Skip to content

Commit

Permalink
Fix UpSampling2D TF issue
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed May 5, 2024
1 parent 89b8acb commit d84c6ee
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
4 changes: 3 additions & 1 deletion keras/src/export/export_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,9 @@ def serving_fn(x):

if input_signature:
if backend.backend() == "tensorflow":
decorated_fn = tf.function(fn, input_signature=input_signature)
decorated_fn = tf.function(
fn, input_signature=input_signature, autograph=False
)
else: # JAX backend

# 1. Create a stateless wrapper for `fn`
Expand Down
7 changes: 4 additions & 3 deletions keras/src/layers/reshaping/up_sampling2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def _resize_images(
if data_format == "channels_first":
x = ops.transpose(x, [0, 2, 3, 1])
# https://github.com/keras-team/keras/issues/294
# Use `ops.repeat` for `nearest` interpolation
# Use `ops.repeat` for `nearest` interpolation to enable XLA
if interpolation == "nearest":
x = ops.repeat(x, height_factor, axis=1)
x = ops.repeat(x, width_factor, axis=2)
Expand All @@ -158,9 +158,10 @@ def _resize_images(
# since when running under torchdynamo, `new_shape`
# will be traced as a symbolic variable (specifically
# a `FakeTensor`) which does not have a `tolist()` method.
shape = ops.shape(x)
new_shape = (
x.shape[1] * height_factor,
x.shape[2] * width_factor,
shape[1] * height_factor,
shape[2] * width_factor,
)
x = ops.image.resize(x, new_shape, interpolation=interpolation)
if data_format == "channels_first":
Expand Down

0 comments on commit d84c6ee

Please sign in to comment.