mlx - test updates and signbit implementation #21180

Merged 8 commits on Apr 18, 2025
2 changes: 1 addition & 1 deletion .github/workflows/actions.yml
@@ -15,7 +15,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: [3.10]
python-version: ['3.10']
backend: [tensorflow, jax, torch, numpy, openvino, mlx]
name: Run tests
runs-on: ubuntu-latest
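Note: quoting '3.10' matters because YAML parses the unquoted 3.10 as the float 3.1, so the matrix would otherwise resolve to the wrong Python version.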
3 changes: 3 additions & 0 deletions keras/api/__init__.py
@@ -41,6 +41,9 @@
from keras.src.backend.common.remat import remat
from keras.src.backend.common.stateless_scope import StatelessScope
from keras.src.backend.common.symbolic_scope import SymbolicScope
from keras.src.backend.exports import Variable
from keras.src.backend.exports import device
from keras.src.backend.exports import name_scope
from keras.src.dtype_policies.dtype_policy import DTypePolicy
from keras.src.dtype_policies.dtype_policy import FloatDTypePolicy
from keras.src.initializers.initializer import Initializer
3 changes: 3 additions & 0 deletions keras/api/_tf_keras/keras/__init__.py
@@ -39,6 +39,9 @@
from keras.src.backend.common.remat import remat
from keras.src.backend.common.stateless_scope import StatelessScope
from keras.src.backend.common.symbolic_scope import SymbolicScope
from keras.src.backend.exports import Variable
from keras.src.backend.exports import device
from keras.src.backend.exports import name_scope
from keras.src.dtype_policies.dtype_policy import DTypePolicy
from keras.src.dtype_policies.dtype_policy import FloatDTypePolicy
from keras.src.initializers.initializer import Initializer
28 changes: 28 additions & 0 deletions keras/src/backend/common/variables_test.py
@@ -943,6 +943,11 @@ def test_add(self, dtypes):
x2_jax = jnp.ones((1,), dtype=dtype2)
expected_dtype = standardize_dtype(jnp.add(x1_jax, x2_jax).dtype)

if backend.backend() == "mlx":
if expected_dtype == "complex128":
# mlx backend does not support complex128
expected_dtype = "complex64"

self.assertDType(x1 + x2, expected_dtype)
self.assertDType(x1.__radd__(x2), expected_dtype)

@@ -959,6 +964,11 @@ def test_sub(self, dtypes):
x2_jax = jnp.ones((1,), dtype=dtype2)
expected_dtype = standardize_dtype(jnp.add(x1_jax, x2_jax).dtype)

if backend.backend() == "mlx":
if expected_dtype == "complex128":
# mlx backend does not support complex128
expected_dtype = "complex64"

self.assertDType(x1 - x2, expected_dtype)
self.assertDType(x1.__rsub__(x2), expected_dtype)

@@ -975,6 +985,11 @@ def test_mul(self, dtypes):
x2_jax = jnp.ones((1,), dtype=dtype2)
expected_dtype = standardize_dtype(jnp.add(x1_jax, x2_jax).dtype)

if backend.backend() == "mlx":
if expected_dtype == "complex128":
# mlx backend does not have complex128
expected_dtype = "complex64"

self.assertDType(x1 * x2, expected_dtype)
self.assertDType(x1.__rmul__(x2), expected_dtype)

@@ -1059,13 +1074,26 @@ def test_pow(self, dtypes):
x2_jax = jnp.ones((1,), dtype=dtype2)
expected_dtype = standardize_dtype(jnp.power(x1_jax, x2_jax).dtype)

if backend.backend() == "mlx":
if expected_dtype == "complex128":
# mlx backend does not support complex128
expected_dtype = "complex64"

self.assertDType(x1**x2, expected_dtype)
self.assertDType(x1.__rpow__(x2), expected_dtype)

@parameterized.named_parameters(
named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))
)
def test_matmul(self, dtypes):
if backend.backend() == "mlx":
result_dtype = backend.result_type(*dtypes)
if "float" not in result_dtype:
self.skipTest(
"mlx backend only supports matmul for real floating point "
"types"
)

import jax.numpy as jnp

dtype1, dtype2 = dtypes
14 changes: 10 additions & 4 deletions keras/src/backend/mlx/core.py
@@ -2,6 +2,7 @@
import functools
import warnings

import ml_dtypes
import mlx.core as mx
import numpy as np

@@ -97,10 +98,14 @@ def convert_to_tensor(x, dtype=None, sparse=None, ragged=None):
if x.dtype == np.float64:
# mlx backend does not support float64
x = x.astype(np.float32)
if standardize_dtype(x.dtype) == "bfloat16" and mlx_dtype is None:
if standardize_dtype(x.dtype) == "bfloat16":
# mlx currently fails to load a numpy array with dtype=bfloat16
# upcast to float32 to avoid error
x = x.astype(np.float32)
# if a bfloat16 np.ndarray is passed to mx.array with dtype=None
# it casts the output to complex64, so we force cast to bfloat16
mlx_dtype = mx.bfloat16
# (but by upcasting we avoid x.dtype=bfloat16 and mlx_dtype=None)
mlx_dtype = mx.bfloat16 if mlx_dtype is None else mlx_dtype
return mx.array(x, dtype=mlx_dtype)

if isinstance(x, list):
@@ -154,8 +159,9 @@ def convert_to_tensors(*xs):
def convert_to_numpy(x):
# Performs a copy. If we want 0-copy we can pass copy=False
if isinstance(x, mx.array) and x.dtype == mx.bfloat16:
# mlx currently has an error passing bloat16 array to numpy
return np.array(x.astype(mx.float32))
# mlx currently has an error passing bfloat16 array to numpy
# upcast to float32 then downcast to bfloat16
return np.array(x.astype(mx.float32)).astype(ml_dtypes.bfloat16)
return np.array(x)
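
For illustration, the bfloat16 round trip these two helpers now preserve looks like the sketch below (illustrative only, not part of the diff; assumes mlx and ml_dtypes are installed):

import ml_dtypes
import mlx.core as mx
import numpy as np

x_np = np.ones((2,), dtype=ml_dtypes.bfloat16)
# into mlx: upcast to float32 so mlx can load it, then request bfloat16
x_mx = mx.array(x_np.astype(np.float32), dtype=mx.bfloat16)
# back to numpy: upcast to float32 first, then downcast to bfloat16
x_back = np.array(x_mx.astype(mx.float32)).astype(ml_dtypes.bfloat16)
assert x_back.dtype == ml_dtypes.bfloat16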


32 changes: 32 additions & 0 deletions keras/src/backend/mlx/image.py
@@ -655,3 +655,35 @@ def _compute_weight_mat(
weights,
0,
)


def elastic_transform(
images,
alpha=20.0,
sigma=5.0,
interpolation="bilinear",
fill_mode="reflect",
fill_value=0.0,
seed=None,
data_format=None,
):
raise NotImplementedError("elastic_transform not yet implemented in mlx.")


def perspective_transform(
images,
start_points,
end_points,
interpolation="bilinear",
fill_value=0,
data_format=None,
):
raise NotImplementedError(
"perspective_transform not yet implemented in mlx."
)


def gaussian_blur(
images, kernel_size=(3, 3), sigma=(1.0, 1.0), data_format=None
):
raise NotImplementedError("gaussian_blur not yet implemented in mlx.")
9 changes: 8 additions & 1 deletion keras/src/backend/mlx/math.py
@@ -60,17 +60,24 @@ def segment_max(data, segment_ids, num_segments=None, sorted=False):


def top_k(x, k, sorted=True):
# default to sorted=True to match other backends
x = convert_to_tensor(x)
indices = mx.argpartition(mx.negative(x), k, axis=-1)[..., :k]
values = mx.take_along_axis(x, indices, axis=-1)

if sorted:
sort_indices = mx.argsort(mx.negative(values), axis=-1)
values = mx.take_along_axis(values, sort_indices, axis=-1)
indices = mx.take_along_axis(indices, sort_indices, axis=-1)

return values, indices


def in_top_k(targets, predictions, k):
targets = convert_to_tensor(targets)
predictions = convert_to_tensor(predictions)
targets = targets[..., None]
topk_values = top_k(predictions, k)[0]
topk_values = top_k(predictions, k, sorted=False)[0]
targets_values = mx.take_along_axis(predictions, targets, axis=-1)
mask = targets_values >= topk_values
return mx.any(mask, axis=-1)
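
A small usage sketch of the behavior these changes pin down (hypothetical values, any backend):

import numpy as np
from keras import ops

preds = np.array([[0.1, 0.9, 0.8, 0.3]])
values, indices = ops.top_k(preds, k=2)  # sorted=True by default
# values -> [[0.9, 0.8]], indices -> [[1, 2]]
# in_top_k only tests membership, so the unsorted path suffices:
ops.in_top_k(np.array([2]), preds, k=2)  # -> [True]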
1 change: 1 addition & 0 deletions keras/src/backend/mlx/nn.py
@@ -1285,6 +1285,7 @@ def dot_product_attention(
scale=None,
is_causal=False,
flash_attention=None,
attn_logits_soft_cap=None,
):
if flash_attention is None:
flash_attention = False
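
For context, a logit soft cap conventionally bounds attention logits with tanh before the softmax; a hypothetical sketch of that formula follows (this diff only adds the argument for signature compatibility, it does not implement the cap):

import mlx.core as mx

def soft_cap(logits, cap):
    # smoothly bound logits to (-cap, cap)
    return cap * mx.tanh(logits / cap)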
64 changes: 60 additions & 4 deletions keras/src/backend/mlx/numpy.py
@@ -12,6 +12,7 @@
from keras.src.backend.mlx.core import cast
from keras.src.backend.mlx.core import convert_to_tensor
from keras.src.backend.mlx.core import convert_to_tensors
from keras.src.backend.mlx.core import is_tensor
from keras.src.backend.mlx.core import slice
from keras.src.backend.mlx.core import to_mlx_dtype

@@ -272,8 +273,20 @@ def bitwise_xor(x, y):

def bitwise_left_shift(x, y):
x = convert_to_tensor(x)
y = convert_to_tensor(y)
return mx.left_shift(x, y)
if not isinstance(y, int):
y = convert_to_tensor(y)

# handle result dtype to match other backends
types = [x.dtype]
if is_tensor(y):
types.append(y.dtype)
result_dtype = result_type(*types)
mlx_result_dtype = to_mlx_dtype(result_dtype)

result = mx.left_shift(x, y)
if result.dtype != mlx_result_dtype:
return result.astype(mlx_result_dtype)
return result


def left_shift(x, y):
@@ -282,8 +295,20 @@ def left_shift(x, y):

def bitwise_right_shift(x, y):
x = convert_to_tensor(x)
y = convert_to_tensor(y)
return mx.right_shift(x, y)
if not isinstance(y, int):
y = convert_to_tensor(y)

# handle result dtype to match other backends
types = [x.dtype]
if is_tensor(y):
types.append(y.dtype)
result_dtype = result_type(*types)
mlx_result_dtype = to_mlx_dtype(result_dtype)

result = mx.right_shift(x, y)
if result.dtype != mlx_result_dtype:
return result.astype(mlx_result_dtype)
return result


def right_shift(x, y):
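
The dtype bookkeeping above mirrors how numpy treats a plain Python int shift amount, which should not promote the array dtype; a minimal comparison (numpy only):

import numpy as np

x = np.array([1, 2, 4], dtype=np.uint8)
assert (x << 2).dtype == np.uint8  # a Python-int shift keeps x's dtype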
@@ -1567,3 +1592,34 @@ def rot90(array, k=1, axes=(0, 1)):
array = array[tuple(slices)]

return array


def signbit(x):
x = convert_to_tensor(x)

if x.dtype in (
mx.float16,
mx.float32,
mx.float64,
mx.bfloat16,
mx.complex64,
):
if x.dtype == mx.complex64:
# check sign of real part for complex numbers
real_part = mx.real(x)
return signbit(real_part)
zeros = x == 0
# this works because in mlx 1/0=inf and 1/-0=-inf
neg_zeros = (1 / x == mx.array(float("-inf"))) & zeros
return mx.where(zeros, neg_zeros, x < 0)
elif x.dtype in (mx.uint8, mx.uint16, mx.uint32, mx.uint64):
# unsigned integers are never negative
return mx.zeros_like(x).astype(mx.bool_)
elif x.dtype in (mx.int8, mx.int16, mx.int32, mx.int64):
# for signed integers, a simple negative check suffices
return x < 0
elif x.dtype == mx.bool_:
# for boolean arrays, return False
return mx.zeros_like(x).astype(mx.bool_)
else:
raise ValueError(f"Unsupported dtype in `signbit`: {x.dtype}")
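
The negative-zero branch matches numpy's reference semantics, which can be checked directly:

import numpy as np

print(np.signbit(np.array([-0.0, 0.0, -1.5, 2.0])))
# [ True False  True False]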
3 changes: 2 additions & 1 deletion keras/src/initializers/constant_initializers_test.py
@@ -84,9 +84,10 @@ def test_stft_initializer(self):
args = -2 * pi * time_range * freq_range

tol_kwargs = {}
if backend.backend() == "jax":
if backend.backend() == "jax" or backend.backend() == "mlx":
# TODO(mostafa-mahmoud): investigate the cases
# of non-small error in jax and torch
# for mlx, minor precision differences with float64 on linux
tol_kwargs = {"atol": 1e-4, "rtol": 1e-6}

initializer = initializers.STFT("real", None)
1 change: 0 additions & 1 deletion keras/src/layers/activations/activation.py
@@ -26,7 +26,6 @@ def __init__(self, activation, **kwargs):
super().__init__(**kwargs)
self.supports_masking = True
self.activation = activations.get(activation)
self.built = True

self._build_at_init()
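
Note: here and in the layers below, the manual self.built = True is replaced by the _build_at_init() helper, which marks the layer built at construction time (presumably centralizing that bookkeeping in the base Layer), so the removed assignment is redundant.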

1 change: 0 additions & 1 deletion keras/src/layers/activations/elu.py
@@ -23,7 +23,6 @@ def __init__(self, alpha=1.0, **kwargs):
super().__init__(**kwargs)
self.alpha = alpha
self.supports_masking = True
self.built = True

self._build_at_init()

1 change: 0 additions & 1 deletion keras/src/layers/core/identity.py
@@ -15,7 +15,6 @@ class Identity(Layer):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.supports_masking = True
self.built = True

self._build_at_init()

1 change: 0 additions & 1 deletion keras/src/layers/pooling/base_global_pooling.py
@@ -14,7 +14,6 @@ def __init__(
self.data_format = backend.standardize_data_format(data_format)
self.keepdims = keepdims
self.input_spec = InputSpec(ndim=pool_dimensions + 2)
self.built = True

self._build_at_init()

1 change: 0 additions & 1 deletion keras/src/layers/pooling/base_pooling.py
@@ -34,7 +34,6 @@ def __init__(
self.data_format = backend.standardize_data_format(data_format)

self.input_spec = InputSpec(ndim=pool_dimensions + 2)
self.built = True

self._build_at_init()

1 change: 0 additions & 1 deletion keras/src/layers/regularization/activity_regularization.py
@@ -27,7 +27,6 @@ def __init__(self, l1=0.0, l2=0.0, **kwargs):
self.supports_masking = True
self.l1 = l1
self.l2 = l2
self.built = True

self._build_at_init()

1 change: 0 additions & 1 deletion keras/src/layers/regularization/alpha_dropout.py
@@ -46,7 +46,6 @@ def __init__(self, rate, noise_shape=None, seed=None, **kwargs):
if rate > 0:
self.seed_generator = backend.random.SeedGenerator(seed)
self.supports_masking = True
self.built = True

self._build_at_init()

1 change: 0 additions & 1 deletion keras/src/layers/regularization/dropout.py
@@ -52,7 +52,6 @@ def __init__(self, rate, noise_shape=None, seed=None, **kwargs):
if rate > 0:
self.seed_generator = backend.random.SeedGenerator(seed)
self.supports_masking = True
self.built = True

self._build_at_init()

1 change: 0 additions & 1 deletion keras/src/layers/regularization/gaussian_dropout.py
@@ -37,7 +37,6 @@ def __init__(self, rate, seed=None, **kwargs):
if rate > 0:
self.seed_generator = backend.random.SeedGenerator(seed)
self.supports_masking = True
self.built = True

self._build_at_init()

1 change: 0 additions & 1 deletion keras/src/layers/regularization/gaussian_noise.py
@@ -38,7 +38,6 @@ def __init__(self, stddev, seed=None, **kwargs):
if stddev > 0:
self.seed_generator = backend.random.SeedGenerator(seed)
self.supports_masking = True
self.built = True

self._build_at_init()
