Skip to content

Commit ac2bbd6

Browse files
authored
Merge pull request coqui-ai#1868 from JRMeyer/data-augmentation-cleaning
Add logging and clean up some augmentation code
2 parents d2c5f97 + 7bec52c commit ac2bbd6

File tree

2 files changed, with 42 additions and 21 deletions.

training/coqui_stt_training/util/augmentations.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
import os
32
import re
43
import math
@@ -10,6 +9,7 @@
109
from .audio import gain_db_to_ratio, max_dbfs, normalize_audio, AUDIO_TYPE_NP, AUDIO_TYPE_PCM, AUDIO_TYPE_OPUS
1110
from .helpers import LimitingPool, int_range, float_range, pick_value_from_range, tf_pick_value_from_range, MEGABYTE
1211
from .sample_collections import samples_from_source, unpack_maybe
12+
from .logging import log_info
1313

1414
BUFFER_SIZE = 1 * MEGABYTE
1515
SPEC_PARSER = re.compile(r'^(?P<cls>[a-z_]+)(\[(?P<params>.*)\])?$')
@@ -90,6 +90,7 @@ def parse_augmentation(augmentation_spec):
9090
kwargs[pair[0]] = pair[1]
9191
else:
9292
raise ValueError('Unable to parse augmentation value assignment')
93+
log_info('Processed augmentation type: [{}] with parameter settings: {}'.format(augmentation_cls.__name__, kwargs))
9394
return augmentation_cls(*args, **kwargs)
9495

9596

@@ -106,7 +107,7 @@ def parse_augmentations(augmentation_specs):
106107
-------
107108
List of augmentation class instances from util.augmentations.*.
108109
"""
109-
return [] if augmentation_specs is None else list(map(parse_augmentation, augmentation_specs))
110+
return list(map(parse_augmentation, augmentation_specs or []))
110111

111112

112113
def apply_graph_augmentations(domain, tensor, augmentations, transcript=None, clock=0.0):

training/coqui_stt_training/util/helpers.py

+39-19
Original file line numberDiff line numberDiff line change
@@ -163,27 +163,41 @@ def do_iterate():
163163

164164

165165
def get_value_range(value, target_type):
166+
"""
167+
This function converts all possible supplied values for augmentation
168+
into the [start,end,r] ValueRange type. The expected inputs are of the form:
169+
170+
<number>
171+
<number>~<number>
172+
<number>:<number>~<number>
173+
174+
Any "missing" values are filled so that ValueRange always includes [start,end,r].
175+
"""
166176
if isinstance(value, str):
167-
r = target_type(0)
168-
parts = value.split('~')
169-
if len(parts) == 2:
177+
if '~' in value:
178+
parts = value.split('~')
179+
if len(parts) != 2:
180+
raise ValueError('Cannot parse value range')
170181
value = parts[0]
171-
r = target_type(parts[1])
172-
elif len(parts) > 2:
173-
raise ValueError('Cannot parse value range')
182+
r = parts[1]
183+
else:
184+
r = 0 # if no <r> supplied, use 0
174185
parts = value.split(':')
175186
if len(parts) == 1:
176-
parts.append(parts[0])
177-
elif len(parts) > 2:
187+
parts.append(parts[0]) # only one <value> given, so double it
188+
if len(parts) != 2:
178189
raise ValueError('Cannot parse value range')
179-
return ValueRange(target_type(parts[0]), target_type(parts[1]), r)
190+
return ValueRange(target_type(parts[0]), target_type(parts[1]), target_type(r))
180191
if isinstance(value, tuple):
181192
if len(value) == 2:
182-
return ValueRange(target_type(value[0]), target_type(value[1]), 0)
193+
return ValueRange(target_type(value[0]), target_type(value[1]), target_type(0))
183194
if len(value) == 3:
184195
return ValueRange(target_type(value[0]), target_type(value[1]), target_type(value[2]))
185-
raise ValueError('Cannot convert to ValueRange: Wrong tuple size')
186-
return ValueRange(target_type(value), target_type(value), 0)
196+
else:
197+
raise ValueError('Cannot convert to ValueRange: Wrong tuple size')
198+
if isinstance(value, int) or isinstance(value, float):
199+
return ValueRange(target_type(value), target_type(value), target_type(0))
200+
raise ValueError('Cannot convert to ValueRange: Wrong tuple size')
187201

188202

189203
def int_range(value):
@@ -203,14 +217,20 @@ def pick_value_from_range(value_range, clock=None):
203217

204218
def tf_pick_value_from_range(value_range, clock=None, double_precision=False):
205219
import tensorflow as tf # pylint: disable=import-outside-toplevel
206-
clock = (tf.random.stateless_uniform([], seed=(-1, 1), dtype=tf.float64) if clock is None
207-
else tf.maximum(tf.constant(0.0, dtype=tf.float64), tf.minimum(tf.constant(1.0, dtype=tf.float64), clock)))
220+
if clock is None:
221+
clock = tf.random.stateless_uniform([], seed=(-1, 1), dtype=tf.float64)
222+
else:
223+
clock = tf.maximum(tf.constant(0.0, dtype=tf.float64),
224+
tf.minimum(tf.constant(1.0, dtype=tf.float64), clock))
208225
value = value_range.start + clock * (value_range.end - value_range.start)
209-
value = tf.random.stateless_uniform([],
210-
minval=value - value_range.r,
211-
maxval=value + value_range.r,
212-
seed=(clock * tf.int32.min, clock * tf.int32.max),
213-
dtype=tf.float64)
226+
if value_range.r:
227+
# if the option <r> (<value>~<r>, randomization radius) is supplied,
228+
# sample the value from a uniform distribution with "radius" <r>
229+
value = tf.random.stateless_uniform([],
230+
minval=value - value_range.r,
231+
maxval=value + value_range.r,
232+
seed=(clock * tf.int32.min, clock * tf.int32.max),
233+
dtype=tf.float64)
214234
if isinstance(value_range.start, int):
215235
return tf.cast(tf.math.round(value), tf.int64 if double_precision else tf.int32)
216236
return tf.cast(value, tf.float64 if double_precision else tf.float32)

0 commit comments

Comments
 (0)