Skip to content

Commit

Permalink
minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
JinZr committed Dec 6, 2024
1 parent ce73643 commit 2504036
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
7 changes: 3 additions & 4 deletions egs/wenetspeech4tts/TTS/valle/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# Mingshuang Luo)
# Copyright 2023 (authors: Feiteng Li)
# Copyright 2024 (authors: Yuekai Zhang)
# Copyright 2024 Tsinghua University (authors: Zengrui Jin)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
Expand Down Expand Up @@ -48,10 +49,8 @@
import argparse
import copy
import logging
import os
import random
import warnings
from contextlib import nullcontext
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, Optional, Tuple, Union
Expand Down Expand Up @@ -686,9 +685,9 @@ def compute_validation_loss(
output_dir = Path(f"{params.exp_dir}/eval/step-{params.batch_idx_train:06d}")
output_dir.mkdir(parents=True, exist_ok=True)
if isinstance(model, DDP):
model.module.visualize(predicts, batch, output_dir=output_dir)
model.module.visualize(predicts, batch, tokenizer, output_dir=output_dir)
else:
model.visualize(predicts, batch, output_dir=output_dir)
model.visualize(predicts, batch, tokenizer, output_dir=output_dir)

return tot_loss

Expand Down
7 changes: 5 additions & 2 deletions egs/wenetspeech4tts/TTS/valle/valle.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import numpy as np
import torch
import torch.nn as nn
from tokenizer import TextTokenCollater
from torch import Tensor
from torch.nn import Linear, Module
from torch.nn import functional as F
Expand Down Expand Up @@ -1664,13 +1665,15 @@ def visualize(
self,
predicts: Tuple[torch.Tensor],
batch: Dict[str, Union[List, torch.Tensor]],
tokenizer: TextTokenCollater,
output_dir: str,
limit: int = 4,
) -> None:
text_tokens = batch["text_tokens"].to("cpu").detach().numpy()
text_tokens_lens = batch["text_tokens_lens"].to("cpu").detach().numpy()
audio_features = batch["audio_features"].to("cpu").detach().numpy()
audio_features_lens = batch["audio_features_lens"].to("cpu").detach().numpy()

tokens = batch["tokens"]
text_tokens, text_tokens_lens = tokenizer(tokens)
assert text_tokens.ndim == 2

utt_ids, texts = batch["utt_id"], batch["text"]
Expand Down

0 comments on commit 2504036

Please sign in to comment.