Skip to content

Commit

Permalink
Clean up the sample repeat code
Browse files (browse the repository at this point in the history)
  • Loading branch information
liPatrick committed Sep 13, 2024
1 parent 9cdc08b commit 600fe70
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions ultravox/tools/ds_tool/ds_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,21 +223,23 @@ def _map_sample_repeat(self, sample):
         sentence = sample[self.asr_column_name]
         translation = sample[self.translation_column_name]
 
-        if isinstance(audio, dict):
-            audio_data = audio["array"]
-        else:
+        if not isinstance(audio, dict) or "array" not in audio:
             raise ValueError(f"Unsupported audio format: {type(audio)}")
 
+        audio_data = audio["array"]
         repeated_audio = np.tile(audio_data, self.multiplier)
         repeated_sentence = " ".join([sentence] * self.multiplier)
         repeated_translation = " ".join([translation] * self.multiplier)
 
-        new_sample = {}
-        new_sample[self.audio_column_name]["array"] = repeated_audio
-        new_sample[self.audio_column_name].pop("path")
-        new_sample[self.asr_column_name] = repeated_sentence
-        new_sample[self.translation_column_name] = repeated_translation
-        new_sample[self.id_column_name] = sample[self.id_column_name]
+        new_audio = {key: value for key, value in audio.items() if key != "path"}
+        new_audio["array"] = repeated_audio
+
+        new_sample = {
+            self.audio_column_name: new_audio,
+            self.asr_column_name: repeated_sentence,
+            self.translation_column_name: repeated_translation,
+            self.id_column_name: sample[self.id_column_name],
+        }
 
         return new_sample

Expand Down

0 comments on commit 600fe70

Please sign in to comment.