diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index bcd419fb7a..ab46e233be 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -158,7 +158,7 @@ def forward( if not is_jit_tracing(): assert x.size(0) == lengths.max().item() - src_key_padding_mask = make_pad_mask(lengths) + src_key_padding_mask = make_pad_mask(lengths, x.size(0)) if self.dynamic_chunk_training: assert ( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index 5b75b8d352..cbde2a2e4d 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -281,7 +281,7 @@ def forward( lengths = (x_lens - 7) >> 1 assert x.size(0) == lengths.max().item(), (x.shape, lengths, lengths.max()) - mask = make_pad_mask(lengths) + mask = make_pad_mask(lengths, x.size(0)) outputs = [] feature_masks = self.get_feature_masks(x)