Commit

Make _GetSourceAndQuerySegmentIds return the correct shaped zero tensor of query_segment_id when inputs.query_vec has different shape.

PiperOrigin-RevId: 718045538
lingvo-bot authored and copybara-github committed Jan 21, 2025
1 parent a4fbd33 commit 03fd809
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions lingvo/core/attention.py
@@ -492,8 +492,8 @@ def _GetSourceAndQuerySegmentIds(
 Returns:
   A tuple of 2 elements.
-  - The source segment id tensor: [time, batch_size].
-  - The query segment id tensor: [batch_size].
+  - The source segment id tensor: [time, source_batch].
+  - The query segment id tensor: [target_batch].
 """
 p = self.params
 if p.packed_input:
@@ -518,9 +518,13 @@ def _GetSourceAndQuerySegmentIds(
 ' a default all-zero value instead, packed_input will be'
 ' ineffective.'
 )
-if source_padding is not None:
+if source_padding is not None and inputs.query_vec is not None:
+  # query_vec.shape could be different from [target_batch, query_dim]
+  # because of potential reshape,e.g. reshaped to
+  # [1, target_batch/source_batch, source_batch, hidden_dims].
+  target_batch = inputs.query_vec.shape.num_elements() // p.hidden_dim
   query_segment_id = tf.zeros(
-      tf.shape(inputs.query_vec)[0], dtype=source_padding.dtype
+      [target_batch], dtype=source_padding.dtype
   )
 else:
   query_segment_id = None
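
The effect of the change can be illustrated with a minimal, framework-free sketch that works on shape tuples only. The helper names and the example sizes below are hypothetical and not part of lingvo; the sketch also assumes, as the diff implicitly does, that the last dimension of query_vec equals hidden_dim.

```python
import math


def old_segment_id_length(query_vec_shape):
    # Hypothetical stand-in for the old code path: it used the leading
    # dimension of query_vec as the length of query_segment_id.
    return query_vec_shape[0]


def new_segment_id_length(query_vec_shape, hidden_dim):
    # Hypothetical stand-in for the new code path: total number of elements
    # divided by hidden_dim, which is independent of how query_vec was reshaped.
    return math.prod(query_vec_shape) // hidden_dim


# Example sizes, chosen only for illustration.
source_batch, target_batch, hidden_dim = 2, 6, 4

flat_shape = (target_batch, hidden_dim)
reshaped_shape = (1, target_batch // source_batch, source_batch, hidden_dim)

# Both code paths agree when query_vec is [target_batch, hidden_dim].
print(old_segment_id_length(flat_shape), new_segment_id_length(flat_shape, hidden_dim))

# After the reshape, only the element-count version still yields target_batch.
print(old_segment_id_length(reshaped_shape), new_segment_id_length(reshaped_shape, hidden_dim))
```

With the reshaped input, the leading dimension is 1, so the old computation would produce a length-1 tensor, while the element-count computation still recovers target_batch.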