-
Notifications
You must be signed in to change notification settings - Fork 14
/
Copy pathstreamlit_app.py
40 lines (26 loc) · 1.26 KB
/
streamlit_app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import streamlit as st
import json
import torch
from collections import Counter
import generate_text
#===========================================#
# Loads Model and word_to_id #
#===========================================#
# NOTE(review): Streamlit re-executes this script on every widget interaction,
# so everything below is reloaded on each rerun. Consider Streamlit's caching
# decorators if load time becomes noticeable.

# Token -> integer id. Wrapping the mapping in a Counter makes a lookup of a
# token that is not in the vocabulary return 0 instead of raising KeyError;
# index 0 of id_to_word (below) is the matching "<Unknown>" placeholder.
# JSON is UTF-8 by specification, so the encoding is stated explicitly rather
# than left to the platform's locale default.
with open('trained_model/word_to_id.json', encoding='utf-8') as json_file:
    word_to_id = Counter(json.load(json_file))

# Loaded as-is from JSON.
# NOTE(review): presumably the tokens that must always be capitalized in the
# generated text -- confirm against generate_text.
with open('trained_model/always_capitalized.json', encoding='utf-8') as json_file:
    always_capitalized = json.load(json_file)

# Reverse lookup by list position: slot 0 is the unknown-word placeholder and
# the vocabulary words follow in the order they appear in word_to_id.
# NOTE(review): this assumes the JSON file lists words in ascending id order
# starting at 1 -- verify against the code that writes word_to_id.json.
id_to_word = ["<Unknown>"] + list(word_to_id)

# torch.load unpickles the file, so it must only be pointed at a trusted
# checkpoint. The loaded object is used directly as the model (it is not a
# state_dict), hence the .eval() call to put it in inference mode.
net = torch.load('trained_model/trained_model.pt')
net.eval()
#===========================================#
# Streamlit Code #
#===========================================#
# Intro blurb shown under the page title; written in Markdown (italic title
# and a link to the project repository).
desc = (
    "Uses an LSTM neural network trained on *The Lord of the Rings*. "
    "Check out the code [here](https://github.com/christian-doucette/tolkein_text)!"
)

st.title("Lord of the Rings Text Generator")
st.write(desc)

# User inputs: how many sentences to generate (1-20, default 5) and an
# optional seed string.
num_sentences = st.number_input(
    "Number of Sentences",
    min_value=1,
    max_value=20,
    value=5,
)
user_input = st.text_input("Seed Text (can leave blank)")

# Text is generated and displayed only when the button has been pressed.
if st.button("Generate Text"):
    generated_text = generate_text.prediction(
        net,
        word_to_id,
        id_to_word,
        always_capitalized,
        user_input,
        9,  # NOTE(review): meaning of this positional argument is not visible here -- confirm in generate_text
        num_sentences,
    )
    st.write(generated_text)