Skip to content

Commit

Permalink
Use the exported keras.Variable class throughout code.
Browse files Browse the repository at this point in the history
The exported `keras.Variable` class is a subclass of the backend specific implementation. However, that public class was not used internally to create variables, for instance via `add_weight()`.

There was no convenient way to detect that a tensor really is a variable. This code would print `False`:
```python
v = self.add_weight(x)
print(isinstance(v, keras.Variable)
```

This change makes `keras.Variable` the actual class used internally. The only exception is in `SeedGenerator` with an overridden backend.
  • Loading branch information
hertschuh committed Oct 7, 2024
1 parent f52f9f5 commit 6129b2b
Show file tree
Hide file tree
Showing 15 changed files with 29 additions and 43 deletions.
2 changes: 1 addition & 1 deletion keras/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@
from keras.api import saving
from keras.api import tree
from keras.api import utils
from keras.src.backend import Variable
from keras.src.backend.common.keras_tensor import KerasTensor
from keras.src.backend.common.stateless_scope import StatelessScope
from keras.src.backend.common.symbolic_scope import SymbolicScope
from keras.src.backend.exports import Variable
from keras.src.backend.exports import device
from keras.src.backend.exports import name_scope
from keras.src.dtype_policies.dtype_policy import DTypePolicy
Expand Down
2 changes: 1 addition & 1 deletion keras/api/_tf_keras/keras/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@
from keras.api._tf_keras.keras import losses
from keras.api._tf_keras.keras import metrics
from keras.api._tf_keras.keras import preprocessing
from keras.src.backend import Variable
from keras.src.backend.common.keras_tensor import KerasTensor
from keras.src.backend.common.stateless_scope import StatelessScope
from keras.src.backend.common.symbolic_scope import SymbolicScope
from keras.src.backend.exports import Variable
from keras.src.backend.exports import device
from keras.src.backend.exports import name_scope
from keras.src.dtype_policies.dtype_policy import DTypePolicy
Expand Down
6 changes: 6 additions & 0 deletions keras/src/backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# upon import.
import torch

from keras.src.api_export import keras_export
from keras.src.backend.common.dtypes import result_type
from keras.src.backend.common.keras_tensor import KerasTensor
from keras.src.backend.common.keras_tensor import any_symbolic_tensors
Expand Down Expand Up @@ -47,3 +48,8 @@
distribution_lib = None
else:
raise ValueError(f"Unable to import backend : {backend()}")


@keras_export("keras.Variable")
class Variable(BackendVariable): # noqa: F405
pass
12 changes: 0 additions & 12 deletions keras/src/backend/exports.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,18 @@
from keras.src import backend
from keras.src.api_export import keras_export
from keras.src.backend.common import KerasVariable

if backend.backend() == "tensorflow":
BackendVariable = backend.tensorflow.core.Variable
backend_name_scope = backend.tensorflow.core.name_scope
elif backend.backend() == "jax":
BackendVariable = backend.jax.core.Variable
backend_name_scope = backend.common.name_scope.name_scope
elif backend.backend() == "torch":
BackendVariable = backend.torch.core.Variable
backend_name_scope = backend.common.name_scope.name_scope
elif backend.backend() == "numpy":
from keras.src.backend.numpy.core import Variable as NumpyVariable

BackendVariable = NumpyVariable
backend_name_scope = backend.common.name_scope.name_scope
else:
raise RuntimeError(f"Invalid backend: {backend.backend()}")


@keras_export("keras.Variable")
class Variable(BackendVariable, KerasVariable):
pass


@keras_export("keras.name_scope")
class name_scope(backend_name_scope):
pass
Expand Down
2 changes: 1 addition & 1 deletion keras/src/backend/jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from keras.src.backend.jax import numpy
from keras.src.backend.jax import random
from keras.src.backend.jax.core import SUPPORTS_SPARSE_TENSORS
from keras.src.backend.jax.core import Variable
from keras.src.backend.jax.core import JaxVariable as BackendVariable
from keras.src.backend.jax.core import cast
from keras.src.backend.jax.core import compute_output_spec
from keras.src.backend.jax.core import cond
Expand Down
4 changes: 2 additions & 2 deletions keras/src/backend/jax/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
SUPPORTS_SPARSE_TENSORS = True


class Variable(KerasVariable):
class JaxVariable(KerasVariable):
def _initialize(self, value):
value = jnp.array(value, dtype=self._dtype)
# Note that variable.shape is needed by distribution_lib
Expand Down Expand Up @@ -56,7 +56,7 @@ def convert_to_tensor(x, dtype=None, sparse=True):
# an existing distributed jax array will raise error.
return x

if isinstance(x, Variable):
if isinstance(x, KerasVariable):
if dtype is not None and x.dtype != dtype:
return x.value.astype(dtype)
return x.value
Expand Down
2 changes: 1 addition & 1 deletion keras/src/backend/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from keras.src.backend.numpy import numpy
from keras.src.backend.numpy import random
from keras.src.backend.numpy.core import SUPPORTS_SPARSE_TENSORS
from keras.src.backend.numpy.core import Variable
from keras.src.backend.numpy.core import NumpyVariable as BackendVariable
from keras.src.backend.numpy.core import cast
from keras.src.backend.numpy.core import compute_output_spec
from keras.src.backend.numpy.core import cond
Expand Down
4 changes: 2 additions & 2 deletions keras/src/backend/numpy/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
SUPPORTS_SPARSE_TENSORS = False


class Variable(KerasVariable):
class NumpyVariable(KerasVariable):
def _initialize(self, value):
self._value = np.array(value, dtype=self._dtype)

Expand All @@ -36,7 +36,7 @@ def convert_to_tensor(x, dtype=None, sparse=None):
raise ValueError("`sparse=True` is not supported with numpy backend")
if dtype is not None:
dtype = standardize_dtype(dtype)
if isinstance(x, Variable):
if isinstance(x, KerasVariable):
if dtype and dtype != x.dtype:
return x.value.astype(dtype)
return x.value
Expand Down
2 changes: 1 addition & 1 deletion keras/src/backend/tensorflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from keras.src.backend.tensorflow import random
from keras.src.backend.tensorflow import tensorboard
from keras.src.backend.tensorflow.core import SUPPORTS_SPARSE_TENSORS
from keras.src.backend.tensorflow.core import Variable
from keras.src.backend.tensorflow.core import TFVariable as BackendVariable
from keras.src.backend.tensorflow.core import cast
from keras.src.backend.tensorflow.core import compute_output_spec
from keras.src.backend.tensorflow.core import cond
Expand Down
2 changes: 1 addition & 1 deletion keras/src/backend/tensorflow/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
SUPPORTS_SPARSE_TENSORS = True


class Variable(
class TFVariable(
KerasVariable,
tf.__internal__.types.Tensor,
tf.__internal__.tracking.Trackable,
Expand Down
17 changes: 7 additions & 10 deletions keras/src/backend/tensorflow/rnn.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import tensorflow as tf

from keras.src import backend
from keras.src import tree


Expand Down Expand Up @@ -471,13 +472,11 @@ def gru(
if not cudnn_supported:
raise NotImplementedError

from keras.src.backend.tensorflow import Variable

if isinstance(kernel, Variable):
if isinstance(kernel, backend.Variable):
kernel = kernel.value
if isinstance(recurrent_kernel, Variable):
if isinstance(recurrent_kernel, backend.Variable):
recurrent_kernel = recurrent_kernel.value
if isinstance(bias, Variable):
if isinstance(bias, backend.Variable):
bias = bias.value

try:
Expand Down Expand Up @@ -828,13 +827,11 @@ def lstm(
if not cudnn_supported:
raise NotImplementedError

from keras.src.backend.tensorflow import Variable

if isinstance(kernel, Variable):
if isinstance(kernel, backend.Variable):
kernel = kernel.value
if isinstance(recurrent_kernel, Variable):
if isinstance(recurrent_kernel, backend.Variable):
recurrent_kernel = recurrent_kernel.value
if isinstance(bias, Variable):
if isinstance(bias, backend.Variable):
bias = bias.value

try:
Expand Down
2 changes: 1 addition & 1 deletion keras/src/backend/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from keras.src.backend.torch import numpy
from keras.src.backend.torch import random
from keras.src.backend.torch.core import SUPPORTS_SPARSE_TENSORS
from keras.src.backend.torch.core import Variable
from keras.src.backend.torch.core import TorchVariable as BackendVariable
from keras.src.backend.torch.core import cast
from keras.src.backend.torch.core import compute_output_spec
from keras.src.backend.torch.core import cond
Expand Down
8 changes: 2 additions & 6 deletions keras/src/backend/torch/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def to_torch_dtype(dtype):
return standardized_dtype


class Variable(KerasVariable):
class TorchVariable(KerasVariable):
def _initialize(self, value):
if isinstance(value, torch.nn.Parameter):
# Reuse same parameter
Expand Down Expand Up @@ -187,11 +187,7 @@ def __eq__(self, other):
def convert_to_tensor(x, dtype=None, sparse=None):
if sparse:
raise ValueError("`sparse=True` is not supported with torch backend")
if type(x) is Variable:
# We cannot use `isinstance(x, Variable)` due to the failure of
# TorchDynamo.
# torch._dynamo.exc.InternalTorchDynamoError:
# GetAttrVariable(SuperVariable(), value) has no type.
if isinstance(x, KerasVariable):
# TorchDynamo has bugs supporting nn.Parameter type check.
# Return it directly instead of pass it to the rest of the logic in the
# function.
Expand Down
2 changes: 1 addition & 1 deletion keras/src/random/seed_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def seed_initializer(*args, **kwargs):
return self.backend.convert_to_tensor([seed, 0], dtype=dtype)

with backend.name_scope(self.name, caller=self):
self.state = self.backend.Variable(
self.state = self.backend.BackendVariable(
seed_initializer,
shape=(2,),
dtype=self.backend.random_seed_dtype(),
Expand Down
5 changes: 2 additions & 3 deletions keras/src/utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from packaging.version import parse

from keras.src import backend
from keras.src.api_export import keras_export
from keras.src.layers import Layer
from keras.src.ops import convert_to_numpy
Expand Down Expand Up @@ -101,12 +102,10 @@ def parameters(self, recurse=True):
return self.module.parameters(recurse=recurse)

def _track_module_parameters(self):
from keras.src.backend.torch import Variable

for param in self.module.parameters():
# The Variable will reuse the raw `param`
# and simply wrap it.
variable = Variable(
variable = backend.Variable(
initializer=param, trainable=param.requires_grad
)
self._track_variable(variable)
Expand Down

0 comments on commit 6129b2b

Please sign in to comment.