Add signbit and fix argmin and argmax (#20821)
* Add `signbit` op and fix `argmin` and `argmax`.

* Add APIs

* Fix CI

* Fix torch CI

* Simplify the logic

* Fix TF GPU CI
james77777778 authored Jan 30, 2025
1 parent af0c2d2 commit 738c313
Showing 16 changed files with 212 additions and 83 deletions.
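For context on the fix below: IEEE-754 negative zero compares equal to positive zero but carries a set sign bit, so a plain argmin/argmax cannot rank -0.0 below +0.0. The sketch below (plain NumPy, illustrative values, not part of the diff) shows the adjustment the changed backends apply: exact negative zeros are rewritten as -tiny, the negated smallest positive normal, so ordinary comparisons order them below +0.0.

import numpy as np

x = np.array([-0.0, 0.0])
print(np.signbit(x))  # [ True False] -- the zeros differ only in the sign bit
print(np.argmax(x))   # 0 -- the values compare equal, so the first index wins

# The backends' adjustment: replace exact negative zeros with -tiny.
is_negative_zero = (x == 0.0) & np.signbit(x)
x_adj = np.where(is_negative_zero, -np.finfo(x.dtype).tiny, x)
print(np.argmax(x_adj))  # 1 -- +0.0 now beats -0.0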
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/__init__.py
@@ -231,6 +231,7 @@
 from keras.src.ops.numpy import searchsorted
 from keras.src.ops.numpy import select
 from keras.src.ops.numpy import sign
+from keras.src.ops.numpy import signbit
 from keras.src.ops.numpy import sin
 from keras.src.ops.numpy import sinh
 from keras.src.ops.numpy import size
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/numpy/__init__.py
@@ -129,6 +129,7 @@
 from keras.src.ops.numpy import round
 from keras.src.ops.numpy import select
 from keras.src.ops.numpy import sign
+from keras.src.ops.numpy import signbit
 from keras.src.ops.numpy import sin
 from keras.src.ops.numpy import sinh
 from keras.src.ops.numpy import size
1 change: 1 addition & 0 deletions keras/api/ops/__init__.py
@@ -231,6 +231,7 @@
 from keras.src.ops.numpy import searchsorted
 from keras.src.ops.numpy import select
 from keras.src.ops.numpy import sign
+from keras.src.ops.numpy import signbit
 from keras.src.ops.numpy import sin
 from keras.src.ops.numpy import sinh
 from keras.src.ops.numpy import size
1 change: 1 addition & 0 deletions keras/api/ops/numpy/__init__.py
@@ -129,6 +129,7 @@
 from keras.src.ops.numpy import round
 from keras.src.ops.numpy import select
 from keras.src.ops.numpy import sign
+from keras.src.ops.numpy import signbit
 from keras.src.ops.numpy import sin
 from keras.src.ops.numpy import sinh
 from keras.src.ops.numpy import size
54 changes: 33 additions & 21 deletions keras/src/backend/jax/numpy.py
@@ -353,31 +353,37 @@ def arctanh(x):
 
 
 def argmax(x, axis=None, keepdims=False):
-    if x.ndim == 0:
+    from keras.src.testing.test_case import uses_cpu
+
+    x = convert_to_tensor(x)
+    dtype = standardize_dtype(x.dtype)
+    if "float" not in dtype or not uses_cpu() or x.ndim == 0:
         return jnp.argmax(x, axis=axis, keepdims=keepdims)
-    x_float = x.astype(jnp.float32)
-    is_negative_zero = (x_float == 0.0) & jnp.signbit(x_float)
-    x_adjusted = jnp.where(
-        is_negative_zero, -jnp.finfo(x_float.dtype).tiny, x_float
-    )
-    return jnp.argmax(x_adjusted, axis=axis, keepdims=keepdims)
+
+    # Fix the flush-to-zero (FTZ) issue based on this issue:
+    # https://github.com/jax-ml/jax/issues/24280
+    dtype = dtypes.result_type(dtype, "float32")
+    x = cast(x, dtype)
+    is_negative_zero = (x == 0.0) & jnp.signbit(x)
+    x = jnp.where(is_negative_zero, -jnp.finfo(x.dtype).tiny, x)
+    return jnp.argmax(x, axis=axis, keepdims=keepdims)
 
 
 def argmin(x, axis=None, keepdims=False):
-    x_64 = jnp.asarray(x, dtype=jnp.float64)
-    if axis is not None:
-        min_mask = x_64 == jnp.min(x_64, axis=axis, keepdims=True)
-        indices = jnp.argmin(
-            jnp.where(min_mask, x_64, jnp.inf), axis=axis, keepdims=keepdims
-        ).astype("int32")
-    else:
-        min_mask = (x_64 < x_64.min()) | (
-            (x_64 == x_64.min()) & (jnp.signbit(x_64))
-        )
-        indices = jnp.argmin(
-            jnp.where(min_mask, x_64, jnp.inf), axis=axis, keepdims=keepdims
-        ).astype("int32")
-    return indices
+    from keras.src.testing.test_case import uses_cpu
+
+    x = convert_to_tensor(x)
+    dtype = standardize_dtype(x.dtype)
+    if "float" not in dtype or not uses_cpu() or x.ndim == 0:
+        return jnp.argmin(x, axis=axis, keepdims=keepdims)
+
+    # Fix the flush-to-zero (FTZ) issue based on this issue:
+    # https://github.com/jax-ml/jax/issues/24280
+    dtype = dtypes.result_type(dtype, "float32")
+    x = cast(x, dtype)
+    is_negative_zero = (x == 0.0) & jnp.signbit(x)
+    x = jnp.where(is_negative_zero, -jnp.finfo(x.dtype).tiny, x)
+    return jnp.argmin(x, axis=axis, keepdims=keepdims)
 
 
 def argsort(x, axis=-1):
@@ -996,6 +1002,12 @@ def sign(x):
     return jnp.sign(x)
 
 
+@sparse.elementwise_unary(linear=False)
+def signbit(x):
+    x = convert_to_tensor(x)
+    return jnp.signbit(x)
+
+
 @sparse.elementwise_unary(linear=False)
 def sin(x):
     x = convert_to_tensor(x)
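A note on the uses_cpu() gate in the JAX paths above: the -tiny substitution only helps where subnormals survive. On devices that flush subnormals to zero (the FTZ behavior in the linked JAX issue), -tiny would collapse back to -0.0, so the adjustment is applied on CPU only. A standalone sketch in plain jax.numpy (illustrative, assumes a CPU backend):

import jax.numpy as jnp

x = jnp.array([0.0, -0.0], dtype=jnp.float32)
print(jnp.argmin(x))  # 0 -- the zeros compare equal, so plain argmin keeps index 0

# The adjustment from the diff: nudge -0.0 to -tiny so it orders below +0.0.
is_negative_zero = (x == 0.0) & jnp.signbit(x)
x_adj = jnp.where(is_negative_zero, -jnp.finfo(x.dtype).tiny, x)
print(jnp.argmin(x_adj))  # 1 -- -0.0 now ranks below +0.0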
45 changes: 24 additions & 21 deletions keras/src/backend/numpy/numpy.py
@@ -245,32 +245,31 @@ def arctanh(x):
 
 
 def argmax(x, axis=None, keepdims=False):
-    if x.ndim == 0:
+    x = convert_to_tensor(x)
+    axis = standardize_axis_for_numpy(axis)
+    dtype = standardize_dtype(x.dtype)
+    if "float" not in dtype or x.ndim == 0:
         return np.argmax(x, axis=axis, keepdims=keepdims).astype("int32")
-    x_float = x.astype(np.float32)
-    is_negative_zero = (x_float == 0.0) & np.signbit(x_float)
-    x_adjusted = np.where(
-        is_negative_zero, -np.finfo(x_float.dtype).tiny, x_float
-    )
-    return np.argmax(x_adjusted, axis=axis, keepdims=keepdims).astype("int32")
+
+    dtype = dtypes.result_type(dtype, "float32")
+    x = x.astype(dtype)
+    is_negative_zero = (x == 0.0) & np.signbit(x)
+    x = np.where(is_negative_zero, -np.finfo(x.dtype).tiny, x)
+    return np.argmax(x, axis=axis, keepdims=keepdims).astype("int32")
 
 
 def argmin(x, axis=None, keepdims=False):
     x = convert_to_tensor(x)
     axis = standardize_axis_for_numpy(axis)
-    x_64 = np.asarray(x, dtype=np.float64)
-    if axis is not None:
-        min_mask = x_64 == np.min(x_64, axis=axis, keepdims=True)
-        indices = np.argmin(
-            np.where(min_mask, x_64, np.inf), axis=axis, keepdims=keepdims
-        ).astype("int32")
-    else:
-        min_mask = (x_64 < x_64.min()) | (
-            (x_64 == x_64.min()) & (np.signbit(x_64))
-        )
-        indices = np.argmin(
-            np.where(min_mask, x_64, np.inf), axis=axis, keepdims=keepdims
-        ).astype("int32")
-    return indices
+    dtype = standardize_dtype(x.dtype)
+    if "float" not in dtype or x.ndim == 0:
+        return np.argmin(x, axis=axis, keepdims=keepdims).astype("int32")
+
+    dtype = dtypes.result_type(dtype, "float32")
+    x = x.astype(dtype)
+    is_negative_zero = (x == 0.0) & np.signbit(x)
+    x = np.where(is_negative_zero, -np.finfo(x.dtype).tiny, x)
+    return np.argmin(x, axis=axis, keepdims=keepdims).astype("int32")
 
 
 def argsort(x, axis=-1):
@@ -907,6 +906,10 @@ def sign(x):
     return np.sign(x)
 
 
+def signbit(x):
+    return np.signbit(x)
+
+
 def sin(x):
     x = convert_to_tensor(x)
     if standardize_dtype(x.dtype) == "int64":
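The '"float" not in dtype' gate above takes the fast path for integer and boolean inputs, which have no negative zero to correct. A quick illustration (plain NumPy, illustrative values):

import numpy as np

ints = np.array([3, -1, 0])
print(np.argmin(ints))   # 1 -- integers need no adjustment
print(np.signbit(ints))  # [False  True False] -- simply the sign of each value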
6 changes: 6 additions & 0 deletions keras/src/backend/openvino/numpy.py
@@ -751,6 +751,12 @@ def sign(x):
     return OpenVINOKerasTensor(ov_opset.sign(x).output(0))
 
 
+def signbit(x):
+    raise NotImplementedError(
+        "`signbit` is not supported with openvino backend"
+    )
+
+
 def sin(x):
     x = get_ov_output(x)
     x_type = x.get_element_type()
95 changes: 64 additions & 31 deletions keras/src/backend/tensorflow/numpy.py
@@ -837,43 +837,56 @@ def _keepdims(x, y, axis):
 
 
 def argmax(x, axis=None, keepdims=False):
-    x_float = tf.cast(x, tf.float32)
-    is_negative_zero = tf.logical_and(
-        tf.equal(x_float, 0.0),
-        tf.less(
-            tf.bitwise.bitwise_and(
-                tf.bitcast(x_float, tf.int32),
-                # tf.float32 sign bit
-                tf.constant(0x80000000, dtype=tf.int32),
-            ),
-            0,
-        ),
-    )
-    non_zero_mask = tf.not_equal(x_float, 0.0)
-    masked_abs = (
-        tf.abs(x_float)
-        + (1.0 - tf.cast(non_zero_mask, tf.float32)) * tf.float32.max
+    x = convert_to_tensor(x)
+    dtype = standardize_dtype(x.dtype)
+    if "float" not in dtype or x.ndim == 0:
+        _x = x
+        if axis is None:
+            x = tf.reshape(x, [-1])
+        y = tf.argmax(x, axis=axis, output_type="int32")
+        if keepdims:
+            y = _keepdims(_x, y, axis)
+        return y
+
+    # Fix the flush-to-zero (FTZ) issue based on this issue:
+    # https://github.com/jax-ml/jax/issues/24280
+    dtype = dtypes.result_type(dtype, "float32")
+    x = cast(x, dtype)
+    is_negative_zero = tf.logical_and(tf.equal(x, 0.0), signbit(x))
+    x = tf.where(
+        is_negative_zero, -np.finfo(standardize_dtype(x.dtype)).tiny, x
     )
-    min_non_zero = tf.reduce_min(masked_abs) - 1e-9
-    x_adjusted = tf.where(is_negative_zero, -min_non_zero, x_float)
     _x = x
     if axis is None:
-        x_adjusted = tf.reshape(x_adjusted, [-1])
-        y = tf.argmax(x_adjusted, axis=0, output_type=tf.int32)
-        if keepdims:
-            y = tf.reshape(y, [1, 1])
-    else:
-        rank = tf.rank(x_adjusted)
-        axis_tensor = tf.convert_to_tensor(axis, dtype=tf.int32)
-        positive_axis = tf.cond(
-            axis_tensor < 0, lambda: axis_tensor + rank, lambda: axis_tensor
-        )
-        y = tf.argmax(x_adjusted, axis=positive_axis, output_type=tf.int32)
-        if keepdims:
-            y = tf.expand_dims(y, axis=positive_axis)
+        x = tf.reshape(x, [-1])
+    y = tf.argmax(x, axis=axis, output_type="int32")
+    if keepdims:
+        y = _keepdims(_x, y, axis)
     return y
 
 
 def argmin(x, axis=None, keepdims=False):
+    from keras.src.testing.test_case import uses_cpu
+
+    x = convert_to_tensor(x)
+    dtype = standardize_dtype(x.dtype)
+    if "float" not in dtype or not uses_cpu() or x.ndim == 0:
+        _x = x
+        if axis is None:
+            x = tf.reshape(x, [-1])
+        y = tf.argmin(x, axis=axis, output_type="int32")
+        if keepdims:
+            y = _keepdims(_x, y, axis)
+        return y
+
+    # Fix the flush-to-zero (FTZ) issue based on this issue:
+    # https://github.com/jax-ml/jax/issues/24280
+    dtype = dtypes.result_type(dtype, "float32")
+    x = cast(x, dtype)
+    is_negative_zero = tf.logical_and(tf.equal(x, 0.0), signbit(x))
+    x = tf.where(
+        is_negative_zero, -np.finfo(standardize_dtype(x.dtype)).tiny, x
+    )
     _x = x
     if axis is None:
         x = tf.reshape(x, [-1])
@@ -2027,6 +2040,26 @@ def sign(x):
     return tf.sign(x)
 
 
+@sparse.elementwise_unary
+def signbit(x):
+    x = convert_to_tensor(x)
+    ori_dtype = standardize_dtype(x.dtype)
+    if ori_dtype == "bool":
+        return tf.fill(tf.shape(x), False)
+    elif "int" in ori_dtype:
+        return x < 0
+    else:
+        x = cast(x, "float32")
+        return tf.less(
+            tf.bitwise.bitwise_and(
+                tf.bitcast(x, tf.int32),
+                # tf.float32 sign bit
+                tf.constant(0x80000000, dtype=tf.int32),
+            ),
+            0,
+        )
+
+
 @sparse.elementwise_unary
 def sin(x):
     x = convert_to_tensor(x)
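The float branch of the TensorFlow signbit above reads the IEEE-754 sign bit directly: it bitcasts the float32 to int32 and tests whether masking out everything but the sign bit leaves a negative integer, which also classifies -0.0 and negative NaNs correctly. A standalone check in plain TensorFlow (illustrative values):

import tensorflow as tf

x = tf.constant([0.0, -0.0, float("nan"), -2.5], dtype=tf.float32)
bits = tf.bitcast(x, tf.int32)
mask = tf.constant(0x80000000, dtype=tf.int32)  # int32 sign-bit mask, as in the diff
sign = tf.less(tf.bitwise.bitwise_and(bits, mask), 0)
print(sign.numpy())  # [False  True False  True]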
5 changes: 5 additions & 0 deletions keras/src/backend/torch/numpy.py
@@ -1319,6 +1319,11 @@ def sign(x):
     return torch.sign(x)
 
 
+def signbit(x):
+    x = convert_to_tensor(x)
+    return torch.signbit(x)
+
+
 def sin(x):
     x = convert_to_tensor(x)
     return torch.sin(x)
27 changes: 27 additions & 0 deletions keras/src/ops/numpy.py
@@ -5067,6 +5067,33 @@ def sign(x):
     return backend.numpy.sign(x)
 
 
+class Signbit(Operation):
+    def call(self, x):
+        return backend.numpy.signbit(x)
+
+    def compute_output_spec(self, x):
+        sparse = getattr(x, "sparse", False)
+        return KerasTensor(x.shape, dtype="bool", sparse=sparse)
+
+
+@keras_export(["keras.ops.signbit", "keras.ops.numpy.signbit"])
+def signbit(x):
+    """Return the sign bit of the elements of `x`.
+
+    The output boolean tensor contains `True` where the sign of `x` is
+    negative, and `False` otherwise.
+
+    Args:
+        x: Input tensor.
+
+    Returns:
+        Output boolean tensor of same shape as `x`.
+    """
+    if any_symbolic_tensors((x,)):
+        return Signbit().symbolic_call(x)
+    return backend.numpy.signbit(x)
+
+
 class Sin(Operation):
     def call(self, x):
         return backend.numpy.sin(x)
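At the public API level the new op is exported as keras.ops.signbit (and keras.ops.numpy.signbit). A small usage sketch, assuming a default backend install:

from keras import ops

x = ops.convert_to_tensor([[0.0, -0.0], [-1.5, 2.0]])
print(ops.signbit(x))
# [[False  True]
#  [ True False]]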
30 changes: 30 additions & 0 deletions keras/src/ops/numpy_test.py
@@ -1597,6 +1597,10 @@ def test_sign(self):
         x = KerasTensor((None, 3))
         self.assertEqual(knp.sign(x).shape, (None, 3))
 
+    def test_signbit(self):
+        x = KerasTensor((None, 3))
+        self.assertEqual(knp.signbit(x).shape, (None, 3))
+
     def test_sin(self):
         x = KerasTensor((None, 3))
         self.assertEqual(knp.sin(x).shape, (None, 3))
@@ -2161,6 +2165,10 @@ def test_sign(self):
         x = KerasTensor((2, 3))
         self.assertEqual(knp.sign(x).shape, (2, 3))
 
+    def test_signbit(self):
+        x = KerasTensor((2, 3))
+        self.assertEqual(knp.signbit(x).shape, (2, 3))
+
     def test_sin(self):
         x = KerasTensor((2, 3))
         self.assertEqual(knp.sin(x).shape, (2, 3))
@@ -4323,6 +4331,11 @@ def test_sign(self):
         self.assertAllClose(knp.sign(x), np.sign(x))
         self.assertAllClose(knp.Sign()(x), np.sign(x))
 
+    def test_signbit(self):
+        x = np.array([[0.0, -0.0, -1.1e-45], [1.1e-38, 2, -1]])
+        self.assertAllClose(knp.signbit(x), np.signbit(x))
+        self.assertAllClose(knp.Signbit()(x), np.signbit(x))
+
     def test_sin(self):
         x = np.array([[1, -2, 3], [-3, 2, -1]])
         self.assertAllClose(knp.sin(x), np.sin(x))
@@ -8039,6 +8052,23 @@ def test_sign(self, dtype):
             expected_dtype,
         )
 
+    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
+    def test_signbit(self, dtype):
+        import jax.numpy as jnp
+
+        x = knp.ones((), dtype=dtype)
+        x_jax = jnp.ones((), dtype=dtype)
+        expected_dtype = standardize_dtype(jnp.signbit(x_jax).dtype)
+
+        self.assertEqual(
+            standardize_dtype(knp.signbit(x).dtype),
+            expected_dtype,
+        )
+        self.assertEqual(
+            standardize_dtype(knp.Signbit().symbolic_call(x).dtype),
+            expected_dtype,
+        )
+
     @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
     def test_sin(self, dtype):
         import jax.numpy as jnp