extract_baseline.py
"""Paraphrase each line of an input file into Shakespearean English
using the filco306/gpt2-shakespeare-paraphraser GPT-2 checkpoint."""
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch


def main():
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    print("Device is", device)
    tokenizer = AutoTokenizer.from_pretrained("filco306/gpt2-shakespeare-paraphraser")
    # GPT-2 has no dedicated pad token, so reuse the EOS token for padding.
    model = AutoModelForCausalLM.from_pretrained(
        "filco306/gpt2-shakespeare-paraphraser",
        pad_token_id=tokenizer.eos_token_id,
    )
    model.to(device)
    generate("../data/extract.txt", "extract_baseline3.txt", tokenizer, model, device)


def generate(input_path, output_path, tokenizer, model, device):
    """Paraphrase each input line with beam search, writing one generation per line."""
    with open(input_path, "r") as f:
        sentences = f.readlines()
    with open(output_path, "w") as output_file:
        for sentence in sentences:
            input_ids = tokenizer.encode(sentence.strip(), return_tensors="pt").to(device)
            # Beam search with 5 beams; block repeated bigrams and cap output at 80 tokens.
            beam_output = model.generate(
                input_ids=input_ids,
                num_beams=5,
                no_repeat_ngram_size=2,
                max_length=80,
                early_stopping=True,
            )
            output_file.write("\n")
            output_file.write(tokenizer.decode(beam_output[0], skip_special_tokens=True))


if __name__ == "__main__":
    main()
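
The original script imports argparse without using it, and the input and output paths are hard-coded in main(). Below is a minimal sketch of how those paths could be exposed as command-line arguments; the flag names --input and --output and the parse_args helper are assumptions for illustration, not part of the original script.

import argparse

def parse_args():
    # Hypothetical CLI wiring: flag names and defaults are assumptions,
    # chosen to match the paths hard-coded in the original main().
    parser = argparse.ArgumentParser(
        description="Shakespeare-paraphrase each line of a text file."
    )
    parser.add_argument("--input", default="../data/extract.txt",
                        help="Path to the input text file, one sentence per line")
    parser.add_argument("--output", default="extract_baseline3.txt",
                        help="Path for the generated paraphrases")
    return parser.parse_args()

With this in place, main() would call args = parse_args() and pass args.input and args.output to generate() instead of the literal paths, e.g. python extract_baseline.py --input ../data/extract.txt --output extract_baseline3.txt.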