Skip to content

Commit

Permalink
fix: get feat_lens for webdataset
Browse files Browse the repository at this point in the history
  • Loading branch information
thomaschhh committed Jun 11, 2024
1 parent d93e6e5 commit 9528bb6
Showing 1 changed file with 5 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,8 @@ def forward(
inputs: Dict[str, tuple[torch.Tensor, torch.Tensor]],
) -> Dict[str, tuple[torch.Tensor, torch.Tensor]]:
x = inputs[self.sample_key] # x.shape: B, T, D
attn_key_mask = self._get_attn_key_mask(inputs["feats_len"])
feat_lens = [log_mel_spec.shape[-1] // 4 for log_mel_spec in inputs["audio"]]
attn_key_mask = self._get_attn_key_mask(feat_lens, inputs["attention_mask"][-1])
# x.shape: B, T, D
x = self.project(x.transpose(1, 2)) # x.shape: B, D, T
x = self.subsampler(x) # x.shape: B, D, T/4
Expand All @@ -211,7 +212,8 @@ def forward(

def _get_attn_key_mask(
self,
lengths: torch.Tensor,
lengths: list[int],
device: str,
):
return (
torch.nn.utils.rnn.pad_sequence(
Expand All @@ -221,4 +223,4 @@ def _get_attn_key_mask(
)
.transpose(1, 2)[:-1]
.unsqueeze_(1)
).to(lengths.device)
).to(device)

0 comments on commit 9528bb6

Please sign in to comment.