diff --git a/markovify.py b/markovify.py index d36d21f..bf661f2 100755 --- a/markovify.py +++ b/markovify.py @@ -215,7 +215,7 @@ def process(urls): convert_to_probabilities(bigrams) print(generate_text(bigrams)) -def process_mlx(urls): +def process_mlx(urls, prompt): """process_mlx(list(string)) -> None Prompt the Microsoft Phi-2 LLM with all the text inside
tags and @@ -224,16 +224,17 @@ def process_mlx(urls): all_text = '' for url in urls: text, final_url = fetch_text(url) - print('FETCHED TEXT FROM: %s\n' % final_url) - all_text += text - all_text = all_text.replace('\n', '') - all_text = re.sub('[^a-zA-Z\.\s]', '', all_text) - all_text = re.sub('\s+', ' ', all_text) - all_text = re.sub(' \.', '.', all_text) - prompt = f'Summarize the following. {all_text}.' + print(f'FETCHED TEXT FROM: {final_url}\n') + text = text.replace('\n', '') + text = re.sub('[^a-zA-Z\.\s]', '', text) + text = re.sub('\s+', ' ', text) + text = re.sub(' \.', '.', text) + all_text += text + '\n\n' + prompt = f'{all_text}{prompt}' + print(f'PROMPT: {prompt}\n') from mlx_lm import load, generate model, tokenizer = load('microsoft/phi-2') - response = generate(model, tokenizer, max_tokens=2048, prompt=all_text, \ + response = generate(model, tokenizer, max_tokens=2048, prompt=prompt, \ verbose=True, temp=0.5) if __name__ == '__main__': @@ -241,16 +242,27 @@ def process_mlx(urls): print('usage: %s [list of urls to learn markov chains from]' % sys.argv[0]) print('To use the phi-2 LLM on MacBook with MLX:') + print('%s --mlx' % sys.argv[0]) print('%s --mlx [list of urls to learn from]' % sys.argv[0]) + print('%s --mlx --prompt "use only this prompt"' % sys.argv[0]) + print('%s --mlx --prompt "use this prompt after the web pages" [list of urls to learn from]' % sys.argv[0]) sys.exit(1) pages = [] random_page = 'https://en.wikipedia.org/wiki/Special:Random' if len(sys.argv) > 1 and sys.argv[1].lower() == '--mlx': - if len(sys.argv) < 3: - pages = [random_page, random_page] + prompt = 'Can you summarize the previous text?' + if len(sys.argv) >= 3 and sys.argv[2].lower() == '--prompt': + prompt = sys.argv[3] + if len(sys.argv) == 3: + pages = [] + else: + pages = sys.argv[4:] else: - pages = sys.argv[2:] - process_mlx(pages) + if len(sys.argv) < 3: + pages = [random_page, random_page] + else: + pages = sys.argv[2:] + process_mlx(pages, prompt) else: if len(sys.argv) < 2: pages = [random_page, random_page]