Skip to content

Commit

Permalink
patch for pytest to run with mlx
Browse files Browse the repository at this point in the history
  • Loading branch information
acsweet committed Jan 29, 2025
1 parent 179ebeb commit dea0fca
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 1 deletion.
3 changes: 2 additions & 1 deletion keras/src/backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@
distribution_lib = None
elif backend() == "mlx":
from keras.src.backend.mlx import * # noqa: F403

from keras.src.backend.mlx.core import Variable as BackendVariable

distribution_lib = None
else:
raise ValueError(f"Unable to import backend : {backend()}")
Expand Down
1 change: 1 addition & 0 deletions keras/src/backend/mlx/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""MLX backend APIs."""

from keras.src.backend.common.name_scope import name_scope
from keras.src.backend.mlx import core
from keras.src.backend.mlx import image
from keras.src.backend.mlx import linalg
Expand Down
10 changes: 10 additions & 0 deletions keras/src/backend/mlx/export.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
class MlxExportArchive:
def track(self, resource):
raise NotImplementedError(
"`track` is not implemented in the mlx backend."
)

def add_endpoint(self, name, fn, input_signature=None, **kwargs):
raise NotImplementedError(
"`add_endpoint` is not implemented in the mlx backend."
)
5 changes: 5 additions & 0 deletions keras/src/backend/mlx/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,11 @@ def gelu_tanh_approx(x):
return f(x)


def celu(x, alpha=1.0):
x = convert_to_tensor(x)
return nn.celu(x, alpha=alpha)


def softmax(x, axis=-1):
x = convert_to_tensor(x)
return mx.softmax(x, axis=axis)
Expand Down
4 changes: 4 additions & 0 deletions keras/src/export/saved_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@
from keras.src.backend.openvino.export import (
OpenvinoExportArchive as BackendExportArchive,
)
elif backend.backend() == "mlx":
from keras.src.backend.mlx.export import (
MlxExportArchive as BackendExportArchive,
)
else:
raise RuntimeError(
f"Backend '{backend.backend()}' must implement a layer mixin class."
Expand Down

0 comments on commit dea0fca

Please sign in to comment.