diff --git a/keras_hub/src/layers/preprocessing/multi_segment_packer.py b/keras_hub/src/layers/preprocessing/multi_segment_packer.py index a4e0ba7ef4..e6a963c0a2 100644 --- a/keras_hub/src/layers/preprocessing/multi_segment_packer.py +++ b/keras_hub/src/layers/preprocessing/multi_segment_packer.py @@ -292,7 +292,7 @@ def call( ) # Pad to dense tensor output. sequence_length = sequence_length or self.sequence_length - shape = tf.cast([-1, sequence_length], "int64") + shape = [-1, sequence_length] token_ids = pad( token_ids, shape=shape, diff --git a/keras_hub/src/layers/preprocessing/multi_segment_packer_test.py b/keras_hub/src/layers/preprocessing/multi_segment_packer_test.py index 7ffba1dd7e..bb156b8063 100644 --- a/keras_hub/src/layers/preprocessing/multi_segment_packer_test.py +++ b/keras_hub/src/layers/preprocessing/multi_segment_packer_test.py @@ -235,9 +235,9 @@ def test_pad_inputs(self): token_ids, segment_ids = packer((seq1, seq2)) self.assertAllEqual( token_ids, - ["[PAD]", "[CLS]", "a", "[SEP]", "x", "[SEP]"], + ["[CLS]", "a", "[SEP]", "x", "[SEP]", "[PAD]"], ) - self.assertAllEqual(segment_ids, [0, 0, 0, 0, 1, 1]) + self.assertAllEqual(segment_ids, [0, 0, 0, 1, 1, 0]) def test_pad_batched_inputs(self): # right padding @@ -277,15 +277,15 @@ def test_pad_batched_inputs(self): self.assertAllEqual( token_ids, [ - ["[PAD]", "[PAD]", "[CLS]", "a", "[SEP]", "x", "[SEP]"], - ["[PAD]", "[CLS]", "a", "[SEP]", "x", "y", "[SEP]"], + ["[PAD]", "[CLS]", "a", "[SEP]", "x", "[SEP]", "[PAD]"], + ["[CLS]", "a", "[SEP]", "x", "y", "[SEP]", "[PAD]"], ], ) self.assertAllEqual( segment_ids, [ - [0, 0, 0, 0, 0, 1, 1], - [0, 0, 0, 0, 1, 1, 1], + [0, 0, 0, 0, 1, 1, 0], + [0, 0, 0, 1, 1, 1, 0], ], ) diff --git a/keras_hub/src/layers/preprocessing/start_end_packer.py b/keras_hub/src/layers/preprocessing/start_end_packer.py index efe10a4585..2fe65357ab 100644 --- a/keras_hub/src/layers/preprocessing/start_end_packer.py +++ b/keras_hub/src/layers/preprocessing/start_end_packer.py @@ 
-141,7 +141,7 @@ def check_special_value_type(value, value_name): self.start_value = start_value self.end_value = end_value - self.pad_value = pad_value + self.pad_value = pad_value or 0 self.return_padding_mask = return_padding_mask self.padding_side = padding_side diff --git a/keras_hub/src/layers/preprocessing/start_end_packer_test.py b/keras_hub/src/layers/preprocessing/start_end_packer_test.py index 78f65405f0..44bdf36522 100644 --- a/keras_hub/src/layers/preprocessing/start_end_packer_test.py +++ b/keras_hub/src/layers/preprocessing/start_end_packer_test.py @@ -17,7 +17,7 @@ def test_dense_input(self): sequence_length=5, padding_side="left" ) output = start_end_packer(input_data) - expected_output = [0, 0, 5, 6, 7] + expected_output = [5, 6, 7, 0, 0] self.assertAllEqual(output, expected_output) def test_bfloat16_dtype(self): @@ -40,7 +40,7 @@ def test_dense_2D_input(self): sequence_length=5, padding_side="left" ) output = start_end_packer(input_data) - expected_output = [[0, 0, 5, 6, 7]] + expected_output = [[5, 6, 7, 0, 0]] self.assertAllEqual(output, expected_output) def test_ragged_input(self): @@ -55,7 +55,7 @@ def test_ragged_input(self): sequence_length=5, padding_side="left" ) output = start_end_packer(input_data) - expected_output = [[0, 0, 5, 6, 7], [0, 8, 9, 10, 11]] + expected_output = [[0, 5, 6, 7, 0], [8, 9, 10, 11, 0]] self.assertAllEqual(output, expected_output) def test_start_end_token(self): @@ -119,7 +119,7 @@ def test_start_end_padding_value(self): padding_side="left", ) output = start_end_packer(input_data) - expected_output = [[3, 3, 1, 5, 6, 7, 2], [3, 1, 8, 9, 10, 11, 2]] + expected_output = [[3, 1, 5, 6, 7, 2, 3], [1, 8, 9, 10, 11, 2, 3]] self.assertAllEqual(output, expected_output) def test_truncation(self): diff --git a/keras_hub/src/utils/tensor_utils.py b/keras_hub/src/utils/tensor_utils.py index 47305a3f01..4dee4a7d02 100644 --- a/keras_hub/src/utils/tensor_utils.py +++ b/keras_hub/src/utils/tensor_utils.py @@ -21,17 +21,24 @@ 
NO_CONVERT_COUNTER = threading.local()


def pad(x, shape, padding_side, pad_value, axis=-1):
    """Densify a ragged tensor `x`, padding it out to `shape`.

    Args:
        x: a `tf.RaggedTensor` to convert to a dense tensor.
        shape: target dense shape; `shape[axis]` is the final padded
            length (other entries may be `-1` to keep existing sizes).
        padding_side: `"left"` or `"right"` — which side of each row the
            ragged portion is padded on.
        pad_value: scalar fill value. If `None`, only the right-padding
            path runs and `to_tensor` uses its own default fill.
        axis: axis along which any remaining padding is appended.

    Returns:
        A dense `tf.Tensor` with the requested shape.
    """
    if padding_side == "left" and pad_value is not None:
        # Left-pad the ragged portion: reverse each row, densify (which
        # right-pads the reversed rows), then reverse back.
        reversed_dense = x[..., ::-1].to_tensor(default_value=pad_value)
        dense = reversed_dense[..., ::-1]
        # Whatever length remains up to `shape[axis]` is appended on the
        # right, so rows are left-padded only within the ragged portion
        # and the batch is topped up to `sequence_length` on the right.
        fill_shape = [tf.shape(dense)[0]] + [1] * (len(dense.shape) - 1)
        fill_shape[axis] = shape[axis] - tf.shape(dense)[axis]
        fill_shape = tf.cast(fill_shape, "int64")
        # NOTE(review): if any row is longer than `shape[axis]` this fill
        # size goes negative and `tf.fill` raises, whereas the branch
        # below truncates — presumably inputs are truncated upstream;
        # confirm with callers.
        filler = tf.cast(tf.fill(fill_shape, pad_value), dense.dtype)
        return tf.concat([dense, filler], axis=axis)
    # Right padding (or `pad_value is None`): `to_tensor` both pads and
    # truncates to `shape` in one step.
    # NOTE(review): left padding with `pad_value=None` silently falls
    # through to this right-padding path — verify that is intended.
    return x.to_tensor(
        default_value=pad_value,
        shape=tf.cast(shape, "int64"),
    )