refactor(encoder): simplify code and run test from python
eginhard committed Jan 29, 2025
1 parent b8ede07 commit 7036788
Showing 5 changed files with 146 additions and 153 deletions.
152 changes: 122 additions & 30 deletions TTS/bin/train_encoder.py
@@ -1,24 +1,30 @@
#!/usr/bin/env python3

# TODO: use Trainer

import logging
import os
import sys
import time
import traceback
import warnings
from dataclasses import dataclass, field

import torch
from torch.utils.data import DataLoader
from trainer.generic_utils import count_parameters, remove_experiment_folder
from trainer.io import copy_model_files, save_best_model, save_checkpoint
from trainer import TrainerArgs, TrainerConfig
from trainer.generic_utils import count_parameters, get_experiment_folder_path, get_git_branch
from trainer.io import copy_model_files, get_last_checkpoint, save_best_model, save_checkpoint
from trainer.logging import BaseDashboardLogger, ConsoleLogger, logger_factory
from trainer.torch import NoamLR
from trainer.trainer_utils import get_optimizer

from TTS.config import load_config, register_config
from TTS.encoder.configs.base_encoder_config import BaseEncoderConfig
from TTS.encoder.dataset import EncoderDataset
from TTS.encoder.utils.generic_utils import setup_encoder_model
from TTS.encoder.utils.training import init_training
from TTS.encoder.utils.visual import plot_embeddings
from TTS.tts.datasets import load_tts_samples
from TTS.tts.utils.text.characters import parse_symbols
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import ConsoleFormatter, setup_logger
from TTS.utils.samplers import PerfectBatchSampler
@@ -33,7 +39,77 @@
print(" > Number of GPUs: ", num_gpus)


def setup_loader(ap: AudioProcessor, is_val: bool = False):
@dataclass
class TrainArgs(TrainerArgs):
config_path: str | None = field(default=None, metadata={"help": "Path to the config file."})


def process_args(
args, config: BaseEncoderConfig | None = None
) -> tuple[BaseEncoderConfig, str, str, ConsoleLogger, BaseDashboardLogger | None]:
"""Process parsed comand line arguments and initialize the config if not provided.
Args:
args (argparse.Namespace or dict like): Parsed input arguments.
config (Coqpit): Model config. If none, it is generated from `args`. Defaults to None.
Returns:
c (Coqpit): Config parameters.
out_path (str): Path to save models and logging.
audio_path (str): Path to save generated test audios.
c_logger (TTS.utils.console_logger.ConsoleLogger): Class that does
logging to the console.
dashboard_logger (WandbLogger or TensorboardLogger): Class that does the dashboard logging.
TODO:
- Interactive config definition.
"""
coqpit_overrides = None
if isinstance(args, tuple):
args, coqpit_overrides = args
if args.continue_path:
# continue a previous training from its output folder
experiment_path = args.continue_path
args.config_path = os.path.join(args.continue_path, "config.json")
args.restore_path, best_model = get_last_checkpoint(args.continue_path)
if not args.best_path:
args.best_path = best_model
# init config if not already defined
if config is None:
if args.config_path:
# init from a file
config = load_config(args.config_path)
else:
# init from console args
from TTS.config.shared_configs import BaseTrainingConfig # pylint: disable=import-outside-toplevel

config_base = BaseTrainingConfig()
config_base.parse_known_args(coqpit_overrides)
config = register_config(config_base.model)()
# override values from command-line args
config.parse_known_args(coqpit_overrides, relaxed_parser=True)
experiment_path = args.continue_path
if not experiment_path:
experiment_path = get_experiment_folder_path(config.output_path, config.run_name)
audio_path = os.path.join(experiment_path, "test_audios")
config.output_log_path = experiment_path
# setup rank 0 process in distributed training
dashboard_logger = None
if args.rank == 0:
new_fields = {}
if args.restore_path:
new_fields["restore_path"] = args.restore_path
new_fields["github_branch"] = get_git_branch()
# if model characters are not set in the config file
# save the default set to the config file for future
# compatibility.
if config.has("characters") and config.characters is None:
used_characters = parse_symbols()
new_fields["characters"] = used_characters
copy_model_files(config, experiment_path, new_fields)
dashboard_logger = logger_factory(config, experiment_path)
c_logger = ConsoleLogger()
return config, experiment_path, audio_path, c_logger, dashboard_logger


def setup_loader(c: TrainerConfig, ap: AudioProcessor, is_val: bool = False):
num_utter_per_class = c.num_utter_per_class if not is_val else c.eval_num_utter_per_class
num_classes_in_batch = c.num_classes_in_batch if not is_val else c.eval_num_classes_in_batch

@@ -83,7 +159,7 @@ def setup_loader(ap: AudioProcessor, is_val: bool = False):
return loader, classes, dataset.get_map_classid_to_classname()


def evaluation(model, criterion, data_loader, global_step):
def evaluation(c: BaseEncoderConfig, model, criterion, data_loader, global_step, dashboard_logger: BaseDashboardLogger):
eval_loss = 0
for _, data in enumerate(data_loader):
with torch.inference_mode():
@@ -127,7 +203,17 @@ def evaluation(model, criterion, data_loader, global_step):
return eval_avg_loss


def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader, global_step):
def train(
c: BaseEncoderConfig,
model,
optimizer,
scheduler,
criterion,
data_loader,
eval_data_loader,
global_step,
dashboard_logger: BaseDashboardLogger,
):
model.train()
best_loss = {"train_loss": None, "eval_loss": float("inf")}
avg_loader_time = 0
@@ -226,7 +312,7 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,
if global_step % c.save_step == 0:
# save model
save_checkpoint(
c, model, optimizer, None, global_step, epoch, OUT_PATH, criterion=criterion.state_dict()
c, model, optimizer, None, global_step, epoch, c.output_log_path, criterion=criterion.state_dict()
)

end_time = time.time()
@@ -240,7 +326,7 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,
# evaluation
if c.run_eval:
model.eval()
eval_loss = evaluation(model, criterion, eval_data_loader, global_step)
eval_loss = evaluation(c, model, criterion, eval_data_loader, global_step, dashboard_logger)
print("\n\n")
print("--> EVAL PERFORMANCE")
print(
@@ -257,15 +343,21 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,
None,
global_step,
epoch,
OUT_PATH,
c.output_log_path,
criterion=criterion.state_dict(),
)
model.train()

return best_loss, global_step


def main(args): # pylint: disable=redefined-outer-name
def main(arg_list: list[str] | None = None):
setup_logger("TTS", level=logging.INFO, stream=sys.stdout, formatter=ConsoleFormatter())

train_config = TrainArgs()
parser = train_config.init_argparse(arg_prefix="")
args, overrides = parser.parse_known_args(arg_list)
c, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger = process_args((args, overrides))
# pylint: disable=global-variable-undefined
global meta_data_train
global meta_data_eval
Expand All @@ -279,9 +371,9 @@ def main(args): # pylint: disable=redefined-outer-name
# pylint: disable=redefined-outer-name
meta_data_train, meta_data_eval = load_tts_samples(c.datasets, eval_split=True)

train_data_loader, train_classes, map_classid_to_classname = setup_loader(ap, is_val=False)
train_data_loader, train_classes, map_classid_to_classname = setup_loader(c, ap, is_val=False)
if c.run_eval:
eval_data_loader, _, _ = setup_loader(ap, is_val=True)
eval_data_loader, _, _ = setup_loader(c, ap, is_val=True)
else:
eval_data_loader = None

@@ -313,23 +405,23 @@ def main(args): # pylint: disable=redefined-outer-name
criterion.cuda()

global_step = args.restore_step
_, global_step = train(model, optimizer, scheduler, criterion, train_data_loader, eval_data_loader, global_step)
_, global_step = train(
c, model, optimizer, scheduler, criterion, train_data_loader, eval_data_loader, global_step, dashboard_logger
)
sys.exit(0)


if __name__ == "__main__":
setup_logger("TTS", level=logging.INFO, stream=sys.stdout, formatter=ConsoleFormatter())

args, c, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger = init_training()

try:
main(args)
except KeyboardInterrupt:
remove_experiment_folder(OUT_PATH)
try:
sys.exit(0)
except SystemExit:
os._exit(0) # pylint: disable=protected-access
except Exception: # pylint: disable=broad-except
remove_experiment_folder(OUT_PATH)
traceback.print_exc()
sys.exit(1)
main()
# try:
# main()
# except KeyboardInterrupt:
# remove_experiment_folder(OUT_PATH)
# try:
# sys.exit(0)
# except SystemExit:
# os._exit(0) # pylint: disable=protected-access
# except Exception: # pylint: disable=broad-except
# remove_experiment_folder(OUT_PATH)
# traceback.print_exc()
# sys.exit(1)
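
With `main()` now taking an optional `arg_list`, the encoder trainer can be driven directly from Python instead of through a subprocess. A minimal sketch, assuming Coqpit exposes the `TrainArgs.config_path` field as a `--config_path` flag (the config path below is hypothetical):

# Minimal sketch: launch encoder training in-process (config path is hypothetical).
from TTS.bin.train_encoder import main

try:
    main(["--config_path", "recipes/encoder/my_encoder_config.json"])
except SystemExit as exc:
    # main() always terminates via sys.exit(); code 0 means training finished cleanly.
    assert exc.code == 0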
6 changes: 5 additions & 1 deletion TTS/encoder/utils/generic_utils.py
@@ -6,6 +6,7 @@
import numpy as np
from scipy import signal

from TTS.encoder.models.base_encoder import BaseEncoder
from TTS.encoder.models.lstm import LSTMSpeakerEncoder
from TTS.encoder.models.resnet import ResNetSpeakerEncoder

@@ -120,7 +121,7 @@ def apply_one(self, audio):
return self.additive_noise(noise_type, audio)


def setup_encoder_model(config: "Coqpit"):
def setup_encoder_model(config: "Coqpit") -> BaseEncoder:
if config.model_params["model_name"].lower() == "lstm":
model = LSTMSpeakerEncoder(
config.model_params["input_dim"],
@@ -138,4 +139,7 @@ def setup_encoder_model(config: "Coqpit"):
use_torch_spec=config.model_params.get("use_torch_spec", False),
audio_config=config.audio,
)
else:
msg = f"Model not supported: {config.model_params['model_name']}"
raise ValueError(msg)
return model
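
The new `else` branch makes unsupported encoder names fail loudly instead of falling through to `return model` with `model` never assigned. A small sketch of the behaviour (the config values here are purely illustrative):

# Illustrative only: an unrecognised model_name now raises a ValueError.
from TTS.encoder.configs.base_encoder_config import BaseEncoderConfig
from TTS.encoder.utils.generic_utils import setup_encoder_model

config = BaseEncoderConfig()
config.model_params = {"model_name": "transformer"}  # neither "lstm" nor "resnet"
try:
    setup_encoder_model(config)
except ValueError as err:
    print(err)  # -> Model not supported: transformer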
99 changes: 0 additions & 99 deletions TTS/encoder/utils/training.py

This file was deleted.

5 changes: 0 additions & 5 deletions tests/__init__.py
@@ -42,11 +42,6 @@ def get_tests_output_path():
return path


def run_cli(command):
exit_status = os.system(command)
assert exit_status == 0, f" [!] command `{command}` failed."


def run_main(main_func: Callable, args: list[str] | None = None, expected_code: int = 0):
with pytest.raises(SystemExit) as exc_info:
main_func(args)

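Together with the new `main(arg_list)` signature, the `run_main` helper lets the encoder training test invoke the script in-process rather than via the removed `run_cli`/`os.system` path. A hedged sketch of how such a test might look (the config path is hypothetical):

# Sketch only: calling the training entry point in-process from a test.
from tests import run_main
from TTS.bin.train_encoder import main as train_encoder_main

def test_train_encoder_smoke():
    # run_main expects main_func to exit via SystemExit with the given code (default 0).
    run_main(train_encoder_main, ["--config_path", "tests/inputs/encoder_config.json"])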