-
Notifications
You must be signed in to change notification settings - Fork 87
/
Copy pathsample.py
39 lines (30 loc) · 1.24 KB
/
sample.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import argparse
import os
import numpy as np
import torch as t
from utils.batch_loader import BatchLoader
from utils.parameters import Parameters
from model.rvae import RVAE
if __name__ == '__main__':
assert os.path.exists('trained_RVAE'), \
'trained model not found'
parser = argparse.ArgumentParser(description='Sampler')
parser.add_argument('--use-cuda', type=bool, default=True, metavar='CUDA',
help='use cuda (default: True)')
parser.add_argument('--num-sample', type=int, default=10, metavar='NS',
help='num samplings (default: 10)')
args = parser.parse_args()
batch_loader = BatchLoader('')
parameters = Parameters(batch_loader.max_word_len,
batch_loader.max_seq_len,
batch_loader.words_vocab_size,
batch_loader.chars_vocab_size)
rvae = RVAE(parameters)
rvae.load_state_dict(t.load('trained_RVAE'))
if args.use_cuda:
rvae = rvae.cuda()
for iteration in range(args.num_sample):
seed = np.random.normal(size=[1, parameters.latent_variable_size])
result = rvae.sample(batch_loader, 50, seed, args.use_cuda)
print(result)
print()