diff --git a/train.py b/train.py old mode 100644 new mode 100755 index 7310903b7..99d646407 --- a/train.py +++ b/train.py @@ -19,6 +19,7 @@ from wavenet import WaveNetModel, AudioReader, optimizer_factory +NUM_GPUS = 1 BATCH_SIZE = 1 DATA_DIRECTORY = './VCTK-Corpus' LOGDIR_ROOT = './logdir' @@ -94,6 +95,8 @@ def _str_to_bool(s): default=MOMENTUM, help='Specify the momentum to be ' 'used by sgd or rmsprop optimizer. Ignored by the ' 'adam optimizer.') + parser.add_argument('--num_gpus', type=int, default=NUM_GPUS, + help='number of gpus to use') parser.add_argument('--histograms', type=_str_to_bool, default=False, help='Whether to store histogram summaries.') return parser.parse_args() @@ -177,6 +180,30 @@ def validate_directories(args): 'restore_from': restore_from } +def make_net(args,wavenet_params,audio_batch,reuse_variables): + # Create network. + net = WaveNetModel( + batch_size=args.batch_size, + dilations=wavenet_params["dilations"], + filter_width=wavenet_params["filter_width"], + residual_channels=wavenet_params["residual_channels"], + dilation_channels=wavenet_params["dilation_channels"], + skip_channels=wavenet_params["skip_channels"], + quantization_channels=wavenet_params["quantization_channels"], + use_biases=wavenet_params["use_biases"], + scalar_input=wavenet_params["scalar_input"], + initial_filter_width=wavenet_params["initial_filter_width"], + reuse_variables=reuse_variables, + histograms=args.histograms) + if args.l2_regularization_strength == 0: + args.l2_regularization_strength = None + loss = net.loss(audio_batch, args.l2_regularization_strength) + optimizer = optimizer_factory[args.optimizer]( + learning_rate=args.learning_rate, + momentum=args.momentum) + trainable = tf.trainable_variables() + return loss, optimizer, trainable + def main(): args = get_arguments() @@ -214,38 +241,52 @@ def main(): sample_rate=wavenet_params['sample_rate'], sample_size=args.sample_size, silence_threshold=args.silence_threshold) - audio_batch = 
reader.dequeue(args.batch_size) - # Create network. - net = WaveNetModel( - batch_size=args.batch_size, - dilations=wavenet_params["dilations"], - filter_width=wavenet_params["filter_width"], - residual_channels=wavenet_params["residual_channels"], - dilation_channels=wavenet_params["dilation_channels"], - skip_channels=wavenet_params["skip_channels"], - quantization_channels=wavenet_params["quantization_channels"], - use_biases=wavenet_params["use_biases"], - scalar_input=wavenet_params["scalar_input"], - initial_filter_width=wavenet_params["initial_filter_width"], - histograms=args.histograms) - if args.l2_regularization_strength == 0: - args.l2_regularization_strength = None - loss = net.loss(audio_batch, args.l2_regularization_strength) - optimizer = optimizer_factory[args.optimizer]( - learning_rate=args.learning_rate, - momentum=args.momentum) - trainable = tf.trainable_variables() - optim = optimizer.minimize(loss, var_list=trainable) + tower_grads = [] + tower_losses = [] + for device_index in xrange(args.num_gpus): + with tf.device('/gpu:%d' % device_index), tf.name_scope('tower_%d' % device_index) as scope: + audio_batch = reader.dequeue(args.batch_size) + loss, optimizer, trainable = make_net(args,wavenet_params,audio_batch,reuse_variables=True) + grads = optimizer.compute_gradients(loss, var_list=trainable) + tower_losses.append(loss) + tower_grads.append(grads) + summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, scope) + tf.get_variable_scope().reuse_variables() + + if args.num_gpus == 1: + optim = optimizer.minimize(loss, var_list=trainable) + else: + loss = tf.reduce_mean(tower_losses) + average_grads = [] + for grad_and_vars in zip(*tower_grads): + grads = [] + for g,_ in grad_and_vars: + if g is None: + continue + expanded_g = tf.expand_dims(g,0) + grads.append(expanded_g) + + v = grad_and_vars[0][1] + if len(grads) == 0: + average_grads.append((None,v)) + continue + grad = tf.concat(0,grads) + grad = tf.reduce_mean(grad,0) + + grad_and_var = 
(grad,v) + average_grads.append(grad_and_var) + optim = optimizer.apply_gradients(average_grads) # Set up logging for TensorBoard. writer = tf.train.SummaryWriter(logdir) writer.add_graph(tf.get_default_graph()) run_metadata = tf.RunMetadata() - summaries = tf.merge_all_summaries() + summaries = tf.merge_summary(summaries) + #summaries = tf.merge_all_summaries() # Set up session - sess = tf.Session(config=tf.ConfigProto(log_device_placement=False)) + sess = tf.Session(config=tf.ConfigProto(log_device_placement=False,allow_soft_placement=True)) init = tf.initialize_all_variables() sess.run(init) diff --git a/wavenet/audio_reader.py b/wavenet/audio_reader.py old mode 100644 new mode 100755 diff --git a/wavenet/model.py b/wavenet/model.py index 071e4738f..5d02aa663 100644 --- a/wavenet/model.py +++ b/wavenet/model.py @@ -3,7 +3,7 @@ from .ops import causal_conv, mu_law_encode -def create_variable(name, shape): +def _create_variable(name, shape): '''Create a convolution filter variable with the specified name and shape, and initialize it using Xavier initialition.''' initializer = tf.contrib.layers.xavier_initializer_conv2d() @@ -11,13 +11,28 @@ def create_variable(name, shape): return variable -def create_bias_variable(name, shape): +def _create_bias_variable(name, shape): '''Create a bias variable with the specified name and shape and initialize it to zero.''' initializer = tf.constant_initializer(value=0.0, dtype=tf.float32) return tf.Variable(initializer(shape=shape), name) +def _get_variable(name, shape): + '''Create a convolution filter variable with the specified name and shape, + and initialize it using Xavier initialition.''' + initializer = tf.contrib.layers.xavier_initializer_conv2d() + variable = tf.get_variable(name, initializer=initializer(shape=shape)) + return variable + + +def _get_bias_variable(name, shape): + '''Create a bias variable with the specified name and shape and initialize + it to zero.''' + initializer = tf.constant_initializer(value=0.0, 
dtype=tf.float32) + return tf.get_variable(name, initializer=initializer(shape=shape)) + + class WaveNetModel(object): '''Implements the WaveNet network for generative audio. @@ -43,6 +58,7 @@ def __init__(self, quantization_channels=2**8, use_biases=False, scalar_input=False, + reuse_variables=False, initial_filter_width=32, histograms=False): '''Initializes the WaveNet model. @@ -83,6 +99,7 @@ def __init__(self, self.scalar_input = scalar_input self.initial_filter_width = initial_filter_width self.histograms = histograms + self.reuse_variables = reuse_variables self.variables = self._create_variables() @@ -93,6 +110,13 @@ def _create_variables(self): var = dict() + if self.reuse_variables: + create_variable = _get_variable + create_bias_variable = _get_bias_variable + else: + create_variable = _create_variable + create_bias_variable = _create_bias_variable + with tf.variable_scope('wavenet'): with tf.variable_scope('causal_layer'): layer = dict()