
Commit

applied text norm to valid & test cuts
JinZr committed Oct 7, 2024
1 parent f074487 commit 156af46
Showing 2 changed files with 17 additions and 17 deletions.
6 changes: 3 additions & 3 deletions egs/libritts/ASR/zipformer/decode.py
@@ -123,7 +123,7 @@
modified_beam_search_LODR,
)
from lhotse import set_caching_enabled
-from train import add_model_arguments, get_model, get_params
+from train import add_model_arguments, get_model, get_params, normalize_text

from icefall import ContextGraph, LmScorer, NgramLm
from icefall.checkpoint import (
@@ -1043,8 +1043,8 @@ def main():
args.return_cuts = True
libritts = LibriTTSAsrDataModule(args)

-    test_clean_cuts = libritts.test_clean_cuts()
-    test_other_cuts = libritts.test_other_cuts()
+    test_clean_cuts = libritts.test_clean_cuts().map(normalize_text)
+    test_other_cuts = libritts.test_other_cuts().map(normalize_text)

test_clean_dl = libritts.test_dataloaders(test_clean_cuts)
test_other_dl = libritts.test_dataloaders(test_other_cuts)
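For context, the .map(...) calls in this commit use lhotse's CutSet.map, which wraps the cut set so the transform is applied as cuts are consumed; the manifests on disk are left untouched. A minimal sketch of the decode-time pattern, assuming main() has already built libritts = LibriTTSAsrDataModule(args) as in the context above (the spot check at the end is illustrative, not part of the diff):

from train import normalize_text  # importable because the helper now lives at module level in train.py

test_clean_cuts = libritts.test_clean_cuts().map(normalize_text)
test_other_cuts = libritts.test_other_cuts().map(normalize_text)

# Normalized supervision text is upper-case with punctuation stripped, so a
# cheap spot check is that upper-casing it again is a no-op.
first = next(iter(test_clean_cuts))
assert first.supervisions[0].text == first.supervisions[0].text.upper()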
28 changes: 14 additions & 14 deletions egs/libritts/ASR/zipformer/train.py
@@ -603,13 +603,18 @@ def _to_int_tuple(s: str):
return tuple(map(int, s.split(",")))


-def remove_punc_to_upper(text: str) -> str:
-    text = text.replace("‘", "'")
-    text = text.replace("’", "'")
-    tokens = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'")
-    s_list = [x.upper() if x in tokens else " " for x in text]
-    s = " ".join("".join(s_list).split()).strip()
-    return s
+def normalize_text(c: Cut):
+    def remove_punc_to_upper(text: str) -> str:
+        text = text.replace("‘", "'")
+        text = text.replace("’", "'")
+        tokens = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'")
+        s_list = [x.upper() if x in tokens else " " for x in text]
+        s = " ".join("".join(s_list).split()).strip()
+        return s
+
+    text = remove_punc_to_upper(c.supervisions[0].text)
+    c.supervisions[0].text = text
+    return c


def get_encoder_embed(params: AttributeDict) -> nn.Module:
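To see what the new helper does (this snippet is not part of the commit; it copies the inner function purely for illustration): curly apostrophes are mapped to the ASCII apostrophe, every character outside letters, digits and the apostrophe becomes a space, letters are upper-cased, and runs of whitespace are collapsed to single spaces.

# Illustration only: a local copy of the inner helper defined above.
def _remove_punc_to_upper(text: str) -> str:
    text = text.replace("‘", "'").replace("’", "'")
    tokens = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'")
    s_list = [x.upper() if x in tokens else " " for x in text]
    return " ".join("".join(s_list).split()).strip()

assert _remove_punc_to_upper("Mrs. Rachel, don’t say that!") == "MRS RACHEL DON'T SAY THAT"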
@@ -1309,11 +1314,6 @@ def run(rank, world_size, args):
else:
train_cuts = libritts.train_clean_100_cuts()

-def normalize_text(c: Cut):
-    text = remove_punc_to_upper(c.supervisions[0].text)
-    c.supervisions[0].text = text
-    return c
-
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
#
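The copy of normalize_text that was nested inside run() is deleted here: a function defined inside another function cannot be imported, so hoisting it to module scope is what lets decode.py reuse exactly the same normalization (see the import change in decode.py above). The resulting usage, sketched:

# Valid only now that normalize_text is defined at module level in train.py.
from train import normalize_text  # same helper shared by train, dev and test cuts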
@@ -1365,8 +1365,8 @@ def remove_short_and_long_utt(c: Cut):
train_cuts, sampler_state_dict=sampler_state_dict
)

-    valid_cuts = libritts.dev_clean_cuts()
-    valid_cuts += libritts.dev_other_cuts()
+    valid_cuts = libritts.dev_clean_cuts().map(normalize_text)
+    valid_cuts += libritts.dev_other_cuts().map(normalize_text)
valid_dl = libritts.valid_dataloaders(valid_cuts)

if not params.print_diagnostics:
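One detail in the validation change above: each dev subset is mapped before the two cut sets are concatenated with +=, so dev-clean and dev-other supervisions both pass through the same normalization. A quick sanity check, sketched here with an illustrative (not actual LibriTTS) transcript:

# Peek at one validation cut to confirm that the text reaching the valid
# dataloader is upper-cased with punctuation stripped.
first_cut = next(iter(valid_cuts))
print(first_cut.supervisions[0].text)  # e.g. "AND DON'T FORGET THE TEA"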
