Skip to content

Commit

Permalink
add instructions
Browse files Browse the repository at this point in the history
  • Loading branch information
kuleshov committed Jun 29, 2017
1 parent 807de6f commit 968d313
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 40 deletions.
30 changes: 20 additions & 10 deletions data/vctk/prep_vctk.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,27 @@

parser = argparse.ArgumentParser()

parser.add_argument('--file-list')
parser.add_argument('--out')
parser.add_argument('--in-dir', default='~/')
parser.add_argument('--interpolate', action='store_true')
parser.add_argument('--low-pass', action='store_true')
parser.add_argument('--dimension', type=int, default=6400)
parser.add_argument('--stride', type=int, default=3200)
parser.add_argument('--scale', type=int, default=2)
parser.add_argument('--batch-size', type=int, default=128)
parser.add_argument('--file-list',
help='list of input wav files to process')
parser.add_argument('--in-dir', default='~/',
help='folder where input files are located')
parser.add_argument('--out',
help='path to output h5 archive')
parser.add_argument('--scale', type=int, default=2,
help='scaling factor')
parser.add_argument('--dimension', type=int, default=8192,
help='dimension of patches')
parser.add_argument('--stride', type=int, default=3200,
help='stride when extracting patches')
parser.add_argument('--interpolate', action='store_true',
help='interpolate low-res patches with cubic splines')
parser.add_argument('--low-pass', action='store_true',
help='apply low-pass filter when generating low-res patches')
parser.add_argument('--batch-size', type=int, default=128,
help='we produce # of patches that is a multiple of batch size')
parser.add_argument('--sr', type=int, default=16000, help='audio sampling rate')
parser.add_argument('--sam', type=float, default=1.)
parser.add_argument('--sam', type=float, default=1.,
help='subsampling factor for the data')

args = parser.parse_args()

Expand Down
55 changes: 25 additions & 30 deletions src/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,28 +23,36 @@ def make_parser():
train_parser = subparsers.add_parser('train')
train_parser.set_defaults(func=train)

train_parser.add_argument('--train', required=True)
train_parser.add_argument('--val', required=True)
train_parser.add_argument('-e', '--epochs', type=int, default=100)
train_parser.add_argument('--batch-size', type=int, default=128)
train_parser.add_argument('--load', help='preload-existing model')
train_parser.add_argument('--logname', default='tmp-run')
train_parser.add_argument('--layers', default=4, type=int)
train_parser.add_argument('--alg', default='adam')
train_parser.add_argument('--lr', default=1e-3, type=float)
train_parser.add_argument('--sr', help='sampling rate for the wav',
type=int, default=16000)
train_parser.add_argument('--train', required=True,
help='path to h5 archive of training patches')
train_parser.add_argument('--val', required=True,
help='path to h5 archive of validation set patches')
train_parser.add_argument('-e', '--epochs', type=int, default=100,
help='number of epochs to train')
train_parser.add_argument('--batch-size', type=int, default=128,
help='training batch size')
train_parser.add_argument('--logname', default='tmp-run',
help='folder where logs will be stored')
train_parser.add_argument('--layers', default=4, type=int,
help='number of layers in each of the D and U halves of the network')
train_parser.add_argument('--alg', default='adam',
help='optimization algorithm')
train_parser.add_argument('--lr', default=1e-3, type=float,
help='learning rate')

# eval

eval_parser = subparsers.add_parser('eval')
eval_parser.set_defaults(func=eval)

eval_parser.add_argument('--logname', required=True)
eval_parser.add_argument('--out-label', default='')
eval_parser.add_argument('--wav-file-list', help='list of audio files')
eval_parser.add_argument('--logname', required=True,
help='path to training checkpoint')
eval_parser.add_argument('--out-label', default='',
help='append label to output samples')
eval_parser.add_argument('--wav-file-list',
help='list of audio files for evaluation')
eval_parser.add_argument('--r', help='upscaling factor', type=int)
eval_parser.add_argument('--sr', help='sampling rate',
eval_parser.add_argument('--sr', help='high-res sampling rate',
type=int, default=16000)

return parser
Expand All @@ -61,21 +69,8 @@ def train(args):
r = Y_train[0].shape[1] / X_train[0].shape[1]
assert n_chan == 1

# # determine super-resolution level
# n_chan, n_dim = Y_train[0].shape
# r = Y_train[0].shape[1] / X_train[0].shape[1]
# assert n_chan == 1

# # transponse to match tensorflow (batch, height, width, chan) format
# X_train, X_val = X_train.transpose([0,2,1]), X_val.transpose([0,2,1])
# Y_train, Y_val = Y_train.transpose([0,2,1]), Y_val.transpose([0,2,1])

# load model
from_ckpt = True if args.load is not None else False
model = get_model(args, n_dim, r, from_ckpt=from_ckpt, train=True)

# load checkpoint
if from_ckpt: model.load(args.load)
# create model
model = get_model(args, n_dim, r, from_ckpt=False, train=True)

# train model
model.fit(X_train, Y_train, X_val, Y_val, n_epoch=args.epochs)
Expand Down

0 comments on commit 968d313

Please sign in to comment.