-
Notifications
You must be signed in to change notification settings - Fork 47
/
Copy pathrun_sample.py
35 lines (26 loc) · 1.05 KB
/
run_sample.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch
import argparse
from load_model import load_model
from transformers import GPT2TokenizerFast
import torch.nn.functional as F
import sampling
def main():
    """Draw one batch of samples from a pretrained model and print them.

    Command-line options (all optional):
        --model_path  Identifier handed to ``load_model``
                      (default ``louaaron/sedd-medium``).
        --dataset     Parsed for CLI compatibility; not read in this function
                      (default ``wikitext103``).
        --batch_size  First entry of the shape tuple passed to the sampler
                      (default 1).
        --steps       Step count forwarded to ``sampling.get_pc_sampler``
                      (default 1024).

    Each decoded sample is written to stdout, followed by a separator line.
    """
    cli = argparse.ArgumentParser(description="Generate some samples")
    # (flag, type, default) -- registered in this order so --help is unchanged.
    option_table = (
        ("--model_path", str, "louaaron/sedd-medium"),
        ("--dataset", str, "wikitext103"),
        ("--batch_size", int, 1),
        ("--steps", int, 1024),
    )
    for flag, value_type, fallback in option_table:
        cli.add_argument(flag, type=value_type, default=fallback)
    opts = cli.parse_args()

    # The device is fixed to CUDA; presumably a GPU is required -- TODO confirm
    # against load_model.
    cuda = torch.device('cuda')
    model, graph, noise = load_model(opts.model_path, cuda)
    tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')

    # Second dimension is hard-coded to 1024 (presumably the sequence length
    # in tokens -- verify against the sampler).
    batch_dims = (opts.batch_size, 1024)
    draw_batch = sampling.get_pc_sampler(
        graph,
        noise,
        batch_dims,
        'analytic',
        opts.steps,
        device=cuda,
    )

    decoded = tokenizer.batch_decode(draw_batch(model))

    divider = "================================================="
    for text in decoded:
        print(text)
        print(divider)
# Run the sampler only when executed as a script, not when imported.
if __name__ == "__main__":
    main()