Skip to content

Commit

Permalink
Merge pull request #2040 from comodoro/patch-1
Browse files: browse the repository at this point in the history
Update docs to the new augmentation format
  • Loading branch information
reuben authored Dec 7, 2021
2 parents f132edc + 017de51 commit 36d3f7b
Show file tree
Hide file tree
Showing 11 changed files with 239 additions and 204 deletions.
143 changes: 84 additions & 59 deletions bin/data_set_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
Tool for building a combined SDB or CSV sample-set from other sets
Use 'python3 data_set_tool.py -h' for help
"""
import argparse
import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional

import progressbar
from coqui_stt_training.util.audio import (
Expand All @@ -19,6 +19,11 @@
apply_sample_augmentations,
parse_augmentations,
)
from coqui_stt_training.util.config import (
BaseSttConfig,
Config,
initialize_globals_from_instance,
)
from coqui_stt_training.util.downloader import SIMPLE_BAR
from coqui_stt_training.util.sample_collections import (
CSVWriter,
Expand All @@ -31,39 +36,36 @@


def build_data_set():
audio_type = AUDIO_TYPE_LOOKUP[CLI_ARGS.audio_type]
augmentations = parse_augmentations(CLI_ARGS.augment)
audio_type = AUDIO_TYPE_LOOKUP[Config.audio_type]
augmentations = parse_augmentations(Config.augment)
print(f"Parsed augmentations from flags: {augmentations}")
if any(not isinstance(a, SampleAugmentation) for a in augmentations):
print(
"Warning: Some of the specified augmentations will not get applied, as this tool only supports "
"overlay, codec, reverb, resample and volume."
)
extension = Path(CLI_ARGS.target).suffix.lower()
labeled = not CLI_ARGS.unlabeled
extension = "".join(Path(Config.target).suffixes).lower()
labeled = not Config.unlabeled
if extension == ".csv":
writer = CSVWriter(
CLI_ARGS.target, absolute_paths=CLI_ARGS.absolute_paths, labeled=labeled
Config.target, absolute_paths=Config.absolute_paths, labeled=labeled
)
elif extension == ".sdb":
writer = DirectSDBWriter(
CLI_ARGS.target, audio_type=audio_type, labeled=labeled
)
writer = DirectSDBWriter(Config.target, audio_type=audio_type, labeled=labeled)
elif extension == ".tar":
writer = TarWriter(
CLI_ARGS.target, labeled=labeled, gz=False, include=CLI_ARGS.include
Config.target, labeled=labeled, gz=False, include=Config.include
)
elif extension == ".tgz" or CLI_ARGS.target.lower().endswith(".tar.gz"):
elif extension in (".tgz", ".tar.gz"):
writer = TarWriter(
CLI_ARGS.target, labeled=labeled, gz=True, include=CLI_ARGS.include
Config.target, labeled=labeled, gz=True, include=Config.include
)
else:
print(
raise RuntimeError(
"Unknown extension of target file - has to be either .csv, .sdb, .tar, .tar.gz or .tgz"
)
sys.exit(1)
with writer:
samples = samples_from_sources(CLI_ARGS.sources, labeled=not CLI_ARGS.unlabeled)
samples = samples_from_sources(Config.sources, labeled=not Config.unlabeled)
num_samples = len(samples)
if augmentations:
samples = apply_sample_augmentations(
Expand All @@ -74,63 +76,86 @@ def build_data_set():
change_audio_types(
samples,
audio_type=audio_type,
bitrate=CLI_ARGS.bitrate,
processes=CLI_ARGS.workers,
bitrate=Config.bitrate,
processes=Config.workers,
)
):
writer.add(sample)


def handle_args():
parser = argparse.ArgumentParser(
description="Tool for building a combined SDB or CSV sample-set from other sets"
@dataclass
class DatasetToolConfig(BaseSttConfig):
sources: List[str] = field(
default_factory=list,
metadata=dict(
help="Source CSV and/or SDB files - "
"Note: For getting a correctly ordered target set, source SDBs have to have their samples "
"already ordered from shortest to longest.",
),
)
parser.add_argument(
"sources",
nargs="+",
help="Source CSV and/or SDB files - "
"Note: For getting a correctly ordered target set, source SDBs have to have their samples "
"already ordered from shortest to longest.",
target: str = field(
default="",
metadata=dict(
help="SDB, CSV or TAR(.gz) file to create",
),
)
parser.add_argument("target", help="SDB, CSV or TAR(.gz) file to create")
parser.add_argument(
"--audio-type",
audio_type: str = field(
default="opus",
choices=AUDIO_TYPE_LOOKUP.keys(),
help="Audio representation inside target SDB",
)
parser.add_argument(
"--bitrate",
type=int,
help="Bitrate for lossy compressed SDB samples like in case of --audio-type opus",
metadata=dict(
help="Audio representation inside target SDB",
),
)
parser.add_argument(
"--workers", type=int, default=None, help="Number of encoding SDB workers"
bitrate: int = field(
default=16000,
metadata=dict(
help="Bitrate for lossy compressed SDB samples like in case of --audio-type opus",
),
)
parser.add_argument(
"--unlabeled",
action="store_true",
help="If to build an data-set with unlabeled (audio only) samples - "
"typically used for building noise augmentation corpora",
workers: Optional[int] = field(
default=None,
metadata=dict(
help="Number of encoding SDB workers",
),
)
parser.add_argument(
"--absolute-paths",
action="store_true",
help="If to reference samples by their absolute paths when writing CSV files",
unlabeled: bool = field(
default=False,
metadata=dict(
help="If to build an data-set with unlabeled (audio only) samples - "
"typically used for building noise augmentation corpora",
),
)
parser.add_argument(
"--augment",
action="append",
help="Add an augmentation operation",
absolute_paths: bool = field(
default=False,
metadata=dict(
help="If to reference samples by their absolute paths when writing CSV files",
),
)
parser.add_argument(
"--include",
action="append",
help="Adds a file to the root directory of .tar(.gz) targets",
include: List[str] = field(
default_factory=list,
metadata=dict(
help="Adds files to the root directory of .tar(.gz) targets",
),
)
return parser.parse_args()

def __post_init__(self):
    """Validate the tool's settings once the dataclass fields are populated.

    Raises:
        RuntimeError: if ``audio_type`` is not a key of ``AUDIO_TYPE_LOOKUP``,
            or if ``sources`` or ``target`` is empty.
    """
    # NOTE(review): there is no super().__post_init__() call here; confirm
    # that skipping any base-class post-init logic is intended.
    known_audio_types = AUDIO_TYPE_LOOKUP.keys()
    if self.audio_type not in known_audio_types:
        raise RuntimeError(
            f"--audio_type must be one of {tuple(known_audio_types)}"
        )

    # Both fields default to an empty value (empty list / empty string), so a
    # falsy value here means the corresponding flag was not supplied.
    # Checked in this order so a missing source is reported first.
    required_settings = (
        (self.sources, "No source specified with --sources"),
        (self.target, "No target specified with --target"),
    )
    for value, complaint in required_settings:
        if not value:
            raise RuntimeError(complaint)


def main():
config = DatasetToolConfig.init_from_argparse(arg_prefix="")
initialize_globals_from_instance(config)

if __name__ == "__main__":
CLI_ARGS = handle_args()
build_data_set()


if __name__ == "__main__":
main()
Loading

0 comments on commit 36d3f7b

Please sign in to comment.