-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
from transformers import AutoModelForCausalLM, AutoTokenizer | ||
import torch | ||
import re | ||
import argparse | ||
|
||
|
||
def main():
    """Paraphrase one aligned input file with the GPT-2 Shakespeare paraphraser.

    Reads ``--filename`` and ``--outfile`` from the command line, loads the
    ``filco306/gpt2-shakespeare-paraphraser`` tokenizer and model, moves the
    model to ``cuda:0`` when CUDA is available (otherwise ``cpu``), and passes
    the resolved input/output paths to ``generate``.

    NOTE(review): no call to ``main()`` is visible in this file; confirm that
    an ``if __name__ == "__main__":`` guard exists so the script actually runs.
    """
    parser = argparse.ArgumentParser()
    # Both values are concatenated into paths below, so a missing argument
    # would otherwise only surface as a TypeError on ``str + None``.
    parser.add_argument('--filename', dest='filename', type=str, required=True, help='Name of input file')
    parser.add_argument('--outfile', dest='outfile', type=str, required=True, help='Name of output file')
    args = parser.parse_args()

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    print("Device is ", device)

    model_name = "filco306/gpt2-shakespeare-paraphraser"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Reuse the EOS token as the pad token id when loading the model.
    model = AutoModelForCausalLM.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id)
    model.to(device)

    base_input_path = "./shakespeare/sparknotes/merged/"
    base_output_path = "./output/"

    # ``generate`` is a module-level function; there is no ``self`` in this
    # scope, so it must be called directly.
    generate(base_input_path + args.filename, base_output_path + args.outfile, tokenizer, model, device)
|
||
|
||
def generate(input_path, output_path, tokenizer, model, device):
    """Run the model over each line of ``input_path`` and write the results.

    The input file is read in full and split on ``'\n'``. For every line a
    newline is written to ``output_path`` first, followed by the decoded model
    output for that line, so the output file starts with an empty line and
    output line ``i + 1`` corresponds to input line ``i``.

    Args:
        input_path: Path of the text file to read, one sentence per line.
        output_path: Path of the file to (over)write with the generations.
        tokenizer: Object providing ``encode(text, return_tensors='pt')`` and
            ``decode(ids, skip_special_tokens=True)``.
        model: Object providing ``generate(input_ids, max_length=..., ...)``.
        device: Device identifier the encoded ids are moved to via ``.to()``.
    """
    # Read the single input file once; the handle is closed by the context
    # manager instead of being leaked.
    with open(input_path, "r") as source:
        text = source.read()

    torch.manual_seed(0)
    with open(output_path, "w") as output_file:
        for sentence in text.split('\n'):
            output_file.write('\n')
            if not sentence:
                # Empty lines (e.g. after a trailing newline) have nothing to
                # encode; keep the line slot but skip the model call.
                continue
            input_ids = tokenizer.encode(sentence, return_tensors='pt').to(device)
            # ceil(1.5 * prompt_length) in integer arithmetic: max_length must
            # be an int, and rounding up keeps the same token budget the
            # original float bound allowed.
            max_length = (3 * input_ids.shape[1] + 1) // 2
            # NOTE(review): top_p is passed without do_sample=True and
            # early_stopping without beam search; per the transformers docs
            # these presumably have no effect here -- confirm intent.
            generated = model.generate(input_ids, max_length=max_length, early_stopping=True, top_p=0.90)
            output_file.write(tokenizer.decode(generated[0], skip_special_tokens=True))
|
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
# Run the baseline paraphraser once per play. Each play's input is
# "<play>_original.snt.aligned" and its output is "<play>_translated.txt";
# baseline.py prepends the input and output directories itself.
#
# NOTE(review): the Julius Caesar output was previously written as
# "juliuscaeser_translated.txt" (misspelled). It now follows the same
# "<play>_translated.txt" pattern as every other play -- confirm nothing
# downstream still expects the misspelled name.
for play in \
    antony-and-cleopatra \
    asyoulikeit \
    errors \
    hamlet \
    henryv \
    juliuscaesar \
    lear \
    macbeth \
    merchant \
    msnd \
    muchado \
    othello \
    richardiii \
    romeojuliet \
    shrew \
    tempest \
    twelfthnight
do
    python3 baseline.py --filename "${play}_original.snt.aligned" --outfile "${play}_translated.txt"
done