
Commit 4d45cd7

Merge branch 'main' of github.com:keras-team/keras-core
2 parents: f9edaeb + 2c6f658

7 files changed: +80 −16 lines


keras_core/backend/tensorflow/numpy.py (6 additions, 2 deletions)

@@ -96,7 +96,9 @@ def append(
 
 
 def arange(start, stop=None, step=1, dtype=None):
-    return tfnp.arange(start, stop, step=step, dtype=dtype)
+    # tfnp.arange has trouble with dynamic Tensors in compiled function.
+    # tf.range does not.
+    return tf.range(start, stop, delta=step, dtype=dtype)
 
 
 def arccos(x):

@@ -444,7 +446,9 @@ def reciprocal(x):
 
 
 def repeat(x, repeats, axis=None):
-    return tfnp.repeat(x, repeats, axis=axis)
+    # tfnp.repeat has trouble with dynamic Tensors in compiled function.
+    # tf.repeat does not.
+    return tf.repeat(x, repeats, axis=axis)
 
 
 def reshape(x, new_shape):
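
An illustrative sketch (not part of this commit) of the situation the new comments describe: the limit of the range is a dynamic Tensor that is only known when a compiled function runs, and tf.range accepts such a Tensor limit. The function and tensor names below are made up.

import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
def positions(x):
    n = tf.shape(x)[0]  # dynamic scalar Tensor inside a compiled function
    # tf.range happily takes the Tensor `n` as its limit.
    return tf.range(0, n, delta=1, dtype=tf.int32)

print(positions(tf.zeros([5])))  # tf.Tensor([0 1 2 3 4], shape=(5,), dtype=int32)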

keras_core/backend/torch/numpy.py (8 additions, 1 deletion)

@@ -750,7 +750,14 @@ def take(x, indices, axis=None):
         # This case is equivalent to embedding lookup.
         return torch.nn.functional.embedding(indices, x)
     if axis is not None:
-        return torch.index_select(x, dim=axis, index=indices).squeeze(axis)
+        # make sure axis is non-negative
+        axis = len(x.shape) + axis if axis < 0 else axis
+        shape = x.shape[:axis] + indices.shape + x.shape[axis + 1 :]
+        # ravel the `indices` since `index_select` expects `indices`
+        # to be a vector (1-D tensor).
+        indices = indices.ravel()
+        out = torch.index_select(x, dim=axis, index=indices).squeeze(axis)
+        return out.reshape(shape)
     return torch.take(x, index=indices)
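
A standalone sketch (not from this commit; array values are arbitrary) of the core ravel / index_select / reshape pattern the new branch uses, checked against np.take:

import numpy as np
import torch

x = torch.arange(2 * 3 * 4 * 5, dtype=torch.float32).reshape(2, 3, 4, 5)
indices = torch.tensor([[0, 2, 1], [3, 3, 0]])  # multi-dimensional indices
axis = -2

# Same steps as the new backend branch: normalize the axis, remember the
# target shape, flatten the indices for index_select, then restore the shape.
axis = x.ndim + axis if axis < 0 else axis
shape = x.shape[:axis] + indices.shape + x.shape[axis + 1 :]
out = torch.index_select(x, dim=axis, index=indices.ravel()).reshape(shape)

np.testing.assert_allclose(out.numpy(), np.take(x.numpy(), indices.numpy(), axis=axis))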

keras_core/layers/normalization/batch_normalization.py (20 additions, 3 deletions)

@@ -3,6 +3,7 @@
 from keras_core import ops
 from keras_core import regularizers
 from keras_core.api_export import keras_core_export
+from keras_core.backend import standardize_dtype
 from keras_core.layers.input_spec import InputSpec
 from keras_core.layers.layer import Layer

@@ -188,6 +189,12 @@ def compute_output_shape(self, input_shape):
         return input_shape
 
     def call(self, inputs, training=None, mask=None):
+        input_dtype = standardize_dtype(inputs.dtype)
+        if input_dtype in ("float16", "bfloat16"):
+            # BN is prone to overflowing for float16/bfloat16 inputs, so we opt
+            # out BN for mixed precision.
+            inputs = ops.cast(inputs, "float32")
+
         broadcast_shape = [1] * len(inputs.shape)
         broadcast_shape[self.axis] = inputs.shape[self.axis]
         if training and self.trainable:

@@ -198,12 +205,20 @@ def call(self, inputs, training=None, mask=None):
             outputs = (inputs - mean) / ops.sqrt(variance + self.epsilon)
             mean = ops.squeeze(mean, self._reduction_axes)
             variance = ops.squeeze(variance, self._reduction_axes)
+            moving_mean = ops.cast(self.moving_mean, inputs.dtype)
+            moving_variance = ops.cast(self.moving_variance, inputs.dtype)
             self.moving_mean.assign(
-                self.moving_mean * self.momentum + mean * (1.0 - self.momentum)
+                ops.cast(
+                    moving_mean * self.momentum + mean * (1.0 - self.momentum),
+                    inputs.dtype,
+                )
             )
             self.moving_variance.assign(
-                self.moving_variance * self.momentum
-                + variance * (1.0 - self.momentum)
+                ops.cast(
+                    moving_variance * self.momentum
+                    + variance * (1.0 - self.momentum),
+                    inputs.dtype,
+                )
             )
         else:
             moving_mean = ops.reshape(self.moving_mean, broadcast_shape)

@@ -213,9 +228,11 @@
             )
         if self.scale:
             gamma = ops.reshape(self.gamma, broadcast_shape)
+            gamma = ops.cast(gamma, outputs.dtype)
             outputs = outputs * gamma
         if self.center:
             beta = ops.reshape(self.beta, broadcast_shape)
+            beta = ops.cast(beta, outputs.dtype)
             outputs = outputs + beta
         return outputs
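
A small numpy-only sketch (not from this commit) of why the statistics are computed in float32 for half-precision inputs: squaring deviations larger than roughly 256 already overflows float16's maximum of about 65504.

import numpy as np

# float16 cannot represent 300**2 = 90000, so a variance computed natively
# in float16 overflows to inf; the float32 path (analogous to the cast the
# layer now applies) gives the expected value.
x16 = np.array([0.0, 600.0], dtype=np.float16)

var_fp16 = np.var(x16)                     # accumulates in float16 -> inf
var_fp32 = np.var(x16.astype(np.float32))  # -> 90000.0

print(var_fp16, var_fp32)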

keras_core/layers/normalization/layer_normalization.py (5 additions, 2 deletions)

@@ -205,9 +205,12 @@ def _broadcast(v):
         # Compute the batch normalization.
         inv = 1 / ops.sqrt(variance + self.epsilon)
         if scale is not None:
+            scale = ops.cast(scale, inputs.dtype)
             inv = inv * scale
-
-        x = offset - mean * inv if offset is not None else -mean * inv
+        x = -mean * inv
+        if offset is not None:
+            offset = ops.cast(offset, inputs.dtype)
+            x = offset + x
         outputs = inputs * ops.cast(inv, inputs.dtype) + ops.cast(
             x, inputs.dtype
         )
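
A standalone numpy sketch (not from this commit; shapes and epsilon are arbitrary) of the algebraic rearrangement this block relies on, which the diff now spells out step by step so scale and offset can be cast first:

import numpy as np

# Normalize-and-affine, (inputs - mean) * inv * scale + offset, is computed
# as inputs * (inv * scale) + (offset - mean * inv * scale), so only one
# multiply and one add touch the full-size `inputs` tensor.
rng = np.random.default_rng(0)
inputs = rng.standard_normal((4, 8)).astype("float32")
mean = inputs.mean(axis=-1, keepdims=True)
variance = inputs.var(axis=-1, keepdims=True)
scale = rng.standard_normal((8,)).astype("float32")
offset = rng.standard_normal((8,)).astype("float32")
epsilon = 1e-3

inv = 1.0 / np.sqrt(variance + epsilon)
inv = inv * scale
x = offset - mean * inv          # what the diff builds up as `x`
fused = inputs * inv + x

direct = (inputs - mean) / np.sqrt(variance + epsilon) * scale + offset
np.testing.assert_allclose(fused, direct, rtol=1e-4, atol=1e-6)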

keras_core/mixed_precision/dtype_policy.py (6 additions, 1 deletion)

@@ -93,7 +93,12 @@ def from_config(cls, config):
         return cls(**config)
 
 
-@keras_core_export("keras_core.mixed_precision.set_dtype_policy")
+@keras_core_export(
+    [
+        "keras_core.mixed_precision.set_dtype_policy",
+        "keras_core.mixed_precision.set_global_policy",
+    ]
+)
 def set_dtype_policy(policy):
     if not isinstance(policy, DTypePolicy):
         if isinstance(policy, str):
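
Usage note (assuming a keras_core build that includes this change): the export decorator now registers the same function under both names, so the two calls below are equivalent; "mixed_float16" is just an example policy string.

import keras_core

keras_core.mixed_precision.set_dtype_policy("mixed_float16")
keras_core.mixed_precision.set_global_policy("mixed_float16")  # equivalent alias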

keras_core/ops/numpy.py (7 additions, 7 deletions)

@@ -2814,16 +2814,16 @@ def call(self, x, indices):
 
     def compute_output_spec(self, x, indices):
         x_shape = list(x.shape)
-        indices_shape = list(getattr(np.array(indices), "shape", []))
+        if isinstance(indices, KerasTensor):
+            indices_shape = list(indices.shape)
+        else:
+            indices_shape = list(getattr(np.array(indices), "shape", []))
         if self.axis is None:
             return KerasTensor(indices_shape, dtype=x.dtype)
 
-        if self.axis == -1:
-            output_shape = x_shape[:-1] + indices_shape
-        else:
-            output_shape = (
-                x_shape[: self.axis] + indices_shape + x_shape[self.axis + 1 :]
-            )
+        # make sure axis is non-negative
+        axis = len(x_shape) + self.axis if self.axis < 0 else self.axis
+        output_shape = x_shape[:axis] + indices_shape + x_shape[axis + 1 :]
        return KerasTensor(output_shape, dtype=x.dtype)
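
A quick illustrative trace of the new shape rule (values mirror the symbolic test added below):

# x is symbolic with shape [None, 3, None, 5]; indices has shape [6, 7]; axis=-2
x_shape = [None, 3, None, 5]
indices_shape = [6, 7]
axis = -2
axis = len(x_shape) + axis if axis < 0 else axis   # -> 2
output_shape = x_shape[:axis] + indices_shape + x_shape[axis + 1 :]
print(output_shape)  # [None, 3, 6, 7, 5]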

keras_core/ops/numpy_test.py (28 additions, 0 deletions)

@@ -203,6 +203,14 @@ def test_take(self):
             knp.take(x, [[1, 2], [1, 2]], axis=1).shape, (None, 2, 2, 3)
         )
 
+        # test with negative axis
+        self.assertEqual(knp.take(x, 1, axis=-2).shape, (None, 3))
+
+        # test with multi-dimensional indices
+        x = KerasTensor([None, 3, None, 5])
+        indices = KerasTensor([6, 7])
+        self.assertEqual(knp.take(x, indices, axis=2).shape, (None, 3, 6, 7, 5))
+
     def test_take_along_axis(self):
         x = KerasTensor([None, 3])
         indices = KerasTensor([1, 3])

@@ -586,6 +594,11 @@ def test_take(self):
         self.assertEqual(knp.take(x, [1, 2]).shape, (2,))
         self.assertEqual(knp.take(x, [[1, 2], [1, 2]], axis=1).shape, (2, 2, 2))
 
+        # test with multi-dimensional indices
+        x = KerasTensor([2, 3, 4, 5])
+        indices = KerasTensor([6, 7])
+        self.assertEqual(knp.take(x, indices, axis=2).shape, (2, 3, 6, 7, 5))
+
     def test_take_along_axis(self):
         x = KerasTensor([2, 3])
         indices = KerasTensor([1, 3])

@@ -1980,6 +1993,21 @@ def test_take(self):
         self.assertAllClose(knp.Take()(x, 0), np.take(x, 0))
         self.assertAllClose(knp.Take(axis=1)(x, 0), np.take(x, 0, axis=1))
 
+        # test with multi-dimensional indices
+        rng = np.random.default_rng(0)
+        x = rng.standard_normal((2, 3, 4, 5))
+        indices = rng.integers(0, 4, (6, 7))
+        self.assertAllClose(
+            knp.take(x, indices, axis=2),
+            np.take(x, indices, axis=2),
+        )
+
+        # test with negative axis
+        self.assertAllClose(
+            knp.take(x, indices, axis=-2),
+            np.take(x, indices, axis=-2),
+        )
+
     def test_take_along_axis(self):
         x = np.arange(24).reshape([1, 2, 3, 4])
         indices = np.ones([1, 4, 1, 1], dtype=np.int32)
