
Commit

Merge branch 'main' of github.com:keras-team/keras-core
fchollet committed Jun 29, 2023
2 parents f9edaeb + 2c6f658 commit 4d45cd7
Showing 7 changed files with 80 additions and 16 deletions.
8 changes: 6 additions & 2 deletions keras_core/backend/tensorflow/numpy.py
@@ -96,7 +96,9 @@ def append(


 def arange(start, stop=None, step=1, dtype=None):
-    return tfnp.arange(start, stop, step=step, dtype=dtype)
+    # tfnp.arange has trouble with dynamic Tensors in a compiled function;
+    # tf.range does not.
+    return tf.range(start, stop, delta=step, dtype=dtype)


 def arccos(x):
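As a quick aside (not part of the commit): tf.range accepts a limit that is only known at trace time, which is what makes it safe inside a compiled function. A minimal sketch:

    import tensorflow as tf

    @tf.function
    def make_range(limit):
        # `limit` is a dynamic Tensor at trace time; tf.range handles it,
        # while tfnp.arange can fail to build the graph.
        return tf.range(0, limit, delta=1, dtype="int32")

    print(make_range(tf.constant(5)))  # tf.Tensor([0 1 2 3 4], ...)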
@@ -444,7 +446,9 @@ def reciprocal(x):


 def repeat(x, repeats, axis=None):
-    return tfnp.repeat(x, repeats, axis=axis)
+    # tfnp.repeat has trouble with dynamic Tensors in a compiled function;
+    # tf.repeat does not.
+    return tf.repeat(x, repeats, axis=axis)


 def reshape(x, new_shape):
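Likewise, a minimal sketch (illustrative, not from the diff) of tf.repeat with a repeat count that is itself a Tensor:

    import tensorflow as tf

    @tf.function
    def tile_each(x, count):
        # `count` is only known at run time; tf.repeat traces cleanly.
        return tf.repeat(x, count, axis=0)

    print(tile_each(tf.constant([1, 2]), tf.constant(3)))
    # tf.Tensor([1 1 1 2 2 2], shape=(6,), dtype=int32)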
9 changes: 8 additions & 1 deletion keras_core/backend/torch/numpy.py
@@ -750,7 +750,14 @@ def take(x, indices, axis=None):
         # This case is equivalent to embedding lookup.
         return torch.nn.functional.embedding(indices, x)
     if axis is not None:
-        return torch.index_select(x, dim=axis, index=indices).squeeze(axis)
+        # make sure axis is non-negative
+        axis = len(x.shape) + axis if axis < 0 else axis
+        shape = x.shape[:axis] + indices.shape + x.shape[axis + 1 :]
+        # ravel the `indices` since `index_select` expects `indices`
+        # to be a vector (1-D tensor).
+        indices = indices.ravel()
+        out = torch.index_select(x, dim=axis, index=indices).squeeze(axis)
+        return out.reshape(shape)
     return torch.take(x, index=indices)
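A standalone sketch of the same trick (names are illustrative, not from the diff), showing why the final reshape restores the expected result shape:

    import torch

    x = torch.randn(2, 3, 4, 5)
    indices = torch.randint(0, 4, (6, 7))
    # index_select only takes 1-D indices, so flatten them first...
    flat = torch.index_select(x, dim=2, index=indices.ravel())  # (2, 3, 42, 5)
    # ...then splice the indices' shape back in place of axis 2.
    out = flat.reshape(x.shape[:2] + indices.shape + x.shape[3:])
    print(out.shape)  # torch.Size([2, 3, 6, 7, 5])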


23 changes: 20 additions & 3 deletions keras_core/layers/normalization/batch_normalization.py
@@ -3,6 +3,7 @@
 from keras_core import ops
 from keras_core import regularizers
 from keras_core.api_export import keras_core_export
+from keras_core.backend import standardize_dtype
 from keras_core.layers.input_spec import InputSpec
 from keras_core.layers.layer import Layer
@@ -188,6 +189,12 @@ def compute_output_shape(self, input_shape):
         return input_shape

     def call(self, inputs, training=None, mask=None):
+        input_dtype = standardize_dtype(inputs.dtype)
+        if input_dtype in ("float16", "bfloat16"):
+            # BN is prone to overflow with float16/bfloat16 inputs, so we
+            # opt BN out of mixed precision and compute in float32.
+            inputs = ops.cast(inputs, "float32")
+
         broadcast_shape = [1] * len(inputs.shape)
         broadcast_shape[self.axis] = inputs.shape[self.axis]
         if training and self.trainable:
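Why the upcast matters: float16 tops out at 65504, so even the squaring inside a variance computation can overflow. A tiny NumPy illustration (not part of the commit):

    import numpy as np

    x = np.full((4,), 3e4, dtype="float16")   # values near float16's max
    print((x * x).mean())                     # inf: 3e4 * 3e4 overflows float16
    print((x.astype("float32") ** 2).mean())  # 900000000.0, fine in float32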
@@ -198,12 +205,20 @@ def call(self, inputs, training=None, mask=None):
             outputs = (inputs - mean) / ops.sqrt(variance + self.epsilon)
             mean = ops.squeeze(mean, self._reduction_axes)
             variance = ops.squeeze(variance, self._reduction_axes)
+            moving_mean = ops.cast(self.moving_mean, inputs.dtype)
+            moving_variance = ops.cast(self.moving_variance, inputs.dtype)
             self.moving_mean.assign(
-                self.moving_mean * self.momentum + mean * (1.0 - self.momentum)
+                ops.cast(
+                    moving_mean * self.momentum + mean * (1.0 - self.momentum),
+                    inputs.dtype,
+                )
             )
             self.moving_variance.assign(
-                self.moving_variance * self.momentum
-                + variance * (1.0 - self.momentum)
+                ops.cast(
+                    moving_variance * self.momentum
+                    + variance * (1.0 - self.momentum),
+                    inputs.dtype,
+                )
             )
         else:
             moving_mean = ops.reshape(self.moving_mean, broadcast_shape)
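For context, the assigned value is a plain exponential moving average, now computed in the (possibly upcast) input dtype and cast back to the variable's dtype. A toy sketch of the update rule, with illustrative values:

    momentum = 0.99
    moving_mean, batch_mean = 0.0, 10.0
    moving_mean = moving_mean * momentum + batch_mean * (1.0 - momentum)
    print(moving_mean)  # ~0.1: moving stats drift slowly toward each batch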
@@ -213,9 +228,11 @@
             )
         if self.scale:
             gamma = ops.reshape(self.gamma, broadcast_shape)
+            gamma = ops.cast(gamma, outputs.dtype)
             outputs = outputs * gamma
         if self.center:
             beta = ops.reshape(self.beta, broadcast_shape)
+            beta = ops.cast(beta, outputs.dtype)
             outputs = outputs + beta
         return outputs

7 changes: 5 additions & 2 deletions keras_core/layers/normalization/layer_normalization.py
@@ -205,9 +205,12 @@ def _broadcast(v):
         # Compute the batch normalization.
         inv = 1 / ops.sqrt(variance + self.epsilon)
         if scale is not None:
+            scale = ops.cast(scale, inputs.dtype)
             inv = inv * scale

-        x = offset - mean * inv if offset is not None else -mean * inv
+        x = -mean * inv
+        if offset is not None:
+            offset = ops.cast(offset, inputs.dtype)
+            x = offset + x
         outputs = inputs * ops.cast(inv, inputs.dtype) + ops.cast(
             x, inputs.dtype
         )
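The rewrite keeps the usual fused form of the normalization; a small NumPy check (illustrative only) of the identity it relies on:

    import numpy as np

    x = np.array([1.0, 2.0, 3.0, 4.0])
    mean, var, eps = x.mean(), x.var(), 1e-5
    scale, offset = 2.0, 0.5
    inv = scale / np.sqrt(var + eps)  # `inv` after folding in `scale`
    fused = x * inv + (offset - mean * inv)  # form used by the layer
    ref = (x - mean) / np.sqrt(var + eps) * scale + offset
    print(np.allclose(fused, ref))  # True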
7 changes: 6 additions & 1 deletion keras_core/mixed_precision/dtype_policy.py
@@ -93,7 +93,12 @@ def from_config(cls, config):
         return cls(**config)


-@keras_core_export("keras_core.mixed_precision.set_dtype_policy")
+@keras_core_export(
+    [
+        "keras_core.mixed_precision.set_dtype_policy",
+        "keras_core.mixed_precision.set_global_policy",
+    ]
+)
 def set_dtype_policy(policy):
     if not isinstance(policy, DTypePolicy):
         if isinstance(policy, str):
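Sketched usage of the newly exported alias (assuming a keras_core install; both names resolve to the same function):

    import keras_core

    # Canonical keras_core name:
    keras_core.mixed_precision.set_dtype_policy("mixed_float16")
    # Alias mirroring tf.keras's mixed_precision.set_global_policy:
    keras_core.mixed_precision.set_global_policy("float32")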
14 changes: 7 additions & 7 deletions keras_core/ops/numpy.py
@@ -2814,16 +2814,16 @@ def call(self, x, indices):

     def compute_output_spec(self, x, indices):
         x_shape = list(x.shape)
-        indices_shape = list(getattr(np.array(indices), "shape", []))
+        if isinstance(indices, KerasTensor):
+            indices_shape = list(indices.shape)
+        else:
+            indices_shape = list(getattr(np.array(indices), "shape", []))
         if self.axis is None:
             return KerasTensor(indices_shape, dtype=x.dtype)

-        if self.axis == -1:
-            output_shape = x_shape[:-1] + indices_shape
-        else:
-            output_shape = (
-                x_shape[: self.axis] + indices_shape + x_shape[self.axis + 1 :]
-            )
+        # make sure axis is non-negative
+        axis = len(x_shape) + self.axis if self.axis < 0 else self.axis
+        output_shape = x_shape[:axis] + indices_shape + x_shape[axis + 1 :]
         return KerasTensor(output_shape, dtype=x.dtype)
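Illustrative shape inference under the new branch (assuming KerasTensor is importable from keras_core.backend, as in the tests below): the indices' shape replaces the taken axis, with axis=-2 normalized to 2.

    from keras_core import ops
    from keras_core.backend import KerasTensor

    x = KerasTensor([2, 3, 4, 5])
    indices = KerasTensor([6, 7])
    print(ops.take(x, indices, axis=-2).shape)  # (2, 3, 6, 7, 5)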


28 changes: 28 additions & 0 deletions keras_core/ops/numpy_test.py
@@ -203,6 +203,14 @@ def test_take(self):
             knp.take(x, [[1, 2], [1, 2]], axis=1).shape, (None, 2, 2, 3)
         )

+        # test with negative axis
+        self.assertEqual(knp.take(x, 1, axis=-2).shape, (None, 3))
+
+        # test with multi-dimensional indices
+        x = KerasTensor([None, 3, None, 5])
+        indices = KerasTensor([6, 7])
+        self.assertEqual(knp.take(x, indices, axis=2).shape, (None, 3, 6, 7, 5))
+
     def test_take_along_axis(self):
         x = KerasTensor([None, 3])
         indices = KerasTensor([1, 3])
@@ -586,6 +594,11 @@ def test_take(self):
         self.assertEqual(knp.take(x, [1, 2]).shape, (2,))
         self.assertEqual(knp.take(x, [[1, 2], [1, 2]], axis=1).shape, (2, 2, 2))

+        # test with multi-dimensional indices
+        x = KerasTensor([2, 3, 4, 5])
+        indices = KerasTensor([6, 7])
+        self.assertEqual(knp.take(x, indices, axis=2).shape, (2, 3, 6, 7, 5))
+
     def test_take_along_axis(self):
         x = KerasTensor([2, 3])
         indices = KerasTensor([1, 3])
@@ -1980,6 +1993,21 @@ def test_take(self):
         self.assertAllClose(knp.Take()(x, 0), np.take(x, 0))
         self.assertAllClose(knp.Take(axis=1)(x, 0), np.take(x, 0, axis=1))

+        # test with multi-dimensional indices
+        rng = np.random.default_rng(0)
+        x = rng.standard_normal((2, 3, 4, 5))
+        indices = rng.integers(0, 4, (6, 7))
+        self.assertAllClose(
+            knp.take(x, indices, axis=2),
+            np.take(x, indices, axis=2),
+        )
+
+        # test with negative axis
+        self.assertAllClose(
+            knp.take(x, indices, axis=-2),
+            np.take(x, indices, axis=-2),
+        )
+
     def test_take_along_axis(self):
         x = np.arange(24).reshape([1, 2, 3, 4])
         indices = np.ones([1, 4, 1, 1], dtype=np.int32)
