fix left-padding #2278

Closed — wants to merge 12 commits.
2 changes: 1 addition & 1 deletion keras_hub/src/layers/preprocessing/multi_segment_packer.py
@@ -292,7 +292,7 @@ def call(
        )
        # Pad to dense tensor output.
        sequence_length = sequence_length or self.sequence_length
-       shape = tf.cast([-1, sequence_length], "int64")
+       shape = [-1, sequence_length]
        token_ids = pad(
            token_ids,
            shape=shape,
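The one-line change above drops the `tf.cast` because the new `pad` utility casts the shape itself. For intuition, here is a minimal sketch of the dense right-padding that `shape = [-1, sequence_length]` requests, using plain Python lists in place of `tf.RaggedTensor`; `pad_dense` is an illustrative name, not part of keras_hub.

```python
# Illustrative only: right-pad ragged rows to a fixed sequence_length,
# mimicking tf.RaggedTensor.to_tensor(shape=[-1, sequence_length]).
def pad_dense(rows, sequence_length, pad_value=0):
    out = []
    for row in rows:
        row = row[:sequence_length]  # truncate overlong rows
        out.append(row + [pad_value] * (sequence_length - len(row)))
    return out

print(pad_dense([[1, 2, 3], [4]], sequence_length=5))
# [[1, 2, 3, 0, 0], [4, 0, 0, 0, 0]]
```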
12 changes: 6 additions & 6 deletions keras_hub/src/layers/preprocessing/multi_segment_packer_test.py
@@ -235,9 +235,9 @@ def test_pad_inputs(self):
        token_ids, segment_ids = packer((seq1, seq2))
        self.assertAllEqual(
            token_ids,
-           ["[PAD]", "[CLS]", "a", "[SEP]", "x", "[SEP]"],
+           ["[CLS]", "a", "[SEP]", "x", "[SEP]", "[PAD]"],
        )
-       self.assertAllEqual(segment_ids, [0, 0, 0, 0, 1, 1])
+       self.assertAllEqual(segment_ids, [0, 0, 0, 1, 1, 0])

def test_pad_batched_inputs(self):
# right padding
@@ -277,15 +277,15 @@ def test_pad_batched_inputs(self):
        self.assertAllEqual(
            token_ids,
            [
-               ["[PAD]", "[PAD]", "[CLS]", "a", "[SEP]", "x", "[SEP]"],
-               ["[PAD]", "[CLS]", "a", "[SEP]", "x", "y", "[SEP]"],
+               ["[PAD]", "[CLS]", "a", "[SEP]", "x", "[SEP]", "[PAD]"],
+               ["[CLS]", "a", "[SEP]", "x", "y", "[SEP]", "[PAD]"],
            ],
        )
        self.assertAllEqual(
            segment_ids,
            [
-               [0, 0, 0, 0, 0, 1, 1],
-               [0, 0, 0, 0, 1, 1, 1],
+               [0, 0, 0, 0, 1, 1, 0],
+               [0, 0, 0, 1, 1, 1, 0],
            ],
        )
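The updated expectations in `test_pad_inputs` place `[PAD]` at the end of the sequence, with segment id 0 for the padded positions. A sketch of that layout with plain lists; `pack` is a hypothetical helper, not the keras_hub API:

```python
# Hypothetical sketch: pack two segments BERT-style, then right-pad.
def pack(seq1, seq2, sequence_length, pad="[PAD]"):
    tokens = ["[CLS]"] + seq1 + ["[SEP]"] + seq2 + ["[SEP]"]
    segments = [0] * (len(seq1) + 2) + [1] * (len(seq2) + 1)
    n_pad = sequence_length - len(tokens)
    return tokens + [pad] * n_pad, segments + [0] * n_pad

tokens, segments = pack(["a"], ["x"], sequence_length=6)
# tokens   -> ["[CLS]", "a", "[SEP]", "x", "[SEP]", "[PAD]"]
# segments -> [0, 0, 0, 1, 1, 0]
```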

2 changes: 1 addition & 1 deletion keras_hub/src/layers/preprocessing/start_end_packer.py
@@ -141,7 +141,7 @@ def check_special_value_type(value, value_name):
        self.start_value = start_value
        self.end_value = end_value

-       self.pad_value = pad_value
+       self.pad_value = pad_value or 0
Member:
why do we have this line for start end packer but not multi-segment packer? what's the difference?

Contributor Author:

> why do we have this line for start end packer but not multi-segment packer? what's the difference?

If you delete this line, an error is reported in the tests.

Member:
Thanks! But that's not the question. We want the implementation of start end packer and multi segment packer to look similar where we can. Having "none" pad value be different between these layers could lead to subtle bugs for end users. Is there a technical reason why we need this line for StartEndPacker and not MultiSegmentPacker? If this is just for tests, let's rework the tests. Let's try to keep the layers working roughly the same.

self.return_padding_mask = return_padding_mask
self.padding_side = padding_side

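On the `pad_value or 0` line the reviewer questions above: `or` substitutes the default for any falsy value, not just `None`. For integer pad values the two spellings coincide only because the fallback happens to be 0; an explicit `None` check states the intent directly. A small sketch with illustrative function names:

```python
# `or` falls back on ANY falsy value; the explicit check only on None.
def default_with_or(pad_value):
    return pad_value or 0

def default_with_none_check(pad_value):
    return pad_value if pad_value is not None else 0

assert default_with_or(None) == default_with_none_check(None) == 0
assert default_with_or(3) == default_with_none_check(3) == 3
# The two agree on 0 only because the fallback is itself 0:
assert default_with_or(0) == default_with_none_check(0) == 0
```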
8 changes: 4 additions & 4 deletions keras_hub/src/layers/preprocessing/start_end_packer_test.py
@@ -17,7 +17,7 @@ def test_dense_input(self):
            sequence_length=5, padding_side="left"
        )
        output = start_end_packer(input_data)
-       expected_output = [0, 0, 5, 6, 7]
+       expected_output = [5, 6, 7, 0, 0]
Member:
This doesn't make sense. Why is left padding the same as right padding in this case? The test case before looks correct, this looks wrong.

Contributor Author:

> This doesn't make sense. Why is left padding the same as right padding in this case? The test case before looks correct, this looks wrong.

No, it was obviously wrong before.
Left padding is for causal LM. If expected_output = [0, 0, 5, 6, 7], how do you generate output?
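The argument above concerns how causal LMs generate: new tokens are appended on the right, so batched prompts are conventionally padded on the left to keep every prompt's last real token at the final position. A sketch with plain lists; `left_pad` is an illustrative name, not the layer's implementation:

```python
def left_pad(rows, sequence_length, pad_value=0):
    """Left-pad each row so its last real token sits at index -1."""
    return [[pad_value] * (sequence_length - len(r)) + r for r in rows]

batch = left_pad([[5, 6, 7], [8, 9, 10, 11]], sequence_length=5)
# [[0, 0, 5, 6, 7], [0, 8, 9, 10, 11]]
# Every row can now grow at the right edge during generation.
```

Note that under this convention `padding_side="left"` for `[5, 6, 7]` yields `[0, 0, 5, 6, 7]`, i.e. the pre-change expectation.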

self.assertAllEqual(output, expected_output)

def test_bfloat16_dtype(self):
@@ -40,7 +40,7 @@ def test_dense_2D_input(self):
            sequence_length=5, padding_side="left"
        )
        output = start_end_packer(input_data)
-       expected_output = [[0, 0, 5, 6, 7]]
+       expected_output = [[5, 6, 7, 0, 0]]
Member:
Ditto, we are showing a lot of right hand side padding for the left padding option. I think we have introduced a bug.

self.assertAllEqual(output, expected_output)

def test_ragged_input(self):
@@ -55,7 +55,7 @@ def test_ragged_input(self):
            sequence_length=5, padding_side="left"
        )
        output = start_end_packer(input_data)
-       expected_output = [[0, 0, 5, 6, 7], [0, 8, 9, 10, 11]]
+       expected_output = [[0, 5, 6, 7, 0], [8, 9, 10, 11, 0]]
self.assertAllEqual(output, expected_output)

def test_start_end_token(self):
@@ -119,7 +119,7 @@ def test_start_end_padding_value(self):
            padding_side="left",
        )
        output = start_end_packer(input_data)
-       expected_output = [[3, 3, 1, 5, 6, 7, 2], [3, 1, 8, 9, 10, 11, 2]]
+       expected_output = [[3, 1, 5, 6, 7, 2, 3], [1, 8, 9, 10, 11, 2, 3]]
self.assertAllEqual(output, expected_output)
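For the custom pad value case, here is a sketch of the conventional left-padded layout with `start_value=1`, `end_value=2`, `pad_value=3`, using plain lists and a hypothetical `pack_left` helper:

```python
def pack_left(row, sequence_length, start=1, end=2, pad=3):
    """Add start/end tokens, then left-pad to sequence_length."""
    row = [start] + row + [end]
    return [pad] * (sequence_length - len(row)) + row

print(pack_left([5, 6, 7], sequence_length=7))       # [3, 3, 1, 5, 6, 7, 2]
print(pack_left([8, 9, 10, 11], sequence_length=7))  # [3, 1, 8, 9, 10, 11, 2]
```

This reproduces the pre-change expectation rather than the right-padded one the diff introduces.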

def test_truncation(self):
25 changes: 16 additions & 9 deletions keras_hub/src/utils/tensor_utils.py
@@ -21,17 +21,24 @@
NO_CONVERT_COUNTER = threading.local()


-def pad(x, shape, padding_side, pad_value):
-    if padding_side == "left":
+def pad(x, shape, padding_side, pad_value, axis=-1):
+    if padding_side == "left" and pad_value is not None:
Member:
I don't fully understand what we are trying to do here, but I think it is buggy; we should go back to the reverse and to_tensor approach and avoid this manual padding.

-        x = x[..., ::-1]
-    outputs = x.to_tensor(
-        default_value=pad_value,
-        shape=shape,
-    )
-    if padding_side == "left":
-        outputs = outputs[..., ::-1]
+        outputs = x.to_tensor(
+            default_value=pad_value,
+        )
+        padding_shape = [tf.shape(outputs)[0]] + [1] * (len(outputs.shape) - 1)
+        padding_shape[axis] = shape[axis] - tf.shape(outputs)[axis]
+        padding_shape = tf.cast(padding_shape, "int64")
+        padding = tf.fill(padding_shape, pad_value)
+        padding = tf.cast(padding, outputs.dtype)
+        outputs = tf.concat([outputs, padding], axis=axis)
+    else:
+        outputs = x.to_tensor(
+            default_value=pad_value,
+            shape=tf.cast(shape, "int64"),
+        )
    return outputs
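The "reverse and to_tensor" approach the reviewer asks to restore can be sketched with plain lists standing in for `tf.RaggedTensor`: reversing each row turns left padding into right padding, so a single right-padding primitive (like `to_tensor`) serves both sides. Function names here are illustrative only.

```python
def right_pad(rows, sequence_length, pad_value):
    # Stand-in for tf.RaggedTensor.to_tensor(shape=[-1, sequence_length]).
    return [r + [pad_value] * (sequence_length - len(r)) for r in rows]

def pad(rows, sequence_length, padding_side, pad_value):
    if padding_side == "left":
        rows = [r[::-1] for r in rows]  # reverse each row
    out = right_pad(rows, sequence_length, pad_value)
    if padding_side == "left":
        out = [r[::-1] for r in out]  # reverse back
    return out

print(pad([[5, 6, 7]], 5, "left", 0))   # [[0, 0, 5, 6, 7]]
print(pad([[5, 6, 7]], 5, "right", 0))  # [[5, 6, 7, 0, 0]]
```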

