generatetweet.py
import os
import pickle

import numpy as np
import tensorflow as tf
from tqdm import tqdm
# Load the training texts (here, a pickled list of positive tweets)
with open("training_data/positive.pkl", "rb") as f:
    songs = pickle.load(f)

# Join the list of training strings into a single string, separated by blank lines
songs_joined = "\n\n".join(songs)
# Find all unique characters in the joined string
vocab = sorted(set(songs_joined))
# Map each unique character to an index, and build the reverse lookup
char2idx = {u: i for i, u in enumerate(vocab)}
idx2char = np.array(vocab)
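
# Optional sanity check (illustrative, not part of the original script): the
# two lookup tables should round-trip any text drawn from the vocabulary.
sample = ''.join(vocab[:5])
assert ''.join(idx2char[[char2idx[c] for c in sample]]) == sample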
def LSTM(rnn_units):
    # stateful=True keeps the hidden state between successive calls, which is
    # what lets generate_text below feed the model one character at a time.
    return tf.keras.layers.LSTM(
        rnn_units,
        return_sequences=True,
        recurrent_initializer='glorot_uniform',
        recurrent_activation='sigmoid',
        stateful=True,
    )
### Defining the RNN Model ###
def build_model(vocab_size, embedding_dim, rnn_units, batch_size):
    model = tf.keras.Sequential([
        # Layer 1: Embedding layer to transform indices into dense vectors of a fixed embedding size
        tf.keras.layers.Embedding(vocab_size, embedding_dim, batch_input_shape=[batch_size, None]),
        # Layer 2: LSTM with `rnn_units` number of units
        LSTM(rnn_units),
        # Layer 3: Dense (fully connected) layer that maps the LSTM output to vocabulary-size logits
        tf.keras.layers.Dense(vocab_size)
    ])
    return model
### Hyperparameter setting and optimization ###
# Model parameters:
vocab_size = len(vocab)
embedding_dim = 256
rnn_units = 1024  # try values between 1 and 2048
# Checkpoint location:
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "my_ckpt")
# Rebuild the model with batch_size=1 so generation can feed one sequence at a time
model = build_model(vocab_size, embedding_dim, rnn_units, batch_size=1)
# Restore the model weights from the latest checkpoint saved during training
model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))
model.build(tf.TensorShape([1, None]))
model.summary()
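
# Optional shape check (illustrative, not part of the original script): a single
# forward pass on a short dummy sequence should produce logits of shape
# (batch_size, sequence_length, vocab_size) -- here (1, 5, vocab_size).
dummy_input = tf.expand_dims([char2idx[c] for c in vocab[:5]], 0)
print(model(dummy_input).shape)
model.reset_states()  # discard the LSTM state touched by the dummy pass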
### Prediction of generated text ###
def generate_text(model, start_string, generation_length=1000):
    # Evaluation step: generate text using the learned RNN model.

    # Convert the start string to indices (vectorize)
    input_eval = [char2idx[s] for s in start_string]
    input_eval = tf.expand_dims(input_eval, 0)

    # List to store the generated characters
    text_generated = []

    # Here batch size == 1; clear the LSTM state before generating
    model.reset_states()

    for _ in tqdm(range(generation_length)):
        predictions = model(input_eval)
        # Remove the batch dimension
        predictions = tf.squeeze(predictions, 0)

        # Sample from the output distribution; keep only the last time step
        predicted_id = tf.random.categorical(predictions, num_samples=1)[-1, 0].numpy()

        # Pass the prediction along with the previous hidden state as the next input to the model
        input_eval = tf.expand_dims([predicted_id], 0)

        text_generated.append(idx2char[predicted_id])

    return start_string + ''.join(text_generated)
generated_text = generate_text(model, start_string="c", generation_length=1000)
print(generated_text)
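
# Hypothetical extension (not part of the original script): temperature-scaled
# sampling. Dividing the logits by a temperature below 1.0 concentrates the
# distribution (more conservative output); above 1.0 flattens it (more varied).
def generate_text_with_temperature(model, start_string, generation_length=1000,
                                   temperature=1.0):
    input_eval = tf.expand_dims([char2idx[s] for s in start_string], 0)
    text_generated = []
    model.reset_states()
    for _ in tqdm(range(generation_length)):
        predictions = tf.squeeze(model(input_eval), 0) / temperature
        predicted_id = tf.random.categorical(predictions, num_samples=1)[-1, 0].numpy()
        input_eval = tf.expand_dims([predicted_id], 0)
        text_generated.append(idx2char[predicted_id])
    return start_string + ''.join(text_generated)

# Example usage (commented out so the script's behavior is unchanged):
# print(generate_text_with_temperature(model, start_string="c", temperature=0.8))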