Skip to content

Commit 042abd0

Browse files
committed
🧀 Use drop_remainder=True.
1 parent b301fbb commit 042abd0

File tree

5 files changed

+23
-13
lines changed

5 files changed

+23
-13
lines changed

examples/fastspeech/fastspeech_dataset.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,9 @@ def create(
246246
# define padded shapes
247247
padded_shapes = {"utt_ids": [], "input_ids": [None]}
248248

249-
datasets = datasets.padded_batch(batch_size, padded_shapes=padded_shapes)
249+
datasets = datasets.padded_batch(
250+
batch_size, padded_shapes=padded_shapes, drop_remainder=True
251+
)
250252
datasets = datasets.prefetch(tf.data.experimental.AUTOTUNE)
251253
return datasets
252254

examples/fastspeech2/fastspeech2_dataset.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,9 @@ def create(
227227
"mel_lengths": [],
228228
}
229229

230-
datasets = datasets.padded_batch(batch_size, padded_shapes=padded_shapes)
230+
datasets = datasets.padded_batch(
231+
batch_size, padded_shapes=padded_shapes, drop_remainder=True
232+
)
231233
datasets = datasets.prefetch(tf.data.experimental.AUTOTUNE)
232234
return datasets
233235

examples/fastspeech2_libritts/fastspeech2_dataset.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,9 @@ def create(
231231
"mel_lengths": [],
232232
}
233233

234-
datasets = datasets.padded_batch(batch_size, padded_shapes=padded_shapes)
234+
datasets = datasets.padded_batch(
235+
batch_size, padded_shapes=padded_shapes, drop_remainder=True
236+
)
235237
datasets = datasets.prefetch(tf.data.experimental.AUTOTUNE)
236238
return datasets
237239

examples/melgan/audio_mel_dataset.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -80,15 +80,15 @@ def generator(self, utt_ids):
8080
for i, utt_id in enumerate(utt_ids):
8181
audio_file = self.audio_files[i]
8282
mel_file = self.mel_files[i]
83-
83+
8484
items = {
8585
"utt_ids": utt_id,
8686
"audio_files": audio_file,
87-
"mel_files": mel_file
87+
"mel_files": mel_file,
8888
}
8989

9090
yield items
91-
91+
9292
@tf.function
9393
def _load_data(self, items):
9494
audio = tf.numpy_function(np.load, [items["audio_files"]], tf.float32)
@@ -101,7 +101,7 @@ def _load_data(self, items):
101101
"mel_lengths": len(mel),
102102
"audio_lengths": len(audio),
103103
}
104-
104+
105105
return items
106106

107107
def create(
@@ -120,8 +120,7 @@ def create(
120120

121121
# load dataset
122122
datasets = datasets.map(
123-
lambda items: self._load_data(items),
124-
tf.data.experimental.AUTOTUNE
123+
lambda items: self._load_data(items), tf.data.experimental.AUTOTUNE
125124
)
126125

127126
datasets = datasets.filter(
@@ -165,17 +164,19 @@ def create(
165164
}
166165

167166
datasets = datasets.padded_batch(
168-
batch_size, padded_shapes=padded_shapes, padding_values=padding_values
167+
batch_size,
168+
padded_shapes=padded_shapes,
169+
padding_values=padding_values,
170+
drop_remainder=True,
169171
)
170172
datasets = datasets.prefetch(tf.data.experimental.AUTOTUNE)
171-
172173
return datasets
173174

174175
def get_output_dtypes(self):
175176
output_types = {
176177
"utt_ids": tf.string,
177178
"audio_files": tf.string,
178-
"mel_files": tf.string
179+
"mel_files": tf.string,
179180
}
180181
return output_types
181182

examples/tacotron2/tacotron_dataset.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,10 @@ def create(
235235
}
236236

237237
datasets = datasets.padded_batch(
238-
batch_size, padded_shapes=padded_shapes, padding_values=padding_values
238+
batch_size,
239+
padded_shapes=padded_shapes,
240+
padding_values=padding_values,
241+
drop_remainder=True,
239242
)
240243
datasets = datasets.prefetch(tf.data.experimental.AUTOTUNE)
241244
return datasets

0 commit comments

Comments (0)