Skip to content

Commit

Permalink
fix error with loading models with h5 and core while_loop
Browse files Browse the repository at this point in the history
  • Loading branch information
acsweet committed Jan 28, 2025
1 parent 2bc4baf commit 315cd25
Showing 1 changed file with 41 additions and 7 deletions.
48 changes: 41 additions & 7 deletions keras/src/backend/mlx/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@
from keras.src.backend.common.keras_tensor import KerasTensor
from keras.src.backend.common.stateless_scope import StatelessScope

try:
import h5py
except ImportError:
h5py = None

SUPPORTS_SPARSE_TENSORS = False

MLX_DTYPES = {
Expand Down Expand Up @@ -55,6 +60,13 @@ def __array__(self, dtype=None):
return value


def _is_h5py_dataset(obj):
return (
type(obj).__module__.startswith("h5py.")
and type(obj).__name__ == "Dataset"
)


def convert_to_tensor(x, dtype=None, sparse=None):
if sparse:
raise ValueError("`sparse=True` is not supported with mlx backend")
Expand Down Expand Up @@ -89,6 +101,14 @@ def to_scalar_list(x):

return mx.array(to_scalar_list(x), dtype=mlx_dtype)

if _is_h5py_dataset(x):
if h5py is None:
raise ImportError(
"h5py must be installed in order to load a model."
)
# load h5py._hl.dataset.Dataset object with numpy
x = np.array(x)

return mx.array(x, dtype=mlx_dtype)


Expand Down Expand Up @@ -279,18 +299,32 @@ def while_loop(
loop_vars,
maximum_iterations=None,
):
# TODO: How should we avoid evaluating cond when tracing?
current_iter = 0
iteration_check = (
lambda iter: maximum_iterations is None or iter < maximum_iterations
)
loop_vars = tuple([convert_to_tensor(v) for v in loop_vars])
while cond(*loop_vars) and iteration_check(current_iter):
loop_vars = body(*loop_vars)
if not isinstance(loop_vars, (list, tuple)):
loop_vars = (loop_vars,)
loop_vars = tuple(loop_vars)

is_sequence = isinstance(loop_vars, (tuple, list))

if is_sequence:
loop_vars = tuple(convert_to_tensor(v) for v in loop_vars)
else:
loop_vars = tree.map_structure(convert_to_tensor, loop_vars)

while (
cond(*loop_vars) if is_sequence else cond(loop_vars)
) and iteration_check(current_iter):
new_vars = body(*loop_vars) if is_sequence else body(loop_vars)

if is_sequence:
if not isinstance(new_vars, (tuple, list)):
new_vars = (new_vars,)
loop_vars = tuple(convert_to_tensor(v) for v in new_vars)
else:
loop_vars = tree.map_structure(convert_to_tensor, new_vars)

current_iter += 1

return loop_vars


Expand Down

0 comments on commit 315cd25

Please sign in to comment.