From 2b3c5d799f3a585dc22071a9148424ff77aefd47 Mon Sep 17 00:00:00 2001 From: Wen Ding Date: Wed, 11 Oct 2023 16:58:00 +0800 Subject: [PATCH] Fix padding issues (#1303) --- egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py | 2 +- egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index bcd419fb7a..ab46e233be 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -158,7 +158,7 @@ def forward( if not is_jit_tracing(): assert x.size(0) == lengths.max().item() - src_key_padding_mask = make_pad_mask(lengths) + src_key_padding_mask = make_pad_mask(lengths, x.size(0)) if self.dynamic_chunk_training: assert ( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index 5b75b8d352..cbde2a2e4d 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -281,7 +281,7 @@ def forward( lengths = (x_lens - 7) >> 1 assert x.size(0) == lengths.max().item(), (x.shape, lengths, lengths.max()) - mask = make_pad_mask(lengths) + mask = make_pad_mask(lengths, x.size(0)) outputs = [] feature_masks = self.get_feature_masks(x)