Replaced resample function by new implementation (2) #1412

Open · wants to merge 7 commits into master
Changes from all commits
138 changes: 98 additions & 40 deletions tensorflow_io/python/ops/audio_ops.py
@@ -17,7 +17,7 @@
import sys

import tensorflow as tf

import math
from tensorflow_io.python.ops import core_ops


@@ -372,55 +372,113 @@ def fade(input, fade_in, fade_out, mode, name=None):
return factor_in * factor_out * input


def resample(input, rate_in, rate_out, name=None):
"""Resample audio.
def _get_sinc_resample_kernel(rate_in, rate_out, lowpass_filter_width):
assert lowpass_filter_width > 0
rate_in = tf.cast(rate_in, tf.float32)
rate_out = tf.cast(rate_out, tf.float32)
base_freq = tf.minimum(rate_in, rate_out)
# This will perform antialiasing filtering by removing the highest frequencies.
# At first I thought I only needed this when downsampling, but when upsampling
# you will get edge artifacts without this, as the edge is equivalent to zero padding,
# which will add high freq artifacts.
base_freq *= 0.99

# The key idea of the algorithm is that x(t) can be exactly reconstructed from x[i] (tensor)
# using the sinc interpolation formula:
# x(t) = sum_i x[i] sinc(pi * rate_in * (i / rate_in - t))
# We can then sample the function x(t) with a different sample rate:
# y[j] = x(j / rate_out)
# or,
# y[j] = sum_i x[i] sinc(pi * rate_in * (i / rate_in - j / rate_out))

# We see here that y[j] is the convolution of x[i] with a specific filter, for which
# we take an FIR approximation, stopping when we see at least `lowpass_filter_width` zero crossings.
# But y[j+1] is going to have a different set of weights and so on, until y[j + rate_out].
# Indeed:
# y[j + rate_out] = sum_i x[i] sinc(pi * rate_in * (i / rate_in - (j + rate_out) / rate_out))
# = sum_i x[i] sinc(pi * rate_in * ((i - rate_in) / rate_in - j / rate_out))
# = sum_i x[i + rate_in] sinc(pi * rate_in * (i / rate_in - j / rate_out))
# so y[j+rate_out] uses the same filter as y[j], but on a shifted version of x by `rate_in`.
# This explains the tf.nn.conv1d below, with a stride of rate_in.
width = tf.experimental.numpy.ceil(lowpass_filter_width * rate_in / base_freq)
# If rate_in is still big after GCD reduction, most filters will be very unbalanced, i.e.,
# they will have a lot of almost zero values to the left or to the right...
# There is probably a way to evaluate those filters more efficiently, but this is kept for
# future work.
idx = tf.range(-width, width + rate_in, dtype=tf.float32)
idx = tf.repeat(tf.expand_dims(idx, axis=-1), tf.cast(rate_out, tf.int32), axis=-1)
aux_i = tf.expand_dims(tf.range(rate_out, dtype=tf.float32), axis=0)
kernels = (-aux_i / rate_out + idx / rate_in) * base_freq

kernels = tf.clip_by_value(kernels, -lowpass_filter_width, lowpass_filter_width)
kernels *= math.pi

window = tf.math.cos(kernels / lowpass_filter_width / 2) ** 2
kernels = tf.where(
kernels == 0, tf.ones_like(kernels), tf.math.sin(kernels) / kernels
)
kernels *= window

scale = base_freq / rate_in
return tf.expand_dims(kernels, axis=1) * scale, width
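A quick shape check may help reviewers relate the kernel above to the tf.nn.conv1d call further down. This is an illustrative sketch, not part of the diff; it assumes `_get_sinc_resample_kernel` as defined above is in scope and uses arbitrary reduced rates 3 -> 2.

# Reduced rates 3 -> 2 with the default lowpass_filter_width=6:
#   base_freq = min(3, 2) * 0.99 = 1.98
#   width = ceil(6 * 3 / 1.98) = 10
kernel, width = _get_sinc_resample_kernel(3, 2, 6)
print(width)         # tf.Tensor(10.0, shape=(), dtype=float32)
print(kernel.shape)  # (23, 1, 2), i.e. [2 * width + rate_in, in_channels=1, out_channels=rate_out]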


def resample(input, rate_in, rate_out, lowpass_filter_width=6):
"""Resamples the input at the new frequency. This matches Kaldi’s OfflineFeatureTpl ResampleWaveform which uses a LinearResample (resample a signal at linearly spaced intervals to upsample/downsample a signal). LinearResample (LR) means that the output signal is at linearly spaced intervals (i.e the output signal has a frequency of rate_out). It uses sinc/bandlimited interpolation to upsample/downsample the signal.

Args:
input: A 1-D (`[samples]`) or 2-D (`[samples, channels]`) or 3-D
(`[batch, samples, channels]`) `Tensor` of type
`int16` or `float`. Audio input.
input: A 1-D (`[samples]`) or 2-D (`[samples, channels]`) or 3-D (`[batch, samples, channels]`) `Tensor` of type `float`. Audio input.
rate_in: The rate of the audio input.
rate_out: The rate of the audio output.
name: A name for the operation (optional).
lowpass_filter_width: Controls the sharpness of the filter; a larger value gives a sharper filter but is less efficient. We suggest around 4 to 10 for normal use. (Default: 6)

Returns:
output: Resampled audio.
"""
rank = tf.rank(input)

def f1():
return tf.expand_dims(tf.expand_dims(input, -1), 0)

def f2():
return tf.expand_dims(input, 0)

def f3():
return input

input = tf.case(
[(tf.math.equal(rank, 1), f1), (tf.math.equal(rank, 2), f2)], default=f3
)

def f(i):
return core_ops.io_audio_resample(
i, rate_in=rate_in, rate_out=rate_out, name=name
waveform = input

if rate_in == rate_out:
return waveform
rate_in = tf.cast(rate_in, tf.int32)
rate_out = tf.cast(rate_out, tf.int32)
gcd = tf.experimental.numpy.gcd(rate_in, rate_out)
rate_in = rate_in // gcd
rate_out = rate_out // gcd

kernel, width = _get_sinc_resample_kernel(rate_in, rate_out, lowpass_filter_width)
width = tf.cast(width, tf.int32)

ori_shape = waveform.shape
ori_shape_len = len(ori_shape)
if ori_shape_len == 1:
waveform = tf.expand_dims(waveform, axis=0)
elif ori_shape_len == 2:
waveform = tf.transpose(waveform, [1, 0])
elif ori_shape_len == 3:
waveform = tf.transpose(waveform, [0, 2, 1])
waveform = tf.reshape(waveform, [ori_shape[0] * ori_shape[2], ori_shape[1]])

waveform = tf.expand_dims(waveform, axis=-1)

num_wavs, length, _ = waveform.shape

waveform = tf.pad(waveform, [[0, 0], [width, width + rate_in], [0, 0]])
resampled = tf.nn.conv1d(waveform, kernel, stride=tf.reshape(rate_in, [1]), padding="VALID")
resampled = tf.reshape(resampled, [num_wavs, -1])
target_length = tf.cast(tf.experimental.numpy.ceil(rate_out * length / rate_in), tf.int32)
if ori_shape_len == 1:
return resampled[0, :target_length]
elif ori_shape_len == 2:
return tf.transpose(resampled[:, :target_length], [1, 0])
elif ori_shape_len == 3:
return tf.transpose(
tf.reshape(
resampled[:, :target_length],
[ori_shape[0], ori_shape[2], target_length],
),
[0, 2, 1],
)

value = tf.vectorized_map(f, input)

def g1():
return tf.squeeze(value, [0, -1])

def g2():
return tf.squeeze(value, [0])

def g3():
return value

return tf.case(
[(tf.math.equal(rank, 1), g1), (tf.math.equal(rank, 2), g2)], default=g3
)
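For reviewers trying the change end to end, a minimal usage sketch, assuming the function is exposed as `tfio.audio.resample` as on master; the sine waveform and the 44100 -> 16000 rates are illustrative only.

import math

import tensorflow as tf
import tensorflow_io as tfio

# One second of a 440 Hz sine at 44.1 kHz, as a 1-D `[samples]` float32 tensor.
t = tf.linspace(0.0, 1.0, 44100)
waveform = tf.sin(2.0 * math.pi * 440.0 * t)

# gcd(44100, 16000) == 100, so internally rate_in=441, rate_out=160.
resampled = tfio.audio.resample(waveform, rate_in=44100, rate_out=16000)
print(resampled.shape)  # (16000,), since target_length = ceil(160 * 44100 / 441)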


def decode_wav(
input, shape=None, dtype=None, name=None