diff --git a/ultravox/tools/ds_tool/ds_tool.py b/ultravox/tools/ds_tool/ds_tool.py
index da3a49ad..a8510ff4 100644
--- a/ultravox/tools/ds_tool/ds_tool.py
+++ b/ultravox/tools/ds_tool/ds_tool.py
@@ -107,9 +107,16 @@ class TimestampGenerationTask:
         default="timestamps", alias="-tsc"
     )
     sample_rate: int = simple_parsing.field(default=16000, alias="-r")
-    temp_dir: Optional[str] = None
     aligned_ratio_check: float = simple_parsing.field(default=0.95, alias="-ar")
 
+    def _get_new_temp_dir(self) -> str:
+        """Create a fresh temp dir (cleaning up any old one) and return its path."""
+        if self._temp_dir is not None:
+            self._temp_dir.cleanup()
+        self._temp_dir = tempfile.TemporaryDirectory()
+        self.temp_dir = self._temp_dir.name
+        return self.temp_dir
+
     def __post_init__(self):
         self._temp_dir: Optional[tempfile.TemporaryDirectory] = None
 
@@ -175,13 +182,7 @@ def get_id(sample):
         raise ValueError("Could not find an ID in the sample")
 
     def _preprocess(self, ds_split: datasets.Dataset) -> datasets.Dataset:
-        if self.temp_dir is None:
-            if self._temp_dir is not None:
-                self._temp_dir.cleanup()
-            self._temp_dir = tempfile.TemporaryDirectory()
-            self.temp_dir = self._temp_dir.name
-
-        os.makedirs(self.temp_dir, exist_ok=True)
+        self._get_new_temp_dir()
 
         # 1. copy all audio-text pairs into a temp directory
         for sample in ds_split:
@@ -537,9 +538,6 @@ def main(args: DatasetToolArgs):
         if args.num_samples:
             ds_split = ds_split.select(range(args.num_samples))
 
-        # if hasattr(args.task, "preprocess"):
-        #     args.task.preprocess(ds_split)
-
         ds_chunk_proc.process_and_upload_split_rescursive(
             split_name, ds_split, 0, len(ds_split)
         )