-
Notifications
You must be signed in to change notification settings - Fork 5
/
textgen.py
82 lines (70 loc) · 2.96 KB
/
textgen.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import argparse
import random
import sys
from collections import defaultdict
from collections import deque

if sys.version_info >= (3,):
    import queue
else:
    import Queue as queue
def get_next_state(markov_chain, state):
    """Pick a successor of ``state`` at random, weighted by transition count.

    Args:
        markov_chain: Mapping of state -> mapping of next state -> count.
        state: The state whose successor should be sampled.

    Returns:
        One of the recorded successors of ``state``, chosen with probability
        proportional to its count, or ``None`` when ``state`` has no recorded
        successors (unknown state or dead end).
    """
    # Use .get() rather than indexing: indexing a defaultdict with an unknown
    # state would silently insert an empty entry and mutate the chain.
    transitions = markov_chain.get(state)
    if not transitions:
        return None

    next_states = list(transitions)
    total_count = sum(transitions.values())

    # Sample a point in [0, total_count) and walk the cumulative counts.
    # Working with the raw counts avoids the accumulated rounding error of
    # summing per-state float probabilities.
    threshold = random.random() * total_count
    cumulative_count = 0
    for candidate in next_states:
        cumulative_count += transitions[candidate]
        if threshold < cumulative_count:
            return candidate

    # Only reachable through float rounding (or all-zero counts); a state
    # that has successors must never be reported as a dead end.
    return next_states[-1]
def tokenise_text_file(file_name):
    """Read a text file and return its whitespace-separated tokens.

    Args:
        file_name: Path of the file to read.

    Returns:
        A list of the tokens in the file, in order; any run of whitespace
        (including line breaks) separates two tokens.
    """
    with open(file_name, 'r') as source:
        contents = source.read()
    return contents.split()
def create_markov_chain(tokens, order):
    """Build a Markov chain of the given order from a token sequence.

    Each state is ``order`` consecutive tokens joined by single spaces. For
    every pair of adjacent (overlapping) windows in ``tokens`` the count of
    the transition "earlier window -> later window" is incremented.

    Args:
        tokens: Sized iterable of string tokens.
        order: Number of tokens per state; must be at least 1.

    Returns:
        A ``defaultdict`` mapping state -> ``defaultdict(int)`` of
        next state -> number of times that transition was observed. It is
        empty when ``len(tokens) == order`` (a single window, no transition).

    Raises:
        ValueError: If ``order`` is smaller than 1 or larger than the number
            of tokens. (``ValueError`` is an ``Exception`` subclass, so
            callers catching ``Exception`` keep working.)
    """
    if order < 1:
        # Without this guard a non-positive order never fills the window and
        # the previous queue-based implementation blocked forever.
        raise ValueError('Order must be at least 1.')
    if order > len(tokens):
        raise ValueError('Order greater than number of tokens.')

    markov_chain = defaultdict(lambda: defaultdict(int))

    # Sliding window over the last `order` tokens; maxlen makes append()
    # drop the oldest token automatically once the window is full.
    window = deque(maxlen=order)
    current_state = None
    for token in tokens:
        window.append(token)
        if len(window) < order:
            continue  # still filling the very first window
        next_state = ' '.join(window)
        if current_state is not None:
            markov_chain[current_state][next_state] += 1
        current_state = next_state
    return markov_chain
def get_random_state(markov_chain):
    """Choose a random state, preferring ones that start with a capital.

    Args:
        markov_chain: Mapping whose keys are the states of the chain.

    Returns:
        A uniformly random key whose first character is uppercase; if no key
        qualifies, a uniformly random key from the whole chain.
    """
    all_states = list(markov_chain.keys())
    capitalised = [candidate for candidate in all_states
                   if candidate[0].isupper()]
    # Fall back to every state when nothing starts with an uppercase letter.
    return random.choice(capitalised or all_states)
def generate_text(markov_chain, words):
    """Generate text of ``words`` words by walking the Markov chain.

    The walk starts from a state chosen by ``get_random_state``. Each step
    samples a successor with ``get_next_state`` and appends its last word.
    When the current state has no successor the walk restarts from a new
    random state.

    Args:
        markov_chain: Mapping of state -> mapping of next state -> count,
            where each state is a space-separated string of words.
        words: Number of words to generate.

    Returns:
        The generated words joined by single spaces; the empty string when
        ``words`` is zero or negative.
    """
    # A negative count used to slice the start state with [:words] and
    # return a non-empty text; treat every non-positive request as "nothing".
    if words <= 0:
        return ''

    state = get_random_state(markov_chain)
    text = state.split()[:words]
    while len(text) < words:
        next_state = get_next_state(markov_chain, state)
        if next_state is None:
            # Dead end: restart from a fresh state and emit ALL of its words.
            # Appending only the last one would drop the leading words of the
            # restart state whenever the chain order is greater than 1.
            state = get_random_state(markov_chain)
            text.extend(state.split())
        else:
            state = next_state
            # Consecutive states overlap in all but the final word.
            text.append(state.split()[-1])
    # A restart may overshoot the requested length; trim to exactly `words`.
    return ' '.join(text[:words])
if __name__ == '__main__':
    # Command-line entry point: read a text file, build a Markov chain of
    # the requested order from its tokens and print the generated text.
    cli = argparse.ArgumentParser(description='Markov Chain Text Generator')
    cli.add_argument(
        '-f', '--file', required=True,
        help='Name of file to read text from.')
    cli.add_argument(
        '-o', '--order', default=1, type=int,
        help='Number of past states each state depends on.')
    cli.add_argument(
        '-w', '--words', default=100, type=int,
        help='Number of words to generate.')
    options = cli.parse_args()

    source_tokens = tokenise_text_file(options.file)
    chain = create_markov_chain(source_tokens, order=options.order)
    print(generate_text(chain, options.words))