Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add VTLP to the demo script #1

Open
wants to merge 3 commits into
base: feature/vtlp
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
308 changes: 171 additions & 137 deletions scripts/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,15 @@
BandStopFilter,
TimeInversion,
Padding,
VTLP,
)
from torch_audiomentations.augmentations.shuffle_channels import ShuffleChannels
from torch_audiomentations.core.transforms_interface import ModeNotSupportedException
from torch_audiomentations.core.transforms_interface import (
ModeNotSupportedException,
MultichannelAudioNotSupportedException,
)
from torch_audiomentations.utils.object_dict import ObjectDict

SAMPLE_RATE = 44100

BASE_DIR = Path(os.path.abspath(os.path.dirname(os.path.dirname(__file__))))
SCRIPTS_DIR = BASE_DIR / "scripts"
TEST_FIXTURES_DIR = BASE_DIR / "test_fixtures"
Expand Down Expand Up @@ -85,141 +87,173 @@ def __exit__(self, type, value, traceback):
np.random.seed(43)
random.seed(43)

filenames = ["perfect-alley1.ogg", "perfect-alley2.ogg"]
samples1, _ = librosa.load(
os.path.join(TEST_FIXTURES_DIR, filenames[0]), sr=SAMPLE_RATE, mono=False
)
samples2, _ = librosa.load(
os.path.join(TEST_FIXTURES_DIR, filenames[1]), sr=SAMPLE_RATE, mono=False
)
samples = np.stack((samples1, samples2), axis=0)
samples = torch.from_numpy(samples)

modes = ["per_batch", "per_example", "per_channel"]
for mode in modes:
transforms = [
{
"get_instance": lambda: AddBackgroundNoise(
background_paths=TEST_FIXTURES_DIR / "bg", mode=mode, p=1.0
),
"num_runs": 5,
},
{"get_instance": lambda: AddColoredNoise(mode=mode, p=1.0), "num_runs": 5},
{
"get_instance": lambda: ApplyImpulseResponse(
ir_paths=TEST_FIXTURES_DIR / "ir", mode=mode, p=1.0
),
"num_runs": 1,
},
{
"get_instance": lambda: ApplyImpulseResponse(
ir_paths=TEST_FIXTURES_DIR / "ir",
compensate_for_propagation_delay=True,
mode=mode,
p=1.0,
),
"name": "ApplyImpulseResponse with compensate_for_propagation_delay set to True",
"num_runs": 1,
},
{"get_instance": lambda: BandPassFilter(mode=mode, p=1.0), "num_runs": 5},
{"get_instance": lambda: BandStopFilter(mode=mode, p=1.0), "num_runs": 5},
{
"get_instance": lambda: Compose(
transforms=[
Gain(
min_gain_in_db=-18.0, max_gain_in_db=-16.0, mode=mode, p=1.0
),
PeakNormalization(mode=mode, p=1.0),
],
shuffle=True,
),
"name": "Shuffled Compose with Gain and PeakNormalization",
"num_runs": 5,
},
{
"get_instance": lambda: Compose(
transforms=[
Gain(
min_gain_in_db=-18.0, max_gain_in_db=-16.0, mode=mode, p=0.5
),
PolarityInversion(mode=mode, p=0.5),
],
shuffle=True,
),
"name": "Compose with Gain and PolarityInversion",
"num_runs": 5,
},
{"get_instance": lambda: Gain(mode=mode, p=1.0), "num_runs": 5},
{"get_instance": lambda: HighPassFilter(mode=mode, p=1.0), "num_runs": 5},
{"get_instance": lambda: LowPassFilter(mode=mode, p=1.0), "num_runs": 5},
{"get_instance": lambda: Padding(mode=mode, p=1.0), "num_runs": 5},
{"get_instance": lambda: PeakNormalization(mode=mode, p=1.0), "num_runs": 1},
{
"get_instance": lambda: PitchShift(
sample_rate=SAMPLE_RATE, mode=mode, p=1.0
),
"num_runs": 5,
},
{"get_instance": lambda: PolarityInversion(mode=mode, p=1.0), "num_runs": 1},
{"get_instance": lambda: Shift(mode=mode, p=1.0), "num_runs": 5},
{"get_instance": lambda: ShuffleChannels(mode=mode, p=1.0), "num_runs": 5},
{"get_instance": lambda: TimeInversion(mode=mode, p=1.0), "num_runs": 1},
]

execution_times = {}

for transform in transforms:
try:
augmenter = transform["get_instance"]()
except ModeNotSupportedException:
continue
transform_name = (
transform.get("name")
if transform.get("name")
else augmenter.__class__.__name__
batch = [["p286_011.wav"], ["perfect-alley1.ogg", "perfect-alley2.ogg"]]

for filenames in batch:
audios = []
batch_sample_rate = None
for filename in filenames:
samples1, batch_sample_rate = librosa.load(
os.path.join(TEST_FIXTURES_DIR, filename), sr=None, mono=False
)
execution_times[transform_name] = []
for i in range(transform["num_runs"]):
with timer() as t:
augmented_samples = augmenter(
samples=samples, sample_rate=SAMPLE_RATE
)
print(
augmenter.__class__.__name__,
"is output ObjectDict:",
type(augmented_samples) is ObjectDict,
)
augmented_samples = (
augmented_samples.samples.numpy()
if type(augmented_samples) is ObjectDict
else augmented_samples.numpy()
)
execution_times[transform_name].append(t.execution_time)
for example_idx, original_filename in enumerate(filenames):
output_file_path = os.path.join(
output_dir,
"{}_{}_{:03d}_{}.wav".format(
transform_name, mode, i, Path(original_filename).stem
),
)
wavfile.write(
output_file_path,
rate=SAMPLE_RATE,
data=augmented_samples[example_idx].transpose(),
)
if samples1.ndim == 1:
samples1 = samples1.reshape((1, -1))
audios.append(samples1)
samples = np.stack(audios, axis=0)
samples = torch.from_numpy(samples)

for transform_name in execution_times:
if len(execution_times[transform_name]) > 1:
print(
"{:<52} {:.3f} s (std: {:.3f} s)".format(
transform_name,
np.mean(execution_times[transform_name]),
np.std(execution_times[transform_name]),
)
modes = ["per_example"]
if samples.shape[0] > 1:
modes.append("per_batch")
if samples.shape[1] > 1:
modes.append("per_channel")
for mode in modes:
transforms = [
{
"get_instance": lambda: AddBackgroundNoise(
background_paths=TEST_FIXTURES_DIR / "bg", mode=mode, p=1.0
),
"num_runs": 5,
},
{
"get_instance": lambda: AddColoredNoise(mode=mode, p=1.0),
"num_runs": 5,
},
{
"get_instance": lambda: ApplyImpulseResponse(
ir_paths=TEST_FIXTURES_DIR / "ir", mode=mode, p=1.0
),
"num_runs": 1,
},
{
"get_instance": lambda: ApplyImpulseResponse(
ir_paths=TEST_FIXTURES_DIR / "ir",
compensate_for_propagation_delay=True,
mode=mode,
p=1.0,
),
"name": "ApplyImpulseResponse with compensate_for_propagation_delay set to True",
"num_runs": 1,
},
{"get_instance": lambda: BandPassFilter(mode=mode, p=1.0), "num_runs": 5},
{"get_instance": lambda: BandStopFilter(mode=mode, p=1.0), "num_runs": 5},
{
"get_instance": lambda: Compose(
transforms=[
Gain(
min_gain_in_db=-18.0,
max_gain_in_db=-16.0,
mode=mode,
p=1.0,
),
PeakNormalization(mode=mode, p=1.0),
],
shuffle=True,
),
"name": "Shuffled Compose with Gain and PeakNormalization",
"num_runs": 5,
},
{
"get_instance": lambda: Compose(
transforms=[
Gain(
min_gain_in_db=-18.0,
max_gain_in_db=-16.0,
mode=mode,
p=0.5,
),
PolarityInversion(mode=mode, p=0.5),
],
shuffle=True,
),
"name": "Compose with Gain and PolarityInversion",
"num_runs": 5,
},
{"get_instance": lambda: Gain(mode=mode, p=1.0), "num_runs": 5},
{"get_instance": lambda: HighPassFilter(mode=mode, p=1.0), "num_runs": 5},
{"get_instance": lambda: LowPassFilter(mode=mode, p=1.0), "num_runs": 5},
{"get_instance": lambda: Padding(mode=mode, p=1.0), "num_runs": 5},
{
"get_instance": lambda: PeakNormalization(mode=mode, p=1.0),
"num_runs": 1,
},
{
"get_instance": lambda: PitchShift(
sample_rate=batch_sample_rate, mode=mode, p=1.0
),
"num_runs": 5,
},
{
"get_instance": lambda: PolarityInversion(mode=mode, p=1.0),
"num_runs": 1,
},
{"get_instance": lambda: Shift(mode=mode, p=1.0), "num_runs": 5},
{
"get_instance": lambda: ShuffleChannels(mode=mode, p=1.0),
"num_runs": 5,
},
{"get_instance": lambda: VTLP(mode=mode, p=1.0), "num_runs": 5},
{"get_instance": lambda: TimeInversion(mode=mode, p=1.0), "num_runs": 1},
]

execution_times = {}

for transform in transforms:
try:
augmenter = transform["get_instance"]()
except ModeNotSupportedException:
continue
transform_name = (
transform.get("name")
if transform.get("name")
else augmenter.__class__.__name__
)
else:
print(
"{:<52} {:.3f} s".format(
transform_name, np.mean(execution_times[transform_name])
execution_times[transform_name] = []
for i in range(transform["num_runs"]):
with timer() as t:
try:
augmented_samples = augmenter(
samples=samples, sample_rate=batch_sample_rate
)
except MultichannelAudioNotSupportedException as e:
print(e)
continue
print(
augmenter.__class__.__name__,
"is output ObjectDict:",
type(augmented_samples) is ObjectDict,
)
augmented_samples = (
augmented_samples.samples.numpy()
if type(augmented_samples) is ObjectDict
else augmented_samples.numpy()
)
execution_times[transform_name].append(t.execution_time)
for example_idx, original_filename in enumerate(filenames):
output_file_path = os.path.join(
output_dir,
"{}_{}_{:03d}_{}.wav".format(
transform_name, mode, i, Path(original_filename).stem
),
)
wavfile.write(
output_file_path,
rate=batch_sample_rate,
data=augmented_samples[example_idx].transpose(),
)

for transform_name in execution_times:
if len(execution_times[transform_name]) > 1:
print(
"{:<52} {:.3f} s (std: {:.3f} s)".format(
transform_name,
np.mean(execution_times[transform_name]),
np.std(execution_times[transform_name]),
)
)
else:
print(
"{:<52} {:.3f} s".format(
transform_name, np.mean(execution_times[transform_name])
)
)
)
Binary file added test_fixtures/p286_011.wav
Binary file not shown.
1 change: 1 addition & 0 deletions test_fixtures/p286_011_license
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
p286_011.wav comes from "Noisy speech database for training speech enhancement algorithms and TTS models", published by University of Edinburgh. School of Informatics. Centre for Speech Technology Research (CSTR). The license is Creative Commons License: Attribution 4.0 International. For more info, including a link to the original license file, see https://datashare.ed.ac.uk/handle/10283/2791
2 changes: 0 additions & 2 deletions torch_audiomentations/augmentations/vtlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,6 @@ def __init__(
)
if min_warp_factor >= max_warp_factor:
raise ValueError("max_warp_factor must be > min_warp_factor")
if not sample_rate:
raise ValueError("sample_rate is invalid.")

self.min_warp_factor = min_warp_factor
self.max_warp_factor = max_warp_factor
Expand Down