Skip to content

Fix sample generation for scalar inputs #292

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open

Fix sample generation for scalar inputs #292

wants to merge 1 commit into from

Conversation

nsfinkelstein
Copy link

Sample generation with scalar inputs is broken. When it is reshaped in predict_proba, it is given too few dimensions. This pull requests corrects the number of dimensions.

Here is a simple reproducible example.

Set the value of wavenet_params.json to the following:

{
    "filter_width": 2,
    "sample_rate": 16000,
    "dilations": [1, 2, 4, 8],
    "residual_channels": 32,
    "dilation_channels": 32,
    "quantization_channels": 256,
    "skip_channels": 512,
    "use_biases": true,
    "scalar_input": true,
    "initial_filter_width": 2
}

Training:

$ python train.py --data_dir train_data --silence_threshold 0.0001
Using default logdir: ./logdir/train/2017-09-18T13-42-07
files length: 100
step 0 - loss = 5.537, (1.306 sec/step)
Storing checkpoint to ./logdir/train/2017-09-18T13-42-07 ...files length: 100

Generating (before the fix):

$ python generate.py logdir/train/2017-09-18T13-42-07/model.ckpt-0 --fast_generation false

Traceback (most recent call last):
  File "/home/noam/code/anaconda/lib/python3.6/site-packages/tensorflow/python/framework/common_shapes.py", line 671, in _call_cpp_shape_fn_impl
    input_tensors_as_shapes, status)
  File "/home/noam/code/anaconda/lib/python3.6/contextlib.py", line 89, in __exit__
    next(self.gen)
  File "/home/noam/code/anaconda/lib/python3.6/site-packages/tensorflow/python/framework/errors_impl.py", line 466, in raise_exception_on_not_ok_status
    pywrap_tensorflow.TF_GetCode(status))
tensorflow.python.framework.errors_impl.InvalidArgumentError: Shape must be rank 4 but is rank 3 for 'wavenet_1/causal_layer/causal_conv/conv1d/Conv2D' (op: 'Conv2D') with input shapes: [?,1,1], [1,2,1,32].

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "generate.py", line 279, in <module>
    main()
  File "generate.py", line 163, in main
    next_sample = net.predict_proba(samples, args.gc_id)
  File "/home/noam/code/heartrate-wavenet/wavenet/model.py", line 586, in predict_proba
    raw_output = self._create_network(encoded, gc_embedding)
  File "/home/noam/code/heartrate-wavenet/wavenet/model.py", line 406, in _create_network
    current_layer = self._create_causal_layer(current_layer)
  File "/home/noam/code/heartrate-wavenet/wavenet/model.py", line 243, in _create_causal_layer
    return causal_conv(input_batch, weights_filter, 1)
  File "/home/noam/code/heartrate-wavenet/wavenet/ops.py", line 55, in causal_conv
    restored = tf.nn.conv1d(value, filter_, stride=1, padding='VALID')
  File "/home/noam/code/anaconda/lib/python3.6/site-packages/tensorflow/python/ops/nn_ops.py", line 2010, in conv1d
    data_format=data_format)
  File "/home/noam/code/anaconda/lib/python3.6/site-packages/tensorflow/python/ops/gen_nn_ops.py", line 399, in conv2d
    data_format=data_format, name=name)
  File "/home/noam/code/anaconda/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 767, in apply_op
    op_def=op_def)
  File "/home/noam/code/anaconda/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 2508, in create_op
    set_shapes_for_outputs(ret)
  File "/home/noam/code/anaconda/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1873, in set_shapes_for_outputs
    shapes = shape_func(op)
  File "/home/noam/code/anaconda/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1823, in call_with_requiring
    return call_cpp_shape_fn(op, require_shape_fn=True)
  File "/home/noam/code/anaconda/lib/python3.6/site-packages/tensorflow/python/framework/common_shapes.py", line 610, in call_cpp_shape_fn
    debug_python_shape_fn, require_shape_fn)
  File "/home/noam/code/anaconda/lib/python3.6/site-packages/tensorflow/python/framework/common_shapes.py", line 676, in _call_cpp_shape_fn_impl
    raise ValueError(err.message)
ValueError: Shape must be rank 4 but is rank 3 for 'wavenet_1/causal_layer/causal_conv/conv1d/Conv2D' (op: 'Conv2D') with input shapes: [?,1,1], [1,2,1,32].

Generating (after the fix):

$ python generate.py logdir/train/2017-09-18T13-42-07/model.ckpt-0 --fast_generation false
Restoring model from logdir/train/2017-09-18T13-42-07/model.ckpt-0
Sample 15506/16000
Finished generating. The result can be viewed in TensorBoard.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant