diff --git a/ivy/compiler/_cache/Translated_Outputs/Translated_BatchNorm2d_output/run_0/Translated__BatchNorm.py b/ivy/compiler/_cache/Translated_Outputs/Translated_BatchNorm2d_output/run_0/Translated__BatchNorm.py index 9d931afa6f0f..c77efa6e2d4f 100644 --- a/ivy/compiler/_cache/Translated_Outputs/Translated_BatchNorm2d_output/run_0/Translated__BatchNorm.py +++ b/ivy/compiler/_cache/Translated_Outputs/Translated_BatchNorm2d_output/run_0/Translated__BatchNorm.py @@ -47,11 +47,9 @@ def forward(self, input): """ return torch.nn.functional.batch_norm( input, - ( - self.running_mean - if not self.training or self.track_running_stats - else None - ), + self.running_mean + if not self.training or self.track_running_stats + else None, self.running_var if not self.training or self.track_running_stats else None, self.weight, self.bias, diff --git a/ivy/compiler/_cache/Translated_Outputs/Translated_Conv2d_output/run_0/Translated_Conv2d.py b/ivy/compiler/_cache/Translated_Outputs/Translated_Conv2d_output/run_0/Translated_Conv2d.py deleted file mode 100644 index 3bb9d2a8ab24..000000000000 --- a/ivy/compiler/_cache/Translated_Outputs/Translated_Conv2d_output/run_0/Translated_Conv2d.py +++ /dev/null @@ -1,62 +0,0 @@ -import ivy.functional.frontends.torch as torch - -from .Translated__ConvNd import Translated__ConvNd -from .helpers import Translated__ntuple_parse - -_pair = Translated__ntuple_parse(2, "_pair") - - -class Translated_Conv2d(Translated__ConvNd): - def __init__( - self, - in_channels, - out_channels, - kernel_size, - stride=1, - padding=0, - dilation=1, - groups=1, - bias=True, - padding_mode="zeros", - device=None, - dtype=None, - ): - factory_kwargs = {"device": device, "dtype": dtype} - kernel_size_ = _pair(kernel_size) - stride_ = _pair(stride) - padding_ = padding if isinstance(padding, str) else _pair(padding) - dilation_ = _pair(dilation) - super().__init__( - in_channels, - out_channels, - kernel_size_, - stride_, - padding_, - dilation_, - False, - _pair(0), - groups, - bias, - padding_mode, - **factory_kwargs, - ) - - def _conv_forward(self, input, weight, bias): - if self.padding_mode != "zeros": - return torch.nn.functional.conv2d( - torch.nn.functional.pad( - input, self._reversed_padding_repeated_twice, mode=self.padding_mode - ), - weight, - bias, - self.stride, - _pair(0), - self.dilation, - self.groups, - ) - return torch.nn.functional.conv2d( - input, weight, bias, self.stride, self.padding, self.dilation, self.groups - ) - - def forward(self, input): - return self._conv_forward(input, self.weight, self.bias) diff --git a/ivy/compiler/_cache/Translated_Outputs/Translated_Conv2d_output/run_0/Translated__ConvNd.py b/ivy/compiler/_cache/Translated_Outputs/Translated_Conv2d_output/run_0/Translated__ConvNd.py deleted file mode 100644 index b177126d465e..000000000000 --- a/ivy/compiler/_cache/Translated_Outputs/Translated_Conv2d_output/run_0/Translated__ConvNd.py +++ /dev/null @@ -1,158 +0,0 @@ -import ivy.functional.frontends.torch as torch -import ivy.functional.frontends.torch.nn as nn - -import typing -import math -from typing import Optional - -from .helpers import Translated__calculate_fan_in_and_fan_out -from .helpers import Translated__reverse_repeat_tuple -from .helpers import Translated_kaiming_uniform_ -from .helpers import Translated_uniform_ - - -class Translated__ConvNd(nn.Module): - __constants__ = [ - "stride", - "padding", - "dilation", - "groups", - "padding_mode", - "output_padding", - "in_channels", - "out_channels", - "kernel_size", - ] - __annotations__ = 
{"bias": Optional[torch.Tensor]} - - def _conv_forward(self, input, weight, bias): ... - - in_channels: typing.Any - _reversed_padding_repeated_twice: typing.Any - out_channels: typing.Any - kernel_size: typing.Any - stride: typing.Any - padding: typing.Any - dilation: typing.Any - transposed: typing.Any - output_padding: typing.Any - groups: typing.Any - padding_mode: typing.Any - weight: typing.Any - bias: typing.Any - - def __init__( - self, - in_channels, - out_channels, - kernel_size, - stride, - padding, - dilation, - transposed, - output_padding, - groups, - bias, - padding_mode, - device=None, - dtype=None, - ): - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - if groups <= 0: - raise ValueError("groups must be a positive integer") - if in_channels % groups != 0: - raise ValueError("in_channels must be divisible by groups") - if out_channels % groups != 0: - raise ValueError("out_channels must be divisible by groups") - valid_padding_strings = {"same", "valid"} - if isinstance(padding, str): - if padding not in valid_padding_strings: - raise ValueError( - f"Invalid padding string {padding!r}, should be one of {valid_padding_strings}" - ) - if padding == "same" and any(s != 1 for s in stride): - raise ValueError( - "padding='same' is not supported for strided convolutions" - ) - valid_padding_modes = {"zeros", "reflect", "replicate", "circular"} - if padding_mode not in valid_padding_modes: - raise ValueError( - f"padding_mode must be one of {valid_padding_modes}, but got padding_mode='{padding_mode}'" - ) - self.in_channels = in_channels - self.out_channels = out_channels - self.kernel_size = kernel_size - self.stride = stride - self.padding = padding - self.dilation = dilation - self.transposed = transposed - self.output_padding = output_padding - self.groups = groups - self.padding_mode = padding_mode - if isinstance(self.padding, str): - self._reversed_padding_repeated_twice = [0, 0] * len(kernel_size) - if padding == "same": - for d, k, i in zip( - dilation, kernel_size, range(len(kernel_size) - 1, -1, -1) - ): - total_padding = d * (k - 1) - left_pad = total_padding // 2 - self._reversed_padding_repeated_twice[2 * i] = left_pad - self._reversed_padding_repeated_twice[2 * i + 1] = ( - total_padding - left_pad - ) - else: - self._reversed_padding_repeated_twice = Translated__reverse_repeat_tuple( - self.padding, 2 - ) - if transposed: - self.weight = torch.nn.parameter.Parameter( - torch.empty( - (in_channels, out_channels // groups, *kernel_size), - **factory_kwargs, - ) - ) - else: - self.weight = torch.nn.parameter.Parameter( - torch.empty( - (out_channels, in_channels // groups, *kernel_size), - **factory_kwargs, - ) - ) - if bias: - self.bias = torch.nn.parameter.Parameter( - torch.empty(out_channels, **factory_kwargs) - ) - else: - self.register_parameter("bias", None) - self.reset_parameters() - - def reset_parameters(self): - Translated_kaiming_uniform_(self.weight, a=math.sqrt(5)) - if self.bias is not None: - fan_in, _ = Translated__calculate_fan_in_and_fan_out(self.weight) - if fan_in != 0: - bound = 1 / math.sqrt(fan_in) - Translated_uniform_(self.bias, -bound, bound) - - def extra_repr(self): - s = "{in_channels}, {out_channels}, kernel_size={kernel_size}, stride={stride}" - if self.padding != (0,) * len(self.padding): - s += ", padding={padding}" - if self.dilation != (1,) * len(self.dilation): - s += ", dilation={dilation}" - if self.output_padding != (0,) * len(self.output_padding): - s += ", output_padding={output_padding}" - if 
self.groups != 1: - s += ", groups={groups}" - if self.bias is None: - s += ", bias=False" - if self.padding_mode != "zeros": - s += ", padding_mode={padding_mode}" - return s.format(**self.__dict__) - - def __setstate__(self, state): - super().__setstate__(state) - if not hasattr(self, "padding_mode"): - self.padding_mode = "zeros" diff --git a/ivy/compiler/_cache/Translated_Outputs/Translated_Conv2d_output/run_0/__init__.py b/ivy/compiler/_cache/Translated_Outputs/Translated_Conv2d_output/run_0/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/ivy/compiler/_cache/Translated_Outputs/Translated_Conv2d_output/run_0/helpers.py b/ivy/compiler/_cache/Translated_Outputs/Translated_Conv2d_output/run_0/helpers.py deleted file mode 100644 index fc2435d1dcbf..000000000000 --- a/ivy/compiler/_cache/Translated_Outputs/Translated_Conv2d_output/run_0/helpers.py +++ /dev/null @@ -1,99 +0,0 @@ -from itertools import repeat -import collections -import math -import warnings - - -def Translated__ntuple_parse(n, name="parse"): - def parse(x): - if isinstance(x, collections.abc.Iterable): - return tuple(x) - return tuple(repeat(x, n)) - - parse.__name__ = name - return parse - - -def Translated__reverse_repeat_tuple(t, n): - return tuple(x for x in reversed(t) for _ in range(n)) - - -def Translated__calculate_fan_in_and_fan_out(tensor): - dimensions = tensor.dim() - if dimensions < 2: - raise ValueError( - "Fan in and fan out can not be computed for tensor with fewer than 2 dimensions" - ) - num_input_fmaps = tensor.size(1) - num_output_fmaps = tensor.size(0) - receptive_field_size = 1 - if tensor.dim() > 2: - for s in tensor.shape[2:]: - receptive_field_size *= s - fan_in = num_input_fmaps * receptive_field_size - fan_out = num_output_fmaps * receptive_field_size - return fan_in, fan_out - - -def Translated__calculate_correct_fan(tensor, mode): - mode = mode.lower() - valid_modes = ["fan_in", "fan_out"] - if mode not in valid_modes: - raise ValueError(f"Mode {mode} not supported, please use one of {valid_modes}") - fan_in, fan_out = Translated__calculate_fan_in_and_fan_out(tensor) - return fan_in if mode == "fan_in" else fan_out - - -def Translated_calculate_gain(nonlinearity, param=None): - linear_fns = [ - "linear", - "conv1d", - "conv2d", - "conv3d", - "conv_transpose1d", - "conv_transpose2d", - "conv_transpose3d", - ] - if nonlinearity in linear_fns or nonlinearity == "sigmoid": - return 1 - elif nonlinearity == "tanh": - return 5.0 / 3 - elif nonlinearity == "relu": - return math.sqrt(2.0) - elif nonlinearity == "leaky_relu": - if param is None: - negative_slope = 0.01 - elif ( - not isinstance(param, bool) - and isinstance(param, int) - or isinstance(param, float) - ): - negative_slope = param - else: - raise ValueError(f"negative_slope {param} not a valid number") - return math.sqrt(2.0 / (1 + negative_slope**2)) - elif nonlinearity == "selu": - return 3.0 / 4 - else: - raise ValueError(f"Unsupported nonlinearity {nonlinearity}") - - -def Translated_kaiming_uniform_( - tensor, a=0, mode="fan_in", nonlinearity="leaky_relu", generator=None -): - if 0 in tensor.shape: - warnings.warn("Initializing zero-element tensors is a no-op") - return tensor - fan = Translated__calculate_correct_fan(tensor, mode) - gain = Translated_calculate_gain(nonlinearity, a) - std = gain / math.sqrt(fan) - bound = math.sqrt(3.0) * std - return tensor.uniform_(-bound, bound, generator=generator) - - -def Translated__no_grad_uniform_(tensor, a, b, generator=None): - return tensor.uniform_(a, b, 
generator=generator) - - -def Translated_uniform_(tensor, a=0.0, b=1.0, generator=None): - return Translated__no_grad_uniform_(tensor, a, b, generator) diff --git a/ivy/compiler/_cache/Translated_Outputs/Translated_ConvTranspose2d_output/run_0/Translated_ConvTranspose2d.py b/ivy/compiler/_cache/Translated_Outputs/Translated_ConvTranspose2d_output/run_0/Translated_ConvTranspose2d.py deleted file mode 100644 index 98d894ad0d3b..000000000000 --- a/ivy/compiler/_cache/Translated_Outputs/Translated_ConvTranspose2d_output/run_0/Translated_ConvTranspose2d.py +++ /dev/null @@ -1,71 +0,0 @@ -import ivy.functional.frontends.torch as torch - -from .Translated__ConvTransposeNd import Translated__ConvTransposeNd -from .helpers import Translated__ntuple_parse - -_pair = Translated__ntuple_parse(2, "_pair") - - -class Translated_ConvTranspose2d(Translated__ConvTransposeNd): - def __init__( - self, - in_channels, - out_channels, - kernel_size, - stride=1, - padding=0, - output_padding=0, - groups=1, - bias=True, - dilation=1, - padding_mode="zeros", - device=None, - dtype=None, - ): - factory_kwargs = {"device": device, "dtype": dtype} - kernel_size = _pair(kernel_size) - stride = _pair(stride) - padding = _pair(padding) - dilation = _pair(dilation) - output_padding = _pair(output_padding) - super().__init__( - in_channels, - out_channels, - kernel_size, - stride, - padding, - dilation, - True, - output_padding, - groups, - bias, - padding_mode, - **factory_kwargs, - ) - - def forward(self, input, output_size=None): - if self.padding_mode != "zeros": - raise ValueError( - "Only `zeros` padding mode is supported for ConvTranspose2d" - ) - assert isinstance(self.padding, tuple) - num_spatial_dims = 2 - output_padding = self._output_padding( - input, - output_size, - self.stride, - self.padding, - self.kernel_size, - num_spatial_dims, - self.dilation, - ) - return torch.nn.functional.conv_transpose2d( - input, - self.weight, - self.bias, - self.stride, - self.padding, - output_padding, - self.groups, - self.dilation, - ) diff --git a/ivy/compiler/_cache/Translated_Outputs/Translated_ConvTranspose2d_output/run_0/Translated__ConvNd.py b/ivy/compiler/_cache/Translated_Outputs/Translated_ConvTranspose2d_output/run_0/Translated__ConvNd.py deleted file mode 100644 index b177126d465e..000000000000 --- a/ivy/compiler/_cache/Translated_Outputs/Translated_ConvTranspose2d_output/run_0/Translated__ConvNd.py +++ /dev/null @@ -1,158 +0,0 @@ -import ivy.functional.frontends.torch as torch -import ivy.functional.frontends.torch.nn as nn - -import typing -import math -from typing import Optional - -from .helpers import Translated__calculate_fan_in_and_fan_out -from .helpers import Translated__reverse_repeat_tuple -from .helpers import Translated_kaiming_uniform_ -from .helpers import Translated_uniform_ - - -class Translated__ConvNd(nn.Module): - __constants__ = [ - "stride", - "padding", - "dilation", - "groups", - "padding_mode", - "output_padding", - "in_channels", - "out_channels", - "kernel_size", - ] - __annotations__ = {"bias": Optional[torch.Tensor]} - - def _conv_forward(self, input, weight, bias): ... 
- - in_channels: typing.Any - _reversed_padding_repeated_twice: typing.Any - out_channels: typing.Any - kernel_size: typing.Any - stride: typing.Any - padding: typing.Any - dilation: typing.Any - transposed: typing.Any - output_padding: typing.Any - groups: typing.Any - padding_mode: typing.Any - weight: typing.Any - bias: typing.Any - - def __init__( - self, - in_channels, - out_channels, - kernel_size, - stride, - padding, - dilation, - transposed, - output_padding, - groups, - bias, - padding_mode, - device=None, - dtype=None, - ): - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - if groups <= 0: - raise ValueError("groups must be a positive integer") - if in_channels % groups != 0: - raise ValueError("in_channels must be divisible by groups") - if out_channels % groups != 0: - raise ValueError("out_channels must be divisible by groups") - valid_padding_strings = {"same", "valid"} - if isinstance(padding, str): - if padding not in valid_padding_strings: - raise ValueError( - f"Invalid padding string {padding!r}, should be one of {valid_padding_strings}" - ) - if padding == "same" and any(s != 1 for s in stride): - raise ValueError( - "padding='same' is not supported for strided convolutions" - ) - valid_padding_modes = {"zeros", "reflect", "replicate", "circular"} - if padding_mode not in valid_padding_modes: - raise ValueError( - f"padding_mode must be one of {valid_padding_modes}, but got padding_mode='{padding_mode}'" - ) - self.in_channels = in_channels - self.out_channels = out_channels - self.kernel_size = kernel_size - self.stride = stride - self.padding = padding - self.dilation = dilation - self.transposed = transposed - self.output_padding = output_padding - self.groups = groups - self.padding_mode = padding_mode - if isinstance(self.padding, str): - self._reversed_padding_repeated_twice = [0, 0] * len(kernel_size) - if padding == "same": - for d, k, i in zip( - dilation, kernel_size, range(len(kernel_size) - 1, -1, -1) - ): - total_padding = d * (k - 1) - left_pad = total_padding // 2 - self._reversed_padding_repeated_twice[2 * i] = left_pad - self._reversed_padding_repeated_twice[2 * i + 1] = ( - total_padding - left_pad - ) - else: - self._reversed_padding_repeated_twice = Translated__reverse_repeat_tuple( - self.padding, 2 - ) - if transposed: - self.weight = torch.nn.parameter.Parameter( - torch.empty( - (in_channels, out_channels // groups, *kernel_size), - **factory_kwargs, - ) - ) - else: - self.weight = torch.nn.parameter.Parameter( - torch.empty( - (out_channels, in_channels // groups, *kernel_size), - **factory_kwargs, - ) - ) - if bias: - self.bias = torch.nn.parameter.Parameter( - torch.empty(out_channels, **factory_kwargs) - ) - else: - self.register_parameter("bias", None) - self.reset_parameters() - - def reset_parameters(self): - Translated_kaiming_uniform_(self.weight, a=math.sqrt(5)) - if self.bias is not None: - fan_in, _ = Translated__calculate_fan_in_and_fan_out(self.weight) - if fan_in != 0: - bound = 1 / math.sqrt(fan_in) - Translated_uniform_(self.bias, -bound, bound) - - def extra_repr(self): - s = "{in_channels}, {out_channels}, kernel_size={kernel_size}, stride={stride}" - if self.padding != (0,) * len(self.padding): - s += ", padding={padding}" - if self.dilation != (1,) * len(self.dilation): - s += ", dilation={dilation}" - if self.output_padding != (0,) * len(self.output_padding): - s += ", output_padding={output_padding}" - if self.groups != 1: - s += ", groups={groups}" - if self.bias is None: - s += ", bias=False" - 
if self.padding_mode != "zeros": - s += ", padding_mode={padding_mode}" - return s.format(**self.__dict__) - - def __setstate__(self, state): - super().__setstate__(state) - if not hasattr(self, "padding_mode"): - self.padding_mode = "zeros" diff --git a/ivy/compiler/_cache/Translated_Outputs/Translated_ConvTranspose2d_output/run_0/Translated__ConvTransposeNd.py b/ivy/compiler/_cache/Translated_Outputs/Translated_ConvTranspose2d_output/run_0/Translated__ConvTransposeNd.py deleted file mode 100644 index a3e86a228a94..000000000000 --- a/ivy/compiler/_cache/Translated_Outputs/Translated_ConvTranspose2d_output/run_0/Translated__ConvTransposeNd.py +++ /dev/null @@ -1,89 +0,0 @@ -from .Translated__ConvNd import Translated__ConvNd -from .helpers import Translated__ntuple_parse - -_single = Translated__ntuple_parse(1, "_single") - - -class Translated__ConvTransposeNd(Translated__ConvNd): - def __init__( - self, - in_channels, - out_channels, - kernel_size, - stride, - padding, - dilation, - transposed, - output_padding, - groups, - bias, - padding_mode, - device=None, - dtype=None, - ): - if padding_mode != "zeros": - raise ValueError( - f'Only "zeros" padding mode is supported for {self.__class__.__name__}' - ) - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__( - in_channels, - out_channels, - kernel_size, - stride, - padding, - dilation, - transposed, - output_padding, - groups, - bias, - padding_mode, - **factory_kwargs, - ) - - def _output_padding( - self, - input, - output_size, - stride, - padding, - kernel_size, - num_spatial_dims, - dilation=None, - ): - if output_size is None: - ret = _single(self.output_padding) - else: - has_batch_dim = input.dim() == num_spatial_dims + 2 - num_non_spatial_dims = 2 if has_batch_dim else 1 - if len(output_size) == num_non_spatial_dims + num_spatial_dims: - output_size = output_size[num_non_spatial_dims:] - if len(output_size) != num_spatial_dims: - raise ValueError( - f"ConvTranspose{num_spatial_dims}D: for {input.dim()}D input, output_size must have {num_spatial_dims} or {num_non_spatial_dims + num_spatial_dims} elements (got {len(output_size)})" - ) - min_sizes = [] - max_sizes = [] - for d in range(num_spatial_dims): - dim_size = ( - (input.size(d + num_non_spatial_dims) - 1) * stride[d] - - 2 * padding[d] - + (dilation[d] if dilation is not None else 1) - * (kernel_size[d] - 1) - + 1 - ) - min_sizes.append(dim_size) - max_sizes.append(min_sizes[d] + stride[d] - 1) - for i in range(len(output_size)): - size = output_size[i] - min_size = min_sizes[i] - max_size = max_sizes[i] - if size < min_size or size > max_size: - raise ValueError( - f"requested an output size of {output_size}, but valid sizes range from {min_sizes} to {max_sizes} (for an input of {input.size()[2:]})" - ) - res = [] - for d in range(num_spatial_dims): - res.append(output_size[d] - min_sizes[d]) - ret = res - return ret diff --git a/ivy/compiler/_cache/Translated_Outputs/Translated_ConvTranspose2d_output/run_0/__init__.py b/ivy/compiler/_cache/Translated_Outputs/Translated_ConvTranspose2d_output/run_0/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/ivy/compiler/_cache/Translated_Outputs/Translated_ConvTranspose2d_output/run_0/helpers.py b/ivy/compiler/_cache/Translated_Outputs/Translated_ConvTranspose2d_output/run_0/helpers.py deleted file mode 100644 index fc2435d1dcbf..000000000000 --- a/ivy/compiler/_cache/Translated_Outputs/Translated_ConvTranspose2d_output/run_0/helpers.py +++ /dev/null @@ -1,99 +0,0 @@ -from itertools 
import repeat -import collections -import math -import warnings - - -def Translated__ntuple_parse(n, name="parse"): - def parse(x): - if isinstance(x, collections.abc.Iterable): - return tuple(x) - return tuple(repeat(x, n)) - - parse.__name__ = name - return parse - - -def Translated__reverse_repeat_tuple(t, n): - return tuple(x for x in reversed(t) for _ in range(n)) - - -def Translated__calculate_fan_in_and_fan_out(tensor): - dimensions = tensor.dim() - if dimensions < 2: - raise ValueError( - "Fan in and fan out can not be computed for tensor with fewer than 2 dimensions" - ) - num_input_fmaps = tensor.size(1) - num_output_fmaps = tensor.size(0) - receptive_field_size = 1 - if tensor.dim() > 2: - for s in tensor.shape[2:]: - receptive_field_size *= s - fan_in = num_input_fmaps * receptive_field_size - fan_out = num_output_fmaps * receptive_field_size - return fan_in, fan_out - - -def Translated__calculate_correct_fan(tensor, mode): - mode = mode.lower() - valid_modes = ["fan_in", "fan_out"] - if mode not in valid_modes: - raise ValueError(f"Mode {mode} not supported, please use one of {valid_modes}") - fan_in, fan_out = Translated__calculate_fan_in_and_fan_out(tensor) - return fan_in if mode == "fan_in" else fan_out - - -def Translated_calculate_gain(nonlinearity, param=None): - linear_fns = [ - "linear", - "conv1d", - "conv2d", - "conv3d", - "conv_transpose1d", - "conv_transpose2d", - "conv_transpose3d", - ] - if nonlinearity in linear_fns or nonlinearity == "sigmoid": - return 1 - elif nonlinearity == "tanh": - return 5.0 / 3 - elif nonlinearity == "relu": - return math.sqrt(2.0) - elif nonlinearity == "leaky_relu": - if param is None: - negative_slope = 0.01 - elif ( - not isinstance(param, bool) - and isinstance(param, int) - or isinstance(param, float) - ): - negative_slope = param - else: - raise ValueError(f"negative_slope {param} not a valid number") - return math.sqrt(2.0 / (1 + negative_slope**2)) - elif nonlinearity == "selu": - return 3.0 / 4 - else: - raise ValueError(f"Unsupported nonlinearity {nonlinearity}") - - -def Translated_kaiming_uniform_( - tensor, a=0, mode="fan_in", nonlinearity="leaky_relu", generator=None -): - if 0 in tensor.shape: - warnings.warn("Initializing zero-element tensors is a no-op") - return tensor - fan = Translated__calculate_correct_fan(tensor, mode) - gain = Translated_calculate_gain(nonlinearity, a) - std = gain / math.sqrt(fan) - bound = math.sqrt(3.0) * std - return tensor.uniform_(-bound, bound, generator=generator) - - -def Translated__no_grad_uniform_(tensor, a, b, generator=None): - return tensor.uniform_(a, b, generator=generator) - - -def Translated_uniform_(tensor, a=0.0, b=1.0, generator=None): - return Translated__no_grad_uniform_(tensor, a, b, generator) diff --git a/ivy/compiler/_cache/Translated_Outputs/Translated_Linear_output/run_0/Translated_Linear.py b/ivy/compiler/_cache/Translated_Outputs/Translated_Linear_output/run_0/Translated_Linear.py index 4db3999a197f..b78e8e89821b 100644 --- a/ivy/compiler/_cache/Translated_Outputs/Translated_Linear_output/run_0/Translated_Linear.py +++ b/ivy/compiler/_cache/Translated_Outputs/Translated_Linear_output/run_0/Translated_Linear.py @@ -1,8 +1,8 @@ import ivy.functional.frontends.torch as torch import ivy.functional.frontends.torch.nn as nn -import math import typing +import math from .helpers import Translated__calculate_fan_in_and_fan_out from .helpers import Translated_kaiming_uniform_ diff --git 
a/ivy/compiler/_cache/Translated_Outputs/Translated_ModuleList_output/run_0/Translated_ModuleList.py b/ivy/compiler/_cache/Translated_Outputs/Translated_ModuleList_output/run_0/Translated_ModuleList.py index fd0baa4b683b..615c3acbd676 100644 --- a/ivy/compiler/_cache/Translated_Outputs/Translated_ModuleList_output/run_0/Translated_ModuleList.py +++ b/ivy/compiler/_cache/Translated_Outputs/Translated_ModuleList_output/run_0/Translated_ModuleList.py @@ -2,9 +2,9 @@ import typing import operator -from collections import abc as container_abcs from itertools import chain from collections import OrderedDict +from collections import abc as container_abcs from .helpers import Translated__addindent diff --git a/ivy/compiler/_cache/Translated_Outputs/Translated_Sequential_output/run_0/Translated_Sequential.py b/ivy/compiler/_cache/Translated_Outputs/Translated_Sequential_output/run_0/Translated_Sequential.py index 83a8904d9f64..ef48c39cb2f7 100644 --- a/ivy/compiler/_cache/Translated_Outputs/Translated_Sequential_output/run_0/Translated_Sequential.py +++ b/ivy/compiler/_cache/Translated_Outputs/Translated_Sequential_output/run_0/Translated_Sequential.py @@ -1,11 +1,11 @@ import ivy.functional.frontends.torch as torch import ivy.functional.frontends.torch.nn as nn -import typing import operator +import typing +from collections import OrderedDict from typing import overload from itertools import islice -from collections import OrderedDict class Translated_Sequential(nn.Module): diff --git a/ivy/compiler/_cache/Translated_Outputs/ivy_BatchNorm2d_output/run_0/ivy__BatchNorm.py b/ivy/compiler/_cache/Translated_Outputs/ivy_BatchNorm2d_output/run_0/ivy__BatchNorm.py index 808b96cb82c9..fd0f6561cbe4 100644 --- a/ivy/compiler/_cache/Translated_Outputs/ivy_BatchNorm2d_output/run_0/ivy__BatchNorm.py +++ b/ivy/compiler/_cache/Translated_Outputs/ivy_BatchNorm2d_output/run_0/ivy__BatchNorm.py @@ -47,11 +47,9 @@ def forward(self, input): """ normalized, self.running_mean, self.running_var = ivy_batch_norm_frnt( input, - ( - self.running_mean - if not self.training or self.track_running_stats - else None - ), + self.running_mean + if not self.training or self.track_running_stats + else None, self.running_var if not self.training or self.track_running_stats else None, self.weight, self.bias, diff --git a/ivy/compiler/_cache/Translated_Outputs/ivy_BatchNorm2d_output/run_0/ivy__helpers.py b/ivy/compiler/_cache/Translated_Outputs/ivy_BatchNorm2d_output/run_0/ivy__helpers.py index f37e5f02c994..23eccb2f7882 100644 --- a/ivy/compiler/_cache/Translated_Outputs/ivy_BatchNorm2d_output/run_0/ivy__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/ivy_BatchNorm2d_output/run_0/ivy__helpers.py @@ -3,6 +3,25 @@ import re +def ivy_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if ivy.is_array(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + def ivy_empty_frnt( *args, size=None, @@ -120,25 +139,6 @@ def ivy_device_frnt(dev): return ivy.default_device(dev) -def ivy_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if ivy.is_array(args[0]): - return 
fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @ivy_handle_methods def ivy_split_frnt(tensor, split_size_or_sections, dim=0): if isinstance(split_size_or_sections, int): diff --git a/ivy/compiler/_cache/Translated_Outputs/ivy_Conv2d_output/run_0/__init__.py b/ivy/compiler/_cache/Translated_Outputs/ivy_Conv2d_output/run_0/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/ivy/compiler/_cache/Translated_Outputs/ivy_Conv2d_output/run_0/ivy_Conv2d.py b/ivy/compiler/_cache/Translated_Outputs/ivy_Conv2d_output/run_0/ivy_Conv2d.py deleted file mode 100644 index 92b307ea4907..000000000000 --- a/ivy/compiler/_cache/Translated_Outputs/ivy_Conv2d_output/run_0/ivy_Conv2d.py +++ /dev/null @@ -1,62 +0,0 @@ -from .ivy__ConvNd import ivy__ConvNd -from .ivy__helpers import ivy__ntuple_parse -from .ivy__helpers import ivy_conv2d_frnt -from .ivy__helpers import ivy_pad_frnt - -_pair = ivy__ntuple_parse(2, "_pair") - - -class ivy_Conv2d(ivy__ConvNd): - def __init__( - self, - in_channels, - out_channels, - kernel_size, - stride=1, - padding=0, - dilation=1, - groups=1, - bias=True, - padding_mode="zeros", - device=None, - dtype=None, - ): - factory_kwargs = {"device": device, "dtype": dtype} - kernel_size_ = _pair(kernel_size) - stride_ = _pair(stride) - padding_ = padding if isinstance(padding, str) else _pair(padding) - dilation_ = _pair(dilation) - super().__init__( - in_channels, - out_channels, - kernel_size_, - stride_, - padding_, - dilation_, - False, - _pair(0), - groups, - bias, - padding_mode, - **factory_kwargs, - ) - - def _conv_forward(self, input, weight, bias): - if self.padding_mode != "zeros": - return ivy_conv2d_frnt( - ivy_pad_frnt( - input, self._reversed_padding_repeated_twice, mode=self.padding_mode - ), - weight, - bias, - self.stride, - _pair(0), - self.dilation, - self.groups, - ) - return ivy_conv2d_frnt( - input, weight, bias, self.stride, self.padding, self.dilation, self.groups - ) - - def forward(self, input): - return self._conv_forward(input, self.weight, self.bias) diff --git a/ivy/compiler/_cache/Translated_Outputs/ivy_Conv2d_output/run_0/ivy__ConvNd.py b/ivy/compiler/_cache/Translated_Outputs/ivy_Conv2d_output/run_0/ivy__ConvNd.py deleted file mode 100644 index 6bca3fd80178..000000000000 --- a/ivy/compiler/_cache/Translated_Outputs/ivy_Conv2d_output/run_0/ivy__ConvNd.py +++ /dev/null @@ -1,508 +0,0 @@ -import ivy -from collections import OrderedDict - -import typing -import math -from typing import Optional - -from .ivy__helpers import ivy__calculate_fan_in_and_fan_out -from .ivy__helpers import ivy__reverse_repeat_tuple -from .ivy__helpers import ivy_add_frnt_ -from .ivy__helpers import ivy_empty_frnt -from .ivy__helpers import ivy_kaiming_uniform_ -from .ivy__helpers import ivy_split_frnt_ -from .ivy__helpers import ivy_uniform_ - - -class ivy__ConvNd(ivy.Module): - __constants__ = [ - "stride", - "padding", - "dilation", - "groups", - "padding_mode", - "output_padding", - "in_channels", - "out_channels", - "kernel_size", - ] - __annotations__ = {"bias": Optional[ivy.Array]} - - def _conv_forward(self, input, weight, bias): ... 
- - in_channels: typing.Any - _reversed_padding_repeated_twice: typing.Any - out_channels: typing.Any - kernel_size: typing.Any - stride: typing.Any - padding: typing.Any - dilation: typing.Any - transposed: typing.Any - output_padding: typing.Any - groups: typing.Any - padding_mode: typing.Any - weight: typing.Any - bias: typing.Any - - def __init__( - self, - in_channels, - out_channels, - kernel_size, - stride, - padding, - dilation, - transposed, - output_padding, - groups, - bias, - padding_mode, - device=None, - dtype=None, - ): - factory_kwargs = {"device": device, "dtype": dtype} - self.super___init__( - in_channels, - out_channels, - kernel_size, - stride, - padding, - dilation, - transposed, - output_padding, - groups, - bias, - padding_mode, - device=device, - dtype=dtype, - v=getattr(self, "_v", None), - buffers=getattr(self, "_buffers", None), - module_dict=getattr(self, "_module_dict", None), - ) - if groups <= 0: - raise ValueError("groups must be a positive integer") - if in_channels % groups != 0: - raise ValueError("in_channels must be divisible by groups") - if out_channels % groups != 0: - raise ValueError("out_channels must be divisible by groups") - valid_padding_strings = {"same", "valid"} - if isinstance(padding, str): - if padding not in valid_padding_strings: - raise ValueError( - f"Invalid padding string {padding!r}, should be one of {valid_padding_strings}" - ) - if padding == "same" and any(s != 1 for s in stride): - raise ValueError( - "padding='same' is not supported for strided convolutions" - ) - valid_padding_modes = {"zeros", "reflect", "replicate", "circular"} - if padding_mode not in valid_padding_modes: - raise ValueError( - f"padding_mode must be one of {valid_padding_modes}, but got padding_mode='{padding_mode}'" - ) - self.in_channels = in_channels - self.out_channels = out_channels - self.kernel_size = kernel_size - self.stride = stride - self.padding = padding - self.dilation = dilation - self.transposed = transposed - self.output_padding = output_padding - self.groups = groups - self.padding_mode = padding_mode - if isinstance(self.padding, str): - self._reversed_padding_repeated_twice = [0, 0] * len(kernel_size) - if padding == "same": - for d, k, i in zip( - dilation, kernel_size, range(len(kernel_size) - 1, -1, -1) - ): - total_padding = d * (k - 1) - left_pad = total_padding // 2 - self._reversed_padding_repeated_twice[2 * i] = left_pad - self._reversed_padding_repeated_twice[2 * i + 1] = ( - total_padding - left_pad - ) - else: - self._reversed_padding_repeated_twice = ivy__reverse_repeat_tuple( - self.padding, 2 - ) - if transposed: - self.weight = ivy.Array( - ivy_empty_frnt( - (*kernel_size, out_channels // groups, in_channels), - **factory_kwargs, - ) - ) - else: - self.weight = ivy.Array( - ivy_empty_frnt( - (*kernel_size, in_channels // groups, out_channels), - **factory_kwargs, - ) - ) - if bias: - self.bias = ivy.Array(ivy_empty_frnt(out_channels, **factory_kwargs)) - else: - self.register_parameter("bias", None) - self.reset_parameters() - - def reset_parameters(self): - ivy_kaiming_uniform_(self.weight, a=math.sqrt(5)) - if self.bias is not None: - fan_in, _ = ivy__calculate_fan_in_and_fan_out(self.weight) - if fan_in != 0: - bound = 1 / math.sqrt(fan_in) - ivy_uniform_(self.bias, -bound, bound) - - def extra_repr(self): - s = "{in_channels}, {out_channels}, kernel_size={kernel_size}, stride={stride}" - if self.padding != (0,) * len(self.padding): - s += ", padding={padding}" - if self.dilation != (1,) * len(self.dilation): - s += 
", dilation={dilation}" - if self.output_padding != (0,) * len(self.output_padding): - s += ", output_padding={output_padding}" - if self.groups != 1: - s += ", groups={groups}" - if self.bias is None: - s += ", bias=False" - if self.padding_mode != "zeros": - s += ", padding_mode={padding_mode}" - return s.format(**self.__dict__) - - def __setstate__(self, state): - super().__setstate__(state) - if not hasattr(self, "padding_mode"): - self.padding_mode = "zeros" - - def super___init__(self, *args, device=None, devices=None, **kwargs): - super().__init__( - *args, - device=device, - devices=devices, - training=True, - build_mode="explicit", - dynamic_backend=True, - **kwargs, - ) - super().__setattr__("_frontend_module", True) - super().__setattr__( - "_attr_mapping", {"_parameters": "v", "_modules": "module_dict"} - ) - - def __dir__(self): - module_attrs = dir(self.__class__) - attrs = list(self.__dict__.keys()) - parameters = list(self._v.keys()) - modules = list(self._module_dict.keys()) - buffers = list(self._buffers.keys()) - keys = module_attrs + attrs + parameters + modules + buffers - keys = [key for key in keys if not key[0].isdigit()] - return sorted(keys) - - def __getattribute__(self, name): - if name == "__dict__": - return super().__getattribute__(name) - if "_module_dict" in self.__dict__: - modules = self.__dict__["_module_dict"] - if name in modules: - return modules[name] - if "_buffers" in self.__dict__: - buffers = self.__dict__["_buffers"] - if name in buffers: - return buffers[name] - if "_v" in self.__dict__: - v = self.__dict__["_v"] - if name in v: - return v[name] - if "_attr_mapping" in self.__dict__: - mapping = self.__dict__["_attr_mapping"] - if name in mapping: - return super().__getattribute__(mapping[name]) - return super().__getattribute__(name) - - def __getstate__(self): - state = self.__dict__.copy() - state.pop("_compiled_call_impl", None) - state.pop("_thread_local", None) - state.pop("_metrics_lock", None) - return state - - def __repr__(self): - extra_lines = [] - extra_repr = self._extra_repr() - if extra_repr: - extra_lines = ivy_split_frnt_(extra_repr, "\n") - child_lines = [] - for key, module in self._module_dict.items(): - mod_str = repr(module) - mod_str = self._addindent(mod_str, 2) - child_lines.append("(" + key + "): " + mod_str) - lines = extra_lines + child_lines - main_str = self._get_name() + "(" - if lines: - if len(extra_lines) == 1 and not child_lines: - main_str += extra_lines[0] - else: - main_str += "\n " + "\n ".join(lines) + "\n" - main_str += ")" - return main_str - - def __setattr__(self, name, value): - def remove_from(*dicts_or_sets): - for d in dicts_or_sets: - if name in d: - if isinstance(d, dict): - del d[name] - else: - d.discard(name) - - params = self.__dict__.get("_v") - if params is not None and name in params and isinstance(value, ivy.Array): - remove_from(self.__dict__, self._buffers, self._module_dict) - self.register_parameter(name, value) - super().__setattr__(name, value) - else: - super().__setattr__(name, value) - - def _build(self, *args, **kwargs): - for module in self.__dict__.values(): - if isinstance(module, ivy.Module) and module is not self: - if not module._built: - module.build( - *module._args, - dynamic_backend=module._dynamic_backend, - **module._kwargs, - ) - return True - - def _call_impl(self, *args, **kwargs): - return self.call(*args, **kwargs) - - def _create_variables(self, device=None, dtype=None): - v = ivy.Container( - OrderedDict( - [ - (k.replace(".", "/"), v) - for k, v in 
self.__dict__.items() - if isinstance(v, ivy.Array) and not k.startswith("_") - ] - ) - ) - v = ( - ivy.Container( - OrderedDict( - { - _k.replace(".", "/"): _v - for _k, _v in self._v.items() - if _k.replace(".", "/") not in v - and not isinstance(_v, ivy.Container) - }, - **v, - ) - ) - if self._v - else v - ) - return v - - def _extra_repr(self): - return "" - - def _forward(self, *a, **kw): - ret = self._call_impl(*a, **kw) - return ret - - def _get_name(self): - return self.__class__.__name__ - - def _named_members( - self, get_members_fn, prefix="", recurse=True, remove_duplicate=True - ): - """Helper method for yielding various names + members of modules.""" - memo = set() - modules = ( - self.named_modules(prefix=prefix, remove_duplicate=remove_duplicate) - if recurse - else [(prefix, self)] - ) - for module_prefix, module in modules: - members = get_members_fn(module) - for k, v in members: - if v is None or id(v) in memo: - continue - if remove_duplicate: - ivy_add_frnt_(memo, id(v)) - name = module_prefix + ("." if module_prefix else "") + k - yield name, v - - def _replace_update_v(self, new_v, native=None): - from ivy.functional.ivy.gradients import _is_variable - - native = ivy.default(native, self) - for k, v in new_v.items(): - if isinstance(v, ivy.Container): - native.module_dict[k] = self._replace_update_v(v, native.module_dict[k]) - elif isinstance(v, ivy.Array): - native.__setattr__(k, v) - elif _is_variable(v): - native.__setattr__(k, ivy.Array(v)) - elif isinstance(v, ivy.Array): - native.__setattr__(k, ivy.Array(v)) - else: - raise ivy.utils.exceptions.IvyException( - f"found item in variable container {v} which was neither a sub ivy.Container nor a variable." - ) - return native - - def _update_v(self, new_v, native=None): - from ivy.functional.ivy.gradients import _is_variable - - native = ivy.default(native, self) - for k, v in new_v.items(): - if isinstance(v, ivy.Container): - native.module_dict[k] = self._replace_update_v(v, native.module_dict[k]) - elif isinstance(v, ivy.Array): - native.__setattr__(k, v) - elif _is_variable(v): - native.__setattr__(k, ivy.Array(v)) - elif isinstance(v, ivy.Array): - native.__setattr__(k, ivy.Array(v)) - else: - raise ivy.utils.exceptions.IvyException( - f"found item in variable container {v} which was neither a sub ivy.Container nor a variable." - ) - return native - - def add_module(self, name, module): - if not isinstance(module, ivy.Module) and module is not None: - raise TypeError(f"{type(module)} is not a Module subclass") - elif not isinstance(name, str): - raise TypeError(f"module name should be a string. Got {type(name)}") - elif hasattr(self, name) and name not in self._modules: - raise KeyError(f"attribute '{name}' already exists") - elif "." 
in name: - raise KeyError(f'module name can\'t contain ".", got: {name}') - elif name == "": - raise KeyError('module name can\'t be empty string ""') - self._modules[name] = module - super().__setattr__(name, module) - - def apply(self, fn): - for module in self.children(): - if hasattr(module, "apply"): - module.apply(fn) - else: - fn(module) - fn(self) - return self - - def children(self): - for _, module in self.named_children(): - yield module - - def forward(self, *input): - raise NotImplementedError( - f'Module [{type(self).__name__}] is missing the required "forward" function' - ) - - def get_parameter(self, target): - target = target.replace(".", "/") - return self.v[target] - - def get_submodule(self, target): - if target == "": - return self - atoms: typing.Any = ivy_split_frnt_(target, ".") - mod: typing.Any = self - for item in atoms: - if not hasattr(mod, item): - raise AttributeError( - mod._get_name() + " has no attribute `" + item + "`" - ) - mod = getattr(mod, item) - if not isinstance(mod, ivy.Module): - raise TypeError("`" + item + "` is not an nn.Module") - return mod - - def modules(self): - for _, module in self.named_modules(): - yield module - - def named_buffers(self, prefix="", recurse=True, remove_duplicate=True): - if not getattr(self, "_built", False): - self.build( - *self._args, dynamic_backend=self._dynamic_backend, **self._kwargs - ) - gen = self._named_members( - lambda module: module.buffers.items(), - prefix=prefix, - recurse=recurse, - remove_duplicate=remove_duplicate, - ) - yield from gen - - def named_children(self): - if not getattr(self, "_built", False): - self.build( - *self._args, dynamic_backend=self._dynamic_backend, **self._kwargs - ) - memo = set() - for name, module in self._module_dict.items(): - if module is not None and id(module) not in memo: - ivy_add_frnt_(memo, id(module)) - yield name, module - - def named_modules(self, memo=None, prefix="", remove_duplicate=True): - if not getattr(self, "_built", False): - self.build( - *self._args, dynamic_backend=self._dynamic_backend, **self._kwargs - ) - if memo is None: - memo = set() - if id(self) not in memo: - if remove_duplicate: - ivy_add_frnt_(memo, id(self)) - yield prefix, self - for name, module in self._module_dict.items(): - if module is None: - continue - submodule_prefix = prefix + ("." 
if prefix else "") + name - if not hasattr(module, "named_modules"): - yield submodule_prefix, self - else: - yield from module.named_modules( - memo, submodule_prefix, remove_duplicate - ) - - def named_parameters(self, prefix="", recurse=True, remove_duplicate=True): - if not getattr(self, "_built", False): - self.build( - *self._args, dynamic_backend=self._dynamic_backend, **self._kwargs - ) - gen = self._named_members( - lambda module: module.v.items(), - prefix=prefix, - recurse=recurse, - remove_duplicate=remove_duplicate, - ) - yield from gen - - def parameters(self, recurse=True): - for _, param in self.named_parameters(recurse=recurse): - yield param - - def register_buffer(self, name, value, persistent=False): - super().register_buffer(name, value) - - def register_module(self, name, module): - """Alias for :func:`add_module`.""" - self.add_module(name, module) - - def register_parameter(self, name, value): - super().register_parameter(name, value) - - def requires_grad_(self, requires_grad=True): - for p in self.parameters(): - p.requires_grad_(requires_grad) - return self diff --git a/ivy/compiler/_cache/Translated_Outputs/ivy_Conv2d_output/run_0/ivy__helpers.py b/ivy/compiler/_cache/Translated_Outputs/ivy_Conv2d_output/run_0/ivy__helpers.py deleted file mode 100644 index 73290c1879c1..000000000000 --- a/ivy/compiler/_cache/Translated_Outputs/ivy_Conv2d_output/run_0/ivy__helpers.py +++ /dev/null @@ -1,287 +0,0 @@ -from itertools import repeat -import collections -import functools -import ivy -import math -import re -import warnings - - -def ivy__ntuple_parse(n, name="parse"): - def parse(x): - if isinstance(x, collections.abc.Iterable): - return tuple(x) - return tuple(repeat(x, n)) - - parse.__name__ = name - return parse - - -def ivy__reverse_repeat_tuple(t, n): - return tuple(x for x in reversed(t) for _ in range(n)) - - -def ivy_empty_frnt( - *args, - size=None, - out=None, - dtype=None, - layout=None, - device=None, - requires_grad=False, - pin_memory=False, - memory_format=None, -): - if args and size: - raise TypeError("empty() got multiple values for argument 'shape'") - if size is None: - size = ( - args[0] - if isinstance(args[0], (tuple, list, ivy.Shape, ivy.NativeShape)) - else args - ) - if isinstance(size, (tuple, list)): - size = tuple(s.to_scalar() if ivy.is_array(s) else s for s in size) - return ivy.empty(shape=size, dtype=dtype, device=device, out=out) - - -def ivy_dim_frnt_(arr): - return arr.ndim - - -def ivy_size_frnt_(arr, dim=None): - shape = arr.shape - if dim is None: - return shape - try: - return shape[dim] - except IndexError as e: - raise IndexError( - f"Dimension out of range (expected to be in range of [{len(shape)}, {len(shape) - 1}], but got {dim}" - ) from e - - -def ivy__calculate_fan_in_and_fan_out(tensor): - dimensions = ivy_dim_frnt_(tensor) - if dimensions < 2: - raise ValueError( - "Fan in and fan out can not be computed for tensor with fewer than 2 dimensions" - ) - num_input_fmaps = ivy_size_frnt_(tensor, 1) - num_output_fmaps = ivy_size_frnt_(tensor, 0) - receptive_field_size = 1 - if ivy_dim_frnt_(tensor) > 2: - for s in tensor.shape[2:]: - receptive_field_size *= s - fan_in = num_input_fmaps * receptive_field_size - fan_out = num_output_fmaps * receptive_field_size - return fan_in, fan_out - - -def ivy__calculate_correct_fan(tensor, mode): - mode = mode.lower() - valid_modes = ["fan_in", "fan_out"] - if mode not in valid_modes: - raise ValueError(f"Mode {mode} not supported, please use one of {valid_modes}") - fan_in, fan_out = 
ivy__calculate_fan_in_and_fan_out(tensor) - return fan_in if mode == "fan_in" else fan_out - - -def ivy_calculate_gain(nonlinearity, param=None): - linear_fns = [ - "linear", - "conv1d", - "conv2d", - "conv3d", - "conv_transpose1d", - "conv_transpose2d", - "conv_transpose3d", - ] - if nonlinearity in linear_fns or nonlinearity == "sigmoid": - return 1 - elif nonlinearity == "tanh": - return 5.0 / 3 - elif nonlinearity == "relu": - return math.sqrt(2.0) - elif nonlinearity == "leaky_relu": - if param is None: - negative_slope = 0.01 - elif ( - not isinstance(param, bool) - and isinstance(param, int) - or isinstance(param, float) - ): - negative_slope = param - else: - raise ValueError(f"negative_slope {param} not a valid number") - return math.sqrt(2.0 / (1 + negative_slope**2)) - elif nonlinearity == "selu": - return 3.0 / 4 - else: - raise ValueError(f"Unsupported nonlinearity {nonlinearity}") - - -def ivy_uniform__frnt_(arr, from_=0, to=1, *, generator=None): - ret = ivy.random_uniform( - low=from_, high=to, shape=arr.shape, dtype=arr.dtype, seed=generator - ) - arr = ivy.inplace_update(arr, ivy.astype(ret, arr.dtype)).data - return arr - - -def ivy_kaiming_uniform_( - tensor, a=0, mode="fan_in", nonlinearity="leaky_relu", generator=None -): - if 0 in tensor.shape: - warnings.warn("Initializing zero-element tensors is a no-op") - return tensor - fan = ivy__calculate_correct_fan(tensor, mode) - gain = ivy_calculate_gain(nonlinearity, a) - std = gain / math.sqrt(fan) - bound = math.sqrt(3.0) * std - return ivy_uniform__frnt_(tensor, -bound, bound, generator=generator) - - -def ivy__no_grad_uniform_(tensor, a, b, generator=None): - return ivy_uniform__frnt_(tensor, a, b, generator=generator) - - -def ivy_uniform_(tensor, a=0.0, b=1.0, generator=None): - return ivy__no_grad_uniform_(tensor, a, b, generator) - - -def ivy_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if ivy.is_array(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - -@ivy_handle_methods -def ivy_split_frnt(tensor, split_size_or_sections, dim=0): - if isinstance(split_size_or_sections, int): - split_size = split_size_or_sections - split_size_or_sections = [split_size] * (tensor.shape[dim] // split_size) - if tensor.shape[dim] % split_size: - split_size_or_sections.append(tensor.shape[dim] % split_size) - return tuple( - ivy.split( - tensor, - num_or_size_splits=split_size_or_sections, - axis=dim, - with_remainder=True, - ) - ) - - -@ivy_handle_methods -def ivy_split_frnt_(arr, split_size, dim=0): - return ivy_split_frnt(arr, split_size, dim) - - -@ivy_handle_methods -def ivy_add_frnt(input, other, *, alpha=1, out=None): - return ivy.add(input, other, alpha=alpha, out=out) - - -@ivy_handle_methods -def ivy_add_frnt_(arr, other, *, alpha=1): - return ivy_add_frnt(arr, other, alpha=alpha) - - -def ivy__conv_frnt(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): - dims = len(input.shape) - 2 - if isinstance(padding, str): - padding = padding.upper() - elif isinstance(padding, int): - padding = [*[(padding, padding) for _ in range(dims)]] - else: - padding = [*[(p, p) for p in padding]] - ret = ivy.conv_general_dilated( - input, - weight, - stride, - padding, - dims=dims, - 
data_format="channel_last", - filter_format="channel_last", - dilations=dilation, - feature_group_count=groups, - bias=bias, - ) - return ret - - -def ivy_conv2d_frnt( - input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1 -): - return ivy__conv_frnt( - input, - weight, - bias=bias, - stride=stride, - padding=padding, - dilation=dilation, - groups=groups, - ) - - -def ivy__handle_padding_shape_frnt(padding, n, mode): - padding = tuple( - [ - (padding[i * 2], padding[i * 2 + 1]) - for i in range(int(len(padding) / 2) - 1, -1, -1) - ] - ) - if mode == "circular": - padding = padding + ((0, 0),) * (n - len(padding)) - else: - padding = ((0, 0),) * (n - len(padding)) + padding - if mode == "circular": - padding = tuple(list(padding)[::-1]) - return padding - - -def ivy_pad_frnt(input, pad, mode="constant", value=0): - if any([(pad_value < 0) for pad_value in pad]): - pad = list(pad) - slices = [] - for n in reversed(range(len(pad) // 2)): - i = n * 2 - j = i + 1 - start = None - stop = None - if pad[i] < 0: - start = -pad[i] - pad[i] = 0 - if pad[j] < 0: - stop = pad[j] - pad[j] = 0 - slices.append(slice(start, stop)) - ndim = len(input.shape) - while len(slices) < ndim: - slices.insert(0, slice(None)) - input = input[tuple(slices)] - value = 0 if value is None else value - mode_dict = { - "constant": "constant", - "reflect": "reflect", - "replicate": "edge", - "circular": "wrap", - } - if mode not in mode_dict: - raise ValueError(f"Unsupported padding mode: {mode}") - pad = ivy__handle_padding_shape_frnt(pad, len(input.shape), mode) - return ivy.pad(input, pad, mode=mode_dict[mode], constant_values=value) diff --git a/ivy/compiler/_cache/Translated_Outputs/ivy_ConvTranspose2d_output/run_0/__init__.py b/ivy/compiler/_cache/Translated_Outputs/ivy_ConvTranspose2d_output/run_0/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/ivy/compiler/_cache/Translated_Outputs/ivy_ConvTranspose2d_output/run_0/ivy_ConvTranspose2d.py b/ivy/compiler/_cache/Translated_Outputs/ivy_ConvTranspose2d_output/run_0/ivy_ConvTranspose2d.py deleted file mode 100644 index 42d88209d492..000000000000 --- a/ivy/compiler/_cache/Translated_Outputs/ivy_ConvTranspose2d_output/run_0/ivy_ConvTranspose2d.py +++ /dev/null @@ -1,70 +0,0 @@ -from .ivy__ConvTransposeNd import ivy__ConvTransposeNd -from .ivy__helpers import ivy__ntuple_parse -from .ivy__helpers import ivy_conv_transpose2d_frnt - -_pair = ivy__ntuple_parse(2, "_pair") - - -class ivy_ConvTranspose2d(ivy__ConvTransposeNd): - def __init__( - self, - in_channels, - out_channels, - kernel_size, - stride=1, - padding=0, - output_padding=0, - groups=1, - bias=True, - dilation=1, - padding_mode="zeros", - device=None, - dtype=None, - ): - factory_kwargs = {"device": device, "dtype": dtype} - kernel_size = _pair(kernel_size) - stride = _pair(stride) - padding = _pair(padding) - dilation = _pair(dilation) - output_padding = _pair(output_padding) - super().__init__( - in_channels, - out_channels, - kernel_size, - stride, - padding, - dilation, - True, - output_padding, - groups, - bias, - padding_mode, - **factory_kwargs, - ) - - def forward(self, input, output_size=None): - if self.padding_mode != "zeros": - raise ValueError( - "Only `zeros` padding mode is supported for ConvTranspose2d" - ) - assert isinstance(self.padding, tuple) - num_spatial_dims = 2 - output_padding = self._output_padding( - input, - output_size, - self.stride, - self.padding, - self.kernel_size, - num_spatial_dims, - self.dilation, - ) - return 
ivy_conv_transpose2d_frnt( - input, - self.weight, - self.bias, - self.stride, - self.padding, - output_padding, - self.groups, - self.dilation, - ) diff --git a/ivy/compiler/_cache/Translated_Outputs/ivy_ConvTranspose2d_output/run_0/ivy__ConvNd.py b/ivy/compiler/_cache/Translated_Outputs/ivy_ConvTranspose2d_output/run_0/ivy__ConvNd.py deleted file mode 100644 index 6bca3fd80178..000000000000 --- a/ivy/compiler/_cache/Translated_Outputs/ivy_ConvTranspose2d_output/run_0/ivy__ConvNd.py +++ /dev/null @@ -1,508 +0,0 @@ -import ivy -from collections import OrderedDict - -import typing -import math -from typing import Optional - -from .ivy__helpers import ivy__calculate_fan_in_and_fan_out -from .ivy__helpers import ivy__reverse_repeat_tuple -from .ivy__helpers import ivy_add_frnt_ -from .ivy__helpers import ivy_empty_frnt -from .ivy__helpers import ivy_kaiming_uniform_ -from .ivy__helpers import ivy_split_frnt_ -from .ivy__helpers import ivy_uniform_ - - -class ivy__ConvNd(ivy.Module): - __constants__ = [ - "stride", - "padding", - "dilation", - "groups", - "padding_mode", - "output_padding", - "in_channels", - "out_channels", - "kernel_size", - ] - __annotations__ = {"bias": Optional[ivy.Array]} - - def _conv_forward(self, input, weight, bias): ... - - in_channels: typing.Any - _reversed_padding_repeated_twice: typing.Any - out_channels: typing.Any - kernel_size: typing.Any - stride: typing.Any - padding: typing.Any - dilation: typing.Any - transposed: typing.Any - output_padding: typing.Any - groups: typing.Any - padding_mode: typing.Any - weight: typing.Any - bias: typing.Any - - def __init__( - self, - in_channels, - out_channels, - kernel_size, - stride, - padding, - dilation, - transposed, - output_padding, - groups, - bias, - padding_mode, - device=None, - dtype=None, - ): - factory_kwargs = {"device": device, "dtype": dtype} - self.super___init__( - in_channels, - out_channels, - kernel_size, - stride, - padding, - dilation, - transposed, - output_padding, - groups, - bias, - padding_mode, - device=device, - dtype=dtype, - v=getattr(self, "_v", None), - buffers=getattr(self, "_buffers", None), - module_dict=getattr(self, "_module_dict", None), - ) - if groups <= 0: - raise ValueError("groups must be a positive integer") - if in_channels % groups != 0: - raise ValueError("in_channels must be divisible by groups") - if out_channels % groups != 0: - raise ValueError("out_channels must be divisible by groups") - valid_padding_strings = {"same", "valid"} - if isinstance(padding, str): - if padding not in valid_padding_strings: - raise ValueError( - f"Invalid padding string {padding!r}, should be one of {valid_padding_strings}" - ) - if padding == "same" and any(s != 1 for s in stride): - raise ValueError( - "padding='same' is not supported for strided convolutions" - ) - valid_padding_modes = {"zeros", "reflect", "replicate", "circular"} - if padding_mode not in valid_padding_modes: - raise ValueError( - f"padding_mode must be one of {valid_padding_modes}, but got padding_mode='{padding_mode}'" - ) - self.in_channels = in_channels - self.out_channels = out_channels - self.kernel_size = kernel_size - self.stride = stride - self.padding = padding - self.dilation = dilation - self.transposed = transposed - self.output_padding = output_padding - self.groups = groups - self.padding_mode = padding_mode - if isinstance(self.padding, str): - self._reversed_padding_repeated_twice = [0, 0] * len(kernel_size) - if padding == "same": - for d, k, i in zip( - dilation, kernel_size, range(len(kernel_size) - 
1, -1, -1) - ): - total_padding = d * (k - 1) - left_pad = total_padding // 2 - self._reversed_padding_repeated_twice[2 * i] = left_pad - self._reversed_padding_repeated_twice[2 * i + 1] = ( - total_padding - left_pad - ) - else: - self._reversed_padding_repeated_twice = ivy__reverse_repeat_tuple( - self.padding, 2 - ) - if transposed: - self.weight = ivy.Array( - ivy_empty_frnt( - (*kernel_size, out_channels // groups, in_channels), - **factory_kwargs, - ) - ) - else: - self.weight = ivy.Array( - ivy_empty_frnt( - (*kernel_size, in_channels // groups, out_channels), - **factory_kwargs, - ) - ) - if bias: - self.bias = ivy.Array(ivy_empty_frnt(out_channels, **factory_kwargs)) - else: - self.register_parameter("bias", None) - self.reset_parameters() - - def reset_parameters(self): - ivy_kaiming_uniform_(self.weight, a=math.sqrt(5)) - if self.bias is not None: - fan_in, _ = ivy__calculate_fan_in_and_fan_out(self.weight) - if fan_in != 0: - bound = 1 / math.sqrt(fan_in) - ivy_uniform_(self.bias, -bound, bound) - - def extra_repr(self): - s = "{in_channels}, {out_channels}, kernel_size={kernel_size}, stride={stride}" - if self.padding != (0,) * len(self.padding): - s += ", padding={padding}" - if self.dilation != (1,) * len(self.dilation): - s += ", dilation={dilation}" - if self.output_padding != (0,) * len(self.output_padding): - s += ", output_padding={output_padding}" - if self.groups != 1: - s += ", groups={groups}" - if self.bias is None: - s += ", bias=False" - if self.padding_mode != "zeros": - s += ", padding_mode={padding_mode}" - return s.format(**self.__dict__) - - def __setstate__(self, state): - super().__setstate__(state) - if not hasattr(self, "padding_mode"): - self.padding_mode = "zeros" - - def super___init__(self, *args, device=None, devices=None, **kwargs): - super().__init__( - *args, - device=device, - devices=devices, - training=True, - build_mode="explicit", - dynamic_backend=True, - **kwargs, - ) - super().__setattr__("_frontend_module", True) - super().__setattr__( - "_attr_mapping", {"_parameters": "v", "_modules": "module_dict"} - ) - - def __dir__(self): - module_attrs = dir(self.__class__) - attrs = list(self.__dict__.keys()) - parameters = list(self._v.keys()) - modules = list(self._module_dict.keys()) - buffers = list(self._buffers.keys()) - keys = module_attrs + attrs + parameters + modules + buffers - keys = [key for key in keys if not key[0].isdigit()] - return sorted(keys) - - def __getattribute__(self, name): - if name == "__dict__": - return super().__getattribute__(name) - if "_module_dict" in self.__dict__: - modules = self.__dict__["_module_dict"] - if name in modules: - return modules[name] - if "_buffers" in self.__dict__: - buffers = self.__dict__["_buffers"] - if name in buffers: - return buffers[name] - if "_v" in self.__dict__: - v = self.__dict__["_v"] - if name in v: - return v[name] - if "_attr_mapping" in self.__dict__: - mapping = self.__dict__["_attr_mapping"] - if name in mapping: - return super().__getattribute__(mapping[name]) - return super().__getattribute__(name) - - def __getstate__(self): - state = self.__dict__.copy() - state.pop("_compiled_call_impl", None) - state.pop("_thread_local", None) - state.pop("_metrics_lock", None) - return state - - def __repr__(self): - extra_lines = [] - extra_repr = self._extra_repr() - if extra_repr: - extra_lines = ivy_split_frnt_(extra_repr, "\n") - child_lines = [] - for key, module in self._module_dict.items(): - mod_str = repr(module) - mod_str = self._addindent(mod_str, 2) - 
child_lines.append("(" + key + "): " + mod_str) - lines = extra_lines + child_lines - main_str = self._get_name() + "(" - if lines: - if len(extra_lines) == 1 and not child_lines: - main_str += extra_lines[0] - else: - main_str += "\n " + "\n ".join(lines) + "\n" - main_str += ")" - return main_str - - def __setattr__(self, name, value): - def remove_from(*dicts_or_sets): - for d in dicts_or_sets: - if name in d: - if isinstance(d, dict): - del d[name] - else: - d.discard(name) - - params = self.__dict__.get("_v") - if params is not None and name in params and isinstance(value, ivy.Array): - remove_from(self.__dict__, self._buffers, self._module_dict) - self.register_parameter(name, value) - super().__setattr__(name, value) - else: - super().__setattr__(name, value) - - def _build(self, *args, **kwargs): - for module in self.__dict__.values(): - if isinstance(module, ivy.Module) and module is not self: - if not module._built: - module.build( - *module._args, - dynamic_backend=module._dynamic_backend, - **module._kwargs, - ) - return True - - def _call_impl(self, *args, **kwargs): - return self.call(*args, **kwargs) - - def _create_variables(self, device=None, dtype=None): - v = ivy.Container( - OrderedDict( - [ - (k.replace(".", "/"), v) - for k, v in self.__dict__.items() - if isinstance(v, ivy.Array) and not k.startswith("_") - ] - ) - ) - v = ( - ivy.Container( - OrderedDict( - { - _k.replace(".", "/"): _v - for _k, _v in self._v.items() - if _k.replace(".", "/") not in v - and not isinstance(_v, ivy.Container) - }, - **v, - ) - ) - if self._v - else v - ) - return v - - def _extra_repr(self): - return "" - - def _forward(self, *a, **kw): - ret = self._call_impl(*a, **kw) - return ret - - def _get_name(self): - return self.__class__.__name__ - - def _named_members( - self, get_members_fn, prefix="", recurse=True, remove_duplicate=True - ): - """Helper method for yielding various names + members of modules.""" - memo = set() - modules = ( - self.named_modules(prefix=prefix, remove_duplicate=remove_duplicate) - if recurse - else [(prefix, self)] - ) - for module_prefix, module in modules: - members = get_members_fn(module) - for k, v in members: - if v is None or id(v) in memo: - continue - if remove_duplicate: - ivy_add_frnt_(memo, id(v)) - name = module_prefix + ("." if module_prefix else "") + k - yield name, v - - def _replace_update_v(self, new_v, native=None): - from ivy.functional.ivy.gradients import _is_variable - - native = ivy.default(native, self) - for k, v in new_v.items(): - if isinstance(v, ivy.Container): - native.module_dict[k] = self._replace_update_v(v, native.module_dict[k]) - elif isinstance(v, ivy.Array): - native.__setattr__(k, v) - elif _is_variable(v): - native.__setattr__(k, ivy.Array(v)) - elif isinstance(v, ivy.Array): - native.__setattr__(k, ivy.Array(v)) - else: - raise ivy.utils.exceptions.IvyException( - f"found item in variable container {v} which was neither a sub ivy.Container nor a variable." 
- ) - return native - - def _update_v(self, new_v, native=None): - from ivy.functional.ivy.gradients import _is_variable - - native = ivy.default(native, self) - for k, v in new_v.items(): - if isinstance(v, ivy.Container): - native.module_dict[k] = self._replace_update_v(v, native.module_dict[k]) - elif isinstance(v, ivy.Array): - native.__setattr__(k, v) - elif _is_variable(v): - native.__setattr__(k, ivy.Array(v)) - elif isinstance(v, ivy.Array): - native.__setattr__(k, ivy.Array(v)) - else: - raise ivy.utils.exceptions.IvyException( - f"found item in variable container {v} which was neither a sub ivy.Container nor a variable." - ) - return native - - def add_module(self, name, module): - if not isinstance(module, ivy.Module) and module is not None: - raise TypeError(f"{type(module)} is not a Module subclass") - elif not isinstance(name, str): - raise TypeError(f"module name should be a string. Got {type(name)}") - elif hasattr(self, name) and name not in self._modules: - raise KeyError(f"attribute '{name}' already exists") - elif "." in name: - raise KeyError(f'module name can\'t contain ".", got: {name}') - elif name == "": - raise KeyError('module name can\'t be empty string ""') - self._modules[name] = module - super().__setattr__(name, module) - - def apply(self, fn): - for module in self.children(): - if hasattr(module, "apply"): - module.apply(fn) - else: - fn(module) - fn(self) - return self - - def children(self): - for _, module in self.named_children(): - yield module - - def forward(self, *input): - raise NotImplementedError( - f'Module [{type(self).__name__}] is missing the required "forward" function' - ) - - def get_parameter(self, target): - target = target.replace(".", "/") - return self.v[target] - - def get_submodule(self, target): - if target == "": - return self - atoms: typing.Any = ivy_split_frnt_(target, ".") - mod: typing.Any = self - for item in atoms: - if not hasattr(mod, item): - raise AttributeError( - mod._get_name() + " has no attribute `" + item + "`" - ) - mod = getattr(mod, item) - if not isinstance(mod, ivy.Module): - raise TypeError("`" + item + "` is not an nn.Module") - return mod - - def modules(self): - for _, module in self.named_modules(): - yield module - - def named_buffers(self, prefix="", recurse=True, remove_duplicate=True): - if not getattr(self, "_built", False): - self.build( - *self._args, dynamic_backend=self._dynamic_backend, **self._kwargs - ) - gen = self._named_members( - lambda module: module.buffers.items(), - prefix=prefix, - recurse=recurse, - remove_duplicate=remove_duplicate, - ) - yield from gen - - def named_children(self): - if not getattr(self, "_built", False): - self.build( - *self._args, dynamic_backend=self._dynamic_backend, **self._kwargs - ) - memo = set() - for name, module in self._module_dict.items(): - if module is not None and id(module) not in memo: - ivy_add_frnt_(memo, id(module)) - yield name, module - - def named_modules(self, memo=None, prefix="", remove_duplicate=True): - if not getattr(self, "_built", False): - self.build( - *self._args, dynamic_backend=self._dynamic_backend, **self._kwargs - ) - if memo is None: - memo = set() - if id(self) not in memo: - if remove_duplicate: - ivy_add_frnt_(memo, id(self)) - yield prefix, self - for name, module in self._module_dict.items(): - if module is None: - continue - submodule_prefix = prefix + ("." 
if prefix else "") + name - if not hasattr(module, "named_modules"): - yield submodule_prefix, self - else: - yield from module.named_modules( - memo, submodule_prefix, remove_duplicate - ) - - def named_parameters(self, prefix="", recurse=True, remove_duplicate=True): - if not getattr(self, "_built", False): - self.build( - *self._args, dynamic_backend=self._dynamic_backend, **self._kwargs - ) - gen = self._named_members( - lambda module: module.v.items(), - prefix=prefix, - recurse=recurse, - remove_duplicate=remove_duplicate, - ) - yield from gen - - def parameters(self, recurse=True): - for _, param in self.named_parameters(recurse=recurse): - yield param - - def register_buffer(self, name, value, persistent=False): - super().register_buffer(name, value) - - def register_module(self, name, module): - """Alias for :func:`add_module`.""" - self.add_module(name, module) - - def register_parameter(self, name, value): - super().register_parameter(name, value) - - def requires_grad_(self, requires_grad=True): - for p in self.parameters(): - p.requires_grad_(requires_grad) - return self diff --git a/ivy/compiler/_cache/Translated_Outputs/ivy_ConvTranspose2d_output/run_0/ivy__ConvTransposeNd.py b/ivy/compiler/_cache/Translated_Outputs/ivy_ConvTranspose2d_output/run_0/ivy__ConvTransposeNd.py deleted file mode 100644 index 973f9331e830..000000000000 --- a/ivy/compiler/_cache/Translated_Outputs/ivy_ConvTranspose2d_output/run_0/ivy__ConvTransposeNd.py +++ /dev/null @@ -1,91 +0,0 @@ -from .ivy__ConvNd import ivy__ConvNd -from .ivy__helpers import ivy__ntuple_parse -from .ivy__helpers import ivy_dim_frnt_ -from .ivy__helpers import ivy_size_frnt_ - -_single = ivy__ntuple_parse(1, "_single") - - -class ivy__ConvTransposeNd(ivy__ConvNd): - def __init__( - self, - in_channels, - out_channels, - kernel_size, - stride, - padding, - dilation, - transposed, - output_padding, - groups, - bias, - padding_mode, - device=None, - dtype=None, - ): - if padding_mode != "zeros": - raise ValueError( - f'Only "zeros" padding mode is supported for {self.__class__.__name__}' - ) - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__( - in_channels, - out_channels, - kernel_size, - stride, - padding, - dilation, - transposed, - output_padding, - groups, - bias, - padding_mode, - **factory_kwargs, - ) - - def _output_padding( - self, - input, - output_size, - stride, - padding, - kernel_size, - num_spatial_dims, - dilation=None, - ): - if output_size is None: - ret = _single(self.output_padding) - else: - has_batch_dim = ivy_dim_frnt_(input) == num_spatial_dims + 2 - num_non_spatial_dims = 2 if has_batch_dim else 1 - if len(output_size) == num_non_spatial_dims + num_spatial_dims: - output_size = output_size[num_non_spatial_dims:] - if len(output_size) != num_spatial_dims: - raise ValueError( - f"ConvTranspose{num_spatial_dims}D: for {ivy_dim_frnt_(input)}D input, output_size must have {num_spatial_dims} or {num_non_spatial_dims + num_spatial_dims} elements (got {len(output_size)})" - ) - min_sizes = [] - max_sizes = [] - for d in range(num_spatial_dims): - dim_size = ( - (ivy_size_frnt_(input, d + num_non_spatial_dims) - 1) * stride[d] - - 2 * padding[d] - + (dilation[d] if dilation is not None else 1) - * (kernel_size[d] - 1) - + 1 - ) - min_sizes.append(dim_size) - max_sizes.append(min_sizes[d] + stride[d] - 1) - for i in range(len(output_size)): - size = output_size[i] - min_size = min_sizes[i] - max_size = max_sizes[i] - if size < min_size or size > max_size: - raise ValueError( - f"requested an 
output size of {output_size}, but valid sizes range from {min_sizes} to {max_sizes} (for an input of {ivy_size_frnt_(input)[2:]})" - ) - res = [] - for d in range(num_spatial_dims): - res.append(output_size[d] - min_sizes[d]) - ret = res - return ret diff --git a/ivy/compiler/_cache/Translated_Outputs/ivy_ConvTranspose2d_output/run_0/ivy__helpers.py b/ivy/compiler/_cache/Translated_Outputs/ivy_ConvTranspose2d_output/run_0/ivy__helpers.py deleted file mode 100644 index 4ff6534bb286..000000000000 --- a/ivy/compiler/_cache/Translated_Outputs/ivy_ConvTranspose2d_output/run_0/ivy__helpers.py +++ /dev/null @@ -1,300 +0,0 @@ -from itertools import repeat -import collections -import functools -import ivy -import math -import re -import warnings - - -def ivy__ntuple_parse(n, name="parse"): - def parse(x): - if isinstance(x, collections.abc.Iterable): - return tuple(x) - return tuple(repeat(x, n)) - - parse.__name__ = name - return parse - - -def ivy__reverse_repeat_tuple(t, n): - return tuple(x for x in reversed(t) for _ in range(n)) - - -def ivy_empty_frnt( - *args, - size=None, - out=None, - dtype=None, - layout=None, - device=None, - requires_grad=False, - pin_memory=False, - memory_format=None, -): - if args and size: - raise TypeError("empty() got multiple values for argument 'shape'") - if size is None: - size = ( - args[0] - if isinstance(args[0], (tuple, list, ivy.Shape, ivy.NativeShape)) - else args - ) - if isinstance(size, (tuple, list)): - size = tuple(s.to_scalar() if ivy.is_array(s) else s for s in size) - return ivy.empty(shape=size, dtype=dtype, device=device, out=out) - - -def ivy_dim_frnt_(arr): - return arr.ndim - - -def ivy_size_frnt_(arr, dim=None): - shape = arr.shape - if dim is None: - return shape - try: - return shape[dim] - except IndexError as e: - raise IndexError( - f"Dimension out of range (expected to be in range of [{len(shape)}, {len(shape) - 1}], but got {dim}" - ) from e - - -def ivy__calculate_fan_in_and_fan_out(tensor): - dimensions = ivy_dim_frnt_(tensor) - if dimensions < 2: - raise ValueError( - "Fan in and fan out can not be computed for tensor with fewer than 2 dimensions" - ) - num_input_fmaps = ivy_size_frnt_(tensor, 1) - num_output_fmaps = ivy_size_frnt_(tensor, 0) - receptive_field_size = 1 - if ivy_dim_frnt_(tensor) > 2: - for s in tensor.shape[2:]: - receptive_field_size *= s - fan_in = num_input_fmaps * receptive_field_size - fan_out = num_output_fmaps * receptive_field_size - return fan_in, fan_out - - -def ivy__calculate_correct_fan(tensor, mode): - mode = mode.lower() - valid_modes = ["fan_in", "fan_out"] - if mode not in valid_modes: - raise ValueError(f"Mode {mode} not supported, please use one of {valid_modes}") - fan_in, fan_out = ivy__calculate_fan_in_and_fan_out(tensor) - return fan_in if mode == "fan_in" else fan_out - - -def ivy_calculate_gain(nonlinearity, param=None): - linear_fns = [ - "linear", - "conv1d", - "conv2d", - "conv3d", - "conv_transpose1d", - "conv_transpose2d", - "conv_transpose3d", - ] - if nonlinearity in linear_fns or nonlinearity == "sigmoid": - return 1 - elif nonlinearity == "tanh": - return 5.0 / 3 - elif nonlinearity == "relu": - return math.sqrt(2.0) - elif nonlinearity == "leaky_relu": - if param is None: - negative_slope = 0.01 - elif ( - not isinstance(param, bool) - and isinstance(param, int) - or isinstance(param, float) - ): - negative_slope = param - else: - raise ValueError(f"negative_slope {param} not a valid number") - return math.sqrt(2.0 / (1 + negative_slope**2)) - elif nonlinearity == "selu": - 
return 3.0 / 4 - else: - raise ValueError(f"Unsupported nonlinearity {nonlinearity}") - - -def ivy_uniform__frnt_(arr, from_=0, to=1, *, generator=None): - ret = ivy.random_uniform( - low=from_, high=to, shape=arr.shape, dtype=arr.dtype, seed=generator - ) - arr = ivy.inplace_update(arr, ivy.astype(ret, arr.dtype)).data - return arr - - -def ivy_kaiming_uniform_( - tensor, a=0, mode="fan_in", nonlinearity="leaky_relu", generator=None -): - if 0 in tensor.shape: - warnings.warn("Initializing zero-element tensors is a no-op") - return tensor - fan = ivy__calculate_correct_fan(tensor, mode) - gain = ivy_calculate_gain(nonlinearity, a) - std = gain / math.sqrt(fan) - bound = math.sqrt(3.0) * std - return ivy_uniform__frnt_(tensor, -bound, bound, generator=generator) - - -def ivy__no_grad_uniform_(tensor, a, b, generator=None): - return ivy_uniform__frnt_(tensor, a, b, generator=generator) - - -def ivy_uniform_(tensor, a=0.0, b=1.0, generator=None): - return ivy__no_grad_uniform_(tensor, a, b, generator) - - -def ivy_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if ivy.is_array(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - -@ivy_handle_methods -def ivy_split_frnt(tensor, split_size_or_sections, dim=0): - if isinstance(split_size_or_sections, int): - split_size = split_size_or_sections - split_size_or_sections = [split_size] * (tensor.shape[dim] // split_size) - if tensor.shape[dim] % split_size: - split_size_or_sections.append(tensor.shape[dim] % split_size) - return tuple( - ivy.split( - tensor, - num_or_size_splits=split_size_or_sections, - axis=dim, - with_remainder=True, - ) - ) - - -@ivy_handle_methods -def ivy_split_frnt_(arr, split_size, dim=0): - return ivy_split_frnt(arr, split_size, dim) - - -@ivy_handle_methods -def ivy_add_frnt(input, other, *, alpha=1, out=None): - return ivy.add(input, other, alpha=alpha, out=out) - - -@ivy_handle_methods -def ivy_add_frnt_(arr, other, *, alpha=1): - return ivy_add_frnt(arr, other, alpha=alpha) - - -def ivy__get_transpose_pad_frnt(padding, output_padding, dims): - padding, output_padding = map( - lambda x: [x] * dims if isinstance(x, int) else x, [padding, output_padding] - ) - asymmetric_padding = [ - [pad, pad - output_pad] for pad, output_pad in zip(padding, output_padding) - ] - return asymmetric_padding - - -def ivy__conv_transpose_frnt( - input, - weight, - bias=None, - stride=1, - padding=0, - output_padding=0, - groups=1, - dilation=1, -): - dims = len(input.shape) - 2 - weight = ivy.permute_dims(weight, axes=(*range(2, dims + 2), 0, 1)) - for i in range(dims): - weight = ivy.flip(weight, axis=i) - padding, output_padding, stride, dilation = map( - lambda x: [x] * dims if isinstance(x, int) else x, - [padding, output_padding, stride, dilation], - ) - pad_widths = [ - ( - ( - (weight.shape[i] - 1) * dilation[i] - + max([output_padding[i] - padding[i], 0]), - ) - * 2 - ) - for i in range(dims) - ] - ret = ivy.conv_general_dilated( - input, - weight, - 1, - pad_widths, - dims=dims, - data_format="channel_last", - feature_group_count=groups, - x_dilations=stride, - dilations=dilation, - bias=bias, - ) - unpad_slice = (slice(None),) * 2 - for i in range(dims): - unpad_slice += ( - slice( - 
max([padding[i] - dilation[i] // 2, padding[i], output_padding[i]]), - ret.shape[2 + i] - padding[i] + output_padding[i] + dilation[i] // 2, - 1, - ), - ) - ret = ret[unpad_slice] - return ret - - -def ivy_conv_transpose2d_frnt( - input, - weight, - bias=None, - stride=1, - padding=0, - output_padding=0, - groups=1, - dilation=1, -): - if ivy.current_backend_str() in ["torch", "tensorflow"]: - return ivy.conv_general_transpose( - input, - weight, - stride, - ivy__get_transpose_pad_frnt(padding, output_padding, 2), - dims=2, - filter_format="channel_last", - data_format="channel_last", - dilations=dilation, - feature_group_count=groups, - bias=bias, - ) - else: - return ivy__conv_transpose_frnt( - input, - weight, - bias=bias, - stride=stride, - padding=padding, - output_padding=output_padding, - groups=groups, - dilation=dilation, - ) diff --git a/ivy/compiler/_cache/Translated_Outputs/ivy_LayerNorm_output/run_0/ivy__helpers.py b/ivy/compiler/_cache/Translated_Outputs/ivy_LayerNorm_output/run_0/ivy__helpers.py index 02ed25dbe90f..967300c0ff7e 100644 --- a/ivy/compiler/_cache/Translated_Outputs/ivy_LayerNorm_output/run_0/ivy__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/ivy_LayerNorm_output/run_0/ivy__helpers.py @@ -3,6 +3,25 @@ import re +def ivy_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if ivy.is_array(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + def ivy_empty_frnt( *args, size=None, @@ -92,25 +111,6 @@ def ivy_layer_norm_frnt(input, normalized_shape, weight=None, bias=None, eps=1e- return ivy.layer_norm(input, axis, scale=weight, offset=bias, eps=eps) -def ivy_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if ivy.is_array(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @ivy_handle_methods def ivy_split_frnt(tensor, split_size_or_sections, dim=0): if isinstance(split_size_or_sections, int): diff --git a/ivy/compiler/_cache/Translated_Outputs/ivy_Linear_output/run_0/ivy_Linear.py b/ivy/compiler/_cache/Translated_Outputs/ivy_Linear_output/run_0/ivy_Linear.py index 89183d65c71a..d50e3b67a281 100644 --- a/ivy/compiler/_cache/Translated_Outputs/ivy_Linear_output/run_0/ivy_Linear.py +++ b/ivy/compiler/_cache/Translated_Outputs/ivy_Linear_output/run_0/ivy_Linear.py @@ -2,8 +2,8 @@ from collections import OrderedDict import threading -import math import typing +import math from .ivy__helpers import ivy__calculate_fan_in_and_fan_out from .ivy__helpers import ivy_add_frnt_ diff --git a/ivy/compiler/_cache/Translated_Outputs/ivy_Linear_output/run_0/ivy__helpers.py b/ivy/compiler/_cache/Translated_Outputs/ivy_Linear_output/run_0/ivy__helpers.py index d10135d22e5c..1c10c6d2fad5 100644 --- a/ivy/compiler/_cache/Translated_Outputs/ivy_Linear_output/run_0/ivy__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/ivy_Linear_output/run_0/ivy__helpers.py @@ -5,6 +5,25 @@ import warnings +def 
ivy_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if ivy.is_array(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + def ivy_empty_frnt( *args, size=None, @@ -138,25 +157,6 @@ def ivy_linear_frnt(input, weight, bias=None): return ivy.linear(input, weight, bias=bias) -def ivy_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if ivy.is_array(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @ivy_handle_methods def ivy_split_frnt(tensor, split_size_or_sections, dim=0): if isinstance(split_size_or_sections, int): diff --git a/ivy/compiler/_cache/Translated_Outputs/ivy_MaxPool2d_output/run_0/ivy__helpers.py b/ivy/compiler/_cache/Translated_Outputs/ivy_MaxPool2d_output/run_0/ivy__helpers.py index eb8855af0920..490457727ad5 100644 --- a/ivy/compiler/_cache/Translated_Outputs/ivy_MaxPool2d_output/run_0/ivy__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/ivy_MaxPool2d_output/run_0/ivy__helpers.py @@ -4,6 +4,22 @@ import re +def ivy__handle_padding_shape_frnt(padding, n, mode): + padding = tuple( + [ + (padding[i * 2], padding[i * 2 + 1]) + for i in range(int(len(padding) / 2) - 1, -1, -1) + ] + ) + if mode == "circular": + padding = padding + ((0, 0),) * (n - len(padding)) + else: + padding = ((0, 0),) * (n - len(padding)) + padding + if mode == "circular": + padding = tuple(list(padding)[::-1]) + return padding + + def ivy_handle_methods(fn): def extract_function_name(s): match = re.search("_(.+?)(?:_\\d+)?$", s) @@ -96,22 +112,6 @@ def ivy_reshape_frnt_(arr, *args, shape=None): raise ValueError("reshape() got no values for argument 'shape'") -def ivy__handle_padding_shape_frnt(padding, n, mode): - padding = tuple( - [ - (padding[i * 2], padding[i * 2 + 1]) - for i in range(int(len(padding) / 2) - 1, -1, -1) - ] - ) - if mode == "circular": - padding = padding + ((0, 0),) * (n - len(padding)) - else: - padding = ((0, 0),) * (n - len(padding)) + padding - if mode == "circular": - padding = tuple(list(padding)[::-1]) - return padding - - def ivy_pad_frnt(input, pad, mode="constant", value=0): if any([(pad_value < 0) for pad_value in pad]): pad = list(pad) diff --git a/ivy/compiler/_cache/Translated_Outputs/ivy_ModuleList_output/run_0/ivy_ModuleList.py b/ivy/compiler/_cache/Translated_Outputs/ivy_ModuleList_output/run_0/ivy_ModuleList.py index ef905e7ac56d..87b0d7697aa3 100644 --- a/ivy/compiler/_cache/Translated_Outputs/ivy_ModuleList_output/run_0/ivy_ModuleList.py +++ b/ivy/compiler/_cache/Translated_Outputs/ivy_ModuleList_output/run_0/ivy_ModuleList.py @@ -4,8 +4,8 @@ import typing import operator -from collections import abc as container_abcs from itertools import chain +from collections import abc as container_abcs from .ivy__helpers import ivy__addindent from .ivy__helpers import ivy_add_frnt_ diff --git a/ivy/compiler/_cache/Translated_Outputs/ivy_Sequential_output/run_0/ivy_Sequential.py 
b/ivy/compiler/_cache/Translated_Outputs/ivy_Sequential_output/run_0/ivy_Sequential.py index 1e2ae79d13e7..d0c33a9026f7 100644 --- a/ivy/compiler/_cache/Translated_Outputs/ivy_Sequential_output/run_0/ivy_Sequential.py +++ b/ivy/compiler/_cache/Translated_Outputs/ivy_Sequential_output/run_0/ivy_Sequential.py @@ -2,8 +2,8 @@ from collections import OrderedDict import threading -import typing import operator +import typing from typing import overload from itertools import islice diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_AdaptiveAvgPool2d_output/run_0/tensorflow_NestedSequence_bknd.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_AdaptiveAvgPool2d_output/run_0/tensorflow_NestedSequence_bknd.py index 9f87b4ae29ef..ac8335fe1e56 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_AdaptiveAvgPool2d_output/run_0/tensorflow_NestedSequence_bknd.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_AdaptiveAvgPool2d_output/run_0/tensorflow_NestedSequence_bknd.py @@ -1,5 +1,5 @@ -from typing import Protocol from typing import TypeVar +from typing import Protocol _T_co = TypeVar("_T_co", covariant=True) diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_AdaptiveAvgPool2d_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_AdaptiveAvgPool2d_output/run_0/tensorflow__helpers.py index e504254cfc3d..eacc25bd4de1 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_AdaptiveAvgPool2d_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_AdaptiveAvgPool2d_output/run_0/tensorflow__helpers.py @@ -27,6 +27,371 @@ import tensorflow as tf +CONV_FUNCS = [ + "Conv1d", + "Conv2d", + "Conv3d", + "ConvTranspose1d", + "ConvTranspose2d", + "ConvTranspose3d", +] +NORM_FUNCS = [ + "_BatchNorm", + "_InstanceNorm", + "BatchNorm1d", + "BatchNorm2d", + "BatchNorm3d", + "GroupNorm", + "SyncBatchNorm", + "InstanceNorm1d", + "InstanceNorm2d", + "InstanceNorm3d", + "LocalResponseNorm", +] +POOL_FUNCS = [ + "MaxPool1d", + "MaxPool2d", + "MaxPool3d", + "AvgPool1d", + "AvgPool2d", + "AvgPool3d", + "FractionalMaxPool2d", + "LPPool1d", + "LPPool2d", + "AdaptiveMaxPool1d", + "AdaptiveMaxPool2d", + "AdaptiveMaxPool3d", + "AdaptiveAvgPool1d", + "AdaptiveAvgPool2d", + "AdaptiveAvgPool3d", +] +KERAS_CONV_FUNCS = [ + "KerasConv1D", + "KerasConv2D", + "KerasConv3D", + "KerasDepthwiseConv2D", + "KerasConv1DTranspose", + "KerasConv2DTranspose", + "KerasConv3DTranspose", +] +KERAS_NORM_FUNCS = [ + "KerasBatchNorm1D", + "KerasBatchNorm2D", + "KerasBatchNorm3D", + "KerasLayerNormalization", + "KerasGroupNormalization", + "KerasUnitNorm1D", + "KerasUnitNorm2D", + "KerasUnitNorm3D", +] +KERAS_POOL_FUNCS = [ + "KerasAveragePooling1D", + "KerasAveragePooling2D", + "KerasAveragePooling3D", + "KerasMaxPool1D", + "KerasMaxPool2D", + "KerasMaxPool3D", +] +PADDING_FUNCS = [ + "ReflectionPad1d", + "ReflectionPad2d", + "ReplicationPad1d", + "ReplicationPad2d", + "ReplicationPad3d", + "ZeroPad2d", + "ConstantPad1d", + "ConstantPad2d", + "ConstantPad3d", +] +KERAS_PADDING_FUNCS = ["KerasZeroPadding1D", "KerasZeroPadding2D", "KerasZeroPadding3D"] +ACTIVATION_FUNCS = [ + "ELU", + "Hardshrink", + "Hardsigmoid", + "Hardswish", + "Hardtanh", + "LeakyReLU", + "PReLU", + "ReLU", + "ReLU6", + "RReLU", + "SELU", + "CELU", + "GELU", + "Sigmoid", + "Softplus", + "Softshrink", + "Softsign", + "Tanh", + "Tanhshrink", + "Threshold", + "Softmin", + "Softmax", + "Softmax2d", + "LogSoftmax", + "AdaptiveLogSoftmaxWithLoss", +] 
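The class-name registries being added above feed the layout-tracking wrapper defined further down in this same hunk (`tensorflow_handle_transpose_in_input_and_output`). A minimal sketch of how that membership test works, assuming only that `next_layer` is the next module in a Sequential-style chain and `conv_block_fns` is the aggregated `CONV_BLOCK_FNS` list; `continues_conv_block` is a hypothetical name for illustration, not a function from the translated sources:

def continues_conv_block(next_layer, conv_block_fns):
    # A layer "continues" a conv block when its class name matches any known
    # conv/pool/padding/activation/norm/dropout entry; in that case the wrapper
    # keeps the channels-last layout instead of transposing back after each call.
    name = next_layer.__class__.__name__ if next_layer is not None else ""
    return any(substr in name for substr in conv_block_fns)

This is why the wrapper only flips `DATA_FORMAT` back to "PT" when the next call in the sequence is not itself a conv-block function.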
+KERAS_ACTIVATION_FUNCS = [ + "KerasReLU", + "KerasPReLU", + "KerasLeakyReLU", + "KerasThresholdedReLU", + "KerasELU", + "KerasSoftmax", +] +DROPOUT_FUNCS = [ + "Dropout", + "Dropout2d", + "Dropout3d", + "AlphaDropout", + "FeatureAlphaDropout", +] +KERAS_DROPOUT_FUNCS = ["KerasDropout"] +CONV_BLOCK_FNS = [ + *CONV_FUNCS, + *KERAS_CONV_FUNCS, + *POOL_FUNCS, + *KERAS_POOL_FUNCS, + *PADDING_FUNCS, + *KERAS_PADDING_FUNCS, + *ACTIVATION_FUNCS, + *KERAS_ACTIVATION_FUNCS, + *NORM_FUNCS, + *KERAS_NORM_FUNCS, + *DROPOUT_FUNCS, + *KERAS_DROPOUT_FUNCS, +] +DATA_FORMAT = "PT" + + +def tensorflow_handle_transpose_in_input_and_output(fn): + from .tensorflow_TransposeType import tensorflow_TransposeType + + original_signature = inspect.signature(fn) + + @functools.wraps(fn) + def transpose_wrapper(self, *args, **kwargs): + global DATA_FORMAT + kwargs_call = { + key: val + for key, val in kwargs.items() + if key not in dict(original_signature.parameters) + } + fn_args_and_kwargs = { + key: val for key, val in kwargs.items() if key not in kwargs_call + } + fn_args_and_kwargs.update(dict(zip(fn.__code__.co_varnames[1:], args))) + conv_block_start = lambda f: any( + substr in f.__qualname__ + for substr in CONV_FUNCS + + NORM_FUNCS + + POOL_FUNCS + + KERAS_CONV_FUNCS + + KERAS_NORM_FUNCS + + KERAS_POOL_FUNCS + ) + next_call_in_seq = tensorflow_get_next_func(self) + name_of_next_call = ( + next_call_in_seq.__class__.__name__ + if hasattr(next_call_in_seq, "__class__") + else "" + ) + conv_block_continued = next_call_in_seq and any( + substr in name_of_next_call for substr in CONV_BLOCK_FNS + ) + if DATA_FORMAT == "PT" and conv_block_start(self.__class__): + input = fn_args_and_kwargs["input"] + if len(input.shape) > 4: + transpose = tensorflow_TransposeType.CONV3D + elif len(input.shape) > 3: + transpose = tensorflow_TransposeType.CONV2D + elif len(input.shape) > 2: + transpose = tensorflow_TransposeType.CONV1D + else: + transpose = tensorflow_TransposeType.NO_TRANSPOSE + fn_args_and_kwargs = tensorflow_set_item_bknd( + fn_args_and_kwargs, + "input", + tensorflow_apply_transpose(input, transpose=transpose, pt_to_tf=True), + ) + DATA_FORMAT = "TF" + os.environ = tensorflow_set_item_bknd( + os.environ, "DATA_FORMAT", "channels_last" + ) + res = fn(self, **fn_args_and_kwargs) + if DATA_FORMAT == "TF" and conv_block_continued or DATA_FORMAT == "PT": + return res + if len(res.shape) > 4: + transpose = tensorflow_TransposeType.CONV3D + elif len(res.shape) > 3: + transpose = tensorflow_TransposeType.CONV2D + elif len(res.shape) > 2: + transpose = tensorflow_TransposeType.CONV1D + else: + transpose = tensorflow_TransposeType.NO_TRANSPOSE + res = tensorflow_apply_transpose(res, transpose=transpose, pt_to_tf=False) + DATA_FORMAT = "PT" + os.environ = tensorflow_set_item_bknd( + os.environ, "DATA_FORMAT", "channels_first" + ) + return res + + tensorflow_handle_transpose_in_input_and_output.__signature__ = original_signature + return transpose_wrapper + + +def tensorflow__handle_manual_pad_avg_pool( + x, kernel, strides, padding, ceil_mode, dims +): + if isinstance(padding, str): + pad_specific = [ + tensorflow__handle_padding_bknd( + x.shape[i + 1], strides[i], kernel[i], padding + ) + for i in range(dims) + ] + padding = [ + (pad_specific[i] // 2, pad_specific[i] - pad_specific[i] // 2) + for i in range(dims) + ] + else: + if isinstance(padding, int): + padding = [(padding,) * 2] * dims + pad_specific = [sum(padding[i]) for i in range(dims)] + c = [] + if ceil_mode: + for i in range(dims): + padding[i], c_i = 
tensorflow__padding_ceil_mode_bknd( + x.shape[i + 1], kernel[i], padding[i], strides[i], True + ) + c.append(c_i) + pad_specific[i] = sum(padding[i]) + return padding, pad_specific, c + + +def tensorflow__handle_padding_bknd(x, strides, filters, padding): + if isinstance(padding, str) and padding.upper() == "SAME": + if x % strides == 0: + pad = max(filters - strides, 0) + else: + pad = max(filters - x % strides, 0) + else: + pad = 0 + return pad + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods_1(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -151,6 +516,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): 
"complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -187,234 +553,13 @@ ("uint64", "uint64"): "uint64", ("float16", "float16"): "float16", ("float16", "float32"): "float32", - ("float16", "float64"): "float64", - ("float32", "float32"): "float32", - ("float32", "float64"): "float64", - ("float64", "float64"): "float64", -} -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -CONV_FUNCS = [ - "Conv1d", - "Conv2d", - "Conv3d", - "ConvTranspose1d", - "ConvTranspose2d", - "ConvTranspose3d", -] -NORM_FUNCS = [ - "_BatchNorm", - "_InstanceNorm", - "BatchNorm1d", - "BatchNorm2d", - "BatchNorm3d", - "GroupNorm", - "SyncBatchNorm", - "InstanceNorm1d", - "InstanceNorm2d", - "InstanceNorm3d", - "LocalResponseNorm", -] -POOL_FUNCS = [ - "MaxPool1d", - "MaxPool2d", - "MaxPool3d", - "AvgPool1d", - "AvgPool2d", - "AvgPool3d", - "FractionalMaxPool2d", - "LPPool1d", - "LPPool2d", - "AdaptiveMaxPool1d", - "AdaptiveMaxPool2d", - "AdaptiveMaxPool3d", - "AdaptiveAvgPool1d", - "AdaptiveAvgPool2d", - "AdaptiveAvgPool3d", -] -KERAS_CONV_FUNCS = [ - "KerasConv1D", - "KerasConv2D", - "KerasConv3D", - "KerasDepthwiseConv2D", - "KerasConv1DTranspose", - "KerasConv2DTranspose", - "KerasConv3DTranspose", -] -KERAS_NORM_FUNCS = [ - "KerasBatchNorm1D", - "KerasBatchNorm2D", - "KerasBatchNorm3D", - "KerasLayerNormalization", - "KerasGroupNormalization", - "KerasUnitNorm1D", - "KerasUnitNorm2D", - "KerasUnitNorm3D", -] -KERAS_POOL_FUNCS = [ - "KerasAveragePooling1D", - "KerasAveragePooling2D", - "KerasAveragePooling3D", - "KerasMaxPool1D", - "KerasMaxPool2D", - "KerasMaxPool3D", -] -PADDING_FUNCS = [ - "ReflectionPad1d", - "ReflectionPad2d", - "ReplicationPad1d", - "ReplicationPad2d", - "ReplicationPad3d", - "ZeroPad2d", - "ConstantPad1d", - "ConstantPad2d", - "ConstantPad3d", -] -KERAS_PADDING_FUNCS = ["KerasZeroPadding1D", "KerasZeroPadding2D", "KerasZeroPadding3D"] -ACTIVATION_FUNCS = [ - "ELU", - "Hardshrink", - "Hardsigmoid", - "Hardswish", - "Hardtanh", - "LeakyReLU", - "PReLU", - "ReLU", - "ReLU6", - "RReLU", - "SELU", - "CELU", - "GELU", - "Sigmoid", - "Softplus", - "Softshrink", - "Softsign", - "Tanh", - "Tanhshrink", - "Threshold", - "Softmin", - "Softmax", - "Softmax2d", - "LogSoftmax", - "AdaptiveLogSoftmaxWithLoss", -] 
-KERAS_ACTIVATION_FUNCS = [ - "KerasReLU", - "KerasPReLU", - "KerasLeakyReLU", - "KerasThresholdedReLU", - "KerasELU", - "KerasSoftmax", -] -DROPOUT_FUNCS = [ - "Dropout", - "Dropout2d", - "Dropout3d", - "AlphaDropout", - "FeatureAlphaDropout", -] -KERAS_DROPOUT_FUNCS = ["KerasDropout"] -CONV_BLOCK_FNS = [ - *CONV_FUNCS, - *KERAS_CONV_FUNCS, - *POOL_FUNCS, - *KERAS_POOL_FUNCS, - *PADDING_FUNCS, - *KERAS_PADDING_FUNCS, - *ACTIVATION_FUNCS, - *KERAS_ACTIVATION_FUNCS, - *NORM_FUNCS, - *KERAS_NORM_FUNCS, - *DROPOUT_FUNCS, - *KERAS_DROPOUT_FUNCS, -] -DATA_FORMAT = "PT" - - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) + ("float16", "float64"): "float64", + ("float32", "float32"): "float32", + ("float32", "float64"): "float64", + ("float64", "float64"): "float64", +} - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_store_config_info(fn): @@ -460,25 +605,6 @@ def tensorflow_is_array_bknd(x: Any, /, *, exclusive: bool = False): ) or tensorflow_is_native_array(x, exclusive=exclusive) -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - def tensorflow_exists_bknd(x: Any, /): return x is not None @@ -511,7 +637,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -673,26 +801,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods_1(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = 
"_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods_1 +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -810,6 +920,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -916,27 +1029,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -1105,6 +1212,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1361,7 +1471,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -1773,7 +1885,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -1941,6 +2055,9 @@ def tensorflow_is_uint_dtype_bknd( return "uint" in tensorflow_as_ivy_dtype_bknd(dtype_in) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -1965,11 +2082,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2203,7 +2318,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2363,11 +2480,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - 
tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2407,21 +2522,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], @@ -2502,6 +2602,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2558,6 +2661,9 @@ def tensorflow_default_complex_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, @@ -2602,6 +2708,9 @@ def tensorflow_default_dtype_bknd( return tensorflow_as_ivy_dtype(ret) +default_float_dtype_stack = [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2656,6 +2765,25 @@ def tensorflow_default_float_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -2692,6 +2820,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2714,21 +2846,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -2766,6 +2894,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + 
"float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -2817,20 +2964,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -2889,7 +3022,9 @@ def tensorflow_add( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2931,17 +3066,6 @@ def tensorflow_permute_dims( return tensorflow.transpose(x, perm=axes) -def tensorflow__handle_padding_bknd(x, strides, filters, padding): - if isinstance(padding, str) and padding.upper() == "SAME": - if x % strides == 0: - pad = max(filters - strides, 0) - else: - pad = max(filters - x % strides, 0) - else: - pad = 0 - return pad - - def tensorflow__output_ceil_shape_bknd(w, f, p, s): return math.ceil((w - f + p) / s) + 1 @@ -2966,35 +3090,6 @@ def tensorflow__padding_ceil_mode_bknd( return p -def tensorflow__handle_manual_pad_avg_pool( - x, kernel, strides, padding, ceil_mode, dims -): - if isinstance(padding, str): - pad_specific = [ - tensorflow__handle_padding_bknd( - x.shape[i + 1], strides[i], kernel[i], padding - ) - for i in range(dims) - ] - padding = [ - (pad_specific[i] // 2, pad_specific[i] - pad_specific[i] // 2) - for i in range(dims) - ] - else: - if isinstance(padding, int): - padding = [(padding,) * 2] * dims - pad_specific = [sum(padding[i]) for i in range(dims)] - c = [] - if ceil_mode: - for i in range(dims): - padding[i], c_i = tensorflow__padding_ceil_mode_bknd( - x.shape[i + 1], kernel[i], padding[i], strides[i], True - ) - c.append(c_i) - pad_specific[i] = sum(padding[i]) - return padding, pad_specific, c - - def tensorflow_map_bknd( fn: Callable, constant: Optional[Dict[str, Any]] = None, @@ -3217,7 +3312,9 @@ def tensorflow_divide( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -3277,7 +3374,9 @@ def tensorflow_minimum( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -3354,7 +3453,9 @@ def tensorflow_greater_equal( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -3543,79 +3644,3 @@ def tensorflow_apply_transpose(input, transpose, pt_to_tf=True): axes = (0, 2, 3, 4, 1) if pt_to_tf else (0, 4, 1, 2, 3) 
input = tensorflow_permute_dims(input, axes=axes) return input - - -def tensorflow_handle_transpose_in_input_and_output(fn): - from .tensorflow_TransposeType import tensorflow_TransposeType - - original_signature = inspect.signature(fn) - - @functools.wraps(fn) - def transpose_wrapper(self, *args, **kwargs): - global DATA_FORMAT - kwargs_call = { - key: val - for key, val in kwargs.items() - if key not in dict(original_signature.parameters) - } - fn_args_and_kwargs = { - key: val for key, val in kwargs.items() if key not in kwargs_call - } - fn_args_and_kwargs.update(dict(zip(fn.__code__.co_varnames[1:], args))) - conv_block_start = lambda f: any( - substr in f.__qualname__ - for substr in CONV_FUNCS - + NORM_FUNCS - + POOL_FUNCS - + KERAS_CONV_FUNCS - + KERAS_NORM_FUNCS - + KERAS_POOL_FUNCS - ) - next_call_in_seq = tensorflow_get_next_func(self) - name_of_next_call = ( - next_call_in_seq.__class__.__name__ - if hasattr(next_call_in_seq, "__class__") - else "" - ) - conv_block_continued = next_call_in_seq and any( - substr in name_of_next_call for substr in CONV_BLOCK_FNS - ) - if DATA_FORMAT == "PT" and conv_block_start(self.__class__): - input = fn_args_and_kwargs["input"] - if len(input.shape) > 4: - transpose = tensorflow_TransposeType.CONV3D - elif len(input.shape) > 3: - transpose = tensorflow_TransposeType.CONV2D - elif len(input.shape) > 2: - transpose = tensorflow_TransposeType.CONV1D - else: - transpose = tensorflow_TransposeType.NO_TRANSPOSE - fn_args_and_kwargs = tensorflow_set_item_bknd( - fn_args_and_kwargs, - "input", - tensorflow_apply_transpose(input, transpose=transpose, pt_to_tf=True), - ) - DATA_FORMAT = "TF" - os.environ = tensorflow_set_item_bknd( - os.environ, "DATA_FORMAT", "channels_last" - ) - res = fn(self, **fn_args_and_kwargs) - if DATA_FORMAT == "TF" and conv_block_continued or DATA_FORMAT == "PT": - return res - if len(res.shape) > 4: - transpose = tensorflow_TransposeType.CONV3D - elif len(res.shape) > 3: - transpose = tensorflow_TransposeType.CONV2D - elif len(res.shape) > 2: - transpose = tensorflow_TransposeType.CONV1D - else: - transpose = tensorflow_TransposeType.NO_TRANSPOSE - res = tensorflow_apply_transpose(res, transpose=transpose, pt_to_tf=False) - DATA_FORMAT = "PT" - os.environ = tensorflow_set_item_bknd( - os.environ, "DATA_FORMAT", "channels_first" - ) - return res - - tensorflow_handle_transpose_in_input_and_output.__signature__ = original_signature - return transpose_wrapper diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_AvgPool2d_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_AvgPool2d_output/run_0/tensorflow__helpers.py index e92d2af4d280..48d4e807029c 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_AvgPool2d_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_AvgPool2d_output/run_0/tensorflow__helpers.py @@ -27,6 +27,371 @@ import tensorflow as tf +CONV_FUNCS = [ + "Conv1d", + "Conv2d", + "Conv3d", + "ConvTranspose1d", + "ConvTranspose2d", + "ConvTranspose3d", +] +NORM_FUNCS = [ + "_BatchNorm", + "_InstanceNorm", + "BatchNorm1d", + "BatchNorm2d", + "BatchNorm3d", + "GroupNorm", + "SyncBatchNorm", + "InstanceNorm1d", + "InstanceNorm2d", + "InstanceNorm3d", + "LocalResponseNorm", +] +POOL_FUNCS = [ + "MaxPool1d", + "MaxPool2d", + "MaxPool3d", + "AvgPool1d", + "AvgPool2d", + "AvgPool3d", + "FractionalMaxPool2d", + "LPPool1d", + "LPPool2d", + "AdaptiveMaxPool1d", + "AdaptiveMaxPool2d", + "AdaptiveMaxPool3d", + "AdaptiveAvgPool1d", + 
"AdaptiveAvgPool2d", + "AdaptiveAvgPool3d", +] +KERAS_CONV_FUNCS = [ + "KerasConv1D", + "KerasConv2D", + "KerasConv3D", + "KerasDepthwiseConv2D", + "KerasConv1DTranspose", + "KerasConv2DTranspose", + "KerasConv3DTranspose", +] +KERAS_NORM_FUNCS = [ + "KerasBatchNorm1D", + "KerasBatchNorm2D", + "KerasBatchNorm3D", + "KerasLayerNormalization", + "KerasGroupNormalization", + "KerasUnitNorm1D", + "KerasUnitNorm2D", + "KerasUnitNorm3D", +] +KERAS_POOL_FUNCS = [ + "KerasAveragePooling1D", + "KerasAveragePooling2D", + "KerasAveragePooling3D", + "KerasMaxPool1D", + "KerasMaxPool2D", + "KerasMaxPool3D", +] +PADDING_FUNCS = [ + "ReflectionPad1d", + "ReflectionPad2d", + "ReplicationPad1d", + "ReplicationPad2d", + "ReplicationPad3d", + "ZeroPad2d", + "ConstantPad1d", + "ConstantPad2d", + "ConstantPad3d", +] +KERAS_PADDING_FUNCS = ["KerasZeroPadding1D", "KerasZeroPadding2D", "KerasZeroPadding3D"] +ACTIVATION_FUNCS = [ + "ELU", + "Hardshrink", + "Hardsigmoid", + "Hardswish", + "Hardtanh", + "LeakyReLU", + "PReLU", + "ReLU", + "ReLU6", + "RReLU", + "SELU", + "CELU", + "GELU", + "Sigmoid", + "Softplus", + "Softshrink", + "Softsign", + "Tanh", + "Tanhshrink", + "Threshold", + "Softmin", + "Softmax", + "Softmax2d", + "LogSoftmax", + "AdaptiveLogSoftmaxWithLoss", +] +KERAS_ACTIVATION_FUNCS = [ + "KerasReLU", + "KerasPReLU", + "KerasLeakyReLU", + "KerasThresholdedReLU", + "KerasELU", + "KerasSoftmax", +] +DROPOUT_FUNCS = [ + "Dropout", + "Dropout2d", + "Dropout3d", + "AlphaDropout", + "FeatureAlphaDropout", +] +KERAS_DROPOUT_FUNCS = ["KerasDropout"] +CONV_BLOCK_FNS = [ + *CONV_FUNCS, + *KERAS_CONV_FUNCS, + *POOL_FUNCS, + *KERAS_POOL_FUNCS, + *PADDING_FUNCS, + *KERAS_PADDING_FUNCS, + *ACTIVATION_FUNCS, + *KERAS_ACTIVATION_FUNCS, + *NORM_FUNCS, + *KERAS_NORM_FUNCS, + *DROPOUT_FUNCS, + *KERAS_DROPOUT_FUNCS, +] +DATA_FORMAT = "PT" + + +def tensorflow_handle_transpose_in_input_and_output(fn): + from .tensorflow_TransposeType import tensorflow_TransposeType + + original_signature = inspect.signature(fn) + + @functools.wraps(fn) + def transpose_wrapper(self, *args, **kwargs): + global DATA_FORMAT + kwargs_call = { + key: val + for key, val in kwargs.items() + if key not in dict(original_signature.parameters) + } + fn_args_and_kwargs = { + key: val for key, val in kwargs.items() if key not in kwargs_call + } + fn_args_and_kwargs.update(dict(zip(fn.__code__.co_varnames[1:], args))) + conv_block_start = lambda f: any( + substr in f.__qualname__ + for substr in CONV_FUNCS + + NORM_FUNCS + + POOL_FUNCS + + KERAS_CONV_FUNCS + + KERAS_NORM_FUNCS + + KERAS_POOL_FUNCS + ) + next_call_in_seq = tensorflow_get_next_func(self) + name_of_next_call = ( + next_call_in_seq.__class__.__name__ + if hasattr(next_call_in_seq, "__class__") + else "" + ) + conv_block_continued = next_call_in_seq and any( + substr in name_of_next_call for substr in CONV_BLOCK_FNS + ) + if DATA_FORMAT == "PT" and conv_block_start(self.__class__): + input = fn_args_and_kwargs["input"] + if len(input.shape) > 4: + transpose = tensorflow_TransposeType.CONV3D + elif len(input.shape) > 3: + transpose = tensorflow_TransposeType.CONV2D + elif len(input.shape) > 2: + transpose = tensorflow_TransposeType.CONV1D + else: + transpose = tensorflow_TransposeType.NO_TRANSPOSE + fn_args_and_kwargs = tensorflow_set_item_bknd( + fn_args_and_kwargs, + "input", + tensorflow_apply_transpose(input, transpose=transpose, pt_to_tf=True), + ) + DATA_FORMAT = "TF" + os.environ = tensorflow_set_item_bknd( + os.environ, "DATA_FORMAT", "channels_last" + ) + res = fn(self, 
**fn_args_and_kwargs) + if DATA_FORMAT == "TF" and conv_block_continued or DATA_FORMAT == "PT": + return res + if len(res.shape) > 4: + transpose = tensorflow_TransposeType.CONV3D + elif len(res.shape) > 3: + transpose = tensorflow_TransposeType.CONV2D + elif len(res.shape) > 2: + transpose = tensorflow_TransposeType.CONV1D + else: + transpose = tensorflow_TransposeType.NO_TRANSPOSE + res = tensorflow_apply_transpose(res, transpose=transpose, pt_to_tf=False) + DATA_FORMAT = "PT" + os.environ = tensorflow_set_item_bknd( + os.environ, "DATA_FORMAT", "channels_first" + ) + return res + + tensorflow_handle_transpose_in_input_and_output.__signature__ = original_signature + return transpose_wrapper + + +def tensorflow__handle_manual_pad_avg_pool( + x, kernel, strides, padding, ceil_mode, dims +): + if isinstance(padding, str): + pad_specific = [ + tensorflow__handle_padding_bknd( + x.shape[i + 1], strides[i], kernel[i], padding + ) + for i in range(dims) + ] + padding = [ + (pad_specific[i] // 2, pad_specific[i] - pad_specific[i] // 2) + for i in range(dims) + ] + else: + if isinstance(padding, int): + padding = [(padding,) * 2] * dims + pad_specific = [sum(padding[i]) for i in range(dims)] + c = [] + if ceil_mode: + for i in range(dims): + padding[i], c_i = tensorflow__padding_ceil_mode_bknd( + x.shape[i + 1], kernel[i], padding[i], strides[i], True + ) + c.append(c_i) + pad_specific[i] = sum(padding[i]) + return padding, pad_specific, c + + +def tensorflow__handle_padding_bknd(x, strides, filters, padding): + if isinstance(padding, str) and padding.upper() == "SAME": + if x % strides == 0: + pad = max(filters - strides, 0) + else: + pad = max(filters - x % strides, 0) + else: + pad = 0 + return pad + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, 
val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods_1(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -151,6 +516,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -187,234 +553,13 @@ ("uint64", "uint64"): "uint64", ("float16", "float16"): "float16", ("float16", "float32"): "float32", - ("float16", "float64"): "float64", - ("float32", "float32"): "float32", - ("float32", "float64"): "float64", - ("float64", "float64"): "float64", -} -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -CONV_FUNCS = [ - "Conv1d", - "Conv2d", - "Conv3d", - "ConvTranspose1d", - "ConvTranspose2d", - "ConvTranspose3d", -] -NORM_FUNCS = [ - "_BatchNorm", - "_InstanceNorm", - "BatchNorm1d", - "BatchNorm2d", - "BatchNorm3d", - "GroupNorm", - "SyncBatchNorm", - "InstanceNorm1d", - "InstanceNorm2d", - "InstanceNorm3d", - "LocalResponseNorm", -] -POOL_FUNCS = [ - "MaxPool1d", - "MaxPool2d", - "MaxPool3d", - "AvgPool1d", - "AvgPool2d", - "AvgPool3d", - "FractionalMaxPool2d", - "LPPool1d", - "LPPool2d", - "AdaptiveMaxPool1d", - "AdaptiveMaxPool2d", - "AdaptiveMaxPool3d", - "AdaptiveAvgPool1d", - 
"AdaptiveAvgPool2d", - "AdaptiveAvgPool3d", -] -KERAS_CONV_FUNCS = [ - "KerasConv1D", - "KerasConv2D", - "KerasConv3D", - "KerasDepthwiseConv2D", - "KerasConv1DTranspose", - "KerasConv2DTranspose", - "KerasConv3DTranspose", -] -KERAS_NORM_FUNCS = [ - "KerasBatchNorm1D", - "KerasBatchNorm2D", - "KerasBatchNorm3D", - "KerasLayerNormalization", - "KerasGroupNormalization", - "KerasUnitNorm1D", - "KerasUnitNorm2D", - "KerasUnitNorm3D", -] -KERAS_POOL_FUNCS = [ - "KerasAveragePooling1D", - "KerasAveragePooling2D", - "KerasAveragePooling3D", - "KerasMaxPool1D", - "KerasMaxPool2D", - "KerasMaxPool3D", -] -PADDING_FUNCS = [ - "ReflectionPad1d", - "ReflectionPad2d", - "ReplicationPad1d", - "ReplicationPad2d", - "ReplicationPad3d", - "ZeroPad2d", - "ConstantPad1d", - "ConstantPad2d", - "ConstantPad3d", -] -KERAS_PADDING_FUNCS = ["KerasZeroPadding1D", "KerasZeroPadding2D", "KerasZeroPadding3D"] -ACTIVATION_FUNCS = [ - "ELU", - "Hardshrink", - "Hardsigmoid", - "Hardswish", - "Hardtanh", - "LeakyReLU", - "PReLU", - "ReLU", - "ReLU6", - "RReLU", - "SELU", - "CELU", - "GELU", - "Sigmoid", - "Softplus", - "Softshrink", - "Softsign", - "Tanh", - "Tanhshrink", - "Threshold", - "Softmin", - "Softmax", - "Softmax2d", - "LogSoftmax", - "AdaptiveLogSoftmaxWithLoss", -] -KERAS_ACTIVATION_FUNCS = [ - "KerasReLU", - "KerasPReLU", - "KerasLeakyReLU", - "KerasThresholdedReLU", - "KerasELU", - "KerasSoftmax", -] -DROPOUT_FUNCS = [ - "Dropout", - "Dropout2d", - "Dropout3d", - "AlphaDropout", - "FeatureAlphaDropout", -] -KERAS_DROPOUT_FUNCS = ["KerasDropout"] -CONV_BLOCK_FNS = [ - *CONV_FUNCS, - *KERAS_CONV_FUNCS, - *POOL_FUNCS, - *KERAS_POOL_FUNCS, - *PADDING_FUNCS, - *KERAS_PADDING_FUNCS, - *ACTIVATION_FUNCS, - *KERAS_ACTIVATION_FUNCS, - *NORM_FUNCS, - *KERAS_NORM_FUNCS, - *DROPOUT_FUNCS, - *KERAS_DROPOUT_FUNCS, -] -DATA_FORMAT = "PT" - - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) + ("float16", "float64"): "float64", + ("float32", "float32"): "float32", + ("float32", "float64"): "float64", + ("float64", "float64"): "float64", +} - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_store_config_info(fn): @@ -460,25 +605,6 @@ def 
tensorflow_is_array_bknd(x: Any, /, *, exclusive: bool = False): ) or tensorflow_is_native_array(x, exclusive=exclusive) -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - def tensorflow_exists_bknd(x: Any, /): return x is not None @@ -511,7 +637,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -673,26 +801,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods_1(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods_1 +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -810,6 +920,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -916,27 +1029,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -1105,6 +1212,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1361,7 +1471,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -1773,7 +1885,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else 
tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -1941,6 +2055,9 @@ def tensorflow_is_uint_dtype_bknd( return "uint" in tensorflow_as_ivy_dtype_bknd(dtype_in) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -1965,11 +2082,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2203,7 +2318,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2363,11 +2480,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2407,21 +2522,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], @@ -2502,6 +2602,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2558,6 +2661,9 @@ def tensorflow_default_complex_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, @@ -2602,6 +2708,9 @@ def tensorflow_default_dtype_bknd( return tensorflow_as_ivy_dtype(ret) +default_float_dtype_stack = [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2656,6 +2765,25 @@ def tensorflow_default_float_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -2692,6 +2820,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: 
Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2714,21 +2846,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -2766,6 +2894,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -2817,20 +2964,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -2889,7 +3022,9 @@ def tensorflow_add( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2915,17 +3050,6 @@ def tensorflow_add_frnt_(tensor, other, *, alpha=1): return tensorflow_add_frnt(tensor, other, alpha=alpha) -def tensorflow__handle_padding_bknd(x, strides, filters, padding): - if isinstance(padding, str) and padding.upper() == "SAME": - if x % strides == 0: - pad = max(filters - strides, 0) - else: - pad = max(filters - x % strides, 0) - else: - pad = 0 - return pad - - def tensorflow__output_ceil_shape_bknd(w, f, p, s): return math.ceil((w - f + p) / s) + 1 @@ -2950,35 +3074,6 @@ def tensorflow__padding_ceil_mode_bknd( return p -def tensorflow__handle_manual_pad_avg_pool( - x, kernel, strides, padding, ceil_mode, dims -): - if isinstance(padding, str): - pad_specific = [ - tensorflow__handle_padding_bknd( - x.shape[i + 1], strides[i], kernel[i], padding - ) - for i in range(dims) - ] - padding = [ - (pad_specific[i] // 2, pad_specific[i] - pad_specific[i] // 2) - for i in range(dims) - ] - else: - if isinstance(padding, int): - padding = [(padding,) * 2] * dims - pad_specific = [sum(padding[i]) for i in range(dims)] - c = [] - if ceil_mode: - for i in range(dims): - padding[i], c_i = tensorflow__padding_ceil_mode_bknd( - x.shape[i + 1], kernel[i], padding[i], strides[i], True - 
) - c.append(c_i) - pad_specific[i] = sum(padding[i]) - return padding, pad_specific, c - - def tensorflow_map_bknd( fn: Callable, constant: Optional[Dict[str, Any]] = None, @@ -3195,79 +3290,3 @@ def tensorflow_apply_transpose(input, transpose, pt_to_tf=True): axes = (0, 2, 3, 4, 1) if pt_to_tf else (0, 4, 1, 2, 3) input = tensorflow_permute_dims(input, axes=axes) return input - - -def tensorflow_handle_transpose_in_input_and_output(fn): - from .tensorflow_TransposeType import tensorflow_TransposeType - - original_signature = inspect.signature(fn) - - @functools.wraps(fn) - def transpose_wrapper(self, *args, **kwargs): - global DATA_FORMAT - kwargs_call = { - key: val - for key, val in kwargs.items() - if key not in dict(original_signature.parameters) - } - fn_args_and_kwargs = { - key: val for key, val in kwargs.items() if key not in kwargs_call - } - fn_args_and_kwargs.update(dict(zip(fn.__code__.co_varnames[1:], args))) - conv_block_start = lambda f: any( - substr in f.__qualname__ - for substr in CONV_FUNCS - + NORM_FUNCS - + POOL_FUNCS - + KERAS_CONV_FUNCS - + KERAS_NORM_FUNCS - + KERAS_POOL_FUNCS - ) - next_call_in_seq = tensorflow_get_next_func(self) - name_of_next_call = ( - next_call_in_seq.__class__.__name__ - if hasattr(next_call_in_seq, "__class__") - else "" - ) - conv_block_continued = next_call_in_seq and any( - substr in name_of_next_call for substr in CONV_BLOCK_FNS - ) - if DATA_FORMAT == "PT" and conv_block_start(self.__class__): - input = fn_args_and_kwargs["input"] - if len(input.shape) > 4: - transpose = tensorflow_TransposeType.CONV3D - elif len(input.shape) > 3: - transpose = tensorflow_TransposeType.CONV2D - elif len(input.shape) > 2: - transpose = tensorflow_TransposeType.CONV1D - else: - transpose = tensorflow_TransposeType.NO_TRANSPOSE - fn_args_and_kwargs = tensorflow_set_item_bknd( - fn_args_and_kwargs, - "input", - tensorflow_apply_transpose(input, transpose=transpose, pt_to_tf=True), - ) - DATA_FORMAT = "TF" - os.environ = tensorflow_set_item_bknd( - os.environ, "DATA_FORMAT", "channels_last" - ) - res = fn(self, **fn_args_and_kwargs) - if DATA_FORMAT == "TF" and conv_block_continued or DATA_FORMAT == "PT": - return res - if len(res.shape) > 4: - transpose = tensorflow_TransposeType.CONV3D - elif len(res.shape) > 3: - transpose = tensorflow_TransposeType.CONV2D - elif len(res.shape) > 2: - transpose = tensorflow_TransposeType.CONV1D - else: - transpose = tensorflow_TransposeType.NO_TRANSPOSE - res = tensorflow_apply_transpose(res, transpose=transpose, pt_to_tf=False) - DATA_FORMAT = "PT" - os.environ = tensorflow_set_item_bknd( - os.environ, "DATA_FORMAT", "channels_first" - ) - return res - - tensorflow_handle_transpose_in_input_and_output.__signature__ = original_signature - return transpose_wrapper diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_BatchNorm2d_output/run_0/tensorflow_NestedSequence_bknd.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_BatchNorm2d_output/run_0/tensorflow_NestedSequence_bknd.py index 9f87b4ae29ef..ac8335fe1e56 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_BatchNorm2d_output/run_0/tensorflow_NestedSequence_bknd.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_BatchNorm2d_output/run_0/tensorflow_NestedSequence_bknd.py @@ -1,5 +1,5 @@ -from typing import Protocol from typing import TypeVar +from typing import Protocol _T_co = TypeVar("_T_co", covariant=True) diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_BatchNorm2d_output/run_0/tensorflow__BatchNorm.py 
b/ivy/compiler/_cache/Translated_Outputs/tensorflow_BatchNorm2d_output/run_0/tensorflow__BatchNorm.py index c2b2b6bb34f4..855f7a97e77c 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_BatchNorm2d_output/run_0/tensorflow__BatchNorm.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_BatchNorm2d_output/run_0/tensorflow__BatchNorm.py @@ -56,16 +56,12 @@ def call(self, input): normalized, self.running_mean, self.running_var = ( tensorflow_batch_norm_frnt( input, - ( - self.running_mean - if not self.training or self.track_running_stats - else None - ), - ( - self.running_var - if not self.training or self.track_running_stats - else None - ), + self.running_mean + if not self.training or self.track_running_stats + else None, + self.running_var + if not self.training or self.track_running_stats + else None, self.weight, self.bias, bn_training, diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_BatchNorm2d_output/run_0/tensorflow__NormBase.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_BatchNorm2d_output/run_0/tensorflow__NormBase.py index bf904536e053..b48ae673ea4b 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_BatchNorm2d_output/run_0/tensorflow__NormBase.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_BatchNorm2d_output/run_0/tensorflow__NormBase.py @@ -133,13 +133,11 @@ def _load_from_state_dict( state_dict = tensorflow_set_item_bknd( state_dict, num_batches_tracked_key, - ( - self.num_batches_tracked - if self.num_batches_tracked is not None - and self.num_batches_tracked.device - != tensorflow_device_frnt("meta") - else tensorflow_tensor_frnt(0, dtype=tf.int64) - ), + self.num_batches_tracked + if self.num_batches_tracked is not None + and self.num_batches_tracked.device + != tensorflow_device_frnt("meta") + else tensorflow_tensor_frnt(0, dtype=tf.int64), ) super()._load_from_state_dict( state_dict, diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_BatchNorm2d_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_BatchNorm2d_output/run_0/tensorflow__helpers.py index 9dd5c8deeea0..9138b8376bd3 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_BatchNorm2d_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_BatchNorm2d_output/run_0/tensorflow__helpers.py @@ -26,214 +26,6 @@ import tensorflow as tf -promotion_table = { - ("bool", "bool"): "bool", - ("int8", "int8"): "int8", - ("int8", "int16"): "int16", - ("int8", "int32"): "int32", - ("int8", "int64"): "int64", - ("int16", "int16"): "int16", - ("int16", "int32"): "int32", - ("int16", "int64"): "int64", - ("int32", "int32"): "int32", - ("int32", "int64"): "int64", - ("int64", "int64"): "int64", - ("uint8", "int8"): "int16", - ("uint8", "int16"): "int16", - ("uint8", "int32"): "int32", - ("uint8", "int64"): "int64", - ("uint8", "uint8"): "uint8", - ("uint8", "uint16"): "uint16", - ("uint8", "uint32"): "uint32", - ("uint8", "uint64"): "uint64", - ("uint16", "int8"): "int32", - ("uint16", "int16"): "int32", - ("uint16", "int32"): "int32", - ("uint16", "int64"): "int64", - ("uint16", "uint16"): "uint16", - ("uint16", "uint32"): "uint32", - ("uint16", "uint64"): "uint64", - ("uint32", "int8"): "int64", - ("uint32", "int16"): "int64", - ("uint32", "int32"): "int64", - ("uint32", "int64"): "int64", - ("uint32", "uint32"): "uint32", - ("uint32", "uint64"): "uint64", - ("uint64", "uint64"): "uint64", - ("float16", "float16"): "float16", - ("float16", "float32"): "float32", - ("float16", 
"float64"): "float64", - ("float32", "float32"): "float32", - ("float32", "float64"): "float64", - ("float64", "float64"): "float64", - ("bool", "int8"): "int8", - ("bool", "int16"): "int16", - ("bool", "int32"): "int32", - ("bool", "int64"): "int64", - ("bool", "uint8"): "uint8", - ("bool", "uint16"): "uint16", - ("bool", "uint32"): "uint32", - ("bool", "uint64"): "uint64", - ("bool", "float16"): "float16", - ("bool", "float32"): "float32", - ("bool", "float64"): "float64", - ("bool", "bfloat16"): "bfloat16", - ("bool", "complex64"): "complex64", - ("bool", "complex128"): "complex128", - ("int8", "float16"): "float16", - ("int8", "float32"): "float32", - ("int8", "float64"): "float64", - ("int8", "bfloat16"): "bfloat16", - ("int8", "complex64"): "complex64", - ("int8", "complex128"): "complex128", - ("int16", "float32"): "float32", - ("int16", "float64"): "float64", - ("int16", "complex64"): "complex64", - ("int16", "complex128"): "complex128", - ("int32", "float64"): "float64", - ("int32", "complex128"): "complex128", - ("int64", "float64"): "float64", - ("int64", "complex128"): "complex128", - ("uint8", "float16"): "float16", - ("uint8", "float32"): "float32", - ("uint8", "float64"): "float64", - ("uint8", "bfloat16"): "bfloat16", - ("uint8", "complex64"): "complex64", - ("uint8", "complex128"): "complex128", - ("uint16", "float32"): "float32", - ("uint16", "float64"): "float64", - ("uint16", "complex64"): "complex64", - ("uint16", "complex128"): "complex128", - ("uint32", "float64"): "float64", - ("uint32", "complex128"): "complex128", - ("uint64", "int8"): "float64", - ("uint64", "int16"): "float64", - ("uint64", "int32"): "float64", - ("uint64", "int64"): "float64", - ("uint64", "float64"): "float64", - ("uint64", "complex128"): "complex128", - ("float16", "bfloat16"): "float32", - ("float16", "complex64"): "complex64", - ("float16", "complex128"): "complex128", - ("float32", "complex64"): "complex64", - ("float32", "complex128"): "complex128", - ("float64", "complex64"): "complex128", - ("float64", "complex128"): "complex128", - ("bfloat16", "float16"): "float32", - ("bfloat16", "float32"): "float32", - ("bfloat16", "float64"): "float64", - ("bfloat16", "bfloat16"): "bfloat16", - ("bfloat16", "complex64"): "complex64", - ("bfloat16", "complex128"): "complex128", - ("complex64", "float64"): "complex128", - ("complex64", "complex64"): "complex64", - ("complex64", "complex128"): "complex128", - ("complex128", "complex128"): "complex128", - ("float16", "int16"): "float32", - ("float16", "int32"): "float64", - ("float16", "int64"): "float64", - ("float16", "uint16"): "float32", - ("float16", "uint32"): "float64", - ("float16", "uint64"): "float64", - ("float32", "int32"): "float64", - ("float32", "int64"): "float64", - ("float32", "uint32"): "float64", - ("float32", "uint64"): "float64", - ("bfloat16", "int16"): "float32", - ("bfloat16", "int32"): "float64", - ("bfloat16", "int64"): "float64", - ("bfloat16", "uint16"): "float32", - ("bfloat16", "uint32"): "float64", - ("bfloat16", "uint64"): "float64", - ("complex64", "int32"): "complex128", - ("complex64", "int64"): "complex128", - ("complex64", "uint32"): "complex128", - ("complex64", "uint64"): "complex128", -} -array_api_promotion_table = { - ("bool", "bool"): "bool", - ("int8", "int8"): "int8", - ("int8", "int16"): "int16", - ("int8", "int32"): "int32", - ("int8", "int64"): "int64", - ("int16", "int16"): "int16", - ("int16", "int32"): "int32", - ("int16", "int64"): "int64", - ("int32", "int32"): "int32", - ("int32", "int64"): 
"int64", - ("int64", "int64"): "int64", - ("uint8", "int8"): "int16", - ("uint8", "int16"): "int16", - ("uint8", "int32"): "int32", - ("uint8", "int64"): "int64", - ("uint8", "uint8"): "uint8", - ("uint8", "uint16"): "uint16", - ("uint8", "uint32"): "uint32", - ("uint8", "uint64"): "uint64", - ("uint16", "int8"): "int32", - ("uint16", "int16"): "int32", - ("uint16", "int32"): "int32", - ("uint16", "int64"): "int64", - ("uint16", "uint16"): "uint16", - ("uint16", "uint32"): "uint32", - ("uint16", "uint64"): "uint64", - ("uint32", "int8"): "int64", - ("uint32", "int16"): "int64", - ("uint32", "int32"): "int64", - ("uint32", "int64"): "int64", - ("uint32", "uint32"): "uint32", - ("uint32", "uint64"): "uint64", - ("uint64", "uint64"): "uint64", - ("float16", "float16"): "float16", - ("float16", "float32"): "float32", - ("float16", "float64"): "float64", - ("float32", "float32"): "float32", - ("float32", "float64"): "float64", - ("float64", "float64"): "float64", -} -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] CONV_FUNCS = [ "Conv1d", "Conv2d", @@ -371,6 +163,101 @@ DATA_FORMAT = "PT" +def tensorflow_handle_transpose_in_input_and_output(fn): + from .tensorflow_TransposeType import tensorflow_TransposeType + + original_signature = inspect.signature(fn) + + @functools.wraps(fn) + def transpose_wrapper(self, *args, **kwargs): + global DATA_FORMAT + kwargs_call = { + key: val + for key, val in kwargs.items() + if key not in dict(original_signature.parameters) + } + fn_args_and_kwargs = { + key: val for key, val in kwargs.items() if key not in kwargs_call + } + fn_args_and_kwargs.update(dict(zip(fn.__code__.co_varnames[1:], args))) + conv_block_start = lambda f: any( + substr in f.__qualname__ + for substr in CONV_FUNCS + + NORM_FUNCS + + POOL_FUNCS + + KERAS_CONV_FUNCS + + KERAS_NORM_FUNCS + + KERAS_POOL_FUNCS + ) + next_call_in_seq = tensorflow_get_next_func(self) + name_of_next_call = ( + next_call_in_seq.__class__.__name__ + if hasattr(next_call_in_seq, "__class__") + else "" + ) + conv_block_continued = next_call_in_seq and any( + substr in name_of_next_call for substr in CONV_BLOCK_FNS + ) + if DATA_FORMAT == "PT" and conv_block_start(self.__class__): + input = fn_args_and_kwargs["input"] + if len(input.shape) > 4: + transpose = 
tensorflow_TransposeType.CONV3D + elif len(input.shape) > 3: + transpose = tensorflow_TransposeType.CONV2D + elif len(input.shape) > 2: + transpose = tensorflow_TransposeType.CONV1D + else: + transpose = tensorflow_TransposeType.NO_TRANSPOSE + fn_args_and_kwargs = tensorflow_set_item_bknd( + fn_args_and_kwargs, + "input", + tensorflow_apply_transpose(input, transpose=transpose, pt_to_tf=True), + ) + DATA_FORMAT = "TF" + os.environ = tensorflow_set_item_bknd( + os.environ, "DATA_FORMAT", "channels_last" + ) + res = fn(self, **fn_args_and_kwargs) + if DATA_FORMAT == "TF" and conv_block_continued or DATA_FORMAT == "PT": + return res + if len(res.shape) > 4: + transpose = tensorflow_TransposeType.CONV3D + elif len(res.shape) > 3: + transpose = tensorflow_TransposeType.CONV2D + elif len(res.shape) > 2: + transpose = tensorflow_TransposeType.CONV1D + else: + transpose = tensorflow_TransposeType.NO_TRANSPOSE + res = tensorflow_apply_transpose(res, transpose=transpose, pt_to_tf=False) + DATA_FORMAT = "PT" + os.environ = tensorflow_set_item_bknd( + os.environ, "DATA_FORMAT", "channels_first" + ) + return res + + tensorflow_handle_transpose_in_input_and_output.__signature__ = original_signature + return transpose_wrapper + + +def tensorflow_handle_methods_1(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + def tensorflow_handle_array_like_without_promotion(fn: Callable): @functools.wraps(fn) def _handle_array_like_without_promotion(*args, **kwargs): @@ -412,8 +299,226 @@ def _handle_array_like_without_promotion(*args, **kwargs): ) return fn(*args, **kwargs) - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + +promotion_table = { + ("bool", "bool"): "bool", + ("int8", "int8"): "int8", + ("int8", "int16"): "int16", + ("int8", "int32"): "int32", + ("int8", "int64"): "int64", + ("int16", "int16"): "int16", + ("int16", "int32"): "int32", + ("int16", "int64"): "int64", + ("int32", "int32"): "int32", + 
("int32", "int64"): "int64", + ("int64", "int64"): "int64", + ("uint8", "int8"): "int16", + ("uint8", "int16"): "int16", + ("uint8", "int32"): "int32", + ("uint8", "int64"): "int64", + ("uint8", "uint8"): "uint8", + ("uint8", "uint16"): "uint16", + ("uint8", "uint32"): "uint32", + ("uint8", "uint64"): "uint64", + ("uint16", "int8"): "int32", + ("uint16", "int16"): "int32", + ("uint16", "int32"): "int32", + ("uint16", "int64"): "int64", + ("uint16", "uint16"): "uint16", + ("uint16", "uint32"): "uint32", + ("uint16", "uint64"): "uint64", + ("uint32", "int8"): "int64", + ("uint32", "int16"): "int64", + ("uint32", "int32"): "int64", + ("uint32", "int64"): "int64", + ("uint32", "uint32"): "uint32", + ("uint32", "uint64"): "uint64", + ("uint64", "uint64"): "uint64", + ("float16", "float16"): "float16", + ("float16", "float32"): "float32", + ("float16", "float64"): "float64", + ("float32", "float32"): "float32", + ("float32", "float64"): "float64", + ("float64", "float64"): "float64", + ("bool", "int8"): "int8", + ("bool", "int16"): "int16", + ("bool", "int32"): "int32", + ("bool", "int64"): "int64", + ("bool", "uint8"): "uint8", + ("bool", "uint16"): "uint16", + ("bool", "uint32"): "uint32", + ("bool", "uint64"): "uint64", + ("bool", "float16"): "float16", + ("bool", "float32"): "float32", + ("bool", "float64"): "float64", + ("bool", "bfloat16"): "bfloat16", + ("bool", "complex64"): "complex64", + ("bool", "complex128"): "complex128", + ("int8", "float16"): "float16", + ("int8", "float32"): "float32", + ("int8", "float64"): "float64", + ("int8", "bfloat16"): "bfloat16", + ("int8", "complex64"): "complex64", + ("int8", "complex128"): "complex128", + ("int16", "float32"): "float32", + ("int16", "float64"): "float64", + ("int16", "complex64"): "complex64", + ("int16", "complex128"): "complex128", + ("int32", "float64"): "float64", + ("int32", "complex128"): "complex128", + ("int64", "float64"): "float64", + ("int64", "complex128"): "complex128", + ("uint8", "float16"): "float16", + ("uint8", "float32"): "float32", + ("uint8", "float64"): "float64", + ("uint8", "bfloat16"): "bfloat16", + ("uint8", "complex64"): "complex64", + ("uint8", "complex128"): "complex128", + ("uint16", "float32"): "float32", + ("uint16", "float64"): "float64", + ("uint16", "complex64"): "complex64", + ("uint16", "complex128"): "complex128", + ("uint32", "float64"): "float64", + ("uint32", "complex128"): "complex128", + ("uint64", "int8"): "float64", + ("uint64", "int16"): "float64", + ("uint64", "int32"): "float64", + ("uint64", "int64"): "float64", + ("uint64", "float64"): "float64", + ("uint64", "complex128"): "complex128", + ("float16", "bfloat16"): "float32", + ("float16", "complex64"): "complex64", + ("float16", "complex128"): "complex128", + ("float32", "complex64"): "complex64", + ("float32", "complex128"): "complex128", + ("float64", "complex64"): "complex128", + ("float64", "complex128"): "complex128", + ("bfloat16", "float16"): "float32", + ("bfloat16", "float32"): "float32", + ("bfloat16", "float64"): "float64", + ("bfloat16", "bfloat16"): "bfloat16", + ("bfloat16", "complex64"): "complex64", + ("bfloat16", "complex128"): "complex128", + ("complex64", "float64"): "complex128", + ("complex64", "complex64"): "complex64", + ("complex64", "complex128"): "complex128", + ("complex128", "complex128"): "complex128", + ("float16", "int16"): "float32", + ("float16", "int32"): "float64", + ("float16", "int64"): "float64", + ("float16", "uint16"): "float32", + ("float16", "uint32"): "float64", + ("float16", "uint64"): 
"float64", + ("float32", "int32"): "float64", + ("float32", "int64"): "float64", + ("float32", "uint32"): "float64", + ("float32", "uint64"): "float64", + ("bfloat16", "int16"): "float32", + ("bfloat16", "int32"): "float64", + ("bfloat16", "int64"): "float64", + ("bfloat16", "uint16"): "float32", + ("bfloat16", "uint32"): "float64", + ("bfloat16", "uint64"): "float64", + ("complex64", "int32"): "complex128", + ("complex64", "int64"): "complex128", + ("complex64", "uint32"): "complex128", + ("complex64", "uint64"): "complex128", +} + +array_api_promotion_table = { + ("bool", "bool"): "bool", + ("int8", "int8"): "int8", + ("int8", "int16"): "int16", + ("int8", "int32"): "int32", + ("int8", "int64"): "int64", + ("int16", "int16"): "int16", + ("int16", "int32"): "int32", + ("int16", "int64"): "int64", + ("int32", "int32"): "int32", + ("int32", "int64"): "int64", + ("int64", "int64"): "int64", + ("uint8", "int8"): "int16", + ("uint8", "int16"): "int16", + ("uint8", "int32"): "int32", + ("uint8", "int64"): "int64", + ("uint8", "uint8"): "uint8", + ("uint8", "uint16"): "uint16", + ("uint8", "uint32"): "uint32", + ("uint8", "uint64"): "uint64", + ("uint16", "int8"): "int32", + ("uint16", "int16"): "int32", + ("uint16", "int32"): "int32", + ("uint16", "int64"): "int64", + ("uint16", "uint16"): "uint16", + ("uint16", "uint32"): "uint32", + ("uint16", "uint64"): "uint64", + ("uint32", "int8"): "int64", + ("uint32", "int16"): "int64", + ("uint32", "int32"): "int64", + ("uint32", "int64"): "int64", + ("uint32", "uint32"): "uint32", + ("uint32", "uint64"): "uint64", + ("uint64", "uint64"): "uint64", + ("float16", "float16"): "float16", + ("float16", "float32"): "float32", + ("float16", "float64"): "float64", + ("float32", "float32"): "float32", + ("float32", "float64"): "float64", + ("float64", "float64"): "float64", +} + +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_is_native_array(x, /, *, exclusive=False): @@ -472,7 +577,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -585,6 +692,7 @@ def tensorflow_is_complex_dtype_bknd( return "complex" in tensorflow_as_ivy_dtype_bknd(dtype_in) +@tensorflow_handle_array_like_without_promotion def tensorflow_real( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -598,6 +706,7 @@ def tensorflow_real_bknd_(self): return tensorflow_real(self) +@tensorflow_handle_array_like_without_promotion def tensorflow_imag( val: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -623,6 +732,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -749,6 +861,9 @@ def nested_fun(x): return "int" in tensorflow_as_ivy_dtype(dtype_in) +default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, @@ -793,6 +908,9 @@ def tensorflow_default_dtype_bknd( return tensorflow_as_ivy_dtype(ret) +default_float_dtype_stack = [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -847,6 +965,25 @@ def tensorflow_default_float_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + 
tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -883,6 +1020,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -905,21 +1046,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -957,6 +1094,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -1008,20 +1164,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -1088,26 +1230,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -1225,6 +1349,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: 
Optional[Union[str, str]] = None, /, @@ -1331,27 +1458,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -1520,6 +1641,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1776,7 +1900,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2188,7 +2314,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -2273,6 +2401,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2297,11 +2428,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2509,7 +2638,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2669,11 +2800,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2713,21 +2842,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, 
query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], @@ -2999,25 +3113,6 @@ def tensorflow_device_frnt(dev): return tensorflow_default_device_bknd(dev) -def tensorflow_handle_methods_1(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods_1 def tensorflow_split_frnt(tensor, split_size_or_sections, dim=0): if isinstance(split_size_or_sections, int): @@ -3059,7 +3154,9 @@ def tensorflow_add( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -3254,82 +3351,6 @@ def tensorflow_apply_transpose(input, transpose, pt_to_tf=True): return input -def tensorflow_handle_transpose_in_input_and_output(fn): - from .tensorflow_TransposeType import tensorflow_TransposeType - - original_signature = inspect.signature(fn) - - @functools.wraps(fn) - def transpose_wrapper(self, *args, **kwargs): - global DATA_FORMAT - kwargs_call = { - key: val - for key, val in kwargs.items() - if key not in dict(original_signature.parameters) - } - fn_args_and_kwargs = { - key: val for key, val in kwargs.items() if key not in kwargs_call - } - fn_args_and_kwargs.update(dict(zip(fn.__code__.co_varnames[1:], args))) - conv_block_start = lambda f: any( - substr in f.__qualname__ - for substr in CONV_FUNCS - + NORM_FUNCS - + POOL_FUNCS - + KERAS_CONV_FUNCS - + KERAS_NORM_FUNCS - + KERAS_POOL_FUNCS - ) - next_call_in_seq = tensorflow_get_next_func(self) - name_of_next_call = ( - next_call_in_seq.__class__.__name__ - if hasattr(next_call_in_seq, "__class__") - else "" - ) - conv_block_continued = next_call_in_seq and any( - substr in name_of_next_call for substr in CONV_BLOCK_FNS - ) - if DATA_FORMAT == "PT" and conv_block_start(self.__class__): - input = fn_args_and_kwargs["input"] - if len(input.shape) > 4: - transpose = tensorflow_TransposeType.CONV3D - elif len(input.shape) > 3: - transpose = tensorflow_TransposeType.CONV2D - elif len(input.shape) > 2: - transpose = tensorflow_TransposeType.CONV1D - else: - transpose = tensorflow_TransposeType.NO_TRANSPOSE - fn_args_and_kwargs = tensorflow_set_item_bknd( - fn_args_and_kwargs, - "input", - tensorflow_apply_transpose(input, transpose=transpose, pt_to_tf=True), - ) - DATA_FORMAT = "TF" - os.environ = tensorflow_set_item_bknd( - os.environ, "DATA_FORMAT", "channels_last" - ) - res = fn(self, **fn_args_and_kwargs) - if DATA_FORMAT == "TF" and conv_block_continued or DATA_FORMAT == "PT": - return res - if len(res.shape) > 4: - transpose = tensorflow_TransposeType.CONV3D - elif len(res.shape) > 3: - transpose = tensorflow_TransposeType.CONV2D - elif len(res.shape) > 2: - transpose = tensorflow_TransposeType.CONV1D - else: - transpose = tensorflow_TransposeType.NO_TRANSPOSE - res = tensorflow_apply_transpose(res, transpose=transpose, pt_to_tf=False) - DATA_FORMAT = "PT" - os.environ = tensorflow_set_item_bknd( - os.environ, 
"DATA_FORMAT", "channels_first" - ) - return res - - tensorflow_handle_transpose_in_input_and_output.__signature__ = original_signature - return transpose_wrapper - - def tensorflow_ndim_bknd_(self): return len(tuple(self.shape)) diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_Conv2d_output/run_0/__init__.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_Conv2d_output/run_0/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_Conv2d_output/run_0/tensorflow_CallVisitor.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_Conv2d_output/run_0/tensorflow_CallVisitor.py deleted file mode 100644 index 1e99977bd5b7..000000000000 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_Conv2d_output/run_0/tensorflow_CallVisitor.py +++ /dev/null @@ -1,13 +0,0 @@ -import ast - -from .tensorflow__helpers import tensorflow_store_config_info - - -class tensorflow_CallVisitor(ast.NodeVisitor): - @tensorflow_store_config_info - def __init__(self): - self.func_name = None - - def visit_Call(self, node): - self.func_name = ast.unparse(node.func).strip() - return super().generic_visit(node) diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_Conv2d_output/run_0/tensorflow_Conv2d.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_Conv2d_output/run_0/tensorflow_Conv2d.py deleted file mode 100644 index 7a6bbfbb0f29..000000000000 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_Conv2d_output/run_0/tensorflow_Conv2d.py +++ /dev/null @@ -1,64 +0,0 @@ -from .tensorflow__ConvNd import tensorflow__ConvNd -from .tensorflow__helpers import tensorflow__ntuple_parse -from .tensorflow__helpers import tensorflow_conv2d_frnt -from .tensorflow__helpers import tensorflow_handle_transpose_in_input_and_output -from .tensorflow__helpers import tensorflow_pad_frnt - -_pair = tensorflow__ntuple_parse(2, "_pair") - - -class tensorflow_Conv2d(tensorflow__ConvNd): - def __init__( - self, - in_channels, - out_channels, - kernel_size, - stride=1, - padding=0, - dilation=1, - groups=1, - bias=True, - padding_mode="zeros", - device=None, - dtype=None, - ): - factory_kwargs = {"device": device, "dtype": dtype} - kernel_size_ = _pair(kernel_size) - stride_ = _pair(stride) - padding_ = padding if isinstance(padding, str) else _pair(padding) - dilation_ = _pair(dilation) - super().__init__( - in_channels, - out_channels, - kernel_size_, - stride_, - padding_, - dilation_, - False, - _pair(0), - groups, - bias, - padding_mode, - **factory_kwargs, - ) - - def _conv_forward(self, input, weight, bias): - if self.padding_mode != "zeros": - return tensorflow_conv2d_frnt( - tensorflow_pad_frnt( - input, self._reversed_padding_repeated_twice, mode=self.padding_mode - ), - weight, - bias, - self.stride, - _pair(0), - self.dilation, - self.groups, - ) - return tensorflow_conv2d_frnt( - input, weight, bias, self.stride, self.padding, self.dilation, self.groups - ) - - @tensorflow_handle_transpose_in_input_and_output - def call(self, input): - return self._conv_forward(input, self.weight, self.bias) diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_Conv2d_output/run_0/tensorflow_NestedSequence_bknd.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_Conv2d_output/run_0/tensorflow_NestedSequence_bknd.py deleted file mode 100644 index ac8335fe1e56..000000000000 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_Conv2d_output/run_0/tensorflow_NestedSequence_bknd.py +++ /dev/null @@ -1,10 +0,0 @@ -from typing 
import TypeVar -from typing import Protocol - -_T_co = TypeVar("_T_co", covariant=True) - - -class tensorflow_NestedSequence_bknd(Protocol[_T_co]): - def __getitem__(self, key: int, /): ... - - def __len__(self, /): ... diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_Conv2d_output/run_0/tensorflow_TransposeType.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_Conv2d_output/run_0/tensorflow_TransposeType.py deleted file mode 100644 index f380aaf0d6e0..000000000000 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_Conv2d_output/run_0/tensorflow_TransposeType.py +++ /dev/null @@ -1,8 +0,0 @@ -from enum import Enum - - -class tensorflow_TransposeType(Enum): - NO_TRANSPOSE = "no_transpose" - CONV1D = "conv1d" - CONV2D = "conv2d" - CONV3D = "conv3d" diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_Conv2d_output/run_0/tensorflow__ConvNd.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_Conv2d_output/run_0/tensorflow__ConvNd.py deleted file mode 100644 index a866ab7b8f83..000000000000 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_Conv2d_output/run_0/tensorflow__ConvNd.py +++ /dev/null @@ -1,541 +0,0 @@ -import tensorflow -from collections import OrderedDict - -import typing -import math -from typing import Optional - -from .tensorflow__stateful import Layer as tensorflow_keras_Layer -from .tensorflow__helpers import tensorflow__calculate_fan_in_and_fan_out -from .tensorflow__helpers import tensorflow__is_variable_bknd -from .tensorflow__helpers import tensorflow__reverse_repeat_tuple -from .tensorflow__helpers import tensorflow_add_frnt_ -from .tensorflow__helpers import tensorflow_default_bknd -from .tensorflow__helpers import tensorflow_empty_frnt -from .tensorflow__helpers import tensorflow_kaiming_uniform_ -from .tensorflow__helpers import tensorflow_set_item_bknd -from .tensorflow__helpers import tensorflow_split_frnt_ -from .tensorflow__helpers import tensorflow_store_config_info -from .tensorflow__helpers import tensorflow_uniform_ - - -class tensorflow__ConvNd(tensorflow_keras_Layer): - __constants__ = [ - "stride", - "padding", - "dilation", - "groups", - "padding_mode", - "output_padding", - "in_channels", - "out_channels", - "kernel_size", - ] - __annotations__ = {"bias": Optional[tensorflow.Variable]} - - def _conv_forward(self, input, weight, bias): ... 
- - in_channels: typing.Any - _reversed_padding_repeated_twice: typing.Any - out_channels: typing.Any - kernel_size: typing.Any - stride: typing.Any - padding: typing.Any - dilation: typing.Any - transposed: typing.Any - output_padding: typing.Any - groups: typing.Any - padding_mode: typing.Any - weight: typing.Any - bias: typing.Any - - @tensorflow_store_config_info - def __init__( - self, - in_channels, - out_channels, - kernel_size, - stride, - padding, - dilation, - transposed, - output_padding, - groups, - bias, - padding_mode, - device=None, - dtype=None, - ): - factory_kwargs = {"device": device, "dtype": dtype} - self.super___init__( - in_channels, - out_channels, - kernel_size, - stride, - padding, - dilation, - transposed, - output_padding, - groups, - bias, - padding_mode, - device=device, - dtype=dtype, - v=getattr(self, "_v", None), - buffers=getattr(self, "_buffers", None), - module_dict=getattr(self, "_module_dict", None), - ) - if groups <= 0: - raise ValueError("groups must be a positive integer") - if in_channels % groups != 0: - raise ValueError("in_channels must be divisible by groups") - if out_channels % groups != 0: - raise ValueError("out_channels must be divisible by groups") - valid_padding_strings = {"same", "valid"} - if isinstance(padding, str): - if padding not in valid_padding_strings: - raise ValueError( - f"Invalid padding string {padding!r}, should be one of {valid_padding_strings}" - ) - if padding == "same" and any(s != 1 for s in stride): - raise ValueError( - "padding='same' is not supported for strided convolutions" - ) - valid_padding_modes = {"zeros", "reflect", "replicate", "circular"} - if padding_mode not in valid_padding_modes: - raise ValueError( - f"padding_mode must be one of {valid_padding_modes}, but got padding_mode='{padding_mode}'" - ) - self.in_channels = in_channels - self.out_channels = out_channels - self.kernel_size = kernel_size - self.stride = stride - self.padding = padding - self.dilation = dilation - self.transposed = transposed - self.output_padding = output_padding - self.groups = groups - self.padding_mode = padding_mode - if isinstance(self.padding, str): - self._reversed_padding_repeated_twice = [0, 0] * len(kernel_size) - if padding == "same": - for d, k, i in zip( - dilation, kernel_size, range(len(kernel_size) - 1, -1, -1) - ): - total_padding = d * (k - 1) - left_pad = total_padding // 2 - with tensorflow.name_scope("_reversed_padding_repeated_twice"): - self._reversed_padding_repeated_twice = ( - tensorflow_set_item_bknd( - self._reversed_padding_repeated_twice, 2 * i, left_pad - ) - ) - with tensorflow.name_scope("_reversed_padding_repeated_twice"): - self._reversed_padding_repeated_twice = ( - tensorflow_set_item_bknd( - self._reversed_padding_repeated_twice, - 2 * i + 1, - total_padding - left_pad, - ) - ) - else: - with tensorflow.name_scope("_reversed_padding_repeated_twice"): - self._reversed_padding_repeated_twice = ( - tensorflow__reverse_repeat_tuple(self.padding, 2) - ) - if transposed: - self.weight = tensorflow.Variable( - tensorflow_empty_frnt( - (*kernel_size, out_channels // groups, in_channels), - **factory_kwargs, - ), - name="weight", - ) - else: - self.weight = tensorflow.Variable( - tensorflow_empty_frnt( - (*kernel_size, in_channels // groups, out_channels), - **factory_kwargs, - ), - name="weight", - ) - if bias: - self.bias = tensorflow.Variable( - tensorflow_empty_frnt(out_channels, **factory_kwargs), name="bias" - ) - else: - self.register_parameter("bias", None) - self.reset_parameters() - - 
def reset_parameters(self): - tensorflow_kaiming_uniform_(self.weight, a=math.sqrt(5)) - if self.bias is not None: - with tensorflow.name_scope(""): - fan_in, _ = tensorflow__calculate_fan_in_and_fan_out(self.weight) - if fan_in != 0: - bound = 1 / math.sqrt(fan_in) - tensorflow_uniform_(self.bias, -bound, bound) - - def extra_repr(self): - s = "{in_channels}, {out_channels}, kernel_size={kernel_size}, stride={stride}" - if self.padding != (0,) * len(self.padding): - s = s + ", padding={padding}" - if self.dilation != (1,) * len(self.dilation): - s = s + ", dilation={dilation}" - if self.output_padding != (0,) * len(self.output_padding): - s = s + ", output_padding={output_padding}" - if self.groups != 1: - s = s + ", groups={groups}" - if self.bias is None: - s = s + ", bias=False" - if self.padding_mode != "zeros": - s = s + ", padding_mode={padding_mode}" - return s.format(**self.__dict__) - - def __setstate__(self, state): - super().__setstate__(state) - if not hasattr(self, "padding_mode"): - self.padding_mode = "zeros" - - def super___init__(self, *args, device=None, devices=None, **kwargs): - super().__init__( - *args, - device=device, - devices=devices, - training=True, - build_mode="explicit", - dynamic_backend=True, - **kwargs, - ) - super().__setattr__("_frontend_module", True) - super().__setattr__( - "_attr_mapping", {"_parameters": "v", "_modules": "module_dict"} - ) - - def __dir__(self): - module_attrs = dir(self.__class__) - attrs = list(self.__dict__.keys()) - parameters = list(self._v.keys()) - modules = list(self._module_dict.keys()) - buffers = list(self._buffers.keys()) - keys = module_attrs + attrs + parameters + modules + buffers - ag__result_list_0 = [] - for key in keys: - if not key[0].isdigit(): - res = key - ag__result_list_0.append(res) - keys = ag__result_list_0 - return sorted(keys) - - def __getattribute__(self, name): - if name == "__dict__": - return super().__getattribute__(name) - if "_module_dict" in self.__dict__: - modules = self.__dict__["_module_dict"] - if name in modules: - return modules[name] - if "_buffers" in self.__dict__: - buffers = self.__dict__["_buffers"] - if name in buffers: - return buffers[name] - if "_v" in self.__dict__: - v = self.__dict__["_v"] - if name in v: - return v[name] - if "_attr_mapping" in self.__dict__: - mapping = self.__dict__["_attr_mapping"] - if name in mapping: - return super().__getattribute__(mapping[name]) - return super().__getattribute__(name) - - def __getstate__(self): - state = self.__dict__.copy() - state.pop("_compiled_call_impl", None) - state.pop("_thread_local", None) - state.pop("_metrics_lock", None) - return state - - def __repr__(self): - extra_lines = [] - extra_repr = self._extra_repr() - if extra_repr: - with tensorflow.name_scope("extra_lines"): - extra_lines = tensorflow_split_frnt_(extra_repr, "\n") - child_lines = [] - for key, module in self._module_dict.items(): - mod_str = repr(module) - mod_str = self._addindent(mod_str, 2) - child_lines.append("(" + key + "): " + mod_str) - lines = extra_lines + child_lines - main_str = self._get_name() + "(" - if lines: - if len(extra_lines) == 1 and not child_lines: - main_str += extra_lines[0] - else: - main_str += "\n " + "\n ".join(lines) + "\n" - main_str += ")" - return main_str - - def __setattr__(self, name, value): - def remove_from(*dicts_or_sets): - for d in dicts_or_sets: - if name in d: - if isinstance(d, dict): - del d[name] - else: - d.discard(name) - - params = self.__dict__.get("_v") - if ( - params is not None - and name in params 
- and isinstance(value, tensorflow.Variable) - ): - remove_from(self.__dict__, self._buffers, self._module_dict) - self.register_parameter(name, value) - super().__setattr__(name, value) - else: - super().__setattr__(name, value) - - def _build(self, *args, **kwargs): - for module in self.__dict__.values(): - if isinstance(module, tensorflow_keras_Layer) and module is not self: - if not module._built: - module.build( - *module._args, - dynamic_backend=module._dynamic_backend, - **module._kwargs, - ) - return True - - def _call_impl(self, *args, **kwargs): - return self.call(*args, **kwargs) - - def _create_variables(self, device=None, dtype=None): - with tensorflow.name_scope("v"): - v = dict( - OrderedDict( - [ - (k.replace(".", "/"), v) - for k, v in self.__dict__.items() - if isinstance(v, tensorflow.Variable) and not k.startswith("_") - ] - ) - ) - v = ( - dict( - OrderedDict( - { - _k.replace(".", "/"): _v - for _k, _v in self._v.items() - if _k.replace(".", "/") not in v and not isinstance(_v, dict) - }, - **v, - ) - ) - if self._v - else v - ) - return v - - def _extra_repr(self): - return "" - - def _forward(self, *a, **kw): - ret = self._call_impl(*a, **kw) - return ret - - def _get_name(self): - return self.__class__.__name__ - - def _named_members( - self, get_members_fn, prefix="", recurse=True, remove_duplicate=True - ): - memo = set() - modules = ( - self.named_modules(prefix=prefix, remove_duplicate=remove_duplicate) - if recurse - else [(prefix, self)] - ) - for module_prefix, module in modules: - members = get_members_fn(module) - for k, v in members: - if v is None or id(v) in memo: - continue - if remove_duplicate: - tensorflow_add_frnt_(memo, id(v)) - name = module_prefix + ("." if module_prefix else "") + k - yield name, v - - def _replace_update_v(self, new_v, native=None): - with tensorflow.name_scope("native"): - native = tensorflow_default_bknd(native, self) - for k, v in new_v.items(): - if isinstance(v, dict): - native.module_dict[k] = self._replace_update_v(v, native.module_dict[k]) - elif isinstance(v, tensorflow.Variable): - native.__setattr__(k, v) - elif tensorflow__is_variable_bknd(v): - native.__setattr__(k, tensorflow.Variable(v)) - elif isinstance(v, tensorflow.Variable): - native.__setattr__(k, tensorflow.Variable(v)) - else: - raise Exception( - f"found item in variable container {v} which was neither a sub ivy.Container nor a variable." - ) - return native - - def _update_v(self, new_v, native=None): - with tensorflow.name_scope("native"): - native = tensorflow_default_bknd(native, self) - for k, v in new_v.items(): - if isinstance(v, dict): - native.module_dict[k] = self._replace_update_v(v, native.module_dict[k]) - elif isinstance(v, tensorflow.Variable): - native.__setattr__(k, v) - elif tensorflow__is_variable_bknd(v): - native.__setattr__(k, tensorflow.Variable(v)) - elif isinstance(v, tensorflow.Variable): - native.__setattr__(k, tensorflow.Variable(v)) - else: - raise Exception( - f"found item in variable container {v} which was neither a sub ivy.Container nor a variable." - ) - return native - - def add_module(self, name, module): - if ( - not isinstance( - module, (tensorflow_keras_Layer, tensorflow.keras.layers.Layer) - ) - and module is not None - ): - raise TypeError(f"{type(module)} is not a Module subclass") - elif not isinstance(name, str): - raise TypeError(f"module name should be a string. Got {type(name)}") - elif hasattr(self, name) and name not in self._modules: - raise KeyError(f"attribute '{name}' already exists") - elif "." 
in name: - raise KeyError(f'module name can\'t contain ".", got: {name}') - elif name == "": - raise KeyError('module name can\'t be empty string ""') - self._modules[name] = module - super().__setattr__(name, module) - - def apply(self, fn): - for module in self.children(): - if hasattr(module, "apply"): - module.apply(fn) - else: - fn(module) - fn(self) - return self - - def children(self): - for _, module in self.named_children(): - yield module - - def call(self, *input): - raise NotImplementedError( - f'Module [{type(self).__name__}] is missing the required "forward" function' - ) - - def get_parameter(self, target): - target = target.replace(".", "/") - return self.pt_v[target] - - def get_submodule(self, target): - if target == "": - return self - atoms: typing.Any = tensorflow_split_frnt_(target, ".") - mod: typing.Any = self - for item in atoms: - if not hasattr(mod, item): - raise AttributeError( - mod._get_name() + " has no attribute `" + item + "`" - ) - mod = getattr(mod, item) - if not isinstance(mod, tensorflow_keras_Layer): - raise TypeError("`" + item + "` is not an nn.Module") - return mod - - def modules(self): - for _, module in self.named_modules(): - yield module - - def named_buffers(self, prefix="", recurse=True, remove_duplicate=True): - if not getattr(self, "_built", False): - self.build( - *self._args, dynamic_backend=self._dynamic_backend, **self._kwargs - ) - gen = self._named_members( - lambda module: module.buffers.items(), - prefix=prefix, - recurse=recurse, - remove_duplicate=remove_duplicate, - ) - yield from gen - - def named_children(self): - if not getattr(self, "_built", False): - self.build( - *self._args, dynamic_backend=self._dynamic_backend, **self._kwargs - ) - memo = set() - for name, module in self._module_dict.items(): - if module is not None and id(module) not in memo: - tensorflow_add_frnt_(memo, id(module)) - yield name, module - - def named_modules(self, memo=None, prefix="", remove_duplicate=True): - if not getattr(self, "_built", False): - self.build( - *self._args, dynamic_backend=self._dynamic_backend, **self._kwargs - ) - if memo is None: - memo = set() - if id(self) not in memo: - if remove_duplicate: - tensorflow_add_frnt_(memo, id(self)) - yield prefix, self - for name, module in self._module_dict.items(): - if module is None: - continue - submodule_prefix = prefix + ("." 
if prefix else "") + name - if not hasattr(module, "named_modules"): - yield submodule_prefix, self - else: - yield from module.named_modules( - memo, submodule_prefix, remove_duplicate - ) - - def named_parameters(self, prefix="", recurse=True, remove_duplicate=True): - if not getattr(self, "_built", False): - self.build( - *self._args, dynamic_backend=self._dynamic_backend, **self._kwargs - ) - gen = self._named_members( - lambda module: module.v.items(), - prefix=prefix, - recurse=recurse, - remove_duplicate=remove_duplicate, - ) - yield from gen - - def parameters(self, recurse=True): - for _, param in self.named_parameters(recurse=recurse): - yield param - - def register_buffer(self, name, value, persistent=False): - super().register_buffer(name, value) - - def register_module(self, name, module): - self.add_module(name, module) - - def register_parameter(self, name, value): - super().register_parameter(name, value) - - def requires_grad_(self, requires_grad=True): - for p in self.parameters(): - p.requires_grad_(requires_grad) - return self diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_Conv2d_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_Conv2d_output/run_0/tensorflow__helpers.py deleted file mode 100644 index 7f665606a6a6..000000000000 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_Conv2d_output/run_0/tensorflow__helpers.py +++ /dev/null @@ -1,3734 +0,0 @@ -from collections import UserDict -from itertools import repeat -from ivy.utils.backend import backend_stack -from numbers import Number -from numpy.core.numeric import normalize_axis_tuple -from operator import mul -from .tensorflow_NestedSequence_bknd import tensorflow_NestedSequence_bknd -from typing import Any -from typing import Callable -from typing import Dict -from typing import Iterable -from typing import List -from typing import Literal -from typing import Optional -from typing import Sequence -from typing import Tuple -from typing import TypeVar -from typing import Union -import ast -import collections -import copy -import functools -import inspect -import itertools -import math -import numpy as np -import os -import re -import tensorflow -import tensorflow as tf -import warnings - - -promotion_table = { - ("bool", "bool"): "bool", - ("int8", "int8"): "int8", - ("int8", "int16"): "int16", - ("int8", "int32"): "int32", - ("int8", "int64"): "int64", - ("int16", "int16"): "int16", - ("int16", "int32"): "int32", - ("int16", "int64"): "int64", - ("int32", "int32"): "int32", - ("int32", "int64"): "int64", - ("int64", "int64"): "int64", - ("uint8", "int8"): "int16", - ("uint8", "int16"): "int16", - ("uint8", "int32"): "int32", - ("uint8", "int64"): "int64", - ("uint8", "uint8"): "uint8", - ("uint8", "uint16"): "uint16", - ("uint8", "uint32"): "uint32", - ("uint8", "uint64"): "uint64", - ("uint16", "int8"): "int32", - ("uint16", "int16"): "int32", - ("uint16", "int32"): "int32", - ("uint16", "int64"): "int64", - ("uint16", "uint16"): "uint16", - ("uint16", "uint32"): "uint32", - ("uint16", "uint64"): "uint64", - ("uint32", "int8"): "int64", - ("uint32", "int16"): "int64", - ("uint32", "int32"): "int64", - ("uint32", "int64"): "int64", - ("uint32", "uint32"): "uint32", - ("uint32", "uint64"): "uint64", - ("uint64", "uint64"): "uint64", - ("float16", "float16"): "float16", - ("float16", "float32"): "float32", - ("float16", "float64"): "float64", - ("float32", "float32"): "float32", - ("float32", "float64"): "float64", - ("float64", "float64"): "float64", - 
("bool", "int8"): "int8", - ("bool", "int16"): "int16", - ("bool", "int32"): "int32", - ("bool", "int64"): "int64", - ("bool", "uint8"): "uint8", - ("bool", "uint16"): "uint16", - ("bool", "uint32"): "uint32", - ("bool", "uint64"): "uint64", - ("bool", "float16"): "float16", - ("bool", "float32"): "float32", - ("bool", "float64"): "float64", - ("bool", "bfloat16"): "bfloat16", - ("bool", "complex64"): "complex64", - ("bool", "complex128"): "complex128", - ("int8", "float16"): "float16", - ("int8", "float32"): "float32", - ("int8", "float64"): "float64", - ("int8", "bfloat16"): "bfloat16", - ("int8", "complex64"): "complex64", - ("int8", "complex128"): "complex128", - ("int16", "float32"): "float32", - ("int16", "float64"): "float64", - ("int16", "complex64"): "complex64", - ("int16", "complex128"): "complex128", - ("int32", "float64"): "float64", - ("int32", "complex128"): "complex128", - ("int64", "float64"): "float64", - ("int64", "complex128"): "complex128", - ("uint8", "float16"): "float16", - ("uint8", "float32"): "float32", - ("uint8", "float64"): "float64", - ("uint8", "bfloat16"): "bfloat16", - ("uint8", "complex64"): "complex64", - ("uint8", "complex128"): "complex128", - ("uint16", "float32"): "float32", - ("uint16", "float64"): "float64", - ("uint16", "complex64"): "complex64", - ("uint16", "complex128"): "complex128", - ("uint32", "float64"): "float64", - ("uint32", "complex128"): "complex128", - ("uint64", "int8"): "float64", - ("uint64", "int16"): "float64", - ("uint64", "int32"): "float64", - ("uint64", "int64"): "float64", - ("uint64", "float64"): "float64", - ("uint64", "complex128"): "complex128", - ("float16", "bfloat16"): "float32", - ("float16", "complex64"): "complex64", - ("float16", "complex128"): "complex128", - ("float32", "complex64"): "complex64", - ("float32", "complex128"): "complex128", - ("float64", "complex64"): "complex128", - ("float64", "complex128"): "complex128", - ("bfloat16", "float16"): "float32", - ("bfloat16", "float32"): "float32", - ("bfloat16", "float64"): "float64", - ("bfloat16", "bfloat16"): "bfloat16", - ("bfloat16", "complex64"): "complex64", - ("bfloat16", "complex128"): "complex128", - ("complex64", "float64"): "complex128", - ("complex64", "complex64"): "complex64", - ("complex64", "complex128"): "complex128", - ("complex128", "complex128"): "complex128", - ("float16", "int16"): "float32", - ("float16", "int32"): "float64", - ("float16", "int64"): "float64", - ("float16", "uint16"): "float32", - ("float16", "uint32"): "float64", - ("float16", "uint64"): "float64", - ("float32", "int32"): "float64", - ("float32", "int64"): "float64", - ("float32", "uint32"): "float64", - ("float32", "uint64"): "float64", - ("bfloat16", "int16"): "float32", - ("bfloat16", "int32"): "float64", - ("bfloat16", "int64"): "float64", - ("bfloat16", "uint16"): "float32", - ("bfloat16", "uint32"): "float64", - ("bfloat16", "uint64"): "float64", - ("complex64", "int32"): "complex128", - ("complex64", "int64"): "complex128", - ("complex64", "uint32"): "complex128", - ("complex64", "uint64"): "complex128", -} -array_api_promotion_table = { - ("bool", "bool"): "bool", - ("int8", "int8"): "int8", - ("int8", "int16"): "int16", - ("int8", "int32"): "int32", - ("int8", "int64"): "int64", - ("int16", "int16"): "int16", - ("int16", "int32"): "int32", - ("int16", "int64"): "int64", - ("int32", "int32"): "int32", - ("int32", "int64"): "int64", - ("int64", "int64"): "int64", - ("uint8", "int8"): "int16", - ("uint8", "int16"): "int16", - ("uint8", "int32"): "int32", - 
("uint8", "int64"): "int64", - ("uint8", "uint8"): "uint8", - ("uint8", "uint16"): "uint16", - ("uint8", "uint32"): "uint32", - ("uint8", "uint64"): "uint64", - ("uint16", "int8"): "int32", - ("uint16", "int16"): "int32", - ("uint16", "int32"): "int32", - ("uint16", "int64"): "int64", - ("uint16", "uint16"): "uint16", - ("uint16", "uint32"): "uint32", - ("uint16", "uint64"): "uint64", - ("uint32", "int8"): "int64", - ("uint32", "int16"): "int64", - ("uint32", "int32"): "int64", - ("uint32", "int64"): "int64", - ("uint32", "uint32"): "uint32", - ("uint32", "uint64"): "uint64", - ("uint64", "uint64"): "uint64", - ("float16", "float16"): "float16", - ("float16", "float32"): "float32", - ("float16", "float64"): "float64", - ("float32", "float32"): "float32", - ("float32", "float64"): "float64", - ("float64", "float64"): "float64", -} -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] -backend_stack = [] -CONV_FUNCS = [ - "Conv1d", - "Conv2d", - "Conv3d", - "ConvTranspose1d", - "ConvTranspose2d", - "ConvTranspose3d", -] -NORM_FUNCS = [ - "_BatchNorm", - "_InstanceNorm", - "BatchNorm1d", - "BatchNorm2d", - "BatchNorm3d", - "GroupNorm", - "SyncBatchNorm", - "InstanceNorm1d", - "InstanceNorm2d", - "InstanceNorm3d", - "LocalResponseNorm", -] -POOL_FUNCS = [ - "MaxPool1d", - "MaxPool2d", - "MaxPool3d", - "AvgPool1d", - "AvgPool2d", - "AvgPool3d", - "FractionalMaxPool2d", - "LPPool1d", - "LPPool2d", - "AdaptiveMaxPool1d", - "AdaptiveMaxPool2d", - "AdaptiveMaxPool3d", - "AdaptiveAvgPool1d", - "AdaptiveAvgPool2d", - "AdaptiveAvgPool3d", -] -KERAS_CONV_FUNCS = [ - "KerasConv1D", - "KerasConv2D", - "KerasConv3D", - "KerasDepthwiseConv2D", - "KerasConv1DTranspose", - "KerasConv2DTranspose", - "KerasConv3DTranspose", -] -KERAS_NORM_FUNCS = [ - "KerasBatchNorm1D", - "KerasBatchNorm2D", - "KerasBatchNorm3D", - "KerasLayerNormalization", - "KerasGroupNormalization", - "KerasUnitNorm1D", - "KerasUnitNorm2D", - "KerasUnitNorm3D", -] -KERAS_POOL_FUNCS = [ - "KerasAveragePooling1D", - "KerasAveragePooling2D", - "KerasAveragePooling3D", - "KerasMaxPool1D", - "KerasMaxPool2D", - "KerasMaxPool3D", -] -PADDING_FUNCS = [ - "ReflectionPad1d", - "ReflectionPad2d", - "ReplicationPad1d", - "ReplicationPad2d", - "ReplicationPad3d", - "ZeroPad2d", - "ConstantPad1d", - "ConstantPad2d", - 
"ConstantPad3d", -] -KERAS_PADDING_FUNCS = ["KerasZeroPadding1D", "KerasZeroPadding2D", "KerasZeroPadding3D"] -ACTIVATION_FUNCS = [ - "ELU", - "Hardshrink", - "Hardsigmoid", - "Hardswish", - "Hardtanh", - "LeakyReLU", - "PReLU", - "ReLU", - "ReLU6", - "RReLU", - "SELU", - "CELU", - "GELU", - "Sigmoid", - "Softplus", - "Softshrink", - "Softsign", - "Tanh", - "Tanhshrink", - "Threshold", - "Softmin", - "Softmax", - "Softmax2d", - "LogSoftmax", - "AdaptiveLogSoftmaxWithLoss", -] -KERAS_ACTIVATION_FUNCS = [ - "KerasReLU", - "KerasPReLU", - "KerasLeakyReLU", - "KerasThresholdedReLU", - "KerasELU", - "KerasSoftmax", -] -DROPOUT_FUNCS = [ - "Dropout", - "Dropout2d", - "Dropout3d", - "AlphaDropout", - "FeatureAlphaDropout", -] -KERAS_DROPOUT_FUNCS = ["KerasDropout"] -CONV_BLOCK_FNS = [ - *CONV_FUNCS, - *KERAS_CONV_FUNCS, - *POOL_FUNCS, - *KERAS_POOL_FUNCS, - *PADDING_FUNCS, - *KERAS_PADDING_FUNCS, - *ACTIVATION_FUNCS, - *KERAS_ACTIVATION_FUNCS, - *NORM_FUNCS, - *KERAS_NORM_FUNCS, - *DROPOUT_FUNCS, - *KERAS_DROPOUT_FUNCS, -] -DATA_FORMAT = "PT" - - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion - - -def tensorflow__ntuple_parse(n, name="parse"): - def parse(x): - if isinstance(x, collections.abc.Iterable): - return tuple(x) - return tuple(repeat(x, n)) - - parse.__name__ = name - return parse - - -def tensorflow_is_native_array(x, /, *, exclusive=False): - if "keras.src.backend.tensorflow.core.Variable" in str(x.__class__): - return not exclusive - if isinstance(x, (tensorflow.Tensor, tensorflow.Variable, tensorflow.TensorArray)): - if exclusive and isinstance(x, tensorflow.Variable): - return False - return True - return False - - -def tensorflow_is_ivy_array_bknd( - x: Union[tensorflow.Tensor, tf.Tensor], /, *, exclusive: Optional[bool] = False -): - return isinstance(x, tensorflow.Tensor) and tensorflow_is_native_array( - x, exclusive=exclusive - ) - - -def tensorflow_is_array_bknd(x: Any, /, *, exclusive: bool = False): - return tensorflow_is_ivy_array_bknd( - x, exclusive=exclusive - ) or tensorflow_is_native_array(x, exclusive=exclusive) - - -def tensorflow_exists_bknd(x: Any, /): - return x is not None - - -def 
tensorflow_default_bknd( - x: Any, - /, - default_val: Any, - *, - catch_exceptions: bool = False, - rev: bool = False, - with_callable: bool = False, -): - with_callable = catch_exceptions or with_callable - if rev: - x, default_val = default_val, x - if with_callable: - x_callable = callable(x) - default_callable = callable(default_val) - else: - x_callable = False - default_callable = False - if catch_exceptions: - try: - x = x() if x_callable else x - except Exception: - return default_val() if default_callable else default_val - else: - x = x() if x_callable else x - return ( - x - if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val - ) - - -def tensorflow_nested_argwhere_bknd( - nest: Iterable, - fn: Callable, - check_nests: bool = False, - to_ignore: Optional[Union[type, Tuple[type]]] = None, - _index: Optional[List] = None, - _base: bool = True, - stop_after_n_found: Optional[int] = None, -): - to_ignore = tensorflow_default_bknd(to_ignore, ()) - _index = [] if _index is None else _index - if isinstance(nest, (tuple, list)) and not isinstance(nest, to_ignore): - n = 0 - _indices = [] - for i, item in enumerate(nest): - ind = ( - tensorflow_nested_argwhere_bknd( - item, - fn, - check_nests, - to_ignore, - _index + [i], - False, - stop_after_n_found - n, - ) - if stop_after_n_found is not None - else tensorflow_nested_argwhere_bknd( - item, fn, check_nests, to_ignore, _index + [i], False, None - ) - ) - if stop_after_n_found is not None and ind: - if n >= stop_after_n_found: - break - n = n + len(ind) - _indices = _indices + [ind] - if stop_after_n_found is not None and n >= stop_after_n_found: - break - _indices = [idx for idxs in _indices if idxs for idx in idxs] - if check_nests and fn(nest): - _indices.append(_index) - elif isinstance(nest, (dict, UserDict)) and not isinstance(nest, to_ignore): - n = 0 - _indices = [] - for k, v in nest.items(): - ind = ( - tensorflow_nested_argwhere_bknd( - v, - fn, - check_nests, - to_ignore, - _index + [k], - False, - stop_after_n_found - n, - ) - if stop_after_n_found is not None - else tensorflow_nested_argwhere_bknd( - v, fn, check_nests, to_ignore, _index + [k], False, None - ) - ) - if stop_after_n_found is not None and ind: - if n >= stop_after_n_found: - break - n = n + len(ind) - _indices = _indices + [ind] - _indices = [idx for idxs in _indices if idxs for idx in idxs] - if check_nests and fn(nest): - _indices.append(_index) - else: - cond_met = fn(nest) - if cond_met: - return [_index] - return False - return [index for index in _indices if index] - - -def tensorflow__check_float64_bknd(input): - if tensorflow_is_array_bknd(input): - return tensorflow_dtype(input) == "float64" - if math.isfinite(input): - m, e = math.frexp(input) - return abs(input) > 3.4028235e38 or e < -126 or e > 128 - return False - - -def tensorflow_as_ivy_dtype_bknd(dtype_in: Union[str, str], /): - return tensorflow_as_ivy_dtype(dtype_in) - - -def tensorflow_is_complex_dtype_bknd( - dtype_in: Union[str, str, tensorflow.Tensor, tf.Tensor, Number], / -): - if tensorflow_is_array_bknd(dtype_in): - dtype_in = tensorflow_dtype(dtype_in) - elif isinstance(dtype_in, tuple): - dtype_in = tensorflow_default_int_dtype_bknd() - elif isinstance(dtype_in, np.ndarray): - return "complex" in dtype_in.dtype.name - elif isinstance(dtype_in, Number): - return isinstance(dtype_in, (complex, np.complexfloating)) - elif isinstance(dtype_in, (list, tuple, dict)): - return tensorflow_nested_argwhere_bknd( - dtype_in, - lambda x: isinstance(x, 
(complex, np.complexfloating)) - or tensorflow_is_array_bknd(x) - and "complex" in tensorflow_dtype(x), - ) - return "complex" in tensorflow_as_ivy_dtype_bknd(dtype_in) - - -def tensorflow_real( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - return tensorflow.math.real(x) - - -def tensorflow_real_bknd_(self): - return tensorflow_real(self) - - -def tensorflow_imag( - val: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - return tensorflow.math.imag(val, name=None) - - -def tensorflow_imag_bknd_(self): - return tensorflow_imag(self) - - -def tensorflow__check_complex128_bknd(input): - if tensorflow_is_array_bknd(input): - return tensorflow_dtype(input) == "complex128" - elif isinstance(input, np.ndarray): - return str(input.dtype) == "complex128" - if hasattr(input, "real") and hasattr(input, "imag"): - return tensorflow__check_float64_bknd( - tensorflow_real_bknd_(input) - ) and tensorflow__check_float64_bknd(tensorflow_imag_bknd_(input)) - return False - - -def tensorflow_default_complex_dtype_bknd( - *, - input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, - complex_dtype: Optional[Union[str, tf.DType]] = None, - as_native: bool = False, -): - global default_complex_dtype_stack - if tensorflow_exists_bknd(complex_dtype): - if as_native is True: - return tensorflow_as_native_dtype(complex_dtype) - return str(tensorflow_as_ivy_dtype(complex_dtype)) - as_native = tensorflow_default_bknd(as_native, False) - if tensorflow_exists_bknd(input): - if tensorflow_is_array_bknd(input): - ret = tensorflow_dtype(input) - elif isinstance(input, np.ndarray): - ret = str(input.dtype) - elif isinstance(input, (list, tuple, dict)): - if tensorflow_nested_argwhere_bknd( - input, - lambda x: tensorflow__check_complex128_bknd(x), - stop_after_n_found=1, - ): - ret = tf.complex128 - elif not default_complex_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_complex_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "complex64" - else: - ret = default_complex_dtype_stack[-1] - elif isinstance(input, Number): - if tensorflow__check_complex128_bknd(input): - ret = tf.complex128 - elif not default_complex_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_complex_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "complex64" - else: - ret = default_complex_dtype_stack[-1] - elif not default_complex_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_complex_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "complex64" - else: - ret = default_complex_dtype_stack[-1] - if as_native: - return tensorflow_as_native_dtype(ret) - return str(tensorflow_as_ivy_dtype(ret)) - - -def tensorflow_is_float_dtype_bknd( - dtype_in: Union[str, str, tensorflow.Tensor, tf.Tensor, Number], / -): - if tensorflow_is_array_bknd(dtype_in): - dtype_in = tensorflow_dtype(dtype_in) - elif isinstance(dtype_in, tuple): - dtype_in = tensorflow_default_int_dtype_bknd() - elif isinstance(dtype_in, np.ndarray): - return "float" in dtype_in.dtype.name - elif isinstance(dtype_in, Number): - return isinstance(dtype_in, (float, np.floating)) - elif isinstance(dtype_in, (list, tuple, dict)): - return bool( - tensorflow_nested_argwhere_bknd( - dtype_in, - lambda x: isinstance(x, (float, np.floating)) - or tensorflow_is_array_bknd(x) - and "float" in tensorflow_dtype(x), - ) 
- ) - return "float" in tensorflow_as_ivy_dtype_bknd(dtype_in) - - -def tensorflow_is_uint_dtype_bknd( - dtype_in: Union[str, str, tensorflow.Tensor, tf.Tensor, Number], / -): - if tensorflow_is_array_bknd(dtype_in): - dtype_in = tensorflow_dtype(dtype_in) - elif isinstance(dtype_in, tuple): - dtype_in = tensorflow_default_int_dtype_bknd() - elif isinstance(dtype_in, np.ndarray): - return "uint" in dtype_in.dtype.name - elif isinstance(dtype_in, Number): - return isinstance(dtype_in, np.unsignedinteger) - elif isinstance(dtype_in, (list, tuple, dict)): - return tensorflow_nested_argwhere_bknd( - dtype_in, - lambda x: isinstance(x, np.unsignedinteger) - or tensorflow_is_array_bknd(x) - and "uint" in tensorflow_dtype(x), - ) - return "uint" in tensorflow_as_ivy_dtype_bknd(dtype_in) - - -def tensorflow_is_int_dtype_bknd( - dtype_in: Union[str, str, tensorflow.Tensor, tf.Tensor, Number], / -): - if tensorflow_is_array_bknd(dtype_in): - dtype_in = tensorflow_dtype(dtype_in) - elif isinstance(dtype_in, tuple): - dtype_in = tensorflow_default_int_dtype_bknd() - elif isinstance(dtype_in, np.ndarray): - return "int" in dtype_in.dtype.name - elif isinstance(dtype_in, Number): - return isinstance(dtype_in, (int, np.integer)) and not isinstance( - dtype_in, bool - ) - elif isinstance(dtype_in, (list, tuple, dict)): - - def nested_fun(x): - return ( - isinstance(x, (int, np.integer)) - or tensorflow_is_array_bknd(x) - and "int" in tensorflow_dtype(x) - ) and x is not bool - - return bool(tensorflow_nested_argwhere_bknd(dtype_in, nested_fun)) - return "int" in tensorflow_as_ivy_dtype(dtype_in) - - -def tensorflow_default_dtype_bknd( - *, - dtype: Optional[Union[str, str]] = None, - item: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, - as_native: bool = False, -): - if tensorflow_exists_bknd(dtype): - if as_native is True: - return tensorflow_as_native_dtype(dtype) - return tensorflow_as_ivy_dtype(dtype) - as_native = tensorflow_default_bknd(as_native, False) - if tensorflow_exists_bknd(item): - if hasattr(item, "override_dtype_check"): - return item.override_dtype_check() - elif isinstance(item, (list, tuple, dict)) and len(item) == 0: - pass - elif tensorflow_is_complex_dtype_bknd(item): - return tensorflow_default_complex_dtype_bknd( - input=item, as_native=as_native - ) - elif tensorflow_is_float_dtype_bknd(item): - return tensorflow_default_float_dtype_bknd(input=item, as_native=as_native) - elif tensorflow_is_uint_dtype_bknd(item): - return tensorflow_default_int_dtype_bknd(input=item, as_native=as_native) - elif tensorflow_is_int_dtype_bknd(item): - return tensorflow_default_int_dtype_bknd(input=item, as_native=as_native) - elif as_native: - return tensorflow_as_native_dtype("bool") - else: - return "bool" - global default_dtype_stack - if not default_dtype_stack: - global default_float_dtype_stack - if default_float_dtype_stack: - ret = default_float_dtype_stack[-1] - else: - ret = "float32" - else: - ret = default_dtype_stack[-1] - if as_native: - return tensorflow_as_native_dtype(ret) - return tensorflow_as_ivy_dtype(ret) - - -def tensorflow_default_float_dtype_bknd( - *, - input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, - float_dtype: Optional[Union[str, tf.DType]] = None, - as_native: bool = False, -): - global default_float_dtype_stack - if tensorflow_exists_bknd(float_dtype): - if as_native is True: - return tensorflow_as_native_dtype(float_dtype) - return str(tensorflow_as_ivy_dtype(float_dtype)) - as_native = tensorflow_default_bknd(as_native, False) - if 
tensorflow_exists_bknd(input): - if tensorflow_is_array_bknd(input): - ret = tensorflow_dtype(input) - elif isinstance(input, np.ndarray): - ret = str(input.dtype) - elif isinstance(input, (list, tuple, dict)): - if tensorflow_nested_argwhere_bknd( - input, lambda x: tensorflow__check_float64_bknd(x), stop_after_n_found=1 - ): - ret = tf.float64 - elif not default_float_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_float_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "float32" - else: - ret = default_float_dtype_stack[-1] - elif isinstance(input, Number): - if tensorflow__check_float64_bknd(input): - ret = tf.float64 - elif not default_float_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_float_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "float32" - else: - ret = default_float_dtype_stack[-1] - elif not default_float_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_float_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "float32" - else: - ret = default_float_dtype_stack[-1] - if as_native: - return tensorflow_as_native_dtype(ret) - return str(tensorflow_as_ivy_dtype(ret)) - - -def tensorflow_as_ivy_dtype( - dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / -): - if dtype_in is int: - return tensorflow_default_int_dtype_bknd() - if dtype_in is float: - return tensorflow_default_float_dtype_bknd() - if dtype_in is complex: - return tensorflow_default_complex_dtype_bknd() - if dtype_in is bool: - return str("bool") - if isinstance(dtype_in, np.dtype): - dtype_in = dtype_in.name - if isinstance(dtype_in, str): - if dtype_in in native_dtype_dict: - dtype_str = dtype_in - else: - raise Exception( - f"Cannot convert to ivy dtype. {dtype_in} is not supported by TensorFlow backend." 
- ) - else: - dtype_str = ivy_dtype_dict[dtype_in] - if "uint" in dtype_str: - return str(dtype_str) - elif "int" in dtype_str: - return str(dtype_str) - elif "float" in dtype_str: - return str(dtype_str) - elif "complex" in dtype_str: - return str(dtype_str) - elif "bool" in dtype_str: - return str("bool") - else: - raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") - - -def tensorflow_default_int_dtype_bknd( - *, - input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, - int_dtype: Optional[Union[str, tf.DType]] = None, - as_native: bool = False, -): - global default_int_dtype_stack - if tensorflow_exists_bknd(int_dtype): - if as_native is True: - return tensorflow_as_native_dtype(int_dtype) - return str(tensorflow_as_ivy_dtype(int_dtype)) - as_native = tensorflow_default_bknd(as_native, False) - if tensorflow_exists_bknd(input): - if tensorflow_is_array_bknd(input): - ret = tensorflow_dtype(input) - elif isinstance(input, tuple): - ret = tensorflow_default_int_dtype_bknd() - elif isinstance(input, np.ndarray): - ret = str(input.dtype) - elif isinstance(input, (list, tuple, dict)): - if tensorflow_nested_argwhere_bknd( - input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), - stop_after_n_found=1, - ): - ret = tf.uint64 - elif tensorflow_nested_argwhere_bknd( - input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), - stop_after_n_found=1, - ): - ret = tf.int64 - elif not default_int_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_int_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "int32" - else: - ret = default_int_dtype_stack[-1] - elif isinstance(input, Number): - if input > 9223372036854775807 and input != math.inf and backend != "torch": - ret = tf.uint64 - elif input > 2147483647 and input != math.inf: - ret = tf.int64 - elif not default_int_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_int_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "int32" - else: - ret = default_int_dtype_stack[-1] - elif not default_int_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_int_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "int32" - else: - ret = default_int_dtype_stack[-1] - if as_native: - return tensorflow_as_native_dtype(ret) - return str(tensorflow_as_ivy_dtype(ret)) - - -def tensorflow_as_native_dtype( - dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], -): - if dtype_in is int: - return tensorflow_default_int_dtype_bknd(as_native=True) - if dtype_in is float: - return tensorflow_default_float_dtype_bknd(as_native=True) - if dtype_in is complex: - return tensorflow_default_complex_dtype_bknd(as_native=True) - if dtype_in is bool: - return tensorflow.bool - if isinstance(dtype_in, np.dtype): - dtype_in = dtype_in.name - if not isinstance(dtype_in, str): - return dtype_in - if dtype_in in native_dtype_dict: - return native_dtype_dict[str(dtype_in)] - else: - raise Exception( - f"Cannot convert to TensorFlow dtype. {dtype_in} is not supported by TensorFlow." 
- ) - - -def tensorflow_dtype( - x: Union[tensorflow.Tensor, tensorflow.Variable, np.ndarray], - *, - as_native: bool = False, -): - if as_native: - return tensorflow_as_native_dtype(x.dtype) - return tensorflow_as_ivy_dtype(x.dtype) - - -def tensorflow_is_bool_dtype_bknd( - dtype_in: Union[str, str, tensorflow.Tensor, tf.Tensor, Number], / -): - if tensorflow_is_array_bknd(dtype_in): - dtype_in = tensorflow_dtype(dtype_in) - elif isinstance(dtype_in, np.ndarray): - return "bool" in dtype_in.dtype.name - elif isinstance(dtype_in, Number): - return isinstance(dtype_in, (bool, np.bool_)) and not isinstance(dtype_in, bool) - elif isinstance(dtype_in, (list, tuple, dict)): - return bool( - tensorflow_nested_argwhere_bknd( - dtype_in, lambda x: isinstance(x, (bool, np.bool_)) and x is not int - ) - ) - return "bool" in tensorflow_as_ivy_dtype(dtype_in) - - -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - -@tensorflow_handle_get_item -def tensorflow_get_item( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - query: Union[tensorflow.Tensor, tensorflow.Variable, Tuple], - *, - copy: Optional[bool] = None, -): - if ( - tensorflow_is_array_bknd(query) - and tensorflow_is_bool_dtype_bknd(query) - and not len(query.shape) - ): - return tensorflow.expand_dims(x, 0) - return x[query] - - -def tensorflow_index_nest_bknd( - nest: Union[List, Tuple, Dict, tensorflow.Tensor, tf.Tensor, dict], - index: Union[List[int], Tuple[int], Iterable[int]], - /, -): - ret = nest - for i in index: - ret = tensorflow_get_item(ret, i) - return ret - - -def tensorflow__get_first_array(*args, **kwargs): - def array_fn(x): - return ( - tensorflow_is_array_bknd(x) - if not hasattr(x, "_ivy_array") - else tensorflow_is_array_bknd(x.ivy_array) - ) - - array_fn = array_fn if "array_fn" not in kwargs else kwargs["array_fn"] - arr = None - if args: - arr_idxs = tensorflow_nested_argwhere_bknd(args, array_fn, stop_after_n_found=1) - if arr_idxs: - arr = tensorflow_index_nest_bknd(args, arr_idxs[0]) - else: - arr_idxs = tensorflow_nested_argwhere_bknd( - kwargs, array_fn, stop_after_n_found=1 - ) - if arr_idxs: - arr = tensorflow_index_nest_bknd(kwargs, arr_idxs[0]) - elif kwargs: - arr_idxs = tensorflow_nested_argwhere_bknd( - kwargs, array_fn, stop_after_n_found=1 - ) - if arr_idxs: - arr = tensorflow_index_nest_bknd(kwargs, arr_idxs[0]) - return arr - - -def tensorflow_as_native_dev(device: str, /): - if isinstance(device, str) and "/" in device: - return device - ret = f"/{str(device).upper()}" - if not ret[-1].isnumeric(): - ret += ":0" - return ret - - -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - -@tensorflow_handle_methods -def tensorflow_split( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - copy: Optional[bool] = None, - num_or_size_splits: Optional[ - Union[int, Sequence[int], Union[tensorflow.Tensor, tensorflow.Variable]] - ] = None, - axis: int = 0, - 
with_remainder: bool = False, -): - if x.shape == (): - if num_or_size_splits is not None and num_or_size_splits != 1: - raise Exception( - f"input array had no shape, but num_sections specified was {num_or_size_splits}" - ) - return [x] - if num_or_size_splits is None: - dim_size = tensorflow.shape(x)[axis] - num_or_size_splits = int(dim_size) - if isinstance(num_or_size_splits, (tensorflow.Tensor, tensorflow.Variable)): - num_or_size_splits = tensorflow.cast(num_or_size_splits, tensorflow.int32) - elif isinstance(num_or_size_splits, int) and with_remainder: - num_chunks = x.shape[axis] / num_or_size_splits - num_chunks_int = math.floor(num_chunks) - remainder = num_chunks - num_chunks_int - if remainder != 0: - num_or_size_splits = [num_or_size_splits] * num_chunks_int + [ - int(remainder * num_or_size_splits) - ] - return tensorflow.split(x, num_or_size_splits, axis) - - -@tensorflow_handle_methods -def tensorflow_split_bknd_( - self: tensorflow.Tensor, - /, - *, - copy: Optional[bool] = None, - num_or_size_splits: Optional[ - Union[int, Sequence[int], tensorflow.Tensor, tf.Tensor] - ] = None, - axis: int = 0, - with_remainder: bool = False, -): - return tensorflow_split( - self, - copy=copy, - num_or_size_splits=num_or_size_splits, - axis=axis, - with_remainder=with_remainder, - ) - - -def tensorflow_as_ivy_dev(device: str, /): - if isinstance(device, str) and "/" not in device: - return str(device) - dev_in_split = tensorflow_split_bknd_(device[1:], ":")[-2:] - if len(dev_in_split) == 1: - return str(dev_in_split[0]) - dev_type, dev_idx = dev_in_split[0], dev_in_split[1] - dev_type = dev_type.lower() - if dev_type == "cpu": - return str(dev_type) - return str(f"{dev_type}:{dev_idx}") - - -def tensorflow_stack( - arrays: Union[Tuple[tensorflow.Tensor], List[tensorflow.Tensor]], - /, - *, - axis: int = 0, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - try: - return tensorflow.experimental.numpy.stack(arrays, axis) - except ValueError as e: - raise Exception(e) from e - - -def tensorflow_stack_bknd_( - self: tensorflow.Tensor, - /, - arrays: Union[ - Tuple[Union[tensorflow.Tensor, tf.Tensor]], - List[Union[tensorflow.Tensor, tf.Tensor]], - ], - *, - axis: int = 0, - out: Optional[tensorflow.Tensor] = None, -): - if not isinstance(arrays, (tuple, list)): - arrays = [arrays] - if isinstance(arrays, tuple): - x = (self,) + arrays - else: - x = [self] + arrays - return tensorflow_stack(x, axis=axis, out=out) - - -def tensorflow_dev( - x: Union[tensorflow.Tensor, tensorflow.Variable, tensorflow.TensorArray], - /, - *, - as_native: bool = False, -): - if "keras.src.backend.tensorflow.core.Variable" in str(x.__class__): - x = x.value - if isinstance(x, tensorflow.TensorArray): - x = tensorflow_stack_bknd_(x) - dv = x.device - if as_native: - return dv - dv = dv if dv else tensorflow_default_device_bknd(as_native=False) - return tensorflow_as_ivy_dev(dv) - - -def tensorflow_default_device_bknd( - device: Optional[Union[str, str]] = None, - /, - *, - item: Optional[Union[list, tuple, dict, tensorflow.Tensor, tf.Tensor]] = None, - as_native: Optional[bool] = None, -): - if tensorflow_exists_bknd(device): - if as_native is True: - return tensorflow_as_native_dev(device) - elif as_native is False: - return tensorflow_as_ivy_dev(device) - return device - as_native = tensorflow_default_bknd(as_native, False) - if tensorflow_exists_bknd(item): - if isinstance(item, (list, tuple, dict)) and len(item) == 0: - pass - elif tensorflow_is_array_bknd(item): - return 
tensorflow_dev(item, as_native=as_native) - global default_device_stack - if not default_device_stack: - ret = "cpu" - else: - ret = default_device_stack[-1] - if as_native: - return tensorflow_as_native_dev(ret) - return tensorflow_as_ivy_dev(ret) - - -def tensorflow__get_preferred_device(args, kwargs): - device = None - if "device" in kwargs and kwargs["device"] is not None: - return device - if not False: - arr_arg = tensorflow__get_first_array(*args, **kwargs) - return tensorflow_default_device_bknd(item=arr_arg, as_native=True) - return tensorflow_default_device_bknd(as_native=True) - - -def tensorflow__check_in_nested_sequence(sequence, value=None, _type=None): - if sequence is value or isinstance(sequence, _type): - return True - elif isinstance(sequence, (tuple, list)): - if any(isinstance(_val, _type) or _val is value for _val in sequence): - return True - else: - return any( - tensorflow__check_in_nested_sequence(sub_sequence, value, _type) - for sub_sequence in sequence - if isinstance(sub_sequence, (tuple, list)) - ) - - -def tensorflow_nested_map_bknd( - fn: Callable, - x: Union[tensorflow.Tensor, tf.Tensor, Iterable], - /, - include_derived: Optional[Union[Dict[str, bool], bool]] = None, - to_ignore: Optional[Union[type, Tuple[type]]] = None, - to_mutable: bool = False, - _tuple_check_fn: Optional[Callable] = None, - _list_check_fn: Optional[Callable] = None, - _dict_check_fn: Optional[Callable] = None, - shallow: bool = True, -): - to_ignore = tensorflow_default_bknd(to_ignore, ()) - if include_derived is True: - include_derived = {"tuple": True, "list": True, "dict": True} - elif not include_derived: - include_derived = {} - for t in ("tuple", "list", "dict"): - if t not in include_derived: - include_derived = tensorflow_set_item_bknd(include_derived, t, False) - class_instance = type(x) - if ( - hasattr(x, "is_tracked_proxy") - and hasattr(class_instance, "__bases__") - and not set(class_instance.__bases__).intersection(set(to_ignore)) - ): - to_ignore = to_ignore + (class_instance,) - tuple_check_fn = tensorflow_default_bknd( - _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), - ) - list_check_fn = tensorflow_default_bknd( - _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), - ) - dict_check_fn = tensorflow_default_bknd( - _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), - ) - if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): - ret_list = [ - tensorflow_nested_map_bknd( - fn, - i, - include_derived, - to_ignore, - to_mutable, - tuple_check_fn, - list_check_fn, - dict_check_fn, - shallow, - ) - for i in x - ] - if to_mutable: - return ret_list - elif hasattr(x, "_fields"): - return class_instance(**dict(zip(x._fields, ret_list))) - else: - return class_instance(ret_list) - elif list_check_fn(x, list) and not isinstance(x, to_ignore): - ret_list = [ - tensorflow_nested_map_bknd( - fn, - i, - include_derived, - to_ignore, - to_mutable, - tuple_check_fn, - list_check_fn, - dict_check_fn, - shallow, - ) - for i in x - ] - if shallow: - x = tensorflow_set_item_bknd(x, slice(None, None, None), ret_list[:]) - return x - return class_instance(ret_list) - elif (dict_check_fn(x, dict) or isinstance(x, UserDict)) and not isinstance( - x, to_ignore - ): - class_instance = type(x) - ret = { - k: tensorflow_nested_map_bknd( - fn, - v, - 
include_derived, - to_ignore, - to_mutable, - tuple_check_fn, - list_check_fn, - dict_check_fn, - shallow, - ) - for k, v in x.items() - } - if shallow: - x.update(ret) - return x - return class_instance(ret) - elif isinstance(x, slice): - return slice(*tensorflow_nested_map_bknd(fn, [x.start, x.stop, x.step])) - return fn(x) - - -def tensorflow__to_ivy_bknd_(x: Any): - if isinstance(x, tensorflow.Tensor): - return x - elif isinstance(x, tf.TensorShape): - return tuple(x) - elif isinstance(x, dict): - return x.to_ivy() - if tensorflow_is_native_array(x) or isinstance(x, np.ndarray): - return tensorflow.convert_to_tensor(x) - return x - - -def tensorflow_to_ivy_bknd_( - x: Union[tensorflow.Tensor, tf.Tensor, Iterable], - nested: bool = False, - include_derived: Optional[Dict[str, bool]] = None, -): - if nested: - return tensorflow_nested_map_bknd( - tensorflow__to_ivy_bknd_, x, include_derived, shallow=False - ) - return tensorflow__to_ivy_bknd_(x) - - -def tensorflow__asarray_to_native_arrays_and_back_bknd(fn: Callable): - @functools.wraps(fn) - def _asarray_to_native_arrays_and_back_wrapper(*args, dtype=None, **kwargs): - new_arg = args[0] - new_args = (new_arg,) + args[1:] - if dtype is not None: - dtype = tensorflow_default_dtype_bknd(dtype=dtype, as_native=True) - return tensorflow_to_ivy_bknd_(fn(*new_args, dtype=dtype, **kwargs)) - - _asarray_to_native_arrays_and_back_wrapper._asarray_to_native_arrays_and_back = True - return _asarray_to_native_arrays_and_back_wrapper - - -def tensorflow__flatten_nest_bknd(xs): - for x in xs: - if isinstance(x, Iterable) and not isinstance(x, (str, bytes)): - yield from tensorflow__flatten_nest_bknd(x) - else: - yield x - - -def tensorflow_promote_types_bknd( - type1: Union[str, tf.DType], - type2: Union[str, tf.DType], - /, - *, - array_api_promotion: bool = False, -): - if not (type1 and type2): - return type1 if type1 else type2 - query = [tensorflow_as_ivy_dtype(type1), tensorflow_as_ivy_dtype(type2)] - query = tuple(query) - if query not in promotion_table: - query = query[1], query[0] - - def _promote(query): - if array_api_promotion: - return tensorflow_get_item(array_api_promotion_table, query) - return tensorflow_get_item(promotion_table, query) - - return _promote(query) - - -def tensorflow__asarray_infer_dtype_bknd(fn: Callable): - @functools.wraps(fn) - def _asarray_infer_dtype_wrapper(*args, dtype=None, **kwargs): - def _infer_dtype(obj): - if isinstance(obj, tf.TensorShape): - obj = list(obj) - if hasattr(obj, "dtype"): - return obj.dtype.name if isinstance(obj, np.ndarray) else obj.dtype - else: - return tensorflow_default_dtype_bknd(item=obj) - - if not tensorflow_exists_bknd(dtype): - arr = args[0] - dtype_list = [ - tensorflow_nested_map_bknd( - lambda x: _infer_dtype(x), arr, shallow=False - ) - ] - dtype_list = tensorflow__flatten_nest_bknd(dtype_list) - dtype_list = list(set(dtype_list)) - if len(dtype_list) != 0: - dtype = dtype_list[0] - for dt in dtype_list[1:]: - dtype = tensorflow_promote_types_bknd(dtype, dt) - else: - dtype = tensorflow_default_float_dtype_bknd() - dtype = tensorflow_as_native_dtype(dtype) - return fn(*args, dtype=dtype, **kwargs) - - _asarray_infer_dtype_wrapper.infer_dtype = True - return _asarray_infer_dtype_wrapper - - -@tensorflow_handle_array_like_without_promotion -@tensorflow__asarray_to_native_arrays_and_back_bknd -@tensorflow__asarray_infer_dtype_bknd -def tensorflow_asarray( - obj: Union[ - tensorflow.Tensor, - tensorflow.Variable, - tensorflow.TensorShape, - bool, - int, - float, - 
tensorflow_NestedSequence_bknd, - SupportsBufferProtocol, - np.ndarray, - ], - /, - *, - copy: Optional[bool] = None, - dtype: Optional[tensorflow.DType] = None, - device: Optional[str] = None, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - with tensorflow.device(device): - if tensorflow.is_tensor(obj): - ret = tensorflow.cast(obj, dtype) if obj.dtype != dtype else obj - elif ( - dtype is not None - and dtype.is_integer - and np.issubdtype(np.array(obj).dtype, np.floating) - ): - obj_np = np.array(obj) - ret = tensorflow.convert_to_tensor(obj_np, dtype) - else: - ret = tensorflow.convert_to_tensor(obj, dtype) - return ( - tensorflow.identity(ret) - if copy or tensorflow_as_native_dev(tensorflow_dev(ret)) != device - else ret - ) - - -def tensorflow_is_variable(x, /, *, exclusive=False): - return isinstance(x, tensorflow.Variable) - - -def tensorflow_variable(x, /): - with tensorflow.device(tensorflow_dev(x, as_native=True)): - return tensorflow.Variable(x, trainable=True) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_stop_gradient( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - preserve_type: bool = True, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - is_var = tensorflow_is_variable(x) - x = tensorflow.stop_gradient(x) - if is_var and preserve_type: - return tensorflow_variable(x) - return x - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_size(x: tensorflow.Tensor, /): - return functools.reduce(mul, x.shape) if len(x.shape) > 0 else 1 - - -def tensorflow_size_bknd_(self): - return tensorflow_size(self) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_unstack( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - copy: Optional[bool] = None, - axis: int = 0, - keepdims: bool = False, -): - if x.shape == (): - return [x] - ret = tensorflow.unstack(x, axis=axis) - if keepdims: - return [tensorflow.expand_dims(r, axis) for r in ret] - return ret - - -def tensorflow_unstack_bknd_( - self: tensorflow.Tensor, - /, - *, - copy: Optional[bool] = None, - axis: int = 0, - keepdims: bool = False, -): - return tensorflow_unstack(self, copy=copy, axis=axis, keepdims=keepdims) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_copy_array( - x: Union[tensorflow.Tensor, tensorflow.Variable, tensorflow.TensorArray], - *, - to_ivy_array: bool = True, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if isinstance(x, tensorflow.TensorArray): - x_wrapped = tensorflow_stack_bknd_(x) - y = tensorflow.TensorArray(x.dtype, tensorflow_size_bknd_(x)()) - x = tensorflow_unstack_bknd_(y, tensorflow_copy_array(x_wrapped)) - else: - x = tensorflow.identity(x) - if to_ivy_array: - return tensorflow_to_ivy_bknd_(x) - return x - - -def tensorflow_tile( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - repeats: Sequence[int], - *, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if x.shape == (): - x = tensorflow.reshape(x, (-1,)) - if isinstance(repeats, Number): - repeats = [repeats] - if isinstance(repeats, tensorflow.Tensor) and repeats.shape == (): - repeats = tensorflow.reshape(repeats, (-1,)) - if len(x.shape) < len(repeats): - while len(x.shape) != len(repeats): - x = tensorflow.expand_dims(x, 0) - elif len(x.shape) > len(repeats): - repeats = list(repeats) - while len(x.shape) != len(repeats): - repeats = [1] + repeats - return tensorflow.tile(x, repeats) - - 
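-# --- editorial sketch (hypothetical helper, not part of the generated file):
-# tensorflow_tile above aligns ranks the way np.tile does; a `repeats`
-# tuple longer than the input's rank left-pads the input with singleton
-# dims, while a shorter one is left-padded with 1s.
-def _example_tile_rank_alignment():
-    x = tensorflow.constant([1, 2, 3])  # shape (3,)
-    a = tensorflow_tile(x, (2, 2))  # x lifted to (1, 3) -> result shape (2, 6)
-    b = tensorflow_tile(x, (2,))  # plain repeat along axis 0 -> shape (6,)
-    return a, b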
-@tensorflow_handle_array_like_without_promotion -def tensorflow_nonzero( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - as_tuple: bool = True, - size: Optional[int] = None, - fill_value: Number = 0, -): - res = tensorflow.experimental.numpy.nonzero(x) - if size is not None: - dtype = tensorflow.int64 - if isinstance(fill_value, float): - dtype = tensorflow.float64 - res = tensorflow.cast(res, dtype) - diff = size - res[0].shape[0] - if diff > 0: - res = tensorflow.pad(res, [[0, 0], [0, diff]], constant_values=fill_value) - elif diff < 0: - res = tensorflow.slice(res, [0, 0], [-1, size]) - if as_tuple: - return tuple(res) - return tensorflow.stack(res, axis=1) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_diff( - x: Union[tensorflow.Tensor, tensorflow.Variable, list, tuple], - /, - *, - n: int = 1, - axis: int = -1, - prepend: Optional[ - Union[tensorflow.Tensor, tensorflow.Variable, int, float, list, tuple] - ] = None, - append: Optional[ - Union[tensorflow.Tensor, tensorflow.Variable, int, float, list, tuple] - ] = None, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if n == 0: - return x - if prepend is not None: - x = tensorflow.experimental.numpy.append( - prepend, x, axis=axis if axis != -1 else None - ) - if append is not None: - x = tensorflow.experimental.numpy.append( - x, append, axis=axis if axis != -1 else None - ) - return tensorflow.experimental.numpy.diff(x, n=n, axis=axis) - - -def tensorflow__parse_ellipsis_bknd(so, ndims): - pre = list() - for s in so: - if s is Ellipsis: - break - pre.append(s) - post = list() - for s in reversed(so): - if s is Ellipsis: - break - post.append(s) - ret = list( - pre - + [slice(None, None, None) for _ in range(ndims - len(pre) - len(post))] - + list(reversed(post)) - ) - return ret, (len(pre), ndims - len(post)) - - -def tensorflow_broadcast_arrays(*arrays: Union[tensorflow.Tensor, tensorflow.Variable]): - if len(arrays) > 1: - try: - desired_shape = tensorflow.broadcast_dynamic_shape( - tensorflow.shape(arrays[0]), tensorflow.shape(arrays[1]) - ) - except tensorflow.errors.InvalidArgumentError as e: - raise Exception(e) from e - if len(arrays) > 2: - for i in range(2, len(arrays)): - try: - desired_shape = tensorflow.broadcast_dynamic_shape( - desired_shape, tensorflow.shape(arrays[i]) - ) - except tensorflow.errors.InvalidArgumentError as e: - raise Exception(e) from e - else: - return [arrays[0]] - result = [] - for tensor in arrays: - result.append(tensorflow.broadcast_to(tensor, desired_shape)) - return result - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_astype( - x: Union[tensorflow.Tensor, tensorflow.Variable], - dtype: Union[tf.DType, str], - /, - *, - copy: bool = True, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - dtype = tensorflow_as_native_dtype(dtype) - if x.dtype == dtype: - return tensorflow.experimental.numpy.copy(x) if copy else x - return tensorflow.cast(x, dtype) - - -def tensorflow_astype_bknd_( - self: tensorflow.Tensor, - dtype: str, - /, - *, - copy: bool = True, - out: Optional[tensorflow.Tensor] = None, -): - return tensorflow_astype(self, dtype, copy=copy, out=out) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_where( - condition: Union[tensorflow.Tensor, tensorflow.Variable], - x1: Union[tensorflow.Tensor, tensorflow.Variable], - x2: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - 
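-    # (editorial note) Non-array operands are first coerced to a shared
-    # dtype; if that coercion fails for any reason, the original operands
-    # are restored and passed to tf.experimental.numpy.where unchanged.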
-    orig_x1 = x1
-    orig_x2 = x2
-    try:
-        dtype = (
-            x1.dtype
-            if hasattr(x1, "dtype")
-            else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd()
-        )
-        if not tensorflow_is_array_bknd(x1):
-            x1 = tensorflow_asarray(x1, dtype=dtype)
-        if not tensorflow_is_array_bknd(x2):
-            x2 = tensorflow_asarray(x2, dtype=dtype)
-    except Exception:
-        x1 = orig_x1
-        x2 = orig_x2
-    return tensorflow.cast(
-        tensorflow.experimental.numpy.where(condition, x1, x2), x1.dtype
-    )
-
-
-@tensorflow_handle_array_like_without_promotion
-def tensorflow_arange(
-    start: float,
-    /,
-    stop: Optional[float] = None,
-    step: float = 1,
-    *,
-    dtype: Optional[tensorflow.DType] = None,
-    device: Optional[str] = None,
-    out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None,
-):
-    if stop is None:
-        stop = start
-        start = 0
-    if (step > 0 and start > stop) or (step < 0 and start < stop):
-        if isinstance(stop, float):
-            stop = float(start)
-        else:
-            stop = start
-    if isinstance(start, (float, int)):
-        start = tensorflow.convert_to_tensor(start)
-    if isinstance(stop, (float, int)):
-        stop = tensorflow.convert_to_tensor(stop)
-    if isinstance(step, (float, int)):
-        step = tensorflow.convert_to_tensor(step)
-    if dtype is None:
-        if isinstance(start, int) and isinstance(stop, int) and isinstance(step, int):
-            return tensorflow.cast(
-                tensorflow.range(start, stop, delta=step, dtype=tensorflow.int64),
-                tensorflow.int32,
-            )
-        else:
-            return tensorflow.range(start, stop, delta=step)
-    else:
-        dtype = tensorflow_as_native_dtype(tensorflow_default_dtype_bknd(dtype=dtype))
-        if dtype in [
-            tensorflow.int8,
-            tensorflow.uint8,
-            tensorflow.int16,
-            tensorflow.uint16,
-            tensorflow.uint32,
-            tensorflow.uint64,
-        ]:
-            return tensorflow.cast(
-                tensorflow.range(start, stop, delta=step, dtype=tensorflow.int64), dtype
-            )
-        else:
-            return tensorflow.range(start, stop, delta=step, dtype=dtype)
-
-
-def tensorflow__parse_slice_bknd(idx, s):
-    step = 1 if idx.step is None else idx.step
-    if step > 0:
-        start = 0 if idx.start is None else idx.start
-        if start >= s:
-            stop = start
-        else:
-            if start <= -s:
-                start = 0
-            elif start < 0:
-                start = start + s
-            stop = s if idx.stop is None else idx.stop
-            if stop > s:
-                stop = s
-            elif start <= -s:
-                stop = 0
-            elif stop < 0:
-                stop = stop + s
-    else:
-        start = s - 1 if idx.start is None else idx.start
-        if start < -s:
-            stop = start
-        else:
-            if start >= s:
-                start = s - 1
-            elif start < 0:
-                start = start + s
-            if idx.stop is None:
-                stop = -1
-            else:
-                stop = idx.stop
-                if stop > s:
-                    stop = s
-                elif stop < -s:
-                    stop = -1
-                elif stop == -s:
-                    stop = 0
-                elif stop < 0:
-                    stop = stop + s
-    q_i = tensorflow_arange(start, stop, step)
-    ag__result_list_0 = []
-    for q in q_i:
-        if 0 <= q < s:
-            res = q
-            ag__result_list_0.append(res)
-    q_i = ag__result_list_0
-    q_i = (
-        tensorflow_asarray(q_i)
-        if len(q_i) or start == stop or idx.stop is not None
-        else tensorflow_arange(0, s, 1)
-    )
-    return q_i
-
-
-@tensorflow_handle_array_like_without_promotion
-def tensorflow_shape(
-    x: Union[tensorflow.Tensor, tensorflow.Variable], /, *, as_array: bool = False
-):
-    if as_array:
-        return tensorflow_asarray(
-            tensorflow.shape(x), dtype=tensorflow_default_int_dtype_bknd()
-        )
-    else:
-        return tuple(x.shape)
-
-
-def tensorflow__deep_flatten_bknd(iterable):
-    def _flatten_gen(iterable):
-        for item in iterable:
-            if isinstance(item, list):
-                yield from _flatten_gen(item)
-            else:
-                yield item
-
-    return list(_flatten_gen(iterable))
-
-
-def tensorflow__calculate_out_shape_bknd(axis, array_shape):
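-    # (editorial note) Computes the result shape of expand_dims: the given
-    # axes are normalized against the output rank, a size-1 dim is placed
-    # at each of them, and remaining slots consume the input dims in order.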
if type(axis) not in (tuple, list): - axis = (axis,) - out_dims = len(axis) + len(array_shape) - norm_axis = normalize_axis_tuple(axis, out_dims) - shape_iter = iter(array_shape) - ag__result_list_0 = [] - for current_ax in range(out_dims): - res = 1 if current_ax in norm_axis else next(shape_iter) - ag__result_list_0.append(res) - out_shape = ag__result_list_0 - return out_shape - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_expand_dims( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - copy: Optional[bool] = None, - axis: Union[int, Sequence[int]] = 0, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - try: - out_shape = tensorflow__calculate_out_shape_bknd(axis, tensorflow.shape(x)) - ret = tensorflow.reshape(x, shape=out_shape) - return ret - except (tensorflow.errors.InvalidArgumentError, np.AxisError) as error: - raise Exception(error) from error - - -def tensorflow_check_elem_in_list(elem, list, inverse=False, message=""): - if inverse and elem in list: - raise Exception( - message if message != "" else f"{elem} must not be one of {list}" - ) - elif not inverse and elem not in list: - raise Exception(message if message != "" else f"{elem} must be one of {list}") - - -def tensorflow__reshape_fortran_tf(x, shape): - if len(x.shape) > 0: - x = tensorflow.transpose(x) - return tensorflow.transpose(tensorflow.reshape(x, shape[::-1])) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_reshape( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - shape: Union[tf.TensorShape, Sequence[int]], - *, - copy: Optional[bool] = None, - order: str = "C", - allowzero: bool = True, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - tensorflow_check_elem_in_list(order, ["C", "F"]) - if not allowzero: - shape = [ - (new_s if con else old_s) - for new_s, con, old_s in zip( - shape, tensorflow.constant(shape) != 0, x.shape - ) - ] - if order == "F": - return tensorflow__reshape_fortran_tf(x, shape) - return tensorflow.reshape(x, shape) - - -def tensorflow_reshape_bknd_( - self: tensorflow.Tensor, - /, - shape: Union[tuple, tf.TensorShape, Sequence[int]], - *, - copy: Optional[bool] = None, - order: str = "C", - allowzero: bool = True, - out: Optional[tensorflow.Tensor] = None, -): - return tensorflow_reshape( - self, shape, copy=copy, allowzero=allowzero, out=out, order=order - ) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_meshgrid( - *arrays: Union[tensorflow.Tensor, tensorflow.Variable], - sparse: bool = False, - indexing: str = "xy", - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if not sparse: - return tensorflow.meshgrid(*arrays, indexing=indexing) - sd = (1,) * len(arrays) - ag__result_list_0 = [] - for i, a in enumerate(arrays): - res = tensorflow.reshape( - tensorflow.convert_to_tensor(a), sd[:i] + (-1,) + sd[i + 1 :] - ) - ag__result_list_0.append(res) - res = ag__result_list_0 - if indexing == "xy" and len(arrays) > 1: - res[0] = tensorflow.reshape(res[0], (1, -1) + sd[2:]) - res[1] = tensorflow.reshape(res[1], (-1, 1) + sd[2:]) - return res - - -def tensorflow_infer_dtype(fn: Callable): - @functools.wraps(fn) - def _infer_dtype(*args, dtype=None, **kwargs): - arr = ( - None - if tensorflow_exists_bknd(dtype) - else tensorflow__get_first_array(*args, **kwargs) - ) - dtype = tensorflow_default_dtype_bknd(dtype=dtype, item=arr, as_native=True) - return fn(*args, dtype=dtype, **kwargs) - - _infer_dtype.infer_dtype = True - return 
_infer_dtype - - -@tensorflow_infer_dtype -@tensorflow_handle_array_like_without_promotion -def tensorflow_empty( - shape: Union[tf.TensorShape, Sequence[int]], - *, - dtype: tensorflow.DType, - device: Optional[str] = None, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - return tensorflow.experimental.numpy.empty(shape, dtype=tensorflow.float32) - - -def tensorflow__parse_query_bknd(query, x_shape, scatter=False): - query = (query,) if not isinstance(query, tuple) else query - ag__result_list_0 = [] - for q in query: - res = tensorflow_asarray(q) if isinstance(q, (tuple, list, int)) else q - ag__result_list_0.append(res) - query = ag__result_list_0 - ag__result_list_1 = [] - for i, q in enumerate(query): - if tensorflow_is_array_bknd(q): - res = i - ag__result_list_1.append(res) - non_slice_q_idxs = ag__result_list_1 - to_front = ( - len(non_slice_q_idxs) > 1 - and any(tensorflow_diff(non_slice_q_idxs) != 1) - and non_slice_q_idxs[-1] < len(x_shape) - ) - ag__result_list_2 = [] - for i, q in enumerate(query): - if q is None: - res = i - ag__result_list_2.append(res) - new_axes = ag__result_list_2 - ag__result_list_3 = [] - for q in query: - if q is not None: - res = q - ag__result_list_3.append(res) - query = ag__result_list_3 - query = [Ellipsis] if query == [] else query - ellipsis_inds = None - if any(q is Ellipsis for q in query): - query, ellipsis_inds = tensorflow__parse_ellipsis_bknd(query, len(x_shape)) - ag__result_list_4 = [] - for i, v in enumerate(query): - if tensorflow_is_array_bknd(v): - res = i - ag__result_list_4.append(res) - array_inds = ag__result_list_4 - if array_inds: - array_queries = tensorflow_broadcast_arrays( - *[v for i, v in enumerate(query) if i in array_inds] - ) - array_queries = [ - ( - tensorflow_nonzero(q, as_tuple=False)[0] - if tensorflow_is_bool_dtype_bknd(q) - else q - ) - for q in array_queries - ] - array_queries = [ - ( - tensorflow_astype_bknd_( - tensorflow_where( - arr < 0, arr + tensorflow_get_item(x_shape, i), arr - ), - tf.int64, - ) - if tensorflow_size_bknd_(arr) - else tensorflow_astype_bknd_(arr, tf.int64) - ) - for arr, i in zip(array_queries, array_inds) - ] - for idx, arr in zip(array_inds, array_queries): - query = tensorflow_set_item_bknd(query, idx, arr) - ag__result_list_5 = [] - for i, q in enumerate(query): - res = ( - tensorflow_astype_bknd_( - tensorflow__parse_slice_bknd(q, tensorflow_get_item(x_shape, i)), - tf.int64, - ) - if isinstance(q, slice) - else q - ) - ag__result_list_5.append(res) - query = ag__result_list_5 - if len(query) < len(x_shape): - query = query + [ - tensorflow_astype_bknd_(tensorflow_arange(0, s, 1), tf.int64) - for s in tensorflow_get_item(x_shape, slice(len(query), None, None)) - ] - if len(array_inds) and to_front: - target_shape = ( - [list(array_queries[0].shape)] - + [ - list(tensorflow_get_item(query, i).shape) - for i in range(len(query)) - if i not in array_inds - ] - + [[] for _ in range(len(array_inds) - 1)] - ) - elif len(array_inds): - target_shape = ( - [list(tensorflow_get_item(query, i).shape) for i in range(0, array_inds[0])] - + [list(tensorflow_shape(array_queries[0], as_array=True))] - + [[] for _ in range(len(array_inds) - 1)] - + [ - list(tensorflow_shape(tensorflow_get_item(query, i), as_array=True)) - for i in range(array_inds[-1] + 1, len(query)) - ] - ) - else: - target_shape = [list(q.shape) for q in query] - if ellipsis_inds is not None: - target_shape = ( - tensorflow_get_item(target_shape, slice(None, ellipsis_inds[0], None)) - + [ - 
tensorflow_get_item( - target_shape, slice(ellipsis_inds[0], ellipsis_inds[1], None) - ) - ] - + tensorflow_get_item(target_shape, slice(ellipsis_inds[1], None, None)) - ) - for i, ax in enumerate(new_axes): - if len(array_inds) and to_front: - ax = ax - (sum(1 for x in array_inds if x < ax) - 1) - ax = ax + i - target_shape = [ - *tensorflow_get_item(target_shape, slice(None, ax, None)), - 1, - *tensorflow_get_item(target_shape, slice(ax, None, None)), - ] - target_shape = tensorflow__deep_flatten_bknd(target_shape) - ag__result_list_6 = [] - for q in query: - res = tensorflow_expand_dims(q) if not len(q.shape) else q - ag__result_list_6.append(res) - query = ag__result_list_6 - if len(array_inds): - array_queries = [ - ( - tensorflow_reshape_bknd_(arr, (-1,)) - if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr - ) - for arr in array_queries - ] - array_queries = tensorflow_stack(array_queries, axis=1) - if len(array_inds) == len(query): - indices = tensorflow_reshape_bknd_(array_queries, (*target_shape, len(x_shape))) - elif len(array_inds) == 0: - indices = tensorflow_reshape_bknd_( - tensorflow_stack(tensorflow_meshgrid(*query, indexing="ij"), axis=-1), - (*target_shape, len(x_shape)), - ) - elif to_front: - post_array_queries = ( - tensorflow_reshape_bknd_( - tensorflow_stack( - tensorflow_meshgrid( - *[v for i, v in enumerate(query) if i not in array_inds], - indexing="ij", - ), - axis=-1, - ), - (-1, len(query) - len(array_inds)), - ) - if len(array_inds) < len(query) - else tensorflow_empty((1, 0)) - ) - indices = tensorflow_reshape_bknd_( - tensorflow_asarray( - [ - (*arr, *post) - for arr, post in itertools.product( - array_queries, post_array_queries - ) - ] - ), - (*target_shape, len(x_shape)), - ) - else: - pre_array_queries = ( - tensorflow_reshape_bknd_( - tensorflow_stack( - tensorflow_meshgrid( - *[v for i, v in enumerate(query) if i < array_inds[0]], - indexing="ij", - ), - axis=-1, - ), - (-1, array_inds[0]), - ) - if array_inds[0] > 0 - else tensorflow_empty((1, 0)) - ) - post_array_queries = ( - tensorflow_reshape_bknd_( - tensorflow_stack( - tensorflow_meshgrid( - *[v for i, v in enumerate(query) if i > array_inds[-1]], - indexing="ij", - ), - axis=-1, - ), - (-1, len(query) - 1 - array_inds[-1]), - ) - if array_inds[-1] < len(query) - 1 - else tensorflow_empty((1, 0)) - ) - indices = tensorflow_reshape_bknd_( - tensorflow_asarray( - [ - (*pre, *arr, *post) - for pre, arr, post in itertools.product( - pre_array_queries, array_queries, post_array_queries - ) - ] - ), - (*target_shape, len(x_shape)), - ) - return ( - tensorflow_astype_bknd_(indices, tf.int64), - target_shape, - array_inds if len(array_inds) and to_front else None, - ) - - -def tensorflow_get_num_dims(x, /, *, as_array=False): - return ( - tensorflow.cast(tensorflow.shape(tensorflow.shape(x))[0], tensorflow.int64) - if as_array - else int(tensorflow.shape(tensorflow.shape(x))) - ) - - -def tensorflow_to_numpy( - x: Union[tensorflow.Tensor, tensorflow.Variable], /, *, copy: bool = True -): - if ( - tensorflow_is_array_bknd(x) - and tensorflow_get_num_dims(x) == 0 - and tensorflow_as_native_dtype(x.dtype) is tensorflow.bfloat16 - ): - x = tensorflow.expand_dims(x, 0) - if copy: - return np.squeeze(np.array(tensorflow.convert_to_tensor(x)), 0) - else: - return np.squeeze(np.asarray(tensorflow.convert_to_tensor(x)), 0) - if copy: - return np.array(tensorflow.convert_to_tensor(x)) - else: - return np.asarray(tensorflow.convert_to_tensor(x)) - - 
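-# --- editorial sketch (hypothetical helper, not part of the generated file):
-# the bfloat16 branch in tensorflow_to_numpy above appears to work around
-# numpy's lack of a native bfloat16; a 0-d bfloat16 tensor is lifted to
-# shape (1,) before conversion and the numpy result is squeezed back to 0-d.
-def _example_to_numpy_bfloat16_scalar():
-    x = tensorflow.constant(1.5, dtype=tensorflow.bfloat16)  # 0-d tensor
-    return tensorflow_to_numpy(x)  # 0-d numpy array, value preserved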
-@tensorflow_handle_array_like_without_promotion -def tensorflow_to_scalar(x: Union[tensorflow.Tensor, tensorflow.Variable], /): - ret = tensorflow_to_numpy(x).item() - if x.dtype == tensorflow.bfloat16: - return float(ret) - return ret - - -def tensorflow_to_scalar_bknd_(self: tensorflow.Tensor): - return tensorflow_to_scalar(self) - - -def tensorflow_default_uint_dtype_bknd( - *, - input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, - uint_dtype: Optional[Union[str, tf.DType]] = None, - as_native: bool = False, -): - global default_uint_dtype_stack - if tensorflow_exists_bknd(uint_dtype): - if as_native is True: - return tensorflow_as_native_dtype(uint_dtype) - return str(tensorflow_as_ivy_dtype(uint_dtype)) - as_native = tensorflow_default_bknd(as_native, False) - if tensorflow_exists_bknd(input): - if tensorflow_is_array_bknd(input): - ret = tensorflow_dtype(input) - elif isinstance(input, np.ndarray): - ret = input.dtype - elif isinstance(input, (list, tuple, dict)): - - def is_native(x): - return tensorflow_is_native_array(x) - - if tensorflow_nested_argwhere_bknd( - input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), - stop_after_n_found=1, - ): - ret = tf.uint64 - elif default_uint_dtype_stack: - ret = default_uint_dtype_stack[-1] - else: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_uint_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "uint32" - elif isinstance(input, Number): - if input > 4294967295 and input != math.inf and backend != "torch": - ret = tf.uint64 - elif default_uint_dtype_stack: - ret = default_uint_dtype_stack[-1] - else: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_uint_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "uint32" - elif default_uint_dtype_stack: - ret = default_uint_dtype_stack[-1] - else: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_uint_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "uint32" - if as_native: - return tensorflow_as_native_dtype(ret) - return str(tensorflow_as_ivy_dtype(ret)) - - -def tensorflow_infer_default_dtype_bknd( - dtype: Union[str, tf.DType, str], as_native: bool = False -): - if tensorflow_is_complex_dtype_bknd(dtype): - default_dtype = tensorflow_default_complex_dtype_bknd(as_native=as_native) - elif tensorflow_is_float_dtype_bknd(dtype): - default_dtype = tensorflow_default_float_dtype_bknd(as_native=as_native) - elif tensorflow_is_uint_dtype_bknd(dtype): - default_dtype = tensorflow_default_uint_dtype_bknd(as_native=as_native) - elif tensorflow_is_int_dtype_bknd(dtype): - default_dtype = tensorflow_default_int_dtype_bknd(as_native=as_native) - elif as_native: - default_dtype = tensorflow_as_native_dtype("bool") - else: - default_dtype = tensorflow_as_ivy_dtype("bool") - return default_dtype - - -def tensorflow_dtype_bits(dtype_in: Union[tensorflow.DType, str, np.dtype], /): - dtype_str = tensorflow_as_ivy_dtype(dtype_in) - if "bool" in dtype_str: - return 1 - return int( - dtype_str.replace("tf.", "") - .replace("uint", "") - .replace("int", "") - .replace("bfloat", "") - .replace("float", "") - .replace("complex", "") - ) - - -def tensorflow__infer_dtype(dtype: tensorflow.DType): - default_dtype = tensorflow_infer_default_dtype_bknd(dtype) - if tensorflow_dtype_bits(dtype) < tensorflow_dtype_bits(default_dtype): - return default_dtype - return dtype - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_prod( - x: Union[tensorflow.Tensor, 
tensorflow.Variable], - /, - *, - axis: Optional[Union[int, Sequence[int]]] = None, - dtype: Optional[tensorflow.DType] = None, - keepdims: bool = False, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - dtype = tensorflow_as_native_dtype(dtype) - if dtype is None: - dtype = tensorflow__infer_dtype(x.dtype) - axis = tuple(axis) if isinstance(axis, list) else axis - return tensorflow.experimental.numpy.prod( - x, axis=axis, dtype=dtype, keepdims=keepdims - ) - - -def tensorflow__numel_bknd(shape): - shape = tuple(shape) - return tensorflow_to_scalar_bknd_(tensorflow_prod(shape)) if shape != () else 1 - - -def tensorflow_check_one_way_broadcastable(x1, x2): - if len(x1) > len(x2): - return False - for a, b in zip(x1[::-1], x2[::-1]): - if a in (1, b): - pass - else: - return False - return True - - -def tensorflow_check_shapes_broadcastable(var, data): - if not tensorflow_check_one_way_broadcastable(var, data): - raise Exception(f"Could not broadcast shape {data} to shape {var}.") - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_broadcast_to( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - shape: Union[tf.TensorShape, Sequence[int]], - *, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - tensorflow_check_shapes_broadcastable(x.shape, shape) - if tensorflow.rank(x) > len(shape): - return tensorflow.broadcast_to(tensorflow.reshape(x, -1), shape) - return tensorflow.broadcast_to(x, shape) - - -def tensorflow__broadcast_to_bknd(input, target_shape): - if tensorflow__numel_bknd(tuple(input.shape)) == tensorflow__numel_bknd( - tuple(target_shape) - ): - return tensorflow_reshape(input, target_shape) - else: - input = input if len(input.shape) else tensorflow_expand_dims(input, axis=0) - return tensorflow_broadcast_to(input, target_shape) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_any( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - axis: Optional[Union[int, Sequence[int]]] = None, - keepdims: bool = False, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if axis is None: - num_dims = len(x.shape) - axis = tuple(range(num_dims)) - elif isinstance(axis, list): - axis = tuple(axis) - try: - return tensorflow.reduce_any( - tensorflow.cast(x, tensorflow.bool), axis=axis, keepdims=keepdims - ) - except tensorflow.errors.InvalidArgumentError as e: - raise Exception(e) from e - - -def tensorflow__broadcast_inputs(x1, x2): - x1_, x2_ = x1, x2 - iterables = list, tuple, tuple - if not isinstance(x1_, iterables): - x1_, x2_ = x2, x1 - if not isinstance(x1_, iterables): - return [x1], [x2] - if not isinstance(x2_, iterables): - x1 = [x1] * len(x2) - return x1, x2 - - -def tensorflow_check_equal(x1, x2, inverse=False, message="", as_array=True): - def eq_fn(x1, x2): - return x1 == x2 if inverse else x1 != x2 - - def comp_fn(x1, x2): - return tensorflow_any(eq_fn(x1, x2)) - - if not as_array: - - def iter_comp_fn(x1_, x2_): - return any(eq_fn(x1, x2) for x1, x2 in zip(x1_, x2_)) - - def comp_fn(x1, x2): - return iter_comp_fn(*tensorflow__broadcast_inputs(x1, x2)) - - eq = comp_fn(x1, x2) - if inverse and eq: - raise Exception(f"{x1} must not be equal to {x2}" if message == "" else message) - elif not inverse and eq: - raise Exception(f"{x1} must be equal to {x2}" if message == "" else message) - - -def tensorflow_multiply( - x1: Union[float, tensorflow.Tensor, tensorflow.Variable], - x2: Union[float, tensorflow.Tensor, tensorflow.Variable], - /, - *, - out: 
Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None,
-):
-    orig_x1 = x1
-    orig_x2 = x2
-    try:
-        dtype = (
-            x1.dtype
-            if hasattr(x1, "dtype")
-            else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd()
-        )
-        if not tensorflow_is_array_bknd(x1):
-            x1 = tensorflow_asarray(x1, dtype=dtype)
-        if not tensorflow_is_array_bknd(x2):
-            x2 = tensorflow_asarray(x2, dtype=dtype)
-    except Exception:
-        x1 = orig_x1
-        x2 = orig_x2
-    return tensorflow.math.multiply(x1, x2)
-
-
-def tensorflow_check_gather_nd_input_valid(params, indices, batch_dims):
-    if batch_dims >= len(params.shape):
-        raise Exception(
-            f"batch_dims = {batch_dims} must be less than rank(`params`) = {len(params.shape)}."
-        )
-    if batch_dims >= len(indices.shape):
-        raise Exception(
-            f"batch_dims = {batch_dims} must be less than rank(`indices`) = {len(indices.shape)}."
-        )
-    if tensorflow_get_item(
-        params.shape, slice(0, batch_dims, None)
-    ) != tensorflow_get_item(indices.shape, slice(0, batch_dims, None)):
-        raise Exception(
-            f"batch dimensions must match in `params` and `indices`; saw {tensorflow_get_item(params.shape, slice(0, batch_dims, None))} vs. {tensorflow_get_item(indices.shape, slice(0, batch_dims, None))}"
-        )
-    if indices.shape[-1] > len(
-        tensorflow_get_item(params.shape, slice(batch_dims, None, None))
-    ):
-        raise Exception(
-            f"index innermost dimension length must be <= rank(`params[batch_dims:]`); saw: {indices.shape[-1]} vs. {len(tensorflow_get_item(params.shape, slice(batch_dims, None, None)))} ."
-        )
-
-
-def tensorflow_gather_nd_helper(params, indices):
-    indices_shape = tensorflow.shape(indices)
-    params_shape = tensorflow.shape(params)
-    num_index_dims = indices_shape[-1]
-    result_dim_sizes_list = [
-        tensorflow.math.reduce_prod(params_shape[i + 1 :])
-        for i in range(len(params_shape) - 1)
-    ] + [1]
-    result_dim_sizes = tensorflow.convert_to_tensor(
-        result_dim_sizes_list, dtype=indices.dtype
-    )
-    implicit_indices_factor = result_dim_sizes[num_index_dims - 1]
-    flat_params = tensorflow.reshape(params, (-1,))
-    new_shape = [1] * (len(indices_shape) - 1) + [num_index_dims]
-    indices_scales = tensorflow.reshape(result_dim_sizes[0:num_index_dims], new_shape)
-    indices_for_flat_tiled = tensorflow.reshape(
-        tensorflow.reduce_sum(indices * indices_scales, -1, keepdims=True), (-1, 1)
-    )
-    indices_for_flat_tiled = tensorflow.repeat(
-        indices_for_flat_tiled, implicit_indices_factor, axis=1
-    )
-    implicit_indices = tensorflow.repeat(
-        tensorflow.expand_dims(tensorflow.range(implicit_indices_factor), 0),
-        indices_for_flat_tiled.shape[0],
-        axis=0,
-    )
-    indices_for_flat = indices_for_flat_tiled + implicit_indices
-    flat_indices_for_flat = tensorflow.reshape(indices_for_flat, (-1,))
-    flat_gather = tensorflow.gather(flat_params, flat_indices_for_flat)
-    res = tensorflow.reshape(
-        flat_gather,
-        tensorflow.concat([indices_shape[:-1], params_shape[num_index_dims:]], 0),
-    )
-    return res
-
-
-@tensorflow_handle_array_like_without_promotion
-def tensorflow_gather_nd(
-    params: Union[tensorflow.Tensor, tensorflow.Variable],
-    indices: Union[tensorflow.Tensor, tensorflow.Variable],
-    /,
-    *,
-    batch_dims: int = 0,
-    out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None,
-):
-    tensorflow_check_gather_nd_input_valid(params, indices, batch_dims)
-    try:
-        return tensorflow.gather_nd(params, indices, batch_dims=batch_dims)
-    except Exception:
-        batch_dims %= len(params.shape)
-        result = []
-        if batch_dims == 0:
-            result = tensorflow_gather_nd_helper(params, indices)
-        else:
-            for b in 
range(batch_dims): - if b == 0: - zip_list = list(zip(params, indices)) - else: - zip_list = [ - (p, i) - for z in [zip(p1, i1) for p1, i1 in zip_list] - for p, i in z - ] - for z in zip_list: - p, i = z[0], z[1] - r = tensorflow_gather_nd_helper(p, i) - result.append(r) - result = tensorflow.stack(result) - result = tensorflow.reshape( - result, - tensorflow.concat([params.shape[0:batch_dims], result.shape[1:]], 0), - ) - return result - - -def tensorflow__is_variable_bknd(x, exclusive=False, to_ignore=None): - x = x - return tensorflow_nested_map_bknd( - lambda x: tensorflow_is_variable(x, exclusive=exclusive), - x, - include_derived=True, - shallow=False, - to_ignore=to_ignore, - ) - - -def tensorflow_inplace_update( - x: Union[tensorflow.Tensor, tensorflow.Tensor], - val: Union[tensorflow.Tensor, tensorflow.Tensor], - /, - *, - ensure_in_backend: bool = False, - keep_input_dtype: bool = False, -): - if tensorflow_is_array_bknd(x) and tensorflow_is_array_bknd(val): - if keep_input_dtype: - val = tensorflow_astype(val, x.dtype) - (x_native, val_native), _ = (x, val), "_" - if tensorflow__is_variable_bknd(x_native): - x_native.assign(val_native) - if tensorflow_is_ivy_array_bknd(x): - x = x_native - else: - x = tensorflow.convert_to_tensor(x_native) - else: - x = x_native - return x - else: - return val - - -def tensorflow_scatter_nd( - indices: Union[tensorflow.Tensor, tensorflow.Variable], - updates: Union[tensorflow.Tensor, tensorflow.Variable], - /, - shape: Optional[Union[tf.TensorShape, Sequence[int]]] = None, - *, - reduction: str = "sum", - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - updates_dtype = updates.dtype - if tensorflow_exists_bknd(out): - dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) - updates = tensorflow.cast( - updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), - ) - expected_shape = ( - list(tensorflow.shape(indices)[:-1]) - + list(out.shape[tensorflow.shape(indices)[-1] :]) - if tensorflow_exists_bknd(out) - else list(tensorflow.shape(indices)[:-1]) - + list(shape[tensorflow.shape(indices)[-1] :]) - ) - updates = tensorflow__broadcast_to_bknd(updates, expected_shape) - if len(updates.shape) == 0: - indices = tensorflow.expand_dims(indices, 0) - updates = tensorflow.expand_dims(updates, 0) - target = out - target_given = tensorflow_exists_bknd(target) - if tensorflow_exists_bknd(shape) and target_given: - tensorflow_check_equal(tuple(target.shape), tuple(shape), as_array=False) - if not target_given: - shape = list(shape) if tensorflow_exists_bknd(shape) else list(out.shape) - target = tensorflow.zeros(shape, dtype=updates.dtype) - if reduction == "sum": - res = tensorflow.tensor_scatter_nd_add(target, indices, updates) - elif reduction == "min": - res = tensorflow.tensor_scatter_nd_min(target, indices, updates) - elif reduction == "max": - res = tensorflow.tensor_scatter_nd_max(target, indices, updates) - elif reduction == "mul": - updates = tensorflow_multiply(tensorflow_gather_nd(target, indices), updates) - res = tensorflow.tensor_scatter_nd_update(target, indices, updates) - elif reduction == "replace": - res = tensorflow.tensor_scatter_nd_update(target, indices, updates) - else: - raise Exception( - f'reduction is {reduction}, but it must be one of "sum", "min", "max", "mul" or "replace"' - ) - if tensorflow_exists_bknd(out): - return tensorflow_inplace_update(out, res) - return res - - -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def 
wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - -@tensorflow_handle_set_item -def tensorflow_set_item_bknd( - x: Union[tensorflow.Tensor, tf.Tensor], - query: Union[tensorflow.Tensor, tf.Tensor, Tuple], - val: Union[tensorflow.Tensor, tf.Tensor], - /, - *, - copy: Optional[bool] = False, -): - if isinstance(query, (list, tuple)) and any( - [(q is Ellipsis or isinstance(q, slice) and q.stop is None) for q in query] - ): - x_stop_gradient = tensorflow_stop_gradient(x, preserve_type=False) - np_array = x_stop_gradient.numpy() - val_stop_gradient = tensorflow_stop_gradient(val, preserve_type=False) - np_array = tensorflow_set_item_bknd( - np_array, query, np.asarray(val_stop_gradient) - ) - return tensorflow_asarray(np_array) - if copy: - x = tensorflow_copy_array(x) - if not tensorflow_is_array_bknd(val): - val = tensorflow_asarray(val) - if 0 in x.shape or 0 in val.shape: - return x - if tensorflow_is_array_bknd(query) and tensorflow_is_bool_dtype_bknd(query): - if not len(query.shape): - query = tensorflow_tile(query, (x.shape[0],)) - indices = tensorflow_nonzero(query, as_tuple=False) - else: - indices, target_shape, _ = tensorflow__parse_query_bknd( - query, tensorflow_shape(x, as_array=True), scatter=True - ) - if indices is None: - return x - val = tensorflow_astype_bknd_(val, x.dtype) - ret = tensorflow_scatter_nd(indices, val, reduction="replace", out=x) - return ret - - -def tensorflow__reverse_repeat_tuple(t, n): - return tuple(x for x in reversed(t) for _ in range(n)) - - -def tensorflow_empty_frnt( - *args, - size=None, - out=None, - dtype=None, - layout=None, - device=None, - requires_grad=False, - pin_memory=False, - memory_format=None, -): - if args and size: - raise TypeError("empty() got multiple values for argument 'shape'") - if size is None: - size = ( - args[0] - if isinstance(args[0], (tuple, list, tuple, tf.TensorShape)) - else args - ) - if isinstance(size, (tuple, list)): - size = tuple( - tensorflow_to_scalar_bknd_(s) if tensorflow_is_array_bknd(s) else s - for s in size - ) - return tensorflow_empty(shape=size, dtype=dtype, device=device, out=out) - - -def tensorflow_store_config_info(fn): - @functools.wraps(fn) - def wrapper(self, *args, **kwargs): - fn(self, *args, **kwargs) - if all( - [ - hasattr(self, "_args"), - hasattr(self, "_kwargs"), - hasattr(self, "_self_tracked_trackables"), - ] - ): - orig_trackables = copy.copy(self._self_tracked_trackables) - self._args = (self,) + args - self._kwargs = kwargs - self._self_tracked_trackables = orig_trackables - - return wrapper - - -def tensorflow_ndim_bknd_(self): - return len(tuple(self.shape)) - - -def tensorflow_dim_frnt_(tensor): - return tensorflow_ndim_bknd_(tensor) - - -def tensorflow_size_frnt_(tensor, dim=None): - shape = tensor.shape - if dim is None: - return shape - try: - return tensorflow_get_item(shape, dim) - except IndexError as e: - raise IndexError( - f"Dimension out of range (expected to be in range of [{len(shape)}, {len(shape) - 1}], but got {dim}" - ) from e - - -def tensorflow__calculate_fan_in_and_fan_out(tensor): - dimensions = tensorflow_dim_frnt_(tensor) - if dimensions < 2: - raise ValueError( - "Fan in and fan out can not be computed for tensor with fewer than 2 dimensions" - ) - num_input_fmaps = tensorflow_size_frnt_(tensor, 1) - num_output_fmaps = tensorflow_size_frnt_(tensor, 0) - receptive_field_size = 1 - if 
tensorflow_dim_frnt_(tensor) > 2: - for s in tensor.shape[2:]: - receptive_field_size = receptive_field_size * s - fan_in = num_input_fmaps * receptive_field_size - fan_out = num_output_fmaps * receptive_field_size - return fan_in, fan_out - - -def tensorflow__calculate_correct_fan(tensor, mode): - mode = mode.lower() - valid_modes = ["fan_in", "fan_out"] - if mode not in valid_modes: - raise ValueError(f"Mode {mode} not supported, please use one of {valid_modes}") - fan_in, fan_out = tensorflow__calculate_fan_in_and_fan_out(tensor) - return fan_in if mode == "fan_in" else fan_out - - -def tensorflow_calculate_gain(nonlinearity, param=None): - linear_fns = [ - "linear", - "conv1d", - "conv2d", - "conv3d", - "conv_transpose1d", - "conv_transpose2d", - "conv_transpose3d", - ] - if nonlinearity in linear_fns or nonlinearity == "sigmoid": - return 1 - elif nonlinearity == "tanh": - return 5.0 / 3 - elif nonlinearity == "relu": - return math.sqrt(2.0) - elif nonlinearity == "leaky_relu": - if param is None: - negative_slope = 0.01 - elif ( - not isinstance(param, bool) - and isinstance(param, int) - or isinstance(param, float) - ): - negative_slope = param - else: - raise ValueError(f"negative_slope {param} not a valid number") - return math.sqrt(2.0 / (1 + negative_slope**2)) - elif nonlinearity == "selu": - return 3.0 / 4 - else: - raise ValueError(f"Unsupported nonlinearity {nonlinearity}") - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_all( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - axis: Optional[Union[int, Sequence[int]]] = None, - keepdims: bool = False, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if axis is None: - num_dims = len(x.shape) - axis = tuple(range(num_dims)) - elif isinstance(axis, list): - axis = tuple(axis) - try: - return tensorflow.reduce_all( - tensorflow.cast(x, tensorflow.bool), axis=axis, keepdims=keepdims - ) - except tensorflow.errors.InvalidArgumentError as e: - raise Exception(e) from e - - -def tensorflow_check_all(results, message="one of the args is False", as_array=True): - if as_array and not tensorflow_all(results) or not as_array and not all(results): - raise Exception(message) - - -def tensorflow_check_all_or_any_fn( - *args, - fn, - type="all", - limit=(0,), - message="args must exist according to type and limit given", - as_array=True, -): - if type == "all": - tensorflow_check_all([fn(arg) for arg in args], message, as_array=as_array) - elif type == "any": - count = 0 - for arg in args: - count = count + 1 if fn(arg) else count - if count not in limit: - raise Exception(message) - else: - raise Exception("type must be all or any") - - -def tensorflow__check_bounds_and_get_shape_bknd(low, high, shape): - if shape is not None: - tensorflow_check_all_or_any_fn( - low, - high, - fn=lambda x: isinstance(x, (int, float)), - type="all", - message="low and high bounds must be numerics when shape is specified", - ) - return tuple(shape) - valid_types = (tensorflow.Tensor,) - if len(backend_stack) == 0: - valid_types = valid_types + (tf.Tensor,) - else: - valid_types = valid_types + (tf.Tensor,) - if isinstance(low, valid_types): - if isinstance(high, valid_types): - tensorflow_check_equal( - tensorflow_shape(low), tensorflow_shape(high), as_array=False - ) - return tensorflow_shape(low) - if isinstance(high, valid_types): - return tensorflow_shape(high) - return tuple(()) - - -@tensorflow_infer_dtype -def tensorflow_random_uniform( - *, - low: Union[float, tensorflow.Tensor, 
tensorflow.Variable] = 0.0, - high: Union[float, tensorflow.Tensor, tensorflow.Variable, None] = 1.0, - shape: Optional[Union[tf.TensorShape, Sequence[int], tensorflow.Tensor]] = None, - dtype: tf.DType, - device: Optional[str] = None, - seed: Optional[int] = None, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - shape = tensorflow__check_bounds_and_get_shape_bknd( - low, - ( - float( - tensorflow.experimental.numpy.finfo(tensorflow.float32).max - if dtype is None - else tensorflow.experimental.numpy.finfo(dtype).max - ) - if high is None - else high - ), - shape, - ) - low = tensorflow.cast(low, dtype) - if high is not None: - high = tensorflow.cast(high, dtype) - if seed: - tensorflow.random.set_seed(seed) - return tensorflow.random.uniform(shape, low, high, dtype=dtype, seed=seed) - - -def tensorflow_uniform__frnt_(tensor, from_=0, to=1, *, generator=None): - ret = tensorflow_random_uniform( - low=from_, high=to, shape=tensor.shape, dtype=tensor.dtype, seed=generator - ) - tensor = tensorflow_inplace_update(tensor, tensorflow_astype(ret, tensor.dtype)) - return tensor - - -def tensorflow_kaiming_uniform_( - tensor, a=0, mode="fan_in", nonlinearity="leaky_relu", generator=None -): - if 0 in tensor.shape: - warnings.warn("Initializing zero-element tensors is a no-op") - return tensor - fan = tensorflow__calculate_correct_fan(tensor, mode) - gain = tensorflow_calculate_gain(nonlinearity, a) - std = gain / math.sqrt(fan) - bound = math.sqrt(3.0) * std - return tensorflow_uniform__frnt_(tensor, -bound, bound, generator=generator) - - -def tensorflow__no_grad_uniform_(tensor, a, b, generator=None): - return tensorflow_uniform__frnt_(tensor, a, b, generator=generator) - - -def tensorflow_uniform_(tensor, a=0.0, b=1.0, generator=None): - return tensorflow__no_grad_uniform_(tensor, a, b, generator) - - -def tensorflow_handle_methods_1(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - -@tensorflow_handle_methods_1 -def tensorflow_split_frnt(tensor, split_size_or_sections, dim=0): - if isinstance(split_size_or_sections, int): - split_size = split_size_or_sections - split_size_or_sections = [split_size] * ( - tensorflow_get_item(tensor.shape, dim) // split_size - ) - if tensorflow_get_item(tensor.shape, dim) % split_size: - split_size_or_sections.append( - tensorflow_get_item(tensor.shape, dim) % split_size - ) - return tuple( - tensorflow_split( - tensor, - num_or_size_splits=split_size_or_sections, - axis=dim, - with_remainder=True, - ) - ) - - -@tensorflow_handle_methods_1 -def tensorflow_split_frnt_(tensor, split_size, dim=0): - return tensorflow_split_frnt(tensor, split_size, dim) - - -@tensorflow_handle_methods -def tensorflow_add( - x1: Union[float, tensorflow.Tensor, tensorflow.Variable], - x2: Union[float, tensorflow.Tensor, tensorflow.Variable], - /, - *, - alpha: Optional[Union[int, float]] = None, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - oirg_x1 = x1 - oirg_x2 = x2 - try: - dtype = ( - x1.dtype - if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() - ) - if not 
tensorflow_is_array_bknd(x1): - x1 = tensorflow_asarray(x1, dtype=dtype) - if not tensorflow_is_array_bknd(x2): - x2 = tensorflow_asarray(x2, dtype=dtype) - except: - x1 = oirg_x1 - x2 = oirg_x2 - if x1.dtype.is_bool and x2.dtype.is_bool: - return tensorflow.math.logical_or(x1, x2) - if alpha not in (1, None): - x2 = tensorflow_multiply(x2, alpha) - return tensorflow.add(x1, x2) - - -@tensorflow_handle_methods_1 -def tensorflow_add_frnt(input, other, *, alpha=1, out=None): - return tensorflow_add(input, other, alpha=alpha, out=out) - - -@tensorflow_handle_methods_1 -def tensorflow_add_frnt_(tensor, other, *, alpha=1): - return tensorflow_add_frnt(tensor, other, alpha=alpha) - - -def tensorflow__get_x_data_format_bknd( - dims: int = 2, data_format: str = "channel_first" -): - if dims == 1: - if data_format == "channel_first": - return "NCW" - else: - return "NWC" - if dims == 2: - if data_format == "channel_first": - return "NCHW" - else: - return "NHWC" - elif dims == 3: - if data_format == "channel_first": - return "NCDHW" - else: - return "NDHWC" - - -def tensorflow__x_dil_before_conv(x, dims, x_dilations, data_format): - x_dilations = [x_dilations] * dims if isinstance(x_dilations, int) else x_dilations - ag__result_list_0 = [] - for i, x_dil in enumerate(x_dilations): - if x_dil > 1: - res = i - ag__result_list_0.append(res) - x_dilations_idxs = ag__result_list_0 - if x_dilations_idxs: - if data_format[-1] == "C": - offset = 1 - else: - offset = 2 - for i in x_dilations_idxs: - h = x.shape[offset + i] - new_height = h + (h - 1) * (x_dilations[i] - 1) - h = tensorflow.eye(new_height, dtype=x.dtype)[:: x_dilations[i]] - x = tensorflow.experimental.numpy.swapaxes(x, offset + i, -1) - x = tensorflow.matmul(x, h) - x = tensorflow.experimental.numpy.swapaxes(x, -1, offset + i) - return x - - -def tensorflow__extend_2d_padding(padding, data_format): - if isinstance(padding, str): - return padding - if isinstance(padding, int): - padding = [(padding, padding)] * 2 - if data_format[-1] == "C": - padding = [(0, 0)] + padding + [(0, 0)] - else: - padding = [(0, 0), (0, 0)] + padding - return padding - - -def tensorflow_depthwise_conv2d( - x: Union[tensorflow.Tensor, tensorflow.Variable], - filters: Union[tensorflow.Tensor, tensorflow.Variable], - strides: Union[int, Tuple[int, int]], - padding: Union[str, int, Sequence[Tuple[int, int]]], - /, - *, - data_format: str = "NHWC", - dilations: Union[int, Tuple[int, int]] = 1, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - strides = [strides] * 2 if isinstance(strides, int) else strides - dilations = [dilations] * 2 if isinstance(dilations, int) else dilations - permuted_x = False - if data_format == "NCHW" and tensorflow_dev(x) == "cpu": - x = tensorflow.transpose(x, (0, 2, 3, 1)) - data_format = "NHWC" - permuted_x = True - if tensorflow.rank(filters) == 3: - filters = tensorflow.expand_dims(filters, -1) - padding = tensorflow__extend_2d_padding(padding, data_format) - strides = [1, strides[0], strides[1], 1] - res = tensorflow.nn.depthwise_conv2d( - x, filters, strides, padding, data_format, dilations - ) - if permuted_x: - res = tensorflow.transpose(res, (0, 3, 1, 2)) - return res - - -def tensorflow__pad_before_conv(x, padding, dims, data_format): - if isinstance(padding, str): - return x, padding - elif isinstance(padding, int): - pad_list = [(padding, padding)] * dims - else: - pad_list = padding - if data_format[-1] == "C": - pad_list = [(0, 0), *pad_list, (0, 0)] - else: - pad_list = [(0, 0), (0, 0), *pad_list] - 
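-    # (editorial note) Numeric padding is materialized here with tf.pad, and
-    # the convolution that follows is told "VALID" so it does not pad again.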
return tensorflow.pad(x, pad_list, "CONSTANT"), "VALID" - - -def tensorflow__extend_3d_strides_dilations(strides, dilations, data_format): - if data_format[-1] == "C": - strides = [1, *([strides] * 3 if isinstance(strides, int) else strides), 1] - dilations = [ - 1, - *([dilations] * 3 if isinstance(dilations, int) else dilations), - 1, - ] - else: - strides = [1, 1, *([strides] * 3 if isinstance(strides, int) else strides)] - dilations = [ - 1, - 1, - *([dilations] * 3 if isinstance(dilations, int) else dilations), - ] - return strides, dilations - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_conv_general_dilated( - x: Union[tensorflow.Tensor, tensorflow.Variable], - filters: Union[tensorflow.Tensor, tensorflow.Variable], - strides: Union[int, Tuple[int], Tuple[int, int], Tuple[int, int, int]], - padding: Union[str, int, Sequence[Tuple[int, int]]], - /, - *, - dims: int = 2, - data_format: str = "channel_last", - filter_format: str = "channel_last", - feature_group_count: int = 1, - x_dilations: Union[int, Tuple[int], Tuple[int, int], Tuple[int, int, int]] = 1, - dilations: Union[int, Tuple[int], Tuple[int, int], Tuple[int, int, int]] = 1, - bias: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if filter_format == "channel_first": - filters = tensorflow.transpose(filters, (*range(2, dims + 2), 1, 0)) - num_channels = x.shape[1] if data_format == "channel_first" else x.shape[-1] - if filters.shape[-2] != num_channels // feature_group_count: - raise Exception( - f"given feature_group_count {feature_group_count} expected input channel of the filter to be {num_channels // feature_group_count} but got {filters.shape[-2]}" - ) - if num_channels % feature_group_count != 0: - raise Exception( - f"input channel should be divisible by feature group count {feature_group_count} but got input channel {num_channels}" - ) - permuted_x = False - if data_format == "channel_first" and ( - tensorflow_dev(x) == "cpu" or feature_group_count != 1 - ): - x = tensorflow.transpose(x, (0, *range(2, dims + 2), 1)) - data_format = "channel_last" - permuted_x = True - data_format = tensorflow__get_x_data_format_bknd(dims, data_format) - x = tensorflow__x_dil_before_conv(x, dims, x_dilations, data_format) - if dims == 2: - padding = tensorflow__extend_2d_padding(padding, data_format) - if feature_group_count == 1: - res = tensorflow.nn.conv2d( - x, - filters, - strides, - padding, - data_format=data_format, - dilations=dilations, - ) - else: - if not isinstance(padding, str): - padding = padding[1:-1] - res = tensorflow_depthwise_conv2d( - x, - tensorflow.transpose(filters, (0, 1, 3, 2)), - strides, - padding, - data_format=data_format, - dilations=dilations, - ) - else: - x, padding = tensorflow__pad_before_conv(x, padding, dims, data_format) - if dims == 1: - if feature_group_count == 1: - res = tensorflow.nn.conv1d( - x, - filters, - strides, - padding, - data_format=data_format, - dilations=dilations, - ) - else: - res = tensorflow.concat( - [ - tensorflow.nn.conv1d( - x[..., i : i + filters.shape[-2]], - filters[ - ..., j : j + filters.shape[-1] // feature_group_count - ], - strides, - padding, - data_format, - dilations, - ) - for i, j in zip( - range(0, x.shape[-1], filters.shape[-2]), - range( - 0, - filters.shape[-1], - filters.shape[-1] // feature_group_count, - ), - ) - ], - axis=-1, - ) - else: - strides, dilations = tensorflow__extend_3d_strides_dilations( - strides, dilations, data_format - 
) - if feature_group_count == 1: - res = tensorflow.nn.conv3d( - x, - filters, - strides, - padding, - data_format=data_format, - dilations=dilations, - ) - else: - res = tensorflow.concat( - [ - tensorflow.nn.conv3d( - x[..., i : i + filters.shape[-2]], - filters[ - ..., j : j + filters.shape[-1] // feature_group_count - ], - strides, - padding, - data_format, - dilations, - ) - for i, j in zip( - range(0, x.shape[-1], filters.shape[-2]), - range( - 0, - filters.shape[-1], - filters.shape[-1] // feature_group_count, - ), - ) - ], - axis=-1, - ) - if bias is not None: - if data_format[1] == "C": - bias = tensorflow.reshape(bias, [1, -1, *([1] * dims)]) - res = tensorflow.math.add(res, bias) - if permuted_x: - return tensorflow.transpose(res, (0, dims + 1, *range(1, dims + 1))) - return res - - -def tensorflow__conv_frnt( - input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1 -): - dims = len(input.shape) - 2 - if isinstance(padding, str): - padding = padding.upper() - elif isinstance(padding, int): - padding = [*[(padding, padding) for _ in range(dims)]] - else: - padding = [*[(p, p) for p in padding]] - ret = tensorflow_conv_general_dilated( - input, - weight, - stride, - padding, - dims=dims, - data_format="channel_last", - filter_format="channel_last", - dilations=dilation, - feature_group_count=groups, - bias=bias, - ) - return ret - - -def tensorflow_conv2d_frnt( - input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1 -): - return tensorflow__conv_frnt( - input, - weight, - bias=bias, - stride=stride, - padding=padding, - dilation=dilation, - groups=groups, - ) - - -def tensorflow__handle_padding_shape_frnt(padding, n, mode): - ag__result_list_0 = [] - for i in range(int(len(padding) / 2) - 1, -1, -1): - res = ( - tensorflow_get_item(padding, i * 2), - tensorflow_get_item(padding, i * 2 + 1), - ) - ag__result_list_0.append(res) - padding = tuple(ag__result_list_0) - if mode == "circular": - padding = padding + ((0, 0),) * (n - len(padding)) - else: - padding = ((0, 0),) * (n - len(padding)) + padding - if mode == "circular": - padding = tuple(list(padding)[::-1]) - return padding - - -def tensorflow__to_tf_padding_bknd(pad_width, ndim): - if isinstance(pad_width, Number): - pad_width = [[pad_width] * 2] * ndim - elif len(pad_width) == 2 and isinstance(pad_width[0], Number): - pad_width = [pad_width] * ndim - elif ( - isinstance(pad_width, (list, tuple)) - and isinstance(pad_width[0], (list, tuple)) - and len(pad_width) < ndim - ): - pad_width = pad_width * ndim - return pad_width - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_pad( - input: Union[tensorflow.Tensor, tensorflow.Variable], - pad_width: Union[Iterable[Tuple[int]], int], - /, - *, - mode: Union[ - Literal[ - "constant", - "dilated", - "edge", - "linear_ramp", - "maximum", - "mean", - "median", - "minimum", - "reflect", - "symmetric", - "wrap", - "empty", - ], - Callable, - ] = "constant", - stat_length: Union[Iterable[Tuple[int]], int] = 1, - constant_values: Union[Iterable[Tuple[Number]], Number] = 0, - end_values: Union[Iterable[Tuple[Number]], Number] = 0, - reflect_type: Literal["even", "odd"] = "even", - **kwargs: Optional[Any], -): - pad_width = tensorflow__to_tf_padding_bknd(pad_width, len(input.shape)) - if not isinstance(constant_values, (tensorflow.Variable, tensorflow.Tensor)): - constant_values = tensorflow.constant(constant_values) - if constant_values.dtype != input.dtype: - constant_values = tensorflow.cast(constant_values, input.dtype) - return 
tensorflow.pad(input, pad_width, mode=mode, constant_values=constant_values) - - -def tensorflow_pad_frnt(input, pad, mode="constant", value=0): - if any([(pad_value < 0) for pad_value in pad]): - pad = list(pad) - slices = [] - for n in reversed(range(len(pad) // 2)): - i = n * 2 - j = i + 1 - start = None - stop = None - if tensorflow_get_item(pad, i) < 0: - start = -tensorflow_get_item(pad, i) - pad = tensorflow_set_item_bknd(pad, i, 0) - if tensorflow_get_item(pad, j) < 0: - stop = tensorflow_get_item(pad, j) - pad = tensorflow_set_item_bknd(pad, j, 0) - slices.append(slice(start, stop)) - ndim = len(input.shape) - while len(slices) < ndim: - slices.insert(0, slice(None)) - input = tensorflow_get_item(input, tuple(slices)) - value = 0 if value is None else value - mode_dict = { - "constant": "constant", - "reflect": "reflect", - "replicate": "edge", - "circular": "wrap", - } - if mode not in mode_dict: - raise ValueError(f"Unsupported padding mode: {mode}") - pad = tensorflow__handle_padding_shape_frnt(pad, len(input.shape), mode) - order = 0, 2, 3, 1 - pad = tuple(pad[i] for i in order) - return tensorflow_pad( - input, pad, mode=tensorflow_get_item(mode_dict, mode), constant_values=value - ) - - -def tensorflow_retrieve_object(frame, name): - if name is None: - return name - names = tensorflow_split_bknd_(name, ".") - obj = frame.f_locals.get(names[0]) or frame.f_globals.get(names[0]) - if obj is None: - return None - for attr in names[1:]: - try: - obj = getattr(obj, attr) - except AttributeError: - return None - return obj - - -def tensorflow_get_next_func(obj): - from .tensorflow_CallVisitor import tensorflow_CallVisitor - - stack = inspect.stack() - for frame_info in stack: - if frame_info == obj._previous_frame_info: - calling_frame = frame_info.frame - break - else: - return None - if "Sequential" in frame_info.filename: - try: - self_seq = calling_frame.f_locals["self"] - idx = calling_frame.f_locals["i"] - next_func = tensorflow_get_item(self_seq, idx + 1) - return next_func - except IndexError: - for frame_info in tensorflow_get_item( - stack, slice(stack.index(frame_info) + 1, None, None) - ): - if frame_info == self_seq._previous_frame_info: - calling_frame = frame_info.frame - break - else: - return None - lines, start_line_no = inspect.getsourcelines(calling_frame) - current_line_no = calling_frame.f_lineno - relative_line_no = current_line_no - start_line_no - try: - next_line = tensorflow_get_item(lines, relative_line_no + 1).strip() - tree = ast.parse(next_line) - visitor = tensorflow_CallVisitor() - visitor.visit(tree) - next_call_str = visitor.func_name - except Exception: - next_call_str = "" - next_func = tensorflow_retrieve_object(calling_frame, next_call_str) - return next_func - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_permute_dims( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - axes: Tuple[int, ...], - *, - copy: Optional[bool] = None, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - return tensorflow.transpose(x, perm=axes) - - -def tensorflow_apply_transpose(input, transpose, pt_to_tf=True): - from .tensorflow_TransposeType import tensorflow_TransposeType - - if transpose is tensorflow_TransposeType.NO_TRANSPOSE: - return input - if transpose is tensorflow_TransposeType.CONV1D: - axes = (0, 2, 1) if pt_to_tf else (0, 2, 1) - elif transpose is tensorflow_TransposeType.CONV2D: - axes = (0, 2, 3, 1) if pt_to_tf else (0, 3, 1, 2) - elif transpose is tensorflow_TransposeType.CONV3D: - axes = (0, 2, 
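# Hedged sketch, standalone: the axes chosen by tensorflow_apply_transpose
# map PyTorch channel-first layouts to TensorFlow channel-last and back.
#
#   x = tensorflow.zeros((8, 3, 32, 32))                  # NCHW
#   tensorflow_permute_dims(x, axes=(0, 2, 3, 1)).shape   # -> (8, 32, 32, 3)
#   # and (0, 3, 1, 2) is the inverse permutation, restoring NCHW.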
3, 4, 1) if pt_to_tf else (0, 4, 1, 2, 3) - input = tensorflow_permute_dims(input, axes=axes) - return input - - -def tensorflow_handle_transpose_in_input_and_output(fn): - from .tensorflow_TransposeType import tensorflow_TransposeType - - original_signature = inspect.signature(fn) - - @functools.wraps(fn) - def transpose_wrapper(self, *args, **kwargs): - global DATA_FORMAT - kwargs_call = { - key: val - for key, val in kwargs.items() - if key not in dict(original_signature.parameters) - } - fn_args_and_kwargs = { - key: val for key, val in kwargs.items() if key not in kwargs_call - } - fn_args_and_kwargs.update(dict(zip(fn.__code__.co_varnames[1:], args))) - conv_block_start = lambda f: any( - substr in f.__qualname__ - for substr in CONV_FUNCS - + NORM_FUNCS - + POOL_FUNCS - + KERAS_CONV_FUNCS - + KERAS_NORM_FUNCS - + KERAS_POOL_FUNCS - ) - next_call_in_seq = tensorflow_get_next_func(self) - name_of_next_call = ( - next_call_in_seq.__class__.__name__ - if hasattr(next_call_in_seq, "__class__") - else "" - ) - conv_block_continued = next_call_in_seq and any( - substr in name_of_next_call for substr in CONV_BLOCK_FNS - ) - if DATA_FORMAT == "PT" and conv_block_start(self.__class__): - input = fn_args_and_kwargs["input"] - if len(input.shape) > 4: - transpose = tensorflow_TransposeType.CONV3D - elif len(input.shape) > 3: - transpose = tensorflow_TransposeType.CONV2D - elif len(input.shape) > 2: - transpose = tensorflow_TransposeType.CONV1D - else: - transpose = tensorflow_TransposeType.NO_TRANSPOSE - fn_args_and_kwargs = tensorflow_set_item_bknd( - fn_args_and_kwargs, - "input", - tensorflow_apply_transpose(input, transpose=transpose, pt_to_tf=True), - ) - DATA_FORMAT = "TF" - os.environ = tensorflow_set_item_bknd( - os.environ, "DATA_FORMAT", "channels_last" - ) - res = fn(self, **fn_args_and_kwargs) - if DATA_FORMAT == "TF" and conv_block_continued or DATA_FORMAT == "PT": - return res - if len(res.shape) > 4: - transpose = tensorflow_TransposeType.CONV3D - elif len(res.shape) > 3: - transpose = tensorflow_TransposeType.CONV2D - elif len(res.shape) > 2: - transpose = tensorflow_TransposeType.CONV1D - else: - transpose = tensorflow_TransposeType.NO_TRANSPOSE - res = tensorflow_apply_transpose(res, transpose=transpose, pt_to_tf=False) - DATA_FORMAT = "PT" - os.environ = tensorflow_set_item_bknd( - os.environ, "DATA_FORMAT", "channels_first" - ) - return res - - tensorflow_handle_transpose_in_input_and_output.__signature__ = original_signature - return transpose_wrapper diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_Conv2d_output/run_0/tensorflow__stateful.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_Conv2d_output/run_0/tensorflow__stateful.py deleted file mode 100644 index dbad1e919ab1..000000000000 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_Conv2d_output/run_0/tensorflow__stateful.py +++ /dev/null @@ -1,1799 +0,0 @@ -# global -from __future__ import annotations -import re -import os -import tensorflow as tf -import functools -from tensorflow.python.util import nest -from typing import NamedTuple, Callable, Any, Tuple, List, Dict, Type, Union -import inspect -from collections import OrderedDict -from packaging.version import parse -import keras - - -def get_assignment_dict(): - # Traverse the call stack - lhs = None - for frame_info in inspect.stack(): - # Check if the code context is an assignment statement - if frame_info.code_context and "=" in frame_info.code_context[0]: - # Split the assignment and retrieve the LHS - lhs = 
frame_info.code_context[0].split("=")[0].strip() - if "self" not in lhs: - continue - break - - if not lhs: - return None, "" - - # Replace indexing with attribute access - lhs = re.sub(r"\[(\d+)\]", r".\1", lhs) - - # Split the LHS based on "." and get individual components - components = lhs.split(".") - - # Initialize the dictionary - assignment_dict = {} - - # Retrieve the live objects associated with each component - for i in range(len(components)): - # Construct the key - key = ".".join(components[: i + 1]) - - # Retrieve the value - if i == 0: - value = frame_info.frame.f_locals.get(components[i]) - else: - value = getattr(assignment_dict[".".join(components[:i])], components[i]) - - # Add the key-value pair to the dictionary - assignment_dict[key] = value - - return assignment_dict, lhs - - -def store_frame_info(fn): - @functools.wraps(fn) - def frame_info_wrapper(self, *args, **kwargs): - if self._previous_frame_info is None: - # store the info about the calling frame. - stack = inspect.stack() - self._previous_frame_info = stack[1] - res = fn(self, *args, **kwargs) - # reset the frame-info - self._previous_frame_info = None - return res - - return frame_info_wrapper - - -# A NodeDef holds two callables: -# - flatten_fn should take the collection and return a flat list of values. -# It can also return some context that is used in reconstructing the -# collection. -# - unflatten_fn should take a flat list of values and some context -# (returned by flatten_fn). It returns the collection by reconstructing -# it from the list and the context. -Context = Any -PyTree = Any -FlattenFunc = Callable[[PyTree], Tuple[List, Context]] -UnflattenFunc = Callable[[List, Context], PyTree] - - -class NodeDef(NamedTuple): - flatten_fn: FlattenFunc - unflatten_fn: UnflattenFunc - - -SUPPORTED_NODES: Dict[Type[Any], NodeDef] = {} - - -def _register_pytree_node( - typ: Any, flatten_fn: FlattenFunc, unflatten_fn: UnflattenFunc -) -> None: - SUPPORTED_NODES[typ] = NodeDef(flatten_fn, unflatten_fn) - - -def _dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]: - return list(d.values()), list(d.keys()) - - -def _dict_unflatten(values: List[Any], context: Context) -> Dict[Any, Any]: - return {key: value for key, value in zip(context, values)} - - -_register_pytree_node(dict, _dict_flatten, _dict_unflatten) - -if parse(keras.__version__).major > 2: - _register_pytree_node( - keras.src.utils.tracking.TrackedDict, _dict_flatten, _dict_unflatten - ) - - -def _get_node_type(pytree: Any) -> Any: - return type(pytree) - - -# A leaf is defined as anything that is not a Node. -def _is_leaf(pytree: PyTree) -> bool: - return _get_node_type(pytree) not in SUPPORTED_NODES.keys() - - -# A TreeSpec represents the structure of a pytree. 
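# Illustrative, standalone: the two registered callables round-trip a dict
# through (values, context) form.
#
#   _dict_flatten({"w": 1, "b": 2})      # -> ([1, 2], ["w", "b"])
#   _dict_unflatten([1, 2], ["w", "b"])  # -> {"w": 1, "b": 2}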
It holds: -# "type": the type of root Node of the pytree -# context: some context that is useful in unflattening the pytree -# children_specs: specs for each child of the root Node -# num_leaves: the number of leaves -class TreeSpec: - def __init__(self, type, context, children_specs): - self.type: Any = type - self.context: Context = context - self.children_specs: List["TreeSpec"] = children_specs - self.num_leaves: int = sum([spec.num_leaves for spec in self.children_specs]) - - def get_keychains(self, prefix="", sep="/"): - keychains = [] - for key, child_spec in zip(self.context, self.children_specs): - new_prefix = prefix + key + sep if prefix else key + sep - if child_spec.children_specs: # Non-leaf node - keychains.extend(child_spec.get_keychains(new_prefix, sep)) - else: # Leaf node - keychains.append(new_prefix[: -len(sep)]) - return keychains - - def __repr__(self, indent: int = 0) -> str: - repr_prefix: str = f"TreeSpec({self.type.__name__}, {self.context}, [" - children_specs_str: str = "" - if len(self.children_specs): - indent += len(repr_prefix) - children_specs_str += self.children_specs[0].__repr__(indent) - children_specs_str += "," if len(self.children_specs) > 1 else "" - children_specs_str += ",".join( - [ - "\n" + " " * indent + child.__repr__(indent) - for child in self.children_specs[1:] - ] - ) - repr_suffix: str = f"{children_specs_str}])" - return repr_prefix + repr_suffix - - -class LeafSpec(TreeSpec): - def __init__(self) -> None: - super().__init__(None, None, []) - self.num_leaves = 1 - - def __repr__(self, indent: int = 0) -> str: - return "*" - - -def tree_flatten(pytree: PyTree) -> Tuple[List[Any], TreeSpec]: - """Flattens a pytree into a list of values and a TreeSpec that can be used - to reconstruct the pytree.""" - if _is_leaf(pytree): - return [pytree], LeafSpec() - - node_type = _get_node_type(pytree) - flatten_fn = _dict_flatten - child_pytrees, context = flatten_fn(pytree) - - # Recursively flatten the children - result: List[Any] = [] - children_specs: List["TreeSpec"] = [] - for child in child_pytrees: - flat, child_spec = tree_flatten(child) - result += flat - children_specs.append(child_spec) - - return result, TreeSpec(node_type, context, children_specs) - - -def tree_unflatten(values: List[Any], spec: TreeSpec) -> PyTree: - """Given a list of values and a TreeSpec, builds a pytree. - - This is the inverse operation of `tree_flatten`. - """ - if not isinstance(spec, TreeSpec): - raise TypeError( - f"tree_unflatten(values, spec): Expected `spec` to be instance of " - f"TreeSpec but got item of type {type(spec)}." - ) - if len(values) != spec.num_leaves: - raise TypeError( - f"tree_unflatten(values, spec): `values` has length {len(values)} " - f"but the spec refers to a pytree that holds {spec.num_leaves} " - f"items ({spec})." 
- ) - if isinstance(spec, LeafSpec): - return values[0] - - unflatten_fn = _dict_unflatten - - # Recursively unflatten the children - start = 0 - end = 0 - child_pytrees = [] - for child_spec in spec.children_specs: - end += child_spec.num_leaves - child_pytrees.append(tree_unflatten(values[start:end], child_spec)) - start = end - - return unflatten_fn(child_pytrees, spec.context) - - -def serialize_obj(obj): - if inspect.isclass(obj) or isinstance(obj, type): - return {"cls_module": obj.__module__, "cls_name": obj.__name__} - return obj - - -def recursive_serialize(d): - if isinstance(d, dict): - return {k: recursive_serialize(v) for k, v in d.items()} - elif isinstance(d, list): - return [recursive_serialize(v) for v in d] - elif isinstance(d, tuple): - return tuple(recursive_serialize(v) for v in d) - else: - return serialize_obj(d) - - -def deserialize_obj(serialized): - if ( - isinstance(serialized, dict) - and "cls_module" in serialized - and "cls_name" in serialized - ): - module = __import__(serialized["cls_module"], fromlist=[serialized["cls_name"]]) - cls = getattr(module, serialized["cls_name"]) - return cls - return serialized - - -def recursive_deserialize(d): - if isinstance(d, dict) and "cls_module" not in d: - return {k: recursive_deserialize(v) for k, v in d.items()} - elif isinstance(d, list): - return [recursive_deserialize(v) for v in d] - elif isinstance(d, tuple): - return tuple(recursive_serialize(v) for v in d) - else: - return deserialize_obj(d) - - -class ModelHelpers: - @staticmethod - @tf.autograph.experimental.do_not_convert - def _get_first_array(*args, **kwargs): - arr = None - flattened_args = tf.nest.flatten((args, kwargs)) - arr_candidates = tf.nest.map_structure( - lambda x: x if isinstance(x, (tf.Tensor, tf.Variable)) else False, - flattened_args, - ) - for arr_candidate in arr_candidates: - if arr_candidate is not False: - arr = arr_candidate - break - return arr - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _get_input_shapes(*args): - input_shapes = [] - for x in args: - if isinstance(x, (tf.Tensor, tf.Variable)): - input_shapes.append(x.shape) - else: - try: - x = tf.convert_to_tensor(x) - input_shapes.append(x.shape) - except Exception: - input_shapes.append(None) - return input_shapes - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _extract_v(v, keychain_mappings: dict, orig_key_chain, /): - if ModelHelpers._dict_has_key_chain(v, orig_key_chain): - ret_cont = ModelHelpers._dict_at_key_chain(v, orig_key_chain) - else: - ret_cont = dict() - for old_kc, new_kc in keychain_mappings.items(): - if orig_key_chain in old_kc: - # Check if `v` contains `new_kc` before replacing in `ret_cont` - if ModelHelpers._dict_has_key_chain(v, new_kc): - ret_cont = ModelHelpers._dict_set_at_key_chain( - ret_cont, - "/".join(old_kc.split("/")[1:]), - ModelHelpers._dict_at_key_chain(v, new_kc), - ) - else: - continue - return ret_cont - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _remove_duplicate_variables(vs, created, /): - created_ids = tf.nest.map_structure(lambda x: id(x), created) - vs_ids = tf.nest.map_structure(lambda x: id(x), vs) - ids = {} - duplicate_keychains = [] - keychain_mappings = {} - - def unique_callback(x, kc): - ids[x] = kc - return x - - def found_dup_callback(x, kc): - if ids[x] == kc: - return x - duplicate_keychains.append(kc) - keychain_mappings[kc] = ids[x] - return x - - created_ids = nest.map_structure_with_paths( - lambda kc, x: unique_callback(x, kc), created_ids - ) - vs_ids = 
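# Usage sketch, standalone: tree_flatten/tree_unflatten round-trip a nested
# variable dict, and the TreeSpec exposes "/"-joined key chains. (Note: the
# tuple branch of recursive_deserialize above delegates to
# recursive_serialize, so only dicts and lists round-trip symmetrically.)
#
#   values, spec = tree_flatten({"conv": {"w": 1, "b": 2}})
#   values                        # -> [1, 2]
#   spec.get_keychains()          # -> ["conv/w", "conv/b"]
#   tree_unflatten(values, spec)  # -> {"conv": {"w": 1, "b": 2}}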
nest.map_structure_with_paths( - lambda kc, x: ( - unique_callback(x, kc) if x not in ids else found_dup_callback(x, kc) - ), - vs_ids, - ) - for dup_kc in duplicate_keychains: - vs = ModelHelpers._dict_prune_key_chain(vs, dup_kc) - return vs, keychain_mappings - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _dict_set_at_key_chain(in_dict, key_chain, val, inplace=False): - keys = re.split("[/.]", key_chain) - if inplace: - cont = in_dict - else: - cont = in_dict - sub_cont = cont - for key in keys[:-1]: - if key not in sub_cont: - sub_cont[key] = dict() - sub_cont = sub_cont[key] - sub_cont[keys[-1]] = val - return cont - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _dict_at_key_chain(dict, key_chain, ignore_key_errors=False): - keys = re.split("[/.]", key_chain) - ret = dict - for key in keys: - try: - ret = ret[key] - except KeyError as e: - if ignore_key_errors: - return - raise Exception(repr(e)) - return ret - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _dict_has_key_chain(dict, key_chain): - keys = re.split("[/.]", key_chain) - ret = dict - for key in keys: - try: - ret = ret[key] - except KeyError: - return False - return True - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _dict_prune_key_chain(in_dict, key_chain): - keys_in_chain = re.split("[/.]", key_chain) - out_dict = {} - for key, value in in_dict.items(): - if isinstance(value, dict): - if key == keys_in_chain[0]: - if len(keys_in_chain) == 1: - new_val = [] - else: - new_val = ModelHelpers._dict_prune_key_chain( - value, - "/".join(keys_in_chain[1:]), - ) - if len(new_val) > 0: - out_dict[key] = new_val - else: - if len(value) > 0: - out_dict[key] = value - else: - if len(keys_in_chain) != 1 or key != keys_in_chain[0]: - out_dict[key] = value - return out_dict - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _addindent(s_, numSpaces): - s = s_.split("\n") - # don't do anything for single-line stuff - if len(s) == 1: - return s_ - first = s.pop(0) - s = [(numSpaces * " ") + line for line in s] - s = "\n".join(s) - s = first + "\n" + s - return s - - -class Layer(tf.keras.layers.Layer, ModelHelpers): - _build_mode = None - _with_partial_v = None - _store_vars = True - _built = False - _v = None - _buffers = None - _module_dict = None - _args = None - _kwargs = None - _module_graph = None - _target = None - _lazy_traced = False - _training = None - _dynamic_backend = None - _device = None - _dtype = None - _previous_frame_info = None - - def __init__( - self, - /, - *args, - v=None, - buffers=None, - build_mode="on_init", - store_vars=True, - with_partial_v=False, - dynamic_backend=None, - training=True, - dtype=None, - device=None, - module_dict=None, - **kwargs, - ): - super(Layer, self).__init__( - trainable=training, - dtype=dtype, - ) - self._build_mode = build_mode - self._with_partial_v = with_partial_v - self._store_vars = store_vars - self._built = False - self._v_from_constructor = v if isinstance(v, dict) or v is None else dict(v) - self._v = v if v is not None else dict() - self._buffers = dict(buffers or {}) - self._module_dict = module_dict if module_dict is not None else dict() - self._args = args - self._kwargs = kwargs - self._module_graph = None - self._target = None - self._lazy_traced = False - self._training = training - self._dynamic_backend = dynamic_backend - self._device = device or "cpu" - self._dtype = dtype or tf.float32 - if build_mode != "on_init": - return - self.build(*args, 
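# Illustrative, standalone: key chains address nested dicts with "/" or "."
# separators (both are split via re.split("[/.]", ...)).
#
#   d = {"conv": {"weight": 7}}
#   ModelHelpers._dict_has_key_chain(d, "conv/weight")  # -> True
#   ModelHelpers._dict_at_key_chain(d, "conv.weight")   # -> 7
#   ModelHelpers._dict_set_at_key_chain(d, "conv/bias", 0)
#   # -> {"conv": {"weight": 7, "bias": 0}}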
dynamic_backend=dynamic_backend, **kwargs) - - @tf.autograph.experimental.do_not_convert - def _find_variables( - self, - /, - *, - obj=None, - without_initialisation=False, - _visited=None, - trainable=True, - ): - _visited = _visited or {} - vs = dict() - if id(obj) in _visited: - return vs - _visited[id(obj)] = True - if isinstance(obj, Layer) and obj is not self: - fn = "_build_and_return_v" if trainable else "_build_and_return_buffers" - if not obj._built and without_initialisation: - return lambda: getattr(obj, fn)( - *obj._args, dynamic_backend=self._dynamic_backend, **obj._kwargs - ) - - return getattr(obj, fn)( - *obj._args, dynamic_backend=obj._dynamic_backend, **obj._kwargs - ) - elif isinstance(obj, tf.keras.layers.Layer) and obj is not self: - return obj.v if trainable else obj.buffers - - elif isinstance(obj, (list, tuple)): - for i, v in enumerate(obj): - ret = self._find_variables( - obj=v, - without_initialisation=without_initialisation, - _visited=_visited, - trainable=trainable, - ) - if ret: - vs[f"v{str(i)}"] = ret - return vs - elif isinstance(obj, dict): - for k, v in obj.items(): - ret = self._find_variables( - obj=v, - without_initialisation=without_initialisation, - _visited=_visited, - trainable=trainable, - ) - if ret: - vs[k[1:] if k[0] == "_" else k] = ret - return vs - elif not hasattr(obj, "__dict__"): - return vs - for k, v in obj.__dict__.items(): - if ( - v is not None - and k[0:2] != "__" - and not k.startswith( - ( - "_module_dict", - "_self_", - "_args", - "_kwargs", - ) - ) - ): - ret = self._find_variables( - obj=v, - without_initialisation=without_initialisation, - _visited=_visited, - trainable=trainable, - ) - if ret: - vs[k[1:] if k[0] == "_" else k] = ret - return vs - - @tf.autograph.experimental.do_not_convert - def _find_buffers(self): - if hasattr(self, "_module_dict"): - for key, sub_module in self._module_dict.items(): - if len(sub_module._buffers) > 0: - self._buffers[key] = sub_module._buffers - - @tf.autograph.experimental.do_not_convert - def _build_and_return_v(self, *args, **kwargs): - if not self._built: - self.build(*args, **kwargs) - return self.v - - @tf.autograph.experimental.do_not_convert - def _build_and_return_buffers(self, *args, **kwargs): - if not self._built: - self.build(*args, **kwargs) - return self.buffers - - @tf.autograph.experimental.do_not_convert - def _wrap_call_methods( - self, keychain_mappings, /, *, key="", obj=None, _visited=None - ): - _visited = _visited or {} - if id(obj) in _visited or not isinstance(key, str): - return - _visited[id(obj)] = True - if isinstance(obj, Model) and obj is not self: - orig_key_chain = key[1:] if key[0] == "_" else key - - obj.__call__ = self._fn_with_var_arg( - obj.__call__, self._extract_v, keychain_mappings, orig_key_chain - ) - return - elif isinstance(obj, (list, tuple)): - for i, val in enumerate(obj): - self._wrap_call_methods( - keychain_mappings, - key=f"{key}/v{str(i)}", - obj=val, - _visited=_visited, - ) - return - elif isinstance(obj, dict): - for k, val in obj.items(): - k = f"{key}/{k}" if key != "" and isinstance(k, str) else k - self._wrap_call_methods( - keychain_mappings, key=k, obj=val, _visited=_visited - ) - return - for k, val in obj.module_dict.items(): - if k.startswith(("__", "_self_")): - continue - k = f"{key}/{k}" if key != "" else k - if val is not None: - self._wrap_call_methods( - keychain_mappings, key=k, obj=val, _visited=_visited - ) - return - - @tf.autograph.experimental.do_not_convert - def _compute_module_dict(self): - self._module_dict 
= dict() - for key, value in self.__dict__.items(): - if isinstance(value, (Layer, tf.keras.layers.Layer)): - if ( - "stateful" in value.__module__ - or hasattr(value, "_frontend_module") - or not hasattr(value, "_module_dict") - ): - self._module_dict[key] = value - else: - self._module_dict[key] = value._module_dict - - @tf.autograph.experimental.do_not_convert - def _fn_with_var_arg_wrapper( - self, *a, fn, v_fn, keychain_mappings, orig_key_chain, **kw - ): - if "v" in kw: - del kw["v"] - v = v_fn(self.v, keychain_mappings, orig_key_chain) - return fn(*a, **kw, v=v) - - @tf.autograph.experimental.do_not_convert - def _fn_with_var_arg(self, fn, v_fn, /, keychain_mappings, orig_key_chain): - _fn_with_var_arg_wrapper = functools.partial( - self._fn_with_var_arg_wrapper, - fn=fn, - v_fn=v_fn, - keychain_mappings=keychain_mappings, - orig_key_chain=orig_key_chain, - ) - _fn_with_var_arg_wrapper.wrapped = True - return _fn_with_var_arg_wrapper - - @tf.autograph.experimental.do_not_convert - def _call(self, *args, v=None, buffers=None, **kwargs): - if not self._built or not self.built: - if not self._built: - first_arr = self._get_first_array(*args, **kwargs) - self.build( - *args, - **kwargs, - from_call=True, - dtype=first_arr.dtype if first_arr is not None else tf.float32, - ) - - if not self.built: - # Don't use `keras` build method - if os.environ.get("USE_KERAS_BUILD", "False").lower() == "false": - self.inputs = tf.nest.flatten(args) - else: - input_shapes = self._get_input_shapes(*args) - if len(input_shapes) == 0: - input_shapes = tf.TensorShape(None) - elif len(input_shapes) == 1: - input_shapes = input_shapes[0] - - super(Layer, self).build(tf.TensorShape(None)) # noqa: UP008 - - # If `v` was provided, replace with the module's v - replace_v = False - if v is not None: - v_orig = self.v - self._v = v - replace_v = True - - # If `buffers` were provided, replace with the module's buffers - replace_buffers = False - if buffers is not None: - buffers_orig = self.buffers - self._buffers = buffers - replace_buffers = True - - if replace_v or replace_buffers: - # Call the forward pass - ret = super(Layer, self).__call__(*args, **kwargs) # noqa: UP008 - # Replace v, buffers if needed - self._v = v_orig if replace_v else self._v - self._buffers = buffers_orig if replace_buffers else self._buffers - return ret - elif hasattr(self.__call__, "wrapped"): - return self.__call__(*args, **kwargs) - - # Get the signature of the call method - call_signature = inspect.signature(self.call) - - # Convert all positional arguments to keyword arguments based on the signature - new_kwargs = {} - for idx, (param_name, param) in enumerate(call_signature.parameters.items()): - if idx < len(args): - new_kwargs[param_name] = args[idx] - - # Merge the existing kwargs - new_kwargs.update(kwargs) - return super(Layer, self).__call__(**new_kwargs) # noqa: UP008 - - @tf.autograph.experimental.do_not_convert - def build( - self, - *args, - from_call=False, - device=None, - dtype=None, - dynamic_backend=None, - **kwargs, - ): - self._built = True - return - - def _lock_state(self): - pass - - @tf.autograph.experimental.do_not_convert - def register_buffer(self, name: str, value: Union[tf.Tensor, tf.Variable]): - self._buffers.update({name: value}) - return value - - @tf.autograph.experimental.do_not_convert - def register_parameter(self, name: str, value: Union[tf.Tensor, tf.Variable]): - self._v.update({name: value}) - - @tf.autograph.experimental.do_not_convert - def train(self, mode: bool = True): - self._training = 
mode - for module in self.children(): - if isinstance(module, tf.keras.layers.Layer) and not hasattr( - module, "train" - ): - module.trainable = mode - continue - module.train(mode) - self.trainable = mode - return self - - @tf.autograph.experimental.do_not_convert - def eval(self): - return self.train(mode=False) - - @tf.autograph.experimental.do_not_convert - def call(self, inputs, training=None, mask=None): - raise NotImplementedError( - "When subclassing the `Module` class, you should implement a `call` method." - ) - - def get_build_config(self): - config = super().get_build_config() - config = recursive_serialize(config) - return config - - def build_from_config(self, config): - config = recursive_deserialize(config) - return super().build_from_config(config) - - def get_config(self): - base_config = super().get_config() - config = {} - - # Get the names and values of positional arguments in __init__ - init_signature = inspect.signature(self.__init__) - arg_names = list(init_signature.parameters.keys()) - - # Include the positional arguments in the config - var_positional_arg_encountered = False - var_positional_arg_name = None - offset = 0 - for i, arg in enumerate(self._args): - arg_name = arg_names[min(i, len(arg_names) - 1)] - if var_positional_arg_encountered: - config.update( - { - f"{var_positional_arg_name}_{i - offset}": arg, - } - ) - elif ( - init_signature.parameters[arg_name].kind - == inspect.Parameter.VAR_POSITIONAL - ): - var_positional_arg_encountered = True - var_positional_arg_name = arg_name - offset = i - config.update( - { - f"{var_positional_arg_name}_{0}": arg, - } - ) - else: - config.update( - { - arg_name: arg, - } - ) - - # Include the keywords arguments in the config - kwargs = self._kwargs.copy() - kwargs.pop("devices", None) - config.update(**kwargs) - new_config = {**base_config, **config} - new_config = recursive_serialize(new_config) - return new_config - - @classmethod - def from_config(cls, config): - config = recursive_deserialize(config) - # Get the signature of the __init__ method - init_signature = inspect.signature(cls.__init__) - arg_names = list(init_signature.parameters.keys()) - - # Separate positional and keyword arguments based on the __init__ signature - args = [] - pos_or_kw = OrderedDict() - kwargs = {} - var_positional_args = [] - for arg_name in arg_names: - if ( - arg_name in config - and init_signature.parameters[arg_name].kind - == inspect.Parameter.KEYWORD_ONLY - ): - # Handle keyword arguments - kwargs[arg_name] = config.pop(arg_name) - elif ( - arg_name in config - and init_signature.parameters[arg_name].kind - == inspect.Parameter.POSITIONAL_OR_KEYWORD - ): - # Handle positional or keyword arguments - pos_or_kw[arg_name] = config.pop(arg_name) - elif any(re.match(rf"{arg_name}_\d+", key) for key in config.keys()): - # Handle variable positional arguments - var_positional_args.extend( - [ - config.pop(key) - for key in sorted(config.keys()) - if re.match(rf"{arg_name}_\d+", key) - ] - ) - - # Unpack positional arguments and the rest as keyword arguments - config.pop("name", None) - config.pop("trainable", None) - config.pop("dtype", None) - kwargs.update(config) - - # Determine the final args and kwargs - if var_positional_args: - args = list(pos_or_kw.values()) + var_positional_args - else: - kwargs.update(pos_or_kw) - - return cls(*args, **kwargs) - - # Methods to be Optionally Overridden # - # -----------------------------------# - - @tf.autograph.experimental.do_not_convert - def _create_variables(self, *, device=None, 
dtype=None): - return {} - - @tf.autograph.experimental.do_not_convert - def _build(self, *args, **kwargs) -> bool: - return True - - @tf.autograph.experimental.do_not_convert - def _forward(self, *args, **kwargs): - raise NotImplementedError( - "When subclassing the `Module` class, you should " - "implement a `_forward` method." - ) - - @tf.autograph.experimental.do_not_convert - def _extra_repr(self) -> str: - return "" - - # Properties # - # -----------# - - @property - def device(self): - return self._device - - @property - def dtype(self): - return self._dtype - - @property - def build_mode(self): - return self._build_mode - - @property - def training(self): - return self._training - - @property - def v(self): - return self._v - - @property - def buffers(self): - return self._buffers - - @property - def state_dict(self): - return {**self.v, **self.buffers} - - @property - def module_dict(self): - return self._module_dict - - @property - def layers(self): - return self._layers - - # Dunder Methods # - # ---------------# - @store_frame_info - @tf.autograph.experimental.do_not_convert - def __call__( - self, - *args, - v=None, - buffers=None, - **kwargs, - ): - # TODO: Temp workaround to avoid `call`` from being transformed by AutoGraph - if not hasattr(self.__class__.call, "autograph_info__"): - setattr(self.__class__.call, "autograph_info__", True) - ret = self._call(*args, v=v, buffers=buffers, **kwargs) - return ret - - @tf.autograph.experimental.do_not_convert - def __getattr__(self, name): - if name == "v": - if not super().__getattribute__("_v") and not getattr( # noqa: E501 - self, "_built", False - ): - return self._build_and_return_v( - *self._args, dynamic_backend=self._dynamic_backend, **self._kwargs - ) - - _dict = super().__getattribute__("__dict__") - if name in _dict: - return _dict[name] - - elif "_v" in _dict and name in _dict["_v"]: - return _dict["_v"][name] - - return super().__getattribute__(name) - - @tf.autograph.experimental.do_not_convert - def __setattr__(self, name, value): - if name in ["v", "buffers"]: - name = "_" + name - if isinstance(value, (Layer, tf.keras.layers.Layer)): - _dict = getattr(self, "__dict__", None) - if _dict: - _dict[name] = value - - # compute the module dict - self._compute_module_dict() - - obj_to_search = ( - None - if not isinstance(value, (Layer, tf.keras.layers.Layer)) - else ( - self._modules - if hasattr(self, "_modules") and self._modules - else self - ) - ) - found_vars = self._find_variables( - obj=obj_to_search, - without_initialisation=( - True - if self._v_from_constructor and not self._with_partial_v - else False - ), - ) - flattened_v, v_spec = tree_flatten(found_vars) - flattend_kc = v_spec.get_keychains() - for kc, v in zip(flattend_kc, flattened_v): - new_kc = kc.replace("/", ".") - if new_kc not in self.v: - self.register_parameter(new_kc, v) - - # once all variables built, find and assign buffers - found_buffers = self._find_variables( - obj=obj_to_search, - without_initialisation=( - True - if self._v_from_constructor and not self._with_partial_v - else False - ), - trainable=False, - ) - flattened_buf, buf_spec = tree_flatten(found_buffers) - flattend_kc = buf_spec.get_keychains() - for kc, buf in zip(flattend_kc, flattened_buf): - new_kc = kc.replace("/", ".") - self.register_buffer(new_kc, buf) - - super().__setattr__(name, value) - return - elif isinstance(value, tf.Variable) and not name.startswith("_"): - _dict = getattr(self, "__dict__", None) - if _dict: - _dict[name] = value - # Manual solution for cases 
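# Hedged sketch, assuming a GPU device is present ("layer" here is a
# hypothetical built instance): an int32 variable assigned as a public
# attribute is recast to int64 before being registered, e.g.
#
#   v = tf.Variable(tf.zeros((2,), dtype=tf.int32))  # placed on "GPU:0"
#   layer.step = v  # registered as a tf.int64 Variable via the cast below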
where a `tf.int32` tensor - # is placed on the GPU. TensorFlow doesn't have dedicated - # kernels for placing `tf.int32` variables on the GPU and so - # we manually cast them to `tf.int64` here otherwise due to - # `tf.config.soft_device_placement(True)` by default, - # TensorFlow puts the `tf.int32` variables on CPU which causes - # unintended consequences downstream during tracing or - # `tf.function` compilation e.g. - # Ref: https://github.com/tensorflow/tensorflow/issues/9506 - # Ref: https://stackoverflow.com/questions/44813939/could-not-satisfy-explicit-device-specification-devicegpu0-because-no-devic - dtype = ( - tf.int64 - if value.dtype == tf.int32 and "gpu:" in value.device.lower() - else value.dtype - ) - cast_dtype = dtype != value.dtype - val = ( - value - if not cast_dtype - else tf.Variable(initial_value=tf.cast(value.value(), dtype), name=name) - ) - self.register_parameter(name, val) - super().__setattr__(name, val) - return - else: - try: - obj_to_search = getattr(self, name) - except AttributeError: - obj_to_search = None - if isinstance(obj_to_search, Layer): - # retrieve all hierarchical submodules - assign_dict, kc = get_assignment_dict() - - # Iterate over all submods in assign_dict - # updating their `v` and `buffers` with the - # new value - for key, submod in assign_dict.items(): - # Get the subkey to match - subkey = kc[len(key) :].lstrip(".") - - if hasattr(submod, "v"): - for v_key, v_value in submod.v.items(): - if v_key.startswith(subkey): - submod.register_parameter(v_key, value) - - # Repeat the same process for submod.buffers - if hasattr(submod, "buffers"): - for b_key, b_value in submod.buffers.items(): - if b_key.startswith(subkey): - submod.register_buffer(b_key, value) - - # finally update the module dict - self._module_dict[name] = value - - return super().__setattr__(name, value) - - @tf.autograph.experimental.do_not_convert - def __delattr__(self, name): - if hasattr(self, name): - if isinstance(getattr(self, name), (Layer, tf.keras.layers.Layer)): - super().__delattr__(name) - return - super().__delattr__(name) - - @tf.autograph.experimental.do_not_convert - def __repr__(self): - extra_lines = [] - extra_repr = self._extra_repr() - if extra_repr: - extra_lines = extra_repr.split("\n") - child_lines = [] - for key in self.v.keys(): - if isinstance(getattr(self, key, None), Layer): - mod_str = repr(getattr(self, key)) - mod_str = self._addindent(mod_str, 2) - child_lines.append(f"({key}): {mod_str}") - lines = extra_lines + child_lines - - main_str = f"{self.__class__.__name__}(" - if lines: - # simple one-liner info, which most builtin Modules will use - if len(extra_lines) == 1 and not child_lines: - main_str += extra_lines[0] - else: - main_str += "\n " + "\n ".join(lines) + "\n" - - main_str += ")" - return main_str - - -class Model(tf.keras.Model, ModelHelpers): - _build_mode = None - _with_partial_v = None - _store_vars = True - _built = False - _v = None - _buffers = None - _module_dict = None - _args = None - _kwargs = None - _module_graph = None - _target = None - _lazy_traced = False - _training = None - _dynamic_backend = None - _device = None - _dtype = None - _previous_frame_info = None - - def __init__( - self, - /, - *args, - v=None, - buffers=None, - build_mode="on_init", - store_vars=True, - with_partial_v=False, - dynamic_backend=None, - training=True, - dtype=None, - device=None, - module_dict=None, - **kwargs, - ): - super(Model, self).__init__( - trainable=training, - dtype=dtype, - ) - self._build_mode = build_mode - 
self._with_partial_v = with_partial_v - self._store_vars = store_vars - self._built = False - self._v_from_constructor = v if isinstance(v, dict) or v is None else dict(v) - self._v = v if v is not None else dict() - self._buffers = dict(buffers or {}) - self._module_dict = module_dict if module_dict is not None else dict() - self._args = args - self._kwargs = kwargs - self._module_graph = None - self._target = None - self._lazy_traced = False - self._training = training - self._dynamic_backend = dynamic_backend - self._device = device or "cpu" - self._dtype = dtype or tf.float32 - if build_mode != "on_init": - return - self.build(*args, dynamic_backend=dynamic_backend, **kwargs) - - @tf.autograph.experimental.do_not_convert - def _find_variables( - self, - /, - *, - obj=None, - without_initialisation=False, - _visited=None, - trainable=True, - ): - _visited = _visited or {} - vs = dict() - if id(obj) in _visited: - return vs - _visited[id(obj)] = True - if isinstance(obj, (Layer, Model)) and obj is not self: - fn = "_build_and_return_v" if trainable else "_build_and_return_buffers" - if not obj._built and without_initialisation: - return lambda: getattr(obj, fn)( - *obj._args, dynamic_backend=self._dynamic_backend, **obj._kwargs - ) - - return getattr(obj, fn)( - *obj._args, dynamic_backend=obj._dynamic_backend, **obj._kwargs - ) - elif isinstance(obj, tf.keras.layers.Layer) and obj is not self: - return obj.v if trainable else obj.buffers - - elif isinstance(obj, (list, tuple)): - for i, v in enumerate(obj): - ret = self._find_variables( - obj=v, - without_initialisation=without_initialisation, - _visited=_visited, - trainable=trainable, - ) - if ret: - vs[f"v{str(i)}"] = ret - return vs - elif isinstance(obj, dict): - for k, v in obj.items(): - ret = self._find_variables( - obj=v, - without_initialisation=without_initialisation, - _visited=_visited, - trainable=trainable, - ) - if ret: - vs[k[1:] if k[0] == "_" else k] = ret - return vs - elif not hasattr(obj, "__dict__"): - return vs - for k, v in obj.__dict__.items(): - if ( - v is not None - and k[0:2] != "__" - and not k.startswith( - ( - "_module_dict", - "_self_", - "_args", - "_kwargs", - ) - ) - ): - ret = self._find_variables( - obj=v, - without_initialisation=without_initialisation, - _visited=_visited, - trainable=trainable, - ) - if ret: - vs[k[1:] if k[0] == "_" else k] = ret - return vs - - @tf.autograph.experimental.do_not_convert - def _find_buffers(self): - if hasattr(self, "_module_dict"): - for key, sub_module in self._module_dict.items(): - if len(sub_module._buffers) > 0: - self._buffers[key] = sub_module._buffers - - @tf.autograph.experimental.do_not_convert - def _build_and_return_v(self, *args, **kwargs): - if not self._built: - self.build(*args, **kwargs) - return self.v - - @tf.autograph.experimental.do_not_convert - def _build_and_return_buffers(self, *args, **kwargs): - if not self._built: - self.build(*args, **kwargs) - return self.buffers - - @tf.autograph.experimental.do_not_convert - def _wrap_call_methods( - self, keychain_mappings, /, *, key="", obj=None, _visited=None - ): - _visited = _visited or {} - if id(obj) in _visited or not isinstance(key, str): - return - _visited[id(obj)] = True - if isinstance(obj, (Layer, Model)) and obj is not self: - orig_key_chain = key[1:] if key[0] == "_" else key - - obj.__call__ = self._fn_with_var_arg( - obj.__call__, self._extract_v, keychain_mappings, orig_key_chain - ) - return - elif isinstance(obj, (list, tuple)): - for i, val in enumerate(obj): - 
self._wrap_call_methods( - keychain_mappings, - key=f"{key}/v{str(i)}", - obj=val, - _visited=_visited, - ) - return - elif isinstance(obj, dict): - for k, val in obj.items(): - k = f"{key}/{k}" if key != "" and isinstance(k, str) else k - self._wrap_call_methods( - keychain_mappings, key=k, obj=val, _visited=_visited - ) - return - for k, val in obj.module_dict.items(): - if k.startswith(("__", "_self_")): - continue - k = f"{key}/{k}" if key != "" else k - if val is not None: - self._wrap_call_methods( - keychain_mappings, key=k, obj=val, _visited=_visited - ) - return - - @tf.autograph.experimental.do_not_convert - def _compute_module_dict(self): - self._module_dict = dict() - for key, value in self.__dict__.items(): - if isinstance(value, (Layer, tf.keras.layers.Layer, Model, tf.keras.Model)): - if ( - "stateful" in value.__module__ - or hasattr(value, "_frontend_module") - or not hasattr(value, "_module_dict") - ): - self._module_dict[key] = value - else: - self._module_dict[key] = value._module_dict - - @tf.autograph.experimental.do_not_convert - def _fn_with_var_arg_wrapper( - self, *a, fn, v_fn, keychain_mappings, orig_key_chain, **kw - ): - if "v" in kw: - del kw["v"] - v = v_fn(self.v, keychain_mappings, orig_key_chain) - return fn(*a, **kw, v=v) - - @tf.autograph.experimental.do_not_convert - def _fn_with_var_arg(self, fn, v_fn, /, keychain_mappings, orig_key_chain): - _fn_with_var_arg_wrapper = functools.partial( - self._fn_with_var_arg_wrapper, - fn=fn, - v_fn=v_fn, - keychain_mappings=keychain_mappings, - orig_key_chain=orig_key_chain, - ) - _fn_with_var_arg_wrapper.wrapped = True - return _fn_with_var_arg_wrapper - - @tf.autograph.experimental.do_not_convert - def _call(self, *args, v=None, buffers=None, **kwargs): - if not self._built or not self.built: - if not self._built: - first_arr = self._get_first_array(*args, **kwargs) - self.build( - *args, - **kwargs, - from_call=True, - dtype=first_arr.dtype if first_arr is not None else tf.float32, - ) - - if not self.built: - # Don't use `keras` build method - if os.environ.get("USE_KERAS_BUILD", "False").lower() == "false": - self.inputs = tf.nest.flatten(args) - else: - input_shapes = self._get_input_shapes(*args) - if len(input_shapes) == 0: - input_shapes = tf.TensorShape(None) - elif len(input_shapes) == 1: - input_shapes = input_shapes[0] - - super(Model, self).build(tf.TensorShape(None)) # noqa: UP008 - - # If `v` was provided, replace with the module's v - replace_v = False - if v is not None: - v_orig = self.v - self._v = v - replace_v = True - - # If `buffers` were provided, replace with the module's buffers - replace_buffers = False - if buffers is not None: - buffers_orig = self.buffers - self._buffers = buffers - replace_buffers = True - - if replace_v or replace_buffers: - # Call the forward pass - ret = super(Model, self).__call__(*args, **kwargs) # noqa: UP008 - # Replace v, buffers if needed - self._v = v_orig if replace_v else self._v - self._buffers = buffers_orig if replace_buffers else self._buffers - return ret - elif hasattr(self.__call__, "wrapped"): - return self.__call__(*args, **kwargs) - - return super(Model, self).__call__(*args, **kwargs) # noqa: UP008 - - @tf.autograph.experimental.do_not_convert - def build( - self, - *args, - from_call=False, - device=None, - dtype=None, - dynamic_backend=None, - **kwargs, - ): - self._built = True - return - - def _lock_state(self): - pass - - @tf.autograph.experimental.do_not_convert - def register_buffer(self, name: str, value: Union[tf.Tensor, tf.Variable]): 
- self._buffers.update({name: value}) - return value - - @tf.autograph.experimental.do_not_convert - def register_parameter(self, name: str, value: Union[tf.Tensor, tf.Variable]): - self._v.update({name: value}) - - @tf.autograph.experimental.do_not_convert - def train(self, mode: bool = True): - self._training = mode - for module in self.children(): - if isinstance(module, tf.keras.layers.Layer) and not hasattr( - module, "train" - ): - module.trainable = mode - continue - module.train(mode) - self.trainable = mode - return self - - @tf.autograph.experimental.do_not_convert - def eval(self): - return self.train(mode=False) - - @tf.autograph.experimental.do_not_convert - def call(self, inputs, training=None, mask=None): - raise NotImplementedError( - "When subclassing the `Module` class, you should implement a `call` method." - ) - - def get_build_config(self): - config = super().get_build_config() - config = recursive_serialize(config) - return config - - def build_from_config(self, config): - config = recursive_deserialize(config) - return super().build_from_config(config) - - def get_config(self): - base_config = super().get_config() - config = {} - - # Get the names and values of positional arguments in __init__ - init_signature = inspect.signature(self.__init__) - arg_names = list(init_signature.parameters.keys()) - - # Include the positional arguments in the config - var_positional_arg_encountered = False - var_positional_arg_name = None - offset = 0 - for i, arg in enumerate(self._args): - arg_name = arg_names[min(i, len(arg_names) - 1)] - if var_positional_arg_encountered: - config.update( - { - f"{var_positional_arg_name}_{i - offset}": arg, - } - ) - elif ( - init_signature.parameters[arg_name].kind - == inspect.Parameter.VAR_POSITIONAL - ): - var_positional_arg_encountered = True - var_positional_arg_name = arg_name - offset = i - config.update( - { - f"{var_positional_arg_name}_{0}": arg, - } - ) - else: - config.update( - { - arg_name: arg, - } - ) - - # Include the keywords arguments in the config - kwargs = self._kwargs.copy() - kwargs.pop("devices", None) - config.update(**kwargs) - new_config = {**base_config, **config} - new_config = recursive_serialize(new_config) - return new_config - - @classmethod - def from_config(cls, config): - config = recursive_deserialize(config) - # Get the signature of the __init__ method - init_signature = inspect.signature(cls.__init__) - arg_names = list(init_signature.parameters.keys()) - - # Separate positional and keyword arguments based on the __init__ signature - args = [] - pos_or_kw = OrderedDict() - kwargs = {} - var_positional_args = [] - for arg_name in arg_names: - if ( - arg_name in config - and init_signature.parameters[arg_name].kind - == inspect.Parameter.KEYWORD_ONLY - ): - # Handle keyword arguments - kwargs[arg_name] = config.pop(arg_name) - elif ( - arg_name in config - and init_signature.parameters[arg_name].kind - == inspect.Parameter.POSITIONAL_OR_KEYWORD - ): - # Handle positional or keyword arguments - pos_or_kw[arg_name] = config.pop(arg_name) - elif any(re.match(rf"{arg_name}_\d+", key) for key in config.keys()): - # Handle variable positional arguments - var_positional_args.extend( - [ - config.pop(key) - for key in sorted(config.keys()) - if re.match(rf"{arg_name}_\d+", key) - ] - ) - - # Unpack positional arguments and the rest as keyword arguments - config.pop("name", None) - config.pop("trainable", None) - config.pop("dtype", None) - kwargs.update(config) - - # Determine the final args and kwargs - if 
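# Illustrative, using a hypothetical subclass: get_config spreads *args into
# indexed keys that from_config later regroups into positional arguments.
#
#   class Stack(Model):
#       def __init__(self, *dims, **kw): ...
#
#   Stack(64, 128, 256).get_config()
#   # includes {"dims_0": 64, "dims_1": 128, "dims_2": 256}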
var_positional_args: - args = list(pos_or_kw.values()) + var_positional_args - else: - kwargs.update(pos_or_kw) - - return cls(*args, **kwargs) - - # Methods to be Optionally Overridden # - # -----------------------------------# - - @tf.autograph.experimental.do_not_convert - def _create_variables(self, *, device=None, dtype=None): - return {} - - @tf.autograph.experimental.do_not_convert - def _build(self, *args, **kwargs) -> bool: - return True - - @tf.autograph.experimental.do_not_convert - def _forward(self, *args, **kwargs): - raise NotImplementedError( - "When subclassing the `Module` class, you should " - "implement a `_forward` method." - ) - - @tf.autograph.experimental.do_not_convert - def _extra_repr(self) -> str: - return "" - - # Properties # - # -----------# - - @property - def device(self): - return self._device - - @property - def dtype(self): - return self._dtype - - @property - def build_mode(self): - return self._build_mode - - @property - def training(self): - return self._training - - @property - def v(self): - return self._v - - @property - def buffers(self): - return self._buffers - - @property - def state_dict(self): - return {**self.v, **self.buffers} - - @property - def module_dict(self): - return self._module_dict - - # Dunder Methods # - # ---------------# - @store_frame_info - @tf.autograph.experimental.do_not_convert - def __call__( - self, - *args, - v=None, - buffers=None, - **kwargs, - ): - # TODO: Temp workaround to avoid `call`` from being transformed by AutoGraph - if not hasattr(self.__class__.call, "autograph_info__"): - setattr(self.__class__.call, "autograph_info__", True) - ret = self._call(*args, v=v, buffers=buffers, **kwargs) - return ret - - @tf.autograph.experimental.do_not_convert - def __getattr__(self, name): - if name == "v": - if not super().__getattribute__("_v") and not getattr( # noqa: E501 - self, "_built", False - ): - return self._build_and_return_v( - *self._args, dynamic_backend=self._dynamic_backend, **self._kwargs - ) - - _dict = super().__getattribute__("__dict__") - if name in _dict: - return _dict[name] - - elif "_v" in _dict and name in _dict["_v"]: - return _dict["_v"][name] - - return super().__getattribute__(name) - - @tf.autograph.experimental.do_not_convert - def __setattr__(self, name, value): - if name in ["v", "buffers"]: - name = "_" + name - if isinstance(value, (Layer, tf.keras.layers.Layer, Model, tf.keras.Model)): - _dict = getattr(self, "__dict__", None) - if _dict: - _dict[name] = value - - # compute the module dict - self._compute_module_dict() - - obj_to_search = ( - None - if not isinstance(value, (tf.keras.layers.Layer, Layer, Model)) - else ( - self._modules - if hasattr(self, "_modules") and self._modules - else self - ) - ) - found_vars = self._find_variables( - obj=obj_to_search, - without_initialisation=( - True - if self._v_from_constructor and not self._with_partial_v - else False - ), - ) - flattened_v, v_spec = tree_flatten(found_vars) - flattend_kc = v_spec.get_keychains() - for kc, v in zip(flattend_kc, flattened_v): - new_kc = kc.replace("/", ".") - if new_kc not in self.v: - self.register_parameter(new_kc, v) - - # once all variables built, find and assign buffers - found_buffers = self._find_variables( - obj=obj_to_search, - without_initialisation=( - True - if self._v_from_constructor and not self._with_partial_v - else False - ), - trainable=False, - ) - flattened_buf, buf_spec = tree_flatten(found_buffers) - flattend_kc = buf_spec.get_keychains() - for kc, buf in zip(flattend_kc, 
flattened_buf): - new_kc = kc.replace("/", ".") - self.register_buffer(new_kc, buf) - - super().__setattr__(name, value) - return - elif isinstance(value, tf.Variable) and not name.startswith("_"): - _dict = getattr(self, "__dict__", None) - if _dict: - _dict[name] = value - - # Manual solution for cases where a `tf.int32` tensor - # is placed on the GPU. TensorFlow doesn't have dedicated - # kernels for placing `tf.int32` variables on the GPU and so - # we manually cast them to `tf.int64` here otherwise due to - # `tf.config.soft_device_placement(True)` by default, - # TensorFlow puts the `tf.int32` variables on CPU which causes - # unintended consequences downstream during tracing or - # `tf.function` compilation e.g. - # Ref: https://github.com/tensorflow/tensorflow/issues/9506 - # Ref: https://stackoverflow.com/questions/44813939/could-not-satisfy-explicit-device-specification-devicegpu0-because-no-devic - dtype = ( - tf.int64 - if value.dtype == tf.int32 and "gpu:" in value.device.lower() - else value.dtype - ) - cast_dtype = dtype != value.dtype - val = ( - value - if not cast_dtype - else tf.Variable(initial_value=tf.cast(value.value(), dtype), name=name) - ) - self.register_parameter(name, val) - super().__setattr__(name, val) - else: - try: - obj_to_search = getattr(self, name) - except AttributeError: - obj_to_search = None - if isinstance(obj_to_search, (Model, Layer)): - # retrieve all hierarchical submodules - assign_dict, kc = get_assignment_dict() - - # Iterate over all submods in assign_dict - # updating their `v` and `buffers` with the - # new value - for key, submod in assign_dict.items(): - # Get the subkey to match - subkey = kc[len(key) :].lstrip(".") - - if hasattr(submod, "v"): - for v_key, v_value in submod.v.items(): - if v_key.startswith(subkey): - submod.register_parameter(v_key, value) - - # Repeat the same process for submod.buffers - if hasattr(submod, "buffers"): - for b_key, b_value in submod.buffers.items(): - if b_key.startswith(subkey): - submod.register_buffer(b_key, value) - - # finally update the module dict - self._module_dict[name] = value - - return super().__setattr__(name, value) - - @tf.autograph.experimental.do_not_convert - def __delattr__(self, name): - if hasattr(self, name): - if isinstance( - getattr(self, name), - (Layer, tf.keras.layers.Layer, Model, tf.keras.Model), - ): - super().__delattr__(name) - return - super().__delattr__(name) - - @tf.autograph.experimental.do_not_convert - def __repr__(self): - extra_lines = [] - extra_repr = self._extra_repr() - if extra_repr: - extra_lines = extra_repr.split("\n") - child_lines = [] - for key in self.v.keys(): - if isinstance(getattr(self, key, None), (Layer, Model)): - mod_str = repr(getattr(self, key)) - mod_str = self._addindent(mod_str, 2) - child_lines.append(f"({key}): {mod_str}") - lines = extra_lines + child_lines - - main_str = f"{self.__class__.__name__}(" - if lines: - # simple one-liner info, which most builtin Modules will use - if len(extra_lines) == 1 and not child_lines: - main_str += extra_lines[0] - else: - main_str += "\n " + "\n ".join(lines) + "\n" - - main_str += ")" - return main_str diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_Conv2d_output/run_0/tensorflow__stateful_layers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_Conv2d_output/run_0/tensorflow__stateful_layers.py deleted file mode 100644 index ce061b0e5584..000000000000 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_Conv2d_output/run_0/tensorflow__stateful_layers.py +++ 
/dev/null @@ -1,700 +0,0 @@ -from .tensorflow__helpers import tensorflow_handle_transpose_in_input_and_output -from .tensorflow__helpers import tensorflow_handle_array_like_without_promotion -from .tensorflow__stateful import store_frame_info -import tensorflow as tf -import keras -import collections -from itertools import repeat -from numbers import Number -import os -from packaging.version import parse as parse_package - - -def parse(x): - n = 2 - if isinstance(x, collections.abc.Iterable): - return tuple(x) - return tuple(repeat(x, n)) - - -def _reverse_repeat_tuple(t, n): - return tuple(x for x in reversed(t) for _ in range(n)) - - -def _handle_padding_shape(padding, n, mode): - padding = tuple( - [ - (padding[i * 2], padding[i * 2 + 1]) - for i in range(int(len(padding) / 2) - 1, -1, -1) - ] - ) - if mode == "circular": - padding = padding + ((0, 0),) * (n - len(padding)) - else: - padding = ((0, 0),) * (n - len(padding)) + padding - if mode == "circular": - padding = tuple(list(padding)[::-1]) - return padding - - -def _to_tf_padding(pad_width, ndim): - if isinstance(pad_width, Number): - pad_width = [[pad_width] * 2] * ndim - elif len(pad_width) == 2 and isinstance(pad_width[0], Number): - pad_width = [pad_width] * ndim - elif ( - isinstance(pad_width, (list, tuple)) - and isinstance(pad_width[0], (list, tuple)) - and len(pad_width) < ndim - ): - pad_width = pad_width * ndim - return pad_width - - -@tensorflow_handle_array_like_without_promotion -def _pad( - input, - pad_width, - /, - *, - mode="constant", - stat_length=1, - constant_values=0, - end_values=0, - reflect_type="even", - **kwargs, -): - pad_width = _to_tf_padding(pad_width, len(input.shape)) - if not isinstance(constant_values, (tf.Variable, tf.Tensor)): - constant_values = tf.constant(constant_values) - if constant_values.dtype != input.dtype: - constant_values = tf.cast(constant_values, input.dtype) - return tf.pad(input, pad_width, mode=mode, constant_values=constant_values) - - -def torch_pad(input, pad, mode="constant", value=0): - # deal with any negative pad values - if any([pad_value < 0 for pad_value in pad]): - pad = list(pad) - slices = [] - for n in reversed(range(len(pad) // 2)): - i = n * 2 - j = i + 1 - start = None - stop = None - if pad[i] < 0: - start = -pad[i] - pad[i] = 0 - if pad[j] < 0: - stop = pad[j] - pad[j] = 0 - slices.append(slice(start, stop)) - ndim = len(input.shape) - while len(slices) < ndim: - slices.insert(0, slice(None)) - input = input[tuple(slices)] - - value = 0 if value is None else value - mode_dict = { - "constant": "constant", - "reflect": "reflect", - "replicate": "edge", - "circular": "wrap", - } - if mode not in mode_dict: - raise ValueError(f"Unsupported padding mode: {mode}") - pad = _handle_padding_shape(pad, len(input.shape), mode) - order = 0, 2, 3, 1 - pad = tuple(pad[i] for i in order) - return _pad(input, pad, mode=mode_dict[mode], constant_values=value) - - -def resolve_convolution(*args, **kwargs): - depthwise_multiplier = kwargs["groups"] // kwargs["filters"] - if depthwise_multiplier < 1: - return KerasConv2D(*args, **kwargs) - else: - return KerasDepthwiseConv2D(*args, **kwargs) - - -class KerasDepthwiseConv2D(tf.keras.layers.DepthwiseConv2D): - def __init__(self, *args, **kwargs): - kernel_size = kwargs.pop("kernel_size") - padding = kwargs.pop("padding", 0) - stride = kwargs.pop("strides", (1, 1)) - dilation = kwargs.pop("dilation_rate", (1, 1)) - data_format = kwargs.pop("data_format", "channels_last") - - self.padding_mode = kwargs.pop("padding_mode", 
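# Illustrative, standalone: resolve_convolution dispatches on the
# PyTorch-style arguments; groups == filters yields a depthwise layer.
#
#   filters=64, groups=1   -> depthwise_multiplier 0 -> KerasConv2D
#   filters=32, groups=32  -> depthwise_multiplier 1 -> KerasDepthwiseConv2D
#
# and _reverse_repeat_tuple converts torch per-dim padding to tf pad order:
#   _reverse_repeat_tuple((1, 2), 2)  # -> (2, 2, 1, 1)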
"zeros") - self._padding = padding - self._previous_frame_info = None - - kernel_size_ = parse(kernel_size) - stride_ = parse(stride) - padding_ = padding if isinstance(padding, str) else parse(padding) - dilation_ = parse(dilation) - - # Call the original __init__ with the remaining args and kwargs - depth_multiplier = kwargs.pop("groups") // kwargs.pop("filters") - self.depth_multiplier = depth_multiplier - - # pytorch layers attributes - self.in_channels = kwargs.pop("in_channels") - - # ivy.Module attributes - self._v = dict() - self._buffers = dict() - - super().__init__( - *args, - kernel_size=kernel_size_, - strides=stride_, - dilation_rate=dilation_, - padding="valid", - depth_multiplier=depth_multiplier, - data_format=data_format, - **kwargs, - ) - - # Compute self._reversed_padding_repeated_twice - if isinstance(padding_, str): - self._reversed_padding_repeated_twice = [0, 0] * len(self.kernel_size) - if padding == "same": - for d, k, i in zip( - self.dilation_rate, - self.kernel_size, - range(len(self.kernel_size) - 1, -1, -1), - ): - total_padding = d * (k - 1) - left_pad = total_padding // 2 - self._reversed_padding_repeated_twice[2 * i] = left_pad - self._reversed_padding_repeated_twice[2 * i + 1] = ( - total_padding - left_pad - ) - else: - self._reversed_padding_repeated_twice = _reverse_repeat_tuple(padding_, 2) - - depthwise_shape = self.kernel_size + ( - self.in_channels, - self.depth_multiplier, - ) - - # create placeholder weights on initialization - self.weight = tf.experimental.numpy.empty( - depthwise_shape, - dtype=tf.float32, - ) - - if self.use_bias: - self.bias = tf.experimental.numpy.empty( - (self.depth_multiplier * self.in_channels,), - dtype=tf.float32, - ) - else: - self.bias = None - - self.v["weight"] = self.weight - self.v["bias"] = self.bias - - os.environ["DATA_FORMAT"] = "channels_first" - - @property - def v(self): - return self._v - - @property - def buffers(self): - return self._buffers - - def named_parameters(self): - return {k: v for k, v in self.v.items() if v is not None} - - def named_buffers(self): - return {k: v for k, v in self.buffers.items() if v is not None} - - def eval(self): - self.trainable = False - - def get_config(self): - config = super().get_config() - config.update( - { - "in_channels": self.in_channels, - "padding_mode": self.padding_mode, - "kernel_size": self.kernel_size, - "padding": self._padding, - "strides": self.strides, - "dilation_rate": self.dilation_rate, - "data_format": self.data_format, - } - ) - return config - - @classmethod - def from_config(cls, config): - return cls(**config) - - @store_frame_info - def __call__(self, *args, **kwargs): - if not self.built: - res = super().__call__(*args, **kwargs) - # recompute build shapes based on transposed input - order = (0, 2, 3, 1) - input_shape = args[0].shape - new_shape = tuple(input_shape[i] for i in order) - self._build_shapes_dict = {"input_shape": new_shape} - return res - return self.call(args[0]) - - def __repr__(self): - return "KerasDepthWiseConv2D()" - - def __setattr__(self, name, value): - if name in ["_v", "_buffers"]: - self.__dict__[name] = value - return - super().__setattr__(name, value) - - def __getattribute__(self, name): - built = object.__getattribute__(self, "__dict__").get("built", False) - - if built: - if parse_package(keras.__version__).major > 2: - attr_map = {"weight": "kernel"} - else: - attr_map = {"weight": "depthwise_kernel"} - else: - attr_map = {"weight": "weight"} - - new_name = attr_map[name] if name in attr_map else name - 
return super().__getattribute__(new_name) - - def build(self, input_shape): - _, ch, _, _ = input_shape - if ( - not self.built - and self.data_format == "channels_last" - and os.environ.get("DATA_FORMAT", "channels_first") == "channels_first" - ): - order = (0, 2, 3, 1) - new_shape = tuple(input_shape[i] for i in order) - input_shape = tf.TensorShape(new_shape) - - super().build(input_shape) - # modify the channel axis to avoid shape assertion checks by keras - self.input_spec.axes = {1: ch} - return - - @tensorflow_handle_transpose_in_input_and_output - def call(self, input, training=False): - if self._padding != 0: - padding_mode = ( - "constant" if self.padding_mode == "zeros" else self.padding_mode - ) - # handle Pytorch-style padding - input = torch_pad( - input, self._reversed_padding_repeated_twice, mode=padding_mode - ) - - return super().call(input) - - -class KerasConv2D(tf.keras.layers.Conv2D): - def __init__(self, *args, **kwargs): - kernel_size = kwargs.pop("kernel_size") - padding = kwargs.pop("padding", 0) - stride = kwargs.pop("strides", (1, 1)) - dilation = kwargs.pop("dilation_rate", (1, 1)) - data_format = kwargs.pop("data_format", "channels_last") - - self.padding_mode = kwargs.pop("padding_mode", "zeros") - self._padding = padding - self._previous_frame_info = None - - kernel_size_ = parse(kernel_size) - stride_ = parse(stride) - padding_ = padding if isinstance(padding, str) else parse(padding) - dilation_ = parse(dilation) - - # pytorch layers attributes - self.in_channels = kwargs.pop("in_channels") - - # ivy.Module attributes - self._v = dict() - self._buffers = dict() - - # Call the original __init__ with the remaining args and kwargs - super().__init__( - *args, - kernel_size=kernel_size_, - strides=stride_, - dilation_rate=dilation_, - padding="valid", - data_format=data_format, - **kwargs, - ) - - # Compute self._reversed_padding_repeated_twice - if isinstance(padding_, str): - self._reversed_padding_repeated_twice = [0, 0] * len(self.kernel_size) - if padding == "same": - for d, k, i in zip( - self.dilation_rate, - self.kernel_size, - range(len(self.kernel_size) - 1, -1, -1), - ): - total_padding = d * (k - 1) - left_pad = total_padding // 2 - self._reversed_padding_repeated_twice[2 * i] = left_pad - self._reversed_padding_repeated_twice[2 * i + 1] = ( - total_padding - left_pad - ) - else: - self._reversed_padding_repeated_twice = _reverse_repeat_tuple(padding_, 2) - - # create placeholder weights on initialization - self.weight = tf.experimental.numpy.empty( - (*kernel_size_, self.in_channels // kwargs["groups"], self.filters), - dtype=tf.float32, - ) - if self.use_bias: - self.bias = tf.experimental.numpy.empty((self.filters,), dtype=tf.float32) - else: - self.bias = None - - self.v["weight"] = self.weight - self.v["bias"] = self.bias - - os.environ["DATA_FORMAT"] = "channels_first" - - @property - def v(self): - return self._v - - @property - def buffers(self): - return self._buffers - - def named_parameters(self): - return {k: v for k, v in self.v.items() if v is not None} - - def named_buffers(self): - return {k: v for k, v in self.buffers.items() if v is not None} - - def eval(self): - self.trainable = False - - def get_config(self): - config = super().get_config() - config.update( - { - "in_channels": self.in_channels, - "padding_mode": self.padding_mode, - "kernel_size": self.kernel_size, - "padding": self._padding, - "strides": self.strides, - "dilation_rate": self.dilation_rate, - "data_format": self.data_format, - } - ) - return config - - 
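# Aside: a minimal sketch (not part of the translated file) of the
# get_config()/from_config() round-trip the class above relies on: the extra
# PyTorch-style constructor arguments are written into the config dict so that
# cls(**config) can rebuild the layer. The toy class below is hypothetical.
#
#     import tensorflow as tf
#
#     class ScaledReLU(tf.keras.layers.ReLU):
#         def __init__(self, *args, scale=1.0, **kwargs):
#             super().__init__(*args, **kwargs)
#             self.scale = scale
#
#         def call(self, inputs):
#             return super().call(inputs) * self.scale
#
#         def get_config(self):
#             config = super().get_config()
#             config.update({"scale": self.scale})  # persist the extra arg
#             return config
#
#     layer = ScaledReLU(scale=2.0)
#     clone = ScaledReLU.from_config(layer.get_config())  # scale == 2.0 again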
@classmethod - def from_config(cls, config): - return cls(**config) - - @store_frame_info - def __call__(self, *args, **kwargs): - if not self.built: - res = super().__call__(*args, **kwargs) - # recompute build shapes based on transposed input - order = (0, 2, 3, 1) - input_shape = args[0].shape - new_shape = tuple(input_shape[i] for i in order) - self._build_shapes_dict = {"input_shape": new_shape} - return res - return self.call(args[0]) - - def __repr__(self): - return "KerasConv2D()" - - def __setattr__(self, name, value): - if name in ["_v", "_buffers"]: - self.__dict__[name] = value - return - super().__setattr__(name, value) - - def __getattribute__(self, name): - built = object.__getattribute__(self, "__dict__").get("built", False) - if built: - attr_map = {"weight": "kernel", "out_channels": "filters"} - else: - attr_map = { - "out_channels": "filters", - } - - new_name = attr_map[name] if name in attr_map else name - return super().__getattribute__(new_name) - - def build(self, input_shape): - _, ch, _, _ = input_shape - if ( - not self.built - and self.data_format == "channels_last" - and os.environ.get("DATA_FORMAT", "channels_first") == "channels_first" - ): - order = (0, 2, 3, 1) - new_shape = tuple(input_shape[i] for i in order) - input_shape = tf.TensorShape(new_shape) - - super().build(input_shape) - # modify the channel axis to avoid shape assertion checks by keras - self.input_spec.axes = {1: ch} - return - - @tensorflow_handle_transpose_in_input_and_output - def call(self, input, training=False): - if self._padding != 0: - padding_mode = ( - "constant" if self.padding_mode == "zeros" else self.padding_mode - ) - # handle Pytorch-style padding - input = torch_pad( - input, self._reversed_padding_repeated_twice, mode=padding_mode - ) - return super().call(input) - - -class KerasDense(tf.keras.layers.Dense): - def __init__(self, *args, **kwargs): - self._previous_frame_info = None - - # pytorch layer attributes - self.in_features = kwargs.pop("in_features") - - # ivy.Module attributes - self._v = dict() - self._buffers = dict() - - super().__init__(*args, **kwargs) - - # create placeholder weights on initialization - self.weight = tf.experimental.numpy.empty( - (self.units, self.in_features), dtype=tf.float32 - ) - if self.use_bias: - self.bias = tf.experimental.numpy.empty((self.units,), dtype=tf.float32) - else: - self.bias = None - - self.v["weight"] = self.weight - self.v["bias"] = self.bias - - @property - def v(self): - return self._v - - @property - def buffers(self): - return self._buffers - - def named_parameters(self): - return {k: v for k, v in self.v.items() if v is not None} - - def named_buffers(self): - return {k: v for k, v in self.buffers.items() if v is not None} - - def eval(self): - self.trainable = False - - def get_config(self): - config = super().get_config() - config.update( - { - "in_features": self.in_features, - } - ) - return config - - @classmethod - def from_config(cls, config): - return cls(**config) - - def __call__(self, *args, **kwargs): - return super().__call__(*args, **kwargs) - - def __repr__(self): - return "KerasDense()" - - def __setattr__(self, name, value): - if name in ["_v", "_buffers"]: - self.__dict__[name] = value - return - super().__setattr__(name, value) - - def __getattribute__(self, name): - built = object.__getattribute__(self, "__dict__").get("built", False) - if built: - attr_map = {"weight": "kernel", "out_features": "units"} - else: - attr_map = {"out_features": "units"} - new_name = attr_map[name] if name in 
attr_map else name - return super().__getattribute__(new_name) - - def build(self, input_shape): - super().build(input_shape) - return - - def call(self, input, training=False): - return super().call(input) - - -class KerasBatchNorm2D(tf.keras.layers.BatchNormalization): - def __init__(self, *args, **kwargs): - self._previous_frame_info = None - - # pytorch layer attributes - self.num_features = kwargs.pop("num_features") - self.track_running_stats = kwargs.pop("track_running_stats") - - # ivy.Module attributes - self._v = dict() - self._buffers = dict() - - super().__init__(*args, **kwargs) - - # create placeholder weights on initialization - if self.scale: - self.weight = tf.experimental.numpy.empty( - (self.num_features,), dtype=tf.float32 - ) - self.bias = tf.experimental.numpy.empty( - (self.num_features,), dtype=tf.float32 - ) - else: - self.weight = None - self.bias = None - - if self.track_running_stats: - self.running_mean = tf.experimental.numpy.zeros( - (self.num_features,), dtype=tf.float32 - ) - self.running_var = tf.experimental.numpy.ones( - (self.num_features,), dtype=tf.float32 - ) - self.num_batches_tracked = tf.constant(0, dtype=tf.int64) - else: - self.running_mean = None - self.running_var = None - self.num_batches_tracked = None - - self.v["weight"] = self.weight - self.v["bias"] = self.bias - self.buffers["running_mean"] = self.running_mean - self.buffers["running_var"] = self.running_var - self.buffers["num_batches_tracked"] = self.num_batches_tracked - - os.environ["DATA_FORMAT"] = "channels_first" - - @property - def v(self): - return self._v - - @property - def buffers(self): - return self._buffers - - def named_parameters(self): - return {k: v for k, v in self.v.items() if v is not None} - - def named_buffers(self): - return {k: v for k, v in self.buffers.items() if v is not None} - - def eval(self): - self.trainable = False - - def get_config(self): - config = super().get_config() - config.update( - { - "num_features": self.num_features, - "track_running_stats": self.track_running_stats, - } - ) - return config - - @classmethod - def from_config(cls, config): - return cls(**config) - - def __repr__(self): - return "KerasBatchNorm2D()" - - def __setattr__(self, name, value): - if name in ["_v", "_buffers"]: - self.__dict__[name] = value - return - super().__setattr__(name, value) - - def __getattribute__(self, name): - built = object.__getattribute__(self, "__dict__").get("built", False) - if built: - attr_map = { - "weight": "gamma", - "bias": "beta", - "running_mean": "moving_mean", - "running_var": "moving_variance", - } - else: - attr_map = {} - new_name = attr_map[name] if name in attr_map else name - return super().__getattribute__(new_name) - - @store_frame_info - def __call__(self, *args, **kwargs): - if not self.built: - res = super().__call__(*args, **kwargs) - # recompute build shapes based on transposed input - order = (0, 2, 3, 1) - input_shape = args[0].shape - new_shape = tuple(input_shape[i] for i in order) - self._build_shapes_dict = {"input_shape": new_shape} - return res - return self.call(args[0]) - - def build(self, input_shape): - _, ch, _, _ = input_shape - if ( - not self.built - and self.axis == -1 - and os.environ.get("DATA_FORMAT", "channels_first") == "channels_first" - ): - order = (0, 2, 3, 1) - new_shape = tuple(input_shape[i] for i in order) - input_shape = tf.TensorShape(new_shape) - - super().build(input_shape) - # modify the channel axis to avoid shape assertion checks by keras - self.input_spec.axes = {1: ch} - return - - 
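# Aside: a standalone sketch (not part of the translated file) of the
# attribute-aliasing pattern used by the __getattribute__ overrides above,
# which map PyTorch names onto Keras ones (weight -> gamma, running_mean ->
# moving_mean, ...) once the layer is built. The class below is hypothetical.
#
#     class Aliased:
#         _attr_map = {"weight": "gamma", "running_mean": "moving_mean"}
#
#         def __init__(self):
#             self.gamma = 1.0
#             self.moving_mean = 0.0
#
#         def __getattribute__(self, name):
#             attr_map = object.__getattribute__(self, "_attr_map")
#             return object.__getattribute__(self, attr_map.get(name, name))
#
#     obj = Aliased()
#     assert obj.weight == 1.0        # resolves to obj.gamma
#     assert obj.running_mean == 0.0  # resolves to obj.moving_mean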
@tensorflow_handle_transpose_in_input_and_output - def call(self, input, training=False): - return super().call(input, training=training) - - -class KerasReLU(tf.keras.layers.ReLU): - def __init__(self, *args, **kwargs): - self._previous_frame_info = None - super().__init__(*args, **kwargs) - - def __repr__(self): - return "KerasReLU()" - - @store_frame_info - def __call__(self, *args, **kwargs): - return super().__call__(*args, **kwargs) - - @tensorflow_handle_transpose_in_input_and_output - def call(self, input, training=False): - return super().call(input) diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_ConvTranspose2d_output/run_0/__init__.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_ConvTranspose2d_output/run_0/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_ConvTranspose2d_output/run_0/tensorflow_CallVisitor.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_ConvTranspose2d_output/run_0/tensorflow_CallVisitor.py deleted file mode 100644 index 1e99977bd5b7..000000000000 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_ConvTranspose2d_output/run_0/tensorflow_CallVisitor.py +++ /dev/null @@ -1,13 +0,0 @@ -import ast - -from .tensorflow__helpers import tensorflow_store_config_info - - -class tensorflow_CallVisitor(ast.NodeVisitor): - @tensorflow_store_config_info - def __init__(self): - self.func_name = None - - def visit_Call(self, node): - self.func_name = ast.unparse(node.func).strip() - return super().generic_visit(node) diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_ConvTranspose2d_output/run_0/tensorflow_ConvTranspose2d.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_ConvTranspose2d_output/run_0/tensorflow_ConvTranspose2d.py deleted file mode 100644 index 176e9a1d8008..000000000000 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_ConvTranspose2d_output/run_0/tensorflow_ConvTranspose2d.py +++ /dev/null @@ -1,72 +0,0 @@ -from .tensorflow__ConvTransposeNd import tensorflow__ConvTransposeNd -from .tensorflow__helpers import tensorflow__ntuple_parse -from .tensorflow__helpers import tensorflow_conv_transpose2d_frnt -from .tensorflow__helpers import tensorflow_handle_transpose_in_input_and_output - -_pair = tensorflow__ntuple_parse(2, "_pair") - - -class tensorflow_ConvTranspose2d(tensorflow__ConvTransposeNd): - def __init__( - self, - in_channels, - out_channels, - kernel_size, - stride=1, - padding=0, - output_padding=0, - groups=1, - bias=True, - dilation=1, - padding_mode="zeros", - device=None, - dtype=None, - ): - factory_kwargs = {"device": device, "dtype": dtype} - kernel_size = _pair(kernel_size) - stride = _pair(stride) - padding = _pair(padding) - dilation = _pair(dilation) - output_padding = _pair(output_padding) - super().__init__( - in_channels, - out_channels, - kernel_size, - stride, - padding, - dilation, - True, - output_padding, - groups, - bias, - padding_mode, - **factory_kwargs, - ) - - @tensorflow_handle_transpose_in_input_and_output - def call(self, input, output_size=None): - if self.padding_mode != "zeros": - raise ValueError( - "Only `zeros` padding mode is supported for ConvTranspose2d" - ) - assert isinstance(self.padding, tuple) - num_spatial_dims = 2 - output_padding = self._output_padding( - input, - output_size, - self.stride, - self.padding, - self.kernel_size, - num_spatial_dims, - self.dilation, - ) - return tensorflow_conv_transpose2d_frnt( - input, - self.weight, - self.bias, - self.stride, - 
self.padding, - output_padding, - self.groups, - self.dilation, - ) diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_ConvTranspose2d_output/run_0/tensorflow_NestedSequence_bknd.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_ConvTranspose2d_output/run_0/tensorflow_NestedSequence_bknd.py deleted file mode 100644 index 9f87b4ae29ef..000000000000 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_ConvTranspose2d_output/run_0/tensorflow_NestedSequence_bknd.py +++ /dev/null @@ -1,10 +0,0 @@ -from typing import Protocol -from typing import TypeVar - -_T_co = TypeVar("_T_co", covariant=True) - - -class tensorflow_NestedSequence_bknd(Protocol[_T_co]): - def __getitem__(self, key: int, /): ... - - def __len__(self, /): ... diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_ConvTranspose2d_output/run_0/tensorflow_TransposeType.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_ConvTranspose2d_output/run_0/tensorflow_TransposeType.py deleted file mode 100644 index f380aaf0d6e0..000000000000 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_ConvTranspose2d_output/run_0/tensorflow_TransposeType.py +++ /dev/null @@ -1,8 +0,0 @@ -from enum import Enum - - -class tensorflow_TransposeType(Enum): - NO_TRANSPOSE = "no_transpose" - CONV1D = "conv1d" - CONV2D = "conv2d" - CONV3D = "conv3d" diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_ConvTranspose2d_output/run_0/tensorflow__ConvNd.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_ConvTranspose2d_output/run_0/tensorflow__ConvNd.py deleted file mode 100644 index a866ab7b8f83..000000000000 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_ConvTranspose2d_output/run_0/tensorflow__ConvNd.py +++ /dev/null @@ -1,541 +0,0 @@ -import tensorflow -from collections import OrderedDict - -import typing -import math -from typing import Optional - -from .tensorflow__stateful import Layer as tensorflow_keras_Layer -from .tensorflow__helpers import tensorflow__calculate_fan_in_and_fan_out -from .tensorflow__helpers import tensorflow__is_variable_bknd -from .tensorflow__helpers import tensorflow__reverse_repeat_tuple -from .tensorflow__helpers import tensorflow_add_frnt_ -from .tensorflow__helpers import tensorflow_default_bknd -from .tensorflow__helpers import tensorflow_empty_frnt -from .tensorflow__helpers import tensorflow_kaiming_uniform_ -from .tensorflow__helpers import tensorflow_set_item_bknd -from .tensorflow__helpers import tensorflow_split_frnt_ -from .tensorflow__helpers import tensorflow_store_config_info -from .tensorflow__helpers import tensorflow_uniform_ - - -class tensorflow__ConvNd(tensorflow_keras_Layer): - __constants__ = [ - "stride", - "padding", - "dilation", - "groups", - "padding_mode", - "output_padding", - "in_channels", - "out_channels", - "kernel_size", - ] - __annotations__ = {"bias": Optional[tensorflow.Variable]} - - def _conv_forward(self, input, weight, bias): ... 
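# Aside: a numeric check (not part of the translated file) of the shape rule
# that _output_padding() in tensorflow__ConvTransposeNd enforces for the
# ConvTranspose2d call above. For each spatial dim the minimum output size is
#   min_size = (in - 1) * stride - 2 * padding + dilation * (kernel - 1) + 1
# and any size up to min_size + stride - 1 is reachable via output_padding.
#
#     def convtranspose_size_range(in_size, stride, padding, kernel, dilation=1):
#         min_size = (in_size - 1) * stride - 2 * padding + dilation * (kernel - 1) + 1
#         return min_size, min_size + stride - 1
#
#     assert convtranspose_size_range(5, stride=2, padding=1, kernel=3) == (9, 10)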
- - in_channels: typing.Any - _reversed_padding_repeated_twice: typing.Any - out_channels: typing.Any - kernel_size: typing.Any - stride: typing.Any - padding: typing.Any - dilation: typing.Any - transposed: typing.Any - output_padding: typing.Any - groups: typing.Any - padding_mode: typing.Any - weight: typing.Any - bias: typing.Any - - @tensorflow_store_config_info - def __init__( - self, - in_channels, - out_channels, - kernel_size, - stride, - padding, - dilation, - transposed, - output_padding, - groups, - bias, - padding_mode, - device=None, - dtype=None, - ): - factory_kwargs = {"device": device, "dtype": dtype} - self.super___init__( - in_channels, - out_channels, - kernel_size, - stride, - padding, - dilation, - transposed, - output_padding, - groups, - bias, - padding_mode, - device=device, - dtype=dtype, - v=getattr(self, "_v", None), - buffers=getattr(self, "_buffers", None), - module_dict=getattr(self, "_module_dict", None), - ) - if groups <= 0: - raise ValueError("groups must be a positive integer") - if in_channels % groups != 0: - raise ValueError("in_channels must be divisible by groups") - if out_channels % groups != 0: - raise ValueError("out_channels must be divisible by groups") - valid_padding_strings = {"same", "valid"} - if isinstance(padding, str): - if padding not in valid_padding_strings: - raise ValueError( - f"Invalid padding string {padding!r}, should be one of {valid_padding_strings}" - ) - if padding == "same" and any(s != 1 for s in stride): - raise ValueError( - "padding='same' is not supported for strided convolutions" - ) - valid_padding_modes = {"zeros", "reflect", "replicate", "circular"} - if padding_mode not in valid_padding_modes: - raise ValueError( - f"padding_mode must be one of {valid_padding_modes}, but got padding_mode='{padding_mode}'" - ) - self.in_channels = in_channels - self.out_channels = out_channels - self.kernel_size = kernel_size - self.stride = stride - self.padding = padding - self.dilation = dilation - self.transposed = transposed - self.output_padding = output_padding - self.groups = groups - self.padding_mode = padding_mode - if isinstance(self.padding, str): - self._reversed_padding_repeated_twice = [0, 0] * len(kernel_size) - if padding == "same": - for d, k, i in zip( - dilation, kernel_size, range(len(kernel_size) - 1, -1, -1) - ): - total_padding = d * (k - 1) - left_pad = total_padding // 2 - with tensorflow.name_scope("_reversed_padding_repeated_twice"): - self._reversed_padding_repeated_twice = ( - tensorflow_set_item_bknd( - self._reversed_padding_repeated_twice, 2 * i, left_pad - ) - ) - with tensorflow.name_scope("_reversed_padding_repeated_twice"): - self._reversed_padding_repeated_twice = ( - tensorflow_set_item_bknd( - self._reversed_padding_repeated_twice, - 2 * i + 1, - total_padding - left_pad, - ) - ) - else: - with tensorflow.name_scope("_reversed_padding_repeated_twice"): - self._reversed_padding_repeated_twice = ( - tensorflow__reverse_repeat_tuple(self.padding, 2) - ) - if transposed: - self.weight = tensorflow.Variable( - tensorflow_empty_frnt( - (*kernel_size, out_channels // groups, in_channels), - **factory_kwargs, - ), - name="weight", - ) - else: - self.weight = tensorflow.Variable( - tensorflow_empty_frnt( - (*kernel_size, in_channels // groups, out_channels), - **factory_kwargs, - ), - name="weight", - ) - if bias: - self.bias = tensorflow.Variable( - tensorflow_empty_frnt(out_channels, **factory_kwargs), name="bias" - ) - else: - self.register_parameter("bias", None) - self.reset_parameters() - - 
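# Aside: a one-dimension re-derivation (not part of the translated file) of
# the padding='same' loop in __init__ above, which splits the total padding
# d * (k - 1) into explicit left/right amounts stored in
# _reversed_padding_repeated_twice.
#
#     def same_pad(kernel, dilation=1):
#         total = dilation * (kernel - 1)
#         left = total // 2
#         return left, total - left  # right side gets the odd remainder
#
#     assert same_pad(kernel=3) == (1, 1)
#     assert same_pad(kernel=4, dilation=2) == (3, 3)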
def reset_parameters(self): - tensorflow_kaiming_uniform_(self.weight, a=math.sqrt(5)) - if self.bias is not None: - with tensorflow.name_scope(""): - fan_in, _ = tensorflow__calculate_fan_in_and_fan_out(self.weight) - if fan_in != 0: - bound = 1 / math.sqrt(fan_in) - tensorflow_uniform_(self.bias, -bound, bound) - - def extra_repr(self): - s = "{in_channels}, {out_channels}, kernel_size={kernel_size}, stride={stride}" - if self.padding != (0,) * len(self.padding): - s = s + ", padding={padding}" - if self.dilation != (1,) * len(self.dilation): - s = s + ", dilation={dilation}" - if self.output_padding != (0,) * len(self.output_padding): - s = s + ", output_padding={output_padding}" - if self.groups != 1: - s = s + ", groups={groups}" - if self.bias is None: - s = s + ", bias=False" - if self.padding_mode != "zeros": - s = s + ", padding_mode={padding_mode}" - return s.format(**self.__dict__) - - def __setstate__(self, state): - super().__setstate__(state) - if not hasattr(self, "padding_mode"): - self.padding_mode = "zeros" - - def super___init__(self, *args, device=None, devices=None, **kwargs): - super().__init__( - *args, - device=device, - devices=devices, - training=True, - build_mode="explicit", - dynamic_backend=True, - **kwargs, - ) - super().__setattr__("_frontend_module", True) - super().__setattr__( - "_attr_mapping", {"_parameters": "v", "_modules": "module_dict"} - ) - - def __dir__(self): - module_attrs = dir(self.__class__) - attrs = list(self.__dict__.keys()) - parameters = list(self._v.keys()) - modules = list(self._module_dict.keys()) - buffers = list(self._buffers.keys()) - keys = module_attrs + attrs + parameters + modules + buffers - ag__result_list_0 = [] - for key in keys: - if not key[0].isdigit(): - res = key - ag__result_list_0.append(res) - keys = ag__result_list_0 - return sorted(keys) - - def __getattribute__(self, name): - if name == "__dict__": - return super().__getattribute__(name) - if "_module_dict" in self.__dict__: - modules = self.__dict__["_module_dict"] - if name in modules: - return modules[name] - if "_buffers" in self.__dict__: - buffers = self.__dict__["_buffers"] - if name in buffers: - return buffers[name] - if "_v" in self.__dict__: - v = self.__dict__["_v"] - if name in v: - return v[name] - if "_attr_mapping" in self.__dict__: - mapping = self.__dict__["_attr_mapping"] - if name in mapping: - return super().__getattribute__(mapping[name]) - return super().__getattribute__(name) - - def __getstate__(self): - state = self.__dict__.copy() - state.pop("_compiled_call_impl", None) - state.pop("_thread_local", None) - state.pop("_metrics_lock", None) - return state - - def __repr__(self): - extra_lines = [] - extra_repr = self._extra_repr() - if extra_repr: - with tensorflow.name_scope("extra_lines"): - extra_lines = tensorflow_split_frnt_(extra_repr, "\n") - child_lines = [] - for key, module in self._module_dict.items(): - mod_str = repr(module) - mod_str = self._addindent(mod_str, 2) - child_lines.append("(" + key + "): " + mod_str) - lines = extra_lines + child_lines - main_str = self._get_name() + "(" - if lines: - if len(extra_lines) == 1 and not child_lines: - main_str += extra_lines[0] - else: - main_str += "\n " + "\n ".join(lines) + "\n" - main_str += ")" - return main_str - - def __setattr__(self, name, value): - def remove_from(*dicts_or_sets): - for d in dicts_or_sets: - if name in d: - if isinstance(d, dict): - del d[name] - else: - d.discard(name) - - params = self.__dict__.get("_v") - if ( - params is not None - and name in params 
- and isinstance(value, tensorflow.Variable) - ): - remove_from(self.__dict__, self._buffers, self._module_dict) - self.register_parameter(name, value) - super().__setattr__(name, value) - else: - super().__setattr__(name, value) - - def _build(self, *args, **kwargs): - for module in self.__dict__.values(): - if isinstance(module, tensorflow_keras_Layer) and module is not self: - if not module._built: - module.build( - *module._args, - dynamic_backend=module._dynamic_backend, - **module._kwargs, - ) - return True - - def _call_impl(self, *args, **kwargs): - return self.call(*args, **kwargs) - - def _create_variables(self, device=None, dtype=None): - with tensorflow.name_scope("v"): - v = dict( - OrderedDict( - [ - (k.replace(".", "/"), v) - for k, v in self.__dict__.items() - if isinstance(v, tensorflow.Variable) and not k.startswith("_") - ] - ) - ) - v = ( - dict( - OrderedDict( - { - _k.replace(".", "/"): _v - for _k, _v in self._v.items() - if _k.replace(".", "/") not in v and not isinstance(_v, dict) - }, - **v, - ) - ) - if self._v - else v - ) - return v - - def _extra_repr(self): - return "" - - def _forward(self, *a, **kw): - ret = self._call_impl(*a, **kw) - return ret - - def _get_name(self): - return self.__class__.__name__ - - def _named_members( - self, get_members_fn, prefix="", recurse=True, remove_duplicate=True - ): - memo = set() - modules = ( - self.named_modules(prefix=prefix, remove_duplicate=remove_duplicate) - if recurse - else [(prefix, self)] - ) - for module_prefix, module in modules: - members = get_members_fn(module) - for k, v in members: - if v is None or id(v) in memo: - continue - if remove_duplicate: - tensorflow_add_frnt_(memo, id(v)) - name = module_prefix + ("." if module_prefix else "") + k - yield name, v - - def _replace_update_v(self, new_v, native=None): - with tensorflow.name_scope("native"): - native = tensorflow_default_bknd(native, self) - for k, v in new_v.items(): - if isinstance(v, dict): - native.module_dict[k] = self._replace_update_v(v, native.module_dict[k]) - elif isinstance(v, tensorflow.Variable): - native.__setattr__(k, v) - elif tensorflow__is_variable_bknd(v): - native.__setattr__(k, tensorflow.Variable(v)) - else: - raise Exception( - f"found item in variable container {v} which was neither a sub ivy.Container nor a variable." - ) - return native - - def _update_v(self, new_v, native=None): - with tensorflow.name_scope("native"): - native = tensorflow_default_bknd(native, self) - for k, v in new_v.items(): - if isinstance(v, dict): - native.module_dict[k] = self._replace_update_v(v, native.module_dict[k]) - elif isinstance(v, tensorflow.Variable): - native.__setattr__(k, v) - elif tensorflow__is_variable_bknd(v): - native.__setattr__(k, tensorflow.Variable(v)) - else: - raise Exception( - f"found item in variable container {v} which was neither a sub ivy.Container nor a variable." - ) - return native - - def add_module(self, name, module): - if ( - not isinstance( - module, (tensorflow_keras_Layer, tensorflow.keras.layers.Layer) - ) - and module is not None - ): - raise TypeError(f"{type(module)} is not a Module subclass") - elif not isinstance(name, str): - raise TypeError(f"module name should be a string. Got {type(name)}") - elif hasattr(self, name) and name not in self._modules: - raise KeyError(f"attribute '{name}' already exists") - elif "." 
in name: - raise KeyError(f'module name can\'t contain ".", got: {name}') - elif name == "": - raise KeyError('module name can\'t be empty string ""') - self._modules[name] = module - super().__setattr__(name, module) - - def apply(self, fn): - for module in self.children(): - if hasattr(module, "apply"): - module.apply(fn) - else: - fn(module) - fn(self) - return self - - def children(self): - for _, module in self.named_children(): - yield module - - def call(self, *input): - raise NotImplementedError( - f'Module [{type(self).__name__}] is missing the required "forward" function' - ) - - def get_parameter(self, target): - target = target.replace(".", "/") - return self.pt_v[target] - - def get_submodule(self, target): - if target == "": - return self - atoms: typing.Any = tensorflow_split_frnt_(target, ".") - mod: typing.Any = self - for item in atoms: - if not hasattr(mod, item): - raise AttributeError( - mod._get_name() + " has no attribute `" + item + "`" - ) - mod = getattr(mod, item) - if not isinstance(mod, tensorflow_keras_Layer): - raise TypeError("`" + item + "` is not an nn.Module") - return mod - - def modules(self): - for _, module in self.named_modules(): - yield module - - def named_buffers(self, prefix="", recurse=True, remove_duplicate=True): - if not getattr(self, "_built", False): - self.build( - *self._args, dynamic_backend=self._dynamic_backend, **self._kwargs - ) - gen = self._named_members( - lambda module: module.buffers.items(), - prefix=prefix, - recurse=recurse, - remove_duplicate=remove_duplicate, - ) - yield from gen - - def named_children(self): - if not getattr(self, "_built", False): - self.build( - *self._args, dynamic_backend=self._dynamic_backend, **self._kwargs - ) - memo = set() - for name, module in self._module_dict.items(): - if module is not None and id(module) not in memo: - tensorflow_add_frnt_(memo, id(module)) - yield name, module - - def named_modules(self, memo=None, prefix="", remove_duplicate=True): - if not getattr(self, "_built", False): - self.build( - *self._args, dynamic_backend=self._dynamic_backend, **self._kwargs - ) - if memo is None: - memo = set() - if id(self) not in memo: - if remove_duplicate: - tensorflow_add_frnt_(memo, id(self)) - yield prefix, self - for name, module in self._module_dict.items(): - if module is None: - continue - submodule_prefix = prefix + ("." 
if prefix else "") + name - if not hasattr(module, "named_modules"): - yield submodule_prefix, module - else: - yield from module.named_modules( - memo, submodule_prefix, remove_duplicate - ) - - def named_parameters(self, prefix="", recurse=True, remove_duplicate=True): - if not getattr(self, "_built", False): - self.build( - *self._args, dynamic_backend=self._dynamic_backend, **self._kwargs - ) - gen = self._named_members( - lambda module: module.v.items(), - prefix=prefix, - recurse=recurse, - remove_duplicate=remove_duplicate, - ) - yield from gen - - def parameters(self, recurse=True): - for _, param in self.named_parameters(recurse=recurse): - yield param - - def register_buffer(self, name, value, persistent=False): - super().register_buffer(name, value) - - def register_module(self, name, module): - self.add_module(name, module) - - def register_parameter(self, name, value): - super().register_parameter(name, value) - - def requires_grad_(self, requires_grad=True): - for p in self.parameters(): - p.requires_grad_(requires_grad) - return self diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_ConvTranspose2d_output/run_0/tensorflow__ConvTransposeNd.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_ConvTranspose2d_output/run_0/tensorflow__ConvTransposeNd.py deleted file mode 100644 index 839dece5ad95..000000000000 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_ConvTranspose2d_output/run_0/tensorflow__ConvTransposeNd.py +++ /dev/null @@ -1,117 +0,0 @@ -import tensorflow - - -from .tensorflow__ConvNd import tensorflow__ConvNd -from .tensorflow__helpers import tensorflow__ntuple_parse -from .tensorflow__helpers import tensorflow_dim_frnt_ -from .tensorflow__helpers import tensorflow_get_item -from .tensorflow__helpers import tensorflow_size_frnt_ -from .tensorflow__helpers import tensorflow_store_config_info - -_single = tensorflow__ntuple_parse(1, "_single") - - -class tensorflow__ConvTransposeNd(tensorflow__ConvNd): - @tensorflow_store_config_info - def __init__( - self, - in_channels, - out_channels, - kernel_size, - stride, - padding, - dilation, - transposed, - output_padding, - groups, - bias, - padding_mode, - device=None, - dtype=None, - ): - if padding_mode != "zeros": - raise ValueError( - f'Only "zeros" padding mode is supported for {self.__class__.__name__}' - ) - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__( - in_channels, - out_channels, - kernel_size, - stride, - padding, - dilation, - transposed, - output_padding, - groups, - bias, - padding_mode, - **factory_kwargs, - ) - - def _output_padding( - self, - input, - output_size, - stride, - padding, - kernel_size, - num_spatial_dims, - dilation=None, - ): - if output_size is None: - ret = _single(self.output_padding) - else: - with tensorflow.name_scope("has_batch_dim"): - has_batch_dim = tensorflow_dim_frnt_(input) == num_spatial_dims + 2 - num_non_spatial_dims = 2 if has_batch_dim else 1 - if len(output_size) == num_non_spatial_dims + num_spatial_dims: - with tensorflow.name_scope("output_size"): - output_size = tensorflow_get_item( - output_size, slice(num_non_spatial_dims, None, None) - ) - if len(output_size) != num_spatial_dims: - raise ValueError( - f"ConvTranspose{num_spatial_dims}D: for {tensorflow_dim_frnt_(input)}D input, output_size must have {num_spatial_dims} or {num_non_spatial_dims + num_spatial_dims} elements (got {len(output_size)})" - ) - min_sizes = [] - max_sizes = [] - for d in range(num_spatial_dims): - with tensorflow.name_scope("dim_size"): - 
dim_size = ( - (tensorflow_size_frnt_(input, d + num_non_spatial_dims) - 1) - * tensorflow_get_item(stride, d) - - 2 * tensorflow_get_item(padding, d) - + ( - tensorflow_get_item(dilation, d) - if dilation is not None - else 1 - ) - * (tensorflow_get_item(kernel_size, d) - 1) - + 1 - ) - min_sizes.append(dim_size) - max_sizes.append( - tensorflow_get_item(min_sizes, d) - + tensorflow_get_item(stride, d) - - 1 - ) - for i in range(len(output_size)): - with tensorflow.name_scope("size"): - size = tensorflow_get_item(output_size, i) - with tensorflow.name_scope("min_size"): - min_size = tensorflow_get_item(min_sizes, i) - with tensorflow.name_scope("max_size"): - max_size = tensorflow_get_item(max_sizes, i) - if size < min_size or size > max_size: - raise ValueError( - f"requested an output size of {output_size}, but valid sizes range from {min_sizes} to {max_sizes} (for an input of {tensorflow_size_frnt_(input)[2:]})" - ) - res = [] - for d in range(num_spatial_dims): - res.append( - tensorflow_get_item(output_size, d) - - tensorflow_get_item(min_sizes, d) - ) - ret = res - return ret diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_ConvTranspose2d_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_ConvTranspose2d_output/run_0/tensorflow__helpers.py deleted file mode 100644 index e64d53c999d9..000000000000 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_ConvTranspose2d_output/run_0/tensorflow__helpers.py +++ /dev/null @@ -1,4049 +0,0 @@ -from collections import UserDict -from itertools import repeat -from ivy.utils.backend import backend_stack -from numbers import Number -from numpy.core.numeric import normalize_axis_tuple -from operator import mul -from .tensorflow_NestedSequence_bknd import tensorflow_NestedSequence_bknd -from typing import Any -from typing import Callable -from typing import Dict -from typing import Iterable -from typing import List -from typing import Optional -from typing import Sequence -from typing import Tuple -from typing import TypeVar -from typing import Union -import ast -import collections -import copy -import functools -import inspect -import itertools -import math -import numpy as np -import os -import re -import tensorflow -import tensorflow as tf -import warnings - - -promotion_table = { - ("bool", "bool"): "bool", - ("int8", "int8"): "int8", - ("int8", "int16"): "int16", - ("int8", "int32"): "int32", - ("int8", "int64"): "int64", - ("int16", "int16"): "int16", - ("int16", "int32"): "int32", - ("int16", "int64"): "int64", - ("int32", "int32"): "int32", - ("int32", "int64"): "int64", - ("int64", "int64"): "int64", - ("uint8", "int8"): "int16", - ("uint8", "int16"): "int16", - ("uint8", "int32"): "int32", - ("uint8", "int64"): "int64", - ("uint8", "uint8"): "uint8", - ("uint8", "uint16"): "uint16", - ("uint8", "uint32"): "uint32", - ("uint8", "uint64"): "uint64", - ("uint16", "int8"): "int32", - ("uint16", "int16"): "int32", - ("uint16", "int32"): "int32", - ("uint16", "int64"): "int64", - ("uint16", "uint16"): "uint16", - ("uint16", "uint32"): "uint32", - ("uint16", "uint64"): "uint64", - ("uint32", "int8"): "int64", - ("uint32", "int16"): "int64", - ("uint32", "int32"): "int64", - ("uint32", "int64"): "int64", - ("uint32", "uint32"): "uint32", - ("uint32", "uint64"): "uint64", - ("uint64", "uint64"): "uint64", - ("float16", "float16"): "float16", - ("float16", "float32"): "float32", - ("float16", "float64"): "float64", - ("float32", "float32"): "float32", - ("float32", "float64"): "float64", - 
("float64", "float64"): "float64", - ("bool", "int8"): "int8", - ("bool", "int16"): "int16", - ("bool", "int32"): "int32", - ("bool", "int64"): "int64", - ("bool", "uint8"): "uint8", - ("bool", "uint16"): "uint16", - ("bool", "uint32"): "uint32", - ("bool", "uint64"): "uint64", - ("bool", "float16"): "float16", - ("bool", "float32"): "float32", - ("bool", "float64"): "float64", - ("bool", "bfloat16"): "bfloat16", - ("bool", "complex64"): "complex64", - ("bool", "complex128"): "complex128", - ("int8", "float16"): "float16", - ("int8", "float32"): "float32", - ("int8", "float64"): "float64", - ("int8", "bfloat16"): "bfloat16", - ("int8", "complex64"): "complex64", - ("int8", "complex128"): "complex128", - ("int16", "float32"): "float32", - ("int16", "float64"): "float64", - ("int16", "complex64"): "complex64", - ("int16", "complex128"): "complex128", - ("int32", "float64"): "float64", - ("int32", "complex128"): "complex128", - ("int64", "float64"): "float64", - ("int64", "complex128"): "complex128", - ("uint8", "float16"): "float16", - ("uint8", "float32"): "float32", - ("uint8", "float64"): "float64", - ("uint8", "bfloat16"): "bfloat16", - ("uint8", "complex64"): "complex64", - ("uint8", "complex128"): "complex128", - ("uint16", "float32"): "float32", - ("uint16", "float64"): "float64", - ("uint16", "complex64"): "complex64", - ("uint16", "complex128"): "complex128", - ("uint32", "float64"): "float64", - ("uint32", "complex128"): "complex128", - ("uint64", "int8"): "float64", - ("uint64", "int16"): "float64", - ("uint64", "int32"): "float64", - ("uint64", "int64"): "float64", - ("uint64", "float64"): "float64", - ("uint64", "complex128"): "complex128", - ("float16", "bfloat16"): "float32", - ("float16", "complex64"): "complex64", - ("float16", "complex128"): "complex128", - ("float32", "complex64"): "complex64", - ("float32", "complex128"): "complex128", - ("float64", "complex64"): "complex128", - ("float64", "complex128"): "complex128", - ("bfloat16", "float16"): "float32", - ("bfloat16", "float32"): "float32", - ("bfloat16", "float64"): "float64", - ("bfloat16", "bfloat16"): "bfloat16", - ("bfloat16", "complex64"): "complex64", - ("bfloat16", "complex128"): "complex128", - ("complex64", "float64"): "complex128", - ("complex64", "complex64"): "complex64", - ("complex64", "complex128"): "complex128", - ("complex128", "complex128"): "complex128", - ("float16", "int16"): "float32", - ("float16", "int32"): "float64", - ("float16", "int64"): "float64", - ("float16", "uint16"): "float32", - ("float16", "uint32"): "float64", - ("float16", "uint64"): "float64", - ("float32", "int32"): "float64", - ("float32", "int64"): "float64", - ("float32", "uint32"): "float64", - ("float32", "uint64"): "float64", - ("bfloat16", "int16"): "float32", - ("bfloat16", "int32"): "float64", - ("bfloat16", "int64"): "float64", - ("bfloat16", "uint16"): "float32", - ("bfloat16", "uint32"): "float64", - ("bfloat16", "uint64"): "float64", - ("complex64", "int32"): "complex128", - ("complex64", "int64"): "complex128", - ("complex64", "uint32"): "complex128", - ("complex64", "uint64"): "complex128", -} -array_api_promotion_table = { - ("bool", "bool"): "bool", - ("int8", "int8"): "int8", - ("int8", "int16"): "int16", - ("int8", "int32"): "int32", - ("int8", "int64"): "int64", - ("int16", "int16"): "int16", - ("int16", "int32"): "int32", - ("int16", "int64"): "int64", - ("int32", "int32"): "int32", - ("int32", "int64"): "int64", - ("int64", "int64"): "int64", - ("uint8", "int8"): "int16", - ("uint8", "int16"): "int16", - 
("uint8", "int32"): "int32", - ("uint8", "int64"): "int64", - ("uint8", "uint8"): "uint8", - ("uint8", "uint16"): "uint16", - ("uint8", "uint32"): "uint32", - ("uint8", "uint64"): "uint64", - ("uint16", "int8"): "int32", - ("uint16", "int16"): "int32", - ("uint16", "int32"): "int32", - ("uint16", "int64"): "int64", - ("uint16", "uint16"): "uint16", - ("uint16", "uint32"): "uint32", - ("uint16", "uint64"): "uint64", - ("uint32", "int8"): "int64", - ("uint32", "int16"): "int64", - ("uint32", "int32"): "int64", - ("uint32", "int64"): "int64", - ("uint32", "uint32"): "uint32", - ("uint32", "uint64"): "uint64", - ("uint64", "uint64"): "uint64", - ("float16", "float16"): "float16", - ("float16", "float32"): "float32", - ("float16", "float64"): "float64", - ("float32", "float32"): "float32", - ("float32", "float64"): "float64", - ("float64", "float64"): "float64", -} -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] -backend_stack = [] -CONV_FUNCS = [ - "Conv1d", - "Conv2d", - "Conv3d", - "ConvTranspose1d", - "ConvTranspose2d", - "ConvTranspose3d", -] -NORM_FUNCS = [ - "_BatchNorm", - "_InstanceNorm", - "BatchNorm1d", - "BatchNorm2d", - "BatchNorm3d", - "GroupNorm", - "SyncBatchNorm", - "InstanceNorm1d", - "InstanceNorm2d", - "InstanceNorm3d", - "LocalResponseNorm", -] -POOL_FUNCS = [ - "MaxPool1d", - "MaxPool2d", - "MaxPool3d", - "AvgPool1d", - "AvgPool2d", - "AvgPool3d", - "FractionalMaxPool2d", - "LPPool1d", - "LPPool2d", - "AdaptiveMaxPool1d", - "AdaptiveMaxPool2d", - "AdaptiveMaxPool3d", - "AdaptiveAvgPool1d", - "AdaptiveAvgPool2d", - "AdaptiveAvgPool3d", -] -KERAS_CONV_FUNCS = [ - "KerasConv1D", - "KerasConv2D", - "KerasConv3D", - "KerasDepthwiseConv2D", - "KerasConv1DTranspose", - "KerasConv2DTranspose", - "KerasConv3DTranspose", -] -KERAS_NORM_FUNCS = [ - "KerasBatchNorm1D", - "KerasBatchNorm2D", - "KerasBatchNorm3D", - "KerasLayerNormalization", - "KerasGroupNormalization", - "KerasUnitNorm1D", - "KerasUnitNorm2D", - "KerasUnitNorm3D", -] -KERAS_POOL_FUNCS = [ - "KerasAveragePooling1D", - "KerasAveragePooling2D", - "KerasAveragePooling3D", - "KerasMaxPool1D", - "KerasMaxPool2D", - "KerasMaxPool3D", -] -PADDING_FUNCS = [ - "ReflectionPad1d", - "ReflectionPad2d", - "ReplicationPad1d", - "ReplicationPad2d", - "ReplicationPad3d", - "ZeroPad2d", - 
"ConstantPad1d", - "ConstantPad2d", - "ConstantPad3d", -] -KERAS_PADDING_FUNCS = ["KerasZeroPadding1D", "KerasZeroPadding2D", "KerasZeroPadding3D"] -ACTIVATION_FUNCS = [ - "ELU", - "Hardshrink", - "Hardsigmoid", - "Hardswish", - "Hardtanh", - "LeakyReLU", - "PReLU", - "ReLU", - "ReLU6", - "RReLU", - "SELU", - "CELU", - "GELU", - "Sigmoid", - "Softplus", - "Softshrink", - "Softsign", - "Tanh", - "Tanhshrink", - "Threshold", - "Softmin", - "Softmax", - "Softmax2d", - "LogSoftmax", - "AdaptiveLogSoftmaxWithLoss", -] -KERAS_ACTIVATION_FUNCS = [ - "KerasReLU", - "KerasPReLU", - "KerasLeakyReLU", - "KerasThresholdedReLU", - "KerasELU", - "KerasSoftmax", -] -DROPOUT_FUNCS = [ - "Dropout", - "Dropout2d", - "Dropout3d", - "AlphaDropout", - "FeatureAlphaDropout", -] -KERAS_DROPOUT_FUNCS = ["KerasDropout"] -CONV_BLOCK_FNS = [ - *CONV_FUNCS, - *KERAS_CONV_FUNCS, - *POOL_FUNCS, - *KERAS_POOL_FUNCS, - *PADDING_FUNCS, - *KERAS_PADDING_FUNCS, - *ACTIVATION_FUNCS, - *KERAS_ACTIVATION_FUNCS, - *NORM_FUNCS, - *KERAS_NORM_FUNCS, - *DROPOUT_FUNCS, - *KERAS_DROPOUT_FUNCS, -] -DATA_FORMAT = "PT" - - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion - - -def tensorflow__ntuple_parse(n, name="parse"): - def parse(x): - if isinstance(x, collections.abc.Iterable): - return tuple(x) - return tuple(repeat(x, n)) - - parse.__name__ = name - return parse - - -def tensorflow_is_native_array(x, /, *, exclusive=False): - if "keras.src.backend.tensorflow.core.Variable" in str(x.__class__): - return not exclusive - if isinstance(x, (tensorflow.Tensor, tensorflow.Variable, tensorflow.TensorArray)): - if exclusive and isinstance(x, tensorflow.Variable): - return False - return True - return False - - -def tensorflow_is_ivy_array_bknd( - x: Union[tensorflow.Tensor, tf.Tensor], /, *, exclusive: Optional[bool] = False -): - return isinstance(x, tensorflow.Tensor) and tensorflow_is_native_array( - x, exclusive=exclusive - ) - - -def tensorflow_is_array_bknd(x: Any, /, *, exclusive: bool = False): - return tensorflow_is_ivy_array_bknd( - x, exclusive=exclusive - ) or tensorflow_is_native_array(x, exclusive=exclusive) - - -def tensorflow_exists_bknd(x: Any, /): - return x is 
not None - - -def tensorflow_default_bknd( - x: Any, - /, - default_val: Any, - *, - catch_exceptions: bool = False, - rev: bool = False, - with_callable: bool = False, -): - with_callable = catch_exceptions or with_callable - if rev: - x, default_val = default_val, x - if with_callable: - x_callable = callable(x) - default_callable = callable(default_val) - else: - x_callable = False - default_callable = False - if catch_exceptions: - try: - x = x() if x_callable else x - except Exception: - return default_val() if default_callable else default_val - else: - x = x() if x_callable else x - return ( - x - if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val - ) - - -def tensorflow_nested_argwhere_bknd( - nest: Iterable, - fn: Callable, - check_nests: bool = False, - to_ignore: Optional[Union[type, Tuple[type]]] = None, - _index: Optional[List] = None, - _base: bool = True, - stop_after_n_found: Optional[int] = None, -): - to_ignore = tensorflow_default_bknd(to_ignore, ()) - _index = [] if _index is None else _index - if isinstance(nest, (tuple, list)) and not isinstance(nest, to_ignore): - n = 0 - _indices = [] - for i, item in enumerate(nest): - ind = ( - tensorflow_nested_argwhere_bknd( - item, - fn, - check_nests, - to_ignore, - _index + [i], - False, - stop_after_n_found - n, - ) - if stop_after_n_found is not None - else tensorflow_nested_argwhere_bknd( - item, fn, check_nests, to_ignore, _index + [i], False, None - ) - ) - if stop_after_n_found is not None and ind: - if n >= stop_after_n_found: - break - n = n + len(ind) - _indices = _indices + [ind] - if stop_after_n_found is not None and n >= stop_after_n_found: - break - _indices = [idx for idxs in _indices if idxs for idx in idxs] - if check_nests and fn(nest): - _indices.append(_index) - elif isinstance(nest, (dict, UserDict)) and not isinstance(nest, to_ignore): - n = 0 - _indices = [] - for k, v in nest.items(): - ind = ( - tensorflow_nested_argwhere_bknd( - v, - fn, - check_nests, - to_ignore, - _index + [k], - False, - stop_after_n_found - n, - ) - if stop_after_n_found is not None - else tensorflow_nested_argwhere_bknd( - v, fn, check_nests, to_ignore, _index + [k], False, None - ) - ) - if stop_after_n_found is not None and ind: - if n >= stop_after_n_found: - break - n = n + len(ind) - _indices = _indices + [ind] - _indices = [idx for idxs in _indices if idxs for idx in idxs] - if check_nests and fn(nest): - _indices.append(_index) - else: - cond_met = fn(nest) - if cond_met: - return [_index] - return False - return [index for index in _indices if index] - - -def tensorflow__check_float64_bknd(input): - if tensorflow_is_array_bknd(input): - return tensorflow_dtype(input) == "float64" - if math.isfinite(input): - m, e = math.frexp(input) - return abs(input) > 3.4028235e38 or e < -126 or e > 128 - return False - - -def tensorflow_as_ivy_dtype_bknd(dtype_in: Union[str, str], /): - return tensorflow_as_ivy_dtype(dtype_in) - - -def tensorflow_is_complex_dtype_bknd( - dtype_in: Union[str, str, tensorflow.Tensor, tf.Tensor, Number], / -): - if tensorflow_is_array_bknd(dtype_in): - dtype_in = tensorflow_dtype(dtype_in) - elif isinstance(dtype_in, tuple): - dtype_in = tensorflow_default_int_dtype_bknd() - elif isinstance(dtype_in, np.ndarray): - return "complex" in dtype_in.dtype.name - elif isinstance(dtype_in, Number): - return isinstance(dtype_in, (complex, np.complexfloating)) - elif isinstance(dtype_in, (list, tuple, dict)): - return tensorflow_nested_argwhere_bknd( - dtype_in, - lambda x: 
isinstance(x, (complex, np.complexfloating)) - or tensorflow_is_array_bknd(x) - and "complex" in tensorflow_dtype(x), - ) - return "complex" in tensorflow_as_ivy_dtype_bknd(dtype_in) - - -def tensorflow_real( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - return tensorflow.math.real(x) - - -def tensorflow_real_bknd_(self): - return tensorflow_real(self) - - -def tensorflow_imag( - val: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - return tensorflow.math.imag(val, name=None) - - -def tensorflow_imag_bknd_(self): - return tensorflow_imag(self) - - -def tensorflow__check_complex128_bknd(input): - if tensorflow_is_array_bknd(input): - return tensorflow_dtype(input) == "complex128" - elif isinstance(input, np.ndarray): - return str(input.dtype) == "complex128" - if hasattr(input, "real") and hasattr(input, "imag"): - return tensorflow__check_float64_bknd( - tensorflow_real_bknd_(input) - ) and tensorflow__check_float64_bknd(tensorflow_imag_bknd_(input)) - return False - - -def tensorflow_default_complex_dtype_bknd( - *, - input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, - complex_dtype: Optional[Union[str, tf.DType]] = None, - as_native: bool = False, -): - global default_complex_dtype_stack - if tensorflow_exists_bknd(complex_dtype): - if as_native is True: - return tensorflow_as_native_dtype(complex_dtype) - return str(tensorflow_as_ivy_dtype(complex_dtype)) - as_native = tensorflow_default_bknd(as_native, False) - if tensorflow_exists_bknd(input): - if tensorflow_is_array_bknd(input): - ret = tensorflow_dtype(input) - elif isinstance(input, np.ndarray): - ret = str(input.dtype) - elif isinstance(input, (list, tuple, dict)): - if tensorflow_nested_argwhere_bknd( - input, - lambda x: tensorflow__check_complex128_bknd(x), - stop_after_n_found=1, - ): - ret = tf.complex128 - elif not default_complex_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_complex_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "complex64" - else: - ret = default_complex_dtype_stack[-1] - elif isinstance(input, Number): - if tensorflow__check_complex128_bknd(input): - ret = tf.complex128 - elif not default_complex_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_complex_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "complex64" - else: - ret = default_complex_dtype_stack[-1] - elif not default_complex_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_complex_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "complex64" - else: - ret = default_complex_dtype_stack[-1] - if as_native: - return tensorflow_as_native_dtype(ret) - return str(tensorflow_as_ivy_dtype(ret)) - - -def tensorflow_is_float_dtype_bknd( - dtype_in: Union[str, str, tensorflow.Tensor, tf.Tensor, Number], / -): - if tensorflow_is_array_bknd(dtype_in): - dtype_in = tensorflow_dtype(dtype_in) - elif isinstance(dtype_in, tuple): - dtype_in = tensorflow_default_int_dtype_bknd() - elif isinstance(dtype_in, np.ndarray): - return "float" in dtype_in.dtype.name - elif isinstance(dtype_in, Number): - return isinstance(dtype_in, (float, np.floating)) - elif isinstance(dtype_in, (list, tuple, dict)): - return bool( - tensorflow_nested_argwhere_bknd( - dtype_in, - lambda x: isinstance(x, (float, np.floating)) - or tensorflow_is_array_bknd(x) - and "float" in 
tensorflow_dtype(x), - ) - ) - return "float" in tensorflow_as_ivy_dtype_bknd(dtype_in) - - -def tensorflow_is_uint_dtype_bknd( - dtype_in: Union[str, str, tensorflow.Tensor, tf.Tensor, Number], / -): - if tensorflow_is_array_bknd(dtype_in): - dtype_in = tensorflow_dtype(dtype_in) - elif isinstance(dtype_in, tuple): - dtype_in = tensorflow_default_int_dtype_bknd() - elif isinstance(dtype_in, np.ndarray): - return "uint" in dtype_in.dtype.name - elif isinstance(dtype_in, Number): - return isinstance(dtype_in, np.unsignedinteger) - elif isinstance(dtype_in, (list, tuple, dict)): - return tensorflow_nested_argwhere_bknd( - dtype_in, - lambda x: isinstance(x, np.unsignedinteger) - or tensorflow_is_array_bknd(x) - and "uint" in tensorflow_dtype(x), - ) - return "uint" in tensorflow_as_ivy_dtype_bknd(dtype_in) - - -def tensorflow_is_int_dtype_bknd( - dtype_in: Union[str, str, tensorflow.Tensor, tf.Tensor, Number], / -): - if tensorflow_is_array_bknd(dtype_in): - dtype_in = tensorflow_dtype(dtype_in) - elif isinstance(dtype_in, tuple): - dtype_in = tensorflow_default_int_dtype_bknd() - elif isinstance(dtype_in, np.ndarray): - return "int" in dtype_in.dtype.name - elif isinstance(dtype_in, Number): - return isinstance(dtype_in, (int, np.integer)) and not isinstance( - dtype_in, bool - ) - elif isinstance(dtype_in, (list, tuple, dict)): - - def nested_fun(x): - return ( - isinstance(x, (int, np.integer)) - or tensorflow_is_array_bknd(x) - and "int" in tensorflow_dtype(x) - ) and x is not bool - - return bool(tensorflow_nested_argwhere_bknd(dtype_in, nested_fun)) - return "int" in tensorflow_as_ivy_dtype(dtype_in) - - -def tensorflow_default_dtype_bknd( - *, - dtype: Optional[Union[str, str]] = None, - item: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, - as_native: bool = False, -): - if tensorflow_exists_bknd(dtype): - if as_native is True: - return tensorflow_as_native_dtype(dtype) - return tensorflow_as_ivy_dtype(dtype) - as_native = tensorflow_default_bknd(as_native, False) - if tensorflow_exists_bknd(item): - if hasattr(item, "override_dtype_check"): - return item.override_dtype_check() - elif isinstance(item, (list, tuple, dict)) and len(item) == 0: - pass - elif tensorflow_is_complex_dtype_bknd(item): - return tensorflow_default_complex_dtype_bknd( - input=item, as_native=as_native - ) - elif tensorflow_is_float_dtype_bknd(item): - return tensorflow_default_float_dtype_bknd(input=item, as_native=as_native) - elif tensorflow_is_uint_dtype_bknd(item): - return tensorflow_default_int_dtype_bknd(input=item, as_native=as_native) - elif tensorflow_is_int_dtype_bknd(item): - return tensorflow_default_int_dtype_bknd(input=item, as_native=as_native) - elif as_native: - return tensorflow_as_native_dtype("bool") - else: - return "bool" - global default_dtype_stack - if not default_dtype_stack: - global default_float_dtype_stack - if default_float_dtype_stack: - ret = default_float_dtype_stack[-1] - else: - ret = "float32" - else: - ret = default_dtype_stack[-1] - if as_native: - return tensorflow_as_native_dtype(ret) - return tensorflow_as_ivy_dtype(ret) - - -def tensorflow_default_float_dtype_bknd( - *, - input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, - float_dtype: Optional[Union[str, tf.DType]] = None, - as_native: bool = False, -): - global default_float_dtype_stack - if tensorflow_exists_bknd(float_dtype): - if as_native is True: - return tensorflow_as_native_dtype(float_dtype) - return str(tensorflow_as_ivy_dtype(float_dtype)) - as_native = tensorflow_default_bknd(as_native, 
False) - if tensorflow_exists_bknd(input): - if tensorflow_is_array_bknd(input): - ret = tensorflow_dtype(input) - elif isinstance(input, np.ndarray): - ret = str(input.dtype) - elif isinstance(input, (list, tuple, dict)): - if tensorflow_nested_argwhere_bknd( - input, lambda x: tensorflow__check_float64_bknd(x), stop_after_n_found=1 - ): - ret = tf.float64 - elif not default_float_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_float_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "float32" - else: - ret = default_float_dtype_stack[-1] - elif isinstance(input, Number): - if tensorflow__check_float64_bknd(input): - ret = tf.float64 - elif not default_float_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_float_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "float32" - else: - ret = default_float_dtype_stack[-1] - elif not default_float_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_float_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "float32" - else: - ret = default_float_dtype_stack[-1] - if as_native: - return tensorflow_as_native_dtype(ret) - return str(tensorflow_as_ivy_dtype(ret)) - - -def tensorflow_as_ivy_dtype( - dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / -): - if dtype_in is int: - return tensorflow_default_int_dtype_bknd() - if dtype_in is float: - return tensorflow_default_float_dtype_bknd() - if dtype_in is complex: - return tensorflow_default_complex_dtype_bknd() - if dtype_in is bool: - return str("bool") - if isinstance(dtype_in, np.dtype): - dtype_in = dtype_in.name - if isinstance(dtype_in, str): - if dtype_in in native_dtype_dict: - dtype_str = dtype_in - else: - raise Exception( - f"Cannot convert to ivy dtype. {dtype_in} is not supported by TensorFlow backend." 
- ) - else: - dtype_str = ivy_dtype_dict[dtype_in] - if "uint" in dtype_str: - return str(dtype_str) - elif "int" in dtype_str: - return str(dtype_str) - elif "float" in dtype_str: - return str(dtype_str) - elif "complex" in dtype_str: - return str(dtype_str) - elif "bool" in dtype_str: - return str("bool") - else: - raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") - - -def tensorflow_default_int_dtype_bknd( - *, - input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, - int_dtype: Optional[Union[str, tf.DType]] = None, - as_native: bool = False, -): - global default_int_dtype_stack - if tensorflow_exists_bknd(int_dtype): - if as_native is True: - return tensorflow_as_native_dtype(int_dtype) - return str(tensorflow_as_ivy_dtype(int_dtype)) - as_native = tensorflow_default_bknd(as_native, False) - if tensorflow_exists_bknd(input): - if tensorflow_is_array_bknd(input): - ret = tensorflow_dtype(input) - elif isinstance(input, tuple): - ret = tensorflow_default_int_dtype_bknd() - elif isinstance(input, np.ndarray): - ret = str(input.dtype) - elif isinstance(input, (list, tuple, dict)): - if tensorflow_nested_argwhere_bknd( - input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), - stop_after_n_found=1, - ): - ret = tf.uint64 - elif tensorflow_nested_argwhere_bknd( - input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), - stop_after_n_found=1, - ): - ret = tf.int64 - elif not default_int_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_int_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "int32" - else: - ret = default_int_dtype_stack[-1] - elif isinstance(input, Number): - if input > 9223372036854775807 and input != math.inf and backend != "torch": - ret = tf.uint64 - elif input > 2147483647 and input != math.inf: - ret = tf.int64 - elif not default_int_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_int_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "int32" - else: - ret = default_int_dtype_stack[-1] - elif not default_int_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_int_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "int32" - else: - ret = default_int_dtype_stack[-1] - if as_native: - return tensorflow_as_native_dtype(ret) - return str(tensorflow_as_ivy_dtype(ret)) - - -def tensorflow_as_native_dtype( - dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], -): - if dtype_in is int: - return tensorflow_default_int_dtype_bknd(as_native=True) - if dtype_in is float: - return tensorflow_default_float_dtype_bknd(as_native=True) - if dtype_in is complex: - return tensorflow_default_complex_dtype_bknd(as_native=True) - if dtype_in is bool: - return tensorflow.bool - if isinstance(dtype_in, np.dtype): - dtype_in = dtype_in.name - if not isinstance(dtype_in, str): - return dtype_in - if dtype_in in native_dtype_dict: - return native_dtype_dict[str(dtype_in)] - else: - raise Exception( - f"Cannot convert to TensorFlow dtype. {dtype_in} is not supported by TensorFlow." 
- ) - - -def tensorflow_dtype( - x: Union[tensorflow.Tensor, tensorflow.Variable, np.ndarray], - *, - as_native: bool = False, -): - if as_native: - return tensorflow_as_native_dtype(x.dtype) - return tensorflow_as_ivy_dtype(x.dtype) - - -def tensorflow_is_bool_dtype_bknd( - dtype_in: Union[str, str, tensorflow.Tensor, tf.Tensor, Number], / -): - if tensorflow_is_array_bknd(dtype_in): - dtype_in = tensorflow_dtype(dtype_in) - elif isinstance(dtype_in, np.ndarray): - return "bool" in dtype_in.dtype.name - elif isinstance(dtype_in, Number): - return isinstance(dtype_in, (bool, np.bool_)) and not isinstance(dtype_in, bool) - elif isinstance(dtype_in, (list, tuple, dict)): - return bool( - tensorflow_nested_argwhere_bknd( - dtype_in, lambda x: isinstance(x, (bool, np.bool_)) and x is not int - ) - ) - return "bool" in tensorflow_as_ivy_dtype(dtype_in) - - -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - -@tensorflow_handle_get_item -def tensorflow_get_item( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - query: Union[tensorflow.Tensor, tensorflow.Variable, Tuple], - *, - copy: Optional[bool] = None, -): - if ( - tensorflow_is_array_bknd(query) - and tensorflow_is_bool_dtype_bknd(query) - and not len(query.shape) - ): - return tensorflow.expand_dims(x, 0) - return x[query] - - -def tensorflow_index_nest_bknd( - nest: Union[List, Tuple, Dict, tensorflow.Tensor, tf.Tensor, dict], - index: Union[List[int], Tuple[int], Iterable[int]], - /, -): - ret = nest - for i in index: - ret = tensorflow_get_item(ret, i) - return ret - - -def tensorflow__get_first_array(*args, **kwargs): - def array_fn(x): - return ( - tensorflow_is_array_bknd(x) - if not hasattr(x, "_ivy_array") - else tensorflow_is_array_bknd(x.ivy_array) - ) - - array_fn = array_fn if "array_fn" not in kwargs else kwargs["array_fn"] - arr = None - if args: - arr_idxs = tensorflow_nested_argwhere_bknd(args, array_fn, stop_after_n_found=1) - if arr_idxs: - arr = tensorflow_index_nest_bknd(args, arr_idxs[0]) - else: - arr_idxs = tensorflow_nested_argwhere_bknd( - kwargs, array_fn, stop_after_n_found=1 - ) - if arr_idxs: - arr = tensorflow_index_nest_bknd(kwargs, arr_idxs[0]) - elif kwargs: - arr_idxs = tensorflow_nested_argwhere_bknd( - kwargs, array_fn, stop_after_n_found=1 - ) - if arr_idxs: - arr = tensorflow_index_nest_bknd(kwargs, arr_idxs[0]) - return arr - - -def tensorflow_as_native_dev(device: str, /): - if isinstance(device, str) and "/" in device: - return device - ret = f"/{str(device).upper()}" - if not ret[-1].isnumeric(): - ret += ":0" - return ret - - -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - -@tensorflow_handle_methods -def tensorflow_split( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - copy: Optional[bool] = None, - num_or_size_splits: Optional[ - Union[int, Sequence[int], Union[tensorflow.Tensor, tensorflow.Variable]] - ] = None, - axis: int = 0, - 
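- # with_remainder=True permits a final chunk smaller than the rest when the split is uneven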
with_remainder: bool = False, -): - if x.shape == (): - if num_or_size_splits is not None and num_or_size_splits != 1: - raise Exception( - f"input array had no shape, but num_sections specified was {num_or_size_splits}" - ) - return [x] - if num_or_size_splits is None: - dim_size = tensorflow.shape(x)[axis] - num_or_size_splits = int(dim_size) - if isinstance(num_or_size_splits, (tensorflow.Tensor, tensorflow.Variable)): - num_or_size_splits = tensorflow.cast(num_or_size_splits, tensorflow.int32) - elif isinstance(num_or_size_splits, int) and with_remainder: - num_chunks = x.shape[axis] / num_or_size_splits - num_chunks_int = math.floor(num_chunks) - remainder = num_chunks - num_chunks_int - if remainder != 0: - num_or_size_splits = [num_or_size_splits] * num_chunks_int + [ - int(remainder * num_or_size_splits) - ] - return tensorflow.split(x, num_or_size_splits, axis) - - -@tensorflow_handle_methods -def tensorflow_split_bknd_( - self: tensorflow.Tensor, - /, - *, - copy: Optional[bool] = None, - num_or_size_splits: Optional[ - Union[int, Sequence[int], tensorflow.Tensor, tf.Tensor] - ] = None, - axis: int = 0, - with_remainder: bool = False, -): - return tensorflow_split( - self, - copy=copy, - num_or_size_splits=num_or_size_splits, - axis=axis, - with_remainder=with_remainder, - ) - - -def tensorflow_as_ivy_dev(device: str, /): - if isinstance(device, str) and "/" not in device: - return str(device) - dev_in_split = tensorflow_split_bknd_(device[1:], ":")[-2:] - if len(dev_in_split) == 1: - return str(dev_in_split[0]) - dev_type, dev_idx = dev_in_split[0], dev_in_split[1] - dev_type = dev_type.lower() - if dev_type == "cpu": - return str(dev_type) - return str(f"{dev_type}:{dev_idx}") - - -def tensorflow_stack( - arrays: Union[Tuple[tensorflow.Tensor], List[tensorflow.Tensor]], - /, - *, - axis: int = 0, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - try: - return tensorflow.experimental.numpy.stack(arrays, axis) - except ValueError as e: - raise Exception(e) from e - - -def tensorflow_stack_bknd_( - self: tensorflow.Tensor, - /, - arrays: Union[ - Tuple[Union[tensorflow.Tensor, tf.Tensor]], - List[Union[tensorflow.Tensor, tf.Tensor]], - ], - *, - axis: int = 0, - out: Optional[tensorflow.Tensor] = None, -): - if not isinstance(arrays, (tuple, list)): - arrays = [arrays] - if isinstance(arrays, tuple): - x = (self,) + arrays - else: - x = [self] + arrays - return tensorflow_stack(x, axis=axis, out=out) - - -def tensorflow_dev( - x: Union[tensorflow.Tensor, tensorflow.Variable, tensorflow.TensorArray], - /, - *, - as_native: bool = False, -): - if "keras.src.backend.tensorflow.core.Variable" in str(x.__class__): - x = x.value - if isinstance(x, tensorflow.TensorArray): - x = tensorflow_stack_bknd_(x) - dv = x.device - if as_native: - return dv - dv = dv if dv else tensorflow_default_device_bknd(as_native=False) - return tensorflow_as_ivy_dev(dv) - - -def tensorflow_default_device_bknd( - device: Optional[Union[str, str]] = None, - /, - *, - item: Optional[Union[list, tuple, dict, tensorflow.Tensor, tf.Tensor]] = None, - as_native: Optional[bool] = None, -): - if tensorflow_exists_bknd(device): - if as_native is True: - return tensorflow_as_native_dev(device) - elif as_native is False: - return tensorflow_as_ivy_dev(device) - return device - as_native = tensorflow_default_bknd(as_native, False) - if tensorflow_exists_bknd(item): - if isinstance(item, (list, tuple, dict)) and len(item) == 0: - pass - elif tensorflow_is_array_bknd(item): - return 
tensorflow_dev(item, as_native=as_native) - global default_device_stack - if not default_device_stack: - ret = "cpu" - else: - ret = default_device_stack[-1] - if as_native: - return tensorflow_as_native_dev(ret) - return tensorflow_as_ivy_dev(ret) - - -def tensorflow__get_preferred_device(args, kwargs): - device = None - if "device" in kwargs and kwargs["device"] is not None: - return device - if not False: - arr_arg = tensorflow__get_first_array(*args, **kwargs) - return tensorflow_default_device_bknd(item=arr_arg, as_native=True) - return tensorflow_default_device_bknd(as_native=True) - - -def tensorflow__check_in_nested_sequence(sequence, value=None, _type=None): - if sequence is value or isinstance(sequence, _type): - return True - elif isinstance(sequence, (tuple, list)): - if any(isinstance(_val, _type) or _val is value for _val in sequence): - return True - else: - return any( - tensorflow__check_in_nested_sequence(sub_sequence, value, _type) - for sub_sequence in sequence - if isinstance(sub_sequence, (tuple, list)) - ) - - -def tensorflow_nested_map_bknd( - fn: Callable, - x: Union[tensorflow.Tensor, tf.Tensor, Iterable], - /, - include_derived: Optional[Union[Dict[str, bool], bool]] = None, - to_ignore: Optional[Union[type, Tuple[type]]] = None, - to_mutable: bool = False, - _tuple_check_fn: Optional[Callable] = None, - _list_check_fn: Optional[Callable] = None, - _dict_check_fn: Optional[Callable] = None, - shallow: bool = True, -): - to_ignore = tensorflow_default_bknd(to_ignore, ()) - if include_derived is True: - include_derived = {"tuple": True, "list": True, "dict": True} - elif not include_derived: - include_derived = {} - for t in ("tuple", "list", "dict"): - if t not in include_derived: - include_derived = tensorflow_set_item_bknd(include_derived, t, False) - class_instance = type(x) - if ( - hasattr(x, "is_tracked_proxy") - and hasattr(class_instance, "__bases__") - and not set(class_instance.__bases__).intersection(set(to_ignore)) - ): - to_ignore = to_ignore + (class_instance,) - tuple_check_fn = tensorflow_default_bknd( - _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), - ) - list_check_fn = tensorflow_default_bknd( - _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), - ) - dict_check_fn = tensorflow_default_bknd( - _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), - ) - if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): - ret_list = [ - tensorflow_nested_map_bknd( - fn, - i, - include_derived, - to_ignore, - to_mutable, - tuple_check_fn, - list_check_fn, - dict_check_fn, - shallow, - ) - for i in x - ] - if to_mutable: - return ret_list - elif hasattr(x, "_fields"): - return class_instance(**dict(zip(x._fields, ret_list))) - else: - return class_instance(ret_list) - elif list_check_fn(x, list) and not isinstance(x, to_ignore): - ret_list = [ - tensorflow_nested_map_bknd( - fn, - i, - include_derived, - to_ignore, - to_mutable, - tuple_check_fn, - list_check_fn, - dict_check_fn, - shallow, - ) - for i in x - ] - if shallow: - x = tensorflow_set_item_bknd(x, slice(None, None, None), ret_list[:]) - return x - return class_instance(ret_list) - elif (dict_check_fn(x, dict) or isinstance(x, UserDict)) and not isinstance( - x, to_ignore - ): - class_instance = type(x) - ret = { - k: tensorflow_nested_map_bknd( - fn, - v, - 
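- # recurse into each dict value, preserving the original keys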
include_derived, - to_ignore, - to_mutable, - tuple_check_fn, - list_check_fn, - dict_check_fn, - shallow, - ) - for k, v in x.items() - } - if shallow: - x.update(ret) - return x - return class_instance(ret) - elif isinstance(x, slice): - return slice(*tensorflow_nested_map_bknd(fn, [x.start, x.stop, x.step])) - return fn(x) - - -def tensorflow__to_ivy_bknd_(x: Any): - if isinstance(x, tensorflow.Tensor): - return x - elif isinstance(x, tf.TensorShape): - return tuple(x) - elif isinstance(x, dict): - return x.to_ivy() - if tensorflow_is_native_array(x) or isinstance(x, np.ndarray): - return tensorflow.convert_to_tensor(x) - return x - - -def tensorflow_to_ivy_bknd_( - x: Union[tensorflow.Tensor, tf.Tensor, Iterable], - nested: bool = False, - include_derived: Optional[Dict[str, bool]] = None, -): - if nested: - return tensorflow_nested_map_bknd( - tensorflow__to_ivy_bknd_, x, include_derived, shallow=False - ) - return tensorflow__to_ivy_bknd_(x) - - -def tensorflow__asarray_to_native_arrays_and_back_bknd(fn: Callable): - @functools.wraps(fn) - def _asarray_to_native_arrays_and_back_wrapper(*args, dtype=None, **kwargs): - new_arg = args[0] - new_args = (new_arg,) + args[1:] - if dtype is not None: - dtype = tensorflow_default_dtype_bknd(dtype=dtype, as_native=True) - return tensorflow_to_ivy_bknd_(fn(*new_args, dtype=dtype, **kwargs)) - - _asarray_to_native_arrays_and_back_wrapper._asarray_to_native_arrays_and_back = True - return _asarray_to_native_arrays_and_back_wrapper - - -def tensorflow__flatten_nest_bknd(xs): - for x in xs: - if isinstance(x, Iterable) and not isinstance(x, (str, bytes)): - yield from tensorflow__flatten_nest_bknd(x) - else: - yield x - - -def tensorflow_promote_types_bknd( - type1: Union[str, tf.DType], - type2: Union[str, tf.DType], - /, - *, - array_api_promotion: bool = False, -): - if not (type1 and type2): - return type1 if type1 else type2 - query = [tensorflow_as_ivy_dtype(type1), tensorflow_as_ivy_dtype(type2)] - query = tuple(query) - if query not in promotion_table: - query = query[1], query[0] - - def _promote(query): - if array_api_promotion: - return tensorflow_get_item(array_api_promotion_table, query) - return tensorflow_get_item(promotion_table, query) - - return _promote(query) - - -def tensorflow__asarray_infer_dtype_bknd(fn: Callable): - @functools.wraps(fn) - def _asarray_infer_dtype_wrapper(*args, dtype=None, **kwargs): - def _infer_dtype(obj): - if isinstance(obj, tf.TensorShape): - obj = list(obj) - if hasattr(obj, "dtype"): - return obj.dtype.name if isinstance(obj, np.ndarray) else obj.dtype - else: - return tensorflow_default_dtype_bknd(item=obj) - - if not tensorflow_exists_bknd(dtype): - arr = args[0] - dtype_list = [ - tensorflow_nested_map_bknd( - lambda x: _infer_dtype(x), arr, shallow=False - ) - ] - dtype_list = tensorflow__flatten_nest_bknd(dtype_list) - dtype_list = list(set(dtype_list)) - if len(dtype_list) != 0: - dtype = dtype_list[0] - for dt in dtype_list[1:]: - dtype = tensorflow_promote_types_bknd(dtype, dt) - else: - dtype = tensorflow_default_float_dtype_bknd() - dtype = tensorflow_as_native_dtype(dtype) - return fn(*args, dtype=dtype, **kwargs) - - _asarray_infer_dtype_wrapper.infer_dtype = True - return _asarray_infer_dtype_wrapper - - -@tensorflow_handle_array_like_without_promotion -@tensorflow__asarray_to_native_arrays_and_back_bknd -@tensorflow__asarray_infer_dtype_bknd -def tensorflow_asarray( - obj: Union[ - tensorflow.Tensor, - tensorflow.Variable, - tensorflow.TensorShape, - bool, - int, - float, - 
tensorflow_NestedSequence_bknd, - SupportsBufferProtocol, - np.ndarray, - ], - /, - *, - copy: Optional[bool] = None, - dtype: Optional[tensorflow.DType] = None, - device: Optional[str] = None, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - with tensorflow.device(device): - if tensorflow.is_tensor(obj): - ret = tensorflow.cast(obj, dtype) if obj.dtype != dtype else obj - elif ( - dtype is not None - and dtype.is_integer - and np.issubdtype(np.array(obj).dtype, np.floating) - ): - obj_np = np.array(obj) - ret = tensorflow.convert_to_tensor(obj_np, dtype) - else: - ret = tensorflow.convert_to_tensor(obj, dtype) - return ( - tensorflow.identity(ret) - if copy or tensorflow_as_native_dev(tensorflow_dev(ret)) != device - else ret - ) - - -def tensorflow_is_variable(x, /, *, exclusive=False): - return isinstance(x, tensorflow.Variable) - - -def tensorflow_variable(x, /): - with tensorflow.device(tensorflow_dev(x, as_native=True)): - return tensorflow.Variable(x, trainable=True) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_stop_gradient( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - preserve_type: bool = True, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - is_var = tensorflow_is_variable(x) - x = tensorflow.stop_gradient(x) - if is_var and preserve_type: - return tensorflow_variable(x) - return x - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_size(x: tensorflow.Tensor, /): - return functools.reduce(mul, x.shape) if len(x.shape) > 0 else 1 - - -def tensorflow_size_bknd_(self): - return tensorflow_size(self) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_unstack( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - copy: Optional[bool] = None, - axis: int = 0, - keepdims: bool = False, -): - if x.shape == (): - return [x] - ret = tensorflow.unstack(x, axis=axis) - if keepdims: - return [tensorflow.expand_dims(r, axis) for r in ret] - return ret - - -def tensorflow_unstack_bknd_( - self: tensorflow.Tensor, - /, - *, - copy: Optional[bool] = None, - axis: int = 0, - keepdims: bool = False, -): - return tensorflow_unstack(self, copy=copy, axis=axis, keepdims=keepdims) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_copy_array( - x: Union[tensorflow.Tensor, tensorflow.Variable, tensorflow.TensorArray], - *, - to_ivy_array: bool = True, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if isinstance(x, tensorflow.TensorArray): - x_wrapped = tensorflow_stack_bknd_(x) - y = tensorflow.TensorArray(x.dtype, tensorflow_size_bknd_(x)()) - x = tensorflow_unstack_bknd_(y, tensorflow_copy_array(x_wrapped)) - else: - x = tensorflow.identity(x) - if to_ivy_array: - return tensorflow_to_ivy_bknd_(x) - return x - - -def tensorflow_tile( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - repeats: Sequence[int], - *, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if x.shape == (): - x = tensorflow.reshape(x, (-1,)) - if isinstance(repeats, Number): - repeats = [repeats] - if isinstance(repeats, tensorflow.Tensor) and repeats.shape == (): - repeats = tensorflow.reshape(repeats, (-1,)) - if len(x.shape) < len(repeats): - while len(x.shape) != len(repeats): - x = tensorflow.expand_dims(x, 0) - elif len(x.shape) > len(repeats): - repeats = list(repeats) - while len(x.shape) != len(repeats): - repeats = [1] + repeats - return tensorflow.tile(x, repeats) - - 
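Note on the deleted tensorflow_tile above: unlike tf.tile, which requires len(repeats) to equal the rank of x, the helper first pads whichever side is shorter (leading unit axes on x, or leading 1s in repeats). A minimal standalone sketch of that alignment, assuming only stock TensorFlow (the variable names here are illustrative, not from the deleted file):

import tensorflow as tf

x = tf.constant([1, 2, 3])        # rank 1, shape (3,)
repeats = [2, 2]                  # one more entry than x has axes

# tf.tile alone would raise: the multiples length must match the input rank.
# Align ranks first, as the deleted helper does with expand_dims in a loop.
while len(x.shape) < len(repeats):
    x = tf.expand_dims(x, 0)      # shape becomes (1, 3)

print(tf.tile(x, repeats).shape)  # (2, 6)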
-@tensorflow_handle_array_like_without_promotion -def tensorflow_nonzero( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - as_tuple: bool = True, - size: Optional[int] = None, - fill_value: Number = 0, -): - res = tensorflow.experimental.numpy.nonzero(x) - if size is not None: - dtype = tensorflow.int64 - if isinstance(fill_value, float): - dtype = tensorflow.float64 - res = tensorflow.cast(res, dtype) - diff = size - res[0].shape[0] - if diff > 0: - res = tensorflow.pad(res, [[0, 0], [0, diff]], constant_values=fill_value) - elif diff < 0: - res = tensorflow.slice(res, [0, 0], [-1, size]) - if as_tuple: - return tuple(res) - return tensorflow.stack(res, axis=1) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_diff( - x: Union[tensorflow.Tensor, tensorflow.Variable, list, tuple], - /, - *, - n: int = 1, - axis: int = -1, - prepend: Optional[ - Union[tensorflow.Tensor, tensorflow.Variable, int, float, list, tuple] - ] = None, - append: Optional[ - Union[tensorflow.Tensor, tensorflow.Variable, int, float, list, tuple] - ] = None, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if n == 0: - return x - if prepend is not None: - x = tensorflow.experimental.numpy.append( - prepend, x, axis=axis if axis != -1 else None - ) - if append is not None: - x = tensorflow.experimental.numpy.append( - x, append, axis=axis if axis != -1 else None - ) - return tensorflow.experimental.numpy.diff(x, n=n, axis=axis) - - -def tensorflow__parse_ellipsis_bknd(so, ndims): - pre = list() - for s in so: - if s is Ellipsis: - break - pre.append(s) - post = list() - for s in reversed(so): - if s is Ellipsis: - break - post.append(s) - ret = list( - pre - + [slice(None, None, None) for _ in range(ndims - len(pre) - len(post))] - + list(reversed(post)) - ) - return ret, (len(pre), ndims - len(post)) - - -def tensorflow_broadcast_arrays(*arrays: Union[tensorflow.Tensor, tensorflow.Variable]): - if len(arrays) > 1: - try: - desired_shape = tensorflow.broadcast_dynamic_shape( - tensorflow.shape(arrays[0]), tensorflow.shape(arrays[1]) - ) - except tensorflow.errors.InvalidArgumentError as e: - raise Exception(e) from e - if len(arrays) > 2: - for i in range(2, len(arrays)): - try: - desired_shape = tensorflow.broadcast_dynamic_shape( - desired_shape, tensorflow.shape(arrays[i]) - ) - except tensorflow.errors.InvalidArgumentError as e: - raise Exception(e) from e - else: - return [arrays[0]] - result = [] - for tensor in arrays: - result.append(tensorflow.broadcast_to(tensor, desired_shape)) - return result - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_astype( - x: Union[tensorflow.Tensor, tensorflow.Variable], - dtype: Union[tf.DType, str], - /, - *, - copy: bool = True, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - dtype = tensorflow_as_native_dtype(dtype) - if x.dtype == dtype: - return tensorflow.experimental.numpy.copy(x) if copy else x - return tensorflow.cast(x, dtype) - - -def tensorflow_astype_bknd_( - self: tensorflow.Tensor, - dtype: str, - /, - *, - copy: bool = True, - out: Optional[tensorflow.Tensor] = None, -): - return tensorflow_astype(self, dtype, copy=copy, out=out) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_where( - condition: Union[tensorflow.Tensor, tensorflow.Variable], - x1: Union[tensorflow.Tensor, tensorflow.Variable], - x2: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - 
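- # stash the original operands; the except branch below restores them if scalar-to-tensor promotion fails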
oirg_x1 = x1 - oirg_x2 = x2 - try: - dtype = ( - x1.dtype - if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() - ) - if not tensorflow_is_array_bknd(x1): - x1 = tensorflow_asarray(x1, dtype=dtype) - if not tensorflow_is_array_bknd(x2): - x2 = tensorflow_asarray(x2, dtype=dtype) - except: - x1 = oirg_x1 - x2 = oirg_x2 - return tensorflow.cast( - tensorflow.experimental.numpy.where(condition, x1, x2), x1.dtype - ) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_arange( - start: float, - /, - stop: Optional[float] = None, - step: float = 1, - *, - dtype: Optional[tensorflow.DType] = None, - device: Optional[str] = None, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if stop is None: - stop = start - start = 0 - if step > 0 and start > stop or step < 0 and start < stop: - if isinstance(stop, float): - stop = float(start) - else: - stop = start - if isinstance(start, (float, int)): - start = tensorflow.convert_to_tensor(start) - if isinstance(stop, (float, int)): - stop = tensorflow.convert_to_tensor(stop) - if isinstance(step, (float, int)): - step = tensorflow.convert_to_tensor(step) - if dtype is None: - if isinstance(start, int) and isinstance(stop, int) and isinstance(step, int): - return tensorflow.cast( - tensorflow.range(start, stop, delta=step, dtype=tensorflow.int64), - tensorflow.int32, - ) - else: - return tensorflow.range(start, stop, delta=step) - else: - dtype = tensorflow_as_native_dtype(tensorflow_default_dtype_bknd(dtype=dtype)) - if dtype in [ - tensorflow.int8, - tensorflow.uint8, - tensorflow.int16, - tensorflow.uint16, - tensorflow.uint32, - tensorflow.uint64, - ]: - return tensorflow.cast( - tensorflow.range(start, stop, delta=step, dtype=tensorflow.int64), dtype - ) - else: - return tensorflow.range(start, stop, delta=step, dtype=dtype) - - -def tensorflow__parse_slice_bknd(idx, s): - step = 1 if idx.step is None else idx.step - if step > 0: - start = 0 if idx.start is None else idx.start - if start >= s: - stop = start - else: - if start <= -s: - start = 0 - elif start < 0: - start = start + s - stop = s if idx.stop is None else idx.stop - if stop > s: - stop = s - elif start <= -s: - stop = 0 - elif stop < 0: - stop = stop + s - else: - start = s - 1 if idx.start is None else idx.start - if start < -s: - stop = start - else: - if start >= s: - start = s - 1 - elif start < 0: - start = start + s - if idx.stop is None: - stop = -1 - else: - stop = idx.stop - if stop > s: - stop = s - elif stop < -s: - stop = -1 - elif stop == -s: - stop = 0 - elif stop < 0: - stop = stop + s - q_i = tensorflow_arange(start, stop, step) - ag__result_list_0 = [] - for q in q_i: - if 0 <= q < s: - res = q - ag__result_list_0.append(res) - q_i = ag__result_list_0 - q_i = ( - tensorflow_asarray(q_i) - if len(q_i) or start == stop or idx.stop is not None - else tensorflow_arange(0, s, 1) - ) - return q_i - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_shape( - x: Union[tensorflow.Tensor, tensorflow.Variable], /, *, as_array: bool = False -): - if as_array: - return tensorflow_asarray( - tensorflow.shape(x), dtype=tensorflow_default_int_dtype_bknd() - ) - else: - return tuple(x.shape) - - -def tensorflow__deep_flatten_bknd(iterable): - def _flatten_gen(iterable): - for item in iterable: - if isinstance(item, list): - yield from _flatten_gen(item) - else: - yield item - - return list(_flatten_gen(iterable)) - - -def tensorflow__calculate_out_shape_bknd(axis, array_shape): - 
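- # the expanded output rank is the input rank plus the number of inserted axes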
if type(axis) not in (tuple, list): - axis = (axis,) - out_dims = len(axis) + len(array_shape) - norm_axis = normalize_axis_tuple(axis, out_dims) - shape_iter = iter(array_shape) - ag__result_list_0 = [] - for current_ax in range(out_dims): - res = 1 if current_ax in norm_axis else next(shape_iter) - ag__result_list_0.append(res) - out_shape = ag__result_list_0 - return out_shape - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_expand_dims( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - copy: Optional[bool] = None, - axis: Union[int, Sequence[int]] = 0, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - try: - out_shape = tensorflow__calculate_out_shape_bknd(axis, tensorflow.shape(x)) - ret = tensorflow.reshape(x, shape=out_shape) - return ret - except (tensorflow.errors.InvalidArgumentError, np.AxisError) as error: - raise Exception(error) from error - - -def tensorflow_check_elem_in_list(elem, list, inverse=False, message=""): - if inverse and elem in list: - raise Exception( - message if message != "" else f"{elem} must not be one of {list}" - ) - elif not inverse and elem not in list: - raise Exception(message if message != "" else f"{elem} must be one of {list}") - - -def tensorflow__reshape_fortran_tf(x, shape): - if len(x.shape) > 0: - x = tensorflow.transpose(x) - return tensorflow.transpose(tensorflow.reshape(x, shape[::-1])) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_reshape( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - shape: Union[tf.TensorShape, Sequence[int]], - *, - copy: Optional[bool] = None, - order: str = "C", - allowzero: bool = True, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - tensorflow_check_elem_in_list(order, ["C", "F"]) - if not allowzero: - shape = [ - (new_s if con else old_s) - for new_s, con, old_s in zip( - shape, tensorflow.constant(shape) != 0, x.shape - ) - ] - if order == "F": - return tensorflow__reshape_fortran_tf(x, shape) - return tensorflow.reshape(x, shape) - - -def tensorflow_reshape_bknd_( - self: tensorflow.Tensor, - /, - shape: Union[tuple, tf.TensorShape, Sequence[int]], - *, - copy: Optional[bool] = None, - order: str = "C", - allowzero: bool = True, - out: Optional[tensorflow.Tensor] = None, -): - return tensorflow_reshape( - self, shape, copy=copy, allowzero=allowzero, out=out, order=order - ) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_meshgrid( - *arrays: Union[tensorflow.Tensor, tensorflow.Variable], - sparse: bool = False, - indexing: str = "xy", - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if not sparse: - return tensorflow.meshgrid(*arrays, indexing=indexing) - sd = (1,) * len(arrays) - ag__result_list_0 = [] - for i, a in enumerate(arrays): - res = tensorflow.reshape( - tensorflow.convert_to_tensor(a), sd[:i] + (-1,) + sd[i + 1 :] - ) - ag__result_list_0.append(res) - res = ag__result_list_0 - if indexing == "xy" and len(arrays) > 1: - res[0] = tensorflow.reshape(res[0], (1, -1) + sd[2:]) - res[1] = tensorflow.reshape(res[1], (-1, 1) + sd[2:]) - return res - - -def tensorflow_infer_dtype(fn: Callable): - @functools.wraps(fn) - def _infer_dtype(*args, dtype=None, **kwargs): - arr = ( - None - if tensorflow_exists_bknd(dtype) - else tensorflow__get_first_array(*args, **kwargs) - ) - dtype = tensorflow_default_dtype_bknd(dtype=dtype, item=arr, as_native=True) - return fn(*args, dtype=dtype, **kwargs) - - _infer_dtype.infer_dtype = True - return 
_infer_dtype - - -@tensorflow_infer_dtype -@tensorflow_handle_array_like_without_promotion -def tensorflow_empty( - shape: Union[tf.TensorShape, Sequence[int]], - *, - dtype: tensorflow.DType, - device: Optional[str] = None, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - return tensorflow.experimental.numpy.empty(shape, dtype=tensorflow.float32) - - -def tensorflow__parse_query_bknd(query, x_shape, scatter=False): - query = (query,) if not isinstance(query, tuple) else query - ag__result_list_0 = [] - for q in query: - res = tensorflow_asarray(q) if isinstance(q, (tuple, list, int)) else q - ag__result_list_0.append(res) - query = ag__result_list_0 - ag__result_list_1 = [] - for i, q in enumerate(query): - if tensorflow_is_array_bknd(q): - res = i - ag__result_list_1.append(res) - non_slice_q_idxs = ag__result_list_1 - to_front = ( - len(non_slice_q_idxs) > 1 - and any(tensorflow_diff(non_slice_q_idxs) != 1) - and non_slice_q_idxs[-1] < len(x_shape) - ) - ag__result_list_2 = [] - for i, q in enumerate(query): - if q is None: - res = i - ag__result_list_2.append(res) - new_axes = ag__result_list_2 - ag__result_list_3 = [] - for q in query: - if q is not None: - res = q - ag__result_list_3.append(res) - query = ag__result_list_3 - query = [Ellipsis] if query == [] else query - ellipsis_inds = None - if any(q is Ellipsis for q in query): - query, ellipsis_inds = tensorflow__parse_ellipsis_bknd(query, len(x_shape)) - ag__result_list_4 = [] - for i, v in enumerate(query): - if tensorflow_is_array_bknd(v): - res = i - ag__result_list_4.append(res) - array_inds = ag__result_list_4 - if array_inds: - array_queries = tensorflow_broadcast_arrays( - *[v for i, v in enumerate(query) if i in array_inds] - ) - array_queries = [ - ( - tensorflow_nonzero(q, as_tuple=False)[0] - if tensorflow_is_bool_dtype_bknd(q) - else q - ) - for q in array_queries - ] - array_queries = [ - ( - tensorflow_astype_bknd_( - tensorflow_where( - arr < 0, arr + tensorflow_get_item(x_shape, i), arr - ), - tf.int64, - ) - if tensorflow_size_bknd_(arr) - else tensorflow_astype_bknd_(arr, tf.int64) - ) - for arr, i in zip(array_queries, array_inds) - ] - for idx, arr in zip(array_inds, array_queries): - query = tensorflow_set_item_bknd(query, idx, arr) - ag__result_list_5 = [] - for i, q in enumerate(query): - res = ( - tensorflow_astype_bknd_( - tensorflow__parse_slice_bknd(q, tensorflow_get_item(x_shape, i)), - tf.int64, - ) - if isinstance(q, slice) - else q - ) - ag__result_list_5.append(res) - query = ag__result_list_5 - if len(query) < len(x_shape): - query = query + [ - tensorflow_astype_bknd_(tensorflow_arange(0, s, 1), tf.int64) - for s in tensorflow_get_item(x_shape, slice(len(query), None, None)) - ] - if len(array_inds) and to_front: - target_shape = ( - [list(array_queries[0].shape)] - + [ - list(tensorflow_get_item(query, i).shape) - for i in range(len(query)) - if i not in array_inds - ] - + [[] for _ in range(len(array_inds) - 1)] - ) - elif len(array_inds): - target_shape = ( - [list(tensorflow_get_item(query, i).shape) for i in range(0, array_inds[0])] - + [list(tensorflow_shape(array_queries[0], as_array=True))] - + [[] for _ in range(len(array_inds) - 1)] - + [ - list(tensorflow_shape(tensorflow_get_item(query, i), as_array=True)) - for i in range(array_inds[-1] + 1, len(query)) - ] - ) - else: - target_shape = [list(q.shape) for q in query] - if ellipsis_inds is not None: - target_shape = ( - tensorflow_get_item(target_shape, slice(None, ellipsis_inds[0], None)) - + [ - 
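- # wrap the dims spanned by the ellipsis in a single nested entry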
tensorflow_get_item( - target_shape, slice(ellipsis_inds[0], ellipsis_inds[1], None) - ) - ] - + tensorflow_get_item(target_shape, slice(ellipsis_inds[1], None, None)) - ) - for i, ax in enumerate(new_axes): - if len(array_inds) and to_front: - ax = ax - (sum(1 for x in array_inds if x < ax) - 1) - ax = ax + i - target_shape = [ - *tensorflow_get_item(target_shape, slice(None, ax, None)), - 1, - *tensorflow_get_item(target_shape, slice(ax, None, None)), - ] - target_shape = tensorflow__deep_flatten_bknd(target_shape) - ag__result_list_6 = [] - for q in query: - res = tensorflow_expand_dims(q) if not len(q.shape) else q - ag__result_list_6.append(res) - query = ag__result_list_6 - if len(array_inds): - array_queries = [ - ( - tensorflow_reshape_bknd_(arr, (-1,)) - if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr - ) - for arr in array_queries - ] - array_queries = tensorflow_stack(array_queries, axis=1) - if len(array_inds) == len(query): - indices = tensorflow_reshape_bknd_(array_queries, (*target_shape, len(x_shape))) - elif len(array_inds) == 0: - indices = tensorflow_reshape_bknd_( - tensorflow_stack(tensorflow_meshgrid(*query, indexing="ij"), axis=-1), - (*target_shape, len(x_shape)), - ) - elif to_front: - post_array_queries = ( - tensorflow_reshape_bknd_( - tensorflow_stack( - tensorflow_meshgrid( - *[v for i, v in enumerate(query) if i not in array_inds], - indexing="ij", - ), - axis=-1, - ), - (-1, len(query) - len(array_inds)), - ) - if len(array_inds) < len(query) - else tensorflow_empty((1, 0)) - ) - indices = tensorflow_reshape_bknd_( - tensorflow_asarray( - [ - (*arr, *post) - for arr, post in itertools.product( - array_queries, post_array_queries - ) - ] - ), - (*target_shape, len(x_shape)), - ) - else: - pre_array_queries = ( - tensorflow_reshape_bknd_( - tensorflow_stack( - tensorflow_meshgrid( - *[v for i, v in enumerate(query) if i < array_inds[0]], - indexing="ij", - ), - axis=-1, - ), - (-1, array_inds[0]), - ) - if array_inds[0] > 0 - else tensorflow_empty((1, 0)) - ) - post_array_queries = ( - tensorflow_reshape_bknd_( - tensorflow_stack( - tensorflow_meshgrid( - *[v for i, v in enumerate(query) if i > array_inds[-1]], - indexing="ij", - ), - axis=-1, - ), - (-1, len(query) - 1 - array_inds[-1]), - ) - if array_inds[-1] < len(query) - 1 - else tensorflow_empty((1, 0)) - ) - indices = tensorflow_reshape_bknd_( - tensorflow_asarray( - [ - (*pre, *arr, *post) - for pre, arr, post in itertools.product( - pre_array_queries, array_queries, post_array_queries - ) - ] - ), - (*target_shape, len(x_shape)), - ) - return ( - tensorflow_astype_bknd_(indices, tf.int64), - target_shape, - array_inds if len(array_inds) and to_front else None, - ) - - -def tensorflow_get_num_dims(x, /, *, as_array=False): - return ( - tensorflow.cast(tensorflow.shape(tensorflow.shape(x))[0], tensorflow.int64) - if as_array - else int(tensorflow.shape(tensorflow.shape(x))) - ) - - -def tensorflow_to_numpy( - x: Union[tensorflow.Tensor, tensorflow.Variable], /, *, copy: bool = True -): - if ( - tensorflow_is_array_bknd(x) - and tensorflow_get_num_dims(x) == 0 - and tensorflow_as_native_dtype(x.dtype) is tensorflow.bfloat16 - ): - x = tensorflow.expand_dims(x, 0) - if copy: - return np.squeeze(np.array(tensorflow.convert_to_tensor(x)), 0) - else: - return np.squeeze(np.asarray(tensorflow.convert_to_tensor(x)), 0) - if copy: - return np.array(tensorflow.convert_to_tensor(x)) - else: - return np.asarray(tensorflow.convert_to_tensor(x)) - - 
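Note on the deleted tensorflow_to_numpy above: it special-cases 0-d bfloat16 tensors, converting through a temporary leading axis, since NumPy conversion of bfloat16 scalars has been unreliable across TF versions. A small sketch of the same round-trip, under the assumption that your TF build can convert 1-d bfloat16 tensors to NumPy:

import numpy as np
import tensorflow as tf

x = tf.constant(1.5, dtype=tf.bfloat16)   # 0-d bfloat16 tensor

# Lift to shape (1,), convert, then squeeze the helper axis back out,
# mirroring the expand_dims / np.squeeze pair in the deleted helper.
y = np.squeeze(np.array(tf.expand_dims(x, 0)), 0)

print(y.shape, y.dtype)                    # () bfloat16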
-@tensorflow_handle_array_like_without_promotion -def tensorflow_to_scalar(x: Union[tensorflow.Tensor, tensorflow.Variable], /): - ret = tensorflow_to_numpy(x).item() - if x.dtype == tensorflow.bfloat16: - return float(ret) - return ret - - -def tensorflow_to_scalar_bknd_(self: tensorflow.Tensor): - return tensorflow_to_scalar(self) - - -def tensorflow_default_uint_dtype_bknd( - *, - input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, - uint_dtype: Optional[Union[str, tf.DType]] = None, - as_native: bool = False, -): - global default_uint_dtype_stack - if tensorflow_exists_bknd(uint_dtype): - if as_native is True: - return tensorflow_as_native_dtype(uint_dtype) - return str(tensorflow_as_ivy_dtype(uint_dtype)) - as_native = tensorflow_default_bknd(as_native, False) - if tensorflow_exists_bknd(input): - if tensorflow_is_array_bknd(input): - ret = tensorflow_dtype(input) - elif isinstance(input, np.ndarray): - ret = input.dtype - elif isinstance(input, (list, tuple, dict)): - - def is_native(x): - return tensorflow_is_native_array(x) - - if tensorflow_nested_argwhere_bknd( - input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), - stop_after_n_found=1, - ): - ret = tf.uint64 - elif default_uint_dtype_stack: - ret = default_uint_dtype_stack[-1] - else: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_uint_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "uint32" - elif isinstance(input, Number): - if input > 4294967295 and input != math.inf and backend != "torch": - ret = tf.uint64 - elif default_uint_dtype_stack: - ret = default_uint_dtype_stack[-1] - else: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_uint_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "uint32" - elif default_uint_dtype_stack: - ret = default_uint_dtype_stack[-1] - else: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_uint_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "uint32" - if as_native: - return tensorflow_as_native_dtype(ret) - return str(tensorflow_as_ivy_dtype(ret)) - - -def tensorflow_infer_default_dtype_bknd( - dtype: Union[str, tf.DType, str], as_native: bool = False -): - if tensorflow_is_complex_dtype_bknd(dtype): - default_dtype = tensorflow_default_complex_dtype_bknd(as_native=as_native) - elif tensorflow_is_float_dtype_bknd(dtype): - default_dtype = tensorflow_default_float_dtype_bknd(as_native=as_native) - elif tensorflow_is_uint_dtype_bknd(dtype): - default_dtype = tensorflow_default_uint_dtype_bknd(as_native=as_native) - elif tensorflow_is_int_dtype_bknd(dtype): - default_dtype = tensorflow_default_int_dtype_bknd(as_native=as_native) - elif as_native: - default_dtype = tensorflow_as_native_dtype("bool") - else: - default_dtype = tensorflow_as_ivy_dtype("bool") - return default_dtype - - -def tensorflow_dtype_bits(dtype_in: Union[tensorflow.DType, str, np.dtype], /): - dtype_str = tensorflow_as_ivy_dtype(dtype_in) - if "bool" in dtype_str: - return 1 - return int( - dtype_str.replace("tf.", "") - .replace("uint", "") - .replace("int", "") - .replace("bfloat", "") - .replace("float", "") - .replace("complex", "") - ) - - -def tensorflow__infer_dtype(dtype: tensorflow.DType): - default_dtype = tensorflow_infer_default_dtype_bknd(dtype) - if tensorflow_dtype_bits(dtype) < tensorflow_dtype_bits(default_dtype): - return default_dtype - return dtype - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_prod( - x: Union[tensorflow.Tensor, 
tensorflow.Variable], - /, - *, - axis: Optional[Union[int, Sequence[int]]] = None, - dtype: Optional[tensorflow.DType] = None, - keepdims: bool = False, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - dtype = tensorflow_as_native_dtype(dtype) - if dtype is None: - dtype = tensorflow__infer_dtype(x.dtype) - axis = tuple(axis) if isinstance(axis, list) else axis - return tensorflow.experimental.numpy.prod( - x, axis=axis, dtype=dtype, keepdims=keepdims - ) - - -def tensorflow__numel_bknd(shape): - shape = tuple(shape) - return tensorflow_to_scalar_bknd_(tensorflow_prod(shape)) if shape != () else 1 - - -def tensorflow_check_one_way_broadcastable(x1, x2): - if len(x1) > len(x2): - return False - for a, b in zip(x1[::-1], x2[::-1]): - if a in (1, b): - pass - else: - return False - return True - - -def tensorflow_check_shapes_broadcastable(var, data): - if not tensorflow_check_one_way_broadcastable(var, data): - raise Exception(f"Could not broadcast shape {data} to shape {var}.") - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_broadcast_to( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - shape: Union[tf.TensorShape, Sequence[int]], - *, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - tensorflow_check_shapes_broadcastable(x.shape, shape) - if tensorflow.rank(x) > len(shape): - return tensorflow.broadcast_to(tensorflow.reshape(x, -1), shape) - return tensorflow.broadcast_to(x, shape) - - -def tensorflow__broadcast_to_bknd(input, target_shape): - if tensorflow__numel_bknd(tuple(input.shape)) == tensorflow__numel_bknd( - tuple(target_shape) - ): - return tensorflow_reshape(input, target_shape) - else: - input = input if len(input.shape) else tensorflow_expand_dims(input, axis=0) - return tensorflow_broadcast_to(input, target_shape) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_any( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - axis: Optional[Union[int, Sequence[int]]] = None, - keepdims: bool = False, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if axis is None: - num_dims = len(x.shape) - axis = tuple(range(num_dims)) - elif isinstance(axis, list): - axis = tuple(axis) - try: - return tensorflow.reduce_any( - tensorflow.cast(x, tensorflow.bool), axis=axis, keepdims=keepdims - ) - except tensorflow.errors.InvalidArgumentError as e: - raise Exception(e) from e - - -def tensorflow__broadcast_inputs(x1, x2): - x1_, x2_ = x1, x2 - iterables = list, tuple, tuple - if not isinstance(x1_, iterables): - x1_, x2_ = x2, x1 - if not isinstance(x1_, iterables): - return [x1], [x2] - if not isinstance(x2_, iterables): - x1 = [x1] * len(x2) - return x1, x2 - - -def tensorflow_check_equal(x1, x2, inverse=False, message="", as_array=True): - def eq_fn(x1, x2): - return x1 == x2 if inverse else x1 != x2 - - def comp_fn(x1, x2): - return tensorflow_any(eq_fn(x1, x2)) - - if not as_array: - - def iter_comp_fn(x1_, x2_): - return any(eq_fn(x1, x2) for x1, x2 in zip(x1_, x2_)) - - def comp_fn(x1, x2): - return iter_comp_fn(*tensorflow__broadcast_inputs(x1, x2)) - - eq = comp_fn(x1, x2) - if inverse and eq: - raise Exception(f"{x1} must not be equal to {x2}" if message == "" else message) - elif not inverse and eq: - raise Exception(f"{x1} must be equal to {x2}" if message == "" else message) - - -def tensorflow_multiply( - x1: Union[float, tensorflow.Tensor, tensorflow.Variable], - x2: Union[float, tensorflow.Tensor, tensorflow.Variable], - /, - *, - out: 
Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - oirg_x1 = x1 - oirg_x2 = x2 - try: - dtype = ( - x1.dtype - if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() - ) - if not tensorflow_is_array_bknd(x1): - x1 = tensorflow_asarray(x1, dtype=dtype) - if not tensorflow_is_array_bknd(x2): - x2 = tensorflow_asarray(x2, dtype=dtype) - except: - x1 = oirg_x1 - x2 = oirg_x2 - return tensorflow.math.multiply(x1, x2) - - -def tensorflow_check_gather_nd_input_valid(params, indices, batch_dims): - if batch_dims >= len(params.shape): - raise Exception( - f"batch_dims = {batch_dims} must be less than rank(`params`) = {len(params.shape)}." - ) - if batch_dims >= len(indices.shape): - raise Exception( - f"batch_dims = {batch_dims} must be less than rank(`indices`) = {len(indices.shape)}." - ) - if tensorflow_get_item( - params.shape, slice(0, batch_dims, None) - ) != tensorflow_get_item(indices.shape, slice(0, batch_dims, None)): - raise Exception( - f"batch dimensions must match in `params` and `indices`; saw {tensorflow_get_item(params.shape, slice(0, batch_dims, None))} vs. {tensorflow_get_item(indices.shape, slice(0, batch_dims, None))}" - ) - if indices.shape[-1] > len( - tensorflow_get_item(params.shape, slice(batch_dims, None, None)) - ): - raise Exception( - f"index innermost dimension length must be <= rank(`params[batch_dims:]`); saw: {indices.shape[-1]} vs. {len(tensorflow_get_item(params.shape, slice(batch_dims, None, None)))} ." - ) - - -def tensorflow_gather_nd_helper(params, indices): - indices_shape = tensorflow.shape(indices) - params_shape = tensorflow.shape(params) - num_index_dims = indices_shape[-1] - result_dim_sizes_list = [ - tensorflow.math.reduce_prod(params_shape[i + 1 :]) - for i in range(len(params_shape) - 1) - ] + [1] - result_dim_sizes = tensorflow.convert_to_tensor( - result_dim_sizes_list, dtype=indices.dtype - ) - implicit_indices_factor = result_dim_sizes[num_index_dims - 1] - flat_params = tensorflow.reshape(params, (-1,)) - new_shape = [1] * (len(indices_shape) - 1) + [num_index_dims] - indices_scales = tensorflow.reshape(result_dim_sizes[0:num_index_dims], new_shape) - indices_for_flat_tiled = tensorflow.reshape( - tensorflow.reduce_sum(indices * indices_scales, -1, keepdims=True), (-1, 1) - ) - indices_for_flat_tiled = tensorflow.repeat( - indices_for_flat_tiled, implicit_indices_factor, axis=1 - ) - implicit_indices = tensorflow.repeat( - tensorflow.expand_dims(tensorflow.range(implicit_indices_factor), 0), - indices_for_flat_tiled.shape[0], - axis=0, - ) - indices_for_flat = indices_for_flat_tiled + implicit_indices - flat_indices_for_flat = tensorflow.reshape(indices_for_flat, (-1,)) - flat_gather = tensorflow.gather(flat_params, flat_indices_for_flat) - res = tensorflow.reshape( - flat_gather, - tensorflow.concat([indices_shape[:-1], params_shape[num_index_dims:]], 0), - ) - return res - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_gather_nd( - params: Union[tensorflow.Tensor, tensorflow.Variable], - indices: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - batch_dims: int = 0, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - tensorflow_check_gather_nd_input_valid(params, indices, batch_dims) - try: - return tensorflow.gather_nd(params, indices, batch_dims=batch_dims) - except Exception: - batch_dims %= len(params.shape) - result = [] - if batch_dims == 0: - result = tensorflow_gather_nd_helper(params, indices) - else: - for b in 
range(batch_dims): - if b == 0: - zip_list = list(zip(params, indices)) - else: - zip_list = [ - (p, i) - for z in [zip(p1, i1) for p1, i1 in zip_list] - for p, i in z - ] - for z in zip_list: - p, i = z[0], z[1] - r = tensorflow_gather_nd_helper(p, i) - result.append(r) - result = tensorflow.stack(result) - result = tensorflow.reshape( - result, - tensorflow.concat([params.shape[0:batch_dims], result.shape[1:]], 0), - ) - return result - - -def tensorflow__is_variable_bknd(x, exclusive=False, to_ignore=None): - x = x - return tensorflow_nested_map_bknd( - lambda x: tensorflow_is_variable(x, exclusive=exclusive), - x, - include_derived=True, - shallow=False, - to_ignore=to_ignore, - ) - - -def tensorflow_inplace_update( - x: Union[tensorflow.Tensor, tensorflow.Tensor], - val: Union[tensorflow.Tensor, tensorflow.Tensor], - /, - *, - ensure_in_backend: bool = False, - keep_input_dtype: bool = False, -): - if tensorflow_is_array_bknd(x) and tensorflow_is_array_bknd(val): - if keep_input_dtype: - val = tensorflow_astype(val, x.dtype) - (x_native, val_native), _ = (x, val), "_" - if tensorflow__is_variable_bknd(x_native): - x_native.assign(val_native) - if tensorflow_is_ivy_array_bknd(x): - x = x_native - else: - x = tensorflow.convert_to_tensor(x_native) - else: - x = x_native - return x - else: - return val - - -def tensorflow_scatter_nd( - indices: Union[tensorflow.Tensor, tensorflow.Variable], - updates: Union[tensorflow.Tensor, tensorflow.Variable], - /, - shape: Optional[Union[tf.TensorShape, Sequence[int]]] = None, - *, - reduction: str = "sum", - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - updates_dtype = updates.dtype - if tensorflow_exists_bknd(out): - dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) - updates = tensorflow.cast( - updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), - ) - expected_shape = ( - list(tensorflow.shape(indices)[:-1]) - + list(out.shape[tensorflow.shape(indices)[-1] :]) - if tensorflow_exists_bknd(out) - else list(tensorflow.shape(indices)[:-1]) - + list(shape[tensorflow.shape(indices)[-1] :]) - ) - updates = tensorflow__broadcast_to_bknd(updates, expected_shape) - if len(updates.shape) == 0: - indices = tensorflow.expand_dims(indices, 0) - updates = tensorflow.expand_dims(updates, 0) - target = out - target_given = tensorflow_exists_bknd(target) - if tensorflow_exists_bknd(shape) and target_given: - tensorflow_check_equal(tuple(target.shape), tuple(shape), as_array=False) - if not target_given: - shape = list(shape) if tensorflow_exists_bknd(shape) else list(out.shape) - target = tensorflow.zeros(shape, dtype=updates.dtype) - if reduction == "sum": - res = tensorflow.tensor_scatter_nd_add(target, indices, updates) - elif reduction == "min": - res = tensorflow.tensor_scatter_nd_min(target, indices, updates) - elif reduction == "max": - res = tensorflow.tensor_scatter_nd_max(target, indices, updates) - elif reduction == "mul": - updates = tensorflow_multiply(tensorflow_gather_nd(target, indices), updates) - res = tensorflow.tensor_scatter_nd_update(target, indices, updates) - elif reduction == "replace": - res = tensorflow.tensor_scatter_nd_update(target, indices, updates) - else: - raise Exception( - f'reduction is {reduction}, but it must be one of "sum", "min", "max", "mul" or "replace"' - ) - if tensorflow_exists_bknd(out): - return tensorflow_inplace_update(out, res) - return res - - -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def 
wrapper(inp, query, val, **kwargs):
- try:
- inp.__setitem__(query, val)
- res = inp
- except IndexError:
- raise
- except Exception:
- res = fn(inp, query, val, **kwargs)
- return res
-
- return wrapper
-
-
-@tensorflow_handle_set_item
-def tensorflow_set_item_bknd(
- x: Union[tensorflow.Tensor, tf.Tensor],
- query: Union[tensorflow.Tensor, tf.Tensor, Tuple],
- val: Union[tensorflow.Tensor, tf.Tensor],
- /,
- *,
- copy: Optional[bool] = False,
-):
- if isinstance(query, (list, tuple)) and any(
- [(q is Ellipsis or isinstance(q, slice) and q.stop is None) for q in query]
- ):
- x_stop_gradient = tensorflow_stop_gradient(x, preserve_type=False)
- np_array = x_stop_gradient.numpy()
- val_stop_gradient = tensorflow_stop_gradient(val, preserve_type=False)
- np_array = tensorflow_set_item_bknd(
- np_array, query, np.asarray(val_stop_gradient)
- )
- return tensorflow_asarray(np_array)
- if copy:
- x = tensorflow_copy_array(x)
- if not tensorflow_is_array_bknd(val):
- val = tensorflow_asarray(val)
- if 0 in x.shape or 0 in val.shape:
- return x
- if tensorflow_is_array_bknd(query) and tensorflow_is_bool_dtype_bknd(query):
- if not len(query.shape):
- query = tensorflow_tile(query, (x.shape[0],))
- indices = tensorflow_nonzero(query, as_tuple=False)
- else:
- indices, target_shape, _ = tensorflow__parse_query_bknd(
- query, tensorflow_shape(x, as_array=True), scatter=True
- )
- if indices is None:
- return x
- val = tensorflow_astype_bknd_(val, x.dtype)
- ret = tensorflow_scatter_nd(indices, val, reduction="replace", out=x)
- return ret
-
-
-def tensorflow__reverse_repeat_tuple(t, n):
- return tuple(x for x in reversed(t) for _ in range(n))
-
-
-def tensorflow_empty_frnt(
- *args,
- size=None,
- out=None,
- dtype=None,
- layout=None,
- device=None,
- requires_grad=False,
- pin_memory=False,
- memory_format=None,
-):
- if args and size:
- raise TypeError("empty() got multiple values for argument 'shape'")
- if size is None:
- size = (
- args[0]
- if isinstance(args[0], (tuple, list, tf.TensorShape))
- else args
- )
- if isinstance(size, (tuple, list)):
- size = tuple(
- tensorflow_to_scalar_bknd_(s) if tensorflow_is_array_bknd(s) else s
- for s in size
- )
- return tensorflow_empty(shape=size, dtype=dtype, device=device, out=out)
-
-
-def tensorflow_store_config_info(fn):
- @functools.wraps(fn)
- def wrapper(self, *args, **kwargs):
- fn(self, *args, **kwargs)
- if all(
- [
- hasattr(self, "_args"),
- hasattr(self, "_kwargs"),
- hasattr(self, "_self_tracked_trackables"),
- ]
- ):
- orig_trackables = copy.copy(self._self_tracked_trackables)
- self._args = (self,) + args
- self._kwargs = kwargs
- self._self_tracked_trackables = orig_trackables
-
- return wrapper
-
-
-def tensorflow_ndim_bknd_(self):
- return len(tuple(self.shape))
-
-
-def tensorflow_dim_frnt_(tensor):
- return tensorflow_ndim_bknd_(tensor)
-
-
-def tensorflow_size_frnt_(tensor, dim=None):
- shape = tensor.shape
- if dim is None:
- return shape
- try:
- return tensorflow_get_item(shape, dim)
- except IndexError as e:
- raise IndexError(
- f"Dimension out of range (expected to be in range of [-{len(shape)}, {len(shape) - 1}], but got {dim})"
- ) from e
-
-
-def tensorflow__calculate_fan_in_and_fan_out(tensor):
- dimensions = tensorflow_dim_frnt_(tensor)
- if dimensions < 2:
- raise ValueError(
- "Fan in and fan out can not be computed for tensor with fewer than 2 dimensions"
- )
- num_input_fmaps = tensorflow_size_frnt_(tensor, 1)
- num_output_fmaps = tensorflow_size_frnt_(tensor, 0)
- receptive_field_size = 1
- if 
tensorflow_dim_frnt_(tensor) > 2: - for s in tensor.shape[2:]: - receptive_field_size = receptive_field_size * s - fan_in = num_input_fmaps * receptive_field_size - fan_out = num_output_fmaps * receptive_field_size - return fan_in, fan_out - - -def tensorflow__calculate_correct_fan(tensor, mode): - mode = mode.lower() - valid_modes = ["fan_in", "fan_out"] - if mode not in valid_modes: - raise ValueError(f"Mode {mode} not supported, please use one of {valid_modes}") - fan_in, fan_out = tensorflow__calculate_fan_in_and_fan_out(tensor) - return fan_in if mode == "fan_in" else fan_out - - -def tensorflow_calculate_gain(nonlinearity, param=None): - linear_fns = [ - "linear", - "conv1d", - "conv2d", - "conv3d", - "conv_transpose1d", - "conv_transpose2d", - "conv_transpose3d", - ] - if nonlinearity in linear_fns or nonlinearity == "sigmoid": - return 1 - elif nonlinearity == "tanh": - return 5.0 / 3 - elif nonlinearity == "relu": - return math.sqrt(2.0) - elif nonlinearity == "leaky_relu": - if param is None: - negative_slope = 0.01 - elif ( - not isinstance(param, bool) - and isinstance(param, int) - or isinstance(param, float) - ): - negative_slope = param - else: - raise ValueError(f"negative_slope {param} not a valid number") - return math.sqrt(2.0 / (1 + negative_slope**2)) - elif nonlinearity == "selu": - return 3.0 / 4 - else: - raise ValueError(f"Unsupported nonlinearity {nonlinearity}") - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_all( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - axis: Optional[Union[int, Sequence[int]]] = None, - keepdims: bool = False, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if axis is None: - num_dims = len(x.shape) - axis = tuple(range(num_dims)) - elif isinstance(axis, list): - axis = tuple(axis) - try: - return tensorflow.reduce_all( - tensorflow.cast(x, tensorflow.bool), axis=axis, keepdims=keepdims - ) - except tensorflow.errors.InvalidArgumentError as e: - raise Exception(e) from e - - -def tensorflow_check_all(results, message="one of the args is False", as_array=True): - if as_array and not tensorflow_all(results) or not as_array and not all(results): - raise Exception(message) - - -def tensorflow_check_all_or_any_fn( - *args, - fn, - type="all", - limit=(0,), - message="args must exist according to type and limit given", - as_array=True, -): - if type == "all": - tensorflow_check_all([fn(arg) for arg in args], message, as_array=as_array) - elif type == "any": - count = 0 - for arg in args: - count = count + 1 if fn(arg) else count - if count not in limit: - raise Exception(message) - else: - raise Exception("type must be all or any") - - -def tensorflow__check_bounds_and_get_shape_bknd(low, high, shape): - if shape is not None: - tensorflow_check_all_or_any_fn( - low, - high, - fn=lambda x: isinstance(x, (int, float)), - type="all", - message="low and high bounds must be numerics when shape is specified", - ) - return tuple(shape) - valid_types = (tensorflow.Tensor,) - if len(backend_stack) == 0: - valid_types = valid_types + (tf.Tensor,) - else: - valid_types = valid_types + (tf.Tensor,) - if isinstance(low, valid_types): - if isinstance(high, valid_types): - tensorflow_check_equal( - tensorflow_shape(low), tensorflow_shape(high), as_array=False - ) - return tensorflow_shape(low) - if isinstance(high, valid_types): - return tensorflow_shape(high) - return tuple(()) - - -@tensorflow_infer_dtype -def tensorflow_random_uniform( - *, - low: Union[float, tensorflow.Tensor, 
tensorflow.Variable] = 0.0,
- high: Union[float, tensorflow.Tensor, tensorflow.Variable, None] = 1.0,
- shape: Optional[Union[tf.TensorShape, Sequence[int], tensorflow.Tensor]] = None,
- dtype: tf.DType,
- device: Optional[str] = None,
- seed: Optional[int] = None,
- out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None,
-):
- shape = tensorflow__check_bounds_and_get_shape_bknd(
- low,
- (
- float(
- tensorflow.experimental.numpy.finfo(tensorflow.float32).max
- if dtype is None
- else tensorflow.experimental.numpy.finfo(dtype).max
- )
- if high is None
- else high
- ),
- shape,
- )
- low = tensorflow.cast(low, dtype)
- if high is not None:
- high = tensorflow.cast(high, dtype)
- if seed:
- tensorflow.random.set_seed(seed)
- return tensorflow.random.uniform(shape, low, high, dtype=dtype, seed=seed)
-
-
-def tensorflow_uniform__frnt_(tensor, from_=0, to=1, *, generator=None):
- ret = tensorflow_random_uniform(
- low=from_, high=to, shape=tensor.shape, dtype=tensor.dtype, seed=generator
- )
- tensor = tensorflow_inplace_update(tensor, tensorflow_astype(ret, tensor.dtype))
- return tensor
-
-
-def tensorflow_kaiming_uniform_(
- tensor, a=0, mode="fan_in", nonlinearity="leaky_relu", generator=None
-):
- if 0 in tensor.shape:
- warnings.warn("Initializing zero-element tensors is a no-op")
- return tensor
- fan = tensorflow__calculate_correct_fan(tensor, mode)
- gain = tensorflow_calculate_gain(nonlinearity, a)
- std = gain / math.sqrt(fan)
- bound = math.sqrt(3.0) * std
- return tensorflow_uniform__frnt_(tensor, -bound, bound, generator=generator)
-
-
-def tensorflow__no_grad_uniform_(tensor, a, b, generator=None):
- return tensorflow_uniform__frnt_(tensor, a, b, generator=generator)
-
-
-def tensorflow_uniform_(tensor, a=0.0, b=1.0, generator=None):
- return tensorflow__no_grad_uniform_(tensor, a, b, generator)
-
-
-def tensorflow_handle_methods_1(fn):
- def extract_function_name(s):
- match = re.search("_(.+?)(?:_\\d+)?$", s)
- if match:
- return match.group(1)
-
- @functools.wraps(fn)
- def wrapper(*args, **kwargs):
- if tensorflow_is_array_bknd(args[0]):
- return fn(*args, **kwargs)
- else:
- pattern = "_bknd_|_bknd|_frnt_|_frnt"
- fn_name = extract_function_name(re.sub(pattern, "", fn.__name__))
- new_fn = getattr(args[0], fn_name)
- return new_fn(*args[1:], **kwargs)
-
- return wrapper
-
-
-@tensorflow_handle_methods_1
-def tensorflow_split_frnt(tensor, split_size_or_sections, dim=0):
- if isinstance(split_size_or_sections, int):
- split_size = split_size_or_sections
- split_size_or_sections = [split_size] * (
- tensorflow_get_item(tensor.shape, dim) // split_size
- )
- if tensorflow_get_item(tensor.shape, dim) % split_size:
- split_size_or_sections.append(
- tensorflow_get_item(tensor.shape, dim) % split_size
- )
- return tuple(
- tensorflow_split(
- tensor,
- num_or_size_splits=split_size_or_sections,
- axis=dim,
- with_remainder=True,
- )
- )
-
-
-@tensorflow_handle_methods_1
-def tensorflow_split_frnt_(tensor, split_size, dim=0):
- return tensorflow_split_frnt(tensor, split_size, dim)
-
-
-@tensorflow_handle_methods
-def tensorflow_add(
- x1: Union[float, tensorflow.Tensor, tensorflow.Variable],
- x2: Union[float, tensorflow.Tensor, tensorflow.Variable],
- /,
- *,
- alpha: Optional[Union[int, float]] = None,
- out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None,
-):
- orig_x1 = x1
- orig_x2 = x2
- try:
- dtype = (
- x1.dtype
- if hasattr(x1, "dtype")
- else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd()
- )
- if not tensorflow_is_array_bknd(x1):
- x1 = tensorflow_asarray(x1, dtype=dtype)
- if not tensorflow_is_array_bknd(x2):
- x2 = tensorflow_asarray(x2, dtype=dtype)
- except Exception:
- x1 = orig_x1
- x2 = orig_x2
- if x1.dtype.is_bool and x2.dtype.is_bool:
- return tensorflow.math.logical_or(x1, x2)
- if alpha not in (1, None):
- x2 = tensorflow_multiply(x2, alpha)
- return tensorflow.add(x1, x2)
-
-
-@tensorflow_handle_methods_1
-def tensorflow_add_frnt(input, other, *, alpha=1, out=None):
- return tensorflow_add(input, other, alpha=alpha, out=out)
-
-
-@tensorflow_handle_methods_1
-def tensorflow_add_frnt_(tensor, other, *, alpha=1):
- return tensorflow_add_frnt(tensor, other, alpha=alpha)
-
-
-def tensorflow_current_backend_str():
- return "tensorflow"
-
-
-def tensorflow__deconv_length_bknd(
- dim_size, stride_size, kernel_size, padding, dilation=1
-):
- kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
- if padding == "SAME":
- dim_size = dim_size * stride_size
- else:
- dim_size = dim_size * stride_size + max(kernel_size - stride_size, 0)
- return dim_size
-
-
-def tensorflow__transpose_out_pad(
- x_shape, filter_shape, strides, padding, dims, dilations, data_format
-):
- if data_format[-1] == "C":
- offset = 1
- else:
- offset = 2
- dilations = [dilations] * dims if isinstance(dilations, int) else dilations
- strides = [strides] * dims if isinstance(strides, int) else strides
- if isinstance(padding, str):
- out_shape = [
- tensorflow__deconv_length_bknd(
- x_shape[offset + i], strides[i], filter_shape[i], padding, dilations[i]
- )
- for i in range(dims)
- ]
- else:
- if isinstance(padding, int):
- padding = [[padding, padding]] * dims
- out_shape = [
- (
- (x_shape[offset + i] - 1) * strides[i]
- - padding[i][0]
- - padding[i][1]
- + dilations[i] * (filter_shape[i] - 1)
- + 1
- )
- for i in range(dims)
- ]
- if data_format[-1] == "C":
- padding = [[0, 0], *padding, [0, 0]]
- else:
- padding = [[0, 0], [0, 0], *padding]
- if data_format[-1] == "C":
- out_shape = [x_shape[0], *out_shape, filter_shape[-2]]
- else:
- out_shape = [x_shape[0], filter_shape[-2], *out_shape]
- return out_shape, padding
-
-
-def tensorflow_conv1d_transpose(
- x: Union[tensorflow.Tensor, tensorflow.Variable],
- filters: Union[tensorflow.Tensor, tensorflow.Variable],
- strides: Union[int, Tuple[int]],
- padding: str,
- /,
- *,
- output_shape: Optional[Union[tf.TensorShape, Sequence[int]]] = None,
- filter_format: str = "channel_last",
- data_format: str = "NWC",
- dilations: Union[int, Tuple[int]] = 1,
- bias: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None,
- out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None,
-):
- if tensorflow_dev(x) == "cpu" and (
- dilations > 1 if isinstance(dilations, int) else any(d > 1 for d in dilations)
- ):
- raise Exception(
- "Tensorflow does not support dilations greater than 1 when device is cpu"
- )
- permuted_x = False
- if data_format == "NCW" and tensorflow_dev(x) == "cpu":
- x = tensorflow.transpose(x, (0, 2, 1))
- data_format = "NWC"
- permuted_x = True
- if filter_format == "channel_first":
- filters = tensorflow.transpose(filters, (2, 1, 0))
- output_shape, padding = tensorflow__transpose_out_pad(
- x.shape, filters.shape, strides, padding, 1, dilations, data_format
- )
- res = tensorflow.nn.conv1d_transpose(
- x, filters, output_shape, strides, padding, data_format, dilations
- )
- if bias is not None:
- if data_format[1] == "C":
- bias = tensorflow.reshape(bias, [1, -1, 1])
- res = tensorflow.math.add(res, bias)
- if permuted_x:
- res = 
tensorflow.transpose(res, (0, 2, 1)) - return res - - -def tensorflow_conv2d_transpose( - x: Union[tensorflow.Tensor, tensorflow.Variable], - filters: Union[tensorflow.Tensor, tensorflow.Variable], - strides: Union[int, Tuple[int, int]], - padding: Union[str, int, Sequence[Tuple[int, int]]], - /, - *, - output_shape: Optional[Union[tf.TensorShape, Sequence[int]]] = None, - filter_format: str = "channel_last", - data_format: str = "NHWC", - dilations: Union[int, Tuple[int, int]] = 1, - bias: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if tensorflow_dev(x) == "cpu" and ( - dilations > 1 if isinstance(dilations, int) else any(d > 1 for d in dilations) - ): - raise Exception( - "Tensorflow does not support dilations greater than 1 when device is cpu" - ) - permuted_x = False - if data_format == "NCHW" and tensorflow_dev(x) == "cpu": - x = tensorflow.transpose(x, (0, 2, 3, 1)) - data_format = "NHWC" - permuted_x = True - if filter_format == "channel_first": - filters = tensorflow.transpose(filters, (2, 3, 1, 0)) - output_shape, padding = tensorflow__transpose_out_pad( - x.shape, filters.shape, strides, padding, 2, dilations, data_format - ) - res = tensorflow.nn.conv2d_transpose( - x, filters, output_shape, strides, padding, data_format, dilations - ) - if bias is not None: - if data_format[1] == "C": - bias = tensorflow.reshape(bias, [1, -1, 1, 1]) - res = tensorflow.math.add(res, bias) - if permuted_x: - return tensorflow.transpose(res, (0, 3, 1, 2)) - return res - - -def tensorflow__extend_3d_strides_dilations(strides, dilations, data_format): - if data_format[-1] == "C": - strides = [1, *([strides] * 3 if isinstance(strides, int) else strides), 1] - dilations = [ - 1, - *([dilations] * 3 if isinstance(dilations, int) else dilations), - 1, - ] - else: - strides = [1, 1, *([strides] * 3 if isinstance(strides, int) else strides)] - dilations = [ - 1, - 1, - *([dilations] * 3 if isinstance(dilations, int) else dilations), - ] - return strides, dilations - - -def tensorflow_conv3d_transpose( - x: tf.Tensor, - filters: tf.Tensor, - strides: Union[int, Tuple[int, int, int]], - padding: str, - /, - *, - output_shape: Optional[Union[tf.TensorShape, Sequence[int]]] = None, - filter_format: str = "channel_last", - data_format: str = "NDHWC", - dilations: Union[int, Tuple[int, int, int]] = 1, - bias: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if tensorflow_dev(x) == "cpu" and ( - dilations > 1 if isinstance(dilations, int) else any(d > 1 for d in dilations) - ): - raise Exception( - "Tensorflow does not support dilations greater than 1 when device is cpu" - ) - permuted_x = False - if data_format == "NCDHW" and tensorflow_dev(x) == "cpu": - x = tensorflow.transpose(x, (0, 2, 3, 4, 1)) - data_format = "NDHWC" - permuted_x = True - if filter_format == "channel_first": - filters = tensorflow.transpose(filters, (2, 3, 4, 1, 0)) - output_shape, padding = tensorflow__transpose_out_pad( - x.shape, filters.shape, strides, padding, 3, dilations, data_format - ) - strides, dilations = tensorflow__extend_3d_strides_dilations( - strides, dilations, data_format - ) - res = tensorflow.nn.conv3d_transpose( - x, filters, output_shape, strides, padding, data_format, dilations - ) - if bias is not None: - if data_format[1] == "C": - bias = tensorflow.reshape(bias, [1, -1, 1, 1, 1]) - res = tensorflow.math.add(res, bias) - if 
permuted_x: - return tensorflow.transpose(res, (0, 4, 1, 2, 3)) - return res - - -def tensorflow__get_x_data_format_bknd( - dims: int = 2, data_format: str = "channel_first" -): - if dims == 1: - if data_format == "channel_first": - return "NCW" - else: - return "NWC" - if dims == 2: - if data_format == "channel_first": - return "NCHW" - else: - return "NHWC" - elif dims == 3: - if data_format == "channel_first": - return "NCDHW" - else: - return "NDHWC" - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_conv_general_transpose( - x: Union[tensorflow.Tensor, tensorflow.Variable], - filters: Union[tensorflow.Tensor, tensorflow.Variable], - strides: Union[int, Tuple[int, int]], - padding: str, - /, - *, - dims: int = 2, - filter_format: str = "channel_last", - data_format: str = "channel_last", - output_shape: Optional[Union[tf.TensorShape, Sequence[int]]] = None, - dilations: Union[int, Tuple[int], Tuple[int, int], Tuple[int, int, int]] = 1, - feature_group_count: int = 1, - bias: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if feature_group_count == 1: - if dims == 1: - return tensorflow_conv1d_transpose( - x, - filters, - strides, - padding, - output_shape=output_shape, - filter_format=filter_format, - data_format="NWC" if data_format == "channel_last" else "NCW", - dilations=dilations, - bias=bias, - ) - elif dims == 2: - return tensorflow_conv2d_transpose( - x, - filters, - strides, - padding, - output_shape=output_shape, - filter_format=filter_format, - data_format="NHWC" if data_format == "channel_last" else "NCHW", - dilations=dilations, - bias=bias, - ) - else: - return tensorflow_conv3d_transpose( - x, - filters, - strides, - padding, - output_shape=output_shape, - filter_format=filter_format, - data_format="NDHWC" if data_format == "channel_last" else "NCDHW", - dilations=dilations, - bias=bias, - ) - else: - if filter_format == "channel_first": - filters = tensorflow.transpose(filters, (*range(2, dims + 2), 1, 0)) - permuted_x = False - if data_format == "channel_first" and tensorflow_dev(x) == "cpu": - x = tensorflow.transpose(x, (0, *range(2, dims + 2), 1)) - data_format = "channel_last" - permuted_x = True - data_format = tensorflow__get_x_data_format_bknd(dims, data_format) - output_shape, padding = tensorflow__transpose_out_pad( - x.shape, filters.shape, strides, padding, dims, dilations, data_format - ) - if dims == 1: - res = tensorflow.concat( - [ - tensorflow.nn.conv1d_transpose( - x[..., j : j + filters.shape[-2] // feature_group_count], - filters[ - ..., j : j + filters.shape[-2] // feature_group_count, : - ], - output_shape, - strides, - padding=padding, - data_format=data_format, - dilations=dilations, - ) - for j in range( - 0, filters.shape[-2], filters.shape[-2] // feature_group_count - ) - ], - axis=-1, - ) - elif dims == 2: - res = tensorflow.concat( - [ - tensorflow.nn.conv2d_transpose( - x[..., j : j + filters.shape[-2] // feature_group_count], - filters[ - ..., j : j + filters.shape[-2] // feature_group_count, : - ], - output_shape, - strides, - padding=padding, - data_format=data_format, - dilations=dilations, - ) - for j in range( - 0, filters.shape[-2], filters.shape[-2] // feature_group_count - ) - ], - axis=-1, - ) - else: - strides, dilations = tensorflow__extend_3d_strides_dilations( - strides, dilations, data_format - ) - res = tensorflow.concat( - [ - tensorflow.nn.conv3d_transpose( - x[..., j : j + filters.shape[-2] // 
feature_group_count], - filters[ - ..., j : j + filters.shape[-2] // feature_group_count, : - ], - output_shape, - strides, - padding=padding, - data_format=data_format, - dilations=dilations, - ) - for j in range( - 0, filters.shape[-2], filters.shape[-2] // feature_group_count - ) - ], - axis=-1, - ) - res = tensorflow.math.add(res, bias) if bias is not None else res - if permuted_x: - return tensorflow.transpose(res, (0, dims + 1, *range(1, dims + 1))) - return res - - -def tensorflow__get_transpose_pad_frnt(padding, output_padding, dims): - padding, output_padding = map( - lambda x: [x] * dims if isinstance(x, int) else x, [padding, output_padding] - ) - ag__result_list_0 = [] - for pad, output_pad in zip(padding, output_padding): - res = [pad, pad - output_pad] - ag__result_list_0.append(res) - asymmetric_padding = ag__result_list_0 - return asymmetric_padding - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_permute_dims( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - axes: Tuple[int, ...], - *, - copy: Optional[bool] = None, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - return tensorflow.transpose(x, perm=axes) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_flip( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - copy: Optional[bool] = None, - axis: Optional[Union[int, Sequence[int]]] = None, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - num_dims = len(x.shape) - if not num_dims: - ret = x - else: - if axis is None: - new_axis = list(range(num_dims)) - else: - new_axis = axis - if isinstance(new_axis, int): - new_axis = [new_axis] - else: - new_axis = new_axis - new_axis = [(item + num_dims if item < 0 else item) for item in new_axis] - ret = tensorflow.reverse(x, new_axis) - return ret - - -def tensorflow__x_dil_before_conv(x, dims, x_dilations, data_format): - x_dilations = [x_dilations] * dims if isinstance(x_dilations, int) else x_dilations - ag__result_list_0 = [] - for i, x_dil in enumerate(x_dilations): - if x_dil > 1: - res = i - ag__result_list_0.append(res) - x_dilations_idxs = ag__result_list_0 - if x_dilations_idxs: - if data_format[-1] == "C": - offset = 1 - else: - offset = 2 - for i in x_dilations_idxs: - h = x.shape[offset + i] - new_height = h + (h - 1) * (x_dilations[i] - 1) - h = tensorflow.eye(new_height, dtype=x.dtype)[:: x_dilations[i]] - x = tensorflow.experimental.numpy.swapaxes(x, offset + i, -1) - x = tensorflow.matmul(x, h) - x = tensorflow.experimental.numpy.swapaxes(x, -1, offset + i) - return x - - -def tensorflow__extend_2d_padding(padding, data_format): - if isinstance(padding, str): - return padding - if isinstance(padding, int): - padding = [(padding, padding)] * 2 - if data_format[-1] == "C": - padding = [(0, 0)] + padding + [(0, 0)] - else: - padding = [(0, 0), (0, 0)] + padding - return padding - - -def tensorflow_depthwise_conv2d( - x: Union[tensorflow.Tensor, tensorflow.Variable], - filters: Union[tensorflow.Tensor, tensorflow.Variable], - strides: Union[int, Tuple[int, int]], - padding: Union[str, int, Sequence[Tuple[int, int]]], - /, - *, - data_format: str = "NHWC", - dilations: Union[int, Tuple[int, int]] = 1, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - strides = [strides] * 2 if isinstance(strides, int) else strides - dilations = [dilations] * 2 if isinstance(dilations, int) else dilations - permuted_x = False - if data_format == "NCHW" and tensorflow_dev(x) == "cpu": - x = 
tensorflow.transpose(x, (0, 2, 3, 1)) - data_format = "NHWC" - permuted_x = True - if tensorflow.rank(filters) == 3: - filters = tensorflow.expand_dims(filters, -1) - padding = tensorflow__extend_2d_padding(padding, data_format) - strides = [1, strides[0], strides[1], 1] - res = tensorflow.nn.depthwise_conv2d( - x, filters, strides, padding, data_format, dilations - ) - if permuted_x: - res = tensorflow.transpose(res, (0, 3, 1, 2)) - return res - - -def tensorflow__pad_before_conv(x, padding, dims, data_format): - if isinstance(padding, str): - return x, padding - elif isinstance(padding, int): - pad_list = [(padding, padding)] * dims - else: - pad_list = padding - if data_format[-1] == "C": - pad_list = [(0, 0), *pad_list, (0, 0)] - else: - pad_list = [(0, 0), (0, 0), *pad_list] - return tensorflow.pad(x, pad_list, "CONSTANT"), "VALID" - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_conv_general_dilated( - x: Union[tensorflow.Tensor, tensorflow.Variable], - filters: Union[tensorflow.Tensor, tensorflow.Variable], - strides: Union[int, Tuple[int], Tuple[int, int], Tuple[int, int, int]], - padding: Union[str, int, Sequence[Tuple[int, int]]], - /, - *, - dims: int = 2, - data_format: str = "channel_last", - filter_format: str = "channel_last", - feature_group_count: int = 1, - x_dilations: Union[int, Tuple[int], Tuple[int, int], Tuple[int, int, int]] = 1, - dilations: Union[int, Tuple[int], Tuple[int, int], Tuple[int, int, int]] = 1, - bias: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if filter_format == "channel_first": - filters = tensorflow.transpose(filters, (*range(2, dims + 2), 1, 0)) - num_channels = x.shape[1] if data_format == "channel_first" else x.shape[-1] - if filters.shape[-2] != num_channels // feature_group_count: - raise Exception( - f"given feature_group_count {feature_group_count} expected input channel of the filter to be {num_channels // feature_group_count} but got {filters.shape[-2]}" - ) - if num_channels % feature_group_count != 0: - raise Exception( - f"input channel should be divisible by feature group count {feature_group_count} but got input channel {num_channels}" - ) - permuted_x = False - if data_format == "channel_first" and ( - tensorflow_dev(x) == "cpu" or feature_group_count != 1 - ): - x = tensorflow.transpose(x, (0, *range(2, dims + 2), 1)) - data_format = "channel_last" - permuted_x = True - data_format = tensorflow__get_x_data_format_bknd(dims, data_format) - x = tensorflow__x_dil_before_conv(x, dims, x_dilations, data_format) - if dims == 2: - padding = tensorflow__extend_2d_padding(padding, data_format) - if feature_group_count == 1: - res = tensorflow.nn.conv2d( - x, - filters, - strides, - padding, - data_format=data_format, - dilations=dilations, - ) - else: - if not isinstance(padding, str): - padding = padding[1:-1] - res = tensorflow_depthwise_conv2d( - x, - tensorflow.transpose(filters, (0, 1, 3, 2)), - strides, - padding, - data_format=data_format, - dilations=dilations, - ) - else: - x, padding = tensorflow__pad_before_conv(x, padding, dims, data_format) - if dims == 1: - if feature_group_count == 1: - res = tensorflow.nn.conv1d( - x, - filters, - strides, - padding, - data_format=data_format, - dilations=dilations, - ) - else: - res = tensorflow.concat( - [ - tensorflow.nn.conv1d( - x[..., i : i + filters.shape[-2]], - filters[ - ..., j : j + filters.shape[-1] // feature_group_count - ], - strides, - padding, - data_format, - 
dilations, - ) - for i, j in zip( - range(0, x.shape[-1], filters.shape[-2]), - range( - 0, - filters.shape[-1], - filters.shape[-1] // feature_group_count, - ), - ) - ], - axis=-1, - ) - else: - strides, dilations = tensorflow__extend_3d_strides_dilations( - strides, dilations, data_format - ) - if feature_group_count == 1: - res = tensorflow.nn.conv3d( - x, - filters, - strides, - padding, - data_format=data_format, - dilations=dilations, - ) - else: - res = tensorflow.concat( - [ - tensorflow.nn.conv3d( - x[..., i : i + filters.shape[-2]], - filters[ - ..., j : j + filters.shape[-1] // feature_group_count - ], - strides, - padding, - data_format, - dilations, - ) - for i, j in zip( - range(0, x.shape[-1], filters.shape[-2]), - range( - 0, - filters.shape[-1], - filters.shape[-1] // feature_group_count, - ), - ) - ], - axis=-1, - ) - if bias is not None: - if data_format[1] == "C": - bias = tensorflow.reshape(bias, [1, -1, *([1] * dims)]) - res = tensorflow.math.add(res, bias) - if permuted_x: - return tensorflow.transpose(res, (0, dims + 1, *range(1, dims + 1))) - return res - - -def tensorflow__conv_transpose_frnt( - input, - weight, - bias=None, - stride=1, - padding=0, - output_padding=0, - groups=1, - dilation=1, -): - dims = len(input.shape) - 2 - weight = tensorflow_permute_dims(weight, axes=(*range(2, dims + 2), 0, 1)) - for i in range(dims): - weight = tensorflow_flip(weight, axis=i) - padding, output_padding, stride, dilation = map( - lambda x: [x] * dims if isinstance(x, int) else x, - [padding, output_padding, stride, dilation], - ) - ag__result_list_0 = [] - for i in range(dims): - res = ( - (tensorflow_get_item(weight.shape, i) - 1) - * tensorflow_get_item(dilation, i) - + max( - [ - tensorflow_get_item(output_padding, i) - - tensorflow_get_item(padding, i), - 0, - ] - ), - ) * 2 - ag__result_list_0.append(res) - pad_widths = ag__result_list_0 - ret = tensorflow_conv_general_dilated( - input, - weight, - 1, - pad_widths, - dims=dims, - data_format="channel_last", - feature_group_count=groups, - x_dilations=stride, - dilations=dilation, - bias=bias, - ) - unpad_slice = (slice(None),) * 2 - for i in range(dims): - unpad_slice = unpad_slice + ( - slice( - max( - [ - tensorflow_get_item(padding, i) - - tensorflow_get_item(dilation, i) // 2, - tensorflow_get_item(padding, i), - tensorflow_get_item(output_padding, i), - ] - ), - tensorflow_get_item(ret.shape, 2 + i) - - tensorflow_get_item(padding, i) - + tensorflow_get_item(output_padding, i) - + tensorflow_get_item(dilation, i) // 2, - 1, - ), - ) - ret = tensorflow_get_item(ret, unpad_slice) - return ret - - -def tensorflow_conv_transpose2d_frnt( - input, - weight, - bias=None, - stride=1, - padding=0, - output_padding=0, - groups=1, - dilation=1, -): - if tensorflow_current_backend_str() in ["torch", "tensorflow"]: - return tensorflow_conv_general_transpose( - input, - weight, - stride, - tensorflow__get_transpose_pad_frnt(padding, output_padding, 2), - dims=2, - filter_format="channel_last", - data_format="channel_last", - dilations=dilation, - feature_group_count=groups, - bias=bias, - ) - else: - return tensorflow__conv_transpose_frnt( - input, - weight, - bias=bias, - stride=stride, - padding=padding, - output_padding=output_padding, - groups=groups, - dilation=dilation, - ) - - -def tensorflow_retrieve_object(frame, name): - if name is None: - return name - names = tensorflow_split_bknd_(name, ".") - obj = frame.f_locals.get(names[0]) or frame.f_globals.get(names[0]) - if obj is None: - return None - for attr in 
names[1:]: - try: - obj = getattr(obj, attr) - except AttributeError: - return None - return obj - - -def tensorflow_get_next_func(obj): - from .tensorflow_CallVisitor import tensorflow_CallVisitor - - stack = inspect.stack() - for frame_info in stack: - if frame_info == obj._previous_frame_info: - calling_frame = frame_info.frame - break - else: - return None - if "Sequential" in frame_info.filename: - try: - self_seq = calling_frame.f_locals["self"] - idx = calling_frame.f_locals["i"] - next_func = tensorflow_get_item(self_seq, idx + 1) - return next_func - except IndexError: - for frame_info in tensorflow_get_item( - stack, slice(stack.index(frame_info) + 1, None, None) - ): - if frame_info == self_seq._previous_frame_info: - calling_frame = frame_info.frame - break - else: - return None - lines, start_line_no = inspect.getsourcelines(calling_frame) - current_line_no = calling_frame.f_lineno - relative_line_no = current_line_no - start_line_no - try: - next_line = tensorflow_get_item(lines, relative_line_no + 1).strip() - tree = ast.parse(next_line) - visitor = tensorflow_CallVisitor() - visitor.visit(tree) - next_call_str = visitor.func_name - except Exception: - next_call_str = "" - next_func = tensorflow_retrieve_object(calling_frame, next_call_str) - return next_func - - -def tensorflow_apply_transpose(input, transpose, pt_to_tf=True): - from .tensorflow_TransposeType import tensorflow_TransposeType - - if transpose is tensorflow_TransposeType.NO_TRANSPOSE: - return input - if transpose is tensorflow_TransposeType.CONV1D: - axes = (0, 2, 1) if pt_to_tf else (0, 2, 1) - elif transpose is tensorflow_TransposeType.CONV2D: - axes = (0, 2, 3, 1) if pt_to_tf else (0, 3, 1, 2) - elif transpose is tensorflow_TransposeType.CONV3D: - axes = (0, 2, 3, 4, 1) if pt_to_tf else (0, 4, 1, 2, 3) - input = tensorflow_permute_dims(input, axes=axes) - return input - - -def tensorflow_handle_transpose_in_input_and_output(fn): - from .tensorflow_TransposeType import tensorflow_TransposeType - - original_signature = inspect.signature(fn) - - @functools.wraps(fn) - def transpose_wrapper(self, *args, **kwargs): - global DATA_FORMAT - kwargs_call = { - key: val - for key, val in kwargs.items() - if key not in dict(original_signature.parameters) - } - fn_args_and_kwargs = { - key: val for key, val in kwargs.items() if key not in kwargs_call - } - fn_args_and_kwargs.update(dict(zip(fn.__code__.co_varnames[1:], args))) - conv_block_start = lambda f: any( - substr in f.__qualname__ - for substr in CONV_FUNCS - + NORM_FUNCS - + POOL_FUNCS - + KERAS_CONV_FUNCS - + KERAS_NORM_FUNCS - + KERAS_POOL_FUNCS - ) - next_call_in_seq = tensorflow_get_next_func(self) - name_of_next_call = ( - next_call_in_seq.__class__.__name__ - if hasattr(next_call_in_seq, "__class__") - else "" - ) - conv_block_continued = next_call_in_seq and any( - substr in name_of_next_call for substr in CONV_BLOCK_FNS - ) - if DATA_FORMAT == "PT" and conv_block_start(self.__class__): - input = fn_args_and_kwargs["input"] - if len(input.shape) > 4: - transpose = tensorflow_TransposeType.CONV3D - elif len(input.shape) > 3: - transpose = tensorflow_TransposeType.CONV2D - elif len(input.shape) > 2: - transpose = tensorflow_TransposeType.CONV1D - else: - transpose = tensorflow_TransposeType.NO_TRANSPOSE - fn_args_and_kwargs = tensorflow_set_item_bknd( - fn_args_and_kwargs, - "input", - tensorflow_apply_transpose(input, transpose=transpose, pt_to_tf=True), - ) - DATA_FORMAT = "TF" - os.environ = tensorflow_set_item_bknd( - os.environ, "DATA_FORMAT", 
"channels_last" - ) - res = fn(self, **fn_args_and_kwargs) - if DATA_FORMAT == "TF" and conv_block_continued or DATA_FORMAT == "PT": - return res - if len(res.shape) > 4: - transpose = tensorflow_TransposeType.CONV3D - elif len(res.shape) > 3: - transpose = tensorflow_TransposeType.CONV2D - elif len(res.shape) > 2: - transpose = tensorflow_TransposeType.CONV1D - else: - transpose = tensorflow_TransposeType.NO_TRANSPOSE - res = tensorflow_apply_transpose(res, transpose=transpose, pt_to_tf=False) - DATA_FORMAT = "PT" - os.environ = tensorflow_set_item_bknd( - os.environ, "DATA_FORMAT", "channels_first" - ) - return res - - tensorflow_handle_transpose_in_input_and_output.__signature__ = original_signature - return transpose_wrapper diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_ConvTranspose2d_output/run_0/tensorflow__stateful.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_ConvTranspose2d_output/run_0/tensorflow__stateful.py deleted file mode 100644 index dbad1e919ab1..000000000000 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_ConvTranspose2d_output/run_0/tensorflow__stateful.py +++ /dev/null @@ -1,1799 +0,0 @@ -# global -from __future__ import annotations -import re -import os -import tensorflow as tf -import functools -from tensorflow.python.util import nest -from typing import NamedTuple, Callable, Any, Tuple, List, Dict, Type, Union -import inspect -from collections import OrderedDict -from packaging.version import parse -import keras - - -def get_assignment_dict(): - # Traverse the call stack - lhs = None - for frame_info in inspect.stack(): - # Check if the code context is an assignment statement - if frame_info.code_context and "=" in frame_info.code_context[0]: - # Split the assignment and retrieve the LHS - lhs = frame_info.code_context[0].split("=")[0].strip() - if "self" not in lhs: - continue - break - - if not lhs: - return None, "" - - # Replace indexing with attribute access - lhs = re.sub(r"\[(\d+)\]", r".\1", lhs) - - # Split the LHS based on "." and get individual components - components = lhs.split(".") - - # Initialize the dictionary - assignment_dict = {} - - # Retrieve the live objects associated with each component - for i in range(len(components)): - # Construct the key - key = ".".join(components[: i + 1]) - - # Retrieve the value - if i == 0: - value = frame_info.frame.f_locals.get(components[i]) - else: - value = getattr(assignment_dict[".".join(components[:i])], components[i]) - - # Add the key-value pair to the dictionary - assignment_dict[key] = value - - return assignment_dict, lhs - - -def store_frame_info(fn): - @functools.wraps(fn) - def frame_info_wrapper(self, *args, **kwargs): - if self._previous_frame_info is None: - # store the info about the calling frame. - stack = inspect.stack() - self._previous_frame_info = stack[1] - res = fn(self, *args, **kwargs) - # reset the frame-info - self._previous_frame_info = None - return res - - return frame_info_wrapper - - -# A NodeDef holds two callables: -# - flatten_fn should take the collection and return a flat list of values. -# It can also return some context that is used in reconstructing the -# collection. -# - unflatten_fn should take a flat list of values and some context -# (returned by flatten_fn). It returns the collection by reconstructing -# it from the list and the context. 
-Context = Any -PyTree = Any -FlattenFunc = Callable[[PyTree], Tuple[List, Context]] -UnflattenFunc = Callable[[List, Context], PyTree] - - -class NodeDef(NamedTuple): - flatten_fn: FlattenFunc - unflatten_fn: UnflattenFunc - - -SUPPORTED_NODES: Dict[Type[Any], NodeDef] = {} - - -def _register_pytree_node( - typ: Any, flatten_fn: FlattenFunc, unflatten_fn: UnflattenFunc -) -> None: - SUPPORTED_NODES[typ] = NodeDef(flatten_fn, unflatten_fn) - - -def _dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]: - return list(d.values()), list(d.keys()) - - -def _dict_unflatten(values: List[Any], context: Context) -> Dict[Any, Any]: - return {key: value for key, value in zip(context, values)} - - -_register_pytree_node(dict, _dict_flatten, _dict_unflatten) - -if parse(keras.__version__).major > 2: - _register_pytree_node( - keras.src.utils.tracking.TrackedDict, _dict_flatten, _dict_unflatten - ) - - -def _get_node_type(pytree: Any) -> Any: - return type(pytree) - - -# A leaf is defined as anything that is not a Node. -def _is_leaf(pytree: PyTree) -> bool: - return _get_node_type(pytree) not in SUPPORTED_NODES.keys() - - -# A TreeSpec represents the structure of a pytree. It holds: -# "type": the type of root Node of the pytree -# context: some context that is useful in unflattening the pytree -# children_specs: specs for each child of the root Node -# num_leaves: the number of leaves -class TreeSpec: - def __init__(self, type, context, children_specs): - self.type: Any = type - self.context: Context = context - self.children_specs: List["TreeSpec"] = children_specs - self.num_leaves: int = sum([spec.num_leaves for spec in self.children_specs]) - - def get_keychains(self, prefix="", sep="/"): - keychains = [] - for key, child_spec in zip(self.context, self.children_specs): - new_prefix = prefix + key + sep if prefix else key + sep - if child_spec.children_specs: # Non-leaf node - keychains.extend(child_spec.get_keychains(new_prefix, sep)) - else: # Leaf node - keychains.append(new_prefix[: -len(sep)]) - return keychains - - def __repr__(self, indent: int = 0) -> str: - repr_prefix: str = f"TreeSpec({self.type.__name__}, {self.context}, [" - children_specs_str: str = "" - if len(self.children_specs): - indent += len(repr_prefix) - children_specs_str += self.children_specs[0].__repr__(indent) - children_specs_str += "," if len(self.children_specs) > 1 else "" - children_specs_str += ",".join( - [ - "\n" + " " * indent + child.__repr__(indent) - for child in self.children_specs[1:] - ] - ) - repr_suffix: str = f"{children_specs_str}])" - return repr_prefix + repr_suffix - - -class LeafSpec(TreeSpec): - def __init__(self) -> None: - super().__init__(None, None, []) - self.num_leaves = 1 - - def __repr__(self, indent: int = 0) -> str: - return "*" - - -def tree_flatten(pytree: PyTree) -> Tuple[List[Any], TreeSpec]: - """Flattens a pytree into a list of values and a TreeSpec that can be used - to reconstruct the pytree.""" - if _is_leaf(pytree): - return [pytree], LeafSpec() - - node_type = _get_node_type(pytree) - flatten_fn = _dict_flatten - child_pytrees, context = flatten_fn(pytree) - - # Recursively flatten the children - result: List[Any] = [] - children_specs: List["TreeSpec"] = [] - for child in child_pytrees: - flat, child_spec = tree_flatten(child) - result += flat - children_specs.append(child_spec) - - return result, TreeSpec(node_type, context, children_specs) - - -def tree_unflatten(values: List[Any], spec: TreeSpec) -> PyTree: - """Given a list of values and a TreeSpec, builds a 
pytree.
-
- This is the inverse operation of `tree_flatten`.
- """
- if not isinstance(spec, TreeSpec):
- raise TypeError(
- f"tree_unflatten(values, spec): Expected `spec` to be instance of "
- f"TreeSpec but got item of type {type(spec)}."
- )
- if len(values) != spec.num_leaves:
- raise TypeError(
- f"tree_unflatten(values, spec): `values` has length {len(values)} "
- f"but the spec refers to a pytree that holds {spec.num_leaves} "
- f"items ({spec})."
- )
- if isinstance(spec, LeafSpec):
- return values[0]
-
- unflatten_fn = _dict_unflatten
-
- # Recursively unflatten the children
- start = 0
- end = 0
- child_pytrees = []
- for child_spec in spec.children_specs:
- end += child_spec.num_leaves
- child_pytrees.append(tree_unflatten(values[start:end], child_spec))
- start = end
-
- return unflatten_fn(child_pytrees, spec.context)
-
-
-def serialize_obj(obj):
- if inspect.isclass(obj) or isinstance(obj, type):
- return {"cls_module": obj.__module__, "cls_name": obj.__name__}
- return obj
-
-
-def recursive_serialize(d):
- if isinstance(d, dict):
- return {k: recursive_serialize(v) for k, v in d.items()}
- elif isinstance(d, list):
- return [recursive_serialize(v) for v in d]
- elif isinstance(d, tuple):
- return tuple(recursive_serialize(v) for v in d)
- else:
- return serialize_obj(d)
-
-
-def deserialize_obj(serialized):
- if (
- isinstance(serialized, dict)
- and "cls_module" in serialized
- and "cls_name" in serialized
- ):
- module = __import__(serialized["cls_module"], fromlist=[serialized["cls_name"]])
- cls = getattr(module, serialized["cls_name"])
- return cls
- return serialized
-
-
-def recursive_deserialize(d):
- if isinstance(d, dict) and "cls_module" not in d:
- return {k: recursive_deserialize(v) for k, v in d.items()}
- elif isinstance(d, list):
- return [recursive_deserialize(v) for v in d]
- elif isinstance(d, tuple):
- return tuple(recursive_deserialize(v) for v in d)
- else:
- return deserialize_obj(d)
-
-
-class ModelHelpers:
- @staticmethod
- @tf.autograph.experimental.do_not_convert
- def _get_first_array(*args, **kwargs):
- arr = None
- flattened_args = tf.nest.flatten((args, kwargs))
- arr_candidates = tf.nest.map_structure(
- lambda x: x if isinstance(x, (tf.Tensor, tf.Variable)) else False,
- flattened_args,
- )
- for arr_candidate in arr_candidates:
- if arr_candidate is not False:
- arr = arr_candidate
- break
- return arr
-
- @staticmethod
- @tf.autograph.experimental.do_not_convert
- def _get_input_shapes(*args):
- input_shapes = []
- for x in args:
- if isinstance(x, (tf.Tensor, tf.Variable)):
- input_shapes.append(x.shape)
- else:
- try:
- x = tf.convert_to_tensor(x)
- input_shapes.append(x.shape)
- except Exception:
- input_shapes.append(None)
- return input_shapes
-
- @staticmethod
- @tf.autograph.experimental.do_not_convert
- def _extract_v(v, keychain_mappings: dict, orig_key_chain, /):
- if ModelHelpers._dict_has_key_chain(v, orig_key_chain):
- ret_cont = ModelHelpers._dict_at_key_chain(v, orig_key_chain)
- else:
- ret_cont = dict()
- for old_kc, new_kc in keychain_mappings.items():
- if orig_key_chain in old_kc:
- # Check if `v` contains `new_kc` before replacing in `ret_cont`
- if ModelHelpers._dict_has_key_chain(v, new_kc):
- ret_cont = ModelHelpers._dict_set_at_key_chain(
- ret_cont,
- "/".join(old_kc.split("/")[1:]),
- ModelHelpers._dict_at_key_chain(v, new_kc),
- )
- else:
- continue
- return ret_cont
-
- @staticmethod
- @tf.autograph.experimental.do_not_convert
- def _remove_duplicate_variables(vs, created, /):
- created_ids = 
tf.nest.map_structure(lambda x: id(x), created) - vs_ids = tf.nest.map_structure(lambda x: id(x), vs) - ids = {} - duplicate_keychains = [] - keychain_mappings = {} - - def unique_callback(x, kc): - ids[x] = kc - return x - - def found_dup_callback(x, kc): - if ids[x] == kc: - return x - duplicate_keychains.append(kc) - keychain_mappings[kc] = ids[x] - return x - - created_ids = nest.map_structure_with_paths( - lambda kc, x: unique_callback(x, kc), created_ids - ) - vs_ids = nest.map_structure_with_paths( - lambda kc, x: ( - unique_callback(x, kc) if x not in ids else found_dup_callback(x, kc) - ), - vs_ids, - ) - for dup_kc in duplicate_keychains: - vs = ModelHelpers._dict_prune_key_chain(vs, dup_kc) - return vs, keychain_mappings - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _dict_set_at_key_chain(in_dict, key_chain, val, inplace=False): - keys = re.split("[/.]", key_chain) - if inplace: - cont = in_dict - else: - cont = in_dict - sub_cont = cont - for key in keys[:-1]: - if key not in sub_cont: - sub_cont[key] = dict() - sub_cont = sub_cont[key] - sub_cont[keys[-1]] = val - return cont - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _dict_at_key_chain(dict, key_chain, ignore_key_errors=False): - keys = re.split("[/.]", key_chain) - ret = dict - for key in keys: - try: - ret = ret[key] - except KeyError as e: - if ignore_key_errors: - return - raise Exception(repr(e)) - return ret - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _dict_has_key_chain(dict, key_chain): - keys = re.split("[/.]", key_chain) - ret = dict - for key in keys: - try: - ret = ret[key] - except KeyError: - return False - return True - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _dict_prune_key_chain(in_dict, key_chain): - keys_in_chain = re.split("[/.]", key_chain) - out_dict = {} - for key, value in in_dict.items(): - if isinstance(value, dict): - if key == keys_in_chain[0]: - if len(keys_in_chain) == 1: - new_val = [] - else: - new_val = ModelHelpers._dict_prune_key_chain( - value, - "/".join(keys_in_chain[1:]), - ) - if len(new_val) > 0: - out_dict[key] = new_val - else: - if len(value) > 0: - out_dict[key] = value - else: - if len(keys_in_chain) != 1 or key != keys_in_chain[0]: - out_dict[key] = value - return out_dict - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _addindent(s_, numSpaces): - s = s_.split("\n") - # don't do anything for single-line stuff - if len(s) == 1: - return s_ - first = s.pop(0) - s = [(numSpaces * " ") + line for line in s] - s = "\n".join(s) - s = first + "\n" + s - return s - - -class Layer(tf.keras.layers.Layer, ModelHelpers): - _build_mode = None - _with_partial_v = None - _store_vars = True - _built = False - _v = None - _buffers = None - _module_dict = None - _args = None - _kwargs = None - _module_graph = None - _target = None - _lazy_traced = False - _training = None - _dynamic_backend = None - _device = None - _dtype = None - _previous_frame_info = None - - def __init__( - self, - /, - *args, - v=None, - buffers=None, - build_mode="on_init", - store_vars=True, - with_partial_v=False, - dynamic_backend=None, - training=True, - dtype=None, - device=None, - module_dict=None, - **kwargs, - ): - super(Layer, self).__init__( - trainable=training, - dtype=dtype, - ) - self._build_mode = build_mode - self._with_partial_v = with_partial_v - self._store_vars = store_vars - self._built = False - self._v_from_constructor = v if isinstance(v, dict) or v is None else dict(v) - self._v = v 
if v is not None else dict() - self._buffers = dict(buffers or {}) - self._module_dict = module_dict if module_dict is not None else dict() - self._args = args - self._kwargs = kwargs - self._module_graph = None - self._target = None - self._lazy_traced = False - self._training = training - self._dynamic_backend = dynamic_backend - self._device = device or "cpu" - self._dtype = dtype or tf.float32 - if build_mode != "on_init": - return - self.build(*args, dynamic_backend=dynamic_backend, **kwargs) - - @tf.autograph.experimental.do_not_convert - def _find_variables( - self, - /, - *, - obj=None, - without_initialisation=False, - _visited=None, - trainable=True, - ): - _visited = _visited or {} - vs = dict() - if id(obj) in _visited: - return vs - _visited[id(obj)] = True - if isinstance(obj, Layer) and obj is not self: - fn = "_build_and_return_v" if trainable else "_build_and_return_buffers" - if not obj._built and without_initialisation: - return lambda: getattr(obj, fn)( - *obj._args, dynamic_backend=self._dynamic_backend, **obj._kwargs - ) - - return getattr(obj, fn)( - *obj._args, dynamic_backend=obj._dynamic_backend, **obj._kwargs - ) - elif isinstance(obj, tf.keras.layers.Layer) and obj is not self: - return obj.v if trainable else obj.buffers - - elif isinstance(obj, (list, tuple)): - for i, v in enumerate(obj): - ret = self._find_variables( - obj=v, - without_initialisation=without_initialisation, - _visited=_visited, - trainable=trainable, - ) - if ret: - vs[f"v{str(i)}"] = ret - return vs - elif isinstance(obj, dict): - for k, v in obj.items(): - ret = self._find_variables( - obj=v, - without_initialisation=without_initialisation, - _visited=_visited, - trainable=trainable, - ) - if ret: - vs[k[1:] if k[0] == "_" else k] = ret - return vs - elif not hasattr(obj, "__dict__"): - return vs - for k, v in obj.__dict__.items(): - if ( - v is not None - and k[0:2] != "__" - and not k.startswith( - ( - "_module_dict", - "_self_", - "_args", - "_kwargs", - ) - ) - ): - ret = self._find_variables( - obj=v, - without_initialisation=without_initialisation, - _visited=_visited, - trainable=trainable, - ) - if ret: - vs[k[1:] if k[0] == "_" else k] = ret - return vs - - @tf.autograph.experimental.do_not_convert - def _find_buffers(self): - if hasattr(self, "_module_dict"): - for key, sub_module in self._module_dict.items(): - if len(sub_module._buffers) > 0: - self._buffers[key] = sub_module._buffers - - @tf.autograph.experimental.do_not_convert - def _build_and_return_v(self, *args, **kwargs): - if not self._built: - self.build(*args, **kwargs) - return self.v - - @tf.autograph.experimental.do_not_convert - def _build_and_return_buffers(self, *args, **kwargs): - if not self._built: - self.build(*args, **kwargs) - return self.buffers - - @tf.autograph.experimental.do_not_convert - def _wrap_call_methods( - self, keychain_mappings, /, *, key="", obj=None, _visited=None - ): - _visited = _visited or {} - if id(obj) in _visited or not isinstance(key, str): - return - _visited[id(obj)] = True - if isinstance(obj, Model) and obj is not self: - orig_key_chain = key[1:] if key[0] == "_" else key - - obj.__call__ = self._fn_with_var_arg( - obj.__call__, self._extract_v, keychain_mappings, orig_key_chain - ) - return - elif isinstance(obj, (list, tuple)): - for i, val in enumerate(obj): - self._wrap_call_methods( - keychain_mappings, - key=f"{key}/v{str(i)}", - obj=val, - _visited=_visited, - ) - return - elif isinstance(obj, dict): - for k, val in obj.items(): - k = f"{key}/{k}" if key != "" and 
isinstance(k, str) else k - self._wrap_call_methods( - keychain_mappings, key=k, obj=val, _visited=_visited - ) - return - for k, val in obj.module_dict.items(): - if k.startswith(("__", "_self_")): - continue - k = f"{key}/{k}" if key != "" else k - if val is not None: - self._wrap_call_methods( - keychain_mappings, key=k, obj=val, _visited=_visited - ) - return - - @tf.autograph.experimental.do_not_convert - def _compute_module_dict(self): - self._module_dict = dict() - for key, value in self.__dict__.items(): - if isinstance(value, (Layer, tf.keras.layers.Layer)): - if ( - "stateful" in value.__module__ - or hasattr(value, "_frontend_module") - or not hasattr(value, "_module_dict") - ): - self._module_dict[key] = value - else: - self._module_dict[key] = value._module_dict - - @tf.autograph.experimental.do_not_convert - def _fn_with_var_arg_wrapper( - self, *a, fn, v_fn, keychain_mappings, orig_key_chain, **kw - ): - if "v" in kw: - del kw["v"] - v = v_fn(self.v, keychain_mappings, orig_key_chain) - return fn(*a, **kw, v=v) - - @tf.autograph.experimental.do_not_convert - def _fn_with_var_arg(self, fn, v_fn, /, keychain_mappings, orig_key_chain): - _fn_with_var_arg_wrapper = functools.partial( - self._fn_with_var_arg_wrapper, - fn=fn, - v_fn=v_fn, - keychain_mappings=keychain_mappings, - orig_key_chain=orig_key_chain, - ) - _fn_with_var_arg_wrapper.wrapped = True - return _fn_with_var_arg_wrapper - - @tf.autograph.experimental.do_not_convert - def _call(self, *args, v=None, buffers=None, **kwargs): - if not self._built or not self.built: - if not self._built: - first_arr = self._get_first_array(*args, **kwargs) - self.build( - *args, - **kwargs, - from_call=True, - dtype=first_arr.dtype if first_arr is not None else tf.float32, - ) - - if not self.built: - # Don't use `keras` build method - if os.environ.get("USE_KERAS_BUILD", "False").lower() == "false": - self.inputs = tf.nest.flatten(args) - else: - input_shapes = self._get_input_shapes(*args) - if len(input_shapes) == 0: - input_shapes = tf.TensorShape(None) - elif len(input_shapes) == 1: - input_shapes = input_shapes[0] - - super(Layer, self).build(tf.TensorShape(None)) # noqa: UP008 - - # If `v` was provided, replace with the module's v - replace_v = False - if v is not None: - v_orig = self.v - self._v = v - replace_v = True - - # If `buffers` were provided, replace with the module's buffers - replace_buffers = False - if buffers is not None: - buffers_orig = self.buffers - self._buffers = buffers - replace_buffers = True - - if replace_v or replace_buffers: - # Call the forward pass - ret = super(Layer, self).__call__(*args, **kwargs) # noqa: UP008 - # Replace v, buffers if needed - self._v = v_orig if replace_v else self._v - self._buffers = buffers_orig if replace_buffers else self._buffers - return ret - elif hasattr(self.__call__, "wrapped"): - return self.__call__(*args, **kwargs) - - # Get the signature of the call method - call_signature = inspect.signature(self.call) - - # Convert all positional arguments to keyword arguments based on the signature - new_kwargs = {} - for idx, (param_name, param) in enumerate(call_signature.parameters.items()): - if idx < len(args): - new_kwargs[param_name] = args[idx] - - # Merge the existing kwargs - new_kwargs.update(kwargs) - return super(Layer, self).__call__(**new_kwargs) # noqa: UP008 - - @tf.autograph.experimental.do_not_convert - def build( - self, - *args, - from_call=False, - device=None, - dtype=None, - dynamic_backend=None, - **kwargs, - ): - self._built = True - return - - 
def _lock_state(self): - pass - - @tf.autograph.experimental.do_not_convert - def register_buffer(self, name: str, value: Union[tf.Tensor, tf.Variable]): - self._buffers.update({name: value}) - return value - - @tf.autograph.experimental.do_not_convert - def register_parameter(self, name: str, value: Union[tf.Tensor, tf.Variable]): - self._v.update({name: value}) - - @tf.autograph.experimental.do_not_convert - def train(self, mode: bool = True): - self._training = mode - for module in self.children(): - if isinstance(module, tf.keras.layers.Layer) and not hasattr( - module, "train" - ): - module.trainable = mode - continue - module.train(mode) - self.trainable = mode - return self - - @tf.autograph.experimental.do_not_convert - def eval(self): - return self.train(mode=False) - - @tf.autograph.experimental.do_not_convert - def call(self, inputs, training=None, mask=None): - raise NotImplementedError( - "When subclassing the `Module` class, you should implement a `call` method." - ) - - def get_build_config(self): - config = super().get_build_config() - config = recursive_serialize(config) - return config - - def build_from_config(self, config): - config = recursive_deserialize(config) - return super().build_from_config(config) - - def get_config(self): - base_config = super().get_config() - config = {} - - # Get the names and values of positional arguments in __init__ - init_signature = inspect.signature(self.__init__) - arg_names = list(init_signature.parameters.keys()) - - # Include the positional arguments in the config - var_positional_arg_encountered = False - var_positional_arg_name = None - offset = 0 - for i, arg in enumerate(self._args): - arg_name = arg_names[min(i, len(arg_names) - 1)] - if var_positional_arg_encountered: - config.update( - { - f"{var_positional_arg_name}_{i - offset}": arg, - } - ) - elif ( - init_signature.parameters[arg_name].kind - == inspect.Parameter.VAR_POSITIONAL - ): - var_positional_arg_encountered = True - var_positional_arg_name = arg_name - offset = i - config.update( - { - f"{var_positional_arg_name}_{0}": arg, - } - ) - else: - config.update( - { - arg_name: arg, - } - ) - - # Include the keywords arguments in the config - kwargs = self._kwargs.copy() - kwargs.pop("devices", None) - config.update(**kwargs) - new_config = {**base_config, **config} - new_config = recursive_serialize(new_config) - return new_config - - @classmethod - def from_config(cls, config): - config = recursive_deserialize(config) - # Get the signature of the __init__ method - init_signature = inspect.signature(cls.__init__) - arg_names = list(init_signature.parameters.keys()) - - # Separate positional and keyword arguments based on the __init__ signature - args = [] - pos_or_kw = OrderedDict() - kwargs = {} - var_positional_args = [] - for arg_name in arg_names: - if ( - arg_name in config - and init_signature.parameters[arg_name].kind - == inspect.Parameter.KEYWORD_ONLY - ): - # Handle keyword arguments - kwargs[arg_name] = config.pop(arg_name) - elif ( - arg_name in config - and init_signature.parameters[arg_name].kind - == inspect.Parameter.POSITIONAL_OR_KEYWORD - ): - # Handle positional or keyword arguments - pos_or_kw[arg_name] = config.pop(arg_name) - elif any(re.match(rf"{arg_name}_\d+", key) for key in config.keys()): - # Handle variable positional arguments - var_positional_args.extend( - [ - config.pop(key) - for key in sorted(config.keys()) - if re.match(rf"{arg_name}_\d+", key) - ] - ) - - # Unpack positional arguments and the rest as keyword arguments - 
config.pop("name", None) - config.pop("trainable", None) - config.pop("dtype", None) - kwargs.update(config) - - # Determine the final args and kwargs - if var_positional_args: - args = list(pos_or_kw.values()) + var_positional_args - else: - kwargs.update(pos_or_kw) - - return cls(*args, **kwargs) - - # Methods to be Optionally Overridden # - # -----------------------------------# - - @tf.autograph.experimental.do_not_convert - def _create_variables(self, *, device=None, dtype=None): - return {} - - @tf.autograph.experimental.do_not_convert - def _build(self, *args, **kwargs) -> bool: - return True - - @tf.autograph.experimental.do_not_convert - def _forward(self, *args, **kwargs): - raise NotImplementedError( - "When subclassing the `Module` class, you should " - "implement a `_forward` method." - ) - - @tf.autograph.experimental.do_not_convert - def _extra_repr(self) -> str: - return "" - - # Properties # - # -----------# - - @property - def device(self): - return self._device - - @property - def dtype(self): - return self._dtype - - @property - def build_mode(self): - return self._build_mode - - @property - def training(self): - return self._training - - @property - def v(self): - return self._v - - @property - def buffers(self): - return self._buffers - - @property - def state_dict(self): - return {**self.v, **self.buffers} - - @property - def module_dict(self): - return self._module_dict - - @property - def layers(self): - return self._layers - - # Dunder Methods # - # ---------------# - @store_frame_info - @tf.autograph.experimental.do_not_convert - def __call__( - self, - *args, - v=None, - buffers=None, - **kwargs, - ): - # TODO: Temp workaround to avoid `call`` from being transformed by AutoGraph - if not hasattr(self.__class__.call, "autograph_info__"): - setattr(self.__class__.call, "autograph_info__", True) - ret = self._call(*args, v=v, buffers=buffers, **kwargs) - return ret - - @tf.autograph.experimental.do_not_convert - def __getattr__(self, name): - if name == "v": - if not super().__getattribute__("_v") and not getattr( # noqa: E501 - self, "_built", False - ): - return self._build_and_return_v( - *self._args, dynamic_backend=self._dynamic_backend, **self._kwargs - ) - - _dict = super().__getattribute__("__dict__") - if name in _dict: - return _dict[name] - - elif "_v" in _dict and name in _dict["_v"]: - return _dict["_v"][name] - - return super().__getattribute__(name) - - @tf.autograph.experimental.do_not_convert - def __setattr__(self, name, value): - if name in ["v", "buffers"]: - name = "_" + name - if isinstance(value, (Layer, tf.keras.layers.Layer)): - _dict = getattr(self, "__dict__", None) - if _dict: - _dict[name] = value - - # compute the module dict - self._compute_module_dict() - - obj_to_search = ( - None - if not isinstance(value, (Layer, tf.keras.layers.Layer)) - else ( - self._modules - if hasattr(self, "_modules") and self._modules - else self - ) - ) - found_vars = self._find_variables( - obj=obj_to_search, - without_initialisation=( - True - if self._v_from_constructor and not self._with_partial_v - else False - ), - ) - flattened_v, v_spec = tree_flatten(found_vars) - flattend_kc = v_spec.get_keychains() - for kc, v in zip(flattend_kc, flattened_v): - new_kc = kc.replace("/", ".") - if new_kc not in self.v: - self.register_parameter(new_kc, v) - - # once all variables built, find and assign buffers - found_buffers = self._find_variables( - obj=obj_to_search, - without_initialisation=( - True - if self._v_from_constructor and not self._with_partial_v 
- else False - ), - trainable=False, - ) - flattened_buf, buf_spec = tree_flatten(found_buffers) - flattend_kc = buf_spec.get_keychains() - for kc, buf in zip(flattend_kc, flattened_buf): - new_kc = kc.replace("/", ".") - self.register_buffer(new_kc, buf) - - super().__setattr__(name, value) - return - elif isinstance(value, tf.Variable) and not name.startswith("_"): - _dict = getattr(self, "__dict__", None) - if _dict: - _dict[name] = value - # Manual solution for cases where a `tf.int32` tensor - # is placed on the GPU. TensorFlow doesn't have dedicated - # kernels for placing `tf.int32` variables on the GPU and so - # we manually cast them to `tf.int64` here otherwise due to - # `tf.config.soft_device_placement(True)` by default, - # TensorFlow puts the `tf.int32` variables on CPU which causes - # unintended consequences downstream during tracing or - # `tf.function` compilation e.g. - # Ref: https://github.com/tensorflow/tensorflow/issues/9506 - # Ref: https://stackoverflow.com/questions/44813939/could-not-satisfy-explicit-device-specification-devicegpu0-because-no-devic - dtype = ( - tf.int64 - if value.dtype == tf.int32 and "gpu:" in value.device.lower() - else value.dtype - ) - cast_dtype = dtype != value.dtype - val = ( - value - if not cast_dtype - else tf.Variable(initial_value=tf.cast(value.value(), dtype), name=name) - ) - self.register_parameter(name, val) - super().__setattr__(name, val) - return - else: - try: - obj_to_search = getattr(self, name) - except AttributeError: - obj_to_search = None - if isinstance(obj_to_search, Layer): - # retrieve all hierarchical submodules - assign_dict, kc = get_assignment_dict() - - # Iterate over all submods in assign_dict - # updating their `v` and `buffers` with the - # new value - for key, submod in assign_dict.items(): - # Get the subkey to match - subkey = kc[len(key) :].lstrip(".") - - if hasattr(submod, "v"): - for v_key, v_value in submod.v.items(): - if v_key.startswith(subkey): - submod.register_parameter(v_key, value) - - # Repeat the same process for submod.buffers - if hasattr(submod, "buffers"): - for b_key, b_value in submod.buffers.items(): - if b_key.startswith(subkey): - submod.register_buffer(b_key, value) - - # finally update the module dict - self._module_dict[name] = value - - return super().__setattr__(name, value) - - @tf.autograph.experimental.do_not_convert - def __delattr__(self, name): - if hasattr(self, name): - if isinstance(getattr(self, name), (Layer, tf.keras.layers.Layer)): - super().__delattr__(name) - return - super().__delattr__(name) - - @tf.autograph.experimental.do_not_convert - def __repr__(self): - extra_lines = [] - extra_repr = self._extra_repr() - if extra_repr: - extra_lines = extra_repr.split("\n") - child_lines = [] - for key in self.v.keys(): - if isinstance(getattr(self, key, None), Layer): - mod_str = repr(getattr(self, key)) - mod_str = self._addindent(mod_str, 2) - child_lines.append(f"({key}): {mod_str}") - lines = extra_lines + child_lines - - main_str = f"{self.__class__.__name__}(" - if lines: - # simple one-liner info, which most builtin Modules will use - if len(extra_lines) == 1 and not child_lines: - main_str += extra_lines[0] - else: - main_str += "\n " + "\n ".join(lines) + "\n" - - main_str += ")" - return main_str - - -class Model(tf.keras.Model, ModelHelpers): - _build_mode = None - _with_partial_v = None - _store_vars = True - _built = False - _v = None - _buffers = None - _module_dict = None - _args = None - _kwargs = None - _module_graph = None - _target = None - 
_lazy_traced = False - _training = None - _dynamic_backend = None - _device = None - _dtype = None - _previous_frame_info = None - - def __init__( - self, - /, - *args, - v=None, - buffers=None, - build_mode="on_init", - store_vars=True, - with_partial_v=False, - dynamic_backend=None, - training=True, - dtype=None, - device=None, - module_dict=None, - **kwargs, - ): - super(Model, self).__init__( - trainable=training, - dtype=dtype, - ) - self._build_mode = build_mode - self._with_partial_v = with_partial_v - self._store_vars = store_vars - self._built = False - self._v_from_constructor = v if isinstance(v, dict) or v is None else dict(v) - self._v = v if v is not None else dict() - self._buffers = dict(buffers or {}) - self._module_dict = module_dict if module_dict is not None else dict() - self._args = args - self._kwargs = kwargs - self._module_graph = None - self._target = None - self._lazy_traced = False - self._training = training - self._dynamic_backend = dynamic_backend - self._device = device or "cpu" - self._dtype = dtype or tf.float32 - if build_mode != "on_init": - return - self.build(*args, dynamic_backend=dynamic_backend, **kwargs) - - @tf.autograph.experimental.do_not_convert - def _find_variables( - self, - /, - *, - obj=None, - without_initialisation=False, - _visited=None, - trainable=True, - ): - _visited = _visited or {} - vs = dict() - if id(obj) in _visited: - return vs - _visited[id(obj)] = True - if isinstance(obj, (Layer, Model)) and obj is not self: - fn = "_build_and_return_v" if trainable else "_build_and_return_buffers" - if not obj._built and without_initialisation: - return lambda: getattr(obj, fn)( - *obj._args, dynamic_backend=self._dynamic_backend, **obj._kwargs - ) - - return getattr(obj, fn)( - *obj._args, dynamic_backend=obj._dynamic_backend, **obj._kwargs - ) - elif isinstance(obj, tf.keras.layers.Layer) and obj is not self: - return obj.v if trainable else obj.buffers - - elif isinstance(obj, (list, tuple)): - for i, v in enumerate(obj): - ret = self._find_variables( - obj=v, - without_initialisation=without_initialisation, - _visited=_visited, - trainable=trainable, - ) - if ret: - vs[f"v{str(i)}"] = ret - return vs - elif isinstance(obj, dict): - for k, v in obj.items(): - ret = self._find_variables( - obj=v, - without_initialisation=without_initialisation, - _visited=_visited, - trainable=trainable, - ) - if ret: - vs[k[1:] if k[0] == "_" else k] = ret - return vs - elif not hasattr(obj, "__dict__"): - return vs - for k, v in obj.__dict__.items(): - if ( - v is not None - and k[0:2] != "__" - and not k.startswith( - ( - "_module_dict", - "_self_", - "_args", - "_kwargs", - ) - ) - ): - ret = self._find_variables( - obj=v, - without_initialisation=without_initialisation, - _visited=_visited, - trainable=trainable, - ) - if ret: - vs[k[1:] if k[0] == "_" else k] = ret - return vs - - @tf.autograph.experimental.do_not_convert - def _find_buffers(self): - if hasattr(self, "_module_dict"): - for key, sub_module in self._module_dict.items(): - if len(sub_module._buffers) > 0: - self._buffers[key] = sub_module._buffers - - @tf.autograph.experimental.do_not_convert - def _build_and_return_v(self, *args, **kwargs): - if not self._built: - self.build(*args, **kwargs) - return self.v - - @tf.autograph.experimental.do_not_convert - def _build_and_return_buffers(self, *args, **kwargs): - if not self._built: - self.build(*args, **kwargs) - return self.buffers - - @tf.autograph.experimental.do_not_convert - def _wrap_call_methods( - self, keychain_mappings, /, *, 
key="", obj=None, _visited=None - ): - _visited = _visited or {} - if id(obj) in _visited or not isinstance(key, str): - return - _visited[id(obj)] = True - if isinstance(obj, (Layer, Model)) and obj is not self: - orig_key_chain = key[1:] if key[0] == "_" else key - - obj.__call__ = self._fn_with_var_arg( - obj.__call__, self._extract_v, keychain_mappings, orig_key_chain - ) - return - elif isinstance(obj, (list, tuple)): - for i, val in enumerate(obj): - self._wrap_call_methods( - keychain_mappings, - key=f"{key}/v{str(i)}", - obj=val, - _visited=_visited, - ) - return - elif isinstance(obj, dict): - for k, val in obj.items(): - k = f"{key}/{k}" if key != "" and isinstance(k, str) else k - self._wrap_call_methods( - keychain_mappings, key=k, obj=val, _visited=_visited - ) - return - for k, val in obj.module_dict.items(): - if k.startswith(("__", "_self_")): - continue - k = f"{key}/{k}" if key != "" else k - if val is not None: - self._wrap_call_methods( - keychain_mappings, key=k, obj=val, _visited=_visited - ) - return - - @tf.autograph.experimental.do_not_convert - def _compute_module_dict(self): - self._module_dict = dict() - for key, value in self.__dict__.items(): - if isinstance(value, (Layer, tf.keras.layers.Layer, Model, tf.keras.Model)): - if ( - "stateful" in value.__module__ - or hasattr(value, "_frontend_module") - or not hasattr(value, "_module_dict") - ): - self._module_dict[key] = value - else: - self._module_dict[key] = value._module_dict - - @tf.autograph.experimental.do_not_convert - def _fn_with_var_arg_wrapper( - self, *a, fn, v_fn, keychain_mappings, orig_key_chain, **kw - ): - if "v" in kw: - del kw["v"] - v = v_fn(self.v, keychain_mappings, orig_key_chain) - return fn(*a, **kw, v=v) - - @tf.autograph.experimental.do_not_convert - def _fn_with_var_arg(self, fn, v_fn, /, keychain_mappings, orig_key_chain): - _fn_with_var_arg_wrapper = functools.partial( - self._fn_with_var_arg_wrapper, - fn=fn, - v_fn=v_fn, - keychain_mappings=keychain_mappings, - orig_key_chain=orig_key_chain, - ) - _fn_with_var_arg_wrapper.wrapped = True - return _fn_with_var_arg_wrapper - - @tf.autograph.experimental.do_not_convert - def _call(self, *args, v=None, buffers=None, **kwargs): - if not self._built or not self.built: - if not self._built: - first_arr = self._get_first_array(*args, **kwargs) - self.build( - *args, - **kwargs, - from_call=True, - dtype=first_arr.dtype if first_arr is not None else tf.float32, - ) - - if not self.built: - # Don't use `keras` build method - if os.environ.get("USE_KERAS_BUILD", "False").lower() == "false": - self.inputs = tf.nest.flatten(args) - else: - input_shapes = self._get_input_shapes(*args) - if len(input_shapes) == 0: - input_shapes = tf.TensorShape(None) - elif len(input_shapes) == 1: - input_shapes = input_shapes[0] - - super(Model, self).build(tf.TensorShape(None)) # noqa: UP008 - - # If `v` was provided, replace with the module's v - replace_v = False - if v is not None: - v_orig = self.v - self._v = v - replace_v = True - - # If `buffers` were provided, replace with the module's buffers - replace_buffers = False - if buffers is not None: - buffers_orig = self.buffers - self._buffers = buffers - replace_buffers = True - - if replace_v or replace_buffers: - # Call the forward pass - ret = super(Model, self).__call__(*args, **kwargs) # noqa: UP008 - # Replace v, buffers if needed - self._v = v_orig if replace_v else self._v - self._buffers = buffers_orig if replace_buffers else self._buffers - return ret - elif hasattr(self.__call__, "wrapped"): - 
return self.__call__(*args, **kwargs) - - return super(Model, self).__call__(*args, **kwargs) # noqa: UP008 - - @tf.autograph.experimental.do_not_convert - def build( - self, - *args, - from_call=False, - device=None, - dtype=None, - dynamic_backend=None, - **kwargs, - ): - self._built = True - return - - def _lock_state(self): - pass - - @tf.autograph.experimental.do_not_convert - def register_buffer(self, name: str, value: Union[tf.Tensor, tf.Variable]): - self._buffers.update({name: value}) - return value - - @tf.autograph.experimental.do_not_convert - def register_parameter(self, name: str, value: Union[tf.Tensor, tf.Variable]): - self._v.update({name: value}) - - @tf.autograph.experimental.do_not_convert - def train(self, mode: bool = True): - self._training = mode - for module in self.children(): - if isinstance(module, tf.keras.layers.Layer) and not hasattr( - module, "train" - ): - module.trainable = mode - continue - module.train(mode) - self.trainable = mode - return self - - @tf.autograph.experimental.do_not_convert - def eval(self): - return self.train(mode=False) - - @tf.autograph.experimental.do_not_convert - def call(self, inputs, training=None, mask=None): - raise NotImplementedError( - "When subclassing the `Module` class, you should implement a `call` method." - ) - - def get_build_config(self): - config = super().get_build_config() - config = recursive_serialize(config) - return config - - def build_from_config(self, config): - config = recursive_deserialize(config) - return super().build_from_config(config) - - def get_config(self): - base_config = super().get_config() - config = {} - - # Get the names and values of positional arguments in __init__ - init_signature = inspect.signature(self.__init__) - arg_names = list(init_signature.parameters.keys()) - - # Include the positional arguments in the config - var_positional_arg_encountered = False - var_positional_arg_name = None - offset = 0 - for i, arg in enumerate(self._args): - arg_name = arg_names[min(i, len(arg_names) - 1)] - if var_positional_arg_encountered: - config.update( - { - f"{var_positional_arg_name}_{i - offset}": arg, - } - ) - elif ( - init_signature.parameters[arg_name].kind - == inspect.Parameter.VAR_POSITIONAL - ): - var_positional_arg_encountered = True - var_positional_arg_name = arg_name - offset = i - config.update( - { - f"{var_positional_arg_name}_{0}": arg, - } - ) - else: - config.update( - { - arg_name: arg, - } - ) - - # Include the keywords arguments in the config - kwargs = self._kwargs.copy() - kwargs.pop("devices", None) - config.update(**kwargs) - new_config = {**base_config, **config} - new_config = recursive_serialize(new_config) - return new_config - - @classmethod - def from_config(cls, config): - config = recursive_deserialize(config) - # Get the signature of the __init__ method - init_signature = inspect.signature(cls.__init__) - arg_names = list(init_signature.parameters.keys()) - - # Separate positional and keyword arguments based on the __init__ signature - args = [] - pos_or_kw = OrderedDict() - kwargs = {} - var_positional_args = [] - for arg_name in arg_names: - if ( - arg_name in config - and init_signature.parameters[arg_name].kind - == inspect.Parameter.KEYWORD_ONLY - ): - # Handle keyword arguments - kwargs[arg_name] = config.pop(arg_name) - elif ( - arg_name in config - and init_signature.parameters[arg_name].kind - == inspect.Parameter.POSITIONAL_OR_KEYWORD - ): - # Handle positional or keyword arguments - pos_or_kw[arg_name] = config.pop(arg_name) - elif 
any(re.match(rf"{arg_name}_\d+", key) for key in config.keys()): - # Handle variable positional arguments - var_positional_args.extend( - [ - config.pop(key) - for key in sorted(config.keys()) - if re.match(rf"{arg_name}_\d+", key) - ] - ) - - # Unpack positional arguments and the rest as keyword arguments - config.pop("name", None) - config.pop("trainable", None) - config.pop("dtype", None) - kwargs.update(config) - - # Determine the final args and kwargs - if var_positional_args: - args = list(pos_or_kw.values()) + var_positional_args - else: - kwargs.update(pos_or_kw) - - return cls(*args, **kwargs) - - # Methods to be Optionally Overridden # - # -----------------------------------# - - @tf.autograph.experimental.do_not_convert - def _create_variables(self, *, device=None, dtype=None): - return {} - - @tf.autograph.experimental.do_not_convert - def _build(self, *args, **kwargs) -> bool: - return True - - @tf.autograph.experimental.do_not_convert - def _forward(self, *args, **kwargs): - raise NotImplementedError( - "When subclassing the `Module` class, you should " - "implement a `_forward` method." - ) - - @tf.autograph.experimental.do_not_convert - def _extra_repr(self) -> str: - return "" - - # Properties # - # -----------# - - @property - def device(self): - return self._device - - @property - def dtype(self): - return self._dtype - - @property - def build_mode(self): - return self._build_mode - - @property - def training(self): - return self._training - - @property - def v(self): - return self._v - - @property - def buffers(self): - return self._buffers - - @property - def state_dict(self): - return {**self.v, **self.buffers} - - @property - def module_dict(self): - return self._module_dict - - # Dunder Methods # - # ---------------# - @store_frame_info - @tf.autograph.experimental.do_not_convert - def __call__( - self, - *args, - v=None, - buffers=None, - **kwargs, - ): - # TODO: Temp workaround to avoid `call`` from being transformed by AutoGraph - if not hasattr(self.__class__.call, "autograph_info__"): - setattr(self.__class__.call, "autograph_info__", True) - ret = self._call(*args, v=v, buffers=buffers, **kwargs) - return ret - - @tf.autograph.experimental.do_not_convert - def __getattr__(self, name): - if name == "v": - if not super().__getattribute__("_v") and not getattr( # noqa: E501 - self, "_built", False - ): - return self._build_and_return_v( - *self._args, dynamic_backend=self._dynamic_backend, **self._kwargs - ) - - _dict = super().__getattribute__("__dict__") - if name in _dict: - return _dict[name] - - elif "_v" in _dict and name in _dict["_v"]: - return _dict["_v"][name] - - return super().__getattribute__(name) - - @tf.autograph.experimental.do_not_convert - def __setattr__(self, name, value): - if name in ["v", "buffers"]: - name = "_" + name - if isinstance(value, (Layer, tf.keras.layers.Layer, Model, tf.keras.Model)): - _dict = getattr(self, "__dict__", None) - if _dict: - _dict[name] = value - - # compute the module dict - self._compute_module_dict() - - obj_to_search = ( - None - if not isinstance(value, (tf.keras.layers.Layer, Layer, Model)) - else ( - self._modules - if hasattr(self, "_modules") and self._modules - else self - ) - ) - found_vars = self._find_variables( - obj=obj_to_search, - without_initialisation=( - True - if self._v_from_constructor and not self._with_partial_v - else False - ), - ) - flattened_v, v_spec = tree_flatten(found_vars) - flattend_kc = v_spec.get_keychains() - for kc, v in zip(flattend_kc, flattened_v): - new_kc = 
kc.replace("/", ".") - if new_kc not in self.v: - self.register_parameter(new_kc, v) - - # once all variables built, find and assign buffers - found_buffers = self._find_variables( - obj=obj_to_search, - without_initialisation=( - True - if self._v_from_constructor and not self._with_partial_v - else False - ), - trainable=False, - ) - flattened_buf, buf_spec = tree_flatten(found_buffers) - flattend_kc = buf_spec.get_keychains() - for kc, buf in zip(flattend_kc, flattened_buf): - new_kc = kc.replace("/", ".") - self.register_buffer(new_kc, buf) - - super().__setattr__(name, value) - return - elif isinstance(value, tf.Variable) and not name.startswith("_"): - _dict = getattr(self, "__dict__", None) - if _dict: - _dict[name] = value - - # Manual solution for cases where a `tf.int32` tensor - # is placed on the GPU. TensorFlow doesn't have dedicated - # kernels for placing `tf.int32` variables on the GPU and so - # we manually cast them to `tf.int64` here otherwise due to - # `tf.config.soft_device_placement(True)` by default, - # TensorFlow puts the `tf.int32` variables on CPU which causes - # unintended consequences downstream during tracing or - # `tf.function` compilation e.g. - # Ref: https://github.com/tensorflow/tensorflow/issues/9506 - # Ref: https://stackoverflow.com/questions/44813939/could-not-satisfy-explicit-device-specification-devicegpu0-because-no-devic - dtype = ( - tf.int64 - if value.dtype == tf.int32 and "gpu:" in value.device.lower() - else value.dtype - ) - cast_dtype = dtype != value.dtype - val = ( - value - if not cast_dtype - else tf.Variable(initial_value=tf.cast(value.value(), dtype), name=name) - ) - self.register_parameter(name, val) - super().__setattr__(name, val) - else: - try: - obj_to_search = getattr(self, name) - except AttributeError: - obj_to_search = None - if isinstance(obj_to_search, (Model, Layer)): - # retrieve all hierarchical submodules - assign_dict, kc = get_assignment_dict() - - # Iterate over all submods in assign_dict - # updating their `v` and `buffers` with the - # new value - for key, submod in assign_dict.items(): - # Get the subkey to match - subkey = kc[len(key) :].lstrip(".") - - if hasattr(submod, "v"): - for v_key, v_value in submod.v.items(): - if v_key.startswith(subkey): - submod.register_parameter(v_key, value) - - # Repeat the same process for submod.buffers - if hasattr(submod, "buffers"): - for b_key, b_value in submod.buffers.items(): - if b_key.startswith(subkey): - submod.register_buffer(b_key, value) - - # finally update the module dict - self._module_dict[name] = value - - return super().__setattr__(name, value) - - @tf.autograph.experimental.do_not_convert - def __delattr__(self, name): - if hasattr(self, name): - if isinstance( - getattr(self, name), - (Layer, tf.keras.layers.Layer, Model, tf.keras.Model), - ): - super().__delattr__(name) - return - super().__delattr__(name) - - @tf.autograph.experimental.do_not_convert - def __repr__(self): - extra_lines = [] - extra_repr = self._extra_repr() - if extra_repr: - extra_lines = extra_repr.split("\n") - child_lines = [] - for key in self.v.keys(): - if isinstance(getattr(self, key, None), (Layer, Model)): - mod_str = repr(getattr(self, key)) - mod_str = self._addindent(mod_str, 2) - child_lines.append(f"({key}): {mod_str}") - lines = extra_lines + child_lines - - main_str = f"{self.__class__.__name__}(" - if lines: - # simple one-liner info, which most builtin Modules will use - if len(extra_lines) == 1 and not child_lines: - main_str += extra_lines[0] - else: - main_str 
+= "\n " + "\n ".join(lines) + "\n" - - main_str += ")" - return main_str diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_ConvTranspose2d_output/run_0/tensorflow__stateful_layers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_ConvTranspose2d_output/run_0/tensorflow__stateful_layers.py deleted file mode 100644 index ce061b0e5584..000000000000 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_ConvTranspose2d_output/run_0/tensorflow__stateful_layers.py +++ /dev/null @@ -1,700 +0,0 @@ -from .tensorflow__helpers import tensorflow_handle_transpose_in_input_and_output -from .tensorflow__helpers import tensorflow_handle_array_like_without_promotion -from .tensorflow__stateful import store_frame_info -import tensorflow as tf -import keras -import collections -from itertools import repeat -from numbers import Number -import os -from packaging.version import parse as parse_package - - -def parse(x): - n = 2 - if isinstance(x, collections.abc.Iterable): - return tuple(x) - return tuple(repeat(x, n)) - - -def _reverse_repeat_tuple(t, n): - return tuple(x for x in reversed(t) for _ in range(n)) - - -def _handle_padding_shape(padding, n, mode): - padding = tuple( - [ - (padding[i * 2], padding[i * 2 + 1]) - for i in range(int(len(padding) / 2) - 1, -1, -1) - ] - ) - if mode == "circular": - padding = padding + ((0, 0),) * (n - len(padding)) - else: - padding = ((0, 0),) * (n - len(padding)) + padding - if mode == "circular": - padding = tuple(list(padding)[::-1]) - return padding - - -def _to_tf_padding(pad_width, ndim): - if isinstance(pad_width, Number): - pad_width = [[pad_width] * 2] * ndim - elif len(pad_width) == 2 and isinstance(pad_width[0], Number): - pad_width = [pad_width] * ndim - elif ( - isinstance(pad_width, (list, tuple)) - and isinstance(pad_width[0], (list, tuple)) - and len(pad_width) < ndim - ): - pad_width = pad_width * ndim - return pad_width - - -@tensorflow_handle_array_like_without_promotion -def _pad( - input, - pad_width, - /, - *, - mode="constant", - stat_length=1, - constant_values=0, - end_values=0, - reflect_type="even", - **kwargs, -): - pad_width = _to_tf_padding(pad_width, len(input.shape)) - if not isinstance(constant_values, (tf.Variable, tf.Tensor)): - constant_values = tf.constant(constant_values) - if constant_values.dtype != input.dtype: - constant_values = tf.cast(constant_values, input.dtype) - return tf.pad(input, pad_width, mode=mode, constant_values=constant_values) - - -def torch_pad(input, pad, mode="constant", value=0): - # deal with any negative pad values - if any([pad_value < 0 for pad_value in pad]): - pad = list(pad) - slices = [] - for n in reversed(range(len(pad) // 2)): - i = n * 2 - j = i + 1 - start = None - stop = None - if pad[i] < 0: - start = -pad[i] - pad[i] = 0 - if pad[j] < 0: - stop = pad[j] - pad[j] = 0 - slices.append(slice(start, stop)) - ndim = len(input.shape) - while len(slices) < ndim: - slices.insert(0, slice(None)) - input = input[tuple(slices)] - - value = 0 if value is None else value - mode_dict = { - "constant": "constant", - "reflect": "reflect", - "replicate": "edge", - "circular": "wrap", - } - if mode not in mode_dict: - raise ValueError(f"Unsupported padding mode: {mode}") - pad = _handle_padding_shape(pad, len(input.shape), mode) - order = 0, 2, 3, 1 - pad = tuple(pad[i] for i in order) - return _pad(input, pad, mode=mode_dict[mode], constant_values=value) - - -def resolve_convolution(*args, **kwargs): - depthwise_multiplier = kwargs["groups"] // kwargs["filters"] - if depthwise_multiplier < 
1: - return KerasConv2D(*args, **kwargs) - else: - return KerasDepthwiseConv2D(*args, **kwargs) - - -class KerasDepthwiseConv2D(tf.keras.layers.DepthwiseConv2D): - def __init__(self, *args, **kwargs): - kernel_size = kwargs.pop("kernel_size") - padding = kwargs.pop("padding", 0) - stride = kwargs.pop("strides", (1, 1)) - dilation = kwargs.pop("dilation_rate", (1, 1)) - data_format = kwargs.pop("data_format", "channels_last") - - self.padding_mode = kwargs.pop("padding_mode", "zeros") - self._padding = padding - self._previous_frame_info = None - - kernel_size_ = parse(kernel_size) - stride_ = parse(stride) - padding_ = padding if isinstance(padding, str) else parse(padding) - dilation_ = parse(dilation) - - # Call the original __init__ with the remaining args and kwargs - depth_multiplier = kwargs.pop("groups") // kwargs.pop("filters") - self.depth_multiplier = depth_multiplier - - # pytorch layers attributes - self.in_channels = kwargs.pop("in_channels") - - # ivy.Module attributes - self._v = dict() - self._buffers = dict() - - super().__init__( - *args, - kernel_size=kernel_size_, - strides=stride_, - dilation_rate=dilation_, - padding="valid", - depth_multiplier=depth_multiplier, - data_format=data_format, - **kwargs, - ) - - # Compute self._reversed_padding_repeated_twice - if isinstance(padding_, str): - self._reversed_padding_repeated_twice = [0, 0] * len(self.kernel_size) - if padding == "same": - for d, k, i in zip( - self.dilation_rate, - self.kernel_size, - range(len(self.kernel_size) - 1, -1, -1), - ): - total_padding = d * (k - 1) - left_pad = total_padding // 2 - self._reversed_padding_repeated_twice[2 * i] = left_pad - self._reversed_padding_repeated_twice[2 * i + 1] = ( - total_padding - left_pad - ) - else: - self._reversed_padding_repeated_twice = _reverse_repeat_tuple(padding_, 2) - - depthwise_shape = self.kernel_size + ( - self.in_channels, - self.depth_multiplier, - ) - - # create placeholder weights on initialization - self.weight = tf.experimental.numpy.empty( - depthwise_shape, - dtype=tf.float32, - ) - - if self.use_bias: - self.bias = tf.experimental.numpy.empty( - (self.depth_multiplier * self.in_channels,), - dtype=tf.float32, - ) - else: - self.bias = None - - self.v["weight"] = self.weight - self.v["bias"] = self.bias - - os.environ["DATA_FORMAT"] = "channels_first" - - @property - def v(self): - return self._v - - @property - def buffers(self): - return self._buffers - - def named_parameters(self): - return {k: v for k, v in self.v.items() if v is not None} - - def named_buffers(self): - return {k: v for k, v in self.buffers.items() if v is not None} - - def eval(self): - self.trainable = False - - def get_config(self): - config = super().get_config() - config.update( - { - "in_channels": self.in_channels, - "padding_mode": self.padding_mode, - "kernel_size": self.kernel_size, - "padding": self._padding, - "strides": self.strides, - "dilation_rate": self.dilation_rate, - "data_format": self.data_format, - } - ) - return config - - @classmethod - def from_config(cls, config): - return cls(**config) - - @store_frame_info - def __call__(self, *args, **kwargs): - if not self.built: - res = super().__call__(*args, **kwargs) - # recompute build shapes based on transposed input - order = (0, 2, 3, 1) - input_shape = args[0].shape - new_shape = tuple(input_shape[i] for i in order) - self._build_shapes_dict = {"input_shape": new_shape} - return res - return self.call(args[0]) - - def __repr__(self): - return "KerasDepthWiseConv2D()" - - def __setattr__(self, name, 
value): - if name in ["_v", "_buffers"]: - self.__dict__[name] = value - return - super().__setattr__(name, value) - - def __getattribute__(self, name): - built = object.__getattribute__(self, "__dict__").get("built", False) - - if built: - if parse_package(keras.__version__).major > 2: - attr_map = {"weight": "kernel"} - else: - attr_map = {"weight": "depthwise_kernel"} - else: - attr_map = {"weight": "weight"} - - new_name = attr_map[name] if name in attr_map else name - return super().__getattribute__(new_name) - - def build(self, input_shape): - _, ch, _, _ = input_shape - if ( - not self.built - and self.data_format == "channels_last" - and os.environ.get("DATA_FORMAT", "channels_first") == "channels_first" - ): - order = (0, 2, 3, 1) - new_shape = tuple(input_shape[i] for i in order) - input_shape = tf.TensorShape(new_shape) - - super().build(input_shape) - # modify the channel axis to avoid shape assertion checks by keras - self.input_spec.axes = {1: ch} - return - - @tensorflow_handle_transpose_in_input_and_output - def call(self, input, training=False): - if self._padding != 0: - padding_mode = ( - "constant" if self.padding_mode == "zeros" else self.padding_mode - ) - # handle Pytorch-style padding - input = torch_pad( - input, self._reversed_padding_repeated_twice, mode=padding_mode - ) - - return super().call(input) - - -class KerasConv2D(tf.keras.layers.Conv2D): - def __init__(self, *args, **kwargs): - kernel_size = kwargs.pop("kernel_size") - padding = kwargs.pop("padding", 0) - stride = kwargs.pop("strides", (1, 1)) - dilation = kwargs.pop("dilation_rate", (1, 1)) - data_format = kwargs.pop("data_format", "channels_last") - - self.padding_mode = kwargs.pop("padding_mode", "zeros") - self._padding = padding - self._previous_frame_info = None - - kernel_size_ = parse(kernel_size) - stride_ = parse(stride) - padding_ = padding if isinstance(padding, str) else parse(padding) - dilation_ = parse(dilation) - - # pytorch layers attributes - self.in_channels = kwargs.pop("in_channels") - - # ivy.Module attributes - self._v = dict() - self._buffers = dict() - - # Call the original __init__ with the remaining args and kwargs - super().__init__( - *args, - kernel_size=kernel_size_, - strides=stride_, - dilation_rate=dilation_, - padding="valid", - data_format=data_format, - **kwargs, - ) - - # Compute self._reversed_padding_repeated_twice - if isinstance(padding_, str): - self._reversed_padding_repeated_twice = [0, 0] * len(self.kernel_size) - if padding == "same": - for d, k, i in zip( - self.dilation_rate, - self.kernel_size, - range(len(self.kernel_size) - 1, -1, -1), - ): - total_padding = d * (k - 1) - left_pad = total_padding // 2 - self._reversed_padding_repeated_twice[2 * i] = left_pad - self._reversed_padding_repeated_twice[2 * i + 1] = ( - total_padding - left_pad - ) - else: - self._reversed_padding_repeated_twice = _reverse_repeat_tuple(padding_, 2) - - # create placeholder weights on initialization - self.weight = tf.experimental.numpy.empty( - (*kernel_size_, self.in_channels // kwargs["groups"], self.filters), - dtype=tf.float32, - ) - if self.use_bias: - self.bias = tf.experimental.numpy.empty((self.filters,), dtype=tf.float32) - else: - self.bias = None - - self.v["weight"] = self.weight - self.v["bias"] = self.bias - - os.environ["DATA_FORMAT"] = "channels_first" - - @property - def v(self): - return self._v - - @property - def buffers(self): - return self._buffers - - def named_parameters(self): - return {k: v for k, v in self.v.items() if v is not None} - - def 
named_buffers(self): - return {k: v for k, v in self.buffers.items() if v is not None} - - def eval(self): - self.trainable = False - - def get_config(self): - config = super().get_config() - config.update( - { - "in_channels": self.in_channels, - "padding_mode": self.padding_mode, - "kernel_size": self.kernel_size, - "padding": self._padding, - "strides": self.strides, - "dilation_rate": self.dilation_rate, - "data_format": self.data_format, - } - ) - return config - - @classmethod - def from_config(cls, config): - return cls(**config) - - @store_frame_info - def __call__(self, *args, **kwargs): - if not self.built: - res = super().__call__(*args, **kwargs) - # recompute build shapes based on transposed input - order = (0, 2, 3, 1) - input_shape = args[0].shape - new_shape = tuple(input_shape[i] for i in order) - self._build_shapes_dict = {"input_shape": new_shape} - return res - return self.call(args[0]) - - def __repr__(self): - return "KerasConv2D()" - - def __setattr__(self, name, value): - if name in ["_v", "_buffers"]: - self.__dict__[name] = value - return - super().__setattr__(name, value) - - def __getattribute__(self, name): - built = object.__getattribute__(self, "__dict__").get("built", False) - if built: - attr_map = {"weight": "kernel", "out_channels": "filters"} - else: - attr_map = { - "out_channels": "filters", - } - - new_name = attr_map[name] if name in attr_map else name - return super().__getattribute__(new_name) - - def build(self, input_shape): - _, ch, _, _ = input_shape - if ( - not self.built - and self.data_format == "channels_last" - and os.environ.get("DATA_FORMAT", "channels_first") == "channels_first" - ): - order = (0, 2, 3, 1) - new_shape = tuple(input_shape[i] for i in order) - input_shape = tf.TensorShape(new_shape) - - super().build(input_shape) - # modify the channel axis to avoid shape assertion checks by keras - self.input_spec.axes = {1: ch} - return - - @tensorflow_handle_transpose_in_input_and_output - def call(self, input, training=False): - if self._padding != 0: - padding_mode = ( - "constant" if self.padding_mode == "zeros" else self.padding_mode - ) - # handle Pytorch-style padding - input = torch_pad( - input, self._reversed_padding_repeated_twice, mode=padding_mode - ) - return super().call(input) - - -class KerasDense(tf.keras.layers.Dense): - def __init__(self, *args, **kwargs): - self._previous_frame_info = None - - # pytorch layer attributes - self.in_features = kwargs.pop("in_features") - - # ivy.Module attributes - self._v = dict() - self._buffers = dict() - - super().__init__(*args, **kwargs) - - # create placeholder weights on initialization - self.weight = tf.experimental.numpy.empty( - (self.units, self.in_features), dtype=tf.float32 - ) - if self.use_bias: - self.bias = tf.experimental.numpy.empty((self.units,), dtype=tf.float32) - else: - self.bias = None - - self.v["weight"] = self.weight - self.v["bias"] = self.bias - - @property - def v(self): - return self._v - - @property - def buffers(self): - return self._buffers - - def named_parameters(self): - return {k: v for k, v in self.v.items() if v is not None} - - def named_buffers(self): - return {k: v for k, v in self.buffers.items() if v is not None} - - def eval(self): - self.trainable = False - - def get_config(self): - config = super().get_config() - config.update( - { - "in_features": self.in_features, - } - ) - return config - - @classmethod - def from_config(cls, config): - return cls(**config) - - def __call__(self, *args, **kwargs): - return super().__call__(*args, 
**kwargs) - - def __repr__(self): - return "KerasDense()" - - def __setattr__(self, name, value): - if name in ["_v", "_buffers"]: - self.__dict__[name] = value - return - super().__setattr__(name, value) - - def __getattribute__(self, name): - built = object.__getattribute__(self, "__dict__").get("built", False) - if built: - attr_map = {"weight": "kernel", "out_features": "units"} - else: - attr_map = {"out_features": "units"} - new_name = attr_map[name] if name in attr_map else name - return super().__getattribute__(new_name) - - def build(self, input_shape): - super().build(input_shape) - return - - def call(self, input, training=False): - return super().call(input) - - -class KerasBatchNorm2D(tf.keras.layers.BatchNormalization): - def __init__(self, *args, **kwargs): - self._previous_frame_info = None - - # pytorch layer attributes - self.num_features = kwargs.pop("num_features") - self.track_running_stats = kwargs.pop("track_running_stats") - - # ivy.Module attributes - self._v = dict() - self._buffers = dict() - - super().__init__(*args, **kwargs) - - # create placeholder weights on initialization - if self.scale: - self.weight = tf.experimental.numpy.empty( - (self.num_features,), dtype=tf.float32 - ) - self.bias = tf.experimental.numpy.empty( - (self.num_features,), dtype=tf.float32 - ) - else: - self.weight = None - self.bias = None - - if self.track_running_stats: - self.running_mean = tf.experimental.numpy.zeros( - (self.num_features,), dtype=tf.float32 - ) - self.running_var = tf.experimental.numpy.ones( - (self.num_features,), dtype=tf.float32 - ) - self.num_batches_tracked = tf.constant(0, dtype=tf.int64) - else: - self.running_mean = None - self.running_var = None - self.num_batches_tracked = None - - self.v["weight"] = self.weight - self.v["bias"] = self.bias - self.buffers["running_mean"] = self.running_mean - self.buffers["running_var"] = self.running_var - self.buffers["num_batches_tracked"] = self.num_batches_tracked - - os.environ["DATA_FORMAT"] = "channels_first" - - @property - def v(self): - return self._v - - @property - def buffers(self): - return self._buffers - - def named_parameters(self): - return {k: v for k, v in self.v.items() if v is not None} - - def named_buffers(self): - return {k: v for k, v in self.buffers.items() if v is not None} - - def eval(self): - self.trainable = False - - def get_config(self): - config = super().get_config() - config.update( - { - "num_features": self.num_features, - "track_running_stats": self.track_running_stats, - } - ) - return config - - @classmethod - def from_config(cls, config): - return cls(**config) - - def __repr__(self): - return "KerasBatchNorm2D()" - - def __setattr__(self, name, value): - if name in ["_v", "_buffers"]: - self.__dict__[name] = value - return - super().__setattr__(name, value) - - def __getattribute__(self, name): - built = object.__getattribute__(self, "__dict__").get("built", False) - if built: - attr_map = { - "weight": "gamma", - "bias": "beta", - "running_mean": "moving_mean", - "running_var": "moving_variance", - } - else: - attr_map = {} - new_name = attr_map[name] if name in attr_map else name - return super().__getattribute__(new_name) - - @store_frame_info - def __call__(self, *args, **kwargs): - if not self.built: - res = super().__call__(*args, **kwargs) - # recompute build shapes based on transposed input - order = (0, 2, 3, 1) - input_shape = args[0].shape - new_shape = tuple(input_shape[i] for i in order) - self._build_shapes_dict = {"input_shape": new_shape} - return res - return 
self.call(args[0]) - - def build(self, input_shape): - _, ch, _, _ = input_shape - if ( - not self.built - and self.axis == -1 - and os.environ.get("DATA_FORMAT", "channels_first") == "channels_first" - ): - order = (0, 2, 3, 1) - new_shape = tuple(input_shape[i] for i in order) - input_shape = tf.TensorShape(new_shape) - - super().build(input_shape) - # modify the channel axis to avoid shape assertion checks by keras - self.input_spec.axes = {1: ch} - return - - @tensorflow_handle_transpose_in_input_and_output - def call(self, input, training=False): - return super().call(input, training=training) - - -class KerasReLU(tf.keras.layers.ReLU): - def __init__(self, *args, **kwargs): - self._previous_frame_info = None - super().__init__(*args, **kwargs) - - def __repr__(self): - return "KerasReLU()" - - @store_frame_info - def __call__(self, *args, **kwargs): - return super().__call__(*args, **kwargs) - - @tensorflow_handle_transpose_in_input_and_output - def call(self, input, training=False): - return super().call(input) diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_Dropout2d_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_Dropout2d_output/run_0/tensorflow__helpers.py index a80a84f6e2be..d4f0cdf3c131 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_Dropout2d_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_Dropout2d_output/run_0/tensorflow__helpers.py @@ -26,214 +26,6 @@ import tensorflow as tf -promotion_table = { - ("bool", "bool"): "bool", - ("int8", "int8"): "int8", - ("int8", "int16"): "int16", - ("int8", "int32"): "int32", - ("int8", "int64"): "int64", - ("int16", "int16"): "int16", - ("int16", "int32"): "int32", - ("int16", "int64"): "int64", - ("int32", "int32"): "int32", - ("int32", "int64"): "int64", - ("int64", "int64"): "int64", - ("uint8", "int8"): "int16", - ("uint8", "int16"): "int16", - ("uint8", "int32"): "int32", - ("uint8", "int64"): "int64", - ("uint8", "uint8"): "uint8", - ("uint8", "uint16"): "uint16", - ("uint8", "uint32"): "uint32", - ("uint8", "uint64"): "uint64", - ("uint16", "int8"): "int32", - ("uint16", "int16"): "int32", - ("uint16", "int32"): "int32", - ("uint16", "int64"): "int64", - ("uint16", "uint16"): "uint16", - ("uint16", "uint32"): "uint32", - ("uint16", "uint64"): "uint64", - ("uint32", "int8"): "int64", - ("uint32", "int16"): "int64", - ("uint32", "int32"): "int64", - ("uint32", "int64"): "int64", - ("uint32", "uint32"): "uint32", - ("uint32", "uint64"): "uint64", - ("uint64", "uint64"): "uint64", - ("float16", "float16"): "float16", - ("float16", "float32"): "float32", - ("float16", "float64"): "float64", - ("float32", "float32"): "float32", - ("float32", "float64"): "float64", - ("float64", "float64"): "float64", - ("bool", "int8"): "int8", - ("bool", "int16"): "int16", - ("bool", "int32"): "int32", - ("bool", "int64"): "int64", - ("bool", "uint8"): "uint8", - ("bool", "uint16"): "uint16", - ("bool", "uint32"): "uint32", - ("bool", "uint64"): "uint64", - ("bool", "float16"): "float16", - ("bool", "float32"): "float32", - ("bool", "float64"): "float64", - ("bool", "bfloat16"): "bfloat16", - ("bool", "complex64"): "complex64", - ("bool", "complex128"): "complex128", - ("int8", "float16"): "float16", - ("int8", "float32"): "float32", - ("int8", "float64"): "float64", - ("int8", "bfloat16"): "bfloat16", - ("int8", "complex64"): "complex64", - ("int8", "complex128"): "complex128", - ("int16", "float32"): "float32", - ("int16", "float64"): 
"float64", - ("int16", "complex64"): "complex64", - ("int16", "complex128"): "complex128", - ("int32", "float64"): "float64", - ("int32", "complex128"): "complex128", - ("int64", "float64"): "float64", - ("int64", "complex128"): "complex128", - ("uint8", "float16"): "float16", - ("uint8", "float32"): "float32", - ("uint8", "float64"): "float64", - ("uint8", "bfloat16"): "bfloat16", - ("uint8", "complex64"): "complex64", - ("uint8", "complex128"): "complex128", - ("uint16", "float32"): "float32", - ("uint16", "float64"): "float64", - ("uint16", "complex64"): "complex64", - ("uint16", "complex128"): "complex128", - ("uint32", "float64"): "float64", - ("uint32", "complex128"): "complex128", - ("uint64", "int8"): "float64", - ("uint64", "int16"): "float64", - ("uint64", "int32"): "float64", - ("uint64", "int64"): "float64", - ("uint64", "float64"): "float64", - ("uint64", "complex128"): "complex128", - ("float16", "bfloat16"): "float32", - ("float16", "complex64"): "complex64", - ("float16", "complex128"): "complex128", - ("float32", "complex64"): "complex64", - ("float32", "complex128"): "complex128", - ("float64", "complex64"): "complex128", - ("float64", "complex128"): "complex128", - ("bfloat16", "float16"): "float32", - ("bfloat16", "float32"): "float32", - ("bfloat16", "float64"): "float64", - ("bfloat16", "bfloat16"): "bfloat16", - ("bfloat16", "complex64"): "complex64", - ("bfloat16", "complex128"): "complex128", - ("complex64", "float64"): "complex128", - ("complex64", "complex64"): "complex64", - ("complex64", "complex128"): "complex128", - ("complex128", "complex128"): "complex128", - ("float16", "int16"): "float32", - ("float16", "int32"): "float64", - ("float16", "int64"): "float64", - ("float16", "uint16"): "float32", - ("float16", "uint32"): "float64", - ("float16", "uint64"): "float64", - ("float32", "int32"): "float64", - ("float32", "int64"): "float64", - ("float32", "uint32"): "float64", - ("float32", "uint64"): "float64", - ("bfloat16", "int16"): "float32", - ("bfloat16", "int32"): "float64", - ("bfloat16", "int64"): "float64", - ("bfloat16", "uint16"): "float32", - ("bfloat16", "uint32"): "float64", - ("bfloat16", "uint64"): "float64", - ("complex64", "int32"): "complex128", - ("complex64", "int64"): "complex128", - ("complex64", "uint32"): "complex128", - ("complex64", "uint64"): "complex128", -} -array_api_promotion_table = { - ("bool", "bool"): "bool", - ("int8", "int8"): "int8", - ("int8", "int16"): "int16", - ("int8", "int32"): "int32", - ("int8", "int64"): "int64", - ("int16", "int16"): "int16", - ("int16", "int32"): "int32", - ("int16", "int64"): "int64", - ("int32", "int32"): "int32", - ("int32", "int64"): "int64", - ("int64", "int64"): "int64", - ("uint8", "int8"): "int16", - ("uint8", "int16"): "int16", - ("uint8", "int32"): "int32", - ("uint8", "int64"): "int64", - ("uint8", "uint8"): "uint8", - ("uint8", "uint16"): "uint16", - ("uint8", "uint32"): "uint32", - ("uint8", "uint64"): "uint64", - ("uint16", "int8"): "int32", - ("uint16", "int16"): "int32", - ("uint16", "int32"): "int32", - ("uint16", "int64"): "int64", - ("uint16", "uint16"): "uint16", - ("uint16", "uint32"): "uint32", - ("uint16", "uint64"): "uint64", - ("uint32", "int8"): "int64", - ("uint32", "int16"): "int64", - ("uint32", "int32"): "int64", - ("uint32", "int64"): "int64", - ("uint32", "uint32"): "uint32", - ("uint32", "uint64"): "uint64", - ("uint64", "uint64"): "uint64", - ("float16", "float16"): "float16", - ("float16", "float32"): "float32", - ("float16", "float64"): "float64", - ("float32", 
"float32"): "float32", - ("float32", "float64"): "float64", - ("float64", "float64"): "float64", -} -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} CONV_FUNCS = [ "Conv1d", "Conv2d", @@ -371,6 +163,96 @@ DATA_FORMAT = "PT" +def tensorflow_handle_transpose_in_input_and_output(fn): + from .tensorflow_TransposeType import tensorflow_TransposeType + + original_signature = inspect.signature(fn) + + @functools.wraps(fn) + def transpose_wrapper(self, *args, **kwargs): + global DATA_FORMAT + kwargs_call = { + key: val + for key, val in kwargs.items() + if key not in dict(original_signature.parameters) + } + fn_args_and_kwargs = { + key: val for key, val in kwargs.items() if key not in kwargs_call + } + fn_args_and_kwargs.update(dict(zip(fn.__code__.co_varnames[1:], args))) + conv_block_start = lambda f: any( + substr in f.__qualname__ + for substr in CONV_FUNCS + + NORM_FUNCS + + POOL_FUNCS + + KERAS_CONV_FUNCS + + KERAS_NORM_FUNCS + + KERAS_POOL_FUNCS + ) + next_call_in_seq = tensorflow_get_next_func(self) + name_of_next_call = ( + next_call_in_seq.__class__.__name__ + if hasattr(next_call_in_seq, "__class__") + else "" + ) + conv_block_continued = next_call_in_seq and any( + substr in name_of_next_call for substr in CONV_BLOCK_FNS + ) + if DATA_FORMAT == "PT" and conv_block_start(self.__class__): + input = fn_args_and_kwargs["input"] + if len(input.shape) > 4: + transpose = tensorflow_TransposeType.CONV3D + elif len(input.shape) > 3: + transpose = tensorflow_TransposeType.CONV2D + elif len(input.shape) > 2: + transpose = tensorflow_TransposeType.CONV1D + else: + transpose = tensorflow_TransposeType.NO_TRANSPOSE + fn_args_and_kwargs = tensorflow_set_item_bknd( + fn_args_and_kwargs, + "input", + tensorflow_apply_transpose(input, transpose=transpose, pt_to_tf=True), + ) + DATA_FORMAT = "TF" + os.environ = tensorflow_set_item_bknd( + os.environ, "DATA_FORMAT", "channels_last" + ) + res = fn(self, **fn_args_and_kwargs) + if DATA_FORMAT == "TF" and conv_block_continued or DATA_FORMAT == "PT": + return res + if len(res.shape) > 4: + transpose = tensorflow_TransposeType.CONV3D + elif len(res.shape) > 3: + transpose = tensorflow_TransposeType.CONV2D + elif len(res.shape) > 2: + transpose = tensorflow_TransposeType.CONV1D + else: + transpose = 
tensorflow_TransposeType.NO_TRANSPOSE + res = tensorflow_apply_transpose(res, transpose=transpose, pt_to_tf=False) + DATA_FORMAT = "PT" + os.environ = tensorflow_set_item_bknd( + os.environ, "DATA_FORMAT", "channels_first" + ) + return res + + tensorflow_handle_transpose_in_input_and_output.__signature__ = original_signature + return transpose_wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + def tensorflow_handle_array_like_without_promotion(fn: Callable): @functools.wraps(fn) def _handle_array_like_without_promotion(*args, **kwargs): @@ -412,8 +294,231 @@ def _handle_array_like_without_promotion(*args, **kwargs): ) return fn(*args, **kwargs) - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods_1(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +promotion_table = { + ("bool", "bool"): "bool", + ("int8", "int8"): "int8", + ("int8", "int16"): "int16", + ("int8", "int32"): "int32", + ("int8", "int64"): "int64", + ("int16", "int16"): "int16", + ("int16", "int32"): "int32", + ("int16", "int64"): "int64", + ("int32", "int32"): "int32", + ("int32", "int64"): "int64", + ("int64", "int64"): "int64", + ("uint8", "int8"): "int16", + ("uint8", "int16"): "int16", + ("uint8", "int32"): "int32", + ("uint8", "int64"): "int64", + ("uint8", "uint8"): "uint8", + ("uint8", "uint16"): "uint16", + ("uint8", "uint32"): "uint32", + ("uint8", "uint64"): "uint64", + ("uint16", "int8"): "int32", + ("uint16", "int16"): "int32", + ("uint16", "int32"): "int32", + ("uint16", "int64"): "int64", + ("uint16", "uint16"): "uint16", + ("uint16", "uint32"): "uint32", + ("uint16", "uint64"): "uint64", + ("uint32", "int8"): "int64", + ("uint32", "int16"): "int64", + ("uint32", "int32"): "int64", + ("uint32", "int64"): "int64", + ("uint32", "uint32"): "uint32", + ("uint32", "uint64"): "uint64", + ("uint64", "uint64"): "uint64", + ("float16", "float16"): "float16", + ("float16", "float32"): "float32", + ("float16", "float64"): "float64", + 
("float32", "float32"): "float32", + ("float32", "float64"): "float64", + ("float64", "float64"): "float64", + ("bool", "int8"): "int8", + ("bool", "int16"): "int16", + ("bool", "int32"): "int32", + ("bool", "int64"): "int64", + ("bool", "uint8"): "uint8", + ("bool", "uint16"): "uint16", + ("bool", "uint32"): "uint32", + ("bool", "uint64"): "uint64", + ("bool", "float16"): "float16", + ("bool", "float32"): "float32", + ("bool", "float64"): "float64", + ("bool", "bfloat16"): "bfloat16", + ("bool", "complex64"): "complex64", + ("bool", "complex128"): "complex128", + ("int8", "float16"): "float16", + ("int8", "float32"): "float32", + ("int8", "float64"): "float64", + ("int8", "bfloat16"): "bfloat16", + ("int8", "complex64"): "complex64", + ("int8", "complex128"): "complex128", + ("int16", "float32"): "float32", + ("int16", "float64"): "float64", + ("int16", "complex64"): "complex64", + ("int16", "complex128"): "complex128", + ("int32", "float64"): "float64", + ("int32", "complex128"): "complex128", + ("int64", "float64"): "float64", + ("int64", "complex128"): "complex128", + ("uint8", "float16"): "float16", + ("uint8", "float32"): "float32", + ("uint8", "float64"): "float64", + ("uint8", "bfloat16"): "bfloat16", + ("uint8", "complex64"): "complex64", + ("uint8", "complex128"): "complex128", + ("uint16", "float32"): "float32", + ("uint16", "float64"): "float64", + ("uint16", "complex64"): "complex64", + ("uint16", "complex128"): "complex128", + ("uint32", "float64"): "float64", + ("uint32", "complex128"): "complex128", + ("uint64", "int8"): "float64", + ("uint64", "int16"): "float64", + ("uint64", "int32"): "float64", + ("uint64", "int64"): "float64", + ("uint64", "float64"): "float64", + ("uint64", "complex128"): "complex128", + ("float16", "bfloat16"): "float32", + ("float16", "complex64"): "complex64", + ("float16", "complex128"): "complex128", + ("float32", "complex64"): "complex64", + ("float32", "complex128"): "complex128", + ("float64", "complex64"): "complex128", + ("float64", "complex128"): "complex128", + ("bfloat16", "float16"): "float32", + ("bfloat16", "float32"): "float32", + ("bfloat16", "float64"): "float64", + ("bfloat16", "bfloat16"): "bfloat16", + ("bfloat16", "complex64"): "complex64", + ("bfloat16", "complex128"): "complex128", + ("complex64", "float64"): "complex128", + ("complex64", "complex64"): "complex64", + ("complex64", "complex128"): "complex128", + ("complex128", "complex128"): "complex128", + ("float16", "int16"): "float32", + ("float16", "int32"): "float64", + ("float16", "int64"): "float64", + ("float16", "uint16"): "float32", + ("float16", "uint32"): "float64", + ("float16", "uint64"): "float64", + ("float32", "int32"): "float64", + ("float32", "int64"): "float64", + ("float32", "uint32"): "float64", + ("float32", "uint64"): "float64", + ("bfloat16", "int16"): "float32", + ("bfloat16", "int32"): "float64", + ("bfloat16", "int64"): "float64", + ("bfloat16", "uint16"): "float32", + ("bfloat16", "uint32"): "float64", + ("bfloat16", "uint64"): "float64", + ("complex64", "int32"): "complex128", + ("complex64", "int64"): "complex128", + ("complex64", "uint32"): "complex128", + ("complex64", "uint64"): "complex128", +} + +array_api_promotion_table = { + ("bool", "bool"): "bool", + ("int8", "int8"): "int8", + ("int8", "int16"): "int16", + ("int8", "int32"): "int32", + ("int8", "int64"): "int64", + ("int16", "int16"): "int16", + ("int16", "int32"): "int32", + ("int16", "int64"): "int64", + ("int32", "int32"): "int32", + ("int32", "int64"): "int64", + ("int64", 
"int64"): "int64", + ("uint8", "int8"): "int16", + ("uint8", "int16"): "int16", + ("uint8", "int32"): "int32", + ("uint8", "int64"): "int64", + ("uint8", "uint8"): "uint8", + ("uint8", "uint16"): "uint16", + ("uint8", "uint32"): "uint32", + ("uint8", "uint64"): "uint64", + ("uint16", "int8"): "int32", + ("uint16", "int16"): "int32", + ("uint16", "int32"): "int32", + ("uint16", "int64"): "int64", + ("uint16", "uint16"): "uint16", + ("uint16", "uint32"): "uint32", + ("uint16", "uint64"): "uint64", + ("uint32", "int8"): "int64", + ("uint32", "int16"): "int64", + ("uint32", "int32"): "int64", + ("uint32", "int64"): "int64", + ("uint32", "uint32"): "uint32", + ("uint32", "uint64"): "uint64", + ("uint64", "uint64"): "uint64", + ("float16", "float16"): "float16", + ("float16", "float32"): "float32", + ("float16", "float64"): "float64", + ("float32", "float32"): "float32", + ("float32", "float64"): "float64", + ("float64", "float64"): "float64", +} + +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_store_config_info(fn): @@ -459,25 +564,6 @@ def tensorflow_is_array_bknd(x: Any, /, *, exclusive: bool = False): ) or tensorflow_is_native_array(x, exclusive=exclusive) -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - def tensorflow_exists_bknd(x: Any, /): return x is not None @@ -510,7 +596,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -672,26 +760,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods_1(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods_1 +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -809,6 +879,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -915,27 +988,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = 
tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -1104,6 +1171,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1360,7 +1430,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -1772,7 +1844,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -1940,6 +2014,9 @@ def tensorflow_is_uint_dtype_bknd( return "uint" in tensorflow_as_ivy_dtype_bknd(dtype_in) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -1964,11 +2041,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2202,7 +2277,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2362,11 +2439,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2406,21 +2481,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], @@ -2501,6 +2561,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2557,6 +2620,9 @@ def tensorflow_default_complex_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, @@ -2601,6 +2667,9 @@ def tensorflow_default_dtype_bknd( return 
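Note: both promotion tables above store each unordered dtype pair under a single key, so any lookup helper has to try both key orders; tensorflow_promote_types_bknd (used later by tensorflow_scatter_nd in this patch) relies on these tables. A sketch of the lookup under that assumption — the exact fallback behaviour of the real helper is not shown in this diff:

    def promote_types(type1: str, type2: str) -> str:
        # each unordered pair appears exactly once, so try both orders
        if (type1, type2) in promotion_table:
            return promotion_table[(type1, type2)]
        if (type2, type1) in promotion_table:
            return promotion_table[(type2, type1)]
        raise Exception(f"{type1} and {type2} are not type promotable")

    promote_types("float32", "int8")   # -> "float32", via the ("int8", "float32") entry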
tensorflow_as_ivy_dtype(ret) +default_float_dtype_stack = [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2655,6 +2724,25 @@ def tensorflow_default_float_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -2691,6 +2779,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2713,21 +2805,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -2765,6 +2853,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -2816,20 +2923,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -2888,7 +2981,9 @@ def tensorflow_add( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2996,82 +3091,6 @@ def tensorflow_apply_transpose(input, transpose, pt_to_tf=True): return input -def 
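Note: the magic constants in tensorflow_default_int_dtype_bknd are just the signed-integer bounds: a Python int larger than 2**31 - 1 no longer fits in int32 and forces int64, and one larger than 2**63 - 1 forces uint64. The `x != math.inf` guard keeps float infinities from tripping the comparison:

    import math

    assert 2147483647 == 2**31 - 1              # int32 max: beyond this, int64
    assert 9223372036854775807 == 2**63 - 1     # int64 max: beyond this, uint64

    x = 9223372036854775808                     # 2**63 no longer fits in int64
    assert x > 9223372036854775807 and x != math.inf            # takes the uint64 branch
    assert not (math.inf > 9223372036854775807 and math.inf != math.inf)  # inf never does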
tensorflow_handle_transpose_in_input_and_output(fn): - from .tensorflow_TransposeType import tensorflow_TransposeType - - original_signature = inspect.signature(fn) - - @functools.wraps(fn) - def transpose_wrapper(self, *args, **kwargs): - global DATA_FORMAT - kwargs_call = { - key: val - for key, val in kwargs.items() - if key not in dict(original_signature.parameters) - } - fn_args_and_kwargs = { - key: val for key, val in kwargs.items() if key not in kwargs_call - } - fn_args_and_kwargs.update(dict(zip(fn.__code__.co_varnames[1:], args))) - conv_block_start = lambda f: any( - substr in f.__qualname__ - for substr in CONV_FUNCS - + NORM_FUNCS - + POOL_FUNCS - + KERAS_CONV_FUNCS - + KERAS_NORM_FUNCS - + KERAS_POOL_FUNCS - ) - next_call_in_seq = tensorflow_get_next_func(self) - name_of_next_call = ( - next_call_in_seq.__class__.__name__ - if hasattr(next_call_in_seq, "__class__") - else "" - ) - conv_block_continued = next_call_in_seq and any( - substr in name_of_next_call for substr in CONV_BLOCK_FNS - ) - if DATA_FORMAT == "PT" and conv_block_start(self.__class__): - input = fn_args_and_kwargs["input"] - if len(input.shape) > 4: - transpose = tensorflow_TransposeType.CONV3D - elif len(input.shape) > 3: - transpose = tensorflow_TransposeType.CONV2D - elif len(input.shape) > 2: - transpose = tensorflow_TransposeType.CONV1D - else: - transpose = tensorflow_TransposeType.NO_TRANSPOSE - fn_args_and_kwargs = tensorflow_set_item_bknd( - fn_args_and_kwargs, - "input", - tensorflow_apply_transpose(input, transpose=transpose, pt_to_tf=True), - ) - DATA_FORMAT = "TF" - os.environ = tensorflow_set_item_bknd( - os.environ, "DATA_FORMAT", "channels_last" - ) - res = fn(self, **fn_args_and_kwargs) - if DATA_FORMAT == "TF" and conv_block_continued or DATA_FORMAT == "PT": - return res - if len(res.shape) > 4: - transpose = tensorflow_TransposeType.CONV3D - elif len(res.shape) > 3: - transpose = tensorflow_TransposeType.CONV2D - elif len(res.shape) > 2: - transpose = tensorflow_TransposeType.CONV1D - else: - transpose = tensorflow_TransposeType.NO_TRANSPOSE - res = tensorflow_apply_transpose(res, transpose=transpose, pt_to_tf=False) - DATA_FORMAT = "PT" - os.environ = tensorflow_set_item_bknd( - os.environ, "DATA_FORMAT", "channels_first" - ) - return res - - tensorflow_handle_transpose_in_input_and_output.__signature__ = original_signature - return transpose_wrapper - - def tensorflow_ndim_bknd_(self): return len(tuple(self.shape)) diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_LayerNorm_output/run_0/tensorflow_NestedSequence_bknd.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_LayerNorm_output/run_0/tensorflow_NestedSequence_bknd.py index ac8335fe1e56..9f87b4ae29ef 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_LayerNorm_output/run_0/tensorflow_NestedSequence_bknd.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_LayerNorm_output/run_0/tensorflow_NestedSequence_bknd.py @@ -1,5 +1,5 @@ -from typing import TypeVar from typing import Protocol +from typing import TypeVar _T_co = TypeVar("_T_co", covariant=True) diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_LayerNorm_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_LayerNorm_output/run_0/tensorflow__helpers.py index fa4402c245a2..b8ad4d1c9813 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_LayerNorm_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_LayerNorm_output/run_0/tensorflow__helpers.py @@ 
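Note: tensorflow_handle_transpose_in_input_and_output (removed here and re-added in the MaxPool2d helpers later in this patch) juggles memory layouts: translated PyTorch modules receive channels-first tensors while TF kernels expect channels-last, so the wrapper transposes the input on the way into a conv/norm/pool block, flips DATA_FORMAT to "TF", and transposes back and restores "PT" once the block ends. For the CONV2D case the two permutations are the conventional NCHW/NHWC ones; the tuples below are an assumption about what tensorflow_apply_transpose does for that case, not code from this diff:

    import tensorflow as tf

    def nchw_to_nhwc(x):                   # pt_to_tf=True, CONV2D case
        return tf.transpose(x, (0, 2, 3, 1))

    def nhwc_to_nchw(x):                   # pt_to_tf=False, CONV2D case
        return tf.transpose(x, (0, 3, 1, 2))

    x = tf.zeros((8, 3, 32, 32))           # N, C, H, W
    assert nchw_to_nhwc(x).shape == (8, 32, 32, 3)
    assert nhwc_to_nchw(nchw_to_nhwc(x)).shape == x.shape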
-23,6 +23,118 @@ import tensorflow as tf +def tensorflow_handle_methods_1(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +259,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +301,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: 
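Note: tensorflow_handle_methods_1 dispatches on its first argument: arrays go to the TF implementation, anything else gets a method of the recovered name looked up on the object itself. The name is recovered by stripping the _bknd/_frnt markers, then taking everything after the first underscore minus any trailing numeric suffix. A worked example of that string surgery, using names that appear in this file:

    import re

    def extract_function_name(s):
        match = re.search("_(.+?)(?:_\\d+)?$", s)
        if match:
            return match.group(1)

    name = re.sub("_bknd_|_bknd|_frnt_|_frnt", "", "tensorflow_split_frnt")
    # name == "tensorflow_split"
    extract_function_name(name)                 # -> "split"
    extract_function_name("tensorflow_add_1")   # numeric suffix dropped -> "add"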
"int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_is_native_array(x, /, *, exclusive=False): @@ -334,7 +361,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -447,6 +476,7 @@ def tensorflow_is_complex_dtype_bknd( return "complex" in tensorflow_as_ivy_dtype_bknd(dtype_in) +@tensorflow_handle_array_like_without_promotion def tensorflow_real( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -460,6 +490,7 @@ def tensorflow_real_bknd_(self): return tensorflow_real(self) +@tensorflow_handle_array_like_without_promotion def tensorflow_imag( val: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -485,6 +516,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -611,6 +645,9 @@ def 
nested_fun(x): return "int" in tensorflow_as_ivy_dtype(dtype_in) +default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, @@ -655,6 +692,9 @@ def tensorflow_default_dtype_bknd( return tensorflow_as_ivy_dtype(ret) +default_float_dtype_stack = [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -709,6 +749,25 @@ def tensorflow_default_float_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -745,6 +804,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -767,21 +830,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -819,6 +878,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -870,20 +948,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -950,26 +1014,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return 
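Note: ivy_dtype_dict and native_dtype_dict are exact inverses of each other — one maps tf.DType objects to canonical strings, the other maps the strings back — so string/native dtype conversion reduces to a pair of dictionary hops, which is essentially what tensorflow_as_ivy_dtype and tensorflow_as_native_dtype appear to do after normalising their input. A round-trip check:

    import tensorflow as tf

    # every entry survives the round trip in both directions
    assert all(native_dtype_dict[s] == d for d, s in ivy_dtype_dict.items())
    assert ivy_dtype_dict[tf.float32] == "float32"
    assert native_dtype_dict["bfloat16"] == tf.bfloat16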
match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -1087,6 +1133,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -1193,27 +1242,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -1382,6 +1425,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1638,7 +1684,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2050,7 +2098,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -2135,6 +2185,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2159,11 +2212,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2371,7 +2422,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2531,11 +2584,9 
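Note: the various *_stack globals this patch relocates (default_dtype_stack, default_device_stack, default_int_dtype_stack, …) implement ivy-style scoped defaults: setting a default pushes onto the stack, the default_* helpers consult the top entry, and unsetting pops. The patch only moves the empty declarations next to their consumers; the push/pop discipline sketched below is an assumption about how such stacks are normally driven and is not part of this diff:

    from contextlib import contextmanager

    default_dtype_stack = []

    @contextmanager
    def default_dtype(dtype: str):
        default_dtype_stack.append(dtype)    # analogous to set_default_dtype
        try:
            yield
        finally:
            default_dtype_stack.pop()        # analogous to unset_default_dtype

    with default_dtype("float64"):
        # helpers like tensorflow_default_dtype_bknd would now see "float64"
        assert default_dtype_stack[-1] == "float64"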
@@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2575,21 +2626,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], @@ -2815,7 +2851,9 @@ def tensorflow_equal( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2887,7 +2925,9 @@ def tensorflow_add( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2944,25 +2984,6 @@ def tensorflow_layer_norm_frnt( return tensorflow_layer_norm_bknd(input, axis, scale=weight, offset=bias, eps=eps) -def tensorflow_handle_methods_1(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods_1 def tensorflow_split_frnt(tensor, split_size_or_sections, dim=0): if isinstance(split_size_or_sections, int): diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_Linear_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_Linear_output/run_0/tensorflow__helpers.py index d50687f412e5..3aaa2aa83879 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_Linear_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_Linear_output/run_0/tensorflow__helpers.py @@ -23,6 +23,99 @@ import tensorflow as tf +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if 
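Note: in tensorflow_scatter_nd the updates are cast before the scatter because, when an `out` tensor exists, its dtype is promoted against the updates' dtype and TF's scatter ops require the target and updates dtypes to match exactly. A small illustration — tf.tensor_scatter_nd_update is the standard TF op, not something introduced by this patch:

    import tensorflow as tf

    out = tf.zeros((4,), dtype=tf.float64)
    updates = tf.constant([1, 2], dtype=tf.int32)
    indices = tf.constant([[0], [2]])

    # ("int32", "float64") promotes to "float64" in the table above, and
    # tf.tensor_scatter_nd_update rejects mismatched dtypes, hence the cast:
    updates = tf.cast(updates, out.dtype)
    tf.tensor_scatter_nd_update(out, indices, updates)   # -> [1., 0., 2., 0.]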
tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +240,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +282,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, 
ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_is_native_array(x, /, *, exclusive=False): @@ -334,7 +342,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -447,6 +457,7 @@ def tensorflow_is_complex_dtype_bknd( return "complex" in tensorflow_as_ivy_dtype_bknd(dtype_in) +@tensorflow_handle_array_like_without_promotion def tensorflow_real( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -460,6 +471,7 @@ def tensorflow_real_bknd_(self): return tensorflow_real(self) +@tensorflow_handle_array_like_without_promotion def tensorflow_imag( val: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -485,6 +497,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -611,6 +626,9 @@ def nested_fun(x): return "int" in tensorflow_as_ivy_dtype(dtype_in) +default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, @@ -655,6 +673,9 @@ def tensorflow_default_dtype_bknd( return tensorflow_as_ivy_dtype(ret) +default_float_dtype_stack = [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -709,6 +730,25 @@ def tensorflow_default_float_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -745,6 +785,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: 
Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -767,21 +811,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -819,6 +859,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -870,20 +929,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -950,26 +995,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -1087,6 +1114,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -1193,27 +1223,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) 
is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -1382,6 +1406,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1638,7 +1665,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2050,7 +2079,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -2174,6 +2205,9 @@ def tensorflow_to_scalar_bknd_(self: tensorflow.Tensor): return tensorflow_to_scalar(self) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2198,11 +2232,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2410,7 +2442,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2570,11 +2604,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2614,21 +2646,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_MaxPool2d_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_MaxPool2d_output/run_0/tensorflow__helpers.py index 04e6ced465b8..c7f895779aba 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_MaxPool2d_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_MaxPool2d_output/run_0/tensorflow__helpers.py @@ -27,6 +27,360 @@ import tensorflow as tf 
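Note: tensorflow_where, tensorflow_multiply, tensorflow_add and tensorflow_equal all open with the same reflowed ternary: take x1's dtype if it has one, else x2's, else the configured default, then run tensorflow_asarray on whichever operand is a plain Python scalar or list so both sides enter the TF op with a common dtype. The pattern, reduced to its skeleton (helper names are the ones from this file):

    def _common_dtype(x1, x2):
        # mirrors the ternary used by the binary ops in this file
        if hasattr(x1, "dtype"):
            return x1.dtype
        if hasattr(x2, "dtype"):
            return x2.dtype
        return tensorflow_default_dtype_bknd()

    # e.g. tensorflow_add(some_tensor, 3): the plain 3 is converted with
    # tensorflow_asarray(3, dtype=_common_dtype(some_tensor, 3)) before the op.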
+CONV_FUNCS = [ + "Conv1d", + "Conv2d", + "Conv3d", + "ConvTranspose1d", + "ConvTranspose2d", + "ConvTranspose3d", +] +NORM_FUNCS = [ + "_BatchNorm", + "_InstanceNorm", + "BatchNorm1d", + "BatchNorm2d", + "BatchNorm3d", + "GroupNorm", + "SyncBatchNorm", + "InstanceNorm1d", + "InstanceNorm2d", + "InstanceNorm3d", + "LocalResponseNorm", +] +POOL_FUNCS = [ + "MaxPool1d", + "MaxPool2d", + "MaxPool3d", + "AvgPool1d", + "AvgPool2d", + "AvgPool3d", + "FractionalMaxPool2d", + "LPPool1d", + "LPPool2d", + "AdaptiveMaxPool1d", + "AdaptiveMaxPool2d", + "AdaptiveMaxPool3d", + "AdaptiveAvgPool1d", + "AdaptiveAvgPool2d", + "AdaptiveAvgPool3d", +] +KERAS_CONV_FUNCS = [ + "KerasConv1D", + "KerasConv2D", + "KerasConv3D", + "KerasDepthwiseConv2D", + "KerasConv1DTranspose", + "KerasConv2DTranspose", + "KerasConv3DTranspose", +] +KERAS_NORM_FUNCS = [ + "KerasBatchNorm1D", + "KerasBatchNorm2D", + "KerasBatchNorm3D", + "KerasLayerNormalization", + "KerasGroupNormalization", + "KerasUnitNorm1D", + "KerasUnitNorm2D", + "KerasUnitNorm3D", +] +KERAS_POOL_FUNCS = [ + "KerasAveragePooling1D", + "KerasAveragePooling2D", + "KerasAveragePooling3D", + "KerasMaxPool1D", + "KerasMaxPool2D", + "KerasMaxPool3D", +] +PADDING_FUNCS = [ + "ReflectionPad1d", + "ReflectionPad2d", + "ReplicationPad1d", + "ReplicationPad2d", + "ReplicationPad3d", + "ZeroPad2d", + "ConstantPad1d", + "ConstantPad2d", + "ConstantPad3d", +] +KERAS_PADDING_FUNCS = ["KerasZeroPadding1D", "KerasZeroPadding2D", "KerasZeroPadding3D"] +ACTIVATION_FUNCS = [ + "ELU", + "Hardshrink", + "Hardsigmoid", + "Hardswish", + "Hardtanh", + "LeakyReLU", + "PReLU", + "ReLU", + "ReLU6", + "RReLU", + "SELU", + "CELU", + "GELU", + "Sigmoid", + "Softplus", + "Softshrink", + "Softsign", + "Tanh", + "Tanhshrink", + "Threshold", + "Softmin", + "Softmax", + "Softmax2d", + "LogSoftmax", + "AdaptiveLogSoftmaxWithLoss", +] +KERAS_ACTIVATION_FUNCS = [ + "KerasReLU", + "KerasPReLU", + "KerasLeakyReLU", + "KerasThresholdedReLU", + "KerasELU", + "KerasSoftmax", +] +DROPOUT_FUNCS = [ + "Dropout", + "Dropout2d", + "Dropout3d", + "AlphaDropout", + "FeatureAlphaDropout", +] +KERAS_DROPOUT_FUNCS = ["KerasDropout"] +CONV_BLOCK_FNS = [ + *CONV_FUNCS, + *KERAS_CONV_FUNCS, + *POOL_FUNCS, + *KERAS_POOL_FUNCS, + *PADDING_FUNCS, + *KERAS_PADDING_FUNCS, + *ACTIVATION_FUNCS, + *KERAS_ACTIVATION_FUNCS, + *NORM_FUNCS, + *KERAS_NORM_FUNCS, + *DROPOUT_FUNCS, + *KERAS_DROPOUT_FUNCS, +] +DATA_FORMAT = "PT" + + +def tensorflow_handle_transpose_in_input_and_output(fn): + from .tensorflow_TransposeType import tensorflow_TransposeType + + original_signature = inspect.signature(fn) + + @functools.wraps(fn) + def transpose_wrapper(self, *args, **kwargs): + global DATA_FORMAT + kwargs_call = { + key: val + for key, val in kwargs.items() + if key not in dict(original_signature.parameters) + } + fn_args_and_kwargs = { + key: val for key, val in kwargs.items() if key not in kwargs_call + } + fn_args_and_kwargs.update(dict(zip(fn.__code__.co_varnames[1:], args))) + conv_block_start = lambda f: any( + substr in f.__qualname__ + for substr in CONV_FUNCS + + NORM_FUNCS + + POOL_FUNCS + + KERAS_CONV_FUNCS + + KERAS_NORM_FUNCS + + KERAS_POOL_FUNCS + ) + next_call_in_seq = tensorflow_get_next_func(self) + name_of_next_call = ( + next_call_in_seq.__class__.__name__ + if hasattr(next_call_in_seq, "__class__") + else "" + ) + conv_block_continued = next_call_in_seq and any( + substr in name_of_next_call for substr in CONV_BLOCK_FNS + ) + if DATA_FORMAT == "PT" and conv_block_start(self.__class__): + input = 
fn_args_and_kwargs["input"] + if len(input.shape) > 4: + transpose = tensorflow_TransposeType.CONV3D + elif len(input.shape) > 3: + transpose = tensorflow_TransposeType.CONV2D + elif len(input.shape) > 2: + transpose = tensorflow_TransposeType.CONV1D + else: + transpose = tensorflow_TransposeType.NO_TRANSPOSE + fn_args_and_kwargs = tensorflow_set_item_bknd( + fn_args_and_kwargs, + "input", + tensorflow_apply_transpose(input, transpose=transpose, pt_to_tf=True), + ) + DATA_FORMAT = "TF" + os.environ = tensorflow_set_item_bknd( + os.environ, "DATA_FORMAT", "channels_last" + ) + res = fn(self, **fn_args_and_kwargs) + if DATA_FORMAT == "TF" and conv_block_continued or DATA_FORMAT == "PT": + return res + if len(res.shape) > 4: + transpose = tensorflow_TransposeType.CONV3D + elif len(res.shape) > 3: + transpose = tensorflow_TransposeType.CONV2D + elif len(res.shape) > 2: + transpose = tensorflow_TransposeType.CONV1D + else: + transpose = tensorflow_TransposeType.NO_TRANSPOSE + res = tensorflow_apply_transpose(res, transpose=transpose, pt_to_tf=False) + DATA_FORMAT = "PT" + os.environ = tensorflow_set_item_bknd( + os.environ, "DATA_FORMAT", "channels_first" + ) + return res + + tensorflow_handle_transpose_in_input_and_output.__signature__ = original_signature + return transpose_wrapper + + +def tensorflow__handle_padding_shape_frnt(padding, n, mode): + ag__result_list_0 = [] + for i in range(int(len(padding) / 2) - 1, -1, -1): + res = ( + tensorflow_get_item(padding, i * 2), + tensorflow_get_item(padding, i * 2 + 1), + ) + ag__result_list_0.append(res) + padding = tuple(ag__result_list_0) + if mode == "circular": + padding = padding + ((0, 0),) * (n - len(padding)) + else: + padding = ((0, 0),) * (n - len(padding)) + padding + if mode == "circular": + padding = tuple(list(padding)[::-1]) + return padding + + +def tensorflow__handle_padding_bknd(x, strides, filters, padding): + if isinstance(padding, str) and padding.upper() == "SAME": + if x % strides == 0: + pad = max(filters - strides, 0) + else: + pad = max(filters - x % strides, 0) + else: + pad = 0 + return pad + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = 
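Note: tensorflow__handle_padding_bknd computes, per spatial dimension, the total extra length that "SAME" padding needs: with input extent x, stride s and window k, the output should cover ceil(x / s) positions, which takes max(k - s, 0) extra pixels when s divides x and max(k - x % s, 0) otherwise. A worked check, reduced to the helper's "SAME" branch:

    def same_pad(x, strides, filters):       # the "SAME" branch of the helper above
        if x % strides == 0:
            return max(filters - strides, 0)
        return max(filters - x % strides, 0)

    # length-7 input, 3-wide window, stride 2: ceil(7 / 2) = 4 output positions
    assert same_pad(7, 2, 3) == 2            # padded length 9 -> (9 - 3) // 2 + 1 == 4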
tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods_1(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -151,6 +505,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -176,245 +531,24 @@ ("uint16", "int32"): "int32", ("uint16", "int64"): "int64", ("uint16", "uint16"): "uint16", - ("uint16", "uint32"): "uint32", - ("uint16", "uint64"): "uint64", - ("uint32", "int8"): "int64", - ("uint32", "int16"): "int64", - ("uint32", "int32"): "int64", - ("uint32", "int64"): "int64", - ("uint32", "uint32"): "uint32", - ("uint32", "uint64"): "uint64", - ("uint64", "uint64"): "uint64", - ("float16", "float16"): "float16", - ("float16", "float32"): "float32", - ("float16", "float64"): "float64", - ("float32", "float32"): "float32", - ("float32", "float64"): "float64", - ("float64", "float64"): "float64", -} -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": 
tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -CONV_FUNCS = [ - "Conv1d", - "Conv2d", - "Conv3d", - "ConvTranspose1d", - "ConvTranspose2d", - "ConvTranspose3d", -] -NORM_FUNCS = [ - "_BatchNorm", - "_InstanceNorm", - "BatchNorm1d", - "BatchNorm2d", - "BatchNorm3d", - "GroupNorm", - "SyncBatchNorm", - "InstanceNorm1d", - "InstanceNorm2d", - "InstanceNorm3d", - "LocalResponseNorm", -] -POOL_FUNCS = [ - "MaxPool1d", - "MaxPool2d", - "MaxPool3d", - "AvgPool1d", - "AvgPool2d", - "AvgPool3d", - "FractionalMaxPool2d", - "LPPool1d", - "LPPool2d", - "AdaptiveMaxPool1d", - "AdaptiveMaxPool2d", - "AdaptiveMaxPool3d", - "AdaptiveAvgPool1d", - "AdaptiveAvgPool2d", - "AdaptiveAvgPool3d", -] -KERAS_CONV_FUNCS = [ - "KerasConv1D", - "KerasConv2D", - "KerasConv3D", - "KerasDepthwiseConv2D", - "KerasConv1DTranspose", - "KerasConv2DTranspose", - "KerasConv3DTranspose", -] -KERAS_NORM_FUNCS = [ - "KerasBatchNorm1D", - "KerasBatchNorm2D", - "KerasBatchNorm3D", - "KerasLayerNormalization", - "KerasGroupNormalization", - "KerasUnitNorm1D", - "KerasUnitNorm2D", - "KerasUnitNorm3D", -] -KERAS_POOL_FUNCS = [ - "KerasAveragePooling1D", - "KerasAveragePooling2D", - "KerasAveragePooling3D", - "KerasMaxPool1D", - "KerasMaxPool2D", - "KerasMaxPool3D", -] -PADDING_FUNCS = [ - "ReflectionPad1d", - "ReflectionPad2d", - "ReplicationPad1d", - "ReplicationPad2d", - "ReplicationPad3d", - "ZeroPad2d", - "ConstantPad1d", - "ConstantPad2d", - "ConstantPad3d", -] -KERAS_PADDING_FUNCS = ["KerasZeroPadding1D", "KerasZeroPadding2D", "KerasZeroPadding3D"] -ACTIVATION_FUNCS = [ - "ELU", - "Hardshrink", - "Hardsigmoid", - "Hardswish", - "Hardtanh", - "LeakyReLU", - "PReLU", - "ReLU", - "ReLU6", - "RReLU", - "SELU", - "CELU", - "GELU", - "Sigmoid", - "Softplus", - "Softshrink", - "Softsign", - "Tanh", - "Tanhshrink", - "Threshold", - "Softmin", - "Softmax", - "Softmax2d", - "LogSoftmax", - "AdaptiveLogSoftmaxWithLoss", -] -KERAS_ACTIVATION_FUNCS = [ - "KerasReLU", - "KerasPReLU", - "KerasLeakyReLU", - "KerasThresholdedReLU", - "KerasELU", - "KerasSoftmax", -] -DROPOUT_FUNCS = [ - "Dropout", - "Dropout2d", - "Dropout3d", - "AlphaDropout", - "FeatureAlphaDropout", -] -KERAS_DROPOUT_FUNCS = ["KerasDropout"] -CONV_BLOCK_FNS = [ - *CONV_FUNCS, - *KERAS_CONV_FUNCS, - *POOL_FUNCS, - *KERAS_POOL_FUNCS, - *PADDING_FUNCS, - *KERAS_PADDING_FUNCS, - *ACTIVATION_FUNCS, - *KERAS_ACTIVATION_FUNCS, - *NORM_FUNCS, - *KERAS_NORM_FUNCS, - *DROPOUT_FUNCS, - *KERAS_DROPOUT_FUNCS, -] -DATA_FORMAT = "PT" - - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not 
tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) + ("uint16", "uint32"): "uint32", + ("uint16", "uint64"): "uint64", + ("uint32", "int8"): "int64", + ("uint32", "int16"): "int64", + ("uint32", "int32"): "int64", + ("uint32", "int64"): "int64", + ("uint32", "uint32"): "uint32", + ("uint32", "uint64"): "uint64", + ("uint64", "uint64"): "uint64", + ("float16", "float16"): "float16", + ("float16", "float32"): "float32", + ("float16", "float64"): "float64", + ("float32", "float32"): "float32", + ("float32", "float64"): "float64", + ("float64", "float64"): "float64", +} - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_store_config_info(fn): @@ -460,25 +594,6 @@ def tensorflow_is_array_bknd(x: Any, /, *, exclusive: bool = False): ) or tensorflow_is_native_array(x, exclusive=exclusive) -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - def tensorflow_exists_bknd(x: Any, /): return x is not None @@ -511,7 +626,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -673,26 +790,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods_1(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods_1 +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -810,6 +909,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -916,27 +1018,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else 
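[editor's note] The promotion_table completed above is keyed on ordered pairs of dtype-name strings. Only one ordering per mixed pair is visible in this excerpt, so a consumer presumably has to try both key orders; a sketch under that assumption, with three entries taken verbatim from the table:

    promotion_table = {
        ("bool", "bool"): "bool",
        ("uint16", "int32"): "int32",
        ("float16", "float32"): "float32",
    }

    def promote_types(t1: str, t2: str) -> str:
        # Fall back to the reversed key if only one ordering is stored.
        try:
            return promotion_table[(t1, t2)]
        except KeyError:
            return promotion_table[(t2, t1)]

    assert promote_types("int32", "uint16") == "int32"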
lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -1105,6 +1201,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1361,7 +1460,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -1773,7 +1874,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -1941,6 +2044,9 @@ def tensorflow_is_uint_dtype_bknd( return "uint" in tensorflow_as_ivy_dtype_bknd(dtype_in) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -1965,11 +2071,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2203,7 +2307,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2363,11 +2469,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2407,21 +2511,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], @@ -2502,6 +2591,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2558,6 +2650,9 @@ def tensorflow_default_complex_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, @@ -2602,6 +2697,9 @@ def 
tensorflow_default_dtype_bknd( return tensorflow_as_ivy_dtype(ret) +default_float_dtype_stack = [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2656,6 +2754,25 @@ def tensorflow_default_float_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -2692,6 +2809,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2714,21 +2835,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -2766,6 +2883,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -2817,20 +2953,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -2889,7 +3011,9 @@ def tensorflow_add( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -3053,17 +3177,6 @@ def tensorflow__determine_depth_max_pooling( return x, 
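[editor's note] tensorflow_default_int_dtype_bknd, shown above, widens plain Python integers by magnitude: anything above 2**31 - 1 (2147483647) needs int64, and anything above 2**63 - 1 (9223372036854775807) only fits uint64. A standalone sketch of just that threshold logic, assuming an empty default_int_dtype_stack so the base case is int32:

    import math

    def default_int_dtype_for(values):
        # Thresholds copied from tensorflow_default_int_dtype_bknd;
        # math.inf is excluded exactly as in the original lambdas.
        if any(v > 9223372036854775807 and v != math.inf for v in values):
            return "uint64"
        if any(v > 2147483647 and v != math.inf for v in values):
            return "int64"
        return "int32"

    default_int_dtype_for([1, 2, 3])    # 'int32'
    default_int_dtype_for([2 ** 40])    # 'int64'
    default_int_dtype_for([2 ** 63])    # 'uint64'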
kernel, strides, depth_pooling -def tensorflow__handle_padding_bknd(x, strides, filters, padding): - if isinstance(padding, str) and padding.upper() == "SAME": - if x % strides == 0: - pad = max(filters - strides, 0) - else: - pad = max(filters - x % strides, 0) - else: - pad = 0 - return pad - - def tensorflow__output_ceil_shape_bknd(w, f, p, s): return math.ceil((w - f + p) / s) + 1 @@ -3254,24 +3367,6 @@ def tensorflow_reshape_frnt_(tensor, *args, shape=None): raise ValueError("reshape() got no values for argument 'shape'") -def tensorflow__handle_padding_shape_frnt(padding, n, mode): - ag__result_list_0 = [] - for i in range(int(len(padding) / 2) - 1, -1, -1): - res = ( - tensorflow_get_item(padding, i * 2), - tensorflow_get_item(padding, i * 2 + 1), - ) - ag__result_list_0.append(res) - padding = tuple(ag__result_list_0) - if mode == "circular": - padding = padding + ((0, 0),) * (n - len(padding)) - else: - padding = ((0, 0),) * (n - len(padding)) + padding - if mode == "circular": - padding = tuple(list(padding)[::-1]) - return padding - - def tensorflow__to_tf_padding_bknd(pad_width, ndim): if isinstance(pad_width, Number): pad_width = [[pad_width] * 2] * ndim @@ -3772,79 +3867,3 @@ def tensorflow_apply_transpose(input, transpose, pt_to_tf=True): axes = (0, 2, 3, 4, 1) if pt_to_tf else (0, 4, 1, 2, 3) input = tensorflow_permute_dims(input, axes=axes) return input - - -def tensorflow_handle_transpose_in_input_and_output(fn): - from .tensorflow_TransposeType import tensorflow_TransposeType - - original_signature = inspect.signature(fn) - - @functools.wraps(fn) - def transpose_wrapper(self, *args, **kwargs): - global DATA_FORMAT - kwargs_call = { - key: val - for key, val in kwargs.items() - if key not in dict(original_signature.parameters) - } - fn_args_and_kwargs = { - key: val for key, val in kwargs.items() if key not in kwargs_call - } - fn_args_and_kwargs.update(dict(zip(fn.__code__.co_varnames[1:], args))) - conv_block_start = lambda f: any( - substr in f.__qualname__ - for substr in CONV_FUNCS - + NORM_FUNCS - + POOL_FUNCS - + KERAS_CONV_FUNCS - + KERAS_NORM_FUNCS - + KERAS_POOL_FUNCS - ) - next_call_in_seq = tensorflow_get_next_func(self) - name_of_next_call = ( - next_call_in_seq.__class__.__name__ - if hasattr(next_call_in_seq, "__class__") - else "" - ) - conv_block_continued = next_call_in_seq and any( - substr in name_of_next_call for substr in CONV_BLOCK_FNS - ) - if DATA_FORMAT == "PT" and conv_block_start(self.__class__): - input = fn_args_and_kwargs["input"] - if len(input.shape) > 4: - transpose = tensorflow_TransposeType.CONV3D - elif len(input.shape) > 3: - transpose = tensorflow_TransposeType.CONV2D - elif len(input.shape) > 2: - transpose = tensorflow_TransposeType.CONV1D - else: - transpose = tensorflow_TransposeType.NO_TRANSPOSE - fn_args_and_kwargs = tensorflow_set_item_bknd( - fn_args_and_kwargs, - "input", - tensorflow_apply_transpose(input, transpose=transpose, pt_to_tf=True), - ) - DATA_FORMAT = "TF" - os.environ = tensorflow_set_item_bknd( - os.environ, "DATA_FORMAT", "channels_last" - ) - res = fn(self, **fn_args_and_kwargs) - if DATA_FORMAT == "TF" and conv_block_continued or DATA_FORMAT == "PT": - return res - if len(res.shape) > 4: - transpose = tensorflow_TransposeType.CONV3D - elif len(res.shape) > 3: - transpose = tensorflow_TransposeType.CONV2D - elif len(res.shape) > 2: - transpose = tensorflow_TransposeType.CONV1D - else: - transpose = tensorflow_TransposeType.NO_TRANSPOSE - res = tensorflow_apply_transpose(res, transpose=transpose, 
pt_to_tf=False) - DATA_FORMAT = "PT" - os.environ = tensorflow_set_item_bknd( - os.environ, "DATA_FORMAT", "channels_first" - ) - return res - - tensorflow_handle_transpose_in_input_and_output.__signature__ = original_signature - return transpose_wrapper diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_ModuleDict_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_ModuleDict_output/run_0/tensorflow__helpers.py index f6b0374e0786..ed9e67ed9d29 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_ModuleDict_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_ModuleDict_output/run_0/tensorflow__helpers.py @@ -23,6 +23,118 @@ import tensorflow as tf +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods_1(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = 
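[editor's note] The deleted tensorflow_handle_transpose_in_input_and_output above flips tensors between PyTorch-style channels-first and TF-style channels-last layouts around each conv block, tracking the current layout in the DATA_FORMAT global. Only the rank-5 permutation is visible in this hunk ((0, 2, 3, 4, 1) and its inverse), so the rank-4 axes below are inferred by analogy and should be treated as an assumption:

    import tensorflow as tf

    def apply_transpose_2d(x, pt_to_tf=True):
        # NCHW -> NHWC when entering TF ops, NHWC -> NCHW on the way out.
        axes = (0, 2, 3, 1) if pt_to_tf else (0, 3, 1, 2)
        return tf.transpose(x, perm=axes)

    x = tf.zeros((8, 3, 32, 32))     # PT-style NCHW input
    apply_transpose_2d(x).shape      # TensorShape([8, 32, 32, 3]), NHWC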
extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +259,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +301,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_is_native_array(x, /, *, exclusive=False): @@ -302,25 +329,6 @@ def tensorflow_is_array_bknd(x: Any, /, *, exclusive: bool = False): ) or tensorflow_is_native_array(x, exclusive=exclusive) -def tensorflow_handle_methods(fn): 
- def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - def tensorflow_exists_bknd(x: Any, /): return x is not None @@ -353,7 +361,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -515,26 +525,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods_1(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods_1 +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -652,6 +644,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -758,27 +753,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -947,6 +936,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1203,7 +1195,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -1615,7 +1609,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ 
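[editor's note] tensorflow_handle_methods (and tensorflow_handle_methods_1, a verbatim duplicate that this diff only relocates) dispatches on the first argument: arrays go to the functional implementation, anything else is forwarded to the object's own method of the same name. After the "_bknd"/"_frnt" suffixes are stripped, the regex keeps everything past the first underscore, so "tensorflow_split_bknd" resolves to the method name "split". A minimal sketch, with a list-based stand-in for tensorflow_is_array_bknd:

    import functools
    import re

    def handle_methods(fn, is_array=lambda x: isinstance(x, list)):
        def extract_function_name(s):
            # "tensorflow_split" -> "split"; a trailing "_<digits>" is dropped.
            match = re.search(r"_(.+?)(?:_\d+)?$", s)
            if match:
                return match.group(1)

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            if is_array(args[0]):
                return fn(*args, **kwargs)
            name = extract_function_name(
                re.sub("_bknd_|_bknd|_frnt_|_frnt", "", fn.__name__)
            )
            return getattr(args[0], name)(*args[1:], **kwargs)

        return wrapper

    @handle_methods
    def tensorflow_split_bknd(x, sep):
        return [part for part in x if part != sep]

    tensorflow_split_bknd([1, 0, 2], 0)    # array path -> [1, 2]
    tensorflow_split_bknd("a b", " ")      # method path -> str.split -> ['a', 'b']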
-1783,6 +1779,9 @@ def tensorflow_is_uint_dtype_bknd( return "uint" in tensorflow_as_ivy_dtype_bknd(dtype_in) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -1807,11 +1806,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2045,7 +2042,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2205,11 +2204,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2249,21 +2246,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], @@ -2344,6 +2326,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2400,6 +2385,9 @@ def tensorflow_default_complex_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, @@ -2444,6 +2432,9 @@ def tensorflow_default_dtype_bknd( return tensorflow_as_ivy_dtype(ret) +default_float_dtype_stack = [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2498,6 +2489,25 @@ def tensorflow_default_float_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -2534,6 +2544,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2556,21 +2570,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if 
tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -2608,6 +2618,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -2659,20 +2688,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -2731,7 +2746,9 @@ def tensorflow_add( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_ModuleList_output/run_0/tensorflow_ModuleList.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_ModuleList_output/run_0/tensorflow_ModuleList.py index 221e8ae19387..dd4cac3df3b9 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_ModuleList_output/run_0/tensorflow_ModuleList.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_ModuleList_output/run_0/tensorflow_ModuleList.py @@ -4,8 +4,8 @@ import typing import operator -from collections import abc as container_abcs from itertools import chain +from collections import abc as container_abcs from .tensorflow__stateful import Model as tensorflow_keras_Model from .tensorflow__helpers import tensorflow__addindent diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_ModuleList_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_ModuleList_output/run_0/tensorflow__helpers.py index 597da46b91f7..10d301dfbf21 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_ModuleList_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_ModuleList_output/run_0/tensorflow__helpers.py @@ -23,6 +23,118 @@ import tensorflow as tf +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + 
except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods_1(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +259,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +301,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: 
"int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_is_native_array(x, /, *, exclusive=False): @@ -302,25 +329,6 @@ def tensorflow_is_array_bknd(x: Any, /, *, exclusive: bool = False): ) or tensorflow_is_native_array(x, exclusive=exclusive) -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - def tensorflow_exists_bknd(x: Any, /): return x is not None @@ -353,7 +361,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -515,26 +525,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods_1(fn): - def extract_function_name(s): - match = 
re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods_1 +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -652,6 +644,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -758,27 +753,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -947,6 +936,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1203,7 +1195,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -1615,7 +1609,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -1783,6 +1779,9 @@ def tensorflow_is_uint_dtype_bknd( return "uint" in tensorflow_as_ivy_dtype_bknd(dtype_in) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -1807,11 +1806,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2045,7 +2042,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not 
tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2205,11 +2204,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2249,21 +2246,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], @@ -2344,6 +2326,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2400,6 +2385,9 @@ def tensorflow_default_complex_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, @@ -2444,6 +2432,9 @@ def tensorflow_default_dtype_bknd( return tensorflow_as_ivy_dtype(ret) +default_float_dtype_stack = [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2498,6 +2489,25 @@ def tensorflow_default_float_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -2534,6 +2544,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2556,21 +2570,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -2608,6 +2618,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, 
+ "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -2659,20 +2688,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -2742,7 +2757,9 @@ def tensorflow_add( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_Sequential_output/run_0/tensorflow_NestedSequence_bknd.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_Sequential_output/run_0/tensorflow_NestedSequence_bknd.py index ac8335fe1e56..9f87b4ae29ef 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_Sequential_output/run_0/tensorflow_NestedSequence_bknd.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_Sequential_output/run_0/tensorflow_NestedSequence_bknd.py @@ -1,5 +1,5 @@ -from typing import TypeVar from typing import Protocol +from typing import TypeVar _T_co = TypeVar("_T_co", covariant=True) diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_Sequential_output/run_0/tensorflow_Sequential.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_Sequential_output/run_0/tensorflow_Sequential.py index bef14edf7dc8..3c7a5be74858 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_Sequential_output/run_0/tensorflow_Sequential.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_Sequential_output/run_0/tensorflow_Sequential.py @@ -2,8 +2,8 @@ from collections import OrderedDict import threading -import typing import operator +import typing from typing import overload from itertools import islice diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_Sequential_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_Sequential_output/run_0/tensorflow__helpers.py index f6b0374e0786..ed9e67ed9d29 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_Sequential_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_Sequential_output/run_0/tensorflow__helpers.py @@ -23,6 +23,118 @@ import tensorflow as tf +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = 
inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods_1(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +259,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +301,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: 
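[editor's note] tensorflow_handle_array_like_without_promotion, repeated here for the Sequential output, converts positional arguments whose annotations mention an array/Tensor type before the wrapped function runs. One caveat worth flagging: because zip() is bounded by len(args), the "if i < num_args:" test is always true, so the "elif parameters in kwargs:" branch appears unreachable as written, and it tests the whole parameters list rather than the single parameter (which would raise TypeError if it ever ran). A simplified, self-contained sketch of the reachable path, with tf.convert_to_tensor standing in for tensorflow_asarray and device selection omitted:

    import functools
    import inspect
    import tensorflow as tf

    def handle_array_like(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            args = list(args)
            params = inspect.signature(fn).parameters
            for i, (name, param) in enumerate(params.items()):
                if i >= len(args):
                    break
                # Convert only annotated, not-yet-tensor positional args.
                if "Tensor" in str(param.annotation) and not tf.is_tensor(args[i]):
                    args[i] = tf.convert_to_tensor(args[i])
            return fn(*args, **kwargs)
        return wrapper

    @handle_array_like
    def double(x: tf.Tensor):
        return x * 2

    double([1, 2, 3])    # the list is converted before fn executes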
"complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_is_native_array(x, /, *, exclusive=False): @@ -302,25 +329,6 @@ def tensorflow_is_array_bknd(x: Any, /, *, exclusive: bool = False): ) or tensorflow_is_native_array(x, exclusive=exclusive) -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - def tensorflow_exists_bknd(x: Any, /): return x is not None @@ -353,7 +361,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -515,26 +525,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods_1(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], 
fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods_1 +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -652,6 +644,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -758,27 +753,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -947,6 +936,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1203,7 +1195,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -1615,7 +1609,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -1783,6 +1779,9 @@ def tensorflow_is_uint_dtype_bknd( return "uint" in tensorflow_as_ivy_dtype_bknd(dtype_in) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -1807,11 +1806,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2045,7 +2042,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2205,11 +2204,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) 
+ if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2249,21 +2246,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], @@ -2344,6 +2326,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2400,6 +2385,9 @@ def tensorflow_default_complex_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, @@ -2444,6 +2432,9 @@ def tensorflow_default_dtype_bknd( return tensorflow_as_ivy_dtype(ret) +default_float_dtype_stack = [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2498,6 +2489,25 @@ def tensorflow_default_float_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -2534,6 +2544,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2556,21 +2570,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -2608,6 +2618,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": 
tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -2659,20 +2688,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -2731,7 +2746,9 @@ def tensorflow_add( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_abs_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_abs_output/run_0/tensorflow__helpers.py index d50687f412e5..3aaa2aa83879 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_abs_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_abs_output/run_0/tensorflow__helpers.py @@ -23,6 +23,99 @@ import tensorflow as tf +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], 
fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +240,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +282,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_is_native_array(x, /, *, exclusive=False): @@ -334,7 
+342,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -447,6 +457,7 @@ def tensorflow_is_complex_dtype_bknd( return "complex" in tensorflow_as_ivy_dtype_bknd(dtype_in) +@tensorflow_handle_array_like_without_promotion def tensorflow_real( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -460,6 +471,7 @@ def tensorflow_real_bknd_(self): return tensorflow_real(self) +@tensorflow_handle_array_like_without_promotion def tensorflow_imag( val: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -485,6 +497,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -611,6 +626,9 @@ def nested_fun(x): return "int" in tensorflow_as_ivy_dtype(dtype_in) +default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, @@ -655,6 +673,9 @@ def tensorflow_default_dtype_bknd( return tensorflow_as_ivy_dtype(ret) +default_float_dtype_stack = [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -709,6 +730,25 @@ def tensorflow_default_float_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -745,6 +785,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -767,21 +811,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -819,6 +859,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": 
tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -870,20 +929,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -950,26 +995,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -1087,6 +1114,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -1193,27 +1223,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -1382,6 +1406,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1638,7 +1665,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2050,7 +2079,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -2174,6 +2205,9 @@ def 
tensorflow_to_scalar_bknd_(self: tensorflow.Tensor): return tensorflow_to_scalar(self) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2198,11 +2232,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2410,7 +2442,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2570,11 +2604,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2614,21 +2646,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_add_output/run_0/tensorflow_NestedSequence_bknd.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_add_output/run_0/tensorflow_NestedSequence_bknd.py index ac8335fe1e56..9f87b4ae29ef 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_add_output/run_0/tensorflow_NestedSequence_bknd.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_add_output/run_0/tensorflow_NestedSequence_bknd.py @@ -1,5 +1,5 @@ -from typing import TypeVar from typing import Protocol +from typing import TypeVar _T_co = TypeVar("_T_co", covariant=True) diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_add_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_add_output/run_0/tensorflow__helpers.py index 1b64cf5d5694..bde7b8c8d8d0 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_add_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_add_output/run_0/tensorflow__helpers.py @@ -23,6 +23,99 @@ import tensorflow as tf +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in 
annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +240,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +282,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] -default_complex_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -default_dtype_stack = [] -default_float_dtype_stack = [] - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args 
= list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_exists_bknd(x: Any, /): @@ -310,7 +318,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -531,20 +541,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -611,26 +607,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -748,6 +726,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -854,27 +835,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if 
include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -1043,6 +1018,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1299,7 +1277,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -1711,7 +1691,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -1877,6 +1859,9 @@ def tensorflow_is_uint_dtype_bknd( return "uint" in tensorflow_as_ivy_dtype_bknd(dtype_in) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -1901,11 +1886,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2139,7 +2122,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2299,11 +2284,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2343,21 +2326,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], @@ -2438,6 +2406,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2494,6 +2465,25 @@ def tensorflow_default_complex_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + 
tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -2530,6 +2520,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2552,21 +2546,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -2604,6 +2594,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -2627,6 +2636,10 @@ def tensorflow_as_native_dtype( ) +default_dtype_stack = [] +default_float_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_add_output/run_0/tensorflow_add.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_add_output/run_0/tensorflow_add.py index 9439d9a57019..3e5f36e70fa9 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_add_output/run_0/tensorflow_add.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_add_output/run_0/tensorflow_add.py @@ -1,7 +1,7 @@ import tensorflow -from typing import Optional from typing import Union +from typing import Optional from .tensorflow__helpers import tensorflow_asarray from .tensorflow__helpers import tensorflow_default_dtype_bknd @@ -23,7 +23,9 @@ def tensorflow_add( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_all_output/run_0/tensorflow__helpers.py 
b/ivy/compiler/_cache/Translated_Outputs/tensorflow_all_output/run_0/tensorflow__helpers.py index d50687f412e5..3aaa2aa83879 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_all_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_all_output/run_0/tensorflow__helpers.py @@ -23,6 +23,99 @@ import tensorflow as tf +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +240,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +282,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - 
tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_is_native_array(x, /, *, exclusive=False): @@ -334,7 +342,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -447,6 +457,7 @@ def tensorflow_is_complex_dtype_bknd( return "complex" in tensorflow_as_ivy_dtype_bknd(dtype_in) +@tensorflow_handle_array_like_without_promotion def tensorflow_real( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -460,6 +471,7 @@ def tensorflow_real_bknd_(self): return tensorflow_real(self) +@tensorflow_handle_array_like_without_promotion def tensorflow_imag( val: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -485,6 +497,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -611,6 +626,9 @@ def nested_fun(x): return "int" in tensorflow_as_ivy_dtype(dtype_in) +default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, 
str]] = None, @@ -655,6 +673,9 @@ def tensorflow_default_dtype_bknd( return tensorflow_as_ivy_dtype(ret) +default_float_dtype_stack = [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -709,6 +730,25 @@ def tensorflow_default_float_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -745,6 +785,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -767,21 +811,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -819,6 +859,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -870,20 +929,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -950,26 +995,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = 
"_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -1087,6 +1114,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -1193,27 +1223,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -1382,6 +1406,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1638,7 +1665,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2050,7 +2079,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -2174,6 +2205,9 @@ def tensorflow_to_scalar_bknd_(self: tensorflow.Tensor): return tensorflow_to_scalar(self) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2198,11 +2232,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2410,7 +2442,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2570,11 +2604,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - 
tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2614,21 +2646,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_all_output/run_0/tensorflow_all.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_all_output/run_0/tensorflow_all.py index d1b70765780c..12337bcdcaf2 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_all_output/run_0/tensorflow_all.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_all_output/run_0/tensorflow_all.py @@ -1,8 +1,8 @@ import tensorflow from typing import Optional -from typing import Union from typing import Sequence +from typing import Union from .tensorflow__helpers import tensorflow_handle_array_like_without_promotion diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_any_output/run_0/tensorflow_NestedSequence_bknd.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_any_output/run_0/tensorflow_NestedSequence_bknd.py index ac8335fe1e56..9f87b4ae29ef 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_any_output/run_0/tensorflow_NestedSequence_bknd.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_any_output/run_0/tensorflow_NestedSequence_bknd.py @@ -1,5 +1,5 @@ -from typing import TypeVar from typing import Protocol +from typing import TypeVar _T_co = TypeVar("_T_co", covariant=True) diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_any_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_any_output/run_0/tensorflow__helpers.py index d50687f412e5..3aaa2aa83879 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_any_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_any_output/run_0/tensorflow__helpers.py @@ -23,6 +23,99 @@ import tensorflow as tf +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not 
tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +240,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +282,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - 
annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_is_native_array(x, /, *, exclusive=False): @@ -334,7 +342,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -447,6 +457,7 @@ def tensorflow_is_complex_dtype_bknd( return "complex" in tensorflow_as_ivy_dtype_bknd(dtype_in) +@tensorflow_handle_array_like_without_promotion def tensorflow_real( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -460,6 +471,7 @@ def tensorflow_real_bknd_(self): return tensorflow_real(self) +@tensorflow_handle_array_like_without_promotion def tensorflow_imag( val: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -485,6 +497,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -611,6 +626,9 @@ def nested_fun(x): return "int" in tensorflow_as_ivy_dtype(dtype_in) +default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, @@ -655,6 +673,9 @@ def tensorflow_default_dtype_bknd( return tensorflow_as_ivy_dtype(ret) +default_float_dtype_stack = [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -709,6 +730,25 @@ def tensorflow_default_float_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -745,6 +785,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -767,21 +811,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != 
math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -819,6 +859,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -870,20 +929,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -950,26 +995,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -1087,6 +1114,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -1193,27 +1223,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not 
isinstance(x, to_ignore): ret_list = [ @@ -1382,6 +1406,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1638,7 +1665,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2050,7 +2079,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -2174,6 +2205,9 @@ def tensorflow_to_scalar_bknd_(self: tensorflow.Tensor): return tensorflow_to_scalar(self) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2198,11 +2232,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2410,7 +2442,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2570,11 +2604,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2614,21 +2646,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_any_output/run_0/tensorflow_any.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_any_output/run_0/tensorflow_any.py index 6bbf37e24b80..8fe92486b3d4 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_any_output/run_0/tensorflow_any.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_any_output/run_0/tensorflow_any.py @@ -1,7 +1,7 @@ import tensorflow -from typing import Sequence from typing import Union +from typing import Sequence from typing import Optional from .tensorflow__helpers import tensorflow_handle_array_like_without_promotion diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_arange_output/run_0/tensorflow__helpers.py 
b/ivy/compiler/_cache/Translated_Outputs/tensorflow_arange_output/run_0/tensorflow__helpers.py index d50687f412e5..3aaa2aa83879 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_arange_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_arange_output/run_0/tensorflow__helpers.py @@ -23,6 +23,99 @@ import tensorflow as tf +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameter in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +240,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +282,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: 
"uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_is_native_array(x, /, *, exclusive=False): @@ -334,7 +342,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -447,6 +457,7 @@ def tensorflow_is_complex_dtype_bknd( return "complex" in tensorflow_as_ivy_dtype_bknd(dtype_in) +@tensorflow_handle_array_like_without_promotion def tensorflow_real( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -460,6 +471,7 @@ def tensorflow_real_bknd_(self): return tensorflow_real(self) +@tensorflow_handle_array_like_without_promotion def tensorflow_imag( val: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -485,6 +497,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -611,6 +626,9 @@ def nested_fun(x): return "int" in tensorflow_as_ivy_dtype(dtype_in) +default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: 
Optional[Union[str, str]] = None, @@ -655,6 +673,9 @@ def tensorflow_default_dtype_bknd( return tensorflow_as_ivy_dtype(ret) +default_float_dtype_stack = [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -709,6 +730,25 @@ def tensorflow_default_float_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -745,6 +785,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -767,21 +811,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -819,6 +859,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -870,20 +929,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -950,26 +995,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - 
else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -1087,6 +1114,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -1193,27 +1223,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -1382,6 +1406,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1638,7 +1665,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2050,7 +2079,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -2174,6 +2205,9 @@ def tensorflow_to_scalar_bknd_(self: tensorflow.Tensor): return tensorflow_to_scalar(self) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2198,11 +2232,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2410,7 +2442,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2570,11 +2604,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = 
tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2614,21 +2646,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_arange_output/run_0/tensorflow_arange.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_arange_output/run_0/tensorflow_arange.py index fcd19b3ccc5b..c7d7db255805 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_arange_output/run_0/tensorflow_arange.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_arange_output/run_0/tensorflow_arange.py @@ -1,7 +1,7 @@ import tensorflow -from typing import Union from typing import Optional +from typing import Union from .tensorflow__helpers import tensorflow_as_native_dtype from .tensorflow__helpers import tensorflow_default_dtype_bknd diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_argwhere_output/run_0/tensorflow_NestedSequence_bknd.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_argwhere_output/run_0/tensorflow_NestedSequence_bknd.py index ac8335fe1e56..9f87b4ae29ef 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_argwhere_output/run_0/tensorflow_NestedSequence_bknd.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_argwhere_output/run_0/tensorflow_NestedSequence_bknd.py @@ -1,5 +1,5 @@ -from typing import TypeVar from typing import Protocol +from typing import TypeVar _T_co = TypeVar("_T_co", covariant=True) diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_argwhere_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_argwhere_output/run_0/tensorflow__helpers.py index b76b3e2a199d..b3c4002471f5 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_argwhere_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_argwhere_output/run_0/tensorflow__helpers.py @@ -23,6 +23,99 @@ import tensorflow as tf +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + 
elif parameter in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +240,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +282,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, 
(annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_is_native_array(x, /, *, exclusive=False): @@ -334,7 +342,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -447,6 +457,7 @@ def tensorflow_is_complex_dtype_bknd( return "complex" in tensorflow_as_ivy_dtype_bknd(dtype_in) +@tensorflow_handle_array_like_without_promotion def tensorflow_real( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -460,6 +471,7 @@ def tensorflow_real_bknd_(self): return tensorflow_real(self) +@tensorflow_handle_array_like_without_promotion def tensorflow_imag( val: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -485,6 +497,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -611,6 +626,9 @@ def nested_fun(x): return "int" in tensorflow_as_ivy_dtype(dtype_in) +default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, @@ -655,6 +673,9 @@ def tensorflow_default_dtype_bknd( return tensorflow_as_ivy_dtype(ret) +default_float_dtype_stack = [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -709,6 +730,25 @@ def tensorflow_default_float_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -745,6 +785,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -767,21 +811,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == 
"uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -819,6 +859,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -870,20 +929,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -950,26 +995,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -1087,6 +1114,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -1193,27 +1223,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if 
include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -1382,6 +1406,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1638,7 +1665,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2050,7 +2079,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -2174,6 +2205,9 @@ def tensorflow_to_scalar_bknd_(self: tensorflow.Tensor): return tensorflow_to_scalar(self) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2198,11 +2232,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2410,7 +2442,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2570,11 +2604,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2614,21 +2646,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_as_ivy_dtype_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_as_ivy_dtype_output/run_0/tensorflow__helpers.py index 0f8e6430e8ff..e65cc400f346 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_as_ivy_dtype_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_as_ivy_dtype_output/run_0/tensorflow__helpers.py @@ -23,6 +23,99 @@ import tensorflow as tf +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = 
inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameter in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +240,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +282,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_device_stack = [] -default_dtype_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] -default_complex_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_float_dtype_stack = [] -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - 
"uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -default_int_dtype_stack = [] -backend = "" - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_exists_bknd(x: Any, /): @@ -310,7 +318,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -450,20 +460,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -530,26 +526,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -667,6 +645,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -839,6 +820,9 @@ def nested_fun(x): return "int" in tensorflow_as_ivy_dtype(dtype_in) 
+default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, @@ -912,27 +896,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -1101,6 +1079,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1357,7 +1338,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -1769,7 +1752,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -1893,6 +1878,9 @@ def tensorflow_to_scalar_bknd_(self: tensorflow.Tensor): return tensorflow_to_scalar(self) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -1917,11 +1905,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2129,7 +2115,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2289,11 +2277,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2333,21 +2319,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - 
inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], @@ -2428,6 +2399,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2484,6 +2458,25 @@ def tensorflow_default_complex_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -2520,6 +2513,9 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_float_dtype_stack = [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2574,6 +2570,25 @@ def tensorflow_default_float_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -2597,6 +2612,10 @@ def tensorflow_as_native_dtype( ) +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2619,21 +2638,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_as_native_dtype_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_as_native_dtype_output/run_0/tensorflow__helpers.py index 0f8e6430e8ff..e65cc400f346 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_as_native_dtype_output/run_0/tensorflow__helpers.py +++ 
b/ivy/compiler/_cache/Translated_Outputs/tensorflow_as_native_dtype_output/run_0/tensorflow__helpers.py @@ -23,6 +23,99 @@ import tensorflow as tf +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameter in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +240,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +282,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_device_stack = [] -default_dtype_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] -default_complex_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - 
tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_float_dtype_stack = [] -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -default_int_dtype_stack = [] -backend = "" - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_exists_bknd(x: Any, /): @@ -310,7 +318,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -450,20 +460,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -530,26 +526,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods 
+@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -667,6 +645,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -839,6 +820,9 @@ def nested_fun(x): return "int" in tensorflow_as_ivy_dtype(dtype_in) +default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, @@ -912,27 +896,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -1101,6 +1079,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1357,7 +1338,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -1769,7 +1752,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -1893,6 +1878,9 @@ def tensorflow_to_scalar_bknd_(self: tensorflow.Tensor): return tensorflow_to_scalar(self) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -1917,11 +1905,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2129,7 +2115,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2289,11 +2277,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - 
tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2333,21 +2319,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], @@ -2428,6 +2399,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2484,6 +2458,25 @@ def tensorflow_default_complex_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -2520,6 +2513,9 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_float_dtype_stack = [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2574,6 +2570,25 @@ def tensorflow_default_float_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -2597,6 +2612,10 @@ def tensorflow_as_native_dtype( ) +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2619,21 +2638,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, 
stop_after_n_found=1, ): ret = tf.int64 diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_asarray_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_asarray_output/run_0/tensorflow__helpers.py index d50687f412e5..3aaa2aa83879 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_asarray_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_asarray_output/run_0/tensorflow__helpers.py @@ -23,6 +23,99 @@ import tensorflow as tf +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +240,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +282,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: 
"int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_is_native_array(x, /, *, exclusive=False): @@ -334,7 +342,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -447,6 +457,7 @@ def tensorflow_is_complex_dtype_bknd( return "complex" in tensorflow_as_ivy_dtype_bknd(dtype_in) +@tensorflow_handle_array_like_without_promotion def tensorflow_real( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -460,6 +471,7 @@ def tensorflow_real_bknd_(self): return tensorflow_real(self) +@tensorflow_handle_array_like_without_promotion def tensorflow_imag( val: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -485,6 +497,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -611,6 +626,9 @@ def 
nested_fun(x): return "int" in tensorflow_as_ivy_dtype(dtype_in) +default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, @@ -655,6 +673,9 @@ def tensorflow_default_dtype_bknd( return tensorflow_as_ivy_dtype(ret) +default_float_dtype_stack = [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -709,6 +730,25 @@ def tensorflow_default_float_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -745,6 +785,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -767,21 +811,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -819,6 +859,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -870,20 +929,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -950,26 +995,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return 
match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -1087,6 +1114,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -1193,27 +1223,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -1382,6 +1406,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1638,7 +1665,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2050,7 +2079,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -2174,6 +2205,9 @@ def tensorflow_to_scalar_bknd_(self: tensorflow.Tensor): return tensorflow_to_scalar(self) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2198,11 +2232,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2410,7 +2442,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, 
dtype=dtype) @@ -2570,11 +2604,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2614,21 +2646,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_asarray_output/run_0/tensorflow_asarray.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_asarray_output/run_0/tensorflow_asarray.py index 34ee7995c305..a65c00fe0b7e 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_asarray_output/run_0/tensorflow_asarray.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_asarray_output/run_0/tensorflow_asarray.py @@ -1,9 +1,9 @@ import tensorflow import numpy as np +from typing import Union from typing import TypeVar from typing import Optional -from typing import Union from .tensorflow_NestedSequence_bknd import tensorflow_NestedSequence_bknd from .tensorflow__helpers import tensorflow__asarray_infer_dtype_bknd diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_astype_output/run_0/tensorflow_NestedSequence_bknd.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_astype_output/run_0/tensorflow_NestedSequence_bknd.py index ac8335fe1e56..9f87b4ae29ef 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_astype_output/run_0/tensorflow_NestedSequence_bknd.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_astype_output/run_0/tensorflow_NestedSequence_bknd.py @@ -1,5 +1,5 @@ -from typing import TypeVar from typing import Protocol +from typing import TypeVar _T_co = TypeVar("_T_co", covariant=True) diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_astype_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_astype_output/run_0/tensorflow__helpers.py index d50687f412e5..3aaa2aa83879 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_astype_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_astype_output/run_0/tensorflow__helpers.py @@ -23,6 +23,99 @@ import tensorflow as tf +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if 
tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +240,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +282,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, 
ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_is_native_array(x, /, *, exclusive=False): @@ -334,7 +342,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -447,6 +457,7 @@ def tensorflow_is_complex_dtype_bknd( return "complex" in tensorflow_as_ivy_dtype_bknd(dtype_in) +@tensorflow_handle_array_like_without_promotion def tensorflow_real( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -460,6 +471,7 @@ def tensorflow_real_bknd_(self): return tensorflow_real(self) +@tensorflow_handle_array_like_without_promotion def tensorflow_imag( val: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -485,6 +497,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -611,6 +626,9 @@ def nested_fun(x): return "int" in tensorflow_as_ivy_dtype(dtype_in) +default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, @@ -655,6 +673,9 @@ def tensorflow_default_dtype_bknd( return tensorflow_as_ivy_dtype(ret) +default_float_dtype_stack = [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -709,6 +730,25 @@ def tensorflow_default_float_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -745,6 +785,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: 
Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -767,21 +811,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -819,6 +859,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -870,20 +929,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -950,26 +995,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -1087,6 +1114,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -1193,27 +1223,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) 
is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -1382,6 +1406,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1638,7 +1665,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2050,7 +2079,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -2174,6 +2205,9 @@ def tensorflow_to_scalar_bknd_(self: tensorflow.Tensor): return tensorflow_to_scalar(self) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2198,11 +2232,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2410,7 +2442,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2570,11 +2604,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2614,21 +2646,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_astype_output/run_0/tensorflow_astype.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_astype_output/run_0/tensorflow_astype.py index a0614a6222b8..a12360f716c3 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_astype_output/run_0/tensorflow_astype.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_astype_output/run_0/tensorflow_astype.py @@ -1,8 +1,8 @@ import tensorflow import tensorflow as tf -from typing 
import Union from typing import Optional +from typing import Union from .tensorflow__helpers import tensorflow_as_native_dtype from .tensorflow__helpers import tensorflow_handle_array_like_without_promotion diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_batch_norm_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_batch_norm_output/run_0/tensorflow__helpers.py index d50687f412e5..3aaa2aa83879 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_batch_norm_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_batch_norm_output/run_0/tensorflow__helpers.py @@ -23,6 +23,99 @@ import tensorflow as tf +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +240,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +282,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } 
-tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_is_native_array(x, /, *, exclusive=False): @@ -334,7 +342,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -447,6 +457,7 @@ def tensorflow_is_complex_dtype_bknd( return "complex" in tensorflow_as_ivy_dtype_bknd(dtype_in) +@tensorflow_handle_array_like_without_promotion def tensorflow_real( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -460,6 +471,7 @@ def tensorflow_real_bknd_(self): return tensorflow_real(self) +@tensorflow_handle_array_like_without_promotion def tensorflow_imag( val: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -485,6 +497,9 @@ def 
tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -611,6 +626,9 @@ def nested_fun(x): return "int" in tensorflow_as_ivy_dtype(dtype_in) +default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, @@ -655,6 +673,9 @@ def tensorflow_default_dtype_bknd( return tensorflow_as_ivy_dtype(ret) +default_float_dtype_stack = [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -709,6 +730,25 @@ def tensorflow_default_float_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -745,6 +785,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -767,21 +811,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -819,6 +859,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -870,20 +929,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, 
tensorflow.Variable], @@ -950,26 +995,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -1087,6 +1114,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -1193,27 +1223,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -1382,6 +1406,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1638,7 +1665,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2050,7 +2079,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -2174,6 +2205,9 @@ def tensorflow_to_scalar_bknd_(self: tensorflow.Tensor): return tensorflow_to_scalar(self) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2198,11 +2232,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2410,7 +2442,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - 
else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2570,11 +2604,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2614,21 +2646,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_batch_norm_output/run_0/tensorflow_batch_norm.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_batch_norm_output/run_0/tensorflow_batch_norm.py index 6bad1a0d5c0c..450862864476 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_batch_norm_output/run_0/tensorflow_batch_norm.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_batch_norm_output/run_0/tensorflow_batch_norm.py @@ -1,7 +1,7 @@ import tensorflow -from typing import Tuple from typing import Union +from typing import Tuple from typing import Optional from .tensorflow__helpers import tensorflow_handle_array_like_without_promotion diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_bernoulli_output/run_0/__init__.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_bernoulli_output/run_0/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_bernoulli_output/run_0/tensorflow_NestedSequence_bknd.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_bernoulli_output/run_0/tensorflow_NestedSequence_bknd.py deleted file mode 100644 index 9f87b4ae29ef..000000000000 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_bernoulli_output/run_0/tensorflow_NestedSequence_bknd.py +++ /dev/null @@ -1,10 +0,0 @@ -from typing import Protocol -from typing import TypeVar - -_T_co = TypeVar("_T_co", covariant=True) - - -class tensorflow_NestedSequence_bknd(Protocol[_T_co]): - def __getitem__(self, key: int, /): ... - - def __len__(self, /): ... 
diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_bernoulli_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_bernoulli_output/run_0/tensorflow__helpers.py deleted file mode 100644 index d85df9b19d15..000000000000 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_bernoulli_output/run_0/tensorflow__helpers.py +++ /dev/null @@ -1,2676 +0,0 @@ -from collections import UserDict -from numbers import Number -from numpy.core.numeric import normalize_axis_tuple -from operator import mul -from .tensorflow_NestedSequence_bknd import tensorflow_NestedSequence_bknd -from typing import Any -from typing import Callable -from typing import Dict -from typing import Iterable -from typing import List -from typing import Optional -from typing import Sequence -from typing import Tuple -from typing import TypeVar -from typing import Union -import functools -import inspect -import itertools -import math -import numpy as np -import re -import tensorflow -import tensorflow as tf - - -promotion_table = { - ("bool", "bool"): "bool", - ("int8", "int8"): "int8", - ("int8", "int16"): "int16", - ("int8", "int32"): "int32", - ("int8", "int64"): "int64", - ("int16", "int16"): "int16", - ("int16", "int32"): "int32", - ("int16", "int64"): "int64", - ("int32", "int32"): "int32", - ("int32", "int64"): "int64", - ("int64", "int64"): "int64", - ("uint8", "int8"): "int16", - ("uint8", "int16"): "int16", - ("uint8", "int32"): "int32", - ("uint8", "int64"): "int64", - ("uint8", "uint8"): "uint8", - ("uint8", "uint16"): "uint16", - ("uint8", "uint32"): "uint32", - ("uint8", "uint64"): "uint64", - ("uint16", "int8"): "int32", - ("uint16", "int16"): "int32", - ("uint16", "int32"): "int32", - ("uint16", "int64"): "int64", - ("uint16", "uint16"): "uint16", - ("uint16", "uint32"): "uint32", - ("uint16", "uint64"): "uint64", - ("uint32", "int8"): "int64", - ("uint32", "int16"): "int64", - ("uint32", "int32"): "int64", - ("uint32", "int64"): "int64", - ("uint32", "uint32"): "uint32", - ("uint32", "uint64"): "uint64", - ("uint64", "uint64"): "uint64", - ("float16", "float16"): "float16", - ("float16", "float32"): "float32", - ("float16", "float64"): "float64", - ("float32", "float32"): "float32", - ("float32", "float64"): "float64", - ("float64", "float64"): "float64", - ("bool", "int8"): "int8", - ("bool", "int16"): "int16", - ("bool", "int32"): "int32", - ("bool", "int64"): "int64", - ("bool", "uint8"): "uint8", - ("bool", "uint16"): "uint16", - ("bool", "uint32"): "uint32", - ("bool", "uint64"): "uint64", - ("bool", "float16"): "float16", - ("bool", "float32"): "float32", - ("bool", "float64"): "float64", - ("bool", "bfloat16"): "bfloat16", - ("bool", "complex64"): "complex64", - ("bool", "complex128"): "complex128", - ("int8", "float16"): "float16", - ("int8", "float32"): "float32", - ("int8", "float64"): "float64", - ("int8", "bfloat16"): "bfloat16", - ("int8", "complex64"): "complex64", - ("int8", "complex128"): "complex128", - ("int16", "float32"): "float32", - ("int16", "float64"): "float64", - ("int16", "complex64"): "complex64", - ("int16", "complex128"): "complex128", - ("int32", "float64"): "float64", - ("int32", "complex128"): "complex128", - ("int64", "float64"): "float64", - ("int64", "complex128"): "complex128", - ("uint8", "float16"): "float16", - ("uint8", "float32"): "float32", - ("uint8", "float64"): "float64", - ("uint8", "bfloat16"): "bfloat16", - ("uint8", "complex64"): "complex64", - ("uint8", "complex128"): "complex128", - ("uint16", "float32"): "float32", - 
("uint16", "float64"): "float64", - ("uint16", "complex64"): "complex64", - ("uint16", "complex128"): "complex128", - ("uint32", "float64"): "float64", - ("uint32", "complex128"): "complex128", - ("uint64", "int8"): "float64", - ("uint64", "int16"): "float64", - ("uint64", "int32"): "float64", - ("uint64", "int64"): "float64", - ("uint64", "float64"): "float64", - ("uint64", "complex128"): "complex128", - ("float16", "bfloat16"): "float32", - ("float16", "complex64"): "complex64", - ("float16", "complex128"): "complex128", - ("float32", "complex64"): "complex64", - ("float32", "complex128"): "complex128", - ("float64", "complex64"): "complex128", - ("float64", "complex128"): "complex128", - ("bfloat16", "float16"): "float32", - ("bfloat16", "float32"): "float32", - ("bfloat16", "float64"): "float64", - ("bfloat16", "bfloat16"): "bfloat16", - ("bfloat16", "complex64"): "complex64", - ("bfloat16", "complex128"): "complex128", - ("complex64", "float64"): "complex128", - ("complex64", "complex64"): "complex64", - ("complex64", "complex128"): "complex128", - ("complex128", "complex128"): "complex128", - ("float16", "int16"): "float32", - ("float16", "int32"): "float64", - ("float16", "int64"): "float64", - ("float16", "uint16"): "float32", - ("float16", "uint32"): "float64", - ("float16", "uint64"): "float64", - ("float32", "int32"): "float64", - ("float32", "int64"): "float64", - ("float32", "uint32"): "float64", - ("float32", "uint64"): "float64", - ("bfloat16", "int16"): "float32", - ("bfloat16", "int32"): "float64", - ("bfloat16", "int64"): "float64", - ("bfloat16", "uint16"): "float32", - ("bfloat16", "uint32"): "float64", - ("bfloat16", "uint64"): "float64", - ("complex64", "int32"): "complex128", - ("complex64", "int64"): "complex128", - ("complex64", "uint32"): "complex128", - ("complex64", "uint64"): "complex128", -} -array_api_promotion_table = { - ("bool", "bool"): "bool", - ("int8", "int8"): "int8", - ("int8", "int16"): "int16", - ("int8", "int32"): "int32", - ("int8", "int64"): "int64", - ("int16", "int16"): "int16", - ("int16", "int32"): "int32", - ("int16", "int64"): "int64", - ("int32", "int32"): "int32", - ("int32", "int64"): "int64", - ("int64", "int64"): "int64", - ("uint8", "int8"): "int16", - ("uint8", "int16"): "int16", - ("uint8", "int32"): "int32", - ("uint8", "int64"): "int64", - ("uint8", "uint8"): "uint8", - ("uint8", "uint16"): "uint16", - ("uint8", "uint32"): "uint32", - ("uint8", "uint64"): "uint64", - ("uint16", "int8"): "int32", - ("uint16", "int16"): "int32", - ("uint16", "int32"): "int32", - ("uint16", "int64"): "int64", - ("uint16", "uint16"): "uint16", - ("uint16", "uint32"): "uint32", - ("uint16", "uint64"): "uint64", - ("uint32", "int8"): "int64", - ("uint32", "int16"): "int64", - ("uint32", "int32"): "int64", - ("uint32", "int64"): "int64", - ("uint32", "uint32"): "uint32", - ("uint32", "uint64"): "uint64", - ("uint64", "uint64"): "uint64", - ("float16", "float16"): "float16", - ("float16", "float32"): "float32", - ("float16", "float64"): "float64", - ("float32", "float32"): "float32", - ("float32", "float64"): "float64", - ("float64", "float64"): "float64", -} -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: 
"int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} - - -def tensorflow_infer_dtype(fn: Callable): - @functools.wraps(fn) - def _infer_dtype(*args, dtype=None, **kwargs): - arr = ( - None - if tensorflow_exists_bknd(dtype) - else tensorflow__get_first_array(*args, **kwargs) - ) - dtype = tensorflow_default_dtype_bknd(dtype=dtype, item=arr, as_native=True) - return fn(*args, dtype=dtype, **kwargs) - - _infer_dtype.infer_dtype = True - return _infer_dtype - - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion - - -def tensorflow_exists_bknd(x: Any, /): - return x is not None - - -def tensorflow_is_native_array(x, /, *, exclusive=False): - if "keras.src.backend.tensorflow.core.Variable" in str(x.__class__): - return not exclusive - if isinstance(x, (tensorflow.Tensor, tensorflow.Variable, tensorflow.TensorArray)): - if exclusive and isinstance(x, tensorflow.Variable): - return False - return True - return False - - -def tensorflow_is_ivy_array_bknd( - x: Union[tensorflow.Tensor, tf.Tensor], /, *, exclusive: Optional[bool] = False -): - return isinstance(x, tensorflow.Tensor) and tensorflow_is_native_array( - x, exclusive=exclusive - ) - - -def tensorflow_is_array_bknd(x: Any, /, *, exclusive: bool = False): - return tensorflow_is_ivy_array_bknd( - x, exclusive=exclusive - ) or 
tensorflow_is_native_array(x, exclusive=exclusive) - - -def tensorflow_default_bknd( - x: Any, - /, - default_val: Any, - *, - catch_exceptions: bool = False, - rev: bool = False, - with_callable: bool = False, -): - with_callable = catch_exceptions or with_callable - if rev: - x, default_val = default_val, x - if with_callable: - x_callable = callable(x) - default_callable = callable(default_val) - else: - x_callable = False - default_callable = False - if catch_exceptions: - try: - x = x() if x_callable else x - except Exception: - return default_val() if default_callable else default_val - else: - x = x() if x_callable else x - return ( - x - if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val - ) - - -def tensorflow_nested_argwhere_bknd( - nest: Iterable, - fn: Callable, - check_nests: bool = False, - to_ignore: Optional[Union[type, Tuple[type]]] = None, - _index: Optional[List] = None, - _base: bool = True, - stop_after_n_found: Optional[int] = None, -): - to_ignore = tensorflow_default_bknd(to_ignore, ()) - _index = [] if _index is None else _index - if isinstance(nest, (tuple, list)) and not isinstance(nest, to_ignore): - n = 0 - _indices = [] - for i, item in enumerate(nest): - ind = ( - tensorflow_nested_argwhere_bknd( - item, - fn, - check_nests, - to_ignore, - _index + [i], - False, - stop_after_n_found - n, - ) - if stop_after_n_found is not None - else tensorflow_nested_argwhere_bknd( - item, fn, check_nests, to_ignore, _index + [i], False, None - ) - ) - if stop_after_n_found is not None and ind: - if n >= stop_after_n_found: - break - n = n + len(ind) - _indices = _indices + [ind] - if stop_after_n_found is not None and n >= stop_after_n_found: - break - _indices = [idx for idxs in _indices if idxs for idx in idxs] - if check_nests and fn(nest): - _indices.append(_index) - elif isinstance(nest, (dict, UserDict)) and not isinstance(nest, to_ignore): - n = 0 - _indices = [] - for k, v in nest.items(): - ind = ( - tensorflow_nested_argwhere_bknd( - v, - fn, - check_nests, - to_ignore, - _index + [k], - False, - stop_after_n_found - n, - ) - if stop_after_n_found is not None - else tensorflow_nested_argwhere_bknd( - v, fn, check_nests, to_ignore, _index + [k], False, None - ) - ) - if stop_after_n_found is not None and ind: - if n >= stop_after_n_found: - break - n = n + len(ind) - _indices = _indices + [ind] - _indices = [idx for idxs in _indices if idxs for idx in idxs] - if check_nests and fn(nest): - _indices.append(_index) - else: - cond_met = fn(nest) - if cond_met: - return [_index] - return False - return [index for index in _indices if index] - - -def tensorflow__check_float64_bknd(input): - if tensorflow_is_array_bknd(input): - return tensorflow_dtype(input) == "float64" - if math.isfinite(input): - m, e = math.frexp(input) - return abs(input) > 3.4028235e38 or e < -126 or e > 128 - return False - - -def tensorflow_as_ivy_dtype_bknd(dtype_in: Union[str, str], /): - return tensorflow_as_ivy_dtype(dtype_in) - - -def tensorflow_is_complex_dtype_bknd( - dtype_in: Union[str, str, tensorflow.Tensor, tf.Tensor, Number], / -): - if tensorflow_is_array_bknd(dtype_in): - dtype_in = tensorflow_dtype(dtype_in) - elif isinstance(dtype_in, tuple): - dtype_in = tensorflow_default_int_dtype_bknd() - elif isinstance(dtype_in, np.ndarray): - return "complex" in dtype_in.dtype.name - elif isinstance(dtype_in, Number): - return isinstance(dtype_in, (complex, np.complexfloating)) - elif isinstance(dtype_in, (list, tuple, dict)): - return 
tensorflow_nested_argwhere_bknd( - dtype_in, - lambda x: isinstance(x, (complex, np.complexfloating)) - or tensorflow_is_array_bknd(x) - and "complex" in tensorflow_dtype(x), - ) - return "complex" in tensorflow_as_ivy_dtype_bknd(dtype_in) - - -def tensorflow_as_native_dev(device: str, /): - if isinstance(device, str) and "/" in device: - return device - ret = f"/{str(device).upper()}" - if not ret[-1].isnumeric(): - ret += ":0" - return ret - - -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - -@tensorflow_handle_methods -def tensorflow_split( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - copy: Optional[bool] = None, - num_or_size_splits: Optional[ - Union[int, Sequence[int], Union[tensorflow.Tensor, tensorflow.Variable]] - ] = None, - axis: int = 0, - with_remainder: bool = False, -): - if x.shape == (): - if num_or_size_splits is not None and num_or_size_splits != 1: - raise Exception( - f"input array had no shape, but num_sections specified was {num_or_size_splits}" - ) - return [x] - if num_or_size_splits is None: - dim_size = tensorflow.shape(x)[axis] - num_or_size_splits = int(dim_size) - if isinstance(num_or_size_splits, (tensorflow.Tensor, tensorflow.Variable)): - num_or_size_splits = tensorflow.cast(num_or_size_splits, tensorflow.int32) - elif isinstance(num_or_size_splits, int) and with_remainder: - num_chunks = x.shape[axis] / num_or_size_splits - num_chunks_int = math.floor(num_chunks) - remainder = num_chunks - num_chunks_int - if remainder != 0: - num_or_size_splits = [num_or_size_splits] * num_chunks_int + [ - int(remainder * num_or_size_splits) - ] - return tensorflow.split(x, num_or_size_splits, axis) - - -@tensorflow_handle_methods -def tensorflow_split_bknd_( - self: tensorflow.Tensor, - /, - *, - copy: Optional[bool] = None, - num_or_size_splits: Optional[ - Union[int, Sequence[int], tensorflow.Tensor, tf.Tensor] - ] = None, - axis: int = 0, - with_remainder: bool = False, -): - return tensorflow_split( - self, - copy=copy, - num_or_size_splits=num_or_size_splits, - axis=axis, - with_remainder=with_remainder, - ) - - -def tensorflow_as_ivy_dev(device: str, /): - if isinstance(device, str) and "/" not in device: - return str(device) - dev_in_split = tensorflow_split_bknd_(device[1:], ":")[-2:] - if len(dev_in_split) == 1: - return str(dev_in_split[0]) - dev_type, dev_idx = dev_in_split[0], dev_in_split[1] - dev_type = dev_type.lower() - if dev_type == "cpu": - return str(dev_type) - return str(f"{dev_type}:{dev_idx}") - - -def tensorflow_stack( - arrays: Union[Tuple[tensorflow.Tensor], List[tensorflow.Tensor]], - /, - *, - axis: int = 0, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - try: - return tensorflow.experimental.numpy.stack(arrays, axis) - except ValueError as e: - raise Exception(e) from e - - -def tensorflow_stack_bknd_( - self: tensorflow.Tensor, - /, - arrays: Union[ - Tuple[Union[tensorflow.Tensor, tf.Tensor]], - List[Union[tensorflow.Tensor, tf.Tensor]], - ], - *, - axis: int = 0, - out: Optional[tensorflow.Tensor] = None, -): - if not isinstance(arrays, (tuple, list)): - arrays 
= [arrays] - if isinstance(arrays, tuple): - x = (self,) + arrays - else: - x = [self] + arrays - return tensorflow_stack(x, axis=axis, out=out) - - -def tensorflow_dev( - x: Union[tensorflow.Tensor, tensorflow.Variable, tensorflow.TensorArray], - /, - *, - as_native: bool = False, -): - if "keras.src.backend.tensorflow.core.Variable" in str(x.__class__): - x = x.value - if isinstance(x, tensorflow.TensorArray): - x = tensorflow_stack_bknd_(x) - dv = x.device - if as_native: - return dv - dv = dv if dv else tensorflow_default_device_bknd(as_native=False) - return tensorflow_as_ivy_dev(dv) - - -def tensorflow_default_device_bknd( - device: Optional[Union[str, str]] = None, - /, - *, - item: Optional[Union[list, tuple, dict, tensorflow.Tensor, tf.Tensor]] = None, - as_native: Optional[bool] = None, -): - if tensorflow_exists_bknd(device): - if as_native is True: - return tensorflow_as_native_dev(device) - elif as_native is False: - return tensorflow_as_ivy_dev(device) - return device - as_native = tensorflow_default_bknd(as_native, False) - if tensorflow_exists_bknd(item): - if isinstance(item, (list, tuple, dict)) and len(item) == 0: - pass - elif tensorflow_is_array_bknd(item): - return tensorflow_dev(item, as_native=as_native) - global default_device_stack - if not default_device_stack: - ret = "cpu" - else: - ret = default_device_stack[-1] - if as_native: - return tensorflow_as_native_dev(ret) - return tensorflow_as_ivy_dev(ret) - - -def tensorflow__get_preferred_device(args, kwargs): - device = None - if "device" in kwargs and kwargs["device"] is not None: - return device - if not False: - arr_arg = tensorflow__get_first_array(*args, **kwargs) - return tensorflow_default_device_bknd(item=arr_arg, as_native=True) - return tensorflow_default_device_bknd(as_native=True) - - -def tensorflow__check_in_nested_sequence(sequence, value=None, _type=None): - if sequence is value or isinstance(sequence, _type): - return True - elif isinstance(sequence, (tuple, list)): - if any(isinstance(_val, _type) or _val is value for _val in sequence): - return True - else: - return any( - tensorflow__check_in_nested_sequence(sub_sequence, value, _type) - for sub_sequence in sequence - if isinstance(sub_sequence, (tuple, list)) - ) - - -def tensorflow_is_variable(x, /, *, exclusive=False): - return isinstance(x, tensorflow.Variable) - - -def tensorflow_variable(x, /): - with tensorflow.device(tensorflow_dev(x, as_native=True)): - return tensorflow.Variable(x, trainable=True) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_stop_gradient( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - preserve_type: bool = True, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - is_var = tensorflow_is_variable(x) - x = tensorflow.stop_gradient(x) - if is_var and preserve_type: - return tensorflow_variable(x) - return x - - -def tensorflow_nested_map_bknd( - fn: Callable, - x: Union[tensorflow.Tensor, tf.Tensor, Iterable], - /, - include_derived: Optional[Union[Dict[str, bool], bool]] = None, - to_ignore: Optional[Union[type, Tuple[type]]] = None, - to_mutable: bool = False, - _tuple_check_fn: Optional[Callable] = None, - _list_check_fn: Optional[Callable] = None, - _dict_check_fn: Optional[Callable] = None, - shallow: bool = True, -): - to_ignore = tensorflow_default_bknd(to_ignore, ()) - if include_derived is True: - include_derived = {"tuple": True, "list": True, "dict": True} - elif not include_derived: - include_derived = {} - for t in ("tuple", "list", "dict"): 
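# ---------------------------------------------------------------------------
# Illustrative sketch (not from the deleted helpers file). In
# `tensorflow_split` above, `with_remainder=True` converts an integer chunk
# size into an explicit size list whose last entry carries the leftover
# elements. The same arithmetic in isolation (`remainder_split_sizes` is a
# hypothetical helper name):
import math

def remainder_split_sizes(dim_size, chunk):
    num_chunks = dim_size / chunk        # e.g. 10 / 3 -> 3.33...
    whole = math.floor(num_chunks)       # 3 full chunks
    frac = num_chunks - whole            # fractional chunk left over
    if frac == 0:
        return [chunk] * whole
    return [chunk] * whole + [int(frac * chunk)]  # e.g. [3, 3, 3, 1]

# tf.split(x, remainder_split_sizes(int(x.shape[0]), 3), axis=0) then yields
# chunks of sizes 3, 3, 3 and 1 along a length-10 axis.
# ---------------------------------------------------------------------------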
- if t not in include_derived: - include_derived = tensorflow_set_item_bknd(include_derived, t, False) - class_instance = type(x) - if ( - hasattr(x, "is_tracked_proxy") - and hasattr(class_instance, "__bases__") - and not set(class_instance.__bases__).intersection(set(to_ignore)) - ): - to_ignore = to_ignore + (class_instance,) - tuple_check_fn = tensorflow_default_bknd( - _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), - ) - list_check_fn = tensorflow_default_bknd( - _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), - ) - dict_check_fn = tensorflow_default_bknd( - _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), - ) - if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): - ret_list = [ - tensorflow_nested_map_bknd( - fn, - i, - include_derived, - to_ignore, - to_mutable, - tuple_check_fn, - list_check_fn, - dict_check_fn, - shallow, - ) - for i in x - ] - if to_mutable: - return ret_list - elif hasattr(x, "_fields"): - return class_instance(**dict(zip(x._fields, ret_list))) - else: - return class_instance(ret_list) - elif list_check_fn(x, list) and not isinstance(x, to_ignore): - ret_list = [ - tensorflow_nested_map_bknd( - fn, - i, - include_derived, - to_ignore, - to_mutable, - tuple_check_fn, - list_check_fn, - dict_check_fn, - shallow, - ) - for i in x - ] - if shallow: - x = tensorflow_set_item_bknd(x, slice(None, None, None), ret_list[:]) - return x - return class_instance(ret_list) - elif (dict_check_fn(x, dict) or isinstance(x, UserDict)) and not isinstance( - x, to_ignore - ): - class_instance = type(x) - ret = { - k: tensorflow_nested_map_bknd( - fn, - v, - include_derived, - to_ignore, - to_mutable, - tuple_check_fn, - list_check_fn, - dict_check_fn, - shallow, - ) - for k, v in x.items() - } - if shallow: - x.update(ret) - return x - return class_instance(ret) - elif isinstance(x, slice): - return slice(*tensorflow_nested_map_bknd(fn, [x.start, x.stop, x.step])) - return fn(x) - - -def tensorflow__to_ivy_bknd_(x: Any): - if isinstance(x, tensorflow.Tensor): - return x - elif isinstance(x, tf.TensorShape): - return tuple(x) - elif isinstance(x, dict): - return x.to_ivy() - if tensorflow_is_native_array(x) or isinstance(x, np.ndarray): - return tensorflow.convert_to_tensor(x) - return x - - -def tensorflow_to_ivy_bknd_( - x: Union[tensorflow.Tensor, tf.Tensor, Iterable], - nested: bool = False, - include_derived: Optional[Dict[str, bool]] = None, -): - if nested: - return tensorflow_nested_map_bknd( - tensorflow__to_ivy_bknd_, x, include_derived, shallow=False - ) - return tensorflow__to_ivy_bknd_(x) - - -def tensorflow__asarray_to_native_arrays_and_back_bknd(fn: Callable): - @functools.wraps(fn) - def _asarray_to_native_arrays_and_back_wrapper(*args, dtype=None, **kwargs): - new_arg = args[0] - new_args = (new_arg,) + args[1:] - if dtype is not None: - dtype = tensorflow_default_dtype_bknd(dtype=dtype, as_native=True) - return tensorflow_to_ivy_bknd_(fn(*new_args, dtype=dtype, **kwargs)) - - _asarray_to_native_arrays_and_back_wrapper._asarray_to_native_arrays_and_back = True - return _asarray_to_native_arrays_and_back_wrapper - - -def tensorflow__flatten_nest_bknd(xs): - for x in xs: - if isinstance(x, Iterable) and not isinstance(x, (str, bytes)): - yield from tensorflow__flatten_nest_bknd(x) - else: - yield x - - -def 
tensorflow_promote_types_bknd( - type1: Union[str, tf.DType], - type2: Union[str, tf.DType], - /, - *, - array_api_promotion: bool = False, -): - if not (type1 and type2): - return type1 if type1 else type2 - query = [tensorflow_as_ivy_dtype(type1), tensorflow_as_ivy_dtype(type2)] - query = tuple(query) - if query not in promotion_table: - query = query[1], query[0] - - def _promote(query): - if array_api_promotion: - return tensorflow_get_item(array_api_promotion_table, query) - return tensorflow_get_item(promotion_table, query) - - return _promote(query) - - -def tensorflow__asarray_infer_dtype_bknd(fn: Callable): - @functools.wraps(fn) - def _asarray_infer_dtype_wrapper(*args, dtype=None, **kwargs): - def _infer_dtype(obj): - if isinstance(obj, tf.TensorShape): - obj = list(obj) - if hasattr(obj, "dtype"): - return obj.dtype.name if isinstance(obj, np.ndarray) else obj.dtype - else: - return tensorflow_default_dtype_bknd(item=obj) - - if not tensorflow_exists_bknd(dtype): - arr = args[0] - dtype_list = [ - tensorflow_nested_map_bknd( - lambda x: _infer_dtype(x), arr, shallow=False - ) - ] - dtype_list = tensorflow__flatten_nest_bknd(dtype_list) - dtype_list = list(set(dtype_list)) - if len(dtype_list) != 0: - dtype = dtype_list[0] - for dt in dtype_list[1:]: - dtype = tensorflow_promote_types_bknd(dtype, dt) - else: - dtype = tensorflow_default_float_dtype_bknd() - dtype = tensorflow_as_native_dtype(dtype) - return fn(*args, dtype=dtype, **kwargs) - - _asarray_infer_dtype_wrapper.infer_dtype = True - return _asarray_infer_dtype_wrapper - - -@tensorflow_handle_array_like_without_promotion -@tensorflow__asarray_to_native_arrays_and_back_bknd -@tensorflow__asarray_infer_dtype_bknd -def tensorflow_asarray( - obj: Union[ - tensorflow.Tensor, - tensorflow.Variable, - tensorflow.TensorShape, - bool, - int, - float, - tensorflow_NestedSequence_bknd, - SupportsBufferProtocol, - np.ndarray, - ], - /, - *, - copy: Optional[bool] = None, - dtype: Optional[tensorflow.DType] = None, - device: Optional[str] = None, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - with tensorflow.device(device): - if tensorflow.is_tensor(obj): - ret = tensorflow.cast(obj, dtype) if obj.dtype != dtype else obj - elif ( - dtype is not None - and dtype.is_integer - and np.issubdtype(np.array(obj).dtype, np.floating) - ): - obj_np = np.array(obj) - ret = tensorflow.convert_to_tensor(obj_np, dtype) - else: - ret = tensorflow.convert_to_tensor(obj, dtype) - return ( - tensorflow.identity(ret) - if copy or tensorflow_as_native_dev(tensorflow_dev(ret)) != device - else ret - ) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_size(x: tensorflow.Tensor, /): - return functools.reduce(mul, x.shape) if len(x.shape) > 0 else 1 - - -def tensorflow_size_bknd_(self): - return tensorflow_size(self) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_unstack( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - copy: Optional[bool] = None, - axis: int = 0, - keepdims: bool = False, -): - if x.shape == (): - return [x] - ret = tensorflow.unstack(x, axis=axis) - if keepdims: - return [tensorflow.expand_dims(r, axis) for r in ret] - return ret - - -def tensorflow_unstack_bknd_( - self: tensorflow.Tensor, - /, - *, - copy: Optional[bool] = None, - axis: int = 0, - keepdims: bool = False, -): - return tensorflow_unstack(self, copy=copy, axis=axis, keepdims=keepdims) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_copy_array( - x: 
Union[tensorflow.Tensor, tensorflow.Variable, tensorflow.TensorArray], - *, - to_ivy_array: bool = True, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if isinstance(x, tensorflow.TensorArray): - x_wrapped = tensorflow_stack_bknd_(x) - y = tensorflow.TensorArray(x.dtype, tensorflow_size_bknd_(x)()) - x = tensorflow_unstack_bknd_(y, tensorflow_copy_array(x_wrapped)) - else: - x = tensorflow.identity(x) - if to_ivy_array: - return tensorflow_to_ivy_bknd_(x) - return x - - -def tensorflow_tile( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - repeats: Sequence[int], - *, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if x.shape == (): - x = tensorflow.reshape(x, (-1,)) - if isinstance(repeats, Number): - repeats = [repeats] - if isinstance(repeats, tensorflow.Tensor) and repeats.shape == (): - repeats = tensorflow.reshape(repeats, (-1,)) - if len(x.shape) < len(repeats): - while len(x.shape) != len(repeats): - x = tensorflow.expand_dims(x, 0) - elif len(x.shape) > len(repeats): - repeats = list(repeats) - while len(x.shape) != len(repeats): - repeats = [1] + repeats - return tensorflow.tile(x, repeats) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_nonzero( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - as_tuple: bool = True, - size: Optional[int] = None, - fill_value: Number = 0, -): - res = tensorflow.experimental.numpy.nonzero(x) - if size is not None: - dtype = tensorflow.int64 - if isinstance(fill_value, float): - dtype = tensorflow.float64 - res = tensorflow.cast(res, dtype) - diff = size - res[0].shape[0] - if diff > 0: - res = tensorflow.pad(res, [[0, 0], [0, diff]], constant_values=fill_value) - elif diff < 0: - res = tensorflow.slice(res, [0, 0], [-1, size]) - if as_tuple: - return tuple(res) - return tensorflow.stack(res, axis=1) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_diff( - x: Union[tensorflow.Tensor, tensorflow.Variable, list, tuple], - /, - *, - n: int = 1, - axis: int = -1, - prepend: Optional[ - Union[tensorflow.Tensor, tensorflow.Variable, int, float, list, tuple] - ] = None, - append: Optional[ - Union[tensorflow.Tensor, tensorflow.Variable, int, float, list, tuple] - ] = None, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if n == 0: - return x - if prepend is not None: - x = tensorflow.experimental.numpy.append( - prepend, x, axis=axis if axis != -1 else None - ) - if append is not None: - x = tensorflow.experimental.numpy.append( - x, append, axis=axis if axis != -1 else None - ) - return tensorflow.experimental.numpy.diff(x, n=n, axis=axis) - - -def tensorflow__parse_ellipsis_bknd(so, ndims): - pre = list() - for s in so: - if s is Ellipsis: - break - pre.append(s) - post = list() - for s in reversed(so): - if s is Ellipsis: - break - post.append(s) - ret = list( - pre - + [slice(None, None, None) for _ in range(ndims - len(pre) - len(post))] - + list(reversed(post)) - ) - return ret, (len(pre), ndims - len(post)) - - -def tensorflow_broadcast_arrays(*arrays: Union[tensorflow.Tensor, tensorflow.Variable]): - if len(arrays) > 1: - try: - desired_shape = tensorflow.broadcast_dynamic_shape( - tensorflow.shape(arrays[0]), tensorflow.shape(arrays[1]) - ) - except tensorflow.errors.InvalidArgumentError as e: - raise Exception(e) from e - if len(arrays) > 2: - for i in range(2, len(arrays)): - try: - desired_shape = tensorflow.broadcast_dynamic_shape( - desired_shape, tensorflow.shape(arrays[i]) - ) - except 
tensorflow.errors.InvalidArgumentError as e: - raise Exception(e) from e - else: - return [arrays[0]] - result = [] - for tensor in arrays: - result.append(tensorflow.broadcast_to(tensor, desired_shape)) - return result - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_astype( - x: Union[tensorflow.Tensor, tensorflow.Variable], - dtype: Union[tf.DType, str], - /, - *, - copy: bool = True, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - dtype = tensorflow_as_native_dtype(dtype) - if x.dtype == dtype: - return tensorflow.experimental.numpy.copy(x) if copy else x - return tensorflow.cast(x, dtype) - - -def tensorflow_astype_bknd_( - self: tensorflow.Tensor, - dtype: str, - /, - *, - copy: bool = True, - out: Optional[tensorflow.Tensor] = None, -): - return tensorflow_astype(self, dtype, copy=copy, out=out) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_where( - condition: Union[tensorflow.Tensor, tensorflow.Variable], - x1: Union[tensorflow.Tensor, tensorflow.Variable], - x2: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - oirg_x1 = x1 - oirg_x2 = x2 - try: - dtype = ( - x1.dtype - if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() - ) - if not tensorflow_is_array_bknd(x1): - x1 = tensorflow_asarray(x1, dtype=dtype) - if not tensorflow_is_array_bknd(x2): - x2 = tensorflow_asarray(x2, dtype=dtype) - except: - x1 = oirg_x1 - x2 = oirg_x2 - return tensorflow.cast( - tensorflow.experimental.numpy.where(condition, x1, x2), x1.dtype - ) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_arange( - start: float, - /, - stop: Optional[float] = None, - step: float = 1, - *, - dtype: Optional[tensorflow.DType] = None, - device: Optional[str] = None, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if stop is None: - stop = start - start = 0 - if step > 0 and start > stop or step < 0 and start < stop: - if isinstance(stop, float): - stop = float(start) - else: - stop = start - if isinstance(start, (float, int)): - start = tensorflow.convert_to_tensor(start) - if isinstance(stop, (float, int)): - stop = tensorflow.convert_to_tensor(stop) - if isinstance(step, (float, int)): - step = tensorflow.convert_to_tensor(step) - if dtype is None: - if isinstance(start, int) and isinstance(stop, int) and isinstance(step, int): - return tensorflow.cast( - tensorflow.range(start, stop, delta=step, dtype=tensorflow.int64), - tensorflow.int32, - ) - else: - return tensorflow.range(start, stop, delta=step) - else: - dtype = tensorflow_as_native_dtype(tensorflow_default_dtype_bknd(dtype=dtype)) - if dtype in [ - tensorflow.int8, - tensorflow.uint8, - tensorflow.int16, - tensorflow.uint16, - tensorflow.uint32, - tensorflow.uint64, - ]: - return tensorflow.cast( - tensorflow.range(start, stop, delta=step, dtype=tensorflow.int64), dtype - ) - else: - return tensorflow.range(start, stop, delta=step, dtype=dtype) - - -def tensorflow__parse_slice_bknd(idx, s): - step = 1 if idx.step is None else idx.step - if step > 0: - start = 0 if idx.start is None else idx.start - if start >= s: - stop = start - else: - if start <= -s: - start = 0 - elif start < 0: - start = start + s - stop = s if idx.stop is None else idx.stop - if stop > s: - stop = s - elif start <= -s: - stop = 0 - elif stop < 0: - stop = stop + s - else: - start = s - 1 if idx.start is None else idx.start - if start < 
-s: - stop = start - else: - if start >= s: - start = s - 1 - elif start < 0: - start = start + s - if idx.stop is None: - stop = -1 - else: - stop = idx.stop - if stop > s: - stop = s - elif stop < -s: - stop = -1 - elif stop == -s: - stop = 0 - elif stop < 0: - stop = stop + s - q_i = tensorflow_arange(start, stop, step) - ag__result_list_0 = [] - for q in q_i: - if 0 <= q < s: - res = q - ag__result_list_0.append(res) - q_i = ag__result_list_0 - q_i = ( - tensorflow_asarray(q_i) - if len(q_i) or start == stop or idx.stop is not None - else tensorflow_arange(0, s, 1) - ) - return q_i - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_shape( - x: Union[tensorflow.Tensor, tensorflow.Variable], /, *, as_array: bool = False -): - if as_array: - return tensorflow_asarray( - tensorflow.shape(x), dtype=tensorflow_default_int_dtype_bknd() - ) - else: - return tuple(x.shape) - - -def tensorflow__deep_flatten_bknd(iterable): - def _flatten_gen(iterable): - for item in iterable: - if isinstance(item, list): - yield from _flatten_gen(item) - else: - yield item - - return list(_flatten_gen(iterable)) - - -def tensorflow__calculate_out_shape_bknd(axis, array_shape): - if type(axis) not in (tuple, list): - axis = (axis,) - out_dims = len(axis) + len(array_shape) - norm_axis = normalize_axis_tuple(axis, out_dims) - shape_iter = iter(array_shape) - ag__result_list_0 = [] - for current_ax in range(out_dims): - res = 1 if current_ax in norm_axis else next(shape_iter) - ag__result_list_0.append(res) - out_shape = ag__result_list_0 - return out_shape - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_expand_dims( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - copy: Optional[bool] = None, - axis: Union[int, Sequence[int]] = 0, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - try: - out_shape = tensorflow__calculate_out_shape_bknd(axis, tensorflow.shape(x)) - ret = tensorflow.reshape(x, shape=out_shape) - return ret - except (tensorflow.errors.InvalidArgumentError, np.AxisError) as error: - raise Exception(error) from error - - -def tensorflow_check_elem_in_list(elem, list, inverse=False, message=""): - if inverse and elem in list: - raise Exception( - message if message != "" else f"{elem} must not be one of {list}" - ) - elif not inverse and elem not in list: - raise Exception(message if message != "" else f"{elem} must be one of {list}") - - -def tensorflow__reshape_fortran_tf(x, shape): - if len(x.shape) > 0: - x = tensorflow.transpose(x) - return tensorflow.transpose(tensorflow.reshape(x, shape[::-1])) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_reshape( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - shape: Union[tf.TensorShape, Sequence[int]], - *, - copy: Optional[bool] = None, - order: str = "C", - allowzero: bool = True, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - tensorflow_check_elem_in_list(order, ["C", "F"]) - if not allowzero: - shape = [ - (new_s if con else old_s) - for new_s, con, old_s in zip( - shape, tensorflow.constant(shape) != 0, x.shape - ) - ] - if order == "F": - return tensorflow__reshape_fortran_tf(x, shape) - return tensorflow.reshape(x, shape) - - -def tensorflow_reshape_bknd_( - self: tensorflow.Tensor, - /, - shape: Union[tuple, tf.TensorShape, Sequence[int]], - *, - copy: Optional[bool] = None, - order: str = "C", - allowzero: bool = True, - out: Optional[tensorflow.Tensor] = None, -): - return tensorflow_reshape( - self, shape, 
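# ---------------------------------------------------------------------------
# Illustrative sketch (not from the deleted helpers file).
# `tensorflow__reshape_fortran_tf` above implements order="F" (column-major)
# reshape, which TensorFlow lacks natively, by sandwiching a row-major
# reshape between two transposes: reshape_F(x, s) ==
# transpose(reshape_C(transpose(x), reversed(s))). Standalone
# (`reshape_fortran` is a hypothetical name):
import tensorflow as tf

def reshape_fortran(x, shape):
    # transpose -> C-order reshape into the reversed shape -> transpose back
    if len(x.shape) > 0:
        x = tf.transpose(x)
    return tf.transpose(tf.reshape(x, shape[::-1]))

# reshape_fortran(tf.constant([[1, 2, 3], [4, 5, 6]]), (3, 2))
# -> [[1, 5], [4, 3], [2, 6]]   (matches numpy reshape with order="F")
# ---------------------------------------------------------------------------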
copy=copy, allowzero=allowzero, out=out, order=order - ) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_meshgrid( - *arrays: Union[tensorflow.Tensor, tensorflow.Variable], - sparse: bool = False, - indexing: str = "xy", - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if not sparse: - return tensorflow.meshgrid(*arrays, indexing=indexing) - sd = (1,) * len(arrays) - ag__result_list_0 = [] - for i, a in enumerate(arrays): - res = tensorflow.reshape( - tensorflow.convert_to_tensor(a), sd[:i] + (-1,) + sd[i + 1 :] - ) - ag__result_list_0.append(res) - res = ag__result_list_0 - if indexing == "xy" and len(arrays) > 1: - res[0] = tensorflow.reshape(res[0], (1, -1) + sd[2:]) - res[1] = tensorflow.reshape(res[1], (-1, 1) + sd[2:]) - return res - - -@tensorflow_infer_dtype -@tensorflow_handle_array_like_without_promotion -def tensorflow_empty( - shape: Union[tf.TensorShape, Sequence[int]], - *, - dtype: tensorflow.DType, - device: Optional[str] = None, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - return tensorflow.experimental.numpy.empty(shape, dtype=tensorflow.float32) - - -def tensorflow__parse_query_bknd(query, x_shape, scatter=False): - query = (query,) if not isinstance(query, tuple) else query - ag__result_list_0 = [] - for q in query: - res = tensorflow_asarray(q) if isinstance(q, (tuple, list, int)) else q - ag__result_list_0.append(res) - query = ag__result_list_0 - ag__result_list_1 = [] - for i, q in enumerate(query): - if tensorflow_is_array_bknd(q): - res = i - ag__result_list_1.append(res) - non_slice_q_idxs = ag__result_list_1 - to_front = ( - len(non_slice_q_idxs) > 1 - and any(tensorflow_diff(non_slice_q_idxs) != 1) - and non_slice_q_idxs[-1] < len(x_shape) - ) - ag__result_list_2 = [] - for i, q in enumerate(query): - if q is None: - res = i - ag__result_list_2.append(res) - new_axes = ag__result_list_2 - ag__result_list_3 = [] - for q in query: - if q is not None: - res = q - ag__result_list_3.append(res) - query = ag__result_list_3 - query = [Ellipsis] if query == [] else query - ellipsis_inds = None - if any(q is Ellipsis for q in query): - query, ellipsis_inds = tensorflow__parse_ellipsis_bknd(query, len(x_shape)) - ag__result_list_4 = [] - for i, v in enumerate(query): - if tensorflow_is_array_bknd(v): - res = i - ag__result_list_4.append(res) - array_inds = ag__result_list_4 - if array_inds: - array_queries = tensorflow_broadcast_arrays( - *[v for i, v in enumerate(query) if i in array_inds] - ) - array_queries = [ - ( - tensorflow_nonzero(q, as_tuple=False)[0] - if tensorflow_is_bool_dtype_bknd(q) - else q - ) - for q in array_queries - ] - array_queries = [ - ( - tensorflow_astype_bknd_( - tensorflow_where( - arr < 0, arr + tensorflow_get_item(x_shape, i), arr - ), - tf.int64, - ) - if tensorflow_size_bknd_(arr) - else tensorflow_astype_bknd_(arr, tf.int64) - ) - for arr, i in zip(array_queries, array_inds) - ] - for idx, arr in zip(array_inds, array_queries): - query = tensorflow_set_item_bknd(query, idx, arr) - ag__result_list_5 = [] - for i, q in enumerate(query): - res = ( - tensorflow_astype_bknd_( - tensorflow__parse_slice_bknd(q, tensorflow_get_item(x_shape, i)), - tf.int64, - ) - if isinstance(q, slice) - else q - ) - ag__result_list_5.append(res) - query = ag__result_list_5 - if len(query) < len(x_shape): - query = query + [ - tensorflow_astype_bknd_(tensorflow_arange(0, s, 1), tf.int64) - for s in tensorflow_get_item(x_shape, slice(len(query), None, None)) - ] - if len(array_inds) 
and to_front: - target_shape = ( - [list(array_queries[0].shape)] - + [ - list(tensorflow_get_item(query, i).shape) - for i in range(len(query)) - if i not in array_inds - ] - + [[] for _ in range(len(array_inds) - 1)] - ) - elif len(array_inds): - target_shape = ( - [list(tensorflow_get_item(query, i).shape) for i in range(0, array_inds[0])] - + [list(tensorflow_shape(array_queries[0], as_array=True))] - + [[] for _ in range(len(array_inds) - 1)] - + [ - list(tensorflow_shape(tensorflow_get_item(query, i), as_array=True)) - for i in range(array_inds[-1] + 1, len(query)) - ] - ) - else: - target_shape = [list(q.shape) for q in query] - if ellipsis_inds is not None: - target_shape = ( - tensorflow_get_item(target_shape, slice(None, ellipsis_inds[0], None)) - + [ - tensorflow_get_item( - target_shape, slice(ellipsis_inds[0], ellipsis_inds[1], None) - ) - ] - + tensorflow_get_item(target_shape, slice(ellipsis_inds[1], None, None)) - ) - for i, ax in enumerate(new_axes): - if len(array_inds) and to_front: - ax = ax - (sum(1 for x in array_inds if x < ax) - 1) - ax = ax + i - target_shape = [ - *tensorflow_get_item(target_shape, slice(None, ax, None)), - 1, - *tensorflow_get_item(target_shape, slice(ax, None, None)), - ] - target_shape = tensorflow__deep_flatten_bknd(target_shape) - ag__result_list_6 = [] - for q in query: - res = tensorflow_expand_dims(q) if not len(q.shape) else q - ag__result_list_6.append(res) - query = ag__result_list_6 - if len(array_inds): - array_queries = [ - ( - tensorflow_reshape_bknd_(arr, (-1,)) - if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr - ) - for arr in array_queries - ] - array_queries = tensorflow_stack(array_queries, axis=1) - if len(array_inds) == len(query): - indices = tensorflow_reshape_bknd_(array_queries, (*target_shape, len(x_shape))) - elif len(array_inds) == 0: - indices = tensorflow_reshape_bknd_( - tensorflow_stack(tensorflow_meshgrid(*query, indexing="ij"), axis=-1), - (*target_shape, len(x_shape)), - ) - elif to_front: - post_array_queries = ( - tensorflow_reshape_bknd_( - tensorflow_stack( - tensorflow_meshgrid( - *[v for i, v in enumerate(query) if i not in array_inds], - indexing="ij", - ), - axis=-1, - ), - (-1, len(query) - len(array_inds)), - ) - if len(array_inds) < len(query) - else tensorflow_empty((1, 0)) - ) - indices = tensorflow_reshape_bknd_( - tensorflow_asarray( - [ - (*arr, *post) - for arr, post in itertools.product( - array_queries, post_array_queries - ) - ] - ), - (*target_shape, len(x_shape)), - ) - else: - pre_array_queries = ( - tensorflow_reshape_bknd_( - tensorflow_stack( - tensorflow_meshgrid( - *[v for i, v in enumerate(query) if i < array_inds[0]], - indexing="ij", - ), - axis=-1, - ), - (-1, array_inds[0]), - ) - if array_inds[0] > 0 - else tensorflow_empty((1, 0)) - ) - post_array_queries = ( - tensorflow_reshape_bknd_( - tensorflow_stack( - tensorflow_meshgrid( - *[v for i, v in enumerate(query) if i > array_inds[-1]], - indexing="ij", - ), - axis=-1, - ), - (-1, len(query) - 1 - array_inds[-1]), - ) - if array_inds[-1] < len(query) - 1 - else tensorflow_empty((1, 0)) - ) - indices = tensorflow_reshape_bknd_( - tensorflow_asarray( - [ - (*pre, *arr, *post) - for pre, arr, post in itertools.product( - pre_array_queries, array_queries, post_array_queries - ) - ] - ), - (*target_shape, len(x_shape)), - ) - return ( - tensorflow_astype_bknd_(indices, tf.int64), - target_shape, - array_inds if len(array_inds) and to_front else None, - ) - - -def tensorflow_get_num_dims(x, /, 
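# ---------------------------------------------------------------------------
# Illustrative sketch (not from the deleted helpers file).
# `tensorflow__parse_query_bknd` above reduces every supported query (slices,
# integer/boolean arrays, None, Ellipsis) to a single int64 index tensor of
# shape (*target_shape, len(x_shape)) that gather_nd/scatter_nd can consume.
# Its core meshgrid+stack step on its own (`index_grid` is a hypothetical
# helper name):
import tensorflow as tf

def index_grid(per_axis_indices):
    # one 1-D tensor of chosen positions per axis -> full coordinate grid
    grids = tf.meshgrid(*per_axis_indices, indexing="ij")
    return tf.stack(grids, axis=-1)

# Rows 0 and 2, columns 1 and 3 of a matrix m:
# idx = index_grid([tf.constant([0, 2]), tf.constant([1, 3])])
# tf.gather_nd(m, idx)   # the 2x2 sub-block m[[0, 2]][:, [1, 3]]
# ---------------------------------------------------------------------------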
*, as_array=False): - return ( - tensorflow.cast(tensorflow.shape(tensorflow.shape(x))[0], tensorflow.int64) - if as_array - else int(tensorflow.shape(tensorflow.shape(x))) - ) - - -def tensorflow_to_numpy( - x: Union[tensorflow.Tensor, tensorflow.Variable], /, *, copy: bool = True -): - if ( - tensorflow_is_array_bknd(x) - and tensorflow_get_num_dims(x) == 0 - and tensorflow_as_native_dtype(x.dtype) is tensorflow.bfloat16 - ): - x = tensorflow.expand_dims(x, 0) - if copy: - return np.squeeze(np.array(tensorflow.convert_to_tensor(x)), 0) - else: - return np.squeeze(np.asarray(tensorflow.convert_to_tensor(x)), 0) - if copy: - return np.array(tensorflow.convert_to_tensor(x)) - else: - return np.asarray(tensorflow.convert_to_tensor(x)) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_to_scalar(x: Union[tensorflow.Tensor, tensorflow.Variable], /): - ret = tensorflow_to_numpy(x).item() - if x.dtype == tensorflow.bfloat16: - return float(ret) - return ret - - -def tensorflow_to_scalar_bknd_(self: tensorflow.Tensor): - return tensorflow_to_scalar(self) - - -def tensorflow_is_float_dtype_bknd( - dtype_in: Union[str, str, tensorflow.Tensor, tf.Tensor, Number], / -): - if tensorflow_is_array_bknd(dtype_in): - dtype_in = tensorflow_dtype(dtype_in) - elif isinstance(dtype_in, tuple): - dtype_in = tensorflow_default_int_dtype_bknd() - elif isinstance(dtype_in, np.ndarray): - return "float" in dtype_in.dtype.name - elif isinstance(dtype_in, Number): - return isinstance(dtype_in, (float, np.floating)) - elif isinstance(dtype_in, (list, tuple, dict)): - return bool( - tensorflow_nested_argwhere_bknd( - dtype_in, - lambda x: isinstance(x, (float, np.floating)) - or tensorflow_is_array_bknd(x) - and "float" in tensorflow_dtype(x), - ) - ) - return "float" in tensorflow_as_ivy_dtype_bknd(dtype_in) - - -def tensorflow_is_uint_dtype_bknd( - dtype_in: Union[str, str, tensorflow.Tensor, tf.Tensor, Number], / -): - if tensorflow_is_array_bknd(dtype_in): - dtype_in = tensorflow_dtype(dtype_in) - elif isinstance(dtype_in, tuple): - dtype_in = tensorflow_default_int_dtype_bknd() - elif isinstance(dtype_in, np.ndarray): - return "uint" in dtype_in.dtype.name - elif isinstance(dtype_in, Number): - return isinstance(dtype_in, np.unsignedinteger) - elif isinstance(dtype_in, (list, tuple, dict)): - return tensorflow_nested_argwhere_bknd( - dtype_in, - lambda x: isinstance(x, np.unsignedinteger) - or tensorflow_is_array_bknd(x) - and "uint" in tensorflow_dtype(x), - ) - return "uint" in tensorflow_as_ivy_dtype_bknd(dtype_in) - - -def tensorflow_default_uint_dtype_bknd( - *, - input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, - uint_dtype: Optional[Union[str, tf.DType]] = None, - as_native: bool = False, -): - global default_uint_dtype_stack - if tensorflow_exists_bknd(uint_dtype): - if as_native is True: - return tensorflow_as_native_dtype(uint_dtype) - return str(tensorflow_as_ivy_dtype(uint_dtype)) - as_native = tensorflow_default_bknd(as_native, False) - if tensorflow_exists_bknd(input): - if tensorflow_is_array_bknd(input): - ret = tensorflow_dtype(input) - elif isinstance(input, np.ndarray): - ret = input.dtype - elif isinstance(input, (list, tuple, dict)): - - def is_native(x): - return tensorflow_is_native_array(x) - - if tensorflow_nested_argwhere_bknd( - input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), - stop_after_n_found=1, - ): - ret = tf.uint64 - elif default_uint_dtype_stack: - ret = 
default_uint_dtype_stack[-1] - else: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_uint_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "uint32" - elif isinstance(input, Number): - if input > 4294967295 and input != math.inf and backend != "torch": - ret = tf.uint64 - elif default_uint_dtype_stack: - ret = default_uint_dtype_stack[-1] - else: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_uint_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "uint32" - elif default_uint_dtype_stack: - ret = default_uint_dtype_stack[-1] - else: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_uint_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "uint32" - if as_native: - return tensorflow_as_native_dtype(ret) - return str(tensorflow_as_ivy_dtype(ret)) - - -def tensorflow_is_int_dtype_bknd( - dtype_in: Union[str, str, tensorflow.Tensor, tf.Tensor, Number], / -): - if tensorflow_is_array_bknd(dtype_in): - dtype_in = tensorflow_dtype(dtype_in) - elif isinstance(dtype_in, tuple): - dtype_in = tensorflow_default_int_dtype_bknd() - elif isinstance(dtype_in, np.ndarray): - return "int" in dtype_in.dtype.name - elif isinstance(dtype_in, Number): - return isinstance(dtype_in, (int, np.integer)) and not isinstance( - dtype_in, bool - ) - elif isinstance(dtype_in, (list, tuple, dict)): - - def nested_fun(x): - return ( - isinstance(x, (int, np.integer)) - or tensorflow_is_array_bknd(x) - and "int" in tensorflow_dtype(x) - ) and x is not bool - - return bool(tensorflow_nested_argwhere_bknd(dtype_in, nested_fun)) - return "int" in tensorflow_as_ivy_dtype(dtype_in) - - -def tensorflow_infer_default_dtype_bknd( - dtype: Union[str, tf.DType, str], as_native: bool = False -): - if tensorflow_is_complex_dtype_bknd(dtype): - default_dtype = tensorflow_default_complex_dtype_bknd(as_native=as_native) - elif tensorflow_is_float_dtype_bknd(dtype): - default_dtype = tensorflow_default_float_dtype_bknd(as_native=as_native) - elif tensorflow_is_uint_dtype_bknd(dtype): - default_dtype = tensorflow_default_uint_dtype_bknd(as_native=as_native) - elif tensorflow_is_int_dtype_bknd(dtype): - default_dtype = tensorflow_default_int_dtype_bknd(as_native=as_native) - elif as_native: - default_dtype = tensorflow_as_native_dtype("bool") - else: - default_dtype = tensorflow_as_ivy_dtype("bool") - return default_dtype - - -def tensorflow_dtype_bits(dtype_in: Union[tensorflow.DType, str, np.dtype], /): - dtype_str = tensorflow_as_ivy_dtype(dtype_in) - if "bool" in dtype_str: - return 1 - return int( - dtype_str.replace("tf.", "") - .replace("uint", "") - .replace("int", "") - .replace("bfloat", "") - .replace("float", "") - .replace("complex", "") - ) - - -def tensorflow__infer_dtype(dtype: tensorflow.DType): - default_dtype = tensorflow_infer_default_dtype_bknd(dtype) - if tensorflow_dtype_bits(dtype) < tensorflow_dtype_bits(default_dtype): - return default_dtype - return dtype - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_prod( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - axis: Optional[Union[int, Sequence[int]]] = None, - dtype: Optional[tensorflow.DType] = None, - keepdims: bool = False, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - dtype = tensorflow_as_native_dtype(dtype) - if dtype is None: - dtype = tensorflow__infer_dtype(x.dtype) - axis = tuple(axis) if isinstance(axis, list) else axis - return tensorflow.experimental.numpy.prod( - x, axis=axis, dtype=dtype, keepdims=keepdims - ) - - -def 
tensorflow__numel_bknd(shape): - shape = tuple(shape) - return tensorflow_to_scalar_bknd_(tensorflow_prod(shape)) if shape != () else 1 - - -def tensorflow_check_one_way_broadcastable(x1, x2): - if len(x1) > len(x2): - return False - for a, b in zip(x1[::-1], x2[::-1]): - if a in (1, b): - pass - else: - return False - return True - - -def tensorflow_check_shapes_broadcastable(var, data): - if not tensorflow_check_one_way_broadcastable(var, data): - raise Exception(f"Could not broadcast shape {data} to shape {var}.") - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_broadcast_to( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - shape: Union[tf.TensorShape, Sequence[int]], - *, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - tensorflow_check_shapes_broadcastable(x.shape, shape) - if tensorflow.rank(x) > len(shape): - return tensorflow.broadcast_to(tensorflow.reshape(x, -1), shape) - return tensorflow.broadcast_to(x, shape) - - -def tensorflow__broadcast_to_bknd(input, target_shape): - if tensorflow__numel_bknd(tuple(input.shape)) == tensorflow__numel_bknd( - tuple(target_shape) - ): - return tensorflow_reshape(input, target_shape) - else: - input = input if len(input.shape) else tensorflow_expand_dims(input, axis=0) - return tensorflow_broadcast_to(input, target_shape) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_any( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - axis: Optional[Union[int, Sequence[int]]] = None, - keepdims: bool = False, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if axis is None: - num_dims = len(x.shape) - axis = tuple(range(num_dims)) - elif isinstance(axis, list): - axis = tuple(axis) - try: - return tensorflow.reduce_any( - tensorflow.cast(x, tensorflow.bool), axis=axis, keepdims=keepdims - ) - except tensorflow.errors.InvalidArgumentError as e: - raise Exception(e) from e - - -def tensorflow__broadcast_inputs(x1, x2): - x1_, x2_ = x1, x2 - iterables = list, tuple, tuple - if not isinstance(x1_, iterables): - x1_, x2_ = x2, x1 - if not isinstance(x1_, iterables): - return [x1], [x2] - if not isinstance(x2_, iterables): - x1 = [x1] * len(x2) - return x1, x2 - - -def tensorflow_check_equal(x1, x2, inverse=False, message="", as_array=True): - def eq_fn(x1, x2): - return x1 == x2 if inverse else x1 != x2 - - def comp_fn(x1, x2): - return tensorflow_any(eq_fn(x1, x2)) - - if not as_array: - - def iter_comp_fn(x1_, x2_): - return any(eq_fn(x1, x2) for x1, x2 in zip(x1_, x2_)) - - def comp_fn(x1, x2): - return iter_comp_fn(*tensorflow__broadcast_inputs(x1, x2)) - - eq = comp_fn(x1, x2) - if inverse and eq: - raise Exception(f"{x1} must not be equal to {x2}" if message == "" else message) - elif not inverse and eq: - raise Exception(f"{x1} must be equal to {x2}" if message == "" else message) - - -def tensorflow_multiply( - x1: Union[float, tensorflow.Tensor, tensorflow.Variable], - x2: Union[float, tensorflow.Tensor, tensorflow.Variable], - /, - *, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - oirg_x1 = x1 - oirg_x2 = x2 - try: - dtype = ( - x1.dtype - if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() - ) - if not tensorflow_is_array_bknd(x1): - x1 = tensorflow_asarray(x1, dtype=dtype) - if not tensorflow_is_array_bknd(x2): - x2 = tensorflow_asarray(x2, dtype=dtype) - except: - x1 = oirg_x1 - x2 = oirg_x2 - return tensorflow.math.multiply(x1, x2) - - -def 
tensorflow_check_gather_nd_input_valid(params, indices, batch_dims): - if batch_dims >= len(params.shape): - raise Exception( - f"batch_dims = {batch_dims} must be less than rank(`params`) = {len(params.shape)}." - ) - if batch_dims >= len(indices.shape): - raise Exception( - f"batch_dims = {batch_dims} must be less than rank(`indices`) = {len(indices.shape)}." - ) - if tensorflow_get_item( - params.shape, slice(0, batch_dims, None) - ) != tensorflow_get_item(indices.shape, slice(0, batch_dims, None)): - raise Exception( - f"batch dimensions must match in `params` and `indices`; saw {tensorflow_get_item(params.shape, slice(0, batch_dims, None))} vs. {tensorflow_get_item(indices.shape, slice(0, batch_dims, None))}" - ) - if indices.shape[-1] > len( - tensorflow_get_item(params.shape, slice(batch_dims, None, None)) - ): - raise Exception( - f"index innermost dimension length must be <= rank(`params[batch_dims:]`); saw: {indices.shape[-1]} vs. {len(tensorflow_get_item(params.shape, slice(batch_dims, None, None)))} ." - ) - - -def tensorflow_gather_nd_helper(params, indices): - indices_shape = tensorflow.shape(indices) - params_shape = tensorflow.shape(params) - num_index_dims = indices_shape[-1] - result_dim_sizes_list = [ - tensorflow.math.reduce_prod(params_shape[i + 1 :]) - for i in range(len(params_shape) - 1) - ] + [1] - result_dim_sizes = tensorflow.convert_to_tensor( - result_dim_sizes_list, dtype=indices.dtype - ) - implicit_indices_factor = result_dim_sizes[num_index_dims - 1] - flat_params = tensorflow.reshape(params, (-1,)) - new_shape = [1] * (len(indices_shape) - 1) + [num_index_dims] - indices_scales = tensorflow.reshape(result_dim_sizes[0:num_index_dims], new_shape) - indices_for_flat_tiled = tensorflow.reshape( - tensorflow.reduce_sum(indices * indices_scales, -1, keepdims=True), (-1, 1) - ) - indices_for_flat_tiled = tensorflow.repeat( - indices_for_flat_tiled, implicit_indices_factor, axis=1 - ) - implicit_indices = tensorflow.repeat( - tensorflow.expand_dims(tensorflow.range(implicit_indices_factor), 0), - indices_for_flat_tiled.shape[0], - axis=0, - ) - indices_for_flat = indices_for_flat_tiled + implicit_indices - flat_indices_for_flat = tensorflow.reshape(indices_for_flat, (-1,)) - flat_gather = tensorflow.gather(flat_params, flat_indices_for_flat) - res = tensorflow.reshape( - flat_gather, - tensorflow.concat([indices_shape[:-1], params_shape[num_index_dims:]], 0), - ) - return res - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_gather_nd( - params: Union[tensorflow.Tensor, tensorflow.Variable], - indices: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - batch_dims: int = 0, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - tensorflow_check_gather_nd_input_valid(params, indices, batch_dims) - try: - return tensorflow.gather_nd(params, indices, batch_dims=batch_dims) - except Exception: - batch_dims %= len(params.shape) - result = [] - if batch_dims == 0: - result = tensorflow_gather_nd_helper(params, indices) - else: - for b in range(batch_dims): - if b == 0: - zip_list = list(zip(params, indices)) - else: - zip_list = [ - (p, i) - for z in [zip(p1, i1) for p1, i1 in zip_list] - for p, i in z - ] - for z in zip_list: - p, i = z[0], z[1] - r = tensorflow_gather_nd_helper(p, i) - result.append(r) - result = tensorflow.stack(result) - result = tensorflow.reshape( - result, - tensorflow.concat([params.shape[0:batch_dims], result.shape[1:]], 0), - ) - return result - - -def tensorflow__is_variable_bknd(x, 
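# ---------------------------------------------------------------------------
# Illustrative sketch (not from the deleted helpers file).
# `tensorflow_gather_nd_helper` above emulates tf.gather_nd by hand: it
# scales each index component by the row-major stride of its dimension, sums
# them into flat offsets, and gathers from the flattened params. The stride
# arithmetic in isolation (`row_major_strides` is a hypothetical helper name):
def row_major_strides(shape):
    # e.g. (2, 3, 4) -> [12, 4, 1]
    strides = [1]
    for dim in reversed(shape[1:]):
        strides.insert(0, strides[0] * dim)
    return strides

# A multi-index (i, j, k) into shape (2, 3, 4) maps to flat position
# i*12 + j*4 + k*1, which is what the helper computes in bulk via
# `indices * indices_scales` before the flat gather.
# ---------------------------------------------------------------------------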
exclusive=False, to_ignore=None): - x = x - return tensorflow_nested_map_bknd( - lambda x: tensorflow_is_variable(x, exclusive=exclusive), - x, - include_derived=True, - shallow=False, - to_ignore=to_ignore, - ) - - -def tensorflow_inplace_update( - x: Union[tensorflow.Tensor, tensorflow.Tensor], - val: Union[tensorflow.Tensor, tensorflow.Tensor], - /, - *, - ensure_in_backend: bool = False, - keep_input_dtype: bool = False, -): - if tensorflow_is_array_bknd(x) and tensorflow_is_array_bknd(val): - if keep_input_dtype: - val = tensorflow_astype(val, x.dtype) - (x_native, val_native), _ = (x, val), "_" - if tensorflow__is_variable_bknd(x_native): - x_native.assign(val_native) - if tensorflow_is_ivy_array_bknd(x): - x = x_native - else: - x = tensorflow.convert_to_tensor(x_native) - else: - x = x_native - return x - else: - return val - - -def tensorflow_scatter_nd( - indices: Union[tensorflow.Tensor, tensorflow.Variable], - updates: Union[tensorflow.Tensor, tensorflow.Variable], - /, - shape: Optional[Union[tf.TensorShape, Sequence[int]]] = None, - *, - reduction: str = "sum", - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - updates_dtype = updates.dtype - if tensorflow_exists_bknd(out): - dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) - updates = tensorflow.cast( - updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), - ) - expected_shape = ( - list(tensorflow.shape(indices)[:-1]) - + list(out.shape[tensorflow.shape(indices)[-1] :]) - if tensorflow_exists_bknd(out) - else list(tensorflow.shape(indices)[:-1]) - + list(shape[tensorflow.shape(indices)[-1] :]) - ) - updates = tensorflow__broadcast_to_bknd(updates, expected_shape) - if len(updates.shape) == 0: - indices = tensorflow.expand_dims(indices, 0) - updates = tensorflow.expand_dims(updates, 0) - target = out - target_given = tensorflow_exists_bknd(target) - if tensorflow_exists_bknd(shape) and target_given: - tensorflow_check_equal(tuple(target.shape), tuple(shape), as_array=False) - if not target_given: - shape = list(shape) if tensorflow_exists_bknd(shape) else list(out.shape) - target = tensorflow.zeros(shape, dtype=updates.dtype) - if reduction == "sum": - res = tensorflow.tensor_scatter_nd_add(target, indices, updates) - elif reduction == "min": - res = tensorflow.tensor_scatter_nd_min(target, indices, updates) - elif reduction == "max": - res = tensorflow.tensor_scatter_nd_max(target, indices, updates) - elif reduction == "mul": - updates = tensorflow_multiply(tensorflow_gather_nd(target, indices), updates) - res = tensorflow.tensor_scatter_nd_update(target, indices, updates) - elif reduction == "replace": - res = tensorflow.tensor_scatter_nd_update(target, indices, updates) - else: - raise Exception( - f'reduction is {reduction}, but it must be one of "sum", "min", "max", "mul" or "replace"' - ) - if tensorflow_exists_bknd(out): - return tensorflow_inplace_update(out, res) - return res - - -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - -@tensorflow_handle_set_item -def tensorflow_set_item_bknd( - x: Union[tensorflow.Tensor, tf.Tensor], - query: Union[tensorflow.Tensor, tf.Tensor, Tuple], - val: Union[tensorflow.Tensor, tf.Tensor], - /, - *, - copy: Optional[bool] = False, -): - if isinstance(query, (list, 
tuple)) and any( - [(q is Ellipsis or isinstance(q, slice) and q.stop is None) for q in query] - ): - x_stop_gradient = tensorflow_stop_gradient(x, preserve_type=False) - np_array = x_stop_gradient.numpy() - val_stop_gradient = tensorflow_stop_gradient(val, preserve_type=False) - np_array = tensorflow_set_item_bknd( - np_array, query, np.asarray(val_stop_gradient) - ) - return tensorflow_asarray(np_array) - if copy: - x = tensorflow_copy_array(x) - if not tensorflow_is_array_bknd(val): - val = tensorflow_asarray(val) - if 0 in x.shape or 0 in val.shape: - return x - if tensorflow_is_array_bknd(query) and tensorflow_is_bool_dtype_bknd(query): - if not len(query.shape): - query = tensorflow_tile(query, (x.shape[0],)) - indices = tensorflow_nonzero(query, as_tuple=False) - else: - indices, target_shape, _ = tensorflow__parse_query_bknd( - query, tensorflow_shape(x, as_array=True), scatter=True - ) - if indices is None: - return x - val = tensorflow_astype_bknd_(val, x.dtype) - ret = tensorflow_scatter_nd(indices, val, reduction="replace", out=x) - return ret - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_real( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - return tensorflow.math.real(x) - - -def tensorflow_real_bknd_(self): - return tensorflow_real(self) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_imag( - val: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - return tensorflow.math.imag(val, name=None) - - -def tensorflow_imag_bknd_(self): - return tensorflow_imag(self) - - -def tensorflow__check_complex128_bknd(input): - if tensorflow_is_array_bknd(input): - return tensorflow_dtype(input) == "complex128" - elif isinstance(input, np.ndarray): - return str(input.dtype) == "complex128" - if hasattr(input, "real") and hasattr(input, "imag"): - return tensorflow__check_float64_bknd( - tensorflow_real_bknd_(input) - ) and tensorflow__check_float64_bknd(tensorflow_imag_bknd_(input)) - return False - - -def tensorflow_default_complex_dtype_bknd( - *, - input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, - complex_dtype: Optional[Union[str, tf.DType]] = None, - as_native: bool = False, -): - global default_complex_dtype_stack - if tensorflow_exists_bknd(complex_dtype): - if as_native is True: - return tensorflow_as_native_dtype(complex_dtype) - return str(tensorflow_as_ivy_dtype(complex_dtype)) - as_native = tensorflow_default_bknd(as_native, False) - if tensorflow_exists_bknd(input): - if tensorflow_is_array_bknd(input): - ret = tensorflow_dtype(input) - elif isinstance(input, np.ndarray): - ret = str(input.dtype) - elif isinstance(input, (list, tuple, dict)): - if tensorflow_nested_argwhere_bknd( - input, - lambda x: tensorflow__check_complex128_bknd(x), - stop_after_n_found=1, - ): - ret = tf.complex128 - elif not default_complex_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_complex_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "complex64" - else: - ret = default_complex_dtype_stack[-1] - elif isinstance(input, Number): - if tensorflow__check_complex128_bknd(input): - ret = tf.complex128 - elif not default_complex_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_complex_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "complex64" - else: - ret = default_complex_dtype_stack[-1] - elif not 
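# The boolean-mask branch of tensorflow_set_item_bknd above reduces an
# assignment to a scatter: mask -> nonzero index list -> replace. A sketch of
# the same reduction with plain TensorFlow ops (illustrative values):
import tensorflow as tf

x = tf.constant([5, 6, 7, 8])
mask = tf.constant([True, False, True, False])
idx = tf.where(mask)                            # [[0], [2]]
val = tf.fill([tf.shape(idx)[0]], 0)            # one value per selected index
out = tf.tensor_scatter_nd_update(x, idx, val)  # [0, 6, 0, 8]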
default_complex_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_complex_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "complex64" - else: - ret = default_complex_dtype_stack[-1] - if as_native: - return tensorflow_as_native_dtype(ret) - return str(tensorflow_as_ivy_dtype(ret)) - - -def tensorflow_default_dtype_bknd( - *, - dtype: Optional[Union[str, str]] = None, - item: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, - as_native: bool = False, -): - if tensorflow_exists_bknd(dtype): - if as_native is True: - return tensorflow_as_native_dtype(dtype) - return tensorflow_as_ivy_dtype(dtype) - as_native = tensorflow_default_bknd(as_native, False) - if tensorflow_exists_bknd(item): - if hasattr(item, "override_dtype_check"): - return item.override_dtype_check() - elif isinstance(item, (list, tuple, dict)) and len(item) == 0: - pass - elif tensorflow_is_complex_dtype_bknd(item): - return tensorflow_default_complex_dtype_bknd( - input=item, as_native=as_native - ) - elif tensorflow_is_float_dtype_bknd(item): - return tensorflow_default_float_dtype_bknd(input=item, as_native=as_native) - elif tensorflow_is_uint_dtype_bknd(item): - return tensorflow_default_int_dtype_bknd(input=item, as_native=as_native) - elif tensorflow_is_int_dtype_bknd(item): - return tensorflow_default_int_dtype_bknd(input=item, as_native=as_native) - elif as_native: - return tensorflow_as_native_dtype("bool") - else: - return "bool" - global default_dtype_stack - if not default_dtype_stack: - global default_float_dtype_stack - if default_float_dtype_stack: - ret = default_float_dtype_stack[-1] - else: - ret = "float32" - else: - ret = default_dtype_stack[-1] - if as_native: - return tensorflow_as_native_dtype(ret) - return tensorflow_as_ivy_dtype(ret) - - -def tensorflow_default_float_dtype_bknd( - *, - input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, - float_dtype: Optional[Union[str, tf.DType]] = None, - as_native: bool = False, -): - global default_float_dtype_stack - if tensorflow_exists_bknd(float_dtype): - if as_native is True: - return tensorflow_as_native_dtype(float_dtype) - return str(tensorflow_as_ivy_dtype(float_dtype)) - as_native = tensorflow_default_bknd(as_native, False) - if tensorflow_exists_bknd(input): - if tensorflow_is_array_bknd(input): - ret = tensorflow_dtype(input) - elif isinstance(input, np.ndarray): - ret = str(input.dtype) - elif isinstance(input, (list, tuple, dict)): - if tensorflow_nested_argwhere_bknd( - input, lambda x: tensorflow__check_float64_bknd(x), stop_after_n_found=1 - ): - ret = tf.float64 - elif not default_float_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_float_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "float32" - else: - ret = default_float_dtype_stack[-1] - elif isinstance(input, Number): - if tensorflow__check_float64_bknd(input): - ret = tf.float64 - elif not default_float_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_float_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "float32" - else: - ret = default_float_dtype_stack[-1] - elif not default_float_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_float_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "float32" - else: - ret = default_float_dtype_stack[-1] - if as_native: - return tensorflow_as_native_dtype(ret) - return str(tensorflow_as_ivy_dtype(ret)) - - -def tensorflow_as_ivy_dtype( - dtype_in: Union[tensorflow.DType, str, int, float, 
complex, bool, np.dtype], / -): - if dtype_in is int: - return tensorflow_default_int_dtype_bknd() - if dtype_in is float: - return tensorflow_default_float_dtype_bknd() - if dtype_in is complex: - return tensorflow_default_complex_dtype_bknd() - if dtype_in is bool: - return str("bool") - if isinstance(dtype_in, np.dtype): - dtype_in = dtype_in.name - if isinstance(dtype_in, str): - if dtype_in in native_dtype_dict: - dtype_str = dtype_in - else: - raise Exception( - f"Cannot convert to ivy dtype. {dtype_in} is not supported by TensorFlow backend." - ) - else: - dtype_str = ivy_dtype_dict[dtype_in] - if "uint" in dtype_str: - return str(dtype_str) - elif "int" in dtype_str: - return str(dtype_str) - elif "float" in dtype_str: - return str(dtype_str) - elif "complex" in dtype_str: - return str(dtype_str) - elif "bool" in dtype_str: - return str("bool") - else: - raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") - - -def tensorflow_default_int_dtype_bknd( - *, - input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, - int_dtype: Optional[Union[str, tf.DType]] = None, - as_native: bool = False, -): - global default_int_dtype_stack - if tensorflow_exists_bknd(int_dtype): - if as_native is True: - return tensorflow_as_native_dtype(int_dtype) - return str(tensorflow_as_ivy_dtype(int_dtype)) - as_native = tensorflow_default_bknd(as_native, False) - if tensorflow_exists_bknd(input): - if tensorflow_is_array_bknd(input): - ret = tensorflow_dtype(input) - elif isinstance(input, tuple): - ret = tensorflow_default_int_dtype_bknd() - elif isinstance(input, np.ndarray): - ret = str(input.dtype) - elif isinstance(input, (list, tuple, dict)): - if tensorflow_nested_argwhere_bknd( - input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), - stop_after_n_found=1, - ): - ret = tf.uint64 - elif tensorflow_nested_argwhere_bknd( - input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), - stop_after_n_found=1, - ): - ret = tf.int64 - elif not default_int_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_int_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "int32" - else: - ret = default_int_dtype_stack[-1] - elif isinstance(input, Number): - if input > 9223372036854775807 and input != math.inf and backend != "torch": - ret = tf.uint64 - elif input > 2147483647 and input != math.inf: - ret = tf.int64 - elif not default_int_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_int_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "int32" - else: - ret = default_int_dtype_stack[-1] - elif not default_int_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_int_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "int32" - else: - ret = default_int_dtype_stack[-1] - if as_native: - return tensorflow_as_native_dtype(ret) - return str(tensorflow_as_ivy_dtype(ret)) - - -def tensorflow_as_native_dtype( - dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], -): - if dtype_in is int: - return tensorflow_default_int_dtype_bknd(as_native=True) - if dtype_in is float: - return tensorflow_default_float_dtype_bknd(as_native=True) - if dtype_in is complex: - return tensorflow_default_complex_dtype_bknd(as_native=True) - if dtype_in is bool: - return tensorflow.bool - if isinstance(dtype_in, np.dtype): - dtype_in = dtype_in.name - if not 
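# The magnitude checks in tensorflow_default_int_dtype_bknd above compare
# against the int32 and int64 maxima; finite values past the int64 bound fall
# through to tf.uint64:
assert 2147483647 == 2**31 - 1             # largest int32; beyond this -> tf.int64
assert 9223372036854775807 == 2**63 - 1    # largest int64; beyond this -> tf.uint64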
isinstance(dtype_in, str): - return dtype_in - if dtype_in in native_dtype_dict: - return native_dtype_dict[str(dtype_in)] - else: - raise Exception( - f"Cannot convert to TensorFlow dtype. {dtype_in} is not supported by TensorFlow." - ) - - -def tensorflow_dtype( - x: Union[tensorflow.Tensor, tensorflow.Variable, np.ndarray], - *, - as_native: bool = False, -): - if as_native: - return tensorflow_as_native_dtype(x.dtype) - return tensorflow_as_ivy_dtype(x.dtype) - - -def tensorflow_is_bool_dtype_bknd( - dtype_in: Union[str, str, tensorflow.Tensor, tf.Tensor, Number], / -): - if tensorflow_is_array_bknd(dtype_in): - dtype_in = tensorflow_dtype(dtype_in) - elif isinstance(dtype_in, np.ndarray): - return "bool" in dtype_in.dtype.name - elif isinstance(dtype_in, Number): - return isinstance(dtype_in, (bool, np.bool_)) and not isinstance(dtype_in, bool) - elif isinstance(dtype_in, (list, tuple, dict)): - return bool( - tensorflow_nested_argwhere_bknd( - dtype_in, lambda x: isinstance(x, (bool, np.bool_)) and x is not int - ) - ) - return "bool" in tensorflow_as_ivy_dtype(dtype_in) - - -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - -@tensorflow_handle_get_item -def tensorflow_get_item( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - query: Union[tensorflow.Tensor, tensorflow.Variable, Tuple], - *, - copy: Optional[bool] = None, -): - if ( - tensorflow_is_array_bknd(query) - and tensorflow_is_bool_dtype_bknd(query) - and not len(query.shape) - ): - return tensorflow.expand_dims(x, 0) - return x[query] - - -def tensorflow_index_nest_bknd( - nest: Union[List, Tuple, Dict, tensorflow.Tensor, tf.Tensor, dict], - index: Union[List[int], Tuple[int], Iterable[int]], - /, -): - ret = nest - for i in index: - ret = tensorflow_get_item(ret, i) - return ret - - -def tensorflow__get_first_array(*args, **kwargs): - def array_fn(x): - return ( - tensorflow_is_array_bknd(x) - if not hasattr(x, "_ivy_array") - else tensorflow_is_array_bknd(x.ivy_array) - ) - - array_fn = array_fn if "array_fn" not in kwargs else kwargs["array_fn"] - arr = None - if args: - arr_idxs = tensorflow_nested_argwhere_bknd(args, array_fn, stop_after_n_found=1) - if arr_idxs: - arr = tensorflow_index_nest_bknd(args, arr_idxs[0]) - else: - arr_idxs = tensorflow_nested_argwhere_bknd( - kwargs, array_fn, stop_after_n_found=1 - ) - if arr_idxs: - arr = tensorflow_index_nest_bknd(kwargs, arr_idxs[0]) - elif kwargs: - arr_idxs = tensorflow_nested_argwhere_bknd( - kwargs, array_fn, stop_after_n_found=1 - ) - if arr_idxs: - arr = tensorflow_index_nest_bknd(kwargs, arr_idxs[0]) - return arr - - -def tensorflow__check_shapes_broadcastable_bknd(out, inp): - if out is not None: - tensorflow_check_shapes_broadcastable(out, inp) diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_bernoulli_output/run_0/tensorflow__stateful.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_bernoulli_output/run_0/tensorflow__stateful.py deleted file mode 100644 index dbad1e919ab1..000000000000 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_bernoulli_output/run_0/tensorflow__stateful.py +++ /dev/null @@ -1,1799 +0,0 @@ -# global -from __future__ import annotations -import re -import os -import tensorflow as tf -import functools -from tensorflow.python.util import nest -from typing import NamedTuple, Callable, Any, Tuple, List, Dict, 
Type, Union -import inspect -from collections import OrderedDict -from packaging.version import parse -import keras - - -def get_assignment_dict(): - # Traverse the call stack - lhs = None - for frame_info in inspect.stack(): - # Check if the code context is an assignment statement - if frame_info.code_context and "=" in frame_info.code_context[0]: - # Split the assignment and retrieve the LHS - lhs = frame_info.code_context[0].split("=")[0].strip() - if "self" not in lhs: - continue - break - - if not lhs: - return None, "" - - # Replace indexing with attribute access - lhs = re.sub(r"\[(\d+)\]", r".\1", lhs) - - # Split the LHS based on "." and get individual components - components = lhs.split(".") - - # Initialize the dictionary - assignment_dict = {} - - # Retrieve the live objects associated with each component - for i in range(len(components)): - # Construct the key - key = ".".join(components[: i + 1]) - - # Retrieve the value - if i == 0: - value = frame_info.frame.f_locals.get(components[i]) - else: - value = getattr(assignment_dict[".".join(components[:i])], components[i]) - - # Add the key-value pair to the dictionary - assignment_dict[key] = value - - return assignment_dict, lhs - - -def store_frame_info(fn): - @functools.wraps(fn) - def frame_info_wrapper(self, *args, **kwargs): - if self._previous_frame_info is None: - # store the info about the calling frame. - stack = inspect.stack() - self._previous_frame_info = stack[1] - res = fn(self, *args, **kwargs) - # reset the frame-info - self._previous_frame_info = None - return res - - return frame_info_wrapper - - -# A NodeDef holds two callables: -# - flatten_fn should take the collection and return a flat list of values. -# It can also return some context that is used in reconstructing the -# collection. -# - unflatten_fn should take a flat list of values and some context -# (returned by flatten_fn). It returns the collection by reconstructing -# it from the list and the context. -Context = Any -PyTree = Any -FlattenFunc = Callable[[PyTree], Tuple[List, Context]] -UnflattenFunc = Callable[[List, Context], PyTree] - - -class NodeDef(NamedTuple): - flatten_fn: FlattenFunc - unflatten_fn: UnflattenFunc - - -SUPPORTED_NODES: Dict[Type[Any], NodeDef] = {} - - -def _register_pytree_node( - typ: Any, flatten_fn: FlattenFunc, unflatten_fn: UnflattenFunc -) -> None: - SUPPORTED_NODES[typ] = NodeDef(flatten_fn, unflatten_fn) - - -def _dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]: - return list(d.values()), list(d.keys()) - - -def _dict_unflatten(values: List[Any], context: Context) -> Dict[Any, Any]: - return {key: value for key, value in zip(context, values)} - - -_register_pytree_node(dict, _dict_flatten, _dict_unflatten) - -if parse(keras.__version__).major > 2: - _register_pytree_node( - keras.src.utils.tracking.TrackedDict, _dict_flatten, _dict_unflatten - ) - - -def _get_node_type(pytree: Any) -> Any: - return type(pytree) - - -# A leaf is defined as anything that is not a Node. -def _is_leaf(pytree: PyTree) -> bool: - return _get_node_type(pytree) not in SUPPORTED_NODES.keys() - - -# A TreeSpec represents the structure of a pytree. 
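# The NodeDef contract above pairs a flatten_fn with an unflatten_fn. A
# hypothetical registration for OrderedDict, mirroring the dict handlers
# (note that tree_flatten/tree_unflatten in this file are hardwired to
# _dict_flatten/_dict_unflatten, so registering a type here mainly marks it
# as a non-leaf node):
from collections import OrderedDict

def _odict_flatten(d):
    return list(d.values()), list(d.keys())

def _odict_unflatten(values, context):
    return OrderedDict(zip(context, values))

_register_pytree_node(OrderedDict, _odict_flatten, _odict_unflatten)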
It holds: -# "type": the type of root Node of the pytree -# context: some context that is useful in unflattening the pytree -# children_specs: specs for each child of the root Node -# num_leaves: the number of leaves -class TreeSpec: - def __init__(self, type, context, children_specs): - self.type: Any = type - self.context: Context = context - self.children_specs: List["TreeSpec"] = children_specs - self.num_leaves: int = sum([spec.num_leaves for spec in self.children_specs]) - - def get_keychains(self, prefix="", sep="/"): - keychains = [] - for key, child_spec in zip(self.context, self.children_specs): - new_prefix = prefix + key + sep if prefix else key + sep - if child_spec.children_specs: # Non-leaf node - keychains.extend(child_spec.get_keychains(new_prefix, sep)) - else: # Leaf node - keychains.append(new_prefix[: -len(sep)]) - return keychains - - def __repr__(self, indent: int = 0) -> str: - repr_prefix: str = f"TreeSpec({self.type.__name__}, {self.context}, [" - children_specs_str: str = "" - if len(self.children_specs): - indent += len(repr_prefix) - children_specs_str += self.children_specs[0].__repr__(indent) - children_specs_str += "," if len(self.children_specs) > 1 else "" - children_specs_str += ",".join( - [ - "\n" + " " * indent + child.__repr__(indent) - for child in self.children_specs[1:] - ] - ) - repr_suffix: str = f"{children_specs_str}])" - return repr_prefix + repr_suffix - - -class LeafSpec(TreeSpec): - def __init__(self) -> None: - super().__init__(None, None, []) - self.num_leaves = 1 - - def __repr__(self, indent: int = 0) -> str: - return "*" - - -def tree_flatten(pytree: PyTree) -> Tuple[List[Any], TreeSpec]: - """Flattens a pytree into a list of values and a TreeSpec that can be used - to reconstruct the pytree.""" - if _is_leaf(pytree): - return [pytree], LeafSpec() - - node_type = _get_node_type(pytree) - flatten_fn = _dict_flatten - child_pytrees, context = flatten_fn(pytree) - - # Recursively flatten the children - result: List[Any] = [] - children_specs: List["TreeSpec"] = [] - for child in child_pytrees: - flat, child_spec = tree_flatten(child) - result += flat - children_specs.append(child_spec) - - return result, TreeSpec(node_type, context, children_specs) - - -def tree_unflatten(values: List[Any], spec: TreeSpec) -> PyTree: - """Given a list of values and a TreeSpec, builds a pytree. - - This is the inverse operation of `tree_flatten`. - """ - if not isinstance(spec, TreeSpec): - raise TypeError( - f"tree_unflatten(values, spec): Expected `spec` to be instance of " - f"TreeSpec but got item of type {type(spec)}." - ) - if len(values) != spec.num_leaves: - raise TypeError( - f"tree_unflatten(values, spec): `values` has length {len(values)} " - f"but the spec refers to a pytree that holds {spec.num_leaves} " - f"items ({spec})." 
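# Round-tripping a nested dict through tree_flatten above (expected results
# shown as comments):
values, spec = tree_flatten({"a": {"b": 1, "c": 2}, "d": 3})
# values == [1, 2, 3]; spec.num_leaves == 3
# spec.get_keychains() == ["a/b", "a/c", "d"]
# tree_unflatten(values, spec) rebuilds the original nested dict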
- ) - if isinstance(spec, LeafSpec): - return values[0] - - unflatten_fn = _dict_unflatten - - # Recursively unflatten the children - start = 0 - end = 0 - child_pytrees = [] - for child_spec in spec.children_specs: - end += child_spec.num_leaves - child_pytrees.append(tree_unflatten(values[start:end], child_spec)) - start = end - - return unflatten_fn(child_pytrees, spec.context) - - -def serialize_obj(obj): - if inspect.isclass(obj) or isinstance(obj, type): - return {"cls_module": obj.__module__, "cls_name": obj.__name__} - return obj - - -def recursive_serialize(d): - if isinstance(d, dict): - return {k: recursive_serialize(v) for k, v in d.items()} - elif isinstance(d, list): - return [recursive_serialize(v) for v in d] - elif isinstance(d, tuple): - return tuple(recursive_serialize(v) for v in d) - else: - return serialize_obj(d) - - -def deserialize_obj(serialized): - if ( - isinstance(serialized, dict) - and "cls_module" in serialized - and "cls_name" in serialized - ): - module = __import__(serialized["cls_module"], fromlist=[serialized["cls_name"]]) - cls = getattr(module, serialized["cls_name"]) - return cls - return serialized - - -def recursive_deserialize(d): - if isinstance(d, dict) and "cls_module" not in d: - return {k: recursive_deserialize(v) for k, v in d.items()} - elif isinstance(d, list): - return [recursive_deserialize(v) for v in d] - elif isinstance(d, tuple): - return tuple(recursive_serialize(v) for v in d) - else: - return deserialize_obj(d) - - -class ModelHelpers: - @staticmethod - @tf.autograph.experimental.do_not_convert - def _get_first_array(*args, **kwargs): - arr = None - flattened_args = tf.nest.flatten((args, kwargs)) - arr_candidates = tf.nest.map_structure( - lambda x: x if isinstance(x, (tf.Tensor, tf.Variable)) else False, - flattened_args, - ) - for arr_candidate in arr_candidates: - if arr_candidate is not False: - arr = arr_candidate - break - return arr - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _get_input_shapes(*args): - input_shapes = [] - for x in args: - if isinstance(x, (tf.Tensor, tf.Variable)): - input_shapes.append(x.shape) - else: - try: - x = tf.convert_to_tensor(x) - input_shapes.append(x.shape) - except Exception: - input_shapes.append(None) - return input_shapes - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _extract_v(v, keychain_mappings: dict, orig_key_chain, /): - if ModelHelpers._dict_has_key_chain(v, orig_key_chain): - ret_cont = ModelHelpers._dict_at_key_chain(v, orig_key_chain) - else: - ret_cont = dict() - for old_kc, new_kc in keychain_mappings.items(): - if orig_key_chain in old_kc: - # Check if `v` contains `new_kc` before replacing in `ret_cont` - if ModelHelpers._dict_has_key_chain(v, new_kc): - ret_cont = ModelHelpers._dict_set_at_key_chain( - ret_cont, - "/".join(old_kc.split("/")[1:]), - ModelHelpers._dict_at_key_chain(v, new_kc), - ) - else: - continue - return ret_cont - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _remove_duplicate_variables(vs, created, /): - created_ids = tf.nest.map_structure(lambda x: id(x), created) - vs_ids = tf.nest.map_structure(lambda x: id(x), vs) - ids = {} - duplicate_keychains = [] - keychain_mappings = {} - - def unique_callback(x, kc): - ids[x] = kc - return x - - def found_dup_callback(x, kc): - if ids[x] == kc: - return x - duplicate_keychains.append(kc) - keychain_mappings[kc] = ids[x] - return x - - created_ids = nest.map_structure_with_paths( - lambda kc, x: unique_callback(x, kc), created_ids - ) - vs_ids = 
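# serialize_obj/deserialize_obj above round-trip class references by module
# and name:
import collections

payload = serialize_obj(collections.OrderedDict)
# payload == {"cls_module": "collections", "cls_name": "OrderedDict"}
assert deserialize_obj(payload) is collections.OrderedDict
# (nb: the tuple branch of recursive_deserialize recurses via
# recursive_serialize, so class payloads nested inside tuples come back
# unrestored)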
nest.map_structure_with_paths( - lambda kc, x: ( - unique_callback(x, kc) if x not in ids else found_dup_callback(x, kc) - ), - vs_ids, - ) - for dup_kc in duplicate_keychains: - vs = ModelHelpers._dict_prune_key_chain(vs, dup_kc) - return vs, keychain_mappings - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _dict_set_at_key_chain(in_dict, key_chain, val, inplace=False): - keys = re.split("[/.]", key_chain) - if inplace: - cont = in_dict - else: - cont = in_dict - sub_cont = cont - for key in keys[:-1]: - if key not in sub_cont: - sub_cont[key] = dict() - sub_cont = sub_cont[key] - sub_cont[keys[-1]] = val - return cont - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _dict_at_key_chain(dict, key_chain, ignore_key_errors=False): - keys = re.split("[/.]", key_chain) - ret = dict - for key in keys: - try: - ret = ret[key] - except KeyError as e: - if ignore_key_errors: - return - raise Exception(repr(e)) - return ret - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _dict_has_key_chain(dict, key_chain): - keys = re.split("[/.]", key_chain) - ret = dict - for key in keys: - try: - ret = ret[key] - except KeyError: - return False - return True - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _dict_prune_key_chain(in_dict, key_chain): - keys_in_chain = re.split("[/.]", key_chain) - out_dict = {} - for key, value in in_dict.items(): - if isinstance(value, dict): - if key == keys_in_chain[0]: - if len(keys_in_chain) == 1: - new_val = [] - else: - new_val = ModelHelpers._dict_prune_key_chain( - value, - "/".join(keys_in_chain[1:]), - ) - if len(new_val) > 0: - out_dict[key] = new_val - else: - if len(value) > 0: - out_dict[key] = value - else: - if len(keys_in_chain) != 1 or key != keys_in_chain[0]: - out_dict[key] = value - return out_dict - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _addindent(s_, numSpaces): - s = s_.split("\n") - # don't do anything for single-line stuff - if len(s) == 1: - return s_ - first = s.pop(0) - s = [(numSpaces * " ") + line for line in s] - s = "\n".join(s) - s = first + "\n" + s - return s - - -class Layer(tf.keras.layers.Layer, ModelHelpers): - _build_mode = None - _with_partial_v = None - _store_vars = True - _built = False - _v = None - _buffers = None - _module_dict = None - _args = None - _kwargs = None - _module_graph = None - _target = None - _lazy_traced = False - _training = None - _dynamic_backend = None - _device = None - _dtype = None - _previous_frame_info = None - - def __init__( - self, - /, - *args, - v=None, - buffers=None, - build_mode="on_init", - store_vars=True, - with_partial_v=False, - dynamic_backend=None, - training=True, - dtype=None, - device=None, - module_dict=None, - **kwargs, - ): - super(Layer, self).__init__( - trainable=training, - dtype=dtype, - ) - self._build_mode = build_mode - self._with_partial_v = with_partial_v - self._store_vars = store_vars - self._built = False - self._v_from_constructor = v if isinstance(v, dict) or v is None else dict(v) - self._v = v if v is not None else dict() - self._buffers = dict(buffers or {}) - self._module_dict = module_dict if module_dict is not None else dict() - self._args = args - self._kwargs = kwargs - self._module_graph = None - self._target = None - self._lazy_traced = False - self._training = training - self._dynamic_backend = dynamic_backend - self._device = device or "cpu" - self._dtype = dtype or tf.float32 - if build_mode != "on_init": - return - self.build(*args, 
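# The _dict_*_key_chain statics above accept "/" or "." interchangeably as
# separators into nested variable dicts. Illustrative usage, assuming the
# statics are called directly on plain dicts:
v = {}
ModelHelpers._dict_set_at_key_chain(v, "encoder/layer0/w", 1.0)
assert ModelHelpers._dict_at_key_chain(v, "encoder.layer0.w") == 1.0
assert ModelHelpers._dict_has_key_chain(v, "encoder/layer0")
assert not ModelHelpers._dict_has_key_chain(v, "decoder")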
dynamic_backend=dynamic_backend, **kwargs) - - @tf.autograph.experimental.do_not_convert - def _find_variables( - self, - /, - *, - obj=None, - without_initialisation=False, - _visited=None, - trainable=True, - ): - _visited = _visited or {} - vs = dict() - if id(obj) in _visited: - return vs - _visited[id(obj)] = True - if isinstance(obj, Layer) and obj is not self: - fn = "_build_and_return_v" if trainable else "_build_and_return_buffers" - if not obj._built and without_initialisation: - return lambda: getattr(obj, fn)( - *obj._args, dynamic_backend=self._dynamic_backend, **obj._kwargs - ) - - return getattr(obj, fn)( - *obj._args, dynamic_backend=obj._dynamic_backend, **obj._kwargs - ) - elif isinstance(obj, tf.keras.layers.Layer) and obj is not self: - return obj.v if trainable else obj.buffers - - elif isinstance(obj, (list, tuple)): - for i, v in enumerate(obj): - ret = self._find_variables( - obj=v, - without_initialisation=without_initialisation, - _visited=_visited, - trainable=trainable, - ) - if ret: - vs[f"v{str(i)}"] = ret - return vs - elif isinstance(obj, dict): - for k, v in obj.items(): - ret = self._find_variables( - obj=v, - without_initialisation=without_initialisation, - _visited=_visited, - trainable=trainable, - ) - if ret: - vs[k[1:] if k[0] == "_" else k] = ret - return vs - elif not hasattr(obj, "__dict__"): - return vs - for k, v in obj.__dict__.items(): - if ( - v is not None - and k[0:2] != "__" - and not k.startswith( - ( - "_module_dict", - "_self_", - "_args", - "_kwargs", - ) - ) - ): - ret = self._find_variables( - obj=v, - without_initialisation=without_initialisation, - _visited=_visited, - trainable=trainable, - ) - if ret: - vs[k[1:] if k[0] == "_" else k] = ret - return vs - - @tf.autograph.experimental.do_not_convert - def _find_buffers(self): - if hasattr(self, "_module_dict"): - for key, sub_module in self._module_dict.items(): - if len(sub_module._buffers) > 0: - self._buffers[key] = sub_module._buffers - - @tf.autograph.experimental.do_not_convert - def _build_and_return_v(self, *args, **kwargs): - if not self._built: - self.build(*args, **kwargs) - return self.v - - @tf.autograph.experimental.do_not_convert - def _build_and_return_buffers(self, *args, **kwargs): - if not self._built: - self.build(*args, **kwargs) - return self.buffers - - @tf.autograph.experimental.do_not_convert - def _wrap_call_methods( - self, keychain_mappings, /, *, key="", obj=None, _visited=None - ): - _visited = _visited or {} - if id(obj) in _visited or not isinstance(key, str): - return - _visited[id(obj)] = True - if isinstance(obj, Model) and obj is not self: - orig_key_chain = key[1:] if key[0] == "_" else key - - obj.__call__ = self._fn_with_var_arg( - obj.__call__, self._extract_v, keychain_mappings, orig_key_chain - ) - return - elif isinstance(obj, (list, tuple)): - for i, val in enumerate(obj): - self._wrap_call_methods( - keychain_mappings, - key=f"{key}/v{str(i)}", - obj=val, - _visited=_visited, - ) - return - elif isinstance(obj, dict): - for k, val in obj.items(): - k = f"{key}/{k}" if key != "" and isinstance(k, str) else k - self._wrap_call_methods( - keychain_mappings, key=k, obj=val, _visited=_visited - ) - return - for k, val in obj.module_dict.items(): - if k.startswith(("__", "_self_")): - continue - k = f"{key}/{k}" if key != "" else k - if val is not None: - self._wrap_call_methods( - keychain_mappings, key=k, obj=val, _visited=_visited - ) - return - - @tf.autograph.experimental.do_not_convert - def _compute_module_dict(self): - self._module_dict 
= dict() - for key, value in self.__dict__.items(): - if isinstance(value, (Layer, tf.keras.layers.Layer)): - if ( - "stateful" in value.__module__ - or hasattr(value, "_frontend_module") - or not hasattr(value, "_module_dict") - ): - self._module_dict[key] = value - else: - self._module_dict[key] = value._module_dict - - @tf.autograph.experimental.do_not_convert - def _fn_with_var_arg_wrapper( - self, *a, fn, v_fn, keychain_mappings, orig_key_chain, **kw - ): - if "v" in kw: - del kw["v"] - v = v_fn(self.v, keychain_mappings, orig_key_chain) - return fn(*a, **kw, v=v) - - @tf.autograph.experimental.do_not_convert - def _fn_with_var_arg(self, fn, v_fn, /, keychain_mappings, orig_key_chain): - _fn_with_var_arg_wrapper = functools.partial( - self._fn_with_var_arg_wrapper, - fn=fn, - v_fn=v_fn, - keychain_mappings=keychain_mappings, - orig_key_chain=orig_key_chain, - ) - _fn_with_var_arg_wrapper.wrapped = True - return _fn_with_var_arg_wrapper - - @tf.autograph.experimental.do_not_convert - def _call(self, *args, v=None, buffers=None, **kwargs): - if not self._built or not self.built: - if not self._built: - first_arr = self._get_first_array(*args, **kwargs) - self.build( - *args, - **kwargs, - from_call=True, - dtype=first_arr.dtype if first_arr is not None else tf.float32, - ) - - if not self.built: - # Don't use `keras` build method - if os.environ.get("USE_KERAS_BUILD", "False").lower() == "false": - self.inputs = tf.nest.flatten(args) - else: - input_shapes = self._get_input_shapes(*args) - if len(input_shapes) == 0: - input_shapes = tf.TensorShape(None) - elif len(input_shapes) == 1: - input_shapes = input_shapes[0] - - super(Layer, self).build(tf.TensorShape(None)) # noqa: UP008 - - # If `v` was provided, replace with the module's v - replace_v = False - if v is not None: - v_orig = self.v - self._v = v - replace_v = True - - # If `buffers` were provided, replace with the module's buffers - replace_buffers = False - if buffers is not None: - buffers_orig = self.buffers - self._buffers = buffers - replace_buffers = True - - if replace_v or replace_buffers: - # Call the forward pass - ret = super(Layer, self).__call__(*args, **kwargs) # noqa: UP008 - # Replace v, buffers if needed - self._v = v_orig if replace_v else self._v - self._buffers = buffers_orig if replace_buffers else self._buffers - return ret - elif hasattr(self.__call__, "wrapped"): - return self.__call__(*args, **kwargs) - - # Get the signature of the call method - call_signature = inspect.signature(self.call) - - # Convert all positional arguments to keyword arguments based on the signature - new_kwargs = {} - for idx, (param_name, param) in enumerate(call_signature.parameters.items()): - if idx < len(args): - new_kwargs[param_name] = args[idx] - - # Merge the existing kwargs - new_kwargs.update(kwargs) - return super(Layer, self).__call__(**new_kwargs) # noqa: UP008 - - @tf.autograph.experimental.do_not_convert - def build( - self, - *args, - from_call=False, - device=None, - dtype=None, - dynamic_backend=None, - **kwargs, - ): - self._built = True - return - - def _lock_state(self): - pass - - @tf.autograph.experimental.do_not_convert - def register_buffer(self, name: str, value: Union[tf.Tensor, tf.Variable]): - self._buffers.update({name: value}) - return value - - @tf.autograph.experimental.do_not_convert - def register_parameter(self, name: str, value: Union[tf.Tensor, tf.Variable]): - self._v.update({name: value}) - - @tf.autograph.experimental.do_not_convert - def train(self, mode: bool = True): - self._training = 
mode - for module in self.children(): - if isinstance(module, tf.keras.layers.Layer) and not hasattr( - module, "train" - ): - module.trainable = mode - continue - module.train(mode) - self.trainable = mode - return self - - @tf.autograph.experimental.do_not_convert - def eval(self): - return self.train(mode=False) - - @tf.autograph.experimental.do_not_convert - def call(self, inputs, training=None, mask=None): - raise NotImplementedError( - "When subclassing the `Module` class, you should implement a `call` method." - ) - - def get_build_config(self): - config = super().get_build_config() - config = recursive_serialize(config) - return config - - def build_from_config(self, config): - config = recursive_deserialize(config) - return super().build_from_config(config) - - def get_config(self): - base_config = super().get_config() - config = {} - - # Get the names and values of positional arguments in __init__ - init_signature = inspect.signature(self.__init__) - arg_names = list(init_signature.parameters.keys()) - - # Include the positional arguments in the config - var_positional_arg_encountered = False - var_positional_arg_name = None - offset = 0 - for i, arg in enumerate(self._args): - arg_name = arg_names[min(i, len(arg_names) - 1)] - if var_positional_arg_encountered: - config.update( - { - f"{var_positional_arg_name}_{i - offset}": arg, - } - ) - elif ( - init_signature.parameters[arg_name].kind - == inspect.Parameter.VAR_POSITIONAL - ): - var_positional_arg_encountered = True - var_positional_arg_name = arg_name - offset = i - config.update( - { - f"{var_positional_arg_name}_{0}": arg, - } - ) - else: - config.update( - { - arg_name: arg, - } - ) - - # Include the keywords arguments in the config - kwargs = self._kwargs.copy() - kwargs.pop("devices", None) - config.update(**kwargs) - new_config = {**base_config, **config} - new_config = recursive_serialize(new_config) - return new_config - - @classmethod - def from_config(cls, config): - config = recursive_deserialize(config) - # Get the signature of the __init__ method - init_signature = inspect.signature(cls.__init__) - arg_names = list(init_signature.parameters.keys()) - - # Separate positional and keyword arguments based on the __init__ signature - args = [] - pos_or_kw = OrderedDict() - kwargs = {} - var_positional_args = [] - for arg_name in arg_names: - if ( - arg_name in config - and init_signature.parameters[arg_name].kind - == inspect.Parameter.KEYWORD_ONLY - ): - # Handle keyword arguments - kwargs[arg_name] = config.pop(arg_name) - elif ( - arg_name in config - and init_signature.parameters[arg_name].kind - == inspect.Parameter.POSITIONAL_OR_KEYWORD - ): - # Handle positional or keyword arguments - pos_or_kw[arg_name] = config.pop(arg_name) - elif any(re.match(rf"{arg_name}_\d+", key) for key in config.keys()): - # Handle variable positional arguments - var_positional_args.extend( - [ - config.pop(key) - for key in sorted(config.keys()) - if re.match(rf"{arg_name}_\d+", key) - ] - ) - - # Unpack positional arguments and the rest as keyword arguments - config.pop("name", None) - config.pop("trainable", None) - config.pop("dtype", None) - kwargs.update(config) - - # Determine the final args and kwargs - if var_positional_args: - args = list(pos_or_kw.values()) + var_positional_args - else: - kwargs.update(pos_or_kw) - - return cls(*args, **kwargs) - - # Methods to be Optionally Overridden # - # -----------------------------------# - - @tf.autograph.experimental.do_not_convert - def _create_variables(self, *, device=None, 
dtype=None): - return {} - - @tf.autograph.experimental.do_not_convert - def _build(self, *args, **kwargs) -> bool: - return True - - @tf.autograph.experimental.do_not_convert - def _forward(self, *args, **kwargs): - raise NotImplementedError( - "When subclassing the `Module` class, you should " - "implement a `_forward` method." - ) - - @tf.autograph.experimental.do_not_convert - def _extra_repr(self) -> str: - return "" - - # Properties # - # -----------# - - @property - def device(self): - return self._device - - @property - def dtype(self): - return self._dtype - - @property - def build_mode(self): - return self._build_mode - - @property - def training(self): - return self._training - - @property - def v(self): - return self._v - - @property - def buffers(self): - return self._buffers - - @property - def state_dict(self): - return {**self.v, **self.buffers} - - @property - def module_dict(self): - return self._module_dict - - @property - def layers(self): - return self._layers - - # Dunder Methods # - # ---------------# - @store_frame_info - @tf.autograph.experimental.do_not_convert - def __call__( - self, - *args, - v=None, - buffers=None, - **kwargs, - ): - # TODO: Temp workaround to avoid `call`` from being transformed by AutoGraph - if not hasattr(self.__class__.call, "autograph_info__"): - setattr(self.__class__.call, "autograph_info__", True) - ret = self._call(*args, v=v, buffers=buffers, **kwargs) - return ret - - @tf.autograph.experimental.do_not_convert - def __getattr__(self, name): - if name == "v": - if not super().__getattribute__("_v") and not getattr( # noqa: E501 - self, "_built", False - ): - return self._build_and_return_v( - *self._args, dynamic_backend=self._dynamic_backend, **self._kwargs - ) - - _dict = super().__getattribute__("__dict__") - if name in _dict: - return _dict[name] - - elif "_v" in _dict and name in _dict["_v"]: - return _dict["_v"][name] - - return super().__getattribute__(name) - - @tf.autograph.experimental.do_not_convert - def __setattr__(self, name, value): - if name in ["v", "buffers"]: - name = "_" + name - if isinstance(value, (Layer, tf.keras.layers.Layer)): - _dict = getattr(self, "__dict__", None) - if _dict: - _dict[name] = value - - # compute the module dict - self._compute_module_dict() - - obj_to_search = ( - None - if not isinstance(value, (Layer, tf.keras.layers.Layer)) - else ( - self._modules - if hasattr(self, "_modules") and self._modules - else self - ) - ) - found_vars = self._find_variables( - obj=obj_to_search, - without_initialisation=( - True - if self._v_from_constructor and not self._with_partial_v - else False - ), - ) - flattened_v, v_spec = tree_flatten(found_vars) - flattend_kc = v_spec.get_keychains() - for kc, v in zip(flattend_kc, flattened_v): - new_kc = kc.replace("/", ".") - if new_kc not in self.v: - self.register_parameter(new_kc, v) - - # once all variables built, find and assign buffers - found_buffers = self._find_variables( - obj=obj_to_search, - without_initialisation=( - True - if self._v_from_constructor and not self._with_partial_v - else False - ), - trainable=False, - ) - flattened_buf, buf_spec = tree_flatten(found_buffers) - flattend_kc = buf_spec.get_keychains() - for kc, buf in zip(flattend_kc, flattened_buf): - new_kc = kc.replace("/", ".") - self.register_buffer(new_kc, buf) - - super().__setattr__(name, value) - return - elif isinstance(value, tf.Variable) and not name.startswith("_"): - _dict = getattr(self, "__dict__", None) - if _dict: - _dict[name] = value - # Manual solution for cases 
where a `tf.int32` tensor - # is placed on the GPU. TensorFlow doesn't have dedicated - # kernels for placing `tf.int32` variables on the GPU and so - # we manually cast them to `tf.int64` here otherwise due to - # `tf.config.soft_device_placement(True)` by default, - # TensorFlow puts the `tf.int32` variables on CPU which causes - # unintended consequences downstream during tracing or - # `tf.function` compilation e.g. - # Ref: https://github.com/tensorflow/tensorflow/issues/9506 - # Ref: https://stackoverflow.com/questions/44813939/could-not-satisfy-explicit-device-specification-devicegpu0-because-no-devic - dtype = ( - tf.int64 - if value.dtype == tf.int32 and "gpu:" in value.device.lower() - else value.dtype - ) - cast_dtype = dtype != value.dtype - val = ( - value - if not cast_dtype - else tf.Variable(initial_value=tf.cast(value.value(), dtype), name=name) - ) - self.register_parameter(name, val) - super().__setattr__(name, val) - return - else: - try: - obj_to_search = getattr(self, name) - except AttributeError: - obj_to_search = None - if isinstance(obj_to_search, Layer): - # retrieve all hierarchical submodules - assign_dict, kc = get_assignment_dict() - - # Iterate over all submods in assign_dict - # updating their `v` and `buffers` with the - # new value - for key, submod in assign_dict.items(): - # Get the subkey to match - subkey = kc[len(key) :].lstrip(".") - - if hasattr(submod, "v"): - for v_key, v_value in submod.v.items(): - if v_key.startswith(subkey): - submod.register_parameter(v_key, value) - - # Repeat the same process for submod.buffers - if hasattr(submod, "buffers"): - for b_key, b_value in submod.buffers.items(): - if b_key.startswith(subkey): - submod.register_buffer(b_key, value) - - # finally update the module dict - self._module_dict[name] = value - - return super().__setattr__(name, value) - - @tf.autograph.experimental.do_not_convert - def __delattr__(self, name): - if hasattr(self, name): - if isinstance(getattr(self, name), (Layer, tf.keras.layers.Layer)): - super().__delattr__(name) - return - super().__delattr__(name) - - @tf.autograph.experimental.do_not_convert - def __repr__(self): - extra_lines = [] - extra_repr = self._extra_repr() - if extra_repr: - extra_lines = extra_repr.split("\n") - child_lines = [] - for key in self.v.keys(): - if isinstance(getattr(self, key, None), Layer): - mod_str = repr(getattr(self, key)) - mod_str = self._addindent(mod_str, 2) - child_lines.append(f"({key}): {mod_str}") - lines = extra_lines + child_lines - - main_str = f"{self.__class__.__name__}(" - if lines: - # simple one-liner info, which most builtin Modules will use - if len(extra_lines) == 1 and not child_lines: - main_str += extra_lines[0] - else: - main_str += "\n " + "\n ".join(lines) + "\n" - - main_str += ")" - return main_str - - -class Model(tf.keras.Model, ModelHelpers): - _build_mode = None - _with_partial_v = None - _store_vars = True - _built = False - _v = None - _buffers = None - _module_dict = None - _args = None - _kwargs = None - _module_graph = None - _target = None - _lazy_traced = False - _training = None - _dynamic_backend = None - _device = None - _dtype = None - _previous_frame_info = None - - def __init__( - self, - /, - *args, - v=None, - buffers=None, - build_mode="on_init", - store_vars=True, - with_partial_v=False, - dynamic_backend=None, - training=True, - dtype=None, - device=None, - module_dict=None, - **kwargs, - ): - super(Model, self).__init__( - trainable=training, - dtype=dtype, - ) - self._build_mode = build_mode - 
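# The int32-on-GPU workaround above, in isolation: int32 variables that land
# on a GPU device are promoted to int64 before being registered. A
# hypothetical standalone helper mirroring the __setattr__ logic:
import tensorflow as tf

def _promote_gpu_int32(value: tf.Variable, name: str) -> tf.Variable:
    # int32 has no GPU placement kernels, so cast up to int64
    if value.dtype == tf.int32 and "gpu:" in value.device.lower():
        return tf.Variable(initial_value=tf.cast(value.value(), tf.int64), name=name)
    return value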
self._with_partial_v = with_partial_v - self._store_vars = store_vars - self._built = False - self._v_from_constructor = v if isinstance(v, dict) or v is None else dict(v) - self._v = v if v is not None else dict() - self._buffers = dict(buffers or {}) - self._module_dict = module_dict if module_dict is not None else dict() - self._args = args - self._kwargs = kwargs - self._module_graph = None - self._target = None - self._lazy_traced = False - self._training = training - self._dynamic_backend = dynamic_backend - self._device = device or "cpu" - self._dtype = dtype or tf.float32 - if build_mode != "on_init": - return - self.build(*args, dynamic_backend=dynamic_backend, **kwargs) - - @tf.autograph.experimental.do_not_convert - def _find_variables( - self, - /, - *, - obj=None, - without_initialisation=False, - _visited=None, - trainable=True, - ): - _visited = _visited or {} - vs = dict() - if id(obj) in _visited: - return vs - _visited[id(obj)] = True - if isinstance(obj, (Layer, Model)) and obj is not self: - fn = "_build_and_return_v" if trainable else "_build_and_return_buffers" - if not obj._built and without_initialisation: - return lambda: getattr(obj, fn)( - *obj._args, dynamic_backend=self._dynamic_backend, **obj._kwargs - ) - - return getattr(obj, fn)( - *obj._args, dynamic_backend=obj._dynamic_backend, **obj._kwargs - ) - elif isinstance(obj, tf.keras.layers.Layer) and obj is not self: - return obj.v if trainable else obj.buffers - - elif isinstance(obj, (list, tuple)): - for i, v in enumerate(obj): - ret = self._find_variables( - obj=v, - without_initialisation=without_initialisation, - _visited=_visited, - trainable=trainable, - ) - if ret: - vs[f"v{str(i)}"] = ret - return vs - elif isinstance(obj, dict): - for k, v in obj.items(): - ret = self._find_variables( - obj=v, - without_initialisation=without_initialisation, - _visited=_visited, - trainable=trainable, - ) - if ret: - vs[k[1:] if k[0] == "_" else k] = ret - return vs - elif not hasattr(obj, "__dict__"): - return vs - for k, v in obj.__dict__.items(): - if ( - v is not None - and k[0:2] != "__" - and not k.startswith( - ( - "_module_dict", - "_self_", - "_args", - "_kwargs", - ) - ) - ): - ret = self._find_variables( - obj=v, - without_initialisation=without_initialisation, - _visited=_visited, - trainable=trainable, - ) - if ret: - vs[k[1:] if k[0] == "_" else k] = ret - return vs - - @tf.autograph.experimental.do_not_convert - def _find_buffers(self): - if hasattr(self, "_module_dict"): - for key, sub_module in self._module_dict.items(): - if len(sub_module._buffers) > 0: - self._buffers[key] = sub_module._buffers - - @tf.autograph.experimental.do_not_convert - def _build_and_return_v(self, *args, **kwargs): - if not self._built: - self.build(*args, **kwargs) - return self.v - - @tf.autograph.experimental.do_not_convert - def _build_and_return_buffers(self, *args, **kwargs): - if not self._built: - self.build(*args, **kwargs) - return self.buffers - - @tf.autograph.experimental.do_not_convert - def _wrap_call_methods( - self, keychain_mappings, /, *, key="", obj=None, _visited=None - ): - _visited = _visited or {} - if id(obj) in _visited or not isinstance(key, str): - return - _visited[id(obj)] = True - if isinstance(obj, (Layer, Model)) and obj is not self: - orig_key_chain = key[1:] if key[0] == "_" else key - - obj.__call__ = self._fn_with_var_arg( - obj.__call__, self._extract_v, keychain_mappings, orig_key_chain - ) - return - elif isinstance(obj, (list, tuple)): - for i, val in enumerate(obj): - 
self._wrap_call_methods( - keychain_mappings, - key=f"{key}/v{str(i)}", - obj=val, - _visited=_visited, - ) - return - elif isinstance(obj, dict): - for k, val in obj.items(): - k = f"{key}/{k}" if key != "" and isinstance(k, str) else k - self._wrap_call_methods( - keychain_mappings, key=k, obj=val, _visited=_visited - ) - return - for k, val in obj.module_dict.items(): - if k.startswith(("__", "_self_")): - continue - k = f"{key}/{k}" if key != "" else k - if val is not None: - self._wrap_call_methods( - keychain_mappings, key=k, obj=val, _visited=_visited - ) - return - - @tf.autograph.experimental.do_not_convert - def _compute_module_dict(self): - self._module_dict = dict() - for key, value in self.__dict__.items(): - if isinstance(value, (Layer, tf.keras.layers.Layer, Model, tf.keras.Model)): - if ( - "stateful" in value.__module__ - or hasattr(value, "_frontend_module") - or not hasattr(value, "_module_dict") - ): - self._module_dict[key] = value - else: - self._module_dict[key] = value._module_dict - - @tf.autograph.experimental.do_not_convert - def _fn_with_var_arg_wrapper( - self, *a, fn, v_fn, keychain_mappings, orig_key_chain, **kw - ): - if "v" in kw: - del kw["v"] - v = v_fn(self.v, keychain_mappings, orig_key_chain) - return fn(*a, **kw, v=v) - - @tf.autograph.experimental.do_not_convert - def _fn_with_var_arg(self, fn, v_fn, /, keychain_mappings, orig_key_chain): - _fn_with_var_arg_wrapper = functools.partial( - self._fn_with_var_arg_wrapper, - fn=fn, - v_fn=v_fn, - keychain_mappings=keychain_mappings, - orig_key_chain=orig_key_chain, - ) - _fn_with_var_arg_wrapper.wrapped = True - return _fn_with_var_arg_wrapper - - @tf.autograph.experimental.do_not_convert - def _call(self, *args, v=None, buffers=None, **kwargs): - if not self._built or not self.built: - if not self._built: - first_arr = self._get_first_array(*args, **kwargs) - self.build( - *args, - **kwargs, - from_call=True, - dtype=first_arr.dtype if first_arr is not None else tf.float32, - ) - - if not self.built: - # Don't use `keras` build method - if os.environ.get("USE_KERAS_BUILD", "False").lower() == "false": - self.inputs = tf.nest.flatten(args) - else: - input_shapes = self._get_input_shapes(*args) - if len(input_shapes) == 0: - input_shapes = tf.TensorShape(None) - elif len(input_shapes) == 1: - input_shapes = input_shapes[0] - - super(Model, self).build(tf.TensorShape(None)) # noqa: UP008 - - # If `v` was provided, replace with the module's v - replace_v = False - if v is not None: - v_orig = self.v - self._v = v - replace_v = True - - # If `buffers` were provided, replace with the module's buffers - replace_buffers = False - if buffers is not None: - buffers_orig = self.buffers - self._buffers = buffers - replace_buffers = True - - if replace_v or replace_buffers: - # Call the forward pass - ret = super(Model, self).__call__(*args, **kwargs) # noqa: UP008 - # Replace v, buffers if needed - self._v = v_orig if replace_v else self._v - self._buffers = buffers_orig if replace_buffers else self._buffers - return ret - elif hasattr(self.__call__, "wrapped"): - return self.__call__(*args, **kwargs) - - return super(Model, self).__call__(*args, **kwargs) # noqa: UP008 - - @tf.autograph.experimental.do_not_convert - def build( - self, - *args, - from_call=False, - device=None, - dtype=None, - dynamic_backend=None, - **kwargs, - ): - self._built = True - return - - def _lock_state(self): - pass - - @tf.autograph.experimental.do_not_convert - def register_buffer(self, name: str, value: Union[tf.Tensor, tf.Variable]): 
- self._buffers.update({name: value}) - return value - - @tf.autograph.experimental.do_not_convert - def register_parameter(self, name: str, value: Union[tf.Tensor, tf.Variable]): - self._v.update({name: value}) - - @tf.autograph.experimental.do_not_convert - def train(self, mode: bool = True): - self._training = mode - for module in self.children(): - if isinstance(module, tf.keras.layers.Layer) and not hasattr( - module, "train" - ): - module.trainable = mode - continue - module.train(mode) - self.trainable = mode - return self - - @tf.autograph.experimental.do_not_convert - def eval(self): - return self.train(mode=False) - - @tf.autograph.experimental.do_not_convert - def call(self, inputs, training=None, mask=None): - raise NotImplementedError( - "When subclassing the `Module` class, you should implement a `call` method." - ) - - def get_build_config(self): - config = super().get_build_config() - config = recursive_serialize(config) - return config - - def build_from_config(self, config): - config = recursive_deserialize(config) - return super().build_from_config(config) - - def get_config(self): - base_config = super().get_config() - config = {} - - # Get the names and values of positional arguments in __init__ - init_signature = inspect.signature(self.__init__) - arg_names = list(init_signature.parameters.keys()) - - # Include the positional arguments in the config - var_positional_arg_encountered = False - var_positional_arg_name = None - offset = 0 - for i, arg in enumerate(self._args): - arg_name = arg_names[min(i, len(arg_names) - 1)] - if var_positional_arg_encountered: - config.update( - { - f"{var_positional_arg_name}_{i - offset}": arg, - } - ) - elif ( - init_signature.parameters[arg_name].kind - == inspect.Parameter.VAR_POSITIONAL - ): - var_positional_arg_encountered = True - var_positional_arg_name = arg_name - offset = i - config.update( - { - f"{var_positional_arg_name}_{0}": arg, - } - ) - else: - config.update( - { - arg_name: arg, - } - ) - - # Include the keywords arguments in the config - kwargs = self._kwargs.copy() - kwargs.pop("devices", None) - config.update(**kwargs) - new_config = {**base_config, **config} - new_config = recursive_serialize(new_config) - return new_config - - @classmethod - def from_config(cls, config): - config = recursive_deserialize(config) - # Get the signature of the __init__ method - init_signature = inspect.signature(cls.__init__) - arg_names = list(init_signature.parameters.keys()) - - # Separate positional and keyword arguments based on the __init__ signature - args = [] - pos_or_kw = OrderedDict() - kwargs = {} - var_positional_args = [] - for arg_name in arg_names: - if ( - arg_name in config - and init_signature.parameters[arg_name].kind - == inspect.Parameter.KEYWORD_ONLY - ): - # Handle keyword arguments - kwargs[arg_name] = config.pop(arg_name) - elif ( - arg_name in config - and init_signature.parameters[arg_name].kind - == inspect.Parameter.POSITIONAL_OR_KEYWORD - ): - # Handle positional or keyword arguments - pos_or_kw[arg_name] = config.pop(arg_name) - elif any(re.match(rf"{arg_name}_\d+", key) for key in config.keys()): - # Handle variable positional arguments - var_positional_args.extend( - [ - config.pop(key) - for key in sorted(config.keys()) - if re.match(rf"{arg_name}_\d+", key) - ] - ) - - # Unpack positional arguments and the rest as keyword arguments - config.pop("name", None) - config.pop("trainable", None) - config.pop("dtype", None) - kwargs.update(config) - - # Determine the final args and kwargs - if 
var_positional_args: - args = list(pos_or_kw.values()) + var_positional_args - else: - kwargs.update(pos_or_kw) - - return cls(*args, **kwargs) - - # Methods to be Optionally Overridden # - # -----------------------------------# - - @tf.autograph.experimental.do_not_convert - def _create_variables(self, *, device=None, dtype=None): - return {} - - @tf.autograph.experimental.do_not_convert - def _build(self, *args, **kwargs) -> bool: - return True - - @tf.autograph.experimental.do_not_convert - def _forward(self, *args, **kwargs): - raise NotImplementedError( - "When subclassing the `Module` class, you should " - "implement a `_forward` method." - ) - - @tf.autograph.experimental.do_not_convert - def _extra_repr(self) -> str: - return "" - - # Properties # - # -----------# - - @property - def device(self): - return self._device - - @property - def dtype(self): - return self._dtype - - @property - def build_mode(self): - return self._build_mode - - @property - def training(self): - return self._training - - @property - def v(self): - return self._v - - @property - def buffers(self): - return self._buffers - - @property - def state_dict(self): - return {**self.v, **self.buffers} - - @property - def module_dict(self): - return self._module_dict - - # Dunder Methods # - # ---------------# - @store_frame_info - @tf.autograph.experimental.do_not_convert - def __call__( - self, - *args, - v=None, - buffers=None, - **kwargs, - ): - # TODO: Temp workaround to avoid `call`` from being transformed by AutoGraph - if not hasattr(self.__class__.call, "autograph_info__"): - setattr(self.__class__.call, "autograph_info__", True) - ret = self._call(*args, v=v, buffers=buffers, **kwargs) - return ret - - @tf.autograph.experimental.do_not_convert - def __getattr__(self, name): - if name == "v": - if not super().__getattribute__("_v") and not getattr( # noqa: E501 - self, "_built", False - ): - return self._build_and_return_v( - *self._args, dynamic_backend=self._dynamic_backend, **self._kwargs - ) - - _dict = super().__getattribute__("__dict__") - if name in _dict: - return _dict[name] - - elif "_v" in _dict and name in _dict["_v"]: - return _dict["_v"][name] - - return super().__getattribute__(name) - - @tf.autograph.experimental.do_not_convert - def __setattr__(self, name, value): - if name in ["v", "buffers"]: - name = "_" + name - if isinstance(value, (Layer, tf.keras.layers.Layer, Model, tf.keras.Model)): - _dict = getattr(self, "__dict__", None) - if _dict: - _dict[name] = value - - # compute the module dict - self._compute_module_dict() - - obj_to_search = ( - None - if not isinstance(value, (tf.keras.layers.Layer, Layer, Model)) - else ( - self._modules - if hasattr(self, "_modules") and self._modules - else self - ) - ) - found_vars = self._find_variables( - obj=obj_to_search, - without_initialisation=( - True - if self._v_from_constructor and not self._with_partial_v - else False - ), - ) - flattened_v, v_spec = tree_flatten(found_vars) - flattend_kc = v_spec.get_keychains() - for kc, v in zip(flattend_kc, flattened_v): - new_kc = kc.replace("/", ".") - if new_kc not in self.v: - self.register_parameter(new_kc, v) - - # once all variables built, find and assign buffers - found_buffers = self._find_variables( - obj=obj_to_search, - without_initialisation=( - True - if self._v_from_constructor and not self._with_partial_v - else False - ), - trainable=False, - ) - flattened_buf, buf_spec = tree_flatten(found_buffers) - flattend_kc = buf_spec.get_keychains() - for kc, buf in zip(flattend_kc, 
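# get_config above flattens the *args captured at __init__ time into numbered
# keys that from_config later regroups. For a hypothetical
#     def __init__(self, *sizes, axis=-1): ...
# constructed as Model(2, 3, 4, axis=0), the stored config holds:
#     {"sizes_0": 2, "sizes_1": 3, "sizes_2": 4, "axis": 0}
# and from_config rebuilds the call as cls(2, 3, 4, axis=0).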
flattened_buf): - new_kc = kc.replace("/", ".") - self.register_buffer(new_kc, buf) - - super().__setattr__(name, value) - return - elif isinstance(value, tf.Variable) and not name.startswith("_"): - _dict = getattr(self, "__dict__", None) - if _dict: - _dict[name] = value - - # Manual solution for cases where a `tf.int32` tensor - # is placed on the GPU. TensorFlow doesn't have dedicated - # kernels for placing `tf.int32` variables on the GPU and so - # we manually cast them to `tf.int64` here otherwise due to - # `tf.config.soft_device_placement(True)` by default, - # TensorFlow puts the `tf.int32` variables on CPU which causes - # unintended consequences downstream during tracing or - # `tf.function` compilation e.g. - # Ref: https://github.com/tensorflow/tensorflow/issues/9506 - # Ref: https://stackoverflow.com/questions/44813939/could-not-satisfy-explicit-device-specification-devicegpu0-because-no-devic - dtype = ( - tf.int64 - if value.dtype == tf.int32 and "gpu:" in value.device.lower() - else value.dtype - ) - cast_dtype = dtype != value.dtype - val = ( - value - if not cast_dtype - else tf.Variable(initial_value=tf.cast(value.value(), dtype), name=name) - ) - self.register_parameter(name, val) - super().__setattr__(name, val) - else: - try: - obj_to_search = getattr(self, name) - except AttributeError: - obj_to_search = None - if isinstance(obj_to_search, (Model, Layer)): - # retrieve all hierarchical submodules - assign_dict, kc = get_assignment_dict() - - # Iterate over all submods in assign_dict - # updating their `v` and `buffers` with the - # new value - for key, submod in assign_dict.items(): - # Get the subkey to match - subkey = kc[len(key) :].lstrip(".") - - if hasattr(submod, "v"): - for v_key, v_value in submod.v.items(): - if v_key.startswith(subkey): - submod.register_parameter(v_key, value) - - # Repeat the same process for submod.buffers - if hasattr(submod, "buffers"): - for b_key, b_value in submod.buffers.items(): - if b_key.startswith(subkey): - submod.register_buffer(b_key, value) - - # finally update the module dict - self._module_dict[name] = value - - return super().__setattr__(name, value) - - @tf.autograph.experimental.do_not_convert - def __delattr__(self, name): - if hasattr(self, name): - if isinstance( - getattr(self, name), - (Layer, tf.keras.layers.Layer, Model, tf.keras.Model), - ): - super().__delattr__(name) - return - super().__delattr__(name) - - @tf.autograph.experimental.do_not_convert - def __repr__(self): - extra_lines = [] - extra_repr = self._extra_repr() - if extra_repr: - extra_lines = extra_repr.split("\n") - child_lines = [] - for key in self.v.keys(): - if isinstance(getattr(self, key, None), (Layer, Model)): - mod_str = repr(getattr(self, key)) - mod_str = self._addindent(mod_str, 2) - child_lines.append(f"({key}): {mod_str}") - lines = extra_lines + child_lines - - main_str = f"{self.__class__.__name__}(" - if lines: - # simple one-liner info, which most builtin Modules will use - if len(extra_lines) == 1 and not child_lines: - main_str += extra_lines[0] - else: - main_str += "\n " + "\n ".join(lines) + "\n" - - main_str += ")" - return main_str diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_bernoulli_output/run_0/tensorflow_bernoulli.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_bernoulli_output/run_0/tensorflow_bernoulli.py deleted file mode 100644 index 429244328aee..000000000000 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_bernoulli_output/run_0/tensorflow_bernoulli.py +++ /dev/null @@ -1,28 
+0,0 @@ -import tensorflow -import tensorflow as tf - -from typing import Union -from typing import Sequence -from typing import Optional - -from .tensorflow__helpers import tensorflow__check_shapes_broadcastable_bknd -from .tensorflow__helpers import tensorflow_infer_dtype - - -@tensorflow_infer_dtype -def tensorflow_bernoulli( - probs: Union[float, tensorflow.Tensor, tensorflow.Variable], - *, - logits: Union[float, tensorflow.Tensor, tensorflow.Variable] = None, - shape: Optional[Union[tf.TensorShape, Sequence[int]]] = None, - device: Optional[str] = None, - dtype: Optional[str] = None, - seed: Optional[int] = None, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - dtype = dtype if dtype is not None else probs.dtype - if logits is not None: - probs = tensorflow.nn.softmax(logits, -1) - if not tensorflow__check_shapes_broadcastable_bknd(shape, probs.shape): - shape = probs.shape - return tensorflow.keras.backend.random_bernoulli(shape, probs, dtype, seed) diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_check_all_or_any_fn_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_check_all_or_any_fn_output/run_0/tensorflow__helpers.py index f80670f5bb39..678addf0745b 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_check_all_or_any_fn_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_check_all_or_any_fn_output/run_0/tensorflow__helpers.py @@ -23,6 +23,99 @@ import tensorflow as tf +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = 
"_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +240,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +282,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion 
+tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_is_native_array(x, /, *, exclusive=False): @@ -334,7 +342,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -447,6 +457,7 @@ def tensorflow_is_complex_dtype_bknd( return "complex" in tensorflow_as_ivy_dtype_bknd(dtype_in) +@tensorflow_handle_array_like_without_promotion def tensorflow_real( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -460,6 +471,7 @@ def tensorflow_real_bknd_(self): return tensorflow_real(self) +@tensorflow_handle_array_like_without_promotion def tensorflow_imag( val: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -485,6 +497,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -611,6 +626,9 @@ def nested_fun(x): return "int" in tensorflow_as_ivy_dtype(dtype_in) +default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, @@ -655,6 +673,9 @@ def tensorflow_default_dtype_bknd( return tensorflow_as_ivy_dtype(ret) +default_float_dtype_stack = [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -709,6 +730,25 @@ def tensorflow_default_float_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -745,6 +785,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -767,21 +811,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -819,6 +859,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, 
+ "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -870,20 +929,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -950,26 +995,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -1087,6 +1114,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -1193,27 +1223,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -1382,6 +1406,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1638,7 +1665,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2050,7 +2079,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else 
tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -2174,6 +2205,9 @@ def tensorflow_to_scalar_bknd_(self: tensorflow.Tensor): return tensorflow_to_scalar(self) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2198,11 +2232,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2410,7 +2442,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2570,11 +2604,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2614,21 +2646,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_check_all_output/run_0/tensorflow_NestedSequence_bknd.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_check_all_output/run_0/tensorflow_NestedSequence_bknd.py index ac8335fe1e56..9f87b4ae29ef 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_check_all_output/run_0/tensorflow_NestedSequence_bknd.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_check_all_output/run_0/tensorflow_NestedSequence_bknd.py @@ -1,5 +1,5 @@ -from typing import TypeVar from typing import Protocol +from typing import TypeVar _T_co = TypeVar("_T_co", covariant=True) diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_check_all_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_check_all_output/run_0/tensorflow__helpers.py index 2bab9cba2ad3..6c6aa45494fd 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_check_all_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_check_all_output/run_0/tensorflow__helpers.py @@ -23,6 +23,99 @@ import tensorflow as tf +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, 
args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +240,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +282,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] 
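# --- Editor's sketch (not part of the generated diff) --------------------
# This hunk relocates the module-level globals removed above: instead of one
# block at the top of the generated file, each stack/table is re-inserted
# directly above the first function that reads it (e.g.
# `default_float_dtype_stack` above `tensorflow_default_float_dtype_bknd`).
# The `default_*_dtype_stack` lists implement overridable defaults: the most
# recently pushed entry wins, and an empty stack falls back to a fixed
# default. The push/pop helpers are not part of this diff, so the following
# only illustrates the assumed contract:
default_float_dtype_stack = []


def current_default_float_dtype(fallback="float32"):
    # Last-in wins; an empty stack means "use the framework default".
    return default_float_dtype_stack[-1] if default_float_dtype_stack else fallback


default_float_dtype_stack.append("bfloat16")
assert current_default_float_dtype() == "bfloat16"
default_float_dtype_stack.pop()
assert current_default_float_dtype() == "float32"
# --------------------------------------------------------------------------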
- -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_is_native_array(x, /, *, exclusive=False): @@ -334,7 +342,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -447,6 +457,7 @@ def tensorflow_is_complex_dtype_bknd( return "complex" in tensorflow_as_ivy_dtype_bknd(dtype_in) +@tensorflow_handle_array_like_without_promotion def tensorflow_real( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -460,6 +471,7 @@ def tensorflow_real_bknd_(self): return tensorflow_real(self) +@tensorflow_handle_array_like_without_promotion def tensorflow_imag( val: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -485,6 +497,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -611,6 +626,9 @@ def nested_fun(x): return "int" in tensorflow_as_ivy_dtype(dtype_in) +default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, @@ -655,6 +673,9 @@ def tensorflow_default_dtype_bknd( return tensorflow_as_ivy_dtype(ret) +default_float_dtype_stack = [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -709,6 +730,25 @@ def tensorflow_default_float_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, 
complex, bool, np.dtype], / ): @@ -745,6 +785,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -767,21 +811,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -819,6 +859,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -870,20 +929,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -950,26 +995,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -1087,6 +1114,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -1193,27 +1223,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = 
tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -1382,6 +1406,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1638,7 +1665,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2050,7 +2079,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -2174,6 +2205,9 @@ def tensorflow_to_scalar_bknd_(self: tensorflow.Tensor): return tensorflow_to_scalar(self) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2198,11 +2232,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2410,7 +2442,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2570,11 +2604,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2614,21 +2646,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_check_any_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_check_any_output/run_0/tensorflow__helpers.py index e20a3f841c03..5bc8d53b9ea1 100644 --- 
a/ivy/compiler/_cache/Translated_Outputs/tensorflow_check_any_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_check_any_output/run_0/tensorflow__helpers.py @@ -23,6 +23,99 @@ import tensorflow as tf +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +240,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +282,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: 
"float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_is_native_array(x, /, *, exclusive=False): @@ -334,7 +342,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -447,6 +457,7 @@ def tensorflow_is_complex_dtype_bknd( return "complex" in tensorflow_as_ivy_dtype_bknd(dtype_in) +@tensorflow_handle_array_like_without_promotion def tensorflow_real( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -460,6 +471,7 @@ def tensorflow_real_bknd_(self): return tensorflow_real(self) +@tensorflow_handle_array_like_without_promotion def tensorflow_imag( val: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -485,6 +497,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -611,6 +626,9 @@ def nested_fun(x): return "int" in tensorflow_as_ivy_dtype(dtype_in) +default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, @@ -655,6 +673,9 @@ def tensorflow_default_dtype_bknd( return tensorflow_as_ivy_dtype(ret) 
+default_float_dtype_stack = [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -709,6 +730,25 @@ def tensorflow_default_float_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -745,6 +785,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -767,21 +811,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -819,6 +859,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -870,20 +929,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -950,26 +995,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = 
getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -1087,6 +1114,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -1193,27 +1223,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -1382,6 +1406,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1638,7 +1665,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2050,7 +2079,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -2174,6 +2205,9 @@ def tensorflow_to_scalar_bknd_(self: tensorflow.Tensor): return tensorflow_to_scalar(self) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2198,11 +2232,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2388,7 +2420,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2548,11 +2582,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + 
tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2592,21 +2624,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_check_equal_output/run_0/tensorflow_NestedSequence_bknd.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_check_equal_output/run_0/tensorflow_NestedSequence_bknd.py index ac8335fe1e56..9f87b4ae29ef 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_check_equal_output/run_0/tensorflow_NestedSequence_bknd.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_check_equal_output/run_0/tensorflow_NestedSequence_bknd.py @@ -1,5 +1,5 @@ -from typing import TypeVar from typing import Protocol +from typing import TypeVar _T_co = TypeVar("_T_co", covariant=True) diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_check_equal_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_check_equal_output/run_0/tensorflow__helpers.py index e20a3f841c03..5bc8d53b9ea1 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_check_equal_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_check_equal_output/run_0/tensorflow__helpers.py @@ -23,6 +23,99 @@ import tensorflow as tf +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if 
match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +240,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +282,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, 
device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_is_native_array(x, /, *, exclusive=False): @@ -334,7 +342,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -447,6 +457,7 @@ def tensorflow_is_complex_dtype_bknd( return "complex" in tensorflow_as_ivy_dtype_bknd(dtype_in) +@tensorflow_handle_array_like_without_promotion def tensorflow_real( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -460,6 +471,7 @@ def tensorflow_real_bknd_(self): return tensorflow_real(self) +@tensorflow_handle_array_like_without_promotion def tensorflow_imag( val: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -485,6 +497,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -611,6 +626,9 @@ def nested_fun(x): return "int" in tensorflow_as_ivy_dtype(dtype_in) +default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, @@ -655,6 +673,9 @@ def tensorflow_default_dtype_bknd( return tensorflow_as_ivy_dtype(ret) +default_float_dtype_stack = [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -709,6 +730,25 @@ def tensorflow_default_float_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -745,6 +785,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -767,21 +811,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -819,6 +859,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + 
"int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -870,20 +929,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -950,26 +995,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -1087,6 +1114,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -1193,27 +1223,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -1382,6 +1406,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1638,7 +1665,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2050,7 +2079,9 @@ def 
tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -2174,6 +2205,9 @@ def tensorflow_to_scalar_bknd_(self: tensorflow.Tensor): return tensorflow_to_scalar(self) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2198,11 +2232,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2388,7 +2420,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2548,11 +2582,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2592,21 +2624,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_clip_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_clip_output/run_0/tensorflow__helpers.py index 0d123456c924..3829cdba671f 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_clip_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_clip_output/run_0/tensorflow__helpers.py @@ -24,6 +24,99 @@ from .tensorflow_NestedSequence_bknd import tensorflow_NestedSequence_bknd +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = 
tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -148,6 +241,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -189,94 +283,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -default_device_stack = [] -default_uint_dtype_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in 
type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_is_native_array(x, /, *, exclusive=False): @@ -335,7 +343,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -448,6 +458,7 @@ def tensorflow_is_complex_dtype_bknd( return "complex" in tensorflow_as_ivy_dtype_bknd(dtype_in) +@tensorflow_handle_array_like_without_promotion def tensorflow_real( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -461,6 +472,7 @@ def tensorflow_real_bknd_(self): return tensorflow_real(self) +@tensorflow_handle_array_like_without_promotion def tensorflow_imag( val: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -486,6 +498,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -612,6 +627,9 @@ def nested_fun(x): return "int" in tensorflow_as_ivy_dtype(dtype_in) +default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, @@ -656,6 +674,9 @@ def tensorflow_default_dtype_bknd( return tensorflow_as_ivy_dtype(ret) +default_float_dtype_stack = [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -710,6 +731,25 @@ def tensorflow_default_float_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -746,6 +786,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -768,21 +812,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, 
dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -820,6 +860,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -871,20 +930,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -951,26 +996,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -1088,6 +1115,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -1231,27 +1261,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda 
x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -1522,7 +1546,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -1934,7 +1960,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -2080,6 +2108,9 @@ def tensorflow_to_scalar_bknd_(self: tensorflow.Tensor): return tensorflow_to_scalar(self) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2104,11 +2135,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2316,7 +2345,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2476,11 +2507,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2520,21 +2549,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], @@ -2629,6 +2643,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_clip_output/run_0/tensorflow_clip.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_clip_output/run_0/tensorflow_clip.py index 011ecfbe0b04..cca98ca6e9c2 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_clip_output/run_0/tensorflow_clip.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_clip_output/run_0/tensorflow_clip.py @@ -1,7 +1,7 @@ import tensorflow -from typing import Union from numbers import Number +from typing import Union from typing import Optional from .tensorflow__helpers import tensorflow_as_native_dtype diff --git 
a/ivy/compiler/_cache/Translated_Outputs/tensorflow_concat_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_concat_output/run_0/tensorflow__helpers.py index e2f11cae1bac..e5130e4363f5 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_concat_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_concat_output/run_0/tensorflow__helpers.py @@ -1,9 +1,9 @@ import tensorflow -from typing import Optional from typing import Tuple -from typing import Union from typing import List +from typing import Optional +from typing import Union def tensorflow_concat( diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_concat_output/run_0/tensorflow_concat.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_concat_output/run_0/tensorflow_concat.py index 424cdd95dcae..95527c04f1f3 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_concat_output/run_0/tensorflow_concat.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_concat_output/run_0/tensorflow_concat.py @@ -1,9 +1,9 @@ import tensorflow -from typing import Optional from typing import Tuple -from typing import Union from typing import List +from typing import Optional +from typing import Union from .tensorflow__helpers import tensorflow_concat diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_conv_general_dilated_output/run_0/tensorflow_NestedSequence_bknd.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_conv_general_dilated_output/run_0/tensorflow_NestedSequence_bknd.py index 9f87b4ae29ef..ac8335fe1e56 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_conv_general_dilated_output/run_0/tensorflow_NestedSequence_bknd.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_conv_general_dilated_output/run_0/tensorflow_NestedSequence_bknd.py @@ -1,5 +1,5 @@ -from typing import Protocol from typing import TypeVar +from typing import Protocol _T_co = TypeVar("_T_co", covariant=True) diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_conv_general_dilated_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_conv_general_dilated_output/run_0/tensorflow__helpers.py index c1e6b1511d21..2f8e4e2407b7 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_conv_general_dilated_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_conv_general_dilated_output/run_0/tensorflow__helpers.py @@ -23,6 +23,99 @@ import tensorflow as tf +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, 
tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +240,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +282,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = 
tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_is_native_array(x, /, *, exclusive=False): @@ -334,7 +342,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -447,6 +457,7 @@ def tensorflow_is_complex_dtype_bknd( return "complex" in tensorflow_as_ivy_dtype_bknd(dtype_in) +@tensorflow_handle_array_like_without_promotion def tensorflow_real( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -460,6 +471,7 @@ def tensorflow_real_bknd_(self): return tensorflow_real(self) +@tensorflow_handle_array_like_without_promotion def tensorflow_imag( val: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -485,6 +497,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -611,6 +626,9 @@ def nested_fun(x): return "int" in tensorflow_as_ivy_dtype(dtype_in) +default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, @@ -655,6 +673,9 @@ def tensorflow_default_dtype_bknd( return tensorflow_as_ivy_dtype(ret) +default_float_dtype_stack = [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -709,6 +730,25 @@ def tensorflow_default_float_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -745,6 +785,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -767,21 +811,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if 
tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -819,6 +859,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -870,20 +929,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -950,26 +995,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -1087,6 +1114,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -1193,27 +1223,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: 
type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -1382,6 +1406,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1638,7 +1665,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2050,7 +2079,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -2174,6 +2205,9 @@ def tensorflow_to_scalar_bknd_(self: tensorflow.Tensor): return tensorflow_to_scalar(self) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2198,11 +2232,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2410,7 +2442,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2570,11 +2604,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2614,21 +2646,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_conv_general_dilated_output/run_0/tensorflow_conv_general_dilated.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_conv_general_dilated_output/run_0/tensorflow_conv_general_dilated.py index 1fe1c4972fe4..c8d6050ece56 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_conv_general_dilated_output/run_0/tensorflow_conv_general_dilated.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_conv_general_dilated_output/run_0/tensorflow_conv_general_dilated.py @@ -1,9 +1,9 @@ import tensorflow -from typing import Sequence -from typing import Optional from typing import Union 
+from typing import Optional from typing import Tuple +from typing import Sequence from .tensorflow__helpers import tensorflow__extend_2d_padding from .tensorflow__helpers import tensorflow__extend_3d_strides_dilations diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_default_device_bknd_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_default_device_bknd_output/run_0/tensorflow__helpers.py index 6787eeb19f46..1f0f14a15070 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_default_device_bknd_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_default_device_bknd_output/run_0/tensorflow__helpers.py @@ -23,6 +23,99 @@ import tensorflow as tf +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +240,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +282,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } 
-tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] - - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_exists_bknd(x: Any, /): @@ -343,7 +351,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -456,6 +466,7 @@ def tensorflow_is_complex_dtype_bknd( return "complex" in tensorflow_as_ivy_dtype_bknd(dtype_in) +@tensorflow_handle_array_like_without_promotion def tensorflow_real( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -469,6 +480,7 @@ def tensorflow_real_bknd_(self): return tensorflow_real(self) +@tensorflow_handle_array_like_without_promotion def tensorflow_imag( val: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -494,6 +506,9 @@ def 
tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -620,6 +635,9 @@ def nested_fun(x): return "int" in tensorflow_as_ivy_dtype(dtype_in) +default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, @@ -664,6 +682,9 @@ def tensorflow_default_dtype_bknd( return tensorflow_as_ivy_dtype(ret) +default_float_dtype_stack = [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -718,6 +739,25 @@ def tensorflow_default_float_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -754,6 +794,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -776,21 +820,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -828,6 +868,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -879,20 +938,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, 
tensorflow.Variable], @@ -1000,6 +1045,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -1106,27 +1154,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -1295,6 +1337,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1551,7 +1596,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -1963,7 +2010,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -2087,6 +2136,9 @@ def tensorflow_to_scalar_bknd_(self: tensorflow.Tensor): return tensorflow_to_scalar(self) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2111,11 +2163,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2323,7 +2373,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2483,11 +2535,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2527,21 +2577,6 @@ def tensorflow_scatter_nd( return res -def 
tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], @@ -2582,25 +2617,6 @@ def tensorflow_set_item_bknd( return ret -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods @tensorflow_handle_array_like_without_promotion def tensorflow_split( diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_default_device_bknd_output/run_0/tensorflow_default_device_bknd.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_default_device_bknd_output/run_0/tensorflow_default_device_bknd.py index b8aa4cf37897..3eac8e949361 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_default_device_bknd_output/run_0/tensorflow_default_device_bknd.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_default_device_bknd_output/run_0/tensorflow_default_device_bknd.py @@ -1,8 +1,8 @@ import tensorflow import tensorflow as tf -from typing import Union from typing import Optional +from typing import Union from .tensorflow__helpers import tensorflow_as_ivy_dev from .tensorflow__helpers import tensorflow_as_native_dev diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_default_dtype_bknd_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_default_dtype_bknd_output/run_0/tensorflow__helpers.py index c0ae7ff73b08..22d81179ec13 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_default_dtype_bknd_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_default_dtype_bknd_output/run_0/tensorflow__helpers.py @@ -23,6 +23,99 @@ import tensorflow as tf +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, 
parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +240,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +282,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in 
annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_exists_bknd(x: Any, /): @@ -310,7 +318,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -475,20 +485,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -555,26 +551,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -692,6 +670,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -798,27 +779,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -987,6 +962,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper 
+SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1243,7 +1221,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -1655,7 +1635,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -1823,6 +1805,9 @@ def tensorflow_is_uint_dtype_bknd( return "uint" in tensorflow_as_ivy_dtype_bknd(dtype_in) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -1847,11 +1832,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2085,7 +2068,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2245,11 +2230,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2289,21 +2272,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], @@ -2384,6 +2352,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2440,6 +2411,9 @@ def tensorflow_default_complex_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, @@ -2484,6 +2458,9 @@ def tensorflow_default_dtype_bknd( return tensorflow_as_ivy_dtype(ret) +default_float_dtype_stack = [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2538,6 +2515,25 @@ def tensorflow_default_float_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + 
tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -2574,6 +2570,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2596,21 +2596,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -2648,6 +2644,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_default_dtype_bknd_output/run_0/tensorflow_default_dtype_bknd.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_default_dtype_bknd_output/run_0/tensorflow_default_dtype_bknd.py index ead8121256dd..9aab49dd98ed 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_default_dtype_bknd_output/run_0/tensorflow_default_dtype_bknd.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_default_dtype_bknd_output/run_0/tensorflow_default_dtype_bknd.py @@ -1,8 +1,8 @@ import tensorflow import tensorflow as tf -from typing import Union from typing import Optional +from typing import Union from .tensorflow__helpers import tensorflow_as_ivy_dtype from .tensorflow__helpers import tensorflow_as_native_dtype diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_dev_output/run_0/tensorflow_NestedSequence_bknd.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_dev_output/run_0/tensorflow_NestedSequence_bknd.py index 9f87b4ae29ef..ac8335fe1e56 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_dev_output/run_0/tensorflow_NestedSequence_bknd.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_dev_output/run_0/tensorflow_NestedSequence_bknd.py @@ -1,5 +1,5 @@ -from typing import Protocol from typing import TypeVar +from typing import 
Protocol _T_co = TypeVar("_T_co", covariant=True) diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_dev_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_dev_output/run_0/tensorflow__helpers.py index 21c38366400c..4b1a31634459 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_dev_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_dev_output/run_0/tensorflow__helpers.py @@ -23,6 +23,99 @@ import tensorflow as tf +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +240,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +282,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - 
tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] -default_device_stack = [] - - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_stack( @@ -376,7 +384,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -489,6 +499,7 @@ def tensorflow_is_complex_dtype_bknd( return "complex" in tensorflow_as_ivy_dtype_bknd(dtype_in) +@tensorflow_handle_array_like_without_promotion def tensorflow_real( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -502,6 +513,7 @@ def tensorflow_real_bknd_(self): return tensorflow_real(self) +@tensorflow_handle_array_like_without_promotion def tensorflow_imag( val: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -527,6 +539,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -653,6 +668,9 @@ def nested_fun(x): return "int" in 
tensorflow_as_ivy_dtype(dtype_in) +default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, @@ -697,6 +715,9 @@ def tensorflow_default_dtype_bknd( return tensorflow_as_ivy_dtype(ret) +default_float_dtype_stack = [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -751,6 +772,25 @@ def tensorflow_default_float_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -787,6 +827,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -809,21 +853,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -861,6 +901,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -912,20 +971,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -1077,27 +1122,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, 
t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -1266,6 +1305,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1522,7 +1564,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -1934,7 +1978,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -2058,6 +2104,9 @@ def tensorflow_to_scalar_bknd_(self: tensorflow.Tensor): return tensorflow_to_scalar(self) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2082,11 +2131,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2294,7 +2341,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2454,11 +2503,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2498,21 +2545,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], @@ -2553,25 +2585,6 @@ def tensorflow_set_item_bknd( return ret -def tensorflow_handle_methods(fn): - 
def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods @tensorflow_handle_array_like_without_promotion def tensorflow_split( @@ -2641,6 +2654,9 @@ def tensorflow_as_ivy_dev(device: str, /): return str(f"{dev_type}:{dev_idx}") +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_divide_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_divide_output/run_0/tensorflow__helpers.py index 1b64cf5d5694..bde7b8c8d8d0 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_divide_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_divide_output/run_0/tensorflow__helpers.py @@ -23,6 +23,99 @@ import tensorflow as tf +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = 
inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +240,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +282,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] -default_complex_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -default_dtype_stack = [] -default_float_dtype_stack = [] - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_exists_bknd(x: Any, /): @@ -310,7 +318,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) 
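The `tensorflow_handle_get_item` / `tensorflow_handle_set_item` wrappers that this diff hoists to the top of each `tensorflow__helpers.py` share one fallback pattern: try the object's native `__getitem__`/`__setitem__`, re-raise genuine `IndexError`s, and route every other failure to the functional backend implementation. A minimal self-contained sketch of the get-item side (the `pylist_get` fallback below is a hypothetical stand-in for `tensorflow_get_item`):

import functools

def handle_get_item(fn):
    # Same shape as tensorflow_handle_get_item: native indexing first,
    # bounds errors surface unchanged, everything else falls back to fn.
    @functools.wraps(fn)
    def wrapper(inp, query, **kwargs):
        try:
            return inp.__getitem__(query)
        except IndexError:
            raise
        except Exception:
            return fn(inp, query, **kwargs)
    return wrapper

@handle_get_item
def pylist_get(x, query):
    # Hypothetical fallback: index a plain-Python copy of the container.
    return list(x)[query]

print(pylist_get((1, 2, 3), 1))           # native path -> 2
print(pylist_get({0: "a"}, slice(0, 1)))  # slices are unhashable for dicts,
                                          # so this takes the fallback -> [0]

The `IndexError` carve-out matters: silently swallowing out-of-range indices in the fallback would turn user bugs into wrong answers, so only type-level failures trigger the functional path. The set-item twin is identical except that it returns `inp` after a successful in-place `__setitem__`, which is why `tensorflow_set_item_bknd` can be applied uniformly to mutable containers and immutable tensors alike.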
@@ -531,20 +541,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -611,26 +607,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -748,6 +726,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -854,27 +835,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -1043,6 +1018,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1299,7 +1277,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -1711,7 +1691,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -1877,6 +1859,9 @@ def tensorflow_is_uint_dtype_bknd( return "uint" in tensorflow_as_ivy_dtype_bknd(dtype_in) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -1901,11 
+1886,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2139,7 +2122,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2299,11 +2284,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2343,21 +2326,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], @@ -2438,6 +2406,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2494,6 +2465,25 @@ def tensorflow_default_complex_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -2530,6 +2520,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2552,21 +2546,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -2604,6 +2594,25 @@ def 
tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -2627,6 +2636,10 @@ def tensorflow_as_native_dtype( ) +default_dtype_stack = [] +default_float_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_divide_output/run_0/tensorflow_divide.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_divide_output/run_0/tensorflow_divide.py index 5a6d0ec4362f..01f2e8dbc5b7 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_divide_output/run_0/tensorflow_divide.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_divide_output/run_0/tensorflow_divide.py @@ -1,7 +1,7 @@ import tensorflow -from typing import Union from typing import Optional +from typing import Union from .tensorflow__helpers import tensorflow_asarray from .tensorflow__helpers import tensorflow_default_dtype_bknd @@ -24,7 +24,9 @@ def tensorflow_divide( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_dropout_output/run_0/tensorflow_NestedSequence_bknd.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_dropout_output/run_0/tensorflow_NestedSequence_bknd.py index 9f87b4ae29ef..ac8335fe1e56 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_dropout_output/run_0/tensorflow_NestedSequence_bknd.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_dropout_output/run_0/tensorflow_NestedSequence_bknd.py @@ -1,5 +1,5 @@ -from typing import Protocol from typing import TypeVar +from typing import Protocol _T_co = TypeVar("_T_co", covariant=True) diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_dropout_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_dropout_output/run_0/tensorflow__helpers.py index d50687f412e5..3aaa2aa83879 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_dropout_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_dropout_output/run_0/tensorflow__helpers.py @@ -23,6 +23,99 @@ import tensorflow as tf +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + 
annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +240,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +282,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] - -def 
tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_is_native_array(x, /, *, exclusive=False): @@ -334,7 +342,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -447,6 +457,7 @@ def tensorflow_is_complex_dtype_bknd( return "complex" in tensorflow_as_ivy_dtype_bknd(dtype_in) +@tensorflow_handle_array_like_without_promotion def tensorflow_real( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -460,6 +471,7 @@ def tensorflow_real_bknd_(self): return tensorflow_real(self) +@tensorflow_handle_array_like_without_promotion def tensorflow_imag( val: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -485,6 +497,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -611,6 +626,9 @@ def nested_fun(x): return "int" in tensorflow_as_ivy_dtype(dtype_in) +default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, @@ -655,6 +673,9 @@ def tensorflow_default_dtype_bknd( return tensorflow_as_ivy_dtype(ret) +default_float_dtype_stack = [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -709,6 +730,25 @@ def tensorflow_default_float_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, 
bool, np.dtype], / ): @@ -745,6 +785,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -767,21 +811,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -819,6 +859,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -870,20 +929,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -950,26 +995,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -1087,6 +1114,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -1193,27 +1223,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = 
tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -1382,6 +1406,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1638,7 +1665,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2050,7 +2079,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -2174,6 +2205,9 @@ def tensorflow_to_scalar_bknd_(self: tensorflow.Tensor): return tensorflow_to_scalar(self) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2198,11 +2232,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2410,7 +2442,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2570,11 +2604,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2614,21 +2646,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_dtype_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_dtype_output/run_0/tensorflow__helpers.py index c0ae7ff73b08..22d81179ec13 100644 --- 
a/ivy/compiler/_cache/Translated_Outputs/tensorflow_dtype_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_dtype_output/run_0/tensorflow__helpers.py @@ -23,6 +23,99 @@ import tensorflow as tf +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +240,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +282,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - 
tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_exists_bknd(x: Any, /): @@ -310,7 +318,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -475,20 +485,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -555,26 +551,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return 
new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -692,6 +670,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -798,27 +779,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -987,6 +962,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1243,7 +1221,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -1655,7 +1635,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -1823,6 +1805,9 @@ def tensorflow_is_uint_dtype_bknd( return "uint" in tensorflow_as_ivy_dtype_bknd(dtype_in) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -1847,11 +1832,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2085,7 +2068,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2245,11 +2230,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if 
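# --- editor's note (illustration, not part of the patch) -------------------
# tensorflow_handle_methods dispatches on the first argument: backend arrays
# go to the decorated function, anything else defers to the object's own
# method of the same name, derived by stripping the _bknd/_frnt suffixes.
# Sketch with a stand-in is_array predicate (hypothetical):

import functools
import re

def is_array(x):
    return hasattr(x, "shape")  # stand-in for tensorflow_is_array_bknd

def handle_methods(fn):
    def extract_function_name(s):
        # "tensorflow_split_1" -> "split"; trailing numeric suffixes dropped
        match = re.search(r"_(.+?)(?:_\d+)?$", s)
        return match.group(1) if match else s  # hardened against no-match

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if is_array(args[0]):
            return fn(*args, **kwargs)  # backend implementation
        name = extract_function_name(
            re.sub(r"_bknd_|_bknd|_frnt_|_frnt", "", fn.__name__)
        )
        return getattr(args[0], name)(*args[1:], **kwargs)
    return wrapper
# ---------------------------------------------------------------------------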
tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2289,21 +2272,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], @@ -2384,6 +2352,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2440,6 +2411,9 @@ def tensorflow_default_complex_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, @@ -2484,6 +2458,9 @@ def tensorflow_default_dtype_bknd( return tensorflow_as_ivy_dtype(ret) +default_float_dtype_stack = [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2538,6 +2515,25 @@ def tensorflow_default_float_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -2574,6 +2570,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2596,21 +2596,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -2648,6 +2644,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, 
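# --- editor's note (illustration, not part of the patch) -------------------
# The nested_argwhere lambdas in tensorflow_default_int_dtype_bknd above
# encode a width rule for plain Python ints: beyond 2**31 - 1 promote to
# int64, beyond 2**63 - 1 to uint64. The same rule for a single value:

import math

def default_int_dtype(value):
    if value > 9223372036854775807 and value != math.inf:  # 2**63 - 1
        return "uint64"
    if value > 2147483647 and value != math.inf:           # 2**31 - 1
        return "int64"
    return "int32"  # stand-in for the configured default int dtype

# default_int_dtype(5)     -> "int32"
# default_int_dtype(2**40) -> "int64"
# default_int_dtype(2**63) -> "uint64"
# ---------------------------------------------------------------------------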
+ "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_empty_output/run_0/__init__.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_empty_output/run_0/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_empty_output/run_0/tensorflow_NestedSequence_bknd.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_empty_output/run_0/tensorflow_NestedSequence_bknd.py deleted file mode 100644 index 9f87b4ae29ef..000000000000 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_empty_output/run_0/tensorflow_NestedSequence_bknd.py +++ /dev/null @@ -1,10 +0,0 @@ -from typing import Protocol -from typing import TypeVar - -_T_co = TypeVar("_T_co", covariant=True) - - -class tensorflow_NestedSequence_bknd(Protocol[_T_co]): - def __getitem__(self, key: int, /): ... - - def __len__(self, /): ... diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_empty_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_empty_output/run_0/tensorflow__helpers.py deleted file mode 100644 index 06e137cf3452..000000000000 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_empty_output/run_0/tensorflow__helpers.py +++ /dev/null @@ -1,2671 +0,0 @@ -from collections import UserDict -from numbers import Number -from numpy.core.numeric import normalize_axis_tuple -from operator import mul -from .tensorflow_NestedSequence_bknd import tensorflow_NestedSequence_bknd -from typing import Any -from typing import Callable -from typing import Dict -from typing import Iterable -from typing import List -from typing import Optional -from typing import Sequence -from typing import Tuple -from typing import TypeVar -from typing import Union -import functools -import inspect -import itertools -import math -import numpy as np -import re -import tensorflow -import tensorflow as tf - - -promotion_table = { - ("bool", "bool"): "bool", - ("int8", "int8"): "int8", - ("int8", "int16"): "int16", - ("int8", "int32"): "int32", - ("int8", "int64"): "int64", - ("int16", "int16"): "int16", - ("int16", "int32"): "int32", - ("int16", "int64"): "int64", - ("int32", "int32"): "int32", - ("int32", "int64"): "int64", - ("int64", "int64"): "int64", - ("uint8", "int8"): "int16", - ("uint8", "int16"): "int16", - ("uint8", "int32"): "int32", - ("uint8", "int64"): "int64", - ("uint8", "uint8"): "uint8", - ("uint8", "uint16"): "uint16", - ("uint8", "uint32"): "uint32", - ("uint8", "uint64"): "uint64", - ("uint16", "int8"): "int32", - ("uint16", "int16"): "int32", - ("uint16", "int32"): "int32", - ("uint16", "int64"): "int64", - ("uint16", "uint16"): "uint16", - ("uint16", "uint32"): "uint32", - ("uint16", "uint64"): "uint64", - ("uint32", "int8"): "int64", - ("uint32", "int16"): "int64", - ("uint32", "int32"): "int64", - ("uint32", "int64"): "int64", - ("uint32", "uint32"): "uint32", - ("uint32", "uint64"): "uint64", - ("uint64", "uint64"): "uint64", - ("float16", "float16"): "float16", - ("float16", "float32"): "float32", - ("float16", "float64"): "float64", - ("float32", "float32"): "float32", - ("float32", "float64"): "float64", - ("float64", "float64"): "float64", - ("bool", "int8"): "int8", - ("bool", "int16"): "int16", - ("bool", "int32"): "int32", - ("bool", "int64"): "int64", - ("bool", "uint8"): "uint8", - ("bool", "uint16"): "uint16", - ("bool", "uint32"): "uint32", - ("bool", 
"uint64"): "uint64", - ("bool", "float16"): "float16", - ("bool", "float32"): "float32", - ("bool", "float64"): "float64", - ("bool", "bfloat16"): "bfloat16", - ("bool", "complex64"): "complex64", - ("bool", "complex128"): "complex128", - ("int8", "float16"): "float16", - ("int8", "float32"): "float32", - ("int8", "float64"): "float64", - ("int8", "bfloat16"): "bfloat16", - ("int8", "complex64"): "complex64", - ("int8", "complex128"): "complex128", - ("int16", "float32"): "float32", - ("int16", "float64"): "float64", - ("int16", "complex64"): "complex64", - ("int16", "complex128"): "complex128", - ("int32", "float64"): "float64", - ("int32", "complex128"): "complex128", - ("int64", "float64"): "float64", - ("int64", "complex128"): "complex128", - ("uint8", "float16"): "float16", - ("uint8", "float32"): "float32", - ("uint8", "float64"): "float64", - ("uint8", "bfloat16"): "bfloat16", - ("uint8", "complex64"): "complex64", - ("uint8", "complex128"): "complex128", - ("uint16", "float32"): "float32", - ("uint16", "float64"): "float64", - ("uint16", "complex64"): "complex64", - ("uint16", "complex128"): "complex128", - ("uint32", "float64"): "float64", - ("uint32", "complex128"): "complex128", - ("uint64", "int8"): "float64", - ("uint64", "int16"): "float64", - ("uint64", "int32"): "float64", - ("uint64", "int64"): "float64", - ("uint64", "float64"): "float64", - ("uint64", "complex128"): "complex128", - ("float16", "bfloat16"): "float32", - ("float16", "complex64"): "complex64", - ("float16", "complex128"): "complex128", - ("float32", "complex64"): "complex64", - ("float32", "complex128"): "complex128", - ("float64", "complex64"): "complex128", - ("float64", "complex128"): "complex128", - ("bfloat16", "float16"): "float32", - ("bfloat16", "float32"): "float32", - ("bfloat16", "float64"): "float64", - ("bfloat16", "bfloat16"): "bfloat16", - ("bfloat16", "complex64"): "complex64", - ("bfloat16", "complex128"): "complex128", - ("complex64", "float64"): "complex128", - ("complex64", "complex64"): "complex64", - ("complex64", "complex128"): "complex128", - ("complex128", "complex128"): "complex128", - ("float16", "int16"): "float32", - ("float16", "int32"): "float64", - ("float16", "int64"): "float64", - ("float16", "uint16"): "float32", - ("float16", "uint32"): "float64", - ("float16", "uint64"): "float64", - ("float32", "int32"): "float64", - ("float32", "int64"): "float64", - ("float32", "uint32"): "float64", - ("float32", "uint64"): "float64", - ("bfloat16", "int16"): "float32", - ("bfloat16", "int32"): "float64", - ("bfloat16", "int64"): "float64", - ("bfloat16", "uint16"): "float32", - ("bfloat16", "uint32"): "float64", - ("bfloat16", "uint64"): "float64", - ("complex64", "int32"): "complex128", - ("complex64", "int64"): "complex128", - ("complex64", "uint32"): "complex128", - ("complex64", "uint64"): "complex128", -} -array_api_promotion_table = { - ("bool", "bool"): "bool", - ("int8", "int8"): "int8", - ("int8", "int16"): "int16", - ("int8", "int32"): "int32", - ("int8", "int64"): "int64", - ("int16", "int16"): "int16", - ("int16", "int32"): "int32", - ("int16", "int64"): "int64", - ("int32", "int32"): "int32", - ("int32", "int64"): "int64", - ("int64", "int64"): "int64", - ("uint8", "int8"): "int16", - ("uint8", "int16"): "int16", - ("uint8", "int32"): "int32", - ("uint8", "int64"): "int64", - ("uint8", "uint8"): "uint8", - ("uint8", "uint16"): "uint16", - ("uint8", "uint32"): "uint32", - ("uint8", "uint64"): "uint64", - ("uint16", "int8"): "int32", - ("uint16", "int16"): "int32", - 
("uint16", "int32"): "int32", - ("uint16", "int64"): "int64", - ("uint16", "uint16"): "uint16", - ("uint16", "uint32"): "uint32", - ("uint16", "uint64"): "uint64", - ("uint32", "int8"): "int64", - ("uint32", "int16"): "int64", - ("uint32", "int32"): "int64", - ("uint32", "int64"): "int64", - ("uint32", "uint32"): "uint32", - ("uint32", "uint64"): "uint64", - ("uint64", "uint64"): "uint64", - ("float16", "float16"): "float16", - ("float16", "float32"): "float32", - ("float16", "float64"): "float64", - ("float32", "float32"): "float32", - ("float32", "float64"): "float64", - ("float64", "float64"): "float64", -} -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} - - -def tensorflow_infer_dtype(fn: Callable): - @functools.wraps(fn) - def _infer_dtype(*args, dtype=None, **kwargs): - arr = ( - None - if tensorflow_exists_bknd(dtype) - else tensorflow__get_first_array(*args, **kwargs) - ) - dtype = tensorflow_default_dtype_bknd(dtype=dtype, item=arr, as_native=True) - return fn(*args, dtype=dtype, **kwargs) - - _infer_dtype.infer_dtype = True - return _infer_dtype - - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - 
kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion - - -def tensorflow_exists_bknd(x: Any, /): - return x is not None - - -def tensorflow_is_native_array(x, /, *, exclusive=False): - if "keras.src.backend.tensorflow.core.Variable" in str(x.__class__): - return not exclusive - if isinstance(x, (tensorflow.Tensor, tensorflow.Variable, tensorflow.TensorArray)): - if exclusive and isinstance(x, tensorflow.Variable): - return False - return True - return False - - -def tensorflow_is_ivy_array_bknd( - x: Union[tensorflow.Tensor, tf.Tensor], /, *, exclusive: Optional[bool] = False -): - return isinstance(x, tensorflow.Tensor) and tensorflow_is_native_array( - x, exclusive=exclusive - ) - - -def tensorflow_is_array_bknd(x: Any, /, *, exclusive: bool = False): - return tensorflow_is_ivy_array_bknd( - x, exclusive=exclusive - ) or tensorflow_is_native_array(x, exclusive=exclusive) - - -def tensorflow_default_bknd( - x: Any, - /, - default_val: Any, - *, - catch_exceptions: bool = False, - rev: bool = False, - with_callable: bool = False, -): - with_callable = catch_exceptions or with_callable - if rev: - x, default_val = default_val, x - if with_callable: - x_callable = callable(x) - default_callable = callable(default_val) - else: - x_callable = False - default_callable = False - if catch_exceptions: - try: - x = x() if x_callable else x - except Exception: - return default_val() if default_callable else default_val - else: - x = x() if x_callable else x - return ( - x - if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val - ) - - -def tensorflow_nested_argwhere_bknd( - nest: Iterable, - fn: Callable, - check_nests: bool = False, - to_ignore: Optional[Union[type, Tuple[type]]] = None, - _index: Optional[List] = None, - _base: bool = True, - stop_after_n_found: Optional[int] = None, -): - to_ignore = tensorflow_default_bknd(to_ignore, ()) - _index = [] if _index is None else _index - if isinstance(nest, (tuple, list)) and not isinstance(nest, to_ignore): - n = 0 - _indices = [] - for i, item in enumerate(nest): - ind = ( - tensorflow_nested_argwhere_bknd( - item, - fn, - check_nests, - to_ignore, - _index + [i], - False, - stop_after_n_found - n, - ) - if stop_after_n_found is not None - else tensorflow_nested_argwhere_bknd( - item, fn, check_nests, to_ignore, _index + [i], False, None - ) - ) - if stop_after_n_found is not None and ind: - if n >= stop_after_n_found: - break - n = n + len(ind) - _indices = _indices + [ind] - if stop_after_n_found is not None and n >= stop_after_n_found: - break - _indices = [idx for idxs in _indices if idxs for idx in idxs] - if check_nests and fn(nest): - _indices.append(_index) - elif isinstance(nest, (dict, UserDict)) and not isinstance(nest, to_ignore): - n = 0 - _indices = [] - for k, v in nest.items(): - ind = ( - tensorflow_nested_argwhere_bknd( - v, - fn, - check_nests, - to_ignore, - _index + [k], - False, - stop_after_n_found - n, - ) - if stop_after_n_found is not None - else tensorflow_nested_argwhere_bknd( - v, fn, check_nests, to_ignore, _index + [k], False, None - ) - ) - if stop_after_n_found is not None and ind: - if n >= stop_after_n_found: - break - n = n + len(ind) - _indices = _indices + [ind] - _indices = [idx for idxs in _indices if idxs for idx in idxs] - if check_nests and fn(nest): - _indices.append(_index) - else: - cond_met = fn(nest) - 
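# --- editor's note (illustration, not part of the patch) -------------------
# tensorflow_nested_argwhere_bknd (above) returns the index paths of every
# leaf in a nested list/tuple/dict that satisfies a predicate. A condensed
# sketch without the to_ignore / stop_after_n_found machinery:

def nested_argwhere(nest, fn, path=()):
    if isinstance(nest, (list, tuple)):
        items = enumerate(nest)
    elif isinstance(nest, dict):
        items = nest.items()
    else:
        return [list(path)] if fn(nest) else []
    found = []
    for key, val in items:
        found += nested_argwhere(val, fn, path + (key,))
    return found

# nested_argwhere({"a": [1, 9], "b": 3}, lambda x: x > 2)
#   -> [["a", 1], ["b"]]
# ---------------------------------------------------------------------------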
if cond_met: - return [_index] - return False - return [index for index in _indices if index] - - -def tensorflow__check_float64_bknd(input): - if tensorflow_is_array_bknd(input): - return tensorflow_dtype(input) == "float64" - if math.isfinite(input): - m, e = math.frexp(input) - return abs(input) > 3.4028235e38 or e < -126 or e > 128 - return False - - -def tensorflow_as_ivy_dtype_bknd(dtype_in: Union[str, str], /): - return tensorflow_as_ivy_dtype(dtype_in) - - -def tensorflow_is_complex_dtype_bknd( - dtype_in: Union[str, str, tensorflow.Tensor, tf.Tensor, Number], / -): - if tensorflow_is_array_bknd(dtype_in): - dtype_in = tensorflow_dtype(dtype_in) - elif isinstance(dtype_in, tuple): - dtype_in = tensorflow_default_int_dtype_bknd() - elif isinstance(dtype_in, np.ndarray): - return "complex" in dtype_in.dtype.name - elif isinstance(dtype_in, Number): - return isinstance(dtype_in, (complex, np.complexfloating)) - elif isinstance(dtype_in, (list, tuple, dict)): - return tensorflow_nested_argwhere_bknd( - dtype_in, - lambda x: isinstance(x, (complex, np.complexfloating)) - or tensorflow_is_array_bknd(x) - and "complex" in tensorflow_dtype(x), - ) - return "complex" in tensorflow_as_ivy_dtype_bknd(dtype_in) - - -def tensorflow_as_native_dev(device: str, /): - if isinstance(device, str) and "/" in device: - return device - ret = f"/{str(device).upper()}" - if not ret[-1].isnumeric(): - ret += ":0" - return ret - - -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - -@tensorflow_handle_methods -def tensorflow_split( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - copy: Optional[bool] = None, - num_or_size_splits: Optional[ - Union[int, Sequence[int], Union[tensorflow.Tensor, tensorflow.Variable]] - ] = None, - axis: int = 0, - with_remainder: bool = False, -): - if x.shape == (): - if num_or_size_splits is not None and num_or_size_splits != 1: - raise Exception( - f"input array had no shape, but num_sections specified was {num_or_size_splits}" - ) - return [x] - if num_or_size_splits is None: - dim_size = tensorflow.shape(x)[axis] - num_or_size_splits = int(dim_size) - if isinstance(num_or_size_splits, (tensorflow.Tensor, tensorflow.Variable)): - num_or_size_splits = tensorflow.cast(num_or_size_splits, tensorflow.int32) - elif isinstance(num_or_size_splits, int) and with_remainder: - num_chunks = x.shape[axis] / num_or_size_splits - num_chunks_int = math.floor(num_chunks) - remainder = num_chunks - num_chunks_int - if remainder != 0: - num_or_size_splits = [num_or_size_splits] * num_chunks_int + [ - int(remainder * num_or_size_splits) - ] - return tensorflow.split(x, num_or_size_splits, axis) - - -@tensorflow_handle_methods -def tensorflow_split_bknd_( - self: tensorflow.Tensor, - /, - *, - copy: Optional[bool] = None, - num_or_size_splits: Optional[ - Union[int, Sequence[int], tensorflow.Tensor, tf.Tensor] - ] = None, - axis: int = 0, - with_remainder: bool = False, -): - return tensorflow_split( - self, - copy=copy, - num_or_size_splits=num_or_size_splits, - axis=axis, - with_remainder=with_remainder, - ) - - -def tensorflow_as_ivy_dev(device: 
str, /): - if isinstance(device, str) and "/" not in device: - return str(device) - dev_in_split = tensorflow_split_bknd_(device[1:], ":")[-2:] - if len(dev_in_split) == 1: - return str(dev_in_split[0]) - dev_type, dev_idx = dev_in_split[0], dev_in_split[1] - dev_type = dev_type.lower() - if dev_type == "cpu": - return str(dev_type) - return str(f"{dev_type}:{dev_idx}") - - -def tensorflow_stack( - arrays: Union[Tuple[tensorflow.Tensor], List[tensorflow.Tensor]], - /, - *, - axis: int = 0, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - try: - return tensorflow.experimental.numpy.stack(arrays, axis) - except ValueError as e: - raise Exception(e) from e - - -def tensorflow_stack_bknd_( - self: tensorflow.Tensor, - /, - arrays: Union[ - Tuple[Union[tensorflow.Tensor, tf.Tensor]], - List[Union[tensorflow.Tensor, tf.Tensor]], - ], - *, - axis: int = 0, - out: Optional[tensorflow.Tensor] = None, -): - if not isinstance(arrays, (tuple, list)): - arrays = [arrays] - if isinstance(arrays, tuple): - x = (self,) + arrays - else: - x = [self] + arrays - return tensorflow_stack(x, axis=axis, out=out) - - -def tensorflow_dev( - x: Union[tensorflow.Tensor, tensorflow.Variable, tensorflow.TensorArray], - /, - *, - as_native: bool = False, -): - if "keras.src.backend.tensorflow.core.Variable" in str(x.__class__): - x = x.value - if isinstance(x, tensorflow.TensorArray): - x = tensorflow_stack_bknd_(x) - dv = x.device - if as_native: - return dv - dv = dv if dv else tensorflow_default_device_bknd(as_native=False) - return tensorflow_as_ivy_dev(dv) - - -def tensorflow_default_device_bknd( - device: Optional[Union[str, str]] = None, - /, - *, - item: Optional[Union[list, tuple, dict, tensorflow.Tensor, tf.Tensor]] = None, - as_native: Optional[bool] = None, -): - if tensorflow_exists_bknd(device): - if as_native is True: - return tensorflow_as_native_dev(device) - elif as_native is False: - return tensorflow_as_ivy_dev(device) - return device - as_native = tensorflow_default_bknd(as_native, False) - if tensorflow_exists_bknd(item): - if isinstance(item, (list, tuple, dict)) and len(item) == 0: - pass - elif tensorflow_is_array_bknd(item): - return tensorflow_dev(item, as_native=as_native) - global default_device_stack - if not default_device_stack: - ret = "cpu" - else: - ret = default_device_stack[-1] - if as_native: - return tensorflow_as_native_dev(ret) - return tensorflow_as_ivy_dev(ret) - - -def tensorflow__get_preferred_device(args, kwargs): - device = None - if "device" in kwargs and kwargs["device"] is not None: - return device - if not False: - arr_arg = tensorflow__get_first_array(*args, **kwargs) - return tensorflow_default_device_bknd(item=arr_arg, as_native=True) - return tensorflow_default_device_bknd(as_native=True) - - -def tensorflow__check_in_nested_sequence(sequence, value=None, _type=None): - if sequence is value or isinstance(sequence, _type): - return True - elif isinstance(sequence, (tuple, list)): - if any(isinstance(_val, _type) or _val is value for _val in sequence): - return True - else: - return any( - tensorflow__check_in_nested_sequence(sub_sequence, value, _type) - for sub_sequence in sequence - if isinstance(sub_sequence, (tuple, list)) - ) - - -def tensorflow_is_variable(x, /, *, exclusive=False): - return isinstance(x, tensorflow.Variable) - - -def tensorflow_variable(x, /): - with tensorflow.device(tensorflow_dev(x, as_native=True)): - return tensorflow.Variable(x, trainable=True) - - -@tensorflow_handle_array_like_without_promotion -def 
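# --- editor's note (illustration, not part of the patch) -------------------
# default_device_stack (used by tensorflow_default_device_bknd above) is a
# plain dynamic-scoping stack: the most recently pushed device wins, with
# "cpu" as the ultimate fallback. The discipline in miniature:

default_device_stack = []

def default_device():
    return default_device_stack[-1] if default_device_stack else "cpu"

# default_device()                                        -> "cpu"
# default_device_stack.append("gpu:0"); default_device()  -> "gpu:0"
# ---------------------------------------------------------------------------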
tensorflow_stop_gradient( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - preserve_type: bool = True, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - is_var = tensorflow_is_variable(x) - x = tensorflow.stop_gradient(x) - if is_var and preserve_type: - return tensorflow_variable(x) - return x - - -def tensorflow_nested_map_bknd( - fn: Callable, - x: Union[tensorflow.Tensor, tf.Tensor, Iterable], - /, - include_derived: Optional[Union[Dict[str, bool], bool]] = None, - to_ignore: Optional[Union[type, Tuple[type]]] = None, - to_mutable: bool = False, - _tuple_check_fn: Optional[Callable] = None, - _list_check_fn: Optional[Callable] = None, - _dict_check_fn: Optional[Callable] = None, - shallow: bool = True, -): - to_ignore = tensorflow_default_bknd(to_ignore, ()) - if include_derived is True: - include_derived = {"tuple": True, "list": True, "dict": True} - elif not include_derived: - include_derived = {} - for t in ("tuple", "list", "dict"): - if t not in include_derived: - include_derived = tensorflow_set_item_bknd(include_derived, t, False) - class_instance = type(x) - if ( - hasattr(x, "is_tracked_proxy") - and hasattr(class_instance, "__bases__") - and not set(class_instance.__bases__).intersection(set(to_ignore)) - ): - to_ignore = to_ignore + (class_instance,) - tuple_check_fn = tensorflow_default_bknd( - _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), - ) - list_check_fn = tensorflow_default_bknd( - _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), - ) - dict_check_fn = tensorflow_default_bknd( - _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), - ) - if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): - ret_list = [ - tensorflow_nested_map_bknd( - fn, - i, - include_derived, - to_ignore, - to_mutable, - tuple_check_fn, - list_check_fn, - dict_check_fn, - shallow, - ) - for i in x - ] - if to_mutable: - return ret_list - elif hasattr(x, "_fields"): - return class_instance(**dict(zip(x._fields, ret_list))) - else: - return class_instance(ret_list) - elif list_check_fn(x, list) and not isinstance(x, to_ignore): - ret_list = [ - tensorflow_nested_map_bknd( - fn, - i, - include_derived, - to_ignore, - to_mutable, - tuple_check_fn, - list_check_fn, - dict_check_fn, - shallow, - ) - for i in x - ] - if shallow: - x = tensorflow_set_item_bknd(x, slice(None, None, None), ret_list[:]) - return x - return class_instance(ret_list) - elif (dict_check_fn(x, dict) or isinstance(x, UserDict)) and not isinstance( - x, to_ignore - ): - class_instance = type(x) - ret = { - k: tensorflow_nested_map_bknd( - fn, - v, - include_derived, - to_ignore, - to_mutable, - tuple_check_fn, - list_check_fn, - dict_check_fn, - shallow, - ) - for k, v in x.items() - } - if shallow: - x.update(ret) - return x - return class_instance(ret) - elif isinstance(x, slice): - return slice(*tensorflow_nested_map_bknd(fn, [x.start, x.stop, x.step])) - return fn(x) - - -def tensorflow__to_ivy_bknd_(x: Any): - if isinstance(x, tensorflow.Tensor): - return x - elif isinstance(x, tf.TensorShape): - return tuple(x) - elif isinstance(x, dict): - return x.to_ivy() - if tensorflow_is_native_array(x) or isinstance(x, np.ndarray): - return tensorflow.convert_to_tensor(x) - return x - - -def tensorflow_to_ivy_bknd_( - x: Union[tensorflow.Tensor, 
tf.Tensor, Iterable], - nested: bool = False, - include_derived: Optional[Dict[str, bool]] = None, -): - if nested: - return tensorflow_nested_map_bknd( - tensorflow__to_ivy_bknd_, x, include_derived, shallow=False - ) - return tensorflow__to_ivy_bknd_(x) - - -def tensorflow__asarray_to_native_arrays_and_back_bknd(fn: Callable): - @functools.wraps(fn) - def _asarray_to_native_arrays_and_back_wrapper(*args, dtype=None, **kwargs): - new_arg = args[0] - new_args = (new_arg,) + args[1:] - if dtype is not None: - dtype = tensorflow_default_dtype_bknd(dtype=dtype, as_native=True) - return tensorflow_to_ivy_bknd_(fn(*new_args, dtype=dtype, **kwargs)) - - _asarray_to_native_arrays_and_back_wrapper._asarray_to_native_arrays_and_back = True - return _asarray_to_native_arrays_and_back_wrapper - - -def tensorflow__flatten_nest_bknd(xs): - for x in xs: - if isinstance(x, Iterable) and not isinstance(x, (str, bytes)): - yield from tensorflow__flatten_nest_bknd(x) - else: - yield x - - -def tensorflow_promote_types_bknd( - type1: Union[str, tf.DType], - type2: Union[str, tf.DType], - /, - *, - array_api_promotion: bool = False, -): - if not (type1 and type2): - return type1 if type1 else type2 - query = [tensorflow_as_ivy_dtype(type1), tensorflow_as_ivy_dtype(type2)] - query = tuple(query) - if query not in promotion_table: - query = query[1], query[0] - - def _promote(query): - if array_api_promotion: - return tensorflow_get_item(array_api_promotion_table, query) - return tensorflow_get_item(promotion_table, query) - - return _promote(query) - - -def tensorflow__asarray_infer_dtype_bknd(fn: Callable): - @functools.wraps(fn) - def _asarray_infer_dtype_wrapper(*args, dtype=None, **kwargs): - def _infer_dtype(obj): - if isinstance(obj, tf.TensorShape): - obj = list(obj) - if hasattr(obj, "dtype"): - return obj.dtype.name if isinstance(obj, np.ndarray) else obj.dtype - else: - return tensorflow_default_dtype_bknd(item=obj) - - if not tensorflow_exists_bknd(dtype): - arr = args[0] - dtype_list = [ - tensorflow_nested_map_bknd( - lambda x: _infer_dtype(x), arr, shallow=False - ) - ] - dtype_list = tensorflow__flatten_nest_bknd(dtype_list) - dtype_list = list(set(dtype_list)) - if len(dtype_list) != 0: - dtype = dtype_list[0] - for dt in dtype_list[1:]: - dtype = tensorflow_promote_types_bknd(dtype, dt) - else: - dtype = tensorflow_default_float_dtype_bknd() - dtype = tensorflow_as_native_dtype(dtype) - return fn(*args, dtype=dtype, **kwargs) - - _asarray_infer_dtype_wrapper.infer_dtype = True - return _asarray_infer_dtype_wrapper - - -@tensorflow_handle_array_like_without_promotion -@tensorflow__asarray_to_native_arrays_and_back_bknd -@tensorflow__asarray_infer_dtype_bknd -def tensorflow_asarray( - obj: Union[ - tensorflow.Tensor, - tensorflow.Variable, - tensorflow.TensorShape, - bool, - int, - float, - tensorflow_NestedSequence_bknd, - SupportsBufferProtocol, - np.ndarray, - ], - /, - *, - copy: Optional[bool] = None, - dtype: Optional[tensorflow.DType] = None, - device: Optional[str] = None, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - with tensorflow.device(device): - if tensorflow.is_tensor(obj): - ret = tensorflow.cast(obj, dtype) if obj.dtype != dtype else obj - elif ( - dtype is not None - and dtype.is_integer - and np.issubdtype(np.array(obj).dtype, np.floating) - ): - obj_np = np.array(obj) - ret = tensorflow.convert_to_tensor(obj_np, dtype) - else: - ret = tensorflow.convert_to_tensor(obj, dtype) - return ( - tensorflow.identity(ret) - if copy or 
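# --- editor's note (illustration, not part of the patch) -------------------
# tensorflow_promote_types_bknd (above) keys promotion_table on ordered
# pairs but stores each unordered pair once, so a failed lookup retries the
# swapped key. The lookup in miniature, over a toy two-entry table:

toy_table = {("int8", "int16"): "int16", ("float16", "float32"): "float32"}

def promote(t1, t2, table=toy_table):
    if not (t1 and t2):
        return t1 or t2  # one side missing: keep the other
    key = (t1, t2)
    if key not in table:
        key = (t2, t1)   # retry the swapped ordering
    return table[key]

# promote("int16", "int8") -> "int16"   (found via the swapped key)
# ---------------------------------------------------------------------------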
tensorflow_as_native_dev(tensorflow_dev(ret)) != device - else ret - ) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_size(x: tensorflow.Tensor, /): - return functools.reduce(mul, x.shape) if len(x.shape) > 0 else 1 - - -def tensorflow_size_bknd_(self): - return tensorflow_size(self) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_unstack( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - copy: Optional[bool] = None, - axis: int = 0, - keepdims: bool = False, -): - if x.shape == (): - return [x] - ret = tensorflow.unstack(x, axis=axis) - if keepdims: - return [tensorflow.expand_dims(r, axis) for r in ret] - return ret - - -def tensorflow_unstack_bknd_( - self: tensorflow.Tensor, - /, - *, - copy: Optional[bool] = None, - axis: int = 0, - keepdims: bool = False, -): - return tensorflow_unstack(self, copy=copy, axis=axis, keepdims=keepdims) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_copy_array( - x: Union[tensorflow.Tensor, tensorflow.Variable, tensorflow.TensorArray], - *, - to_ivy_array: bool = True, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if isinstance(x, tensorflow.TensorArray): - x_wrapped = tensorflow_stack_bknd_(x) - y = tensorflow.TensorArray(x.dtype, tensorflow_size_bknd_(x)()) - x = tensorflow_unstack_bknd_(y, tensorflow_copy_array(x_wrapped)) - else: - x = tensorflow.identity(x) - if to_ivy_array: - return tensorflow_to_ivy_bknd_(x) - return x - - -def tensorflow_tile( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - repeats: Sequence[int], - *, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if x.shape == (): - x = tensorflow.reshape(x, (-1,)) - if isinstance(repeats, Number): - repeats = [repeats] - if isinstance(repeats, tensorflow.Tensor) and repeats.shape == (): - repeats = tensorflow.reshape(repeats, (-1,)) - if len(x.shape) < len(repeats): - while len(x.shape) != len(repeats): - x = tensorflow.expand_dims(x, 0) - elif len(x.shape) > len(repeats): - repeats = list(repeats) - while len(x.shape) != len(repeats): - repeats = [1] + repeats - return tensorflow.tile(x, repeats) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_nonzero( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - as_tuple: bool = True, - size: Optional[int] = None, - fill_value: Number = 0, -): - res = tensorflow.experimental.numpy.nonzero(x) - if size is not None: - dtype = tensorflow.int64 - if isinstance(fill_value, float): - dtype = tensorflow.float64 - res = tensorflow.cast(res, dtype) - diff = size - res[0].shape[0] - if diff > 0: - res = tensorflow.pad(res, [[0, 0], [0, diff]], constant_values=fill_value) - elif diff < 0: - res = tensorflow.slice(res, [0, 0], [-1, size]) - if as_tuple: - return tuple(res) - return tensorflow.stack(res, axis=1) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_diff( - x: Union[tensorflow.Tensor, tensorflow.Variable, list, tuple], - /, - *, - n: int = 1, - axis: int = -1, - prepend: Optional[ - Union[tensorflow.Tensor, tensorflow.Variable, int, float, list, tuple] - ] = None, - append: Optional[ - Union[tensorflow.Tensor, tensorflow.Variable, int, float, list, tuple] - ] = None, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if n == 0: - return x - if prepend is not None: - x = tensorflow.experimental.numpy.append( - prepend, x, axis=axis if axis != -1 else None - ) - if append is not None: - x = tensorflow.experimental.numpy.append( 
- x, append, axis=axis if axis != -1 else None - ) - return tensorflow.experimental.numpy.diff(x, n=n, axis=axis) - - -def tensorflow__parse_ellipsis_bknd(so, ndims): - pre = list() - for s in so: - if s is Ellipsis: - break - pre.append(s) - post = list() - for s in reversed(so): - if s is Ellipsis: - break - post.append(s) - ret = list( - pre - + [slice(None, None, None) for _ in range(ndims - len(pre) - len(post))] - + list(reversed(post)) - ) - return ret, (len(pre), ndims - len(post)) - - -def tensorflow_broadcast_arrays(*arrays: Union[tensorflow.Tensor, tensorflow.Variable]): - if len(arrays) > 1: - try: - desired_shape = tensorflow.broadcast_dynamic_shape( - tensorflow.shape(arrays[0]), tensorflow.shape(arrays[1]) - ) - except tensorflow.errors.InvalidArgumentError as e: - raise Exception(e) from e - if len(arrays) > 2: - for i in range(2, len(arrays)): - try: - desired_shape = tensorflow.broadcast_dynamic_shape( - desired_shape, tensorflow.shape(arrays[i]) - ) - except tensorflow.errors.InvalidArgumentError as e: - raise Exception(e) from e - else: - return [arrays[0]] - result = [] - for tensor in arrays: - result.append(tensorflow.broadcast_to(tensor, desired_shape)) - return result - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_astype( - x: Union[tensorflow.Tensor, tensorflow.Variable], - dtype: Union[tf.DType, str], - /, - *, - copy: bool = True, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - dtype = tensorflow_as_native_dtype(dtype) - if x.dtype == dtype: - return tensorflow.experimental.numpy.copy(x) if copy else x - return tensorflow.cast(x, dtype) - - -def tensorflow_astype_bknd_( - self: tensorflow.Tensor, - dtype: str, - /, - *, - copy: bool = True, - out: Optional[tensorflow.Tensor] = None, -): - return tensorflow_astype(self, dtype, copy=copy, out=out) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_where( - condition: Union[tensorflow.Tensor, tensorflow.Variable], - x1: Union[tensorflow.Tensor, tensorflow.Variable], - x2: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - oirg_x1 = x1 - oirg_x2 = x2 - try: - dtype = ( - x1.dtype - if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() - ) - if not tensorflow_is_array_bknd(x1): - x1 = tensorflow_asarray(x1, dtype=dtype) - if not tensorflow_is_array_bknd(x2): - x2 = tensorflow_asarray(x2, dtype=dtype) - except: - x1 = oirg_x1 - x2 = oirg_x2 - return tensorflow.cast( - tensorflow.experimental.numpy.where(condition, x1, x2), x1.dtype - ) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_arange( - start: float, - /, - stop: Optional[float] = None, - step: float = 1, - *, - dtype: Optional[tensorflow.DType] = None, - device: Optional[str] = None, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if stop is None: - stop = start - start = 0 - if step > 0 and start > stop or step < 0 and start < stop: - if isinstance(stop, float): - stop = float(start) - else: - stop = start - if isinstance(start, (float, int)): - start = tensorflow.convert_to_tensor(start) - if isinstance(stop, (float, int)): - stop = tensorflow.convert_to_tensor(stop) - if isinstance(step, (float, int)): - step = tensorflow.convert_to_tensor(step) - if dtype is None: - if isinstance(start, int) and isinstance(stop, int) and isinstance(step, int): - return tensorflow.cast( - tensorflow.range(start, stop, delta=step, 
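# --- editor's note (illustration, not part of the patch) -------------------
# The chained conditional expression this patch reflows (in tensorflow_where,
# tensorflow_multiply, ...) resolves a common dtype by preference order:
# x1's dtype, else x2's dtype, else the configured default. Written out:

def resolve_dtype(x1, x2, default="float32"):  # default value is illustrative
    if hasattr(x1, "dtype"):
        return x1.dtype
    if hasattr(x2, "dtype"):
        return x2.dtype
    return default
# ---------------------------------------------------------------------------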
dtype=tensorflow.int64), - tensorflow.int32, - ) - else: - return tensorflow.range(start, stop, delta=step) - else: - dtype = tensorflow_as_native_dtype(tensorflow_default_dtype_bknd(dtype=dtype)) - if dtype in [ - tensorflow.int8, - tensorflow.uint8, - tensorflow.int16, - tensorflow.uint16, - tensorflow.uint32, - tensorflow.uint64, - ]: - return tensorflow.cast( - tensorflow.range(start, stop, delta=step, dtype=tensorflow.int64), dtype - ) - else: - return tensorflow.range(start, stop, delta=step, dtype=dtype) - - -def tensorflow__parse_slice_bknd(idx, s): - step = 1 if idx.step is None else idx.step - if step > 0: - start = 0 if idx.start is None else idx.start - if start >= s: - stop = start - else: - if start <= -s: - start = 0 - elif start < 0: - start = start + s - stop = s if idx.stop is None else idx.stop - if stop > s: - stop = s - elif start <= -s: - stop = 0 - elif stop < 0: - stop = stop + s - else: - start = s - 1 if idx.start is None else idx.start - if start < -s: - stop = start - else: - if start >= s: - start = s - 1 - elif start < 0: - start = start + s - if idx.stop is None: - stop = -1 - else: - stop = idx.stop - if stop > s: - stop = s - elif stop < -s: - stop = -1 - elif stop == -s: - stop = 0 - elif stop < 0: - stop = stop + s - q_i = tensorflow_arange(start, stop, step) - ag__result_list_0 = [] - for q in q_i: - if 0 <= q < s: - res = q - ag__result_list_0.append(res) - q_i = ag__result_list_0 - q_i = ( - tensorflow_asarray(q_i) - if len(q_i) or start == stop or idx.stop is not None - else tensorflow_arange(0, s, 1) - ) - return q_i - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_shape( - x: Union[tensorflow.Tensor, tensorflow.Variable], /, *, as_array: bool = False -): - if as_array: - return tensorflow_asarray( - tensorflow.shape(x), dtype=tensorflow_default_int_dtype_bknd() - ) - else: - return tuple(x.shape) - - -def tensorflow__deep_flatten_bknd(iterable): - def _flatten_gen(iterable): - for item in iterable: - if isinstance(item, list): - yield from _flatten_gen(item) - else: - yield item - - return list(_flatten_gen(iterable)) - - -def tensorflow__calculate_out_shape_bknd(axis, array_shape): - if type(axis) not in (tuple, list): - axis = (axis,) - out_dims = len(axis) + len(array_shape) - norm_axis = normalize_axis_tuple(axis, out_dims) - shape_iter = iter(array_shape) - ag__result_list_0 = [] - for current_ax in range(out_dims): - res = 1 if current_ax in norm_axis else next(shape_iter) - ag__result_list_0.append(res) - out_shape = ag__result_list_0 - return out_shape - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_expand_dims( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - copy: Optional[bool] = None, - axis: Union[int, Sequence[int]] = 0, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - try: - out_shape = tensorflow__calculate_out_shape_bknd(axis, tensorflow.shape(x)) - ret = tensorflow.reshape(x, shape=out_shape) - return ret - except (tensorflow.errors.InvalidArgumentError, np.AxisError) as error: - raise Exception(error) from error - - -def tensorflow_check_elem_in_list(elem, list, inverse=False, message=""): - if inverse and elem in list: - raise Exception( - message if message != "" else f"{elem} must not be one of {list}" - ) - elif not inverse and elem not in list: - raise Exception(message if message != "" else f"{elem} must be one of {list}") - - -def tensorflow__reshape_fortran_tf(x, shape): - if len(x.shape) > 0: - x = tensorflow.transpose(x) - return 
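# --- editor's note (illustration, not part of the patch) -------------------
# tensorflow__parse_slice_bknd (above) hand-derives the clamping that
# Python's slice.indices() already implements, then materialises the index
# list. An equivalent sketch, modulo the helper's empty-result edge cases:

def parse_slice(idx, length):
    start, stop, step = idx.indices(length)  # clamps negatives and overruns
    return list(range(start, stop, step))

# parse_slice(slice(None, None, -1), 3) -> [2, 1, 0]
# parse_slice(slice(1, 100), 4)         -> [1, 2, 3]
# ---------------------------------------------------------------------------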
tensorflow.transpose(tensorflow.reshape(x, shape[::-1])) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_reshape( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - shape: Union[tf.TensorShape, Sequence[int]], - *, - copy: Optional[bool] = None, - order: str = "C", - allowzero: bool = True, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - tensorflow_check_elem_in_list(order, ["C", "F"]) - if not allowzero: - shape = [ - (new_s if con else old_s) - for new_s, con, old_s in zip( - shape, tensorflow.constant(shape) != 0, x.shape - ) - ] - if order == "F": - return tensorflow__reshape_fortran_tf(x, shape) - return tensorflow.reshape(x, shape) - - -def tensorflow_reshape_bknd_( - self: tensorflow.Tensor, - /, - shape: Union[tuple, tf.TensorShape, Sequence[int]], - *, - copy: Optional[bool] = None, - order: str = "C", - allowzero: bool = True, - out: Optional[tensorflow.Tensor] = None, -): - return tensorflow_reshape( - self, shape, copy=copy, allowzero=allowzero, out=out, order=order - ) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_meshgrid( - *arrays: Union[tensorflow.Tensor, tensorflow.Variable], - sparse: bool = False, - indexing: str = "xy", - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if not sparse: - return tensorflow.meshgrid(*arrays, indexing=indexing) - sd = (1,) * len(arrays) - ag__result_list_0 = [] - for i, a in enumerate(arrays): - res = tensorflow.reshape( - tensorflow.convert_to_tensor(a), sd[:i] + (-1,) + sd[i + 1 :] - ) - ag__result_list_0.append(res) - res = ag__result_list_0 - if indexing == "xy" and len(arrays) > 1: - res[0] = tensorflow.reshape(res[0], (1, -1) + sd[2:]) - res[1] = tensorflow.reshape(res[1], (-1, 1) + sd[2:]) - return res - - -@tensorflow_infer_dtype -@tensorflow_handle_array_like_without_promotion -def tensorflow_empty( - shape: Union[tf.TensorShape, Sequence[int]], - *, - dtype: tensorflow.DType, - device: Optional[str] = None, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - return tensorflow.experimental.numpy.empty(shape, dtype=tensorflow.float32) - - -def tensorflow__parse_query_bknd(query, x_shape, scatter=False): - query = (query,) if not isinstance(query, tuple) else query - ag__result_list_0 = [] - for q in query: - res = tensorflow_asarray(q) if isinstance(q, (tuple, list, int)) else q - ag__result_list_0.append(res) - query = ag__result_list_0 - ag__result_list_1 = [] - for i, q in enumerate(query): - if tensorflow_is_array_bknd(q): - res = i - ag__result_list_1.append(res) - non_slice_q_idxs = ag__result_list_1 - to_front = ( - len(non_slice_q_idxs) > 1 - and any(tensorflow_diff(non_slice_q_idxs) != 1) - and non_slice_q_idxs[-1] < len(x_shape) - ) - ag__result_list_2 = [] - for i, q in enumerate(query): - if q is None: - res = i - ag__result_list_2.append(res) - new_axes = ag__result_list_2 - ag__result_list_3 = [] - for q in query: - if q is not None: - res = q - ag__result_list_3.append(res) - query = ag__result_list_3 - query = [Ellipsis] if query == [] else query - ellipsis_inds = None - if any(q is Ellipsis for q in query): - query, ellipsis_inds = tensorflow__parse_ellipsis_bknd(query, len(x_shape)) - ag__result_list_4 = [] - for i, v in enumerate(query): - if tensorflow_is_array_bknd(v): - res = i - ag__result_list_4.append(res) - array_inds = ag__result_list_4 - if array_inds: - array_queries = tensorflow_broadcast_arrays( - *[v for i, v in enumerate(query) if i in array_inds] - ) - array_queries = 
[ - ( - tensorflow_nonzero(q, as_tuple=False)[0] - if tensorflow_is_bool_dtype_bknd(q) - else q - ) - for q in array_queries - ] - array_queries = [ - ( - tensorflow_astype_bknd_( - tensorflow_where( - arr < 0, arr + tensorflow_get_item(x_shape, i), arr - ), - tf.int64, - ) - if tensorflow_size_bknd_(arr) - else tensorflow_astype_bknd_(arr, tf.int64) - ) - for arr, i in zip(array_queries, array_inds) - ] - for idx, arr in zip(array_inds, array_queries): - query = tensorflow_set_item_bknd(query, idx, arr) - ag__result_list_5 = [] - for i, q in enumerate(query): - res = ( - tensorflow_astype_bknd_( - tensorflow__parse_slice_bknd(q, tensorflow_get_item(x_shape, i)), - tf.int64, - ) - if isinstance(q, slice) - else q - ) - ag__result_list_5.append(res) - query = ag__result_list_5 - if len(query) < len(x_shape): - query = query + [ - tensorflow_astype_bknd_(tensorflow_arange(0, s, 1), tf.int64) - for s in tensorflow_get_item(x_shape, slice(len(query), None, None)) - ] - if len(array_inds) and to_front: - target_shape = ( - [list(array_queries[0].shape)] - + [ - list(tensorflow_get_item(query, i).shape) - for i in range(len(query)) - if i not in array_inds - ] - + [[] for _ in range(len(array_inds) - 1)] - ) - elif len(array_inds): - target_shape = ( - [list(tensorflow_get_item(query, i).shape) for i in range(0, array_inds[0])] - + [list(tensorflow_shape(array_queries[0], as_array=True))] - + [[] for _ in range(len(array_inds) - 1)] - + [ - list(tensorflow_shape(tensorflow_get_item(query, i), as_array=True)) - for i in range(array_inds[-1] + 1, len(query)) - ] - ) - else: - target_shape = [list(q.shape) for q in query] - if ellipsis_inds is not None: - target_shape = ( - tensorflow_get_item(target_shape, slice(None, ellipsis_inds[0], None)) - + [ - tensorflow_get_item( - target_shape, slice(ellipsis_inds[0], ellipsis_inds[1], None) - ) - ] - + tensorflow_get_item(target_shape, slice(ellipsis_inds[1], None, None)) - ) - for i, ax in enumerate(new_axes): - if len(array_inds) and to_front: - ax = ax - (sum(1 for x in array_inds if x < ax) - 1) - ax = ax + i - target_shape = [ - *tensorflow_get_item(target_shape, slice(None, ax, None)), - 1, - *tensorflow_get_item(target_shape, slice(ax, None, None)), - ] - target_shape = tensorflow__deep_flatten_bknd(target_shape) - ag__result_list_6 = [] - for q in query: - res = tensorflow_expand_dims(q) if not len(q.shape) else q - ag__result_list_6.append(res) - query = ag__result_list_6 - if len(array_inds): - array_queries = [ - ( - tensorflow_reshape_bknd_(arr, (-1,)) - if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr - ) - for arr in array_queries - ] - array_queries = tensorflow_stack(array_queries, axis=1) - if len(array_inds) == len(query): - indices = tensorflow_reshape_bknd_(array_queries, (*target_shape, len(x_shape))) - elif len(array_inds) == 0: - indices = tensorflow_reshape_bknd_( - tensorflow_stack(tensorflow_meshgrid(*query, indexing="ij"), axis=-1), - (*target_shape, len(x_shape)), - ) - elif to_front: - post_array_queries = ( - tensorflow_reshape_bknd_( - tensorflow_stack( - tensorflow_meshgrid( - *[v for i, v in enumerate(query) if i not in array_inds], - indexing="ij", - ), - axis=-1, - ), - (-1, len(query) - len(array_inds)), - ) - if len(array_inds) < len(query) - else tensorflow_empty((1, 0)) - ) - indices = tensorflow_reshape_bknd_( - tensorflow_asarray( - [ - (*arr, *post) - for arr, post in itertools.product( - array_queries, post_array_queries - ) - ] - ), - (*target_shape, len(x_shape)), - ) - 
else: - pre_array_queries = ( - tensorflow_reshape_bknd_( - tensorflow_stack( - tensorflow_meshgrid( - *[v for i, v in enumerate(query) if i < array_inds[0]], - indexing="ij", - ), - axis=-1, - ), - (-1, array_inds[0]), - ) - if array_inds[0] > 0 - else tensorflow_empty((1, 0)) - ) - post_array_queries = ( - tensorflow_reshape_bknd_( - tensorflow_stack( - tensorflow_meshgrid( - *[v for i, v in enumerate(query) if i > array_inds[-1]], - indexing="ij", - ), - axis=-1, - ), - (-1, len(query) - 1 - array_inds[-1]), - ) - if array_inds[-1] < len(query) - 1 - else tensorflow_empty((1, 0)) - ) - indices = tensorflow_reshape_bknd_( - tensorflow_asarray( - [ - (*pre, *arr, *post) - for pre, arr, post in itertools.product( - pre_array_queries, array_queries, post_array_queries - ) - ] - ), - (*target_shape, len(x_shape)), - ) - return ( - tensorflow_astype_bknd_(indices, tf.int64), - target_shape, - array_inds if len(array_inds) and to_front else None, - ) - - -def tensorflow_get_num_dims(x, /, *, as_array=False): - return ( - tensorflow.cast(tensorflow.shape(tensorflow.shape(x))[0], tensorflow.int64) - if as_array - else int(tensorflow.shape(tensorflow.shape(x))) - ) - - -def tensorflow_to_numpy( - x: Union[tensorflow.Tensor, tensorflow.Variable], /, *, copy: bool = True -): - if ( - tensorflow_is_array_bknd(x) - and tensorflow_get_num_dims(x) == 0 - and tensorflow_as_native_dtype(x.dtype) is tensorflow.bfloat16 - ): - x = tensorflow.expand_dims(x, 0) - if copy: - return np.squeeze(np.array(tensorflow.convert_to_tensor(x)), 0) - else: - return np.squeeze(np.asarray(tensorflow.convert_to_tensor(x)), 0) - if copy: - return np.array(tensorflow.convert_to_tensor(x)) - else: - return np.asarray(tensorflow.convert_to_tensor(x)) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_to_scalar(x: Union[tensorflow.Tensor, tensorflow.Variable], /): - ret = tensorflow_to_numpy(x).item() - if x.dtype == tensorflow.bfloat16: - return float(ret) - return ret - - -def tensorflow_to_scalar_bknd_(self: tensorflow.Tensor): - return tensorflow_to_scalar(self) - - -def tensorflow_is_float_dtype_bknd( - dtype_in: Union[str, str, tensorflow.Tensor, tf.Tensor, Number], / -): - if tensorflow_is_array_bknd(dtype_in): - dtype_in = tensorflow_dtype(dtype_in) - elif isinstance(dtype_in, tuple): - dtype_in = tensorflow_default_int_dtype_bknd() - elif isinstance(dtype_in, np.ndarray): - return "float" in dtype_in.dtype.name - elif isinstance(dtype_in, Number): - return isinstance(dtype_in, (float, np.floating)) - elif isinstance(dtype_in, (list, tuple, dict)): - return bool( - tensorflow_nested_argwhere_bknd( - dtype_in, - lambda x: isinstance(x, (float, np.floating)) - or tensorflow_is_array_bknd(x) - and "float" in tensorflow_dtype(x), - ) - ) - return "float" in tensorflow_as_ivy_dtype_bknd(dtype_in) - - -def tensorflow_is_uint_dtype_bknd( - dtype_in: Union[str, str, tensorflow.Tensor, tf.Tensor, Number], / -): - if tensorflow_is_array_bknd(dtype_in): - dtype_in = tensorflow_dtype(dtype_in) - elif isinstance(dtype_in, tuple): - dtype_in = tensorflow_default_int_dtype_bknd() - elif isinstance(dtype_in, np.ndarray): - return "uint" in dtype_in.dtype.name - elif isinstance(dtype_in, Number): - return isinstance(dtype_in, np.unsignedinteger) - elif isinstance(dtype_in, (list, tuple, dict)): - return tensorflow_nested_argwhere_bknd( - dtype_in, - lambda x: isinstance(x, np.unsignedinteger) - or tensorflow_is_array_bknd(x) - and "uint" in tensorflow_dtype(x), - ) - return "uint" in 
tensorflow_as_ivy_dtype_bknd(dtype_in) - - -def tensorflow_default_uint_dtype_bknd( - *, - input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, - uint_dtype: Optional[Union[str, tf.DType]] = None, - as_native: bool = False, -): - global default_uint_dtype_stack - if tensorflow_exists_bknd(uint_dtype): - if as_native is True: - return tensorflow_as_native_dtype(uint_dtype) - return str(tensorflow_as_ivy_dtype(uint_dtype)) - as_native = tensorflow_default_bknd(as_native, False) - if tensorflow_exists_bknd(input): - if tensorflow_is_array_bknd(input): - ret = tensorflow_dtype(input) - elif isinstance(input, np.ndarray): - ret = input.dtype - elif isinstance(input, (list, tuple, dict)): - - def is_native(x): - return tensorflow_is_native_array(x) - - if tensorflow_nested_argwhere_bknd( - input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), - stop_after_n_found=1, - ): - ret = tf.uint64 - elif default_uint_dtype_stack: - ret = default_uint_dtype_stack[-1] - else: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_uint_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "uint32" - elif isinstance(input, Number): - if input > 4294967295 and input != math.inf and backend != "torch": - ret = tf.uint64 - elif default_uint_dtype_stack: - ret = default_uint_dtype_stack[-1] - else: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_uint_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "uint32" - elif default_uint_dtype_stack: - ret = default_uint_dtype_stack[-1] - else: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_uint_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "uint32" - if as_native: - return tensorflow_as_native_dtype(ret) - return str(tensorflow_as_ivy_dtype(ret)) - - -def tensorflow_is_int_dtype_bknd( - dtype_in: Union[str, str, tensorflow.Tensor, tf.Tensor, Number], / -): - if tensorflow_is_array_bknd(dtype_in): - dtype_in = tensorflow_dtype(dtype_in) - elif isinstance(dtype_in, tuple): - dtype_in = tensorflow_default_int_dtype_bknd() - elif isinstance(dtype_in, np.ndarray): - return "int" in dtype_in.dtype.name - elif isinstance(dtype_in, Number): - return isinstance(dtype_in, (int, np.integer)) and not isinstance( - dtype_in, bool - ) - elif isinstance(dtype_in, (list, tuple, dict)): - - def nested_fun(x): - return ( - isinstance(x, (int, np.integer)) - or tensorflow_is_array_bknd(x) - and "int" in tensorflow_dtype(x) - ) and x is not bool - - return bool(tensorflow_nested_argwhere_bknd(dtype_in, nested_fun)) - return "int" in tensorflow_as_ivy_dtype(dtype_in) - - -def tensorflow_infer_default_dtype_bknd( - dtype: Union[str, tf.DType, str], as_native: bool = False -): - if tensorflow_is_complex_dtype_bknd(dtype): - default_dtype = tensorflow_default_complex_dtype_bknd(as_native=as_native) - elif tensorflow_is_float_dtype_bknd(dtype): - default_dtype = tensorflow_default_float_dtype_bknd(as_native=as_native) - elif tensorflow_is_uint_dtype_bknd(dtype): - default_dtype = tensorflow_default_uint_dtype_bknd(as_native=as_native) - elif tensorflow_is_int_dtype_bknd(dtype): - default_dtype = tensorflow_default_int_dtype_bknd(as_native=as_native) - elif as_native: - default_dtype = tensorflow_as_native_dtype("bool") - else: - default_dtype = tensorflow_as_ivy_dtype("bool") - return default_dtype - - -def tensorflow_dtype_bits(dtype_in: Union[tensorflow.DType, str, np.dtype], /): - dtype_str = tensorflow_as_ivy_dtype(dtype_in) - if "bool" in 
dtype_str: - return 1 - return int( - dtype_str.replace("tf.", "") - .replace("uint", "") - .replace("int", "") - .replace("bfloat", "") - .replace("float", "") - .replace("complex", "") - ) - - -def tensorflow__infer_dtype(dtype: tensorflow.DType): - default_dtype = tensorflow_infer_default_dtype_bknd(dtype) - if tensorflow_dtype_bits(dtype) < tensorflow_dtype_bits(default_dtype): - return default_dtype - return dtype - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_prod( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - axis: Optional[Union[int, Sequence[int]]] = None, - dtype: Optional[tensorflow.DType] = None, - keepdims: bool = False, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - dtype = tensorflow_as_native_dtype(dtype) - if dtype is None: - dtype = tensorflow__infer_dtype(x.dtype) - axis = tuple(axis) if isinstance(axis, list) else axis - return tensorflow.experimental.numpy.prod( - x, axis=axis, dtype=dtype, keepdims=keepdims - ) - - -def tensorflow__numel_bknd(shape): - shape = tuple(shape) - return tensorflow_to_scalar_bknd_(tensorflow_prod(shape)) if shape != () else 1 - - -def tensorflow_check_one_way_broadcastable(x1, x2): - if len(x1) > len(x2): - return False - for a, b in zip(x1[::-1], x2[::-1]): - if a in (1, b): - pass - else: - return False - return True - - -def tensorflow_check_shapes_broadcastable(var, data): - if not tensorflow_check_one_way_broadcastable(var, data): - raise Exception(f"Could not broadcast shape {data} to shape {var}.") - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_broadcast_to( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - shape: Union[tf.TensorShape, Sequence[int]], - *, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - tensorflow_check_shapes_broadcastable(x.shape, shape) - if tensorflow.rank(x) > len(shape): - return tensorflow.broadcast_to(tensorflow.reshape(x, -1), shape) - return tensorflow.broadcast_to(x, shape) - - -def tensorflow__broadcast_to_bknd(input, target_shape): - if tensorflow__numel_bknd(tuple(input.shape)) == tensorflow__numel_bknd( - tuple(target_shape) - ): - return tensorflow_reshape(input, target_shape) - else: - input = input if len(input.shape) else tensorflow_expand_dims(input, axis=0) - return tensorflow_broadcast_to(input, target_shape) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_any( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - axis: Optional[Union[int, Sequence[int]]] = None, - keepdims: bool = False, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if axis is None: - num_dims = len(x.shape) - axis = tuple(range(num_dims)) - elif isinstance(axis, list): - axis = tuple(axis) - try: - return tensorflow.reduce_any( - tensorflow.cast(x, tensorflow.bool), axis=axis, keepdims=keepdims - ) - except tensorflow.errors.InvalidArgumentError as e: - raise Exception(e) from e - - -def tensorflow__broadcast_inputs(x1, x2): - x1_, x2_ = x1, x2 - iterables = (list, tuple) - if not isinstance(x1_, iterables): - x1_, x2_ = x2, x1 - if not isinstance(x1_, iterables): - return [x1], [x2] - if not isinstance(x2_, iterables): - x1 = [x1] * len(x2) - return x1, x2 - - -def tensorflow_check_equal(x1, x2, inverse=False, message="", as_array=True): - def eq_fn(x1, x2): - return x1 == x2 if inverse else x1 != x2 - - def comp_fn(x1, x2): - return tensorflow_any(eq_fn(x1, x2)) - - if not as_array: - - def iter_comp_fn(x1_, x2_): - return
any(eq_fn(x1, x2) for x1, x2 in zip(x1_, x2_)) - - def comp_fn(x1, x2): - return iter_comp_fn(*tensorflow__broadcast_inputs(x1, x2)) - - eq = comp_fn(x1, x2) - if inverse and eq: - raise Exception(f"{x1} must not be equal to {x2}" if message == "" else message) - elif not inverse and eq: - raise Exception(f"{x1} must be equal to {x2}" if message == "" else message) - - -def tensorflow_multiply( - x1: Union[float, tensorflow.Tensor, tensorflow.Variable], - x2: Union[float, tensorflow.Tensor, tensorflow.Variable], - /, - *, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - orig_x1 = x1 - orig_x2 = x2 - try: - dtype = ( - x1.dtype - if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() - ) - if not tensorflow_is_array_bknd(x1): - x1 = tensorflow_asarray(x1, dtype=dtype) - if not tensorflow_is_array_bknd(x2): - x2 = tensorflow_asarray(x2, dtype=dtype) - except Exception: - x1 = orig_x1 - x2 = orig_x2 - return tensorflow.math.multiply(x1, x2) - - -def tensorflow_check_gather_nd_input_valid(params, indices, batch_dims): - if batch_dims >= len(params.shape): - raise Exception( - f"batch_dims = {batch_dims} must be less than rank(`params`) = {len(params.shape)}." - ) - if batch_dims >= len(indices.shape): - raise Exception( - f"batch_dims = {batch_dims} must be less than rank(`indices`) = {len(indices.shape)}." - ) - if tensorflow_get_item( - params.shape, slice(0, batch_dims, None) - ) != tensorflow_get_item(indices.shape, slice(0, batch_dims, None)): - raise Exception( - f"batch dimensions must match in `params` and `indices`; saw {tensorflow_get_item(params.shape, slice(0, batch_dims, None))} vs. {tensorflow_get_item(indices.shape, slice(0, batch_dims, None))}" - ) - if indices.shape[-1] > len( - tensorflow_get_item(params.shape, slice(batch_dims, None, None)) - ): - raise Exception( - f"index innermost dimension length must be <= rank(`params[batch_dims:]`); saw: {indices.shape[-1]} vs. {len(tensorflow_get_item(params.shape, slice(batch_dims, None, None)))} ."
- ) - - -def tensorflow_gather_nd_helper(params, indices): - indices_shape = tensorflow.shape(indices) - params_shape = tensorflow.shape(params) - num_index_dims = indices_shape[-1] - result_dim_sizes_list = [ - tensorflow.math.reduce_prod(params_shape[i + 1 :]) - for i in range(len(params_shape) - 1) - ] + [1] - result_dim_sizes = tensorflow.convert_to_tensor( - result_dim_sizes_list, dtype=indices.dtype - ) - implicit_indices_factor = result_dim_sizes[num_index_dims - 1] - flat_params = tensorflow.reshape(params, (-1,)) - new_shape = [1] * (len(indices_shape) - 1) + [num_index_dims] - indices_scales = tensorflow.reshape(result_dim_sizes[0:num_index_dims], new_shape) - indices_for_flat_tiled = tensorflow.reshape( - tensorflow.reduce_sum(indices * indices_scales, -1, keepdims=True), (-1, 1) - ) - indices_for_flat_tiled = tensorflow.repeat( - indices_for_flat_tiled, implicit_indices_factor, axis=1 - ) - implicit_indices = tensorflow.repeat( - tensorflow.expand_dims(tensorflow.range(implicit_indices_factor), 0), - indices_for_flat_tiled.shape[0], - axis=0, - ) - indices_for_flat = indices_for_flat_tiled + implicit_indices - flat_indices_for_flat = tensorflow.reshape(indices_for_flat, (-1,)) - flat_gather = tensorflow.gather(flat_params, flat_indices_for_flat) - res = tensorflow.reshape( - flat_gather, - tensorflow.concat([indices_shape[:-1], params_shape[num_index_dims:]], 0), - ) - return res - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_gather_nd( - params: Union[tensorflow.Tensor, tensorflow.Variable], - indices: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - batch_dims: int = 0, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - tensorflow_check_gather_nd_input_valid(params, indices, batch_dims) - try: - return tensorflow.gather_nd(params, indices, batch_dims=batch_dims) - except Exception: - batch_dims %= len(params.shape) - result = [] - if batch_dims == 0: - result = tensorflow_gather_nd_helper(params, indices) - else: - for b in range(batch_dims): - if b == 0: - zip_list = list(zip(params, indices)) - else: - zip_list = [ - (p, i) - for z in [zip(p1, i1) for p1, i1 in zip_list] - for p, i in z - ] - for z in zip_list: - p, i = z[0], z[1] - r = tensorflow_gather_nd_helper(p, i) - result.append(r) - result = tensorflow.stack(result) - result = tensorflow.reshape( - result, - tensorflow.concat([params.shape[0:batch_dims], result.shape[1:]], 0), - ) - return result - - -def tensorflow__is_variable_bknd(x, exclusive=False, to_ignore=None): - return tensorflow_nested_map_bknd( - lambda x: tensorflow_is_variable(x, exclusive=exclusive), - x, - include_derived=True, - shallow=False, - to_ignore=to_ignore, - ) - - -def tensorflow_inplace_update( - x: Union[tensorflow.Tensor, tensorflow.Tensor], - val: Union[tensorflow.Tensor, tensorflow.Tensor], - /, - *, - ensure_in_backend: bool = False, - keep_input_dtype: bool = False, -): - if tensorflow_is_array_bknd(x) and tensorflow_is_array_bknd(val): - if keep_input_dtype: - val = tensorflow_astype(val, x.dtype) - x_native, val_native = x, val - if tensorflow__is_variable_bknd(x_native): - x_native.assign(val_native) - if tensorflow_is_ivy_array_bknd(x): - x = x_native - else: - x = tensorflow.convert_to_tensor(x_native) - else: - x = x_native - return x - else: - return val - - -def tensorflow_scatter_nd( - indices: Union[tensorflow.Tensor, tensorflow.Variable], - updates: Union[tensorflow.Tensor, tensorflow.Variable], - /, - shape: Optional[Union[tf.TensorShape,
Sequence[int]]] = None, - *, - reduction: str = "sum", - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - updates_dtype = updates.dtype - if tensorflow_exists_bknd(out): - dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) - updates = tensorflow.cast( - updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), - ) - expected_shape = ( - list(tensorflow.shape(indices)[:-1]) - + list(out.shape[tensorflow.shape(indices)[-1] :]) - if tensorflow_exists_bknd(out) - else list(tensorflow.shape(indices)[:-1]) - + list(shape[tensorflow.shape(indices)[-1] :]) - ) - updates = tensorflow__broadcast_to_bknd(updates, expected_shape) - if len(updates.shape) == 0: - indices = tensorflow.expand_dims(indices, 0) - updates = tensorflow.expand_dims(updates, 0) - target = out - target_given = tensorflow_exists_bknd(target) - if tensorflow_exists_bknd(shape) and target_given: - tensorflow_check_equal(tuple(target.shape), tuple(shape), as_array=False) - if not target_given: - shape = list(shape) if tensorflow_exists_bknd(shape) else list(out.shape) - target = tensorflow.zeros(shape, dtype=updates.dtype) - if reduction == "sum": - res = tensorflow.tensor_scatter_nd_add(target, indices, updates) - elif reduction == "min": - res = tensorflow.tensor_scatter_nd_min(target, indices, updates) - elif reduction == "max": - res = tensorflow.tensor_scatter_nd_max(target, indices, updates) - elif reduction == "mul": - updates = tensorflow_multiply(tensorflow_gather_nd(target, indices), updates) - res = tensorflow.tensor_scatter_nd_update(target, indices, updates) - elif reduction == "replace": - res = tensorflow.tensor_scatter_nd_update(target, indices, updates) - else: - raise Exception( - f'reduction is {reduction}, but it must be one of "sum", "min", "max", "mul" or "replace"' - ) - if tensorflow_exists_bknd(out): - return tensorflow_inplace_update(out, res) - return res - - -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - -@tensorflow_handle_set_item -def tensorflow_set_item_bknd( - x: Union[tensorflow.Tensor, tf.Tensor], - query: Union[tensorflow.Tensor, tf.Tensor, Tuple], - val: Union[tensorflow.Tensor, tf.Tensor], - /, - *, - copy: Optional[bool] = False, -): - if isinstance(query, (list, tuple)) and any( - [(q is Ellipsis or isinstance(q, slice) and q.stop is None) for q in query] - ): - x_stop_gradient = tensorflow_stop_gradient(x, preserve_type=False) - np_array = x_stop_gradient.numpy() - val_stop_gradient = tensorflow_stop_gradient(val, preserve_type=False) - np_array = tensorflow_set_item_bknd( - np_array, query, np.asarray(val_stop_gradient) - ) - return tensorflow_asarray(np_array) - if copy: - x = tensorflow_copy_array(x) - if not tensorflow_is_array_bknd(val): - val = tensorflow_asarray(val) - if 0 in x.shape or 0 in val.shape: - return x - if tensorflow_is_array_bknd(query) and tensorflow_is_bool_dtype_bknd(query): - if not len(query.shape): - query = tensorflow_tile(query, (x.shape[0],)) - indices = tensorflow_nonzero(query, as_tuple=False) - else: - indices, target_shape, _ = tensorflow__parse_query_bknd( - query, tensorflow_shape(x, as_array=True), scatter=True - ) - if indices is None: - return x - val = tensorflow_astype_bknd_(val, x.dtype) - ret = tensorflow_scatter_nd(indices, val, 
reduction="replace", out=x) - return ret - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_real( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - return tensorflow.math.real(x) - - -def tensorflow_real_bknd_(self): - return tensorflow_real(self) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_imag( - val: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - return tensorflow.math.imag(val, name=None) - - -def tensorflow_imag_bknd_(self): - return tensorflow_imag(self) - - -def tensorflow__check_complex128_bknd(input): - if tensorflow_is_array_bknd(input): - return tensorflow_dtype(input) == "complex128" - elif isinstance(input, np.ndarray): - return str(input.dtype) == "complex128" - if hasattr(input, "real") and hasattr(input, "imag"): - return tensorflow__check_float64_bknd( - tensorflow_real_bknd_(input) - ) and tensorflow__check_float64_bknd(tensorflow_imag_bknd_(input)) - return False - - -def tensorflow_default_complex_dtype_bknd( - *, - input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, - complex_dtype: Optional[Union[str, tf.DType]] = None, - as_native: bool = False, -): - global default_complex_dtype_stack - if tensorflow_exists_bknd(complex_dtype): - if as_native is True: - return tensorflow_as_native_dtype(complex_dtype) - return str(tensorflow_as_ivy_dtype(complex_dtype)) - as_native = tensorflow_default_bknd(as_native, False) - if tensorflow_exists_bknd(input): - if tensorflow_is_array_bknd(input): - ret = tensorflow_dtype(input) - elif isinstance(input, np.ndarray): - ret = str(input.dtype) - elif isinstance(input, (list, tuple, dict)): - if tensorflow_nested_argwhere_bknd( - input, - lambda x: tensorflow__check_complex128_bknd(x), - stop_after_n_found=1, - ): - ret = tf.complex128 - elif not default_complex_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_complex_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "complex64" - else: - ret = default_complex_dtype_stack[-1] - elif isinstance(input, Number): - if tensorflow__check_complex128_bknd(input): - ret = tf.complex128 - elif not default_complex_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_complex_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "complex64" - else: - ret = default_complex_dtype_stack[-1] - elif not default_complex_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_complex_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "complex64" - else: - ret = default_complex_dtype_stack[-1] - if as_native: - return tensorflow_as_native_dtype(ret) - return str(tensorflow_as_ivy_dtype(ret)) - - -def tensorflow_default_dtype_bknd( - *, - dtype: Optional[Union[str, str]] = None, - item: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, - as_native: bool = False, -): - if tensorflow_exists_bknd(dtype): - if as_native is True: - return tensorflow_as_native_dtype(dtype) - return tensorflow_as_ivy_dtype(dtype) - as_native = tensorflow_default_bknd(as_native, False) - if tensorflow_exists_bknd(item): - if hasattr(item, "override_dtype_check"): - return item.override_dtype_check() - elif isinstance(item, (list, tuple, dict)) and len(item) == 0: - pass - elif tensorflow_is_complex_dtype_bknd(item): - return tensorflow_default_complex_dtype_bknd( - input=item, as_native=as_native - ) - elif 
tensorflow_is_float_dtype_bknd(item): - return tensorflow_default_float_dtype_bknd(input=item, as_native=as_native) - elif tensorflow_is_uint_dtype_bknd(item): - return tensorflow_default_uint_dtype_bknd(input=item, as_native=as_native) - elif tensorflow_is_int_dtype_bknd(item): - return tensorflow_default_int_dtype_bknd(input=item, as_native=as_native) - elif as_native: - return tensorflow_as_native_dtype("bool") - else: - return "bool" - global default_dtype_stack - if not default_dtype_stack: - global default_float_dtype_stack - if default_float_dtype_stack: - ret = default_float_dtype_stack[-1] - else: - ret = "float32" - else: - ret = default_dtype_stack[-1] - if as_native: - return tensorflow_as_native_dtype(ret) - return tensorflow_as_ivy_dtype(ret) - - -def tensorflow_default_float_dtype_bknd( - *, - input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, - float_dtype: Optional[Union[str, tf.DType]] = None, - as_native: bool = False, -): - global default_float_dtype_stack - if tensorflow_exists_bknd(float_dtype): - if as_native is True: - return tensorflow_as_native_dtype(float_dtype) - return str(tensorflow_as_ivy_dtype(float_dtype)) - as_native = tensorflow_default_bknd(as_native, False) - if tensorflow_exists_bknd(input): - if tensorflow_is_array_bknd(input): - ret = tensorflow_dtype(input) - elif isinstance(input, np.ndarray): - ret = str(input.dtype) - elif isinstance(input, (list, tuple, dict)): - if tensorflow_nested_argwhere_bknd( - input, lambda x: tensorflow__check_float64_bknd(x), stop_after_n_found=1 - ): - ret = tf.float64 - elif not default_float_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_float_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "float32" - else: - ret = default_float_dtype_stack[-1] - elif isinstance(input, Number): - if tensorflow__check_float64_bknd(input): - ret = tf.float64 - elif not default_float_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_float_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "float32" - else: - ret = default_float_dtype_stack[-1] - elif not default_float_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_float_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "float32" - else: - ret = default_float_dtype_stack[-1] - if as_native: - return tensorflow_as_native_dtype(ret) - return str(tensorflow_as_ivy_dtype(ret)) - - -def tensorflow_as_ivy_dtype( - dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / -): - if dtype_in is int: - return tensorflow_default_int_dtype_bknd() - if dtype_in is float: - return tensorflow_default_float_dtype_bknd() - if dtype_in is complex: - return tensorflow_default_complex_dtype_bknd() - if dtype_in is bool: - return str("bool") - if isinstance(dtype_in, np.dtype): - dtype_in = dtype_in.name - if isinstance(dtype_in, str): - if dtype_in in native_dtype_dict: - dtype_str = dtype_in - else: - raise Exception( - f"Cannot convert to ivy dtype. {dtype_in} is not supported by TensorFlow backend."
- ) - else: - dtype_str = ivy_dtype_dict[dtype_in] - if "uint" in dtype_str: - return str(dtype_str) - elif "int" in dtype_str: - return str(dtype_str) - elif "float" in dtype_str: - return str(dtype_str) - elif "complex" in dtype_str: - return str(dtype_str) - elif "bool" in dtype_str: - return str("bool") - else: - raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") - - -def tensorflow_default_int_dtype_bknd( - *, - input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, - int_dtype: Optional[Union[str, tf.DType]] = None, - as_native: bool = False, -): - global default_int_dtype_stack - if tensorflow_exists_bknd(int_dtype): - if as_native is True: - return tensorflow_as_native_dtype(int_dtype) - return str(tensorflow_as_ivy_dtype(int_dtype)) - as_native = tensorflow_default_bknd(as_native, False) - if tensorflow_exists_bknd(input): - if tensorflow_is_array_bknd(input): - ret = tensorflow_dtype(input) - elif isinstance(input, tuple): - ret = tensorflow_default_int_dtype_bknd() - elif isinstance(input, np.ndarray): - ret = str(input.dtype) - elif isinstance(input, (list, tuple, dict)): - if tensorflow_nested_argwhere_bknd( - input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), - stop_after_n_found=1, - ): - ret = tf.uint64 - elif tensorflow_nested_argwhere_bknd( - input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), - stop_after_n_found=1, - ): - ret = tf.int64 - elif not default_int_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_int_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "int32" - else: - ret = default_int_dtype_stack[-1] - elif isinstance(input, Number): - if input > 9223372036854775807 and input != math.inf and backend != "torch": - ret = tf.uint64 - elif input > 2147483647 and input != math.inf: - ret = tf.int64 - elif not default_int_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_int_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "int32" - else: - ret = default_int_dtype_stack[-1] - elif not default_int_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_int_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "int32" - else: - ret = default_int_dtype_stack[-1] - if as_native: - return tensorflow_as_native_dtype(ret) - return str(tensorflow_as_ivy_dtype(ret)) - - -def tensorflow_as_native_dtype( - dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], -): - if dtype_in is int: - return tensorflow_default_int_dtype_bknd(as_native=True) - if dtype_in is float: - return tensorflow_default_float_dtype_bknd(as_native=True) - if dtype_in is complex: - return tensorflow_default_complex_dtype_bknd(as_native=True) - if dtype_in is bool: - return tensorflow.bool - if isinstance(dtype_in, np.dtype): - dtype_in = dtype_in.name - if not isinstance(dtype_in, str): - return dtype_in - if dtype_in in native_dtype_dict: - return native_dtype_dict[str(dtype_in)] - else: - raise Exception( - f"Cannot convert to TensorFlow dtype. {dtype_in} is not supported by TensorFlow." 
- ) - - -def tensorflow_dtype( - x: Union[tensorflow.Tensor, tensorflow.Variable, np.ndarray], - *, - as_native: bool = False, -): - if as_native: - return tensorflow_as_native_dtype(x.dtype) - return tensorflow_as_ivy_dtype(x.dtype) - - -def tensorflow_is_bool_dtype_bknd( - dtype_in: Union[str, str, tensorflow.Tensor, tf.Tensor, Number], / -): - if tensorflow_is_array_bknd(dtype_in): - dtype_in = tensorflow_dtype(dtype_in) - elif isinstance(dtype_in, np.ndarray): - return "bool" in dtype_in.dtype.name - elif isinstance(dtype_in, Number): - return isinstance(dtype_in, (bool, np.bool_)) and not isinstance(dtype_in, bool) - elif isinstance(dtype_in, (list, tuple, dict)): - return bool( - tensorflow_nested_argwhere_bknd( - dtype_in, lambda x: isinstance(x, (bool, np.bool_)) and x is not int - ) - ) - return "bool" in tensorflow_as_ivy_dtype(dtype_in) - - -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - -@tensorflow_handle_get_item -def tensorflow_get_item( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - query: Union[tensorflow.Tensor, tensorflow.Variable, Tuple], - *, - copy: Optional[bool] = None, -): - if ( - tensorflow_is_array_bknd(query) - and tensorflow_is_bool_dtype_bknd(query) - and not len(query.shape) - ): - return tensorflow.expand_dims(x, 0) - return x[query] - - -def tensorflow_index_nest_bknd( - nest: Union[List, Tuple, Dict, tensorflow.Tensor, tf.Tensor, dict], - index: Union[List[int], Tuple[int], Iterable[int]], - /, -): - ret = nest - for i in index: - ret = tensorflow_get_item(ret, i) - return ret - - -def tensorflow__get_first_array(*args, **kwargs): - def array_fn(x): - return ( - tensorflow_is_array_bknd(x) - if not hasattr(x, "_ivy_array") - else tensorflow_is_array_bknd(x.ivy_array) - ) - - array_fn = array_fn if "array_fn" not in kwargs else kwargs["array_fn"] - arr = None - if args: - arr_idxs = tensorflow_nested_argwhere_bknd(args, array_fn, stop_after_n_found=1) - if arr_idxs: - arr = tensorflow_index_nest_bknd(args, arr_idxs[0]) - else: - arr_idxs = tensorflow_nested_argwhere_bknd( - kwargs, array_fn, stop_after_n_found=1 - ) - if arr_idxs: - arr = tensorflow_index_nest_bknd(kwargs, arr_idxs[0]) - elif kwargs: - arr_idxs = tensorflow_nested_argwhere_bknd( - kwargs, array_fn, stop_after_n_found=1 - ) - if arr_idxs: - arr = tensorflow_index_nest_bknd(kwargs, arr_idxs[0]) - return arr diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_empty_output/run_0/tensorflow__stateful.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_empty_output/run_0/tensorflow__stateful.py deleted file mode 100644 index dbad1e919ab1..000000000000 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_empty_output/run_0/tensorflow__stateful.py +++ /dev/null @@ -1,1799 +0,0 @@ -# global -from __future__ import annotations -import re -import os -import tensorflow as tf -import functools -from tensorflow.python.util import nest -from typing import NamedTuple, Callable, Any, Tuple, List, Dict, Type, Union -import inspect -from collections import OrderedDict -from packaging.version import parse -import keras - - -def get_assignment_dict(): - # Traverse the call stack - lhs = None - for frame_info in inspect.stack(): - # Check if the code context is an assignment statement - if frame_info.code_context and "=" in frame_info.code_context[0]: - # Split the assignment and 
retrieve the LHS - lhs = frame_info.code_context[0].split("=")[0].strip() - if "self" not in lhs: - continue - break - - if not lhs: - return None, "" - - # Replace indexing with attribute access - lhs = re.sub(r"\[(\d+)\]", r".\1", lhs) - - # Split the LHS based on "." and get individual components - components = lhs.split(".") - - # Initialize the dictionary - assignment_dict = {} - - # Retrieve the live objects associated with each component - for i in range(len(components)): - # Construct the key - key = ".".join(components[: i + 1]) - - # Retrieve the value - if i == 0: - value = frame_info.frame.f_locals.get(components[i]) - else: - value = getattr(assignment_dict[".".join(components[:i])], components[i]) - - # Add the key-value pair to the dictionary - assignment_dict[key] = value - - return assignment_dict, lhs - - -def store_frame_info(fn): - @functools.wraps(fn) - def frame_info_wrapper(self, *args, **kwargs): - if self._previous_frame_info is None: - # store the info about the calling frame. - stack = inspect.stack() - self._previous_frame_info = stack[1] - res = fn(self, *args, **kwargs) - # reset the frame-info - self._previous_frame_info = None - return res - - return frame_info_wrapper - - -# A NodeDef holds two callables: -# - flatten_fn should take the collection and return a flat list of values. -# It can also return some context that is used in reconstructing the -# collection. -# - unflatten_fn should take a flat list of values and some context -# (returned by flatten_fn). It returns the collection by reconstructing -# it from the list and the context. -Context = Any -PyTree = Any -FlattenFunc = Callable[[PyTree], Tuple[List, Context]] -UnflattenFunc = Callable[[List, Context], PyTree] - - -class NodeDef(NamedTuple): - flatten_fn: FlattenFunc - unflatten_fn: UnflattenFunc - - -SUPPORTED_NODES: Dict[Type[Any], NodeDef] = {} - - -def _register_pytree_node( - typ: Any, flatten_fn: FlattenFunc, unflatten_fn: UnflattenFunc -) -> None: - SUPPORTED_NODES[typ] = NodeDef(flatten_fn, unflatten_fn) - - -def _dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]: - return list(d.values()), list(d.keys()) - - -def _dict_unflatten(values: List[Any], context: Context) -> Dict[Any, Any]: - return {key: value for key, value in zip(context, values)} - - -_register_pytree_node(dict, _dict_flatten, _dict_unflatten) - -if parse(keras.__version__).major > 2: - _register_pytree_node( - keras.src.utils.tracking.TrackedDict, _dict_flatten, _dict_unflatten - ) - - -def _get_node_type(pytree: Any) -> Any: - return type(pytree) - - -# A leaf is defined as anything that is not a Node. -def _is_leaf(pytree: PyTree) -> bool: - return _get_node_type(pytree) not in SUPPORTED_NODES.keys() - - -# A TreeSpec represents the structure of a pytree. 
It holds: -# "type": the type of root Node of the pytree -# context: some context that is useful in unflattening the pytree -# children_specs: specs for each child of the root Node -# num_leaves: the number of leaves -class TreeSpec: - def __init__(self, type, context, children_specs): - self.type: Any = type - self.context: Context = context - self.children_specs: List["TreeSpec"] = children_specs - self.num_leaves: int = sum([spec.num_leaves for spec in self.children_specs]) - - def get_keychains(self, prefix="", sep="/"): - keychains = [] - for key, child_spec in zip(self.context, self.children_specs): - new_prefix = prefix + key + sep if prefix else key + sep - if child_spec.children_specs: # Non-leaf node - keychains.extend(child_spec.get_keychains(new_prefix, sep)) - else: # Leaf node - keychains.append(new_prefix[: -len(sep)]) - return keychains - - def __repr__(self, indent: int = 0) -> str: - repr_prefix: str = f"TreeSpec({self.type.__name__}, {self.context}, [" - children_specs_str: str = "" - if len(self.children_specs): - indent += len(repr_prefix) - children_specs_str += self.children_specs[0].__repr__(indent) - children_specs_str += "," if len(self.children_specs) > 1 else "" - children_specs_str += ",".join( - [ - "\n" + " " * indent + child.__repr__(indent) - for child in self.children_specs[1:] - ] - ) - repr_suffix: str = f"{children_specs_str}])" - return repr_prefix + repr_suffix - - -class LeafSpec(TreeSpec): - def __init__(self) -> None: - super().__init__(None, None, []) - self.num_leaves = 1 - - def __repr__(self, indent: int = 0) -> str: - return "*" - - -def tree_flatten(pytree: PyTree) -> Tuple[List[Any], TreeSpec]: - """Flattens a pytree into a list of values and a TreeSpec that can be used - to reconstruct the pytree.""" - if _is_leaf(pytree): - return [pytree], LeafSpec() - - node_type = _get_node_type(pytree) - flatten_fn = _dict_flatten - child_pytrees, context = flatten_fn(pytree) - - # Recursively flatten the children - result: List[Any] = [] - children_specs: List["TreeSpec"] = [] - for child in child_pytrees: - flat, child_spec = tree_flatten(child) - result += flat - children_specs.append(child_spec) - - return result, TreeSpec(node_type, context, children_specs) - - -def tree_unflatten(values: List[Any], spec: TreeSpec) -> PyTree: - """Given a list of values and a TreeSpec, builds a pytree. - - This is the inverse operation of `tree_flatten`. - """ - if not isinstance(spec, TreeSpec): - raise TypeError( - f"tree_unflatten(values, spec): Expected `spec` to be instance of " - f"TreeSpec but got item of type {type(spec)}." - ) - if len(values) != spec.num_leaves: - raise TypeError( - f"tree_unflatten(values, spec): `values` has length {len(values)} " - f"but the spec refers to a pytree that holds {spec.num_leaves} " - f"items ({spec})." 
- ) - if isinstance(spec, LeafSpec): - return values[0] - - unflatten_fn = _dict_unflatten - - # Recursively unflatten the children - start = 0 - end = 0 - child_pytrees = [] - for child_spec in spec.children_specs: - end += child_spec.num_leaves - child_pytrees.append(tree_unflatten(values[start:end], child_spec)) - start = end - - return unflatten_fn(child_pytrees, spec.context) - - -def serialize_obj(obj): - if inspect.isclass(obj) or isinstance(obj, type): - return {"cls_module": obj.__module__, "cls_name": obj.__name__} - return obj - - -def recursive_serialize(d): - if isinstance(d, dict): - return {k: recursive_serialize(v) for k, v in d.items()} - elif isinstance(d, list): - return [recursive_serialize(v) for v in d] - elif isinstance(d, tuple): - return tuple(recursive_serialize(v) for v in d) - else: - return serialize_obj(d) - - -def deserialize_obj(serialized): - if ( - isinstance(serialized, dict) - and "cls_module" in serialized - and "cls_name" in serialized - ): - module = __import__(serialized["cls_module"], fromlist=[serialized["cls_name"]]) - cls = getattr(module, serialized["cls_name"]) - return cls - return serialized - - -def recursive_deserialize(d): - if isinstance(d, dict) and "cls_module" not in d: - return {k: recursive_deserialize(v) for k, v in d.items()} - elif isinstance(d, list): - return [recursive_deserialize(v) for v in d] - elif isinstance(d, tuple): - return tuple(recursive_deserialize(v) for v in d) - else: - return deserialize_obj(d) - - -class ModelHelpers: - @staticmethod - @tf.autograph.experimental.do_not_convert - def _get_first_array(*args, **kwargs): - arr = None - flattened_args = tf.nest.flatten((args, kwargs)) - arr_candidates = tf.nest.map_structure( - lambda x: x if isinstance(x, (tf.Tensor, tf.Variable)) else False, - flattened_args, - ) - for arr_candidate in arr_candidates: - if arr_candidate is not False: - arr = arr_candidate - break - return arr - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _get_input_shapes(*args): - input_shapes = [] - for x in args: - if isinstance(x, (tf.Tensor, tf.Variable)): - input_shapes.append(x.shape) - else: - try: - x = tf.convert_to_tensor(x) - input_shapes.append(x.shape) - except Exception: - input_shapes.append(None) - return input_shapes - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _extract_v(v, keychain_mappings: dict, orig_key_chain, /): - if ModelHelpers._dict_has_key_chain(v, orig_key_chain): - ret_cont = ModelHelpers._dict_at_key_chain(v, orig_key_chain) - else: - ret_cont = dict() - for old_kc, new_kc in keychain_mappings.items(): - if orig_key_chain in old_kc: - # Check if `v` contains `new_kc` before replacing in `ret_cont` - if ModelHelpers._dict_has_key_chain(v, new_kc): - ret_cont = ModelHelpers._dict_set_at_key_chain( - ret_cont, - "/".join(old_kc.split("/")[1:]), - ModelHelpers._dict_at_key_chain(v, new_kc), - ) - else: - continue - return ret_cont - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _remove_duplicate_variables(vs, created, /): - created_ids = tf.nest.map_structure(lambda x: id(x), created) - vs_ids = tf.nest.map_structure(lambda x: id(x), vs) - ids = {} - duplicate_keychains = [] - keychain_mappings = {} - - def unique_callback(x, kc): - ids[x] = kc - return x - - def found_dup_callback(x, kc): - if ids[x] == kc: - return x - duplicate_keychains.append(kc) - keychain_mappings[kc] = ids[x] - return x - - created_ids = nest.map_structure_with_paths( - lambda kc, x: unique_callback(x, kc), created_ids - ) - vs_ids =
nest.map_structure_with_paths( - lambda kc, x: ( - unique_callback(x, kc) if x not in ids else found_dup_callback(x, kc) - ), - vs_ids, - ) - for dup_kc in duplicate_keychains: - vs = ModelHelpers._dict_prune_key_chain(vs, dup_kc) - return vs, keychain_mappings - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _dict_set_at_key_chain(in_dict, key_chain, val, inplace=False): - keys = re.split("[/.]", key_chain) - cont = in_dict - sub_cont = cont - for key in keys[:-1]: - if key not in sub_cont: - sub_cont[key] = dict() - sub_cont = sub_cont[key] - sub_cont[keys[-1]] = val - return cont - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _dict_at_key_chain(dict, key_chain, ignore_key_errors=False): - keys = re.split("[/.]", key_chain) - ret = dict - for key in keys: - try: - ret = ret[key] - except KeyError as e: - if ignore_key_errors: - return - raise Exception(repr(e)) - return ret - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _dict_has_key_chain(dict, key_chain): - keys = re.split("[/.]", key_chain) - ret = dict - for key in keys: - try: - ret = ret[key] - except KeyError: - return False - return True - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _dict_prune_key_chain(in_dict, key_chain): - keys_in_chain = re.split("[/.]", key_chain) - out_dict = {} - for key, value in in_dict.items(): - if isinstance(value, dict): - if key == keys_in_chain[0]: - if len(keys_in_chain) == 1: - new_val = [] - else: - new_val = ModelHelpers._dict_prune_key_chain( - value, - "/".join(keys_in_chain[1:]), - ) - if len(new_val) > 0: - out_dict[key] = new_val - else: - if len(value) > 0: - out_dict[key] = value - else: - if len(keys_in_chain) != 1 or key != keys_in_chain[0]: - out_dict[key] = value - return out_dict - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _addindent(s_, numSpaces): - s = s_.split("\n") - # don't do anything for single-line stuff - if len(s) == 1: - return s_ - first = s.pop(0) - s = [(numSpaces * " ") + line for line in s] - s = "\n".join(s) - s = first + "\n" + s - return s - - -class Layer(tf.keras.layers.Layer, ModelHelpers): - _build_mode = None - _with_partial_v = None - _store_vars = True - _built = False - _v = None - _buffers = None - _module_dict = None - _args = None - _kwargs = None - _module_graph = None - _target = None - _lazy_traced = False - _training = None - _dynamic_backend = None - _device = None - _dtype = None - _previous_frame_info = None - - def __init__( - self, - /, - *args, - v=None, - buffers=None, - build_mode="on_init", - store_vars=True, - with_partial_v=False, - dynamic_backend=None, - training=True, - dtype=None, - device=None, - module_dict=None, - **kwargs, - ): - super(Layer, self).__init__( - trainable=training, - dtype=dtype, - ) - self._build_mode = build_mode - self._with_partial_v = with_partial_v - self._store_vars = store_vars - self._built = False - self._v_from_constructor = v if isinstance(v, dict) or v is None else dict(v) - self._v = v if v is not None else dict() - self._buffers = dict(buffers or {}) - self._module_dict = module_dict if module_dict is not None else dict() - self._args = args - self._kwargs = kwargs - self._module_graph = None - self._target = None - self._lazy_traced = False - self._training = training - self._dynamic_backend = dynamic_backend - self._device = device or "cpu" - self._dtype = dtype or tf.float32 - if build_mode != "on_init": - return - self.build(*args,
dynamic_backend=dynamic_backend, **kwargs) - - @tf.autograph.experimental.do_not_convert - def _find_variables( - self, - /, - *, - obj=None, - without_initialisation=False, - _visited=None, - trainable=True, - ): - _visited = _visited or {} - vs = dict() - if id(obj) in _visited: - return vs - _visited[id(obj)] = True - if isinstance(obj, Layer) and obj is not self: - fn = "_build_and_return_v" if trainable else "_build_and_return_buffers" - if not obj._built and without_initialisation: - return lambda: getattr(obj, fn)( - *obj._args, dynamic_backend=self._dynamic_backend, **obj._kwargs - ) - - return getattr(obj, fn)( - *obj._args, dynamic_backend=obj._dynamic_backend, **obj._kwargs - ) - elif isinstance(obj, tf.keras.layers.Layer) and obj is not self: - return obj.v if trainable else obj.buffers - - elif isinstance(obj, (list, tuple)): - for i, v in enumerate(obj): - ret = self._find_variables( - obj=v, - without_initialisation=without_initialisation, - _visited=_visited, - trainable=trainable, - ) - if ret: - vs[f"v{str(i)}"] = ret - return vs - elif isinstance(obj, dict): - for k, v in obj.items(): - ret = self._find_variables( - obj=v, - without_initialisation=without_initialisation, - _visited=_visited, - trainable=trainable, - ) - if ret: - vs[k[1:] if k[0] == "_" else k] = ret - return vs - elif not hasattr(obj, "__dict__"): - return vs - for k, v in obj.__dict__.items(): - if ( - v is not None - and k[0:2] != "__" - and not k.startswith( - ( - "_module_dict", - "_self_", - "_args", - "_kwargs", - ) - ) - ): - ret = self._find_variables( - obj=v, - without_initialisation=without_initialisation, - _visited=_visited, - trainable=trainable, - ) - if ret: - vs[k[1:] if k[0] == "_" else k] = ret - return vs - - @tf.autograph.experimental.do_not_convert - def _find_buffers(self): - if hasattr(self, "_module_dict"): - for key, sub_module in self._module_dict.items(): - if len(sub_module._buffers) > 0: - self._buffers[key] = sub_module._buffers - - @tf.autograph.experimental.do_not_convert - def _build_and_return_v(self, *args, **kwargs): - if not self._built: - self.build(*args, **kwargs) - return self.v - - @tf.autograph.experimental.do_not_convert - def _build_and_return_buffers(self, *args, **kwargs): - if not self._built: - self.build(*args, **kwargs) - return self.buffers - - @tf.autograph.experimental.do_not_convert - def _wrap_call_methods( - self, keychain_mappings, /, *, key="", obj=None, _visited=None - ): - _visited = _visited or {} - if id(obj) in _visited or not isinstance(key, str): - return - _visited[id(obj)] = True - if isinstance(obj, Model) and obj is not self: - orig_key_chain = key[1:] if key[0] == "_" else key - - obj.__call__ = self._fn_with_var_arg( - obj.__call__, self._extract_v, keychain_mappings, orig_key_chain - ) - return - elif isinstance(obj, (list, tuple)): - for i, val in enumerate(obj): - self._wrap_call_methods( - keychain_mappings, - key=f"{key}/v{str(i)}", - obj=val, - _visited=_visited, - ) - return - elif isinstance(obj, dict): - for k, val in obj.items(): - k = f"{key}/{k}" if key != "" and isinstance(k, str) else k - self._wrap_call_methods( - keychain_mappings, key=k, obj=val, _visited=_visited - ) - return - for k, val in obj.module_dict.items(): - if k.startswith(("__", "_self_")): - continue - k = f"{key}/{k}" if key != "" else k - if val is not None: - self._wrap_call_methods( - keychain_mappings, key=k, obj=val, _visited=_visited - ) - return - - @tf.autograph.experimental.do_not_convert - def _compute_module_dict(self): - self._module_dict 
= dict() - for key, value in self.__dict__.items(): - if isinstance(value, (Layer, tf.keras.layers.Layer)): - if ( - "stateful" in value.__module__ - or hasattr(value, "_frontend_module") - or not hasattr(value, "_module_dict") - ): - self._module_dict[key] = value - else: - self._module_dict[key] = value._module_dict - - @tf.autograph.experimental.do_not_convert - def _fn_with_var_arg_wrapper( - self, *a, fn, v_fn, keychain_mappings, orig_key_chain, **kw - ): - if "v" in kw: - del kw["v"] - v = v_fn(self.v, keychain_mappings, orig_key_chain) - return fn(*a, **kw, v=v) - - @tf.autograph.experimental.do_not_convert - def _fn_with_var_arg(self, fn, v_fn, /, keychain_mappings, orig_key_chain): - _fn_with_var_arg_wrapper = functools.partial( - self._fn_with_var_arg_wrapper, - fn=fn, - v_fn=v_fn, - keychain_mappings=keychain_mappings, - orig_key_chain=orig_key_chain, - ) - _fn_with_var_arg_wrapper.wrapped = True - return _fn_with_var_arg_wrapper - - @tf.autograph.experimental.do_not_convert - def _call(self, *args, v=None, buffers=None, **kwargs): - if not self._built or not self.built: - if not self._built: - first_arr = self._get_first_array(*args, **kwargs) - self.build( - *args, - **kwargs, - from_call=True, - dtype=first_arr.dtype if first_arr is not None else tf.float32, - ) - - if not self.built: - # Don't use `keras` build method - if os.environ.get("USE_KERAS_BUILD", "False").lower() == "false": - self.inputs = tf.nest.flatten(args) - else: - input_shapes = self._get_input_shapes(*args) - if len(input_shapes) == 0: - input_shapes = tf.TensorShape(None) - elif len(input_shapes) == 1: - input_shapes = input_shapes[0] - - super(Layer, self).build(tf.TensorShape(None)) # noqa: UP008 - - # If `v` was provided, replace with the module's v - replace_v = False - if v is not None: - v_orig = self.v - self._v = v - replace_v = True - - # If `buffers` were provided, replace with the module's buffers - replace_buffers = False - if buffers is not None: - buffers_orig = self.buffers - self._buffers = buffers - replace_buffers = True - - if replace_v or replace_buffers: - # Call the forward pass - ret = super(Layer, self).__call__(*args, **kwargs) # noqa: UP008 - # Replace v, buffers if needed - self._v = v_orig if replace_v else self._v - self._buffers = buffers_orig if replace_buffers else self._buffers - return ret - elif hasattr(self.__call__, "wrapped"): - return self.__call__(*args, **kwargs) - - # Get the signature of the call method - call_signature = inspect.signature(self.call) - - # Convert all positional arguments to keyword arguments based on the signature - new_kwargs = {} - for idx, (param_name, param) in enumerate(call_signature.parameters.items()): - if idx < len(args): - new_kwargs[param_name] = args[idx] - - # Merge the existing kwargs - new_kwargs.update(kwargs) - return super(Layer, self).__call__(**new_kwargs) # noqa: UP008 - - @tf.autograph.experimental.do_not_convert - def build( - self, - *args, - from_call=False, - device=None, - dtype=None, - dynamic_backend=None, - **kwargs, - ): - self._built = True - return - - def _lock_state(self): - pass - - @tf.autograph.experimental.do_not_convert - def register_buffer(self, name: str, value: Union[tf.Tensor, tf.Variable]): - self._buffers.update({name: value}) - return value - - @tf.autograph.experimental.do_not_convert - def register_parameter(self, name: str, value: Union[tf.Tensor, tf.Variable]): - self._v.update({name: value}) - - @tf.autograph.experimental.do_not_convert - def train(self, mode: bool = True): - self._training = 
mode - for module in self.children(): - if isinstance(module, tf.keras.layers.Layer) and not hasattr( - module, "train" - ): - module.trainable = mode - continue - module.train(mode) - self.trainable = mode - return self - - @tf.autograph.experimental.do_not_convert - def eval(self): - return self.train(mode=False) - - @tf.autograph.experimental.do_not_convert - def call(self, inputs, training=None, mask=None): - raise NotImplementedError( - "When subclassing the `Module` class, you should implement a `call` method." - ) - - def get_build_config(self): - config = super().get_build_config() - config = recursive_serialize(config) - return config - - def build_from_config(self, config): - config = recursive_deserialize(config) - return super().build_from_config(config) - - def get_config(self): - base_config = super().get_config() - config = {} - - # Get the names and values of positional arguments in __init__ - init_signature = inspect.signature(self.__init__) - arg_names = list(init_signature.parameters.keys()) - - # Include the positional arguments in the config - var_positional_arg_encountered = False - var_positional_arg_name = None - offset = 0 - for i, arg in enumerate(self._args): - arg_name = arg_names[min(i, len(arg_names) - 1)] - if var_positional_arg_encountered: - config.update( - { - f"{var_positional_arg_name}_{i - offset}": arg, - } - ) - elif ( - init_signature.parameters[arg_name].kind - == inspect.Parameter.VAR_POSITIONAL - ): - var_positional_arg_encountered = True - var_positional_arg_name = arg_name - offset = i - config.update( - { - f"{var_positional_arg_name}_{0}": arg, - } - ) - else: - config.update( - { - arg_name: arg, - } - ) - - # Include the keywords arguments in the config - kwargs = self._kwargs.copy() - kwargs.pop("devices", None) - config.update(**kwargs) - new_config = {**base_config, **config} - new_config = recursive_serialize(new_config) - return new_config - - @classmethod - def from_config(cls, config): - config = recursive_deserialize(config) - # Get the signature of the __init__ method - init_signature = inspect.signature(cls.__init__) - arg_names = list(init_signature.parameters.keys()) - - # Separate positional and keyword arguments based on the __init__ signature - args = [] - pos_or_kw = OrderedDict() - kwargs = {} - var_positional_args = [] - for arg_name in arg_names: - if ( - arg_name in config - and init_signature.parameters[arg_name].kind - == inspect.Parameter.KEYWORD_ONLY - ): - # Handle keyword arguments - kwargs[arg_name] = config.pop(arg_name) - elif ( - arg_name in config - and init_signature.parameters[arg_name].kind - == inspect.Parameter.POSITIONAL_OR_KEYWORD - ): - # Handle positional or keyword arguments - pos_or_kw[arg_name] = config.pop(arg_name) - elif any(re.match(rf"{arg_name}_\d+", key) for key in config.keys()): - # Handle variable positional arguments - var_positional_args.extend( - [ - config.pop(key) - for key in sorted(config.keys()) - if re.match(rf"{arg_name}_\d+", key) - ] - ) - - # Unpack positional arguments and the rest as keyword arguments - config.pop("name", None) - config.pop("trainable", None) - config.pop("dtype", None) - kwargs.update(config) - - # Determine the final args and kwargs - if var_positional_args: - args = list(pos_or_kw.values()) + var_positional_args - else: - kwargs.update(pos_or_kw) - - return cls(*args, **kwargs) - - # Methods to be Optionally Overridden # - # -----------------------------------# - - @tf.autograph.experimental.do_not_convert - def _create_variables(self, *, device=None, 
dtype=None): - return {} - - @tf.autograph.experimental.do_not_convert - def _build(self, *args, **kwargs) -> bool: - return True - - @tf.autograph.experimental.do_not_convert - def _forward(self, *args, **kwargs): - raise NotImplementedError( - "When subclassing the `Module` class, you should " - "implement a `_forward` method." - ) - - @tf.autograph.experimental.do_not_convert - def _extra_repr(self) -> str: - return "" - - # Properties # - # -----------# - - @property - def device(self): - return self._device - - @property - def dtype(self): - return self._dtype - - @property - def build_mode(self): - return self._build_mode - - @property - def training(self): - return self._training - - @property - def v(self): - return self._v - - @property - def buffers(self): - return self._buffers - - @property - def state_dict(self): - return {**self.v, **self.buffers} - - @property - def module_dict(self): - return self._module_dict - - @property - def layers(self): - return self._layers - - # Dunder Methods # - # ---------------# - @store_frame_info - @tf.autograph.experimental.do_not_convert - def __call__( - self, - *args, - v=None, - buffers=None, - **kwargs, - ): - # TODO: Temp workaround to avoid `call` from being transformed by AutoGraph - if not hasattr(self.__class__.call, "autograph_info__"): - setattr(self.__class__.call, "autograph_info__", True) - ret = self._call(*args, v=v, buffers=buffers, **kwargs) - return ret - - @tf.autograph.experimental.do_not_convert - def __getattr__(self, name): - if name == "v": - if not super().__getattribute__("_v") and not getattr( # noqa: E501 - self, "_built", False - ): - return self._build_and_return_v( - *self._args, dynamic_backend=self._dynamic_backend, **self._kwargs - ) - - _dict = super().__getattribute__("__dict__") - if name in _dict: - return _dict[name] - - elif "_v" in _dict and name in _dict["_v"]: - return _dict["_v"][name] - - return super().__getattribute__(name) - - @tf.autograph.experimental.do_not_convert - def __setattr__(self, name, value): - if name in ["v", "buffers"]: - name = "_" + name - if isinstance(value, (Layer, tf.keras.layers.Layer)): - _dict = getattr(self, "__dict__", None) - if _dict: - _dict[name] = value - - # compute the module dict - self._compute_module_dict() - - obj_to_search = ( - None - if not isinstance(value, (Layer, tf.keras.layers.Layer)) - else ( - self._modules - if hasattr(self, "_modules") and self._modules - else self - ) - ) - found_vars = self._find_variables( - obj=obj_to_search, - without_initialisation=( - True - if self._v_from_constructor and not self._with_partial_v - else False - ), - ) - flattened_v, v_spec = tree_flatten(found_vars) - flattened_kc = v_spec.get_keychains() - for kc, v in zip(flattened_kc, flattened_v): - new_kc = kc.replace("/", ".") - if new_kc not in self.v: - self.register_parameter(new_kc, v) - - # once all variables built, find and assign buffers - found_buffers = self._find_variables( - obj=obj_to_search, - without_initialisation=( - True - if self._v_from_constructor and not self._with_partial_v - else False - ), - trainable=False, - ) - flattened_buf, buf_spec = tree_flatten(found_buffers) - flattened_kc = buf_spec.get_keychains() - for kc, buf in zip(flattened_kc, flattened_buf): - new_kc = kc.replace("/", ".") - self.register_buffer(new_kc, buf) - - super().__setattr__(name, value) - return - elif isinstance(value, tf.Variable) and not name.startswith("_"): - _dict = getattr(self, "__dict__", None) - if _dict: - _dict[name] = value - # Manual solution for cases
where a `tf.int32` tensor - # is placed on the GPU. TensorFlow doesn't have dedicated - # kernels for placing `tf.int32` variables on the GPU and so - # we manually cast them to `tf.int64` here otherwise due to - # `tf.config.soft_device_placement(True)` by default, - # TensorFlow puts the `tf.int32` variables on CPU which causes - # unintended consequences downstream during tracing or - # `tf.function` compilation e.g. - # Ref: https://github.com/tensorflow/tensorflow/issues/9506 - # Ref: https://stackoverflow.com/questions/44813939/could-not-satisfy-explicit-device-specification-devicegpu0-because-no-devic - dtype = ( - tf.int64 - if value.dtype == tf.int32 and "gpu:" in value.device.lower() - else value.dtype - ) - cast_dtype = dtype != value.dtype - val = ( - value - if not cast_dtype - else tf.Variable(initial_value=tf.cast(value.value(), dtype), name=name) - ) - self.register_parameter(name, val) - super().__setattr__(name, val) - return - else: - try: - obj_to_search = getattr(self, name) - except AttributeError: - obj_to_search = None - if isinstance(obj_to_search, Layer): - # retrieve all hierarchical submodules - assign_dict, kc = get_assignment_dict() - - # Iterate over all submods in assign_dict - # updating their `v` and `buffers` with the - # new value - for key, submod in assign_dict.items(): - # Get the subkey to match - subkey = kc[len(key) :].lstrip(".") - - if hasattr(submod, "v"): - for v_key, v_value in submod.v.items(): - if v_key.startswith(subkey): - submod.register_parameter(v_key, value) - - # Repeat the same process for submod.buffers - if hasattr(submod, "buffers"): - for b_key, b_value in submod.buffers.items(): - if b_key.startswith(subkey): - submod.register_buffer(b_key, value) - - # finally update the module dict - self._module_dict[name] = value - - return super().__setattr__(name, value) - - @tf.autograph.experimental.do_not_convert - def __delattr__(self, name): - if hasattr(self, name): - if isinstance(getattr(self, name), (Layer, tf.keras.layers.Layer)): - super().__delattr__(name) - return - super().__delattr__(name) - - @tf.autograph.experimental.do_not_convert - def __repr__(self): - extra_lines = [] - extra_repr = self._extra_repr() - if extra_repr: - extra_lines = extra_repr.split("\n") - child_lines = [] - for key in self.v.keys(): - if isinstance(getattr(self, key, None), Layer): - mod_str = repr(getattr(self, key)) - mod_str = self._addindent(mod_str, 2) - child_lines.append(f"({key}): {mod_str}") - lines = extra_lines + child_lines - - main_str = f"{self.__class__.__name__}(" - if lines: - # simple one-liner info, which most builtin Modules will use - if len(extra_lines) == 1 and not child_lines: - main_str += extra_lines[0] - else: - main_str += "\n " + "\n ".join(lines) + "\n" - - main_str += ")" - return main_str - - -class Model(tf.keras.Model, ModelHelpers): - _build_mode = None - _with_partial_v = None - _store_vars = True - _built = False - _v = None - _buffers = None - _module_dict = None - _args = None - _kwargs = None - _module_graph = None - _target = None - _lazy_traced = False - _training = None - _dynamic_backend = None - _device = None - _dtype = None - _previous_frame_info = None - - def __init__( - self, - /, - *args, - v=None, - buffers=None, - build_mode="on_init", - store_vars=True, - with_partial_v=False, - dynamic_backend=None, - training=True, - dtype=None, - device=None, - module_dict=None, - **kwargs, - ): - super(Model, self).__init__( - trainable=training, - dtype=dtype, - ) - self._build_mode = build_mode - 
self._with_partial_v = with_partial_v - self._store_vars = store_vars - self._built = False - self._v_from_constructor = v if isinstance(v, dict) or v is None else dict(v) - self._v = v if v is not None else dict() - self._buffers = dict(buffers or {}) - self._module_dict = module_dict if module_dict is not None else dict() - self._args = args - self._kwargs = kwargs - self._module_graph = None - self._target = None - self._lazy_traced = False - self._training = training - self._dynamic_backend = dynamic_backend - self._device = device or "cpu" - self._dtype = dtype or tf.float32 - if build_mode != "on_init": - return - self.build(*args, dynamic_backend=dynamic_backend, **kwargs) - - @tf.autograph.experimental.do_not_convert - def _find_variables( - self, - /, - *, - obj=None, - without_initialisation=False, - _visited=None, - trainable=True, - ): - _visited = _visited or {} - vs = dict() - if id(obj) in _visited: - return vs - _visited[id(obj)] = True - if isinstance(obj, (Layer, Model)) and obj is not self: - fn = "_build_and_return_v" if trainable else "_build_and_return_buffers" - if not obj._built and without_initialisation: - return lambda: getattr(obj, fn)( - *obj._args, dynamic_backend=self._dynamic_backend, **obj._kwargs - ) - - return getattr(obj, fn)( - *obj._args, dynamic_backend=obj._dynamic_backend, **obj._kwargs - ) - elif isinstance(obj, tf.keras.layers.Layer) and obj is not self: - return obj.v if trainable else obj.buffers - - elif isinstance(obj, (list, tuple)): - for i, v in enumerate(obj): - ret = self._find_variables( - obj=v, - without_initialisation=without_initialisation, - _visited=_visited, - trainable=trainable, - ) - if ret: - vs[f"v{str(i)}"] = ret - return vs - elif isinstance(obj, dict): - for k, v in obj.items(): - ret = self._find_variables( - obj=v, - without_initialisation=without_initialisation, - _visited=_visited, - trainable=trainable, - ) - if ret: - vs[k[1:] if k[0] == "_" else k] = ret - return vs - elif not hasattr(obj, "__dict__"): - return vs - for k, v in obj.__dict__.items(): - if ( - v is not None - and k[0:2] != "__" - and not k.startswith( - ( - "_module_dict", - "_self_", - "_args", - "_kwargs", - ) - ) - ): - ret = self._find_variables( - obj=v, - without_initialisation=without_initialisation, - _visited=_visited, - trainable=trainable, - ) - if ret: - vs[k[1:] if k[0] == "_" else k] = ret - return vs - - @tf.autograph.experimental.do_not_convert - def _find_buffers(self): - if hasattr(self, "_module_dict"): - for key, sub_module in self._module_dict.items(): - if len(sub_module._buffers) > 0: - self._buffers[key] = sub_module._buffers - - @tf.autograph.experimental.do_not_convert - def _build_and_return_v(self, *args, **kwargs): - if not self._built: - self.build(*args, **kwargs) - return self.v - - @tf.autograph.experimental.do_not_convert - def _build_and_return_buffers(self, *args, **kwargs): - if not self._built: - self.build(*args, **kwargs) - return self.buffers - - @tf.autograph.experimental.do_not_convert - def _wrap_call_methods( - self, keychain_mappings, /, *, key="", obj=None, _visited=None - ): - _visited = _visited or {} - if id(obj) in _visited or not isinstance(key, str): - return - _visited[id(obj)] = True - if isinstance(obj, (Layer, Model)) and obj is not self: - orig_key_chain = key[1:] if key[0] == "_" else key - - obj.__call__ = self._fn_with_var_arg( - obj.__call__, self._extract_v, keychain_mappings, orig_key_chain - ) - return - elif isinstance(obj, (list, tuple)): - for i, val in enumerate(obj): - 
self._wrap_call_methods( - keychain_mappings, - key=f"{key}/v{str(i)}", - obj=val, - _visited=_visited, - ) - return - elif isinstance(obj, dict): - for k, val in obj.items(): - k = f"{key}/{k}" if key != "" and isinstance(k, str) else k - self._wrap_call_methods( - keychain_mappings, key=k, obj=val, _visited=_visited - ) - return - for k, val in obj.module_dict.items(): - if k.startswith(("__", "_self_")): - continue - k = f"{key}/{k}" if key != "" else k - if val is not None: - self._wrap_call_methods( - keychain_mappings, key=k, obj=val, _visited=_visited - ) - return - - @tf.autograph.experimental.do_not_convert - def _compute_module_dict(self): - self._module_dict = dict() - for key, value in self.__dict__.items(): - if isinstance(value, (Layer, tf.keras.layers.Layer, Model, tf.keras.Model)): - if ( - "stateful" in value.__module__ - or hasattr(value, "_frontend_module") - or not hasattr(value, "_module_dict") - ): - self._module_dict[key] = value - else: - self._module_dict[key] = value._module_dict - - @tf.autograph.experimental.do_not_convert - def _fn_with_var_arg_wrapper( - self, *a, fn, v_fn, keychain_mappings, orig_key_chain, **kw - ): - if "v" in kw: - del kw["v"] - v = v_fn(self.v, keychain_mappings, orig_key_chain) - return fn(*a, **kw, v=v) - - @tf.autograph.experimental.do_not_convert - def _fn_with_var_arg(self, fn, v_fn, /, keychain_mappings, orig_key_chain): - _fn_with_var_arg_wrapper = functools.partial( - self._fn_with_var_arg_wrapper, - fn=fn, - v_fn=v_fn, - keychain_mappings=keychain_mappings, - orig_key_chain=orig_key_chain, - ) - _fn_with_var_arg_wrapper.wrapped = True - return _fn_with_var_arg_wrapper - - @tf.autograph.experimental.do_not_convert - def _call(self, *args, v=None, buffers=None, **kwargs): - if not self._built or not self.built: - if not self._built: - first_arr = self._get_first_array(*args, **kwargs) - self.build( - *args, - **kwargs, - from_call=True, - dtype=first_arr.dtype if first_arr is not None else tf.float32, - ) - - if not self.built: - # Don't use `keras` build method - if os.environ.get("USE_KERAS_BUILD", "False").lower() == "false": - self.inputs = tf.nest.flatten(args) - else: - input_shapes = self._get_input_shapes(*args) - if len(input_shapes) == 0: - input_shapes = tf.TensorShape(None) - elif len(input_shapes) == 1: - input_shapes = input_shapes[0] - - super(Model, self).build(tf.TensorShape(None)) # noqa: UP008 - - # If `v` was provided, replace with the module's v - replace_v = False - if v is not None: - v_orig = self.v - self._v = v - replace_v = True - - # If `buffers` were provided, replace with the module's buffers - replace_buffers = False - if buffers is not None: - buffers_orig = self.buffers - self._buffers = buffers - replace_buffers = True - - if replace_v or replace_buffers: - # Call the forward pass - ret = super(Model, self).__call__(*args, **kwargs) # noqa: UP008 - # Replace v, buffers if needed - self._v = v_orig if replace_v else self._v - self._buffers = buffers_orig if replace_buffers else self._buffers - return ret - elif hasattr(self.__call__, "wrapped"): - return self.__call__(*args, **kwargs) - - return super(Model, self).__call__(*args, **kwargs) # noqa: UP008 - - @tf.autograph.experimental.do_not_convert - def build( - self, - *args, - from_call=False, - device=None, - dtype=None, - dynamic_backend=None, - **kwargs, - ): - self._built = True - return - - def _lock_state(self): - pass - - @tf.autograph.experimental.do_not_convert - def register_buffer(self, name: str, value: Union[tf.Tensor, tf.Variable]): 
- self._buffers.update({name: value}) - return value - - @tf.autograph.experimental.do_not_convert - def register_parameter(self, name: str, value: Union[tf.Tensor, tf.Variable]): - self._v.update({name: value}) - - @tf.autograph.experimental.do_not_convert - def train(self, mode: bool = True): - self._training = mode - for module in self.children(): - if isinstance(module, tf.keras.layers.Layer) and not hasattr( - module, "train" - ): - module.trainable = mode - continue - module.train(mode) - self.trainable = mode - return self - - @tf.autograph.experimental.do_not_convert - def eval(self): - return self.train(mode=False) - - @tf.autograph.experimental.do_not_convert - def call(self, inputs, training=None, mask=None): - raise NotImplementedError( - "When subclassing the `Module` class, you should implement a `call` method." - ) - - def get_build_config(self): - config = super().get_build_config() - config = recursive_serialize(config) - return config - - def build_from_config(self, config): - config = recursive_deserialize(config) - return super().build_from_config(config) - - def get_config(self): - base_config = super().get_config() - config = {} - - # Get the names and values of positional arguments in __init__ - init_signature = inspect.signature(self.__init__) - arg_names = list(init_signature.parameters.keys()) - - # Include the positional arguments in the config - var_positional_arg_encountered = False - var_positional_arg_name = None - offset = 0 - for i, arg in enumerate(self._args): - arg_name = arg_names[min(i, len(arg_names) - 1)] - if var_positional_arg_encountered: - config.update( - { - f"{var_positional_arg_name}_{i - offset}": arg, - } - ) - elif ( - init_signature.parameters[arg_name].kind - == inspect.Parameter.VAR_POSITIONAL - ): - var_positional_arg_encountered = True - var_positional_arg_name = arg_name - offset = i - config.update( - { - f"{var_positional_arg_name}_{0}": arg, - } - ) - else: - config.update( - { - arg_name: arg, - } - ) - - # Include the keywords arguments in the config - kwargs = self._kwargs.copy() - kwargs.pop("devices", None) - config.update(**kwargs) - new_config = {**base_config, **config} - new_config = recursive_serialize(new_config) - return new_config - - @classmethod - def from_config(cls, config): - config = recursive_deserialize(config) - # Get the signature of the __init__ method - init_signature = inspect.signature(cls.__init__) - arg_names = list(init_signature.parameters.keys()) - - # Separate positional and keyword arguments based on the __init__ signature - args = [] - pos_or_kw = OrderedDict() - kwargs = {} - var_positional_args = [] - for arg_name in arg_names: - if ( - arg_name in config - and init_signature.parameters[arg_name].kind - == inspect.Parameter.KEYWORD_ONLY - ): - # Handle keyword arguments - kwargs[arg_name] = config.pop(arg_name) - elif ( - arg_name in config - and init_signature.parameters[arg_name].kind - == inspect.Parameter.POSITIONAL_OR_KEYWORD - ): - # Handle positional or keyword arguments - pos_or_kw[arg_name] = config.pop(arg_name) - elif any(re.match(rf"{arg_name}_\d+", key) for key in config.keys()): - # Handle variable positional arguments - var_positional_args.extend( - [ - config.pop(key) - for key in sorted(config.keys()) - if re.match(rf"{arg_name}_\d+", key) - ] - ) - - # Unpack positional arguments and the rest as keyword arguments - config.pop("name", None) - config.pop("trainable", None) - config.pop("dtype", None) - kwargs.update(config) - - # Determine the final args and kwargs - if 
var_positional_args: - args = list(pos_or_kw.values()) + var_positional_args - else: - kwargs.update(pos_or_kw) - - return cls(*args, **kwargs) - - # Methods to be Optionally Overridden # - # -----------------------------------# - - @tf.autograph.experimental.do_not_convert - def _create_variables(self, *, device=None, dtype=None): - return {} - - @tf.autograph.experimental.do_not_convert - def _build(self, *args, **kwargs) -> bool: - return True - - @tf.autograph.experimental.do_not_convert - def _forward(self, *args, **kwargs): - raise NotImplementedError( - "When subclassing the `Module` class, you should " - "implement a `_forward` method." - ) - - @tf.autograph.experimental.do_not_convert - def _extra_repr(self) -> str: - return "" - - # Properties # - # -----------# - - @property - def device(self): - return self._device - - @property - def dtype(self): - return self._dtype - - @property - def build_mode(self): - return self._build_mode - - @property - def training(self): - return self._training - - @property - def v(self): - return self._v - - @property - def buffers(self): - return self._buffers - - @property - def state_dict(self): - return {**self.v, **self.buffers} - - @property - def module_dict(self): - return self._module_dict - - # Dunder Methods # - # ---------------# - @store_frame_info - @tf.autograph.experimental.do_not_convert - def __call__( - self, - *args, - v=None, - buffers=None, - **kwargs, - ): - # TODO: Temp workaround to avoid `call` from being transformed by AutoGraph - if not hasattr(self.__class__.call, "autograph_info__"): - setattr(self.__class__.call, "autograph_info__", True) - ret = self._call(*args, v=v, buffers=buffers, **kwargs) - return ret - - @tf.autograph.experimental.do_not_convert - def __getattr__(self, name): - if name == "v": - if not super().__getattribute__("_v") and not getattr( # noqa: E501 - self, "_built", False - ): - return self._build_and_return_v( - *self._args, dynamic_backend=self._dynamic_backend, **self._kwargs - ) - - _dict = super().__getattribute__("__dict__") - if name in _dict: - return _dict[name] - - elif "_v" in _dict and name in _dict["_v"]: - return _dict["_v"][name] - - return super().__getattribute__(name) - - @tf.autograph.experimental.do_not_convert - def __setattr__(self, name, value): - if name in ["v", "buffers"]: - name = "_" + name - if isinstance(value, (Layer, tf.keras.layers.Layer, Model, tf.keras.Model)): - _dict = getattr(self, "__dict__", None) - if _dict: - _dict[name] = value - - # compute the module dict - self._compute_module_dict() - - obj_to_search = ( - None - if not isinstance(value, (tf.keras.layers.Layer, Layer, Model)) - else ( - self._modules - if hasattr(self, "_modules") and self._modules - else self - ) - ) - found_vars = self._find_variables( - obj=obj_to_search, - without_initialisation=( - True - if self._v_from_constructor and not self._with_partial_v - else False - ), - ) - flattened_v, v_spec = tree_flatten(found_vars) - flattend_kc = v_spec.get_keychains() - for kc, v in zip(flattend_kc, flattened_v): - new_kc = kc.replace("/", ".") - if new_kc not in self.v: - self.register_parameter(new_kc, v) - - # once all variables built, find and assign buffers - found_buffers = self._find_variables( - obj=obj_to_search, - without_initialisation=( - True - if self._v_from_constructor and not self._with_partial_v - else False - ), - trainable=False, - ) - flattened_buf, buf_spec = tree_flatten(found_buffers) - flattend_kc = buf_spec.get_keychains() - for kc, buf in zip(flattend_kc, 
flattened_buf): - new_kc = kc.replace("/", ".") - self.register_buffer(new_kc, buf) - - super().__setattr__(name, value) - return - elif isinstance(value, tf.Variable) and not name.startswith("_"): - _dict = getattr(self, "__dict__", None) - if _dict: - _dict[name] = value - - # Manual solution for cases where a `tf.int32` tensor - # is placed on the GPU. TensorFlow doesn't have dedicated - # kernels for placing `tf.int32` variables on the GPU and so - # we manually cast them to `tf.int64` here otherwise due to - # `tf.config.soft_device_placement(True)` by default, - # TensorFlow puts the `tf.int32` variables on CPU which causes - # unintended consequences downstream during tracing or - # `tf.function` compilation e.g. - # Ref: https://github.com/tensorflow/tensorflow/issues/9506 - # Ref: https://stackoverflow.com/questions/44813939/could-not-satisfy-explicit-device-specification-devicegpu0-because-no-devic - dtype = ( - tf.int64 - if value.dtype == tf.int32 and "gpu:" in value.device.lower() - else value.dtype - ) - cast_dtype = dtype != value.dtype - val = ( - value - if not cast_dtype - else tf.Variable(initial_value=tf.cast(value.value(), dtype), name=name) - ) - self.register_parameter(name, val) - super().__setattr__(name, val) - else: - try: - obj_to_search = getattr(self, name) - except AttributeError: - obj_to_search = None - if isinstance(obj_to_search, (Model, Layer)): - # retrieve all hierarchical submodules - assign_dict, kc = get_assignment_dict() - - # Iterate over all submods in assign_dict - # updating their `v` and `buffers` with the - # new value - for key, submod in assign_dict.items(): - # Get the subkey to match - subkey = kc[len(key) :].lstrip(".") - - if hasattr(submod, "v"): - for v_key, v_value in submod.v.items(): - if v_key.startswith(subkey): - submod.register_parameter(v_key, value) - - # Repeat the same process for submod.buffers - if hasattr(submod, "buffers"): - for b_key, b_value in submod.buffers.items(): - if b_key.startswith(subkey): - submod.register_buffer(b_key, value) - - # finally update the module dict - self._module_dict[name] = value - - return super().__setattr__(name, value) - - @tf.autograph.experimental.do_not_convert - def __delattr__(self, name): - if hasattr(self, name): - if isinstance( - getattr(self, name), - (Layer, tf.keras.layers.Layer, Model, tf.keras.Model), - ): - super().__delattr__(name) - return - super().__delattr__(name) - - @tf.autograph.experimental.do_not_convert - def __repr__(self): - extra_lines = [] - extra_repr = self._extra_repr() - if extra_repr: - extra_lines = extra_repr.split("\n") - child_lines = [] - for key in self.v.keys(): - if isinstance(getattr(self, key, None), (Layer, Model)): - mod_str = repr(getattr(self, key)) - mod_str = self._addindent(mod_str, 2) - child_lines.append(f"({key}): {mod_str}") - lines = extra_lines + child_lines - - main_str = f"{self.__class__.__name__}(" - if lines: - # simple one-liner info, which most builtin Modules will use - if len(extra_lines) == 1 and not child_lines: - main_str += extra_lines[0] - else: - main_str += "\n " + "\n ".join(lines) + "\n" - - main_str += ")" - return main_str diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_empty_output/run_0/tensorflow_empty.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_empty_output/run_0/tensorflow_empty.py deleted file mode 100644 index 73fb31fa90b7..000000000000 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_empty_output/run_0/tensorflow_empty.py +++ /dev/null @@ -1,21 +0,0 @@ -import 
tensorflow -import tensorflow as tf - -from typing import Union -from typing import Sequence -from typing import Optional - -from .tensorflow__helpers import tensorflow_handle_array_like_without_promotion -from .tensorflow__helpers import tensorflow_infer_dtype - - -@tensorflow_infer_dtype -@tensorflow_handle_array_like_without_promotion -def tensorflow_empty( - shape: Union[tf.TensorShape, Sequence[int]], - *, - dtype: tensorflow.DType, - device: Optional[str] = None, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - return tensorflow.experimental.numpy.empty(shape, dtype=tensorflow.float32) diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_equal_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_equal_output/run_0/tensorflow__helpers.py index 1b64cf5d5694..bde7b8c8d8d0 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_equal_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_equal_output/run_0/tensorflow__helpers.py @@ -23,6 +23,99 @@ import tensorflow as tf +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + 
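A minimal runnable sketch of the try-then-fallback decorator pattern shared by the relocated tensorflow_handle_get_item / tensorflow_handle_set_item helpers above: attempt the object's own dunder first, re-raise IndexError so genuine out-of-bounds errors still surface, and defer to the decorated implementation for every other failure. The names demo_handle_get_item and fancy_get below are hypothetical illustrations, not part of this patch:

import functools

def demo_handle_get_item(fn):
    # Same shape as tensorflow_handle_get_item: try the native path first.
    @functools.wraps(fn)
    def wrapper(inp, query, **kwargs):
        try:
            return inp.__getitem__(query)
        except IndexError:
            # Real out-of-bounds errors must propagate unchanged.
            raise
        except Exception:
            # Unsupported query types fall back to the decorated function.
            return fn(inp, query, **kwargs)
    return wrapper

@demo_handle_get_item
def fancy_get(inp, query):
    # Fallback: support a list of indices, which list.__getitem__ rejects.
    return [inp[i] for i in query]

assert fancy_get([10, 20, 30], 1) == 20             # native __getitem__ path
assert fancy_get([10, 20, 30], [0, 2]) == [10, 30]  # fallback path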
+ promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +240,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +282,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] -default_complex_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -default_dtype_stack = [] -default_float_dtype_stack = [] - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_exists_bknd(x: Any, /): @@ -310,7 +318,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -531,20 +541,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def 
tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -611,26 +607,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -748,6 +726,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -854,27 +835,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -1043,6 +1018,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1299,7 +1277,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -1711,7 +1691,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -1877,6 +1859,9 @@ def tensorflow_is_uint_dtype_bknd( return "uint" in tensorflow_as_ivy_dtype_bknd(dtype_in) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -1901,11 +1886,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == 
"uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2139,7 +2122,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2299,11 +2284,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2343,21 +2326,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], @@ -2438,6 +2406,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2494,6 +2465,25 @@ def tensorflow_default_complex_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -2530,6 +2520,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2552,21 +2546,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -2604,6 +2594,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": 
tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -2627,6 +2636,10 @@ def tensorflow_as_native_dtype( ) +default_dtype_stack = [] +default_float_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_equal_output/run_0/tensorflow_equal.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_equal_output/run_0/tensorflow_equal.py index 728f68a913f6..a9f7a098de5e 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_equal_output/run_0/tensorflow_equal.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_equal_output/run_0/tensorflow_equal.py @@ -21,7 +21,9 @@ def tensorflow_equal( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_erfinv_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_erfinv_output/run_0/tensorflow__helpers.py index d50687f412e5..3aaa2aa83879 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_erfinv_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_erfinv_output/run_0/tensorflow__helpers.py @@ -23,6 +23,99 @@ import tensorflow as tf +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = 
inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +240,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +282,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, 
tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_is_native_array(x, /, *, exclusive=False): @@ -334,7 +342,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -447,6 +457,7 @@ def tensorflow_is_complex_dtype_bknd( return "complex" in tensorflow_as_ivy_dtype_bknd(dtype_in) +@tensorflow_handle_array_like_without_promotion def tensorflow_real( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -460,6 +471,7 @@ def tensorflow_real_bknd_(self): return tensorflow_real(self) +@tensorflow_handle_array_like_without_promotion def tensorflow_imag( val: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -485,6 +497,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -611,6 +626,9 @@ def nested_fun(x): return "int" in tensorflow_as_ivy_dtype(dtype_in) +default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, @@ -655,6 +673,9 @@ def tensorflow_default_dtype_bknd( return tensorflow_as_ivy_dtype(ret) +default_float_dtype_stack = [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -709,6 +730,25 @@ def tensorflow_default_float_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -745,6 +785,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -767,21 +811,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 
2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -819,6 +859,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -870,20 +929,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -950,26 +995,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -1087,6 +1114,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -1193,27 +1223,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -1382,6 +1406,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1638,7 +1665,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, 
"dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2050,7 +2079,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -2174,6 +2205,9 @@ def tensorflow_to_scalar_bknd_(self: tensorflow.Tensor): return tensorflow_to_scalar(self) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2198,11 +2232,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2410,7 +2442,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2570,11 +2604,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2614,21 +2646,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_erfinv_output/run_0/tensorflow_erfinv.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_erfinv_output/run_0/tensorflow_erfinv.py index 9bb43f0d8328..87f2e1c9b719 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_erfinv_output/run_0/tensorflow_erfinv.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_erfinv_output/run_0/tensorflow_erfinv.py @@ -1,7 +1,7 @@ import tensorflow -from typing import Optional from typing import Union +from typing import Optional from .tensorflow__helpers import tensorflow_handle_array_like_without_promotion diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_exp_output/run_0/tensorflow_NestedSequence_bknd.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_exp_output/run_0/tensorflow_NestedSequence_bknd.py index 9f87b4ae29ef..ac8335fe1e56 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_exp_output/run_0/tensorflow_NestedSequence_bknd.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_exp_output/run_0/tensorflow_NestedSequence_bknd.py @@ -1,5 +1,5 @@ -from typing import Protocol from typing import TypeVar +from typing import Protocol _T_co = TypeVar("_T_co", covariant=True) diff --git 
a/ivy/compiler/_cache/Translated_Outputs/tensorflow_exp_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_exp_output/run_0/tensorflow__helpers.py index d50687f412e5..3aaa2aa83879 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_exp_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_exp_output/run_0/tensorflow__helpers.py @@ -23,6 +23,99 @@ import tensorflow as tf +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +240,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +282,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - 
tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_is_native_array(x, /, *, exclusive=False): @@ -334,7 +342,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -447,6 +457,7 @@ def tensorflow_is_complex_dtype_bknd( return "complex" in tensorflow_as_ivy_dtype_bknd(dtype_in) +@tensorflow_handle_array_like_without_promotion def tensorflow_real( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -460,6 +471,7 @@ def tensorflow_real_bknd_(self): return tensorflow_real(self) +@tensorflow_handle_array_like_without_promotion def tensorflow_imag( val: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -485,6 +497,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -611,6 +626,9 @@ def nested_fun(x): return "int" in tensorflow_as_ivy_dtype(dtype_in) 
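The hunks above relocate module-level globals such as default_dtype_stack from one shared block to just before the functions that read them, without changing their values. These globals implement a last-in-wins default stack: the *_bknd resolvers use the most recently pushed entry when one exists, otherwise they fall back to an inferred or fixed default. A minimal sketch of that idiom under hypothetical names (current_default_dtype, with_default_dtype); the real helpers expose this behaviour through functions like tensorflow_default_dtype_bknd rather than this API:

import contextlib

default_dtype_stack = []  # innermost (last pushed) default wins

def current_default_dtype(fallback="float32"):
    # Empty stack -> the global fallback; otherwise the latest override.
    return default_dtype_stack[-1] if default_dtype_stack else fallback

@contextlib.contextmanager
def with_default_dtype(dtype):
    default_dtype_stack.append(dtype)
    try:
        yield
    finally:
        default_dtype_stack.pop()  # restore the enclosing default

assert current_default_dtype() == "float32"
with with_default_dtype("float64"):
    assert current_default_dtype() == "float64"
    with with_default_dtype("float16"):
        assert current_default_dtype() == "float16"
    assert current_default_dtype() == "float64"
assert current_default_dtype() == "float32"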
+default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, @@ -655,6 +673,9 @@ def tensorflow_default_dtype_bknd( return tensorflow_as_ivy_dtype(ret) +default_float_dtype_stack = [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -709,6 +730,25 @@ def tensorflow_default_float_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -745,6 +785,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -767,21 +811,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -819,6 +859,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -870,20 +929,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -950,26 +995,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): 
- if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -1087,6 +1114,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -1193,27 +1223,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -1382,6 +1406,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1638,7 +1665,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2050,7 +2079,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -2174,6 +2205,9 @@ def tensorflow_to_scalar_bknd_(self: tensorflow.Tensor): return tensorflow_to_scalar(self) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2198,11 +2232,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2410,7 +2442,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2570,11 +2604,9 @@ def tensorflow_scatter_nd( dtype = 
tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2614,21 +2646,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_exp_output/run_0/tensorflow_exp.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_exp_output/run_0/tensorflow_exp.py index 2cd2d3e62048..b2a02898870e 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_exp_output/run_0/tensorflow_exp.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_exp_output/run_0/tensorflow_exp.py @@ -1,7 +1,7 @@ import tensorflow -from typing import Union from typing import Optional +from typing import Union from .tensorflow__helpers import tensorflow_handle_array_like_without_promotion diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_expand_dims_output/run_0/tensorflow_NestedSequence_bknd.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_expand_dims_output/run_0/tensorflow_NestedSequence_bknd.py index 9f87b4ae29ef..ac8335fe1e56 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_expand_dims_output/run_0/tensorflow_NestedSequence_bknd.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_expand_dims_output/run_0/tensorflow_NestedSequence_bknd.py @@ -1,5 +1,5 @@ -from typing import Protocol from typing import TypeVar +from typing import Protocol _T_co = TypeVar("_T_co", covariant=True) diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_expand_dims_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_expand_dims_output/run_0/tensorflow__helpers.py index d50687f412e5..3aaa2aa83879 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_expand_dims_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_expand_dims_output/run_0/tensorflow__helpers.py @@ -23,6 +23,99 @@ import tensorflow as tf +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, 
device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +240,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +282,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, 
kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_is_native_array(x, /, *, exclusive=False): @@ -334,7 +342,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -447,6 +457,7 @@ def tensorflow_is_complex_dtype_bknd( return "complex" in tensorflow_as_ivy_dtype_bknd(dtype_in) +@tensorflow_handle_array_like_without_promotion def tensorflow_real( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -460,6 +471,7 @@ def tensorflow_real_bknd_(self): return tensorflow_real(self) +@tensorflow_handle_array_like_without_promotion def tensorflow_imag( val: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -485,6 +497,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -611,6 +626,9 @@ def nested_fun(x): return "int" in tensorflow_as_ivy_dtype(dtype_in) +default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, @@ -655,6 +673,9 @@ def tensorflow_default_dtype_bknd( return tensorflow_as_ivy_dtype(ret) +default_float_dtype_stack = [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -709,6 +730,25 @@ def tensorflow_default_float_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -745,6 +785,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -767,21 +811,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - 
tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -819,6 +859,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -870,20 +929,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -950,26 +995,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -1087,6 +1114,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -1193,27 +1223,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + 
if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -1382,6 +1406,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1638,7 +1665,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2050,7 +2079,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -2174,6 +2205,9 @@ def tensorflow_to_scalar_bknd_(self: tensorflow.Tensor): return tensorflow_to_scalar(self) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2198,11 +2232,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2410,7 +2442,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2570,11 +2604,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2614,21 +2646,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_expand_dims_output/run_0/tensorflow_expand_dims.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_expand_dims_output/run_0/tensorflow_expand_dims.py index 73a2060a66cc..a23155c2520e 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_expand_dims_output/run_0/tensorflow_expand_dims.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_expand_dims_output/run_0/tensorflow_expand_dims.py @@ -2,8 +2,8 @@ import numpy as np from typing import Optional -from typing import Union from typing import Sequence +from typing import Union from .tensorflow__helpers import tensorflow__calculate_out_shape_bknd from .tensorflow__helpers import 
tensorflow_handle_array_like_without_promotion diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_flatten_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_flatten_output/run_0/tensorflow__helpers.py index d50687f412e5..3aaa2aa83879 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_flatten_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_flatten_output/run_0/tensorflow__helpers.py @@ -23,6 +23,99 @@ import tensorflow as tf +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +240,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +282,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - 
tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_is_native_array(x, /, *, exclusive=False): @@ -334,7 +342,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -447,6 +457,7 @@ def tensorflow_is_complex_dtype_bknd( return "complex" in tensorflow_as_ivy_dtype_bknd(dtype_in) +@tensorflow_handle_array_like_without_promotion def tensorflow_real( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -460,6 +471,7 @@ def tensorflow_real_bknd_(self): return tensorflow_real(self) +@tensorflow_handle_array_like_without_promotion def tensorflow_imag( val: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -485,6 +497,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ 
-611,6 +626,9 @@ def nested_fun(x): return "int" in tensorflow_as_ivy_dtype(dtype_in) +default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, @@ -655,6 +673,9 @@ def tensorflow_default_dtype_bknd( return tensorflow_as_ivy_dtype(ret) +default_float_dtype_stack = [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -709,6 +730,25 @@ def tensorflow_default_float_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -745,6 +785,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -767,21 +811,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -819,6 +859,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -870,20 +929,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -950,26 +995,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if 
match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -1087,6 +1114,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -1193,27 +1223,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -1382,6 +1406,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1638,7 +1665,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2050,7 +2079,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -2174,6 +2205,9 @@ def tensorflow_to_scalar_bknd_(self: tensorflow.Tensor): return tensorflow_to_scalar(self) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2198,11 +2232,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2410,7 +2442,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = 
tensorflow_asarray(x1, dtype=dtype) @@ -2570,11 +2604,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2614,21 +2646,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_floor_divide_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_floor_divide_output/run_0/tensorflow__helpers.py index 1b64cf5d5694..bde7b8c8d8d0 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_floor_divide_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_floor_divide_output/run_0/tensorflow__helpers.py @@ -23,6 +23,99 @@ import tensorflow as tf +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + 
return wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +240,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +282,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] -default_complex_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -default_dtype_stack = [] -default_float_dtype_stack = [] - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_exists_bknd(x: Any, /): @@ -310,7 +318,9 @@ def tensorflow_default_bknd( return ( x if 
tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -531,20 +541,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -611,26 +607,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -748,6 +726,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -854,27 +835,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -1043,6 +1018,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1299,7 +1277,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -1711,7 +1691,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -1877,6 +1859,9 @@ def tensorflow_is_uint_dtype_bknd( return "uint" in tensorflow_as_ivy_dtype_bknd(dtype_in) 
+default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -1901,11 +1886,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2139,7 +2122,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2299,11 +2284,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2343,21 +2326,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], @@ -2438,6 +2406,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2494,6 +2465,25 @@ def tensorflow_default_complex_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -2530,6 +2520,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2552,21 +2546,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if 
tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -2604,6 +2594,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -2627,6 +2636,10 @@ def tensorflow_as_native_dtype( ) +default_dtype_stack = [] +default_float_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_floor_divide_output/run_0/tensorflow_floor_divide.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_floor_divide_output/run_0/tensorflow_floor_divide.py index 22e215c67b61..72894af8b44f 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_floor_divide_output/run_0/tensorflow_floor_divide.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_floor_divide_output/run_0/tensorflow_floor_divide.py @@ -21,7 +21,9 @@ def tensorflow_floor_divide( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_full_like_output/run_0/__init__.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_full_like_output/run_0/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_full_like_output/run_0/tensorflow_NestedSequence_bknd.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_full_like_output/run_0/tensorflow_NestedSequence_bknd.py deleted file mode 100644 index 9f87b4ae29ef..000000000000 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_full_like_output/run_0/tensorflow_NestedSequence_bknd.py +++ /dev/null @@ -1,10 +0,0 @@ -from typing import Protocol -from typing import TypeVar - -_T_co = TypeVar("_T_co", covariant=True) - - -class tensorflow_NestedSequence_bknd(Protocol[_T_co]): - def __getitem__(self, key: int, /): ... - - def __len__(self, /): ... 
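Across the `tensorflow__helpers.py` files above, the same refactor recurs: the `tensorflow_handle_get_item`, `tensorflow_handle_set_item`, `tensorflow_handle_methods`, and `tensorflow_handle_array_like_without_promotion` wrappers are hoisted to the top of the module and their original later definitions deleted, while module-level constants (`ivy_dtype_dict`, `native_dtype_dict`, the `default_*_dtype_stack` lists, `backend`) move to sit directly above the function that first reads them. The hoisting matters because a decorator expression is evaluated at the moment the decorated `def` statement executes, so the wrapper name must already be bound at that point in the module. A minimal, self-contained sketch of that ordering constraint and of the try/fallback pattern used by the get-item wrapper (plain Python, no TensorFlow; `handle_get_item` and `get_item` are hypothetical stand-ins for the translated helpers, not names from this diff):

import functools

def handle_get_item(fn):
    # Mirrors the shape of tensorflow_handle_get_item in the hunks above:
    # try the object's native __getitem__ first, let IndexError propagate
    # unchanged, and route any other failure to the decorated fallback.
    @functools.wraps(fn)
    def wrapper(inp, query, **kwargs):
        try:
            return inp.__getitem__(query)
        except IndexError:
            raise
        except Exception:
            return fn(inp, query, **kwargs)
    return wrapper

@handle_get_item  # NameError at import time if the definition came later in the file
def get_item(x, query):
    # Hypothetical fallback: list-of-indices ("fancy") indexing on a plain list.
    return [x[i] for i in query]

print(get_item([10, 20, 30], 1))       # 20 -- served by list.__getitem__
print(get_item([10, 20, 30], [0, 2]))  # [10, 30] -- __getitem__ raises TypeError, fallback runs

The second call illustrates the division of labour: the native path handles ordinary indexing, and only queries the backend object cannot interpret reach the decorated function.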
diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_full_like_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_full_like_output/run_0/tensorflow__helpers.py deleted file mode 100644 index 06e137cf3452..000000000000 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_full_like_output/run_0/tensorflow__helpers.py +++ /dev/null @@ -1,2671 +0,0 @@ -from collections import UserDict -from numbers import Number -from numpy.core.numeric import normalize_axis_tuple -from operator import mul -from .tensorflow_NestedSequence_bknd import tensorflow_NestedSequence_bknd -from typing import Any -from typing import Callable -from typing import Dict -from typing import Iterable -from typing import List -from typing import Optional -from typing import Sequence -from typing import Tuple -from typing import TypeVar -from typing import Union -import functools -import inspect -import itertools -import math -import numpy as np -import re -import tensorflow -import tensorflow as tf - - -promotion_table = { - ("bool", "bool"): "bool", - ("int8", "int8"): "int8", - ("int8", "int16"): "int16", - ("int8", "int32"): "int32", - ("int8", "int64"): "int64", - ("int16", "int16"): "int16", - ("int16", "int32"): "int32", - ("int16", "int64"): "int64", - ("int32", "int32"): "int32", - ("int32", "int64"): "int64", - ("int64", "int64"): "int64", - ("uint8", "int8"): "int16", - ("uint8", "int16"): "int16", - ("uint8", "int32"): "int32", - ("uint8", "int64"): "int64", - ("uint8", "uint8"): "uint8", - ("uint8", "uint16"): "uint16", - ("uint8", "uint32"): "uint32", - ("uint8", "uint64"): "uint64", - ("uint16", "int8"): "int32", - ("uint16", "int16"): "int32", - ("uint16", "int32"): "int32", - ("uint16", "int64"): "int64", - ("uint16", "uint16"): "uint16", - ("uint16", "uint32"): "uint32", - ("uint16", "uint64"): "uint64", - ("uint32", "int8"): "int64", - ("uint32", "int16"): "int64", - ("uint32", "int32"): "int64", - ("uint32", "int64"): "int64", - ("uint32", "uint32"): "uint32", - ("uint32", "uint64"): "uint64", - ("uint64", "uint64"): "uint64", - ("float16", "float16"): "float16", - ("float16", "float32"): "float32", - ("float16", "float64"): "float64", - ("float32", "float32"): "float32", - ("float32", "float64"): "float64", - ("float64", "float64"): "float64", - ("bool", "int8"): "int8", - ("bool", "int16"): "int16", - ("bool", "int32"): "int32", - ("bool", "int64"): "int64", - ("bool", "uint8"): "uint8", - ("bool", "uint16"): "uint16", - ("bool", "uint32"): "uint32", - ("bool", "uint64"): "uint64", - ("bool", "float16"): "float16", - ("bool", "float32"): "float32", - ("bool", "float64"): "float64", - ("bool", "bfloat16"): "bfloat16", - ("bool", "complex64"): "complex64", - ("bool", "complex128"): "complex128", - ("int8", "float16"): "float16", - ("int8", "float32"): "float32", - ("int8", "float64"): "float64", - ("int8", "bfloat16"): "bfloat16", - ("int8", "complex64"): "complex64", - ("int8", "complex128"): "complex128", - ("int16", "float32"): "float32", - ("int16", "float64"): "float64", - ("int16", "complex64"): "complex64", - ("int16", "complex128"): "complex128", - ("int32", "float64"): "float64", - ("int32", "complex128"): "complex128", - ("int64", "float64"): "float64", - ("int64", "complex128"): "complex128", - ("uint8", "float16"): "float16", - ("uint8", "float32"): "float32", - ("uint8", "float64"): "float64", - ("uint8", "bfloat16"): "bfloat16", - ("uint8", "complex64"): "complex64", - ("uint8", "complex128"): "complex128", - ("uint16", "float32"): "float32", - 
("uint16", "float64"): "float64", - ("uint16", "complex64"): "complex64", - ("uint16", "complex128"): "complex128", - ("uint32", "float64"): "float64", - ("uint32", "complex128"): "complex128", - ("uint64", "int8"): "float64", - ("uint64", "int16"): "float64", - ("uint64", "int32"): "float64", - ("uint64", "int64"): "float64", - ("uint64", "float64"): "float64", - ("uint64", "complex128"): "complex128", - ("float16", "bfloat16"): "float32", - ("float16", "complex64"): "complex64", - ("float16", "complex128"): "complex128", - ("float32", "complex64"): "complex64", - ("float32", "complex128"): "complex128", - ("float64", "complex64"): "complex128", - ("float64", "complex128"): "complex128", - ("bfloat16", "float16"): "float32", - ("bfloat16", "float32"): "float32", - ("bfloat16", "float64"): "float64", - ("bfloat16", "bfloat16"): "bfloat16", - ("bfloat16", "complex64"): "complex64", - ("bfloat16", "complex128"): "complex128", - ("complex64", "float64"): "complex128", - ("complex64", "complex64"): "complex64", - ("complex64", "complex128"): "complex128", - ("complex128", "complex128"): "complex128", - ("float16", "int16"): "float32", - ("float16", "int32"): "float64", - ("float16", "int64"): "float64", - ("float16", "uint16"): "float32", - ("float16", "uint32"): "float64", - ("float16", "uint64"): "float64", - ("float32", "int32"): "float64", - ("float32", "int64"): "float64", - ("float32", "uint32"): "float64", - ("float32", "uint64"): "float64", - ("bfloat16", "int16"): "float32", - ("bfloat16", "int32"): "float64", - ("bfloat16", "int64"): "float64", - ("bfloat16", "uint16"): "float32", - ("bfloat16", "uint32"): "float64", - ("bfloat16", "uint64"): "float64", - ("complex64", "int32"): "complex128", - ("complex64", "int64"): "complex128", - ("complex64", "uint32"): "complex128", - ("complex64", "uint64"): "complex128", -} -array_api_promotion_table = { - ("bool", "bool"): "bool", - ("int8", "int8"): "int8", - ("int8", "int16"): "int16", - ("int8", "int32"): "int32", - ("int8", "int64"): "int64", - ("int16", "int16"): "int16", - ("int16", "int32"): "int32", - ("int16", "int64"): "int64", - ("int32", "int32"): "int32", - ("int32", "int64"): "int64", - ("int64", "int64"): "int64", - ("uint8", "int8"): "int16", - ("uint8", "int16"): "int16", - ("uint8", "int32"): "int32", - ("uint8", "int64"): "int64", - ("uint8", "uint8"): "uint8", - ("uint8", "uint16"): "uint16", - ("uint8", "uint32"): "uint32", - ("uint8", "uint64"): "uint64", - ("uint16", "int8"): "int32", - ("uint16", "int16"): "int32", - ("uint16", "int32"): "int32", - ("uint16", "int64"): "int64", - ("uint16", "uint16"): "uint16", - ("uint16", "uint32"): "uint32", - ("uint16", "uint64"): "uint64", - ("uint32", "int8"): "int64", - ("uint32", "int16"): "int64", - ("uint32", "int32"): "int64", - ("uint32", "int64"): "int64", - ("uint32", "uint32"): "uint32", - ("uint32", "uint64"): "uint64", - ("uint64", "uint64"): "uint64", - ("float16", "float16"): "float16", - ("float16", "float32"): "float32", - ("float16", "float64"): "float64", - ("float32", "float32"): "float32", - ("float32", "float64"): "float64", - ("float64", "float64"): "float64", -} -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: 
"int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} - - -def tensorflow_infer_dtype(fn: Callable): - @functools.wraps(fn) - def _infer_dtype(*args, dtype=None, **kwargs): - arr = ( - None - if tensorflow_exists_bknd(dtype) - else tensorflow__get_first_array(*args, **kwargs) - ) - dtype = tensorflow_default_dtype_bknd(dtype=dtype, item=arr, as_native=True) - return fn(*args, dtype=dtype, **kwargs) - - _infer_dtype.infer_dtype = True - return _infer_dtype - - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion - - -def tensorflow_exists_bknd(x: Any, /): - return x is not None - - -def tensorflow_is_native_array(x, /, *, exclusive=False): - if "keras.src.backend.tensorflow.core.Variable" in str(x.__class__): - return not exclusive - if isinstance(x, (tensorflow.Tensor, tensorflow.Variable, tensorflow.TensorArray)): - if exclusive and isinstance(x, tensorflow.Variable): - return False - return True - return False - - -def tensorflow_is_ivy_array_bknd( - x: Union[tensorflow.Tensor, tf.Tensor], /, *, exclusive: Optional[bool] = False -): - return isinstance(x, tensorflow.Tensor) and tensorflow_is_native_array( - x, exclusive=exclusive - ) - - -def tensorflow_is_array_bknd(x: Any, /, *, exclusive: bool = False): - return tensorflow_is_ivy_array_bknd( - x, exclusive=exclusive - ) or 
tensorflow_is_native_array(x, exclusive=exclusive) - - -def tensorflow_default_bknd( - x: Any, - /, - default_val: Any, - *, - catch_exceptions: bool = False, - rev: bool = False, - with_callable: bool = False, -): - with_callable = catch_exceptions or with_callable - if rev: - x, default_val = default_val, x - if with_callable: - x_callable = callable(x) - default_callable = callable(default_val) - else: - x_callable = False - default_callable = False - if catch_exceptions: - try: - x = x() if x_callable else x - except Exception: - return default_val() if default_callable else default_val - else: - x = x() if x_callable else x - return ( - x - if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val - ) - - -def tensorflow_nested_argwhere_bknd( - nest: Iterable, - fn: Callable, - check_nests: bool = False, - to_ignore: Optional[Union[type, Tuple[type]]] = None, - _index: Optional[List] = None, - _base: bool = True, - stop_after_n_found: Optional[int] = None, -): - to_ignore = tensorflow_default_bknd(to_ignore, ()) - _index = [] if _index is None else _index - if isinstance(nest, (tuple, list)) and not isinstance(nest, to_ignore): - n = 0 - _indices = [] - for i, item in enumerate(nest): - ind = ( - tensorflow_nested_argwhere_bknd( - item, - fn, - check_nests, - to_ignore, - _index + [i], - False, - stop_after_n_found - n, - ) - if stop_after_n_found is not None - else tensorflow_nested_argwhere_bknd( - item, fn, check_nests, to_ignore, _index + [i], False, None - ) - ) - if stop_after_n_found is not None and ind: - if n >= stop_after_n_found: - break - n = n + len(ind) - _indices = _indices + [ind] - if stop_after_n_found is not None and n >= stop_after_n_found: - break - _indices = [idx for idxs in _indices if idxs for idx in idxs] - if check_nests and fn(nest): - _indices.append(_index) - elif isinstance(nest, (dict, UserDict)) and not isinstance(nest, to_ignore): - n = 0 - _indices = [] - for k, v in nest.items(): - ind = ( - tensorflow_nested_argwhere_bknd( - v, - fn, - check_nests, - to_ignore, - _index + [k], - False, - stop_after_n_found - n, - ) - if stop_after_n_found is not None - else tensorflow_nested_argwhere_bknd( - v, fn, check_nests, to_ignore, _index + [k], False, None - ) - ) - if stop_after_n_found is not None and ind: - if n >= stop_after_n_found: - break - n = n + len(ind) - _indices = _indices + [ind] - _indices = [idx for idxs in _indices if idxs for idx in idxs] - if check_nests and fn(nest): - _indices.append(_index) - else: - cond_met = fn(nest) - if cond_met: - return [_index] - return False - return [index for index in _indices if index] - - -def tensorflow__check_float64_bknd(input): - if tensorflow_is_array_bknd(input): - return tensorflow_dtype(input) == "float64" - if math.isfinite(input): - m, e = math.frexp(input) - return abs(input) > 3.4028235e38 or e < -126 or e > 128 - return False - - -def tensorflow_as_ivy_dtype_bknd(dtype_in: Union[str, str], /): - return tensorflow_as_ivy_dtype(dtype_in) - - -def tensorflow_is_complex_dtype_bknd( - dtype_in: Union[str, str, tensorflow.Tensor, tf.Tensor, Number], / -): - if tensorflow_is_array_bknd(dtype_in): - dtype_in = tensorflow_dtype(dtype_in) - elif isinstance(dtype_in, tuple): - dtype_in = tensorflow_default_int_dtype_bknd() - elif isinstance(dtype_in, np.ndarray): - return "complex" in dtype_in.dtype.name - elif isinstance(dtype_in, Number): - return isinstance(dtype_in, (complex, np.complexfloating)) - elif isinstance(dtype_in, (list, tuple, dict)): - return 
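tensorflow_nested_argwhere_bknd returns the index paths of all leaves in a nest that satisfy a predicate. A compact pure-Python rendering of the same recursion (nested_argwhere is an illustrative name, and the stop_after_n_found bookkeeping is omitted):

def nested_argwhere(nest, fn, path=()):
    if isinstance(nest, (list, tuple)):
        return [hit for i, item in enumerate(nest)
                for hit in nested_argwhere(item, fn, path + (i,))]
    if isinstance(nest, dict):
        return [hit for k, v in nest.items()
                for hit in nested_argwhere(v, fn, path + (k,))]
    return [list(path)] if fn(nest) else []

print(nested_argwhere({"a": [1, 2.5], "b": (3.0,)},
                      lambda x: isinstance(x, float)))
# [['a', 1], ['b', 0]]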
tensorflow_nested_argwhere_bknd( - dtype_in, - lambda x: isinstance(x, (complex, np.complexfloating)) - or tensorflow_is_array_bknd(x) - and "complex" in tensorflow_dtype(x), - ) - return "complex" in tensorflow_as_ivy_dtype_bknd(dtype_in) - - -def tensorflow_as_native_dev(device: str, /): - if isinstance(device, str) and "/" in device: - return device - ret = f"/{str(device).upper()}" - if not ret[-1].isnumeric(): - ret += ":0" - return ret - - -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - -@tensorflow_handle_methods -def tensorflow_split( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - copy: Optional[bool] = None, - num_or_size_splits: Optional[ - Union[int, Sequence[int], Union[tensorflow.Tensor, tensorflow.Variable]] - ] = None, - axis: int = 0, - with_remainder: bool = False, -): - if x.shape == (): - if num_or_size_splits is not None and num_or_size_splits != 1: - raise Exception( - f"input array had no shape, but num_sections specified was {num_or_size_splits}" - ) - return [x] - if num_or_size_splits is None: - dim_size = tensorflow.shape(x)[axis] - num_or_size_splits = int(dim_size) - if isinstance(num_or_size_splits, (tensorflow.Tensor, tensorflow.Variable)): - num_or_size_splits = tensorflow.cast(num_or_size_splits, tensorflow.int32) - elif isinstance(num_or_size_splits, int) and with_remainder: - num_chunks = x.shape[axis] / num_or_size_splits - num_chunks_int = math.floor(num_chunks) - remainder = num_chunks - num_chunks_int - if remainder != 0: - num_or_size_splits = [num_or_size_splits] * num_chunks_int + [ - int(remainder * num_or_size_splits) - ] - return tensorflow.split(x, num_or_size_splits, axis) - - -@tensorflow_handle_methods -def tensorflow_split_bknd_( - self: tensorflow.Tensor, - /, - *, - copy: Optional[bool] = None, - num_or_size_splits: Optional[ - Union[int, Sequence[int], tensorflow.Tensor, tf.Tensor] - ] = None, - axis: int = 0, - with_remainder: bool = False, -): - return tensorflow_split( - self, - copy=copy, - num_or_size_splits=num_or_size_splits, - axis=axis, - with_remainder=with_remainder, - ) - - -def tensorflow_as_ivy_dev(device: str, /): - if isinstance(device, str) and "/" not in device: - return str(device) - dev_in_split = tensorflow_split_bknd_(device[1:], ":")[-2:] - if len(dev_in_split) == 1: - return str(dev_in_split[0]) - dev_type, dev_idx = dev_in_split[0], dev_in_split[1] - dev_type = dev_type.lower() - if dev_type == "cpu": - return str(dev_type) - return str(f"{dev_type}:{dev_idx}") - - -def tensorflow_stack( - arrays: Union[Tuple[tensorflow.Tensor], List[tensorflow.Tensor]], - /, - *, - axis: int = 0, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - try: - return tensorflow.experimental.numpy.stack(arrays, axis) - except ValueError as e: - raise Exception(e) from e - - -def tensorflow_stack_bknd_( - self: tensorflow.Tensor, - /, - arrays: Union[ - Tuple[Union[tensorflow.Tensor, tf.Tensor]], - List[Union[tensorflow.Tensor, tf.Tensor]], - ], - *, - axis: int = 0, - out: Optional[tensorflow.Tensor] = None, -): - if not isinstance(arrays, (tuple, list)): - arrays 
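tensorflow_as_native_dev and tensorflow_as_ivy_dev above are near-inverses translating between ivy-style ("gpu:0") and TensorFlow-style ("/GPU:0") device strings. A sketch of the round trip (to_native and to_ivy are invented names):

def to_native(device):
    if "/" in device:
        return device
    ret = f"/{device.upper()}"
    return ret if ret[-1].isdigit() else ret + ":0"   # default to index 0

def to_ivy(device):
    if "/" not in device:
        return device
    dev_type, _, dev_idx = device[1:].partition(":")
    dev_type = dev_type.lower()
    return dev_type if dev_type == "cpu" else f"{dev_type}:{dev_idx or 0}"

assert to_native("gpu:1") == "/GPU:1"
assert to_ivy("/CPU:0") == "cpu"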
= [arrays] - if isinstance(arrays, tuple): - x = (self,) + arrays - else: - x = [self] + arrays - return tensorflow_stack(x, axis=axis, out=out) - - -def tensorflow_dev( - x: Union[tensorflow.Tensor, tensorflow.Variable, tensorflow.TensorArray], - /, - *, - as_native: bool = False, -): - if "keras.src.backend.tensorflow.core.Variable" in str(x.__class__): - x = x.value - if isinstance(x, tensorflow.TensorArray): - x = tensorflow_stack_bknd_(x) - dv = x.device - if as_native: - return dv - dv = dv if dv else tensorflow_default_device_bknd(as_native=False) - return tensorflow_as_ivy_dev(dv) - - -def tensorflow_default_device_bknd( - device: Optional[Union[str, str]] = None, - /, - *, - item: Optional[Union[list, tuple, dict, tensorflow.Tensor, tf.Tensor]] = None, - as_native: Optional[bool] = None, -): - if tensorflow_exists_bknd(device): - if as_native is True: - return tensorflow_as_native_dev(device) - elif as_native is False: - return tensorflow_as_ivy_dev(device) - return device - as_native = tensorflow_default_bknd(as_native, False) - if tensorflow_exists_bknd(item): - if isinstance(item, (list, tuple, dict)) and len(item) == 0: - pass - elif tensorflow_is_array_bknd(item): - return tensorflow_dev(item, as_native=as_native) - global default_device_stack - if not default_device_stack: - ret = "cpu" - else: - ret = default_device_stack[-1] - if as_native: - return tensorflow_as_native_dev(ret) - return tensorflow_as_ivy_dev(ret) - - -def tensorflow__get_preferred_device(args, kwargs): - device = None - if "device" in kwargs and kwargs["device"] is not None: - return device - if not False: - arr_arg = tensorflow__get_first_array(*args, **kwargs) - return tensorflow_default_device_bknd(item=arr_arg, as_native=True) - return tensorflow_default_device_bknd(as_native=True) - - -def tensorflow__check_in_nested_sequence(sequence, value=None, _type=None): - if sequence is value or isinstance(sequence, _type): - return True - elif isinstance(sequence, (tuple, list)): - if any(isinstance(_val, _type) or _val is value for _val in sequence): - return True - else: - return any( - tensorflow__check_in_nested_sequence(sub_sequence, value, _type) - for sub_sequence in sequence - if isinstance(sub_sequence, (tuple, list)) - ) - - -def tensorflow_is_variable(x, /, *, exclusive=False): - return isinstance(x, tensorflow.Variable) - - -def tensorflow_variable(x, /): - with tensorflow.device(tensorflow_dev(x, as_native=True)): - return tensorflow.Variable(x, trainable=True) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_stop_gradient( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - preserve_type: bool = True, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - is_var = tensorflow_is_variable(x) - x = tensorflow.stop_gradient(x) - if is_var and preserve_type: - return tensorflow_variable(x) - return x - - -def tensorflow_nested_map_bknd( - fn: Callable, - x: Union[tensorflow.Tensor, tf.Tensor, Iterable], - /, - include_derived: Optional[Union[Dict[str, bool], bool]] = None, - to_ignore: Optional[Union[type, Tuple[type]]] = None, - to_mutable: bool = False, - _tuple_check_fn: Optional[Callable] = None, - _list_check_fn: Optional[Callable] = None, - _dict_check_fn: Optional[Callable] = None, - shallow: bool = True, -): - to_ignore = tensorflow_default_bknd(to_ignore, ()) - if include_derived is True: - include_derived = {"tuple": True, "list": True, "dict": True} - elif not include_derived: - include_derived = {} - for t in ("tuple", "list", "dict"): 
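tensorflow_default_device_bknd resolves a device with a fixed precedence: an explicit argument wins, then the device of a supplied array, then the top of default_device_stack, then "cpu". Roughly, ignoring the native/ivy string conversions:

default_device_stack = []

def default_device(device=None, item=None):
    if device is not None:
        return device
    if item is not None and hasattr(item, "device"):
        return item.device
    return default_device_stack[-1] if default_device_stack else "cpu"

print(default_device())              # cpu
default_device_stack.append("gpu:0")
print(default_device())              # gpu:0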
- if t not in include_derived: - include_derived = tensorflow_set_item_bknd(include_derived, t, False) - class_instance = type(x) - if ( - hasattr(x, "is_tracked_proxy") - and hasattr(class_instance, "__bases__") - and not set(class_instance.__bases__).intersection(set(to_ignore)) - ): - to_ignore = to_ignore + (class_instance,) - tuple_check_fn = tensorflow_default_bknd( - _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), - ) - list_check_fn = tensorflow_default_bknd( - _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), - ) - dict_check_fn = tensorflow_default_bknd( - _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), - ) - if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): - ret_list = [ - tensorflow_nested_map_bknd( - fn, - i, - include_derived, - to_ignore, - to_mutable, - tuple_check_fn, - list_check_fn, - dict_check_fn, - shallow, - ) - for i in x - ] - if to_mutable: - return ret_list - elif hasattr(x, "_fields"): - return class_instance(**dict(zip(x._fields, ret_list))) - else: - return class_instance(ret_list) - elif list_check_fn(x, list) and not isinstance(x, to_ignore): - ret_list = [ - tensorflow_nested_map_bknd( - fn, - i, - include_derived, - to_ignore, - to_mutable, - tuple_check_fn, - list_check_fn, - dict_check_fn, - shallow, - ) - for i in x - ] - if shallow: - x = tensorflow_set_item_bknd(x, slice(None, None, None), ret_list[:]) - return x - return class_instance(ret_list) - elif (dict_check_fn(x, dict) or isinstance(x, UserDict)) and not isinstance( - x, to_ignore - ): - class_instance = type(x) - ret = { - k: tensorflow_nested_map_bknd( - fn, - v, - include_derived, - to_ignore, - to_mutable, - tuple_check_fn, - list_check_fn, - dict_check_fn, - shallow, - ) - for k, v in x.items() - } - if shallow: - x.update(ret) - return x - return class_instance(ret) - elif isinstance(x, slice): - return slice(*tensorflow_nested_map_bknd(fn, [x.start, x.stop, x.step])) - return fn(x) - - -def tensorflow__to_ivy_bknd_(x: Any): - if isinstance(x, tensorflow.Tensor): - return x - elif isinstance(x, tf.TensorShape): - return tuple(x) - elif isinstance(x, dict): - return x.to_ivy() - if tensorflow_is_native_array(x) or isinstance(x, np.ndarray): - return tensorflow.convert_to_tensor(x) - return x - - -def tensorflow_to_ivy_bknd_( - x: Union[tensorflow.Tensor, tf.Tensor, Iterable], - nested: bool = False, - include_derived: Optional[Dict[str, bool]] = None, -): - if nested: - return tensorflow_nested_map_bknd( - tensorflow__to_ivy_bknd_, x, include_derived, shallow=False - ) - return tensorflow__to_ivy_bknd_(x) - - -def tensorflow__asarray_to_native_arrays_and_back_bknd(fn: Callable): - @functools.wraps(fn) - def _asarray_to_native_arrays_and_back_wrapper(*args, dtype=None, **kwargs): - new_arg = args[0] - new_args = (new_arg,) + args[1:] - if dtype is not None: - dtype = tensorflow_default_dtype_bknd(dtype=dtype, as_native=True) - return tensorflow_to_ivy_bknd_(fn(*new_args, dtype=dtype, **kwargs)) - - _asarray_to_native_arrays_and_back_wrapper._asarray_to_native_arrays_and_back = True - return _asarray_to_native_arrays_and_back_wrapper - - -def tensorflow__flatten_nest_bknd(xs): - for x in xs: - if isinstance(x, Iterable) and not isinstance(x, (str, bytes)): - yield from tensorflow__flatten_nest_bknd(x) - else: - yield x - - -def 
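tensorflow_nested_map_bknd above is a structure-preserving map: tuples, lists, and dicts are rebuilt around recursive calls, and fn is applied only at the leaves. The core scheme, minus the include_derived/shallow options:

def nested_map(fn, x):
    if isinstance(x, tuple):
        # namedtuples would need class_instance(**dict(zip(x._fields, ...)))
        return type(x)(nested_map(fn, v) for v in x)
    if isinstance(x, list):
        return [nested_map(fn, v) for v in x]
    if isinstance(x, dict):
        return {k: nested_map(fn, v) for k, v in x.items()}
    return fn(x)

print(nested_map(lambda v: v + 1, {"a": [1, (2, 3)]}))
# {'a': [2, (3, 4)]}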
tensorflow_promote_types_bknd( - type1: Union[str, tf.DType], - type2: Union[str, tf.DType], - /, - *, - array_api_promotion: bool = False, -): - if not (type1 and type2): - return type1 if type1 else type2 - query = [tensorflow_as_ivy_dtype(type1), tensorflow_as_ivy_dtype(type2)] - query = tuple(query) - if query not in promotion_table: - query = query[1], query[0] - - def _promote(query): - if array_api_promotion: - return tensorflow_get_item(array_api_promotion_table, query) - return tensorflow_get_item(promotion_table, query) - - return _promote(query) - - -def tensorflow__asarray_infer_dtype_bknd(fn: Callable): - @functools.wraps(fn) - def _asarray_infer_dtype_wrapper(*args, dtype=None, **kwargs): - def _infer_dtype(obj): - if isinstance(obj, tf.TensorShape): - obj = list(obj) - if hasattr(obj, "dtype"): - return obj.dtype.name if isinstance(obj, np.ndarray) else obj.dtype - else: - return tensorflow_default_dtype_bknd(item=obj) - - if not tensorflow_exists_bknd(dtype): - arr = args[0] - dtype_list = [ - tensorflow_nested_map_bknd( - lambda x: _infer_dtype(x), arr, shallow=False - ) - ] - dtype_list = tensorflow__flatten_nest_bknd(dtype_list) - dtype_list = list(set(dtype_list)) - if len(dtype_list) != 0: - dtype = dtype_list[0] - for dt in dtype_list[1:]: - dtype = tensorflow_promote_types_bknd(dtype, dt) - else: - dtype = tensorflow_default_float_dtype_bknd() - dtype = tensorflow_as_native_dtype(dtype) - return fn(*args, dtype=dtype, **kwargs) - - _asarray_infer_dtype_wrapper.infer_dtype = True - return _asarray_infer_dtype_wrapper - - -@tensorflow_handle_array_like_without_promotion -@tensorflow__asarray_to_native_arrays_and_back_bknd -@tensorflow__asarray_infer_dtype_bknd -def tensorflow_asarray( - obj: Union[ - tensorflow.Tensor, - tensorflow.Variable, - tensorflow.TensorShape, - bool, - int, - float, - tensorflow_NestedSequence_bknd, - SupportsBufferProtocol, - np.ndarray, - ], - /, - *, - copy: Optional[bool] = None, - dtype: Optional[tensorflow.DType] = None, - device: Optional[str] = None, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - with tensorflow.device(device): - if tensorflow.is_tensor(obj): - ret = tensorflow.cast(obj, dtype) if obj.dtype != dtype else obj - elif ( - dtype is not None - and dtype.is_integer - and np.issubdtype(np.array(obj).dtype, np.floating) - ): - obj_np = np.array(obj) - ret = tensorflow.convert_to_tensor(obj_np, dtype) - else: - ret = tensorflow.convert_to_tensor(obj, dtype) - return ( - tensorflow.identity(ret) - if copy or tensorflow_as_native_dev(tensorflow_dev(ret)) != device - else ret - ) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_size(x: tensorflow.Tensor, /): - return functools.reduce(mul, x.shape) if len(x.shape) > 0 else 1 - - -def tensorflow_size_bknd_(self): - return tensorflow_size(self) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_unstack( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - copy: Optional[bool] = None, - axis: int = 0, - keepdims: bool = False, -): - if x.shape == (): - return [x] - ret = tensorflow.unstack(x, axis=axis) - if keepdims: - return [tensorflow.expand_dims(r, axis) for r in ret] - return ret - - -def tensorflow_unstack_bknd_( - self: tensorflow.Tensor, - /, - *, - copy: Optional[bool] = None, - axis: int = 0, - keepdims: bool = False, -): - return tensorflow_unstack(self, copy=copy, axis=axis, keepdims=keepdims) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_copy_array( - x: 
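The _asarray_infer_dtype wrapper above picks a dtype by collecting the dtype of every leaf and folding the promotion rule across them. The fold itself, with an invented two-rule table:

from functools import reduce

promotion = {("int32", "float32"): "float32",
             ("float32", "float64"): "float64"}

def promote(t1, t2):
    if t1 == t2:
        return t1
    return promotion.get((t1, t2)) or promotion[(t2, t1)]

print(reduce(promote, ["int32", "float32", "float64"]))  # float64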
Union[tensorflow.Tensor, tensorflow.Variable, tensorflow.TensorArray], - *, - to_ivy_array: bool = True, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if isinstance(x, tensorflow.TensorArray): - x_wrapped = tensorflow_stack_bknd_(x) - y = tensorflow.TensorArray(x.dtype, tensorflow_size_bknd_(x)()) - x = tensorflow_unstack_bknd_(y, tensorflow_copy_array(x_wrapped)) - else: - x = tensorflow.identity(x) - if to_ivy_array: - return tensorflow_to_ivy_bknd_(x) - return x - - -def tensorflow_tile( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - repeats: Sequence[int], - *, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if x.shape == (): - x = tensorflow.reshape(x, (-1,)) - if isinstance(repeats, Number): - repeats = [repeats] - if isinstance(repeats, tensorflow.Tensor) and repeats.shape == (): - repeats = tensorflow.reshape(repeats, (-1,)) - if len(x.shape) < len(repeats): - while len(x.shape) != len(repeats): - x = tensorflow.expand_dims(x, 0) - elif len(x.shape) > len(repeats): - repeats = list(repeats) - while len(x.shape) != len(repeats): - repeats = [1] + repeats - return tensorflow.tile(x, repeats) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_nonzero( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - as_tuple: bool = True, - size: Optional[int] = None, - fill_value: Number = 0, -): - res = tensorflow.experimental.numpy.nonzero(x) - if size is not None: - dtype = tensorflow.int64 - if isinstance(fill_value, float): - dtype = tensorflow.float64 - res = tensorflow.cast(res, dtype) - diff = size - res[0].shape[0] - if diff > 0: - res = tensorflow.pad(res, [[0, 0], [0, diff]], constant_values=fill_value) - elif diff < 0: - res = tensorflow.slice(res, [0, 0], [-1, size]) - if as_tuple: - return tuple(res) - return tensorflow.stack(res, axis=1) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_diff( - x: Union[tensorflow.Tensor, tensorflow.Variable, list, tuple], - /, - *, - n: int = 1, - axis: int = -1, - prepend: Optional[ - Union[tensorflow.Tensor, tensorflow.Variable, int, float, list, tuple] - ] = None, - append: Optional[ - Union[tensorflow.Tensor, tensorflow.Variable, int, float, list, tuple] - ] = None, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if n == 0: - return x - if prepend is not None: - x = tensorflow.experimental.numpy.append( - prepend, x, axis=axis if axis != -1 else None - ) - if append is not None: - x = tensorflow.experimental.numpy.append( - x, append, axis=axis if axis != -1 else None - ) - return tensorflow.experimental.numpy.diff(x, n=n, axis=axis) - - -def tensorflow__parse_ellipsis_bknd(so, ndims): - pre = list() - for s in so: - if s is Ellipsis: - break - pre.append(s) - post = list() - for s in reversed(so): - if s is Ellipsis: - break - post.append(s) - ret = list( - pre - + [slice(None, None, None) for _ in range(ndims - len(pre) - len(post))] - + list(reversed(post)) - ) - return ret, (len(pre), ndims - len(post)) - - -def tensorflow_broadcast_arrays(*arrays: Union[tensorflow.Tensor, tensorflow.Variable]): - if len(arrays) > 1: - try: - desired_shape = tensorflow.broadcast_dynamic_shape( - tensorflow.shape(arrays[0]), tensorflow.shape(arrays[1]) - ) - except tensorflow.errors.InvalidArgumentError as e: - raise Exception(e) from e - if len(arrays) > 2: - for i in range(2, len(arrays)): - try: - desired_shape = tensorflow.broadcast_dynamic_shape( - desired_shape, tensorflow.shape(arrays[i]) - ) - except 
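tensorflow__parse_ellipsis_bknd above rewrites a query containing Ellipsis into one full-length slice per unaddressed dimension. Equivalent logic for a single Ellipsis (expand_ellipsis is an illustrative name):

def expand_ellipsis(query, ndims):
    pre, post, seen = [], [], False
    for s in query:
        if s is Ellipsis:
            seen = True
        elif not seen:
            pre.append(s)
        else:
            post.append(s)
    fill = [slice(None)] * (ndims - len(pre) - len(post))
    return pre + fill + post

print(expand_ellipsis([0, Ellipsis, 2], ndims=4))
# [0, slice(None, None, None), slice(None, None, None), 2]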
tensorflow.errors.InvalidArgumentError as e: - raise Exception(e) from e - else: - return [arrays[0]] - result = [] - for tensor in arrays: - result.append(tensorflow.broadcast_to(tensor, desired_shape)) - return result - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_astype( - x: Union[tensorflow.Tensor, tensorflow.Variable], - dtype: Union[tf.DType, str], - /, - *, - copy: bool = True, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - dtype = tensorflow_as_native_dtype(dtype) - if x.dtype == dtype: - return tensorflow.experimental.numpy.copy(x) if copy else x - return tensorflow.cast(x, dtype) - - -def tensorflow_astype_bknd_( - self: tensorflow.Tensor, - dtype: str, - /, - *, - copy: bool = True, - out: Optional[tensorflow.Tensor] = None, -): - return tensorflow_astype(self, dtype, copy=copy, out=out) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_where( - condition: Union[tensorflow.Tensor, tensorflow.Variable], - x1: Union[tensorflow.Tensor, tensorflow.Variable], - x2: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - oirg_x1 = x1 - oirg_x2 = x2 - try: - dtype = ( - x1.dtype - if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() - ) - if not tensorflow_is_array_bknd(x1): - x1 = tensorflow_asarray(x1, dtype=dtype) - if not tensorflow_is_array_bknd(x2): - x2 = tensorflow_asarray(x2, dtype=dtype) - except: - x1 = oirg_x1 - x2 = oirg_x2 - return tensorflow.cast( - tensorflow.experimental.numpy.where(condition, x1, x2), x1.dtype - ) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_arange( - start: float, - /, - stop: Optional[float] = None, - step: float = 1, - *, - dtype: Optional[tensorflow.DType] = None, - device: Optional[str] = None, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if stop is None: - stop = start - start = 0 - if step > 0 and start > stop or step < 0 and start < stop: - if isinstance(stop, float): - stop = float(start) - else: - stop = start - if isinstance(start, (float, int)): - start = tensorflow.convert_to_tensor(start) - if isinstance(stop, (float, int)): - stop = tensorflow.convert_to_tensor(stop) - if isinstance(step, (float, int)): - step = tensorflow.convert_to_tensor(step) - if dtype is None: - if isinstance(start, int) and isinstance(stop, int) and isinstance(step, int): - return tensorflow.cast( - tensorflow.range(start, stop, delta=step, dtype=tensorflow.int64), - tensorflow.int32, - ) - else: - return tensorflow.range(start, stop, delta=step) - else: - dtype = tensorflow_as_native_dtype(tensorflow_default_dtype_bknd(dtype=dtype)) - if dtype in [ - tensorflow.int8, - tensorflow.uint8, - tensorflow.int16, - tensorflow.uint16, - tensorflow.uint32, - tensorflow.uint64, - ]: - return tensorflow.cast( - tensorflow.range(start, stop, delta=step, dtype=tensorflow.int64), dtype - ) - else: - return tensorflow.range(start, stop, delta=step, dtype=dtype) - - -def tensorflow__parse_slice_bknd(idx, s): - step = 1 if idx.step is None else idx.step - if step > 0: - start = 0 if idx.start is None else idx.start - if start >= s: - stop = start - else: - if start <= -s: - start = 0 - elif start < 0: - start = start + s - stop = s if idx.stop is None else idx.stop - if stop > s: - stop = s - elif start <= -s: - stop = 0 - elif stop < 0: - stop = stop + s - else: - start = s - 1 if idx.start is None else idx.start - if start < 
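The arange wrapper above turns an inverted range into an empty one rather than erroring. In plain terms (normalize_range is an illustrative name):

def normalize_range(start, stop=None, step=1):
    if stop is None:
        start, stop = 0, start
    if (step > 0 and start > stop) or (step < 0 and start < stop):
        stop = start          # direction mismatch yields an empty range
    return start, stop, step

print(normalize_range(5))        # (0, 5, 1)
print(normalize_range(5, 2))     # (5, 5, 1), i.e. empty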
-s: - stop = start - else: - if start >= s: - start = s - 1 - elif start < 0: - start = start + s - if idx.stop is None: - stop = -1 - else: - stop = idx.stop - if stop > s: - stop = s - elif stop < -s: - stop = -1 - elif stop == -s: - stop = 0 - elif stop < 0: - stop = stop + s - q_i = tensorflow_arange(start, stop, step) - ag__result_list_0 = [] - for q in q_i: - if 0 <= q < s: - res = q - ag__result_list_0.append(res) - q_i = ag__result_list_0 - q_i = ( - tensorflow_asarray(q_i) - if len(q_i) or start == stop or idx.stop is not None - else tensorflow_arange(0, s, 1) - ) - return q_i - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_shape( - x: Union[tensorflow.Tensor, tensorflow.Variable], /, *, as_array: bool = False -): - if as_array: - return tensorflow_asarray( - tensorflow.shape(x), dtype=tensorflow_default_int_dtype_bknd() - ) - else: - return tuple(x.shape) - - -def tensorflow__deep_flatten_bknd(iterable): - def _flatten_gen(iterable): - for item in iterable: - if isinstance(item, list): - yield from _flatten_gen(item) - else: - yield item - - return list(_flatten_gen(iterable)) - - -def tensorflow__calculate_out_shape_bknd(axis, array_shape): - if type(axis) not in (tuple, list): - axis = (axis,) - out_dims = len(axis) + len(array_shape) - norm_axis = normalize_axis_tuple(axis, out_dims) - shape_iter = iter(array_shape) - ag__result_list_0 = [] - for current_ax in range(out_dims): - res = 1 if current_ax in norm_axis else next(shape_iter) - ag__result_list_0.append(res) - out_shape = ag__result_list_0 - return out_shape - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_expand_dims( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - copy: Optional[bool] = None, - axis: Union[int, Sequence[int]] = 0, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - try: - out_shape = tensorflow__calculate_out_shape_bknd(axis, tensorflow.shape(x)) - ret = tensorflow.reshape(x, shape=out_shape) - return ret - except (tensorflow.errors.InvalidArgumentError, np.AxisError) as error: - raise Exception(error) from error - - -def tensorflow_check_elem_in_list(elem, list, inverse=False, message=""): - if inverse and elem in list: - raise Exception( - message if message != "" else f"{elem} must not be one of {list}" - ) - elif not inverse and elem not in list: - raise Exception(message if message != "" else f"{elem} must be one of {list}") - - -def tensorflow__reshape_fortran_tf(x, shape): - if len(x.shape) > 0: - x = tensorflow.transpose(x) - return tensorflow.transpose(tensorflow.reshape(x, shape[::-1])) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_reshape( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - shape: Union[tf.TensorShape, Sequence[int]], - *, - copy: Optional[bool] = None, - order: str = "C", - allowzero: bool = True, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - tensorflow_check_elem_in_list(order, ["C", "F"]) - if not allowzero: - shape = [ - (new_s if con else old_s) - for new_s, con, old_s in zip( - shape, tensorflow.constant(shape) != 0, x.shape - ) - ] - if order == "F": - return tensorflow__reshape_fortran_tf(x, shape) - return tensorflow.reshape(x, shape) - - -def tensorflow_reshape_bknd_( - self: tensorflow.Tensor, - /, - shape: Union[tuple, tf.TensorShape, Sequence[int]], - *, - copy: Optional[bool] = None, - order: str = "C", - allowzero: bool = True, - out: Optional[tensorflow.Tensor] = None, -): - return tensorflow_reshape( - self, shape, 
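What tensorflow__parse_slice_bknd computes is essentially Python's own slice.indices(), the concrete (start, stop, step) for a given axis length, materialised as an index array:

s = slice(-3, None, 1)
print(s.indices(10))                  # (7, 10, 1)
print(list(range(*s.indices(10))))    # [7, 8, 9]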
copy=copy, allowzero=allowzero, out=out, order=order - ) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_meshgrid( - *arrays: Union[tensorflow.Tensor, tensorflow.Variable], - sparse: bool = False, - indexing: str = "xy", - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if not sparse: - return tensorflow.meshgrid(*arrays, indexing=indexing) - sd = (1,) * len(arrays) - ag__result_list_0 = [] - for i, a in enumerate(arrays): - res = tensorflow.reshape( - tensorflow.convert_to_tensor(a), sd[:i] + (-1,) + sd[i + 1 :] - ) - ag__result_list_0.append(res) - res = ag__result_list_0 - if indexing == "xy" and len(arrays) > 1: - res[0] = tensorflow.reshape(res[0], (1, -1) + sd[2:]) - res[1] = tensorflow.reshape(res[1], (-1, 1) + sd[2:]) - return res - - -@tensorflow_infer_dtype -@tensorflow_handle_array_like_without_promotion -def tensorflow_empty( - shape: Union[tf.TensorShape, Sequence[int]], - *, - dtype: tensorflow.DType, - device: Optional[str] = None, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - return tensorflow.experimental.numpy.empty(shape, dtype=tensorflow.float32) - - -def tensorflow__parse_query_bknd(query, x_shape, scatter=False): - query = (query,) if not isinstance(query, tuple) else query - ag__result_list_0 = [] - for q in query: - res = tensorflow_asarray(q) if isinstance(q, (tuple, list, int)) else q - ag__result_list_0.append(res) - query = ag__result_list_0 - ag__result_list_1 = [] - for i, q in enumerate(query): - if tensorflow_is_array_bknd(q): - res = i - ag__result_list_1.append(res) - non_slice_q_idxs = ag__result_list_1 - to_front = ( - len(non_slice_q_idxs) > 1 - and any(tensorflow_diff(non_slice_q_idxs) != 1) - and non_slice_q_idxs[-1] < len(x_shape) - ) - ag__result_list_2 = [] - for i, q in enumerate(query): - if q is None: - res = i - ag__result_list_2.append(res) - new_axes = ag__result_list_2 - ag__result_list_3 = [] - for q in query: - if q is not None: - res = q - ag__result_list_3.append(res) - query = ag__result_list_3 - query = [Ellipsis] if query == [] else query - ellipsis_inds = None - if any(q is Ellipsis for q in query): - query, ellipsis_inds = tensorflow__parse_ellipsis_bknd(query, len(x_shape)) - ag__result_list_4 = [] - for i, v in enumerate(query): - if tensorflow_is_array_bknd(v): - res = i - ag__result_list_4.append(res) - array_inds = ag__result_list_4 - if array_inds: - array_queries = tensorflow_broadcast_arrays( - *[v for i, v in enumerate(query) if i in array_inds] - ) - array_queries = [ - ( - tensorflow_nonzero(q, as_tuple=False)[0] - if tensorflow_is_bool_dtype_bknd(q) - else q - ) - for q in array_queries - ] - array_queries = [ - ( - tensorflow_astype_bknd_( - tensorflow_where( - arr < 0, arr + tensorflow_get_item(x_shape, i), arr - ), - tf.int64, - ) - if tensorflow_size_bknd_(arr) - else tensorflow_astype_bknd_(arr, tf.int64) - ) - for arr, i in zip(array_queries, array_inds) - ] - for idx, arr in zip(array_inds, array_queries): - query = tensorflow_set_item_bknd(query, idx, arr) - ag__result_list_5 = [] - for i, q in enumerate(query): - res = ( - tensorflow_astype_bknd_( - tensorflow__parse_slice_bknd(q, tensorflow_get_item(x_shape, i)), - tf.int64, - ) - if isinstance(q, slice) - else q - ) - ag__result_list_5.append(res) - query = ag__result_list_5 - if len(query) < len(x_shape): - query = query + [ - tensorflow_astype_bknd_(tensorflow_arange(0, s, 1), tf.int64) - for s in tensorflow_get_item(x_shape, slice(len(query), None, None)) - ] - if len(array_inds) 
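With sparse=True, tensorflow_meshgrid above returns broadcast-ready axis arrays instead of fully materialised grids, mirroring NumPy's behaviour:

import numpy as np

xs, ys = np.meshgrid(np.arange(3), np.arange(4), sparse=True, indexing="xy")
print(xs.shape, ys.shape)   # (1, 3) (4, 1), versus (4, 3) each when dense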
and to_front: - target_shape = ( - [list(array_queries[0].shape)] - + [ - list(tensorflow_get_item(query, i).shape) - for i in range(len(query)) - if i not in array_inds - ] - + [[] for _ in range(len(array_inds) - 1)] - ) - elif len(array_inds): - target_shape = ( - [list(tensorflow_get_item(query, i).shape) for i in range(0, array_inds[0])] - + [list(tensorflow_shape(array_queries[0], as_array=True))] - + [[] for _ in range(len(array_inds) - 1)] - + [ - list(tensorflow_shape(tensorflow_get_item(query, i), as_array=True)) - for i in range(array_inds[-1] + 1, len(query)) - ] - ) - else: - target_shape = [list(q.shape) for q in query] - if ellipsis_inds is not None: - target_shape = ( - tensorflow_get_item(target_shape, slice(None, ellipsis_inds[0], None)) - + [ - tensorflow_get_item( - target_shape, slice(ellipsis_inds[0], ellipsis_inds[1], None) - ) - ] - + tensorflow_get_item(target_shape, slice(ellipsis_inds[1], None, None)) - ) - for i, ax in enumerate(new_axes): - if len(array_inds) and to_front: - ax = ax - (sum(1 for x in array_inds if x < ax) - 1) - ax = ax + i - target_shape = [ - *tensorflow_get_item(target_shape, slice(None, ax, None)), - 1, - *tensorflow_get_item(target_shape, slice(ax, None, None)), - ] - target_shape = tensorflow__deep_flatten_bknd(target_shape) - ag__result_list_6 = [] - for q in query: - res = tensorflow_expand_dims(q) if not len(q.shape) else q - ag__result_list_6.append(res) - query = ag__result_list_6 - if len(array_inds): - array_queries = [ - ( - tensorflow_reshape_bknd_(arr, (-1,)) - if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr - ) - for arr in array_queries - ] - array_queries = tensorflow_stack(array_queries, axis=1) - if len(array_inds) == len(query): - indices = tensorflow_reshape_bknd_(array_queries, (*target_shape, len(x_shape))) - elif len(array_inds) == 0: - indices = tensorflow_reshape_bknd_( - tensorflow_stack(tensorflow_meshgrid(*query, indexing="ij"), axis=-1), - (*target_shape, len(x_shape)), - ) - elif to_front: - post_array_queries = ( - tensorflow_reshape_bknd_( - tensorflow_stack( - tensorflow_meshgrid( - *[v for i, v in enumerate(query) if i not in array_inds], - indexing="ij", - ), - axis=-1, - ), - (-1, len(query) - len(array_inds)), - ) - if len(array_inds) < len(query) - else tensorflow_empty((1, 0)) - ) - indices = tensorflow_reshape_bknd_( - tensorflow_asarray( - [ - (*arr, *post) - for arr, post in itertools.product( - array_queries, post_array_queries - ) - ] - ), - (*target_shape, len(x_shape)), - ) - else: - pre_array_queries = ( - tensorflow_reshape_bknd_( - tensorflow_stack( - tensorflow_meshgrid( - *[v for i, v in enumerate(query) if i < array_inds[0]], - indexing="ij", - ), - axis=-1, - ), - (-1, array_inds[0]), - ) - if array_inds[0] > 0 - else tensorflow_empty((1, 0)) - ) - post_array_queries = ( - tensorflow_reshape_bknd_( - tensorflow_stack( - tensorflow_meshgrid( - *[v for i, v in enumerate(query) if i > array_inds[-1]], - indexing="ij", - ), - axis=-1, - ), - (-1, len(query) - 1 - array_inds[-1]), - ) - if array_inds[-1] < len(query) - 1 - else tensorflow_empty((1, 0)) - ) - indices = tensorflow_reshape_bknd_( - tensorflow_asarray( - [ - (*pre, *arr, *post) - for pre, arr, post in itertools.product( - pre_array_queries, array_queries, post_array_queries - ) - ] - ), - (*target_shape, len(x_shape)), - ) - return ( - tensorflow_astype_bknd_(indices, tf.int64), - target_shape, - array_inds if len(array_inds) and to_front else None, - ) - - -def tensorflow_get_num_dims(x, /, 
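The heart of tensorflow__parse_query_bknd is turning per-axis index vectors into one (..., ndim) grid of coordinates that gather_nd/scatter_nd can consume. The same trick in NumPy:

import numpy as np

axes = [np.array([0, 2]), np.array([1, 3])]
grid = np.stack(np.meshgrid(*axes, indexing="ij"), axis=-1)
print(grid.shape)    # (2, 2, 2): a 2x2 grid of (row, col) pairs
print(grid[0, 1])    # [0 3]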
*, as_array=False): - return ( - tensorflow.cast(tensorflow.shape(tensorflow.shape(x))[0], tensorflow.int64) - if as_array - else int(tensorflow.shape(tensorflow.shape(x))) - ) - - -def tensorflow_to_numpy( - x: Union[tensorflow.Tensor, tensorflow.Variable], /, *, copy: bool = True -): - if ( - tensorflow_is_array_bknd(x) - and tensorflow_get_num_dims(x) == 0 - and tensorflow_as_native_dtype(x.dtype) is tensorflow.bfloat16 - ): - x = tensorflow.expand_dims(x, 0) - if copy: - return np.squeeze(np.array(tensorflow.convert_to_tensor(x)), 0) - else: - return np.squeeze(np.asarray(tensorflow.convert_to_tensor(x)), 0) - if copy: - return np.array(tensorflow.convert_to_tensor(x)) - else: - return np.asarray(tensorflow.convert_to_tensor(x)) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_to_scalar(x: Union[tensorflow.Tensor, tensorflow.Variable], /): - ret = tensorflow_to_numpy(x).item() - if x.dtype == tensorflow.bfloat16: - return float(ret) - return ret - - -def tensorflow_to_scalar_bknd_(self: tensorflow.Tensor): - return tensorflow_to_scalar(self) - - -def tensorflow_is_float_dtype_bknd( - dtype_in: Union[str, str, tensorflow.Tensor, tf.Tensor, Number], / -): - if tensorflow_is_array_bknd(dtype_in): - dtype_in = tensorflow_dtype(dtype_in) - elif isinstance(dtype_in, tuple): - dtype_in = tensorflow_default_int_dtype_bknd() - elif isinstance(dtype_in, np.ndarray): - return "float" in dtype_in.dtype.name - elif isinstance(dtype_in, Number): - return isinstance(dtype_in, (float, np.floating)) - elif isinstance(dtype_in, (list, tuple, dict)): - return bool( - tensorflow_nested_argwhere_bknd( - dtype_in, - lambda x: isinstance(x, (float, np.floating)) - or tensorflow_is_array_bknd(x) - and "float" in tensorflow_dtype(x), - ) - ) - return "float" in tensorflow_as_ivy_dtype_bknd(dtype_in) - - -def tensorflow_is_uint_dtype_bknd( - dtype_in: Union[str, str, tensorflow.Tensor, tf.Tensor, Number], / -): - if tensorflow_is_array_bknd(dtype_in): - dtype_in = tensorflow_dtype(dtype_in) - elif isinstance(dtype_in, tuple): - dtype_in = tensorflow_default_int_dtype_bknd() - elif isinstance(dtype_in, np.ndarray): - return "uint" in dtype_in.dtype.name - elif isinstance(dtype_in, Number): - return isinstance(dtype_in, np.unsignedinteger) - elif isinstance(dtype_in, (list, tuple, dict)): - return tensorflow_nested_argwhere_bknd( - dtype_in, - lambda x: isinstance(x, np.unsignedinteger) - or tensorflow_is_array_bknd(x) - and "uint" in tensorflow_dtype(x), - ) - return "uint" in tensorflow_as_ivy_dtype_bknd(dtype_in) - - -def tensorflow_default_uint_dtype_bknd( - *, - input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, - uint_dtype: Optional[Union[str, tf.DType]] = None, - as_native: bool = False, -): - global default_uint_dtype_stack - if tensorflow_exists_bknd(uint_dtype): - if as_native is True: - return tensorflow_as_native_dtype(uint_dtype) - return str(tensorflow_as_ivy_dtype(uint_dtype)) - as_native = tensorflow_default_bknd(as_native, False) - if tensorflow_exists_bknd(input): - if tensorflow_is_array_bknd(input): - ret = tensorflow_dtype(input) - elif isinstance(input, np.ndarray): - ret = input.dtype - elif isinstance(input, (list, tuple, dict)): - - def is_native(x): - return tensorflow_is_native_array(x) - - if tensorflow_nested_argwhere_bknd( - input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), - stop_after_n_found=1, - ): - ret = tf.uint64 - elif default_uint_dtype_stack: - ret = 
default_uint_dtype_stack[-1] - else: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_uint_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "uint32" - elif isinstance(input, Number): - if input > 4294967295 and input != math.inf and backend != "torch": - ret = tf.uint64 - elif default_uint_dtype_stack: - ret = default_uint_dtype_stack[-1] - else: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_uint_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "uint32" - elif default_uint_dtype_stack: - ret = default_uint_dtype_stack[-1] - else: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_uint_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "uint32" - if as_native: - return tensorflow_as_native_dtype(ret) - return str(tensorflow_as_ivy_dtype(ret)) - - -def tensorflow_is_int_dtype_bknd( - dtype_in: Union[str, str, tensorflow.Tensor, tf.Tensor, Number], / -): - if tensorflow_is_array_bknd(dtype_in): - dtype_in = tensorflow_dtype(dtype_in) - elif isinstance(dtype_in, tuple): - dtype_in = tensorflow_default_int_dtype_bknd() - elif isinstance(dtype_in, np.ndarray): - return "int" in dtype_in.dtype.name - elif isinstance(dtype_in, Number): - return isinstance(dtype_in, (int, np.integer)) and not isinstance( - dtype_in, bool - ) - elif isinstance(dtype_in, (list, tuple, dict)): - - def nested_fun(x): - return ( - isinstance(x, (int, np.integer)) - or tensorflow_is_array_bknd(x) - and "int" in tensorflow_dtype(x) - ) and x is not bool - - return bool(tensorflow_nested_argwhere_bknd(dtype_in, nested_fun)) - return "int" in tensorflow_as_ivy_dtype(dtype_in) - - -def tensorflow_infer_default_dtype_bknd( - dtype: Union[str, tf.DType, str], as_native: bool = False -): - if tensorflow_is_complex_dtype_bknd(dtype): - default_dtype = tensorflow_default_complex_dtype_bknd(as_native=as_native) - elif tensorflow_is_float_dtype_bknd(dtype): - default_dtype = tensorflow_default_float_dtype_bknd(as_native=as_native) - elif tensorflow_is_uint_dtype_bknd(dtype): - default_dtype = tensorflow_default_uint_dtype_bknd(as_native=as_native) - elif tensorflow_is_int_dtype_bknd(dtype): - default_dtype = tensorflow_default_int_dtype_bknd(as_native=as_native) - elif as_native: - default_dtype = tensorflow_as_native_dtype("bool") - else: - default_dtype = tensorflow_as_ivy_dtype("bool") - return default_dtype - - -def tensorflow_dtype_bits(dtype_in: Union[tensorflow.DType, str, np.dtype], /): - dtype_str = tensorflow_as_ivy_dtype(dtype_in) - if "bool" in dtype_str: - return 1 - return int( - dtype_str.replace("tf.", "") - .replace("uint", "") - .replace("int", "") - .replace("bfloat", "") - .replace("float", "") - .replace("complex", "") - ) - - -def tensorflow__infer_dtype(dtype: tensorflow.DType): - default_dtype = tensorflow_infer_default_dtype_bknd(dtype) - if tensorflow_dtype_bits(dtype) < tensorflow_dtype_bits(default_dtype): - return default_dtype - return dtype - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_prod( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - axis: Optional[Union[int, Sequence[int]]] = None, - dtype: Optional[tensorflow.DType] = None, - keepdims: bool = False, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - dtype = tensorflow_as_native_dtype(dtype) - if dtype is None: - dtype = tensorflow__infer_dtype(x.dtype) - axis = tuple(axis) if isinstance(axis, list) else axis - return tensorflow.experimental.numpy.prod( - x, axis=axis, dtype=dtype, keepdims=keepdims - ) - - -def 
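tensorflow_dtype_bits above just strips the family name and parses the remaining digits. The same effect in two lines (dtype_bits here is an illustrative reimplementation):

def dtype_bits(name):
    return 1 if "bool" in name else int("".join(c for c in name if c.isdigit()))

print(dtype_bits("float32"), dtype_bits("bfloat16"), dtype_bits("bool"))
# 32 16 1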
tensorflow__numel_bknd(shape): - shape = tuple(shape) - return tensorflow_to_scalar_bknd_(tensorflow_prod(shape)) if shape != () else 1 - - -def tensorflow_check_one_way_broadcastable(x1, x2): - if len(x1) > len(x2): - return False - for a, b in zip(x1[::-1], x2[::-1]): - if a in (1, b): - pass - else: - return False - return True - - -def tensorflow_check_shapes_broadcastable(var, data): - if not tensorflow_check_one_way_broadcastable(var, data): - raise Exception(f"Could not broadcast shape {data} to shape {var}.") - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_broadcast_to( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - shape: Union[tf.TensorShape, Sequence[int]], - *, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - tensorflow_check_shapes_broadcastable(x.shape, shape) - if tensorflow.rank(x) > len(shape): - return tensorflow.broadcast_to(tensorflow.reshape(x, -1), shape) - return tensorflow.broadcast_to(x, shape) - - -def tensorflow__broadcast_to_bknd(input, target_shape): - if tensorflow__numel_bknd(tuple(input.shape)) == tensorflow__numel_bknd( - tuple(target_shape) - ): - return tensorflow_reshape(input, target_shape) - else: - input = input if len(input.shape) else tensorflow_expand_dims(input, axis=0) - return tensorflow_broadcast_to(input, target_shape) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_any( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - axis: Optional[Union[int, Sequence[int]]] = None, - keepdims: bool = False, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if axis is None: - num_dims = len(x.shape) - axis = tuple(range(num_dims)) - elif isinstance(axis, list): - axis = tuple(axis) - try: - return tensorflow.reduce_any( - tensorflow.cast(x, tensorflow.bool), axis=axis, keepdims=keepdims - ) - except tensorflow.errors.InvalidArgumentError as e: - raise Exception(e) from e - - -def tensorflow__broadcast_inputs(x1, x2): - x1_, x2_ = x1, x2 - iterables = list, tuple, tuple - if not isinstance(x1_, iterables): - x1_, x2_ = x2, x1 - if not isinstance(x1_, iterables): - return [x1], [x2] - if not isinstance(x2_, iterables): - x1 = [x1] * len(x2) - return x1, x2 - - -def tensorflow_check_equal(x1, x2, inverse=False, message="", as_array=True): - def eq_fn(x1, x2): - return x1 == x2 if inverse else x1 != x2 - - def comp_fn(x1, x2): - return tensorflow_any(eq_fn(x1, x2)) - - if not as_array: - - def iter_comp_fn(x1_, x2_): - return any(eq_fn(x1, x2) for x1, x2 in zip(x1_, x2_)) - - def comp_fn(x1, x2): - return iter_comp_fn(*tensorflow__broadcast_inputs(x1, x2)) - - eq = comp_fn(x1, x2) - if inverse and eq: - raise Exception(f"{x1} must not be equal to {x2}" if message == "" else message) - elif not inverse and eq: - raise Exception(f"{x1} must be equal to {x2}" if message == "" else message) - - -def tensorflow_multiply( - x1: Union[float, tensorflow.Tensor, tensorflow.Variable], - x2: Union[float, tensorflow.Tensor, tensorflow.Variable], - /, - *, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - oirg_x1 = x1 - oirg_x2 = x2 - try: - dtype = ( - x1.dtype - if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() - ) - if not tensorflow_is_array_bknd(x1): - x1 = tensorflow_asarray(x1, dtype=dtype) - if not tensorflow_is_array_bknd(x2): - x2 = tensorflow_asarray(x2, dtype=dtype) - except: - x1 = oirg_x1 - x2 = oirg_x2 - return tensorflow.math.multiply(x1, x2) - - -def 
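tensorflow_check_one_way_broadcastable implements NumPy's right-aligned broadcasting rule, but directionally: every trailing dimension of x1 must be 1 or equal to the corresponding dimension of x2. Compactly:

def one_way_broadcastable(x1, x2):
    if len(x1) > len(x2):
        return False
    return all(a in (1, b) for a, b in zip(reversed(x1), reversed(x2)))

print(one_way_broadcastable((3, 1, 5), (2, 3, 4, 5)))  # True
print(one_way_broadcastable((2,), (3,)))               # False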
tensorflow_check_gather_nd_input_valid(params, indices, batch_dims): - if batch_dims >= len(params.shape): - raise Exception( - f"batch_dims = {batch_dims} must be less than rank(`params`) = {len(params.shape)}." - ) - if batch_dims >= len(indices.shape): - raise Exception( - f"batch_dims = {batch_dims} must be less than rank(`indices`) = {len(indices.shape)}." - ) - if tensorflow_get_item( - params.shape, slice(0, batch_dims, None) - ) != tensorflow_get_item(indices.shape, slice(0, batch_dims, None)): - raise Exception( - f"batch dimensions must match in `params` and `indices`; saw {tensorflow_get_item(params.shape, slice(0, batch_dims, None))} vs. {tensorflow_get_item(indices.shape, slice(0, batch_dims, None))}" - ) - if indices.shape[-1] > len( - tensorflow_get_item(params.shape, slice(batch_dims, None, None)) - ): - raise Exception( - f"index innermost dimension length must be <= rank(`params[batch_dims:]`); saw: {indices.shape[-1]} vs. {len(tensorflow_get_item(params.shape, slice(batch_dims, None, None)))} ." - ) - - -def tensorflow_gather_nd_helper(params, indices): - indices_shape = tensorflow.shape(indices) - params_shape = tensorflow.shape(params) - num_index_dims = indices_shape[-1] - result_dim_sizes_list = [ - tensorflow.math.reduce_prod(params_shape[i + 1 :]) - for i in range(len(params_shape) - 1) - ] + [1] - result_dim_sizes = tensorflow.convert_to_tensor( - result_dim_sizes_list, dtype=indices.dtype - ) - implicit_indices_factor = result_dim_sizes[num_index_dims - 1] - flat_params = tensorflow.reshape(params, (-1,)) - new_shape = [1] * (len(indices_shape) - 1) + [num_index_dims] - indices_scales = tensorflow.reshape(result_dim_sizes[0:num_index_dims], new_shape) - indices_for_flat_tiled = tensorflow.reshape( - tensorflow.reduce_sum(indices * indices_scales, -1, keepdims=True), (-1, 1) - ) - indices_for_flat_tiled = tensorflow.repeat( - indices_for_flat_tiled, implicit_indices_factor, axis=1 - ) - implicit_indices = tensorflow.repeat( - tensorflow.expand_dims(tensorflow.range(implicit_indices_factor), 0), - indices_for_flat_tiled.shape[0], - axis=0, - ) - indices_for_flat = indices_for_flat_tiled + implicit_indices - flat_indices_for_flat = tensorflow.reshape(indices_for_flat, (-1,)) - flat_gather = tensorflow.gather(flat_params, flat_indices_for_flat) - res = tensorflow.reshape( - flat_gather, - tensorflow.concat([indices_shape[:-1], params_shape[num_index_dims:]], 0), - ) - return res - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_gather_nd( - params: Union[tensorflow.Tensor, tensorflow.Variable], - indices: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - batch_dims: int = 0, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - tensorflow_check_gather_nd_input_valid(params, indices, batch_dims) - try: - return tensorflow.gather_nd(params, indices, batch_dims=batch_dims) - except Exception: - batch_dims %= len(params.shape) - result = [] - if batch_dims == 0: - result = tensorflow_gather_nd_helper(params, indices) - else: - for b in range(batch_dims): - if b == 0: - zip_list = list(zip(params, indices)) - else: - zip_list = [ - (p, i) - for z in [zip(p1, i1) for p1, i1 in zip_list] - for p, i in z - ] - for z in zip_list: - p, i = z[0], z[1] - r = tensorflow_gather_nd_helper(p, i) - result.append(r) - result = tensorflow.stack(result) - result = tensorflow.reshape( - result, - tensorflow.concat([params.shape[0:batch_dims], result.shape[1:]], 0), - ) - return result - - -def tensorflow__is_variable_bknd(x, 
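tensorflow_gather_nd_helper linearises N-d indices by dotting them with row-major strides, then gathers from the flattened tensor. The arithmetic in NumPy:

import numpy as np

params = np.arange(24).reshape(2, 3, 4)
indices = np.array([[1, 2], [0, 0]])     # each row indexes the first two dims
strides = np.array([12, 4])              # products of the trailing dim sizes
flat_base = indices @ strides            # [20, 0]
rows = params.reshape(-1)[flat_base[:, None] + np.arange(4)]
print(rows)   # the slices params[1, 2] and params[0, 0]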
exclusive=False, to_ignore=None): - x = x - return tensorflow_nested_map_bknd( - lambda x: tensorflow_is_variable(x, exclusive=exclusive), - x, - include_derived=True, - shallow=False, - to_ignore=to_ignore, - ) - - -def tensorflow_inplace_update( - x: Union[tensorflow.Tensor, tensorflow.Tensor], - val: Union[tensorflow.Tensor, tensorflow.Tensor], - /, - *, - ensure_in_backend: bool = False, - keep_input_dtype: bool = False, -): - if tensorflow_is_array_bknd(x) and tensorflow_is_array_bknd(val): - if keep_input_dtype: - val = tensorflow_astype(val, x.dtype) - (x_native, val_native), _ = (x, val), "_" - if tensorflow__is_variable_bknd(x_native): - x_native.assign(val_native) - if tensorflow_is_ivy_array_bknd(x): - x = x_native - else: - x = tensorflow.convert_to_tensor(x_native) - else: - x = x_native - return x - else: - return val - - -def tensorflow_scatter_nd( - indices: Union[tensorflow.Tensor, tensorflow.Variable], - updates: Union[tensorflow.Tensor, tensorflow.Variable], - /, - shape: Optional[Union[tf.TensorShape, Sequence[int]]] = None, - *, - reduction: str = "sum", - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - updates_dtype = updates.dtype - if tensorflow_exists_bknd(out): - dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) - updates = tensorflow.cast( - updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), - ) - expected_shape = ( - list(tensorflow.shape(indices)[:-1]) - + list(out.shape[tensorflow.shape(indices)[-1] :]) - if tensorflow_exists_bknd(out) - else list(tensorflow.shape(indices)[:-1]) - + list(shape[tensorflow.shape(indices)[-1] :]) - ) - updates = tensorflow__broadcast_to_bknd(updates, expected_shape) - if len(updates.shape) == 0: - indices = tensorflow.expand_dims(indices, 0) - updates = tensorflow.expand_dims(updates, 0) - target = out - target_given = tensorflow_exists_bknd(target) - if tensorflow_exists_bknd(shape) and target_given: - tensorflow_check_equal(tuple(target.shape), tuple(shape), as_array=False) - if not target_given: - shape = list(shape) if tensorflow_exists_bknd(shape) else list(out.shape) - target = tensorflow.zeros(shape, dtype=updates.dtype) - if reduction == "sum": - res = tensorflow.tensor_scatter_nd_add(target, indices, updates) - elif reduction == "min": - res = tensorflow.tensor_scatter_nd_min(target, indices, updates) - elif reduction == "max": - res = tensorflow.tensor_scatter_nd_max(target, indices, updates) - elif reduction == "mul": - updates = tensorflow_multiply(tensorflow_gather_nd(target, indices), updates) - res = tensorflow.tensor_scatter_nd_update(target, indices, updates) - elif reduction == "replace": - res = tensorflow.tensor_scatter_nd_update(target, indices, updates) - else: - raise Exception( - f'reduction is {reduction}, but it must be one of "sum", "min", "max", "mul" or "replace"' - ) - if tensorflow_exists_bknd(out): - return tensorflow_inplace_update(out, res) - return res - - -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - -@tensorflow_handle_set_item -def tensorflow_set_item_bknd( - x: Union[tensorflow.Tensor, tf.Tensor], - query: Union[tensorflow.Tensor, tf.Tensor, Tuple], - val: Union[tensorflow.Tensor, tf.Tensor], - /, - *, - copy: Optional[bool] = False, -): - if isinstance(query, (list, 
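tensorflow_scatter_nd dispatches the reduction string onto TensorFlow's tensor_scatter_nd_* ops; "mul" has no native op, so it gathers the current values, multiplies, and writes back. That emulation in isolation (assumes TensorFlow is importable):

import tensorflow as tf

target = tf.ones([4]) * 2.0
indices = tf.constant([[1], [3]])
updates = tf.constant([10.0, 100.0])
result = tf.tensor_scatter_nd_update(
    target, indices, tf.gather_nd(target, indices) * updates)
print(result.numpy())   # [  2.  20.   2. 200.]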
tuple)) and any( - [(q is Ellipsis or isinstance(q, slice) and q.stop is None) for q in query] - ): - x_stop_gradient = tensorflow_stop_gradient(x, preserve_type=False) - np_array = x_stop_gradient.numpy() - val_stop_gradient = tensorflow_stop_gradient(val, preserve_type=False) - np_array = tensorflow_set_item_bknd( - np_array, query, np.asarray(val_stop_gradient) - ) - return tensorflow_asarray(np_array) - if copy: - x = tensorflow_copy_array(x) - if not tensorflow_is_array_bknd(val): - val = tensorflow_asarray(val) - if 0 in x.shape or 0 in val.shape: - return x - if tensorflow_is_array_bknd(query) and tensorflow_is_bool_dtype_bknd(query): - if not len(query.shape): - query = tensorflow_tile(query, (x.shape[0],)) - indices = tensorflow_nonzero(query, as_tuple=False) - else: - indices, target_shape, _ = tensorflow__parse_query_bknd( - query, tensorflow_shape(x, as_array=True), scatter=True - ) - if indices is None: - return x - val = tensorflow_astype_bknd_(val, x.dtype) - ret = tensorflow_scatter_nd(indices, val, reduction="replace", out=x) - return ret - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_real( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - return tensorflow.math.real(x) - - -def tensorflow_real_bknd_(self): - return tensorflow_real(self) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_imag( - val: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - return tensorflow.math.imag(val, name=None) - - -def tensorflow_imag_bknd_(self): - return tensorflow_imag(self) - - -def tensorflow__check_complex128_bknd(input): - if tensorflow_is_array_bknd(input): - return tensorflow_dtype(input) == "complex128" - elif isinstance(input, np.ndarray): - return str(input.dtype) == "complex128" - if hasattr(input, "real") and hasattr(input, "imag"): - return tensorflow__check_float64_bknd( - tensorflow_real_bknd_(input) - ) and tensorflow__check_float64_bknd(tensorflow_imag_bknd_(input)) - return False - - -def tensorflow_default_complex_dtype_bknd( - *, - input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, - complex_dtype: Optional[Union[str, tf.DType]] = None, - as_native: bool = False, -): - global default_complex_dtype_stack - if tensorflow_exists_bknd(complex_dtype): - if as_native is True: - return tensorflow_as_native_dtype(complex_dtype) - return str(tensorflow_as_ivy_dtype(complex_dtype)) - as_native = tensorflow_default_bknd(as_native, False) - if tensorflow_exists_bknd(input): - if tensorflow_is_array_bknd(input): - ret = tensorflow_dtype(input) - elif isinstance(input, np.ndarray): - ret = str(input.dtype) - elif isinstance(input, (list, tuple, dict)): - if tensorflow_nested_argwhere_bknd( - input, - lambda x: tensorflow__check_complex128_bknd(x), - stop_after_n_found=1, - ): - ret = tf.complex128 - elif not default_complex_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_complex_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "complex64" - else: - ret = default_complex_dtype_stack[-1] - elif isinstance(input, Number): - if tensorflow__check_complex128_bknd(input): - ret = tf.complex128 - elif not default_complex_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_complex_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "complex64" - else: - ret = default_complex_dtype_stack[-1] - elif not 
default_complex_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_complex_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "complex64" - else: - ret = default_complex_dtype_stack[-1] - if as_native: - return tensorflow_as_native_dtype(ret) - return str(tensorflow_as_ivy_dtype(ret)) - - -def tensorflow_default_dtype_bknd( - *, - dtype: Optional[Union[str, str]] = None, - item: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, - as_native: bool = False, -): - if tensorflow_exists_bknd(dtype): - if as_native is True: - return tensorflow_as_native_dtype(dtype) - return tensorflow_as_ivy_dtype(dtype) - as_native = tensorflow_default_bknd(as_native, False) - if tensorflow_exists_bknd(item): - if hasattr(item, "override_dtype_check"): - return item.override_dtype_check() - elif isinstance(item, (list, tuple, dict)) and len(item) == 0: - pass - elif tensorflow_is_complex_dtype_bknd(item): - return tensorflow_default_complex_dtype_bknd( - input=item, as_native=as_native - ) - elif tensorflow_is_float_dtype_bknd(item): - return tensorflow_default_float_dtype_bknd(input=item, as_native=as_native) - elif tensorflow_is_uint_dtype_bknd(item): - return tensorflow_default_int_dtype_bknd(input=item, as_native=as_native) - elif tensorflow_is_int_dtype_bknd(item): - return tensorflow_default_int_dtype_bknd(input=item, as_native=as_native) - elif as_native: - return tensorflow_as_native_dtype("bool") - else: - return "bool" - global default_dtype_stack - if not default_dtype_stack: - global default_float_dtype_stack - if default_float_dtype_stack: - ret = default_float_dtype_stack[-1] - else: - ret = "float32" - else: - ret = default_dtype_stack[-1] - if as_native: - return tensorflow_as_native_dtype(ret) - return tensorflow_as_ivy_dtype(ret) - - -def tensorflow_default_float_dtype_bknd( - *, - input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, - float_dtype: Optional[Union[str, tf.DType]] = None, - as_native: bool = False, -): - global default_float_dtype_stack - if tensorflow_exists_bknd(float_dtype): - if as_native is True: - return tensorflow_as_native_dtype(float_dtype) - return str(tensorflow_as_ivy_dtype(float_dtype)) - as_native = tensorflow_default_bknd(as_native, False) - if tensorflow_exists_bknd(input): - if tensorflow_is_array_bknd(input): - ret = tensorflow_dtype(input) - elif isinstance(input, np.ndarray): - ret = str(input.dtype) - elif isinstance(input, (list, tuple, dict)): - if tensorflow_nested_argwhere_bknd( - input, lambda x: tensorflow__check_float64_bknd(x), stop_after_n_found=1 - ): - ret = tf.float64 - elif not default_float_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_float_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "float32" - else: - ret = default_float_dtype_stack[-1] - elif isinstance(input, Number): - if tensorflow__check_float64_bknd(input): - ret = tf.float64 - elif not default_float_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_float_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "float32" - else: - ret = default_float_dtype_stack[-1] - elif not default_float_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_float_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "float32" - else: - ret = default_float_dtype_stack[-1] - if as_native: - return tensorflow_as_native_dtype(ret) - return str(tensorflow_as_ivy_dtype(ret)) - - -def tensorflow_as_ivy_dtype( - dtype_in: Union[tensorflow.DType, str, int, float, 
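# --- editorial aside: a simplified sketch (not part of the diff) ---
# tensorflow_default_dtype_bknd above resolves an item's dtype in a fixed
# priority order: complex -> float -> uint -> int -> bool, falling back to the
# dtype stacks and finally "float32". A plain-Python analogue (note bool must
# be tested before int here, since Python's bool subclasses int):
def default_dtype_for(item):  # hypothetical simplification
    if isinstance(item, complex):
        return "complex64"
    if isinstance(item, float):
        return "float32"
    if isinstance(item, bool):
        return "bool"
    if isinstance(item, int):
        return "int32"
    return "float32"  # global fallback when no stacks are set

assert default_dtype_for(1.5) == "float32"
assert default_dtype_for(7) == "int32"
assert default_dtype_for(True) == "bool"
# --- end aside ---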
complex, bool, np.dtype], / -): - if dtype_in is int: - return tensorflow_default_int_dtype_bknd() - if dtype_in is float: - return tensorflow_default_float_dtype_bknd() - if dtype_in is complex: - return tensorflow_default_complex_dtype_bknd() - if dtype_in is bool: - return str("bool") - if isinstance(dtype_in, np.dtype): - dtype_in = dtype_in.name - if isinstance(dtype_in, str): - if dtype_in in native_dtype_dict: - dtype_str = dtype_in - else: - raise Exception( - f"Cannot convert to ivy dtype. {dtype_in} is not supported by TensorFlow backend." - ) - else: - dtype_str = ivy_dtype_dict[dtype_in] - if "uint" in dtype_str: - return str(dtype_str) - elif "int" in dtype_str: - return str(dtype_str) - elif "float" in dtype_str: - return str(dtype_str) - elif "complex" in dtype_str: - return str(dtype_str) - elif "bool" in dtype_str: - return str("bool") - else: - raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") - - -def tensorflow_default_int_dtype_bknd( - *, - input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, - int_dtype: Optional[Union[str, tf.DType]] = None, - as_native: bool = False, -): - global default_int_dtype_stack - if tensorflow_exists_bknd(int_dtype): - if as_native is True: - return tensorflow_as_native_dtype(int_dtype) - return str(tensorflow_as_ivy_dtype(int_dtype)) - as_native = tensorflow_default_bknd(as_native, False) - if tensorflow_exists_bknd(input): - if tensorflow_is_array_bknd(input): - ret = tensorflow_dtype(input) - elif isinstance(input, tuple): - ret = tensorflow_default_int_dtype_bknd() - elif isinstance(input, np.ndarray): - ret = str(input.dtype) - elif isinstance(input, (list, tuple, dict)): - if tensorflow_nested_argwhere_bknd( - input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), - stop_after_n_found=1, - ): - ret = tf.uint64 - elif tensorflow_nested_argwhere_bknd( - input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), - stop_after_n_found=1, - ): - ret = tf.int64 - elif not default_int_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_int_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "int32" - else: - ret = default_int_dtype_stack[-1] - elif isinstance(input, Number): - if input > 9223372036854775807 and input != math.inf and backend != "torch": - ret = tf.uint64 - elif input > 2147483647 and input != math.inf: - ret = tf.int64 - elif not default_int_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_int_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "int32" - else: - ret = default_int_dtype_stack[-1] - elif not default_int_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_int_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "int32" - else: - ret = default_int_dtype_stack[-1] - if as_native: - return tensorflow_as_native_dtype(ret) - return str(tensorflow_as_ivy_dtype(ret)) - - -def tensorflow_as_native_dtype( - dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], -): - if dtype_in is int: - return tensorflow_default_int_dtype_bknd(as_native=True) - if dtype_in is float: - return tensorflow_default_float_dtype_bknd(as_native=True) - if dtype_in is complex: - return tensorflow_default_complex_dtype_bknd(as_native=True) - if dtype_in is bool: - return tensorflow.bool - if isinstance(dtype_in, np.dtype): - dtype_in = dtype_in.name - if not 
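# --- editorial aside: a minimal sketch (not part of the diff) ---
# tensorflow_default_int_dtype_bknd above widens by magnitude: values past
# int32's range select int64, and past int64's range select uint64.
def pick_int_dtype(n):  # hypothetical helper mirroring the thresholds above
    if n > 9223372036854775807:   # 2**63 - 1, int64 max -> uint64
        return "uint64"
    if n > 2147483647:            # 2**31 - 1, int32 max -> int64
        return "int64"
    return "int32"

assert pick_int_dtype(5) == "int32"
assert pick_int_dtype(2**40) == "int64"
assert pick_int_dtype(2**63) == "uint64"
# --- end aside ---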
isinstance(dtype_in, str): - return dtype_in - if dtype_in in native_dtype_dict: - return native_dtype_dict[str(dtype_in)] - else: - raise Exception( - f"Cannot convert to TensorFlow dtype. {dtype_in} is not supported by TensorFlow." - ) - - -def tensorflow_dtype( - x: Union[tensorflow.Tensor, tensorflow.Variable, np.ndarray], - *, - as_native: bool = False, -): - if as_native: - return tensorflow_as_native_dtype(x.dtype) - return tensorflow_as_ivy_dtype(x.dtype) - - -def tensorflow_is_bool_dtype_bknd( - dtype_in: Union[str, str, tensorflow.Tensor, tf.Tensor, Number], / -): - if tensorflow_is_array_bknd(dtype_in): - dtype_in = tensorflow_dtype(dtype_in) - elif isinstance(dtype_in, np.ndarray): - return "bool" in dtype_in.dtype.name - elif isinstance(dtype_in, Number): - return isinstance(dtype_in, (bool, np.bool_)) and not isinstance(dtype_in, bool) - elif isinstance(dtype_in, (list, tuple, dict)): - return bool( - tensorflow_nested_argwhere_bknd( - dtype_in, lambda x: isinstance(x, (bool, np.bool_)) and x is not int - ) - ) - return "bool" in tensorflow_as_ivy_dtype(dtype_in) - - -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - -@tensorflow_handle_get_item -def tensorflow_get_item( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - query: Union[tensorflow.Tensor, tensorflow.Variable, Tuple], - *, - copy: Optional[bool] = None, -): - if ( - tensorflow_is_array_bknd(query) - and tensorflow_is_bool_dtype_bknd(query) - and not len(query.shape) - ): - return tensorflow.expand_dims(x, 0) - return x[query] - - -def tensorflow_index_nest_bknd( - nest: Union[List, Tuple, Dict, tensorflow.Tensor, tf.Tensor, dict], - index: Union[List[int], Tuple[int], Iterable[int]], - /, -): - ret = nest - for i in index: - ret = tensorflow_get_item(ret, i) - return ret - - -def tensorflow__get_first_array(*args, **kwargs): - def array_fn(x): - return ( - tensorflow_is_array_bknd(x) - if not hasattr(x, "_ivy_array") - else tensorflow_is_array_bknd(x.ivy_array) - ) - - array_fn = array_fn if "array_fn" not in kwargs else kwargs["array_fn"] - arr = None - if args: - arr_idxs = tensorflow_nested_argwhere_bknd(args, array_fn, stop_after_n_found=1) - if arr_idxs: - arr = tensorflow_index_nest_bknd(args, arr_idxs[0]) - else: - arr_idxs = tensorflow_nested_argwhere_bknd( - kwargs, array_fn, stop_after_n_found=1 - ) - if arr_idxs: - arr = tensorflow_index_nest_bknd(kwargs, arr_idxs[0]) - elif kwargs: - arr_idxs = tensorflow_nested_argwhere_bknd( - kwargs, array_fn, stop_after_n_found=1 - ) - if arr_idxs: - arr = tensorflow_index_nest_bknd(kwargs, arr_idxs[0]) - return arr diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_full_like_output/run_0/tensorflow__stateful.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_full_like_output/run_0/tensorflow__stateful.py deleted file mode 100644 index dbad1e919ab1..000000000000 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_full_like_output/run_0/tensorflow__stateful.py +++ /dev/null @@ -1,1799 +0,0 @@ -# global -from __future__ import annotations -import re -import os -import tensorflow as tf -import functools -from tensorflow.python.util import nest -from typing import NamedTuple, Callable, Any, Tuple, List, Dict, Type, Union -import inspect -from collections import OrderedDict -from packaging.version import parse -import keras - - -def 
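# --- editorial aside: standalone demonstration (not part of the diff) ---
# tensorflow_handle_get_item / tensorflow_handle_set_item above share one
# pattern: try the object's native __getitem__/__setitem__ first, re-raise
# IndexError (a genuine out-of-range error), and fall back to the wrapped
# backend implementation for anything else. The same pattern, minimally:
import functools

def handle_get_item(fn):
    @functools.wraps(fn)
    def wrapper(inp, query, **kwargs):
        try:
            return inp[query]                # native path first
        except IndexError:
            raise                            # real indexing errors propagate
        except Exception:
            return fn(inp, query, **kwargs)  # everything else -> fallback
    return wrapper

@handle_get_item
def get_item(inp, query):
    return list(inp)[query]                  # toy fallback: index the keys

assert get_item([10, 20], 1) == 20           # served natively
assert get_item({"a": 1, "b": 2}, 0) == "a"  # dict[0] -> KeyError -> fallback
# --- end aside ---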
get_assignment_dict(): - # Traverse the call stack - lhs = None - for frame_info in inspect.stack(): - # Check if the code context is an assignment statement - if frame_info.code_context and "=" in frame_info.code_context[0]: - # Split the assignment and retrieve the LHS - lhs = frame_info.code_context[0].split("=")[0].strip() - if "self" not in lhs: - continue - break - - if not lhs: - return None, "" - - # Replace indexing with attribute access - lhs = re.sub(r"\[(\d+)\]", r".\1", lhs) - - # Split the LHS based on "." and get individual components - components = lhs.split(".") - - # Initialize the dictionary - assignment_dict = {} - - # Retrieve the live objects associated with each component - for i in range(len(components)): - # Construct the key - key = ".".join(components[: i + 1]) - - # Retrieve the value - if i == 0: - value = frame_info.frame.f_locals.get(components[i]) - else: - value = getattr(assignment_dict[".".join(components[:i])], components[i]) - - # Add the key-value pair to the dictionary - assignment_dict[key] = value - - return assignment_dict, lhs - - -def store_frame_info(fn): - @functools.wraps(fn) - def frame_info_wrapper(self, *args, **kwargs): - if self._previous_frame_info is None: - # store the info about the calling frame. - stack = inspect.stack() - self._previous_frame_info = stack[1] - res = fn(self, *args, **kwargs) - # reset the frame-info - self._previous_frame_info = None - return res - - return frame_info_wrapper - - -# A NodeDef holds two callables: -# - flatten_fn should take the collection and return a flat list of values. -# It can also return some context that is used in reconstructing the -# collection. -# - unflatten_fn should take a flat list of values and some context -# (returned by flatten_fn). It returns the collection by reconstructing -# it from the list and the context. -Context = Any -PyTree = Any -FlattenFunc = Callable[[PyTree], Tuple[List, Context]] -UnflattenFunc = Callable[[List, Context], PyTree] - - -class NodeDef(NamedTuple): - flatten_fn: FlattenFunc - unflatten_fn: UnflattenFunc - - -SUPPORTED_NODES: Dict[Type[Any], NodeDef] = {} - - -def _register_pytree_node( - typ: Any, flatten_fn: FlattenFunc, unflatten_fn: UnflattenFunc -) -> None: - SUPPORTED_NODES[typ] = NodeDef(flatten_fn, unflatten_fn) - - -def _dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]: - return list(d.values()), list(d.keys()) - - -def _dict_unflatten(values: List[Any], context: Context) -> Dict[Any, Any]: - return {key: value for key, value in zip(context, values)} - - -_register_pytree_node(dict, _dict_flatten, _dict_unflatten) - -if parse(keras.__version__).major > 2: - _register_pytree_node( - keras.src.utils.tracking.TrackedDict, _dict_flatten, _dict_unflatten - ) - - -def _get_node_type(pytree: Any) -> Any: - return type(pytree) - - -# A leaf is defined as anything that is not a Node. -def _is_leaf(pytree: PyTree) -> bool: - return _get_node_type(pytree) not in SUPPORTED_NODES.keys() - - -# A TreeSpec represents the structure of a pytree. 
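# --- editorial aside: a tiny round-trip check (not part of the diff) ---
# _dict_flatten and _dict_unflatten above are exact inverses: flatten yields
# (values, keys) and unflatten zips them back together.
d = {"w": 1, "b": 2}
values, context = list(d.values()), list(d.keys())        # _dict_flatten
assert {k: v for k, v in zip(context, values)} == d       # _dict_unflatten
# --- end aside ---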
It holds: -# "type": the type of root Node of the pytree -# context: some context that is useful in unflattening the pytree -# children_specs: specs for each child of the root Node -# num_leaves: the number of leaves -class TreeSpec: - def __init__(self, type, context, children_specs): - self.type: Any = type - self.context: Context = context - self.children_specs: List["TreeSpec"] = children_specs - self.num_leaves: int = sum([spec.num_leaves for spec in self.children_specs]) - - def get_keychains(self, prefix="", sep="/"): - keychains = [] - for key, child_spec in zip(self.context, self.children_specs): - new_prefix = prefix + key + sep if prefix else key + sep - if child_spec.children_specs: # Non-leaf node - keychains.extend(child_spec.get_keychains(new_prefix, sep)) - else: # Leaf node - keychains.append(new_prefix[: -len(sep)]) - return keychains - - def __repr__(self, indent: int = 0) -> str: - repr_prefix: str = f"TreeSpec({self.type.__name__}, {self.context}, [" - children_specs_str: str = "" - if len(self.children_specs): - indent += len(repr_prefix) - children_specs_str += self.children_specs[0].__repr__(indent) - children_specs_str += "," if len(self.children_specs) > 1 else "" - children_specs_str += ",".join( - [ - "\n" + " " * indent + child.__repr__(indent) - for child in self.children_specs[1:] - ] - ) - repr_suffix: str = f"{children_specs_str}])" - return repr_prefix + repr_suffix - - -class LeafSpec(TreeSpec): - def __init__(self) -> None: - super().__init__(None, None, []) - self.num_leaves = 1 - - def __repr__(self, indent: int = 0) -> str: - return "*" - - -def tree_flatten(pytree: PyTree) -> Tuple[List[Any], TreeSpec]: - """Flattens a pytree into a list of values and a TreeSpec that can be used - to reconstruct the pytree.""" - if _is_leaf(pytree): - return [pytree], LeafSpec() - - node_type = _get_node_type(pytree) - flatten_fn = _dict_flatten - child_pytrees, context = flatten_fn(pytree) - - # Recursively flatten the children - result: List[Any] = [] - children_specs: List["TreeSpec"] = [] - for child in child_pytrees: - flat, child_spec = tree_flatten(child) - result += flat - children_specs.append(child_spec) - - return result, TreeSpec(node_type, context, children_specs) - - -def tree_unflatten(values: List[Any], spec: TreeSpec) -> PyTree: - """Given a list of values and a TreeSpec, builds a pytree. - - This is the inverse operation of `tree_flatten`. - """ - if not isinstance(spec, TreeSpec): - raise TypeError( - f"tree_unflatten(values, spec): Expected `spec` to be instance of " - f"TreeSpec but got item of type {type(spec)}." - ) - if len(values) != spec.num_leaves: - raise TypeError( - f"tree_unflatten(values, spec): `values` has length {len(values)} " - f"but the spec refers to a pytree that holds {spec.num_leaves} " - f"items ({spec})." 
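# --- editorial aside: usage sketch (not part of the diff) ---
# Assuming the tree_flatten / tree_unflatten definitions above are in scope,
# nested dicts flatten depth-first into a leaf list plus a TreeSpec that
# records the keys at each level:
leaves, spec = tree_flatten({"a": {"x": 1, "y": 2}, "b": 3})
assert leaves == [1, 2, 3]
assert spec.num_leaves == 3
assert spec.get_keychains() == ["a/x", "a/y", "b"]
assert tree_unflatten(leaves, spec) == {"a": {"x": 1, "y": 2}, "b": 3}
# --- end aside ---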
- ) - if isinstance(spec, LeafSpec): - return values[0] - - unflatten_fn = _dict_unflatten - - # Recursively unflatten the children - start = 0 - end = 0 - child_pytrees = [] - for child_spec in spec.children_specs: - end += child_spec.num_leaves - child_pytrees.append(tree_unflatten(values[start:end], child_spec)) - start = end - - return unflatten_fn(child_pytrees, spec.context) - - -def serialize_obj(obj): - if inspect.isclass(obj) or isinstance(obj, type): - return {"cls_module": obj.__module__, "cls_name": obj.__name__} - return obj - - -def recursive_serialize(d): - if isinstance(d, dict): - return {k: recursive_serialize(v) for k, v in d.items()} - elif isinstance(d, list): - return [recursive_serialize(v) for v in d] - elif isinstance(d, tuple): - return tuple(recursive_serialize(v) for v in d) - else: - return serialize_obj(d) - - -def deserialize_obj(serialized): - if ( - isinstance(serialized, dict) - and "cls_module" in serialized - and "cls_name" in serialized - ): - module = __import__(serialized["cls_module"], fromlist=[serialized["cls_name"]]) - cls = getattr(module, serialized["cls_name"]) - return cls - return serialized - - -def recursive_deserialize(d): - if isinstance(d, dict) and "cls_module" not in d: - return {k: recursive_deserialize(v) for k, v in d.items()} - elif isinstance(d, list): - return [recursive_deserialize(v) for v in d] - elif isinstance(d, tuple): - return tuple(recursive_serialize(v) for v in d) - else: - return deserialize_obj(d) - - -class ModelHelpers: - @staticmethod - @tf.autograph.experimental.do_not_convert - def _get_first_array(*args, **kwargs): - arr = None - flattened_args = tf.nest.flatten((args, kwargs)) - arr_candidates = tf.nest.map_structure( - lambda x: x if isinstance(x, (tf.Tensor, tf.Variable)) else False, - flattened_args, - ) - for arr_candidate in arr_candidates: - if arr_candidate is not False: - arr = arr_candidate - break - return arr - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _get_input_shapes(*args): - input_shapes = [] - for x in args: - if isinstance(x, (tf.Tensor, tf.Variable)): - input_shapes.append(x.shape) - else: - try: - x = tf.convert_to_tensor(x) - input_shapes.append(x.shape) - except Exception: - input_shapes.append(None) - return input_shapes - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _extract_v(v, keychain_mappings: dict, orig_key_chain, /): - if ModelHelpers._dict_has_key_chain(v, orig_key_chain): - ret_cont = ModelHelpers._dict_at_key_chain(v, orig_key_chain) - else: - ret_cont = dict() - for old_kc, new_kc in keychain_mappings.items(): - if orig_key_chain in old_kc: - # Check if `v` contains `new_kc` before replacing in `ret_cont` - if ModelHelpers._dict_has_key_chain(v, new_kc): - ret_cont = ModelHelpers._dict_set_at_key_chain( - ret_cont, - "/".join(old_kc.split("/")[1:]), - ModelHelpers._dict_at_key_chain(v, new_kc), - ) - else: - continue - return ret_cont - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _remove_duplicate_variables(vs, created, /): - created_ids = tf.nest.map_structure(lambda x: id(x), created) - vs_ids = tf.nest.map_structure(lambda x: id(x), vs) - ids = {} - duplicate_keychains = [] - keychain_mappings = {} - - def unique_callback(x, kc): - ids[x] = kc - return x - - def found_dup_callback(x, kc): - if ids[x] == kc: - return x - duplicate_keychains.append(kc) - keychain_mappings[kc] = ids[x] - return x - - created_ids = nest.map_structure_with_paths( - lambda kc, x: unique_callback(x, kc), created_ids - ) - vs_ids = 
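# --- editorial aside: a minimal round trip (not part of the diff) ---
# serialize_obj above encodes a class as an importable reference, and
# deserialize_obj re-imports it; e.g. for the builtin dict:
import inspect

def serialize_obj(obj):
    if inspect.isclass(obj) or isinstance(obj, type):
        return {"cls_module": obj.__module__, "cls_name": obj.__name__}
    return obj

s = serialize_obj(dict)         # {'cls_module': 'builtins', 'cls_name': 'dict'}
module = __import__(s["cls_module"], fromlist=[s["cls_name"]])
assert getattr(module, s["cls_name"]) is dict   # deserialization recovers it
# --- end aside ---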
nest.map_structure_with_paths( - lambda kc, x: ( - unique_callback(x, kc) if x not in ids else found_dup_callback(x, kc) - ), - vs_ids, - ) - for dup_kc in duplicate_keychains: - vs = ModelHelpers._dict_prune_key_chain(vs, dup_kc) - return vs, keychain_mappings - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _dict_set_at_key_chain(in_dict, key_chain, val, inplace=False): - keys = re.split("[/.]", key_chain) - if inplace: - cont = in_dict - else: - cont = in_dict - sub_cont = cont - for key in keys[:-1]: - if key not in sub_cont: - sub_cont[key] = dict() - sub_cont = sub_cont[key] - sub_cont[keys[-1]] = val - return cont - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _dict_at_key_chain(dict, key_chain, ignore_key_errors=False): - keys = re.split("[/.]", key_chain) - ret = dict - for key in keys: - try: - ret = ret[key] - except KeyError as e: - if ignore_key_errors: - return - raise Exception(repr(e)) - return ret - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _dict_has_key_chain(dict, key_chain): - keys = re.split("[/.]", key_chain) - ret = dict - for key in keys: - try: - ret = ret[key] - except KeyError: - return False - return True - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _dict_prune_key_chain(in_dict, key_chain): - keys_in_chain = re.split("[/.]", key_chain) - out_dict = {} - for key, value in in_dict.items(): - if isinstance(value, dict): - if key == keys_in_chain[0]: - if len(keys_in_chain) == 1: - new_val = [] - else: - new_val = ModelHelpers._dict_prune_key_chain( - value, - "/".join(keys_in_chain[1:]), - ) - if len(new_val) > 0: - out_dict[key] = new_val - else: - if len(value) > 0: - out_dict[key] = value - else: - if len(keys_in_chain) != 1 or key != keys_in_chain[0]: - out_dict[key] = value - return out_dict - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _addindent(s_, numSpaces): - s = s_.split("\n") - # don't do anything for single-line stuff - if len(s) == 1: - return s_ - first = s.pop(0) - s = [(numSpaces * " ") + line for line in s] - s = "\n".join(s) - s = first + "\n" + s - return s - - -class Layer(tf.keras.layers.Layer, ModelHelpers): - _build_mode = None - _with_partial_v = None - _store_vars = True - _built = False - _v = None - _buffers = None - _module_dict = None - _args = None - _kwargs = None - _module_graph = None - _target = None - _lazy_traced = False - _training = None - _dynamic_backend = None - _device = None - _dtype = None - _previous_frame_info = None - - def __init__( - self, - /, - *args, - v=None, - buffers=None, - build_mode="on_init", - store_vars=True, - with_partial_v=False, - dynamic_backend=None, - training=True, - dtype=None, - device=None, - module_dict=None, - **kwargs, - ): - super(Layer, self).__init__( - trainable=training, - dtype=dtype, - ) - self._build_mode = build_mode - self._with_partial_v = with_partial_v - self._store_vars = store_vars - self._built = False - self._v_from_constructor = v if isinstance(v, dict) or v is None else dict(v) - self._v = v if v is not None else dict() - self._buffers = dict(buffers or {}) - self._module_dict = module_dict if module_dict is not None else dict() - self._args = args - self._kwargs = kwargs - self._module_graph = None - self._target = None - self._lazy_traced = False - self._training = training - self._dynamic_backend = dynamic_backend - self._device = device or "cpu" - self._dtype = dtype or tf.float32 - if build_mode != "on_init": - return - self.build(*args, 
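# --- editorial aside: a minimal sketch (not part of the diff) ---
# The _dict_*_key_chain helpers above split chains on "/" or "." and walk
# nested dicts. The lookup logic, standalone:
import re

def dict_at_key_chain(d, key_chain):
    for key in re.split("[/.]", key_chain):
        d = d[key]
    return d

cfg = {"conv": {"weight": 1}}
assert dict_at_key_chain(cfg, "conv/weight") == 1
assert dict_at_key_chain(cfg, "conv.weight") == 1   # either separator works
# --- end aside ---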
dynamic_backend=dynamic_backend, **kwargs) - - @tf.autograph.experimental.do_not_convert - def _find_variables( - self, - /, - *, - obj=None, - without_initialisation=False, - _visited=None, - trainable=True, - ): - _visited = _visited or {} - vs = dict() - if id(obj) in _visited: - return vs - _visited[id(obj)] = True - if isinstance(obj, Layer) and obj is not self: - fn = "_build_and_return_v" if trainable else "_build_and_return_buffers" - if not obj._built and without_initialisation: - return lambda: getattr(obj, fn)( - *obj._args, dynamic_backend=self._dynamic_backend, **obj._kwargs - ) - - return getattr(obj, fn)( - *obj._args, dynamic_backend=obj._dynamic_backend, **obj._kwargs - ) - elif isinstance(obj, tf.keras.layers.Layer) and obj is not self: - return obj.v if trainable else obj.buffers - - elif isinstance(obj, (list, tuple)): - for i, v in enumerate(obj): - ret = self._find_variables( - obj=v, - without_initialisation=without_initialisation, - _visited=_visited, - trainable=trainable, - ) - if ret: - vs[f"v{str(i)}"] = ret - return vs - elif isinstance(obj, dict): - for k, v in obj.items(): - ret = self._find_variables( - obj=v, - without_initialisation=without_initialisation, - _visited=_visited, - trainable=trainable, - ) - if ret: - vs[k[1:] if k[0] == "_" else k] = ret - return vs - elif not hasattr(obj, "__dict__"): - return vs - for k, v in obj.__dict__.items(): - if ( - v is not None - and k[0:2] != "__" - and not k.startswith( - ( - "_module_dict", - "_self_", - "_args", - "_kwargs", - ) - ) - ): - ret = self._find_variables( - obj=v, - without_initialisation=without_initialisation, - _visited=_visited, - trainable=trainable, - ) - if ret: - vs[k[1:] if k[0] == "_" else k] = ret - return vs - - @tf.autograph.experimental.do_not_convert - def _find_buffers(self): - if hasattr(self, "_module_dict"): - for key, sub_module in self._module_dict.items(): - if len(sub_module._buffers) > 0: - self._buffers[key] = sub_module._buffers - - @tf.autograph.experimental.do_not_convert - def _build_and_return_v(self, *args, **kwargs): - if not self._built: - self.build(*args, **kwargs) - return self.v - - @tf.autograph.experimental.do_not_convert - def _build_and_return_buffers(self, *args, **kwargs): - if not self._built: - self.build(*args, **kwargs) - return self.buffers - - @tf.autograph.experimental.do_not_convert - def _wrap_call_methods( - self, keychain_mappings, /, *, key="", obj=None, _visited=None - ): - _visited = _visited or {} - if id(obj) in _visited or not isinstance(key, str): - return - _visited[id(obj)] = True - if isinstance(obj, Model) and obj is not self: - orig_key_chain = key[1:] if key[0] == "_" else key - - obj.__call__ = self._fn_with_var_arg( - obj.__call__, self._extract_v, keychain_mappings, orig_key_chain - ) - return - elif isinstance(obj, (list, tuple)): - for i, val in enumerate(obj): - self._wrap_call_methods( - keychain_mappings, - key=f"{key}/v{str(i)}", - obj=val, - _visited=_visited, - ) - return - elif isinstance(obj, dict): - for k, val in obj.items(): - k = f"{key}/{k}" if key != "" and isinstance(k, str) else k - self._wrap_call_methods( - keychain_mappings, key=k, obj=val, _visited=_visited - ) - return - for k, val in obj.module_dict.items(): - if k.startswith(("__", "_self_")): - continue - k = f"{key}/{k}" if key != "" else k - if val is not None: - self._wrap_call_methods( - keychain_mappings, key=k, obj=val, _visited=_visited - ) - return - - @tf.autograph.experimental.do_not_convert - def _compute_module_dict(self): - self._module_dict 
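# --- editorial aside: the traversal skeleton (not part of the diff) ---
# _find_variables above recursively walks lists, dicts and object __dict__s,
# using an id()-keyed `_visited` map as a cycle guard and nesting results under
# "v{i}" keys for sequences. A stripped-down analogue of that skeleton, where
# plain values stand in for the variables it would collect:
def find_leaves(obj, _visited=None):
    _visited = _visited or {}
    if id(obj) in _visited:
        return {}                 # cycle guard: each object is visited once
    _visited[id(obj)] = True
    if isinstance(obj, (list, tuple)):
        return {f"v{i}": find_leaves(v, _visited) or v for i, v in enumerate(obj)}
    if isinstance(obj, dict):
        return {k: find_leaves(v, _visited) or v for k, v in obj.items()}
    return {}

print(find_leaves({"layers": [1, 2]}))  # {'layers': {'v0': 1, 'v1': 2}}
# --- end aside ---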
= dict() - for key, value in self.__dict__.items(): - if isinstance(value, (Layer, tf.keras.layers.Layer)): - if ( - "stateful" in value.__module__ - or hasattr(value, "_frontend_module") - or not hasattr(value, "_module_dict") - ): - self._module_dict[key] = value - else: - self._module_dict[key] = value._module_dict - - @tf.autograph.experimental.do_not_convert - def _fn_with_var_arg_wrapper( - self, *a, fn, v_fn, keychain_mappings, orig_key_chain, **kw - ): - if "v" in kw: - del kw["v"] - v = v_fn(self.v, keychain_mappings, orig_key_chain) - return fn(*a, **kw, v=v) - - @tf.autograph.experimental.do_not_convert - def _fn_with_var_arg(self, fn, v_fn, /, keychain_mappings, orig_key_chain): - _fn_with_var_arg_wrapper = functools.partial( - self._fn_with_var_arg_wrapper, - fn=fn, - v_fn=v_fn, - keychain_mappings=keychain_mappings, - orig_key_chain=orig_key_chain, - ) - _fn_with_var_arg_wrapper.wrapped = True - return _fn_with_var_arg_wrapper - - @tf.autograph.experimental.do_not_convert - def _call(self, *args, v=None, buffers=None, **kwargs): - if not self._built or not self.built: - if not self._built: - first_arr = self._get_first_array(*args, **kwargs) - self.build( - *args, - **kwargs, - from_call=True, - dtype=first_arr.dtype if first_arr is not None else tf.float32, - ) - - if not self.built: - # Don't use `keras` build method - if os.environ.get("USE_KERAS_BUILD", "False").lower() == "false": - self.inputs = tf.nest.flatten(args) - else: - input_shapes = self._get_input_shapes(*args) - if len(input_shapes) == 0: - input_shapes = tf.TensorShape(None) - elif len(input_shapes) == 1: - input_shapes = input_shapes[0] - - super(Layer, self).build(tf.TensorShape(None)) # noqa: UP008 - - # If `v` was provided, replace with the module's v - replace_v = False - if v is not None: - v_orig = self.v - self._v = v - replace_v = True - - # If `buffers` were provided, replace with the module's buffers - replace_buffers = False - if buffers is not None: - buffers_orig = self.buffers - self._buffers = buffers - replace_buffers = True - - if replace_v or replace_buffers: - # Call the forward pass - ret = super(Layer, self).__call__(*args, **kwargs) # noqa: UP008 - # Replace v, buffers if needed - self._v = v_orig if replace_v else self._v - self._buffers = buffers_orig if replace_buffers else self._buffers - return ret - elif hasattr(self.__call__, "wrapped"): - return self.__call__(*args, **kwargs) - - # Get the signature of the call method - call_signature = inspect.signature(self.call) - - # Convert all positional arguments to keyword arguments based on the signature - new_kwargs = {} - for idx, (param_name, param) in enumerate(call_signature.parameters.items()): - if idx < len(args): - new_kwargs[param_name] = args[idx] - - # Merge the existing kwargs - new_kwargs.update(kwargs) - return super(Layer, self).__call__(**new_kwargs) # noqa: UP008 - - @tf.autograph.experimental.do_not_convert - def build( - self, - *args, - from_call=False, - device=None, - dtype=None, - dynamic_backend=None, - **kwargs, - ): - self._built = True - return - - def _lock_state(self): - pass - - @tf.autograph.experimental.do_not_convert - def register_buffer(self, name: str, value: Union[tf.Tensor, tf.Variable]): - self._buffers.update({name: value}) - return value - - @tf.autograph.experimental.do_not_convert - def register_parameter(self, name: str, value: Union[tf.Tensor, tf.Variable]): - self._v.update({name: value}) - - @tf.autograph.experimental.do_not_convert - def train(self, mode: bool = True): - self._training = 
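# --- editorial aside: standalone illustration (not part of the diff) ---
# Layer._call above maps positional arguments onto call()'s parameter names
# via inspect.signature before delegating to keras. The same mapping:
import inspect

def call(inputs, training=None, mask=None): ...

params = list(inspect.signature(call).parameters)   # ['inputs', 'training', 'mask']
args = ("x", True)
new_kwargs = {name: args[i] for i, name in enumerate(params) if i < len(args)}
assert new_kwargs == {"inputs": "x", "training": True}
# --- end aside ---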
mode - for module in self.children(): - if isinstance(module, tf.keras.layers.Layer) and not hasattr( - module, "train" - ): - module.trainable = mode - continue - module.train(mode) - self.trainable = mode - return self - - @tf.autograph.experimental.do_not_convert - def eval(self): - return self.train(mode=False) - - @tf.autograph.experimental.do_not_convert - def call(self, inputs, training=None, mask=None): - raise NotImplementedError( - "When subclassing the `Module` class, you should implement a `call` method." - ) - - def get_build_config(self): - config = super().get_build_config() - config = recursive_serialize(config) - return config - - def build_from_config(self, config): - config = recursive_deserialize(config) - return super().build_from_config(config) - - def get_config(self): - base_config = super().get_config() - config = {} - - # Get the names and values of positional arguments in __init__ - init_signature = inspect.signature(self.__init__) - arg_names = list(init_signature.parameters.keys()) - - # Include the positional arguments in the config - var_positional_arg_encountered = False - var_positional_arg_name = None - offset = 0 - for i, arg in enumerate(self._args): - arg_name = arg_names[min(i, len(arg_names) - 1)] - if var_positional_arg_encountered: - config.update( - { - f"{var_positional_arg_name}_{i - offset}": arg, - } - ) - elif ( - init_signature.parameters[arg_name].kind - == inspect.Parameter.VAR_POSITIONAL - ): - var_positional_arg_encountered = True - var_positional_arg_name = arg_name - offset = i - config.update( - { - f"{var_positional_arg_name}_{0}": arg, - } - ) - else: - config.update( - { - arg_name: arg, - } - ) - - # Include the keywords arguments in the config - kwargs = self._kwargs.copy() - kwargs.pop("devices", None) - config.update(**kwargs) - new_config = {**base_config, **config} - new_config = recursive_serialize(new_config) - return new_config - - @classmethod - def from_config(cls, config): - config = recursive_deserialize(config) - # Get the signature of the __init__ method - init_signature = inspect.signature(cls.__init__) - arg_names = list(init_signature.parameters.keys()) - - # Separate positional and keyword arguments based on the __init__ signature - args = [] - pos_or_kw = OrderedDict() - kwargs = {} - var_positional_args = [] - for arg_name in arg_names: - if ( - arg_name in config - and init_signature.parameters[arg_name].kind - == inspect.Parameter.KEYWORD_ONLY - ): - # Handle keyword arguments - kwargs[arg_name] = config.pop(arg_name) - elif ( - arg_name in config - and init_signature.parameters[arg_name].kind - == inspect.Parameter.POSITIONAL_OR_KEYWORD - ): - # Handle positional or keyword arguments - pos_or_kw[arg_name] = config.pop(arg_name) - elif any(re.match(rf"{arg_name}_\d+", key) for key in config.keys()): - # Handle variable positional arguments - var_positional_args.extend( - [ - config.pop(key) - for key in sorted(config.keys()) - if re.match(rf"{arg_name}_\d+", key) - ] - ) - - # Unpack positional arguments and the rest as keyword arguments - config.pop("name", None) - config.pop("trainable", None) - config.pop("dtype", None) - kwargs.update(config) - - # Determine the final args and kwargs - if var_positional_args: - args = list(pos_or_kw.values()) + var_positional_args - else: - kwargs.update(pos_or_kw) - - return cls(*args, **kwargs) - - # Methods to be Optionally Overridden # - # -----------------------------------# - - @tf.autograph.experimental.do_not_convert - def _create_variables(self, *, device=None, 
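# --- editorial aside: a small sketch (not part of the diff) ---
# get_config above flattens a VAR_POSITIONAL parameter (*args) into numbered
# keys so the config stays a flat dict; from_config later collects keys
# matching r"<name>_\d+" back into positional arguments.
def encode_var_positional(name, values):  # hypothetical helper for clarity
    return {f"{name}_{i}": v for i, v in enumerate(values)}

assert encode_var_positional("args", ("a", "b")) == {"args_0": "a", "args_1": "b"}
# --- end aside ---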
dtype=None): - return {} - - @tf.autograph.experimental.do_not_convert - def _build(self, *args, **kwargs) -> bool: - return True - - @tf.autograph.experimental.do_not_convert - def _forward(self, *args, **kwargs): - raise NotImplementedError( - "When subclassing the `Module` class, you should " - "implement a `_forward` method." - ) - - @tf.autograph.experimental.do_not_convert - def _extra_repr(self) -> str: - return "" - - # Properties # - # -----------# - - @property - def device(self): - return self._device - - @property - def dtype(self): - return self._dtype - - @property - def build_mode(self): - return self._build_mode - - @property - def training(self): - return self._training - - @property - def v(self): - return self._v - - @property - def buffers(self): - return self._buffers - - @property - def state_dict(self): - return {**self.v, **self.buffers} - - @property - def module_dict(self): - return self._module_dict - - @property - def layers(self): - return self._layers - - # Dunder Methods # - # ---------------# - @store_frame_info - @tf.autograph.experimental.do_not_convert - def __call__( - self, - *args, - v=None, - buffers=None, - **kwargs, - ): - # TODO: Temp workaround to avoid `call`` from being transformed by AutoGraph - if not hasattr(self.__class__.call, "autograph_info__"): - setattr(self.__class__.call, "autograph_info__", True) - ret = self._call(*args, v=v, buffers=buffers, **kwargs) - return ret - - @tf.autograph.experimental.do_not_convert - def __getattr__(self, name): - if name == "v": - if not super().__getattribute__("_v") and not getattr( # noqa: E501 - self, "_built", False - ): - return self._build_and_return_v( - *self._args, dynamic_backend=self._dynamic_backend, **self._kwargs - ) - - _dict = super().__getattribute__("__dict__") - if name in _dict: - return _dict[name] - - elif "_v" in _dict and name in _dict["_v"]: - return _dict["_v"][name] - - return super().__getattribute__(name) - - @tf.autograph.experimental.do_not_convert - def __setattr__(self, name, value): - if name in ["v", "buffers"]: - name = "_" + name - if isinstance(value, (Layer, tf.keras.layers.Layer)): - _dict = getattr(self, "__dict__", None) - if _dict: - _dict[name] = value - - # compute the module dict - self._compute_module_dict() - - obj_to_search = ( - None - if not isinstance(value, (Layer, tf.keras.layers.Layer)) - else ( - self._modules - if hasattr(self, "_modules") and self._modules - else self - ) - ) - found_vars = self._find_variables( - obj=obj_to_search, - without_initialisation=( - True - if self._v_from_constructor and not self._with_partial_v - else False - ), - ) - flattened_v, v_spec = tree_flatten(found_vars) - flattend_kc = v_spec.get_keychains() - for kc, v in zip(flattend_kc, flattened_v): - new_kc = kc.replace("/", ".") - if new_kc not in self.v: - self.register_parameter(new_kc, v) - - # once all variables built, find and assign buffers - found_buffers = self._find_variables( - obj=obj_to_search, - without_initialisation=( - True - if self._v_from_constructor and not self._with_partial_v - else False - ), - trainable=False, - ) - flattened_buf, buf_spec = tree_flatten(found_buffers) - flattend_kc = buf_spec.get_keychains() - for kc, buf in zip(flattend_kc, flattened_buf): - new_kc = kc.replace("/", ".") - self.register_buffer(new_kc, buf) - - super().__setattr__(name, value) - return - elif isinstance(value, tf.Variable) and not name.startswith("_"): - _dict = getattr(self, "__dict__", None) - if _dict: - _dict[name] = value - # Manual solution for cases 
where a `tf.int32` tensor - # is placed on the GPU. TensorFlow doesn't have dedicated - # kernels for placing `tf.int32` variables on the GPU and so - # we manually cast them to `tf.int64` here otherwise due to - # `tf.config.soft_device_placement(True)` by default, - # TensorFlow puts the `tf.int32` variables on CPU which causes - # unintended consequences downstream during tracing or - # `tf.function` compilation e.g. - # Ref: https://github.com/tensorflow/tensorflow/issues/9506 - # Ref: https://stackoverflow.com/questions/44813939/could-not-satisfy-explicit-device-specification-devicegpu0-because-no-devic - dtype = ( - tf.int64 - if value.dtype == tf.int32 and "gpu:" in value.device.lower() - else value.dtype - ) - cast_dtype = dtype != value.dtype - val = ( - value - if not cast_dtype - else tf.Variable(initial_value=tf.cast(value.value(), dtype), name=name) - ) - self.register_parameter(name, val) - super().__setattr__(name, val) - return - else: - try: - obj_to_search = getattr(self, name) - except AttributeError: - obj_to_search = None - if isinstance(obj_to_search, Layer): - # retrieve all hierarchical submodules - assign_dict, kc = get_assignment_dict() - - # Iterate over all submods in assign_dict - # updating their `v` and `buffers` with the - # new value - for key, submod in assign_dict.items(): - # Get the subkey to match - subkey = kc[len(key) :].lstrip(".") - - if hasattr(submod, "v"): - for v_key, v_value in submod.v.items(): - if v_key.startswith(subkey): - submod.register_parameter(v_key, value) - - # Repeat the same process for submod.buffers - if hasattr(submod, "buffers"): - for b_key, b_value in submod.buffers.items(): - if b_key.startswith(subkey): - submod.register_buffer(b_key, value) - - # finally update the module dict - self._module_dict[name] = value - - return super().__setattr__(name, value) - - @tf.autograph.experimental.do_not_convert - def __delattr__(self, name): - if hasattr(self, name): - if isinstance(getattr(self, name), (Layer, tf.keras.layers.Layer)): - super().__delattr__(name) - return - super().__delattr__(name) - - @tf.autograph.experimental.do_not_convert - def __repr__(self): - extra_lines = [] - extra_repr = self._extra_repr() - if extra_repr: - extra_lines = extra_repr.split("\n") - child_lines = [] - for key in self.v.keys(): - if isinstance(getattr(self, key, None), Layer): - mod_str = repr(getattr(self, key)) - mod_str = self._addindent(mod_str, 2) - child_lines.append(f"({key}): {mod_str}") - lines = extra_lines + child_lines - - main_str = f"{self.__class__.__name__}(" - if lines: - # simple one-liner info, which most builtin Modules will use - if len(extra_lines) == 1 and not child_lines: - main_str += extra_lines[0] - else: - main_str += "\n " + "\n ".join(lines) + "\n" - - main_str += ")" - return main_str - - -class Model(tf.keras.Model, ModelHelpers): - _build_mode = None - _with_partial_v = None - _store_vars = True - _built = False - _v = None - _buffers = None - _module_dict = None - _args = None - _kwargs = None - _module_graph = None - _target = None - _lazy_traced = False - _training = None - _dynamic_backend = None - _device = None - _dtype = None - _previous_frame_info = None - - def __init__( - self, - /, - *args, - v=None, - buffers=None, - build_mode="on_init", - store_vars=True, - with_partial_v=False, - dynamic_backend=None, - training=True, - dtype=None, - device=None, - module_dict=None, - **kwargs, - ): - super(Model, self).__init__( - trainable=training, - dtype=dtype, - ) - self._build_mode = build_mode - 
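# --- editorial aside: the same cast, standalone (not part of the diff) ---
# TensorFlow lacks GPU kernels for int32 Variables, so under soft device
# placement they silently land on CPU; __setattr__ above sidesteps this by
# promoting GPU-resident int32 variables to int64 before registering them.
import tensorflow as tf

value = tf.Variable(tf.constant([1, 2], dtype=tf.int32))
dtype = (
    tf.int64
    if value.dtype == tf.int32 and "gpu:" in value.device.lower()
    else value.dtype
)
if dtype != value.dtype:
    value = tf.Variable(initial_value=tf.cast(value.value(), dtype))
# on a CPU-only machine this is a no-op; with a GPU the variable becomes int64
# --- end aside ---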
self._with_partial_v = with_partial_v - self._store_vars = store_vars - self._built = False - self._v_from_constructor = v if isinstance(v, dict) or v is None else dict(v) - self._v = v if v is not None else dict() - self._buffers = dict(buffers or {}) - self._module_dict = module_dict if module_dict is not None else dict() - self._args = args - self._kwargs = kwargs - self._module_graph = None - self._target = None - self._lazy_traced = False - self._training = training - self._dynamic_backend = dynamic_backend - self._device = device or "cpu" - self._dtype = dtype or tf.float32 - if build_mode != "on_init": - return - self.build(*args, dynamic_backend=dynamic_backend, **kwargs) - - @tf.autograph.experimental.do_not_convert - def _find_variables( - self, - /, - *, - obj=None, - without_initialisation=False, - _visited=None, - trainable=True, - ): - _visited = _visited or {} - vs = dict() - if id(obj) in _visited: - return vs - _visited[id(obj)] = True - if isinstance(obj, (Layer, Model)) and obj is not self: - fn = "_build_and_return_v" if trainable else "_build_and_return_buffers" - if not obj._built and without_initialisation: - return lambda: getattr(obj, fn)( - *obj._args, dynamic_backend=self._dynamic_backend, **obj._kwargs - ) - - return getattr(obj, fn)( - *obj._args, dynamic_backend=obj._dynamic_backend, **obj._kwargs - ) - elif isinstance(obj, tf.keras.layers.Layer) and obj is not self: - return obj.v if trainable else obj.buffers - - elif isinstance(obj, (list, tuple)): - for i, v in enumerate(obj): - ret = self._find_variables( - obj=v, - without_initialisation=without_initialisation, - _visited=_visited, - trainable=trainable, - ) - if ret: - vs[f"v{str(i)}"] = ret - return vs - elif isinstance(obj, dict): - for k, v in obj.items(): - ret = self._find_variables( - obj=v, - without_initialisation=without_initialisation, - _visited=_visited, - trainable=trainable, - ) - if ret: - vs[k[1:] if k[0] == "_" else k] = ret - return vs - elif not hasattr(obj, "__dict__"): - return vs - for k, v in obj.__dict__.items(): - if ( - v is not None - and k[0:2] != "__" - and not k.startswith( - ( - "_module_dict", - "_self_", - "_args", - "_kwargs", - ) - ) - ): - ret = self._find_variables( - obj=v, - without_initialisation=without_initialisation, - _visited=_visited, - trainable=trainable, - ) - if ret: - vs[k[1:] if k[0] == "_" else k] = ret - return vs - - @tf.autograph.experimental.do_not_convert - def _find_buffers(self): - if hasattr(self, "_module_dict"): - for key, sub_module in self._module_dict.items(): - if len(sub_module._buffers) > 0: - self._buffers[key] = sub_module._buffers - - @tf.autograph.experimental.do_not_convert - def _build_and_return_v(self, *args, **kwargs): - if not self._built: - self.build(*args, **kwargs) - return self.v - - @tf.autograph.experimental.do_not_convert - def _build_and_return_buffers(self, *args, **kwargs): - if not self._built: - self.build(*args, **kwargs) - return self.buffers - - @tf.autograph.experimental.do_not_convert - def _wrap_call_methods( - self, keychain_mappings, /, *, key="", obj=None, _visited=None - ): - _visited = _visited or {} - if id(obj) in _visited or not isinstance(key, str): - return - _visited[id(obj)] = True - if isinstance(obj, (Layer, Model)) and obj is not self: - orig_key_chain = key[1:] if key[0] == "_" else key - - obj.__call__ = self._fn_with_var_arg( - obj.__call__, self._extract_v, keychain_mappings, orig_key_chain - ) - return - elif isinstance(obj, (list, tuple)): - for i, val in enumerate(obj): - 
self._wrap_call_methods( - keychain_mappings, - key=f"{key}/v{str(i)}", - obj=val, - _visited=_visited, - ) - return - elif isinstance(obj, dict): - for k, val in obj.items(): - k = f"{key}/{k}" if key != "" and isinstance(k, str) else k - self._wrap_call_methods( - keychain_mappings, key=k, obj=val, _visited=_visited - ) - return - for k, val in obj.module_dict.items(): - if k.startswith(("__", "_self_")): - continue - k = f"{key}/{k}" if key != "" else k - if val is not None: - self._wrap_call_methods( - keychain_mappings, key=k, obj=val, _visited=_visited - ) - return - - @tf.autograph.experimental.do_not_convert - def _compute_module_dict(self): - self._module_dict = dict() - for key, value in self.__dict__.items(): - if isinstance(value, (Layer, tf.keras.layers.Layer, Model, tf.keras.Model)): - if ( - "stateful" in value.__module__ - or hasattr(value, "_frontend_module") - or not hasattr(value, "_module_dict") - ): - self._module_dict[key] = value - else: - self._module_dict[key] = value._module_dict - - @tf.autograph.experimental.do_not_convert - def _fn_with_var_arg_wrapper( - self, *a, fn, v_fn, keychain_mappings, orig_key_chain, **kw - ): - if "v" in kw: - del kw["v"] - v = v_fn(self.v, keychain_mappings, orig_key_chain) - return fn(*a, **kw, v=v) - - @tf.autograph.experimental.do_not_convert - def _fn_with_var_arg(self, fn, v_fn, /, keychain_mappings, orig_key_chain): - _fn_with_var_arg_wrapper = functools.partial( - self._fn_with_var_arg_wrapper, - fn=fn, - v_fn=v_fn, - keychain_mappings=keychain_mappings, - orig_key_chain=orig_key_chain, - ) - _fn_with_var_arg_wrapper.wrapped = True - return _fn_with_var_arg_wrapper - - @tf.autograph.experimental.do_not_convert - def _call(self, *args, v=None, buffers=None, **kwargs): - if not self._built or not self.built: - if not self._built: - first_arr = self._get_first_array(*args, **kwargs) - self.build( - *args, - **kwargs, - from_call=True, - dtype=first_arr.dtype if first_arr is not None else tf.float32, - ) - - if not self.built: - # Don't use `keras` build method - if os.environ.get("USE_KERAS_BUILD", "False").lower() == "false": - self.inputs = tf.nest.flatten(args) - else: - input_shapes = self._get_input_shapes(*args) - if len(input_shapes) == 0: - input_shapes = tf.TensorShape(None) - elif len(input_shapes) == 1: - input_shapes = input_shapes[0] - - super(Model, self).build(tf.TensorShape(None)) # noqa: UP008 - - # If `v` was provided, replace with the module's v - replace_v = False - if v is not None: - v_orig = self.v - self._v = v - replace_v = True - - # If `buffers` were provided, replace with the module's buffers - replace_buffers = False - if buffers is not None: - buffers_orig = self.buffers - self._buffers = buffers - replace_buffers = True - - if replace_v or replace_buffers: - # Call the forward pass - ret = super(Model, self).__call__(*args, **kwargs) # noqa: UP008 - # Replace v, buffers if needed - self._v = v_orig if replace_v else self._v - self._buffers = buffers_orig if replace_buffers else self._buffers - return ret - elif hasattr(self.__call__, "wrapped"): - return self.__call__(*args, **kwargs) - - return super(Model, self).__call__(*args, **kwargs) # noqa: UP008 - - @tf.autograph.experimental.do_not_convert - def build( - self, - *args, - from_call=False, - device=None, - dtype=None, - dynamic_backend=None, - **kwargs, - ): - self._built = True - return - - def _lock_state(self): - pass - - @tf.autograph.experimental.do_not_convert - def register_buffer(self, name: str, value: Union[tf.Tensor, tf.Variable]): 
- self._buffers.update({name: value}) - return value - - @tf.autograph.experimental.do_not_convert - def register_parameter(self, name: str, value: Union[tf.Tensor, tf.Variable]): - self._v.update({name: value}) - - @tf.autograph.experimental.do_not_convert - def train(self, mode: bool = True): - self._training = mode - for module in self.children(): - if isinstance(module, tf.keras.layers.Layer) and not hasattr( - module, "train" - ): - module.trainable = mode - continue - module.train(mode) - self.trainable = mode - return self - - @tf.autograph.experimental.do_not_convert - def eval(self): - return self.train(mode=False) - - @tf.autograph.experimental.do_not_convert - def call(self, inputs, training=None, mask=None): - raise NotImplementedError( - "When subclassing the `Module` class, you should implement a `call` method." - ) - - def get_build_config(self): - config = super().get_build_config() - config = recursive_serialize(config) - return config - - def build_from_config(self, config): - config = recursive_deserialize(config) - return super().build_from_config(config) - - def get_config(self): - base_config = super().get_config() - config = {} - - # Get the names and values of positional arguments in __init__ - init_signature = inspect.signature(self.__init__) - arg_names = list(init_signature.parameters.keys()) - - # Include the positional arguments in the config - var_positional_arg_encountered = False - var_positional_arg_name = None - offset = 0 - for i, arg in enumerate(self._args): - arg_name = arg_names[min(i, len(arg_names) - 1)] - if var_positional_arg_encountered: - config.update( - { - f"{var_positional_arg_name}_{i - offset}": arg, - } - ) - elif ( - init_signature.parameters[arg_name].kind - == inspect.Parameter.VAR_POSITIONAL - ): - var_positional_arg_encountered = True - var_positional_arg_name = arg_name - offset = i - config.update( - { - f"{var_positional_arg_name}_{0}": arg, - } - ) - else: - config.update( - { - arg_name: arg, - } - ) - - # Include the keywords arguments in the config - kwargs = self._kwargs.copy() - kwargs.pop("devices", None) - config.update(**kwargs) - new_config = {**base_config, **config} - new_config = recursive_serialize(new_config) - return new_config - - @classmethod - def from_config(cls, config): - config = recursive_deserialize(config) - # Get the signature of the __init__ method - init_signature = inspect.signature(cls.__init__) - arg_names = list(init_signature.parameters.keys()) - - # Separate positional and keyword arguments based on the __init__ signature - args = [] - pos_or_kw = OrderedDict() - kwargs = {} - var_positional_args = [] - for arg_name in arg_names: - if ( - arg_name in config - and init_signature.parameters[arg_name].kind - == inspect.Parameter.KEYWORD_ONLY - ): - # Handle keyword arguments - kwargs[arg_name] = config.pop(arg_name) - elif ( - arg_name in config - and init_signature.parameters[arg_name].kind - == inspect.Parameter.POSITIONAL_OR_KEYWORD - ): - # Handle positional or keyword arguments - pos_or_kw[arg_name] = config.pop(arg_name) - elif any(re.match(rf"{arg_name}_\d+", key) for key in config.keys()): - # Handle variable positional arguments - var_positional_args.extend( - [ - config.pop(key) - for key in sorted(config.keys()) - if re.match(rf"{arg_name}_\d+", key) - ] - ) - - # Unpack positional arguments and the rest as keyword arguments - config.pop("name", None) - config.pop("trainable", None) - config.pop("dtype", None) - kwargs.update(config) - - # Determine the final args and kwargs - if 
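# --- editorial aside: a toy analogue (not part of the diff) ---
# train()/eval() above mirror PyTorch's Module.train API on top of keras'
# `trainable` flag: eval() is just train(mode=False), applied recursively.
class ToyModule:  # hypothetical stand-in showing only the recursion
    def __init__(self):
        self._training, self._children = True, []

    def train(self, mode=True):
        self._training = mode
        for child in self._children:
            child.train(mode)
        return self

    def eval(self):
        return self.train(mode=False)

root = ToyModule()
root._children.append(ToyModule())
root.eval()
assert not root._training and not root._children[0]._training
# --- end aside ---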
var_positional_args: - args = list(pos_or_kw.values()) + var_positional_args - else: - kwargs.update(pos_or_kw) - - return cls(*args, **kwargs) - - # Methods to be Optionally Overridden # - # -----------------------------------# - - @tf.autograph.experimental.do_not_convert - def _create_variables(self, *, device=None, dtype=None): - return {} - - @tf.autograph.experimental.do_not_convert - def _build(self, *args, **kwargs) -> bool: - return True - - @tf.autograph.experimental.do_not_convert - def _forward(self, *args, **kwargs): - raise NotImplementedError( - "When subclassing the `Module` class, you should " - "implement a `_forward` method." - ) - - @tf.autograph.experimental.do_not_convert - def _extra_repr(self) -> str: - return "" - - # Properties # - # -----------# - - @property - def device(self): - return self._device - - @property - def dtype(self): - return self._dtype - - @property - def build_mode(self): - return self._build_mode - - @property - def training(self): - return self._training - - @property - def v(self): - return self._v - - @property - def buffers(self): - return self._buffers - - @property - def state_dict(self): - return {**self.v, **self.buffers} - - @property - def module_dict(self): - return self._module_dict - - # Dunder Methods # - # ---------------# - @store_frame_info - @tf.autograph.experimental.do_not_convert - def __call__( - self, - *args, - v=None, - buffers=None, - **kwargs, - ): - # TODO: Temp workaround to avoid `call`` from being transformed by AutoGraph - if not hasattr(self.__class__.call, "autograph_info__"): - setattr(self.__class__.call, "autograph_info__", True) - ret = self._call(*args, v=v, buffers=buffers, **kwargs) - return ret - - @tf.autograph.experimental.do_not_convert - def __getattr__(self, name): - if name == "v": - if not super().__getattribute__("_v") and not getattr( # noqa: E501 - self, "_built", False - ): - return self._build_and_return_v( - *self._args, dynamic_backend=self._dynamic_backend, **self._kwargs - ) - - _dict = super().__getattribute__("__dict__") - if name in _dict: - return _dict[name] - - elif "_v" in _dict and name in _dict["_v"]: - return _dict["_v"][name] - - return super().__getattribute__(name) - - @tf.autograph.experimental.do_not_convert - def __setattr__(self, name, value): - if name in ["v", "buffers"]: - name = "_" + name - if isinstance(value, (Layer, tf.keras.layers.Layer, Model, tf.keras.Model)): - _dict = getattr(self, "__dict__", None) - if _dict: - _dict[name] = value - - # compute the module dict - self._compute_module_dict() - - obj_to_search = ( - None - if not isinstance(value, (tf.keras.layers.Layer, Layer, Model)) - else ( - self._modules - if hasattr(self, "_modules") and self._modules - else self - ) - ) - found_vars = self._find_variables( - obj=obj_to_search, - without_initialisation=( - True - if self._v_from_constructor and not self._with_partial_v - else False - ), - ) - flattened_v, v_spec = tree_flatten(found_vars) - flattend_kc = v_spec.get_keychains() - for kc, v in zip(flattend_kc, flattened_v): - new_kc = kc.replace("/", ".") - if new_kc not in self.v: - self.register_parameter(new_kc, v) - - # once all variables built, find and assign buffers - found_buffers = self._find_variables( - obj=obj_to_search, - without_initialisation=( - True - if self._v_from_constructor and not self._with_partial_v - else False - ), - trainable=False, - ) - flattened_buf, buf_spec = tree_flatten(found_buffers) - flattend_kc = buf_spec.get_keychains() - for kc, buf in zip(flattend_kc, 
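# --- editorial aside: a small sketch (not part of the diff) ---
# from_config above splits the saved config by parameter kind, routing
# KEYWORD_ONLY entries to kwargs and POSITIONAL_OR_KEYWORD entries to args.
# The same separation, standalone (split_config is a hypothetical helper):
import inspect

def split_config(cls, config):
    kwargs, pos_or_kw = {}, {}
    for name, p in inspect.signature(cls.__init__).parameters.items():
        if name in config and p.kind is inspect.Parameter.KEYWORD_ONLY:
            kwargs[name] = config.pop(name)
        elif name in config and p.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD:
            pos_or_kw[name] = config.pop(name)
    return pos_or_kw, kwargs

class Block:
    def __init__(self, units, *, activation=None): ...

pos, kw = split_config(Block, {"units": 8, "activation": "relu"})
assert pos == {"units": 8} and kw == {"activation": "relu"}
# --- end aside ---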
flattened_buf): - new_kc = kc.replace("/", ".") - self.register_buffer(new_kc, buf) - - super().__setattr__(name, value) - return - elif isinstance(value, tf.Variable) and not name.startswith("_"): - _dict = getattr(self, "__dict__", None) - if _dict: - _dict[name] = value - - # Manual solution for cases where a `tf.int32` tensor - # is placed on the GPU. TensorFlow doesn't have dedicated - # kernels for placing `tf.int32` variables on the GPU and so - # we manually cast them to `tf.int64` here otherwise due to - # `tf.config.soft_device_placement(True)` by default, - # TensorFlow puts the `tf.int32` variables on CPU which causes - # unintended consequences downstream during tracing or - # `tf.function` compilation e.g. - # Ref: https://github.com/tensorflow/tensorflow/issues/9506 - # Ref: https://stackoverflow.com/questions/44813939/could-not-satisfy-explicit-device-specification-devicegpu0-because-no-devic - dtype = ( - tf.int64 - if value.dtype == tf.int32 and "gpu:" in value.device.lower() - else value.dtype - ) - cast_dtype = dtype != value.dtype - val = ( - value - if not cast_dtype - else tf.Variable(initial_value=tf.cast(value.value(), dtype), name=name) - ) - self.register_parameter(name, val) - super().__setattr__(name, val) - else: - try: - obj_to_search = getattr(self, name) - except AttributeError: - obj_to_search = None - if isinstance(obj_to_search, (Model, Layer)): - # retrieve all hierarchical submodules - assign_dict, kc = get_assignment_dict() - - # Iterate over all submods in assign_dict - # updating their `v` and `buffers` with the - # new value - for key, submod in assign_dict.items(): - # Get the subkey to match - subkey = kc[len(key) :].lstrip(".") - - if hasattr(submod, "v"): - for v_key, v_value in submod.v.items(): - if v_key.startswith(subkey): - submod.register_parameter(v_key, value) - - # Repeat the same process for submod.buffers - if hasattr(submod, "buffers"): - for b_key, b_value in submod.buffers.items(): - if b_key.startswith(subkey): - submod.register_buffer(b_key, value) - - # finally update the module dict - self._module_dict[name] = value - - return super().__setattr__(name, value) - - @tf.autograph.experimental.do_not_convert - def __delattr__(self, name): - if hasattr(self, name): - if isinstance( - getattr(self, name), - (Layer, tf.keras.layers.Layer, Model, tf.keras.Model), - ): - super().__delattr__(name) - return - super().__delattr__(name) - - @tf.autograph.experimental.do_not_convert - def __repr__(self): - extra_lines = [] - extra_repr = self._extra_repr() - if extra_repr: - extra_lines = extra_repr.split("\n") - child_lines = [] - for key in self.v.keys(): - if isinstance(getattr(self, key, None), (Layer, Model)): - mod_str = repr(getattr(self, key)) - mod_str = self._addindent(mod_str, 2) - child_lines.append(f"({key}): {mod_str}") - lines = extra_lines + child_lines - - main_str = f"{self.__class__.__name__}(" - if lines: - # simple one-liner info, which most builtin Modules will use - if len(extra_lines) == 1 and not child_lines: - main_str += extra_lines[0] - else: - main_str += "\n " + "\n ".join(lines) + "\n" - - main_str += ")" - return main_str diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_full_like_output/run_0/tensorflow_full_like.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_full_like_output/run_0/tensorflow_full_like.py deleted file mode 100644 index 8d236571ecce..000000000000 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_full_like_output/run_0/tensorflow_full_like.py +++ /dev/null @@ -1,22 
+0,0 @@ -import tensorflow - -from typing import Union -from typing import Optional -from numbers import Number - -from .tensorflow__helpers import tensorflow_handle_array_like_without_promotion -from .tensorflow__helpers import tensorflow_infer_dtype - - -@tensorflow_infer_dtype -@tensorflow_handle_array_like_without_promotion -def tensorflow_full_like( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - fill_value: Number, - *, - dtype: tensorflow.DType, - device: Optional[str] = None, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - return tensorflow.experimental.numpy.full_like(x, fill_value, dtype=dtype) diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_gather_nd_output/run_0/tensorflow_NestedSequence_bknd.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_gather_nd_output/run_0/tensorflow_NestedSequence_bknd.py index ac8335fe1e56..9f87b4ae29ef 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_gather_nd_output/run_0/tensorflow_NestedSequence_bknd.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_gather_nd_output/run_0/tensorflow_NestedSequence_bknd.py @@ -1,5 +1,5 @@ -from typing import TypeVar from typing import Protocol +from typing import TypeVar _T_co = TypeVar("_T_co", covariant=True) diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_gather_nd_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_gather_nd_output/run_0/tensorflow__helpers.py index d50687f412e5..3aaa2aa83879 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_gather_nd_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_gather_nd_output/run_0/tensorflow__helpers.py @@ -23,6 +23,99 @@ import tensorflow as tf +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods(fn): + def 
extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +240,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +282,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = 
tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_is_native_array(x, /, *, exclusive=False): @@ -334,7 +342,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -447,6 +457,7 @@ def tensorflow_is_complex_dtype_bknd( return "complex" in tensorflow_as_ivy_dtype_bknd(dtype_in) +@tensorflow_handle_array_like_without_promotion def tensorflow_real( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -460,6 +471,7 @@ def tensorflow_real_bknd_(self): return tensorflow_real(self) +@tensorflow_handle_array_like_without_promotion def tensorflow_imag( val: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -485,6 +497,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -611,6 +626,9 @@ def nested_fun(x): return "int" in tensorflow_as_ivy_dtype(dtype_in) +default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, @@ -655,6 +673,9 @@ def tensorflow_default_dtype_bknd( return tensorflow_as_ivy_dtype(ret) +default_float_dtype_stack = [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -709,6 +730,25 @@ def tensorflow_default_float_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -745,6 +785,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -767,21 +811,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -819,6 +859,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) 
+native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -870,20 +929,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -950,26 +995,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -1087,6 +1114,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -1193,27 +1223,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -1382,6 +1406,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1638,7 +1665,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = 
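# Note: the `tensorflow_default_int_dtype_bknd` hunk above walks the
# (possibly nested) input and widens the default integer dtype whenever a
# plain Python int exceeds the int32/int64 ranges (2147483647 and
# 9223372036854775807). A minimal runnable sketch of that selection rule
# follows; `pick_int_dtype` is a hypothetical condensation for
# illustration, not a helper defined anywhere in this diff:
import math

import tensorflow as tf


def pick_int_dtype(value):
    # Mirror the bounds checks used by the default-int-dtype logic:
    # ints past the int64 max widen to uint64, ints past the int32 max
    # widen to int64, everything else stays int32. The real function
    # simply skips non-finite values; raising here keeps the sketch small.
    if value == math.inf:
        raise ValueError("non-finite values carry no integer dtype")
    if value > 9223372036854775807:  # beyond int64 max
        return tf.uint64
    if value > 2147483647:  # beyond int32 max
        return tf.int64
    return tf.int32


# The thresholds are exactly the int32/int64 maxima checked above:
assert pick_int_dtype(7) is tf.int32
assert pick_int_dtype(2**31) is tf.int64
assert pick_int_dtype(2**63) is tf.uint64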
tensorflow_asarray(x1, dtype=dtype) @@ -2050,7 +2079,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -2174,6 +2205,9 @@ def tensorflow_to_scalar_bknd_(self: tensorflow.Tensor): return tensorflow_to_scalar(self) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2198,11 +2232,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2410,7 +2442,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2570,11 +2604,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2614,21 +2646,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_gather_nd_output/run_0/tensorflow_gather_nd.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_gather_nd_output/run_0/tensorflow_gather_nd.py index ca479018e633..34ace526faa0 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_gather_nd_output/run_0/tensorflow_gather_nd.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_gather_nd_output/run_0/tensorflow_gather_nd.py @@ -1,7 +1,7 @@ import tensorflow -from typing import Optional from typing import Union +from typing import Optional from .tensorflow__helpers import tensorflow_check_gather_nd_input_valid from .tensorflow__helpers import tensorflow_gather_nd_helper diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_gelu_output/run_0/tensorflow_NestedSequence_bknd.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_gelu_output/run_0/tensorflow_NestedSequence_bknd.py index ac8335fe1e56..9f87b4ae29ef 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_gelu_output/run_0/tensorflow_NestedSequence_bknd.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_gelu_output/run_0/tensorflow_NestedSequence_bknd.py @@ -1,5 +1,5 @@ -from typing import TypeVar from typing import Protocol +from typing import TypeVar _T_co = TypeVar("_T_co", covariant=True) diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_gelu_output/run_0/tensorflow__helpers.py 
b/ivy/compiler/_cache/Translated_Outputs/tensorflow_gelu_output/run_0/tensorflow__helpers.py index d50687f412e5..3aaa2aa83879 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_gelu_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_gelu_output/run_0/tensorflow__helpers.py @@ -23,6 +23,99 @@ import tensorflow as tf +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +240,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +282,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - 
tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_is_native_array(x, /, *, exclusive=False): @@ -334,7 +342,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -447,6 +457,7 @@ def tensorflow_is_complex_dtype_bknd( return "complex" in tensorflow_as_ivy_dtype_bknd(dtype_in) +@tensorflow_handle_array_like_without_promotion def tensorflow_real( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -460,6 +471,7 @@ def tensorflow_real_bknd_(self): return tensorflow_real(self) +@tensorflow_handle_array_like_without_promotion def tensorflow_imag( val: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -485,6 +497,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -611,6 +626,9 @@ def nested_fun(x): return "int" in tensorflow_as_ivy_dtype(dtype_in) +default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, 
str]] = None, @@ -655,6 +673,9 @@ def tensorflow_default_dtype_bknd( return tensorflow_as_ivy_dtype(ret) +default_float_dtype_stack = [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -709,6 +730,25 @@ def tensorflow_default_float_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -745,6 +785,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -767,21 +811,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -819,6 +859,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -870,20 +929,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -950,26 +995,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = 
"_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -1087,6 +1114,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -1193,27 +1223,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -1382,6 +1406,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1638,7 +1665,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2050,7 +2079,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -2174,6 +2205,9 @@ def tensorflow_to_scalar_bknd_(self: tensorflow.Tensor): return tensorflow_to_scalar(self) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2198,11 +2232,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2410,7 +2442,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2570,11 +2604,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - 
tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2614,21 +2646,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_get_item_bknd_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_get_item_bknd_output/run_0/tensorflow__helpers.py index df16104cad4b..953e5a910cd1 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_get_item_bknd_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_get_item_bknd_output/run_0/tensorflow__helpers.py @@ -23,6 +23,99 @@ import tensorflow as tf +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + 
except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +240,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +282,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_is_native_array(x, /, *, exclusive=False): @@ -334,7 +342,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -447,20 +457,6 @@ def 
tensorflow_is_complex_dtype_bknd( return "complex" in tensorflow_as_ivy_dtype_bknd(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -527,26 +523,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -664,6 +642,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -770,27 +751,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -959,6 +934,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1215,7 +1193,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -1627,7 +1607,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -1795,6 +1777,9 @@ def tensorflow_is_uint_dtype_bknd( return "uint" in tensorflow_as_ivy_dtype_bknd(dtype_in) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -1819,11 +1804,9 @@ def 
is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2057,7 +2040,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2217,11 +2202,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2261,21 +2244,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], @@ -2356,6 +2324,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2412,6 +2383,9 @@ def tensorflow_default_complex_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, @@ -2456,6 +2430,9 @@ def tensorflow_default_dtype_bknd( return tensorflow_as_ivy_dtype(ret) +default_float_dtype_stack = [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2510,6 +2487,25 @@ def tensorflow_default_float_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -2546,6 +2542,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2568,21 +2568,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != 
math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -2620,6 +2616,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_inplace_update_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_inplace_update_output/run_0/tensorflow__helpers.py index 58222341c5f5..d1e89024eb25 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_inplace_update_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_inplace_update_output/run_0/tensorflow__helpers.py @@ -23,6 +23,99 @@ import tensorflow as tf +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if 
tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +240,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +282,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - 
_handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_is_native_array(x, /, *, exclusive=False): @@ -334,7 +342,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -447,6 +457,7 @@ def tensorflow_is_complex_dtype_bknd( return "complex" in tensorflow_as_ivy_dtype_bknd(dtype_in) +@tensorflow_handle_array_like_without_promotion def tensorflow_real( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -460,6 +471,7 @@ def tensorflow_real_bknd_(self): return tensorflow_real(self) +@tensorflow_handle_array_like_without_promotion def tensorflow_imag( val: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -485,6 +497,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -611,6 +626,9 @@ def nested_fun(x): return "int" in tensorflow_as_ivy_dtype(dtype_in) +default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, @@ -655,6 +673,9 @@ def tensorflow_default_dtype_bknd( return tensorflow_as_ivy_dtype(ret) +default_float_dtype_stack = [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -709,6 +730,25 @@ def tensorflow_default_float_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -745,6 +785,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -767,21 +811,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -819,6 +859,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": 
tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -870,20 +929,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -950,26 +995,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -1087,6 +1114,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -1193,27 +1223,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -1382,6 +1406,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1623,7 +1650,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2035,7 +2064,9 @@ def tensorflow__parse_query_bknd(query, x_shape, 
scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -2159,6 +2190,9 @@ def tensorflow_to_scalar_bknd_(self: tensorflow.Tensor): return tensorflow_to_scalar(self) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2183,11 +2217,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2395,7 +2427,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2555,11 +2589,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2599,21 +2631,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_interpolate_output/run_0/tensorflow_NestedSequence_bknd.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_interpolate_output/run_0/tensorflow_NestedSequence_bknd.py index ac8335fe1e56..9f87b4ae29ef 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_interpolate_output/run_0/tensorflow_NestedSequence_bknd.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_interpolate_output/run_0/tensorflow_NestedSequence_bknd.py @@ -1,5 +1,5 @@ -from typing import TypeVar from typing import Protocol +from typing import TypeVar _T_co = TypeVar("_T_co", covariant=True) diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_interpolate_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_interpolate_output/run_0/tensorflow__helpers.py index 79061ec850f0..ab1fefb864d2 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_interpolate_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_interpolate_output/run_0/tensorflow__helpers.py @@ -23,6 +23,99 @@ import tensorflow as tf +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in 
type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +240,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +282,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": 
tensorflow.complex128, - "bool": tensorflow.bool, -} -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_is_native_array(x, /, *, exclusive=False): @@ -334,7 +342,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -447,6 +457,7 @@ def tensorflow_is_complex_dtype_bknd( return "complex" in tensorflow_as_ivy_dtype_bknd(dtype_in) +@tensorflow_handle_array_like_without_promotion def tensorflow_real( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -460,6 +471,7 @@ def tensorflow_real_bknd_(self): return tensorflow_real(self) +@tensorflow_handle_array_like_without_promotion def tensorflow_imag( val: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -485,6 +497,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -611,6 +626,9 @@ def nested_fun(x): return "int" in tensorflow_as_ivy_dtype(dtype_in) +default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, @@ -655,6 +673,9 @@ def tensorflow_default_dtype_bknd( return tensorflow_as_ivy_dtype(ret) +default_float_dtype_stack = [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -709,6 +730,25 @@ def tensorflow_default_float_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + 
tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -745,6 +785,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -767,21 +811,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -819,6 +859,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -870,20 +929,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -950,26 +995,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -1087,6 +1114,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -1193,27 +1223,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if 
include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -1382,6 +1406,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1638,7 +1665,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2038,7 +2067,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -2162,6 +2193,9 @@ def tensorflow_to_scalar_bknd_(self: tensorflow.Tensor): return tensorflow_to_scalar(self) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2186,11 +2220,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2398,7 +2430,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2558,11 +2592,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2602,21 +2634,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], @@ -2682,7 +2699,9 @@ def tensorflow_divide( dtype = ( 
x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_interpolate_output/run_0/tensorflow_interpolate.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_interpolate_output/run_0/tensorflow_interpolate.py index fc92079548c5..e5c05693ff2c 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_interpolate_output/run_0/tensorflow_interpolate.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_interpolate_output/run_0/tensorflow_interpolate.py @@ -1,8 +1,8 @@ import tensorflow from typing import Sequence -from typing import Literal from typing import Optional +from typing import Literal from typing import Union from .tensorflow__helpers import tensorflow__get_size_bknd @@ -54,11 +54,11 @@ def tensorflow_interpolate( mode = ( "bilinear" if mode == "linear" - else ( - "area" - if mode == "tf_area" - else "nearest" if mode == "nearest-exact" else mode - ) + else "area" + if mode == "tf_area" + else "nearest" + if mode == "nearest-exact" + else mode ) if mode == "tf_bicubic": mode = "bicubic" diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_is_array_bknd_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_is_array_bknd_output/run_0/tensorflow__helpers.py index ef32b68640a5..ea0fb33b6951 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_is_array_bknd_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_is_array_bknd_output/run_0/tensorflow__helpers.py @@ -128,6 +128,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -169,6 +170,7 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } + tf.experimental.numpy.experimental_enable_numpy_behavior(True) diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_layer_norm_bknd_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_layer_norm_bknd_output/run_0/tensorflow__helpers.py index 751b26a6de35..460e74b1df87 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_layer_norm_bknd_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_layer_norm_bknd_output/run_0/tensorflow__helpers.py @@ -23,6 +23,99 @@ import tensorflow as tf +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if 
not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +240,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +282,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = 
[param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_is_native_array(x, /, *, exclusive=False): @@ -334,7 +342,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -447,6 +457,7 @@ def tensorflow_is_complex_dtype_bknd( return "complex" in tensorflow_as_ivy_dtype_bknd(dtype_in) +@tensorflow_handle_array_like_without_promotion def tensorflow_real( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -460,6 +471,7 @@ def tensorflow_real_bknd_(self): return tensorflow_real(self) +@tensorflow_handle_array_like_without_promotion def tensorflow_imag( val: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -485,6 +497,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -611,6 +626,9 @@ def nested_fun(x): return "int" in tensorflow_as_ivy_dtype(dtype_in) +default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, @@ -655,6 +673,9 @@ def tensorflow_default_dtype_bknd( return tensorflow_as_ivy_dtype(ret) +default_float_dtype_stack = [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -709,6 +730,25 @@ def tensorflow_default_float_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -745,6 +785,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -767,21 +811,17 @@ def tensorflow_default_int_dtype_bknd( elif 
isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -819,6 +859,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -870,20 +929,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -950,26 +995,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -1087,6 +1114,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -1193,27 +1223,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if 
include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -1382,6 +1406,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1638,7 +1665,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2050,7 +2079,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -2174,6 +2205,9 @@ def tensorflow_to_scalar_bknd_(self: tensorflow.Tensor): return tensorflow_to_scalar(self) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2198,11 +2232,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2410,7 +2442,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2570,11 +2604,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2614,21 +2646,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], @@ -2728,7 +2745,9 @@ def tensorflow_add( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_layer_norm_bknd_output/run_0/tensorflow_layer_norm_bknd.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_layer_norm_bknd_output/run_0/tensorflow_layer_norm_bknd.py index 01f3f5899901..39a88a9686a4 100644 --- 
a/ivy/compiler/_cache/Translated_Outputs/tensorflow_layer_norm_bknd_output/run_0/tensorflow_layer_norm_bknd.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_layer_norm_bknd_output/run_0/tensorflow_layer_norm_bknd.py @@ -1,9 +1,9 @@ import tensorflow import tensorflow as tf -from typing import Union from typing import Optional from typing import List +from typing import Union from .tensorflow__helpers import tensorflow_add from .tensorflow__helpers import tensorflow_handle_array_like_without_promotion diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_leaky_relu_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_leaky_relu_output/run_0/tensorflow__helpers.py index d50687f412e5..3aaa2aa83879 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_leaky_relu_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_leaky_relu_output/run_0/tensorflow__helpers.py @@ -23,6 +23,99 @@ import tensorflow as tf +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +240,7 
@@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +282,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_is_native_array(x, /, *, exclusive=False): @@ -334,7 +342,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -447,6 +457,7 @@ def tensorflow_is_complex_dtype_bknd( return "complex" in tensorflow_as_ivy_dtype_bknd(dtype_in) +@tensorflow_handle_array_like_without_promotion def tensorflow_real( x: 
Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -460,6 +471,7 @@ def tensorflow_real_bknd_(self): return tensorflow_real(self) +@tensorflow_handle_array_like_without_promotion def tensorflow_imag( val: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -485,6 +497,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -611,6 +626,9 @@ def nested_fun(x): return "int" in tensorflow_as_ivy_dtype(dtype_in) +default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, @@ -655,6 +673,9 @@ def tensorflow_default_dtype_bknd( return tensorflow_as_ivy_dtype(ret) +default_float_dtype_stack = [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -709,6 +730,25 @@ def tensorflow_default_float_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -745,6 +785,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -767,21 +811,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -819,6 +859,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -870,20 +929,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def 
wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -950,26 +995,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -1087,6 +1114,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -1193,27 +1223,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -1382,6 +1406,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1638,7 +1665,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2050,7 +2079,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -2174,6 +2205,9 @@ def tensorflow_to_scalar_bknd_(self: tensorflow.Tensor): return tensorflow_to_scalar(self) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2198,11 +2232,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 
and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2410,7 +2442,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2570,11 +2604,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2614,21 +2646,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_linspace_output/run_0/__init__.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_linspace_output/run_0/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_linspace_output/run_0/tensorflow_NestedSequence_bknd.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_linspace_output/run_0/tensorflow_NestedSequence_bknd.py deleted file mode 100644 index ac8335fe1e56..000000000000 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_linspace_output/run_0/tensorflow_NestedSequence_bknd.py +++ /dev/null @@ -1,10 +0,0 @@ -from typing import TypeVar -from typing import Protocol - -_T_co = TypeVar("_T_co", covariant=True) - - -class tensorflow_NestedSequence_bknd(Protocol[_T_co]): - def __getitem__(self, key: int, /): ... - - def __len__(self, /): ... 
diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_linspace_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_linspace_output/run_0/tensorflow__helpers.py
deleted file mode 100644
index eaee8c852649..000000000000
--- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_linspace_output/run_0/tensorflow__helpers.py
+++ /dev/null
@@ -1,2675 +0,0 @@
-from collections import UserDict
-from numbers import Number
-from numpy.core.numeric import normalize_axis_tuple
-from operator import mul
-from .tensorflow_NestedSequence_bknd import tensorflow_NestedSequence_bknd
-from typing import Any
-from typing import Callable
-from typing import Dict
-from typing import Iterable
-from typing import List
-from typing import Optional
-from typing import Sequence
-from typing import Tuple
-from typing import TypeVar
-from typing import Union
-import functools
-import inspect
-import itertools
-import math
-import numpy as np
-import re
-import tensorflow
-import tensorflow as tf
-
-
-promotion_table = {
-    ("bool", "bool"): "bool",
-    ("int8", "int8"): "int8",
-    ("int8", "int16"): "int16",
-    ("int8", "int32"): "int32",
-    ("int8", "int64"): "int64",
-    ("int16", "int16"): "int16",
-    ("int16", "int32"): "int32",
-    ("int16", "int64"): "int64",
-    ("int32", "int32"): "int32",
-    ("int32", "int64"): "int64",
-    ("int64", "int64"): "int64",
-    ("uint8", "int8"): "int16",
-    ("uint8", "int16"): "int16",
-    ("uint8", "int32"): "int32",
-    ("uint8", "int64"): "int64",
-    ("uint8", "uint8"): "uint8",
-    ("uint8", "uint16"): "uint16",
-    ("uint8", "uint32"): "uint32",
-    ("uint8", "uint64"): "uint64",
-    ("uint16", "int8"): "int32",
-    ("uint16", "int16"): "int32",
-    ("uint16", "int32"): "int32",
-    ("uint16", "int64"): "int64",
-    ("uint16", "uint16"): "uint16",
-    ("uint16", "uint32"): "uint32",
-    ("uint16", "uint64"): "uint64",
-    ("uint32", "int8"): "int64",
-    ("uint32", "int16"): "int64",
-    ("uint32", "int32"): "int64",
-    ("uint32", "int64"): "int64",
-    ("uint32", "uint32"): "uint32",
-    ("uint32", "uint64"): "uint64",
-    ("uint64", "uint64"): "uint64",
-    ("float16", "float16"): "float16",
-    ("float16", "float32"): "float32",
-    ("float16", "float64"): "float64",
-    ("float32", "float32"): "float32",
-    ("float32", "float64"): "float64",
-    ("float64", "float64"): "float64",
-    ("bool", "int8"): "int8",
-    ("bool", "int16"): "int16",
-    ("bool", "int32"): "int32",
-    ("bool", "int64"): "int64",
-    ("bool", "uint8"): "uint8",
-    ("bool", "uint16"): "uint16",
-    ("bool", "uint32"): "uint32",
-    ("bool", "uint64"): "uint64",
-    ("bool", "float16"): "float16",
-    ("bool", "float32"): "float32",
-    ("bool", "float64"): "float64",
-    ("bool", "bfloat16"): "bfloat16",
-    ("bool", "complex64"): "complex64",
-    ("bool", "complex128"): "complex128",
-    ("int8", "float16"): "float16",
-    ("int8", "float32"): "float32",
-    ("int8", "float64"): "float64",
-    ("int8", "bfloat16"): "bfloat16",
-    ("int8", "complex64"): "complex64",
-    ("int8", "complex128"): "complex128",
-    ("int16", "float32"): "float32",
-    ("int16", "float64"): "float64",
-    ("int16", "complex64"): "complex64",
-    ("int16", "complex128"): "complex128",
-    ("int32", "float64"): "float64",
-    ("int32", "complex128"): "complex128",
-    ("int64", "float64"): "float64",
-    ("int64", "complex128"): "complex128",
-    ("uint8", "float16"): "float16",
-    ("uint8", "float32"): "float32",
-    ("uint8", "float64"): "float64",
-    ("uint8", "bfloat16"): "bfloat16",
-    ("uint8", "complex64"): "complex64",
-    ("uint8", "complex128"): "complex128",
-    ("uint16", "float32"): "float32",
-    ("uint16", "float64"): "float64",
-    ("uint16", "complex64"): "complex64",
-    ("uint16", "complex128"): "complex128",
-    ("uint32", "float64"): "float64",
-    ("uint32", "complex128"): "complex128",
-    ("uint64", "int8"): "float64",
-    ("uint64", "int16"): "float64",
-    ("uint64", "int32"): "float64",
-    ("uint64", "int64"): "float64",
-    ("uint64", "float64"): "float64",
-    ("uint64", "complex128"): "complex128",
-    ("float16", "bfloat16"): "float32",
-    ("float16", "complex64"): "complex64",
-    ("float16", "complex128"): "complex128",
-    ("float32", "complex64"): "complex64",
-    ("float32", "complex128"): "complex128",
-    ("float64", "complex64"): "complex128",
-    ("float64", "complex128"): "complex128",
-    ("bfloat16", "float16"): "float32",
-    ("bfloat16", "float32"): "float32",
-    ("bfloat16", "float64"): "float64",
-    ("bfloat16", "bfloat16"): "bfloat16",
-    ("bfloat16", "complex64"): "complex64",
-    ("bfloat16", "complex128"): "complex128",
-    ("complex64", "float64"): "complex128",
-    ("complex64", "complex64"): "complex64",
-    ("complex64", "complex128"): "complex128",
-    ("complex128", "complex128"): "complex128",
-    ("float16", "int16"): "float32",
-    ("float16", "int32"): "float64",
-    ("float16", "int64"): "float64",
-    ("float16", "uint16"): "float32",
-    ("float16", "uint32"): "float64",
-    ("float16", "uint64"): "float64",
-    ("float32", "int32"): "float64",
-    ("float32", "int64"): "float64",
-    ("float32", "uint32"): "float64",
-    ("float32", "uint64"): "float64",
-    ("bfloat16", "int16"): "float32",
-    ("bfloat16", "int32"): "float64",
-    ("bfloat16", "int64"): "float64",
-    ("bfloat16", "uint16"): "float32",
-    ("bfloat16", "uint32"): "float64",
-    ("bfloat16", "uint64"): "float64",
-    ("complex64", "int32"): "complex128",
-    ("complex64", "int64"): "complex128",
-    ("complex64", "uint32"): "complex128",
-    ("complex64", "uint64"): "complex128",
-}
-array_api_promotion_table = {
-    ("bool", "bool"): "bool",
-    ("int8", "int8"): "int8",
-    ("int8", "int16"): "int16",
-    ("int8", "int32"): "int32",
-    ("int8", "int64"): "int64",
-    ("int16", "int16"): "int16",
-    ("int16", "int32"): "int32",
-    ("int16", "int64"): "int64",
-    ("int32", "int32"): "int32",
-    ("int32", "int64"): "int64",
-    ("int64", "int64"): "int64",
-    ("uint8", "int8"): "int16",
-    ("uint8", "int16"): "int16",
-    ("uint8", "int32"): "int32",
-    ("uint8", "int64"): "int64",
-    ("uint8", "uint8"): "uint8",
-    ("uint8", "uint16"): "uint16",
-    ("uint8", "uint32"): "uint32",
-    ("uint8", "uint64"): "uint64",
-    ("uint16", "int8"): "int32",
-    ("uint16", "int16"): "int32",
-    ("uint16", "int32"): "int32",
-    ("uint16", "int64"): "int64",
-    ("uint16", "uint16"): "uint16",
-    ("uint16", "uint32"): "uint32",
-    ("uint16", "uint64"): "uint64",
-    ("uint32", "int8"): "int64",
-    ("uint32", "int16"): "int64",
-    ("uint32", "int32"): "int64",
-    ("uint32", "int64"): "int64",
-    ("uint32", "uint32"): "uint32",
-    ("uint32", "uint64"): "uint64",
-    ("uint64", "uint64"): "uint64",
-    ("float16", "float16"): "float16",
-    ("float16", "float32"): "float32",
-    ("float16", "float64"): "float64",
-    ("float32", "float32"): "float32",
-    ("float32", "float64"): "float64",
-    ("float64", "float64"): "float64",
-}
-tf.experimental.numpy.experimental_enable_numpy_behavior(True)
-default_device_stack = []
-SupportsBufferProtocol = TypeVar("SupportsBufferProtocol")
-default_uint_dtype_stack = []
-default_complex_dtype_stack = []
-default_dtype_stack = []
-default_float_dtype_stack = []
-ivy_dtype_dict = {
-    tensorflow.int8: "int8",
-    tensorflow.int16: "int16",
-    tensorflow.int32: "int32",
-    tensorflow.int64: "int64",
-    tensorflow.uint8: "uint8",
-    tensorflow.uint16: "uint16",
-    tensorflow.uint32: "uint32",
-    tensorflow.uint64: "uint64",
-    tensorflow.bfloat16: "bfloat16",
-    tensorflow.float16: "float16",
-    tensorflow.float32: "float32",
-    tensorflow.float64: "float64",
-    tensorflow.complex64: "complex64",
-    tensorflow.complex128: "complex128",
-    tensorflow.bool: "bool",
-}
-default_int_dtype_stack = []
-backend = ""
-native_dtype_dict = {
-    "int8": tensorflow.int8,
-    "int16": tensorflow.int16,
-    "int32": tensorflow.int32,
-    "int64": tensorflow.int64,
-    "uint8": tensorflow.uint8,
-    "uint16": tensorflow.uint16,
-    "uint32": tensorflow.uint32,
-    "uint64": tensorflow.uint64,
-    "bfloat16": tensorflow.bfloat16,
-    "float16": tensorflow.float16,
-    "float32": tensorflow.float32,
-    "float64": tensorflow.float64,
-    "complex64": tensorflow.complex64,
-    "complex128": tensorflow.complex128,
-    "bool": tensorflow.bool,
-}
-
-
-def tensorflow_infer_dtype(fn: Callable):
-    @functools.wraps(fn)
-    def _infer_dtype(*args, dtype=None, **kwargs):
-        arr = (
-            None
-            if tensorflow_exists_bknd(dtype)
-            else tensorflow__get_first_array(*args, **kwargs)
-        )
-        dtype = tensorflow_default_dtype_bknd(dtype=dtype, item=arr, as_native=True)
-        return fn(*args, dtype=dtype, **kwargs)
-
-    _infer_dtype.infer_dtype = True
-    return _infer_dtype
-
-
-def tensorflow_handle_array_like_without_promotion(fn: Callable):
-    @functools.wraps(fn)
-    def _handle_array_like_without_promotion(*args, **kwargs):
-        args = list(args)
-        num_args = len(args)
-        try:
-            type_hints = inspect.signature(fn).parameters
-        except (TypeError, ValueError):
-            return fn(*args, **kwargs)
-        parameters = list(type_hints.keys())
-        annotations = [param.annotation for param in type_hints.values()]
-        device = tensorflow__get_preferred_device(args, kwargs)
-        for i, (annotation, parameter, arg) in enumerate(
-            zip(annotations, parameters, args)
-        ):
-            annotation_str = str(annotation)
-            if (
-                ("rray" in annotation_str or "Tensor" in annotation_str)
-                and parameter != "out"
-                and all(
-                    sq not in annotation_str
-                    for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"]
-                )
-            ):
-                if i < num_args:
-                    if tensorflow__check_in_nested_sequence(
-                        arg, value=Ellipsis, _type=slice
-                    ):
-                        continue
-                    if not tensorflow_is_array_bknd(arg):
-                        args = tensorflow_set_item_bknd(
-                            args, i, tensorflow_asarray(arg, device=device)
-                        )
-                elif parameters in kwargs:
-                    kwarg = tensorflow_get_item(kwargs, parameter)
-                    if not tensorflow_is_array_bknd(kwarg):
-                        kwargs = tensorflow_set_item_bknd(
-                            kwargs, parameter, tensorflow_asarray(kwarg, device=device)
-                        )
-        return fn(*args, **kwargs)
-
-    _handle_array_like_without_promotion.handle_array_like_without_promotion = True
-    return _handle_array_like_without_promotion
-
-
-def tensorflow_exists_bknd(x: Any, /):
-    return x is not None
-
-
-def tensorflow_is_native_array(x, /, *, exclusive=False):
-    if "keras.src.backend.tensorflow.core.Variable" in str(x.__class__):
-        return not exclusive
-    if isinstance(x, (tensorflow.Tensor, tensorflow.Variable, tensorflow.TensorArray)):
-        if exclusive and isinstance(x, tensorflow.Variable):
-            return False
-        return True
-    return False
-
-
-def tensorflow_is_ivy_array_bknd(
-    x: Union[tensorflow.Tensor, tf.Tensor], /, *, exclusive: Optional[bool] = False
-):
-    return isinstance(x, tensorflow.Tensor) and tensorflow_is_native_array(
-        x, exclusive=exclusive
-    )
-
-
-def tensorflow_is_array_bknd(x: Any, /, *, exclusive: bool = False):
-    return tensorflow_is_ivy_array_bknd(
-        x, exclusive=exclusive
-    ) or
tensorflow_is_native_array(x, exclusive=exclusive) - - -def tensorflow_default_bknd( - x: Any, - /, - default_val: Any, - *, - catch_exceptions: bool = False, - rev: bool = False, - with_callable: bool = False, -): - with_callable = catch_exceptions or with_callable - if rev: - x, default_val = default_val, x - if with_callable: - x_callable = callable(x) - default_callable = callable(default_val) - else: - x_callable = False - default_callable = False - if catch_exceptions: - try: - x = x() if x_callable else x - except Exception: - return default_val() if default_callable else default_val - else: - x = x() if x_callable else x - return ( - x - if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val - ) - - -def tensorflow_nested_argwhere_bknd( - nest: Iterable, - fn: Callable, - check_nests: bool = False, - to_ignore: Optional[Union[type, Tuple[type]]] = None, - _index: Optional[List] = None, - _base: bool = True, - stop_after_n_found: Optional[int] = None, -): - to_ignore = tensorflow_default_bknd(to_ignore, ()) - _index = [] if _index is None else _index - if isinstance(nest, (tuple, list)) and not isinstance(nest, to_ignore): - n = 0 - _indices = [] - for i, item in enumerate(nest): - ind = ( - tensorflow_nested_argwhere_bknd( - item, - fn, - check_nests, - to_ignore, - _index + [i], - False, - stop_after_n_found - n, - ) - if stop_after_n_found is not None - else tensorflow_nested_argwhere_bknd( - item, fn, check_nests, to_ignore, _index + [i], False, None - ) - ) - if stop_after_n_found is not None and ind: - if n >= stop_after_n_found: - break - n = n + len(ind) - _indices = _indices + [ind] - if stop_after_n_found is not None and n >= stop_after_n_found: - break - _indices = [idx for idxs in _indices if idxs for idx in idxs] - if check_nests and fn(nest): - _indices.append(_index) - elif isinstance(nest, (dict, UserDict)) and not isinstance(nest, to_ignore): - n = 0 - _indices = [] - for k, v in nest.items(): - ind = ( - tensorflow_nested_argwhere_bknd( - v, - fn, - check_nests, - to_ignore, - _index + [k], - False, - stop_after_n_found - n, - ) - if stop_after_n_found is not None - else tensorflow_nested_argwhere_bknd( - v, fn, check_nests, to_ignore, _index + [k], False, None - ) - ) - if stop_after_n_found is not None and ind: - if n >= stop_after_n_found: - break - n = n + len(ind) - _indices = _indices + [ind] - _indices = [idx for idxs in _indices if idxs for idx in idxs] - if check_nests and fn(nest): - _indices.append(_index) - else: - cond_met = fn(nest) - if cond_met: - return [_index] - return False - return [index for index in _indices if index] - - -def tensorflow__check_float64_bknd(input): - if tensorflow_is_array_bknd(input): - return tensorflow_dtype(input) == "float64" - if math.isfinite(input): - m, e = math.frexp(input) - return abs(input) > 3.4028235e38 or e < -126 or e > 128 - return False - - -def tensorflow_as_ivy_dtype_bknd(dtype_in: Union[str, str], /): - return tensorflow_as_ivy_dtype(dtype_in) - - -def tensorflow_is_complex_dtype_bknd( - dtype_in: Union[str, str, tensorflow.Tensor, tf.Tensor, Number], / -): - if tensorflow_is_array_bknd(dtype_in): - dtype_in = tensorflow_dtype(dtype_in) - elif isinstance(dtype_in, tuple): - dtype_in = tensorflow_default_int_dtype_bknd() - elif isinstance(dtype_in, np.ndarray): - return "complex" in dtype_in.dtype.name - elif isinstance(dtype_in, Number): - return isinstance(dtype_in, (complex, np.complexfloating)) - elif isinstance(dtype_in, (list, tuple, dict)): - return 
tensorflow_nested_argwhere_bknd( - dtype_in, - lambda x: isinstance(x, (complex, np.complexfloating)) - or tensorflow_is_array_bknd(x) - and "complex" in tensorflow_dtype(x), - ) - return "complex" in tensorflow_as_ivy_dtype_bknd(dtype_in) - - -def tensorflow_as_native_dev(device: str, /): - if isinstance(device, str) and "/" in device: - return device - ret = f"/{str(device).upper()}" - if not ret[-1].isnumeric(): - ret += ":0" - return ret - - -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - -@tensorflow_handle_methods -def tensorflow_split( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - copy: Optional[bool] = None, - num_or_size_splits: Optional[ - Union[int, Sequence[int], Union[tensorflow.Tensor, tensorflow.Variable]] - ] = None, - axis: int = 0, - with_remainder: bool = False, -): - if x.shape == (): - if num_or_size_splits is not None and num_or_size_splits != 1: - raise Exception( - f"input array had no shape, but num_sections specified was {num_or_size_splits}" - ) - return [x] - if num_or_size_splits is None: - dim_size = tensorflow.shape(x)[axis] - num_or_size_splits = int(dim_size) - if isinstance(num_or_size_splits, (tensorflow.Tensor, tensorflow.Variable)): - num_or_size_splits = tensorflow.cast(num_or_size_splits, tensorflow.int32) - elif isinstance(num_or_size_splits, int) and with_remainder: - num_chunks = x.shape[axis] / num_or_size_splits - num_chunks_int = math.floor(num_chunks) - remainder = num_chunks - num_chunks_int - if remainder != 0: - num_or_size_splits = [num_or_size_splits] * num_chunks_int + [ - int(remainder * num_or_size_splits) - ] - return tensorflow.split(x, num_or_size_splits, axis) - - -@tensorflow_handle_methods -def tensorflow_split_bknd_( - self: tensorflow.Tensor, - /, - *, - copy: Optional[bool] = None, - num_or_size_splits: Optional[ - Union[int, Sequence[int], tensorflow.Tensor, tf.Tensor] - ] = None, - axis: int = 0, - with_remainder: bool = False, -): - return tensorflow_split( - self, - copy=copy, - num_or_size_splits=num_or_size_splits, - axis=axis, - with_remainder=with_remainder, - ) - - -def tensorflow_as_ivy_dev(device: str, /): - if isinstance(device, str) and "/" not in device: - return str(device) - dev_in_split = tensorflow_split_bknd_(device[1:], ":")[-2:] - if len(dev_in_split) == 1: - return str(dev_in_split[0]) - dev_type, dev_idx = dev_in_split[0], dev_in_split[1] - dev_type = dev_type.lower() - if dev_type == "cpu": - return str(dev_type) - return str(f"{dev_type}:{dev_idx}") - - -def tensorflow_stack( - arrays: Union[Tuple[tensorflow.Tensor], List[tensorflow.Tensor]], - /, - *, - axis: int = 0, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - try: - return tensorflow.experimental.numpy.stack(arrays, axis) - except ValueError as e: - raise Exception(e) from e - - -def tensorflow_stack_bknd_( - self: tensorflow.Tensor, - /, - arrays: Union[ - Tuple[Union[tensorflow.Tensor, tf.Tensor]], - List[Union[tensorflow.Tensor, tf.Tensor]], - ], - *, - axis: int = 0, - out: Optional[tensorflow.Tensor] = None, -): - if not isinstance(arrays, (tuple, list)): - arrays 
= [arrays] - if isinstance(arrays, tuple): - x = (self,) + arrays - else: - x = [self] + arrays - return tensorflow_stack(x, axis=axis, out=out) - - -def tensorflow_dev( - x: Union[tensorflow.Tensor, tensorflow.Variable, tensorflow.TensorArray], - /, - *, - as_native: bool = False, -): - if "keras.src.backend.tensorflow.core.Variable" in str(x.__class__): - x = x.value - if isinstance(x, tensorflow.TensorArray): - x = tensorflow_stack_bknd_(x) - dv = x.device - if as_native: - return dv - dv = dv if dv else tensorflow_default_device_bknd(as_native=False) - return tensorflow_as_ivy_dev(dv) - - -def tensorflow_default_device_bknd( - device: Optional[Union[str, str]] = None, - /, - *, - item: Optional[Union[list, tuple, dict, tensorflow.Tensor, tf.Tensor]] = None, - as_native: Optional[bool] = None, -): - if tensorflow_exists_bknd(device): - if as_native is True: - return tensorflow_as_native_dev(device) - elif as_native is False: - return tensorflow_as_ivy_dev(device) - return device - as_native = tensorflow_default_bknd(as_native, False) - if tensorflow_exists_bknd(item): - if isinstance(item, (list, tuple, dict)) and len(item) == 0: - pass - elif tensorflow_is_array_bknd(item): - return tensorflow_dev(item, as_native=as_native) - global default_device_stack - if not default_device_stack: - ret = "cpu" - else: - ret = default_device_stack[-1] - if as_native: - return tensorflow_as_native_dev(ret) - return tensorflow_as_ivy_dev(ret) - - -def tensorflow__get_preferred_device(args, kwargs): - device = None - if "device" in kwargs and kwargs["device"] is not None: - return device - if not False: - arr_arg = tensorflow__get_first_array(*args, **kwargs) - return tensorflow_default_device_bknd(item=arr_arg, as_native=True) - return tensorflow_default_device_bknd(as_native=True) - - -def tensorflow__check_in_nested_sequence(sequence, value=None, _type=None): - if sequence is value or isinstance(sequence, _type): - return True - elif isinstance(sequence, (tuple, list)): - if any(isinstance(_val, _type) or _val is value for _val in sequence): - return True - else: - return any( - tensorflow__check_in_nested_sequence(sub_sequence, value, _type) - for sub_sequence in sequence - if isinstance(sub_sequence, (tuple, list)) - ) - - -def tensorflow_is_variable(x, /, *, exclusive=False): - return isinstance(x, tensorflow.Variable) - - -def tensorflow_variable(x, /): - with tensorflow.device(tensorflow_dev(x, as_native=True)): - return tensorflow.Variable(x, trainable=True) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_stop_gradient( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - preserve_type: bool = True, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - is_var = tensorflow_is_variable(x) - x = tensorflow.stop_gradient(x) - if is_var and preserve_type: - return tensorflow_variable(x) - return x - - -def tensorflow_nested_map_bknd( - fn: Callable, - x: Union[tensorflow.Tensor, tf.Tensor, Iterable], - /, - include_derived: Optional[Union[Dict[str, bool], bool]] = None, - to_ignore: Optional[Union[type, Tuple[type]]] = None, - to_mutable: bool = False, - _tuple_check_fn: Optional[Callable] = None, - _list_check_fn: Optional[Callable] = None, - _dict_check_fn: Optional[Callable] = None, - shallow: bool = True, -): - to_ignore = tensorflow_default_bknd(to_ignore, ()) - if include_derived is True: - include_derived = {"tuple": True, "list": True, "dict": True} - elif not include_derived: - include_derived = {} - for t in ("tuple", "list", "dict"): 
- if t not in include_derived: - include_derived = tensorflow_set_item_bknd(include_derived, t, False) - class_instance = type(x) - if ( - hasattr(x, "is_tracked_proxy") - and hasattr(class_instance, "__bases__") - and not set(class_instance.__bases__).intersection(set(to_ignore)) - ): - to_ignore = to_ignore + (class_instance,) - tuple_check_fn = tensorflow_default_bknd( - _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), - ) - list_check_fn = tensorflow_default_bknd( - _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), - ) - dict_check_fn = tensorflow_default_bknd( - _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), - ) - if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): - ret_list = [ - tensorflow_nested_map_bknd( - fn, - i, - include_derived, - to_ignore, - to_mutable, - tuple_check_fn, - list_check_fn, - dict_check_fn, - shallow, - ) - for i in x - ] - if to_mutable: - return ret_list - elif hasattr(x, "_fields"): - return class_instance(**dict(zip(x._fields, ret_list))) - else: - return class_instance(ret_list) - elif list_check_fn(x, list) and not isinstance(x, to_ignore): - ret_list = [ - tensorflow_nested_map_bknd( - fn, - i, - include_derived, - to_ignore, - to_mutable, - tuple_check_fn, - list_check_fn, - dict_check_fn, - shallow, - ) - for i in x - ] - if shallow: - x = tensorflow_set_item_bknd(x, slice(None, None, None), ret_list[:]) - return x - return class_instance(ret_list) - elif (dict_check_fn(x, dict) or isinstance(x, UserDict)) and not isinstance( - x, to_ignore - ): - class_instance = type(x) - ret = { - k: tensorflow_nested_map_bknd( - fn, - v, - include_derived, - to_ignore, - to_mutable, - tuple_check_fn, - list_check_fn, - dict_check_fn, - shallow, - ) - for k, v in x.items() - } - if shallow: - x.update(ret) - return x - return class_instance(ret) - elif isinstance(x, slice): - return slice(*tensorflow_nested_map_bknd(fn, [x.start, x.stop, x.step])) - return fn(x) - - -def tensorflow__to_ivy_bknd_(x: Any): - if isinstance(x, tensorflow.Tensor): - return x - elif isinstance(x, tf.TensorShape): - return tuple(x) - elif isinstance(x, dict): - return x.to_ivy() - if tensorflow_is_native_array(x) or isinstance(x, np.ndarray): - return tensorflow.convert_to_tensor(x) - return x - - -def tensorflow_to_ivy_bknd_( - x: Union[tensorflow.Tensor, tf.Tensor, Iterable], - nested: bool = False, - include_derived: Optional[Dict[str, bool]] = None, -): - if nested: - return tensorflow_nested_map_bknd( - tensorflow__to_ivy_bknd_, x, include_derived, shallow=False - ) - return tensorflow__to_ivy_bknd_(x) - - -def tensorflow__asarray_to_native_arrays_and_back_bknd(fn: Callable): - @functools.wraps(fn) - def _asarray_to_native_arrays_and_back_wrapper(*args, dtype=None, **kwargs): - new_arg = args[0] - new_args = (new_arg,) + args[1:] - if dtype is not None: - dtype = tensorflow_default_dtype_bknd(dtype=dtype, as_native=True) - return tensorflow_to_ivy_bknd_(fn(*new_args, dtype=dtype, **kwargs)) - - _asarray_to_native_arrays_and_back_wrapper._asarray_to_native_arrays_and_back = True - return _asarray_to_native_arrays_and_back_wrapper - - -def tensorflow__flatten_nest_bknd(xs): - for x in xs: - if isinstance(x, Iterable) and not isinstance(x, (str, bytes)): - yield from tensorflow__flatten_nest_bknd(x) - else: - yield x - - -def 
tensorflow_promote_types_bknd( - type1: Union[str, tf.DType], - type2: Union[str, tf.DType], - /, - *, - array_api_promotion: bool = False, -): - if not (type1 and type2): - return type1 if type1 else type2 - query = [tensorflow_as_ivy_dtype(type1), tensorflow_as_ivy_dtype(type2)] - query = tuple(query) - if query not in promotion_table: - query = query[1], query[0] - - def _promote(query): - if array_api_promotion: - return tensorflow_get_item(array_api_promotion_table, query) - return tensorflow_get_item(promotion_table, query) - - return _promote(query) - - -def tensorflow__asarray_infer_dtype_bknd(fn: Callable): - @functools.wraps(fn) - def _asarray_infer_dtype_wrapper(*args, dtype=None, **kwargs): - def _infer_dtype(obj): - if isinstance(obj, tf.TensorShape): - obj = list(obj) - if hasattr(obj, "dtype"): - return obj.dtype.name if isinstance(obj, np.ndarray) else obj.dtype - else: - return tensorflow_default_dtype_bknd(item=obj) - - if not tensorflow_exists_bknd(dtype): - arr = args[0] - dtype_list = [ - tensorflow_nested_map_bknd( - lambda x: _infer_dtype(x), arr, shallow=False - ) - ] - dtype_list = tensorflow__flatten_nest_bknd(dtype_list) - dtype_list = list(set(dtype_list)) - if len(dtype_list) != 0: - dtype = dtype_list[0] - for dt in dtype_list[1:]: - dtype = tensorflow_promote_types_bknd(dtype, dt) - else: - dtype = tensorflow_default_float_dtype_bknd() - dtype = tensorflow_as_native_dtype(dtype) - return fn(*args, dtype=dtype, **kwargs) - - _asarray_infer_dtype_wrapper.infer_dtype = True - return _asarray_infer_dtype_wrapper - - -@tensorflow_handle_array_like_without_promotion -@tensorflow__asarray_to_native_arrays_and_back_bknd -@tensorflow__asarray_infer_dtype_bknd -def tensorflow_asarray( - obj: Union[ - tensorflow.Tensor, - tensorflow.Variable, - tensorflow.TensorShape, - bool, - int, - float, - tensorflow_NestedSequence_bknd, - SupportsBufferProtocol, - np.ndarray, - ], - /, - *, - copy: Optional[bool] = None, - dtype: Optional[tensorflow.DType] = None, - device: Optional[str] = None, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - with tensorflow.device(device): - if tensorflow.is_tensor(obj): - ret = tensorflow.cast(obj, dtype) if obj.dtype != dtype else obj - elif ( - dtype is not None - and dtype.is_integer - and np.issubdtype(np.array(obj).dtype, np.floating) - ): - obj_np = np.array(obj) - ret = tensorflow.convert_to_tensor(obj_np, dtype) - else: - ret = tensorflow.convert_to_tensor(obj, dtype) - return ( - tensorflow.identity(ret) - if copy or tensorflow_as_native_dev(tensorflow_dev(ret)) != device - else ret - ) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_size(x: tensorflow.Tensor, /): - return functools.reduce(mul, x.shape) if len(x.shape) > 0 else 1 - - -def tensorflow_size_bknd_(self): - return tensorflow_size(self) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_unstack( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - copy: Optional[bool] = None, - axis: int = 0, - keepdims: bool = False, -): - if x.shape == (): - return [x] - ret = tensorflow.unstack(x, axis=axis) - if keepdims: - return [tensorflow.expand_dims(r, axis) for r in ret] - return ret - - -def tensorflow_unstack_bknd_( - self: tensorflow.Tensor, - /, - *, - copy: Optional[bool] = None, - axis: int = 0, - keepdims: bool = False, -): - return tensorflow_unstack(self, copy=copy, axis=axis, keepdims=keepdims) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_copy_array( - x: 
Union[tensorflow.Tensor, tensorflow.Variable, tensorflow.TensorArray], - *, - to_ivy_array: bool = True, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if isinstance(x, tensorflow.TensorArray): - x_wrapped = tensorflow_stack_bknd_(x) - y = tensorflow.TensorArray(x.dtype, tensorflow_size_bknd_(x)()) - x = tensorflow_unstack_bknd_(y, tensorflow_copy_array(x_wrapped)) - else: - x = tensorflow.identity(x) - if to_ivy_array: - return tensorflow_to_ivy_bknd_(x) - return x - - -def tensorflow_tile( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - repeats: Sequence[int], - *, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if x.shape == (): - x = tensorflow.reshape(x, (-1,)) - if isinstance(repeats, Number): - repeats = [repeats] - if isinstance(repeats, tensorflow.Tensor) and repeats.shape == (): - repeats = tensorflow.reshape(repeats, (-1,)) - if len(x.shape) < len(repeats): - while len(x.shape) != len(repeats): - x = tensorflow.expand_dims(x, 0) - elif len(x.shape) > len(repeats): - repeats = list(repeats) - while len(x.shape) != len(repeats): - repeats = [1] + repeats - return tensorflow.tile(x, repeats) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_nonzero( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - as_tuple: bool = True, - size: Optional[int] = None, - fill_value: Number = 0, -): - res = tensorflow.experimental.numpy.nonzero(x) - if size is not None: - dtype = tensorflow.int64 - if isinstance(fill_value, float): - dtype = tensorflow.float64 - res = tensorflow.cast(res, dtype) - diff = size - res[0].shape[0] - if diff > 0: - res = tensorflow.pad(res, [[0, 0], [0, diff]], constant_values=fill_value) - elif diff < 0: - res = tensorflow.slice(res, [0, 0], [-1, size]) - if as_tuple: - return tuple(res) - return tensorflow.stack(res, axis=1) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_diff( - x: Union[tensorflow.Tensor, tensorflow.Variable, list, tuple], - /, - *, - n: int = 1, - axis: int = -1, - prepend: Optional[ - Union[tensorflow.Tensor, tensorflow.Variable, int, float, list, tuple] - ] = None, - append: Optional[ - Union[tensorflow.Tensor, tensorflow.Variable, int, float, list, tuple] - ] = None, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if n == 0: - return x - if prepend is not None: - x = tensorflow.experimental.numpy.append( - prepend, x, axis=axis if axis != -1 else None - ) - if append is not None: - x = tensorflow.experimental.numpy.append( - x, append, axis=axis if axis != -1 else None - ) - return tensorflow.experimental.numpy.diff(x, n=n, axis=axis) - - -def tensorflow__parse_ellipsis_bknd(so, ndims): - pre = list() - for s in so: - if s is Ellipsis: - break - pre.append(s) - post = list() - for s in reversed(so): - if s is Ellipsis: - break - post.append(s) - ret = list( - pre - + [slice(None, None, None) for _ in range(ndims - len(pre) - len(post))] - + list(reversed(post)) - ) - return ret, (len(pre), ndims - len(post)) - - -def tensorflow_broadcast_arrays(*arrays: Union[tensorflow.Tensor, tensorflow.Variable]): - if len(arrays) > 1: - try: - desired_shape = tensorflow.broadcast_dynamic_shape( - tensorflow.shape(arrays[0]), tensorflow.shape(arrays[1]) - ) - except tensorflow.errors.InvalidArgumentError as e: - raise Exception(e) from e - if len(arrays) > 2: - for i in range(2, len(arrays)): - try: - desired_shape = tensorflow.broadcast_dynamic_shape( - desired_shape, tensorflow.shape(arrays[i]) - ) - except 
tensorflow.errors.InvalidArgumentError as e: - raise Exception(e) from e - else: - return [arrays[0]] - result = [] - for tensor in arrays: - result.append(tensorflow.broadcast_to(tensor, desired_shape)) - return result - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_astype( - x: Union[tensorflow.Tensor, tensorflow.Variable], - dtype: Union[tf.DType, str], - /, - *, - copy: bool = True, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - dtype = tensorflow_as_native_dtype(dtype) - if x.dtype == dtype: - return tensorflow.experimental.numpy.copy(x) if copy else x - return tensorflow.cast(x, dtype) - - -def tensorflow_astype_bknd_( - self: tensorflow.Tensor, - dtype: str, - /, - *, - copy: bool = True, - out: Optional[tensorflow.Tensor] = None, -): - return tensorflow_astype(self, dtype, copy=copy, out=out) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_where( - condition: Union[tensorflow.Tensor, tensorflow.Variable], - x1: Union[tensorflow.Tensor, tensorflow.Variable], - x2: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - oirg_x1 = x1 - oirg_x2 = x2 - try: - dtype = ( - x1.dtype - if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() - ) - if not tensorflow_is_array_bknd(x1): - x1 = tensorflow_asarray(x1, dtype=dtype) - if not tensorflow_is_array_bknd(x2): - x2 = tensorflow_asarray(x2, dtype=dtype) - except: - x1 = oirg_x1 - x2 = oirg_x2 - return tensorflow.cast( - tensorflow.experimental.numpy.where(condition, x1, x2), x1.dtype - ) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_arange( - start: float, - /, - stop: Optional[float] = None, - step: float = 1, - *, - dtype: Optional[tensorflow.DType] = None, - device: Optional[str] = None, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if stop is None: - stop = start - start = 0 - if step > 0 and start > stop or step < 0 and start < stop: - if isinstance(stop, float): - stop = float(start) - else: - stop = start - if isinstance(start, (float, int)): - start = tensorflow.convert_to_tensor(start) - if isinstance(stop, (float, int)): - stop = tensorflow.convert_to_tensor(stop) - if isinstance(step, (float, int)): - step = tensorflow.convert_to_tensor(step) - if dtype is None: - if isinstance(start, int) and isinstance(stop, int) and isinstance(step, int): - return tensorflow.cast( - tensorflow.range(start, stop, delta=step, dtype=tensorflow.int64), - tensorflow.int32, - ) - else: - return tensorflow.range(start, stop, delta=step) - else: - dtype = tensorflow_as_native_dtype(tensorflow_default_dtype_bknd(dtype=dtype)) - if dtype in [ - tensorflow.int8, - tensorflow.uint8, - tensorflow.int16, - tensorflow.uint16, - tensorflow.uint32, - tensorflow.uint64, - ]: - return tensorflow.cast( - tensorflow.range(start, stop, delta=step, dtype=tensorflow.int64), dtype - ) - else: - return tensorflow.range(start, stop, delta=step, dtype=dtype) - - -def tensorflow__parse_slice_bknd(idx, s): - step = 1 if idx.step is None else idx.step - if step > 0: - start = 0 if idx.start is None else idx.start - if start >= s: - stop = start - else: - if start <= -s: - start = 0 - elif start < 0: - start = start + s - stop = s if idx.stop is None else idx.stop - if stop > s: - stop = s - elif start <= -s: - stop = 0 - elif stop < 0: - stop = stop + s - else: - start = s - 1 if idx.start is None else idx.start - if start < 
-s: - stop = start - else: - if start >= s: - start = s - 1 - elif start < 0: - start = start + s - if idx.stop is None: - stop = -1 - else: - stop = idx.stop - if stop > s: - stop = s - elif stop < -s: - stop = -1 - elif stop == -s: - stop = 0 - elif stop < 0: - stop = stop + s - q_i = tensorflow_arange(start, stop, step) - ag__result_list_0 = [] - for q in q_i: - if 0 <= q < s: - res = q - ag__result_list_0.append(res) - q_i = ag__result_list_0 - q_i = ( - tensorflow_asarray(q_i) - if len(q_i) or start == stop or idx.stop is not None - else tensorflow_arange(0, s, 1) - ) - return q_i - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_shape( - x: Union[tensorflow.Tensor, tensorflow.Variable], /, *, as_array: bool = False -): - if as_array: - return tensorflow_asarray( - tensorflow.shape(x), dtype=tensorflow_default_int_dtype_bknd() - ) - else: - return tuple(x.shape) - - -def tensorflow__deep_flatten_bknd(iterable): - def _flatten_gen(iterable): - for item in iterable: - if isinstance(item, list): - yield from _flatten_gen(item) - else: - yield item - - return list(_flatten_gen(iterable)) - - -def tensorflow__calculate_out_shape_bknd(axis, array_shape): - if type(axis) not in (tuple, list): - axis = (axis,) - out_dims = len(axis) + len(array_shape) - norm_axis = normalize_axis_tuple(axis, out_dims) - shape_iter = iter(array_shape) - ag__result_list_0 = [] - for current_ax in range(out_dims): - res = 1 if current_ax in norm_axis else next(shape_iter) - ag__result_list_0.append(res) - out_shape = ag__result_list_0 - return out_shape - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_expand_dims( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - copy: Optional[bool] = None, - axis: Union[int, Sequence[int]] = 0, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - try: - out_shape = tensorflow__calculate_out_shape_bknd(axis, tensorflow.shape(x)) - ret = tensorflow.reshape(x, shape=out_shape) - return ret - except (tensorflow.errors.InvalidArgumentError, np.AxisError) as error: - raise Exception(error) from error - - -def tensorflow_check_elem_in_list(elem, list, inverse=False, message=""): - if inverse and elem in list: - raise Exception( - message if message != "" else f"{elem} must not be one of {list}" - ) - elif not inverse and elem not in list: - raise Exception(message if message != "" else f"{elem} must be one of {list}") - - -def tensorflow__reshape_fortran_tf(x, shape): - if len(x.shape) > 0: - x = tensorflow.transpose(x) - return tensorflow.transpose(tensorflow.reshape(x, shape[::-1])) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_reshape( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - shape: Union[tf.TensorShape, Sequence[int]], - *, - copy: Optional[bool] = None, - order: str = "C", - allowzero: bool = True, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - tensorflow_check_elem_in_list(order, ["C", "F"]) - if not allowzero: - shape = [ - (new_s if con else old_s) - for new_s, con, old_s in zip( - shape, tensorflow.constant(shape) != 0, x.shape - ) - ] - if order == "F": - return tensorflow__reshape_fortran_tf(x, shape) - return tensorflow.reshape(x, shape) - - -def tensorflow_reshape_bknd_( - self: tensorflow.Tensor, - /, - shape: Union[tuple, tf.TensorShape, Sequence[int]], - *, - copy: Optional[bool] = None, - order: str = "C", - allowzero: bool = True, - out: Optional[tensorflow.Tensor] = None, -): - return tensorflow_reshape( - self, shape, 
copy=copy, allowzero=allowzero, out=out, order=order - ) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_meshgrid( - *arrays: Union[tensorflow.Tensor, tensorflow.Variable], - sparse: bool = False, - indexing: str = "xy", - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if not sparse: - return tensorflow.meshgrid(*arrays, indexing=indexing) - sd = (1,) * len(arrays) - ag__result_list_0 = [] - for i, a in enumerate(arrays): - res = tensorflow.reshape( - tensorflow.convert_to_tensor(a), sd[:i] + (-1,) + sd[i + 1 :] - ) - ag__result_list_0.append(res) - res = ag__result_list_0 - if indexing == "xy" and len(arrays) > 1: - res[0] = tensorflow.reshape(res[0], (1, -1) + sd[2:]) - res[1] = tensorflow.reshape(res[1], (-1, 1) + sd[2:]) - return res - - -@tensorflow_infer_dtype -@tensorflow_handle_array_like_without_promotion -def tensorflow_empty( - shape: Union[tf.TensorShape, Sequence[int]], - *, - dtype: tensorflow.DType, - device: Optional[str] = None, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - return tensorflow.experimental.numpy.empty(shape, dtype=tensorflow.float32) - - -def tensorflow__parse_query_bknd(query, x_shape, scatter=False): - query = (query,) if not isinstance(query, tuple) else query - ag__result_list_0 = [] - for q in query: - res = tensorflow_asarray(q) if isinstance(q, (tuple, list, int)) else q - ag__result_list_0.append(res) - query = ag__result_list_0 - ag__result_list_1 = [] - for i, q in enumerate(query): - if tensorflow_is_array_bknd(q): - res = i - ag__result_list_1.append(res) - non_slice_q_idxs = ag__result_list_1 - to_front = ( - len(non_slice_q_idxs) > 1 - and any(tensorflow_diff(non_slice_q_idxs) != 1) - and non_slice_q_idxs[-1] < len(x_shape) - ) - ag__result_list_2 = [] - for i, q in enumerate(query): - if q is None: - res = i - ag__result_list_2.append(res) - new_axes = ag__result_list_2 - ag__result_list_3 = [] - for q in query: - if q is not None: - res = q - ag__result_list_3.append(res) - query = ag__result_list_3 - query = [Ellipsis] if query == [] else query - ellipsis_inds = None - if any(q is Ellipsis for q in query): - query, ellipsis_inds = tensorflow__parse_ellipsis_bknd(query, len(x_shape)) - ag__result_list_4 = [] - for i, v in enumerate(query): - if tensorflow_is_array_bknd(v): - res = i - ag__result_list_4.append(res) - array_inds = ag__result_list_4 - if array_inds: - array_queries = tensorflow_broadcast_arrays( - *[v for i, v in enumerate(query) if i in array_inds] - ) - array_queries = [ - ( - tensorflow_nonzero(q, as_tuple=False)[0] - if tensorflow_is_bool_dtype_bknd(q) - else q - ) - for q in array_queries - ] - array_queries = [ - ( - tensorflow_astype_bknd_( - tensorflow_where( - arr < 0, arr + tensorflow_get_item(x_shape, i), arr - ), - tf.int64, - ) - if tensorflow_size_bknd_(arr) - else tensorflow_astype_bknd_(arr, tf.int64) - ) - for arr, i in zip(array_queries, array_inds) - ] - for idx, arr in zip(array_inds, array_queries): - query = tensorflow_set_item_bknd(query, idx, arr) - ag__result_list_5 = [] - for i, q in enumerate(query): - res = ( - tensorflow_astype_bknd_( - tensorflow__parse_slice_bknd(q, tensorflow_get_item(x_shape, i)), - tf.int64, - ) - if isinstance(q, slice) - else q - ) - ag__result_list_5.append(res) - query = ag__result_list_5 - if len(query) < len(x_shape): - query = query + [ - tensorflow_astype_bknd_(tensorflow_arange(0, s, 1), tf.int64) - for s in tensorflow_get_item(x_shape, slice(len(query), None, None)) - ] - if len(array_inds) 
and to_front: - target_shape = ( - [list(array_queries[0].shape)] - + [ - list(tensorflow_get_item(query, i).shape) - for i in range(len(query)) - if i not in array_inds - ] - + [[] for _ in range(len(array_inds) - 1)] - ) - elif len(array_inds): - target_shape = ( - [list(tensorflow_get_item(query, i).shape) for i in range(0, array_inds[0])] - + [list(tensorflow_shape(array_queries[0], as_array=True))] - + [[] for _ in range(len(array_inds) - 1)] - + [ - list(tensorflow_shape(tensorflow_get_item(query, i), as_array=True)) - for i in range(array_inds[-1] + 1, len(query)) - ] - ) - else: - target_shape = [list(q.shape) for q in query] - if ellipsis_inds is not None: - target_shape = ( - tensorflow_get_item(target_shape, slice(None, ellipsis_inds[0], None)) - + [ - tensorflow_get_item( - target_shape, slice(ellipsis_inds[0], ellipsis_inds[1], None) - ) - ] - + tensorflow_get_item(target_shape, slice(ellipsis_inds[1], None, None)) - ) - for i, ax in enumerate(new_axes): - if len(array_inds) and to_front: - ax = ax - (sum(1 for x in array_inds if x < ax) - 1) - ax = ax + i - target_shape = [ - *tensorflow_get_item(target_shape, slice(None, ax, None)), - 1, - *tensorflow_get_item(target_shape, slice(ax, None, None)), - ] - target_shape = tensorflow__deep_flatten_bknd(target_shape) - ag__result_list_6 = [] - for q in query: - res = tensorflow_expand_dims(q) if not len(q.shape) else q - ag__result_list_6.append(res) - query = ag__result_list_6 - if len(array_inds): - array_queries = [ - ( - tensorflow_reshape_bknd_(arr, (-1,)) - if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr - ) - for arr in array_queries - ] - array_queries = tensorflow_stack(array_queries, axis=1) - if len(array_inds) == len(query): - indices = tensorflow_reshape_bknd_(array_queries, (*target_shape, len(x_shape))) - elif len(array_inds) == 0: - indices = tensorflow_reshape_bknd_( - tensorflow_stack(tensorflow_meshgrid(*query, indexing="ij"), axis=-1), - (*target_shape, len(x_shape)), - ) - elif to_front: - post_array_queries = ( - tensorflow_reshape_bknd_( - tensorflow_stack( - tensorflow_meshgrid( - *[v for i, v in enumerate(query) if i not in array_inds], - indexing="ij", - ), - axis=-1, - ), - (-1, len(query) - len(array_inds)), - ) - if len(array_inds) < len(query) - else tensorflow_empty((1, 0)) - ) - indices = tensorflow_reshape_bknd_( - tensorflow_asarray( - [ - (*arr, *post) - for arr, post in itertools.product( - array_queries, post_array_queries - ) - ] - ), - (*target_shape, len(x_shape)), - ) - else: - pre_array_queries = ( - tensorflow_reshape_bknd_( - tensorflow_stack( - tensorflow_meshgrid( - *[v for i, v in enumerate(query) if i < array_inds[0]], - indexing="ij", - ), - axis=-1, - ), - (-1, array_inds[0]), - ) - if array_inds[0] > 0 - else tensorflow_empty((1, 0)) - ) - post_array_queries = ( - tensorflow_reshape_bknd_( - tensorflow_stack( - tensorflow_meshgrid( - *[v for i, v in enumerate(query) if i > array_inds[-1]], - indexing="ij", - ), - axis=-1, - ), - (-1, len(query) - 1 - array_inds[-1]), - ) - if array_inds[-1] < len(query) - 1 - else tensorflow_empty((1, 0)) - ) - indices = tensorflow_reshape_bknd_( - tensorflow_asarray( - [ - (*pre, *arr, *post) - for pre, arr, post in itertools.product( - pre_array_queries, array_queries, post_array_queries - ) - ] - ), - (*target_shape, len(x_shape)), - ) - return ( - tensorflow_astype_bknd_(indices, tf.int64), - target_shape, - array_inds if len(array_inds) and to_front else None, - ) - - -def tensorflow_get_num_dims(x, /, 
*, as_array=False): - return ( - tensorflow.cast(tensorflow.shape(tensorflow.shape(x))[0], tensorflow.int64) - if as_array - else int(tensorflow.shape(tensorflow.shape(x))) - ) - - -def tensorflow_to_numpy( - x: Union[tensorflow.Tensor, tensorflow.Variable], /, *, copy: bool = True -): - if ( - tensorflow_is_array_bknd(x) - and tensorflow_get_num_dims(x) == 0 - and tensorflow_as_native_dtype(x.dtype) is tensorflow.bfloat16 - ): - x = tensorflow.expand_dims(x, 0) - if copy: - return np.squeeze(np.array(tensorflow.convert_to_tensor(x)), 0) - else: - return np.squeeze(np.asarray(tensorflow.convert_to_tensor(x)), 0) - if copy: - return np.array(tensorflow.convert_to_tensor(x)) - else: - return np.asarray(tensorflow.convert_to_tensor(x)) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_to_scalar(x: Union[tensorflow.Tensor, tensorflow.Variable], /): - ret = tensorflow_to_numpy(x).item() - if x.dtype == tensorflow.bfloat16: - return float(ret) - return ret - - -def tensorflow_to_scalar_bknd_(self: tensorflow.Tensor): - return tensorflow_to_scalar(self) - - -def tensorflow_is_float_dtype_bknd( - dtype_in: Union[str, str, tensorflow.Tensor, tf.Tensor, Number], / -): - if tensorflow_is_array_bknd(dtype_in): - dtype_in = tensorflow_dtype(dtype_in) - elif isinstance(dtype_in, tuple): - dtype_in = tensorflow_default_int_dtype_bknd() - elif isinstance(dtype_in, np.ndarray): - return "float" in dtype_in.dtype.name - elif isinstance(dtype_in, Number): - return isinstance(dtype_in, (float, np.floating)) - elif isinstance(dtype_in, (list, tuple, dict)): - return bool( - tensorflow_nested_argwhere_bknd( - dtype_in, - lambda x: isinstance(x, (float, np.floating)) - or tensorflow_is_array_bknd(x) - and "float" in tensorflow_dtype(x), - ) - ) - return "float" in tensorflow_as_ivy_dtype_bknd(dtype_in) - - -def tensorflow_is_uint_dtype_bknd( - dtype_in: Union[str, str, tensorflow.Tensor, tf.Tensor, Number], / -): - if tensorflow_is_array_bknd(dtype_in): - dtype_in = tensorflow_dtype(dtype_in) - elif isinstance(dtype_in, tuple): - dtype_in = tensorflow_default_int_dtype_bknd() - elif isinstance(dtype_in, np.ndarray): - return "uint" in dtype_in.dtype.name - elif isinstance(dtype_in, Number): - return isinstance(dtype_in, np.unsignedinteger) - elif isinstance(dtype_in, (list, tuple, dict)): - return tensorflow_nested_argwhere_bknd( - dtype_in, - lambda x: isinstance(x, np.unsignedinteger) - or tensorflow_is_array_bknd(x) - and "uint" in tensorflow_dtype(x), - ) - return "uint" in tensorflow_as_ivy_dtype_bknd(dtype_in) - - -def tensorflow_default_uint_dtype_bknd( - *, - input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, - uint_dtype: Optional[Union[str, tf.DType]] = None, - as_native: bool = False, -): - global default_uint_dtype_stack - if tensorflow_exists_bknd(uint_dtype): - if as_native is True: - return tensorflow_as_native_dtype(uint_dtype) - return str(tensorflow_as_ivy_dtype(uint_dtype)) - as_native = tensorflow_default_bknd(as_native, False) - if tensorflow_exists_bknd(input): - if tensorflow_is_array_bknd(input): - ret = tensorflow_dtype(input) - elif isinstance(input, np.ndarray): - ret = input.dtype - elif isinstance(input, (list, tuple, dict)): - - def is_native(x): - return tensorflow_is_native_array(x) - - if tensorflow_nested_argwhere_bknd( - input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), - stop_after_n_found=1, - ): - ret = tf.uint64 - elif default_uint_dtype_stack: - ret = 
default_uint_dtype_stack[-1] - else: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_uint_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "uint32" - elif isinstance(input, Number): - if input > 4294967295 and input != math.inf and backend != "torch": - ret = tf.uint64 - elif default_uint_dtype_stack: - ret = default_uint_dtype_stack[-1] - else: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_uint_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "uint32" - elif default_uint_dtype_stack: - ret = default_uint_dtype_stack[-1] - else: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_uint_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "uint32" - if as_native: - return tensorflow_as_native_dtype(ret) - return str(tensorflow_as_ivy_dtype(ret)) - - -def tensorflow_is_int_dtype_bknd( - dtype_in: Union[str, str, tensorflow.Tensor, tf.Tensor, Number], / -): - if tensorflow_is_array_bknd(dtype_in): - dtype_in = tensorflow_dtype(dtype_in) - elif isinstance(dtype_in, tuple): - dtype_in = tensorflow_default_int_dtype_bknd() - elif isinstance(dtype_in, np.ndarray): - return "int" in dtype_in.dtype.name - elif isinstance(dtype_in, Number): - return isinstance(dtype_in, (int, np.integer)) and not isinstance( - dtype_in, bool - ) - elif isinstance(dtype_in, (list, tuple, dict)): - - def nested_fun(x): - return ( - isinstance(x, (int, np.integer)) - or tensorflow_is_array_bknd(x) - and "int" in tensorflow_dtype(x) - ) and x is not bool - - return bool(tensorflow_nested_argwhere_bknd(dtype_in, nested_fun)) - return "int" in tensorflow_as_ivy_dtype(dtype_in) - - -def tensorflow_infer_default_dtype_bknd( - dtype: Union[str, tf.DType, str], as_native: bool = False -): - if tensorflow_is_complex_dtype_bknd(dtype): - default_dtype = tensorflow_default_complex_dtype_bknd(as_native=as_native) - elif tensorflow_is_float_dtype_bknd(dtype): - default_dtype = tensorflow_default_float_dtype_bknd(as_native=as_native) - elif tensorflow_is_uint_dtype_bknd(dtype): - default_dtype = tensorflow_default_uint_dtype_bknd(as_native=as_native) - elif tensorflow_is_int_dtype_bknd(dtype): - default_dtype = tensorflow_default_int_dtype_bknd(as_native=as_native) - elif as_native: - default_dtype = tensorflow_as_native_dtype("bool") - else: - default_dtype = tensorflow_as_ivy_dtype("bool") - return default_dtype - - -def tensorflow_dtype_bits(dtype_in: Union[tensorflow.DType, str, np.dtype], /): - dtype_str = tensorflow_as_ivy_dtype(dtype_in) - if "bool" in dtype_str: - return 1 - return int( - dtype_str.replace("tf.", "") - .replace("uint", "") - .replace("int", "") - .replace("bfloat", "") - .replace("float", "") - .replace("complex", "") - ) - - -def tensorflow__infer_dtype(dtype: tensorflow.DType): - default_dtype = tensorflow_infer_default_dtype_bknd(dtype) - if tensorflow_dtype_bits(dtype) < tensorflow_dtype_bits(default_dtype): - return default_dtype - return dtype - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_prod( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - axis: Optional[Union[int, Sequence[int]]] = None, - dtype: Optional[tensorflow.DType] = None, - keepdims: bool = False, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - dtype = tensorflow_as_native_dtype(dtype) - if dtype is None: - dtype = tensorflow__infer_dtype(x.dtype) - axis = tuple(axis) if isinstance(axis, list) else axis - return tensorflow.experimental.numpy.prod( - x, axis=axis, dtype=dtype, keepdims=keepdims - ) - - -def 
tensorflow__numel_bknd(shape): - shape = tuple(shape) - return tensorflow_to_scalar_bknd_(tensorflow_prod(shape)) if shape != () else 1 - - -def tensorflow_check_one_way_broadcastable(x1, x2): - if len(x1) > len(x2): - return False - for a, b in zip(x1[::-1], x2[::-1]): - if a in (1, b): - pass - else: - return False - return True - - -def tensorflow_check_shapes_broadcastable(var, data): - if not tensorflow_check_one_way_broadcastable(var, data): - raise Exception(f"Could not broadcast shape {data} to shape {var}.") - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_broadcast_to( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - shape: Union[tf.TensorShape, Sequence[int]], - *, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - tensorflow_check_shapes_broadcastable(x.shape, shape) - if tensorflow.rank(x) > len(shape): - return tensorflow.broadcast_to(tensorflow.reshape(x, -1), shape) - return tensorflow.broadcast_to(x, shape) - - -def tensorflow__broadcast_to_bknd(input, target_shape): - if tensorflow__numel_bknd(tuple(input.shape)) == tensorflow__numel_bknd( - tuple(target_shape) - ): - return tensorflow_reshape(input, target_shape) - else: - input = input if len(input.shape) else tensorflow_expand_dims(input, axis=0) - return tensorflow_broadcast_to(input, target_shape) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_any( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - axis: Optional[Union[int, Sequence[int]]] = None, - keepdims: bool = False, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if axis is None: - num_dims = len(x.shape) - axis = tuple(range(num_dims)) - elif isinstance(axis, list): - axis = tuple(axis) - try: - return tensorflow.reduce_any( - tensorflow.cast(x, tensorflow.bool), axis=axis, keepdims=keepdims - ) - except tensorflow.errors.InvalidArgumentError as e: - raise Exception(e) from e - - -def tensorflow__broadcast_inputs(x1, x2): - x1_, x2_ = x1, x2 - iterables = list, tuple, tuple - if not isinstance(x1_, iterables): - x1_, x2_ = x2, x1 - if not isinstance(x1_, iterables): - return [x1], [x2] - if not isinstance(x2_, iterables): - x1 = [x1] * len(x2) - return x1, x2 - - -def tensorflow_check_equal(x1, x2, inverse=False, message="", as_array=True): - def eq_fn(x1, x2): - return x1 == x2 if inverse else x1 != x2 - - def comp_fn(x1, x2): - return tensorflow_any(eq_fn(x1, x2)) - - if not as_array: - - def iter_comp_fn(x1_, x2_): - return any(eq_fn(x1, x2) for x1, x2 in zip(x1_, x2_)) - - def comp_fn(x1, x2): - return iter_comp_fn(*tensorflow__broadcast_inputs(x1, x2)) - - eq = comp_fn(x1, x2) - if inverse and eq: - raise Exception(f"{x1} must not be equal to {x2}" if message == "" else message) - elif not inverse and eq: - raise Exception(f"{x1} must be equal to {x2}" if message == "" else message) - - -def tensorflow_multiply( - x1: Union[float, tensorflow.Tensor, tensorflow.Variable], - x2: Union[float, tensorflow.Tensor, tensorflow.Variable], - /, - *, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - oirg_x1 = x1 - oirg_x2 = x2 - try: - dtype = ( - x1.dtype - if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() - ) - if not tensorflow_is_array_bknd(x1): - x1 = tensorflow_asarray(x1, dtype=dtype) - if not tensorflow_is_array_bknd(x2): - x2 = tensorflow_asarray(x2, dtype=dtype) - except: - x1 = oirg_x1 - x2 = oirg_x2 - return tensorflow.math.multiply(x1, x2) - - -def 
tensorflow_check_gather_nd_input_valid(params, indices, batch_dims): - if batch_dims >= len(params.shape): - raise Exception( - f"batch_dims = {batch_dims} must be less than rank(`params`) = {len(params.shape)}." - ) - if batch_dims >= len(indices.shape): - raise Exception( - f"batch_dims = {batch_dims} must be less than rank(`indices`) = {len(indices.shape)}." - ) - if tensorflow_get_item( - params.shape, slice(0, batch_dims, None) - ) != tensorflow_get_item(indices.shape, slice(0, batch_dims, None)): - raise Exception( - f"batch dimensions must match in `params` and `indices`; saw {tensorflow_get_item(params.shape, slice(0, batch_dims, None))} vs. {tensorflow_get_item(indices.shape, slice(0, batch_dims, None))}" - ) - if indices.shape[-1] > len( - tensorflow_get_item(params.shape, slice(batch_dims, None, None)) - ): - raise Exception( - f"index innermost dimension length must be <= rank(`params[batch_dims:]`); saw: {indices.shape[-1]} vs. {len(tensorflow_get_item(params.shape, slice(batch_dims, None, None)))} ." - ) - - -def tensorflow_gather_nd_helper(params, indices): - indices_shape = tensorflow.shape(indices) - params_shape = tensorflow.shape(params) - num_index_dims = indices_shape[-1] - result_dim_sizes_list = [ - tensorflow.math.reduce_prod(params_shape[i + 1 :]) - for i in range(len(params_shape) - 1) - ] + [1] - result_dim_sizes = tensorflow.convert_to_tensor( - result_dim_sizes_list, dtype=indices.dtype - ) - implicit_indices_factor = result_dim_sizes[num_index_dims - 1] - flat_params = tensorflow.reshape(params, (-1,)) - new_shape = [1] * (len(indices_shape) - 1) + [num_index_dims] - indices_scales = tensorflow.reshape(result_dim_sizes[0:num_index_dims], new_shape) - indices_for_flat_tiled = tensorflow.reshape( - tensorflow.reduce_sum(indices * indices_scales, -1, keepdims=True), (-1, 1) - ) - indices_for_flat_tiled = tensorflow.repeat( - indices_for_flat_tiled, implicit_indices_factor, axis=1 - ) - implicit_indices = tensorflow.repeat( - tensorflow.expand_dims(tensorflow.range(implicit_indices_factor), 0), - indices_for_flat_tiled.shape[0], - axis=0, - ) - indices_for_flat = indices_for_flat_tiled + implicit_indices - flat_indices_for_flat = tensorflow.reshape(indices_for_flat, (-1,)) - flat_gather = tensorflow.gather(flat_params, flat_indices_for_flat) - res = tensorflow.reshape( - flat_gather, - tensorflow.concat([indices_shape[:-1], params_shape[num_index_dims:]], 0), - ) - return res - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_gather_nd( - params: Union[tensorflow.Tensor, tensorflow.Variable], - indices: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - batch_dims: int = 0, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - tensorflow_check_gather_nd_input_valid(params, indices, batch_dims) - try: - return tensorflow.gather_nd(params, indices, batch_dims=batch_dims) - except Exception: - batch_dims %= len(params.shape) - result = [] - if batch_dims == 0: - result = tensorflow_gather_nd_helper(params, indices) - else: - for b in range(batch_dims): - if b == 0: - zip_list = list(zip(params, indices)) - else: - zip_list = [ - (p, i) - for z in [zip(p1, i1) for p1, i1 in zip_list] - for p, i in z - ] - for z in zip_list: - p, i = z[0], z[1] - r = tensorflow_gather_nd_helper(p, i) - result.append(r) - result = tensorflow.stack(result) - result = tensorflow.reshape( - result, - tensorflow.concat([params.shape[0:batch_dims], result.shape[1:]], 0), - ) - return result - - -def tensorflow__is_variable_bknd(x, 
exclusive=False, to_ignore=None): - x = x - return tensorflow_nested_map_bknd( - lambda x: tensorflow_is_variable(x, exclusive=exclusive), - x, - include_derived=True, - shallow=False, - to_ignore=to_ignore, - ) - - -def tensorflow_inplace_update( - x: Union[tensorflow.Tensor, tensorflow.Tensor], - val: Union[tensorflow.Tensor, tensorflow.Tensor], - /, - *, - ensure_in_backend: bool = False, - keep_input_dtype: bool = False, -): - if tensorflow_is_array_bknd(x) and tensorflow_is_array_bknd(val): - if keep_input_dtype: - val = tensorflow_astype(val, x.dtype) - (x_native, val_native), _ = (x, val), "_" - if tensorflow__is_variable_bknd(x_native): - x_native.assign(val_native) - if tensorflow_is_ivy_array_bknd(x): - x = x_native - else: - x = tensorflow.convert_to_tensor(x_native) - else: - x = x_native - return x - else: - return val - - -def tensorflow_scatter_nd( - indices: Union[tensorflow.Tensor, tensorflow.Variable], - updates: Union[tensorflow.Tensor, tensorflow.Variable], - /, - shape: Optional[Union[tf.TensorShape, Sequence[int]]] = None, - *, - reduction: str = "sum", - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - updates_dtype = updates.dtype - if tensorflow_exists_bknd(out): - dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) - updates = tensorflow.cast( - updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), - ) - expected_shape = ( - list(tensorflow.shape(indices)[:-1]) - + list(out.shape[tensorflow.shape(indices)[-1] :]) - if tensorflow_exists_bknd(out) - else list(tensorflow.shape(indices)[:-1]) - + list(shape[tensorflow.shape(indices)[-1] :]) - ) - updates = tensorflow__broadcast_to_bknd(updates, expected_shape) - if len(updates.shape) == 0: - indices = tensorflow.expand_dims(indices, 0) - updates = tensorflow.expand_dims(updates, 0) - target = out - target_given = tensorflow_exists_bknd(target) - if tensorflow_exists_bknd(shape) and target_given: - tensorflow_check_equal(tuple(target.shape), tuple(shape), as_array=False) - if not target_given: - shape = list(shape) if tensorflow_exists_bknd(shape) else list(out.shape) - target = tensorflow.zeros(shape, dtype=updates.dtype) - if reduction == "sum": - res = tensorflow.tensor_scatter_nd_add(target, indices, updates) - elif reduction == "min": - res = tensorflow.tensor_scatter_nd_min(target, indices, updates) - elif reduction == "max": - res = tensorflow.tensor_scatter_nd_max(target, indices, updates) - elif reduction == "mul": - updates = tensorflow_multiply(tensorflow_gather_nd(target, indices), updates) - res = tensorflow.tensor_scatter_nd_update(target, indices, updates) - elif reduction == "replace": - res = tensorflow.tensor_scatter_nd_update(target, indices, updates) - else: - raise Exception( - f'reduction is {reduction}, but it must be one of "sum", "min", "max", "mul" or "replace"' - ) - if tensorflow_exists_bknd(out): - return tensorflow_inplace_update(out, res) - return res - - -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - -@tensorflow_handle_set_item -def tensorflow_set_item_bknd( - x: Union[tensorflow.Tensor, tf.Tensor], - query: Union[tensorflow.Tensor, tf.Tensor, Tuple], - val: Union[tensorflow.Tensor, tf.Tensor], - /, - *, - copy: Optional[bool] = False, -): - if isinstance(query, (list, 
tuple)) and any( - [(q is Ellipsis or isinstance(q, slice) and q.stop is None) for q in query] - ): - x_stop_gradient = tensorflow_stop_gradient(x, preserve_type=False) - np_array = x_stop_gradient.numpy() - val_stop_gradient = tensorflow_stop_gradient(val, preserve_type=False) - np_array = tensorflow_set_item_bknd( - np_array, query, np.asarray(val_stop_gradient) - ) - return tensorflow_asarray(np_array) - if copy: - x = tensorflow_copy_array(x) - if not tensorflow_is_array_bknd(val): - val = tensorflow_asarray(val) - if 0 in x.shape or 0 in val.shape: - return x - if tensorflow_is_array_bknd(query) and tensorflow_is_bool_dtype_bknd(query): - if not len(query.shape): - query = tensorflow_tile(query, (x.shape[0],)) - indices = tensorflow_nonzero(query, as_tuple=False) - else: - indices, target_shape, _ = tensorflow__parse_query_bknd( - query, tensorflow_shape(x, as_array=True), scatter=True - ) - if indices is None: - return x - val = tensorflow_astype_bknd_(val, x.dtype) - ret = tensorflow_scatter_nd(indices, val, reduction="replace", out=x) - return ret - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_real( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - return tensorflow.math.real(x) - - -def tensorflow_real_bknd_(self): - return tensorflow_real(self) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_imag( - val: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - return tensorflow.math.imag(val, name=None) - - -def tensorflow_imag_bknd_(self): - return tensorflow_imag(self) - - -def tensorflow__check_complex128_bknd(input): - if tensorflow_is_array_bknd(input): - return tensorflow_dtype(input) == "complex128" - elif isinstance(input, np.ndarray): - return str(input.dtype) == "complex128" - if hasattr(input, "real") and hasattr(input, "imag"): - return tensorflow__check_float64_bknd( - tensorflow_real_bknd_(input) - ) and tensorflow__check_float64_bknd(tensorflow_imag_bknd_(input)) - return False - - -def tensorflow_default_complex_dtype_bknd( - *, - input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, - complex_dtype: Optional[Union[str, tf.DType]] = None, - as_native: bool = False, -): - global default_complex_dtype_stack - if tensorflow_exists_bknd(complex_dtype): - if as_native is True: - return tensorflow_as_native_dtype(complex_dtype) - return str(tensorflow_as_ivy_dtype(complex_dtype)) - as_native = tensorflow_default_bknd(as_native, False) - if tensorflow_exists_bknd(input): - if tensorflow_is_array_bknd(input): - ret = tensorflow_dtype(input) - elif isinstance(input, np.ndarray): - ret = str(input.dtype) - elif isinstance(input, (list, tuple, dict)): - if tensorflow_nested_argwhere_bknd( - input, - lambda x: tensorflow__check_complex128_bknd(x), - stop_after_n_found=1, - ): - ret = tf.complex128 - elif not default_complex_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_complex_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "complex64" - else: - ret = default_complex_dtype_stack[-1] - elif isinstance(input, Number): - if tensorflow__check_complex128_bknd(input): - ret = tf.complex128 - elif not default_complex_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_complex_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "complex64" - else: - ret = default_complex_dtype_stack[-1] - elif not 
default_complex_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_complex_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "complex64" - else: - ret = default_complex_dtype_stack[-1] - if as_native: - return tensorflow_as_native_dtype(ret) - return str(tensorflow_as_ivy_dtype(ret)) - - -def tensorflow_default_dtype_bknd( - *, - dtype: Optional[Union[str, str]] = None, - item: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, - as_native: bool = False, -): - if tensorflow_exists_bknd(dtype): - if as_native is True: - return tensorflow_as_native_dtype(dtype) - return tensorflow_as_ivy_dtype(dtype) - as_native = tensorflow_default_bknd(as_native, False) - if tensorflow_exists_bknd(item): - if hasattr(item, "override_dtype_check"): - return item.override_dtype_check() - elif isinstance(item, (list, tuple, dict)) and len(item) == 0: - pass - elif tensorflow_is_complex_dtype_bknd(item): - return tensorflow_default_complex_dtype_bknd( - input=item, as_native=as_native - ) - elif tensorflow_is_float_dtype_bknd(item): - return tensorflow_default_float_dtype_bknd(input=item, as_native=as_native) - elif tensorflow_is_uint_dtype_bknd(item): - return tensorflow_default_int_dtype_bknd(input=item, as_native=as_native) - elif tensorflow_is_int_dtype_bknd(item): - return tensorflow_default_int_dtype_bknd(input=item, as_native=as_native) - elif as_native: - return tensorflow_as_native_dtype("bool") - else: - return "bool" - global default_dtype_stack - if not default_dtype_stack: - global default_float_dtype_stack - if default_float_dtype_stack: - ret = default_float_dtype_stack[-1] - else: - ret = "float32" - else: - ret = default_dtype_stack[-1] - if as_native: - return tensorflow_as_native_dtype(ret) - return tensorflow_as_ivy_dtype(ret) - - -def tensorflow_default_float_dtype_bknd( - *, - input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, - float_dtype: Optional[Union[str, tf.DType]] = None, - as_native: bool = False, -): - global default_float_dtype_stack - if tensorflow_exists_bknd(float_dtype): - if as_native is True: - return tensorflow_as_native_dtype(float_dtype) - return str(tensorflow_as_ivy_dtype(float_dtype)) - as_native = tensorflow_default_bknd(as_native, False) - if tensorflow_exists_bknd(input): - if tensorflow_is_array_bknd(input): - ret = tensorflow_dtype(input) - elif isinstance(input, np.ndarray): - ret = str(input.dtype) - elif isinstance(input, (list, tuple, dict)): - if tensorflow_nested_argwhere_bknd( - input, lambda x: tensorflow__check_float64_bknd(x), stop_after_n_found=1 - ): - ret = tf.float64 - elif not default_float_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_float_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "float32" - else: - ret = default_float_dtype_stack[-1] - elif isinstance(input, Number): - if tensorflow__check_float64_bknd(input): - ret = tf.float64 - elif not default_float_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_float_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "float32" - else: - ret = default_float_dtype_stack[-1] - elif not default_float_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_float_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "float32" - else: - ret = default_float_dtype_stack[-1] - if as_native: - return tensorflow_as_native_dtype(ret) - return str(tensorflow_as_ivy_dtype(ret)) - - -def tensorflow_as_ivy_dtype( - dtype_in: Union[tensorflow.DType, str, int, float, 
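All of these default-dtype resolvers share one precedence order: an explicit dtype argument wins, then a dtype inferred from `input`, then the innermost entry on the relevant default stack, and finally a hard-coded backend fallback ("complex64", "float32", or "int32"). A hypothetical condensed form of that rule:

def resolve_default(explicit, inferred, stack, fallback):
    if explicit is not None:   # an explicit dtype wins outright
        return explicit
    if inferred is not None:   # otherwise trust the input value
        return inferred
    if stack:                  # otherwise the most recently pushed default
        return stack[-1]
    return fallback            # otherwise the backend's hard-coded default

assert resolve_default(None, None, [], "float32") == "float32"
assert resolve_default(None, None, ["float64"], "float32") == "float64"
assert resolve_default("float16", "float64", ["float64"], "float32") == "float16"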
complex, bool, np.dtype], / -): - if dtype_in is int: - return tensorflow_default_int_dtype_bknd() - if dtype_in is float: - return tensorflow_default_float_dtype_bknd() - if dtype_in is complex: - return tensorflow_default_complex_dtype_bknd() - if dtype_in is bool: - return str("bool") - if isinstance(dtype_in, np.dtype): - dtype_in = dtype_in.name - if isinstance(dtype_in, str): - if dtype_in in native_dtype_dict: - dtype_str = dtype_in - else: - raise Exception( - f"Cannot convert to ivy dtype. {dtype_in} is not supported by TensorFlow backend." - ) - else: - dtype_str = ivy_dtype_dict[dtype_in] - if "uint" in dtype_str: - return str(dtype_str) - elif "int" in dtype_str: - return str(dtype_str) - elif "float" in dtype_str: - return str(dtype_str) - elif "complex" in dtype_str: - return str(dtype_str) - elif "bool" in dtype_str: - return str("bool") - else: - raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") - - -def tensorflow_default_int_dtype_bknd( - *, - input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, - int_dtype: Optional[Union[str, tf.DType]] = None, - as_native: bool = False, -): - global default_int_dtype_stack - if tensorflow_exists_bknd(int_dtype): - if as_native is True: - return tensorflow_as_native_dtype(int_dtype) - return str(tensorflow_as_ivy_dtype(int_dtype)) - as_native = tensorflow_default_bknd(as_native, False) - if tensorflow_exists_bknd(input): - if tensorflow_is_array_bknd(input): - ret = tensorflow_dtype(input) - elif isinstance(input, tuple): - ret = tensorflow_default_int_dtype_bknd() - elif isinstance(input, np.ndarray): - ret = str(input.dtype) - elif isinstance(input, (list, tuple, dict)): - if tensorflow_nested_argwhere_bknd( - input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), - stop_after_n_found=1, - ): - ret = tf.uint64 - elif tensorflow_nested_argwhere_bknd( - input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), - stop_after_n_found=1, - ): - ret = tf.int64 - elif not default_int_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_int_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "int32" - else: - ret = default_int_dtype_stack[-1] - elif isinstance(input, Number): - if input > 9223372036854775807 and input != math.inf and backend != "torch": - ret = tf.uint64 - elif input > 2147483647 and input != math.inf: - ret = tf.int64 - elif not default_int_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_int_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "int32" - else: - ret = default_int_dtype_stack[-1] - elif not default_int_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_int_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "int32" - else: - ret = default_int_dtype_stack[-1] - if as_native: - return tensorflow_as_native_dtype(ret) - return str(tensorflow_as_ivy_dtype(ret)) - - -def tensorflow_as_native_dtype( - dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], -): - if dtype_in is int: - return tensorflow_default_int_dtype_bknd(as_native=True) - if dtype_in is float: - return tensorflow_default_float_dtype_bknd(as_native=True) - if dtype_in is complex: - return tensorflow_default_complex_dtype_bknd(as_native=True) - if dtype_in is bool: - return tensorflow.bool - if isinstance(dtype_in, np.dtype): - dtype_in = dtype_in.name - if not 
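For plain Python ints, tensorflow_default_int_dtype_bknd widens by magnitude, using the int32 and int64 maxima that appear literally above. The same rule distilled into a hypothetical helper:

INT32_MAX = 2147483647              # 2**31 - 1
INT64_MAX = 9223372036854775807     # 2**63 - 1

def int_dtype_for(value: int) -> str:
    if value > INT64_MAX:
        return "uint64"   # only an unsigned 64-bit type can hold it
    if value > INT32_MAX:
        return "int64"
    return "int32"

assert int_dtype_for(7) == "int32"
assert int_dtype_for(INT32_MAX + 1) == "int64"
assert int_dtype_for(INT64_MAX + 1) == "uint64"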
isinstance(dtype_in, str): - return dtype_in - if dtype_in in native_dtype_dict: - return native_dtype_dict[str(dtype_in)] - else: - raise Exception( - f"Cannot convert to TensorFlow dtype. {dtype_in} is not supported by TensorFlow." - ) - - -def tensorflow_dtype( - x: Union[tensorflow.Tensor, tensorflow.Variable, np.ndarray], - *, - as_native: bool = False, -): - if as_native: - return tensorflow_as_native_dtype(x.dtype) - return tensorflow_as_ivy_dtype(x.dtype) - - -def tensorflow_is_bool_dtype_bknd( - dtype_in: Union[str, str, tensorflow.Tensor, tf.Tensor, Number], / -): - if tensorflow_is_array_bknd(dtype_in): - dtype_in = tensorflow_dtype(dtype_in) - elif isinstance(dtype_in, np.ndarray): - return "bool" in dtype_in.dtype.name - elif isinstance(dtype_in, Number): - return isinstance(dtype_in, (bool, np.bool_)) and not isinstance(dtype_in, bool) - elif isinstance(dtype_in, (list, tuple, dict)): - return bool( - tensorflow_nested_argwhere_bknd( - dtype_in, lambda x: isinstance(x, (bool, np.bool_)) and x is not int - ) - ) - return "bool" in tensorflow_as_ivy_dtype(dtype_in) - - -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - -@tensorflow_handle_get_item -def tensorflow_get_item( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - query: Union[tensorflow.Tensor, tensorflow.Variable, Tuple], - *, - copy: Optional[bool] = None, -): - if ( - tensorflow_is_array_bknd(query) - and tensorflow_is_bool_dtype_bknd(query) - and not len(query.shape) - ): - return tensorflow.expand_dims(x, 0) - return x[query] - - -def tensorflow_index_nest_bknd( - nest: Union[List, Tuple, Dict, tensorflow.Tensor, tf.Tensor, dict], - index: Union[List[int], Tuple[int], Iterable[int]], - /, -): - ret = nest - for i in index: - ret = tensorflow_get_item(ret, i) - return ret - - -def tensorflow__get_first_array(*args, **kwargs): - def array_fn(x): - return ( - tensorflow_is_array_bknd(x) - if not hasattr(x, "_ivy_array") - else tensorflow_is_array_bknd(x.ivy_array) - ) - - array_fn = array_fn if "array_fn" not in kwargs else kwargs["array_fn"] - arr = None - if args: - arr_idxs = tensorflow_nested_argwhere_bknd(args, array_fn, stop_after_n_found=1) - if arr_idxs: - arr = tensorflow_index_nest_bknd(args, arr_idxs[0]) - else: - arr_idxs = tensorflow_nested_argwhere_bknd( - kwargs, array_fn, stop_after_n_found=1 - ) - if arr_idxs: - arr = tensorflow_index_nest_bknd(kwargs, arr_idxs[0]) - elif kwargs: - arr_idxs = tensorflow_nested_argwhere_bknd( - kwargs, array_fn, stop_after_n_found=1 - ) - if arr_idxs: - arr = tensorflow_index_nest_bknd(kwargs, arr_idxs[0]) - return arr - - -def tensorflow__slice_at_axis(sl, axis): - return (slice(None),) * axis + (sl,) + (...,) diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_linspace_output/run_0/tensorflow__stateful.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_linspace_output/run_0/tensorflow__stateful.py deleted file mode 100644 index dbad1e919ab1..000000000000 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_linspace_output/run_0/tensorflow__stateful.py +++ /dev/null @@ -1,1799 +0,0 @@ -# global -from __future__ import annotations -import re -import os -import tensorflow as tf -import functools -from tensorflow.python.util import nest -from typing import NamedTuple, Callable, Any, Tuple, List, Dict, Type, Union -import inspect -from 
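tensorflow__slice_at_axis builds an index tuple that slices one axis and leaves every other axis untouched; tensorflow_linspace further down in this diff uses it to drop the extra sample when endpoint=False. A self-contained example of the same trick:

import tensorflow as tf

def slice_at_axis(sl, axis):
    # mirrors tensorflow__slice_at_axis above
    return (slice(None),) * axis + (sl,) + (...,)

x = tf.reshape(tf.range(24), (2, 3, 4))
idx = slice_at_axis(slice(None, -1), 1)   # equivalent to x[:, :-1, ...]
assert idx == (slice(None), slice(None, -1), Ellipsis)
assert x[idx].shape == (2, 2, 4)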
collections import OrderedDict -from packaging.version import parse -import keras - - -def get_assignment_dict(): - # Traverse the call stack - lhs = None - for frame_info in inspect.stack(): - # Check if the code context is an assignment statement - if frame_info.code_context and "=" in frame_info.code_context[0]: - # Split the assignment and retrieve the LHS - lhs = frame_info.code_context[0].split("=")[0].strip() - if "self" not in lhs: - continue - break - - if not lhs: - return None, "" - - # Replace indexing with attribute access - lhs = re.sub(r"\[(\d+)\]", r".\1", lhs) - - # Split the LHS based on "." and get individual components - components = lhs.split(".") - - # Initialize the dictionary - assignment_dict = {} - - # Retrieve the live objects associated with each component - for i in range(len(components)): - # Construct the key - key = ".".join(components[: i + 1]) - - # Retrieve the value - if i == 0: - value = frame_info.frame.f_locals.get(components[i]) - else: - value = getattr(assignment_dict[".".join(components[:i])], components[i]) - - # Add the key-value pair to the dictionary - assignment_dict[key] = value - - return assignment_dict, lhs - - -def store_frame_info(fn): - @functools.wraps(fn) - def frame_info_wrapper(self, *args, **kwargs): - if self._previous_frame_info is None: - # store the info about the calling frame. - stack = inspect.stack() - self._previous_frame_info = stack[1] - res = fn(self, *args, **kwargs) - # reset the frame-info - self._previous_frame_info = None - return res - - return frame_info_wrapper - - -# A NodeDef holds two callables: -# - flatten_fn should take the collection and return a flat list of values. -# It can also return some context that is used in reconstructing the -# collection. -# - unflatten_fn should take a flat list of values and some context -# (returned by flatten_fn). It returns the collection by reconstructing -# it from the list and the context. -Context = Any -PyTree = Any -FlattenFunc = Callable[[PyTree], Tuple[List, Context]] -UnflattenFunc = Callable[[List, Context], PyTree] - - -class NodeDef(NamedTuple): - flatten_fn: FlattenFunc - unflatten_fn: UnflattenFunc - - -SUPPORTED_NODES: Dict[Type[Any], NodeDef] = {} - - -def _register_pytree_node( - typ: Any, flatten_fn: FlattenFunc, unflatten_fn: UnflattenFunc -) -> None: - SUPPORTED_NODES[typ] = NodeDef(flatten_fn, unflatten_fn) - - -def _dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]: - return list(d.values()), list(d.keys()) - - -def _dict_unflatten(values: List[Any], context: Context) -> Dict[Any, Any]: - return {key: value for key, value in zip(context, values)} - - -_register_pytree_node(dict, _dict_flatten, _dict_unflatten) - -if parse(keras.__version__).major > 2: - _register_pytree_node( - keras.src.utils.tracking.TrackedDict, _dict_flatten, _dict_unflatten - ) - - -def _get_node_type(pytree: Any) -> Any: - return type(pytree) - - -# A leaf is defined as anything that is not a Node. -def _is_leaf(pytree: PyTree) -> bool: - return _get_node_type(pytree) not in SUPPORTED_NODES.keys() - - -# A TreeSpec represents the structure of a pytree. 
It holds: -# "type": the type of root Node of the pytree -# context: some context that is useful in unflattening the pytree -# children_specs: specs for each child of the root Node -# num_leaves: the number of leaves -class TreeSpec: - def __init__(self, type, context, children_specs): - self.type: Any = type - self.context: Context = context - self.children_specs: List["TreeSpec"] = children_specs - self.num_leaves: int = sum([spec.num_leaves for spec in self.children_specs]) - - def get_keychains(self, prefix="", sep="/"): - keychains = [] - for key, child_spec in zip(self.context, self.children_specs): - new_prefix = prefix + key + sep if prefix else key + sep - if child_spec.children_specs: # Non-leaf node - keychains.extend(child_spec.get_keychains(new_prefix, sep)) - else: # Leaf node - keychains.append(new_prefix[: -len(sep)]) - return keychains - - def __repr__(self, indent: int = 0) -> str: - repr_prefix: str = f"TreeSpec({self.type.__name__}, {self.context}, [" - children_specs_str: str = "" - if len(self.children_specs): - indent += len(repr_prefix) - children_specs_str += self.children_specs[0].__repr__(indent) - children_specs_str += "," if len(self.children_specs) > 1 else "" - children_specs_str += ",".join( - [ - "\n" + " " * indent + child.__repr__(indent) - for child in self.children_specs[1:] - ] - ) - repr_suffix: str = f"{children_specs_str}])" - return repr_prefix + repr_suffix - - -class LeafSpec(TreeSpec): - def __init__(self) -> None: - super().__init__(None, None, []) - self.num_leaves = 1 - - def __repr__(self, indent: int = 0) -> str: - return "*" - - -def tree_flatten(pytree: PyTree) -> Tuple[List[Any], TreeSpec]: - """Flattens a pytree into a list of values and a TreeSpec that can be used - to reconstruct the pytree.""" - if _is_leaf(pytree): - return [pytree], LeafSpec() - - node_type = _get_node_type(pytree) - flatten_fn = _dict_flatten - child_pytrees, context = flatten_fn(pytree) - - # Recursively flatten the children - result: List[Any] = [] - children_specs: List["TreeSpec"] = [] - for child in child_pytrees: - flat, child_spec = tree_flatten(child) - result += flat - children_specs.append(child_spec) - - return result, TreeSpec(node_type, context, children_specs) - - -def tree_unflatten(values: List[Any], spec: TreeSpec) -> PyTree: - """Given a list of values and a TreeSpec, builds a pytree. - - This is the inverse operation of `tree_flatten`. - """ - if not isinstance(spec, TreeSpec): - raise TypeError( - f"tree_unflatten(values, spec): Expected `spec` to be instance of " - f"TreeSpec but got item of type {type(spec)}." - ) - if len(values) != spec.num_leaves: - raise TypeError( - f"tree_unflatten(values, spec): `values` has length {len(values)} " - f"but the spec refers to a pytree that holds {spec.num_leaves} " - f"items ({spec})." 
- ) - if isinstance(spec, LeafSpec): - return values[0] - - unflatten_fn = _dict_unflatten - - # Recursively unflatten the children - start = 0 - end = 0 - child_pytrees = [] - for child_spec in spec.children_specs: - end += child_spec.num_leaves - child_pytrees.append(tree_unflatten(values[start:end], child_spec)) - start = end - - return unflatten_fn(child_pytrees, spec.context) - - -def serialize_obj(obj): - if inspect.isclass(obj) or isinstance(obj, type): - return {"cls_module": obj.__module__, "cls_name": obj.__name__} - return obj - - -def recursive_serialize(d): - if isinstance(d, dict): - return {k: recursive_serialize(v) for k, v in d.items()} - elif isinstance(d, list): - return [recursive_serialize(v) for v in d] - elif isinstance(d, tuple): - return tuple(recursive_serialize(v) for v in d) - else: - return serialize_obj(d) - - -def deserialize_obj(serialized): - if ( - isinstance(serialized, dict) - and "cls_module" in serialized - and "cls_name" in serialized - ): - module = __import__(serialized["cls_module"], fromlist=[serialized["cls_name"]]) - cls = getattr(module, serialized["cls_name"]) - return cls - return serialized - - -def recursive_deserialize(d): - if isinstance(d, dict) and "cls_module" not in d: - return {k: recursive_deserialize(v) for k, v in d.items()} - elif isinstance(d, list): - return [recursive_deserialize(v) for v in d] - elif isinstance(d, tuple): - return tuple(recursive_serialize(v) for v in d) - else: - return deserialize_obj(d) - - -class ModelHelpers: - @staticmethod - @tf.autograph.experimental.do_not_convert - def _get_first_array(*args, **kwargs): - arr = None - flattened_args = tf.nest.flatten((args, kwargs)) - arr_candidates = tf.nest.map_structure( - lambda x: x if isinstance(x, (tf.Tensor, tf.Variable)) else False, - flattened_args, - ) - for arr_candidate in arr_candidates: - if arr_candidate is not False: - arr = arr_candidate - break - return arr - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _get_input_shapes(*args): - input_shapes = [] - for x in args: - if isinstance(x, (tf.Tensor, tf.Variable)): - input_shapes.append(x.shape) - else: - try: - x = tf.convert_to_tensor(x) - input_shapes.append(x.shape) - except Exception: - input_shapes.append(None) - return input_shapes - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _extract_v(v, keychain_mappings: dict, orig_key_chain, /): - if ModelHelpers._dict_has_key_chain(v, orig_key_chain): - ret_cont = ModelHelpers._dict_at_key_chain(v, orig_key_chain) - else: - ret_cont = dict() - for old_kc, new_kc in keychain_mappings.items(): - if orig_key_chain in old_kc: - # Check if `v` contains `new_kc` before replacing in `ret_cont` - if ModelHelpers._dict_has_key_chain(v, new_kc): - ret_cont = ModelHelpers._dict_set_at_key_chain( - ret_cont, - "/".join(old_kc.split("/")[1:]), - ModelHelpers._dict_at_key_chain(v, new_kc), - ) - else: - continue - return ret_cont - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _remove_duplicate_variables(vs, created, /): - created_ids = tf.nest.map_structure(lambda x: id(x), created) - vs_ids = tf.nest.map_structure(lambda x: id(x), vs) - ids = {} - duplicate_keychains = [] - keychain_mappings = {} - - def unique_callback(x, kc): - ids[x] = kc - return x - - def found_dup_callback(x, kc): - if ids[x] == kc: - return x - duplicate_keychains.append(kc) - keychain_mappings[kc] = ids[x] - return x - - created_ids = nest.map_structure_with_paths( - lambda kc, x: unique_callback(x, kc), created_ids - ) - vs_ids = 
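Assuming the helpers above are importable, tree_flatten and tree_unflatten round-trip nested dicts (the only node type registered here, plus Keras 3's TrackedDict), and get_keychains exposes the "/"-joined paths. One wrinkle worth flagging: recursive_deserialize recurses through tuples with recursive_serialize, which looks like a copy-paste slip; the tuple branch presumably ought to call recursive_deserialize.

nested = {"a": 1, "b": {"c": 2, "d": 3}}
leaves, spec = tree_flatten(nested)
assert leaves == [1, 2, 3]                        # values in insertion order
assert spec.num_leaves == 3
assert spec.get_keychains() == ["a", "b/c", "b/d"]
assert tree_unflatten(leaves, spec) == nested     # exact inverse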
nest.map_structure_with_paths( - lambda kc, x: ( - unique_callback(x, kc) if x not in ids else found_dup_callback(x, kc) - ), - vs_ids, - ) - for dup_kc in duplicate_keychains: - vs = ModelHelpers._dict_prune_key_chain(vs, dup_kc) - return vs, keychain_mappings - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _dict_set_at_key_chain(in_dict, key_chain, val, inplace=False): - keys = re.split("[/.]", key_chain) - if inplace: - cont = in_dict - else: - cont = in_dict - sub_cont = cont - for key in keys[:-1]: - if key not in sub_cont: - sub_cont[key] = dict() - sub_cont = sub_cont[key] - sub_cont[keys[-1]] = val - return cont - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _dict_at_key_chain(dict, key_chain, ignore_key_errors=False): - keys = re.split("[/.]", key_chain) - ret = dict - for key in keys: - try: - ret = ret[key] - except KeyError as e: - if ignore_key_errors: - return - raise Exception(repr(e)) - return ret - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _dict_has_key_chain(dict, key_chain): - keys = re.split("[/.]", key_chain) - ret = dict - for key in keys: - try: - ret = ret[key] - except KeyError: - return False - return True - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _dict_prune_key_chain(in_dict, key_chain): - keys_in_chain = re.split("[/.]", key_chain) - out_dict = {} - for key, value in in_dict.items(): - if isinstance(value, dict): - if key == keys_in_chain[0]: - if len(keys_in_chain) == 1: - new_val = [] - else: - new_val = ModelHelpers._dict_prune_key_chain( - value, - "/".join(keys_in_chain[1:]), - ) - if len(new_val) > 0: - out_dict[key] = new_val - else: - if len(value) > 0: - out_dict[key] = value - else: - if len(keys_in_chain) != 1 or key != keys_in_chain[0]: - out_dict[key] = value - return out_dict - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _addindent(s_, numSpaces): - s = s_.split("\n") - # don't do anything for single-line stuff - if len(s) == 1: - return s_ - first = s.pop(0) - s = [(numSpaces * " ") + line for line in s] - s = "\n".join(s) - s = first + "\n" + s - return s - - -class Layer(tf.keras.layers.Layer, ModelHelpers): - _build_mode = None - _with_partial_v = None - _store_vars = True - _built = False - _v = None - _buffers = None - _module_dict = None - _args = None - _kwargs = None - _module_graph = None - _target = None - _lazy_traced = False - _training = None - _dynamic_backend = None - _device = None - _dtype = None - _previous_frame_info = None - - def __init__( - self, - /, - *args, - v=None, - buffers=None, - build_mode="on_init", - store_vars=True, - with_partial_v=False, - dynamic_backend=None, - training=True, - dtype=None, - device=None, - module_dict=None, - **kwargs, - ): - super(Layer, self).__init__( - trainable=training, - dtype=dtype, - ) - self._build_mode = build_mode - self._with_partial_v = with_partial_v - self._store_vars = store_vars - self._built = False - self._v_from_constructor = v if isinstance(v, dict) or v is None else dict(v) - self._v = v if v is not None else dict() - self._buffers = dict(buffers or {}) - self._module_dict = module_dict if module_dict is not None else dict() - self._args = args - self._kwargs = kwargs - self._module_graph = None - self._target = None - self._lazy_traced = False - self._training = training - self._dynamic_backend = dynamic_backend - self._device = device or "cpu" - self._dtype = dtype or tf.float32 - if build_mode != "on_init": - return - self.build(*args, 
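The _dict_*_key_chain helpers treat "/"- or "."-separated strings as paths into nested dicts. Note that _dict_set_at_key_chain binds cont = in_dict on both branches, so inplace=False does not actually copy; the sketch below sticks to inplace=True. Against plain dicts:

params = {}
ModelHelpers._dict_set_at_key_chain(params, "layer1/weight", 3.14, inplace=True)
assert params == {"layer1": {"weight": 3.14}}
assert ModelHelpers._dict_has_key_chain(params, "layer1/weight")
assert ModelHelpers._dict_at_key_chain(params, "layer1.weight") == 3.14
assert ModelHelpers._dict_prune_key_chain(params, "layer1/weight") == {}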
dynamic_backend=dynamic_backend, **kwargs) - - @tf.autograph.experimental.do_not_convert - def _find_variables( - self, - /, - *, - obj=None, - without_initialisation=False, - _visited=None, - trainable=True, - ): - _visited = _visited or {} - vs = dict() - if id(obj) in _visited: - return vs - _visited[id(obj)] = True - if isinstance(obj, Layer) and obj is not self: - fn = "_build_and_return_v" if trainable else "_build_and_return_buffers" - if not obj._built and without_initialisation: - return lambda: getattr(obj, fn)( - *obj._args, dynamic_backend=self._dynamic_backend, **obj._kwargs - ) - - return getattr(obj, fn)( - *obj._args, dynamic_backend=obj._dynamic_backend, **obj._kwargs - ) - elif isinstance(obj, tf.keras.layers.Layer) and obj is not self: - return obj.v if trainable else obj.buffers - - elif isinstance(obj, (list, tuple)): - for i, v in enumerate(obj): - ret = self._find_variables( - obj=v, - without_initialisation=without_initialisation, - _visited=_visited, - trainable=trainable, - ) - if ret: - vs[f"v{str(i)}"] = ret - return vs - elif isinstance(obj, dict): - for k, v in obj.items(): - ret = self._find_variables( - obj=v, - without_initialisation=without_initialisation, - _visited=_visited, - trainable=trainable, - ) - if ret: - vs[k[1:] if k[0] == "_" else k] = ret - return vs - elif not hasattr(obj, "__dict__"): - return vs - for k, v in obj.__dict__.items(): - if ( - v is not None - and k[0:2] != "__" - and not k.startswith( - ( - "_module_dict", - "_self_", - "_args", - "_kwargs", - ) - ) - ): - ret = self._find_variables( - obj=v, - without_initialisation=without_initialisation, - _visited=_visited, - trainable=trainable, - ) - if ret: - vs[k[1:] if k[0] == "_" else k] = ret - return vs - - @tf.autograph.experimental.do_not_convert - def _find_buffers(self): - if hasattr(self, "_module_dict"): - for key, sub_module in self._module_dict.items(): - if len(sub_module._buffers) > 0: - self._buffers[key] = sub_module._buffers - - @tf.autograph.experimental.do_not_convert - def _build_and_return_v(self, *args, **kwargs): - if not self._built: - self.build(*args, **kwargs) - return self.v - - @tf.autograph.experimental.do_not_convert - def _build_and_return_buffers(self, *args, **kwargs): - if not self._built: - self.build(*args, **kwargs) - return self.buffers - - @tf.autograph.experimental.do_not_convert - def _wrap_call_methods( - self, keychain_mappings, /, *, key="", obj=None, _visited=None - ): - _visited = _visited or {} - if id(obj) in _visited or not isinstance(key, str): - return - _visited[id(obj)] = True - if isinstance(obj, Model) and obj is not self: - orig_key_chain = key[1:] if key[0] == "_" else key - - obj.__call__ = self._fn_with_var_arg( - obj.__call__, self._extract_v, keychain_mappings, orig_key_chain - ) - return - elif isinstance(obj, (list, tuple)): - for i, val in enumerate(obj): - self._wrap_call_methods( - keychain_mappings, - key=f"{key}/v{str(i)}", - obj=val, - _visited=_visited, - ) - return - elif isinstance(obj, dict): - for k, val in obj.items(): - k = f"{key}/{k}" if key != "" and isinstance(k, str) else k - self._wrap_call_methods( - keychain_mappings, key=k, obj=val, _visited=_visited - ) - return - for k, val in obj.module_dict.items(): - if k.startswith(("__", "_self_")): - continue - k = f"{key}/{k}" if key != "" else k - if val is not None: - self._wrap_call_methods( - keychain_mappings, key=k, obj=val, _visited=_visited - ) - return - - @tf.autograph.experimental.do_not_convert - def _compute_module_dict(self): - self._module_dict 
= dict() - for key, value in self.__dict__.items(): - if isinstance(value, (Layer, tf.keras.layers.Layer)): - if ( - "stateful" in value.__module__ - or hasattr(value, "_frontend_module") - or not hasattr(value, "_module_dict") - ): - self._module_dict[key] = value - else: - self._module_dict[key] = value._module_dict - - @tf.autograph.experimental.do_not_convert - def _fn_with_var_arg_wrapper( - self, *a, fn, v_fn, keychain_mappings, orig_key_chain, **kw - ): - if "v" in kw: - del kw["v"] - v = v_fn(self.v, keychain_mappings, orig_key_chain) - return fn(*a, **kw, v=v) - - @tf.autograph.experimental.do_not_convert - def _fn_with_var_arg(self, fn, v_fn, /, keychain_mappings, orig_key_chain): - _fn_with_var_arg_wrapper = functools.partial( - self._fn_with_var_arg_wrapper, - fn=fn, - v_fn=v_fn, - keychain_mappings=keychain_mappings, - orig_key_chain=orig_key_chain, - ) - _fn_with_var_arg_wrapper.wrapped = True - return _fn_with_var_arg_wrapper - - @tf.autograph.experimental.do_not_convert - def _call(self, *args, v=None, buffers=None, **kwargs): - if not self._built or not self.built: - if not self._built: - first_arr = self._get_first_array(*args, **kwargs) - self.build( - *args, - **kwargs, - from_call=True, - dtype=first_arr.dtype if first_arr is not None else tf.float32, - ) - - if not self.built: - # Don't use `keras` build method - if os.environ.get("USE_KERAS_BUILD", "False").lower() == "false": - self.inputs = tf.nest.flatten(args) - else: - input_shapes = self._get_input_shapes(*args) - if len(input_shapes) == 0: - input_shapes = tf.TensorShape(None) - elif len(input_shapes) == 1: - input_shapes = input_shapes[0] - - super(Layer, self).build(tf.TensorShape(None)) # noqa: UP008 - - # If `v` was provided, replace with the module's v - replace_v = False - if v is not None: - v_orig = self.v - self._v = v - replace_v = True - - # If `buffers` were provided, replace with the module's buffers - replace_buffers = False - if buffers is not None: - buffers_orig = self.buffers - self._buffers = buffers - replace_buffers = True - - if replace_v or replace_buffers: - # Call the forward pass - ret = super(Layer, self).__call__(*args, **kwargs) # noqa: UP008 - # Replace v, buffers if needed - self._v = v_orig if replace_v else self._v - self._buffers = buffers_orig if replace_buffers else self._buffers - return ret - elif hasattr(self.__call__, "wrapped"): - return self.__call__(*args, **kwargs) - - # Get the signature of the call method - call_signature = inspect.signature(self.call) - - # Convert all positional arguments to keyword arguments based on the signature - new_kwargs = {} - for idx, (param_name, param) in enumerate(call_signature.parameters.items()): - if idx < len(args): - new_kwargs[param_name] = args[idx] - - # Merge the existing kwargs - new_kwargs.update(kwargs) - return super(Layer, self).__call__(**new_kwargs) # noqa: UP008 - - @tf.autograph.experimental.do_not_convert - def build( - self, - *args, - from_call=False, - device=None, - dtype=None, - dynamic_backend=None, - **kwargs, - ): - self._built = True - return - - def _lock_state(self): - pass - - @tf.autograph.experimental.do_not_convert - def register_buffer(self, name: str, value: Union[tf.Tensor, tf.Variable]): - self._buffers.update({name: value}) - return value - - @tf.autograph.experimental.do_not_convert - def register_parameter(self, name: str, value: Union[tf.Tensor, tf.Variable]): - self._v.update({name: value}) - - @tf.autograph.experimental.do_not_convert - def train(self, mode: bool = True): - self._training = 
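The v/buffers handling inside _call is a stash-swap-restore pattern around the forward pass. A hypothetical minimal version (module._v is the private store behind the v property; the try/finally is an addition for safety that the original does not use):

def call_with_override(module, *args, v=None, **kwargs):
    if v is None:
        return module(*args, **kwargs)
    v_orig = module._v        # stash the module's own variables
    module._v = v             # run the forward pass against the override
    try:
        return module(*args, **kwargs)
    finally:
        module._v = v_orig    # restore afterwards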
mode - for module in self.children(): - if isinstance(module, tf.keras.layers.Layer) and not hasattr( - module, "train" - ): - module.trainable = mode - continue - module.train(mode) - self.trainable = mode - return self - - @tf.autograph.experimental.do_not_convert - def eval(self): - return self.train(mode=False) - - @tf.autograph.experimental.do_not_convert - def call(self, inputs, training=None, mask=None): - raise NotImplementedError( - "When subclassing the `Module` class, you should implement a `call` method." - ) - - def get_build_config(self): - config = super().get_build_config() - config = recursive_serialize(config) - return config - - def build_from_config(self, config): - config = recursive_deserialize(config) - return super().build_from_config(config) - - def get_config(self): - base_config = super().get_config() - config = {} - - # Get the names and values of positional arguments in __init__ - init_signature = inspect.signature(self.__init__) - arg_names = list(init_signature.parameters.keys()) - - # Include the positional arguments in the config - var_positional_arg_encountered = False - var_positional_arg_name = None - offset = 0 - for i, arg in enumerate(self._args): - arg_name = arg_names[min(i, len(arg_names) - 1)] - if var_positional_arg_encountered: - config.update( - { - f"{var_positional_arg_name}_{i - offset}": arg, - } - ) - elif ( - init_signature.parameters[arg_name].kind - == inspect.Parameter.VAR_POSITIONAL - ): - var_positional_arg_encountered = True - var_positional_arg_name = arg_name - offset = i - config.update( - { - f"{var_positional_arg_name}_{0}": arg, - } - ) - else: - config.update( - { - arg_name: arg, - } - ) - - # Include the keywords arguments in the config - kwargs = self._kwargs.copy() - kwargs.pop("devices", None) - config.update(**kwargs) - new_config = {**base_config, **config} - new_config = recursive_serialize(new_config) - return new_config - - @classmethod - def from_config(cls, config): - config = recursive_deserialize(config) - # Get the signature of the __init__ method - init_signature = inspect.signature(cls.__init__) - arg_names = list(init_signature.parameters.keys()) - - # Separate positional and keyword arguments based on the __init__ signature - args = [] - pos_or_kw = OrderedDict() - kwargs = {} - var_positional_args = [] - for arg_name in arg_names: - if ( - arg_name in config - and init_signature.parameters[arg_name].kind - == inspect.Parameter.KEYWORD_ONLY - ): - # Handle keyword arguments - kwargs[arg_name] = config.pop(arg_name) - elif ( - arg_name in config - and init_signature.parameters[arg_name].kind - == inspect.Parameter.POSITIONAL_OR_KEYWORD - ): - # Handle positional or keyword arguments - pos_or_kw[arg_name] = config.pop(arg_name) - elif any(re.match(rf"{arg_name}_\d+", key) for key in config.keys()): - # Handle variable positional arguments - var_positional_args.extend( - [ - config.pop(key) - for key in sorted(config.keys()) - if re.match(rf"{arg_name}_\d+", key) - ] - ) - - # Unpack positional arguments and the rest as keyword arguments - config.pop("name", None) - config.pop("trainable", None) - config.pop("dtype", None) - kwargs.update(config) - - # Determine the final args and kwargs - if var_positional_args: - args = list(pos_or_kw.values()) + var_positional_args - else: - kwargs.update(pos_or_kw) - - return cls(*args, **kwargs) - - # Methods to be Optionally Overridden # - # -----------------------------------# - - @tf.autograph.experimental.do_not_convert - def _create_variables(self, *, device=None, 
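get_config flattens a VAR_POSITIONAL parameter into numbered keys, "<name>_<i>", so that from_config can regroup them, as sketched below for a hypothetical __init__(self, *sizes, activation=None). Note the lexicographic sorted(), inherited from the code above, which would misorder ten or more such arguments (sizes_10 sorts before sizes_2):

import re

config = {"sizes_0": 64, "sizes_1": 32, "activation": "relu"}
var_positional = [
    config.pop(key)
    for key in sorted(k for k in list(config) if re.match(r"sizes_\d+", k))
]
assert var_positional == [64, 32]
assert config == {"activation": "relu"}   # what remains becomes kwargs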
dtype=None): - return {} - - @tf.autograph.experimental.do_not_convert - def _build(self, *args, **kwargs) -> bool: - return True - - @tf.autograph.experimental.do_not_convert - def _forward(self, *args, **kwargs): - raise NotImplementedError( - "When subclassing the `Module` class, you should " - "implement a `_forward` method." - ) - - @tf.autograph.experimental.do_not_convert - def _extra_repr(self) -> str: - return "" - - # Properties # - # -----------# - - @property - def device(self): - return self._device - - @property - def dtype(self): - return self._dtype - - @property - def build_mode(self): - return self._build_mode - - @property - def training(self): - return self._training - - @property - def v(self): - return self._v - - @property - def buffers(self): - return self._buffers - - @property - def state_dict(self): - return {**self.v, **self.buffers} - - @property - def module_dict(self): - return self._module_dict - - @property - def layers(self): - return self._layers - - # Dunder Methods # - # ---------------# - @store_frame_info - @tf.autograph.experimental.do_not_convert - def __call__( - self, - *args, - v=None, - buffers=None, - **kwargs, - ): - # TODO: Temp workaround to avoid `call`` from being transformed by AutoGraph - if not hasattr(self.__class__.call, "autograph_info__"): - setattr(self.__class__.call, "autograph_info__", True) - ret = self._call(*args, v=v, buffers=buffers, **kwargs) - return ret - - @tf.autograph.experimental.do_not_convert - def __getattr__(self, name): - if name == "v": - if not super().__getattribute__("_v") and not getattr( # noqa: E501 - self, "_built", False - ): - return self._build_and_return_v( - *self._args, dynamic_backend=self._dynamic_backend, **self._kwargs - ) - - _dict = super().__getattribute__("__dict__") - if name in _dict: - return _dict[name] - - elif "_v" in _dict and name in _dict["_v"]: - return _dict["_v"][name] - - return super().__getattribute__(name) - - @tf.autograph.experimental.do_not_convert - def __setattr__(self, name, value): - if name in ["v", "buffers"]: - name = "_" + name - if isinstance(value, (Layer, tf.keras.layers.Layer)): - _dict = getattr(self, "__dict__", None) - if _dict: - _dict[name] = value - - # compute the module dict - self._compute_module_dict() - - obj_to_search = ( - None - if not isinstance(value, (Layer, tf.keras.layers.Layer)) - else ( - self._modules - if hasattr(self, "_modules") and self._modules - else self - ) - ) - found_vars = self._find_variables( - obj=obj_to_search, - without_initialisation=( - True - if self._v_from_constructor and not self._with_partial_v - else False - ), - ) - flattened_v, v_spec = tree_flatten(found_vars) - flattend_kc = v_spec.get_keychains() - for kc, v in zip(flattend_kc, flattened_v): - new_kc = kc.replace("/", ".") - if new_kc not in self.v: - self.register_parameter(new_kc, v) - - # once all variables built, find and assign buffers - found_buffers = self._find_variables( - obj=obj_to_search, - without_initialisation=( - True - if self._v_from_constructor and not self._with_partial_v - else False - ), - trainable=False, - ) - flattened_buf, buf_spec = tree_flatten(found_buffers) - flattend_kc = buf_spec.get_keychains() - for kc, buf in zip(flattend_kc, flattened_buf): - new_kc = kc.replace("/", ".") - self.register_buffer(new_kc, buf) - - super().__setattr__(name, value) - return - elif isinstance(value, tf.Variable) and not name.startswith("_"): - _dict = getattr(self, "__dict__", None) - if _dict: - _dict[name] = value - # Manual solution for cases 
where a `tf.int32` tensor - # is placed on the GPU. TensorFlow doesn't have dedicated - # kernels for placing `tf.int32` variables on the GPU and so - # we manually cast them to `tf.int64` here otherwise due to - # `tf.config.soft_device_placement(True)` by default, - # TensorFlow puts the `tf.int32` variables on CPU which causes - # unintended consequences downstream during tracing or - # `tf.function` compilation e.g. - # Ref: https://github.com/tensorflow/tensorflow/issues/9506 - # Ref: https://stackoverflow.com/questions/44813939/could-not-satisfy-explicit-device-specification-devicegpu0-because-no-devic - dtype = ( - tf.int64 - if value.dtype == tf.int32 and "gpu:" in value.device.lower() - else value.dtype - ) - cast_dtype = dtype != value.dtype - val = ( - value - if not cast_dtype - else tf.Variable(initial_value=tf.cast(value.value(), dtype), name=name) - ) - self.register_parameter(name, val) - super().__setattr__(name, val) - return - else: - try: - obj_to_search = getattr(self, name) - except AttributeError: - obj_to_search = None - if isinstance(obj_to_search, Layer): - # retrieve all hierarchical submodules - assign_dict, kc = get_assignment_dict() - - # Iterate over all submods in assign_dict - # updating their `v` and `buffers` with the - # new value - for key, submod in assign_dict.items(): - # Get the subkey to match - subkey = kc[len(key) :].lstrip(".") - - if hasattr(submod, "v"): - for v_key, v_value in submod.v.items(): - if v_key.startswith(subkey): - submod.register_parameter(v_key, value) - - # Repeat the same process for submod.buffers - if hasattr(submod, "buffers"): - for b_key, b_value in submod.buffers.items(): - if b_key.startswith(subkey): - submod.register_buffer(b_key, value) - - # finally update the module dict - self._module_dict[name] = value - - return super().__setattr__(name, value) - - @tf.autograph.experimental.do_not_convert - def __delattr__(self, name): - if hasattr(self, name): - if isinstance(getattr(self, name), (Layer, tf.keras.layers.Layer)): - super().__delattr__(name) - return - super().__delattr__(name) - - @tf.autograph.experimental.do_not_convert - def __repr__(self): - extra_lines = [] - extra_repr = self._extra_repr() - if extra_repr: - extra_lines = extra_repr.split("\n") - child_lines = [] - for key in self.v.keys(): - if isinstance(getattr(self, key, None), Layer): - mod_str = repr(getattr(self, key)) - mod_str = self._addindent(mod_str, 2) - child_lines.append(f"({key}): {mod_str}") - lines = extra_lines + child_lines - - main_str = f"{self.__class__.__name__}(" - if lines: - # simple one-liner info, which most builtin Modules will use - if len(extra_lines) == 1 and not child_lines: - main_str += extra_lines[0] - else: - main_str += "\n " + "\n ".join(lines) + "\n" - - main_str += ")" - return main_str - - -class Model(tf.keras.Model, ModelHelpers): - _build_mode = None - _with_partial_v = None - _store_vars = True - _built = False - _v = None - _buffers = None - _module_dict = None - _args = None - _kwargs = None - _module_graph = None - _target = None - _lazy_traced = False - _training = None - _dynamic_backend = None - _device = None - _dtype = None - _previous_frame_info = None - - def __init__( - self, - /, - *args, - v=None, - buffers=None, - build_mode="on_init", - store_vars=True, - with_partial_v=False, - dynamic_backend=None, - training=True, - dtype=None, - device=None, - module_dict=None, - **kwargs, - ): - super(Model, self).__init__( - trainable=training, - dtype=dtype, - ) - self._build_mode = build_mode - 
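The int32-on-GPU workaround in __setattr__ (identical in Layer and Model) reads more clearly in isolation: int32 variables that report a GPU placement are rebuilt as int64 before registration, because TensorFlow has no GPU kernels for many int32 variable ops and soft device placement would silently bounce them to the CPU. A standalone sketch:

import tensorflow as tf

def widen_if_int32_on_gpu(value: tf.Variable, name: str) -> tf.Variable:
    dtype = (
        tf.int64
        if value.dtype == tf.int32 and "gpu:" in value.device.lower()
        else value.dtype
    )
    if dtype == value.dtype:
        return value   # nothing to do for CPU placement or non-int32 dtypes
    return tf.Variable(initial_value=tf.cast(value.value(), dtype), name=name)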
self._with_partial_v = with_partial_v - self._store_vars = store_vars - self._built = False - self._v_from_constructor = v if isinstance(v, dict) or v is None else dict(v) - self._v = v if v is not None else dict() - self._buffers = dict(buffers or {}) - self._module_dict = module_dict if module_dict is not None else dict() - self._args = args - self._kwargs = kwargs - self._module_graph = None - self._target = None - self._lazy_traced = False - self._training = training - self._dynamic_backend = dynamic_backend - self._device = device or "cpu" - self._dtype = dtype or tf.float32 - if build_mode != "on_init": - return - self.build(*args, dynamic_backend=dynamic_backend, **kwargs) - - @tf.autograph.experimental.do_not_convert - def _find_variables( - self, - /, - *, - obj=None, - without_initialisation=False, - _visited=None, - trainable=True, - ): - _visited = _visited or {} - vs = dict() - if id(obj) in _visited: - return vs - _visited[id(obj)] = True - if isinstance(obj, (Layer, Model)) and obj is not self: - fn = "_build_and_return_v" if trainable else "_build_and_return_buffers" - if not obj._built and without_initialisation: - return lambda: getattr(obj, fn)( - *obj._args, dynamic_backend=self._dynamic_backend, **obj._kwargs - ) - - return getattr(obj, fn)( - *obj._args, dynamic_backend=obj._dynamic_backend, **obj._kwargs - ) - elif isinstance(obj, tf.keras.layers.Layer) and obj is not self: - return obj.v if trainable else obj.buffers - - elif isinstance(obj, (list, tuple)): - for i, v in enumerate(obj): - ret = self._find_variables( - obj=v, - without_initialisation=without_initialisation, - _visited=_visited, - trainable=trainable, - ) - if ret: - vs[f"v{str(i)}"] = ret - return vs - elif isinstance(obj, dict): - for k, v in obj.items(): - ret = self._find_variables( - obj=v, - without_initialisation=without_initialisation, - _visited=_visited, - trainable=trainable, - ) - if ret: - vs[k[1:] if k[0] == "_" else k] = ret - return vs - elif not hasattr(obj, "__dict__"): - return vs - for k, v in obj.__dict__.items(): - if ( - v is not None - and k[0:2] != "__" - and not k.startswith( - ( - "_module_dict", - "_self_", - "_args", - "_kwargs", - ) - ) - ): - ret = self._find_variables( - obj=v, - without_initialisation=without_initialisation, - _visited=_visited, - trainable=trainable, - ) - if ret: - vs[k[1:] if k[0] == "_" else k] = ret - return vs - - @tf.autograph.experimental.do_not_convert - def _find_buffers(self): - if hasattr(self, "_module_dict"): - for key, sub_module in self._module_dict.items(): - if len(sub_module._buffers) > 0: - self._buffers[key] = sub_module._buffers - - @tf.autograph.experimental.do_not_convert - def _build_and_return_v(self, *args, **kwargs): - if not self._built: - self.build(*args, **kwargs) - return self.v - - @tf.autograph.experimental.do_not_convert - def _build_and_return_buffers(self, *args, **kwargs): - if not self._built: - self.build(*args, **kwargs) - return self.buffers - - @tf.autograph.experimental.do_not_convert - def _wrap_call_methods( - self, keychain_mappings, /, *, key="", obj=None, _visited=None - ): - _visited = _visited or {} - if id(obj) in _visited or not isinstance(key, str): - return - _visited[id(obj)] = True - if isinstance(obj, (Layer, Model)) and obj is not self: - orig_key_chain = key[1:] if key[0] == "_" else key - - obj.__call__ = self._fn_with_var_arg( - obj.__call__, self._extract_v, keychain_mappings, orig_key_chain - ) - return - elif isinstance(obj, (list, tuple)): - for i, val in enumerate(obj): - 
self._wrap_call_methods( - keychain_mappings, - key=f"{key}/v{str(i)}", - obj=val, - _visited=_visited, - ) - return - elif isinstance(obj, dict): - for k, val in obj.items(): - k = f"{key}/{k}" if key != "" and isinstance(k, str) else k - self._wrap_call_methods( - keychain_mappings, key=k, obj=val, _visited=_visited - ) - return - for k, val in obj.module_dict.items(): - if k.startswith(("__", "_self_")): - continue - k = f"{key}/{k}" if key != "" else k - if val is not None: - self._wrap_call_methods( - keychain_mappings, key=k, obj=val, _visited=_visited - ) - return - - @tf.autograph.experimental.do_not_convert - def _compute_module_dict(self): - self._module_dict = dict() - for key, value in self.__dict__.items(): - if isinstance(value, (Layer, tf.keras.layers.Layer, Model, tf.keras.Model)): - if ( - "stateful" in value.__module__ - or hasattr(value, "_frontend_module") - or not hasattr(value, "_module_dict") - ): - self._module_dict[key] = value - else: - self._module_dict[key] = value._module_dict - - @tf.autograph.experimental.do_not_convert - def _fn_with_var_arg_wrapper( - self, *a, fn, v_fn, keychain_mappings, orig_key_chain, **kw - ): - if "v" in kw: - del kw["v"] - v = v_fn(self.v, keychain_mappings, orig_key_chain) - return fn(*a, **kw, v=v) - - @tf.autograph.experimental.do_not_convert - def _fn_with_var_arg(self, fn, v_fn, /, keychain_mappings, orig_key_chain): - _fn_with_var_arg_wrapper = functools.partial( - self._fn_with_var_arg_wrapper, - fn=fn, - v_fn=v_fn, - keychain_mappings=keychain_mappings, - orig_key_chain=orig_key_chain, - ) - _fn_with_var_arg_wrapper.wrapped = True - return _fn_with_var_arg_wrapper - - @tf.autograph.experimental.do_not_convert - def _call(self, *args, v=None, buffers=None, **kwargs): - if not self._built or not self.built: - if not self._built: - first_arr = self._get_first_array(*args, **kwargs) - self.build( - *args, - **kwargs, - from_call=True, - dtype=first_arr.dtype if first_arr is not None else tf.float32, - ) - - if not self.built: - # Don't use `keras` build method - if os.environ.get("USE_KERAS_BUILD", "False").lower() == "false": - self.inputs = tf.nest.flatten(args) - else: - input_shapes = self._get_input_shapes(*args) - if len(input_shapes) == 0: - input_shapes = tf.TensorShape(None) - elif len(input_shapes) == 1: - input_shapes = input_shapes[0] - - super(Model, self).build(tf.TensorShape(None)) # noqa: UP008 - - # If `v` was provided, replace with the module's v - replace_v = False - if v is not None: - v_orig = self.v - self._v = v - replace_v = True - - # If `buffers` were provided, replace with the module's buffers - replace_buffers = False - if buffers is not None: - buffers_orig = self.buffers - self._buffers = buffers - replace_buffers = True - - if replace_v or replace_buffers: - # Call the forward pass - ret = super(Model, self).__call__(*args, **kwargs) # noqa: UP008 - # Replace v, buffers if needed - self._v = v_orig if replace_v else self._v - self._buffers = buffers_orig if replace_buffers else self._buffers - return ret - elif hasattr(self.__call__, "wrapped"): - return self.__call__(*args, **kwargs) - - return super(Model, self).__call__(*args, **kwargs) # noqa: UP008 - - @tf.autograph.experimental.do_not_convert - def build( - self, - *args, - from_call=False, - device=None, - dtype=None, - dynamic_backend=None, - **kwargs, - ): - self._built = True - return - - def _lock_state(self): - pass - - @tf.autograph.experimental.do_not_convert - def register_buffer(self, name: str, value: Union[tf.Tensor, tf.Variable]): 
- self._buffers.update({name: value}) - return value - - @tf.autograph.experimental.do_not_convert - def register_parameter(self, name: str, value: Union[tf.Tensor, tf.Variable]): - self._v.update({name: value}) - - @tf.autograph.experimental.do_not_convert - def train(self, mode: bool = True): - self._training = mode - for module in self.children(): - if isinstance(module, tf.keras.layers.Layer) and not hasattr( - module, "train" - ): - module.trainable = mode - continue - module.train(mode) - self.trainable = mode - return self - - @tf.autograph.experimental.do_not_convert - def eval(self): - return self.train(mode=False) - - @tf.autograph.experimental.do_not_convert - def call(self, inputs, training=None, mask=None): - raise NotImplementedError( - "When subclassing the `Module` class, you should implement a `call` method." - ) - - def get_build_config(self): - config = super().get_build_config() - config = recursive_serialize(config) - return config - - def build_from_config(self, config): - config = recursive_deserialize(config) - return super().build_from_config(config) - - def get_config(self): - base_config = super().get_config() - config = {} - - # Get the names and values of positional arguments in __init__ - init_signature = inspect.signature(self.__init__) - arg_names = list(init_signature.parameters.keys()) - - # Include the positional arguments in the config - var_positional_arg_encountered = False - var_positional_arg_name = None - offset = 0 - for i, arg in enumerate(self._args): - arg_name = arg_names[min(i, len(arg_names) - 1)] - if var_positional_arg_encountered: - config.update( - { - f"{var_positional_arg_name}_{i - offset}": arg, - } - ) - elif ( - init_signature.parameters[arg_name].kind - == inspect.Parameter.VAR_POSITIONAL - ): - var_positional_arg_encountered = True - var_positional_arg_name = arg_name - offset = i - config.update( - { - f"{var_positional_arg_name}_{0}": arg, - } - ) - else: - config.update( - { - arg_name: arg, - } - ) - - # Include the keywords arguments in the config - kwargs = self._kwargs.copy() - kwargs.pop("devices", None) - config.update(**kwargs) - new_config = {**base_config, **config} - new_config = recursive_serialize(new_config) - return new_config - - @classmethod - def from_config(cls, config): - config = recursive_deserialize(config) - # Get the signature of the __init__ method - init_signature = inspect.signature(cls.__init__) - arg_names = list(init_signature.parameters.keys()) - - # Separate positional and keyword arguments based on the __init__ signature - args = [] - pos_or_kw = OrderedDict() - kwargs = {} - var_positional_args = [] - for arg_name in arg_names: - if ( - arg_name in config - and init_signature.parameters[arg_name].kind - == inspect.Parameter.KEYWORD_ONLY - ): - # Handle keyword arguments - kwargs[arg_name] = config.pop(arg_name) - elif ( - arg_name in config - and init_signature.parameters[arg_name].kind - == inspect.Parameter.POSITIONAL_OR_KEYWORD - ): - # Handle positional or keyword arguments - pos_or_kw[arg_name] = config.pop(arg_name) - elif any(re.match(rf"{arg_name}_\d+", key) for key in config.keys()): - # Handle variable positional arguments - var_positional_args.extend( - [ - config.pop(key) - for key in sorted(config.keys()) - if re.match(rf"{arg_name}_\d+", key) - ] - ) - - # Unpack positional arguments and the rest as keyword arguments - config.pop("name", None) - config.pop("trainable", None) - config.pop("dtype", None) - kwargs.update(config) - - # Determine the final args and kwargs - if 
var_positional_args: - args = list(pos_or_kw.values()) + var_positional_args - else: - kwargs.update(pos_or_kw) - - return cls(*args, **kwargs) - - # Methods to be Optionally Overridden # - # -----------------------------------# - - @tf.autograph.experimental.do_not_convert - def _create_variables(self, *, device=None, dtype=None): - return {} - - @tf.autograph.experimental.do_not_convert - def _build(self, *args, **kwargs) -> bool: - return True - - @tf.autograph.experimental.do_not_convert - def _forward(self, *args, **kwargs): - raise NotImplementedError( - "When subclassing the `Module` class, you should " - "implement a `_forward` method." - ) - - @tf.autograph.experimental.do_not_convert - def _extra_repr(self) -> str: - return "" - - # Properties # - # -----------# - - @property - def device(self): - return self._device - - @property - def dtype(self): - return self._dtype - - @property - def build_mode(self): - return self._build_mode - - @property - def training(self): - return self._training - - @property - def v(self): - return self._v - - @property - def buffers(self): - return self._buffers - - @property - def state_dict(self): - return {**self.v, **self.buffers} - - @property - def module_dict(self): - return self._module_dict - - # Dunder Methods # - # ---------------# - @store_frame_info - @tf.autograph.experimental.do_not_convert - def __call__( - self, - *args, - v=None, - buffers=None, - **kwargs, - ): - # TODO: Temp workaround to avoid `call`` from being transformed by AutoGraph - if not hasattr(self.__class__.call, "autograph_info__"): - setattr(self.__class__.call, "autograph_info__", True) - ret = self._call(*args, v=v, buffers=buffers, **kwargs) - return ret - - @tf.autograph.experimental.do_not_convert - def __getattr__(self, name): - if name == "v": - if not super().__getattribute__("_v") and not getattr( # noqa: E501 - self, "_built", False - ): - return self._build_and_return_v( - *self._args, dynamic_backend=self._dynamic_backend, **self._kwargs - ) - - _dict = super().__getattribute__("__dict__") - if name in _dict: - return _dict[name] - - elif "_v" in _dict and name in _dict["_v"]: - return _dict["_v"][name] - - return super().__getattribute__(name) - - @tf.autograph.experimental.do_not_convert - def __setattr__(self, name, value): - if name in ["v", "buffers"]: - name = "_" + name - if isinstance(value, (Layer, tf.keras.layers.Layer, Model, tf.keras.Model)): - _dict = getattr(self, "__dict__", None) - if _dict: - _dict[name] = value - - # compute the module dict - self._compute_module_dict() - - obj_to_search = ( - None - if not isinstance(value, (tf.keras.layers.Layer, Layer, Model)) - else ( - self._modules - if hasattr(self, "_modules") and self._modules - else self - ) - ) - found_vars = self._find_variables( - obj=obj_to_search, - without_initialisation=( - True - if self._v_from_constructor and not self._with_partial_v - else False - ), - ) - flattened_v, v_spec = tree_flatten(found_vars) - flattend_kc = v_spec.get_keychains() - for kc, v in zip(flattend_kc, flattened_v): - new_kc = kc.replace("/", ".") - if new_kc not in self.v: - self.register_parameter(new_kc, v) - - # once all variables built, find and assign buffers - found_buffers = self._find_variables( - obj=obj_to_search, - without_initialisation=( - True - if self._v_from_constructor and not self._with_partial_v - else False - ), - trainable=False, - ) - flattened_buf, buf_spec = tree_flatten(found_buffers) - flattend_kc = buf_spec.get_keychains() - for kc, buf in zip(flattend_kc, 
flattened_buf): - new_kc = kc.replace("/", ".") - self.register_buffer(new_kc, buf) - - super().__setattr__(name, value) - return - elif isinstance(value, tf.Variable) and not name.startswith("_"): - _dict = getattr(self, "__dict__", None) - if _dict: - _dict[name] = value - - # Manual solution for cases where a `tf.int32` tensor - # is placed on the GPU. TensorFlow doesn't have dedicated - # kernels for placing `tf.int32` variables on the GPU and so - # we manually cast them to `tf.int64` here otherwise due to - # `tf.config.soft_device_placement(True)` by default, - # TensorFlow puts the `tf.int32` variables on CPU which causes - # unintended consequences downstream during tracing or - # `tf.function` compilation e.g. - # Ref: https://github.com/tensorflow/tensorflow/issues/9506 - # Ref: https://stackoverflow.com/questions/44813939/could-not-satisfy-explicit-device-specification-devicegpu0-because-no-devic - dtype = ( - tf.int64 - if value.dtype == tf.int32 and "gpu:" in value.device.lower() - else value.dtype - ) - cast_dtype = dtype != value.dtype - val = ( - value - if not cast_dtype - else tf.Variable(initial_value=tf.cast(value.value(), dtype), name=name) - ) - self.register_parameter(name, val) - super().__setattr__(name, val) - else: - try: - obj_to_search = getattr(self, name) - except AttributeError: - obj_to_search = None - if isinstance(obj_to_search, (Model, Layer)): - # retrieve all hierarchical submodules - assign_dict, kc = get_assignment_dict() - - # Iterate over all submods in assign_dict - # updating their `v` and `buffers` with the - # new value - for key, submod in assign_dict.items(): - # Get the subkey to match - subkey = kc[len(key) :].lstrip(".") - - if hasattr(submod, "v"): - for v_key, v_value in submod.v.items(): - if v_key.startswith(subkey): - submod.register_parameter(v_key, value) - - # Repeat the same process for submod.buffers - if hasattr(submod, "buffers"): - for b_key, b_value in submod.buffers.items(): - if b_key.startswith(subkey): - submod.register_buffer(b_key, value) - - # finally update the module dict - self._module_dict[name] = value - - return super().__setattr__(name, value) - - @tf.autograph.experimental.do_not_convert - def __delattr__(self, name): - if hasattr(self, name): - if isinstance( - getattr(self, name), - (Layer, tf.keras.layers.Layer, Model, tf.keras.Model), - ): - super().__delattr__(name) - return - super().__delattr__(name) - - @tf.autograph.experimental.do_not_convert - def __repr__(self): - extra_lines = [] - extra_repr = self._extra_repr() - if extra_repr: - extra_lines = extra_repr.split("\n") - child_lines = [] - for key in self.v.keys(): - if isinstance(getattr(self, key, None), (Layer, Model)): - mod_str = repr(getattr(self, key)) - mod_str = self._addindent(mod_str, 2) - child_lines.append(f"({key}): {mod_str}") - lines = extra_lines + child_lines - - main_str = f"{self.__class__.__name__}(" - if lines: - # simple one-liner info, which most builtin Modules will use - if len(extra_lines) == 1 and not child_lines: - main_str += extra_lines[0] - else: - main_str += "\n " + "\n ".join(lines) + "\n" - - main_str += ")" - return main_str diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_linspace_output/run_0/tensorflow_linspace.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_linspace_output/run_0/tensorflow_linspace.py deleted file mode 100644 index 8692fe07d4cc..000000000000 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_linspace_output/run_0/tensorflow_linspace.py +++ /dev/null @@ -1,40 +0,0 @@ 
-import tensorflow - -from typing import Optional -from typing import Union - -from .tensorflow__helpers import tensorflow__slice_at_axis -from .tensorflow__helpers import tensorflow_handle_array_like_without_promotion -from .tensorflow__helpers import tensorflow_infer_dtype - - -@tensorflow_infer_dtype -@tensorflow_handle_array_like_without_promotion -def tensorflow_linspace( - start: Union[tensorflow.Tensor, tensorflow.Variable, float], - stop: Union[tensorflow.Tensor, tensorflow.Variable, float], - /, - num: int, - *, - axis: Optional[int] = None, - endpoint: bool = True, - dtype: tensorflow.DType, - device: Optional[str] = None, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if axis is None: - axis = -1 - start = tensorflow.cast(tensorflow.constant(start), dtype=dtype) - stop = tensorflow.cast(tensorflow.constant(stop), dtype=dtype) - if not endpoint: - ans = tensorflow.linspace(start, stop, num + 1, axis=axis) - if axis < 0: - axis += len(ans.shape) - ans = tensorflow.convert_to_tensor( - ans.numpy()[tensorflow__slice_at_axis(slice(None, -1), axis)] - ) - else: - ans = tensorflow.linspace(start, stop, num, axis=axis) - if dtype.is_integer and ans.dtype.is_floating: - ans = tensorflow.math.floor(ans) - return tensorflow.cast(ans, dtype) diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_log2_output/run_0/tensorflow_NestedSequence_bknd.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_log2_output/run_0/tensorflow_NestedSequence_bknd.py index ac8335fe1e56..9f87b4ae29ef 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_log2_output/run_0/tensorflow_NestedSequence_bknd.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_log2_output/run_0/tensorflow_NestedSequence_bknd.py @@ -1,5 +1,5 @@ -from typing import TypeVar from typing import Protocol +from typing import TypeVar _T_co = TypeVar("_T_co", covariant=True) diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_log2_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_log2_output/run_0/tensorflow__helpers.py index d50687f412e5..3aaa2aa83879 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_log2_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_log2_output/run_0/tensorflow__helpers.py @@ -23,6 +23,99 @@ import tensorflow as tf +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + 
kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +240,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +282,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( 
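The deleted `tensorflow_linspace` above emulated `endpoint=False` by sampling `num + 1` points and slicing the last one off along the chosen axis, since `tf.linspace` always includes the endpoint. A hedged sketch of the same trick in plain TensorFlow, without the Ivy decorators (`linspace_no_endpoint` is our name):

import tensorflow as tf

def linspace_no_endpoint(start, stop, num, axis=-1):
    # tf.linspace always includes `stop`; request one extra sample and
    # drop the final element so the spacing matches endpoint=False.
    full = tf.linspace(start, stop, num + 1, axis=axis)
    if axis < 0:
        axis += len(full.shape)
    index = [slice(None)] * len(full.shape)
    index[axis] = slice(None, -1)
    return full[tuple(index)]

print(linspace_no_endpoint(0.0, 1.0, 4))  # [0.0, 0.25, 0.5, 0.75]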
- ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_is_native_array(x, /, *, exclusive=False): @@ -334,7 +342,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -447,6 +457,7 @@ def tensorflow_is_complex_dtype_bknd( return "complex" in tensorflow_as_ivy_dtype_bknd(dtype_in) +@tensorflow_handle_array_like_without_promotion def tensorflow_real( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -460,6 +471,7 @@ def tensorflow_real_bknd_(self): return tensorflow_real(self) +@tensorflow_handle_array_like_without_promotion def tensorflow_imag( val: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -485,6 +497,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -611,6 +626,9 @@ def nested_fun(x): return "int" in tensorflow_as_ivy_dtype(dtype_in) +default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, @@ -655,6 +673,9 @@ def tensorflow_default_dtype_bknd( return tensorflow_as_ivy_dtype(ret) +default_float_dtype_stack = [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -709,6 +730,25 @@ def tensorflow_default_float_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -745,6 +785,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -767,21 +811,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) 
== "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -819,6 +859,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -870,20 +929,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -950,26 +995,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -1087,6 +1114,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -1193,27 +1223,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -1382,6 
+1406,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1638,7 +1665,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2050,7 +2079,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -2174,6 +2205,9 @@ def tensorflow_to_scalar_bknd_(self: tensorflow.Tensor): return tensorflow_to_scalar(self) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2198,11 +2232,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2410,7 +2442,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2570,11 +2604,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2614,21 +2646,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_log2_output/run_0/tensorflow_log2.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_log2_output/run_0/tensorflow_log2.py index e8a92632748c..6935fe667d60 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_log2_output/run_0/tensorflow_log2.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_log2_output/run_0/tensorflow_log2.py @@ -1,7 +1,7 @@ import tensorflow -from typing import Optional from typing import Union +from typing import Optional from .tensorflow__helpers import tensorflow_handle_array_like_without_promotion diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_log_output/run_0/tensorflow_NestedSequence_bknd.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_log_output/run_0/tensorflow_NestedSequence_bknd.py index ac8335fe1e56..9f87b4ae29ef 100644 --- 
a/ivy/compiler/_cache/Translated_Outputs/tensorflow_log_output/run_0/tensorflow_NestedSequence_bknd.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_log_output/run_0/tensorflow_NestedSequence_bknd.py @@ -1,5 +1,5 @@ -from typing import TypeVar from typing import Protocol +from typing import TypeVar _T_co = TypeVar("_T_co", covariant=True) diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_log_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_log_output/run_0/tensorflow__helpers.py index d50687f412e5..3aaa2aa83879 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_log_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_log_output/run_0/tensorflow__helpers.py @@ -23,6 +23,99 @@ import tensorflow as tf +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +240,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +282,8 @@ 
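`tensorflow_handle_array_like_without_promotion`, re-added near the top of each helpers module here, converts array-like positional and keyword arguments to tensors by inspecting the decorated function's type annotations. A stripped-down sketch of the same idea, assuming plain TensorFlow and skipping the device inference and nested-sequence checks of the real helper (`to_tensor_by_annotation` is our name):

import functools
import inspect
from typing import Union

import tensorflow as tf

def to_tensor_by_annotation(fn):
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        params = inspect.signature(fn).parameters
        names = list(params)
        args = list(args)
        for i, arg in enumerate(args):
            # Convert only arguments whose annotation mentions a Tensor.
            annotation = str(params[names[i]].annotation)
            if "Tensor" in annotation and not tf.is_tensor(arg):
                args[i] = tf.convert_to_tensor(arg)
        return fn(*args, **kwargs)

    return wrapper

@to_tensor_by_annotation
def double(x: Union[tf.Tensor, tf.Variable]):
    return x * 2

print(double([1.0, 2.0]))  # the list is converted to a tensor first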
("float32", "float64"): "float64", ("float64", "float64"): "float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_is_native_array(x, /, *, exclusive=False): @@ -334,7 +342,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -447,6 +457,7 @@ def tensorflow_is_complex_dtype_bknd( return "complex" in tensorflow_as_ivy_dtype_bknd(dtype_in) +@tensorflow_handle_array_like_without_promotion def tensorflow_real( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -460,6 +471,7 @@ def tensorflow_real_bknd_(self): return tensorflow_real(self) +@tensorflow_handle_array_like_without_promotion def tensorflow_imag( val: 
Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -485,6 +497,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -611,6 +626,9 @@ def nested_fun(x): return "int" in tensorflow_as_ivy_dtype(dtype_in) +default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, @@ -655,6 +673,9 @@ def tensorflow_default_dtype_bknd( return tensorflow_as_ivy_dtype(ret) +default_float_dtype_stack = [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -709,6 +730,25 @@ def tensorflow_default_float_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -745,6 +785,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -767,21 +811,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -819,6 +859,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -870,20 +929,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - 
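The reflowed lambdas in `tensorflow_default_int_dtype_bknd` widen the inferred dtype by magnitude: scalars above 2**63 - 1 (9223372036854775807) force uint64, scalars above 2**31 - 1 (2147483647) force int64, and smaller values keep the 32-bit default. A worked sketch of that decision for plain Python integers (`infer_int_dtype` is our name for the scalar branch only):

import math

INT32_MAX = 2**31 - 1  # 2147483647
INT64_MAX = 2**63 - 1  # 9223372036854775807

def infer_int_dtype(value):
    # Mirrors the threshold checks applied to non-array scalars above.
    if value > INT64_MAX and value != math.inf:
        return "uint64"
    if value > INT32_MAX and value != math.inf:
        return "int64"
    return "int32"

print(infer_int_dtype(100))    # -> int32
print(infer_int_dtype(2**40))  # -> int64
print(infer_int_dtype(2**63))  # -> uint64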
@tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -950,26 +995,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -1087,6 +1114,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -1193,27 +1223,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -1382,6 +1406,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1638,7 +1665,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2050,7 +2079,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -2174,6 +2205,9 @@ def tensorflow_to_scalar_bknd_(self: tensorflow.Tensor): return tensorflow_to_scalar(self) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2198,11 +2232,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2410,7 
+2442,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2570,11 +2604,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2614,21 +2646,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_max_pool2d_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_max_pool2d_output/run_0/tensorflow__helpers.py index 55023176b74b..1fce5432ddff 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_max_pool2d_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_max_pool2d_output/run_0/tensorflow__helpers.py @@ -23,6 +23,110 @@ import tensorflow as tf +def tensorflow__handle_padding_bknd(x, strides, filters, padding): + if isinstance(padding, str) and padding.upper() == "SAME": + if x % strides == 0: + pad = max(filters - strides, 0) + else: + pad = max(filters - x % strides, 0) + else: + pad = 0 + return pad + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True 
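`tensorflow_multiply` above, like `tensorflow_where` and `tensorflow_maximum`, resolves a common dtype for mixed tensor/scalar operands with the same fallback chain: prefer `x1.dtype`, then `x2.dtype`, then the backend default. A minimal sketch, with the backend default hardcoded to float32 for illustration (`resolve_dtype` is our name):

import tensorflow as tf

def resolve_dtype(x1, x2, default=tf.float32):
    # Prefer an existing tensor dtype; fall back to the default only
    # when both operands are plain Python scalars or lists.
    if hasattr(x1, "dtype"):
        return x1.dtype
    if hasattr(x2, "dtype"):
        return x2.dtype
    return default

x1, x2 = 3, tf.constant([1.0, 2.0])
dtype = resolve_dtype(x1, x2)
x1 = tf.convert_to_tensor(x1, dtype=dtype)  # scalar promoted to float32
print(tf.multiply(x1, x2))                  # -> [3.0, 6.0]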
+ return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +251,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +293,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in 
kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_is_native_array(x, /, *, exclusive=False): @@ -334,7 +353,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -496,26 +517,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -633,6 +636,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -739,27 +745,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -928,6 +928,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1184,7 +1187,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -1596,7 +1601,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -1764,6 +1771,9 @@ def 
tensorflow_is_uint_dtype_bknd( return "uint" in tensorflow_as_ivy_dtype_bknd(dtype_in) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -1788,11 +1798,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2026,7 +2034,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2186,11 +2196,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2230,21 +2238,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], @@ -2325,6 +2318,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2381,6 +2377,9 @@ def tensorflow_default_complex_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, @@ -2425,6 +2424,9 @@ def tensorflow_default_dtype_bknd( return tensorflow_as_ivy_dtype(ret) +default_float_dtype_stack = [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2479,6 +2481,25 @@ def tensorflow_default_float_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -2515,6 +2536,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2537,21 +2562,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, 
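`tensorflow_handle_set_item` and `tensorflow_handle_get_item`, applied to `tensorflow_set_item_bknd` and `tensorflow_get_item` above, share one fallback pattern: try the object's native `__setitem__`/`__getitem__` first, re-raise genuine `IndexError`s, and only fall back to the decorated backend implementation when the type rejects the operation (as immutable `tf.Tensor`s do for item assignment). A small self-contained illustration of the pattern, with names of our own choosing:

import functools

def with_getitem_fallback(fn):
    @functools.wraps(fn)
    def wrapper(inp, query, **kwargs):
        try:
            return inp.__getitem__(query)
        except IndexError:
            raise  # out-of-range is a real user error; never mask it
        except Exception:
            return fn(inp, query, **kwargs)  # type lacks native indexing
    return wrapper

@with_getitem_fallback
def get_item(obj, query):
    return list(obj)[query]  # fallback path: materialise, then index

print(get_item({"a": 1}, "a"))  # native dict lookup -> 1
print(get_item({10, 20}, 0))    # sets aren't indexable -> fallback path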
- lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -2589,6 +2610,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -2640,20 +2680,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -2797,17 +2823,6 @@ def tensorflow__determine_depth_max_pooling( return x, kernel, strides, depth_pooling -def tensorflow__handle_padding_bknd(x, strides, filters, padding): - if isinstance(padding, str) and padding.upper() == "SAME": - if x % strides == 0: - pad = max(filters - strides, 0) - else: - pad = max(filters - x % strides, 0) - else: - pad = 0 - return pad - - def tensorflow__output_ceil_shape_bknd(w, f, p, s): return math.ceil((w - f + p) / s) + 1 diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_max_pool2d_output/run_0/tensorflow_max_pool2d.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_max_pool2d_output/run_0/tensorflow_max_pool2d.py index eb20f95351ce..97e5ddc86675 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_max_pool2d_output/run_0/tensorflow_max_pool2d.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_max_pool2d_output/run_0/tensorflow_max_pool2d.py @@ -1,9 +1,9 @@ import tensorflow -from typing import List -from typing import Union -from typing import Tuple from typing import Optional +from typing import Tuple +from typing import Union +from typing import List from .tensorflow__helpers import tensorflow__determine_depth_max_pooling from .tensorflow__helpers import tensorflow__handle_padding_bknd diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_maximum_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_maximum_output/run_0/tensorflow__helpers.py index 1b64cf5d5694..bde7b8c8d8d0 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_maximum_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_maximum_output/run_0/tensorflow__helpers.py @@ -23,6 +23,99 @@ 
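`tensorflow__handle_padding_bknd` (hoisted to the top of the max_pool2d helpers above) and `tensorflow__output_ceil_shape_bknd` work as a pair: the first computes the total SAME padding along one spatial dimension, the second the resulting output length, out = ceil((w - f + p) / s) + 1. Worked numbers: for input length 7, stride 2 and filter 3, 7 % 2 = 1, so pad = max(3 - 1, 0) = 2, and then ceil((7 - 3 + 2) / 2) + 1 = 4, matching SAME's ceil(7 / 2) = 4. A quick check of the same arithmetic (function names shortened by us):

import math

def same_pad(x, strides, filters, padding):
    # Total padding along one spatial dim under SAME semantics.
    if isinstance(padding, str) and padding.upper() == "SAME":
        if x % strides == 0:
            return max(filters - strides, 0)
        return max(filters - x % strides, 0)
    return 0

def output_ceil_shape(w, f, p, s):
    # w: input length, f: filter size, p: total padding, s: stride.
    return math.ceil((w - f + p) / s) + 1

assert same_pad(7, 2, 3, "SAME") == 2
assert output_ceil_shape(7, 3, 2, 2) == 4   # ceil(7 / 2) output positions
assert same_pad(8, 2, 3, "SAME") == 1       # 8 % 2 == 0 -> max(3 - 2, 0)
assert same_pad(7, 2, 3, "VALID") == 0      # no implicit padding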
import tensorflow as tf +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +240,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +282,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] -default_complex_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} 
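`tensorflow_handle_methods`, shown again in this copy of the helpers, dispatches on the first argument: native arrays go through the backend function, while frontend objects are routed to their own method of the same base name after the `_bknd`/`_frnt` suffixes are stripped by regex. A simplified sketch of that dispatch, with a toy `Box` class standing in for a frontend object and a plain `list` standing in for a native array:

import functools
import re

def handle_methods(fn):
    def base_name(s):
        match = re.search(r"_(.+?)(?:_\d+)?$", s)
        return match.group(1) if match else s

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if isinstance(args[0], list):  # stand-in for "is a native array"
            return fn(*args, **kwargs)
        name = base_name(re.sub(r"_bknd_|_bknd|_frnt_|_frnt", "", fn.__name__))
        return getattr(args[0], name)(*args[1:], **kwargs)

    return wrapper

class Box:
    def __init__(self, items):
        self.items = items

    def split(self, n):
        return self.items[:n], self.items[n:]

@handle_methods
def toy_split_bknd(x, n):
    return x[:n], x[n:]

print(toy_split_bknd([1, 2, 3], 2))       # backend path: ([1, 2], [3])
print(toy_split_bknd(Box([1, 2, 3]), 2))  # dispatches to Box.split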
-default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -default_dtype_stack = [] -default_float_dtype_stack = [] - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_exists_bknd(x: Any, /): @@ -310,7 +318,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -531,20 +541,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -611,26 +607,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -748,6 +726,9 @@ def 
tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -854,27 +835,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -1043,6 +1018,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1299,7 +1277,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -1711,7 +1691,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -1877,6 +1859,9 @@ def tensorflow_is_uint_dtype_bknd( return "uint" in tensorflow_as_ivy_dtype_bknd(dtype_in) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -1901,11 +1886,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2139,7 +2122,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2299,11 +2284,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2343,21 +2326,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def 
wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], @@ -2438,6 +2406,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2494,6 +2465,25 @@ def tensorflow_default_complex_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -2530,6 +2520,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2552,21 +2546,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -2604,6 +2594,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -2627,6 +2636,10 @@ def tensorflow_as_native_dtype( ) +default_dtype_stack = [] +default_float_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_maximum_output/run_0/tensorflow_maximum.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_maximum_output/run_0/tensorflow_maximum.py index c6712c3b1882..3148955e052c 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_maximum_output/run_0/tensorflow_maximum.py +++ 
b/ivy/compiler/_cache/Translated_Outputs/tensorflow_maximum_output/run_0/tensorflow_maximum.py @@ -1,7 +1,7 @@ import tensorflow -from typing import Union from typing import Optional +from typing import Union from .tensorflow__helpers import tensorflow_asarray from .tensorflow__helpers import tensorflow_default_dtype_bknd @@ -22,7 +22,9 @@ def tensorflow_maximum( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_meshgrid_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_meshgrid_output/run_0/tensorflow__helpers.py index d50687f412e5..3aaa2aa83879 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_meshgrid_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_meshgrid_output/run_0/tensorflow__helpers.py @@ -23,6 +23,99 @@ import tensorflow as tf +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + 
return res + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +240,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +282,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_is_native_array(x, /, *, exclusive=False): @@ -334,7 +342,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -447,6 +457,7 @@ def tensorflow_is_complex_dtype_bknd( return "complex" in 
tensorflow_as_ivy_dtype_bknd(dtype_in) +@tensorflow_handle_array_like_without_promotion def tensorflow_real( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -460,6 +471,7 @@ def tensorflow_real_bknd_(self): return tensorflow_real(self) +@tensorflow_handle_array_like_without_promotion def tensorflow_imag( val: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -485,6 +497,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -611,6 +626,9 @@ def nested_fun(x): return "int" in tensorflow_as_ivy_dtype(dtype_in) +default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, @@ -655,6 +673,9 @@ def tensorflow_default_dtype_bknd( return tensorflow_as_ivy_dtype(ret) +default_float_dtype_stack = [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -709,6 +730,25 @@ def tensorflow_default_float_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -745,6 +785,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -767,21 +811,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -819,6 +859,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -870,20 +929,6 @@ def tensorflow_is_bool_dtype_bknd( return 
"bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -950,26 +995,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -1087,6 +1114,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -1193,27 +1223,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -1382,6 +1406,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1638,7 +1665,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2050,7 +2079,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -2174,6 +2205,9 @@ def tensorflow_to_scalar_bknd_(self: tensorflow.Tensor): return tensorflow_to_scalar(self) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2198,11 +2232,9 @@ def is_native(x): if 
tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2410,7 +2442,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2570,11 +2604,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2614,21 +2646,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_meshgrid_output/run_0/tensorflow_meshgrid.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_meshgrid_output/run_0/tensorflow_meshgrid.py index b9d46a3a9338..febe54725d45 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_meshgrid_output/run_0/tensorflow_meshgrid.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_meshgrid_output/run_0/tensorflow_meshgrid.py @@ -1,7 +1,7 @@ import tensorflow -from typing import Union from typing import Optional +from typing import Union from .tensorflow__helpers import tensorflow_handle_array_like_without_promotion diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_minimum_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_minimum_output/run_0/tensorflow__helpers.py index 1b64cf5d5694..bde7b8c8d8d0 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_minimum_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_minimum_output/run_0/tensorflow__helpers.py @@ -23,6 +23,99 @@ import tensorflow as tf +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not 
tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +240,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +282,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] -default_complex_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -default_dtype_stack = [] -default_float_dtype_stack = [] - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = 
[param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_exists_bknd(x: Any, /): @@ -310,7 +318,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -531,20 +541,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -611,26 +607,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -748,6 +726,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -854,27 +835,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: 
type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -1043,6 +1018,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1299,7 +1277,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -1711,7 +1691,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -1877,6 +1859,9 @@ def tensorflow_is_uint_dtype_bknd( return "uint" in tensorflow_as_ivy_dtype_bknd(dtype_in) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -1901,11 +1886,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2139,7 +2122,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2299,11 +2284,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2343,21 +2326,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], @@ -2438,6 +2406,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2494,6 +2465,25 @@ def tensorflow_default_complex_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + 
tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -2530,6 +2520,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2552,21 +2546,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -2604,6 +2594,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -2627,6 +2636,10 @@ def tensorflow_as_native_dtype( ) +default_dtype_stack = [] +default_float_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_minimum_output/run_0/tensorflow_minimum.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_minimum_output/run_0/tensorflow_minimum.py index 435339bedc1c..8b9ecc694ee4 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_minimum_output/run_0/tensorflow_minimum.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_minimum_output/run_0/tensorflow_minimum.py @@ -1,7 +1,7 @@ import tensorflow -from typing import Union from typing import Optional +from typing import Union from .tensorflow__helpers import tensorflow_asarray from .tensorflow__helpers import tensorflow_default_dtype_bknd @@ -22,7 +22,9 @@ def tensorflow_minimum( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_multiply_output/run_0/tensorflow_NestedSequence_bknd.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_multiply_output/run_0/tensorflow_NestedSequence_bknd.py index 9f87b4ae29ef..ac8335fe1e56 100644 --- 
a/ivy/compiler/_cache/Translated_Outputs/tensorflow_multiply_output/run_0/tensorflow_NestedSequence_bknd.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_multiply_output/run_0/tensorflow_NestedSequence_bknd.py @@ -1,5 +1,5 @@ -from typing import Protocol from typing import TypeVar +from typing import Protocol _T_co = TypeVar("_T_co", covariant=True) diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_multiply_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_multiply_output/run_0/tensorflow__helpers.py index 1b64cf5d5694..bde7b8c8d8d0 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_multiply_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_multiply_output/run_0/tensorflow__helpers.py @@ -23,6 +23,99 @@ import tensorflow as tf +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +240,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): 
"int8", @@ -188,94 +282,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] -default_complex_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -default_dtype_stack = [] -default_float_dtype_stack = [] - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_exists_bknd(x: Any, /): @@ -310,7 +318,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -531,20 +541,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item 
def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -611,26 +607,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -748,6 +726,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -854,27 +835,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -1043,6 +1018,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1299,7 +1277,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -1711,7 +1691,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -1877,6 +1859,9 @@ def tensorflow_is_uint_dtype_bknd( return "uint" in tensorflow_as_ivy_dtype_bknd(dtype_in) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -1901,11 +1886,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2139,7 +2122,9 @@ def 
tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2299,11 +2284,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2343,21 +2326,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], @@ -2438,6 +2406,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2494,6 +2465,25 @@ def tensorflow_default_complex_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -2530,6 +2520,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2552,21 +2546,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -2604,6 +2594,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": 
tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -2627,6 +2636,10 @@ def tensorflow_as_native_dtype( ) +default_dtype_stack = [] +default_float_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_multiply_output/run_0/tensorflow_multiply.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_multiply_output/run_0/tensorflow_multiply.py index 81dc60ebf373..604c30be52df 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_multiply_output/run_0/tensorflow_multiply.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_multiply_output/run_0/tensorflow_multiply.py @@ -1,7 +1,7 @@ import tensorflow -from typing import Optional from typing import Union +from typing import Optional from .tensorflow__helpers import tensorflow_asarray from .tensorflow__helpers import tensorflow_default_dtype_bknd @@ -21,7 +21,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_nonzero_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_nonzero_output/run_0/tensorflow__helpers.py index d50687f412e5..3aaa2aa83879 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_nonzero_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_nonzero_output/run_0/tensorflow__helpers.py @@ -23,6 +23,99 @@ import tensorflow as tf +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + 
inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +240,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +282,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = 
tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_is_native_array(x, /, *, exclusive=False): @@ -334,7 +342,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -447,6 +457,7 @@ def tensorflow_is_complex_dtype_bknd( return "complex" in tensorflow_as_ivy_dtype_bknd(dtype_in) +@tensorflow_handle_array_like_without_promotion def tensorflow_real( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -460,6 +471,7 @@ def tensorflow_real_bknd_(self): return tensorflow_real(self) +@tensorflow_handle_array_like_without_promotion def tensorflow_imag( val: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -485,6 +497,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -611,6 +626,9 @@ def nested_fun(x): return "int" in tensorflow_as_ivy_dtype(dtype_in) +default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, @@ -655,6 +673,9 @@ def tensorflow_default_dtype_bknd( return tensorflow_as_ivy_dtype(ret) +default_float_dtype_stack = [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -709,6 +730,25 @@ def tensorflow_default_float_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -745,6 +785,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -767,21 +811,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if 
tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -819,6 +859,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -870,20 +929,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -950,26 +995,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -1087,6 +1114,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -1193,27 +1223,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -1382,6 +1406,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1638,7 +1665,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, 
"dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2050,7 +2079,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -2174,6 +2205,9 @@ def tensorflow_to_scalar_bknd_(self: tensorflow.Tensor): return tensorflow_to_scalar(self) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2198,11 +2232,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2410,7 +2442,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2570,11 +2604,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2614,21 +2646,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_nonzero_output/run_0/tensorflow_nonzero.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_nonzero_output/run_0/tensorflow_nonzero.py index a307b8b1a3ff..edc8a537a4b8 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_nonzero_output/run_0/tensorflow_nonzero.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_nonzero_output/run_0/tensorflow_nonzero.py @@ -1,7 +1,7 @@ import tensorflow -from numbers import Number from typing import Union +from numbers import Number from typing import Optional from .tensorflow__helpers import tensorflow_handle_array_like_without_promotion diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_ones_output/run_0/__init__.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_ones_output/run_0/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_ones_output/run_0/tensorflow_NestedSequence_bknd.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_ones_output/run_0/tensorflow_NestedSequence_bknd.py deleted file mode 100644 index 9f87b4ae29ef..000000000000 --- 
a/ivy/compiler/_cache/Translated_Outputs/tensorflow_ones_output/run_0/tensorflow_NestedSequence_bknd.py +++ /dev/null @@ -1,10 +0,0 @@ -from typing import Protocol -from typing import TypeVar - -_T_co = TypeVar("_T_co", covariant=True) - - -class tensorflow_NestedSequence_bknd(Protocol[_T_co]): - def __getitem__(self, key: int, /): ... - - def __len__(self, /): ... diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_ones_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_ones_output/run_0/tensorflow__helpers.py deleted file mode 100644 index 06e137cf3452..000000000000 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_ones_output/run_0/tensorflow__helpers.py +++ /dev/null @@ -1,2671 +0,0 @@ -from collections import UserDict -from numbers import Number -from numpy.core.numeric import normalize_axis_tuple -from operator import mul -from .tensorflow_NestedSequence_bknd import tensorflow_NestedSequence_bknd -from typing import Any -from typing import Callable -from typing import Dict -from typing import Iterable -from typing import List -from typing import Optional -from typing import Sequence -from typing import Tuple -from typing import TypeVar -from typing import Union -import functools -import inspect -import itertools -import math -import numpy as np -import re -import tensorflow -import tensorflow as tf - - -promotion_table = { - ("bool", "bool"): "bool", - ("int8", "int8"): "int8", - ("int8", "int16"): "int16", - ("int8", "int32"): "int32", - ("int8", "int64"): "int64", - ("int16", "int16"): "int16", - ("int16", "int32"): "int32", - ("int16", "int64"): "int64", - ("int32", "int32"): "int32", - ("int32", "int64"): "int64", - ("int64", "int64"): "int64", - ("uint8", "int8"): "int16", - ("uint8", "int16"): "int16", - ("uint8", "int32"): "int32", - ("uint8", "int64"): "int64", - ("uint8", "uint8"): "uint8", - ("uint8", "uint16"): "uint16", - ("uint8", "uint32"): "uint32", - ("uint8", "uint64"): "uint64", - ("uint16", "int8"): "int32", - ("uint16", "int16"): "int32", - ("uint16", "int32"): "int32", - ("uint16", "int64"): "int64", - ("uint16", "uint16"): "uint16", - ("uint16", "uint32"): "uint32", - ("uint16", "uint64"): "uint64", - ("uint32", "int8"): "int64", - ("uint32", "int16"): "int64", - ("uint32", "int32"): "int64", - ("uint32", "int64"): "int64", - ("uint32", "uint32"): "uint32", - ("uint32", "uint64"): "uint64", - ("uint64", "uint64"): "uint64", - ("float16", "float16"): "float16", - ("float16", "float32"): "float32", - ("float16", "float64"): "float64", - ("float32", "float32"): "float32", - ("float32", "float64"): "float64", - ("float64", "float64"): "float64", - ("bool", "int8"): "int8", - ("bool", "int16"): "int16", - ("bool", "int32"): "int32", - ("bool", "int64"): "int64", - ("bool", "uint8"): "uint8", - ("bool", "uint16"): "uint16", - ("bool", "uint32"): "uint32", - ("bool", "uint64"): "uint64", - ("bool", "float16"): "float16", - ("bool", "float32"): "float32", - ("bool", "float64"): "float64", - ("bool", "bfloat16"): "bfloat16", - ("bool", "complex64"): "complex64", - ("bool", "complex128"): "complex128", - ("int8", "float16"): "float16", - ("int8", "float32"): "float32", - ("int8", "float64"): "float64", - ("int8", "bfloat16"): "bfloat16", - ("int8", "complex64"): "complex64", - ("int8", "complex128"): "complex128", - ("int16", "float32"): "float32", - ("int16", "float64"): "float64", - ("int16", "complex64"): "complex64", - ("int16", "complex128"): "complex128", - ("int32", "float64"): "float64", - ("int32", 
"complex128"): "complex128", - ("int64", "float64"): "float64", - ("int64", "complex128"): "complex128", - ("uint8", "float16"): "float16", - ("uint8", "float32"): "float32", - ("uint8", "float64"): "float64", - ("uint8", "bfloat16"): "bfloat16", - ("uint8", "complex64"): "complex64", - ("uint8", "complex128"): "complex128", - ("uint16", "float32"): "float32", - ("uint16", "float64"): "float64", - ("uint16", "complex64"): "complex64", - ("uint16", "complex128"): "complex128", - ("uint32", "float64"): "float64", - ("uint32", "complex128"): "complex128", - ("uint64", "int8"): "float64", - ("uint64", "int16"): "float64", - ("uint64", "int32"): "float64", - ("uint64", "int64"): "float64", - ("uint64", "float64"): "float64", - ("uint64", "complex128"): "complex128", - ("float16", "bfloat16"): "float32", - ("float16", "complex64"): "complex64", - ("float16", "complex128"): "complex128", - ("float32", "complex64"): "complex64", - ("float32", "complex128"): "complex128", - ("float64", "complex64"): "complex128", - ("float64", "complex128"): "complex128", - ("bfloat16", "float16"): "float32", - ("bfloat16", "float32"): "float32", - ("bfloat16", "float64"): "float64", - ("bfloat16", "bfloat16"): "bfloat16", - ("bfloat16", "complex64"): "complex64", - ("bfloat16", "complex128"): "complex128", - ("complex64", "float64"): "complex128", - ("complex64", "complex64"): "complex64", - ("complex64", "complex128"): "complex128", - ("complex128", "complex128"): "complex128", - ("float16", "int16"): "float32", - ("float16", "int32"): "float64", - ("float16", "int64"): "float64", - ("float16", "uint16"): "float32", - ("float16", "uint32"): "float64", - ("float16", "uint64"): "float64", - ("float32", "int32"): "float64", - ("float32", "int64"): "float64", - ("float32", "uint32"): "float64", - ("float32", "uint64"): "float64", - ("bfloat16", "int16"): "float32", - ("bfloat16", "int32"): "float64", - ("bfloat16", "int64"): "float64", - ("bfloat16", "uint16"): "float32", - ("bfloat16", "uint32"): "float64", - ("bfloat16", "uint64"): "float64", - ("complex64", "int32"): "complex128", - ("complex64", "int64"): "complex128", - ("complex64", "uint32"): "complex128", - ("complex64", "uint64"): "complex128", -} -array_api_promotion_table = { - ("bool", "bool"): "bool", - ("int8", "int8"): "int8", - ("int8", "int16"): "int16", - ("int8", "int32"): "int32", - ("int8", "int64"): "int64", - ("int16", "int16"): "int16", - ("int16", "int32"): "int32", - ("int16", "int64"): "int64", - ("int32", "int32"): "int32", - ("int32", "int64"): "int64", - ("int64", "int64"): "int64", - ("uint8", "int8"): "int16", - ("uint8", "int16"): "int16", - ("uint8", "int32"): "int32", - ("uint8", "int64"): "int64", - ("uint8", "uint8"): "uint8", - ("uint8", "uint16"): "uint16", - ("uint8", "uint32"): "uint32", - ("uint8", "uint64"): "uint64", - ("uint16", "int8"): "int32", - ("uint16", "int16"): "int32", - ("uint16", "int32"): "int32", - ("uint16", "int64"): "int64", - ("uint16", "uint16"): "uint16", - ("uint16", "uint32"): "uint32", - ("uint16", "uint64"): "uint64", - ("uint32", "int8"): "int64", - ("uint32", "int16"): "int64", - ("uint32", "int32"): "int64", - ("uint32", "int64"): "int64", - ("uint32", "uint32"): "uint32", - ("uint32", "uint64"): "uint64", - ("uint64", "uint64"): "uint64", - ("float16", "float16"): "float16", - ("float16", "float32"): "float32", - ("float16", "float64"): "float64", - ("float32", "float32"): "float32", - ("float32", "float64"): "float64", - ("float64", "float64"): "float64", -} 
-tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} - - -def tensorflow_infer_dtype(fn: Callable): - @functools.wraps(fn) - def _infer_dtype(*args, dtype=None, **kwargs): - arr = ( - None - if tensorflow_exists_bknd(dtype) - else tensorflow__get_first_array(*args, **kwargs) - ) - dtype = tensorflow_default_dtype_bknd(dtype=dtype, item=arr, as_native=True) - return fn(*args, dtype=dtype, **kwargs) - - _infer_dtype.infer_dtype = True - return _infer_dtype - - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion - - -def tensorflow_exists_bknd(x: Any, /): - return x is not None - - -def tensorflow_is_native_array(x, /, *, exclusive=False): - if "keras.src.backend.tensorflow.core.Variable" in str(x.__class__): - return not exclusive - if isinstance(x, (tensorflow.Tensor, tensorflow.Variable, tensorflow.TensorArray)): - if exclusive and isinstance(x, tensorflow.Variable): - return False - return True - 
return False - - -def tensorflow_is_ivy_array_bknd( - x: Union[tensorflow.Tensor, tf.Tensor], /, *, exclusive: Optional[bool] = False -): - return isinstance(x, tensorflow.Tensor) and tensorflow_is_native_array( - x, exclusive=exclusive - ) - - -def tensorflow_is_array_bknd(x: Any, /, *, exclusive: bool = False): - return tensorflow_is_ivy_array_bknd( - x, exclusive=exclusive - ) or tensorflow_is_native_array(x, exclusive=exclusive) - - -def tensorflow_default_bknd( - x: Any, - /, - default_val: Any, - *, - catch_exceptions: bool = False, - rev: bool = False, - with_callable: bool = False, -): - with_callable = catch_exceptions or with_callable - if rev: - x, default_val = default_val, x - if with_callable: - x_callable = callable(x) - default_callable = callable(default_val) - else: - x_callable = False - default_callable = False - if catch_exceptions: - try: - x = x() if x_callable else x - except Exception: - return default_val() if default_callable else default_val - else: - x = x() if x_callable else x - return ( - x - if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val - ) - - -def tensorflow_nested_argwhere_bknd( - nest: Iterable, - fn: Callable, - check_nests: bool = False, - to_ignore: Optional[Union[type, Tuple[type]]] = None, - _index: Optional[List] = None, - _base: bool = True, - stop_after_n_found: Optional[int] = None, -): - to_ignore = tensorflow_default_bknd(to_ignore, ()) - _index = [] if _index is None else _index - if isinstance(nest, (tuple, list)) and not isinstance(nest, to_ignore): - n = 0 - _indices = [] - for i, item in enumerate(nest): - ind = ( - tensorflow_nested_argwhere_bknd( - item, - fn, - check_nests, - to_ignore, - _index + [i], - False, - stop_after_n_found - n, - ) - if stop_after_n_found is not None - else tensorflow_nested_argwhere_bknd( - item, fn, check_nests, to_ignore, _index + [i], False, None - ) - ) - if stop_after_n_found is not None and ind: - if n >= stop_after_n_found: - break - n = n + len(ind) - _indices = _indices + [ind] - if stop_after_n_found is not None and n >= stop_after_n_found: - break - _indices = [idx for idxs in _indices if idxs for idx in idxs] - if check_nests and fn(nest): - _indices.append(_index) - elif isinstance(nest, (dict, UserDict)) and not isinstance(nest, to_ignore): - n = 0 - _indices = [] - for k, v in nest.items(): - ind = ( - tensorflow_nested_argwhere_bknd( - v, - fn, - check_nests, - to_ignore, - _index + [k], - False, - stop_after_n_found - n, - ) - if stop_after_n_found is not None - else tensorflow_nested_argwhere_bknd( - v, fn, check_nests, to_ignore, _index + [k], False, None - ) - ) - if stop_after_n_found is not None and ind: - if n >= stop_after_n_found: - break - n = n + len(ind) - _indices = _indices + [ind] - _indices = [idx for idxs in _indices if idxs for idx in idxs] - if check_nests and fn(nest): - _indices.append(_index) - else: - cond_met = fn(nest) - if cond_met: - return [_index] - return False - return [index for index in _indices if index] - - -def tensorflow__check_float64_bknd(input): - if tensorflow_is_array_bknd(input): - return tensorflow_dtype(input) == "float64" - if math.isfinite(input): - m, e = math.frexp(input) - return abs(input) > 3.4028235e38 or e < -126 or e > 128 - return False - - -def tensorflow_as_ivy_dtype_bknd(dtype_in: Union[str, str], /): - return tensorflow_as_ivy_dtype(dtype_in) - - -def tensorflow_is_complex_dtype_bknd( - dtype_in: Union[str, str, tensorflow.Tensor, tf.Tensor, Number], / -): - if 
tensorflow_is_array_bknd(dtype_in): - dtype_in = tensorflow_dtype(dtype_in) - elif isinstance(dtype_in, tuple): - dtype_in = tensorflow_default_int_dtype_bknd() - elif isinstance(dtype_in, np.ndarray): - return "complex" in dtype_in.dtype.name - elif isinstance(dtype_in, Number): - return isinstance(dtype_in, (complex, np.complexfloating)) - elif isinstance(dtype_in, (list, tuple, dict)): - return tensorflow_nested_argwhere_bknd( - dtype_in, - lambda x: isinstance(x, (complex, np.complexfloating)) - or tensorflow_is_array_bknd(x) - and "complex" in tensorflow_dtype(x), - ) - return "complex" in tensorflow_as_ivy_dtype_bknd(dtype_in) - - -def tensorflow_as_native_dev(device: str, /): - if isinstance(device, str) and "/" in device: - return device - ret = f"/{str(device).upper()}" - if not ret[-1].isnumeric(): - ret += ":0" - return ret - - -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - -@tensorflow_handle_methods -def tensorflow_split( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - copy: Optional[bool] = None, - num_or_size_splits: Optional[ - Union[int, Sequence[int], Union[tensorflow.Tensor, tensorflow.Variable]] - ] = None, - axis: int = 0, - with_remainder: bool = False, -): - if x.shape == (): - if num_or_size_splits is not None and num_or_size_splits != 1: - raise Exception( - f"input array had no shape, but num_sections specified was {num_or_size_splits}" - ) - return [x] - if num_or_size_splits is None: - dim_size = tensorflow.shape(x)[axis] - num_or_size_splits = int(dim_size) - if isinstance(num_or_size_splits, (tensorflow.Tensor, tensorflow.Variable)): - num_or_size_splits = tensorflow.cast(num_or_size_splits, tensorflow.int32) - elif isinstance(num_or_size_splits, int) and with_remainder: - num_chunks = x.shape[axis] / num_or_size_splits - num_chunks_int = math.floor(num_chunks) - remainder = num_chunks - num_chunks_int - if remainder != 0: - num_or_size_splits = [num_or_size_splits] * num_chunks_int + [ - int(remainder * num_or_size_splits) - ] - return tensorflow.split(x, num_or_size_splits, axis) - - -@tensorflow_handle_methods -def tensorflow_split_bknd_( - self: tensorflow.Tensor, - /, - *, - copy: Optional[bool] = None, - num_or_size_splits: Optional[ - Union[int, Sequence[int], tensorflow.Tensor, tf.Tensor] - ] = None, - axis: int = 0, - with_remainder: bool = False, -): - return tensorflow_split( - self, - copy=copy, - num_or_size_splits=num_or_size_splits, - axis=axis, - with_remainder=with_remainder, - ) - - -def tensorflow_as_ivy_dev(device: str, /): - if isinstance(device, str) and "/" not in device: - return str(device) - dev_in_split = tensorflow_split_bknd_(device[1:], ":")[-2:] - if len(dev_in_split) == 1: - return str(dev_in_split[0]) - dev_type, dev_idx = dev_in_split[0], dev_in_split[1] - dev_type = dev_type.lower() - if dev_type == "cpu": - return str(dev_type) - return str(f"{dev_type}:{dev_idx}") - - -def tensorflow_stack( - arrays: Union[Tuple[tensorflow.Tensor], List[tensorflow.Tensor]], - /, - *, - axis: int = 0, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - try: - return 
tensorflow.experimental.numpy.stack(arrays, axis) - except ValueError as e: - raise Exception(e) from e - - -def tensorflow_stack_bknd_( - self: tensorflow.Tensor, - /, - arrays: Union[ - Tuple[Union[tensorflow.Tensor, tf.Tensor]], - List[Union[tensorflow.Tensor, tf.Tensor]], - ], - *, - axis: int = 0, - out: Optional[tensorflow.Tensor] = None, -): - if not isinstance(arrays, (tuple, list)): - arrays = [arrays] - if isinstance(arrays, tuple): - x = (self,) + arrays - else: - x = [self] + arrays - return tensorflow_stack(x, axis=axis, out=out) - - -def tensorflow_dev( - x: Union[tensorflow.Tensor, tensorflow.Variable, tensorflow.TensorArray], - /, - *, - as_native: bool = False, -): - if "keras.src.backend.tensorflow.core.Variable" in str(x.__class__): - x = x.value - if isinstance(x, tensorflow.TensorArray): - x = tensorflow_stack_bknd_(x) - dv = x.device - if as_native: - return dv - dv = dv if dv else tensorflow_default_device_bknd(as_native=False) - return tensorflow_as_ivy_dev(dv) - - -def tensorflow_default_device_bknd( - device: Optional[Union[str, str]] = None, - /, - *, - item: Optional[Union[list, tuple, dict, tensorflow.Tensor, tf.Tensor]] = None, - as_native: Optional[bool] = None, -): - if tensorflow_exists_bknd(device): - if as_native is True: - return tensorflow_as_native_dev(device) - elif as_native is False: - return tensorflow_as_ivy_dev(device) - return device - as_native = tensorflow_default_bknd(as_native, False) - if tensorflow_exists_bknd(item): - if isinstance(item, (list, tuple, dict)) and len(item) == 0: - pass - elif tensorflow_is_array_bknd(item): - return tensorflow_dev(item, as_native=as_native) - global default_device_stack - if not default_device_stack: - ret = "cpu" - else: - ret = default_device_stack[-1] - if as_native: - return tensorflow_as_native_dev(ret) - return tensorflow_as_ivy_dev(ret) - - -def tensorflow__get_preferred_device(args, kwargs): - device = None - if "device" in kwargs and kwargs["device"] is not None: - return device - if not False: - arr_arg = tensorflow__get_first_array(*args, **kwargs) - return tensorflow_default_device_bknd(item=arr_arg, as_native=True) - return tensorflow_default_device_bknd(as_native=True) - - -def tensorflow__check_in_nested_sequence(sequence, value=None, _type=None): - if sequence is value or isinstance(sequence, _type): - return True - elif isinstance(sequence, (tuple, list)): - if any(isinstance(_val, _type) or _val is value for _val in sequence): - return True - else: - return any( - tensorflow__check_in_nested_sequence(sub_sequence, value, _type) - for sub_sequence in sequence - if isinstance(sub_sequence, (tuple, list)) - ) - - -def tensorflow_is_variable(x, /, *, exclusive=False): - return isinstance(x, tensorflow.Variable) - - -def tensorflow_variable(x, /): - with tensorflow.device(tensorflow_dev(x, as_native=True)): - return tensorflow.Variable(x, trainable=True) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_stop_gradient( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - preserve_type: bool = True, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - is_var = tensorflow_is_variable(x) - x = tensorflow.stop_gradient(x) - if is_var and preserve_type: - return tensorflow_variable(x) - return x - - -def tensorflow_nested_map_bknd( - fn: Callable, - x: Union[tensorflow.Tensor, tf.Tensor, Iterable], - /, - include_derived: Optional[Union[Dict[str, bool], bool]] = None, - to_ignore: Optional[Union[type, Tuple[type]]] = None, - to_mutable: bool = 
False, - _tuple_check_fn: Optional[Callable] = None, - _list_check_fn: Optional[Callable] = None, - _dict_check_fn: Optional[Callable] = None, - shallow: bool = True, -): - to_ignore = tensorflow_default_bknd(to_ignore, ()) - if include_derived is True: - include_derived = {"tuple": True, "list": True, "dict": True} - elif not include_derived: - include_derived = {} - for t in ("tuple", "list", "dict"): - if t not in include_derived: - include_derived = tensorflow_set_item_bknd(include_derived, t, False) - class_instance = type(x) - if ( - hasattr(x, "is_tracked_proxy") - and hasattr(class_instance, "__bases__") - and not set(class_instance.__bases__).intersection(set(to_ignore)) - ): - to_ignore = to_ignore + (class_instance,) - tuple_check_fn = tensorflow_default_bknd( - _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), - ) - list_check_fn = tensorflow_default_bknd( - _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), - ) - dict_check_fn = tensorflow_default_bknd( - _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), - ) - if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): - ret_list = [ - tensorflow_nested_map_bknd( - fn, - i, - include_derived, - to_ignore, - to_mutable, - tuple_check_fn, - list_check_fn, - dict_check_fn, - shallow, - ) - for i in x - ] - if to_mutable: - return ret_list - elif hasattr(x, "_fields"): - return class_instance(**dict(zip(x._fields, ret_list))) - else: - return class_instance(ret_list) - elif list_check_fn(x, list) and not isinstance(x, to_ignore): - ret_list = [ - tensorflow_nested_map_bknd( - fn, - i, - include_derived, - to_ignore, - to_mutable, - tuple_check_fn, - list_check_fn, - dict_check_fn, - shallow, - ) - for i in x - ] - if shallow: - x = tensorflow_set_item_bknd(x, slice(None, None, None), ret_list[:]) - return x - return class_instance(ret_list) - elif (dict_check_fn(x, dict) or isinstance(x, UserDict)) and not isinstance( - x, to_ignore - ): - class_instance = type(x) - ret = { - k: tensorflow_nested_map_bknd( - fn, - v, - include_derived, - to_ignore, - to_mutable, - tuple_check_fn, - list_check_fn, - dict_check_fn, - shallow, - ) - for k, v in x.items() - } - if shallow: - x.update(ret) - return x - return class_instance(ret) - elif isinstance(x, slice): - return slice(*tensorflow_nested_map_bknd(fn, [x.start, x.stop, x.step])) - return fn(x) - - -def tensorflow__to_ivy_bknd_(x: Any): - if isinstance(x, tensorflow.Tensor): - return x - elif isinstance(x, tf.TensorShape): - return tuple(x) - elif isinstance(x, dict): - return x.to_ivy() - if tensorflow_is_native_array(x) or isinstance(x, np.ndarray): - return tensorflow.convert_to_tensor(x) - return x - - -def tensorflow_to_ivy_bknd_( - x: Union[tensorflow.Tensor, tf.Tensor, Iterable], - nested: bool = False, - include_derived: Optional[Dict[str, bool]] = None, -): - if nested: - return tensorflow_nested_map_bknd( - tensorflow__to_ivy_bknd_, x, include_derived, shallow=False - ) - return tensorflow__to_ivy_bknd_(x) - - -def tensorflow__asarray_to_native_arrays_and_back_bknd(fn: Callable): - @functools.wraps(fn) - def _asarray_to_native_arrays_and_back_wrapper(*args, dtype=None, **kwargs): - new_arg = args[0] - new_args = (new_arg,) + args[1:] - if dtype is not None: - dtype = tensorflow_default_dtype_bknd(dtype=dtype, as_native=True) - return 
tensorflow_to_ivy_bknd_(fn(*new_args, dtype=dtype, **kwargs)) - - _asarray_to_native_arrays_and_back_wrapper._asarray_to_native_arrays_and_back = True - return _asarray_to_native_arrays_and_back_wrapper - - -def tensorflow__flatten_nest_bknd(xs): - for x in xs: - if isinstance(x, Iterable) and not isinstance(x, (str, bytes)): - yield from tensorflow__flatten_nest_bknd(x) - else: - yield x - - -def tensorflow_promote_types_bknd( - type1: Union[str, tf.DType], - type2: Union[str, tf.DType], - /, - *, - array_api_promotion: bool = False, -): - if not (type1 and type2): - return type1 if type1 else type2 - query = [tensorflow_as_ivy_dtype(type1), tensorflow_as_ivy_dtype(type2)] - query = tuple(query) - if query not in promotion_table: - query = query[1], query[0] - - def _promote(query): - if array_api_promotion: - return tensorflow_get_item(array_api_promotion_table, query) - return tensorflow_get_item(promotion_table, query) - - return _promote(query) - - -def tensorflow__asarray_infer_dtype_bknd(fn: Callable): - @functools.wraps(fn) - def _asarray_infer_dtype_wrapper(*args, dtype=None, **kwargs): - def _infer_dtype(obj): - if isinstance(obj, tf.TensorShape): - obj = list(obj) - if hasattr(obj, "dtype"): - return obj.dtype.name if isinstance(obj, np.ndarray) else obj.dtype - else: - return tensorflow_default_dtype_bknd(item=obj) - - if not tensorflow_exists_bknd(dtype): - arr = args[0] - dtype_list = [ - tensorflow_nested_map_bknd( - lambda x: _infer_dtype(x), arr, shallow=False - ) - ] - dtype_list = tensorflow__flatten_nest_bknd(dtype_list) - dtype_list = list(set(dtype_list)) - if len(dtype_list) != 0: - dtype = dtype_list[0] - for dt in dtype_list[1:]: - dtype = tensorflow_promote_types_bknd(dtype, dt) - else: - dtype = tensorflow_default_float_dtype_bknd() - dtype = tensorflow_as_native_dtype(dtype) - return fn(*args, dtype=dtype, **kwargs) - - _asarray_infer_dtype_wrapper.infer_dtype = True - return _asarray_infer_dtype_wrapper - - -@tensorflow_handle_array_like_without_promotion -@tensorflow__asarray_to_native_arrays_and_back_bknd -@tensorflow__asarray_infer_dtype_bknd -def tensorflow_asarray( - obj: Union[ - tensorflow.Tensor, - tensorflow.Variable, - tensorflow.TensorShape, - bool, - int, - float, - tensorflow_NestedSequence_bknd, - SupportsBufferProtocol, - np.ndarray, - ], - /, - *, - copy: Optional[bool] = None, - dtype: Optional[tensorflow.DType] = None, - device: Optional[str] = None, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - with tensorflow.device(device): - if tensorflow.is_tensor(obj): - ret = tensorflow.cast(obj, dtype) if obj.dtype != dtype else obj - elif ( - dtype is not None - and dtype.is_integer - and np.issubdtype(np.array(obj).dtype, np.floating) - ): - obj_np = np.array(obj) - ret = tensorflow.convert_to_tensor(obj_np, dtype) - else: - ret = tensorflow.convert_to_tensor(obj, dtype) - return ( - tensorflow.identity(ret) - if copy or tensorflow_as_native_dev(tensorflow_dev(ret)) != device - else ret - ) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_size(x: tensorflow.Tensor, /): - return functools.reduce(mul, x.shape) if len(x.shape) > 0 else 1 - - -def tensorflow_size_bknd_(self): - return tensorflow_size(self) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_unstack( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - copy: Optional[bool] = None, - axis: int = 0, - keepdims: bool = False, -): - if x.shape == (): - return [x] - ret = tensorflow.unstack(x, axis=axis) - if 
keepdims: - return [tensorflow.expand_dims(r, axis) for r in ret] - return ret - - -def tensorflow_unstack_bknd_( - self: tensorflow.Tensor, - /, - *, - copy: Optional[bool] = None, - axis: int = 0, - keepdims: bool = False, -): - return tensorflow_unstack(self, copy=copy, axis=axis, keepdims=keepdims) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_copy_array( - x: Union[tensorflow.Tensor, tensorflow.Variable, tensorflow.TensorArray], - *, - to_ivy_array: bool = True, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if isinstance(x, tensorflow.TensorArray): - x_wrapped = tensorflow_stack_bknd_(x) - y = tensorflow.TensorArray(x.dtype, tensorflow_size_bknd_(x)()) - x = tensorflow_unstack_bknd_(y, tensorflow_copy_array(x_wrapped)) - else: - x = tensorflow.identity(x) - if to_ivy_array: - return tensorflow_to_ivy_bknd_(x) - return x - - -def tensorflow_tile( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - repeats: Sequence[int], - *, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if x.shape == (): - x = tensorflow.reshape(x, (-1,)) - if isinstance(repeats, Number): - repeats = [repeats] - if isinstance(repeats, tensorflow.Tensor) and repeats.shape == (): - repeats = tensorflow.reshape(repeats, (-1,)) - if len(x.shape) < len(repeats): - while len(x.shape) != len(repeats): - x = tensorflow.expand_dims(x, 0) - elif len(x.shape) > len(repeats): - repeats = list(repeats) - while len(x.shape) != len(repeats): - repeats = [1] + repeats - return tensorflow.tile(x, repeats) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_nonzero( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - as_tuple: bool = True, - size: Optional[int] = None, - fill_value: Number = 0, -): - res = tensorflow.experimental.numpy.nonzero(x) - if size is not None: - dtype = tensorflow.int64 - if isinstance(fill_value, float): - dtype = tensorflow.float64 - res = tensorflow.cast(res, dtype) - diff = size - res[0].shape[0] - if diff > 0: - res = tensorflow.pad(res, [[0, 0], [0, diff]], constant_values=fill_value) - elif diff < 0: - res = tensorflow.slice(res, [0, 0], [-1, size]) - if as_tuple: - return tuple(res) - return tensorflow.stack(res, axis=1) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_diff( - x: Union[tensorflow.Tensor, tensorflow.Variable, list, tuple], - /, - *, - n: int = 1, - axis: int = -1, - prepend: Optional[ - Union[tensorflow.Tensor, tensorflow.Variable, int, float, list, tuple] - ] = None, - append: Optional[ - Union[tensorflow.Tensor, tensorflow.Variable, int, float, list, tuple] - ] = None, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if n == 0: - return x - if prepend is not None: - x = tensorflow.experimental.numpy.append( - prepend, x, axis=axis if axis != -1 else None - ) - if append is not None: - x = tensorflow.experimental.numpy.append( - x, append, axis=axis if axis != -1 else None - ) - return tensorflow.experimental.numpy.diff(x, n=n, axis=axis) - - -def tensorflow__parse_ellipsis_bknd(so, ndims): - pre = list() - for s in so: - if s is Ellipsis: - break - pre.append(s) - post = list() - for s in reversed(so): - if s is Ellipsis: - break - post.append(s) - ret = list( - pre - + [slice(None, None, None) for _ in range(ndims - len(pre) - len(post))] - + list(reversed(post)) - ) - return ret, (len(pre), ndims - len(post)) - - -def tensorflow_broadcast_arrays(*arrays: Union[tensorflow.Tensor, tensorflow.Variable]): - if 
len(arrays) > 1: - try: - desired_shape = tensorflow.broadcast_dynamic_shape( - tensorflow.shape(arrays[0]), tensorflow.shape(arrays[1]) - ) - except tensorflow.errors.InvalidArgumentError as e: - raise Exception(e) from e - if len(arrays) > 2: - for i in range(2, len(arrays)): - try: - desired_shape = tensorflow.broadcast_dynamic_shape( - desired_shape, tensorflow.shape(arrays[i]) - ) - except tensorflow.errors.InvalidArgumentError as e: - raise Exception(e) from e - else: - return [arrays[0]] - result = [] - for tensor in arrays: - result.append(tensorflow.broadcast_to(tensor, desired_shape)) - return result - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_astype( - x: Union[tensorflow.Tensor, tensorflow.Variable], - dtype: Union[tf.DType, str], - /, - *, - copy: bool = True, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - dtype = tensorflow_as_native_dtype(dtype) - if x.dtype == dtype: - return tensorflow.experimental.numpy.copy(x) if copy else x - return tensorflow.cast(x, dtype) - - -def tensorflow_astype_bknd_( - self: tensorflow.Tensor, - dtype: str, - /, - *, - copy: bool = True, - out: Optional[tensorflow.Tensor] = None, -): - return tensorflow_astype(self, dtype, copy=copy, out=out) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_where( - condition: Union[tensorflow.Tensor, tensorflow.Variable], - x1: Union[tensorflow.Tensor, tensorflow.Variable], - x2: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - oirg_x1 = x1 - oirg_x2 = x2 - try: - dtype = ( - x1.dtype - if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() - ) - if not tensorflow_is_array_bknd(x1): - x1 = tensorflow_asarray(x1, dtype=dtype) - if not tensorflow_is_array_bknd(x2): - x2 = tensorflow_asarray(x2, dtype=dtype) - except: - x1 = oirg_x1 - x2 = oirg_x2 - return tensorflow.cast( - tensorflow.experimental.numpy.where(condition, x1, x2), x1.dtype - ) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_arange( - start: float, - /, - stop: Optional[float] = None, - step: float = 1, - *, - dtype: Optional[tensorflow.DType] = None, - device: Optional[str] = None, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if stop is None: - stop = start - start = 0 - if step > 0 and start > stop or step < 0 and start < stop: - if isinstance(stop, float): - stop = float(start) - else: - stop = start - if isinstance(start, (float, int)): - start = tensorflow.convert_to_tensor(start) - if isinstance(stop, (float, int)): - stop = tensorflow.convert_to_tensor(stop) - if isinstance(step, (float, int)): - step = tensorflow.convert_to_tensor(step) - if dtype is None: - if isinstance(start, int) and isinstance(stop, int) and isinstance(step, int): - return tensorflow.cast( - tensorflow.range(start, stop, delta=step, dtype=tensorflow.int64), - tensorflow.int32, - ) - else: - return tensorflow.range(start, stop, delta=step) - else: - dtype = tensorflow_as_native_dtype(tensorflow_default_dtype_bknd(dtype=dtype)) - if dtype in [ - tensorflow.int8, - tensorflow.uint8, - tensorflow.int16, - tensorflow.uint16, - tensorflow.uint32, - tensorflow.uint64, - ]: - return tensorflow.cast( - tensorflow.range(start, stop, delta=step, dtype=tensorflow.int64), dtype - ) - else: - return tensorflow.range(start, stop, delta=step, dtype=dtype) - - -def tensorflow__parse_slice_bknd(idx, s): - step = 1 if idx.step is None 
else idx.step - if step > 0: - start = 0 if idx.start is None else idx.start - if start >= s: - stop = start - else: - if start <= -s: - start = 0 - elif start < 0: - start = start + s - stop = s if idx.stop is None else idx.stop - if stop > s: - stop = s - elif start <= -s: - stop = 0 - elif stop < 0: - stop = stop + s - else: - start = s - 1 if idx.start is None else idx.start - if start < -s: - stop = start - else: - if start >= s: - start = s - 1 - elif start < 0: - start = start + s - if idx.stop is None: - stop = -1 - else: - stop = idx.stop - if stop > s: - stop = s - elif stop < -s: - stop = -1 - elif stop == -s: - stop = 0 - elif stop < 0: - stop = stop + s - q_i = tensorflow_arange(start, stop, step) - ag__result_list_0 = [] - for q in q_i: - if 0 <= q < s: - res = q - ag__result_list_0.append(res) - q_i = ag__result_list_0 - q_i = ( - tensorflow_asarray(q_i) - if len(q_i) or start == stop or idx.stop is not None - else tensorflow_arange(0, s, 1) - ) - return q_i - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_shape( - x: Union[tensorflow.Tensor, tensorflow.Variable], /, *, as_array: bool = False -): - if as_array: - return tensorflow_asarray( - tensorflow.shape(x), dtype=tensorflow_default_int_dtype_bknd() - ) - else: - return tuple(x.shape) - - -def tensorflow__deep_flatten_bknd(iterable): - def _flatten_gen(iterable): - for item in iterable: - if isinstance(item, list): - yield from _flatten_gen(item) - else: - yield item - - return list(_flatten_gen(iterable)) - - -def tensorflow__calculate_out_shape_bknd(axis, array_shape): - if type(axis) not in (tuple, list): - axis = (axis,) - out_dims = len(axis) + len(array_shape) - norm_axis = normalize_axis_tuple(axis, out_dims) - shape_iter = iter(array_shape) - ag__result_list_0 = [] - for current_ax in range(out_dims): - res = 1 if current_ax in norm_axis else next(shape_iter) - ag__result_list_0.append(res) - out_shape = ag__result_list_0 - return out_shape - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_expand_dims( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - copy: Optional[bool] = None, - axis: Union[int, Sequence[int]] = 0, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - try: - out_shape = tensorflow__calculate_out_shape_bknd(axis, tensorflow.shape(x)) - ret = tensorflow.reshape(x, shape=out_shape) - return ret - except (tensorflow.errors.InvalidArgumentError, np.AxisError) as error: - raise Exception(error) from error - - -def tensorflow_check_elem_in_list(elem, list, inverse=False, message=""): - if inverse and elem in list: - raise Exception( - message if message != "" else f"{elem} must not be one of {list}" - ) - elif not inverse and elem not in list: - raise Exception(message if message != "" else f"{elem} must be one of {list}") - - -def tensorflow__reshape_fortran_tf(x, shape): - if len(x.shape) > 0: - x = tensorflow.transpose(x) - return tensorflow.transpose(tensorflow.reshape(x, shape[::-1])) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_reshape( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - shape: Union[tf.TensorShape, Sequence[int]], - *, - copy: Optional[bool] = None, - order: str = "C", - allowzero: bool = True, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - tensorflow_check_elem_in_list(order, ["C", "F"]) - if not allowzero: - shape = [ - (new_s if con else old_s) - for new_s, con, old_s in zip( - shape, tensorflow.constant(shape) != 0, x.shape - ) - ] - if 
order == "F": - return tensorflow__reshape_fortran_tf(x, shape) - return tensorflow.reshape(x, shape) - - -def tensorflow_reshape_bknd_( - self: tensorflow.Tensor, - /, - shape: Union[tuple, tf.TensorShape, Sequence[int]], - *, - copy: Optional[bool] = None, - order: str = "C", - allowzero: bool = True, - out: Optional[tensorflow.Tensor] = None, -): - return tensorflow_reshape( - self, shape, copy=copy, allowzero=allowzero, out=out, order=order - ) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_meshgrid( - *arrays: Union[tensorflow.Tensor, tensorflow.Variable], - sparse: bool = False, - indexing: str = "xy", - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if not sparse: - return tensorflow.meshgrid(*arrays, indexing=indexing) - sd = (1,) * len(arrays) - ag__result_list_0 = [] - for i, a in enumerate(arrays): - res = tensorflow.reshape( - tensorflow.convert_to_tensor(a), sd[:i] + (-1,) + sd[i + 1 :] - ) - ag__result_list_0.append(res) - res = ag__result_list_0 - if indexing == "xy" and len(arrays) > 1: - res[0] = tensorflow.reshape(res[0], (1, -1) + sd[2:]) - res[1] = tensorflow.reshape(res[1], (-1, 1) + sd[2:]) - return res - - -@tensorflow_infer_dtype -@tensorflow_handle_array_like_without_promotion -def tensorflow_empty( - shape: Union[tf.TensorShape, Sequence[int]], - *, - dtype: tensorflow.DType, - device: Optional[str] = None, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - return tensorflow.experimental.numpy.empty(shape, dtype=tensorflow.float32) - - -def tensorflow__parse_query_bknd(query, x_shape, scatter=False): - query = (query,) if not isinstance(query, tuple) else query - ag__result_list_0 = [] - for q in query: - res = tensorflow_asarray(q) if isinstance(q, (tuple, list, int)) else q - ag__result_list_0.append(res) - query = ag__result_list_0 - ag__result_list_1 = [] - for i, q in enumerate(query): - if tensorflow_is_array_bknd(q): - res = i - ag__result_list_1.append(res) - non_slice_q_idxs = ag__result_list_1 - to_front = ( - len(non_slice_q_idxs) > 1 - and any(tensorflow_diff(non_slice_q_idxs) != 1) - and non_slice_q_idxs[-1] < len(x_shape) - ) - ag__result_list_2 = [] - for i, q in enumerate(query): - if q is None: - res = i - ag__result_list_2.append(res) - new_axes = ag__result_list_2 - ag__result_list_3 = [] - for q in query: - if q is not None: - res = q - ag__result_list_3.append(res) - query = ag__result_list_3 - query = [Ellipsis] if query == [] else query - ellipsis_inds = None - if any(q is Ellipsis for q in query): - query, ellipsis_inds = tensorflow__parse_ellipsis_bknd(query, len(x_shape)) - ag__result_list_4 = [] - for i, v in enumerate(query): - if tensorflow_is_array_bknd(v): - res = i - ag__result_list_4.append(res) - array_inds = ag__result_list_4 - if array_inds: - array_queries = tensorflow_broadcast_arrays( - *[v for i, v in enumerate(query) if i in array_inds] - ) - array_queries = [ - ( - tensorflow_nonzero(q, as_tuple=False)[0] - if tensorflow_is_bool_dtype_bknd(q) - else q - ) - for q in array_queries - ] - array_queries = [ - ( - tensorflow_astype_bknd_( - tensorflow_where( - arr < 0, arr + tensorflow_get_item(x_shape, i), arr - ), - tf.int64, - ) - if tensorflow_size_bknd_(arr) - else tensorflow_astype_bknd_(arr, tf.int64) - ) - for arr, i in zip(array_queries, array_inds) - ] - for idx, arr in zip(array_inds, array_queries): - query = tensorflow_set_item_bknd(query, idx, arr) - ag__result_list_5 = [] - for i, q in enumerate(query): - res = ( - 
tensorflow_astype_bknd_( - tensorflow__parse_slice_bknd(q, tensorflow_get_item(x_shape, i)), - tf.int64, - ) - if isinstance(q, slice) - else q - ) - ag__result_list_5.append(res) - query = ag__result_list_5 - if len(query) < len(x_shape): - query = query + [ - tensorflow_astype_bknd_(tensorflow_arange(0, s, 1), tf.int64) - for s in tensorflow_get_item(x_shape, slice(len(query), None, None)) - ] - if len(array_inds) and to_front: - target_shape = ( - [list(array_queries[0].shape)] - + [ - list(tensorflow_get_item(query, i).shape) - for i in range(len(query)) - if i not in array_inds - ] - + [[] for _ in range(len(array_inds) - 1)] - ) - elif len(array_inds): - target_shape = ( - [list(tensorflow_get_item(query, i).shape) for i in range(0, array_inds[0])] - + [list(tensorflow_shape(array_queries[0], as_array=True))] - + [[] for _ in range(len(array_inds) - 1)] - + [ - list(tensorflow_shape(tensorflow_get_item(query, i), as_array=True)) - for i in range(array_inds[-1] + 1, len(query)) - ] - ) - else: - target_shape = [list(q.shape) for q in query] - if ellipsis_inds is not None: - target_shape = ( - tensorflow_get_item(target_shape, slice(None, ellipsis_inds[0], None)) - + [ - tensorflow_get_item( - target_shape, slice(ellipsis_inds[0], ellipsis_inds[1], None) - ) - ] - + tensorflow_get_item(target_shape, slice(ellipsis_inds[1], None, None)) - ) - for i, ax in enumerate(new_axes): - if len(array_inds) and to_front: - ax = ax - (sum(1 for x in array_inds if x < ax) - 1) - ax = ax + i - target_shape = [ - *tensorflow_get_item(target_shape, slice(None, ax, None)), - 1, - *tensorflow_get_item(target_shape, slice(ax, None, None)), - ] - target_shape = tensorflow__deep_flatten_bknd(target_shape) - ag__result_list_6 = [] - for q in query: - res = tensorflow_expand_dims(q) if not len(q.shape) else q - ag__result_list_6.append(res) - query = ag__result_list_6 - if len(array_inds): - array_queries = [ - ( - tensorflow_reshape_bknd_(arr, (-1,)) - if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr - ) - for arr in array_queries - ] - array_queries = tensorflow_stack(array_queries, axis=1) - if len(array_inds) == len(query): - indices = tensorflow_reshape_bknd_(array_queries, (*target_shape, len(x_shape))) - elif len(array_inds) == 0: - indices = tensorflow_reshape_bknd_( - tensorflow_stack(tensorflow_meshgrid(*query, indexing="ij"), axis=-1), - (*target_shape, len(x_shape)), - ) - elif to_front: - post_array_queries = ( - tensorflow_reshape_bknd_( - tensorflow_stack( - tensorflow_meshgrid( - *[v for i, v in enumerate(query) if i not in array_inds], - indexing="ij", - ), - axis=-1, - ), - (-1, len(query) - len(array_inds)), - ) - if len(array_inds) < len(query) - else tensorflow_empty((1, 0)) - ) - indices = tensorflow_reshape_bknd_( - tensorflow_asarray( - [ - (*arr, *post) - for arr, post in itertools.product( - array_queries, post_array_queries - ) - ] - ), - (*target_shape, len(x_shape)), - ) - else: - pre_array_queries = ( - tensorflow_reshape_bknd_( - tensorflow_stack( - tensorflow_meshgrid( - *[v for i, v in enumerate(query) if i < array_inds[0]], - indexing="ij", - ), - axis=-1, - ), - (-1, array_inds[0]), - ) - if array_inds[0] > 0 - else tensorflow_empty((1, 0)) - ) - post_array_queries = ( - tensorflow_reshape_bknd_( - tensorflow_stack( - tensorflow_meshgrid( - *[v for i, v in enumerate(query) if i > array_inds[-1]], - indexing="ij", - ), - axis=-1, - ), - (-1, len(query) - 1 - array_inds[-1]), - ) - if array_inds[-1] < len(query) - 1 - else 
tensorflow_empty((1, 0)) - ) - indices = tensorflow_reshape_bknd_( - tensorflow_asarray( - [ - (*pre, *arr, *post) - for pre, arr, post in itertools.product( - pre_array_queries, array_queries, post_array_queries - ) - ] - ), - (*target_shape, len(x_shape)), - ) - return ( - tensorflow_astype_bknd_(indices, tf.int64), - target_shape, - array_inds if len(array_inds) and to_front else None, - ) - - -def tensorflow_get_num_dims(x, /, *, as_array=False): - return ( - tensorflow.cast(tensorflow.shape(tensorflow.shape(x))[0], tensorflow.int64) - if as_array - else int(tensorflow.shape(tensorflow.shape(x))) - ) - - -def tensorflow_to_numpy( - x: Union[tensorflow.Tensor, tensorflow.Variable], /, *, copy: bool = True -): - if ( - tensorflow_is_array_bknd(x) - and tensorflow_get_num_dims(x) == 0 - and tensorflow_as_native_dtype(x.dtype) is tensorflow.bfloat16 - ): - x = tensorflow.expand_dims(x, 0) - if copy: - return np.squeeze(np.array(tensorflow.convert_to_tensor(x)), 0) - else: - return np.squeeze(np.asarray(tensorflow.convert_to_tensor(x)), 0) - if copy: - return np.array(tensorflow.convert_to_tensor(x)) - else: - return np.asarray(tensorflow.convert_to_tensor(x)) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_to_scalar(x: Union[tensorflow.Tensor, tensorflow.Variable], /): - ret = tensorflow_to_numpy(x).item() - if x.dtype == tensorflow.bfloat16: - return float(ret) - return ret - - -def tensorflow_to_scalar_bknd_(self: tensorflow.Tensor): - return tensorflow_to_scalar(self) - - -def tensorflow_is_float_dtype_bknd( - dtype_in: Union[str, str, tensorflow.Tensor, tf.Tensor, Number], / -): - if tensorflow_is_array_bknd(dtype_in): - dtype_in = tensorflow_dtype(dtype_in) - elif isinstance(dtype_in, tuple): - dtype_in = tensorflow_default_int_dtype_bknd() - elif isinstance(dtype_in, np.ndarray): - return "float" in dtype_in.dtype.name - elif isinstance(dtype_in, Number): - return isinstance(dtype_in, (float, np.floating)) - elif isinstance(dtype_in, (list, tuple, dict)): - return bool( - tensorflow_nested_argwhere_bknd( - dtype_in, - lambda x: isinstance(x, (float, np.floating)) - or tensorflow_is_array_bknd(x) - and "float" in tensorflow_dtype(x), - ) - ) - return "float" in tensorflow_as_ivy_dtype_bknd(dtype_in) - - -def tensorflow_is_uint_dtype_bknd( - dtype_in: Union[str, str, tensorflow.Tensor, tf.Tensor, Number], / -): - if tensorflow_is_array_bknd(dtype_in): - dtype_in = tensorflow_dtype(dtype_in) - elif isinstance(dtype_in, tuple): - dtype_in = tensorflow_default_int_dtype_bknd() - elif isinstance(dtype_in, np.ndarray): - return "uint" in dtype_in.dtype.name - elif isinstance(dtype_in, Number): - return isinstance(dtype_in, np.unsignedinteger) - elif isinstance(dtype_in, (list, tuple, dict)): - return tensorflow_nested_argwhere_bknd( - dtype_in, - lambda x: isinstance(x, np.unsignedinteger) - or tensorflow_is_array_bknd(x) - and "uint" in tensorflow_dtype(x), - ) - return "uint" in tensorflow_as_ivy_dtype_bknd(dtype_in) - - -def tensorflow_default_uint_dtype_bknd( - *, - input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, - uint_dtype: Optional[Union[str, tf.DType]] = None, - as_native: bool = False, -): - global default_uint_dtype_stack - if tensorflow_exists_bknd(uint_dtype): - if as_native is True: - return tensorflow_as_native_dtype(uint_dtype) - return str(tensorflow_as_ivy_dtype(uint_dtype)) - as_native = tensorflow_default_bknd(as_native, False) - if tensorflow_exists_bknd(input): - if tensorflow_is_array_bknd(input): - ret = tensorflow_dtype(input) - 
elif isinstance(input, np.ndarray): - ret = input.dtype - elif isinstance(input, (list, tuple, dict)): - - def is_native(x): - return tensorflow_is_native_array(x) - - if tensorflow_nested_argwhere_bknd( - input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), - stop_after_n_found=1, - ): - ret = tf.uint64 - elif default_uint_dtype_stack: - ret = default_uint_dtype_stack[-1] - else: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_uint_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "uint32" - elif isinstance(input, Number): - if input > 4294967295 and input != math.inf and backend != "torch": - ret = tf.uint64 - elif default_uint_dtype_stack: - ret = default_uint_dtype_stack[-1] - else: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_uint_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "uint32" - elif default_uint_dtype_stack: - ret = default_uint_dtype_stack[-1] - else: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_uint_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "uint32" - if as_native: - return tensorflow_as_native_dtype(ret) - return str(tensorflow_as_ivy_dtype(ret)) - - -def tensorflow_is_int_dtype_bknd( - dtype_in: Union[str, str, tensorflow.Tensor, tf.Tensor, Number], / -): - if tensorflow_is_array_bknd(dtype_in): - dtype_in = tensorflow_dtype(dtype_in) - elif isinstance(dtype_in, tuple): - dtype_in = tensorflow_default_int_dtype_bknd() - elif isinstance(dtype_in, np.ndarray): - return "int" in dtype_in.dtype.name - elif isinstance(dtype_in, Number): - return isinstance(dtype_in, (int, np.integer)) and not isinstance( - dtype_in, bool - ) - elif isinstance(dtype_in, (list, tuple, dict)): - - def nested_fun(x): - return ( - isinstance(x, (int, np.integer)) - or tensorflow_is_array_bknd(x) - and "int" in tensorflow_dtype(x) - ) and x is not bool - - return bool(tensorflow_nested_argwhere_bknd(dtype_in, nested_fun)) - return "int" in tensorflow_as_ivy_dtype(dtype_in) - - -def tensorflow_infer_default_dtype_bknd( - dtype: Union[str, tf.DType, str], as_native: bool = False -): - if tensorflow_is_complex_dtype_bknd(dtype): - default_dtype = tensorflow_default_complex_dtype_bknd(as_native=as_native) - elif tensorflow_is_float_dtype_bknd(dtype): - default_dtype = tensorflow_default_float_dtype_bknd(as_native=as_native) - elif tensorflow_is_uint_dtype_bknd(dtype): - default_dtype = tensorflow_default_uint_dtype_bknd(as_native=as_native) - elif tensorflow_is_int_dtype_bknd(dtype): - default_dtype = tensorflow_default_int_dtype_bknd(as_native=as_native) - elif as_native: - default_dtype = tensorflow_as_native_dtype("bool") - else: - default_dtype = tensorflow_as_ivy_dtype("bool") - return default_dtype - - -def tensorflow_dtype_bits(dtype_in: Union[tensorflow.DType, str, np.dtype], /): - dtype_str = tensorflow_as_ivy_dtype(dtype_in) - if "bool" in dtype_str: - return 1 - return int( - dtype_str.replace("tf.", "") - .replace("uint", "") - .replace("int", "") - .replace("bfloat", "") - .replace("float", "") - .replace("complex", "") - ) - - -def tensorflow__infer_dtype(dtype: tensorflow.DType): - default_dtype = tensorflow_infer_default_dtype_bknd(dtype) - if tensorflow_dtype_bits(dtype) < tensorflow_dtype_bits(default_dtype): - return default_dtype - return dtype - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_prod( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - axis: Optional[Union[int, Sequence[int]]] 
= None, - dtype: Optional[tensorflow.DType] = None, - keepdims: bool = False, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - dtype = tensorflow_as_native_dtype(dtype) - if dtype is None: - dtype = tensorflow__infer_dtype(x.dtype) - axis = tuple(axis) if isinstance(axis, list) else axis - return tensorflow.experimental.numpy.prod( - x, axis=axis, dtype=dtype, keepdims=keepdims - ) - - -def tensorflow__numel_bknd(shape): - shape = tuple(shape) - return tensorflow_to_scalar_bknd_(tensorflow_prod(shape)) if shape != () else 1 - - -def tensorflow_check_one_way_broadcastable(x1, x2): - if len(x1) > len(x2): - return False - for a, b in zip(x1[::-1], x2[::-1]): - if a in (1, b): - pass - else: - return False - return True - - -def tensorflow_check_shapes_broadcastable(var, data): - if not tensorflow_check_one_way_broadcastable(var, data): - raise Exception(f"Could not broadcast shape {data} to shape {var}.") - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_broadcast_to( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - shape: Union[tf.TensorShape, Sequence[int]], - *, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - tensorflow_check_shapes_broadcastable(x.shape, shape) - if tensorflow.rank(x) > len(shape): - return tensorflow.broadcast_to(tensorflow.reshape(x, -1), shape) - return tensorflow.broadcast_to(x, shape) - - -def tensorflow__broadcast_to_bknd(input, target_shape): - if tensorflow__numel_bknd(tuple(input.shape)) == tensorflow__numel_bknd( - tuple(target_shape) - ): - return tensorflow_reshape(input, target_shape) - else: - input = input if len(input.shape) else tensorflow_expand_dims(input, axis=0) - return tensorflow_broadcast_to(input, target_shape) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_any( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - axis: Optional[Union[int, Sequence[int]]] = None, - keepdims: bool = False, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if axis is None: - num_dims = len(x.shape) - axis = tuple(range(num_dims)) - elif isinstance(axis, list): - axis = tuple(axis) - try: - return tensorflow.reduce_any( - tensorflow.cast(x, tensorflow.bool), axis=axis, keepdims=keepdims - ) - except tensorflow.errors.InvalidArgumentError as e: - raise Exception(e) from e - - -def tensorflow__broadcast_inputs(x1, x2): - x1_, x2_ = x1, x2 - iterables = (list, tuple) - if not isinstance(x1_, iterables): - x1_, x2_ = x2, x1 - if not isinstance(x1_, iterables): - return [x1], [x2] - if not isinstance(x2_, iterables): - x1 = [x1] * len(x2) - return x1, x2 - - -def tensorflow_check_equal(x1, x2, inverse=False, message="", as_array=True): - def eq_fn(x1, x2): - return x1 == x2 if inverse else x1 != x2 - - def comp_fn(x1, x2): - return tensorflow_any(eq_fn(x1, x2)) - - if not as_array: - - def iter_comp_fn(x1_, x2_): - return any(eq_fn(x1, x2) for x1, x2 in zip(x1_, x2_)) - - def comp_fn(x1, x2): - return iter_comp_fn(*tensorflow__broadcast_inputs(x1, x2)) - - eq = comp_fn(x1, x2) - if inverse and eq: - raise Exception(f"{x1} must not be equal to {x2}" if message == "" else message) - elif not inverse and eq: - raise Exception(f"{x1} must be equal to {x2}" if message == "" else message) - - -def tensorflow_multiply( - x1: Union[float, tensorflow.Tensor, tensorflow.Variable], - x2: Union[float, tensorflow.Tensor, tensorflow.Variable], - /, - *, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - orig_x1
= x1 - orig_x2 = x2 - try: - dtype = ( - x1.dtype - if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() - ) - if not tensorflow_is_array_bknd(x1): - x1 = tensorflow_asarray(x1, dtype=dtype) - if not tensorflow_is_array_bknd(x2): - x2 = tensorflow_asarray(x2, dtype=dtype) - except Exception: - x1 = orig_x1 - x2 = orig_x2 - return tensorflow.math.multiply(x1, x2) - - -def tensorflow_check_gather_nd_input_valid(params, indices, batch_dims): - if batch_dims >= len(params.shape): - raise Exception( - f"batch_dims = {batch_dims} must be less than rank(`params`) = {len(params.shape)}." - ) - if batch_dims >= len(indices.shape): - raise Exception( - f"batch_dims = {batch_dims} must be less than rank(`indices`) = {len(indices.shape)}." - ) - if tensorflow_get_item( - params.shape, slice(0, batch_dims, None) - ) != tensorflow_get_item(indices.shape, slice(0, batch_dims, None)): - raise Exception( - f"batch dimensions must match in `params` and `indices`; saw {tensorflow_get_item(params.shape, slice(0, batch_dims, None))} vs. {tensorflow_get_item(indices.shape, slice(0, batch_dims, None))}" - ) - if indices.shape[-1] > len( - tensorflow_get_item(params.shape, slice(batch_dims, None, None)) - ): - raise Exception( - f"index innermost dimension length must be <= rank(`params[batch_dims:]`); saw: {indices.shape[-1]} vs. {len(tensorflow_get_item(params.shape, slice(batch_dims, None, None)))} ." - ) - - -def tensorflow_gather_nd_helper(params, indices): - indices_shape = tensorflow.shape(indices) - params_shape = tensorflow.shape(params) - num_index_dims = indices_shape[-1] - result_dim_sizes_list = [ - tensorflow.math.reduce_prod(params_shape[i + 1 :]) - for i in range(len(params_shape) - 1) - ] + [1] - result_dim_sizes = tensorflow.convert_to_tensor( - result_dim_sizes_list, dtype=indices.dtype - ) - implicit_indices_factor = result_dim_sizes[num_index_dims - 1] - flat_params = tensorflow.reshape(params, (-1,)) - new_shape = [1] * (len(indices_shape) - 1) + [num_index_dims] - indices_scales = tensorflow.reshape(result_dim_sizes[0:num_index_dims], new_shape) - indices_for_flat_tiled = tensorflow.reshape( - tensorflow.reduce_sum(indices * indices_scales, -1, keepdims=True), (-1, 1) - ) - indices_for_flat_tiled = tensorflow.repeat( - indices_for_flat_tiled, implicit_indices_factor, axis=1 - ) - implicit_indices = tensorflow.repeat( - tensorflow.expand_dims(tensorflow.range(implicit_indices_factor), 0), - indices_for_flat_tiled.shape[0], - axis=0, - ) - indices_for_flat = indices_for_flat_tiled + implicit_indices - flat_indices_for_flat = tensorflow.reshape(indices_for_flat, (-1,)) - flat_gather = tensorflow.gather(flat_params, flat_indices_for_flat) - res = tensorflow.reshape( - flat_gather, - tensorflow.concat([indices_shape[:-1], params_shape[num_index_dims:]], 0), - ) - return res - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_gather_nd( - params: Union[tensorflow.Tensor, tensorflow.Variable], - indices: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - batch_dims: int = 0, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - tensorflow_check_gather_nd_input_valid(params, indices, batch_dims) - try: - return tensorflow.gather_nd(params, indices, batch_dims=batch_dims) - except Exception: - batch_dims %= len(params.shape) - result = [] - if batch_dims == 0: - result = tensorflow_gather_nd_helper(params, indices) - else: - for b in range(batch_dims): - if b == 0: - zip_list = list(zip(params, indices)) - else: - 
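The helper above linearises N-d indices into flat, row-major offsets and performs a single flat gather; a minimal eager check of the equivalent native op (plain TensorFlow, for illustration only):

    import tensorflow as tf

    params = tf.constant([[1, 2, 3], [4, 5, 6]])  # shape (2, 3), row stride 3
    indices = tf.constant([[1], [0]])             # each inner index picks a row
    # index [1] -> flat offsets 1 * 3 + (0, 1, 2) = (3, 4, 5), i.e. row [4, 5, 6]
    print(tf.gather_nd(params, indices).numpy())  # [[4 5 6] [1 2 3]]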
zip_list = [ - (p, i) - for z in [zip(p1, i1) for p1, i1 in zip_list] - for p, i in z - ] - for z in zip_list: - p, i = z[0], z[1] - r = tensorflow_gather_nd_helper(p, i) - result.append(r) - result = tensorflow.stack(result) - result = tensorflow.reshape( - result, - tensorflow.concat([params.shape[0:batch_dims], result.shape[1:]], 0), - ) - return result - - -def tensorflow__is_variable_bknd(x, exclusive=False, to_ignore=None): - x = x - return tensorflow_nested_map_bknd( - lambda x: tensorflow_is_variable(x, exclusive=exclusive), - x, - include_derived=True, - shallow=False, - to_ignore=to_ignore, - ) - - -def tensorflow_inplace_update( - x: Union[tensorflow.Tensor, tensorflow.Tensor], - val: Union[tensorflow.Tensor, tensorflow.Tensor], - /, - *, - ensure_in_backend: bool = False, - keep_input_dtype: bool = False, -): - if tensorflow_is_array_bknd(x) and tensorflow_is_array_bknd(val): - if keep_input_dtype: - val = tensorflow_astype(val, x.dtype) - (x_native, val_native), _ = (x, val), "_" - if tensorflow__is_variable_bknd(x_native): - x_native.assign(val_native) - if tensorflow_is_ivy_array_bknd(x): - x = x_native - else: - x = tensorflow.convert_to_tensor(x_native) - else: - x = x_native - return x - else: - return val - - -def tensorflow_scatter_nd( - indices: Union[tensorflow.Tensor, tensorflow.Variable], - updates: Union[tensorflow.Tensor, tensorflow.Variable], - /, - shape: Optional[Union[tf.TensorShape, Sequence[int]]] = None, - *, - reduction: str = "sum", - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - updates_dtype = updates.dtype - if tensorflow_exists_bknd(out): - dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) - updates = tensorflow.cast( - updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), - ) - expected_shape = ( - list(tensorflow.shape(indices)[:-1]) - + list(out.shape[tensorflow.shape(indices)[-1] :]) - if tensorflow_exists_bknd(out) - else list(tensorflow.shape(indices)[:-1]) - + list(shape[tensorflow.shape(indices)[-1] :]) - ) - updates = tensorflow__broadcast_to_bknd(updates, expected_shape) - if len(updates.shape) == 0: - indices = tensorflow.expand_dims(indices, 0) - updates = tensorflow.expand_dims(updates, 0) - target = out - target_given = tensorflow_exists_bknd(target) - if tensorflow_exists_bknd(shape) and target_given: - tensorflow_check_equal(tuple(target.shape), tuple(shape), as_array=False) - if not target_given: - shape = list(shape) if tensorflow_exists_bknd(shape) else list(out.shape) - target = tensorflow.zeros(shape, dtype=updates.dtype) - if reduction == "sum": - res = tensorflow.tensor_scatter_nd_add(target, indices, updates) - elif reduction == "min": - res = tensorflow.tensor_scatter_nd_min(target, indices, updates) - elif reduction == "max": - res = tensorflow.tensor_scatter_nd_max(target, indices, updates) - elif reduction == "mul": - updates = tensorflow_multiply(tensorflow_gather_nd(target, indices), updates) - res = tensorflow.tensor_scatter_nd_update(target, indices, updates) - elif reduction == "replace": - res = tensorflow.tensor_scatter_nd_update(target, indices, updates) - else: - raise Exception( - f'reduction is {reduction}, but it must be one of "sum", "min", "max", "mul" or "replace"' - ) - if tensorflow_exists_bknd(out): - return tensorflow_inplace_update(out, res) - return res - - -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - 
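Each reduction branch in scatter_nd above maps onto a native tensor_scatter_nd_* op, with "mul" emulated by a gather followed by an update; for instance (plain TensorFlow, illustration only):

    import tensorflow as tf

    target = tf.zeros([5], dtype=tf.int32)
    indices = tf.constant([[1], [3]])
    updates = tf.constant([10, 20])
    print(tf.tensor_scatter_nd_add(target, indices, updates).numpy())  # [ 0 10  0 20  0]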
except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - -@tensorflow_handle_set_item -def tensorflow_set_item_bknd( - x: Union[tensorflow.Tensor, tf.Tensor], - query: Union[tensorflow.Tensor, tf.Tensor, Tuple], - val: Union[tensorflow.Tensor, tf.Tensor], - /, - *, - copy: Optional[bool] = False, -): - if isinstance(query, (list, tuple)) and any( - [(q is Ellipsis or isinstance(q, slice) and q.stop is None) for q in query] - ): - x_stop_gradient = tensorflow_stop_gradient(x, preserve_type=False) - np_array = x_stop_gradient.numpy() - val_stop_gradient = tensorflow_stop_gradient(val, preserve_type=False) - np_array = tensorflow_set_item_bknd( - np_array, query, np.asarray(val_stop_gradient) - ) - return tensorflow_asarray(np_array) - if copy: - x = tensorflow_copy_array(x) - if not tensorflow_is_array_bknd(val): - val = tensorflow_asarray(val) - if 0 in x.shape or 0 in val.shape: - return x - if tensorflow_is_array_bknd(query) and tensorflow_is_bool_dtype_bknd(query): - if not len(query.shape): - query = tensorflow_tile(query, (x.shape[0],)) - indices = tensorflow_nonzero(query, as_tuple=False) - else: - indices, target_shape, _ = tensorflow__parse_query_bknd( - query, tensorflow_shape(x, as_array=True), scatter=True - ) - if indices is None: - return x - val = tensorflow_astype_bknd_(val, x.dtype) - ret = tensorflow_scatter_nd(indices, val, reduction="replace", out=x) - return ret - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_real( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - return tensorflow.math.real(x) - - -def tensorflow_real_bknd_(self): - return tensorflow_real(self) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_imag( - val: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - return tensorflow.math.imag(val, name=None) - - -def tensorflow_imag_bknd_(self): - return tensorflow_imag(self) - - -def tensorflow__check_complex128_bknd(input): - if tensorflow_is_array_bknd(input): - return tensorflow_dtype(input) == "complex128" - elif isinstance(input, np.ndarray): - return str(input.dtype) == "complex128" - if hasattr(input, "real") and hasattr(input, "imag"): - return tensorflow__check_float64_bknd( - tensorflow_real_bknd_(input) - ) and tensorflow__check_float64_bknd(tensorflow_imag_bknd_(input)) - return False - - -def tensorflow_default_complex_dtype_bknd( - *, - input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, - complex_dtype: Optional[Union[str, tf.DType]] = None, - as_native: bool = False, -): - global default_complex_dtype_stack - if tensorflow_exists_bknd(complex_dtype): - if as_native is True: - return tensorflow_as_native_dtype(complex_dtype) - return str(tensorflow_as_ivy_dtype(complex_dtype)) - as_native = tensorflow_default_bknd(as_native, False) - if tensorflow_exists_bknd(input): - if tensorflow_is_array_bknd(input): - ret = tensorflow_dtype(input) - elif isinstance(input, np.ndarray): - ret = str(input.dtype) - elif isinstance(input, (list, tuple, dict)): - if tensorflow_nested_argwhere_bknd( - input, - lambda x: tensorflow__check_complex128_bknd(x), - stop_after_n_found=1, - ): - ret = tf.complex128 - elif not default_complex_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_complex_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "complex64" 
- else: - ret = default_complex_dtype_stack[-1] - elif isinstance(input, Number): - if tensorflow__check_complex128_bknd(input): - ret = tf.complex128 - elif not default_complex_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_complex_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "complex64" - else: - ret = default_complex_dtype_stack[-1] - elif not default_complex_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_complex_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "complex64" - else: - ret = default_complex_dtype_stack[-1] - if as_native: - return tensorflow_as_native_dtype(ret) - return str(tensorflow_as_ivy_dtype(ret)) - - -def tensorflow_default_dtype_bknd( - *, - dtype: Optional[Union[str, str]] = None, - item: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, - as_native: bool = False, -): - if tensorflow_exists_bknd(dtype): - if as_native is True: - return tensorflow_as_native_dtype(dtype) - return tensorflow_as_ivy_dtype(dtype) - as_native = tensorflow_default_bknd(as_native, False) - if tensorflow_exists_bknd(item): - if hasattr(item, "override_dtype_check"): - return item.override_dtype_check() - elif isinstance(item, (list, tuple, dict)) and len(item) == 0: - pass - elif tensorflow_is_complex_dtype_bknd(item): - return tensorflow_default_complex_dtype_bknd( - input=item, as_native=as_native - ) - elif tensorflow_is_float_dtype_bknd(item): - return tensorflow_default_float_dtype_bknd(input=item, as_native=as_native) - elif tensorflow_is_uint_dtype_bknd(item): - return tensorflow_default_int_dtype_bknd(input=item, as_native=as_native) - elif tensorflow_is_int_dtype_bknd(item): - return tensorflow_default_int_dtype_bknd(input=item, as_native=as_native) - elif as_native: - return tensorflow_as_native_dtype("bool") - else: - return "bool" - global default_dtype_stack - if not default_dtype_stack: - global default_float_dtype_stack - if default_float_dtype_stack: - ret = default_float_dtype_stack[-1] - else: - ret = "float32" - else: - ret = default_dtype_stack[-1] - if as_native: - return tensorflow_as_native_dtype(ret) - return tensorflow_as_ivy_dtype(ret) - - -def tensorflow_default_float_dtype_bknd( - *, - input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, - float_dtype: Optional[Union[str, tf.DType]] = None, - as_native: bool = False, -): - global default_float_dtype_stack - if tensorflow_exists_bknd(float_dtype): - if as_native is True: - return tensorflow_as_native_dtype(float_dtype) - return str(tensorflow_as_ivy_dtype(float_dtype)) - as_native = tensorflow_default_bknd(as_native, False) - if tensorflow_exists_bknd(input): - if tensorflow_is_array_bknd(input): - ret = tensorflow_dtype(input) - elif isinstance(input, np.ndarray): - ret = str(input.dtype) - elif isinstance(input, (list, tuple, dict)): - if tensorflow_nested_argwhere_bknd( - input, lambda x: tensorflow__check_float64_bknd(x), stop_after_n_found=1 - ): - ret = tf.float64 - elif not default_float_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_float_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "float32" - else: - ret = default_float_dtype_stack[-1] - elif isinstance(input, Number): - if tensorflow__check_float64_bknd(input): - ret = tf.float64 - elif not default_float_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_float_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "float32" - else: - ret = default_float_dtype_stack[-1] - elif not 
default_float_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_float_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "float32" - else: - ret = default_float_dtype_stack[-1] - if as_native: - return tensorflow_as_native_dtype(ret) - return str(tensorflow_as_ivy_dtype(ret)) - - -def tensorflow_as_ivy_dtype( - dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / -): - if dtype_in is int: - return tensorflow_default_int_dtype_bknd() - if dtype_in is float: - return tensorflow_default_float_dtype_bknd() - if dtype_in is complex: - return tensorflow_default_complex_dtype_bknd() - if dtype_in is bool: - return str("bool") - if isinstance(dtype_in, np.dtype): - dtype_in = dtype_in.name - if isinstance(dtype_in, str): - if dtype_in in native_dtype_dict: - dtype_str = dtype_in - else: - raise Exception( - f"Cannot convert to ivy dtype. {dtype_in} is not supported by TensorFlow backend." - ) - else: - dtype_str = ivy_dtype_dict[dtype_in] - if "uint" in dtype_str: - return str(dtype_str) - elif "int" in dtype_str: - return str(dtype_str) - elif "float" in dtype_str: - return str(dtype_str) - elif "complex" in dtype_str: - return str(dtype_str) - elif "bool" in dtype_str: - return str("bool") - else: - raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") - - -def tensorflow_default_int_dtype_bknd( - *, - input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, - int_dtype: Optional[Union[str, tf.DType]] = None, - as_native: bool = False, -): - global default_int_dtype_stack - if tensorflow_exists_bknd(int_dtype): - if as_native is True: - return tensorflow_as_native_dtype(int_dtype) - return str(tensorflow_as_ivy_dtype(int_dtype)) - as_native = tensorflow_default_bknd(as_native, False) - if tensorflow_exists_bknd(input): - if tensorflow_is_array_bknd(input): - ret = tensorflow_dtype(input) - elif isinstance(input, tuple): - ret = tensorflow_default_int_dtype_bknd() - elif isinstance(input, np.ndarray): - ret = str(input.dtype) - elif isinstance(input, (list, tuple, dict)): - if tensorflow_nested_argwhere_bknd( - input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), - stop_after_n_found=1, - ): - ret = tf.uint64 - elif tensorflow_nested_argwhere_bknd( - input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), - stop_after_n_found=1, - ): - ret = tf.int64 - elif not default_int_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_int_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "int32" - else: - ret = default_int_dtype_stack[-1] - elif isinstance(input, Number): - if input > 9223372036854775807 and input != math.inf and backend != "torch": - ret = tf.uint64 - elif input > 2147483647 and input != math.inf: - ret = tf.int64 - elif not default_int_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_int_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "int32" - else: - ret = default_int_dtype_stack[-1] - elif not default_int_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_int_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "int32" - else: - ret = default_int_dtype_stack[-1] - if as_native: - return tensorflow_as_native_dtype(ret) - return str(tensorflow_as_ivy_dtype(ret)) - - -def tensorflow_as_native_dtype( - dtype_in: Union[tensorflow.DType, str, bool, int, 
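The magnitude thresholds above are the int32/int64/uint64 boundaries; assuming empty dtype stacks and the stock "int32" fallback, scalar inputs resolve as follows (a sketch):

    tensorflow_default_int_dtype_bknd(input=5)      # -> 'int32'
    tensorflow_default_int_dtype_bknd(input=2**31)  # -> 'int64'  (exceeds 2147483647)
    tensorflow_default_int_dtype_bknd(input=2**63)  # -> 'uint64' (exceeds 9223372036854775807; non-torch backends)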
float, np.dtype], -): - if dtype_in is int: - return tensorflow_default_int_dtype_bknd(as_native=True) - if dtype_in is float: - return tensorflow_default_float_dtype_bknd(as_native=True) - if dtype_in is complex: - return tensorflow_default_complex_dtype_bknd(as_native=True) - if dtype_in is bool: - return tensorflow.bool - if isinstance(dtype_in, np.dtype): - dtype_in = dtype_in.name - if not isinstance(dtype_in, str): - return dtype_in - if dtype_in in native_dtype_dict: - return native_dtype_dict[str(dtype_in)] - else: - raise Exception( - f"Cannot convert to TensorFlow dtype. {dtype_in} is not supported by TensorFlow." - ) - - -def tensorflow_dtype( - x: Union[tensorflow.Tensor, tensorflow.Variable, np.ndarray], - *, - as_native: bool = False, -): - if as_native: - return tensorflow_as_native_dtype(x.dtype) - return tensorflow_as_ivy_dtype(x.dtype) - - -def tensorflow_is_bool_dtype_bknd( - dtype_in: Union[str, str, tensorflow.Tensor, tf.Tensor, Number], / -): - if tensorflow_is_array_bknd(dtype_in): - dtype_in = tensorflow_dtype(dtype_in) - elif isinstance(dtype_in, np.ndarray): - return "bool" in dtype_in.dtype.name - elif isinstance(dtype_in, Number): - return isinstance(dtype_in, (bool, np.bool_)) and not isinstance(dtype_in, bool) - elif isinstance(dtype_in, (list, tuple, dict)): - return bool( - tensorflow_nested_argwhere_bknd( - dtype_in, lambda x: isinstance(x, (bool, np.bool_)) and x is not int - ) - ) - return "bool" in tensorflow_as_ivy_dtype(dtype_in) - - -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - -@tensorflow_handle_get_item -def tensorflow_get_item( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - query: Union[tensorflow.Tensor, tensorflow.Variable, Tuple], - *, - copy: Optional[bool] = None, -): - if ( - tensorflow_is_array_bknd(query) - and tensorflow_is_bool_dtype_bknd(query) - and not len(query.shape) - ): - return tensorflow.expand_dims(x, 0) - return x[query] - - -def tensorflow_index_nest_bknd( - nest: Union[List, Tuple, Dict, tensorflow.Tensor, tf.Tensor, dict], - index: Union[List[int], Tuple[int], Iterable[int]], - /, -): - ret = nest - for i in index: - ret = tensorflow_get_item(ret, i) - return ret - - -def tensorflow__get_first_array(*args, **kwargs): - def array_fn(x): - return ( - tensorflow_is_array_bknd(x) - if not hasattr(x, "_ivy_array") - else tensorflow_is_array_bknd(x.ivy_array) - ) - - array_fn = array_fn if "array_fn" not in kwargs else kwargs["array_fn"] - arr = None - if args: - arr_idxs = tensorflow_nested_argwhere_bknd(args, array_fn, stop_after_n_found=1) - if arr_idxs: - arr = tensorflow_index_nest_bknd(args, arr_idxs[0]) - else: - arr_idxs = tensorflow_nested_argwhere_bknd( - kwargs, array_fn, stop_after_n_found=1 - ) - if arr_idxs: - arr = tensorflow_index_nest_bknd(kwargs, arr_idxs[0]) - elif kwargs: - arr_idxs = tensorflow_nested_argwhere_bknd( - kwargs, array_fn, stop_after_n_found=1 - ) - if arr_idxs: - arr = tensorflow_index_nest_bknd(kwargs, arr_idxs[0]) - return arr diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_ones_output/run_0/tensorflow__stateful.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_ones_output/run_0/tensorflow__stateful.py deleted file mode 100644 index dbad1e919ab1..000000000000 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_ones_output/run_0/tensorflow__stateful.py +++ 
/dev/null @@ -1,1799 +0,0 @@ -# global -from __future__ import annotations -import re -import os -import tensorflow as tf -import functools -from tensorflow.python.util import nest -from typing import NamedTuple, Callable, Any, Tuple, List, Dict, Type, Union -import inspect -from collections import OrderedDict -from packaging.version import parse -import keras - - -def get_assignment_dict(): - # Traverse the call stack - lhs = None - for frame_info in inspect.stack(): - # Check if the code context is an assignment statement - if frame_info.code_context and "=" in frame_info.code_context[0]: - # Split the assignment and retrieve the LHS - lhs = frame_info.code_context[0].split("=")[0].strip() - if "self" not in lhs: - continue - break - - if not lhs: - return None, "" - - # Replace indexing with attribute access - lhs = re.sub(r"\[(\d+)\]", r".\1", lhs) - - # Split the LHS based on "." and get individual components - components = lhs.split(".") - - # Initialize the dictionary - assignment_dict = {} - - # Retrieve the live objects associated with each component - for i in range(len(components)): - # Construct the key - key = ".".join(components[: i + 1]) - - # Retrieve the value - if i == 0: - value = frame_info.frame.f_locals.get(components[i]) - else: - value = getattr(assignment_dict[".".join(components[:i])], components[i]) - - # Add the key-value pair to the dictionary - assignment_dict[key] = value - - return assignment_dict, lhs - - -def store_frame_info(fn): - @functools.wraps(fn) - def frame_info_wrapper(self, *args, **kwargs): - if self._previous_frame_info is None: - # store the info about the calling frame. - stack = inspect.stack() - self._previous_frame_info = stack[1] - res = fn(self, *args, **kwargs) - # reset the frame-info - self._previous_frame_info = None - return res - - return frame_info_wrapper - - -# A NodeDef holds two callables: -# - flatten_fn should take the collection and return a flat list of values. -# It can also return some context that is used in reconstructing the -# collection. -# - unflatten_fn should take a flat list of values and some context -# (returned by flatten_fn). It returns the collection by reconstructing -# it from the list and the context. -Context = Any -PyTree = Any -FlattenFunc = Callable[[PyTree], Tuple[List, Context]] -UnflattenFunc = Callable[[List, Context], PyTree] - - -class NodeDef(NamedTuple): - flatten_fn: FlattenFunc - unflatten_fn: UnflattenFunc - - -SUPPORTED_NODES: Dict[Type[Any], NodeDef] = {} - - -def _register_pytree_node( - typ: Any, flatten_fn: FlattenFunc, unflatten_fn: UnflattenFunc -) -> None: - SUPPORTED_NODES[typ] = NodeDef(flatten_fn, unflatten_fn) - - -def _dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]: - return list(d.values()), list(d.keys()) - - -def _dict_unflatten(values: List[Any], context: Context) -> Dict[Any, Any]: - return {key: value for key, value in zip(context, values)} - - -_register_pytree_node(dict, _dict_flatten, _dict_unflatten) - -if parse(keras.__version__).major > 2: - _register_pytree_node( - keras.src.utils.tracking.TrackedDict, _dict_flatten, _dict_unflatten - ) - - -def _get_node_type(pytree: Any) -> Any: - return type(pytree) - - -# A leaf is defined as anything that is not a Node. -def _is_leaf(pytree: PyTree) -> bool: - return _get_node_type(pytree) not in SUPPORTED_NODES.keys() - - -# A TreeSpec represents the structure of a pytree. 
It holds: -# "type": the type of root Node of the pytree -# context: some context that is useful in unflattening the pytree -# children_specs: specs for each child of the root Node -# num_leaves: the number of leaves -class TreeSpec: - def __init__(self, type, context, children_specs): - self.type: Any = type - self.context: Context = context - self.children_specs: List["TreeSpec"] = children_specs - self.num_leaves: int = sum([spec.num_leaves for spec in self.children_specs]) - - def get_keychains(self, prefix="", sep="/"): - keychains = [] - for key, child_spec in zip(self.context, self.children_specs): - new_prefix = prefix + key + sep if prefix else key + sep - if child_spec.children_specs: # Non-leaf node - keychains.extend(child_spec.get_keychains(new_prefix, sep)) - else: # Leaf node - keychains.append(new_prefix[: -len(sep)]) - return keychains - - def __repr__(self, indent: int = 0) -> str: - repr_prefix: str = f"TreeSpec({self.type.__name__}, {self.context}, [" - children_specs_str: str = "" - if len(self.children_specs): - indent += len(repr_prefix) - children_specs_str += self.children_specs[0].__repr__(indent) - children_specs_str += "," if len(self.children_specs) > 1 else "" - children_specs_str += ",".join( - [ - "\n" + " " * indent + child.__repr__(indent) - for child in self.children_specs[1:] - ] - ) - repr_suffix: str = f"{children_specs_str}])" - return repr_prefix + repr_suffix - - -class LeafSpec(TreeSpec): - def __init__(self) -> None: - super().__init__(None, None, []) - self.num_leaves = 1 - - def __repr__(self, indent: int = 0) -> str: - return "*" - - -def tree_flatten(pytree: PyTree) -> Tuple[List[Any], TreeSpec]: - """Flattens a pytree into a list of values and a TreeSpec that can be used - to reconstruct the pytree.""" - if _is_leaf(pytree): - return [pytree], LeafSpec() - - node_type = _get_node_type(pytree) - flatten_fn = _dict_flatten - child_pytrees, context = flatten_fn(pytree) - - # Recursively flatten the children - result: List[Any] = [] - children_specs: List["TreeSpec"] = [] - for child in child_pytrees: - flat, child_spec = tree_flatten(child) - result += flat - children_specs.append(child_spec) - - return result, TreeSpec(node_type, context, children_specs) - - -def tree_unflatten(values: List[Any], spec: TreeSpec) -> PyTree: - """Given a list of values and a TreeSpec, builds a pytree. - - This is the inverse operation of `tree_flatten`. - """ - if not isinstance(spec, TreeSpec): - raise TypeError( - f"tree_unflatten(values, spec): Expected `spec` to be instance of " - f"TreeSpec but got item of type {type(spec)}." - ) - if len(values) != spec.num_leaves: - raise TypeError( - f"tree_unflatten(values, spec): `values` has length {len(values)} " - f"but the spec refers to a pytree that holds {spec.num_leaves} " - f"items ({spec})." 
- ) - if isinstance(spec, LeafSpec): - return values[0] - - unflatten_fn = _dict_unflatten - - # Recursively unflatten the children - start = 0 - end = 0 - child_pytrees = [] - for child_spec in spec.children_specs: - end += child_spec.num_leaves - child_pytrees.append(tree_unflatten(values[start:end], child_spec)) - start = end - - return unflatten_fn(child_pytrees, spec.context) - - -def serialize_obj(obj): - if inspect.isclass(obj) or isinstance(obj, type): - return {"cls_module": obj.__module__, "cls_name": obj.__name__} - return obj - - -def recursive_serialize(d): - if isinstance(d, dict): - return {k: recursive_serialize(v) for k, v in d.items()} - elif isinstance(d, list): - return [recursive_serialize(v) for v in d] - elif isinstance(d, tuple): - return tuple(recursive_serialize(v) for v in d) - else: - return serialize_obj(d) - - -def deserialize_obj(serialized): - if ( - isinstance(serialized, dict) - and "cls_module" in serialized - and "cls_name" in serialized - ): - module = __import__(serialized["cls_module"], fromlist=[serialized["cls_name"]]) - cls = getattr(module, serialized["cls_name"]) - return cls - return serialized - - -def recursive_deserialize(d): - if isinstance(d, dict) and "cls_module" not in d: - return {k: recursive_deserialize(v) for k, v in d.items()} - elif isinstance(d, list): - return [recursive_deserialize(v) for v in d] - elif isinstance(d, tuple): - return tuple(recursive_deserialize(v) for v in d) - else: - return deserialize_obj(d) - - -class ModelHelpers: - @staticmethod - @tf.autograph.experimental.do_not_convert - def _get_first_array(*args, **kwargs): - arr = None - flattened_args = tf.nest.flatten((args, kwargs)) - arr_candidates = tf.nest.map_structure( - lambda x: x if isinstance(x, (tf.Tensor, tf.Variable)) else False, - flattened_args, - ) - for arr_candidate in arr_candidates: - if arr_candidate is not False: - arr = arr_candidate - break - return arr - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _get_input_shapes(*args): - input_shapes = [] - for x in args: - if isinstance(x, (tf.Tensor, tf.Variable)): - input_shapes.append(x.shape) - else: - try: - x = tf.convert_to_tensor(x) - input_shapes.append(x.shape) - except Exception: - input_shapes.append(None) - return input_shapes - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _extract_v(v, keychain_mappings: dict, orig_key_chain, /): - if ModelHelpers._dict_has_key_chain(v, orig_key_chain): - ret_cont = ModelHelpers._dict_at_key_chain(v, orig_key_chain) - else: - ret_cont = dict() - for old_kc, new_kc in keychain_mappings.items(): - if orig_key_chain in old_kc: - # Check if `v` contains `new_kc` before replacing in `ret_cont` - if ModelHelpers._dict_has_key_chain(v, new_kc): - ret_cont = ModelHelpers._dict_set_at_key_chain( - ret_cont, - "/".join(old_kc.split("/")[1:]), - ModelHelpers._dict_at_key_chain(v, new_kc), - ) - else: - continue - return ret_cont - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _remove_duplicate_variables(vs, created, /): - created_ids = tf.nest.map_structure(lambda x: id(x), created) - vs_ids = tf.nest.map_structure(lambda x: id(x), vs) - ids = {} - duplicate_keychains = [] - keychain_mappings = {} - - def unique_callback(x, kc): - ids[x] = kc - return x - - def found_dup_callback(x, kc): - if ids[x] == kc: - return x - duplicate_keychains.append(kc) - keychain_mappings[kc] = ids[x] - return x - - created_ids = nest.map_structure_with_paths( - lambda kc, x: unique_callback(x, kc), created_ids - ) - vs_ids = 
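A round-trip through tree_flatten/tree_unflatten above, plus keychain extraction (uses only definitions from this file):

    cfg = {"a": 1, "b": {"c": 2, "d": 3}}
    leaves, spec = tree_flatten(cfg)            # leaves == [1, 2, 3]
    assert tree_unflatten(leaves, spec) == cfg
    assert spec.get_keychains() == ["a", "b/c", "b/d"]
    # hypothetical: other mapping types can be registered the same way dict is,
    # e.g. _register_pytree_node(SomeMapping, _dict_flatten, _dict_unflatten)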
nest.map_structure_with_paths( - lambda kc, x: ( - unique_callback(x, kc) if x not in ids else found_dup_callback(x, kc) - ), - vs_ids, - ) - for dup_kc in duplicate_keychains: - vs = ModelHelpers._dict_prune_key_chain(vs, dup_kc) - return vs, keychain_mappings - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _dict_set_at_key_chain(in_dict, key_chain, val, inplace=False): - keys = re.split("[/.]", key_chain) - if inplace: - cont = in_dict - else: - cont = in_dict - sub_cont = cont - for key in keys[:-1]: - if key not in sub_cont: - sub_cont[key] = dict() - sub_cont = sub_cont[key] - sub_cont[keys[-1]] = val - return cont - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _dict_at_key_chain(dict, key_chain, ignore_key_errors=False): - keys = re.split("[/.]", key_chain) - ret = dict - for key in keys: - try: - ret = ret[key] - except KeyError as e: - if ignore_key_errors: - return - raise Exception(repr(e)) - return ret - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _dict_has_key_chain(dict, key_chain): - keys = re.split("[/.]", key_chain) - ret = dict - for key in keys: - try: - ret = ret[key] - except KeyError: - return False - return True - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _dict_prune_key_chain(in_dict, key_chain): - keys_in_chain = re.split("[/.]", key_chain) - out_dict = {} - for key, value in in_dict.items(): - if isinstance(value, dict): - if key == keys_in_chain[0]: - if len(keys_in_chain) == 1: - new_val = [] - else: - new_val = ModelHelpers._dict_prune_key_chain( - value, - "/".join(keys_in_chain[1:]), - ) - if len(new_val) > 0: - out_dict[key] = new_val - else: - if len(value) > 0: - out_dict[key] = value - else: - if len(keys_in_chain) != 1 or key != keys_in_chain[0]: - out_dict[key] = value - return out_dict - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _addindent(s_, numSpaces): - s = s_.split("\n") - # don't do anything for single-line stuff - if len(s) == 1: - return s_ - first = s.pop(0) - s = [(numSpaces * " ") + line for line in s] - s = "\n".join(s) - s = first + "\n" + s - return s - - -class Layer(tf.keras.layers.Layer, ModelHelpers): - _build_mode = None - _with_partial_v = None - _store_vars = True - _built = False - _v = None - _buffers = None - _module_dict = None - _args = None - _kwargs = None - _module_graph = None - _target = None - _lazy_traced = False - _training = None - _dynamic_backend = None - _device = None - _dtype = None - _previous_frame_info = None - - def __init__( - self, - /, - *args, - v=None, - buffers=None, - build_mode="on_init", - store_vars=True, - with_partial_v=False, - dynamic_backend=None, - training=True, - dtype=None, - device=None, - module_dict=None, - **kwargs, - ): - super(Layer, self).__init__( - trainable=training, - dtype=dtype, - ) - self._build_mode = build_mode - self._with_partial_v = with_partial_v - self._store_vars = store_vars - self._built = False - self._v_from_constructor = v if isinstance(v, dict) or v is None else dict(v) - self._v = v if v is not None else dict() - self._buffers = dict(buffers or {}) - self._module_dict = module_dict if module_dict is not None else dict() - self._args = args - self._kwargs = kwargs - self._module_graph = None - self._target = None - self._lazy_traced = False - self._training = training - self._dynamic_backend = dynamic_backend - self._device = device or "cpu" - self._dtype = dtype or tf.float32 - if build_mode != "on_init": - return - self.build(*args, 
dynamic_backend=dynamic_backend, **kwargs) - - @tf.autograph.experimental.do_not_convert - def _find_variables( - self, - /, - *, - obj=None, - without_initialisation=False, - _visited=None, - trainable=True, - ): - _visited = _visited or {} - vs = dict() - if id(obj) in _visited: - return vs - _visited[id(obj)] = True - if isinstance(obj, Layer) and obj is not self: - fn = "_build_and_return_v" if trainable else "_build_and_return_buffers" - if not obj._built and without_initialisation: - return lambda: getattr(obj, fn)( - *obj._args, dynamic_backend=self._dynamic_backend, **obj._kwargs - ) - - return getattr(obj, fn)( - *obj._args, dynamic_backend=obj._dynamic_backend, **obj._kwargs - ) - elif isinstance(obj, tf.keras.layers.Layer) and obj is not self: - return obj.v if trainable else obj.buffers - - elif isinstance(obj, (list, tuple)): - for i, v in enumerate(obj): - ret = self._find_variables( - obj=v, - without_initialisation=without_initialisation, - _visited=_visited, - trainable=trainable, - ) - if ret: - vs[f"v{str(i)}"] = ret - return vs - elif isinstance(obj, dict): - for k, v in obj.items(): - ret = self._find_variables( - obj=v, - without_initialisation=without_initialisation, - _visited=_visited, - trainable=trainable, - ) - if ret: - vs[k[1:] if k[0] == "_" else k] = ret - return vs - elif not hasattr(obj, "__dict__"): - return vs - for k, v in obj.__dict__.items(): - if ( - v is not None - and k[0:2] != "__" - and not k.startswith( - ( - "_module_dict", - "_self_", - "_args", - "_kwargs", - ) - ) - ): - ret = self._find_variables( - obj=v, - without_initialisation=without_initialisation, - _visited=_visited, - trainable=trainable, - ) - if ret: - vs[k[1:] if k[0] == "_" else k] = ret - return vs - - @tf.autograph.experimental.do_not_convert - def _find_buffers(self): - if hasattr(self, "_module_dict"): - for key, sub_module in self._module_dict.items(): - if len(sub_module._buffers) > 0: - self._buffers[key] = sub_module._buffers - - @tf.autograph.experimental.do_not_convert - def _build_and_return_v(self, *args, **kwargs): - if not self._built: - self.build(*args, **kwargs) - return self.v - - @tf.autograph.experimental.do_not_convert - def _build_and_return_buffers(self, *args, **kwargs): - if not self._built: - self.build(*args, **kwargs) - return self.buffers - - @tf.autograph.experimental.do_not_convert - def _wrap_call_methods( - self, keychain_mappings, /, *, key="", obj=None, _visited=None - ): - _visited = _visited or {} - if id(obj) in _visited or not isinstance(key, str): - return - _visited[id(obj)] = True - if isinstance(obj, Model) and obj is not self: - orig_key_chain = key[1:] if key[0] == "_" else key - - obj.__call__ = self._fn_with_var_arg( - obj.__call__, self._extract_v, keychain_mappings, orig_key_chain - ) - return - elif isinstance(obj, (list, tuple)): - for i, val in enumerate(obj): - self._wrap_call_methods( - keychain_mappings, - key=f"{key}/v{str(i)}", - obj=val, - _visited=_visited, - ) - return - elif isinstance(obj, dict): - for k, val in obj.items(): - k = f"{key}/{k}" if key != "" and isinstance(k, str) else k - self._wrap_call_methods( - keychain_mappings, key=k, obj=val, _visited=_visited - ) - return - for k, val in obj.module_dict.items(): - if k.startswith(("__", "_self_")): - continue - k = f"{key}/{k}" if key != "" else k - if val is not None: - self._wrap_call_methods( - keychain_mappings, key=k, obj=val, _visited=_visited - ) - return - - @tf.autograph.experimental.do_not_convert - def _compute_module_dict(self): - self._module_dict 
= dict() - for key, value in self.__dict__.items(): - if isinstance(value, (Layer, tf.keras.layers.Layer)): - if ( - "stateful" in value.__module__ - or hasattr(value, "_frontend_module") - or not hasattr(value, "_module_dict") - ): - self._module_dict[key] = value - else: - self._module_dict[key] = value._module_dict - - @tf.autograph.experimental.do_not_convert - def _fn_with_var_arg_wrapper( - self, *a, fn, v_fn, keychain_mappings, orig_key_chain, **kw - ): - if "v" in kw: - del kw["v"] - v = v_fn(self.v, keychain_mappings, orig_key_chain) - return fn(*a, **kw, v=v) - - @tf.autograph.experimental.do_not_convert - def _fn_with_var_arg(self, fn, v_fn, /, keychain_mappings, orig_key_chain): - _fn_with_var_arg_wrapper = functools.partial( - self._fn_with_var_arg_wrapper, - fn=fn, - v_fn=v_fn, - keychain_mappings=keychain_mappings, - orig_key_chain=orig_key_chain, - ) - _fn_with_var_arg_wrapper.wrapped = True - return _fn_with_var_arg_wrapper - - @tf.autograph.experimental.do_not_convert - def _call(self, *args, v=None, buffers=None, **kwargs): - if not self._built or not self.built: - if not self._built: - first_arr = self._get_first_array(*args, **kwargs) - self.build( - *args, - **kwargs, - from_call=True, - dtype=first_arr.dtype if first_arr is not None else tf.float32, - ) - - if not self.built: - # Don't use `keras` build method - if os.environ.get("USE_KERAS_BUILD", "False").lower() == "false": - self.inputs = tf.nest.flatten(args) - else: - input_shapes = self._get_input_shapes(*args) - if len(input_shapes) == 0: - input_shapes = tf.TensorShape(None) - elif len(input_shapes) == 1: - input_shapes = input_shapes[0] - - super(Layer, self).build(tf.TensorShape(None)) # noqa: UP008 - - # If `v` was provided, replace with the module's v - replace_v = False - if v is not None: - v_orig = self.v - self._v = v - replace_v = True - - # If `buffers` were provided, replace with the module's buffers - replace_buffers = False - if buffers is not None: - buffers_orig = self.buffers - self._buffers = buffers - replace_buffers = True - - if replace_v or replace_buffers: - # Call the forward pass - ret = super(Layer, self).__call__(*args, **kwargs) # noqa: UP008 - # Replace v, buffers if needed - self._v = v_orig if replace_v else self._v - self._buffers = buffers_orig if replace_buffers else self._buffers - return ret - elif hasattr(self.__call__, "wrapped"): - return self.__call__(*args, **kwargs) - - # Get the signature of the call method - call_signature = inspect.signature(self.call) - - # Convert all positional arguments to keyword arguments based on the signature - new_kwargs = {} - for idx, (param_name, param) in enumerate(call_signature.parameters.items()): - if idx < len(args): - new_kwargs[param_name] = args[idx] - - # Merge the existing kwargs - new_kwargs.update(kwargs) - return super(Layer, self).__call__(**new_kwargs) # noqa: UP008 - - @tf.autograph.experimental.do_not_convert - def build( - self, - *args, - from_call=False, - device=None, - dtype=None, - dynamic_backend=None, - **kwargs, - ): - self._built = True - return - - def _lock_state(self): - pass - - @tf.autograph.experimental.do_not_convert - def register_buffer(self, name: str, value: Union[tf.Tensor, tf.Variable]): - self._buffers.update({name: value}) - return value - - @tf.autograph.experimental.do_not_convert - def register_parameter(self, name: str, value: Union[tf.Tensor, tf.Variable]): - self._v.update({name: value}) - - @tf.autograph.experimental.do_not_convert - def train(self, mode: bool = True): - self._training = 
mode - for module in self.children(): - if isinstance(module, tf.keras.layers.Layer) and not hasattr( - module, "train" - ): - module.trainable = mode - continue - module.train(mode) - self.trainable = mode - return self - - @tf.autograph.experimental.do_not_convert - def eval(self): - return self.train(mode=False) - - @tf.autograph.experimental.do_not_convert - def call(self, inputs, training=None, mask=None): - raise NotImplementedError( - "When subclassing the `Module` class, you should implement a `call` method." - ) - - def get_build_config(self): - config = super().get_build_config() - config = recursive_serialize(config) - return config - - def build_from_config(self, config): - config = recursive_deserialize(config) - return super().build_from_config(config) - - def get_config(self): - base_config = super().get_config() - config = {} - - # Get the names and values of positional arguments in __init__ - init_signature = inspect.signature(self.__init__) - arg_names = list(init_signature.parameters.keys()) - - # Include the positional arguments in the config - var_positional_arg_encountered = False - var_positional_arg_name = None - offset = 0 - for i, arg in enumerate(self._args): - arg_name = arg_names[min(i, len(arg_names) - 1)] - if var_positional_arg_encountered: - config.update( - { - f"{var_positional_arg_name}_{i - offset}": arg, - } - ) - elif ( - init_signature.parameters[arg_name].kind - == inspect.Parameter.VAR_POSITIONAL - ): - var_positional_arg_encountered = True - var_positional_arg_name = arg_name - offset = i - config.update( - { - f"{var_positional_arg_name}_{0}": arg, - } - ) - else: - config.update( - { - arg_name: arg, - } - ) - - # Include the keywords arguments in the config - kwargs = self._kwargs.copy() - kwargs.pop("devices", None) - config.update(**kwargs) - new_config = {**base_config, **config} - new_config = recursive_serialize(new_config) - return new_config - - @classmethod - def from_config(cls, config): - config = recursive_deserialize(config) - # Get the signature of the __init__ method - init_signature = inspect.signature(cls.__init__) - arg_names = list(init_signature.parameters.keys()) - - # Separate positional and keyword arguments based on the __init__ signature - args = [] - pos_or_kw = OrderedDict() - kwargs = {} - var_positional_args = [] - for arg_name in arg_names: - if ( - arg_name in config - and init_signature.parameters[arg_name].kind - == inspect.Parameter.KEYWORD_ONLY - ): - # Handle keyword arguments - kwargs[arg_name] = config.pop(arg_name) - elif ( - arg_name in config - and init_signature.parameters[arg_name].kind - == inspect.Parameter.POSITIONAL_OR_KEYWORD - ): - # Handle positional or keyword arguments - pos_or_kw[arg_name] = config.pop(arg_name) - elif any(re.match(rf"{arg_name}_\d+", key) for key in config.keys()): - # Handle variable positional arguments - var_positional_args.extend( - [ - config.pop(key) - for key in sorted(config.keys()) - if re.match(rf"{arg_name}_\d+", key) - ] - ) - - # Unpack positional arguments and the rest as keyword arguments - config.pop("name", None) - config.pop("trainable", None) - config.pop("dtype", None) - kwargs.update(config) - - # Determine the final args and kwargs - if var_positional_args: - args = list(pos_or_kw.values()) + var_positional_args - else: - kwargs.update(pos_or_kw) - - return cls(*args, **kwargs) - - # Methods to be Optionally Overridden # - # -----------------------------------# - - @tf.autograph.experimental.do_not_convert - def _create_variables(self, *, device=None, 
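train() above toggles the stored flag recursively, falling back to Keras' trainable attribute for plain layers, and eval() is simply train(mode=False); hypothetical usage with any subclass of the Layer above:

    layer = MyTranslatedLayer()  # hypothetical subclass of Layer
    layer.eval()                 # sets _training/trainable to False across the subtree
    assert layer.training is False
    layer.train()                # back to training mode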
dtype=None): - return {} - - @tf.autograph.experimental.do_not_convert - def _build(self, *args, **kwargs) -> bool: - return True - - @tf.autograph.experimental.do_not_convert - def _forward(self, *args, **kwargs): - raise NotImplementedError( - "When subclassing the `Module` class, you should " - "implement a `_forward` method." - ) - - @tf.autograph.experimental.do_not_convert - def _extra_repr(self) -> str: - return "" - - # Properties # - # -----------# - - @property - def device(self): - return self._device - - @property - def dtype(self): - return self._dtype - - @property - def build_mode(self): - return self._build_mode - - @property - def training(self): - return self._training - - @property - def v(self): - return self._v - - @property - def buffers(self): - return self._buffers - - @property - def state_dict(self): - return {**self.v, **self.buffers} - - @property - def module_dict(self): - return self._module_dict - - @property - def layers(self): - return self._layers - - # Dunder Methods # - # ---------------# - @store_frame_info - @tf.autograph.experimental.do_not_convert - def __call__( - self, - *args, - v=None, - buffers=None, - **kwargs, - ): - # TODO: Temp workaround to avoid `call` from being transformed by AutoGraph - if not hasattr(self.__class__.call, "autograph_info__"): - setattr(self.__class__.call, "autograph_info__", True) - ret = self._call(*args, v=v, buffers=buffers, **kwargs) - return ret - - @tf.autograph.experimental.do_not_convert - def __getattr__(self, name): - if name == "v": - if not super().__getattribute__("_v") and not getattr( # noqa: E501 - self, "_built", False - ): - return self._build_and_return_v( - *self._args, dynamic_backend=self._dynamic_backend, **self._kwargs - ) - - _dict = super().__getattribute__("__dict__") - if name in _dict: - return _dict[name] - - elif "_v" in _dict and name in _dict["_v"]: - return _dict["_v"][name] - - return super().__getattribute__(name) - - @tf.autograph.experimental.do_not_convert - def __setattr__(self, name, value): - if name in ["v", "buffers"]: - name = "_" + name - if isinstance(value, (Layer, tf.keras.layers.Layer)): - _dict = getattr(self, "__dict__", None) - if _dict: - _dict[name] = value - - # compute the module dict - self._compute_module_dict() - - obj_to_search = ( - None - if not isinstance(value, (Layer, tf.keras.layers.Layer)) - else ( - self._modules - if hasattr(self, "_modules") and self._modules - else self - ) - ) - found_vars = self._find_variables( - obj=obj_to_search, - without_initialisation=( - True - if self._v_from_constructor and not self._with_partial_v - else False - ), - ) - flattened_v, v_spec = tree_flatten(found_vars) - flattened_kc = v_spec.get_keychains() - for kc, v in zip(flattened_kc, flattened_v): - new_kc = kc.replace("/", ".") - if new_kc not in self.v: - self.register_parameter(new_kc, v) - - # once all variables built, find and assign buffers - found_buffers = self._find_variables( - obj=obj_to_search, - without_initialisation=( - True - if self._v_from_constructor and not self._with_partial_v - else False - ), - trainable=False, - ) - flattened_buf, buf_spec = tree_flatten(found_buffers) - flattened_kc = buf_spec.get_keychains() - for kc, buf in zip(flattened_kc, flattened_buf): - new_kc = kc.replace("/", ".") - self.register_buffer(new_kc, buf) - - super().__setattr__(name, value) - return - elif isinstance(value, tf.Variable) and not name.startswith("_"): - _dict = getattr(self, "__dict__", None) - if _dict: - _dict[name] = value - # Manual solution for cases 
where a `tf.int32` tensor - # is placed on the GPU. TensorFlow doesn't have dedicated - # kernels for placing `tf.int32` variables on the GPU and so - # we manually cast them to `tf.int64` here otherwise due to - # `tf.config.soft_device_placement(True)` by default, - # TensorFlow puts the `tf.int32` variables on CPU which causes - # unintended consequences downstream during tracing or - # `tf.function` compilation e.g. - # Ref: https://github.com/tensorflow/tensorflow/issues/9506 - # Ref: https://stackoverflow.com/questions/44813939/could-not-satisfy-explicit-device-specification-devicegpu0-because-no-devic - dtype = ( - tf.int64 - if value.dtype == tf.int32 and "gpu:" in value.device.lower() - else value.dtype - ) - cast_dtype = dtype != value.dtype - val = ( - value - if not cast_dtype - else tf.Variable(initial_value=tf.cast(value.value(), dtype), name=name) - ) - self.register_parameter(name, val) - super().__setattr__(name, val) - return - else: - try: - obj_to_search = getattr(self, name) - except AttributeError: - obj_to_search = None - if isinstance(obj_to_search, Layer): - # retrieve all hierarchical submodules - assign_dict, kc = get_assignment_dict() - - # Iterate over all submods in assign_dict - # updating their `v` and `buffers` with the - # new value - for key, submod in assign_dict.items(): - # Get the subkey to match - subkey = kc[len(key) :].lstrip(".") - - if hasattr(submod, "v"): - for v_key, v_value in submod.v.items(): - if v_key.startswith(subkey): - submod.register_parameter(v_key, value) - - # Repeat the same process for submod.buffers - if hasattr(submod, "buffers"): - for b_key, b_value in submod.buffers.items(): - if b_key.startswith(subkey): - submod.register_buffer(b_key, value) - - # finally update the module dict - self._module_dict[name] = value - - return super().__setattr__(name, value) - - @tf.autograph.experimental.do_not_convert - def __delattr__(self, name): - if hasattr(self, name): - if isinstance(getattr(self, name), (Layer, tf.keras.layers.Layer)): - super().__delattr__(name) - return - super().__delattr__(name) - - @tf.autograph.experimental.do_not_convert - def __repr__(self): - extra_lines = [] - extra_repr = self._extra_repr() - if extra_repr: - extra_lines = extra_repr.split("\n") - child_lines = [] - for key in self.v.keys(): - if isinstance(getattr(self, key, None), Layer): - mod_str = repr(getattr(self, key)) - mod_str = self._addindent(mod_str, 2) - child_lines.append(f"({key}): {mod_str}") - lines = extra_lines + child_lines - - main_str = f"{self.__class__.__name__}(" - if lines: - # simple one-liner info, which most builtin Modules will use - if len(extra_lines) == 1 and not child_lines: - main_str += extra_lines[0] - else: - main_str += "\n " + "\n ".join(lines) + "\n" - - main_str += ")" - return main_str - - -class Model(tf.keras.Model, ModelHelpers): - _build_mode = None - _with_partial_v = None - _store_vars = True - _built = False - _v = None - _buffers = None - _module_dict = None - _args = None - _kwargs = None - _module_graph = None - _target = None - _lazy_traced = False - _training = None - _dynamic_backend = None - _device = None - _dtype = None - _previous_frame_info = None - - def __init__( - self, - /, - *args, - v=None, - buffers=None, - build_mode="on_init", - store_vars=True, - with_partial_v=False, - dynamic_backend=None, - training=True, - dtype=None, - device=None, - module_dict=None, - **kwargs, - ): - super(Model, self).__init__( - trainable=training, - dtype=dtype, - ) - self._build_mode = build_mode - 
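The guard described in the comment above only rewrites GPU-resident tf.int32 variables; a CPU-resident one passes through unchanged (sketch):

    v = tf.Variable(tf.constant([1, 2], dtype=tf.int32))  # CPU placement: stays int32
    needs_cast = v.dtype == tf.int32 and "gpu:" in v.device.lower()
    # on a GPU device string the branch above instead rebuilds the variable as
    # tf.Variable(initial_value=tf.cast(v.value(), tf.int64), name=...)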
self._with_partial_v = with_partial_v - self._store_vars = store_vars - self._built = False - self._v_from_constructor = v if isinstance(v, dict) or v is None else dict(v) - self._v = v if v is not None else dict() - self._buffers = dict(buffers or {}) - self._module_dict = module_dict if module_dict is not None else dict() - self._args = args - self._kwargs = kwargs - self._module_graph = None - self._target = None - self._lazy_traced = False - self._training = training - self._dynamic_backend = dynamic_backend - self._device = device or "cpu" - self._dtype = dtype or tf.float32 - if build_mode != "on_init": - return - self.build(*args, dynamic_backend=dynamic_backend, **kwargs) - - @tf.autograph.experimental.do_not_convert - def _find_variables( - self, - /, - *, - obj=None, - without_initialisation=False, - _visited=None, - trainable=True, - ): - _visited = _visited or {} - vs = dict() - if id(obj) in _visited: - return vs - _visited[id(obj)] = True - if isinstance(obj, (Layer, Model)) and obj is not self: - fn = "_build_and_return_v" if trainable else "_build_and_return_buffers" - if not obj._built and without_initialisation: - return lambda: getattr(obj, fn)( - *obj._args, dynamic_backend=self._dynamic_backend, **obj._kwargs - ) - - return getattr(obj, fn)( - *obj._args, dynamic_backend=obj._dynamic_backend, **obj._kwargs - ) - elif isinstance(obj, tf.keras.layers.Layer) and obj is not self: - return obj.v if trainable else obj.buffers - - elif isinstance(obj, (list, tuple)): - for i, v in enumerate(obj): - ret = self._find_variables( - obj=v, - without_initialisation=without_initialisation, - _visited=_visited, - trainable=trainable, - ) - if ret: - vs[f"v{str(i)}"] = ret - return vs - elif isinstance(obj, dict): - for k, v in obj.items(): - ret = self._find_variables( - obj=v, - without_initialisation=without_initialisation, - _visited=_visited, - trainable=trainable, - ) - if ret: - vs[k[1:] if k[0] == "_" else k] = ret - return vs - elif not hasattr(obj, "__dict__"): - return vs - for k, v in obj.__dict__.items(): - if ( - v is not None - and k[0:2] != "__" - and not k.startswith( - ( - "_module_dict", - "_self_", - "_args", - "_kwargs", - ) - ) - ): - ret = self._find_variables( - obj=v, - without_initialisation=without_initialisation, - _visited=_visited, - trainable=trainable, - ) - if ret: - vs[k[1:] if k[0] == "_" else k] = ret - return vs - - @tf.autograph.experimental.do_not_convert - def _find_buffers(self): - if hasattr(self, "_module_dict"): - for key, sub_module in self._module_dict.items(): - if len(sub_module._buffers) > 0: - self._buffers[key] = sub_module._buffers - - @tf.autograph.experimental.do_not_convert - def _build_and_return_v(self, *args, **kwargs): - if not self._built: - self.build(*args, **kwargs) - return self.v - - @tf.autograph.experimental.do_not_convert - def _build_and_return_buffers(self, *args, **kwargs): - if not self._built: - self.build(*args, **kwargs) - return self.buffers - - @tf.autograph.experimental.do_not_convert - def _wrap_call_methods( - self, keychain_mappings, /, *, key="", obj=None, _visited=None - ): - _visited = _visited or {} - if id(obj) in _visited or not isinstance(key, str): - return - _visited[id(obj)] = True - if isinstance(obj, (Layer, Model)) and obj is not self: - orig_key_chain = key[1:] if key[0] == "_" else key - - obj.__call__ = self._fn_with_var_arg( - obj.__call__, self._extract_v, keychain_mappings, orig_key_chain - ) - return - elif isinstance(obj, (list, tuple)): - for i, val in enumerate(obj): - 
self._wrap_call_methods( - keychain_mappings, - key=f"{key}/v{str(i)}", - obj=val, - _visited=_visited, - ) - return - elif isinstance(obj, dict): - for k, val in obj.items(): - k = f"{key}/{k}" if key != "" and isinstance(k, str) else k - self._wrap_call_methods( - keychain_mappings, key=k, obj=val, _visited=_visited - ) - return - for k, val in obj.module_dict.items(): - if k.startswith(("__", "_self_")): - continue - k = f"{key}/{k}" if key != "" else k - if val is not None: - self._wrap_call_methods( - keychain_mappings, key=k, obj=val, _visited=_visited - ) - return - - @tf.autograph.experimental.do_not_convert - def _compute_module_dict(self): - self._module_dict = dict() - for key, value in self.__dict__.items(): - if isinstance(value, (Layer, tf.keras.layers.Layer, Model, tf.keras.Model)): - if ( - "stateful" in value.__module__ - or hasattr(value, "_frontend_module") - or not hasattr(value, "_module_dict") - ): - self._module_dict[key] = value - else: - self._module_dict[key] = value._module_dict - - @tf.autograph.experimental.do_not_convert - def _fn_with_var_arg_wrapper( - self, *a, fn, v_fn, keychain_mappings, orig_key_chain, **kw - ): - if "v" in kw: - del kw["v"] - v = v_fn(self.v, keychain_mappings, orig_key_chain) - return fn(*a, **kw, v=v) - - @tf.autograph.experimental.do_not_convert - def _fn_with_var_arg(self, fn, v_fn, /, keychain_mappings, orig_key_chain): - _fn_with_var_arg_wrapper = functools.partial( - self._fn_with_var_arg_wrapper, - fn=fn, - v_fn=v_fn, - keychain_mappings=keychain_mappings, - orig_key_chain=orig_key_chain, - ) - _fn_with_var_arg_wrapper.wrapped = True - return _fn_with_var_arg_wrapper - - @tf.autograph.experimental.do_not_convert - def _call(self, *args, v=None, buffers=None, **kwargs): - if not self._built or not self.built: - if not self._built: - first_arr = self._get_first_array(*args, **kwargs) - self.build( - *args, - **kwargs, - from_call=True, - dtype=first_arr.dtype if first_arr is not None else tf.float32, - ) - - if not self.built: - # Don't use `keras` build method - if os.environ.get("USE_KERAS_BUILD", "False").lower() == "false": - self.inputs = tf.nest.flatten(args) - else: - input_shapes = self._get_input_shapes(*args) - if len(input_shapes) == 0: - input_shapes = tf.TensorShape(None) - elif len(input_shapes) == 1: - input_shapes = input_shapes[0] - - super(Model, self).build(tf.TensorShape(None)) # noqa: UP008 - - # If `v` was provided, replace with the module's v - replace_v = False - if v is not None: - v_orig = self.v - self._v = v - replace_v = True - - # If `buffers` were provided, replace with the module's buffers - replace_buffers = False - if buffers is not None: - buffers_orig = self.buffers - self._buffers = buffers - replace_buffers = True - - if replace_v or replace_buffers: - # Call the forward pass - ret = super(Model, self).__call__(*args, **kwargs) # noqa: UP008 - # Replace v, buffers if needed - self._v = v_orig if replace_v else self._v - self._buffers = buffers_orig if replace_buffers else self._buffers - return ret - elif hasattr(self.__call__, "wrapped"): - return self.__call__(*args, **kwargs) - - return super(Model, self).__call__(*args, **kwargs) # noqa: UP008 - - @tf.autograph.experimental.do_not_convert - def build( - self, - *args, - from_call=False, - device=None, - dtype=None, - dynamic_backend=None, - **kwargs, - ): - self._built = True - return - - def _lock_state(self): - pass - - @tf.autograph.experimental.do_not_convert - def register_buffer(self, name: str, value: Union[tf.Tensor, tf.Variable]): 
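        # Buffers hold non-trainable state; they live in self._buffers, separate
        # from the trainable self._v dict, and both are merged by state_dict.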
- self._buffers.update({name: value}) - return value - - @tf.autograph.experimental.do_not_convert - def register_parameter(self, name: str, value: Union[tf.Tensor, tf.Variable]): - self._v.update({name: value}) - - @tf.autograph.experimental.do_not_convert - def train(self, mode: bool = True): - self._training = mode - for module in self.children(): - if isinstance(module, tf.keras.layers.Layer) and not hasattr( - module, "train" - ): - module.trainable = mode - continue - module.train(mode) - self.trainable = mode - return self - - @tf.autograph.experimental.do_not_convert - def eval(self): - return self.train(mode=False) - - @tf.autograph.experimental.do_not_convert - def call(self, inputs, training=None, mask=None): - raise NotImplementedError( - "When subclassing the `Module` class, you should implement a `call` method." - ) - - def get_build_config(self): - config = super().get_build_config() - config = recursive_serialize(config) - return config - - def build_from_config(self, config): - config = recursive_deserialize(config) - return super().build_from_config(config) - - def get_config(self): - base_config = super().get_config() - config = {} - - # Get the names and values of positional arguments in __init__ - init_signature = inspect.signature(self.__init__) - arg_names = list(init_signature.parameters.keys()) - - # Include the positional arguments in the config - var_positional_arg_encountered = False - var_positional_arg_name = None - offset = 0 - for i, arg in enumerate(self._args): - arg_name = arg_names[min(i, len(arg_names) - 1)] - if var_positional_arg_encountered: - config.update( - { - f"{var_positional_arg_name}_{i - offset}": arg, - } - ) - elif ( - init_signature.parameters[arg_name].kind - == inspect.Parameter.VAR_POSITIONAL - ): - var_positional_arg_encountered = True - var_positional_arg_name = arg_name - offset = i - config.update( - { - f"{var_positional_arg_name}_{0}": arg, - } - ) - else: - config.update( - { - arg_name: arg, - } - ) - - # Include the keywords arguments in the config - kwargs = self._kwargs.copy() - kwargs.pop("devices", None) - config.update(**kwargs) - new_config = {**base_config, **config} - new_config = recursive_serialize(new_config) - return new_config - - @classmethod - def from_config(cls, config): - config = recursive_deserialize(config) - # Get the signature of the __init__ method - init_signature = inspect.signature(cls.__init__) - arg_names = list(init_signature.parameters.keys()) - - # Separate positional and keyword arguments based on the __init__ signature - args = [] - pos_or_kw = OrderedDict() - kwargs = {} - var_positional_args = [] - for arg_name in arg_names: - if ( - arg_name in config - and init_signature.parameters[arg_name].kind - == inspect.Parameter.KEYWORD_ONLY - ): - # Handle keyword arguments - kwargs[arg_name] = config.pop(arg_name) - elif ( - arg_name in config - and init_signature.parameters[arg_name].kind - == inspect.Parameter.POSITIONAL_OR_KEYWORD - ): - # Handle positional or keyword arguments - pos_or_kw[arg_name] = config.pop(arg_name) - elif any(re.match(rf"{arg_name}_\d+", key) for key in config.keys()): - # Handle variable positional arguments - var_positional_args.extend( - [ - config.pop(key) - for key in sorted(config.keys()) - if re.match(rf"{arg_name}_\d+", key) - ] - ) - - # Unpack positional arguments and the rest as keyword arguments - config.pop("name", None) - config.pop("trainable", None) - config.pop("dtype", None) - kwargs.update(config) - - # Determine the final args and kwargs - if 
var_positional_args: - args = list(pos_or_kw.values()) + var_positional_args - else: - kwargs.update(pos_or_kw) - - return cls(*args, **kwargs) - - # Methods to be Optionally Overridden # - # -----------------------------------# - - @tf.autograph.experimental.do_not_convert - def _create_variables(self, *, device=None, dtype=None): - return {} - - @tf.autograph.experimental.do_not_convert - def _build(self, *args, **kwargs) -> bool: - return True - - @tf.autograph.experimental.do_not_convert - def _forward(self, *args, **kwargs): - raise NotImplementedError( - "When subclassing the `Module` class, you should " - "implement a `_forward` method." - ) - - @tf.autograph.experimental.do_not_convert - def _extra_repr(self) -> str: - return "" - - # Properties # - # -----------# - - @property - def device(self): - return self._device - - @property - def dtype(self): - return self._dtype - - @property - def build_mode(self): - return self._build_mode - - @property - def training(self): - return self._training - - @property - def v(self): - return self._v - - @property - def buffers(self): - return self._buffers - - @property - def state_dict(self): - return {**self.v, **self.buffers} - - @property - def module_dict(self): - return self._module_dict - - # Dunder Methods # - # ---------------# - @store_frame_info - @tf.autograph.experimental.do_not_convert - def __call__( - self, - *args, - v=None, - buffers=None, - **kwargs, - ): - # TODO: Temp workaround to avoid `call`` from being transformed by AutoGraph - if not hasattr(self.__class__.call, "autograph_info__"): - setattr(self.__class__.call, "autograph_info__", True) - ret = self._call(*args, v=v, buffers=buffers, **kwargs) - return ret - - @tf.autograph.experimental.do_not_convert - def __getattr__(self, name): - if name == "v": - if not super().__getattribute__("_v") and not getattr( # noqa: E501 - self, "_built", False - ): - return self._build_and_return_v( - *self._args, dynamic_backend=self._dynamic_backend, **self._kwargs - ) - - _dict = super().__getattribute__("__dict__") - if name in _dict: - return _dict[name] - - elif "_v" in _dict and name in _dict["_v"]: - return _dict["_v"][name] - - return super().__getattribute__(name) - - @tf.autograph.experimental.do_not_convert - def __setattr__(self, name, value): - if name in ["v", "buffers"]: - name = "_" + name - if isinstance(value, (Layer, tf.keras.layers.Layer, Model, tf.keras.Model)): - _dict = getattr(self, "__dict__", None) - if _dict: - _dict[name] = value - - # compute the module dict - self._compute_module_dict() - - obj_to_search = ( - None - if not isinstance(value, (tf.keras.layers.Layer, Layer, Model)) - else ( - self._modules - if hasattr(self, "_modules") and self._modules - else self - ) - ) - found_vars = self._find_variables( - obj=obj_to_search, - without_initialisation=( - True - if self._v_from_constructor and not self._with_partial_v - else False - ), - ) - flattened_v, v_spec = tree_flatten(found_vars) - flattend_kc = v_spec.get_keychains() - for kc, v in zip(flattend_kc, flattened_v): - new_kc = kc.replace("/", ".") - if new_kc not in self.v: - self.register_parameter(new_kc, v) - - # once all variables built, find and assign buffers - found_buffers = self._find_variables( - obj=obj_to_search, - without_initialisation=( - True - if self._v_from_constructor and not self._with_partial_v - else False - ), - trainable=False, - ) - flattened_buf, buf_spec = tree_flatten(found_buffers) - flattend_kc = buf_spec.get_keychains() - for kc, buf in zip(flattend_kc, 
flattened_buf): - new_kc = kc.replace("/", ".") - self.register_buffer(new_kc, buf) - - super().__setattr__(name, value) - return - elif isinstance(value, tf.Variable) and not name.startswith("_"): - _dict = getattr(self, "__dict__", None) - if _dict: - _dict[name] = value - - # Manual solution for cases where a `tf.int32` tensor - # is placed on the GPU. TensorFlow doesn't have dedicated - # kernels for placing `tf.int32` variables on the GPU and so - # we manually cast them to `tf.int64` here otherwise due to - # `tf.config.soft_device_placement(True)` by default, - # TensorFlow puts the `tf.int32` variables on CPU which causes - # unintended consequences downstream during tracing or - # `tf.function` compilation e.g. - # Ref: https://github.com/tensorflow/tensorflow/issues/9506 - # Ref: https://stackoverflow.com/questions/44813939/could-not-satisfy-explicit-device-specification-devicegpu0-because-no-devic - dtype = ( - tf.int64 - if value.dtype == tf.int32 and "gpu:" in value.device.lower() - else value.dtype - ) - cast_dtype = dtype != value.dtype - val = ( - value - if not cast_dtype - else tf.Variable(initial_value=tf.cast(value.value(), dtype), name=name) - ) - self.register_parameter(name, val) - super().__setattr__(name, val) - else: - try: - obj_to_search = getattr(self, name) - except AttributeError: - obj_to_search = None - if isinstance(obj_to_search, (Model, Layer)): - # retrieve all hierarchical submodules - assign_dict, kc = get_assignment_dict() - - # Iterate over all submods in assign_dict - # updating their `v` and `buffers` with the - # new value - for key, submod in assign_dict.items(): - # Get the subkey to match - subkey = kc[len(key) :].lstrip(".") - - if hasattr(submod, "v"): - for v_key, v_value in submod.v.items(): - if v_key.startswith(subkey): - submod.register_parameter(v_key, value) - - # Repeat the same process for submod.buffers - if hasattr(submod, "buffers"): - for b_key, b_value in submod.buffers.items(): - if b_key.startswith(subkey): - submod.register_buffer(b_key, value) - - # finally update the module dict - self._module_dict[name] = value - - return super().__setattr__(name, value) - - @tf.autograph.experimental.do_not_convert - def __delattr__(self, name): - if hasattr(self, name): - if isinstance( - getattr(self, name), - (Layer, tf.keras.layers.Layer, Model, tf.keras.Model), - ): - super().__delattr__(name) - return - super().__delattr__(name) - - @tf.autograph.experimental.do_not_convert - def __repr__(self): - extra_lines = [] - extra_repr = self._extra_repr() - if extra_repr: - extra_lines = extra_repr.split("\n") - child_lines = [] - for key in self.v.keys(): - if isinstance(getattr(self, key, None), (Layer, Model)): - mod_str = repr(getattr(self, key)) - mod_str = self._addindent(mod_str, 2) - child_lines.append(f"({key}): {mod_str}") - lines = extra_lines + child_lines - - main_str = f"{self.__class__.__name__}(" - if lines: - # simple one-liner info, which most builtin Modules will use - if len(extra_lines) == 1 and not child_lines: - main_str += extra_lines[0] - else: - main_str += "\n " + "\n ".join(lines) + "\n" - - main_str += ")" - return main_str diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_ones_output/run_0/tensorflow_ones.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_ones_output/run_0/tensorflow_ones.py deleted file mode 100644 index 8549c443e681..000000000000 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_ones_output/run_0/tensorflow_ones.py +++ /dev/null @@ -1,21 +0,0 @@ -import tensorflow 
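As an aside, the `tf.int32` workaround duplicated in the two `__setattr__` methods above is easy to exercise in isolation. A minimal sketch of that logic (the helper name `promote_int32_on_gpu` is illustrative, not part of this diff):

import tensorflow as tf

def promote_int32_on_gpu(value: tf.Variable, name: str) -> tf.Variable:
    # tf.int32 variables have no dedicated GPU kernels, so soft device
    # placement would silently pin them to the CPU; casting to tf.int64
    # keeps the variable on the GPU.
    if value.dtype == tf.int32 and "gpu:" in value.device.lower():
        return tf.Variable(initial_value=tf.cast(value.value(), tf.int64), name=name)
    return value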
-import tensorflow as tf - -from typing import Sequence -from typing import Union -from typing import Optional - -from .tensorflow__helpers import tensorflow_handle_array_like_without_promotion -from .tensorflow__helpers import tensorflow_infer_dtype - - -@tensorflow_infer_dtype -@tensorflow_handle_array_like_without_promotion -def tensorflow_ones( - shape: Union[tf.TensorShape, Sequence[int]], - *, - dtype: tensorflow.DType, - device: Optional[str] = None, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - return tensorflow.ones(shape, dtype=tensorflow.float32) diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_pad_output/run_0/tensorflow_NestedSequence_bknd.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_pad_output/run_0/tensorflow_NestedSequence_bknd.py index 9f87b4ae29ef..ac8335fe1e56 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_pad_output/run_0/tensorflow_NestedSequence_bknd.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_pad_output/run_0/tensorflow_NestedSequence_bknd.py @@ -1,5 +1,5 @@ -from typing import Protocol from typing import TypeVar +from typing import Protocol _T_co = TypeVar("_T_co", covariant=True) diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_pad_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_pad_output/run_0/tensorflow__helpers.py index 37fa10ab2117..c7f7ab515b51 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_pad_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_pad_output/run_0/tensorflow__helpers.py @@ -23,6 +23,99 @@ import tensorflow as tf +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + 
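    # Illustrative behaviour of the helper above, using a name from this file:
    #   extract_function_name("tensorflow_split")   -> "split"
    #   extract_function_name("tensorflow_split_1") -> "split"
    # Combined with the _bknd_/_frnt_ substitution in the wrapper below, a
    # call such as tensorflow_split(obj, ...) falls back to obj.split(...)
    # whenever obj is not a backend array.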
@functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +240,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +282,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, 
**kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_is_native_array(x, /, *, exclusive=False): @@ -334,7 +342,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -447,6 +457,7 @@ def tensorflow_is_complex_dtype_bknd( return "complex" in tensorflow_as_ivy_dtype_bknd(dtype_in) +@tensorflow_handle_array_like_without_promotion def tensorflow_real( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -460,6 +471,7 @@ def tensorflow_real_bknd_(self): return tensorflow_real(self) +@tensorflow_handle_array_like_without_promotion def tensorflow_imag( val: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -485,6 +497,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -611,6 +626,9 @@ def nested_fun(x): return "int" in tensorflow_as_ivy_dtype(dtype_in) +default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, @@ -655,6 +673,9 @@ def tensorflow_default_dtype_bknd( return tensorflow_as_ivy_dtype(ret) +default_float_dtype_stack = [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -709,6 +730,25 @@ def tensorflow_default_float_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -745,6 +785,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -767,21 +811,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -819,6 +859,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": 
tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -870,20 +929,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -950,26 +995,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -1087,6 +1114,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -1193,27 +1223,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -1382,6 +1406,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1638,7 +1665,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2050,7 +2079,9 @@ def tensorflow__parse_query_bknd(query, x_shape, 
scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -2174,6 +2205,9 @@ def tensorflow_to_scalar_bknd_(self: tensorflow.Tensor): return tensorflow_to_scalar(self) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2198,11 +2232,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2410,7 +2442,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2570,11 +2604,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2614,21 +2646,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_pad_output/run_0/tensorflow_pad.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_pad_output/run_0/tensorflow_pad.py index 3cf8e5e78c65..e3b9654b0ed7 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_pad_output/run_0/tensorflow_pad.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_pad_output/run_0/tensorflow_pad.py @@ -1,13 +1,13 @@ import tensorflow -from numbers import Number -from typing import Literal -from typing import Tuple -from typing import Iterable -from typing import Any -from typing import Callable from typing import Optional from typing import Union +from typing import Literal +from numbers import Number +from typing import Callable +from typing import Any +from typing import Iterable +from typing import Tuple from .tensorflow__helpers import tensorflow__to_tf_padding_bknd from .tensorflow__helpers import tensorflow_handle_array_like_without_promotion diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_permute_dims_output/run_0/tensorflow_NestedSequence_bknd.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_permute_dims_output/run_0/tensorflow_NestedSequence_bknd.py index ac8335fe1e56..9f87b4ae29ef 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_permute_dims_output/run_0/tensorflow_NestedSequence_bknd.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_permute_dims_output/run_0/tensorflow_NestedSequence_bknd.py @@ -1,5 +1,5 @@ -from typing import TypeVar from typing import 
Protocol +from typing import TypeVar _T_co = TypeVar("_T_co", covariant=True) diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_permute_dims_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_permute_dims_output/run_0/tensorflow__helpers.py index d50687f412e5..3aaa2aa83879 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_permute_dims_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_permute_dims_output/run_0/tensorflow__helpers.py @@ -23,6 +23,99 @@ import tensorflow as tf +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +240,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +282,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] 
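The `*_dtype_stack` globals being moved in this hunk implement a stack-based defaulting scheme: the most recently pushed dtype wins, and an empty stack falls back to a framework default. A minimal sketch of the pattern (the setter/getter names are illustrative, not part of this diff):

default_float_dtype_stack = []

def set_default_float_dtype(dtype: str):
    # Push an override; it stays active until popped.
    default_float_dtype_stack.append(dtype)

def unset_default_float_dtype():
    # Pop the most recent override, restoring any previous one.
    if default_float_dtype_stack:
        default_float_dtype_stack.pop()

def current_default_float_dtype() -> str:
    # Top of the stack, else a sensible framework default.
    return default_float_dtype_stack[-1] if default_float_dtype_stack else "float32"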
-ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_is_native_array(x, /, *, exclusive=False): @@ -334,7 +342,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -447,6 +457,7 @@ def tensorflow_is_complex_dtype_bknd( return "complex" in tensorflow_as_ivy_dtype_bknd(dtype_in) +@tensorflow_handle_array_like_without_promotion def tensorflow_real( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -460,6 +471,7 @@ def tensorflow_real_bknd_(self): return tensorflow_real(self) +@tensorflow_handle_array_like_without_promotion def tensorflow_imag( val: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -485,6 +497,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: 
Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -611,6 +626,9 @@ def nested_fun(x): return "int" in tensorflow_as_ivy_dtype(dtype_in) +default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, @@ -655,6 +673,9 @@ def tensorflow_default_dtype_bknd( return tensorflow_as_ivy_dtype(ret) +default_float_dtype_stack = [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -709,6 +730,25 @@ def tensorflow_default_float_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -745,6 +785,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -767,21 +811,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -819,6 +859,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -870,20 +929,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -950,26 +995,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods(fn): - def 
extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -1087,6 +1114,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -1193,27 +1223,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -1382,6 +1406,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1638,7 +1665,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2050,7 +2079,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -2174,6 +2205,9 @@ def tensorflow_to_scalar_bknd_(self: tensorflow.Tensor): return tensorflow_to_scalar(self) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2198,11 +2232,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2410,7 +2442,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else 
tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2570,11 +2604,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2614,21 +2646,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_permute_dims_output/run_0/tensorflow_permute_dims.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_permute_dims_output/run_0/tensorflow_permute_dims.py index de4729f5c9f9..fb1ae482b707 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_permute_dims_output/run_0/tensorflow_permute_dims.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_permute_dims_output/run_0/tensorflow_permute_dims.py @@ -1,8 +1,8 @@ import tensorflow +from typing import Tuple from typing import Union from typing import Optional -from typing import Tuple from .tensorflow__helpers import tensorflow_handle_array_like_without_promotion diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_pow_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_pow_output/run_0/tensorflow__helpers.py index 81493dd7f88d..ef833e3ab065 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_pow_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_pow_output/run_0/tensorflow__helpers.py @@ -23,6 +23,99 @@ import tensorflow as tf +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def 
tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +240,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +282,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, 
value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_is_native_array(x, /, *, exclusive=False): @@ -334,7 +342,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -440,20 +450,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -520,26 +516,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -657,6 +635,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -763,27 +744,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -952,6 +927,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1208,7 
+1186,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -1620,7 +1600,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -1792,6 +1774,9 @@ def tensorflow_is_uint_dtype_bknd( return "uint" in tensorflow_as_ivy_dtype_bknd(dtype_in) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -1816,11 +1801,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2054,7 +2037,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2214,11 +2199,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2258,21 +2241,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], @@ -2353,6 +2321,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2409,6 +2380,9 @@ def tensorflow_default_complex_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, @@ -2453,6 +2427,9 @@ def tensorflow_default_dtype_bknd( return tensorflow_as_ivy_dtype(ret) +default_float_dtype_stack = [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2507,6 +2484,25 @@ def tensorflow_default_float_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: 
"float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -2543,6 +2539,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2565,21 +2565,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -2617,6 +2613,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_pow_output/run_0/tensorflow_pow.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_pow_output/run_0/tensorflow_pow.py index a17fb05b6ae3..e5ace5d24038 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_pow_output/run_0/tensorflow_pow.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_pow_output/run_0/tensorflow_pow.py @@ -36,7 +36,9 @@ def tensorflow_pow( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_random_uniform_output/run_0/__init__.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_random_uniform_output/run_0/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_random_uniform_output/run_0/tensorflow_NestedSequence_bknd.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_random_uniform_output/run_0/tensorflow_NestedSequence_bknd.py deleted file mode 100644 index ac8335fe1e56..000000000000 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_random_uniform_output/run_0/tensorflow_NestedSequence_bknd.py +++ /dev/null @@ -1,10 +0,0 @@ -from typing import TypeVar -from typing import Protocol - -_T_co = TypeVar("_T_co", covariant=True) - - -class 
tensorflow_NestedSequence_bknd(Protocol[_T_co]): - def __getitem__(self, key: int, /): ... - - def __len__(self, /): ... diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_random_uniform_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_random_uniform_output/run_0/tensorflow__helpers.py deleted file mode 100644 index 394932ff412f..000000000000 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_random_uniform_output/run_0/tensorflow__helpers.py +++ /dev/null @@ -1,2746 +0,0 @@ -from collections import UserDict -from ivy.utils.backend import backend_stack -from numbers import Number -from numpy.core.numeric import normalize_axis_tuple -from operator import mul -from .tensorflow_NestedSequence_bknd import tensorflow_NestedSequence_bknd -from typing import Any -from typing import Callable -from typing import Dict -from typing import Iterable -from typing import List -from typing import Optional -from typing import Sequence -from typing import Tuple -from typing import TypeVar -from typing import Union -import functools -import inspect -import itertools -import math -import numpy as np -import re -import tensorflow -import tensorflow as tf - - -promotion_table = { - ("bool", "bool"): "bool", - ("int8", "int8"): "int8", - ("int8", "int16"): "int16", - ("int8", "int32"): "int32", - ("int8", "int64"): "int64", - ("int16", "int16"): "int16", - ("int16", "int32"): "int32", - ("int16", "int64"): "int64", - ("int32", "int32"): "int32", - ("int32", "int64"): "int64", - ("int64", "int64"): "int64", - ("uint8", "int8"): "int16", - ("uint8", "int16"): "int16", - ("uint8", "int32"): "int32", - ("uint8", "int64"): "int64", - ("uint8", "uint8"): "uint8", - ("uint8", "uint16"): "uint16", - ("uint8", "uint32"): "uint32", - ("uint8", "uint64"): "uint64", - ("uint16", "int8"): "int32", - ("uint16", "int16"): "int32", - ("uint16", "int32"): "int32", - ("uint16", "int64"): "int64", - ("uint16", "uint16"): "uint16", - ("uint16", "uint32"): "uint32", - ("uint16", "uint64"): "uint64", - ("uint32", "int8"): "int64", - ("uint32", "int16"): "int64", - ("uint32", "int32"): "int64", - ("uint32", "int64"): "int64", - ("uint32", "uint32"): "uint32", - ("uint32", "uint64"): "uint64", - ("uint64", "uint64"): "uint64", - ("float16", "float16"): "float16", - ("float16", "float32"): "float32", - ("float16", "float64"): "float64", - ("float32", "float32"): "float32", - ("float32", "float64"): "float64", - ("float64", "float64"): "float64", - ("bool", "int8"): "int8", - ("bool", "int16"): "int16", - ("bool", "int32"): "int32", - ("bool", "int64"): "int64", - ("bool", "uint8"): "uint8", - ("bool", "uint16"): "uint16", - ("bool", "uint32"): "uint32", - ("bool", "uint64"): "uint64", - ("bool", "float16"): "float16", - ("bool", "float32"): "float32", - ("bool", "float64"): "float64", - ("bool", "bfloat16"): "bfloat16", - ("bool", "complex64"): "complex64", - ("bool", "complex128"): "complex128", - ("int8", "float16"): "float16", - ("int8", "float32"): "float32", - ("int8", "float64"): "float64", - ("int8", "bfloat16"): "bfloat16", - ("int8", "complex64"): "complex64", - ("int8", "complex128"): "complex128", - ("int16", "float32"): "float32", - ("int16", "float64"): "float64", - ("int16", "complex64"): "complex64", - ("int16", "complex128"): "complex128", - ("int32", "float64"): "float64", - ("int32", "complex128"): "complex128", - ("int64", "float64"): "float64", - ("int64", "complex128"): "complex128", - ("uint8", "float16"): "float16", - ("uint8", "float32"): "float32", - 
("uint8", "float64"): "float64", - ("uint8", "bfloat16"): "bfloat16", - ("uint8", "complex64"): "complex64", - ("uint8", "complex128"): "complex128", - ("uint16", "float32"): "float32", - ("uint16", "float64"): "float64", - ("uint16", "complex64"): "complex64", - ("uint16", "complex128"): "complex128", - ("uint32", "float64"): "float64", - ("uint32", "complex128"): "complex128", - ("uint64", "int8"): "float64", - ("uint64", "int16"): "float64", - ("uint64", "int32"): "float64", - ("uint64", "int64"): "float64", - ("uint64", "float64"): "float64", - ("uint64", "complex128"): "complex128", - ("float16", "bfloat16"): "float32", - ("float16", "complex64"): "complex64", - ("float16", "complex128"): "complex128", - ("float32", "complex64"): "complex64", - ("float32", "complex128"): "complex128", - ("float64", "complex64"): "complex128", - ("float64", "complex128"): "complex128", - ("bfloat16", "float16"): "float32", - ("bfloat16", "float32"): "float32", - ("bfloat16", "float64"): "float64", - ("bfloat16", "bfloat16"): "bfloat16", - ("bfloat16", "complex64"): "complex64", - ("bfloat16", "complex128"): "complex128", - ("complex64", "float64"): "complex128", - ("complex64", "complex64"): "complex64", - ("complex64", "complex128"): "complex128", - ("complex128", "complex128"): "complex128", - ("float16", "int16"): "float32", - ("float16", "int32"): "float64", - ("float16", "int64"): "float64", - ("float16", "uint16"): "float32", - ("float16", "uint32"): "float64", - ("float16", "uint64"): "float64", - ("float32", "int32"): "float64", - ("float32", "int64"): "float64", - ("float32", "uint32"): "float64", - ("float32", "uint64"): "float64", - ("bfloat16", "int16"): "float32", - ("bfloat16", "int32"): "float64", - ("bfloat16", "int64"): "float64", - ("bfloat16", "uint16"): "float32", - ("bfloat16", "uint32"): "float64", - ("bfloat16", "uint64"): "float64", - ("complex64", "int32"): "complex128", - ("complex64", "int64"): "complex128", - ("complex64", "uint32"): "complex128", - ("complex64", "uint64"): "complex128", -} -array_api_promotion_table = { - ("bool", "bool"): "bool", - ("int8", "int8"): "int8", - ("int8", "int16"): "int16", - ("int8", "int32"): "int32", - ("int8", "int64"): "int64", - ("int16", "int16"): "int16", - ("int16", "int32"): "int32", - ("int16", "int64"): "int64", - ("int32", "int32"): "int32", - ("int32", "int64"): "int64", - ("int64", "int64"): "int64", - ("uint8", "int8"): "int16", - ("uint8", "int16"): "int16", - ("uint8", "int32"): "int32", - ("uint8", "int64"): "int64", - ("uint8", "uint8"): "uint8", - ("uint8", "uint16"): "uint16", - ("uint8", "uint32"): "uint32", - ("uint8", "uint64"): "uint64", - ("uint16", "int8"): "int32", - ("uint16", "int16"): "int32", - ("uint16", "int32"): "int32", - ("uint16", "int64"): "int64", - ("uint16", "uint16"): "uint16", - ("uint16", "uint32"): "uint32", - ("uint16", "uint64"): "uint64", - ("uint32", "int8"): "int64", - ("uint32", "int16"): "int64", - ("uint32", "int32"): "int64", - ("uint32", "int64"): "int64", - ("uint32", "uint32"): "uint32", - ("uint32", "uint64"): "uint64", - ("uint64", "uint64"): "uint64", - ("float16", "float16"): "float16", - ("float16", "float32"): "float32", - ("float16", "float64"): "float64", - ("float32", "float32"): "float32", - ("float32", "float64"): "float64", - ("float64", "float64"): "float64", -} -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] -default_complex_dtype_stack = 
[] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -backend_stack = [] - - -def tensorflow_infer_dtype(fn: Callable): - @functools.wraps(fn) - def _infer_dtype(*args, dtype=None, **kwargs): - arr = ( - None - if tensorflow_exists_bknd(dtype) - else tensorflow__get_first_array(*args, **kwargs) - ) - dtype = tensorflow_default_dtype_bknd(dtype=dtype, item=arr, as_native=True) - return fn(*args, dtype=dtype, **kwargs) - - _infer_dtype.infer_dtype = True - return _infer_dtype - - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion - - -def tensorflow_exists_bknd(x: Any, /): - return x is not None - - -def tensorflow_is_native_array(x, /, *, exclusive=False): - if "keras.src.backend.tensorflow.core.Variable" in str(x.__class__): - return not exclusive - if isinstance(x, (tensorflow.Tensor, tensorflow.Variable, tensorflow.TensorArray)): - if exclusive and isinstance(x, tensorflow.Variable): - return False - return True - return False - - -def tensorflow_is_ivy_array_bknd( - x: Union[tensorflow.Tensor, tf.Tensor], /, *, exclusive: Optional[bool] = False -): - return isinstance(x, tensorflow.Tensor) and 
tensorflow_is_native_array( - x, exclusive=exclusive - ) - - -def tensorflow_is_array_bknd(x: Any, /, *, exclusive: bool = False): - return tensorflow_is_ivy_array_bknd( - x, exclusive=exclusive - ) or tensorflow_is_native_array(x, exclusive=exclusive) - - -def tensorflow_default_bknd( - x: Any, - /, - default_val: Any, - *, - catch_exceptions: bool = False, - rev: bool = False, - with_callable: bool = False, -): - with_callable = catch_exceptions or with_callable - if rev: - x, default_val = default_val, x - if with_callable: - x_callable = callable(x) - default_callable = callable(default_val) - else: - x_callable = False - default_callable = False - if catch_exceptions: - try: - x = x() if x_callable else x - except Exception: - return default_val() if default_callable else default_val - else: - x = x() if x_callable else x - return ( - x - if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val - ) - - -def tensorflow_nested_argwhere_bknd( - nest: Iterable, - fn: Callable, - check_nests: bool = False, - to_ignore: Optional[Union[type, Tuple[type]]] = None, - _index: Optional[List] = None, - _base: bool = True, - stop_after_n_found: Optional[int] = None, -): - to_ignore = tensorflow_default_bknd(to_ignore, ()) - _index = [] if _index is None else _index - if isinstance(nest, (tuple, list)) and not isinstance(nest, to_ignore): - n = 0 - _indices = [] - for i, item in enumerate(nest): - ind = ( - tensorflow_nested_argwhere_bknd( - item, - fn, - check_nests, - to_ignore, - _index + [i], - False, - stop_after_n_found - n, - ) - if stop_after_n_found is not None - else tensorflow_nested_argwhere_bknd( - item, fn, check_nests, to_ignore, _index + [i], False, None - ) - ) - if stop_after_n_found is not None and ind: - if n >= stop_after_n_found: - break - n = n + len(ind) - _indices = _indices + [ind] - if stop_after_n_found is not None and n >= stop_after_n_found: - break - _indices = [idx for idxs in _indices if idxs for idx in idxs] - if check_nests and fn(nest): - _indices.append(_index) - elif isinstance(nest, (dict, UserDict)) and not isinstance(nest, to_ignore): - n = 0 - _indices = [] - for k, v in nest.items(): - ind = ( - tensorflow_nested_argwhere_bknd( - v, - fn, - check_nests, - to_ignore, - _index + [k], - False, - stop_after_n_found - n, - ) - if stop_after_n_found is not None - else tensorflow_nested_argwhere_bknd( - v, fn, check_nests, to_ignore, _index + [k], False, None - ) - ) - if stop_after_n_found is not None and ind: - if n >= stop_after_n_found: - break - n = n + len(ind) - _indices = _indices + [ind] - _indices = [idx for idxs in _indices if idxs for idx in idxs] - if check_nests and fn(nest): - _indices.append(_index) - else: - cond_met = fn(nest) - if cond_met: - return [_index] - return False - return [index for index in _indices if index] - - -def tensorflow__check_float64_bknd(input): - if tensorflow_is_array_bknd(input): - return tensorflow_dtype(input) == "float64" - if math.isfinite(input): - m, e = math.frexp(input) - return abs(input) > 3.4028235e38 or e < -126 or e > 128 - return False - - -def tensorflow_as_ivy_dtype_bknd(dtype_in: Union[str, str], /): - return tensorflow_as_ivy_dtype(dtype_in) - - -def tensorflow_is_complex_dtype_bknd( - dtype_in: Union[str, str, tensorflow.Tensor, tf.Tensor, Number], / -): - if tensorflow_is_array_bknd(dtype_in): - dtype_in = tensorflow_dtype(dtype_in) - elif isinstance(dtype_in, tuple): - dtype_in = tensorflow_default_int_dtype_bknd() - elif isinstance(dtype_in, np.ndarray): - return 
"complex" in dtype_in.dtype.name - elif isinstance(dtype_in, Number): - return isinstance(dtype_in, (complex, np.complexfloating)) - elif isinstance(dtype_in, (list, tuple, dict)): - return tensorflow_nested_argwhere_bknd( - dtype_in, - lambda x: isinstance(x, (complex, np.complexfloating)) - or tensorflow_is_array_bknd(x) - and "complex" in tensorflow_dtype(x), - ) - return "complex" in tensorflow_as_ivy_dtype_bknd(dtype_in) - - -def tensorflow_as_native_dev(device: str, /): - if isinstance(device, str) and "/" in device: - return device - ret = f"/{str(device).upper()}" - if not ret[-1].isnumeric(): - ret += ":0" - return ret - - -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - -@tensorflow_handle_methods -def tensorflow_split( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - copy: Optional[bool] = None, - num_or_size_splits: Optional[ - Union[int, Sequence[int], Union[tensorflow.Tensor, tensorflow.Variable]] - ] = None, - axis: int = 0, - with_remainder: bool = False, -): - if x.shape == (): - if num_or_size_splits is not None and num_or_size_splits != 1: - raise Exception( - f"input array had no shape, but num_sections specified was {num_or_size_splits}" - ) - return [x] - if num_or_size_splits is None: - dim_size = tensorflow.shape(x)[axis] - num_or_size_splits = int(dim_size) - if isinstance(num_or_size_splits, (tensorflow.Tensor, tensorflow.Variable)): - num_or_size_splits = tensorflow.cast(num_or_size_splits, tensorflow.int32) - elif isinstance(num_or_size_splits, int) and with_remainder: - num_chunks = x.shape[axis] / num_or_size_splits - num_chunks_int = math.floor(num_chunks) - remainder = num_chunks - num_chunks_int - if remainder != 0: - num_or_size_splits = [num_or_size_splits] * num_chunks_int + [ - int(remainder * num_or_size_splits) - ] - return tensorflow.split(x, num_or_size_splits, axis) - - -@tensorflow_handle_methods -def tensorflow_split_bknd_( - self: tensorflow.Tensor, - /, - *, - copy: Optional[bool] = None, - num_or_size_splits: Optional[ - Union[int, Sequence[int], tensorflow.Tensor, tf.Tensor] - ] = None, - axis: int = 0, - with_remainder: bool = False, -): - return tensorflow_split( - self, - copy=copy, - num_or_size_splits=num_or_size_splits, - axis=axis, - with_remainder=with_remainder, - ) - - -def tensorflow_as_ivy_dev(device: str, /): - if isinstance(device, str) and "/" not in device: - return str(device) - dev_in_split = tensorflow_split_bknd_(device[1:], ":")[-2:] - if len(dev_in_split) == 1: - return str(dev_in_split[0]) - dev_type, dev_idx = dev_in_split[0], dev_in_split[1] - dev_type = dev_type.lower() - if dev_type == "cpu": - return str(dev_type) - return str(f"{dev_type}:{dev_idx}") - - -def tensorflow_stack( - arrays: Union[Tuple[tensorflow.Tensor], List[tensorflow.Tensor]], - /, - *, - axis: int = 0, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - try: - return tensorflow.experimental.numpy.stack(arrays, axis) - except ValueError as e: - raise Exception(e) from e - - -def tensorflow_stack_bknd_( - self: tensorflow.Tensor, - /, - arrays: Union[ - 
Tuple[Union[tensorflow.Tensor, tf.Tensor]], - List[Union[tensorflow.Tensor, tf.Tensor]], - ], - *, - axis: int = 0, - out: Optional[tensorflow.Tensor] = None, -): - if not isinstance(arrays, (tuple, list)): - arrays = [arrays] - if isinstance(arrays, tuple): - x = (self,) + arrays - else: - x = [self] + arrays - return tensorflow_stack(x, axis=axis, out=out) - - -def tensorflow_dev( - x: Union[tensorflow.Tensor, tensorflow.Variable, tensorflow.TensorArray], - /, - *, - as_native: bool = False, -): - if "keras.src.backend.tensorflow.core.Variable" in str(x.__class__): - x = x.value - if isinstance(x, tensorflow.TensorArray): - x = tensorflow_stack_bknd_(x) - dv = x.device - if as_native: - return dv - dv = dv if dv else tensorflow_default_device_bknd(as_native=False) - return tensorflow_as_ivy_dev(dv) - - -def tensorflow_default_device_bknd( - device: Optional[Union[str, str]] = None, - /, - *, - item: Optional[Union[list, tuple, dict, tensorflow.Tensor, tf.Tensor]] = None, - as_native: Optional[bool] = None, -): - if tensorflow_exists_bknd(device): - if as_native is True: - return tensorflow_as_native_dev(device) - elif as_native is False: - return tensorflow_as_ivy_dev(device) - return device - as_native = tensorflow_default_bknd(as_native, False) - if tensorflow_exists_bknd(item): - if isinstance(item, (list, tuple, dict)) and len(item) == 0: - pass - elif tensorflow_is_array_bknd(item): - return tensorflow_dev(item, as_native=as_native) - global default_device_stack - if not default_device_stack: - ret = "cpu" - else: - ret = default_device_stack[-1] - if as_native: - return tensorflow_as_native_dev(ret) - return tensorflow_as_ivy_dev(ret) - - -def tensorflow__get_preferred_device(args, kwargs): - device = None - if "device" in kwargs and kwargs["device"] is not None: - return device - if not False: - arr_arg = tensorflow__get_first_array(*args, **kwargs) - return tensorflow_default_device_bknd(item=arr_arg, as_native=True) - return tensorflow_default_device_bknd(as_native=True) - - -def tensorflow__check_in_nested_sequence(sequence, value=None, _type=None): - if sequence is value or isinstance(sequence, _type): - return True - elif isinstance(sequence, (tuple, list)): - if any(isinstance(_val, _type) or _val is value for _val in sequence): - return True - else: - return any( - tensorflow__check_in_nested_sequence(sub_sequence, value, _type) - for sub_sequence in sequence - if isinstance(sub_sequence, (tuple, list)) - ) - - -def tensorflow_is_variable(x, /, *, exclusive=False): - return isinstance(x, tensorflow.Variable) - - -def tensorflow_variable(x, /): - with tensorflow.device(tensorflow_dev(x, as_native=True)): - return tensorflow.Variable(x, trainable=True) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_stop_gradient( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - preserve_type: bool = True, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - is_var = tensorflow_is_variable(x) - x = tensorflow.stop_gradient(x) - if is_var and preserve_type: - return tensorflow_variable(x) - return x - - -def tensorflow_nested_map_bknd( - fn: Callable, - x: Union[tensorflow.Tensor, tf.Tensor, Iterable], - /, - include_derived: Optional[Union[Dict[str, bool], bool]] = None, - to_ignore: Optional[Union[type, Tuple[type]]] = None, - to_mutable: bool = False, - _tuple_check_fn: Optional[Callable] = None, - _list_check_fn: Optional[Callable] = None, - _dict_check_fn: Optional[Callable] = None, - shallow: bool = True, -): - to_ignore = 
tensorflow_default_bknd(to_ignore, ()) - if include_derived is True: - include_derived = {"tuple": True, "list": True, "dict": True} - elif not include_derived: - include_derived = {} - for t in ("tuple", "list", "dict"): - if t not in include_derived: - include_derived = tensorflow_set_item_bknd(include_derived, t, False) - class_instance = type(x) - if ( - hasattr(x, "is_tracked_proxy") - and hasattr(class_instance, "__bases__") - and not set(class_instance.__bases__).intersection(set(to_ignore)) - ): - to_ignore = to_ignore + (class_instance,) - tuple_check_fn = tensorflow_default_bknd( - _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), - ) - list_check_fn = tensorflow_default_bknd( - _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), - ) - dict_check_fn = tensorflow_default_bknd( - _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), - ) - if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): - ret_list = [ - tensorflow_nested_map_bknd( - fn, - i, - include_derived, - to_ignore, - to_mutable, - tuple_check_fn, - list_check_fn, - dict_check_fn, - shallow, - ) - for i in x - ] - if to_mutable: - return ret_list - elif hasattr(x, "_fields"): - return class_instance(**dict(zip(x._fields, ret_list))) - else: - return class_instance(ret_list) - elif list_check_fn(x, list) and not isinstance(x, to_ignore): - ret_list = [ - tensorflow_nested_map_bknd( - fn, - i, - include_derived, - to_ignore, - to_mutable, - tuple_check_fn, - list_check_fn, - dict_check_fn, - shallow, - ) - for i in x - ] - if shallow: - x = tensorflow_set_item_bknd(x, slice(None, None, None), ret_list[:]) - return x - return class_instance(ret_list) - elif (dict_check_fn(x, dict) or isinstance(x, UserDict)) and not isinstance( - x, to_ignore - ): - class_instance = type(x) - ret = { - k: tensorflow_nested_map_bknd( - fn, - v, - include_derived, - to_ignore, - to_mutable, - tuple_check_fn, - list_check_fn, - dict_check_fn, - shallow, - ) - for k, v in x.items() - } - if shallow: - x.update(ret) - return x - return class_instance(ret) - elif isinstance(x, slice): - return slice(*tensorflow_nested_map_bknd(fn, [x.start, x.stop, x.step])) - return fn(x) - - -def tensorflow__to_ivy_bknd_(x: Any): - if isinstance(x, tensorflow.Tensor): - return x - elif isinstance(x, tf.TensorShape): - return tuple(x) - elif isinstance(x, dict): - return x.to_ivy() - if tensorflow_is_native_array(x) or isinstance(x, np.ndarray): - return tensorflow.convert_to_tensor(x) - return x - - -def tensorflow_to_ivy_bknd_( - x: Union[tensorflow.Tensor, tf.Tensor, Iterable], - nested: bool = False, - include_derived: Optional[Dict[str, bool]] = None, -): - if nested: - return tensorflow_nested_map_bknd( - tensorflow__to_ivy_bknd_, x, include_derived, shallow=False - ) - return tensorflow__to_ivy_bknd_(x) - - -def tensorflow__asarray_to_native_arrays_and_back_bknd(fn: Callable): - @functools.wraps(fn) - def _asarray_to_native_arrays_and_back_wrapper(*args, dtype=None, **kwargs): - new_arg = args[0] - new_args = (new_arg,) + args[1:] - if dtype is not None: - dtype = tensorflow_default_dtype_bknd(dtype=dtype, as_native=True) - return tensorflow_to_ivy_bknd_(fn(*new_args, dtype=dtype, **kwargs)) - - _asarray_to_native_arrays_and_back_wrapper._asarray_to_native_arrays_and_back = True - return 
_asarray_to_native_arrays_and_back_wrapper - - -def tensorflow__flatten_nest_bknd(xs): - for x in xs: - if isinstance(x, Iterable) and not isinstance(x, (str, bytes)): - yield from tensorflow__flatten_nest_bknd(x) - else: - yield x - - -def tensorflow_promote_types_bknd( - type1: Union[str, tf.DType], - type2: Union[str, tf.DType], - /, - *, - array_api_promotion: bool = False, -): - if not (type1 and type2): - return type1 if type1 else type2 - query = [tensorflow_as_ivy_dtype(type1), tensorflow_as_ivy_dtype(type2)] - query = tuple(query) - if query not in promotion_table: - query = query[1], query[0] - - def _promote(query): - if array_api_promotion: - return tensorflow_get_item(array_api_promotion_table, query) - return tensorflow_get_item(promotion_table, query) - - return _promote(query) - - -def tensorflow__asarray_infer_dtype_bknd(fn: Callable): - @functools.wraps(fn) - def _asarray_infer_dtype_wrapper(*args, dtype=None, **kwargs): - def _infer_dtype(obj): - if isinstance(obj, tf.TensorShape): - obj = list(obj) - if hasattr(obj, "dtype"): - return obj.dtype.name if isinstance(obj, np.ndarray) else obj.dtype - else: - return tensorflow_default_dtype_bknd(item=obj) - - if not tensorflow_exists_bknd(dtype): - arr = args[0] - dtype_list = [ - tensorflow_nested_map_bknd( - lambda x: _infer_dtype(x), arr, shallow=False - ) - ] - dtype_list = tensorflow__flatten_nest_bknd(dtype_list) - dtype_list = list(set(dtype_list)) - if len(dtype_list) != 0: - dtype = dtype_list[0] - for dt in dtype_list[1:]: - dtype = tensorflow_promote_types_bknd(dtype, dt) - else: - dtype = tensorflow_default_float_dtype_bknd() - dtype = tensorflow_as_native_dtype(dtype) - return fn(*args, dtype=dtype, **kwargs) - - _asarray_infer_dtype_wrapper.infer_dtype = True - return _asarray_infer_dtype_wrapper - - -@tensorflow_handle_array_like_without_promotion -@tensorflow__asarray_to_native_arrays_and_back_bknd -@tensorflow__asarray_infer_dtype_bknd -def tensorflow_asarray( - obj: Union[ - tensorflow.Tensor, - tensorflow.Variable, - tensorflow.TensorShape, - bool, - int, - float, - tensorflow_NestedSequence_bknd, - SupportsBufferProtocol, - np.ndarray, - ], - /, - *, - copy: Optional[bool] = None, - dtype: Optional[tensorflow.DType] = None, - device: Optional[str] = None, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - with tensorflow.device(device): - if tensorflow.is_tensor(obj): - ret = tensorflow.cast(obj, dtype) if obj.dtype != dtype else obj - elif ( - dtype is not None - and dtype.is_integer - and np.issubdtype(np.array(obj).dtype, np.floating) - ): - obj_np = np.array(obj) - ret = tensorflow.convert_to_tensor(obj_np, dtype) - else: - ret = tensorflow.convert_to_tensor(obj, dtype) - return ( - tensorflow.identity(ret) - if copy or tensorflow_as_native_dev(tensorflow_dev(ret)) != device - else ret - ) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_size(x: tensorflow.Tensor, /): - return functools.reduce(mul, x.shape) if len(x.shape) > 0 else 1 - - -def tensorflow_size_bknd_(self): - return tensorflow_size(self) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_unstack( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - copy: Optional[bool] = None, - axis: int = 0, - keepdims: bool = False, -): - if x.shape == (): - return [x] - ret = tensorflow.unstack(x, axis=axis) - if keepdims: - return [tensorflow.expand_dims(r, axis) for r in ret] - return ret - - -def tensorflow_unstack_bknd_( - self: tensorflow.Tensor, - /, - *, - copy: 
Optional[bool] = None, - axis: int = 0, - keepdims: bool = False, -): - return tensorflow_unstack(self, copy=copy, axis=axis, keepdims=keepdims) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_copy_array( - x: Union[tensorflow.Tensor, tensorflow.Variable, tensorflow.TensorArray], - *, - to_ivy_array: bool = True, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if isinstance(x, tensorflow.TensorArray): - x_wrapped = tensorflow_stack_bknd_(x) - y = tensorflow.TensorArray(x.dtype, tensorflow_size_bknd_(x)()) - x = tensorflow_unstack_bknd_(y, tensorflow_copy_array(x_wrapped)) - else: - x = tensorflow.identity(x) - if to_ivy_array: - return tensorflow_to_ivy_bknd_(x) - return x - - -def tensorflow_tile( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - repeats: Sequence[int], - *, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if x.shape == (): - x = tensorflow.reshape(x, (-1,)) - if isinstance(repeats, Number): - repeats = [repeats] - if isinstance(repeats, tensorflow.Tensor) and repeats.shape == (): - repeats = tensorflow.reshape(repeats, (-1,)) - if len(x.shape) < len(repeats): - while len(x.shape) != len(repeats): - x = tensorflow.expand_dims(x, 0) - elif len(x.shape) > len(repeats): - repeats = list(repeats) - while len(x.shape) != len(repeats): - repeats = [1] + repeats - return tensorflow.tile(x, repeats) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_nonzero( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - as_tuple: bool = True, - size: Optional[int] = None, - fill_value: Number = 0, -): - res = tensorflow.experimental.numpy.nonzero(x) - if size is not None: - dtype = tensorflow.int64 - if isinstance(fill_value, float): - dtype = tensorflow.float64 - res = tensorflow.cast(res, dtype) - diff = size - res[0].shape[0] - if diff > 0: - res = tensorflow.pad(res, [[0, 0], [0, diff]], constant_values=fill_value) - elif diff < 0: - res = tensorflow.slice(res, [0, 0], [-1, size]) - if as_tuple: - return tuple(res) - return tensorflow.stack(res, axis=1) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_diff( - x: Union[tensorflow.Tensor, tensorflow.Variable, list, tuple], - /, - *, - n: int = 1, - axis: int = -1, - prepend: Optional[ - Union[tensorflow.Tensor, tensorflow.Variable, int, float, list, tuple] - ] = None, - append: Optional[ - Union[tensorflow.Tensor, tensorflow.Variable, int, float, list, tuple] - ] = None, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if n == 0: - return x - if prepend is not None: - x = tensorflow.experimental.numpy.append( - prepend, x, axis=axis if axis != -1 else None - ) - if append is not None: - x = tensorflow.experimental.numpy.append( - x, append, axis=axis if axis != -1 else None - ) - return tensorflow.experimental.numpy.diff(x, n=n, axis=axis) - - -def tensorflow__parse_ellipsis_bknd(so, ndims): - pre = list() - for s in so: - if s is Ellipsis: - break - pre.append(s) - post = list() - for s in reversed(so): - if s is Ellipsis: - break - post.append(s) - ret = list( - pre - + [slice(None, None, None) for _ in range(ndims - len(pre) - len(post))] - + list(reversed(post)) - ) - return ret, (len(pre), ndims - len(post)) - - -def tensorflow_broadcast_arrays(*arrays: Union[tensorflow.Tensor, tensorflow.Variable]): - if len(arrays) > 1: - try: - desired_shape = tensorflow.broadcast_dynamic_shape( - tensorflow.shape(arrays[0]), tensorflow.shape(arrays[1]) - ) - except 
tensorflow.errors.InvalidArgumentError as e: - raise Exception(e) from e - if len(arrays) > 2: - for i in range(2, len(arrays)): - try: - desired_shape = tensorflow.broadcast_dynamic_shape( - desired_shape, tensorflow.shape(arrays[i]) - ) - except tensorflow.errors.InvalidArgumentError as e: - raise Exception(e) from e - else: - return [arrays[0]] - result = [] - for tensor in arrays: - result.append(tensorflow.broadcast_to(tensor, desired_shape)) - return result - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_astype( - x: Union[tensorflow.Tensor, tensorflow.Variable], - dtype: Union[tf.DType, str], - /, - *, - copy: bool = True, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - dtype = tensorflow_as_native_dtype(dtype) - if x.dtype == dtype: - return tensorflow.experimental.numpy.copy(x) if copy else x - return tensorflow.cast(x, dtype) - - -def tensorflow_astype_bknd_( - self: tensorflow.Tensor, - dtype: str, - /, - *, - copy: bool = True, - out: Optional[tensorflow.Tensor] = None, -): - return tensorflow_astype(self, dtype, copy=copy, out=out) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_where( - condition: Union[tensorflow.Tensor, tensorflow.Variable], - x1: Union[tensorflow.Tensor, tensorflow.Variable], - x2: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - oirg_x1 = x1 - oirg_x2 = x2 - try: - dtype = ( - x1.dtype - if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() - ) - if not tensorflow_is_array_bknd(x1): - x1 = tensorflow_asarray(x1, dtype=dtype) - if not tensorflow_is_array_bknd(x2): - x2 = tensorflow_asarray(x2, dtype=dtype) - except: - x1 = oirg_x1 - x2 = oirg_x2 - return tensorflow.cast( - tensorflow.experimental.numpy.where(condition, x1, x2), x1.dtype - ) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_arange( - start: float, - /, - stop: Optional[float] = None, - step: float = 1, - *, - dtype: Optional[tensorflow.DType] = None, - device: Optional[str] = None, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if stop is None: - stop = start - start = 0 - if step > 0 and start > stop or step < 0 and start < stop: - if isinstance(stop, float): - stop = float(start) - else: - stop = start - if isinstance(start, (float, int)): - start = tensorflow.convert_to_tensor(start) - if isinstance(stop, (float, int)): - stop = tensorflow.convert_to_tensor(stop) - if isinstance(step, (float, int)): - step = tensorflow.convert_to_tensor(step) - if dtype is None: - if isinstance(start, int) and isinstance(stop, int) and isinstance(step, int): - return tensorflow.cast( - tensorflow.range(start, stop, delta=step, dtype=tensorflow.int64), - tensorflow.int32, - ) - else: - return tensorflow.range(start, stop, delta=step) - else: - dtype = tensorflow_as_native_dtype(tensorflow_default_dtype_bknd(dtype=dtype)) - if dtype in [ - tensorflow.int8, - tensorflow.uint8, - tensorflow.int16, - tensorflow.uint16, - tensorflow.uint32, - tensorflow.uint64, - ]: - return tensorflow.cast( - tensorflow.range(start, stop, delta=step, dtype=tensorflow.int64), dtype - ) - else: - return tensorflow.range(start, stop, delta=step, dtype=dtype) - - -def tensorflow__parse_slice_bknd(idx, s): - step = 1 if idx.step is None else idx.step - if step > 0: - start = 0 if idx.start is None else idx.start - if start >= s: - stop = start - else: - if start <= -s: - start = 0 - 
elif start < 0: - start = start + s - stop = s if idx.stop is None else idx.stop - if stop > s: - stop = s - elif start <= -s: - stop = 0 - elif stop < 0: - stop = stop + s - else: - start = s - 1 if idx.start is None else idx.start - if start < -s: - stop = start - else: - if start >= s: - start = s - 1 - elif start < 0: - start = start + s - if idx.stop is None: - stop = -1 - else: - stop = idx.stop - if stop > s: - stop = s - elif stop < -s: - stop = -1 - elif stop == -s: - stop = 0 - elif stop < 0: - stop = stop + s - q_i = tensorflow_arange(start, stop, step) - ag__result_list_0 = [] - for q in q_i: - if 0 <= q < s: - res = q - ag__result_list_0.append(res) - q_i = ag__result_list_0 - q_i = ( - tensorflow_asarray(q_i) - if len(q_i) or start == stop or idx.stop is not None - else tensorflow_arange(0, s, 1) - ) - return q_i - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_shape( - x: Union[tensorflow.Tensor, tensorflow.Variable], /, *, as_array: bool = False -): - if as_array: - return tensorflow_asarray( - tensorflow.shape(x), dtype=tensorflow_default_int_dtype_bknd() - ) - else: - return tuple(x.shape) - - -def tensorflow__deep_flatten_bknd(iterable): - def _flatten_gen(iterable): - for item in iterable: - if isinstance(item, list): - yield from _flatten_gen(item) - else: - yield item - - return list(_flatten_gen(iterable)) - - -def tensorflow__calculate_out_shape_bknd(axis, array_shape): - if type(axis) not in (tuple, list): - axis = (axis,) - out_dims = len(axis) + len(array_shape) - norm_axis = normalize_axis_tuple(axis, out_dims) - shape_iter = iter(array_shape) - ag__result_list_0 = [] - for current_ax in range(out_dims): - res = 1 if current_ax in norm_axis else next(shape_iter) - ag__result_list_0.append(res) - out_shape = ag__result_list_0 - return out_shape - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_expand_dims( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - copy: Optional[bool] = None, - axis: Union[int, Sequence[int]] = 0, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - try: - out_shape = tensorflow__calculate_out_shape_bknd(axis, tensorflow.shape(x)) - ret = tensorflow.reshape(x, shape=out_shape) - return ret - except (tensorflow.errors.InvalidArgumentError, np.AxisError) as error: - raise Exception(error) from error - - -def tensorflow_check_elem_in_list(elem, list, inverse=False, message=""): - if inverse and elem in list: - raise Exception( - message if message != "" else f"{elem} must not be one of {list}" - ) - elif not inverse and elem not in list: - raise Exception(message if message != "" else f"{elem} must be one of {list}") - - -def tensorflow__reshape_fortran_tf(x, shape): - if len(x.shape) > 0: - x = tensorflow.transpose(x) - return tensorflow.transpose(tensorflow.reshape(x, shape[::-1])) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_reshape( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - shape: Union[tf.TensorShape, Sequence[int]], - *, - copy: Optional[bool] = None, - order: str = "C", - allowzero: bool = True, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - tensorflow_check_elem_in_list(order, ["C", "F"]) - if not allowzero: - shape = [ - (new_s if con else old_s) - for new_s, con, old_s in zip( - shape, tensorflow.constant(shape) != 0, x.shape - ) - ] - if order == "F": - return tensorflow__reshape_fortran_tf(x, shape) - return tensorflow.reshape(x, shape) - - -def tensorflow_reshape_bknd_( - self: 
tensorflow.Tensor, - /, - shape: Union[tuple, tf.TensorShape, Sequence[int]], - *, - copy: Optional[bool] = None, - order: str = "C", - allowzero: bool = True, - out: Optional[tensorflow.Tensor] = None, -): - return tensorflow_reshape( - self, shape, copy=copy, allowzero=allowzero, out=out, order=order - ) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_meshgrid( - *arrays: Union[tensorflow.Tensor, tensorflow.Variable], - sparse: bool = False, - indexing: str = "xy", - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if not sparse: - return tensorflow.meshgrid(*arrays, indexing=indexing) - sd = (1,) * len(arrays) - ag__result_list_0 = [] - for i, a in enumerate(arrays): - res = tensorflow.reshape( - tensorflow.convert_to_tensor(a), sd[:i] + (-1,) + sd[i + 1 :] - ) - ag__result_list_0.append(res) - res = ag__result_list_0 - if indexing == "xy" and len(arrays) > 1: - res[0] = tensorflow.reshape(res[0], (1, -1) + sd[2:]) - res[1] = tensorflow.reshape(res[1], (-1, 1) + sd[2:]) - return res - - -@tensorflow_infer_dtype -@tensorflow_handle_array_like_without_promotion -def tensorflow_empty( - shape: Union[tf.TensorShape, Sequence[int]], - *, - dtype: tensorflow.DType, - device: Optional[str] = None, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - return tensorflow.experimental.numpy.empty(shape, dtype=tensorflow.float32) - - -def tensorflow__parse_query_bknd(query, x_shape, scatter=False): - query = (query,) if not isinstance(query, tuple) else query - ag__result_list_0 = [] - for q in query: - res = tensorflow_asarray(q) if isinstance(q, (tuple, list, int)) else q - ag__result_list_0.append(res) - query = ag__result_list_0 - ag__result_list_1 = [] - for i, q in enumerate(query): - if tensorflow_is_array_bknd(q): - res = i - ag__result_list_1.append(res) - non_slice_q_idxs = ag__result_list_1 - to_front = ( - len(non_slice_q_idxs) > 1 - and any(tensorflow_diff(non_slice_q_idxs) != 1) - and non_slice_q_idxs[-1] < len(x_shape) - ) - ag__result_list_2 = [] - for i, q in enumerate(query): - if q is None: - res = i - ag__result_list_2.append(res) - new_axes = ag__result_list_2 - ag__result_list_3 = [] - for q in query: - if q is not None: - res = q - ag__result_list_3.append(res) - query = ag__result_list_3 - query = [Ellipsis] if query == [] else query - ellipsis_inds = None - if any(q is Ellipsis for q in query): - query, ellipsis_inds = tensorflow__parse_ellipsis_bknd(query, len(x_shape)) - ag__result_list_4 = [] - for i, v in enumerate(query): - if tensorflow_is_array_bknd(v): - res = i - ag__result_list_4.append(res) - array_inds = ag__result_list_4 - if array_inds: - array_queries = tensorflow_broadcast_arrays( - *[v for i, v in enumerate(query) if i in array_inds] - ) - array_queries = [ - ( - tensorflow_nonzero(q, as_tuple=False)[0] - if tensorflow_is_bool_dtype_bknd(q) - else q - ) - for q in array_queries - ] - array_queries = [ - ( - tensorflow_astype_bknd_( - tensorflow_where( - arr < 0, arr + tensorflow_get_item(x_shape, i), arr - ), - tf.int64, - ) - if tensorflow_size_bknd_(arr) - else tensorflow_astype_bknd_(arr, tf.int64) - ) - for arr, i in zip(array_queries, array_inds) - ] - for idx, arr in zip(array_inds, array_queries): - query = tensorflow_set_item_bknd(query, idx, arr) - ag__result_list_5 = [] - for i, q in enumerate(query): - res = ( - tensorflow_astype_bknd_( - tensorflow__parse_slice_bknd(q, tensorflow_get_item(x_shape, i)), - tf.int64, - ) - if isinstance(q, slice) - else q - ) - 
ag__result_list_5.append(res) - query = ag__result_list_5 - if len(query) < len(x_shape): - query = query + [ - tensorflow_astype_bknd_(tensorflow_arange(0, s, 1), tf.int64) - for s in tensorflow_get_item(x_shape, slice(len(query), None, None)) - ] - if len(array_inds) and to_front: - target_shape = ( - [list(array_queries[0].shape)] - + [ - list(tensorflow_get_item(query, i).shape) - for i in range(len(query)) - if i not in array_inds - ] - + [[] for _ in range(len(array_inds) - 1)] - ) - elif len(array_inds): - target_shape = ( - [list(tensorflow_get_item(query, i).shape) for i in range(0, array_inds[0])] - + [list(tensorflow_shape(array_queries[0], as_array=True))] - + [[] for _ in range(len(array_inds) - 1)] - + [ - list(tensorflow_shape(tensorflow_get_item(query, i), as_array=True)) - for i in range(array_inds[-1] + 1, len(query)) - ] - ) - else: - target_shape = [list(q.shape) for q in query] - if ellipsis_inds is not None: - target_shape = ( - tensorflow_get_item(target_shape, slice(None, ellipsis_inds[0], None)) - + [ - tensorflow_get_item( - target_shape, slice(ellipsis_inds[0], ellipsis_inds[1], None) - ) - ] - + tensorflow_get_item(target_shape, slice(ellipsis_inds[1], None, None)) - ) - for i, ax in enumerate(new_axes): - if len(array_inds) and to_front: - ax = ax - (sum(1 for x in array_inds if x < ax) - 1) - ax = ax + i - target_shape = [ - *tensorflow_get_item(target_shape, slice(None, ax, None)), - 1, - *tensorflow_get_item(target_shape, slice(ax, None, None)), - ] - target_shape = tensorflow__deep_flatten_bknd(target_shape) - ag__result_list_6 = [] - for q in query: - res = tensorflow_expand_dims(q) if not len(q.shape) else q - ag__result_list_6.append(res) - query = ag__result_list_6 - if len(array_inds): - array_queries = [ - ( - tensorflow_reshape_bknd_(arr, (-1,)) - if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr - ) - for arr in array_queries - ] - array_queries = tensorflow_stack(array_queries, axis=1) - if len(array_inds) == len(query): - indices = tensorflow_reshape_bknd_(array_queries, (*target_shape, len(x_shape))) - elif len(array_inds) == 0: - indices = tensorflow_reshape_bknd_( - tensorflow_stack(tensorflow_meshgrid(*query, indexing="ij"), axis=-1), - (*target_shape, len(x_shape)), - ) - elif to_front: - post_array_queries = ( - tensorflow_reshape_bknd_( - tensorflow_stack( - tensorflow_meshgrid( - *[v for i, v in enumerate(query) if i not in array_inds], - indexing="ij", - ), - axis=-1, - ), - (-1, len(query) - len(array_inds)), - ) - if len(array_inds) < len(query) - else tensorflow_empty((1, 0)) - ) - indices = tensorflow_reshape_bknd_( - tensorflow_asarray( - [ - (*arr, *post) - for arr, post in itertools.product( - array_queries, post_array_queries - ) - ] - ), - (*target_shape, len(x_shape)), - ) - else: - pre_array_queries = ( - tensorflow_reshape_bknd_( - tensorflow_stack( - tensorflow_meshgrid( - *[v for i, v in enumerate(query) if i < array_inds[0]], - indexing="ij", - ), - axis=-1, - ), - (-1, array_inds[0]), - ) - if array_inds[0] > 0 - else tensorflow_empty((1, 0)) - ) - post_array_queries = ( - tensorflow_reshape_bknd_( - tensorflow_stack( - tensorflow_meshgrid( - *[v for i, v in enumerate(query) if i > array_inds[-1]], - indexing="ij", - ), - axis=-1, - ), - (-1, len(query) - 1 - array_inds[-1]), - ) - if array_inds[-1] < len(query) - 1 - else tensorflow_empty((1, 0)) - ) - indices = tensorflow_reshape_bknd_( - tensorflow_asarray( - [ - (*pre, *arr, *post) - for pre, arr, post in itertools.product( - 
pre_array_queries, array_queries, post_array_queries - ) - ] - ), - (*target_shape, len(x_shape)), - ) - return ( - tensorflow_astype_bknd_(indices, tf.int64), - target_shape, - array_inds if len(array_inds) and to_front else None, - ) - - -def tensorflow_get_num_dims(x, /, *, as_array=False): - return ( - tensorflow.cast(tensorflow.shape(tensorflow.shape(x))[0], tensorflow.int64) - if as_array - else int(tensorflow.shape(tensorflow.shape(x))) - ) - - -def tensorflow_to_numpy( - x: Union[tensorflow.Tensor, tensorflow.Variable], /, *, copy: bool = True -): - if ( - tensorflow_is_array_bknd(x) - and tensorflow_get_num_dims(x) == 0 - and tensorflow_as_native_dtype(x.dtype) is tensorflow.bfloat16 - ): - x = tensorflow.expand_dims(x, 0) - if copy: - return np.squeeze(np.array(tensorflow.convert_to_tensor(x)), 0) - else: - return np.squeeze(np.asarray(tensorflow.convert_to_tensor(x)), 0) - if copy: - return np.array(tensorflow.convert_to_tensor(x)) - else: - return np.asarray(tensorflow.convert_to_tensor(x)) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_to_scalar(x: Union[tensorflow.Tensor, tensorflow.Variable], /): - ret = tensorflow_to_numpy(x).item() - if x.dtype == tensorflow.bfloat16: - return float(ret) - return ret - - -def tensorflow_to_scalar_bknd_(self: tensorflow.Tensor): - return tensorflow_to_scalar(self) - - -def tensorflow_is_float_dtype_bknd( - dtype_in: Union[str, str, tensorflow.Tensor, tf.Tensor, Number], / -): - if tensorflow_is_array_bknd(dtype_in): - dtype_in = tensorflow_dtype(dtype_in) - elif isinstance(dtype_in, tuple): - dtype_in = tensorflow_default_int_dtype_bknd() - elif isinstance(dtype_in, np.ndarray): - return "float" in dtype_in.dtype.name - elif isinstance(dtype_in, Number): - return isinstance(dtype_in, (float, np.floating)) - elif isinstance(dtype_in, (list, tuple, dict)): - return bool( - tensorflow_nested_argwhere_bknd( - dtype_in, - lambda x: isinstance(x, (float, np.floating)) - or tensorflow_is_array_bknd(x) - and "float" in tensorflow_dtype(x), - ) - ) - return "float" in tensorflow_as_ivy_dtype_bknd(dtype_in) - - -def tensorflow_is_uint_dtype_bknd( - dtype_in: Union[str, str, tensorflow.Tensor, tf.Tensor, Number], / -): - if tensorflow_is_array_bknd(dtype_in): - dtype_in = tensorflow_dtype(dtype_in) - elif isinstance(dtype_in, tuple): - dtype_in = tensorflow_default_int_dtype_bknd() - elif isinstance(dtype_in, np.ndarray): - return "uint" in dtype_in.dtype.name - elif isinstance(dtype_in, Number): - return isinstance(dtype_in, np.unsignedinteger) - elif isinstance(dtype_in, (list, tuple, dict)): - return tensorflow_nested_argwhere_bknd( - dtype_in, - lambda x: isinstance(x, np.unsignedinteger) - or tensorflow_is_array_bknd(x) - and "uint" in tensorflow_dtype(x), - ) - return "uint" in tensorflow_as_ivy_dtype_bknd(dtype_in) - - -def tensorflow_default_uint_dtype_bknd( - *, - input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, - uint_dtype: Optional[Union[str, tf.DType]] = None, - as_native: bool = False, -): - global default_uint_dtype_stack - if tensorflow_exists_bknd(uint_dtype): - if as_native is True: - return tensorflow_as_native_dtype(uint_dtype) - return str(tensorflow_as_ivy_dtype(uint_dtype)) - as_native = tensorflow_default_bknd(as_native, False) - if tensorflow_exists_bknd(input): - if tensorflow_is_array_bknd(input): - ret = tensorflow_dtype(input) - elif isinstance(input, np.ndarray): - ret = input.dtype - elif isinstance(input, (list, tuple, dict)): - - def is_native(x): - return tensorflow_is_native_array(x) 
- - if tensorflow_nested_argwhere_bknd( - input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), - stop_after_n_found=1, - ): - ret = tf.uint64 - elif default_uint_dtype_stack: - ret = default_uint_dtype_stack[-1] - else: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_uint_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "uint32" - elif isinstance(input, Number): - if input > 4294967295 and input != math.inf and backend != "torch": - ret = tf.uint64 - elif default_uint_dtype_stack: - ret = default_uint_dtype_stack[-1] - else: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_uint_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "uint32" - elif default_uint_dtype_stack: - ret = default_uint_dtype_stack[-1] - else: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_uint_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "uint32" - if as_native: - return tensorflow_as_native_dtype(ret) - return str(tensorflow_as_ivy_dtype(ret)) - - -def tensorflow_is_int_dtype_bknd( - dtype_in: Union[str, str, tensorflow.Tensor, tf.Tensor, Number], / -): - if tensorflow_is_array_bknd(dtype_in): - dtype_in = tensorflow_dtype(dtype_in) - elif isinstance(dtype_in, tuple): - dtype_in = tensorflow_default_int_dtype_bknd() - elif isinstance(dtype_in, np.ndarray): - return "int" in dtype_in.dtype.name - elif isinstance(dtype_in, Number): - return isinstance(dtype_in, (int, np.integer)) and not isinstance( - dtype_in, bool - ) - elif isinstance(dtype_in, (list, tuple, dict)): - - def nested_fun(x): - return ( - isinstance(x, (int, np.integer)) - or tensorflow_is_array_bknd(x) - and "int" in tensorflow_dtype(x) - ) and x is not bool - - return bool(tensorflow_nested_argwhere_bknd(dtype_in, nested_fun)) - return "int" in tensorflow_as_ivy_dtype(dtype_in) - - -def tensorflow_infer_default_dtype_bknd( - dtype: Union[str, tf.DType, str], as_native: bool = False -): - if tensorflow_is_complex_dtype_bknd(dtype): - default_dtype = tensorflow_default_complex_dtype_bknd(as_native=as_native) - elif tensorflow_is_float_dtype_bknd(dtype): - default_dtype = tensorflow_default_float_dtype_bknd(as_native=as_native) - elif tensorflow_is_uint_dtype_bknd(dtype): - default_dtype = tensorflow_default_uint_dtype_bknd(as_native=as_native) - elif tensorflow_is_int_dtype_bknd(dtype): - default_dtype = tensorflow_default_int_dtype_bknd(as_native=as_native) - elif as_native: - default_dtype = tensorflow_as_native_dtype("bool") - else: - default_dtype = tensorflow_as_ivy_dtype("bool") - return default_dtype - - -def tensorflow_dtype_bits(dtype_in: Union[tensorflow.DType, str, np.dtype], /): - dtype_str = tensorflow_as_ivy_dtype(dtype_in) - if "bool" in dtype_str: - return 1 - return int( - dtype_str.replace("tf.", "") - .replace("uint", "") - .replace("int", "") - .replace("bfloat", "") - .replace("float", "") - .replace("complex", "") - ) - - -def tensorflow__infer_dtype(dtype: tensorflow.DType): - default_dtype = tensorflow_infer_default_dtype_bknd(dtype) - if tensorflow_dtype_bits(dtype) < tensorflow_dtype_bits(default_dtype): - return default_dtype - return dtype - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_prod( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - axis: Optional[Union[int, Sequence[int]]] = None, - dtype: Optional[tensorflow.DType] = None, - keepdims: bool = False, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - dtype = 
tensorflow_as_native_dtype(dtype) - if dtype is None: - dtype = tensorflow__infer_dtype(x.dtype) - axis = tuple(axis) if isinstance(axis, list) else axis - return tensorflow.experimental.numpy.prod( - x, axis=axis, dtype=dtype, keepdims=keepdims - ) - - -def tensorflow__numel_bknd(shape): - shape = tuple(shape) - return tensorflow_to_scalar_bknd_(tensorflow_prod(shape)) if shape != () else 1 - - -def tensorflow_check_one_way_broadcastable(x1, x2): - if len(x1) > len(x2): - return False - for a, b in zip(x1[::-1], x2[::-1]): - if a in (1, b): - pass - else: - return False - return True - - -def tensorflow_check_shapes_broadcastable(var, data): - if not tensorflow_check_one_way_broadcastable(var, data): - raise Exception(f"Could not broadcast shape {data} to shape {var}.") - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_broadcast_to( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - shape: Union[tf.TensorShape, Sequence[int]], - *, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - tensorflow_check_shapes_broadcastable(x.shape, shape) - if tensorflow.rank(x) > len(shape): - return tensorflow.broadcast_to(tensorflow.reshape(x, -1), shape) - return tensorflow.broadcast_to(x, shape) - - -def tensorflow__broadcast_to_bknd(input, target_shape): - if tensorflow__numel_bknd(tuple(input.shape)) == tensorflow__numel_bknd( - tuple(target_shape) - ): - return tensorflow_reshape(input, target_shape) - else: - input = input if len(input.shape) else tensorflow_expand_dims(input, axis=0) - return tensorflow_broadcast_to(input, target_shape) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_any( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - axis: Optional[Union[int, Sequence[int]]] = None, - keepdims: bool = False, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if axis is None: - num_dims = len(x.shape) - axis = tuple(range(num_dims)) - elif isinstance(axis, list): - axis = tuple(axis) - try: - return tensorflow.reduce_any( - tensorflow.cast(x, tensorflow.bool), axis=axis, keepdims=keepdims - ) - except tensorflow.errors.InvalidArgumentError as e: - raise Exception(e) from e - - -def tensorflow__broadcast_inputs(x1, x2): - x1_, x2_ = x1, x2 - iterables = list, tuple, tuple - if not isinstance(x1_, iterables): - x1_, x2_ = x2, x1 - if not isinstance(x1_, iterables): - return [x1], [x2] - if not isinstance(x2_, iterables): - x1 = [x1] * len(x2) - return x1, x2 - - -def tensorflow_check_equal(x1, x2, inverse=False, message="", as_array=True): - def eq_fn(x1, x2): - return x1 == x2 if inverse else x1 != x2 - - def comp_fn(x1, x2): - return tensorflow_any(eq_fn(x1, x2)) - - if not as_array: - - def iter_comp_fn(x1_, x2_): - return any(eq_fn(x1, x2) for x1, x2 in zip(x1_, x2_)) - - def comp_fn(x1, x2): - return iter_comp_fn(*tensorflow__broadcast_inputs(x1, x2)) - - eq = comp_fn(x1, x2) - if inverse and eq: - raise Exception(f"{x1} must not be equal to {x2}" if message == "" else message) - elif not inverse and eq: - raise Exception(f"{x1} must be equal to {x2}" if message == "" else message) - - -def tensorflow_multiply( - x1: Union[float, tensorflow.Tensor, tensorflow.Variable], - x2: Union[float, tensorflow.Tensor, tensorflow.Variable], - /, - *, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - oirg_x1 = x1 - oirg_x2 = x2 - try: - dtype = ( - x1.dtype - if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() - ) - if 
-def tensorflow_multiply(
-    x1: Union[float, tensorflow.Tensor, tensorflow.Variable],
-    x2: Union[float, tensorflow.Tensor, tensorflow.Variable],
-    /,
-    *,
-    out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None,
-):
-    orig_x1 = x1
-    orig_x2 = x2
-    try:
-        dtype = (
-            x1.dtype
-            if hasattr(x1, "dtype")
-            else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd()
-        )
-        if not tensorflow_is_array_bknd(x1):
-            x1 = tensorflow_asarray(x1, dtype=dtype)
-        if not tensorflow_is_array_bknd(x2):
-            x2 = tensorflow_asarray(x2, dtype=dtype)
-    except Exception:
-        x1 = orig_x1
-        x2 = orig_x2
-    return tensorflow.math.multiply(x1, x2)
-
-
-def tensorflow_check_gather_nd_input_valid(params, indices, batch_dims):
-    if batch_dims >= len(params.shape):
-        raise Exception(
-            f"batch_dims = {batch_dims} must be less than rank(`params`) = {len(params.shape)}."
-        )
-    if batch_dims >= len(indices.shape):
-        raise Exception(
-            f"batch_dims = {batch_dims} must be less than rank(`indices`) = {len(indices.shape)}."
-        )
-    if tensorflow_get_item(
-        params.shape, slice(0, batch_dims, None)
-    ) != tensorflow_get_item(indices.shape, slice(0, batch_dims, None)):
-        raise Exception(
-            f"batch dimensions must match in `params` and `indices`; saw {tensorflow_get_item(params.shape, slice(0, batch_dims, None))} vs. {tensorflow_get_item(indices.shape, slice(0, batch_dims, None))}"
-        )
-    if indices.shape[-1] > len(
-        tensorflow_get_item(params.shape, slice(batch_dims, None, None))
-    ):
-        raise Exception(
-            f"index innermost dimension length must be <= rank(`params[batch_dims:]`); saw: {indices.shape[-1]} vs. {len(tensorflow_get_item(params.shape, slice(batch_dims, None, None)))} ."
-        )
-
-
-def tensorflow_gather_nd_helper(params, indices):
-    indices_shape = tensorflow.shape(indices)
-    params_shape = tensorflow.shape(params)
-    num_index_dims = indices_shape[-1]
-    result_dim_sizes_list = [
-        tensorflow.math.reduce_prod(params_shape[i + 1 :])
-        for i in range(len(params_shape) - 1)
-    ] + [1]
-    result_dim_sizes = tensorflow.convert_to_tensor(
-        result_dim_sizes_list, dtype=indices.dtype
-    )
-    implicit_indices_factor = result_dim_sizes[num_index_dims - 1]
-    flat_params = tensorflow.reshape(params, (-1,))
-    new_shape = [1] * (len(indices_shape) - 1) + [num_index_dims]
-    indices_scales = tensorflow.reshape(result_dim_sizes[0:num_index_dims], new_shape)
-    indices_for_flat_tiled = tensorflow.reshape(
-        tensorflow.reduce_sum(indices * indices_scales, -1, keepdims=True), (-1, 1)
-    )
-    indices_for_flat_tiled = tensorflow.repeat(
-        indices_for_flat_tiled, implicit_indices_factor, axis=1
-    )
-    implicit_indices = tensorflow.repeat(
-        tensorflow.expand_dims(tensorflow.range(implicit_indices_factor), 0),
-        indices_for_flat_tiled.shape[0],
-        axis=0,
-    )
-    indices_for_flat = indices_for_flat_tiled + implicit_indices
-    flat_indices_for_flat = tensorflow.reshape(indices_for_flat, (-1,))
-    flat_gather = tensorflow.gather(flat_params, flat_indices_for_flat)
-    res = tensorflow.reshape(
-        flat_gather,
-        tensorflow.concat([indices_shape[:-1], params_shape[num_index_dims:]], 0),
-    )
-    return res
-
-
-@tensorflow_handle_array_like_without_promotion
-def tensorflow_gather_nd(
-    params: Union[tensorflow.Tensor, tensorflow.Variable],
-    indices: Union[tensorflow.Tensor, tensorflow.Variable],
-    /,
-    *,
-    batch_dims: int = 0,
-    out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None,
-):
-    tensorflow_check_gather_nd_input_valid(params, indices, batch_dims)
-    try:
-        return tensorflow.gather_nd(params, indices, batch_dims=batch_dims)
-    except Exception:
-        batch_dims %= len(params.shape)
-        result = []
-        if batch_dims == 0:
-            result = tensorflow_gather_nd_helper(params, indices)
-        else:
-            for b in range(batch_dims):
-                if b == 0:
-                    zip_list = list(zip(params, indices))
-                else:
-                    zip_list = [
-                        (p, i)
-                        for z in [zip(p1, i1) for p1, i1 in zip_list]
-                        for p, i in z
-                    ]
-            for z in zip_list:
-                p, i = z[0], z[1]
-                r = 
tensorflow_gather_nd_helper(p, i) - result.append(r) - result = tensorflow.stack(result) - result = tensorflow.reshape( - result, - tensorflow.concat([params.shape[0:batch_dims], result.shape[1:]], 0), - ) - return result - - -def tensorflow__is_variable_bknd(x, exclusive=False, to_ignore=None): - x = x - return tensorflow_nested_map_bknd( - lambda x: tensorflow_is_variable(x, exclusive=exclusive), - x, - include_derived=True, - shallow=False, - to_ignore=to_ignore, - ) - - -def tensorflow_inplace_update( - x: Union[tensorflow.Tensor, tensorflow.Tensor], - val: Union[tensorflow.Tensor, tensorflow.Tensor], - /, - *, - ensure_in_backend: bool = False, - keep_input_dtype: bool = False, -): - if tensorflow_is_array_bknd(x) and tensorflow_is_array_bknd(val): - if keep_input_dtype: - val = tensorflow_astype(val, x.dtype) - (x_native, val_native), _ = (x, val), "_" - if tensorflow__is_variable_bknd(x_native): - x_native.assign(val_native) - if tensorflow_is_ivy_array_bknd(x): - x = x_native - else: - x = tensorflow.convert_to_tensor(x_native) - else: - x = x_native - return x - else: - return val - - -def tensorflow_scatter_nd( - indices: Union[tensorflow.Tensor, tensorflow.Variable], - updates: Union[tensorflow.Tensor, tensorflow.Variable], - /, - shape: Optional[Union[tf.TensorShape, Sequence[int]]] = None, - *, - reduction: str = "sum", - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - updates_dtype = updates.dtype - if tensorflow_exists_bknd(out): - dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) - updates = tensorflow.cast( - updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), - ) - expected_shape = ( - list(tensorflow.shape(indices)[:-1]) - + list(out.shape[tensorflow.shape(indices)[-1] :]) - if tensorflow_exists_bknd(out) - else list(tensorflow.shape(indices)[:-1]) - + list(shape[tensorflow.shape(indices)[-1] :]) - ) - updates = tensorflow__broadcast_to_bknd(updates, expected_shape) - if len(updates.shape) == 0: - indices = tensorflow.expand_dims(indices, 0) - updates = tensorflow.expand_dims(updates, 0) - target = out - target_given = tensorflow_exists_bknd(target) - if tensorflow_exists_bknd(shape) and target_given: - tensorflow_check_equal(tuple(target.shape), tuple(shape), as_array=False) - if not target_given: - shape = list(shape) if tensorflow_exists_bknd(shape) else list(out.shape) - target = tensorflow.zeros(shape, dtype=updates.dtype) - if reduction == "sum": - res = tensorflow.tensor_scatter_nd_add(target, indices, updates) - elif reduction == "min": - res = tensorflow.tensor_scatter_nd_min(target, indices, updates) - elif reduction == "max": - res = tensorflow.tensor_scatter_nd_max(target, indices, updates) - elif reduction == "mul": - updates = tensorflow_multiply(tensorflow_gather_nd(target, indices), updates) - res = tensorflow.tensor_scatter_nd_update(target, indices, updates) - elif reduction == "replace": - res = tensorflow.tensor_scatter_nd_update(target, indices, updates) - else: - raise Exception( - f'reduction is {reduction}, but it must be one of "sum", "min", "max", "mul" or "replace"' - ) - if tensorflow_exists_bknd(out): - return tensorflow_inplace_update(out, res) - return res - - -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - 
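# --- aside -------------------------------------------------------------------
# tensorflow_gather_nd_helper above reduces an N-d gather to a 1-d tf.gather by
# turning each multi-index into a flat offset with row-major strides.  A
# condensed, standalone sketch of that arithmetic (illustrative values):
import tensorflow as tf

params = tf.reshape(tf.range(24), (2, 3, 4))
indices = tf.constant([[1, 2], [0, 1]])  # two (i, j) pairs into dims 0 and 1
strides = tf.constant([12, 4])           # row-major: dim 0 spans 3*4, dim 1 spans 4
flat = tf.reduce_sum(indices * strides, axis=-1)  # -> [20, 4]
rows = tf.gather(tf.reshape(params, (-1,)), flat[:, None] + tf.range(4))
# rows matches tf.gather_nd(params, indices): each flat offset is expanded by
# the trailing range, exactly like implicit_indices in the helper above.
# --- end aside -----------------------------------------------------------------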
-@tensorflow_handle_set_item -def tensorflow_set_item_bknd( - x: Union[tensorflow.Tensor, tf.Tensor], - query: Union[tensorflow.Tensor, tf.Tensor, Tuple], - val: Union[tensorflow.Tensor, tf.Tensor], - /, - *, - copy: Optional[bool] = False, -): - if isinstance(query, (list, tuple)) and any( - [(q is Ellipsis or isinstance(q, slice) and q.stop is None) for q in query] - ): - x_stop_gradient = tensorflow_stop_gradient(x, preserve_type=False) - np_array = x_stop_gradient.numpy() - val_stop_gradient = tensorflow_stop_gradient(val, preserve_type=False) - np_array = tensorflow_set_item_bknd( - np_array, query, np.asarray(val_stop_gradient) - ) - return tensorflow_asarray(np_array) - if copy: - x = tensorflow_copy_array(x) - if not tensorflow_is_array_bknd(val): - val = tensorflow_asarray(val) - if 0 in x.shape or 0 in val.shape: - return x - if tensorflow_is_array_bknd(query) and tensorflow_is_bool_dtype_bknd(query): - if not len(query.shape): - query = tensorflow_tile(query, (x.shape[0],)) - indices = tensorflow_nonzero(query, as_tuple=False) - else: - indices, target_shape, _ = tensorflow__parse_query_bknd( - query, tensorflow_shape(x, as_array=True), scatter=True - ) - if indices is None: - return x - val = tensorflow_astype_bknd_(val, x.dtype) - ret = tensorflow_scatter_nd(indices, val, reduction="replace", out=x) - return ret - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_real( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - return tensorflow.math.real(x) - - -def tensorflow_real_bknd_(self): - return tensorflow_real(self) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_imag( - val: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - return tensorflow.math.imag(val, name=None) - - -def tensorflow_imag_bknd_(self): - return tensorflow_imag(self) - - -def tensorflow__check_complex128_bknd(input): - if tensorflow_is_array_bknd(input): - return tensorflow_dtype(input) == "complex128" - elif isinstance(input, np.ndarray): - return str(input.dtype) == "complex128" - if hasattr(input, "real") and hasattr(input, "imag"): - return tensorflow__check_float64_bknd( - tensorflow_real_bknd_(input) - ) and tensorflow__check_float64_bknd(tensorflow_imag_bknd_(input)) - return False - - -def tensorflow_default_complex_dtype_bknd( - *, - input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, - complex_dtype: Optional[Union[str, tf.DType]] = None, - as_native: bool = False, -): - global default_complex_dtype_stack - if tensorflow_exists_bknd(complex_dtype): - if as_native is True: - return tensorflow_as_native_dtype(complex_dtype) - return str(tensorflow_as_ivy_dtype(complex_dtype)) - as_native = tensorflow_default_bknd(as_native, False) - if tensorflow_exists_bknd(input): - if tensorflow_is_array_bknd(input): - ret = tensorflow_dtype(input) - elif isinstance(input, np.ndarray): - ret = str(input.dtype) - elif isinstance(input, (list, tuple, dict)): - if tensorflow_nested_argwhere_bknd( - input, - lambda x: tensorflow__check_complex128_bknd(x), - stop_after_n_found=1, - ): - ret = tf.complex128 - elif not default_complex_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_complex_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "complex64" - else: - ret = default_complex_dtype_stack[-1] - elif isinstance(input, Number): - if 
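# --- aside -------------------------------------------------------------------
# The boolean-mask branch of tensorflow_set_item_bknd above lowers a mask
# assignment onto nonzero + scatter with reduction="replace".  A condensed,
# standalone TF sketch of that path (illustrative values):
import tensorflow as tf

x = tf.constant([1, 2, 3, 4])
mask = tf.constant([True, False, True, False])
indices = tf.where(mask)                                    # -> [[0], [2]]
out = tf.tensor_scatter_nd_update(x, indices, tf.constant([9, 9]))
# out == [9, 2, 9, 4]
# --- end aside -----------------------------------------------------------------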
tensorflow__check_complex128_bknd(input): - ret = tf.complex128 - elif not default_complex_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_complex_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "complex64" - else: - ret = default_complex_dtype_stack[-1] - elif not default_complex_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_complex_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "complex64" - else: - ret = default_complex_dtype_stack[-1] - if as_native: - return tensorflow_as_native_dtype(ret) - return str(tensorflow_as_ivy_dtype(ret)) - - -def tensorflow_default_dtype_bknd( - *, - dtype: Optional[Union[str, str]] = None, - item: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, - as_native: bool = False, -): - if tensorflow_exists_bknd(dtype): - if as_native is True: - return tensorflow_as_native_dtype(dtype) - return tensorflow_as_ivy_dtype(dtype) - as_native = tensorflow_default_bknd(as_native, False) - if tensorflow_exists_bknd(item): - if hasattr(item, "override_dtype_check"): - return item.override_dtype_check() - elif isinstance(item, (list, tuple, dict)) and len(item) == 0: - pass - elif tensorflow_is_complex_dtype_bknd(item): - return tensorflow_default_complex_dtype_bknd( - input=item, as_native=as_native - ) - elif tensorflow_is_float_dtype_bknd(item): - return tensorflow_default_float_dtype_bknd(input=item, as_native=as_native) - elif tensorflow_is_uint_dtype_bknd(item): - return tensorflow_default_int_dtype_bknd(input=item, as_native=as_native) - elif tensorflow_is_int_dtype_bknd(item): - return tensorflow_default_int_dtype_bknd(input=item, as_native=as_native) - elif as_native: - return tensorflow_as_native_dtype("bool") - else: - return "bool" - global default_dtype_stack - if not default_dtype_stack: - global default_float_dtype_stack - if default_float_dtype_stack: - ret = default_float_dtype_stack[-1] - else: - ret = "float32" - else: - ret = default_dtype_stack[-1] - if as_native: - return tensorflow_as_native_dtype(ret) - return tensorflow_as_ivy_dtype(ret) - - -def tensorflow_default_float_dtype_bknd( - *, - input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, - float_dtype: Optional[Union[str, tf.DType]] = None, - as_native: bool = False, -): - global default_float_dtype_stack - if tensorflow_exists_bknd(float_dtype): - if as_native is True: - return tensorflow_as_native_dtype(float_dtype) - return str(tensorflow_as_ivy_dtype(float_dtype)) - as_native = tensorflow_default_bknd(as_native, False) - if tensorflow_exists_bknd(input): - if tensorflow_is_array_bknd(input): - ret = tensorflow_dtype(input) - elif isinstance(input, np.ndarray): - ret = str(input.dtype) - elif isinstance(input, (list, tuple, dict)): - if tensorflow_nested_argwhere_bknd( - input, lambda x: tensorflow__check_float64_bknd(x), stop_after_n_found=1 - ): - ret = tf.float64 - elif not default_float_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_float_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "float32" - else: - ret = default_float_dtype_stack[-1] - elif isinstance(input, Number): - if tensorflow__check_float64_bknd(input): - ret = tf.float64 - elif not default_float_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_float_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "float32" - else: - ret = default_float_dtype_stack[-1] - elif not default_float_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if 
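# --- aside -------------------------------------------------------------------
# tensorflow_default_dtype_bknd above dispatches on the kind of `item`
# (complex -> float -> uint -> int -> bool) before falling back to the dtype
# stacks.  A collapsed sketch of that ordering, deliberately ignoring the
# stacks and the 64-bit promotion rules (assumed simplification):
import tensorflow as tf

def default_dtype_sketch(item):
    if item.dtype.is_complex:
        return tf.complex64
    if item.dtype.is_floating:
        return tf.float32
    if item.dtype.is_unsigned or item.dtype.is_integer:
        return tf.int32
    return tf.bool

# default_dtype_sketch(tf.constant(1.0))  -> tf.float32
# default_dtype_sketch(tf.constant(True)) -> tf.bool
# --- end aside -----------------------------------------------------------------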
tensorflow_is_float_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "float32" - else: - ret = default_float_dtype_stack[-1] - if as_native: - return tensorflow_as_native_dtype(ret) - return str(tensorflow_as_ivy_dtype(ret)) - - -def tensorflow_as_ivy_dtype( - dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / -): - if dtype_in is int: - return tensorflow_default_int_dtype_bknd() - if dtype_in is float: - return tensorflow_default_float_dtype_bknd() - if dtype_in is complex: - return tensorflow_default_complex_dtype_bknd() - if dtype_in is bool: - return str("bool") - if isinstance(dtype_in, np.dtype): - dtype_in = dtype_in.name - if isinstance(dtype_in, str): - if dtype_in in native_dtype_dict: - dtype_str = dtype_in - else: - raise Exception( - f"Cannot convert to ivy dtype. {dtype_in} is not supported by TensorFlow backend." - ) - else: - dtype_str = ivy_dtype_dict[dtype_in] - if "uint" in dtype_str: - return str(dtype_str) - elif "int" in dtype_str: - return str(dtype_str) - elif "float" in dtype_str: - return str(dtype_str) - elif "complex" in dtype_str: - return str(dtype_str) - elif "bool" in dtype_str: - return str("bool") - else: - raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") - - -def tensorflow_default_int_dtype_bknd( - *, - input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, - int_dtype: Optional[Union[str, tf.DType]] = None, - as_native: bool = False, -): - global default_int_dtype_stack - if tensorflow_exists_bknd(int_dtype): - if as_native is True: - return tensorflow_as_native_dtype(int_dtype) - return str(tensorflow_as_ivy_dtype(int_dtype)) - as_native = tensorflow_default_bknd(as_native, False) - if tensorflow_exists_bknd(input): - if tensorflow_is_array_bknd(input): - ret = tensorflow_dtype(input) - elif isinstance(input, tuple): - ret = tensorflow_default_int_dtype_bknd() - elif isinstance(input, np.ndarray): - ret = str(input.dtype) - elif isinstance(input, (list, tuple, dict)): - if tensorflow_nested_argwhere_bknd( - input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), - stop_after_n_found=1, - ): - ret = tf.uint64 - elif tensorflow_nested_argwhere_bknd( - input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), - stop_after_n_found=1, - ): - ret = tf.int64 - elif not default_int_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_int_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "int32" - else: - ret = default_int_dtype_stack[-1] - elif isinstance(input, Number): - if input > 9223372036854775807 and input != math.inf and backend != "torch": - ret = tf.uint64 - elif input > 2147483647 and input != math.inf: - ret = tf.int64 - elif not default_int_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_int_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "int32" - else: - ret = default_int_dtype_stack[-1] - elif not default_int_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_int_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "int32" - else: - ret = default_int_dtype_stack[-1] - if as_native: - return tensorflow_as_native_dtype(ret) - return str(tensorflow_as_ivy_dtype(ret)) - - -def tensorflow_as_native_dtype( - dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], -): - if dtype_in is int: - return 
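# --- aside -------------------------------------------------------------------
# tensorflow_default_int_dtype_bknd above widens the default integer type by
# magnitude: values above 2**31 - 1 need int64, and values above 2**63 - 1
# need uint64.  The same thresholds stated directly (standalone sketch that
# omits the math.inf and dtype-stack handling):
import tensorflow as tf

def int_dtype_for(value):
    if value > 9223372036854775807:  # 2**63 - 1
        return tf.uint64
    if value > 2147483647:           # 2**31 - 1
        return tf.int64
    return tf.int32

# int_dtype_for(5) -> tf.int32;  int_dtype_for(2**40) -> tf.int64
# --- end aside -----------------------------------------------------------------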
tensorflow_default_int_dtype_bknd(as_native=True) - if dtype_in is float: - return tensorflow_default_float_dtype_bknd(as_native=True) - if dtype_in is complex: - return tensorflow_default_complex_dtype_bknd(as_native=True) - if dtype_in is bool: - return tensorflow.bool - if isinstance(dtype_in, np.dtype): - dtype_in = dtype_in.name - if not isinstance(dtype_in, str): - return dtype_in - if dtype_in in native_dtype_dict: - return native_dtype_dict[str(dtype_in)] - else: - raise Exception( - f"Cannot convert to TensorFlow dtype. {dtype_in} is not supported by TensorFlow." - ) - - -def tensorflow_dtype( - x: Union[tensorflow.Tensor, tensorflow.Variable, np.ndarray], - *, - as_native: bool = False, -): - if as_native: - return tensorflow_as_native_dtype(x.dtype) - return tensorflow_as_ivy_dtype(x.dtype) - - -def tensorflow_is_bool_dtype_bknd( - dtype_in: Union[str, str, tensorflow.Tensor, tf.Tensor, Number], / -): - if tensorflow_is_array_bknd(dtype_in): - dtype_in = tensorflow_dtype(dtype_in) - elif isinstance(dtype_in, np.ndarray): - return "bool" in dtype_in.dtype.name - elif isinstance(dtype_in, Number): - return isinstance(dtype_in, (bool, np.bool_)) and not isinstance(dtype_in, bool) - elif isinstance(dtype_in, (list, tuple, dict)): - return bool( - tensorflow_nested_argwhere_bknd( - dtype_in, lambda x: isinstance(x, (bool, np.bool_)) and x is not int - ) - ) - return "bool" in tensorflow_as_ivy_dtype(dtype_in) - - -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - -@tensorflow_handle_get_item -def tensorflow_get_item( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - query: Union[tensorflow.Tensor, tensorflow.Variable, Tuple], - *, - copy: Optional[bool] = None, -): - if ( - tensorflow_is_array_bknd(query) - and tensorflow_is_bool_dtype_bknd(query) - and not len(query.shape) - ): - return tensorflow.expand_dims(x, 0) - return x[query] - - -def tensorflow_index_nest_bknd( - nest: Union[List, Tuple, Dict, tensorflow.Tensor, tf.Tensor, dict], - index: Union[List[int], Tuple[int], Iterable[int]], - /, -): - ret = nest - for i in index: - ret = tensorflow_get_item(ret, i) - return ret - - -def tensorflow__get_first_array(*args, **kwargs): - def array_fn(x): - return ( - tensorflow_is_array_bknd(x) - if not hasattr(x, "_ivy_array") - else tensorflow_is_array_bknd(x.ivy_array) - ) - - array_fn = array_fn if "array_fn" not in kwargs else kwargs["array_fn"] - arr = None - if args: - arr_idxs = tensorflow_nested_argwhere_bknd(args, array_fn, stop_after_n_found=1) - if arr_idxs: - arr = tensorflow_index_nest_bknd(args, arr_idxs[0]) - else: - arr_idxs = tensorflow_nested_argwhere_bknd( - kwargs, array_fn, stop_after_n_found=1 - ) - if arr_idxs: - arr = tensorflow_index_nest_bknd(kwargs, arr_idxs[0]) - elif kwargs: - arr_idxs = tensorflow_nested_argwhere_bknd( - kwargs, array_fn, stop_after_n_found=1 - ) - if arr_idxs: - arr = tensorflow_index_nest_bknd(kwargs, arr_idxs[0]) - return arr - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_all( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - axis: Optional[Union[int, Sequence[int]]] = None, - keepdims: bool = False, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if axis is None: - num_dims = len(x.shape) - axis = tuple(range(num_dims)) - elif isinstance(axis, list): - axis = 
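# --- aside -------------------------------------------------------------------
# tensorflow_handle_get_item above is a try-native-first wrapper: the object's
# own __getitem__ is preferred, IndexError propagates unchanged, and any other
# failure falls through to the decorated function.  A standalone sketch of the
# same pattern applied to a plain list (illustrative names):
import functools

def handle_get_item_sketch(fn):
    @functools.wraps(fn)
    def wrapper(inp, query, **kwargs):
        try:
            return inp.__getitem__(query)
        except IndexError:
            raise
        except Exception:
            return fn(inp, query, **kwargs)
    return wrapper

@handle_get_item_sketch
def fancy_get(inp, query):
    return [inp[i] for i in query]

# fancy_get([10, 20, 30], [0, 2]) -> [10, 30]
# (list.__getitem__ raised TypeError on the list query, so the fallback ran)
# --- end aside -----------------------------------------------------------------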
tuple(axis) - try: - return tensorflow.reduce_all( - tensorflow.cast(x, tensorflow.bool), axis=axis, keepdims=keepdims - ) - except tensorflow.errors.InvalidArgumentError as e: - raise Exception(e) from e - - -def tensorflow_check_all(results, message="one of the args is False", as_array=True): - if as_array and not tensorflow_all(results) or not as_array and not all(results): - raise Exception(message) - - -def tensorflow_check_all_or_any_fn( - *args, - fn, - type="all", - limit=(0,), - message="args must exist according to type and limit given", - as_array=True, -): - if type == "all": - tensorflow_check_all([fn(arg) for arg in args], message, as_array=as_array) - elif type == "any": - count = 0 - for arg in args: - count = count + 1 if fn(arg) else count - if count not in limit: - raise Exception(message) - else: - raise Exception("type must be all or any") - - -def tensorflow__check_bounds_and_get_shape_bknd(low, high, shape): - if shape is not None: - tensorflow_check_all_or_any_fn( - low, - high, - fn=lambda x: isinstance(x, (int, float)), - type="all", - message="low and high bounds must be numerics when shape is specified", - ) - return tuple(shape) - valid_types = (tensorflow.Tensor,) - if len(backend_stack) == 0: - valid_types = valid_types + (tf.Tensor,) - else: - valid_types = valid_types + (tf.Tensor,) - if isinstance(low, valid_types): - if isinstance(high, valid_types): - tensorflow_check_equal( - tensorflow_shape(low), tensorflow_shape(high), as_array=False - ) - return tensorflow_shape(low) - if isinstance(high, valid_types): - return tensorflow_shape(high) - return tuple(()) diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_random_uniform_output/run_0/tensorflow__stateful.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_random_uniform_output/run_0/tensorflow__stateful.py deleted file mode 100644 index dbad1e919ab1..000000000000 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_random_uniform_output/run_0/tensorflow__stateful.py +++ /dev/null @@ -1,1799 +0,0 @@ -# global -from __future__ import annotations -import re -import os -import tensorflow as tf -import functools -from tensorflow.python.util import nest -from typing import NamedTuple, Callable, Any, Tuple, List, Dict, Type, Union -import inspect -from collections import OrderedDict -from packaging.version import parse -import keras - - -def get_assignment_dict(): - # Traverse the call stack - lhs = None - for frame_info in inspect.stack(): - # Check if the code context is an assignment statement - if frame_info.code_context and "=" in frame_info.code_context[0]: - # Split the assignment and retrieve the LHS - lhs = frame_info.code_context[0].split("=")[0].strip() - if "self" not in lhs: - continue - break - - if not lhs: - return None, "" - - # Replace indexing with attribute access - lhs = re.sub(r"\[(\d+)\]", r".\1", lhs) - - # Split the LHS based on "." 
and get individual components - components = lhs.split(".") - - # Initialize the dictionary - assignment_dict = {} - - # Retrieve the live objects associated with each component - for i in range(len(components)): - # Construct the key - key = ".".join(components[: i + 1]) - - # Retrieve the value - if i == 0: - value = frame_info.frame.f_locals.get(components[i]) - else: - value = getattr(assignment_dict[".".join(components[:i])], components[i]) - - # Add the key-value pair to the dictionary - assignment_dict[key] = value - - return assignment_dict, lhs - - -def store_frame_info(fn): - @functools.wraps(fn) - def frame_info_wrapper(self, *args, **kwargs): - if self._previous_frame_info is None: - # store the info about the calling frame. - stack = inspect.stack() - self._previous_frame_info = stack[1] - res = fn(self, *args, **kwargs) - # reset the frame-info - self._previous_frame_info = None - return res - - return frame_info_wrapper - - -# A NodeDef holds two callables: -# - flatten_fn should take the collection and return a flat list of values. -# It can also return some context that is used in reconstructing the -# collection. -# - unflatten_fn should take a flat list of values and some context -# (returned by flatten_fn). It returns the collection by reconstructing -# it from the list and the context. -Context = Any -PyTree = Any -FlattenFunc = Callable[[PyTree], Tuple[List, Context]] -UnflattenFunc = Callable[[List, Context], PyTree] - - -class NodeDef(NamedTuple): - flatten_fn: FlattenFunc - unflatten_fn: UnflattenFunc - - -SUPPORTED_NODES: Dict[Type[Any], NodeDef] = {} - - -def _register_pytree_node( - typ: Any, flatten_fn: FlattenFunc, unflatten_fn: UnflattenFunc -) -> None: - SUPPORTED_NODES[typ] = NodeDef(flatten_fn, unflatten_fn) - - -def _dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]: - return list(d.values()), list(d.keys()) - - -def _dict_unflatten(values: List[Any], context: Context) -> Dict[Any, Any]: - return {key: value for key, value in zip(context, values)} - - -_register_pytree_node(dict, _dict_flatten, _dict_unflatten) - -if parse(keras.__version__).major > 2: - _register_pytree_node( - keras.src.utils.tracking.TrackedDict, _dict_flatten, _dict_unflatten - ) - - -def _get_node_type(pytree: Any) -> Any: - return type(pytree) - - -# A leaf is defined as anything that is not a Node. -def _is_leaf(pytree: PyTree) -> bool: - return _get_node_type(pytree) not in SUPPORTED_NODES.keys() - - -# A TreeSpec represents the structure of a pytree. 
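# --- aside -------------------------------------------------------------------
# _register_pytree_node above is the extension point: a container type joins
# tree_flatten/tree_unflatten by supplying the two callables.  Registering
# OrderedDict here is an illustrative assumption (the file itself registers
# only dict and keras' TrackedDict); it relies on the helpers defined above:
from collections import OrderedDict

_register_pytree_node(OrderedDict, _dict_flatten, _dict_unflatten)
# OrderedDicts are now treated as nodes rather than leaves, flattening to
# (values, keys) and rebuilding in key order.
# --- end aside -----------------------------------------------------------------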
It holds: -# "type": the type of root Node of the pytree -# context: some context that is useful in unflattening the pytree -# children_specs: specs for each child of the root Node -# num_leaves: the number of leaves -class TreeSpec: - def __init__(self, type, context, children_specs): - self.type: Any = type - self.context: Context = context - self.children_specs: List["TreeSpec"] = children_specs - self.num_leaves: int = sum([spec.num_leaves for spec in self.children_specs]) - - def get_keychains(self, prefix="", sep="/"): - keychains = [] - for key, child_spec in zip(self.context, self.children_specs): - new_prefix = prefix + key + sep if prefix else key + sep - if child_spec.children_specs: # Non-leaf node - keychains.extend(child_spec.get_keychains(new_prefix, sep)) - else: # Leaf node - keychains.append(new_prefix[: -len(sep)]) - return keychains - - def __repr__(self, indent: int = 0) -> str: - repr_prefix: str = f"TreeSpec({self.type.__name__}, {self.context}, [" - children_specs_str: str = "" - if len(self.children_specs): - indent += len(repr_prefix) - children_specs_str += self.children_specs[0].__repr__(indent) - children_specs_str += "," if len(self.children_specs) > 1 else "" - children_specs_str += ",".join( - [ - "\n" + " " * indent + child.__repr__(indent) - for child in self.children_specs[1:] - ] - ) - repr_suffix: str = f"{children_specs_str}])" - return repr_prefix + repr_suffix - - -class LeafSpec(TreeSpec): - def __init__(self) -> None: - super().__init__(None, None, []) - self.num_leaves = 1 - - def __repr__(self, indent: int = 0) -> str: - return "*" - - -def tree_flatten(pytree: PyTree) -> Tuple[List[Any], TreeSpec]: - """Flattens a pytree into a list of values and a TreeSpec that can be used - to reconstruct the pytree.""" - if _is_leaf(pytree): - return [pytree], LeafSpec() - - node_type = _get_node_type(pytree) - flatten_fn = _dict_flatten - child_pytrees, context = flatten_fn(pytree) - - # Recursively flatten the children - result: List[Any] = [] - children_specs: List["TreeSpec"] = [] - for child in child_pytrees: - flat, child_spec = tree_flatten(child) - result += flat - children_specs.append(child_spec) - - return result, TreeSpec(node_type, context, children_specs) - - -def tree_unflatten(values: List[Any], spec: TreeSpec) -> PyTree: - """Given a list of values and a TreeSpec, builds a pytree. - - This is the inverse operation of `tree_flatten`. - """ - if not isinstance(spec, TreeSpec): - raise TypeError( - f"tree_unflatten(values, spec): Expected `spec` to be instance of " - f"TreeSpec but got item of type {type(spec)}." - ) - if len(values) != spec.num_leaves: - raise TypeError( - f"tree_unflatten(values, spec): `values` has length {len(values)} " - f"but the spec refers to a pytree that holds {spec.num_leaves} " - f"items ({spec})." 
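# --- aside -------------------------------------------------------------------
# A round-trip through the helpers defined in this file (illustrative values):
# tree_flatten walks nested dicts depth-first, so leaves come back in key
# order and the TreeSpec records how to reassemble them.
nested = {"a": {"b": 1, "c": 2}, "d": 3}
leaves, spec = tree_flatten(nested)     # leaves == [1, 2, 3]
rebuilt = tree_unflatten(leaves, spec)  # rebuilt == nested
# spec.get_keychains() -> ["a/b", "a/c", "d"]
# --- end aside -----------------------------------------------------------------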
- ) - if isinstance(spec, LeafSpec): - return values[0] - - unflatten_fn = _dict_unflatten - - # Recursively unflatten the children - start = 0 - end = 0 - child_pytrees = [] - for child_spec in spec.children_specs: - end += child_spec.num_leaves - child_pytrees.append(tree_unflatten(values[start:end], child_spec)) - start = end - - return unflatten_fn(child_pytrees, spec.context) - - -def serialize_obj(obj): - if inspect.isclass(obj) or isinstance(obj, type): - return {"cls_module": obj.__module__, "cls_name": obj.__name__} - return obj - - -def recursive_serialize(d): - if isinstance(d, dict): - return {k: recursive_serialize(v) for k, v in d.items()} - elif isinstance(d, list): - return [recursive_serialize(v) for v in d] - elif isinstance(d, tuple): - return tuple(recursive_serialize(v) for v in d) - else: - return serialize_obj(d) - - -def deserialize_obj(serialized): - if ( - isinstance(serialized, dict) - and "cls_module" in serialized - and "cls_name" in serialized - ): - module = __import__(serialized["cls_module"], fromlist=[serialized["cls_name"]]) - cls = getattr(module, serialized["cls_name"]) - return cls - return serialized - - -def recursive_deserialize(d): - if isinstance(d, dict) and "cls_module" not in d: - return {k: recursive_deserialize(v) for k, v in d.items()} - elif isinstance(d, list): - return [recursive_deserialize(v) for v in d] - elif isinstance(d, tuple): - return tuple(recursive_serialize(v) for v in d) - else: - return deserialize_obj(d) - - -class ModelHelpers: - @staticmethod - @tf.autograph.experimental.do_not_convert - def _get_first_array(*args, **kwargs): - arr = None - flattened_args = tf.nest.flatten((args, kwargs)) - arr_candidates = tf.nest.map_structure( - lambda x: x if isinstance(x, (tf.Tensor, tf.Variable)) else False, - flattened_args, - ) - for arr_candidate in arr_candidates: - if arr_candidate is not False: - arr = arr_candidate - break - return arr - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _get_input_shapes(*args): - input_shapes = [] - for x in args: - if isinstance(x, (tf.Tensor, tf.Variable)): - input_shapes.append(x.shape) - else: - try: - x = tf.convert_to_tensor(x) - input_shapes.append(x.shape) - except Exception: - input_shapes.append(None) - return input_shapes - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _extract_v(v, keychain_mappings: dict, orig_key_chain, /): - if ModelHelpers._dict_has_key_chain(v, orig_key_chain): - ret_cont = ModelHelpers._dict_at_key_chain(v, orig_key_chain) - else: - ret_cont = dict() - for old_kc, new_kc in keychain_mappings.items(): - if orig_key_chain in old_kc: - # Check if `v` contains `new_kc` before replacing in `ret_cont` - if ModelHelpers._dict_has_key_chain(v, new_kc): - ret_cont = ModelHelpers._dict_set_at_key_chain( - ret_cont, - "/".join(old_kc.split("/")[1:]), - ModelHelpers._dict_at_key_chain(v, new_kc), - ) - else: - continue - return ret_cont - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _remove_duplicate_variables(vs, created, /): - created_ids = tf.nest.map_structure(lambda x: id(x), created) - vs_ids = tf.nest.map_structure(lambda x: id(x), vs) - ids = {} - duplicate_keychains = [] - keychain_mappings = {} - - def unique_callback(x, kc): - ids[x] = kc - return x - - def found_dup_callback(x, kc): - if ids[x] == kc: - return x - duplicate_keychains.append(kc) - keychain_mappings[kc] = ids[x] - return x - - created_ids = nest.map_structure_with_paths( - lambda kc, x: unique_callback(x, kc), created_ids - ) - vs_ids = 
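# --- aside -------------------------------------------------------------------
# serialize_obj/recursive_serialize above make class objects config-friendly
# by recording (module, name) pairs; deserialize_obj re-imports them.  A small
# standalone round-trip (illustrative values):
import collections

cfg = {"cls": collections.OrderedDict, "n": 3}
ser = recursive_serialize(cfg)
# ser == {"cls": {"cls_module": "collections", "cls_name": "OrderedDict"}, "n": 3}
assert recursive_deserialize(ser)["cls"] is collections.OrderedDict
# --- end aside -----------------------------------------------------------------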
nest.map_structure_with_paths( - lambda kc, x: ( - unique_callback(x, kc) if x not in ids else found_dup_callback(x, kc) - ), - vs_ids, - ) - for dup_kc in duplicate_keychains: - vs = ModelHelpers._dict_prune_key_chain(vs, dup_kc) - return vs, keychain_mappings - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _dict_set_at_key_chain(in_dict, key_chain, val, inplace=False): - keys = re.split("[/.]", key_chain) - if inplace: - cont = in_dict - else: - cont = in_dict - sub_cont = cont - for key in keys[:-1]: - if key not in sub_cont: - sub_cont[key] = dict() - sub_cont = sub_cont[key] - sub_cont[keys[-1]] = val - return cont - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _dict_at_key_chain(dict, key_chain, ignore_key_errors=False): - keys = re.split("[/.]", key_chain) - ret = dict - for key in keys: - try: - ret = ret[key] - except KeyError as e: - if ignore_key_errors: - return - raise Exception(repr(e)) - return ret - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _dict_has_key_chain(dict, key_chain): - keys = re.split("[/.]", key_chain) - ret = dict - for key in keys: - try: - ret = ret[key] - except KeyError: - return False - return True - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _dict_prune_key_chain(in_dict, key_chain): - keys_in_chain = re.split("[/.]", key_chain) - out_dict = {} - for key, value in in_dict.items(): - if isinstance(value, dict): - if key == keys_in_chain[0]: - if len(keys_in_chain) == 1: - new_val = [] - else: - new_val = ModelHelpers._dict_prune_key_chain( - value, - "/".join(keys_in_chain[1:]), - ) - if len(new_val) > 0: - out_dict[key] = new_val - else: - if len(value) > 0: - out_dict[key] = value - else: - if len(keys_in_chain) != 1 or key != keys_in_chain[0]: - out_dict[key] = value - return out_dict - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _addindent(s_, numSpaces): - s = s_.split("\n") - # don't do anything for single-line stuff - if len(s) == 1: - return s_ - first = s.pop(0) - s = [(numSpaces * " ") + line for line in s] - s = "\n".join(s) - s = first + "\n" + s - return s - - -class Layer(tf.keras.layers.Layer, ModelHelpers): - _build_mode = None - _with_partial_v = None - _store_vars = True - _built = False - _v = None - _buffers = None - _module_dict = None - _args = None - _kwargs = None - _module_graph = None - _target = None - _lazy_traced = False - _training = None - _dynamic_backend = None - _device = None - _dtype = None - _previous_frame_info = None - - def __init__( - self, - /, - *args, - v=None, - buffers=None, - build_mode="on_init", - store_vars=True, - with_partial_v=False, - dynamic_backend=None, - training=True, - dtype=None, - device=None, - module_dict=None, - **kwargs, - ): - super(Layer, self).__init__( - trainable=training, - dtype=dtype, - ) - self._build_mode = build_mode - self._with_partial_v = with_partial_v - self._store_vars = store_vars - self._built = False - self._v_from_constructor = v if isinstance(v, dict) or v is None else dict(v) - self._v = v if v is not None else dict() - self._buffers = dict(buffers or {}) - self._module_dict = module_dict if module_dict is not None else dict() - self._args = args - self._kwargs = kwargs - self._module_graph = None - self._target = None - self._lazy_traced = False - self._training = training - self._dynamic_backend = dynamic_backend - self._device = device or "cpu" - self._dtype = dtype or tf.float32 - if build_mode != "on_init": - return - self.build(*args, 
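# --- aside -------------------------------------------------------------------
# The _dict_*_key_chain helpers above address nested dicts with "/"- or
# "."-separated paths (re.split("[/.]") accepts either separator).  A short
# illustrative use (values assumed):
params = {"layer": {"weight": 1.0}}
params = ModelHelpers._dict_set_at_key_chain(params, "layer/bias", 0.0)
ModelHelpers._dict_at_key_chain(params, "layer.bias")      # -> 0.0
ModelHelpers._dict_has_key_chain(params, "layer/missing")  # -> False
# --- end aside -----------------------------------------------------------------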
dynamic_backend=dynamic_backend, **kwargs) - - @tf.autograph.experimental.do_not_convert - def _find_variables( - self, - /, - *, - obj=None, - without_initialisation=False, - _visited=None, - trainable=True, - ): - _visited = _visited or {} - vs = dict() - if id(obj) in _visited: - return vs - _visited[id(obj)] = True - if isinstance(obj, Layer) and obj is not self: - fn = "_build_and_return_v" if trainable else "_build_and_return_buffers" - if not obj._built and without_initialisation: - return lambda: getattr(obj, fn)( - *obj._args, dynamic_backend=self._dynamic_backend, **obj._kwargs - ) - - return getattr(obj, fn)( - *obj._args, dynamic_backend=obj._dynamic_backend, **obj._kwargs - ) - elif isinstance(obj, tf.keras.layers.Layer) and obj is not self: - return obj.v if trainable else obj.buffers - - elif isinstance(obj, (list, tuple)): - for i, v in enumerate(obj): - ret = self._find_variables( - obj=v, - without_initialisation=without_initialisation, - _visited=_visited, - trainable=trainable, - ) - if ret: - vs[f"v{str(i)}"] = ret - return vs - elif isinstance(obj, dict): - for k, v in obj.items(): - ret = self._find_variables( - obj=v, - without_initialisation=without_initialisation, - _visited=_visited, - trainable=trainable, - ) - if ret: - vs[k[1:] if k[0] == "_" else k] = ret - return vs - elif not hasattr(obj, "__dict__"): - return vs - for k, v in obj.__dict__.items(): - if ( - v is not None - and k[0:2] != "__" - and not k.startswith( - ( - "_module_dict", - "_self_", - "_args", - "_kwargs", - ) - ) - ): - ret = self._find_variables( - obj=v, - without_initialisation=without_initialisation, - _visited=_visited, - trainable=trainable, - ) - if ret: - vs[k[1:] if k[0] == "_" else k] = ret - return vs - - @tf.autograph.experimental.do_not_convert - def _find_buffers(self): - if hasattr(self, "_module_dict"): - for key, sub_module in self._module_dict.items(): - if len(sub_module._buffers) > 0: - self._buffers[key] = sub_module._buffers - - @tf.autograph.experimental.do_not_convert - def _build_and_return_v(self, *args, **kwargs): - if not self._built: - self.build(*args, **kwargs) - return self.v - - @tf.autograph.experimental.do_not_convert - def _build_and_return_buffers(self, *args, **kwargs): - if not self._built: - self.build(*args, **kwargs) - return self.buffers - - @tf.autograph.experimental.do_not_convert - def _wrap_call_methods( - self, keychain_mappings, /, *, key="", obj=None, _visited=None - ): - _visited = _visited or {} - if id(obj) in _visited or not isinstance(key, str): - return - _visited[id(obj)] = True - if isinstance(obj, Model) and obj is not self: - orig_key_chain = key[1:] if key[0] == "_" else key - - obj.__call__ = self._fn_with_var_arg( - obj.__call__, self._extract_v, keychain_mappings, orig_key_chain - ) - return - elif isinstance(obj, (list, tuple)): - for i, val in enumerate(obj): - self._wrap_call_methods( - keychain_mappings, - key=f"{key}/v{str(i)}", - obj=val, - _visited=_visited, - ) - return - elif isinstance(obj, dict): - for k, val in obj.items(): - k = f"{key}/{k}" if key != "" and isinstance(k, str) else k - self._wrap_call_methods( - keychain_mappings, key=k, obj=val, _visited=_visited - ) - return - for k, val in obj.module_dict.items(): - if k.startswith(("__", "_self_")): - continue - k = f"{key}/{k}" if key != "" else k - if val is not None: - self._wrap_call_methods( - keychain_mappings, key=k, obj=val, _visited=_visited - ) - return - - @tf.autograph.experimental.do_not_convert - def _compute_module_dict(self): - self._module_dict 
= dict() - for key, value in self.__dict__.items(): - if isinstance(value, (Layer, tf.keras.layers.Layer)): - if ( - "stateful" in value.__module__ - or hasattr(value, "_frontend_module") - or not hasattr(value, "_module_dict") - ): - self._module_dict[key] = value - else: - self._module_dict[key] = value._module_dict - - @tf.autograph.experimental.do_not_convert - def _fn_with_var_arg_wrapper( - self, *a, fn, v_fn, keychain_mappings, orig_key_chain, **kw - ): - if "v" in kw: - del kw["v"] - v = v_fn(self.v, keychain_mappings, orig_key_chain) - return fn(*a, **kw, v=v) - - @tf.autograph.experimental.do_not_convert - def _fn_with_var_arg(self, fn, v_fn, /, keychain_mappings, orig_key_chain): - _fn_with_var_arg_wrapper = functools.partial( - self._fn_with_var_arg_wrapper, - fn=fn, - v_fn=v_fn, - keychain_mappings=keychain_mappings, - orig_key_chain=orig_key_chain, - ) - _fn_with_var_arg_wrapper.wrapped = True - return _fn_with_var_arg_wrapper - - @tf.autograph.experimental.do_not_convert - def _call(self, *args, v=None, buffers=None, **kwargs): - if not self._built or not self.built: - if not self._built: - first_arr = self._get_first_array(*args, **kwargs) - self.build( - *args, - **kwargs, - from_call=True, - dtype=first_arr.dtype if first_arr is not None else tf.float32, - ) - - if not self.built: - # Don't use `keras` build method - if os.environ.get("USE_KERAS_BUILD", "False").lower() == "false": - self.inputs = tf.nest.flatten(args) - else: - input_shapes = self._get_input_shapes(*args) - if len(input_shapes) == 0: - input_shapes = tf.TensorShape(None) - elif len(input_shapes) == 1: - input_shapes = input_shapes[0] - - super(Layer, self).build(tf.TensorShape(None)) # noqa: UP008 - - # If `v` was provided, replace with the module's v - replace_v = False - if v is not None: - v_orig = self.v - self._v = v - replace_v = True - - # If `buffers` were provided, replace with the module's buffers - replace_buffers = False - if buffers is not None: - buffers_orig = self.buffers - self._buffers = buffers - replace_buffers = True - - if replace_v or replace_buffers: - # Call the forward pass - ret = super(Layer, self).__call__(*args, **kwargs) # noqa: UP008 - # Replace v, buffers if needed - self._v = v_orig if replace_v else self._v - self._buffers = buffers_orig if replace_buffers else self._buffers - return ret - elif hasattr(self.__call__, "wrapped"): - return self.__call__(*args, **kwargs) - - # Get the signature of the call method - call_signature = inspect.signature(self.call) - - # Convert all positional arguments to keyword arguments based on the signature - new_kwargs = {} - for idx, (param_name, param) in enumerate(call_signature.parameters.items()): - if idx < len(args): - new_kwargs[param_name] = args[idx] - - # Merge the existing kwargs - new_kwargs.update(kwargs) - return super(Layer, self).__call__(**new_kwargs) # noqa: UP008 - - @tf.autograph.experimental.do_not_convert - def build( - self, - *args, - from_call=False, - device=None, - dtype=None, - dynamic_backend=None, - **kwargs, - ): - self._built = True - return - - def _lock_state(self): - pass - - @tf.autograph.experimental.do_not_convert - def register_buffer(self, name: str, value: Union[tf.Tensor, tf.Variable]): - self._buffers.update({name: value}) - return value - - @tf.autograph.experimental.do_not_convert - def register_parameter(self, name: str, value: Union[tf.Tensor, tf.Variable]): - self._v.update({name: value}) - - @tf.autograph.experimental.do_not_convert - def train(self, mode: bool = True): - self._training = 
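# --- aside -------------------------------------------------------------------
# The tail of Layer._call above rebinds positional arguments to the parameter
# names of `call` via inspect.signature before invoking keras' __call__.  The
# same step in isolation (illustrative function and values):
import inspect

def call_sketch(inputs, training=None, mask=None):
    return inputs, training, mask

args, kwargs = ("x", True), {"mask": None}
new_kwargs = {
    name: args[i]
    for i, name in enumerate(inspect.signature(call_sketch).parameters)
    if i < len(args)
}
new_kwargs.update(kwargs)
# new_kwargs == {"inputs": "x", "training": True, "mask": None}
# --- end aside -----------------------------------------------------------------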
mode - for module in self.children(): - if isinstance(module, tf.keras.layers.Layer) and not hasattr( - module, "train" - ): - module.trainable = mode - continue - module.train(mode) - self.trainable = mode - return self - - @tf.autograph.experimental.do_not_convert - def eval(self): - return self.train(mode=False) - - @tf.autograph.experimental.do_not_convert - def call(self, inputs, training=None, mask=None): - raise NotImplementedError( - "When subclassing the `Module` class, you should implement a `call` method." - ) - - def get_build_config(self): - config = super().get_build_config() - config = recursive_serialize(config) - return config - - def build_from_config(self, config): - config = recursive_deserialize(config) - return super().build_from_config(config) - - def get_config(self): - base_config = super().get_config() - config = {} - - # Get the names and values of positional arguments in __init__ - init_signature = inspect.signature(self.__init__) - arg_names = list(init_signature.parameters.keys()) - - # Include the positional arguments in the config - var_positional_arg_encountered = False - var_positional_arg_name = None - offset = 0 - for i, arg in enumerate(self._args): - arg_name = arg_names[min(i, len(arg_names) - 1)] - if var_positional_arg_encountered: - config.update( - { - f"{var_positional_arg_name}_{i - offset}": arg, - } - ) - elif ( - init_signature.parameters[arg_name].kind - == inspect.Parameter.VAR_POSITIONAL - ): - var_positional_arg_encountered = True - var_positional_arg_name = arg_name - offset = i - config.update( - { - f"{var_positional_arg_name}_{0}": arg, - } - ) - else: - config.update( - { - arg_name: arg, - } - ) - - # Include the keywords arguments in the config - kwargs = self._kwargs.copy() - kwargs.pop("devices", None) - config.update(**kwargs) - new_config = {**base_config, **config} - new_config = recursive_serialize(new_config) - return new_config - - @classmethod - def from_config(cls, config): - config = recursive_deserialize(config) - # Get the signature of the __init__ method - init_signature = inspect.signature(cls.__init__) - arg_names = list(init_signature.parameters.keys()) - - # Separate positional and keyword arguments based on the __init__ signature - args = [] - pos_or_kw = OrderedDict() - kwargs = {} - var_positional_args = [] - for arg_name in arg_names: - if ( - arg_name in config - and init_signature.parameters[arg_name].kind - == inspect.Parameter.KEYWORD_ONLY - ): - # Handle keyword arguments - kwargs[arg_name] = config.pop(arg_name) - elif ( - arg_name in config - and init_signature.parameters[arg_name].kind - == inspect.Parameter.POSITIONAL_OR_KEYWORD - ): - # Handle positional or keyword arguments - pos_or_kw[arg_name] = config.pop(arg_name) - elif any(re.match(rf"{arg_name}_\d+", key) for key in config.keys()): - # Handle variable positional arguments - var_positional_args.extend( - [ - config.pop(key) - for key in sorted(config.keys()) - if re.match(rf"{arg_name}_\d+", key) - ] - ) - - # Unpack positional arguments and the rest as keyword arguments - config.pop("name", None) - config.pop("trainable", None) - config.pop("dtype", None) - kwargs.update(config) - - # Determine the final args and kwargs - if var_positional_args: - args = list(pos_or_kw.values()) + var_positional_args - else: - kwargs.update(pos_or_kw) - - return cls(*args, **kwargs) - - # Methods to be Optionally Overridden # - # -----------------------------------# - - @tf.autograph.experimental.do_not_convert - def _create_variables(self, *, device=None, 
dtype=None): - return {} - - @tf.autograph.experimental.do_not_convert - def _build(self, *args, **kwargs) -> bool: - return True - - @tf.autograph.experimental.do_not_convert - def _forward(self, *args, **kwargs): - raise NotImplementedError( - "When subclassing the `Module` class, you should " - "implement a `_forward` method." - ) - - @tf.autograph.experimental.do_not_convert - def _extra_repr(self) -> str: - return "" - - # Properties # - # -----------# - - @property - def device(self): - return self._device - - @property - def dtype(self): - return self._dtype - - @property - def build_mode(self): - return self._build_mode - - @property - def training(self): - return self._training - - @property - def v(self): - return self._v - - @property - def buffers(self): - return self._buffers - - @property - def state_dict(self): - return {**self.v, **self.buffers} - - @property - def module_dict(self): - return self._module_dict - - @property - def layers(self): - return self._layers - - # Dunder Methods # - # ---------------# - @store_frame_info - @tf.autograph.experimental.do_not_convert - def __call__( - self, - *args, - v=None, - buffers=None, - **kwargs, - ): - # TODO: Temp workaround to avoid `call`` from being transformed by AutoGraph - if not hasattr(self.__class__.call, "autograph_info__"): - setattr(self.__class__.call, "autograph_info__", True) - ret = self._call(*args, v=v, buffers=buffers, **kwargs) - return ret - - @tf.autograph.experimental.do_not_convert - def __getattr__(self, name): - if name == "v": - if not super().__getattribute__("_v") and not getattr( # noqa: E501 - self, "_built", False - ): - return self._build_and_return_v( - *self._args, dynamic_backend=self._dynamic_backend, **self._kwargs - ) - - _dict = super().__getattribute__("__dict__") - if name in _dict: - return _dict[name] - - elif "_v" in _dict and name in _dict["_v"]: - return _dict["_v"][name] - - return super().__getattribute__(name) - - @tf.autograph.experimental.do_not_convert - def __setattr__(self, name, value): - if name in ["v", "buffers"]: - name = "_" + name - if isinstance(value, (Layer, tf.keras.layers.Layer)): - _dict = getattr(self, "__dict__", None) - if _dict: - _dict[name] = value - - # compute the module dict - self._compute_module_dict() - - obj_to_search = ( - None - if not isinstance(value, (Layer, tf.keras.layers.Layer)) - else ( - self._modules - if hasattr(self, "_modules") and self._modules - else self - ) - ) - found_vars = self._find_variables( - obj=obj_to_search, - without_initialisation=( - True - if self._v_from_constructor and not self._with_partial_v - else False - ), - ) - flattened_v, v_spec = tree_flatten(found_vars) - flattend_kc = v_spec.get_keychains() - for kc, v in zip(flattend_kc, flattened_v): - new_kc = kc.replace("/", ".") - if new_kc not in self.v: - self.register_parameter(new_kc, v) - - # once all variables built, find and assign buffers - found_buffers = self._find_variables( - obj=obj_to_search, - without_initialisation=( - True - if self._v_from_constructor and not self._with_partial_v - else False - ), - trainable=False, - ) - flattened_buf, buf_spec = tree_flatten(found_buffers) - flattend_kc = buf_spec.get_keychains() - for kc, buf in zip(flattend_kc, flattened_buf): - new_kc = kc.replace("/", ".") - self.register_buffer(new_kc, buf) - - super().__setattr__(name, value) - return - elif isinstance(value, tf.Variable) and not name.startswith("_"): - _dict = getattr(self, "__dict__", None) - if _dict: - _dict[name] = value - # Manual solution for cases 
where a `tf.int32` tensor - # is placed on the GPU. TensorFlow doesn't have dedicated - # kernels for placing `tf.int32` variables on the GPU and so - # we manually cast them to `tf.int64` here otherwise due to - # `tf.config.soft_device_placement(True)` by default, - # TensorFlow puts the `tf.int32` variables on CPU which causes - # unintended consequences downstream during tracing or - # `tf.function` compilation e.g. - # Ref: https://github.com/tensorflow/tensorflow/issues/9506 - # Ref: https://stackoverflow.com/questions/44813939/could-not-satisfy-explicit-device-specification-devicegpu0-because-no-devic - dtype = ( - tf.int64 - if value.dtype == tf.int32 and "gpu:" in value.device.lower() - else value.dtype - ) - cast_dtype = dtype != value.dtype - val = ( - value - if not cast_dtype - else tf.Variable(initial_value=tf.cast(value.value(), dtype), name=name) - ) - self.register_parameter(name, val) - super().__setattr__(name, val) - return - else: - try: - obj_to_search = getattr(self, name) - except AttributeError: - obj_to_search = None - if isinstance(obj_to_search, Layer): - # retrieve all hierarchical submodules - assign_dict, kc = get_assignment_dict() - - # Iterate over all submods in assign_dict - # updating their `v` and `buffers` with the - # new value - for key, submod in assign_dict.items(): - # Get the subkey to match - subkey = kc[len(key) :].lstrip(".") - - if hasattr(submod, "v"): - for v_key, v_value in submod.v.items(): - if v_key.startswith(subkey): - submod.register_parameter(v_key, value) - - # Repeat the same process for submod.buffers - if hasattr(submod, "buffers"): - for b_key, b_value in submod.buffers.items(): - if b_key.startswith(subkey): - submod.register_buffer(b_key, value) - - # finally update the module dict - self._module_dict[name] = value - - return super().__setattr__(name, value) - - @tf.autograph.experimental.do_not_convert - def __delattr__(self, name): - if hasattr(self, name): - if isinstance(getattr(self, name), (Layer, tf.keras.layers.Layer)): - super().__delattr__(name) - return - super().__delattr__(name) - - @tf.autograph.experimental.do_not_convert - def __repr__(self): - extra_lines = [] - extra_repr = self._extra_repr() - if extra_repr: - extra_lines = extra_repr.split("\n") - child_lines = [] - for key in self.v.keys(): - if isinstance(getattr(self, key, None), Layer): - mod_str = repr(getattr(self, key)) - mod_str = self._addindent(mod_str, 2) - child_lines.append(f"({key}): {mod_str}") - lines = extra_lines + child_lines - - main_str = f"{self.__class__.__name__}(" - if lines: - # simple one-liner info, which most builtin Modules will use - if len(extra_lines) == 1 and not child_lines: - main_str += extra_lines[0] - else: - main_str += "\n " + "\n ".join(lines) + "\n" - - main_str += ")" - return main_str - - -class Model(tf.keras.Model, ModelHelpers): - _build_mode = None - _with_partial_v = None - _store_vars = True - _built = False - _v = None - _buffers = None - _module_dict = None - _args = None - _kwargs = None - _module_graph = None - _target = None - _lazy_traced = False - _training = None - _dynamic_backend = None - _device = None - _dtype = None - _previous_frame_info = None - - def __init__( - self, - /, - *args, - v=None, - buffers=None, - build_mode="on_init", - store_vars=True, - with_partial_v=False, - dynamic_backend=None, - training=True, - dtype=None, - device=None, - module_dict=None, - **kwargs, - ): - super(Model, self).__init__( - trainable=training, - dtype=dtype, - ) - self._build_mode = build_mode - 
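# --- aside -------------------------------------------------------------------
# The int32-on-GPU workaround in __setattr__ above, in isolation: int32
# variables that soft device placement parked on a GPU are re-created as
# int64; everything else passes through untouched (standalone sketch):
import tensorflow as tf

value = tf.Variable(tf.constant([1, 2], dtype=tf.int32))
dtype = (
    tf.int64
    if value.dtype == tf.int32 and "gpu:" in value.device.lower()
    else value.dtype
)
if dtype != value.dtype:
    value = tf.Variable(tf.cast(value.value(), dtype))
# on a CPU-only host this is a no-op; with a GPU the variable becomes int64
# --- end aside -----------------------------------------------------------------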
self._with_partial_v = with_partial_v - self._store_vars = store_vars - self._built = False - self._v_from_constructor = v if isinstance(v, dict) or v is None else dict(v) - self._v = v if v is not None else dict() - self._buffers = dict(buffers or {}) - self._module_dict = module_dict if module_dict is not None else dict() - self._args = args - self._kwargs = kwargs - self._module_graph = None - self._target = None - self._lazy_traced = False - self._training = training - self._dynamic_backend = dynamic_backend - self._device = device or "cpu" - self._dtype = dtype or tf.float32 - if build_mode != "on_init": - return - self.build(*args, dynamic_backend=dynamic_backend, **kwargs) - - @tf.autograph.experimental.do_not_convert - def _find_variables( - self, - /, - *, - obj=None, - without_initialisation=False, - _visited=None, - trainable=True, - ): - _visited = _visited or {} - vs = dict() - if id(obj) in _visited: - return vs - _visited[id(obj)] = True - if isinstance(obj, (Layer, Model)) and obj is not self: - fn = "_build_and_return_v" if trainable else "_build_and_return_buffers" - if not obj._built and without_initialisation: - return lambda: getattr(obj, fn)( - *obj._args, dynamic_backend=self._dynamic_backend, **obj._kwargs - ) - - return getattr(obj, fn)( - *obj._args, dynamic_backend=obj._dynamic_backend, **obj._kwargs - ) - elif isinstance(obj, tf.keras.layers.Layer) and obj is not self: - return obj.v if trainable else obj.buffers - - elif isinstance(obj, (list, tuple)): - for i, v in enumerate(obj): - ret = self._find_variables( - obj=v, - without_initialisation=without_initialisation, - _visited=_visited, - trainable=trainable, - ) - if ret: - vs[f"v{str(i)}"] = ret - return vs - elif isinstance(obj, dict): - for k, v in obj.items(): - ret = self._find_variables( - obj=v, - without_initialisation=without_initialisation, - _visited=_visited, - trainable=trainable, - ) - if ret: - vs[k[1:] if k[0] == "_" else k] = ret - return vs - elif not hasattr(obj, "__dict__"): - return vs - for k, v in obj.__dict__.items(): - if ( - v is not None - and k[0:2] != "__" - and not k.startswith( - ( - "_module_dict", - "_self_", - "_args", - "_kwargs", - ) - ) - ): - ret = self._find_variables( - obj=v, - without_initialisation=without_initialisation, - _visited=_visited, - trainable=trainable, - ) - if ret: - vs[k[1:] if k[0] == "_" else k] = ret - return vs - - @tf.autograph.experimental.do_not_convert - def _find_buffers(self): - if hasattr(self, "_module_dict"): - for key, sub_module in self._module_dict.items(): - if len(sub_module._buffers) > 0: - self._buffers[key] = sub_module._buffers - - @tf.autograph.experimental.do_not_convert - def _build_and_return_v(self, *args, **kwargs): - if not self._built: - self.build(*args, **kwargs) - return self.v - - @tf.autograph.experimental.do_not_convert - def _build_and_return_buffers(self, *args, **kwargs): - if not self._built: - self.build(*args, **kwargs) - return self.buffers - - @tf.autograph.experimental.do_not_convert - def _wrap_call_methods( - self, keychain_mappings, /, *, key="", obj=None, _visited=None - ): - _visited = _visited or {} - if id(obj) in _visited or not isinstance(key, str): - return - _visited[id(obj)] = True - if isinstance(obj, (Layer, Model)) and obj is not self: - orig_key_chain = key[1:] if key[0] == "_" else key - - obj.__call__ = self._fn_with_var_arg( - obj.__call__, self._extract_v, keychain_mappings, orig_key_chain - ) - return - elif isinstance(obj, (list, tuple)): - for i, val in enumerate(obj): - 
self._wrap_call_methods( - keychain_mappings, - key=f"{key}/v{str(i)}", - obj=val, - _visited=_visited, - ) - return - elif isinstance(obj, dict): - for k, val in obj.items(): - k = f"{key}/{k}" if key != "" and isinstance(k, str) else k - self._wrap_call_methods( - keychain_mappings, key=k, obj=val, _visited=_visited - ) - return - for k, val in obj.module_dict.items(): - if k.startswith(("__", "_self_")): - continue - k = f"{key}/{k}" if key != "" else k - if val is not None: - self._wrap_call_methods( - keychain_mappings, key=k, obj=val, _visited=_visited - ) - return - - @tf.autograph.experimental.do_not_convert - def _compute_module_dict(self): - self._module_dict = dict() - for key, value in self.__dict__.items(): - if isinstance(value, (Layer, tf.keras.layers.Layer, Model, tf.keras.Model)): - if ( - "stateful" in value.__module__ - or hasattr(value, "_frontend_module") - or not hasattr(value, "_module_dict") - ): - self._module_dict[key] = value - else: - self._module_dict[key] = value._module_dict - - @tf.autograph.experimental.do_not_convert - def _fn_with_var_arg_wrapper( - self, *a, fn, v_fn, keychain_mappings, orig_key_chain, **kw - ): - if "v" in kw: - del kw["v"] - v = v_fn(self.v, keychain_mappings, orig_key_chain) - return fn(*a, **kw, v=v) - - @tf.autograph.experimental.do_not_convert - def _fn_with_var_arg(self, fn, v_fn, /, keychain_mappings, orig_key_chain): - _fn_with_var_arg_wrapper = functools.partial( - self._fn_with_var_arg_wrapper, - fn=fn, - v_fn=v_fn, - keychain_mappings=keychain_mappings, - orig_key_chain=orig_key_chain, - ) - _fn_with_var_arg_wrapper.wrapped = True - return _fn_with_var_arg_wrapper - - @tf.autograph.experimental.do_not_convert - def _call(self, *args, v=None, buffers=None, **kwargs): - if not self._built or not self.built: - if not self._built: - first_arr = self._get_first_array(*args, **kwargs) - self.build( - *args, - **kwargs, - from_call=True, - dtype=first_arr.dtype if first_arr is not None else tf.float32, - ) - - if not self.built: - # Don't use `keras` build method - if os.environ.get("USE_KERAS_BUILD", "False").lower() == "false": - self.inputs = tf.nest.flatten(args) - else: - input_shapes = self._get_input_shapes(*args) - if len(input_shapes) == 0: - input_shapes = tf.TensorShape(None) - elif len(input_shapes) == 1: - input_shapes = input_shapes[0] - - super(Model, self).build(tf.TensorShape(None)) # noqa: UP008 - - # If `v` was provided, replace with the module's v - replace_v = False - if v is not None: - v_orig = self.v - self._v = v - replace_v = True - - # If `buffers` were provided, replace with the module's buffers - replace_buffers = False - if buffers is not None: - buffers_orig = self.buffers - self._buffers = buffers - replace_buffers = True - - if replace_v or replace_buffers: - # Call the forward pass - ret = super(Model, self).__call__(*args, **kwargs) # noqa: UP008 - # Replace v, buffers if needed - self._v = v_orig if replace_v else self._v - self._buffers = buffers_orig if replace_buffers else self._buffers - return ret - elif hasattr(self.__call__, "wrapped"): - return self.__call__(*args, **kwargs) - - return super(Model, self).__call__(*args, **kwargs) # noqa: UP008 - - @tf.autograph.experimental.do_not_convert - def build( - self, - *args, - from_call=False, - device=None, - dtype=None, - dynamic_backend=None, - **kwargs, - ): - self._built = True - return - - def _lock_state(self): - pass - - @tf.autograph.experimental.do_not_convert - def register_buffer(self, name: str, value: Union[tf.Tensor, tf.Variable]): 
- self._buffers.update({name: value}) - return value - - @tf.autograph.experimental.do_not_convert - def register_parameter(self, name: str, value: Union[tf.Tensor, tf.Variable]): - self._v.update({name: value}) - - @tf.autograph.experimental.do_not_convert - def train(self, mode: bool = True): - self._training = mode - for module in self.children(): - if isinstance(module, tf.keras.layers.Layer) and not hasattr( - module, "train" - ): - module.trainable = mode - continue - module.train(mode) - self.trainable = mode - return self - - @tf.autograph.experimental.do_not_convert - def eval(self): - return self.train(mode=False) - - @tf.autograph.experimental.do_not_convert - def call(self, inputs, training=None, mask=None): - raise NotImplementedError( - "When subclassing the `Module` class, you should implement a `call` method." - ) - - def get_build_config(self): - config = super().get_build_config() - config = recursive_serialize(config) - return config - - def build_from_config(self, config): - config = recursive_deserialize(config) - return super().build_from_config(config) - - def get_config(self): - base_config = super().get_config() - config = {} - - # Get the names and values of positional arguments in __init__ - init_signature = inspect.signature(self.__init__) - arg_names = list(init_signature.parameters.keys()) - - # Include the positional arguments in the config - var_positional_arg_encountered = False - var_positional_arg_name = None - offset = 0 - for i, arg in enumerate(self._args): - arg_name = arg_names[min(i, len(arg_names) - 1)] - if var_positional_arg_encountered: - config.update( - { - f"{var_positional_arg_name}_{i - offset}": arg, - } - ) - elif ( - init_signature.parameters[arg_name].kind - == inspect.Parameter.VAR_POSITIONAL - ): - var_positional_arg_encountered = True - var_positional_arg_name = arg_name - offset = i - config.update( - { - f"{var_positional_arg_name}_{0}": arg, - } - ) - else: - config.update( - { - arg_name: arg, - } - ) - - # Include the keywords arguments in the config - kwargs = self._kwargs.copy() - kwargs.pop("devices", None) - config.update(**kwargs) - new_config = {**base_config, **config} - new_config = recursive_serialize(new_config) - return new_config - - @classmethod - def from_config(cls, config): - config = recursive_deserialize(config) - # Get the signature of the __init__ method - init_signature = inspect.signature(cls.__init__) - arg_names = list(init_signature.parameters.keys()) - - # Separate positional and keyword arguments based on the __init__ signature - args = [] - pos_or_kw = OrderedDict() - kwargs = {} - var_positional_args = [] - for arg_name in arg_names: - if ( - arg_name in config - and init_signature.parameters[arg_name].kind - == inspect.Parameter.KEYWORD_ONLY - ): - # Handle keyword arguments - kwargs[arg_name] = config.pop(arg_name) - elif ( - arg_name in config - and init_signature.parameters[arg_name].kind - == inspect.Parameter.POSITIONAL_OR_KEYWORD - ): - # Handle positional or keyword arguments - pos_or_kw[arg_name] = config.pop(arg_name) - elif any(re.match(rf"{arg_name}_\d+", key) for key in config.keys()): - # Handle variable positional arguments - var_positional_args.extend( - [ - config.pop(key) - for key in sorted(config.keys()) - if re.match(rf"{arg_name}_\d+", key) - ] - ) - - # Unpack positional arguments and the rest as keyword arguments - config.pop("name", None) - config.pop("trainable", None) - config.pop("dtype", None) - kwargs.update(config) - - # Determine the final args and kwargs - if 
var_positional_args: - args = list(pos_or_kw.values()) + var_positional_args - else: - kwargs.update(pos_or_kw) - - return cls(*args, **kwargs) - - # Methods to be Optionally Overridden # - # -----------------------------------# - - @tf.autograph.experimental.do_not_convert - def _create_variables(self, *, device=None, dtype=None): - return {} - - @tf.autograph.experimental.do_not_convert - def _build(self, *args, **kwargs) -> bool: - return True - - @tf.autograph.experimental.do_not_convert - def _forward(self, *args, **kwargs): - raise NotImplementedError( - "When subclassing the `Module` class, you should " - "implement a `_forward` method." - ) - - @tf.autograph.experimental.do_not_convert - def _extra_repr(self) -> str: - return "" - - # Properties # - # -----------# - - @property - def device(self): - return self._device - - @property - def dtype(self): - return self._dtype - - @property - def build_mode(self): - return self._build_mode - - @property - def training(self): - return self._training - - @property - def v(self): - return self._v - - @property - def buffers(self): - return self._buffers - - @property - def state_dict(self): - return {**self.v, **self.buffers} - - @property - def module_dict(self): - return self._module_dict - - # Dunder Methods # - # ---------------# - @store_frame_info - @tf.autograph.experimental.do_not_convert - def __call__( - self, - *args, - v=None, - buffers=None, - **kwargs, - ): - # TODO: Temp workaround to avoid `call`` from being transformed by AutoGraph - if not hasattr(self.__class__.call, "autograph_info__"): - setattr(self.__class__.call, "autograph_info__", True) - ret = self._call(*args, v=v, buffers=buffers, **kwargs) - return ret - - @tf.autograph.experimental.do_not_convert - def __getattr__(self, name): - if name == "v": - if not super().__getattribute__("_v") and not getattr( # noqa: E501 - self, "_built", False - ): - return self._build_and_return_v( - *self._args, dynamic_backend=self._dynamic_backend, **self._kwargs - ) - - _dict = super().__getattribute__("__dict__") - if name in _dict: - return _dict[name] - - elif "_v" in _dict and name in _dict["_v"]: - return _dict["_v"][name] - - return super().__getattribute__(name) - - @tf.autograph.experimental.do_not_convert - def __setattr__(self, name, value): - if name in ["v", "buffers"]: - name = "_" + name - if isinstance(value, (Layer, tf.keras.layers.Layer, Model, tf.keras.Model)): - _dict = getattr(self, "__dict__", None) - if _dict: - _dict[name] = value - - # compute the module dict - self._compute_module_dict() - - obj_to_search = ( - None - if not isinstance(value, (tf.keras.layers.Layer, Layer, Model)) - else ( - self._modules - if hasattr(self, "_modules") and self._modules - else self - ) - ) - found_vars = self._find_variables( - obj=obj_to_search, - without_initialisation=( - True - if self._v_from_constructor and not self._with_partial_v - else False - ), - ) - flattened_v, v_spec = tree_flatten(found_vars) - flattend_kc = v_spec.get_keychains() - for kc, v in zip(flattend_kc, flattened_v): - new_kc = kc.replace("/", ".") - if new_kc not in self.v: - self.register_parameter(new_kc, v) - - # once all variables built, find and assign buffers - found_buffers = self._find_variables( - obj=obj_to_search, - without_initialisation=( - True - if self._v_from_constructor and not self._with_partial_v - else False - ), - trainable=False, - ) - flattened_buf, buf_spec = tree_flatten(found_buffers) - flattend_kc = buf_spec.get_keychains() - for kc, buf in zip(flattend_kc, 
flattened_buf): - new_kc = kc.replace("/", ".") - self.register_buffer(new_kc, buf) - - super().__setattr__(name, value) - return - elif isinstance(value, tf.Variable) and not name.startswith("_"): - _dict = getattr(self, "__dict__", None) - if _dict: - _dict[name] = value - - # Manual solution for cases where a `tf.int32` tensor - # is placed on the GPU. TensorFlow doesn't have dedicated - # kernels for placing `tf.int32` variables on the GPU and so - # we manually cast them to `tf.int64` here otherwise due to - # `tf.config.soft_device_placement(True)` by default, - # TensorFlow puts the `tf.int32` variables on CPU which causes - # unintended consequences downstream during tracing or - # `tf.function` compilation e.g. - # Ref: https://github.com/tensorflow/tensorflow/issues/9506 - # Ref: https://stackoverflow.com/questions/44813939/could-not-satisfy-explicit-device-specification-devicegpu0-because-no-devic - dtype = ( - tf.int64 - if value.dtype == tf.int32 and "gpu:" in value.device.lower() - else value.dtype - ) - cast_dtype = dtype != value.dtype - val = ( - value - if not cast_dtype - else tf.Variable(initial_value=tf.cast(value.value(), dtype), name=name) - ) - self.register_parameter(name, val) - super().__setattr__(name, val) - else: - try: - obj_to_search = getattr(self, name) - except AttributeError: - obj_to_search = None - if isinstance(obj_to_search, (Model, Layer)): - # retrieve all hierarchical submodules - assign_dict, kc = get_assignment_dict() - - # Iterate over all submods in assign_dict - # updating their `v` and `buffers` with the - # new value - for key, submod in assign_dict.items(): - # Get the subkey to match - subkey = kc[len(key) :].lstrip(".") - - if hasattr(submod, "v"): - for v_key, v_value in submod.v.items(): - if v_key.startswith(subkey): - submod.register_parameter(v_key, value) - - # Repeat the same process for submod.buffers - if hasattr(submod, "buffers"): - for b_key, b_value in submod.buffers.items(): - if b_key.startswith(subkey): - submod.register_buffer(b_key, value) - - # finally update the module dict - self._module_dict[name] = value - - return super().__setattr__(name, value) - - @tf.autograph.experimental.do_not_convert - def __delattr__(self, name): - if hasattr(self, name): - if isinstance( - getattr(self, name), - (Layer, tf.keras.layers.Layer, Model, tf.keras.Model), - ): - super().__delattr__(name) - return - super().__delattr__(name) - - @tf.autograph.experimental.do_not_convert - def __repr__(self): - extra_lines = [] - extra_repr = self._extra_repr() - if extra_repr: - extra_lines = extra_repr.split("\n") - child_lines = [] - for key in self.v.keys(): - if isinstance(getattr(self, key, None), (Layer, Model)): - mod_str = repr(getattr(self, key)) - mod_str = self._addindent(mod_str, 2) - child_lines.append(f"({key}): {mod_str}") - lines = extra_lines + child_lines - - main_str = f"{self.__class__.__name__}(" - if lines: - # simple one-liner info, which most builtin Modules will use - if len(extra_lines) == 1 and not child_lines: - main_str += extra_lines[0] - else: - main_str += "\n " + "\n ".join(lines) + "\n" - - main_str += ")" - return main_str diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_random_uniform_output/run_0/tensorflow_random_uniform.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_random_uniform_output/run_0/tensorflow_random_uniform.py deleted file mode 100644 index 733e0643f54f..000000000000 --- 
a/ivy/compiler/_cache/Translated_Outputs/tensorflow_random_uniform_output/run_0/tensorflow_random_uniform.py +++ /dev/null @@ -1,41 +0,0 @@ -import tensorflow -import tensorflow as tf - -from typing import Union -from typing import Sequence -from typing import Optional - -from .tensorflow__helpers import tensorflow__check_bounds_and_get_shape_bknd -from .tensorflow__helpers import tensorflow_infer_dtype - - -@tensorflow_infer_dtype -def tensorflow_random_uniform( - *, - low: Union[float, tensorflow.Tensor, tensorflow.Variable] = 0.0, - high: Union[float, tensorflow.Tensor, tensorflow.Variable, None] = 1.0, - shape: Optional[Union[tf.TensorShape, Sequence[int], tensorflow.Tensor]] = None, - dtype: tf.DType, - device: Optional[str] = None, - seed: Optional[int] = None, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - shape = tensorflow__check_bounds_and_get_shape_bknd( - low, - ( - float( - tensorflow.experimental.numpy.finfo(tensorflow.float32).max - if dtype is None - else tensorflow.experimental.numpy.finfo(dtype).max - ) - if high is None - else high - ), - shape, - ) - low = tensorflow.cast(low, dtype) - if high is not None: - high = tensorflow.cast(high, dtype) - if seed: - tensorflow.random.set_seed(seed) - return tensorflow.random.uniform(shape, low, high, dtype=dtype, seed=seed) diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_relu_output/run_0/tensorflow_NestedSequence_bknd.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_relu_output/run_0/tensorflow_NestedSequence_bknd.py index 9f87b4ae29ef..ac8335fe1e56 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_relu_output/run_0/tensorflow_NestedSequence_bknd.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_relu_output/run_0/tensorflow_NestedSequence_bknd.py @@ -1,5 +1,5 @@ -from typing import Protocol from typing import TypeVar +from typing import Protocol _T_co = TypeVar("_T_co", covariant=True) diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_relu_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_relu_output/run_0/tensorflow__helpers.py index d50687f412e5..3aaa2aa83879 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_relu_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_relu_output/run_0/tensorflow__helpers.py @@ -23,6 +23,99 @@ import tensorflow as tf +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, 
parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +240,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +282,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, 
parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_is_native_array(x, /, *, exclusive=False): @@ -334,7 +342,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -447,6 +457,7 @@ def tensorflow_is_complex_dtype_bknd( return "complex" in tensorflow_as_ivy_dtype_bknd(dtype_in) +@tensorflow_handle_array_like_without_promotion def tensorflow_real( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -460,6 +471,7 @@ def tensorflow_real_bknd_(self): return tensorflow_real(self) +@tensorflow_handle_array_like_without_promotion def tensorflow_imag( val: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -485,6 +497,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -611,6 +626,9 @@ def nested_fun(x): return "int" in tensorflow_as_ivy_dtype(dtype_in) +default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, @@ -655,6 +673,9 @@ def tensorflow_default_dtype_bknd( return tensorflow_as_ivy_dtype(ret) +default_float_dtype_stack = [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -709,6 +730,25 @@ def tensorflow_default_float_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -745,6 +785,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -767,21 +811,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 
9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -819,6 +859,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -870,20 +929,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -950,26 +995,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -1087,6 +1114,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -1193,27 +1223,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if 
tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -1382,6 +1406,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1638,7 +1665,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2050,7 +2079,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -2174,6 +2205,9 @@ def tensorflow_to_scalar_bknd_(self: tensorflow.Tensor): return tensorflow_to_scalar(self) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2198,11 +2232,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2410,7 +2442,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2570,11 +2604,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2614,21 +2646,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_reshape_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_reshape_output/run_0/tensorflow__helpers.py index d50687f412e5..3aaa2aa83879 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_reshape_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_reshape_output/run_0/tensorflow__helpers.py @@ -23,6 +23,99 @@ import tensorflow as tf +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, 
**kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +240,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +282,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, 
- "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_is_native_array(x, /, *, exclusive=False): @@ -334,7 +342,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -447,6 +457,7 @@ def tensorflow_is_complex_dtype_bknd( return "complex" in tensorflow_as_ivy_dtype_bknd(dtype_in) +@tensorflow_handle_array_like_without_promotion def tensorflow_real( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -460,6 +471,7 @@ def tensorflow_real_bknd_(self): return tensorflow_real(self) +@tensorflow_handle_array_like_without_promotion def tensorflow_imag( val: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -485,6 +497,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -611,6 +626,9 @@ def nested_fun(x): return "int" in tensorflow_as_ivy_dtype(dtype_in) +default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, @@ -655,6 +673,9 @@ def tensorflow_default_dtype_bknd( return tensorflow_as_ivy_dtype(ret) +default_float_dtype_stack = [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -709,6 +730,25 @@ def tensorflow_default_float_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", 
+ tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -745,6 +785,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -767,21 +811,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -819,6 +859,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -870,20 +929,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -950,26 +995,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -1087,6 +1114,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -1193,27 +1223,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( 
_tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -1382,6 +1406,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1638,7 +1665,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2050,7 +2079,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -2174,6 +2205,9 @@ def tensorflow_to_scalar_bknd_(self: tensorflow.Tensor): return tensorflow_to_scalar(self) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2198,11 +2232,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2410,7 +2442,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2570,11 +2604,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2614,21 +2646,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, 
tf.Tensor], diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_reshape_output/run_0/tensorflow_reshape.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_reshape_output/run_0/tensorflow_reshape.py index d65dd7e1ec26..ebcbf0d55c36 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_reshape_output/run_0/tensorflow_reshape.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_reshape_output/run_0/tensorflow_reshape.py @@ -1,8 +1,8 @@ import tensorflow import tensorflow as tf -from typing import Sequence from typing import Optional +from typing import Sequence from typing import Union from .tensorflow__helpers import tensorflow__reshape_fortran_tf diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_roll_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_roll_output/run_0/tensorflow__helpers.py index d50687f412e5..3aaa2aa83879 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_roll_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_roll_output/run_0/tensorflow__helpers.py @@ -23,6 +23,99 @@ import tensorflow as tf +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = 
fn(inp, query, **kwargs) + return res + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +240,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +282,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_is_native_array(x, /, *, exclusive=False): @@ -334,7 +342,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -447,6 +457,7 @@ def tensorflow_is_complex_dtype_bknd( 
return "complex" in tensorflow_as_ivy_dtype_bknd(dtype_in) +@tensorflow_handle_array_like_without_promotion def tensorflow_real( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -460,6 +471,7 @@ def tensorflow_real_bknd_(self): return tensorflow_real(self) +@tensorflow_handle_array_like_without_promotion def tensorflow_imag( val: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -485,6 +497,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -611,6 +626,9 @@ def nested_fun(x): return "int" in tensorflow_as_ivy_dtype(dtype_in) +default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, @@ -655,6 +673,9 @@ def tensorflow_default_dtype_bknd( return tensorflow_as_ivy_dtype(ret) +default_float_dtype_stack = [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -709,6 +730,25 @@ def tensorflow_default_float_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -745,6 +785,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -767,21 +811,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -819,6 +859,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -870,20 +929,6 @@ def 
tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -950,26 +995,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -1087,6 +1114,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -1193,27 +1223,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -1382,6 +1406,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1638,7 +1665,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2050,7 +2079,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -2174,6 +2205,9 @@ def tensorflow_to_scalar_bknd_(self: tensorflow.Tensor): return tensorflow_to_scalar(self) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2198,11 +2232,9 @@ def 
is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2410,7 +2442,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2570,11 +2604,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2614,21 +2646,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_roll_output/run_0/tensorflow_roll.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_roll_output/run_0/tensorflow_roll.py index e18e6d210b48..6c7ec58634c2 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_roll_output/run_0/tensorflow_roll.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_roll_output/run_0/tensorflow_roll.py @@ -1,8 +1,8 @@ import tensorflow +from typing import Union from typing import Optional from typing import Sequence -from typing import Union from .tensorflow__helpers import tensorflow_handle_array_like_without_promotion diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_scatter_nd_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_scatter_nd_output/run_0/tensorflow__helpers.py index 9605f2ec5525..ad6f03385ba3 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_scatter_nd_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_scatter_nd_output/run_0/tensorflow__helpers.py @@ -23,6 +23,99 @@ import tensorflow as tf +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if 
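`tensorflow_handle_set_item` follows the same hoisting pattern: attempt the object's own `__setitem__`, re-raise `IndexError`, and only fall back to the functional helper for inputs that reject in-place assignment. A toy reproduction (the tuple fallback is hypothetical):

    import functools

    def handle_set_item(fn):
        # Mirrors tensorflow_handle_set_item: mutate in place when possible,
        # otherwise delegate to the functional helper.
        @functools.wraps(fn)
        def wrapper(inp, query, val, **kwargs):
            try:
                inp.__setitem__(query, val)
                return inp
            except IndexError:
                raise
            except Exception:
                return fn(inp, query, val, **kwargs)
        return wrapper

    @handle_set_item
    def set_item(inp, query, val):
        # Toy fallback for immutable sequences such as tuples.
        return inp[:query] + (val,) + inp[query + 1:]

    assert set_item([0, 0], 1, 9) == [0, 9]   # in-place path
    assert set_item((0, 0), 1, 9) == (0, 9)   # fallback path (tuples are immutable)
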
not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +240,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +282,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = 
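The annotation test inside `tensorflow_handle_array_like_without_promotion` is purely string-based: an argument is converted only when its annotation mentions an array/Tensor type, the parameter is not named `out`, and the annotation is not a container or scalar union. A standalone check of that heuristic (the function `f` is a hypothetical signature for demonstration):

    import inspect

    def needs_conversion(annotation_str, parameter):
        # Mirrors the string test above: looks for "rray"/"Tensor" in the
        # annotation, skips `out`, and skips container/scalar unions.
        return (
            ("rray" in annotation_str or "Tensor" in annotation_str)
            and parameter != "out"
            and all(
                sq not in annotation_str
                for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"]
            )
        )

    def f(x: "tensorflow.Tensor", axis: int = 0, out: "tensorflow.Tensor" = None):
        pass

    params = inspect.signature(f).parameters
    flags = {n: needs_conversion(str(p.annotation), n) for n, p in params.items()}
    assert flags == {"x": True, "axis": False, "out": False}
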
[param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_exists_bknd(x: Any, /): @@ -310,7 +318,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -475,20 +485,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -555,26 +551,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -692,6 +670,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -798,27 +779,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: 
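`tensorflow_default_bknd`, reflowed in the hunk above, reduces to a guarded ternary in which `default_callable` makes the default lazy. A sketch of just the branch shown in this hunk:

    def default(x, default_val, default_callable=False):
        # Keep x when it exists; otherwise either call default_val (lazy
        # default) or return it as-is.
        exists = x is not None
        return x if exists else (default_val() if default_callable else default_val)

    assert default("float64", "float32") == "float64"
    assert default(None, "float32") == "float32"
    assert default(None, lambda: "float32", default_callable=True) == "float32"
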
type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -965,6 +940,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1221,7 +1199,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -1633,7 +1613,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -1801,6 +1783,9 @@ def tensorflow_is_uint_dtype_bknd( return "uint" in tensorflow_as_ivy_dtype_bknd(dtype_in) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -1825,11 +1810,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2063,7 +2046,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2223,11 +2208,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2267,21 +2250,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], @@ -2362,6 +2330,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2418,6 +2389,9 @@ def tensorflow_default_complex_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, @@ -2462,6 +2436,9 @@ def tensorflow_default_dtype_bknd( return tensorflow_as_ivy_dtype(ret) +default_float_dtype_stack = [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, 
tf.Tensor]] = None, @@ -2539,6 +2516,10 @@ def tensorflow_as_native_dtype( ) +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2561,21 +2542,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -2613,6 +2590,42 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_scatter_nd_output/run_0/tensorflow_scatter_nd.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_scatter_nd_output/run_0/tensorflow_scatter_nd.py index 402742e3fe9c..09d5a9e844af 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_scatter_nd_output/run_0/tensorflow_scatter_nd.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_scatter_nd_output/run_0/tensorflow_scatter_nd.py @@ -29,11 +29,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_set_item_bknd_output/run_0/tensorflow_NestedSequence_bknd.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_set_item_bknd_output/run_0/tensorflow_NestedSequence_bknd.py index ac8335fe1e56..9f87b4ae29ef 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_set_item_bknd_output/run_0/tensorflow_NestedSequence_bknd.py +++ 
b/ivy/compiler/_cache/Translated_Outputs/tensorflow_set_item_bknd_output/run_0/tensorflow_NestedSequence_bknd.py @@ -1,5 +1,5 @@ -from typing import TypeVar from typing import Protocol +from typing import TypeVar _T_co = TypeVar("_T_co", covariant=True) diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_set_item_bknd_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_set_item_bknd_output/run_0/tensorflow__helpers.py index 283529150f44..41a7f46379d9 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_set_item_bknd_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_set_item_bknd_output/run_0/tensorflow__helpers.py @@ -23,6 +23,99 @@ import tensorflow as tf +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +240,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +282,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): 
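`promotion_table` (and the Array-API subset below it) is keyed on pairs of dtype names. The excerpts in this diff suggest both orderings of each mixed pair are stored; the two-way lookup below is a hypothetical convenience that covers either case:

    # Excerpt only; the full table above covers all supported pairs.
    promotion_table = {
        ("int8", "int8"): "int8",
        ("int8", "int32"): "int32",
        ("float32", "float64"): "float64",
    }

    def promote(t1, t2):
        # Try the pair as given, then reversed.
        return promotion_table.get((t1, t2)) or promotion_table[(t2, t1)]

    assert promote("int32", "int8") == "int32"
    assert promote("float32", "float64") == "float64"
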
"float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_is_native_array(x, /, *, exclusive=False): @@ -334,7 +342,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -447,6 +457,7 @@ def tensorflow_is_complex_dtype_bknd( return "complex" in tensorflow_as_ivy_dtype_bknd(dtype_in) +@tensorflow_handle_array_like_without_promotion def tensorflow_real( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -460,6 +471,7 @@ def tensorflow_real_bknd_(self): return tensorflow_real(self) +@tensorflow_handle_array_like_without_promotion def tensorflow_imag( val: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -485,6 
+497,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -611,6 +626,9 @@ def nested_fun(x): return "int" in tensorflow_as_ivy_dtype(dtype_in) +default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, @@ -655,6 +673,9 @@ def tensorflow_default_dtype_bknd( return tensorflow_as_ivy_dtype(ret) +default_float_dtype_stack = [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -709,6 +730,25 @@ def tensorflow_default_float_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -745,6 +785,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -767,21 +811,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -819,6 +859,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -870,20 +929,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, 
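`ivy_dtype_dict`, added above, maps native TF dtypes to ivy's string names, and `native_dtype_dict` (added further down) is its inverse, so a round trip through the two should be the identity:

    import tensorflow as tf

    # Excerpt only; the full tables above cover all fifteen dtypes.
    ivy_excerpt = {tf.float32: "float32", tf.int64: "int64", tf.bool: "bool"}
    native_excerpt = {name: dt for dt, name in ivy_excerpt.items()}

    for native, name in ivy_excerpt.items():
        assert native_excerpt[name] is native
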
tensorflow.Variable], @@ -950,26 +995,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -1087,6 +1114,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -1169,27 +1199,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -1358,6 +1382,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1614,7 +1641,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2026,7 +2055,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -2150,6 +2181,9 @@ def tensorflow_to_scalar_bknd_(self: tensorflow.Tensor): return tensorflow_to_scalar(self) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2174,11 +2208,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2386,7 +2418,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - 
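`tensorflow_split` keeps `@tensorflow_handle_methods`, whose job is dispatch: call the functional implementation for array inputs, otherwise defer to the first argument's own method after stripping the `_bknd`/`_frnt` suffixes from the wrapped function's name. A testable sketch (the `is_array` injection is an assumption made for demonstration):

    import functools
    import re

    def handle_methods(is_array):
        def deco(fn):
            @functools.wraps(fn)
            def wrapper(obj, *args, **kwargs):
                if is_array(obj):
                    return fn(obj, *args, **kwargs)
                # Strip translation suffixes, then take the trailing segment.
                name = re.sub("_bknd_|_bknd|_frnt_|_frnt", "", fn.__name__)
                name = re.search(r"_(.+?)(?:_\d+)?$", name).group(1)
                return getattr(obj, name)(*args, **kwargs)
            return wrapper
        return deco

    @handle_methods(is_array=lambda o: isinstance(o, list))
    def my_split(x, sep=None):
        return x  # toy "array" branch

    assert my_split([1, 2]) == [1, 2]           # array path
    assert my_split("a,b", ",") == ["a", "b"]   # dispatches to str.split
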
else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2550,11 +2584,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2594,21 +2626,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_set_item_bknd_output/run_0/tensorflow_set_item_bknd.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_set_item_bknd_output/run_0/tensorflow_set_item_bknd.py index cba8bbc72209..688205fd1e70 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_set_item_bknd_output/run_0/tensorflow_set_item_bknd.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_set_item_bknd_output/run_0/tensorflow_set_item_bknd.py @@ -2,9 +2,9 @@ import tensorflow as tf import numpy as np +from typing import Optional from typing import Union from typing import Tuple -from typing import Optional from .tensorflow__helpers import tensorflow__parse_query_bknd from .tensorflow__helpers import tensorflow_asarray diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_sigmoid_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_sigmoid_output/run_0/tensorflow__helpers.py index d50687f412e5..3aaa2aa83879 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_sigmoid_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_sigmoid_output/run_0/tensorflow__helpers.py @@ -23,6 +23,99 @@ import tensorflow as tf +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, 
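The reflowed ternaries in `tensorflow_where` and `tensorflow_multiply` implement a three-step dtype fallback: prefer `x1.dtype`, then `x2.dtype`, then the configured default. The chain in isolation:

    class Arr:
        # Stand-in object exposing only a dtype attribute.
        def __init__(self, dtype):
            self.dtype = dtype

    def infer_binary_dtype(x1, x2, default):
        return (
            x1.dtype
            if hasattr(x1, "dtype")
            else x2.dtype
            if hasattr(x2, "dtype")
            else default
        )

    assert infer_binary_dtype(3, Arr("float32"), "float64") == "float32"
    assert infer_binary_dtype(3, 4, "float64") == "float64"
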
tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +240,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +282,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in 
annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_is_native_array(x, /, *, exclusive=False): @@ -334,7 +342,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -447,6 +457,7 @@ def tensorflow_is_complex_dtype_bknd( return "complex" in tensorflow_as_ivy_dtype_bknd(dtype_in) +@tensorflow_handle_array_like_without_promotion def tensorflow_real( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -460,6 +471,7 @@ def tensorflow_real_bknd_(self): return tensorflow_real(self) +@tensorflow_handle_array_like_without_promotion def tensorflow_imag( val: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -485,6 +497,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -611,6 +626,9 @@ def nested_fun(x): return "int" in tensorflow_as_ivy_dtype(dtype_in) +default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, @@ -655,6 +673,9 @@ def tensorflow_default_dtype_bknd( return tensorflow_as_ivy_dtype(ret) +default_float_dtype_stack = [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -709,6 +730,25 @@ def tensorflow_default_float_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -745,6 +785,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -767,21 +811,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) 
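One detail worth flagging while this decorator is copied around verbatim: the keyword branch tests `elif parameters in kwargs:`, but `parameters` is the full list of parameter names, and a list can never be a dict key, so the branch appears unreachable. The intent is presumably the loop variable `parameter`; a corrected sketch of just that branch (all names hypothetical):

    def convert_kwarg_if_needed(kwargs, parameter, to_tensor, is_tensor):
        # `parameter` (singular) names the current argument in the loop.
        if parameter in kwargs and not is_tensor(kwargs[parameter]):
            kwargs = dict(kwargs)
            kwargs[parameter] = to_tensor(kwargs[parameter])
        return kwargs

    out = convert_kwarg_if_needed(
        {"x": [1, 2]}, "x", tuple, lambda v: isinstance(v, tuple)
    )
    assert out == {"x": (1, 2)}
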
+ else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -819,6 +859,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -870,20 +929,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -950,26 +995,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -1087,6 +1114,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -1193,27 +1223,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -1382,6 +1406,9 @@ def _infer_dtype(obj): return 
_asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1638,7 +1665,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2050,7 +2079,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -2174,6 +2205,9 @@ def tensorflow_to_scalar_bknd_(self: tensorflow.Tensor): return tensorflow_to_scalar(self) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2198,11 +2232,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2410,7 +2442,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2570,11 +2604,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2614,21 +2646,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_sign_output/run_0/tensorflow_NestedSequence_bknd.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_sign_output/run_0/tensorflow_NestedSequence_bknd.py index 9f87b4ae29ef..ac8335fe1e56 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_sign_output/run_0/tensorflow_NestedSequence_bknd.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_sign_output/run_0/tensorflow_NestedSequence_bknd.py @@ -1,5 +1,5 @@ -from typing import Protocol from typing import TypeVar +from typing import Protocol _T_co = TypeVar("_T_co", covariant=True) diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_sign_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_sign_output/run_0/tensorflow__helpers.py index d50687f412e5..3aaa2aa83879 100644 --- 
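`tensorflow_asarray` keeps its three-decorator stack. Listed top to bottom, the outermost wrapper runs first, so array-like conversion happens before the native round-trip and dtype inference ever see the arguments. The ordering rule in miniature:

    def tag(label):
        # Each decorator prepends its label, exposing execution order.
        def deco(fn):
            def wrapper(*a, **k):
                return [label] + fn(*a, **k)
            return wrapper
        return deco

    @tag("array_like")
    @tag("to_native")
    @tag("infer_dtype")
    def core():
        return ["core"]

    assert core() == ["array_like", "to_native", "infer_dtype", "core"]
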
a/ivy/compiler/_cache/Translated_Outputs/tensorflow_sign_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_sign_output/run_0/tensorflow__helpers.py @@ -23,6 +23,99 @@ import tensorflow as tf +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +240,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +282,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - 
tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_is_native_array(x, /, *, exclusive=False): @@ -334,7 +342,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -447,6 +457,7 @@ def tensorflow_is_complex_dtype_bknd( return "complex" in tensorflow_as_ivy_dtype_bknd(dtype_in) +@tensorflow_handle_array_like_without_promotion def tensorflow_real( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -460,6 +471,7 @@ def tensorflow_real_bknd_(self): return tensorflow_real(self) +@tensorflow_handle_array_like_without_promotion def tensorflow_imag( val: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -485,6 +497,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -611,6 +626,9 @@ def nested_fun(x): return "int" in tensorflow_as_ivy_dtype(dtype_in) +default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, @@ -655,6 +673,9 @@ def tensorflow_default_dtype_bknd( return tensorflow_as_ivy_dtype(ret) +default_float_dtype_stack 
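`tensorflow_nested_argwhere_bknd`, used throughout these dtype helpers with `stop_after_n_found=1`, exists so the scan can bail out at the first match instead of walking the whole nest. A toy flat-list version showing the early exit (the real helper recurses into nested containers):

    def nested_argwhere(nest, fn, stop_after_n_found=None):
        found = []
        for i, v in enumerate(nest):
            if fn(v):
                found.append([i])
                if stop_after_n_found and len(found) >= stop_after_n_found:
                    break  # early exit: no need to scan the rest
        return found

    assert nested_argwhere([1, 2**40, 2**41], lambda v: v > 2147483647, 1) == [[1]]
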
= [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -709,6 +730,25 @@ def tensorflow_default_float_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -745,6 +785,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -767,21 +811,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -819,6 +859,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -870,20 +929,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -950,26 +995,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - 
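The reflowed lambdas in `tensorflow_default_int_dtype_bknd` encode plain-Python overflow thresholds: values above 2**63 - 1 (9223372036854775807) force uint64, and values above 2**31 - 1 (2147483647) force int64. The scalar-only core of that decision (the real helper also inspects tensor dtypes and the default stack):

    import math

    def pick_int_dtype(values):
        if any(v > 9223372036854775807 and v != math.inf for v in values):
            return "uint64"  # exceeds int64 range
        if any(v > 2147483647 and v != math.inf for v in values):
            return "int64"   # exceeds int32 range
        return "int32"

    assert pick_int_dtype([1, 2, 3]) == "int32"
    assert pick_int_dtype([2**40]) == "int64"
    assert pick_int_dtype([2**63]) == "uint64"
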
return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -1087,6 +1114,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -1193,27 +1223,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -1382,6 +1406,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1638,7 +1665,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2050,7 +2079,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -2174,6 +2205,9 @@ def tensorflow_to_scalar_bknd_(self: tensorflow.Tensor): return tensorflow_to_scalar(self) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2198,11 +2232,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2410,7 +2442,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2570,11 +2604,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if 
tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2614,21 +2646,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_sign_output/run_0/tensorflow_sign.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_sign_output/run_0/tensorflow_sign.py index b784bd121eda..ed6c78f2c060 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_sign_output/run_0/tensorflow_sign.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_sign_output/run_0/tensorflow_sign.py @@ -1,7 +1,7 @@ import tensorflow -from typing import Union from typing import Optional +from typing import Union from .tensorflow__helpers import tensorflow_handle_array_like_without_promotion diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_softmax_output/run_0/tensorflow_NestedSequence_bknd.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_softmax_output/run_0/tensorflow_NestedSequence_bknd.py index 9f87b4ae29ef..ac8335fe1e56 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_softmax_output/run_0/tensorflow_NestedSequence_bknd.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_softmax_output/run_0/tensorflow_NestedSequence_bknd.py @@ -1,5 +1,5 @@ -from typing import Protocol from typing import TypeVar +from typing import Protocol _T_co = TypeVar("_T_co", covariant=True) diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_softmax_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_softmax_output/run_0/tensorflow__helpers.py index d8608d8ffe3e..eaf3294965aa 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_softmax_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_softmax_output/run_0/tensorflow__helpers.py @@ -23,6 +23,99 @@ import tensorflow as tf +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return 
fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +240,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +282,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in 
annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_is_native_array(x, /, *, exclusive=False): @@ -334,7 +342,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -447,6 +457,7 @@ def tensorflow_is_complex_dtype_bknd( return "complex" in tensorflow_as_ivy_dtype_bknd(dtype_in) +@tensorflow_handle_array_like_without_promotion def tensorflow_real( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -460,6 +471,7 @@ def tensorflow_real_bknd_(self): return tensorflow_real(self) +@tensorflow_handle_array_like_without_promotion def tensorflow_imag( val: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -485,6 +497,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -611,6 +626,9 @@ def nested_fun(x): return "int" in tensorflow_as_ivy_dtype(dtype_in) +default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, @@ -655,6 +673,9 @@ def tensorflow_default_dtype_bknd( return tensorflow_as_ivy_dtype(ret) +default_float_dtype_stack = [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -709,6 +730,25 @@ def tensorflow_default_float_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -745,6 +785,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -767,21 +811,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, 
stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -819,6 +859,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -870,20 +929,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -950,26 +995,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -1087,6 +1114,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -1193,27 +1223,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -1382,6 +1406,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol 
= TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1638,7 +1665,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2050,7 +2079,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -2174,6 +2205,9 @@ def tensorflow_to_scalar_bknd_(self: tensorflow.Tensor): return tensorflow_to_scalar(self) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2198,11 +2232,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2410,7 +2442,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2570,11 +2604,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2614,21 +2646,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_split_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_split_output/run_0/tensorflow__helpers.py index d50687f412e5..3aaa2aa83879 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_split_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_split_output/run_0/tensorflow__helpers.py @@ -23,6 +23,99 @@ import tensorflow as tf +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, 
(annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +240,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +282,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -default_device_stack = [] 
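The hunks above hoist tensorflow_handle_array_like_without_promotion to the top of each generated tensorflow__helpers.py so it is defined before the functions it decorates. The decorator inspects a function's annotations and converts any array-like positional argument to a tensor before the call. Note that its keyword branch tests the whole `parameters` list against kwargs rather than the current `parameter` name, and the loop only zips over the positional args, so keyword array-likes are never converted. A minimal standalone sketch of the intended behaviour, with that branch corrected (helper names here are illustrative, not Ivy's):

import functools
import inspect

import tensorflow as tf


def handle_array_like(fn):
    # Convert array-like arguments to tensors for every parameter whose
    # annotation mentions an array/Tensor type, mirroring the generated
    # tensorflow_handle_array_like_without_promotion decorator.
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        args = list(args)
        params = inspect.signature(fn).parameters
        for i, (name, param) in enumerate(params.items()):
            annotation = str(param.annotation)
            if "Tensor" not in annotation and "rray" not in annotation:
                continue
            if i < len(args):
                if not tf.is_tensor(args[i]):
                    args[i] = tf.convert_to_tensor(args[i])
            elif name in kwargs:  # `name`, not the whole parameter list
                if not tf.is_tensor(kwargs[name]):
                    kwargs[name] = tf.convert_to_tensor(kwargs[name])
        return fn(*args, **kwargs)

    wrapper.handle_array_like_without_promotion = True
    return wrapper


@handle_array_like
def scale(x: tf.Tensor, factor: tf.Tensor = 2.0):
    return x * factor


print(scale([1.0, 2.0], factor=3.0))  # both arguments arrive as tensors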
-SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_is_native_array(x, /, *, exclusive=False): @@ -334,7 +342,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -447,6 +457,7 @@ def tensorflow_is_complex_dtype_bknd( return "complex" in tensorflow_as_ivy_dtype_bknd(dtype_in) +@tensorflow_handle_array_like_without_promotion def tensorflow_real( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -460,6 +471,7 @@ def tensorflow_real_bknd_(self): return tensorflow_real(self) +@tensorflow_handle_array_like_without_promotion def tensorflow_imag( val: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -485,6 +497,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -611,6 +626,9 @@ def nested_fun(x): return "int" in tensorflow_as_ivy_dtype(dtype_in) +default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, @@ -655,6 +673,9 @@ def tensorflow_default_dtype_bknd( return tensorflow_as_ivy_dtype(ret) +default_float_dtype_stack = [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -709,6 +730,25 @@ def tensorflow_default_float_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: 
"bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -745,6 +785,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -767,21 +811,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -819,6 +859,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -870,20 +929,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -950,26 +995,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -1087,6 +1114,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -1193,27 +1223,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) 
+ if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -1382,6 +1406,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1638,7 +1665,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2050,7 +2079,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -2174,6 +2205,9 @@ def tensorflow_to_scalar_bknd_(self: tensorflow.Tensor): return tensorflow_to_scalar(self) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2198,11 +2232,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2410,7 +2442,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2570,11 +2604,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2614,21 +2646,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_stack_output/run_0/tensorflow_stack.py 
b/ivy/compiler/_cache/Translated_Outputs/tensorflow_stack_output/run_0/tensorflow_stack.py index 9883d6b3d5cb..626bbcf3a19a 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_stack_output/run_0/tensorflow_stack.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_stack_output/run_0/tensorflow_stack.py @@ -1,9 +1,9 @@ import tensorflow -from typing import Optional +from typing import List from typing import Tuple from typing import Union -from typing import List +from typing import Optional def tensorflow_stack( diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_swapaxes_output/run_0/tensorflow_NestedSequence_bknd.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_swapaxes_output/run_0/tensorflow_NestedSequence_bknd.py index ac8335fe1e56..9f87b4ae29ef 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_swapaxes_output/run_0/tensorflow_NestedSequence_bknd.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_swapaxes_output/run_0/tensorflow_NestedSequence_bknd.py @@ -1,5 +1,5 @@ -from typing import TypeVar from typing import Protocol +from typing import TypeVar _T_co = TypeVar("_T_co", covariant=True) diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_swapaxes_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_swapaxes_output/run_0/tensorflow__helpers.py index d50687f412e5..3aaa2aa83879 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_swapaxes_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_swapaxes_output/run_0/tensorflow__helpers.py @@ -23,6 +23,99 @@ import tensorflow as tf +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def 
wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +240,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +282,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - 
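tensorflow_handle_methods, moved above alongside the other wrappers, implements a simple dispatch rule: if the first argument is an array, run the translated function; otherwise strip the _bknd/_frnt suffix noise from the function's name and call the object's own method of that name. A condensed sketch of the pattern (the "is array" test and the sample functions are toy stand-ins, not Ivy's):

import functools
import re


def handle_methods(fn):
    def extract_function_name(s):
        # "tensorflow_split" -> "split"
        match = re.search(r"_(.+?)(?:_\d+)?$", s)
        return match.group(1) if match else s

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if isinstance(args[0], (list, tuple)):  # toy stand-in for "is array"
            return fn(*args, **kwargs)
        # Otherwise dispatch to the object's own method of the same name.
        name = extract_function_name(
            re.sub(r"_bknd_|_bknd|_frnt_|_frnt", "", fn.__name__)
        )
        return getattr(args[0], name)(*args[1:], **kwargs)

    return wrapper


@handle_methods
def tensorflow_split(x, num):
    # "Array" path: naively chunk a list into num equal pieces.
    k = len(x) // num
    return [x[i * k : (i + 1) * k] for i in range(num)]


print(tensorflow_split([1, 2, 3, 4], 2))  # array path: [[1, 2], [3, 4]]
print(tensorflow_split("a,b,c", ","))     # dispatches to str.split: ['a', 'b', 'c']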
_handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_is_native_array(x, /, *, exclusive=False): @@ -334,7 +342,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -447,6 +457,7 @@ def tensorflow_is_complex_dtype_bknd( return "complex" in tensorflow_as_ivy_dtype_bknd(dtype_in) +@tensorflow_handle_array_like_without_promotion def tensorflow_real( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -460,6 +471,7 @@ def tensorflow_real_bknd_(self): return tensorflow_real(self) +@tensorflow_handle_array_like_without_promotion def tensorflow_imag( val: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -485,6 +497,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -611,6 +626,9 @@ def nested_fun(x): return "int" in tensorflow_as_ivy_dtype(dtype_in) +default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, @@ -655,6 +673,9 @@ def tensorflow_default_dtype_bknd( return tensorflow_as_ivy_dtype(ret) +default_float_dtype_stack = [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -709,6 +730,25 @@ def tensorflow_default_float_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -745,6 +785,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -767,21 +811,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -819,6 +859,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": 
tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -870,20 +929,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -950,26 +995,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -1087,6 +1114,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -1193,27 +1223,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -1382,6 +1406,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1638,7 +1665,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2050,7 +2079,9 @@ def tensorflow__parse_query_bknd(query, x_shape, 
scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -2174,6 +2205,9 @@ def tensorflow_to_scalar_bknd_(self: tensorflow.Tensor): return tensorflow_to_scalar(self) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2198,11 +2232,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2410,7 +2442,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2570,11 +2604,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2614,21 +2646,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_swapaxes_output/run_0/tensorflow_swapaxes.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_swapaxes_output/run_0/tensorflow_swapaxes.py index 9d273d75721d..57f12ef3d0a7 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_swapaxes_output/run_0/tensorflow_swapaxes.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_swapaxes_output/run_0/tensorflow_swapaxes.py @@ -1,7 +1,7 @@ import tensorflow -from typing import Optional from typing import Union +from typing import Optional from .tensorflow__helpers import tensorflow_handle_array_like_without_promotion diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_to_device_output/run_0/tensorflow_NestedSequence_bknd.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_to_device_output/run_0/tensorflow_NestedSequence_bknd.py index 9f87b4ae29ef..ac8335fe1e56 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_to_device_output/run_0/tensorflow_NestedSequence_bknd.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_to_device_output/run_0/tensorflow_NestedSequence_bknd.py @@ -1,5 +1,5 @@ -from typing import Protocol from typing import TypeVar +from typing import Protocol _T_co = TypeVar("_T_co", covariant=True) diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_to_device_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_to_device_output/run_0/tensorflow__helpers.py index 3687fc1df4db..b0ccff1b3ca8 100644 --- 
a/ivy/compiler/_cache/Translated_Outputs/tensorflow_to_device_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_to_device_output/run_0/tensorflow__helpers.py @@ -23,6 +23,99 @@ import tensorflow as tf +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +240,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +282,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: 
"float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_is_native_array(x, /, *, exclusive=False): @@ -334,7 +342,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -447,6 +457,7 @@ def tensorflow_is_complex_dtype_bknd( return "complex" in tensorflow_as_ivy_dtype_bknd(dtype_in) +@tensorflow_handle_array_like_without_promotion def tensorflow_real( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -460,6 +471,7 @@ def tensorflow_real_bknd_(self): return tensorflow_real(self) +@tensorflow_handle_array_like_without_promotion def tensorflow_imag( val: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -485,6 +497,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -611,6 +626,9 @@ def nested_fun(x): return "int" in tensorflow_as_ivy_dtype(dtype_in) +default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, @@ -655,6 +673,9 @@ def tensorflow_default_dtype_bknd( return tensorflow_as_ivy_dtype(ret) 
+default_float_dtype_stack = [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -709,6 +730,25 @@ def tensorflow_default_float_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -745,6 +785,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -767,21 +811,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -819,6 +859,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -870,20 +929,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -950,26 +995,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = 
getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -1087,6 +1114,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -1193,27 +1223,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -1382,6 +1406,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1638,7 +1665,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2050,7 +2079,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -2174,6 +2205,9 @@ def tensorflow_to_scalar_bknd_(self: tensorflow.Tensor): return tensorflow_to_scalar(self) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2198,11 +2232,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2410,7 +2442,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2570,11 +2604,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + 
tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2614,21 +2646,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_to_device_output/run_0/tensorflow_to_device.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_to_device_output/run_0/tensorflow_to_device.py index 208ac581376a..d6054943fde8 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_to_device_output/run_0/tensorflow_to_device.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_to_device_output/run_0/tensorflow_to_device.py @@ -1,7 +1,7 @@ import tensorflow -from typing import Union from typing import Optional +from typing import Union from .tensorflow__helpers import tensorflow__same_device from .tensorflow__helpers import tensorflow_as_native_dev diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_to_numpy_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_to_numpy_output/run_0/tensorflow__helpers.py index d50687f412e5..3aaa2aa83879 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_to_numpy_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_to_numpy_output/run_0/tensorflow__helpers.py @@ -23,6 +23,99 @@ import tensorflow as tf +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = 
re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +240,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +282,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, 
parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_is_native_array(x, /, *, exclusive=False): @@ -334,7 +342,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -447,6 +457,7 @@ def tensorflow_is_complex_dtype_bknd( return "complex" in tensorflow_as_ivy_dtype_bknd(dtype_in) +@tensorflow_handle_array_like_without_promotion def tensorflow_real( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -460,6 +471,7 @@ def tensorflow_real_bknd_(self): return tensorflow_real(self) +@tensorflow_handle_array_like_without_promotion def tensorflow_imag( val: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -485,6 +497,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -611,6 +626,9 @@ def nested_fun(x): return "int" in tensorflow_as_ivy_dtype(dtype_in) +default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, @@ -655,6 +673,9 @@ def tensorflow_default_dtype_bknd( return tensorflow_as_ivy_dtype(ret) +default_float_dtype_stack = [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -709,6 +730,25 @@ def tensorflow_default_float_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -745,6 +785,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -767,21 +811,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -819,6 +859,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": 
tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -870,20 +929,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -950,26 +995,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -1087,6 +1114,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -1193,27 +1223,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -1382,6 +1406,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1638,7 +1665,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) 
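The tensorflow_where hunk above only reflows a chained conditional expression; the dtype-resolution order is unchanged. A small sketch of that chain, with a hypothetical resolve_dtype helper standing in for the inline expression and a fixed default in place of tensorflow_default_dtype_bknd():

import tensorflow as tf


def resolve_dtype(x1, x2, default=tf.float32):
    # Same precedence as the inline expression in tensorflow_where:
    # x1's dtype if it has one, else x2's, else the backend default.
    return (
        x1.dtype
        if hasattr(x1, "dtype")
        else x2.dtype if hasattr(x2, "dtype") else default
    )


print(resolve_dtype(3, tf.constant([1.0])))  # float32, taken from x2
print(resolve_dtype(3, 4))                   # falls back to the default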
@@ -2050,7 +2079,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -2174,6 +2205,9 @@ def tensorflow_to_scalar_bknd_(self: tensorflow.Tensor): return tensorflow_to_scalar(self) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2198,11 +2232,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2410,7 +2442,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2570,11 +2604,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2614,21 +2646,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_to_scalar_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_to_scalar_output/run_0/tensorflow__helpers.py index d50687f412e5..3aaa2aa83879 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_to_scalar_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_to_scalar_output/run_0/tensorflow__helpers.py @@ -23,6 +23,99 @@ import tensorflow as tf +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = 
tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +240,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +282,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in 
type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_is_native_array(x, /, *, exclusive=False): @@ -334,7 +342,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -447,6 +457,7 @@ def tensorflow_is_complex_dtype_bknd( return "complex" in tensorflow_as_ivy_dtype_bknd(dtype_in) +@tensorflow_handle_array_like_without_promotion def tensorflow_real( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -460,6 +471,7 @@ def tensorflow_real_bknd_(self): return tensorflow_real(self) +@tensorflow_handle_array_like_without_promotion def tensorflow_imag( val: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -485,6 +497,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -611,6 +626,9 @@ def nested_fun(x): return "int" in tensorflow_as_ivy_dtype(dtype_in) +default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, @@ -655,6 +673,9 @@ def tensorflow_default_dtype_bknd( return tensorflow_as_ivy_dtype(ret) +default_float_dtype_stack = [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -709,6 +730,25 @@ def tensorflow_default_float_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", + tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -745,6 +785,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -767,21 +811,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, 
dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -819,6 +859,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -870,20 +929,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -950,26 +995,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -1087,6 +1114,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -1193,27 +1223,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda 
x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -1382,6 +1406,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1638,7 +1665,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2050,7 +2079,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -2174,6 +2205,9 @@ def tensorflow_to_scalar_bknd_(self: tensorflow.Tensor): return tensorflow_to_scalar(self) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2198,11 +2232,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2410,7 +2442,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2570,11 +2604,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2614,21 +2646,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_where_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_where_output/run_0/tensorflow__helpers.py index d50687f412e5..3aaa2aa83879 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_where_output/run_0/tensorflow__helpers.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_where_output/run_0/tensorflow__helpers.py @@ -23,6 +23,99 @@ import tensorflow as tf +def tensorflow_handle_array_like_without_promotion(fn: Callable): + @functools.wraps(fn) + def _handle_array_like_without_promotion(*args, **kwargs): + args = list(args) + 
num_args = len(args) + try: + type_hints = inspect.signature(fn).parameters + except (TypeError, ValueError): + return fn(*args, **kwargs) + parameters = list(type_hints.keys()) + annotations = [param.annotation for param in type_hints.values()] + device = tensorflow__get_preferred_device(args, kwargs) + for i, (annotation, parameter, arg) in enumerate( + zip(annotations, parameters, args) + ): + annotation_str = str(annotation) + if ( + ("rray" in annotation_str or "Tensor" in annotation_str) + and parameter != "out" + and all( + sq not in annotation_str + for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] + ) + ): + if i < num_args: + if tensorflow__check_in_nested_sequence( + arg, value=Ellipsis, _type=slice + ): + continue + if not tensorflow_is_array_bknd(arg): + args = tensorflow_set_item_bknd( + args, i, tensorflow_asarray(arg, device=device) + ) + elif parameters in kwargs: + kwarg = tensorflow_get_item(kwargs, parameter) + if not tensorflow_is_array_bknd(kwarg): + kwargs = tensorflow_set_item_bknd( + kwargs, parameter, tensorflow_asarray(kwarg, device=device) + ) + return fn(*args, **kwargs) + + _handle_array_like_without_promotion.handle_array_like_without_promotion = True + return _handle_array_like_without_promotion + + +def tensorflow_handle_set_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, val, **kwargs): + try: + inp.__setitem__(query, val) + res = inp + except IndexError: + raise + except Exception: + res = fn(inp, query, val, **kwargs) + return res + + return wrapper + + +def tensorflow_handle_methods(fn): + def extract_function_name(s): + match = re.search("_(.+?)(?:_\\d+)?$", s) + if match: + return match.group(1) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if tensorflow_is_array_bknd(args[0]): + return fn(*args, **kwargs) + else: + pattern = "_bknd_|_bknd|_frnt_|_frnt" + fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) + new_fn = getattr(args[0], fn_name) + return new_fn(*args[1:], **kwargs) + + return wrapper + + +def tensorflow_handle_get_item(fn): + @functools.wraps(fn) + def wrapper(inp, query, **kwargs): + try: + res = inp.__getitem__(query) + except IndexError: + raise + except Exception: + res = fn(inp, query, **kwargs) + return res + + return wrapper + + promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -147,6 +240,7 @@ ("complex64", "uint32"): "complex128", ("complex64", "uint64"): "complex128", } + array_api_promotion_table = { ("bool", "bool"): "bool", ("int8", "int8"): "int8", @@ -188,94 +282,8 @@ ("float32", "float64"): "float64", ("float64", "float64"): "float64", } -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - 
"uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion +tf.experimental.numpy.experimental_enable_numpy_behavior(True) def tensorflow_is_native_array(x, /, *, exclusive=False): @@ -334,7 +342,9 @@ def tensorflow_default_bknd( return ( x if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val + else default_val() + if default_callable + else default_val ) @@ -447,6 +457,7 @@ def tensorflow_is_complex_dtype_bknd( return "complex" in tensorflow_as_ivy_dtype_bknd(dtype_in) +@tensorflow_handle_array_like_without_promotion def tensorflow_real( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -460,6 +471,7 @@ def tensorflow_real_bknd_(self): return tensorflow_real(self) +@tensorflow_handle_array_like_without_promotion def tensorflow_imag( val: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -485,6 +497,9 @@ def tensorflow__check_complex128_bknd(input): return False +default_complex_dtype_stack = [] + + def tensorflow_default_complex_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -611,6 +626,9 @@ def nested_fun(x): return "int" in tensorflow_as_ivy_dtype(dtype_in) +default_dtype_stack = [] + + def tensorflow_default_dtype_bknd( *, dtype: Optional[Union[str, str]] = None, @@ -655,6 +673,9 @@ def tensorflow_default_dtype_bknd( return tensorflow_as_ivy_dtype(ret) +default_float_dtype_stack = [] + + def tensorflow_default_float_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -709,6 +730,25 @@ def tensorflow_default_float_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +ivy_dtype_dict = { + tensorflow.int8: "int8", + tensorflow.int16: "int16", + tensorflow.int32: "int32", + tensorflow.int64: "int64", + tensorflow.uint8: "uint8", + tensorflow.uint16: "uint16", 
+ tensorflow.uint32: "uint32", + tensorflow.uint64: "uint64", + tensorflow.bfloat16: "bfloat16", + tensorflow.float16: "float16", + tensorflow.float32: "float32", + tensorflow.float64: "float64", + tensorflow.complex64: "complex64", + tensorflow.complex128: "complex128", + tensorflow.bool: "bool", +} + + def tensorflow_as_ivy_dtype( dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / ): @@ -745,6 +785,10 @@ def tensorflow_as_ivy_dtype( raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") +default_int_dtype_stack = [] +backend = "" + + def tensorflow_default_int_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -767,21 +811,17 @@ def tensorflow_default_int_dtype_bknd( elif isinstance(input, (list, tuple, dict)): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if tensorflow_is_array_bknd(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 elif tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "int64" + if tensorflow_is_array_bknd(x) + else x > 2147483647 and x != math.inf, stop_after_n_found=1, ): ret = tf.int64 @@ -819,6 +859,25 @@ def tensorflow_default_int_dtype_bknd( return str(tensorflow_as_ivy_dtype(ret)) +native_dtype_dict = { + "int8": tensorflow.int8, + "int16": tensorflow.int16, + "int32": tensorflow.int32, + "int64": tensorflow.int64, + "uint8": tensorflow.uint8, + "uint16": tensorflow.uint16, + "uint32": tensorflow.uint32, + "uint64": tensorflow.uint64, + "bfloat16": tensorflow.bfloat16, + "float16": tensorflow.float16, + "float32": tensorflow.float32, + "float64": tensorflow.float64, + "complex64": tensorflow.complex64, + "complex128": tensorflow.complex128, + "bool": tensorflow.bool, +} + + def tensorflow_as_native_dtype( dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], ): @@ -870,20 +929,6 @@ def tensorflow_is_bool_dtype_bknd( return "bool" in tensorflow_as_ivy_dtype(dtype_in) -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - @tensorflow_handle_get_item def tensorflow_get_item( x: Union[tensorflow.Tensor, tensorflow.Variable], @@ -950,26 +995,8 @@ def tensorflow_as_native_dev(device: str, /): return ret -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - @tensorflow_handle_methods +@tensorflow_handle_array_like_without_promotion def tensorflow_split( x: Union[tensorflow.Tensor, tensorflow.Variable], /, @@ -1087,6 +1114,9 @@ def tensorflow_dev( return tensorflow_as_ivy_dev(dv) +default_device_stack = [] + + def tensorflow_default_device_bknd( device: Optional[Union[str, str]] = None, /, @@ -1193,27 
+1223,21 @@ def tensorflow_nested_map_bknd( to_ignore = to_ignore + (class_instance,) tuple_check_fn = tensorflow_default_bknd( _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["tuple"] + else lambda x_, t_: type(x_) is t_, ) list_check_fn = tensorflow_default_bknd( _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["list"] + else lambda x_, t_: type(x_) is t_, ) dict_check_fn = tensorflow_default_bknd( _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), + (lambda x_, t_: isinstance(x_, t_)) + if include_derived["dict"] + else lambda x_, t_: type(x_) is t_, ) if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): ret_list = [ @@ -1382,6 +1406,9 @@ def _infer_dtype(obj): return _asarray_infer_dtype_wrapper +SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") + + @tensorflow_handle_array_like_without_promotion @tensorflow__asarray_to_native_arrays_and_back_bknd @tensorflow__asarray_infer_dtype_bknd @@ -1638,7 +1665,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2050,7 +2079,9 @@ def tensorflow__parse_query_bknd(query, x_shape, scatter=False): ( tensorflow_reshape_bknd_(arr, (-1,)) if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr + else tensorflow_expand_dims(arr) + if not len(arr.shape) + else arr ) for arr in array_queries ] @@ -2174,6 +2205,9 @@ def tensorflow_to_scalar_bknd_(self: tensorflow.Tensor): return tensorflow_to_scalar(self) +default_uint_dtype_stack = [] + + def tensorflow_default_uint_dtype_bknd( *, input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, @@ -2198,11 +2232,9 @@ def is_native(x): if tensorflow_nested_argwhere_bknd( input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), + lambda x: tensorflow_dtype(x) == "uint64" + if is_native(x) + else x > 9223372036854775807 and x != math.inf, stop_after_n_found=1, ): ret = tf.uint64 @@ -2410,7 +2442,9 @@ def tensorflow_multiply( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) @@ -2570,11 +2604,9 @@ def tensorflow_scatter_nd( dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) updates = tensorflow.cast( updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), + tensorflow_as_native_dtype(dtype) + if tensorflow_exists_bknd(out) + else updates_dtype, ) expected_shape = ( list(tensorflow.shape(indices)[:-1]) @@ -2614,21 +2646,6 @@ def tensorflow_scatter_nd( return res -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - 
return res - - return wrapper - - @tensorflow_handle_set_item def tensorflow_set_item_bknd( x: Union[tensorflow.Tensor, tf.Tensor], diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_where_output/run_0/tensorflow_where.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_where_output/run_0/tensorflow_where.py index b1c272950179..3c1301a63194 100644 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_where_output/run_0/tensorflow_where.py +++ b/ivy/compiler/_cache/Translated_Outputs/tensorflow_where_output/run_0/tensorflow_where.py @@ -24,7 +24,9 @@ def tensorflow_where( dtype = ( x1.dtype if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() + else x2.dtype + if hasattr(x2, "dtype") + else tensorflow_default_dtype_bknd() ) if not tensorflow_is_array_bknd(x1): x1 = tensorflow_asarray(x1, dtype=dtype) diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_zeros_like_output/run_0/__init__.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_zeros_like_output/run_0/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_zeros_like_output/run_0/tensorflow_NestedSequence_bknd.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_zeros_like_output/run_0/tensorflow_NestedSequence_bknd.py deleted file mode 100644 index ac8335fe1e56..000000000000 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_zeros_like_output/run_0/tensorflow_NestedSequence_bknd.py +++ /dev/null @@ -1,10 +0,0 @@ -from typing import TypeVar -from typing import Protocol - -_T_co = TypeVar("_T_co", covariant=True) - - -class tensorflow_NestedSequence_bknd(Protocol[_T_co]): - def __getitem__(self, key: int, /): ... - - def __len__(self, /): ... 
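The deleted tensorflow_NestedSequence_bknd.py above held a structural Protocol: any object exposing __getitem__ and __len__ counts as a nested sequence, with no inheritance required. A minimal sketch of that structural check (the @runtime_checkable decorator is added here purely so isinstance works in the demo; the deleted file did not use it):

from typing import Protocol, TypeVar, runtime_checkable

_T_co = TypeVar("_T_co", covariant=True)


@runtime_checkable
class NestedSequence(Protocol[_T_co]):
    def __getitem__(self, key: int, /): ...

    def __len__(self, /): ...


# Structural typing: a plain list conforms because it is indexable and sized.
print(isinstance([1, [2, 3]], NestedSequence))  # True
print(isinstance(42, NestedSequence))           # False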
diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_zeros_like_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_zeros_like_output/run_0/tensorflow__helpers.py deleted file mode 100644 index 06e137cf3452..000000000000 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_zeros_like_output/run_0/tensorflow__helpers.py +++ /dev/null @@ -1,2671 +0,0 @@ -from collections import UserDict -from numbers import Number -from numpy.core.numeric import normalize_axis_tuple -from operator import mul -from .tensorflow_NestedSequence_bknd import tensorflow_NestedSequence_bknd -from typing import Any -from typing import Callable -from typing import Dict -from typing import Iterable -from typing import List -from typing import Optional -from typing import Sequence -from typing import Tuple -from typing import TypeVar -from typing import Union -import functools -import inspect -import itertools -import math -import numpy as np -import re -import tensorflow -import tensorflow as tf - - -promotion_table = { - ("bool", "bool"): "bool", - ("int8", "int8"): "int8", - ("int8", "int16"): "int16", - ("int8", "int32"): "int32", - ("int8", "int64"): "int64", - ("int16", "int16"): "int16", - ("int16", "int32"): "int32", - ("int16", "int64"): "int64", - ("int32", "int32"): "int32", - ("int32", "int64"): "int64", - ("int64", "int64"): "int64", - ("uint8", "int8"): "int16", - ("uint8", "int16"): "int16", - ("uint8", "int32"): "int32", - ("uint8", "int64"): "int64", - ("uint8", "uint8"): "uint8", - ("uint8", "uint16"): "uint16", - ("uint8", "uint32"): "uint32", - ("uint8", "uint64"): "uint64", - ("uint16", "int8"): "int32", - ("uint16", "int16"): "int32", - ("uint16", "int32"): "int32", - ("uint16", "int64"): "int64", - ("uint16", "uint16"): "uint16", - ("uint16", "uint32"): "uint32", - ("uint16", "uint64"): "uint64", - ("uint32", "int8"): "int64", - ("uint32", "int16"): "int64", - ("uint32", "int32"): "int64", - ("uint32", "int64"): "int64", - ("uint32", "uint32"): "uint32", - ("uint32", "uint64"): "uint64", - ("uint64", "uint64"): "uint64", - ("float16", "float16"): "float16", - ("float16", "float32"): "float32", - ("float16", "float64"): "float64", - ("float32", "float32"): "float32", - ("float32", "float64"): "float64", - ("float64", "float64"): "float64", - ("bool", "int8"): "int8", - ("bool", "int16"): "int16", - ("bool", "int32"): "int32", - ("bool", "int64"): "int64", - ("bool", "uint8"): "uint8", - ("bool", "uint16"): "uint16", - ("bool", "uint32"): "uint32", - ("bool", "uint64"): "uint64", - ("bool", "float16"): "float16", - ("bool", "float32"): "float32", - ("bool", "float64"): "float64", - ("bool", "bfloat16"): "bfloat16", - ("bool", "complex64"): "complex64", - ("bool", "complex128"): "complex128", - ("int8", "float16"): "float16", - ("int8", "float32"): "float32", - ("int8", "float64"): "float64", - ("int8", "bfloat16"): "bfloat16", - ("int8", "complex64"): "complex64", - ("int8", "complex128"): "complex128", - ("int16", "float32"): "float32", - ("int16", "float64"): "float64", - ("int16", "complex64"): "complex64", - ("int16", "complex128"): "complex128", - ("int32", "float64"): "float64", - ("int32", "complex128"): "complex128", - ("int64", "float64"): "float64", - ("int64", "complex128"): "complex128", - ("uint8", "float16"): "float16", - ("uint8", "float32"): "float32", - ("uint8", "float64"): "float64", - ("uint8", "bfloat16"): "bfloat16", - ("uint8", "complex64"): "complex64", - ("uint8", "complex128"): "complex128", - ("uint16", "float32"): 
"float32", - ("uint16", "float64"): "float64", - ("uint16", "complex64"): "complex64", - ("uint16", "complex128"): "complex128", - ("uint32", "float64"): "float64", - ("uint32", "complex128"): "complex128", - ("uint64", "int8"): "float64", - ("uint64", "int16"): "float64", - ("uint64", "int32"): "float64", - ("uint64", "int64"): "float64", - ("uint64", "float64"): "float64", - ("uint64", "complex128"): "complex128", - ("float16", "bfloat16"): "float32", - ("float16", "complex64"): "complex64", - ("float16", "complex128"): "complex128", - ("float32", "complex64"): "complex64", - ("float32", "complex128"): "complex128", - ("float64", "complex64"): "complex128", - ("float64", "complex128"): "complex128", - ("bfloat16", "float16"): "float32", - ("bfloat16", "float32"): "float32", - ("bfloat16", "float64"): "float64", - ("bfloat16", "bfloat16"): "bfloat16", - ("bfloat16", "complex64"): "complex64", - ("bfloat16", "complex128"): "complex128", - ("complex64", "float64"): "complex128", - ("complex64", "complex64"): "complex64", - ("complex64", "complex128"): "complex128", - ("complex128", "complex128"): "complex128", - ("float16", "int16"): "float32", - ("float16", "int32"): "float64", - ("float16", "int64"): "float64", - ("float16", "uint16"): "float32", - ("float16", "uint32"): "float64", - ("float16", "uint64"): "float64", - ("float32", "int32"): "float64", - ("float32", "int64"): "float64", - ("float32", "uint32"): "float64", - ("float32", "uint64"): "float64", - ("bfloat16", "int16"): "float32", - ("bfloat16", "int32"): "float64", - ("bfloat16", "int64"): "float64", - ("bfloat16", "uint16"): "float32", - ("bfloat16", "uint32"): "float64", - ("bfloat16", "uint64"): "float64", - ("complex64", "int32"): "complex128", - ("complex64", "int64"): "complex128", - ("complex64", "uint32"): "complex128", - ("complex64", "uint64"): "complex128", -} -array_api_promotion_table = { - ("bool", "bool"): "bool", - ("int8", "int8"): "int8", - ("int8", "int16"): "int16", - ("int8", "int32"): "int32", - ("int8", "int64"): "int64", - ("int16", "int16"): "int16", - ("int16", "int32"): "int32", - ("int16", "int64"): "int64", - ("int32", "int32"): "int32", - ("int32", "int64"): "int64", - ("int64", "int64"): "int64", - ("uint8", "int8"): "int16", - ("uint8", "int16"): "int16", - ("uint8", "int32"): "int32", - ("uint8", "int64"): "int64", - ("uint8", "uint8"): "uint8", - ("uint8", "uint16"): "uint16", - ("uint8", "uint32"): "uint32", - ("uint8", "uint64"): "uint64", - ("uint16", "int8"): "int32", - ("uint16", "int16"): "int32", - ("uint16", "int32"): "int32", - ("uint16", "int64"): "int64", - ("uint16", "uint16"): "uint16", - ("uint16", "uint32"): "uint32", - ("uint16", "uint64"): "uint64", - ("uint32", "int8"): "int64", - ("uint32", "int16"): "int64", - ("uint32", "int32"): "int64", - ("uint32", "int64"): "int64", - ("uint32", "uint32"): "uint32", - ("uint32", "uint64"): "uint64", - ("uint64", "uint64"): "uint64", - ("float16", "float16"): "float16", - ("float16", "float32"): "float32", - ("float16", "float64"): "float64", - ("float32", "float32"): "float32", - ("float32", "float64"): "float64", - ("float64", "float64"): "float64", -} -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - 
tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} - - -def tensorflow_infer_dtype(fn: Callable): - @functools.wraps(fn) - def _infer_dtype(*args, dtype=None, **kwargs): - arr = ( - None - if tensorflow_exists_bknd(dtype) - else tensorflow__get_first_array(*args, **kwargs) - ) - dtype = tensorflow_default_dtype_bknd(dtype=dtype, item=arr, as_native=True) - return fn(*args, dtype=dtype, **kwargs) - - _infer_dtype.infer_dtype = True - return _infer_dtype - - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", "Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion - - -def tensorflow_exists_bknd(x: Any, /): - return x is not None - - -def tensorflow_is_native_array(x, /, *, exclusive=False): - if "keras.src.backend.tensorflow.core.Variable" in str(x.__class__): - return not exclusive - if isinstance(x, (tensorflow.Tensor, tensorflow.Variable, tensorflow.TensorArray)): - if exclusive and isinstance(x, tensorflow.Variable): - return False - return True - return False - - -def tensorflow_is_ivy_array_bknd( - x: Union[tensorflow.Tensor, tf.Tensor], /, *, exclusive: Optional[bool] = False -): - return isinstance(x, tensorflow.Tensor) and tensorflow_is_native_array( - x, exclusive=exclusive - ) - - -def tensorflow_is_array_bknd(x: Any, /, *, exclusive: bool = False): - return tensorflow_is_ivy_array_bknd( - x, exclusive=exclusive - ) 
or tensorflow_is_native_array(x, exclusive=exclusive) - - -def tensorflow_default_bknd( - x: Any, - /, - default_val: Any, - *, - catch_exceptions: bool = False, - rev: bool = False, - with_callable: bool = False, -): - with_callable = catch_exceptions or with_callable - if rev: - x, default_val = default_val, x - if with_callable: - x_callable = callable(x) - default_callable = callable(default_val) - else: - x_callable = False - default_callable = False - if catch_exceptions: - try: - x = x() if x_callable else x - except Exception: - return default_val() if default_callable else default_val - else: - x = x() if x_callable else x - return ( - x - if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val - ) - - -def tensorflow_nested_argwhere_bknd( - nest: Iterable, - fn: Callable, - check_nests: bool = False, - to_ignore: Optional[Union[type, Tuple[type]]] = None, - _index: Optional[List] = None, - _base: bool = True, - stop_after_n_found: Optional[int] = None, -): - to_ignore = tensorflow_default_bknd(to_ignore, ()) - _index = [] if _index is None else _index - if isinstance(nest, (tuple, list)) and not isinstance(nest, to_ignore): - n = 0 - _indices = [] - for i, item in enumerate(nest): - ind = ( - tensorflow_nested_argwhere_bknd( - item, - fn, - check_nests, - to_ignore, - _index + [i], - False, - stop_after_n_found - n, - ) - if stop_after_n_found is not None - else tensorflow_nested_argwhere_bknd( - item, fn, check_nests, to_ignore, _index + [i], False, None - ) - ) - if stop_after_n_found is not None and ind: - if n >= stop_after_n_found: - break - n = n + len(ind) - _indices = _indices + [ind] - if stop_after_n_found is not None and n >= stop_after_n_found: - break - _indices = [idx for idxs in _indices if idxs for idx in idxs] - if check_nests and fn(nest): - _indices.append(_index) - elif isinstance(nest, (dict, UserDict)) and not isinstance(nest, to_ignore): - n = 0 - _indices = [] - for k, v in nest.items(): - ind = ( - tensorflow_nested_argwhere_bknd( - v, - fn, - check_nests, - to_ignore, - _index + [k], - False, - stop_after_n_found - n, - ) - if stop_after_n_found is not None - else tensorflow_nested_argwhere_bknd( - v, fn, check_nests, to_ignore, _index + [k], False, None - ) - ) - if stop_after_n_found is not None and ind: - if n >= stop_after_n_found: - break - n = n + len(ind) - _indices = _indices + [ind] - _indices = [idx for idxs in _indices if idxs for idx in idxs] - if check_nests and fn(nest): - _indices.append(_index) - else: - cond_met = fn(nest) - if cond_met: - return [_index] - return False - return [index for index in _indices if index] - - -def tensorflow__check_float64_bknd(input): - if tensorflow_is_array_bknd(input): - return tensorflow_dtype(input) == "float64" - if math.isfinite(input): - m, e = math.frexp(input) - return abs(input) > 3.4028235e38 or e < -126 or e > 128 - return False - - -def tensorflow_as_ivy_dtype_bknd(dtype_in: Union[str, str], /): - return tensorflow_as_ivy_dtype(dtype_in) - - -def tensorflow_is_complex_dtype_bknd( - dtype_in: Union[str, str, tensorflow.Tensor, tf.Tensor, Number], / -): - if tensorflow_is_array_bknd(dtype_in): - dtype_in = tensorflow_dtype(dtype_in) - elif isinstance(dtype_in, tuple): - dtype_in = tensorflow_default_int_dtype_bknd() - elif isinstance(dtype_in, np.ndarray): - return "complex" in dtype_in.dtype.name - elif isinstance(dtype_in, Number): - return isinstance(dtype_in, (complex, np.complexfloating)) - elif isinstance(dtype_in, (list, tuple, dict)): - return 
tensorflow_nested_argwhere_bknd( - dtype_in, - lambda x: isinstance(x, (complex, np.complexfloating)) - or tensorflow_is_array_bknd(x) - and "complex" in tensorflow_dtype(x), - ) - return "complex" in tensorflow_as_ivy_dtype_bknd(dtype_in) - - -def tensorflow_as_native_dev(device: str, /): - if isinstance(device, str) and "/" in device: - return device - ret = f"/{str(device).upper()}" - if not ret[-1].isnumeric(): - ret += ":0" - return ret - - -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - -@tensorflow_handle_methods -def tensorflow_split( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - copy: Optional[bool] = None, - num_or_size_splits: Optional[ - Union[int, Sequence[int], Union[tensorflow.Tensor, tensorflow.Variable]] - ] = None, - axis: int = 0, - with_remainder: bool = False, -): - if x.shape == (): - if num_or_size_splits is not None and num_or_size_splits != 1: - raise Exception( - f"input array had no shape, but num_sections specified was {num_or_size_splits}" - ) - return [x] - if num_or_size_splits is None: - dim_size = tensorflow.shape(x)[axis] - num_or_size_splits = int(dim_size) - if isinstance(num_or_size_splits, (tensorflow.Tensor, tensorflow.Variable)): - num_or_size_splits = tensorflow.cast(num_or_size_splits, tensorflow.int32) - elif isinstance(num_or_size_splits, int) and with_remainder: - num_chunks = x.shape[axis] / num_or_size_splits - num_chunks_int = math.floor(num_chunks) - remainder = num_chunks - num_chunks_int - if remainder != 0: - num_or_size_splits = [num_or_size_splits] * num_chunks_int + [ - int(remainder * num_or_size_splits) - ] - return tensorflow.split(x, num_or_size_splits, axis) - - -@tensorflow_handle_methods -def tensorflow_split_bknd_( - self: tensorflow.Tensor, - /, - *, - copy: Optional[bool] = None, - num_or_size_splits: Optional[ - Union[int, Sequence[int], tensorflow.Tensor, tf.Tensor] - ] = None, - axis: int = 0, - with_remainder: bool = False, -): - return tensorflow_split( - self, - copy=copy, - num_or_size_splits=num_or_size_splits, - axis=axis, - with_remainder=with_remainder, - ) - - -def tensorflow_as_ivy_dev(device: str, /): - if isinstance(device, str) and "/" not in device: - return str(device) - dev_in_split = tensorflow_split_bknd_(device[1:], ":")[-2:] - if len(dev_in_split) == 1: - return str(dev_in_split[0]) - dev_type, dev_idx = dev_in_split[0], dev_in_split[1] - dev_type = dev_type.lower() - if dev_type == "cpu": - return str(dev_type) - return str(f"{dev_type}:{dev_idx}") - - -def tensorflow_stack( - arrays: Union[Tuple[tensorflow.Tensor], List[tensorflow.Tensor]], - /, - *, - axis: int = 0, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - try: - return tensorflow.experimental.numpy.stack(arrays, axis) - except ValueError as e: - raise Exception(e) from e - - -def tensorflow_stack_bknd_( - self: tensorflow.Tensor, - /, - arrays: Union[ - Tuple[Union[tensorflow.Tensor, tf.Tensor]], - List[Union[tensorflow.Tensor, tf.Tensor]], - ], - *, - axis: int = 0, - out: Optional[tensorflow.Tensor] = None, -): - if not isinstance(arrays, (tuple, list)): - arrays 
= [arrays] - if isinstance(arrays, tuple): - x = (self,) + arrays - else: - x = [self] + arrays - return tensorflow_stack(x, axis=axis, out=out) - - -def tensorflow_dev( - x: Union[tensorflow.Tensor, tensorflow.Variable, tensorflow.TensorArray], - /, - *, - as_native: bool = False, -): - if "keras.src.backend.tensorflow.core.Variable" in str(x.__class__): - x = x.value - if isinstance(x, tensorflow.TensorArray): - x = tensorflow_stack_bknd_(x) - dv = x.device - if as_native: - return dv - dv = dv if dv else tensorflow_default_device_bknd(as_native=False) - return tensorflow_as_ivy_dev(dv) - - -def tensorflow_default_device_bknd( - device: Optional[Union[str, str]] = None, - /, - *, - item: Optional[Union[list, tuple, dict, tensorflow.Tensor, tf.Tensor]] = None, - as_native: Optional[bool] = None, -): - if tensorflow_exists_bknd(device): - if as_native is True: - return tensorflow_as_native_dev(device) - elif as_native is False: - return tensorflow_as_ivy_dev(device) - return device - as_native = tensorflow_default_bknd(as_native, False) - if tensorflow_exists_bknd(item): - if isinstance(item, (list, tuple, dict)) and len(item) == 0: - pass - elif tensorflow_is_array_bknd(item): - return tensorflow_dev(item, as_native=as_native) - global default_device_stack - if not default_device_stack: - ret = "cpu" - else: - ret = default_device_stack[-1] - if as_native: - return tensorflow_as_native_dev(ret) - return tensorflow_as_ivy_dev(ret) - - -def tensorflow__get_preferred_device(args, kwargs): - device = None - if "device" in kwargs and kwargs["device"] is not None: - return device - if not False: - arr_arg = tensorflow__get_first_array(*args, **kwargs) - return tensorflow_default_device_bknd(item=arr_arg, as_native=True) - return tensorflow_default_device_bknd(as_native=True) - - -def tensorflow__check_in_nested_sequence(sequence, value=None, _type=None): - if sequence is value or isinstance(sequence, _type): - return True - elif isinstance(sequence, (tuple, list)): - if any(isinstance(_val, _type) or _val is value for _val in sequence): - return True - else: - return any( - tensorflow__check_in_nested_sequence(sub_sequence, value, _type) - for sub_sequence in sequence - if isinstance(sub_sequence, (tuple, list)) - ) - - -def tensorflow_is_variable(x, /, *, exclusive=False): - return isinstance(x, tensorflow.Variable) - - -def tensorflow_variable(x, /): - with tensorflow.device(tensorflow_dev(x, as_native=True)): - return tensorflow.Variable(x, trainable=True) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_stop_gradient( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - preserve_type: bool = True, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - is_var = tensorflow_is_variable(x) - x = tensorflow.stop_gradient(x) - if is_var and preserve_type: - return tensorflow_variable(x) - return x - - -def tensorflow_nested_map_bknd( - fn: Callable, - x: Union[tensorflow.Tensor, tf.Tensor, Iterable], - /, - include_derived: Optional[Union[Dict[str, bool], bool]] = None, - to_ignore: Optional[Union[type, Tuple[type]]] = None, - to_mutable: bool = False, - _tuple_check_fn: Optional[Callable] = None, - _list_check_fn: Optional[Callable] = None, - _dict_check_fn: Optional[Callable] = None, - shallow: bool = True, -): - to_ignore = tensorflow_default_bknd(to_ignore, ()) - if include_derived is True: - include_derived = {"tuple": True, "list": True, "dict": True} - elif not include_derived: - include_derived = {} - for t in ("tuple", "list", "dict"): 
- if t not in include_derived: - include_derived = tensorflow_set_item_bknd(include_derived, t, False) - class_instance = type(x) - if ( - hasattr(x, "is_tracked_proxy") - and hasattr(class_instance, "__bases__") - and not set(class_instance.__bases__).intersection(set(to_ignore)) - ): - to_ignore = to_ignore + (class_instance,) - tuple_check_fn = tensorflow_default_bknd( - _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), - ) - list_check_fn = tensorflow_default_bknd( - _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), - ) - dict_check_fn = tensorflow_default_bknd( - _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), - ) - if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): - ret_list = [ - tensorflow_nested_map_bknd( - fn, - i, - include_derived, - to_ignore, - to_mutable, - tuple_check_fn, - list_check_fn, - dict_check_fn, - shallow, - ) - for i in x - ] - if to_mutable: - return ret_list - elif hasattr(x, "_fields"): - return class_instance(**dict(zip(x._fields, ret_list))) - else: - return class_instance(ret_list) - elif list_check_fn(x, list) and not isinstance(x, to_ignore): - ret_list = [ - tensorflow_nested_map_bknd( - fn, - i, - include_derived, - to_ignore, - to_mutable, - tuple_check_fn, - list_check_fn, - dict_check_fn, - shallow, - ) - for i in x - ] - if shallow: - x = tensorflow_set_item_bknd(x, slice(None, None, None), ret_list[:]) - return x - return class_instance(ret_list) - elif (dict_check_fn(x, dict) or isinstance(x, UserDict)) and not isinstance( - x, to_ignore - ): - class_instance = type(x) - ret = { - k: tensorflow_nested_map_bknd( - fn, - v, - include_derived, - to_ignore, - to_mutable, - tuple_check_fn, - list_check_fn, - dict_check_fn, - shallow, - ) - for k, v in x.items() - } - if shallow: - x.update(ret) - return x - return class_instance(ret) - elif isinstance(x, slice): - return slice(*tensorflow_nested_map_bknd(fn, [x.start, x.stop, x.step])) - return fn(x) - - -def tensorflow__to_ivy_bknd_(x: Any): - if isinstance(x, tensorflow.Tensor): - return x - elif isinstance(x, tf.TensorShape): - return tuple(x) - elif isinstance(x, dict): - return x.to_ivy() - if tensorflow_is_native_array(x) or isinstance(x, np.ndarray): - return tensorflow.convert_to_tensor(x) - return x - - -def tensorflow_to_ivy_bknd_( - x: Union[tensorflow.Tensor, tf.Tensor, Iterable], - nested: bool = False, - include_derived: Optional[Dict[str, bool]] = None, -): - if nested: - return tensorflow_nested_map_bknd( - tensorflow__to_ivy_bknd_, x, include_derived, shallow=False - ) - return tensorflow__to_ivy_bknd_(x) - - -def tensorflow__asarray_to_native_arrays_and_back_bknd(fn: Callable): - @functools.wraps(fn) - def _asarray_to_native_arrays_and_back_wrapper(*args, dtype=None, **kwargs): - new_arg = args[0] - new_args = (new_arg,) + args[1:] - if dtype is not None: - dtype = tensorflow_default_dtype_bknd(dtype=dtype, as_native=True) - return tensorflow_to_ivy_bknd_(fn(*new_args, dtype=dtype, **kwargs)) - - _asarray_to_native_arrays_and_back_wrapper._asarray_to_native_arrays_and_back = True - return _asarray_to_native_arrays_and_back_wrapper - - -def tensorflow__flatten_nest_bknd(xs): - for x in xs: - if isinstance(x, Iterable) and not isinstance(x, (str, bytes)): - yield from tensorflow__flatten_nest_bknd(x) - else: - yield x - - -def 
tensorflow_promote_types_bknd( - type1: Union[str, tf.DType], - type2: Union[str, tf.DType], - /, - *, - array_api_promotion: bool = False, -): - if not (type1 and type2): - return type1 if type1 else type2 - query = [tensorflow_as_ivy_dtype(type1), tensorflow_as_ivy_dtype(type2)] - query = tuple(query) - if query not in promotion_table: - query = query[1], query[0] - - def _promote(query): - if array_api_promotion: - return tensorflow_get_item(array_api_promotion_table, query) - return tensorflow_get_item(promotion_table, query) - - return _promote(query) - - -def tensorflow__asarray_infer_dtype_bknd(fn: Callable): - @functools.wraps(fn) - def _asarray_infer_dtype_wrapper(*args, dtype=None, **kwargs): - def _infer_dtype(obj): - if isinstance(obj, tf.TensorShape): - obj = list(obj) - if hasattr(obj, "dtype"): - return obj.dtype.name if isinstance(obj, np.ndarray) else obj.dtype - else: - return tensorflow_default_dtype_bknd(item=obj) - - if not tensorflow_exists_bknd(dtype): - arr = args[0] - dtype_list = [ - tensorflow_nested_map_bknd( - lambda x: _infer_dtype(x), arr, shallow=False - ) - ] - dtype_list = tensorflow__flatten_nest_bknd(dtype_list) - dtype_list = list(set(dtype_list)) - if len(dtype_list) != 0: - dtype = dtype_list[0] - for dt in dtype_list[1:]: - dtype = tensorflow_promote_types_bknd(dtype, dt) - else: - dtype = tensorflow_default_float_dtype_bknd() - dtype = tensorflow_as_native_dtype(dtype) - return fn(*args, dtype=dtype, **kwargs) - - _asarray_infer_dtype_wrapper.infer_dtype = True - return _asarray_infer_dtype_wrapper - - -@tensorflow_handle_array_like_without_promotion -@tensorflow__asarray_to_native_arrays_and_back_bknd -@tensorflow__asarray_infer_dtype_bknd -def tensorflow_asarray( - obj: Union[ - tensorflow.Tensor, - tensorflow.Variable, - tensorflow.TensorShape, - bool, - int, - float, - tensorflow_NestedSequence_bknd, - SupportsBufferProtocol, - np.ndarray, - ], - /, - *, - copy: Optional[bool] = None, - dtype: Optional[tensorflow.DType] = None, - device: Optional[str] = None, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - with tensorflow.device(device): - if tensorflow.is_tensor(obj): - ret = tensorflow.cast(obj, dtype) if obj.dtype != dtype else obj - elif ( - dtype is not None - and dtype.is_integer - and np.issubdtype(np.array(obj).dtype, np.floating) - ): - obj_np = np.array(obj) - ret = tensorflow.convert_to_tensor(obj_np, dtype) - else: - ret = tensorflow.convert_to_tensor(obj, dtype) - return ( - tensorflow.identity(ret) - if copy or tensorflow_as_native_dev(tensorflow_dev(ret)) != device - else ret - ) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_size(x: tensorflow.Tensor, /): - return functools.reduce(mul, x.shape) if len(x.shape) > 0 else 1 - - -def tensorflow_size_bknd_(self): - return tensorflow_size(self) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_unstack( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - copy: Optional[bool] = None, - axis: int = 0, - keepdims: bool = False, -): - if x.shape == (): - return [x] - ret = tensorflow.unstack(x, axis=axis) - if keepdims: - return [tensorflow.expand_dims(r, axis) for r in ret] - return ret - - -def tensorflow_unstack_bknd_( - self: tensorflow.Tensor, - /, - *, - copy: Optional[bool] = None, - axis: int = 0, - keepdims: bool = False, -): - return tensorflow_unstack(self, copy=copy, axis=axis, keepdims=keepdims) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_copy_array( - x: 
Union[tensorflow.Tensor, tensorflow.Variable, tensorflow.TensorArray], - *, - to_ivy_array: bool = True, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if isinstance(x, tensorflow.TensorArray): - x_wrapped = tensorflow_stack_bknd_(x) - y = tensorflow.TensorArray(x.dtype, tensorflow_size_bknd_(x)()) - x = tensorflow_unstack_bknd_(y, tensorflow_copy_array(x_wrapped)) - else: - x = tensorflow.identity(x) - if to_ivy_array: - return tensorflow_to_ivy_bknd_(x) - return x - - -def tensorflow_tile( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - repeats: Sequence[int], - *, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if x.shape == (): - x = tensorflow.reshape(x, (-1,)) - if isinstance(repeats, Number): - repeats = [repeats] - if isinstance(repeats, tensorflow.Tensor) and repeats.shape == (): - repeats = tensorflow.reshape(repeats, (-1,)) - if len(x.shape) < len(repeats): - while len(x.shape) != len(repeats): - x = tensorflow.expand_dims(x, 0) - elif len(x.shape) > len(repeats): - repeats = list(repeats) - while len(x.shape) != len(repeats): - repeats = [1] + repeats - return tensorflow.tile(x, repeats) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_nonzero( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - as_tuple: bool = True, - size: Optional[int] = None, - fill_value: Number = 0, -): - res = tensorflow.experimental.numpy.nonzero(x) - if size is not None: - dtype = tensorflow.int64 - if isinstance(fill_value, float): - dtype = tensorflow.float64 - res = tensorflow.cast(res, dtype) - diff = size - res[0].shape[0] - if diff > 0: - res = tensorflow.pad(res, [[0, 0], [0, diff]], constant_values=fill_value) - elif diff < 0: - res = tensorflow.slice(res, [0, 0], [-1, size]) - if as_tuple: - return tuple(res) - return tensorflow.stack(res, axis=1) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_diff( - x: Union[tensorflow.Tensor, tensorflow.Variable, list, tuple], - /, - *, - n: int = 1, - axis: int = -1, - prepend: Optional[ - Union[tensorflow.Tensor, tensorflow.Variable, int, float, list, tuple] - ] = None, - append: Optional[ - Union[tensorflow.Tensor, tensorflow.Variable, int, float, list, tuple] - ] = None, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if n == 0: - return x - if prepend is not None: - x = tensorflow.experimental.numpy.append( - prepend, x, axis=axis if axis != -1 else None - ) - if append is not None: - x = tensorflow.experimental.numpy.append( - x, append, axis=axis if axis != -1 else None - ) - return tensorflow.experimental.numpy.diff(x, n=n, axis=axis) - - -def tensorflow__parse_ellipsis_bknd(so, ndims): - pre = list() - for s in so: - if s is Ellipsis: - break - pre.append(s) - post = list() - for s in reversed(so): - if s is Ellipsis: - break - post.append(s) - ret = list( - pre - + [slice(None, None, None) for _ in range(ndims - len(pre) - len(post))] - + list(reversed(post)) - ) - return ret, (len(pre), ndims - len(post)) - - -def tensorflow_broadcast_arrays(*arrays: Union[tensorflow.Tensor, tensorflow.Variable]): - if len(arrays) > 1: - try: - desired_shape = tensorflow.broadcast_dynamic_shape( - tensorflow.shape(arrays[0]), tensorflow.shape(arrays[1]) - ) - except tensorflow.errors.InvalidArgumentError as e: - raise Exception(e) from e - if len(arrays) > 2: - for i in range(2, len(arrays)): - try: - desired_shape = tensorflow.broadcast_dynamic_shape( - desired_shape, tensorflow.shape(arrays[i]) - ) - except 
tensorflow.errors.InvalidArgumentError as e: - raise Exception(e) from e - else: - return [arrays[0]] - result = [] - for tensor in arrays: - result.append(tensorflow.broadcast_to(tensor, desired_shape)) - return result - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_astype( - x: Union[tensorflow.Tensor, tensorflow.Variable], - dtype: Union[tf.DType, str], - /, - *, - copy: bool = True, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - dtype = tensorflow_as_native_dtype(dtype) - if x.dtype == dtype: - return tensorflow.experimental.numpy.copy(x) if copy else x - return tensorflow.cast(x, dtype) - - -def tensorflow_astype_bknd_( - self: tensorflow.Tensor, - dtype: str, - /, - *, - copy: bool = True, - out: Optional[tensorflow.Tensor] = None, -): - return tensorflow_astype(self, dtype, copy=copy, out=out) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_where( - condition: Union[tensorflow.Tensor, tensorflow.Variable], - x1: Union[tensorflow.Tensor, tensorflow.Variable], - x2: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - oirg_x1 = x1 - oirg_x2 = x2 - try: - dtype = ( - x1.dtype - if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() - ) - if not tensorflow_is_array_bknd(x1): - x1 = tensorflow_asarray(x1, dtype=dtype) - if not tensorflow_is_array_bknd(x2): - x2 = tensorflow_asarray(x2, dtype=dtype) - except: - x1 = oirg_x1 - x2 = oirg_x2 - return tensorflow.cast( - tensorflow.experimental.numpy.where(condition, x1, x2), x1.dtype - ) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_arange( - start: float, - /, - stop: Optional[float] = None, - step: float = 1, - *, - dtype: Optional[tensorflow.DType] = None, - device: Optional[str] = None, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if stop is None: - stop = start - start = 0 - if step > 0 and start > stop or step < 0 and start < stop: - if isinstance(stop, float): - stop = float(start) - else: - stop = start - if isinstance(start, (float, int)): - start = tensorflow.convert_to_tensor(start) - if isinstance(stop, (float, int)): - stop = tensorflow.convert_to_tensor(stop) - if isinstance(step, (float, int)): - step = tensorflow.convert_to_tensor(step) - if dtype is None: - if isinstance(start, int) and isinstance(stop, int) and isinstance(step, int): - return tensorflow.cast( - tensorflow.range(start, stop, delta=step, dtype=tensorflow.int64), - tensorflow.int32, - ) - else: - return tensorflow.range(start, stop, delta=step) - else: - dtype = tensorflow_as_native_dtype(tensorflow_default_dtype_bknd(dtype=dtype)) - if dtype in [ - tensorflow.int8, - tensorflow.uint8, - tensorflow.int16, - tensorflow.uint16, - tensorflow.uint32, - tensorflow.uint64, - ]: - return tensorflow.cast( - tensorflow.range(start, stop, delta=step, dtype=tensorflow.int64), dtype - ) - else: - return tensorflow.range(start, stop, delta=step, dtype=dtype) - - -def tensorflow__parse_slice_bknd(idx, s): - step = 1 if idx.step is None else idx.step - if step > 0: - start = 0 if idx.start is None else idx.start - if start >= s: - stop = start - else: - if start <= -s: - start = 0 - elif start < 0: - start = start + s - stop = s if idx.stop is None else idx.stop - if stop > s: - stop = s - elif start <= -s: - stop = 0 - elif stop < 0: - stop = stop + s - else: - start = s - 1 if idx.start is None else idx.start - if start < 
-s: - stop = start - else: - if start >= s: - start = s - 1 - elif start < 0: - start = start + s - if idx.stop is None: - stop = -1 - else: - stop = idx.stop - if stop > s: - stop = s - elif stop < -s: - stop = -1 - elif stop == -s: - stop = 0 - elif stop < 0: - stop = stop + s - q_i = tensorflow_arange(start, stop, step) - ag__result_list_0 = [] - for q in q_i: - if 0 <= q < s: - res = q - ag__result_list_0.append(res) - q_i = ag__result_list_0 - q_i = ( - tensorflow_asarray(q_i) - if len(q_i) or start == stop or idx.stop is not None - else tensorflow_arange(0, s, 1) - ) - return q_i - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_shape( - x: Union[tensorflow.Tensor, tensorflow.Variable], /, *, as_array: bool = False -): - if as_array: - return tensorflow_asarray( - tensorflow.shape(x), dtype=tensorflow_default_int_dtype_bknd() - ) - else: - return tuple(x.shape) - - -def tensorflow__deep_flatten_bknd(iterable): - def _flatten_gen(iterable): - for item in iterable: - if isinstance(item, list): - yield from _flatten_gen(item) - else: - yield item - - return list(_flatten_gen(iterable)) - - -def tensorflow__calculate_out_shape_bknd(axis, array_shape): - if type(axis) not in (tuple, list): - axis = (axis,) - out_dims = len(axis) + len(array_shape) - norm_axis = normalize_axis_tuple(axis, out_dims) - shape_iter = iter(array_shape) - ag__result_list_0 = [] - for current_ax in range(out_dims): - res = 1 if current_ax in norm_axis else next(shape_iter) - ag__result_list_0.append(res) - out_shape = ag__result_list_0 - return out_shape - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_expand_dims( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - copy: Optional[bool] = None, - axis: Union[int, Sequence[int]] = 0, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - try: - out_shape = tensorflow__calculate_out_shape_bknd(axis, tensorflow.shape(x)) - ret = tensorflow.reshape(x, shape=out_shape) - return ret - except (tensorflow.errors.InvalidArgumentError, np.AxisError) as error: - raise Exception(error) from error - - -def tensorflow_check_elem_in_list(elem, list, inverse=False, message=""): - if inverse and elem in list: - raise Exception( - message if message != "" else f"{elem} must not be one of {list}" - ) - elif not inverse and elem not in list: - raise Exception(message if message != "" else f"{elem} must be one of {list}") - - -def tensorflow__reshape_fortran_tf(x, shape): - if len(x.shape) > 0: - x = tensorflow.transpose(x) - return tensorflow.transpose(tensorflow.reshape(x, shape[::-1])) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_reshape( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - shape: Union[tf.TensorShape, Sequence[int]], - *, - copy: Optional[bool] = None, - order: str = "C", - allowzero: bool = True, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - tensorflow_check_elem_in_list(order, ["C", "F"]) - if not allowzero: - shape = [ - (new_s if con else old_s) - for new_s, con, old_s in zip( - shape, tensorflow.constant(shape) != 0, x.shape - ) - ] - if order == "F": - return tensorflow__reshape_fortran_tf(x, shape) - return tensorflow.reshape(x, shape) - - -def tensorflow_reshape_bknd_( - self: tensorflow.Tensor, - /, - shape: Union[tuple, tf.TensorShape, Sequence[int]], - *, - copy: Optional[bool] = None, - order: str = "C", - allowzero: bool = True, - out: Optional[tensorflow.Tensor] = None, -): - return tensorflow_reshape( - self, shape, 
copy=copy, allowzero=allowzero, out=out, order=order - ) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_meshgrid( - *arrays: Union[tensorflow.Tensor, tensorflow.Variable], - sparse: bool = False, - indexing: str = "xy", - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if not sparse: - return tensorflow.meshgrid(*arrays, indexing=indexing) - sd = (1,) * len(arrays) - ag__result_list_0 = [] - for i, a in enumerate(arrays): - res = tensorflow.reshape( - tensorflow.convert_to_tensor(a), sd[:i] + (-1,) + sd[i + 1 :] - ) - ag__result_list_0.append(res) - res = ag__result_list_0 - if indexing == "xy" and len(arrays) > 1: - res[0] = tensorflow.reshape(res[0], (1, -1) + sd[2:]) - res[1] = tensorflow.reshape(res[1], (-1, 1) + sd[2:]) - return res - - -@tensorflow_infer_dtype -@tensorflow_handle_array_like_without_promotion -def tensorflow_empty( - shape: Union[tf.TensorShape, Sequence[int]], - *, - dtype: tensorflow.DType, - device: Optional[str] = None, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - return tensorflow.experimental.numpy.empty(shape, dtype=tensorflow.float32) - - -def tensorflow__parse_query_bknd(query, x_shape, scatter=False): - query = (query,) if not isinstance(query, tuple) else query - ag__result_list_0 = [] - for q in query: - res = tensorflow_asarray(q) if isinstance(q, (tuple, list, int)) else q - ag__result_list_0.append(res) - query = ag__result_list_0 - ag__result_list_1 = [] - for i, q in enumerate(query): - if tensorflow_is_array_bknd(q): - res = i - ag__result_list_1.append(res) - non_slice_q_idxs = ag__result_list_1 - to_front = ( - len(non_slice_q_idxs) > 1 - and any(tensorflow_diff(non_slice_q_idxs) != 1) - and non_slice_q_idxs[-1] < len(x_shape) - ) - ag__result_list_2 = [] - for i, q in enumerate(query): - if q is None: - res = i - ag__result_list_2.append(res) - new_axes = ag__result_list_2 - ag__result_list_3 = [] - for q in query: - if q is not None: - res = q - ag__result_list_3.append(res) - query = ag__result_list_3 - query = [Ellipsis] if query == [] else query - ellipsis_inds = None - if any(q is Ellipsis for q in query): - query, ellipsis_inds = tensorflow__parse_ellipsis_bknd(query, len(x_shape)) - ag__result_list_4 = [] - for i, v in enumerate(query): - if tensorflow_is_array_bknd(v): - res = i - ag__result_list_4.append(res) - array_inds = ag__result_list_4 - if array_inds: - array_queries = tensorflow_broadcast_arrays( - *[v for i, v in enumerate(query) if i in array_inds] - ) - array_queries = [ - ( - tensorflow_nonzero(q, as_tuple=False)[0] - if tensorflow_is_bool_dtype_bknd(q) - else q - ) - for q in array_queries - ] - array_queries = [ - ( - tensorflow_astype_bknd_( - tensorflow_where( - arr < 0, arr + tensorflow_get_item(x_shape, i), arr - ), - tf.int64, - ) - if tensorflow_size_bknd_(arr) - else tensorflow_astype_bknd_(arr, tf.int64) - ) - for arr, i in zip(array_queries, array_inds) - ] - for idx, arr in zip(array_inds, array_queries): - query = tensorflow_set_item_bknd(query, idx, arr) - ag__result_list_5 = [] - for i, q in enumerate(query): - res = ( - tensorflow_astype_bknd_( - tensorflow__parse_slice_bknd(q, tensorflow_get_item(x_shape, i)), - tf.int64, - ) - if isinstance(q, slice) - else q - ) - ag__result_list_5.append(res) - query = ag__result_list_5 - if len(query) < len(x_shape): - query = query + [ - tensorflow_astype_bknd_(tensorflow_arange(0, s, 1), tf.int64) - for s in tensorflow_get_item(x_shape, slice(len(query), None, None)) - ] - if len(array_inds) 
and to_front: - target_shape = ( - [list(array_queries[0].shape)] - + [ - list(tensorflow_get_item(query, i).shape) - for i in range(len(query)) - if i not in array_inds - ] - + [[] for _ in range(len(array_inds) - 1)] - ) - elif len(array_inds): - target_shape = ( - [list(tensorflow_get_item(query, i).shape) for i in range(0, array_inds[0])] - + [list(tensorflow_shape(array_queries[0], as_array=True))] - + [[] for _ in range(len(array_inds) - 1)] - + [ - list(tensorflow_shape(tensorflow_get_item(query, i), as_array=True)) - for i in range(array_inds[-1] + 1, len(query)) - ] - ) - else: - target_shape = [list(q.shape) for q in query] - if ellipsis_inds is not None: - target_shape = ( - tensorflow_get_item(target_shape, slice(None, ellipsis_inds[0], None)) - + [ - tensorflow_get_item( - target_shape, slice(ellipsis_inds[0], ellipsis_inds[1], None) - ) - ] - + tensorflow_get_item(target_shape, slice(ellipsis_inds[1], None, None)) - ) - for i, ax in enumerate(new_axes): - if len(array_inds) and to_front: - ax = ax - (sum(1 for x in array_inds if x < ax) - 1) - ax = ax + i - target_shape = [ - *tensorflow_get_item(target_shape, slice(None, ax, None)), - 1, - *tensorflow_get_item(target_shape, slice(ax, None, None)), - ] - target_shape = tensorflow__deep_flatten_bknd(target_shape) - ag__result_list_6 = [] - for q in query: - res = tensorflow_expand_dims(q) if not len(q.shape) else q - ag__result_list_6.append(res) - query = ag__result_list_6 - if len(array_inds): - array_queries = [ - ( - tensorflow_reshape_bknd_(arr, (-1,)) - if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr - ) - for arr in array_queries - ] - array_queries = tensorflow_stack(array_queries, axis=1) - if len(array_inds) == len(query): - indices = tensorflow_reshape_bknd_(array_queries, (*target_shape, len(x_shape))) - elif len(array_inds) == 0: - indices = tensorflow_reshape_bknd_( - tensorflow_stack(tensorflow_meshgrid(*query, indexing="ij"), axis=-1), - (*target_shape, len(x_shape)), - ) - elif to_front: - post_array_queries = ( - tensorflow_reshape_bknd_( - tensorflow_stack( - tensorflow_meshgrid( - *[v for i, v in enumerate(query) if i not in array_inds], - indexing="ij", - ), - axis=-1, - ), - (-1, len(query) - len(array_inds)), - ) - if len(array_inds) < len(query) - else tensorflow_empty((1, 0)) - ) - indices = tensorflow_reshape_bknd_( - tensorflow_asarray( - [ - (*arr, *post) - for arr, post in itertools.product( - array_queries, post_array_queries - ) - ] - ), - (*target_shape, len(x_shape)), - ) - else: - pre_array_queries = ( - tensorflow_reshape_bknd_( - tensorflow_stack( - tensorflow_meshgrid( - *[v for i, v in enumerate(query) if i < array_inds[0]], - indexing="ij", - ), - axis=-1, - ), - (-1, array_inds[0]), - ) - if array_inds[0] > 0 - else tensorflow_empty((1, 0)) - ) - post_array_queries = ( - tensorflow_reshape_bknd_( - tensorflow_stack( - tensorflow_meshgrid( - *[v for i, v in enumerate(query) if i > array_inds[-1]], - indexing="ij", - ), - axis=-1, - ), - (-1, len(query) - 1 - array_inds[-1]), - ) - if array_inds[-1] < len(query) - 1 - else tensorflow_empty((1, 0)) - ) - indices = tensorflow_reshape_bknd_( - tensorflow_asarray( - [ - (*pre, *arr, *post) - for pre, arr, post in itertools.product( - pre_array_queries, array_queries, post_array_queries - ) - ] - ), - (*target_shape, len(x_shape)), - ) - return ( - tensorflow_astype_bknd_(indices, tf.int64), - target_shape, - array_inds if len(array_inds) and to_front else None, - ) - - -def tensorflow_get_num_dims(x, /, 
*, as_array=False): - return ( - tensorflow.cast(tensorflow.shape(tensorflow.shape(x))[0], tensorflow.int64) - if as_array - else int(tensorflow.shape(tensorflow.shape(x))) - ) - - -def tensorflow_to_numpy( - x: Union[tensorflow.Tensor, tensorflow.Variable], /, *, copy: bool = True -): - if ( - tensorflow_is_array_bknd(x) - and tensorflow_get_num_dims(x) == 0 - and tensorflow_as_native_dtype(x.dtype) is tensorflow.bfloat16 - ): - x = tensorflow.expand_dims(x, 0) - if copy: - return np.squeeze(np.array(tensorflow.convert_to_tensor(x)), 0) - else: - return np.squeeze(np.asarray(tensorflow.convert_to_tensor(x)), 0) - if copy: - return np.array(tensorflow.convert_to_tensor(x)) - else: - return np.asarray(tensorflow.convert_to_tensor(x)) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_to_scalar(x: Union[tensorflow.Tensor, tensorflow.Variable], /): - ret = tensorflow_to_numpy(x).item() - if x.dtype == tensorflow.bfloat16: - return float(ret) - return ret - - -def tensorflow_to_scalar_bknd_(self: tensorflow.Tensor): - return tensorflow_to_scalar(self) - - -def tensorflow_is_float_dtype_bknd( - dtype_in: Union[str, str, tensorflow.Tensor, tf.Tensor, Number], / -): - if tensorflow_is_array_bknd(dtype_in): - dtype_in = tensorflow_dtype(dtype_in) - elif isinstance(dtype_in, tuple): - dtype_in = tensorflow_default_int_dtype_bknd() - elif isinstance(dtype_in, np.ndarray): - return "float" in dtype_in.dtype.name - elif isinstance(dtype_in, Number): - return isinstance(dtype_in, (float, np.floating)) - elif isinstance(dtype_in, (list, tuple, dict)): - return bool( - tensorflow_nested_argwhere_bknd( - dtype_in, - lambda x: isinstance(x, (float, np.floating)) - or tensorflow_is_array_bknd(x) - and "float" in tensorflow_dtype(x), - ) - ) - return "float" in tensorflow_as_ivy_dtype_bknd(dtype_in) - - -def tensorflow_is_uint_dtype_bknd( - dtype_in: Union[str, str, tensorflow.Tensor, tf.Tensor, Number], / -): - if tensorflow_is_array_bknd(dtype_in): - dtype_in = tensorflow_dtype(dtype_in) - elif isinstance(dtype_in, tuple): - dtype_in = tensorflow_default_int_dtype_bknd() - elif isinstance(dtype_in, np.ndarray): - return "uint" in dtype_in.dtype.name - elif isinstance(dtype_in, Number): - return isinstance(dtype_in, np.unsignedinteger) - elif isinstance(dtype_in, (list, tuple, dict)): - return tensorflow_nested_argwhere_bknd( - dtype_in, - lambda x: isinstance(x, np.unsignedinteger) - or tensorflow_is_array_bknd(x) - and "uint" in tensorflow_dtype(x), - ) - return "uint" in tensorflow_as_ivy_dtype_bknd(dtype_in) - - -def tensorflow_default_uint_dtype_bknd( - *, - input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, - uint_dtype: Optional[Union[str, tf.DType]] = None, - as_native: bool = False, -): - global default_uint_dtype_stack - if tensorflow_exists_bknd(uint_dtype): - if as_native is True: - return tensorflow_as_native_dtype(uint_dtype) - return str(tensorflow_as_ivy_dtype(uint_dtype)) - as_native = tensorflow_default_bknd(as_native, False) - if tensorflow_exists_bknd(input): - if tensorflow_is_array_bknd(input): - ret = tensorflow_dtype(input) - elif isinstance(input, np.ndarray): - ret = input.dtype - elif isinstance(input, (list, tuple, dict)): - - def is_native(x): - return tensorflow_is_native_array(x) - - if tensorflow_nested_argwhere_bknd( - input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), - stop_after_n_found=1, - ): - ret = tf.uint64 - elif default_uint_dtype_stack: - ret = 
default_uint_dtype_stack[-1] - else: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_uint_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "uint32" - elif isinstance(input, Number): - if input > 4294967295 and input != math.inf and backend != "torch": - ret = tf.uint64 - elif default_uint_dtype_stack: - ret = default_uint_dtype_stack[-1] - else: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_uint_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "uint32" - elif default_uint_dtype_stack: - ret = default_uint_dtype_stack[-1] - else: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_uint_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "uint32" - if as_native: - return tensorflow_as_native_dtype(ret) - return str(tensorflow_as_ivy_dtype(ret)) - - -def tensorflow_is_int_dtype_bknd( - dtype_in: Union[str, str, tensorflow.Tensor, tf.Tensor, Number], / -): - if tensorflow_is_array_bknd(dtype_in): - dtype_in = tensorflow_dtype(dtype_in) - elif isinstance(dtype_in, tuple): - dtype_in = tensorflow_default_int_dtype_bknd() - elif isinstance(dtype_in, np.ndarray): - return "int" in dtype_in.dtype.name - elif isinstance(dtype_in, Number): - return isinstance(dtype_in, (int, np.integer)) and not isinstance( - dtype_in, bool - ) - elif isinstance(dtype_in, (list, tuple, dict)): - - def nested_fun(x): - return ( - isinstance(x, (int, np.integer)) - or tensorflow_is_array_bknd(x) - and "int" in tensorflow_dtype(x) - ) and x is not bool - - return bool(tensorflow_nested_argwhere_bknd(dtype_in, nested_fun)) - return "int" in tensorflow_as_ivy_dtype(dtype_in) - - -def tensorflow_infer_default_dtype_bknd( - dtype: Union[str, tf.DType, str], as_native: bool = False -): - if tensorflow_is_complex_dtype_bknd(dtype): - default_dtype = tensorflow_default_complex_dtype_bknd(as_native=as_native) - elif tensorflow_is_float_dtype_bknd(dtype): - default_dtype = tensorflow_default_float_dtype_bknd(as_native=as_native) - elif tensorflow_is_uint_dtype_bknd(dtype): - default_dtype = tensorflow_default_uint_dtype_bknd(as_native=as_native) - elif tensorflow_is_int_dtype_bknd(dtype): - default_dtype = tensorflow_default_int_dtype_bknd(as_native=as_native) - elif as_native: - default_dtype = tensorflow_as_native_dtype("bool") - else: - default_dtype = tensorflow_as_ivy_dtype("bool") - return default_dtype - - -def tensorflow_dtype_bits(dtype_in: Union[tensorflow.DType, str, np.dtype], /): - dtype_str = tensorflow_as_ivy_dtype(dtype_in) - if "bool" in dtype_str: - return 1 - return int( - dtype_str.replace("tf.", "") - .replace("uint", "") - .replace("int", "") - .replace("bfloat", "") - .replace("float", "") - .replace("complex", "") - ) - - -def tensorflow__infer_dtype(dtype: tensorflow.DType): - default_dtype = tensorflow_infer_default_dtype_bknd(dtype) - if tensorflow_dtype_bits(dtype) < tensorflow_dtype_bits(default_dtype): - return default_dtype - return dtype - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_prod( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - axis: Optional[Union[int, Sequence[int]]] = None, - dtype: Optional[tensorflow.DType] = None, - keepdims: bool = False, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - dtype = tensorflow_as_native_dtype(dtype) - if dtype is None: - dtype = tensorflow__infer_dtype(x.dtype) - axis = tuple(axis) if isinstance(axis, list) else axis - return tensorflow.experimental.numpy.prod( - x, axis=axis, dtype=dtype, keepdims=keepdims - ) - - -def 
tensorflow__numel_bknd(shape): - shape = tuple(shape) - return tensorflow_to_scalar_bknd_(tensorflow_prod(shape)) if shape != () else 1 - - -def tensorflow_check_one_way_broadcastable(x1, x2): - if len(x1) > len(x2): - return False - for a, b in zip(x1[::-1], x2[::-1]): - if a in (1, b): - pass - else: - return False - return True - - -def tensorflow_check_shapes_broadcastable(var, data): - if not tensorflow_check_one_way_broadcastable(var, data): - raise Exception(f"Could not broadcast shape {data} to shape {var}.") - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_broadcast_to( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - shape: Union[tf.TensorShape, Sequence[int]], - *, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - tensorflow_check_shapes_broadcastable(x.shape, shape) - if tensorflow.rank(x) > len(shape): - return tensorflow.broadcast_to(tensorflow.reshape(x, -1), shape) - return tensorflow.broadcast_to(x, shape) - - -def tensorflow__broadcast_to_bknd(input, target_shape): - if tensorflow__numel_bknd(tuple(input.shape)) == tensorflow__numel_bknd( - tuple(target_shape) - ): - return tensorflow_reshape(input, target_shape) - else: - input = input if len(input.shape) else tensorflow_expand_dims(input, axis=0) - return tensorflow_broadcast_to(input, target_shape) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_any( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - axis: Optional[Union[int, Sequence[int]]] = None, - keepdims: bool = False, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if axis is None: - num_dims = len(x.shape) - axis = tuple(range(num_dims)) - elif isinstance(axis, list): - axis = tuple(axis) - try: - return tensorflow.reduce_any( - tensorflow.cast(x, tensorflow.bool), axis=axis, keepdims=keepdims - ) - except tensorflow.errors.InvalidArgumentError as e: - raise Exception(e) from e - - -def tensorflow__broadcast_inputs(x1, x2): - x1_, x2_ = x1, x2 - iterables = list, tuple, tuple - if not isinstance(x1_, iterables): - x1_, x2_ = x2, x1 - if not isinstance(x1_, iterables): - return [x1], [x2] - if not isinstance(x2_, iterables): - x1 = [x1] * len(x2) - return x1, x2 - - -def tensorflow_check_equal(x1, x2, inverse=False, message="", as_array=True): - def eq_fn(x1, x2): - return x1 == x2 if inverse else x1 != x2 - - def comp_fn(x1, x2): - return tensorflow_any(eq_fn(x1, x2)) - - if not as_array: - - def iter_comp_fn(x1_, x2_): - return any(eq_fn(x1, x2) for x1, x2 in zip(x1_, x2_)) - - def comp_fn(x1, x2): - return iter_comp_fn(*tensorflow__broadcast_inputs(x1, x2)) - - eq = comp_fn(x1, x2) - if inverse and eq: - raise Exception(f"{x1} must not be equal to {x2}" if message == "" else message) - elif not inverse and eq: - raise Exception(f"{x1} must be equal to {x2}" if message == "" else message) - - -def tensorflow_multiply( - x1: Union[float, tensorflow.Tensor, tensorflow.Variable], - x2: Union[float, tensorflow.Tensor, tensorflow.Variable], - /, - *, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - oirg_x1 = x1 - oirg_x2 = x2 - try: - dtype = ( - x1.dtype - if hasattr(x1, "dtype") - else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd() - ) - if not tensorflow_is_array_bknd(x1): - x1 = tensorflow_asarray(x1, dtype=dtype) - if not tensorflow_is_array_bknd(x2): - x2 = tensorflow_asarray(x2, dtype=dtype) - except: - x1 = oirg_x1 - x2 = oirg_x2 - return tensorflow.math.multiply(x1, x2) - - -def 
tensorflow_check_gather_nd_input_valid(params, indices, batch_dims): - if batch_dims >= len(params.shape): - raise Exception( - f"batch_dims = {batch_dims} must be less than rank(`params`) = {len(params.shape)}." - ) - if batch_dims >= len(indices.shape): - raise Exception( - f"batch_dims = {batch_dims} must be less than rank(`indices`) = {len(indices.shape)}." - ) - if tensorflow_get_item( - params.shape, slice(0, batch_dims, None) - ) != tensorflow_get_item(indices.shape, slice(0, batch_dims, None)): - raise Exception( - f"batch dimensions must match in `params` and `indices`; saw {tensorflow_get_item(params.shape, slice(0, batch_dims, None))} vs. {tensorflow_get_item(indices.shape, slice(0, batch_dims, None))}" - ) - if indices.shape[-1] > len( - tensorflow_get_item(params.shape, slice(batch_dims, None, None)) - ): - raise Exception( - f"index innermost dimension length must be <= rank(`params[batch_dims:]`); saw: {indices.shape[-1]} vs. {len(tensorflow_get_item(params.shape, slice(batch_dims, None, None)))} ." - ) - - -def tensorflow_gather_nd_helper(params, indices): - indices_shape = tensorflow.shape(indices) - params_shape = tensorflow.shape(params) - num_index_dims = indices_shape[-1] - result_dim_sizes_list = [ - tensorflow.math.reduce_prod(params_shape[i + 1 :]) - for i in range(len(params_shape) - 1) - ] + [1] - result_dim_sizes = tensorflow.convert_to_tensor( - result_dim_sizes_list, dtype=indices.dtype - ) - implicit_indices_factor = result_dim_sizes[num_index_dims - 1] - flat_params = tensorflow.reshape(params, (-1,)) - new_shape = [1] * (len(indices_shape) - 1) + [num_index_dims] - indices_scales = tensorflow.reshape(result_dim_sizes[0:num_index_dims], new_shape) - indices_for_flat_tiled = tensorflow.reshape( - tensorflow.reduce_sum(indices * indices_scales, -1, keepdims=True), (-1, 1) - ) - indices_for_flat_tiled = tensorflow.repeat( - indices_for_flat_tiled, implicit_indices_factor, axis=1 - ) - implicit_indices = tensorflow.repeat( - tensorflow.expand_dims(tensorflow.range(implicit_indices_factor), 0), - indices_for_flat_tiled.shape[0], - axis=0, - ) - indices_for_flat = indices_for_flat_tiled + implicit_indices - flat_indices_for_flat = tensorflow.reshape(indices_for_flat, (-1,)) - flat_gather = tensorflow.gather(flat_params, flat_indices_for_flat) - res = tensorflow.reshape( - flat_gather, - tensorflow.concat([indices_shape[:-1], params_shape[num_index_dims:]], 0), - ) - return res - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_gather_nd( - params: Union[tensorflow.Tensor, tensorflow.Variable], - indices: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - batch_dims: int = 0, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - tensorflow_check_gather_nd_input_valid(params, indices, batch_dims) - try: - return tensorflow.gather_nd(params, indices, batch_dims=batch_dims) - except Exception: - batch_dims %= len(params.shape) - result = [] - if batch_dims == 0: - result = tensorflow_gather_nd_helper(params, indices) - else: - for b in range(batch_dims): - if b == 0: - zip_list = list(zip(params, indices)) - else: - zip_list = [ - (p, i) - for z in [zip(p1, i1) for p1, i1 in zip_list] - for p, i in z - ] - for z in zip_list: - p, i = z[0], z[1] - r = tensorflow_gather_nd_helper(p, i) - result.append(r) - result = tensorflow.stack(result) - result = tensorflow.reshape( - result, - tensorflow.concat([params.shape[0:batch_dims], result.shape[1:]], 0), - ) - return result - - -def tensorflow__is_variable_bknd(x, 
exclusive=False, to_ignore=None): - x = x - return tensorflow_nested_map_bknd( - lambda x: tensorflow_is_variable(x, exclusive=exclusive), - x, - include_derived=True, - shallow=False, - to_ignore=to_ignore, - ) - - -def tensorflow_inplace_update( - x: Union[tensorflow.Tensor, tensorflow.Tensor], - val: Union[tensorflow.Tensor, tensorflow.Tensor], - /, - *, - ensure_in_backend: bool = False, - keep_input_dtype: bool = False, -): - if tensorflow_is_array_bknd(x) and tensorflow_is_array_bknd(val): - if keep_input_dtype: - val = tensorflow_astype(val, x.dtype) - (x_native, val_native), _ = (x, val), "_" - if tensorflow__is_variable_bknd(x_native): - x_native.assign(val_native) - if tensorflow_is_ivy_array_bknd(x): - x = x_native - else: - x = tensorflow.convert_to_tensor(x_native) - else: - x = x_native - return x - else: - return val - - -def tensorflow_scatter_nd( - indices: Union[tensorflow.Tensor, tensorflow.Variable], - updates: Union[tensorflow.Tensor, tensorflow.Variable], - /, - shape: Optional[Union[tf.TensorShape, Sequence[int]]] = None, - *, - reduction: str = "sum", - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - updates_dtype = updates.dtype - if tensorflow_exists_bknd(out): - dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) - updates = tensorflow.cast( - updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), - ) - expected_shape = ( - list(tensorflow.shape(indices)[:-1]) - + list(out.shape[tensorflow.shape(indices)[-1] :]) - if tensorflow_exists_bknd(out) - else list(tensorflow.shape(indices)[:-1]) - + list(shape[tensorflow.shape(indices)[-1] :]) - ) - updates = tensorflow__broadcast_to_bknd(updates, expected_shape) - if len(updates.shape) == 0: - indices = tensorflow.expand_dims(indices, 0) - updates = tensorflow.expand_dims(updates, 0) - target = out - target_given = tensorflow_exists_bknd(target) - if tensorflow_exists_bknd(shape) and target_given: - tensorflow_check_equal(tuple(target.shape), tuple(shape), as_array=False) - if not target_given: - shape = list(shape) if tensorflow_exists_bknd(shape) else list(out.shape) - target = tensorflow.zeros(shape, dtype=updates.dtype) - if reduction == "sum": - res = tensorflow.tensor_scatter_nd_add(target, indices, updates) - elif reduction == "min": - res = tensorflow.tensor_scatter_nd_min(target, indices, updates) - elif reduction == "max": - res = tensorflow.tensor_scatter_nd_max(target, indices, updates) - elif reduction == "mul": - updates = tensorflow_multiply(tensorflow_gather_nd(target, indices), updates) - res = tensorflow.tensor_scatter_nd_update(target, indices, updates) - elif reduction == "replace": - res = tensorflow.tensor_scatter_nd_update(target, indices, updates) - else: - raise Exception( - f'reduction is {reduction}, but it must be one of "sum", "min", "max", "mul" or "replace"' - ) - if tensorflow_exists_bknd(out): - return tensorflow_inplace_update(out, res) - return res - - -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - -@tensorflow_handle_set_item -def tensorflow_set_item_bknd( - x: Union[tensorflow.Tensor, tf.Tensor], - query: Union[tensorflow.Tensor, tf.Tensor, Tuple], - val: Union[tensorflow.Tensor, tf.Tensor], - /, - *, - copy: Optional[bool] = False, -): - if isinstance(query, (list, 
tuple)) and any( - [(q is Ellipsis or isinstance(q, slice) and q.stop is None) for q in query] - ): - x_stop_gradient = tensorflow_stop_gradient(x, preserve_type=False) - np_array = x_stop_gradient.numpy() - val_stop_gradient = tensorflow_stop_gradient(val, preserve_type=False) - np_array = tensorflow_set_item_bknd( - np_array, query, np.asarray(val_stop_gradient) - ) - return tensorflow_asarray(np_array) - if copy: - x = tensorflow_copy_array(x) - if not tensorflow_is_array_bknd(val): - val = tensorflow_asarray(val) - if 0 in x.shape or 0 in val.shape: - return x - if tensorflow_is_array_bknd(query) and tensorflow_is_bool_dtype_bknd(query): - if not len(query.shape): - query = tensorflow_tile(query, (x.shape[0],)) - indices = tensorflow_nonzero(query, as_tuple=False) - else: - indices, target_shape, _ = tensorflow__parse_query_bknd( - query, tensorflow_shape(x, as_array=True), scatter=True - ) - if indices is None: - return x - val = tensorflow_astype_bknd_(val, x.dtype) - ret = tensorflow_scatter_nd(indices, val, reduction="replace", out=x) - return ret - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_real( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - return tensorflow.math.real(x) - - -def tensorflow_real_bknd_(self): - return tensorflow_real(self) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_imag( - val: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - return tensorflow.math.imag(val, name=None) - - -def tensorflow_imag_bknd_(self): - return tensorflow_imag(self) - - -def tensorflow__check_complex128_bknd(input): - if tensorflow_is_array_bknd(input): - return tensorflow_dtype(input) == "complex128" - elif isinstance(input, np.ndarray): - return str(input.dtype) == "complex128" - if hasattr(input, "real") and hasattr(input, "imag"): - return tensorflow__check_float64_bknd( - tensorflow_real_bknd_(input) - ) and tensorflow__check_float64_bknd(tensorflow_imag_bknd_(input)) - return False - - -def tensorflow_default_complex_dtype_bknd( - *, - input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, - complex_dtype: Optional[Union[str, tf.DType]] = None, - as_native: bool = False, -): - global default_complex_dtype_stack - if tensorflow_exists_bknd(complex_dtype): - if as_native is True: - return tensorflow_as_native_dtype(complex_dtype) - return str(tensorflow_as_ivy_dtype(complex_dtype)) - as_native = tensorflow_default_bknd(as_native, False) - if tensorflow_exists_bknd(input): - if tensorflow_is_array_bknd(input): - ret = tensorflow_dtype(input) - elif isinstance(input, np.ndarray): - ret = str(input.dtype) - elif isinstance(input, (list, tuple, dict)): - if tensorflow_nested_argwhere_bknd( - input, - lambda x: tensorflow__check_complex128_bknd(x), - stop_after_n_found=1, - ): - ret = tf.complex128 - elif not default_complex_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_complex_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "complex64" - else: - ret = default_complex_dtype_stack[-1] - elif isinstance(input, Number): - if tensorflow__check_complex128_bknd(input): - ret = tf.complex128 - elif not default_complex_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_complex_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "complex64" - else: - ret = default_complex_dtype_stack[-1] - elif not 
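For the boolean-mask branch of tensorflow_set_item_bknd above, the mask is turned into explicit coordinates and routed through scatter with reduction="replace". A minimal standalone sketch, using tf.where as the nonzero(as_tuple=False) analogue:

import tensorflow as tf

x = tf.constant([10, 20, 30, 40])
mask = tf.constant([True, False, True, False])
indices = tf.where(mask)                    # [[0], [2]]: mask as 2-D indices
val = tf.constant([-1, -2])
print(tf.tensor_scatter_nd_update(x, indices, val))   # [-1 20 -2 40]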
default_complex_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_complex_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "complex64" - else: - ret = default_complex_dtype_stack[-1] - if as_native: - return tensorflow_as_native_dtype(ret) - return str(tensorflow_as_ivy_dtype(ret)) - - -def tensorflow_default_dtype_bknd( - *, - dtype: Optional[Union[str, str]] = None, - item: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, - as_native: bool = False, -): - if tensorflow_exists_bknd(dtype): - if as_native is True: - return tensorflow_as_native_dtype(dtype) - return tensorflow_as_ivy_dtype(dtype) - as_native = tensorflow_default_bknd(as_native, False) - if tensorflow_exists_bknd(item): - if hasattr(item, "override_dtype_check"): - return item.override_dtype_check() - elif isinstance(item, (list, tuple, dict)) and len(item) == 0: - pass - elif tensorflow_is_complex_dtype_bknd(item): - return tensorflow_default_complex_dtype_bknd( - input=item, as_native=as_native - ) - elif tensorflow_is_float_dtype_bknd(item): - return tensorflow_default_float_dtype_bknd(input=item, as_native=as_native) - elif tensorflow_is_uint_dtype_bknd(item): - return tensorflow_default_int_dtype_bknd(input=item, as_native=as_native) - elif tensorflow_is_int_dtype_bknd(item): - return tensorflow_default_int_dtype_bknd(input=item, as_native=as_native) - elif as_native: - return tensorflow_as_native_dtype("bool") - else: - return "bool" - global default_dtype_stack - if not default_dtype_stack: - global default_float_dtype_stack - if default_float_dtype_stack: - ret = default_float_dtype_stack[-1] - else: - ret = "float32" - else: - ret = default_dtype_stack[-1] - if as_native: - return tensorflow_as_native_dtype(ret) - return tensorflow_as_ivy_dtype(ret) - - -def tensorflow_default_float_dtype_bknd( - *, - input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, - float_dtype: Optional[Union[str, tf.DType]] = None, - as_native: bool = False, -): - global default_float_dtype_stack - if tensorflow_exists_bknd(float_dtype): - if as_native is True: - return tensorflow_as_native_dtype(float_dtype) - return str(tensorflow_as_ivy_dtype(float_dtype)) - as_native = tensorflow_default_bknd(as_native, False) - if tensorflow_exists_bknd(input): - if tensorflow_is_array_bknd(input): - ret = tensorflow_dtype(input) - elif isinstance(input, np.ndarray): - ret = str(input.dtype) - elif isinstance(input, (list, tuple, dict)): - if tensorflow_nested_argwhere_bknd( - input, lambda x: tensorflow__check_float64_bknd(x), stop_after_n_found=1 - ): - ret = tf.float64 - elif not default_float_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_float_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "float32" - else: - ret = default_float_dtype_stack[-1] - elif isinstance(input, Number): - if tensorflow__check_float64_bknd(input): - ret = tf.float64 - elif not default_float_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_float_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "float32" - else: - ret = default_float_dtype_stack[-1] - elif not default_float_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_float_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "float32" - else: - ret = default_float_dtype_stack[-1] - if as_native: - return tensorflow_as_native_dtype(ret) - return str(tensorflow_as_ivy_dtype(ret)) - - -def tensorflow_as_ivy_dtype( - dtype_in: Union[tensorflow.DType, str, int, float, 
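All of the default_*_dtype_bknd helpers above resolve their fallback the same way: consult a module-level stack, and fall back to a hard-coded default when it is empty. A dependency-free sketch of that lookup pattern (the stack name and context-manager comment are illustrative):

default_float_dtype_stack = []

def default_float_dtype():
    # top of stack wins; otherwise the global default
    return default_float_dtype_stack[-1] if default_float_dtype_stack else "float32"

assert default_float_dtype() == "float32"      # empty stack -> global default
default_float_dtype_stack.append("float64")    # e.g. pushed by a context manager
assert default_float_dtype() == "float64"
default_float_dtype_stack.pop()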
complex, bool, np.dtype], / -): - if dtype_in is int: - return tensorflow_default_int_dtype_bknd() - if dtype_in is float: - return tensorflow_default_float_dtype_bknd() - if dtype_in is complex: - return tensorflow_default_complex_dtype_bknd() - if dtype_in is bool: - return str("bool") - if isinstance(dtype_in, np.dtype): - dtype_in = dtype_in.name - if isinstance(dtype_in, str): - if dtype_in in native_dtype_dict: - dtype_str = dtype_in - else: - raise Exception( - f"Cannot convert to ivy dtype. {dtype_in} is not supported by TensorFlow backend." - ) - else: - dtype_str = ivy_dtype_dict[dtype_in] - if "uint" in dtype_str: - return str(dtype_str) - elif "int" in dtype_str: - return str(dtype_str) - elif "float" in dtype_str: - return str(dtype_str) - elif "complex" in dtype_str: - return str(dtype_str) - elif "bool" in dtype_str: - return str("bool") - else: - raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") - - -def tensorflow_default_int_dtype_bknd( - *, - input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, - int_dtype: Optional[Union[str, tf.DType]] = None, - as_native: bool = False, -): - global default_int_dtype_stack - if tensorflow_exists_bknd(int_dtype): - if as_native is True: - return tensorflow_as_native_dtype(int_dtype) - return str(tensorflow_as_ivy_dtype(int_dtype)) - as_native = tensorflow_default_bknd(as_native, False) - if tensorflow_exists_bknd(input): - if tensorflow_is_array_bknd(input): - ret = tensorflow_dtype(input) - elif isinstance(input, tuple): - ret = tensorflow_default_int_dtype_bknd() - elif isinstance(input, np.ndarray): - ret = str(input.dtype) - elif isinstance(input, (list, tuple, dict)): - if tensorflow_nested_argwhere_bknd( - input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), - stop_after_n_found=1, - ): - ret = tf.uint64 - elif tensorflow_nested_argwhere_bknd( - input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), - stop_after_n_found=1, - ): - ret = tf.int64 - elif not default_int_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_int_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "int32" - else: - ret = default_int_dtype_stack[-1] - elif isinstance(input, Number): - if input > 9223372036854775807 and input != math.inf and backend != "torch": - ret = tf.uint64 - elif input > 2147483647 and input != math.inf: - ret = tf.int64 - elif not default_int_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_int_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "int32" - else: - ret = default_int_dtype_stack[-1] - elif not default_int_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_int_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "int32" - else: - ret = default_int_dtype_stack[-1] - if as_native: - return tensorflow_as_native_dtype(ret) - return str(tensorflow_as_ivy_dtype(ret)) - - -def tensorflow_as_native_dtype( - dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], -): - if dtype_in is int: - return tensorflow_default_int_dtype_bknd(as_native=True) - if dtype_in is float: - return tensorflow_default_float_dtype_bknd(as_native=True) - if dtype_in is complex: - return tensorflow_default_complex_dtype_bknd(as_native=True) - if dtype_in is bool: - return tensorflow.bool - if isinstance(dtype_in, np.dtype): - dtype_in = dtype_in.name - if not 
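The Number branch of tensorflow_default_int_dtype_bknd above picks the smallest dtype whose range fits the value, widening past the int32 and int64 bounds as needed. A sketch of just the magnitude checks, ignoring the backend-specific uint64 guard and the math.inf exclusion:

def infer_int_dtype(n: int) -> str:
    # mirrors the bounds used above: int64 max, then int32 max
    if n > 9223372036854775807:      # > 2**63 - 1 -> needs uint64
        return "uint64"
    if n > 2147483647:               # > 2**31 - 1 -> int64
        return "int64"
    return "int32"

assert infer_int_dtype(7) == "int32"
assert infer_int_dtype(2**40) == "int64"
assert infer_int_dtype(2**63) == "uint64"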
isinstance(dtype_in, str): - return dtype_in - if dtype_in in native_dtype_dict: - return native_dtype_dict[str(dtype_in)] - else: - raise Exception( - f"Cannot convert to TensorFlow dtype. {dtype_in} is not supported by TensorFlow." - ) - - -def tensorflow_dtype( - x: Union[tensorflow.Tensor, tensorflow.Variable, np.ndarray], - *, - as_native: bool = False, -): - if as_native: - return tensorflow_as_native_dtype(x.dtype) - return tensorflow_as_ivy_dtype(x.dtype) - - -def tensorflow_is_bool_dtype_bknd( - dtype_in: Union[str, str, tensorflow.Tensor, tf.Tensor, Number], / -): - if tensorflow_is_array_bknd(dtype_in): - dtype_in = tensorflow_dtype(dtype_in) - elif isinstance(dtype_in, np.ndarray): - return "bool" in dtype_in.dtype.name - elif isinstance(dtype_in, Number): - return isinstance(dtype_in, (bool, np.bool_)) and not isinstance(dtype_in, bool) - elif isinstance(dtype_in, (list, tuple, dict)): - return bool( - tensorflow_nested_argwhere_bknd( - dtype_in, lambda x: isinstance(x, (bool, np.bool_)) and x is not int - ) - ) - return "bool" in tensorflow_as_ivy_dtype(dtype_in) - - -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - -@tensorflow_handle_get_item -def tensorflow_get_item( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - query: Union[tensorflow.Tensor, tensorflow.Variable, Tuple], - *, - copy: Optional[bool] = None, -): - if ( - tensorflow_is_array_bknd(query) - and tensorflow_is_bool_dtype_bknd(query) - and not len(query.shape) - ): - return tensorflow.expand_dims(x, 0) - return x[query] - - -def tensorflow_index_nest_bknd( - nest: Union[List, Tuple, Dict, tensorflow.Tensor, tf.Tensor, dict], - index: Union[List[int], Tuple[int], Iterable[int]], - /, -): - ret = nest - for i in index: - ret = tensorflow_get_item(ret, i) - return ret - - -def tensorflow__get_first_array(*args, **kwargs): - def array_fn(x): - return ( - tensorflow_is_array_bknd(x) - if not hasattr(x, "_ivy_array") - else tensorflow_is_array_bknd(x.ivy_array) - ) - - array_fn = array_fn if "array_fn" not in kwargs else kwargs["array_fn"] - arr = None - if args: - arr_idxs = tensorflow_nested_argwhere_bknd(args, array_fn, stop_after_n_found=1) - if arr_idxs: - arr = tensorflow_index_nest_bknd(args, arr_idxs[0]) - else: - arr_idxs = tensorflow_nested_argwhere_bknd( - kwargs, array_fn, stop_after_n_found=1 - ) - if arr_idxs: - arr = tensorflow_index_nest_bknd(kwargs, arr_idxs[0]) - elif kwargs: - arr_idxs = tensorflow_nested_argwhere_bknd( - kwargs, array_fn, stop_after_n_found=1 - ) - if arr_idxs: - arr = tensorflow_index_nest_bknd(kwargs, arr_idxs[0]) - return arr diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_zeros_like_output/run_0/tensorflow__stateful.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_zeros_like_output/run_0/tensorflow__stateful.py deleted file mode 100644 index dbad1e919ab1..000000000000 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_zeros_like_output/run_0/tensorflow__stateful.py +++ /dev/null @@ -1,1799 +0,0 @@ -# global -from __future__ import annotations -import re -import os -import tensorflow as tf -import functools -from tensorflow.python.util import nest -from typing import NamedTuple, Callable, Any, Tuple, List, Dict, Type, Union -import inspect -from collections import OrderedDict -from packaging.version import parse -import keras - - -def 
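tensorflow_handle_get_item above is a try-native-first decorator: it attempts the object's own __getitem__, re-raises IndexError untouched so out-of-range bugs stay visible, and only falls back to the wrapped function for other failures. A self-contained sketch with a hypothetical fallback:

import functools

def handle_get_item(fn):
    @functools.wraps(fn)
    def wrapper(inp, query, **kwargs):
        try:
            return inp.__getitem__(query)    # native path first
        except IndexError:
            raise                            # genuine out-of-range: propagate
        except Exception:
            return fn(inp, query, **kwargs)  # anything else: custom fallback
    return wrapper

@handle_get_item
def get_item(inp, query):
    return [inp[q] for q in query]           # hypothetical fallback behaviour

assert get_item([1, 2, 3], 1) == 2               # native __getitem__
assert get_item([1, 2, 3], (0, 2)) == [1, 3]     # list rejects tuples -> fallback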
get_assignment_dict(): - # Traverse the call stack - lhs = None - for frame_info in inspect.stack(): - # Check if the code context is an assignment statement - if frame_info.code_context and "=" in frame_info.code_context[0]: - # Split the assignment and retrieve the LHS - lhs = frame_info.code_context[0].split("=")[0].strip() - if "self" not in lhs: - continue - break - - if not lhs: - return None, "" - - # Replace indexing with attribute access - lhs = re.sub(r"\[(\d+)\]", r".\1", lhs) - - # Split the LHS based on "." and get individual components - components = lhs.split(".") - - # Initialize the dictionary - assignment_dict = {} - - # Retrieve the live objects associated with each component - for i in range(len(components)): - # Construct the key - key = ".".join(components[: i + 1]) - - # Retrieve the value - if i == 0: - value = frame_info.frame.f_locals.get(components[i]) - else: - value = getattr(assignment_dict[".".join(components[:i])], components[i]) - - # Add the key-value pair to the dictionary - assignment_dict[key] = value - - return assignment_dict, lhs - - -def store_frame_info(fn): - @functools.wraps(fn) - def frame_info_wrapper(self, *args, **kwargs): - if self._previous_frame_info is None: - # store the info about the calling frame. - stack = inspect.stack() - self._previous_frame_info = stack[1] - res = fn(self, *args, **kwargs) - # reset the frame-info - self._previous_frame_info = None - return res - - return frame_info_wrapper - - -# A NodeDef holds two callables: -# - flatten_fn should take the collection and return a flat list of values. -# It can also return some context that is used in reconstructing the -# collection. -# - unflatten_fn should take a flat list of values and some context -# (returned by flatten_fn). It returns the collection by reconstructing -# it from the list and the context. -Context = Any -PyTree = Any -FlattenFunc = Callable[[PyTree], Tuple[List, Context]] -UnflattenFunc = Callable[[List, Context], PyTree] - - -class NodeDef(NamedTuple): - flatten_fn: FlattenFunc - unflatten_fn: UnflattenFunc - - -SUPPORTED_NODES: Dict[Type[Any], NodeDef] = {} - - -def _register_pytree_node( - typ: Any, flatten_fn: FlattenFunc, unflatten_fn: UnflattenFunc -) -> None: - SUPPORTED_NODES[typ] = NodeDef(flatten_fn, unflatten_fn) - - -def _dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]: - return list(d.values()), list(d.keys()) - - -def _dict_unflatten(values: List[Any], context: Context) -> Dict[Any, Any]: - return {key: value for key, value in zip(context, values)} - - -_register_pytree_node(dict, _dict_flatten, _dict_unflatten) - -if parse(keras.__version__).major > 2: - _register_pytree_node( - keras.src.utils.tracking.TrackedDict, _dict_flatten, _dict_unflatten - ) - - -def _get_node_type(pytree: Any) -> Any: - return type(pytree) - - -# A leaf is defined as anything that is not a Node. -def _is_leaf(pytree: PyTree) -> bool: - return _get_node_type(pytree) not in SUPPORTED_NODES.keys() - - -# A TreeSpec represents the structure of a pytree. 
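_register_pytree_node above is an open registry: any container type becomes flattenable once a (flatten_fn, unflatten_fn) pair is recorded for it, which is exactly how the keras TrackedDict support is bolted on. A sketch registering a hypothetical extra node type (list) the same way:

from typing import Any, Callable, Dict, List, NamedTuple, Tuple, Type

Context = Any

class NodeDef(NamedTuple):
    flatten_fn: Callable
    unflatten_fn: Callable

SUPPORTED_NODES: Dict[Type[Any], NodeDef] = {}

def _register_pytree_node(typ, flatten_fn, unflatten_fn):
    SUPPORTED_NODES[typ] = NodeDef(flatten_fn, unflatten_fn)

# values become children; list nodes need no reconstruction context
def _list_flatten(xs: list) -> Tuple[List[Any], Context]:
    return list(xs), None

def _list_unflatten(values: List[Any], context: Context) -> list:
    return list(values)

_register_pytree_node(list, _list_flatten, _list_unflatten)
assert list in SUPPORTED_NODES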
It holds: -# "type": the type of root Node of the pytree -# context: some context that is useful in unflattening the pytree -# children_specs: specs for each child of the root Node -# num_leaves: the number of leaves -class TreeSpec: - def __init__(self, type, context, children_specs): - self.type: Any = type - self.context: Context = context - self.children_specs: List["TreeSpec"] = children_specs - self.num_leaves: int = sum([spec.num_leaves for spec in self.children_specs]) - - def get_keychains(self, prefix="", sep="/"): - keychains = [] - for key, child_spec in zip(self.context, self.children_specs): - new_prefix = prefix + key + sep if prefix else key + sep - if child_spec.children_specs: # Non-leaf node - keychains.extend(child_spec.get_keychains(new_prefix, sep)) - else: # Leaf node - keychains.append(new_prefix[: -len(sep)]) - return keychains - - def __repr__(self, indent: int = 0) -> str: - repr_prefix: str = f"TreeSpec({self.type.__name__}, {self.context}, [" - children_specs_str: str = "" - if len(self.children_specs): - indent += len(repr_prefix) - children_specs_str += self.children_specs[0].__repr__(indent) - children_specs_str += "," if len(self.children_specs) > 1 else "" - children_specs_str += ",".join( - [ - "\n" + " " * indent + child.__repr__(indent) - for child in self.children_specs[1:] - ] - ) - repr_suffix: str = f"{children_specs_str}])" - return repr_prefix + repr_suffix - - -class LeafSpec(TreeSpec): - def __init__(self) -> None: - super().__init__(None, None, []) - self.num_leaves = 1 - - def __repr__(self, indent: int = 0) -> str: - return "*" - - -def tree_flatten(pytree: PyTree) -> Tuple[List[Any], TreeSpec]: - """Flattens a pytree into a list of values and a TreeSpec that can be used - to reconstruct the pytree.""" - if _is_leaf(pytree): - return [pytree], LeafSpec() - - node_type = _get_node_type(pytree) - flatten_fn = _dict_flatten - child_pytrees, context = flatten_fn(pytree) - - # Recursively flatten the children - result: List[Any] = [] - children_specs: List["TreeSpec"] = [] - for child in child_pytrees: - flat, child_spec = tree_flatten(child) - result += flat - children_specs.append(child_spec) - - return result, TreeSpec(node_type, context, children_specs) - - -def tree_unflatten(values: List[Any], spec: TreeSpec) -> PyTree: - """Given a list of values and a TreeSpec, builds a pytree. - - This is the inverse operation of `tree_flatten`. - """ - if not isinstance(spec, TreeSpec): - raise TypeError( - f"tree_unflatten(values, spec): Expected `spec` to be instance of " - f"TreeSpec but got item of type {type(spec)}." - ) - if len(values) != spec.num_leaves: - raise TypeError( - f"tree_unflatten(values, spec): `values` has length {len(values)} " - f"but the spec refers to a pytree that holds {spec.num_leaves} " - f"items ({spec})." 
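tree_flatten above returns the leaves plus a TreeSpec describing where they came from, and tree_unflatten inverts it by walking children_specs and consuming num_leaves values per child. A much-simplified dict-only analogue of that round trip (the spec representation here is illustrative, not the TreeSpec class):

nested = {"a": 1, "b": {"c": 2, "d": 3}}

def flatten(d):
    leaves, spec = [], []
    for k, v in d.items():
        if isinstance(v, dict):
            sub_leaves, sub_spec = flatten(v)
            leaves += sub_leaves
            spec.append((k, sub_spec))       # non-leaf: keep the child spec
        else:
            leaves.append(v)
            spec.append((k, None))           # leaf marker
    return leaves, spec

def unflatten(leaves, spec):
    it = iter(leaves)                        # leaves are consumed in order
    def rebuild(s):
        return {k: (rebuild(sub) if sub is not None else next(it)) for k, sub in s}
    return rebuild(spec)

leaves, spec = flatten(nested)
assert leaves == [1, 2, 3]
assert unflatten(leaves, spec) == nested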
- ) - if isinstance(spec, LeafSpec): - return values[0] - - unflatten_fn = _dict_unflatten - - # Recursively unflatten the children - start = 0 - end = 0 - child_pytrees = [] - for child_spec in spec.children_specs: - end += child_spec.num_leaves - child_pytrees.append(tree_unflatten(values[start:end], child_spec)) - start = end - - return unflatten_fn(child_pytrees, spec.context) - - -def serialize_obj(obj): - if inspect.isclass(obj) or isinstance(obj, type): - return {"cls_module": obj.__module__, "cls_name": obj.__name__} - return obj - - -def recursive_serialize(d): - if isinstance(d, dict): - return {k: recursive_serialize(v) for k, v in d.items()} - elif isinstance(d, list): - return [recursive_serialize(v) for v in d] - elif isinstance(d, tuple): - return tuple(recursive_serialize(v) for v in d) - else: - return serialize_obj(d) - - -def deserialize_obj(serialized): - if ( - isinstance(serialized, dict) - and "cls_module" in serialized - and "cls_name" in serialized - ): - module = __import__(serialized["cls_module"], fromlist=[serialized["cls_name"]]) - cls = getattr(module, serialized["cls_name"]) - return cls - return serialized - - -def recursive_deserialize(d): - if isinstance(d, dict) and "cls_module" not in d: - return {k: recursive_deserialize(v) for k, v in d.items()} - elif isinstance(d, list): - return [recursive_deserialize(v) for v in d] - elif isinstance(d, tuple): - return tuple(recursive_serialize(v) for v in d) - else: - return deserialize_obj(d) - - -class ModelHelpers: - @staticmethod - @tf.autograph.experimental.do_not_convert - def _get_first_array(*args, **kwargs): - arr = None - flattened_args = tf.nest.flatten((args, kwargs)) - arr_candidates = tf.nest.map_structure( - lambda x: x if isinstance(x, (tf.Tensor, tf.Variable)) else False, - flattened_args, - ) - for arr_candidate in arr_candidates: - if arr_candidate is not False: - arr = arr_candidate - break - return arr - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _get_input_shapes(*args): - input_shapes = [] - for x in args: - if isinstance(x, (tf.Tensor, tf.Variable)): - input_shapes.append(x.shape) - else: - try: - x = tf.convert_to_tensor(x) - input_shapes.append(x.shape) - except Exception: - input_shapes.append(None) - return input_shapes - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _extract_v(v, keychain_mappings: dict, orig_key_chain, /): - if ModelHelpers._dict_has_key_chain(v, orig_key_chain): - ret_cont = ModelHelpers._dict_at_key_chain(v, orig_key_chain) - else: - ret_cont = dict() - for old_kc, new_kc in keychain_mappings.items(): - if orig_key_chain in old_kc: - # Check if `v` contains `new_kc` before replacing in `ret_cont` - if ModelHelpers._dict_has_key_chain(v, new_kc): - ret_cont = ModelHelpers._dict_set_at_key_chain( - ret_cont, - "/".join(old_kc.split("/")[1:]), - ModelHelpers._dict_at_key_chain(v, new_kc), - ) - else: - continue - return ret_cont - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _remove_duplicate_variables(vs, created, /): - created_ids = tf.nest.map_structure(lambda x: id(x), created) - vs_ids = tf.nest.map_structure(lambda x: id(x), vs) - ids = {} - duplicate_keychains = [] - keychain_mappings = {} - - def unique_callback(x, kc): - ids[x] = kc - return x - - def found_dup_callback(x, kc): - if ids[x] == kc: - return x - duplicate_keychains.append(kc) - keychain_mappings[kc] = ids[x] - return x - - created_ids = nest.map_structure_with_paths( - lambda kc, x: unique_callback(x, kc), created_ids - ) - vs_ids = 
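serialize_obj/deserialize_obj above handle the one thing config dicts cannot store directly, class objects, by recording the import path and re-importing on the way back. The round trip in isolation:

import collections

def serialize_cls(cls):
    # same record shape as serialize_obj above
    return {"cls_module": cls.__module__, "cls_name": cls.__name__}

def deserialize_cls(d):
    module = __import__(d["cls_module"], fromlist=[d["cls_name"]])
    return getattr(module, d["cls_name"])

blob = serialize_cls(collections.OrderedDict)
assert deserialize_cls(blob) is collections.OrderedDict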
nest.map_structure_with_paths( - lambda kc, x: ( - unique_callback(x, kc) if x not in ids else found_dup_callback(x, kc) - ), - vs_ids, - ) - for dup_kc in duplicate_keychains: - vs = ModelHelpers._dict_prune_key_chain(vs, dup_kc) - return vs, keychain_mappings - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _dict_set_at_key_chain(in_dict, key_chain, val, inplace=False): - keys = re.split("[/.]", key_chain) - if inplace: - cont = in_dict - else: - cont = in_dict - sub_cont = cont - for key in keys[:-1]: - if key not in sub_cont: - sub_cont[key] = dict() - sub_cont = sub_cont[key] - sub_cont[keys[-1]] = val - return cont - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _dict_at_key_chain(dict, key_chain, ignore_key_errors=False): - keys = re.split("[/.]", key_chain) - ret = dict - for key in keys: - try: - ret = ret[key] - except KeyError as e: - if ignore_key_errors: - return - raise Exception(repr(e)) - return ret - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _dict_has_key_chain(dict, key_chain): - keys = re.split("[/.]", key_chain) - ret = dict - for key in keys: - try: - ret = ret[key] - except KeyError: - return False - return True - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _dict_prune_key_chain(in_dict, key_chain): - keys_in_chain = re.split("[/.]", key_chain) - out_dict = {} - for key, value in in_dict.items(): - if isinstance(value, dict): - if key == keys_in_chain[0]: - if len(keys_in_chain) == 1: - new_val = [] - else: - new_val = ModelHelpers._dict_prune_key_chain( - value, - "/".join(keys_in_chain[1:]), - ) - if len(new_val) > 0: - out_dict[key] = new_val - else: - if len(value) > 0: - out_dict[key] = value - else: - if len(keys_in_chain) != 1 or key != keys_in_chain[0]: - out_dict[key] = value - return out_dict - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _addindent(s_, numSpaces): - s = s_.split("\n") - # don't do anything for single-line stuff - if len(s) == 1: - return s_ - first = s.pop(0) - s = [(numSpaces * " ") + line for line in s] - s = "\n".join(s) - s = first + "\n" + s - return s - - -class Layer(tf.keras.layers.Layer, ModelHelpers): - _build_mode = None - _with_partial_v = None - _store_vars = True - _built = False - _v = None - _buffers = None - _module_dict = None - _args = None - _kwargs = None - _module_graph = None - _target = None - _lazy_traced = False - _training = None - _dynamic_backend = None - _device = None - _dtype = None - _previous_frame_info = None - - def __init__( - self, - /, - *args, - v=None, - buffers=None, - build_mode="on_init", - store_vars=True, - with_partial_v=False, - dynamic_backend=None, - training=True, - dtype=None, - device=None, - module_dict=None, - **kwargs, - ): - super(Layer, self).__init__( - trainable=training, - dtype=dtype, - ) - self._build_mode = build_mode - self._with_partial_v = with_partial_v - self._store_vars = store_vars - self._built = False - self._v_from_constructor = v if isinstance(v, dict) or v is None else dict(v) - self._v = v if v is not None else dict() - self._buffers = dict(buffers or {}) - self._module_dict = module_dict if module_dict is not None else dict() - self._args = args - self._kwargs = kwargs - self._module_graph = None - self._target = None - self._lazy_traced = False - self._training = training - self._dynamic_backend = dynamic_backend - self._device = device or "cpu" - self._dtype = dtype or tf.float32 - if build_mode != "on_init": - return - self.build(*args, 
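The _dict_*_key_chain helpers above all share one convention: a key chain is a string whose segments may be separated by either "/" or ".", split with the same regex. The read path in isolation:

import re

def dict_at_key_chain(d, key_chain):
    # same traversal as ModelHelpers._dict_at_key_chain above
    ret = d
    for key in re.split("[/.]", key_chain):
        ret = ret[key]
    return ret

cfg = {"layers": {"0": {"weight": 3}}}
assert dict_at_key_chain(cfg, "layers/0/weight") == 3
assert dict_at_key_chain(cfg, "layers.0.weight") == 3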
dynamic_backend=dynamic_backend, **kwargs) - - @tf.autograph.experimental.do_not_convert - def _find_variables( - self, - /, - *, - obj=None, - without_initialisation=False, - _visited=None, - trainable=True, - ): - _visited = _visited or {} - vs = dict() - if id(obj) in _visited: - return vs - _visited[id(obj)] = True - if isinstance(obj, Layer) and obj is not self: - fn = "_build_and_return_v" if trainable else "_build_and_return_buffers" - if not obj._built and without_initialisation: - return lambda: getattr(obj, fn)( - *obj._args, dynamic_backend=self._dynamic_backend, **obj._kwargs - ) - - return getattr(obj, fn)( - *obj._args, dynamic_backend=obj._dynamic_backend, **obj._kwargs - ) - elif isinstance(obj, tf.keras.layers.Layer) and obj is not self: - return obj.v if trainable else obj.buffers - - elif isinstance(obj, (list, tuple)): - for i, v in enumerate(obj): - ret = self._find_variables( - obj=v, - without_initialisation=without_initialisation, - _visited=_visited, - trainable=trainable, - ) - if ret: - vs[f"v{str(i)}"] = ret - return vs - elif isinstance(obj, dict): - for k, v in obj.items(): - ret = self._find_variables( - obj=v, - without_initialisation=without_initialisation, - _visited=_visited, - trainable=trainable, - ) - if ret: - vs[k[1:] if k[0] == "_" else k] = ret - return vs - elif not hasattr(obj, "__dict__"): - return vs - for k, v in obj.__dict__.items(): - if ( - v is not None - and k[0:2] != "__" - and not k.startswith( - ( - "_module_dict", - "_self_", - "_args", - "_kwargs", - ) - ) - ): - ret = self._find_variables( - obj=v, - without_initialisation=without_initialisation, - _visited=_visited, - trainable=trainable, - ) - if ret: - vs[k[1:] if k[0] == "_" else k] = ret - return vs - - @tf.autograph.experimental.do_not_convert - def _find_buffers(self): - if hasattr(self, "_module_dict"): - for key, sub_module in self._module_dict.items(): - if len(sub_module._buffers) > 0: - self._buffers[key] = sub_module._buffers - - @tf.autograph.experimental.do_not_convert - def _build_and_return_v(self, *args, **kwargs): - if not self._built: - self.build(*args, **kwargs) - return self.v - - @tf.autograph.experimental.do_not_convert - def _build_and_return_buffers(self, *args, **kwargs): - if not self._built: - self.build(*args, **kwargs) - return self.buffers - - @tf.autograph.experimental.do_not_convert - def _wrap_call_methods( - self, keychain_mappings, /, *, key="", obj=None, _visited=None - ): - _visited = _visited or {} - if id(obj) in _visited or not isinstance(key, str): - return - _visited[id(obj)] = True - if isinstance(obj, Model) and obj is not self: - orig_key_chain = key[1:] if key[0] == "_" else key - - obj.__call__ = self._fn_with_var_arg( - obj.__call__, self._extract_v, keychain_mappings, orig_key_chain - ) - return - elif isinstance(obj, (list, tuple)): - for i, val in enumerate(obj): - self._wrap_call_methods( - keychain_mappings, - key=f"{key}/v{str(i)}", - obj=val, - _visited=_visited, - ) - return - elif isinstance(obj, dict): - for k, val in obj.items(): - k = f"{key}/{k}" if key != "" and isinstance(k, str) else k - self._wrap_call_methods( - keychain_mappings, key=k, obj=val, _visited=_visited - ) - return - for k, val in obj.module_dict.items(): - if k.startswith(("__", "_self_")): - continue - k = f"{key}/{k}" if key != "" else k - if val is not None: - self._wrap_call_methods( - keychain_mappings, key=k, obj=val, _visited=_visited - ) - return - - @tf.autograph.experimental.do_not_convert - def _compute_module_dict(self): - self._module_dict 
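_find_variables above threads an id()-keyed _visited dict through the recursion; that, not the isinstance checks, is what makes it safe on self-referential module graphs. A minimal analogue showing the cycle being cut:

def collect_ints(obj, _visited=None):
    _visited = _visited or {}
    if id(obj) in _visited:          # already seen this exact object: stop
        return []
    _visited[id(obj)] = True
    if isinstance(obj, int):
        return [obj]
    if isinstance(obj, dict):
        return [x for v in obj.values() for x in collect_ints(v, _visited)]
    if isinstance(obj, list):
        return [x for v in obj for x in collect_ints(v, _visited)]
    return []

cyc = {"a": 1}
cyc["self"] = cyc                    # cycle: would recurse forever without _visited
assert collect_ints(cyc) == [1]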
= dict() - for key, value in self.__dict__.items(): - if isinstance(value, (Layer, tf.keras.layers.Layer)): - if ( - "stateful" in value.__module__ - or hasattr(value, "_frontend_module") - or not hasattr(value, "_module_dict") - ): - self._module_dict[key] = value - else: - self._module_dict[key] = value._module_dict - - @tf.autograph.experimental.do_not_convert - def _fn_with_var_arg_wrapper( - self, *a, fn, v_fn, keychain_mappings, orig_key_chain, **kw - ): - if "v" in kw: - del kw["v"] - v = v_fn(self.v, keychain_mappings, orig_key_chain) - return fn(*a, **kw, v=v) - - @tf.autograph.experimental.do_not_convert - def _fn_with_var_arg(self, fn, v_fn, /, keychain_mappings, orig_key_chain): - _fn_with_var_arg_wrapper = functools.partial( - self._fn_with_var_arg_wrapper, - fn=fn, - v_fn=v_fn, - keychain_mappings=keychain_mappings, - orig_key_chain=orig_key_chain, - ) - _fn_with_var_arg_wrapper.wrapped = True - return _fn_with_var_arg_wrapper - - @tf.autograph.experimental.do_not_convert - def _call(self, *args, v=None, buffers=None, **kwargs): - if not self._built or not self.built: - if not self._built: - first_arr = self._get_first_array(*args, **kwargs) - self.build( - *args, - **kwargs, - from_call=True, - dtype=first_arr.dtype if first_arr is not None else tf.float32, - ) - - if not self.built: - # Don't use `keras` build method - if os.environ.get("USE_KERAS_BUILD", "False").lower() == "false": - self.inputs = tf.nest.flatten(args) - else: - input_shapes = self._get_input_shapes(*args) - if len(input_shapes) == 0: - input_shapes = tf.TensorShape(None) - elif len(input_shapes) == 1: - input_shapes = input_shapes[0] - - super(Layer, self).build(tf.TensorShape(None)) # noqa: UP008 - - # If `v` was provided, replace with the module's v - replace_v = False - if v is not None: - v_orig = self.v - self._v = v - replace_v = True - - # If `buffers` were provided, replace with the module's buffers - replace_buffers = False - if buffers is not None: - buffers_orig = self.buffers - self._buffers = buffers - replace_buffers = True - - if replace_v or replace_buffers: - # Call the forward pass - ret = super(Layer, self).__call__(*args, **kwargs) # noqa: UP008 - # Replace v, buffers if needed - self._v = v_orig if replace_v else self._v - self._buffers = buffers_orig if replace_buffers else self._buffers - return ret - elif hasattr(self.__call__, "wrapped"): - return self.__call__(*args, **kwargs) - - # Get the signature of the call method - call_signature = inspect.signature(self.call) - - # Convert all positional arguments to keyword arguments based on the signature - new_kwargs = {} - for idx, (param_name, param) in enumerate(call_signature.parameters.items()): - if idx < len(args): - new_kwargs[param_name] = args[idx] - - # Merge the existing kwargs - new_kwargs.update(kwargs) - return super(Layer, self).__call__(**new_kwargs) # noqa: UP008 - - @tf.autograph.experimental.do_not_convert - def build( - self, - *args, - from_call=False, - device=None, - dtype=None, - dynamic_backend=None, - **kwargs, - ): - self._built = True - return - - def _lock_state(self): - pass - - @tf.autograph.experimental.do_not_convert - def register_buffer(self, name: str, value: Union[tf.Tensor, tf.Variable]): - self._buffers.update({name: value}) - return value - - @tf.autograph.experimental.do_not_convert - def register_parameter(self, name: str, value: Union[tf.Tensor, tf.Variable]): - self._v.update({name: value}) - - @tf.autograph.experimental.do_not_convert - def train(self, mode: bool = True): - self._training = 
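Layer._call above defers build() until the first real call, inferring the dtype from the first array argument it can find. The shape of that deferred-build pattern, with a hypothetical variable created at build time:

class LazyLayer:
    def __init__(self):
        self._built = False          # nothing allocated yet

    def build(self, x):
        self.scale = float(len(x))   # hypothetical variable, sized from input
        self._built = True

    def __call__(self, x):
        if not self._built:
            self.build(x)            # deferred until real inputs arrive
        return [v * self.scale for v in x]

layer = LazyLayer()
assert layer([1.0, 2.0]) == [2.0, 4.0]   # build happened inside the call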
mode - for module in self.children(): - if isinstance(module, tf.keras.layers.Layer) and not hasattr( - module, "train" - ): - module.trainable = mode - continue - module.train(mode) - self.trainable = mode - return self - - @tf.autograph.experimental.do_not_convert - def eval(self): - return self.train(mode=False) - - @tf.autograph.experimental.do_not_convert - def call(self, inputs, training=None, mask=None): - raise NotImplementedError( - "When subclassing the `Module` class, you should implement a `call` method." - ) - - def get_build_config(self): - config = super().get_build_config() - config = recursive_serialize(config) - return config - - def build_from_config(self, config): - config = recursive_deserialize(config) - return super().build_from_config(config) - - def get_config(self): - base_config = super().get_config() - config = {} - - # Get the names and values of positional arguments in __init__ - init_signature = inspect.signature(self.__init__) - arg_names = list(init_signature.parameters.keys()) - - # Include the positional arguments in the config - var_positional_arg_encountered = False - var_positional_arg_name = None - offset = 0 - for i, arg in enumerate(self._args): - arg_name = arg_names[min(i, len(arg_names) - 1)] - if var_positional_arg_encountered: - config.update( - { - f"{var_positional_arg_name}_{i - offset}": arg, - } - ) - elif ( - init_signature.parameters[arg_name].kind - == inspect.Parameter.VAR_POSITIONAL - ): - var_positional_arg_encountered = True - var_positional_arg_name = arg_name - offset = i - config.update( - { - f"{var_positional_arg_name}_{0}": arg, - } - ) - else: - config.update( - { - arg_name: arg, - } - ) - - # Include the keywords arguments in the config - kwargs = self._kwargs.copy() - kwargs.pop("devices", None) - config.update(**kwargs) - new_config = {**base_config, **config} - new_config = recursive_serialize(new_config) - return new_config - - @classmethod - def from_config(cls, config): - config = recursive_deserialize(config) - # Get the signature of the __init__ method - init_signature = inspect.signature(cls.__init__) - arg_names = list(init_signature.parameters.keys()) - - # Separate positional and keyword arguments based on the __init__ signature - args = [] - pos_or_kw = OrderedDict() - kwargs = {} - var_positional_args = [] - for arg_name in arg_names: - if ( - arg_name in config - and init_signature.parameters[arg_name].kind - == inspect.Parameter.KEYWORD_ONLY - ): - # Handle keyword arguments - kwargs[arg_name] = config.pop(arg_name) - elif ( - arg_name in config - and init_signature.parameters[arg_name].kind - == inspect.Parameter.POSITIONAL_OR_KEYWORD - ): - # Handle positional or keyword arguments - pos_or_kw[arg_name] = config.pop(arg_name) - elif any(re.match(rf"{arg_name}_\d+", key) for key in config.keys()): - # Handle variable positional arguments - var_positional_args.extend( - [ - config.pop(key) - for key in sorted(config.keys()) - if re.match(rf"{arg_name}_\d+", key) - ] - ) - - # Unpack positional arguments and the rest as keyword arguments - config.pop("name", None) - config.pop("trainable", None) - config.pop("dtype", None) - kwargs.update(config) - - # Determine the final args and kwargs - if var_positional_args: - args = list(pos_or_kw.values()) + var_positional_args - else: - kwargs.update(pos_or_kw) - - return cls(*args, **kwargs) - - # Methods to be Optionally Overridden # - # -----------------------------------# - - @tf.autograph.experimental.do_not_convert - def _create_variables(self, *, device=None, 
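get_config above reconstructs constructor arguments by pairing self._args with the parameter names from inspect.signature, so from_config can rebuild the object from plain keyword arguments. A minimal round trip of that idea (Block and its parameters are hypothetical):

import inspect

class Block:
    def __init__(self, units, activation="relu"):
        self.units, self.activation = units, activation
        self._args, self._kwargs = (units,), {"activation": activation}

    def get_config(self):
        # pair positional args with their parameter names, then merge kwargs
        names = list(inspect.signature(self.__init__).parameters)
        cfg = dict(zip(names, self._args))
        cfg.update(self._kwargs)
        return cfg

    @classmethod
    def from_config(cls, cfg):
        return cls(**cfg)

b = Block(32, activation="tanh")
b2 = Block.from_config(b.get_config())
assert (b2.units, b2.activation) == (32, "tanh")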
dtype=None): - return {} - - @tf.autograph.experimental.do_not_convert - def _build(self, *args, **kwargs) -> bool: - return True - - @tf.autograph.experimental.do_not_convert - def _forward(self, *args, **kwargs): - raise NotImplementedError( - "When subclassing the `Module` class, you should " - "implement a `_forward` method." - ) - - @tf.autograph.experimental.do_not_convert - def _extra_repr(self) -> str: - return "" - - # Properties # - # -----------# - - @property - def device(self): - return self._device - - @property - def dtype(self): - return self._dtype - - @property - def build_mode(self): - return self._build_mode - - @property - def training(self): - return self._training - - @property - def v(self): - return self._v - - @property - def buffers(self): - return self._buffers - - @property - def state_dict(self): - return {**self.v, **self.buffers} - - @property - def module_dict(self): - return self._module_dict - - @property - def layers(self): - return self._layers - - # Dunder Methods # - # ---------------# - @store_frame_info - @tf.autograph.experimental.do_not_convert - def __call__( - self, - *args, - v=None, - buffers=None, - **kwargs, - ): - # TODO: Temp workaround to avoid `call`` from being transformed by AutoGraph - if not hasattr(self.__class__.call, "autograph_info__"): - setattr(self.__class__.call, "autograph_info__", True) - ret = self._call(*args, v=v, buffers=buffers, **kwargs) - return ret - - @tf.autograph.experimental.do_not_convert - def __getattr__(self, name): - if name == "v": - if not super().__getattribute__("_v") and not getattr( # noqa: E501 - self, "_built", False - ): - return self._build_and_return_v( - *self._args, dynamic_backend=self._dynamic_backend, **self._kwargs - ) - - _dict = super().__getattribute__("__dict__") - if name in _dict: - return _dict[name] - - elif "_v" in _dict and name in _dict["_v"]: - return _dict["_v"][name] - - return super().__getattribute__(name) - - @tf.autograph.experimental.do_not_convert - def __setattr__(self, name, value): - if name in ["v", "buffers"]: - name = "_" + name - if isinstance(value, (Layer, tf.keras.layers.Layer)): - _dict = getattr(self, "__dict__", None) - if _dict: - _dict[name] = value - - # compute the module dict - self._compute_module_dict() - - obj_to_search = ( - None - if not isinstance(value, (Layer, tf.keras.layers.Layer)) - else ( - self._modules - if hasattr(self, "_modules") and self._modules - else self - ) - ) - found_vars = self._find_variables( - obj=obj_to_search, - without_initialisation=( - True - if self._v_from_constructor and not self._with_partial_v - else False - ), - ) - flattened_v, v_spec = tree_flatten(found_vars) - flattend_kc = v_spec.get_keychains() - for kc, v in zip(flattend_kc, flattened_v): - new_kc = kc.replace("/", ".") - if new_kc not in self.v: - self.register_parameter(new_kc, v) - - # once all variables built, find and assign buffers - found_buffers = self._find_variables( - obj=obj_to_search, - without_initialisation=( - True - if self._v_from_constructor and not self._with_partial_v - else False - ), - trainable=False, - ) - flattened_buf, buf_spec = tree_flatten(found_buffers) - flattend_kc = buf_spec.get_keychains() - for kc, buf in zip(flattend_kc, flattened_buf): - new_kc = kc.replace("/", ".") - self.register_buffer(new_kc, buf) - - super().__setattr__(name, value) - return - elif isinstance(value, tf.Variable) and not name.startswith("_"): - _dict = getattr(self, "__dict__", None) - if _dict: - _dict[name] = value - # Manual solution for cases 
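The __setattr__ override above intercepts assignments: tf.Variable values on non-underscore names are registered as parameters rather than just stored. A dependency-free sketch of the interception, with float standing in for tf.Variable:

class Tracked:
    def __init__(self):
        self._v = {}                         # underscore name: not tracked below

    def __setattr__(self, name, value):
        if isinstance(value, float) and not name.startswith("_"):
            self._v[name] = value            # register, then store as usual
        object.__setattr__(self, name, value)

t = Tracked()
t.weight = 1.5
t._tmp = 2.5                                 # underscore names skip registration
assert t._v == {"weight": 1.5}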
where a `tf.int32` tensor - # is placed on the GPU. TensorFlow doesn't have dedicated - # kernels for placing `tf.int32` variables on the GPU and so - # we manually cast them to `tf.int64` here otherwise due to - # `tf.config.soft_device_placement(True)` by default, - # TensorFlow puts the `tf.int32` variables on CPU which causes - # unintended consequences downstream during tracing or - # `tf.function` compilation e.g. - # Ref: https://github.com/tensorflow/tensorflow/issues/9506 - # Ref: https://stackoverflow.com/questions/44813939/could-not-satisfy-explicit-device-specification-devicegpu0-because-no-devic - dtype = ( - tf.int64 - if value.dtype == tf.int32 and "gpu:" in value.device.lower() - else value.dtype - ) - cast_dtype = dtype != value.dtype - val = ( - value - if not cast_dtype - else tf.Variable(initial_value=tf.cast(value.value(), dtype), name=name) - ) - self.register_parameter(name, val) - super().__setattr__(name, val) - return - else: - try: - obj_to_search = getattr(self, name) - except AttributeError: - obj_to_search = None - if isinstance(obj_to_search, Layer): - # retrieve all hierarchical submodules - assign_dict, kc = get_assignment_dict() - - # Iterate over all submods in assign_dict - # updating their `v` and `buffers` with the - # new value - for key, submod in assign_dict.items(): - # Get the subkey to match - subkey = kc[len(key) :].lstrip(".") - - if hasattr(submod, "v"): - for v_key, v_value in submod.v.items(): - if v_key.startswith(subkey): - submod.register_parameter(v_key, value) - - # Repeat the same process for submod.buffers - if hasattr(submod, "buffers"): - for b_key, b_value in submod.buffers.items(): - if b_key.startswith(subkey): - submod.register_buffer(b_key, value) - - # finally update the module dict - self._module_dict[name] = value - - return super().__setattr__(name, value) - - @tf.autograph.experimental.do_not_convert - def __delattr__(self, name): - if hasattr(self, name): - if isinstance(getattr(self, name), (Layer, tf.keras.layers.Layer)): - super().__delattr__(name) - return - super().__delattr__(name) - - @tf.autograph.experimental.do_not_convert - def __repr__(self): - extra_lines = [] - extra_repr = self._extra_repr() - if extra_repr: - extra_lines = extra_repr.split("\n") - child_lines = [] - for key in self.v.keys(): - if isinstance(getattr(self, key, None), Layer): - mod_str = repr(getattr(self, key)) - mod_str = self._addindent(mod_str, 2) - child_lines.append(f"({key}): {mod_str}") - lines = extra_lines + child_lines - - main_str = f"{self.__class__.__name__}(" - if lines: - # simple one-liner info, which most builtin Modules will use - if len(extra_lines) == 1 and not child_lines: - main_str += extra_lines[0] - else: - main_str += "\n " + "\n ".join(lines) + "\n" - - main_str += ")" - return main_str - - -class Model(tf.keras.Model, ModelHelpers): - _build_mode = None - _with_partial_v = None - _store_vars = True - _built = False - _v = None - _buffers = None - _module_dict = None - _args = None - _kwargs = None - _module_graph = None - _target = None - _lazy_traced = False - _training = None - _dynamic_backend = None - _device = None - _dtype = None - _previous_frame_info = None - - def __init__( - self, - /, - *args, - v=None, - buffers=None, - build_mode="on_init", - store_vars=True, - with_partial_v=False, - dynamic_backend=None, - training=True, - dtype=None, - device=None, - module_dict=None, - **kwargs, - ): - super(Model, self).__init__( - trainable=training, - dtype=dtype, - ) - self._build_mode = build_mode - 
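The int32-on-GPU workaround described in the comment above, expressed as a standalone function. It only changes anything when the incoming tensor actually reports a GPU device, so on CPU-only hosts it is a no-op; the function name is illustrative:

import tensorflow as tf

def make_variable(value: tf.Tensor, name: str) -> tf.Variable:
    # int32 has no GPU kernels for variable placement, so widen to int64
    # before wrapping, exactly as the __setattr__ branch above does
    if value.dtype == tf.int32 and "gpu:" in value.device.lower():
        value = tf.cast(value, tf.int64)
    return tf.Variable(initial_value=value, name=name)

v = make_variable(tf.constant([1, 2], dtype=tf.int32), "steps")
print(v.dtype)   # int32 on CPU-only hosts; int64 when placed on a GPU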
self._with_partial_v = with_partial_v - self._store_vars = store_vars - self._built = False - self._v_from_constructor = v if isinstance(v, dict) or v is None else dict(v) - self._v = v if v is not None else dict() - self._buffers = dict(buffers or {}) - self._module_dict = module_dict if module_dict is not None else dict() - self._args = args - self._kwargs = kwargs - self._module_graph = None - self._target = None - self._lazy_traced = False - self._training = training - self._dynamic_backend = dynamic_backend - self._device = device or "cpu" - self._dtype = dtype or tf.float32 - if build_mode != "on_init": - return - self.build(*args, dynamic_backend=dynamic_backend, **kwargs) - - @tf.autograph.experimental.do_not_convert - def _find_variables( - self, - /, - *, - obj=None, - without_initialisation=False, - _visited=None, - trainable=True, - ): - _visited = _visited or {} - vs = dict() - if id(obj) in _visited: - return vs - _visited[id(obj)] = True - if isinstance(obj, (Layer, Model)) and obj is not self: - fn = "_build_and_return_v" if trainable else "_build_and_return_buffers" - if not obj._built and without_initialisation: - return lambda: getattr(obj, fn)( - *obj._args, dynamic_backend=self._dynamic_backend, **obj._kwargs - ) - - return getattr(obj, fn)( - *obj._args, dynamic_backend=obj._dynamic_backend, **obj._kwargs - ) - elif isinstance(obj, tf.keras.layers.Layer) and obj is not self: - return obj.v if trainable else obj.buffers - - elif isinstance(obj, (list, tuple)): - for i, v in enumerate(obj): - ret = self._find_variables( - obj=v, - without_initialisation=without_initialisation, - _visited=_visited, - trainable=trainable, - ) - if ret: - vs[f"v{str(i)}"] = ret - return vs - elif isinstance(obj, dict): - for k, v in obj.items(): - ret = self._find_variables( - obj=v, - without_initialisation=without_initialisation, - _visited=_visited, - trainable=trainable, - ) - if ret: - vs[k[1:] if k[0] == "_" else k] = ret - return vs - elif not hasattr(obj, "__dict__"): - return vs - for k, v in obj.__dict__.items(): - if ( - v is not None - and k[0:2] != "__" - and not k.startswith( - ( - "_module_dict", - "_self_", - "_args", - "_kwargs", - ) - ) - ): - ret = self._find_variables( - obj=v, - without_initialisation=without_initialisation, - _visited=_visited, - trainable=trainable, - ) - if ret: - vs[k[1:] if k[0] == "_" else k] = ret - return vs - - @tf.autograph.experimental.do_not_convert - def _find_buffers(self): - if hasattr(self, "_module_dict"): - for key, sub_module in self._module_dict.items(): - if len(sub_module._buffers) > 0: - self._buffers[key] = sub_module._buffers - - @tf.autograph.experimental.do_not_convert - def _build_and_return_v(self, *args, **kwargs): - if not self._built: - self.build(*args, **kwargs) - return self.v - - @tf.autograph.experimental.do_not_convert - def _build_and_return_buffers(self, *args, **kwargs): - if not self._built: - self.build(*args, **kwargs) - return self.buffers - - @tf.autograph.experimental.do_not_convert - def _wrap_call_methods( - self, keychain_mappings, /, *, key="", obj=None, _visited=None - ): - _visited = _visited or {} - if id(obj) in _visited or not isinstance(key, str): - return - _visited[id(obj)] = True - if isinstance(obj, (Layer, Model)) and obj is not self: - orig_key_chain = key[1:] if key[0] == "_" else key - - obj.__call__ = self._fn_with_var_arg( - obj.__call__, self._extract_v, keychain_mappings, orig_key_chain - ) - return - elif isinstance(obj, (list, tuple)): - for i, val in enumerate(obj): - 
self._wrap_call_methods( - keychain_mappings, - key=f"{key}/v{str(i)}", - obj=val, - _visited=_visited, - ) - return - elif isinstance(obj, dict): - for k, val in obj.items(): - k = f"{key}/{k}" if key != "" and isinstance(k, str) else k - self._wrap_call_methods( - keychain_mappings, key=k, obj=val, _visited=_visited - ) - return - for k, val in obj.module_dict.items(): - if k.startswith(("__", "_self_")): - continue - k = f"{key}/{k}" if key != "" else k - if val is not None: - self._wrap_call_methods( - keychain_mappings, key=k, obj=val, _visited=_visited - ) - return - - @tf.autograph.experimental.do_not_convert - def _compute_module_dict(self): - self._module_dict = dict() - for key, value in self.__dict__.items(): - if isinstance(value, (Layer, tf.keras.layers.Layer, Model, tf.keras.Model)): - if ( - "stateful" in value.__module__ - or hasattr(value, "_frontend_module") - or not hasattr(value, "_module_dict") - ): - self._module_dict[key] = value - else: - self._module_dict[key] = value._module_dict - - @tf.autograph.experimental.do_not_convert - def _fn_with_var_arg_wrapper( - self, *a, fn, v_fn, keychain_mappings, orig_key_chain, **kw - ): - if "v" in kw: - del kw["v"] - v = v_fn(self.v, keychain_mappings, orig_key_chain) - return fn(*a, **kw, v=v) - - @tf.autograph.experimental.do_not_convert - def _fn_with_var_arg(self, fn, v_fn, /, keychain_mappings, orig_key_chain): - _fn_with_var_arg_wrapper = functools.partial( - self._fn_with_var_arg_wrapper, - fn=fn, - v_fn=v_fn, - keychain_mappings=keychain_mappings, - orig_key_chain=orig_key_chain, - ) - _fn_with_var_arg_wrapper.wrapped = True - return _fn_with_var_arg_wrapper - - @tf.autograph.experimental.do_not_convert - def _call(self, *args, v=None, buffers=None, **kwargs): - if not self._built or not self.built: - if not self._built: - first_arr = self._get_first_array(*args, **kwargs) - self.build( - *args, - **kwargs, - from_call=True, - dtype=first_arr.dtype if first_arr is not None else tf.float32, - ) - - if not self.built: - # Don't use `keras` build method - if os.environ.get("USE_KERAS_BUILD", "False").lower() == "false": - self.inputs = tf.nest.flatten(args) - else: - input_shapes = self._get_input_shapes(*args) - if len(input_shapes) == 0: - input_shapes = tf.TensorShape(None) - elif len(input_shapes) == 1: - input_shapes = input_shapes[0] - - super(Model, self).build(tf.TensorShape(None)) # noqa: UP008 - - # If `v` was provided, replace with the module's v - replace_v = False - if v is not None: - v_orig = self.v - self._v = v - replace_v = True - - # If `buffers` were provided, replace with the module's buffers - replace_buffers = False - if buffers is not None: - buffers_orig = self.buffers - self._buffers = buffers - replace_buffers = True - - if replace_v or replace_buffers: - # Call the forward pass - ret = super(Model, self).__call__(*args, **kwargs) # noqa: UP008 - # Replace v, buffers if needed - self._v = v_orig if replace_v else self._v - self._buffers = buffers_orig if replace_buffers else self._buffers - return ret - elif hasattr(self.__call__, "wrapped"): - return self.__call__(*args, **kwargs) - - return super(Model, self).__call__(*args, **kwargs) # noqa: UP008 - - @tf.autograph.experimental.do_not_convert - def build( - self, - *args, - from_call=False, - device=None, - dtype=None, - dynamic_backend=None, - **kwargs, - ): - self._built = True - return - - def _lock_state(self): - pass - - @tf.autograph.experimental.do_not_convert - def register_buffer(self, name: str, value: Union[tf.Tensor, tf.Variable]): 
- self._buffers.update({name: value}) - return value - - @tf.autograph.experimental.do_not_convert - def register_parameter(self, name: str, value: Union[tf.Tensor, tf.Variable]): - self._v.update({name: value}) - - @tf.autograph.experimental.do_not_convert - def train(self, mode: bool = True): - self._training = mode - for module in self.children(): - if isinstance(module, tf.keras.layers.Layer) and not hasattr( - module, "train" - ): - module.trainable = mode - continue - module.train(mode) - self.trainable = mode - return self - - @tf.autograph.experimental.do_not_convert - def eval(self): - return self.train(mode=False) - - @tf.autograph.experimental.do_not_convert - def call(self, inputs, training=None, mask=None): - raise NotImplementedError( - "When subclassing the `Module` class, you should implement a `call` method." - ) - - def get_build_config(self): - config = super().get_build_config() - config = recursive_serialize(config) - return config - - def build_from_config(self, config): - config = recursive_deserialize(config) - return super().build_from_config(config) - - def get_config(self): - base_config = super().get_config() - config = {} - - # Get the names and values of positional arguments in __init__ - init_signature = inspect.signature(self.__init__) - arg_names = list(init_signature.parameters.keys()) - - # Include the positional arguments in the config - var_positional_arg_encountered = False - var_positional_arg_name = None - offset = 0 - for i, arg in enumerate(self._args): - arg_name = arg_names[min(i, len(arg_names) - 1)] - if var_positional_arg_encountered: - config.update( - { - f"{var_positional_arg_name}_{i - offset}": arg, - } - ) - elif ( - init_signature.parameters[arg_name].kind - == inspect.Parameter.VAR_POSITIONAL - ): - var_positional_arg_encountered = True - var_positional_arg_name = arg_name - offset = i - config.update( - { - f"{var_positional_arg_name}_{0}": arg, - } - ) - else: - config.update( - { - arg_name: arg, - } - ) - - # Include the keywords arguments in the config - kwargs = self._kwargs.copy() - kwargs.pop("devices", None) - config.update(**kwargs) - new_config = {**base_config, **config} - new_config = recursive_serialize(new_config) - return new_config - - @classmethod - def from_config(cls, config): - config = recursive_deserialize(config) - # Get the signature of the __init__ method - init_signature = inspect.signature(cls.__init__) - arg_names = list(init_signature.parameters.keys()) - - # Separate positional and keyword arguments based on the __init__ signature - args = [] - pos_or_kw = OrderedDict() - kwargs = {} - var_positional_args = [] - for arg_name in arg_names: - if ( - arg_name in config - and init_signature.parameters[arg_name].kind - == inspect.Parameter.KEYWORD_ONLY - ): - # Handle keyword arguments - kwargs[arg_name] = config.pop(arg_name) - elif ( - arg_name in config - and init_signature.parameters[arg_name].kind - == inspect.Parameter.POSITIONAL_OR_KEYWORD - ): - # Handle positional or keyword arguments - pos_or_kw[arg_name] = config.pop(arg_name) - elif any(re.match(rf"{arg_name}_\d+", key) for key in config.keys()): - # Handle variable positional arguments - var_positional_args.extend( - [ - config.pop(key) - for key in sorted(config.keys()) - if re.match(rf"{arg_name}_\d+", key) - ] - ) - - # Unpack positional arguments and the rest as keyword arguments - config.pop("name", None) - config.pop("trainable", None) - config.pop("dtype", None) - kwargs.update(config) - - # Determine the final args and kwargs - if 
var_positional_args: - args = list(pos_or_kw.values()) + var_positional_args - else: - kwargs.update(pos_or_kw) - - return cls(*args, **kwargs) - - # Methods to be Optionally Overridden # - # -----------------------------------# - - @tf.autograph.experimental.do_not_convert - def _create_variables(self, *, device=None, dtype=None): - return {} - - @tf.autograph.experimental.do_not_convert - def _build(self, *args, **kwargs) -> bool: - return True - - @tf.autograph.experimental.do_not_convert - def _forward(self, *args, **kwargs): - raise NotImplementedError( - "When subclassing the `Module` class, you should " - "implement a `_forward` method." - ) - - @tf.autograph.experimental.do_not_convert - def _extra_repr(self) -> str: - return "" - - # Properties # - # -----------# - - @property - def device(self): - return self._device - - @property - def dtype(self): - return self._dtype - - @property - def build_mode(self): - return self._build_mode - - @property - def training(self): - return self._training - - @property - def v(self): - return self._v - - @property - def buffers(self): - return self._buffers - - @property - def state_dict(self): - return {**self.v, **self.buffers} - - @property - def module_dict(self): - return self._module_dict - - # Dunder Methods # - # ---------------# - @store_frame_info - @tf.autograph.experimental.do_not_convert - def __call__( - self, - *args, - v=None, - buffers=None, - **kwargs, - ): - # TODO: Temp workaround to avoid `call`` from being transformed by AutoGraph - if not hasattr(self.__class__.call, "autograph_info__"): - setattr(self.__class__.call, "autograph_info__", True) - ret = self._call(*args, v=v, buffers=buffers, **kwargs) - return ret - - @tf.autograph.experimental.do_not_convert - def __getattr__(self, name): - if name == "v": - if not super().__getattribute__("_v") and not getattr( # noqa: E501 - self, "_built", False - ): - return self._build_and_return_v( - *self._args, dynamic_backend=self._dynamic_backend, **self._kwargs - ) - - _dict = super().__getattribute__("__dict__") - if name in _dict: - return _dict[name] - - elif "_v" in _dict and name in _dict["_v"]: - return _dict["_v"][name] - - return super().__getattribute__(name) - - @tf.autograph.experimental.do_not_convert - def __setattr__(self, name, value): - if name in ["v", "buffers"]: - name = "_" + name - if isinstance(value, (Layer, tf.keras.layers.Layer, Model, tf.keras.Model)): - _dict = getattr(self, "__dict__", None) - if _dict: - _dict[name] = value - - # compute the module dict - self._compute_module_dict() - - obj_to_search = ( - None - if not isinstance(value, (tf.keras.layers.Layer, Layer, Model)) - else ( - self._modules - if hasattr(self, "_modules") and self._modules - else self - ) - ) - found_vars = self._find_variables( - obj=obj_to_search, - without_initialisation=( - True - if self._v_from_constructor and not self._with_partial_v - else False - ), - ) - flattened_v, v_spec = tree_flatten(found_vars) - flattend_kc = v_spec.get_keychains() - for kc, v in zip(flattend_kc, flattened_v): - new_kc = kc.replace("/", ".") - if new_kc not in self.v: - self.register_parameter(new_kc, v) - - # once all variables built, find and assign buffers - found_buffers = self._find_variables( - obj=obj_to_search, - without_initialisation=( - True - if self._v_from_constructor and not self._with_partial_v - else False - ), - trainable=False, - ) - flattened_buf, buf_spec = tree_flatten(found_buffers) - flattend_kc = buf_spec.get_keychains() - for kc, buf in zip(flattend_kc, 
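from_config above regroups variable positional arguments by matching numbered keys against rf"{arg_name}_\d+" and popping them back into an ordered list. The regex-and-pop step in isolation (key names are illustrative):

import re

config = {"layers_0": "conv", "layers_1": "pool", "units": 8}
arg_name = "layers"
# sorted() materialises the keys first, so popping inside the
# comprehension is safe, same as the list-comp pop in from_config
var_positional = [config.pop(k) for k in sorted(config)
                  if re.match(rf"{arg_name}_\d+", k)]
assert var_positional == ["conv", "pool"]
assert config == {"units": 8}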
flattened_buf): - new_kc = kc.replace("/", ".") - self.register_buffer(new_kc, buf) - - super().__setattr__(name, value) - return - elif isinstance(value, tf.Variable) and not name.startswith("_"): - _dict = getattr(self, "__dict__", None) - if _dict: - _dict[name] = value - - # Manual solution for cases where a `tf.int32` tensor - # is placed on the GPU. TensorFlow doesn't have dedicated - # kernels for placing `tf.int32` variables on the GPU and so - # we manually cast them to `tf.int64` here otherwise due to - # `tf.config.soft_device_placement(True)` by default, - # TensorFlow puts the `tf.int32` variables on CPU which causes - # unintended consequences downstream during tracing or - # `tf.function` compilation e.g. - # Ref: https://github.com/tensorflow/tensorflow/issues/9506 - # Ref: https://stackoverflow.com/questions/44813939/could-not-satisfy-explicit-device-specification-devicegpu0-because-no-devic - dtype = ( - tf.int64 - if value.dtype == tf.int32 and "gpu:" in value.device.lower() - else value.dtype - ) - cast_dtype = dtype != value.dtype - val = ( - value - if not cast_dtype - else tf.Variable(initial_value=tf.cast(value.value(), dtype), name=name) - ) - self.register_parameter(name, val) - super().__setattr__(name, val) - else: - try: - obj_to_search = getattr(self, name) - except AttributeError: - obj_to_search = None - if isinstance(obj_to_search, (Model, Layer)): - # retrieve all hierarchical submodules - assign_dict, kc = get_assignment_dict() - - # Iterate over all submods in assign_dict - # updating their `v` and `buffers` with the - # new value - for key, submod in assign_dict.items(): - # Get the subkey to match - subkey = kc[len(key) :].lstrip(".") - - if hasattr(submod, "v"): - for v_key, v_value in submod.v.items(): - if v_key.startswith(subkey): - submod.register_parameter(v_key, value) - - # Repeat the same process for submod.buffers - if hasattr(submod, "buffers"): - for b_key, b_value in submod.buffers.items(): - if b_key.startswith(subkey): - submod.register_buffer(b_key, value) - - # finally update the module dict - self._module_dict[name] = value - - return super().__setattr__(name, value) - - @tf.autograph.experimental.do_not_convert - def __delattr__(self, name): - if hasattr(self, name): - if isinstance( - getattr(self, name), - (Layer, tf.keras.layers.Layer, Model, tf.keras.Model), - ): - super().__delattr__(name) - return - super().__delattr__(name) - - @tf.autograph.experimental.do_not_convert - def __repr__(self): - extra_lines = [] - extra_repr = self._extra_repr() - if extra_repr: - extra_lines = extra_repr.split("\n") - child_lines = [] - for key in self.v.keys(): - if isinstance(getattr(self, key, None), (Layer, Model)): - mod_str = repr(getattr(self, key)) - mod_str = self._addindent(mod_str, 2) - child_lines.append(f"({key}): {mod_str}") - lines = extra_lines + child_lines - - main_str = f"{self.__class__.__name__}(" - if lines: - # simple one-liner info, which most builtin Modules will use - if len(extra_lines) == 1 and not child_lines: - main_str += extra_lines[0] - else: - main_str += "\n " + "\n ".join(lines) + "\n" - - main_str += ")" - return main_str diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_zeros_like_output/run_0/tensorflow_zeros_like.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_zeros_like_output/run_0/tensorflow_zeros_like.py deleted file mode 100644 index 9588228d3656..000000000000 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_zeros_like_output/run_0/tensorflow_zeros_like.py +++ /dev/null @@ 
-1,20 +0,0 @@ -import tensorflow - -from typing import Optional -from typing import Union - -from .tensorflow__helpers import tensorflow_handle_array_like_without_promotion -from .tensorflow__helpers import tensorflow_infer_dtype - - -@tensorflow_infer_dtype -@tensorflow_handle_array_like_without_promotion -def tensorflow_zeros_like( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - dtype: tensorflow.DType, - device: Optional[str] = None, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - return tensorflow.zeros_like(x, dtype=dtype) diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_zeros_output/run_0/__init__.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_zeros_output/run_0/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_zeros_output/run_0/tensorflow_NestedSequence_bknd.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_zeros_output/run_0/tensorflow_NestedSequence_bknd.py deleted file mode 100644 index ac8335fe1e56..000000000000 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_zeros_output/run_0/tensorflow_NestedSequence_bknd.py +++ /dev/null @@ -1,10 +0,0 @@ -from typing import TypeVar -from typing import Protocol - -_T_co = TypeVar("_T_co", covariant=True) - - -class tensorflow_NestedSequence_bknd(Protocol[_T_co]): - def __getitem__(self, key: int, /): ... - - def __len__(self, /): ... diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_zeros_output/run_0/tensorflow__helpers.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_zeros_output/run_0/tensorflow__helpers.py deleted file mode 100644 index 06e137cf3452..000000000000 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_zeros_output/run_0/tensorflow__helpers.py +++ /dev/null @@ -1,2671 +0,0 @@ -from collections import UserDict -from numbers import Number -from numpy.core.numeric import normalize_axis_tuple -from operator import mul -from .tensorflow_NestedSequence_bknd import tensorflow_NestedSequence_bknd -from typing import Any -from typing import Callable -from typing import Dict -from typing import Iterable -from typing import List -from typing import Optional -from typing import Sequence -from typing import Tuple -from typing import TypeVar -from typing import Union -import functools -import inspect -import itertools -import math -import numpy as np -import re -import tensorflow -import tensorflow as tf - - -promotion_table = { - ("bool", "bool"): "bool", - ("int8", "int8"): "int8", - ("int8", "int16"): "int16", - ("int8", "int32"): "int32", - ("int8", "int64"): "int64", - ("int16", "int16"): "int16", - ("int16", "int32"): "int32", - ("int16", "int64"): "int64", - ("int32", "int32"): "int32", - ("int32", "int64"): "int64", - ("int64", "int64"): "int64", - ("uint8", "int8"): "int16", - ("uint8", "int16"): "int16", - ("uint8", "int32"): "int32", - ("uint8", "int64"): "int64", - ("uint8", "uint8"): "uint8", - ("uint8", "uint16"): "uint16", - ("uint8", "uint32"): "uint32", - ("uint8", "uint64"): "uint64", - ("uint16", "int8"): "int32", - ("uint16", "int16"): "int32", - ("uint16", "int32"): "int32", - ("uint16", "int64"): "int64", - ("uint16", "uint16"): "uint16", - ("uint16", "uint32"): "uint32", - ("uint16", "uint64"): "uint64", - ("uint32", "int8"): "int64", - ("uint32", "int16"): "int64", - ("uint32", "int32"): "int64", - ("uint32", "int64"): "int64", - ("uint32", "uint32"): "uint32", - ("uint32", "uint64"): "uint64", - ("uint64", "uint64"): "uint64", - 
("float16", "float16"): "float16", - ("float16", "float32"): "float32", - ("float16", "float64"): "float64", - ("float32", "float32"): "float32", - ("float32", "float64"): "float64", - ("float64", "float64"): "float64", - ("bool", "int8"): "int8", - ("bool", "int16"): "int16", - ("bool", "int32"): "int32", - ("bool", "int64"): "int64", - ("bool", "uint8"): "uint8", - ("bool", "uint16"): "uint16", - ("bool", "uint32"): "uint32", - ("bool", "uint64"): "uint64", - ("bool", "float16"): "float16", - ("bool", "float32"): "float32", - ("bool", "float64"): "float64", - ("bool", "bfloat16"): "bfloat16", - ("bool", "complex64"): "complex64", - ("bool", "complex128"): "complex128", - ("int8", "float16"): "float16", - ("int8", "float32"): "float32", - ("int8", "float64"): "float64", - ("int8", "bfloat16"): "bfloat16", - ("int8", "complex64"): "complex64", - ("int8", "complex128"): "complex128", - ("int16", "float32"): "float32", - ("int16", "float64"): "float64", - ("int16", "complex64"): "complex64", - ("int16", "complex128"): "complex128", - ("int32", "float64"): "float64", - ("int32", "complex128"): "complex128", - ("int64", "float64"): "float64", - ("int64", "complex128"): "complex128", - ("uint8", "float16"): "float16", - ("uint8", "float32"): "float32", - ("uint8", "float64"): "float64", - ("uint8", "bfloat16"): "bfloat16", - ("uint8", "complex64"): "complex64", - ("uint8", "complex128"): "complex128", - ("uint16", "float32"): "float32", - ("uint16", "float64"): "float64", - ("uint16", "complex64"): "complex64", - ("uint16", "complex128"): "complex128", - ("uint32", "float64"): "float64", - ("uint32", "complex128"): "complex128", - ("uint64", "int8"): "float64", - ("uint64", "int16"): "float64", - ("uint64", "int32"): "float64", - ("uint64", "int64"): "float64", - ("uint64", "float64"): "float64", - ("uint64", "complex128"): "complex128", - ("float16", "bfloat16"): "float32", - ("float16", "complex64"): "complex64", - ("float16", "complex128"): "complex128", - ("float32", "complex64"): "complex64", - ("float32", "complex128"): "complex128", - ("float64", "complex64"): "complex128", - ("float64", "complex128"): "complex128", - ("bfloat16", "float16"): "float32", - ("bfloat16", "float32"): "float32", - ("bfloat16", "float64"): "float64", - ("bfloat16", "bfloat16"): "bfloat16", - ("bfloat16", "complex64"): "complex64", - ("bfloat16", "complex128"): "complex128", - ("complex64", "float64"): "complex128", - ("complex64", "complex64"): "complex64", - ("complex64", "complex128"): "complex128", - ("complex128", "complex128"): "complex128", - ("float16", "int16"): "float32", - ("float16", "int32"): "float64", - ("float16", "int64"): "float64", - ("float16", "uint16"): "float32", - ("float16", "uint32"): "float64", - ("float16", "uint64"): "float64", - ("float32", "int32"): "float64", - ("float32", "int64"): "float64", - ("float32", "uint32"): "float64", - ("float32", "uint64"): "float64", - ("bfloat16", "int16"): "float32", - ("bfloat16", "int32"): "float64", - ("bfloat16", "int64"): "float64", - ("bfloat16", "uint16"): "float32", - ("bfloat16", "uint32"): "float64", - ("bfloat16", "uint64"): "float64", - ("complex64", "int32"): "complex128", - ("complex64", "int64"): "complex128", - ("complex64", "uint32"): "complex128", - ("complex64", "uint64"): "complex128", -} -array_api_promotion_table = { - ("bool", "bool"): "bool", - ("int8", "int8"): "int8", - ("int8", "int16"): "int16", - ("int8", "int32"): "int32", - ("int8", "int64"): "int64", - ("int16", "int16"): "int16", - ("int16", "int32"): "int32", - 
("int16", "int64"): "int64", - ("int32", "int32"): "int32", - ("int32", "int64"): "int64", - ("int64", "int64"): "int64", - ("uint8", "int8"): "int16", - ("uint8", "int16"): "int16", - ("uint8", "int32"): "int32", - ("uint8", "int64"): "int64", - ("uint8", "uint8"): "uint8", - ("uint8", "uint16"): "uint16", - ("uint8", "uint32"): "uint32", - ("uint8", "uint64"): "uint64", - ("uint16", "int8"): "int32", - ("uint16", "int16"): "int32", - ("uint16", "int32"): "int32", - ("uint16", "int64"): "int64", - ("uint16", "uint16"): "uint16", - ("uint16", "uint32"): "uint32", - ("uint16", "uint64"): "uint64", - ("uint32", "int8"): "int64", - ("uint32", "int16"): "int64", - ("uint32", "int32"): "int64", - ("uint32", "int64"): "int64", - ("uint32", "uint32"): "uint32", - ("uint32", "uint64"): "uint64", - ("uint64", "uint64"): "uint64", - ("float16", "float16"): "float16", - ("float16", "float32"): "float32", - ("float16", "float64"): "float64", - ("float32", "float32"): "float32", - ("float32", "float64"): "float64", - ("float64", "float64"): "float64", -} -tf.experimental.numpy.experimental_enable_numpy_behavior(True) -default_device_stack = [] -SupportsBufferProtocol = TypeVar("SupportsBufferProtocol") -default_uint_dtype_stack = [] -default_complex_dtype_stack = [] -default_dtype_stack = [] -default_float_dtype_stack = [] -ivy_dtype_dict = { - tensorflow.int8: "int8", - tensorflow.int16: "int16", - tensorflow.int32: "int32", - tensorflow.int64: "int64", - tensorflow.uint8: "uint8", - tensorflow.uint16: "uint16", - tensorflow.uint32: "uint32", - tensorflow.uint64: "uint64", - tensorflow.bfloat16: "bfloat16", - tensorflow.float16: "float16", - tensorflow.float32: "float32", - tensorflow.float64: "float64", - tensorflow.complex64: "complex64", - tensorflow.complex128: "complex128", - tensorflow.bool: "bool", -} -default_int_dtype_stack = [] -backend = "" -native_dtype_dict = { - "int8": tensorflow.int8, - "int16": tensorflow.int16, - "int32": tensorflow.int32, - "int64": tensorflow.int64, - "uint8": tensorflow.uint8, - "uint16": tensorflow.uint16, - "uint32": tensorflow.uint32, - "uint64": tensorflow.uint64, - "bfloat16": tensorflow.bfloat16, - "float16": tensorflow.float16, - "float32": tensorflow.float32, - "float64": tensorflow.float64, - "complex64": tensorflow.complex64, - "complex128": tensorflow.complex128, - "bool": tensorflow.bool, -} - - -def tensorflow_infer_dtype(fn: Callable): - @functools.wraps(fn) - def _infer_dtype(*args, dtype=None, **kwargs): - arr = ( - None - if tensorflow_exists_bknd(dtype) - else tensorflow__get_first_array(*args, **kwargs) - ) - dtype = tensorflow_default_dtype_bknd(dtype=dtype, item=arr, as_native=True) - return fn(*args, dtype=dtype, **kwargs) - - _infer_dtype.infer_dtype = True - return _infer_dtype - - -def tensorflow_handle_array_like_without_promotion(fn: Callable): - @functools.wraps(fn) - def _handle_array_like_without_promotion(*args, **kwargs): - args = list(args) - num_args = len(args) - try: - type_hints = inspect.signature(fn).parameters - except (TypeError, ValueError): - return fn(*args, **kwargs) - parameters = list(type_hints.keys()) - annotations = [param.annotation for param in type_hints.values()] - device = tensorflow__get_preferred_device(args, kwargs) - for i, (annotation, parameter, arg) in enumerate( - zip(annotations, parameters, args) - ): - annotation_str = str(annotation) - if ( - ("rray" in annotation_str or "Tensor" in annotation_str) - and parameter != "out" - and all( - sq not in annotation_str - for sq in ["Sequence", "List", 
"Tuple", "float", "int", "bool"] - ) - ): - if i < num_args: - if tensorflow__check_in_nested_sequence( - arg, value=Ellipsis, _type=slice - ): - continue - if not tensorflow_is_array_bknd(arg): - args = tensorflow_set_item_bknd( - args, i, tensorflow_asarray(arg, device=device) - ) - elif parameters in kwargs: - kwarg = tensorflow_get_item(kwargs, parameter) - if not tensorflow_is_array_bknd(kwarg): - kwargs = tensorflow_set_item_bknd( - kwargs, parameter, tensorflow_asarray(kwarg, device=device) - ) - return fn(*args, **kwargs) - - _handle_array_like_without_promotion.handle_array_like_without_promotion = True - return _handle_array_like_without_promotion - - -def tensorflow_exists_bknd(x: Any, /): - return x is not None - - -def tensorflow_is_native_array(x, /, *, exclusive=False): - if "keras.src.backend.tensorflow.core.Variable" in str(x.__class__): - return not exclusive - if isinstance(x, (tensorflow.Tensor, tensorflow.Variable, tensorflow.TensorArray)): - if exclusive and isinstance(x, tensorflow.Variable): - return False - return True - return False - - -def tensorflow_is_ivy_array_bknd( - x: Union[tensorflow.Tensor, tf.Tensor], /, *, exclusive: Optional[bool] = False -): - return isinstance(x, tensorflow.Tensor) and tensorflow_is_native_array( - x, exclusive=exclusive - ) - - -def tensorflow_is_array_bknd(x: Any, /, *, exclusive: bool = False): - return tensorflow_is_ivy_array_bknd( - x, exclusive=exclusive - ) or tensorflow_is_native_array(x, exclusive=exclusive) - - -def tensorflow_default_bknd( - x: Any, - /, - default_val: Any, - *, - catch_exceptions: bool = False, - rev: bool = False, - with_callable: bool = False, -): - with_callable = catch_exceptions or with_callable - if rev: - x, default_val = default_val, x - if with_callable: - x_callable = callable(x) - default_callable = callable(default_val) - else: - x_callable = False - default_callable = False - if catch_exceptions: - try: - x = x() if x_callable else x - except Exception: - return default_val() if default_callable else default_val - else: - x = x() if x_callable else x - return ( - x - if tensorflow_exists_bknd(x) - else default_val() if default_callable else default_val - ) - - -def tensorflow_nested_argwhere_bknd( - nest: Iterable, - fn: Callable, - check_nests: bool = False, - to_ignore: Optional[Union[type, Tuple[type]]] = None, - _index: Optional[List] = None, - _base: bool = True, - stop_after_n_found: Optional[int] = None, -): - to_ignore = tensorflow_default_bknd(to_ignore, ()) - _index = [] if _index is None else _index - if isinstance(nest, (tuple, list)) and not isinstance(nest, to_ignore): - n = 0 - _indices = [] - for i, item in enumerate(nest): - ind = ( - tensorflow_nested_argwhere_bknd( - item, - fn, - check_nests, - to_ignore, - _index + [i], - False, - stop_after_n_found - n, - ) - if stop_after_n_found is not None - else tensorflow_nested_argwhere_bknd( - item, fn, check_nests, to_ignore, _index + [i], False, None - ) - ) - if stop_after_n_found is not None and ind: - if n >= stop_after_n_found: - break - n = n + len(ind) - _indices = _indices + [ind] - if stop_after_n_found is not None and n >= stop_after_n_found: - break - _indices = [idx for idxs in _indices if idxs for idx in idxs] - if check_nests and fn(nest): - _indices.append(_index) - elif isinstance(nest, (dict, UserDict)) and not isinstance(nest, to_ignore): - n = 0 - _indices = [] - for k, v in nest.items(): - ind = ( - tensorflow_nested_argwhere_bknd( - v, - fn, - check_nests, - to_ignore, - _index + [k], - False, - 
stop_after_n_found - n, - ) - if stop_after_n_found is not None - else tensorflow_nested_argwhere_bknd( - v, fn, check_nests, to_ignore, _index + [k], False, None - ) - ) - if stop_after_n_found is not None and ind: - if n >= stop_after_n_found: - break - n = n + len(ind) - _indices = _indices + [ind] - _indices = [idx for idxs in _indices if idxs for idx in idxs] - if check_nests and fn(nest): - _indices.append(_index) - else: - cond_met = fn(nest) - if cond_met: - return [_index] - return False - return [index for index in _indices if index] - - -def tensorflow__check_float64_bknd(input): - if tensorflow_is_array_bknd(input): - return tensorflow_dtype(input) == "float64" - if math.isfinite(input): - m, e = math.frexp(input) - return abs(input) > 3.4028235e38 or e < -126 or e > 128 - return False - - -def tensorflow_as_ivy_dtype_bknd(dtype_in: Union[str, str], /): - return tensorflow_as_ivy_dtype(dtype_in) - - -def tensorflow_is_complex_dtype_bknd( - dtype_in: Union[str, str, tensorflow.Tensor, tf.Tensor, Number], / -): - if tensorflow_is_array_bknd(dtype_in): - dtype_in = tensorflow_dtype(dtype_in) - elif isinstance(dtype_in, tuple): - dtype_in = tensorflow_default_int_dtype_bknd() - elif isinstance(dtype_in, np.ndarray): - return "complex" in dtype_in.dtype.name - elif isinstance(dtype_in, Number): - return isinstance(dtype_in, (complex, np.complexfloating)) - elif isinstance(dtype_in, (list, tuple, dict)): - return tensorflow_nested_argwhere_bknd( - dtype_in, - lambda x: isinstance(x, (complex, np.complexfloating)) - or tensorflow_is_array_bknd(x) - and "complex" in tensorflow_dtype(x), - ) - return "complex" in tensorflow_as_ivy_dtype_bknd(dtype_in) - - -def tensorflow_as_native_dev(device: str, /): - if isinstance(device, str) and "/" in device: - return device - ret = f"/{str(device).upper()}" - if not ret[-1].isnumeric(): - ret += ":0" - return ret - - -def tensorflow_handle_methods(fn): - def extract_function_name(s): - match = re.search("_(.+?)(?:_\\d+)?$", s) - if match: - return match.group(1) - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - if tensorflow_is_array_bknd(args[0]): - return fn(*args, **kwargs) - else: - pattern = "_bknd_|_bknd|_frnt_|_frnt" - fn_name = extract_function_name(re.sub(pattern, "", fn.__name__)) - new_fn = getattr(args[0], fn_name) - return new_fn(*args[1:], **kwargs) - - return wrapper - - -@tensorflow_handle_methods -def tensorflow_split( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - copy: Optional[bool] = None, - num_or_size_splits: Optional[ - Union[int, Sequence[int], Union[tensorflow.Tensor, tensorflow.Variable]] - ] = None, - axis: int = 0, - with_remainder: bool = False, -): - if x.shape == (): - if num_or_size_splits is not None and num_or_size_splits != 1: - raise Exception( - f"input array had no shape, but num_sections specified was {num_or_size_splits}" - ) - return [x] - if num_or_size_splits is None: - dim_size = tensorflow.shape(x)[axis] - num_or_size_splits = int(dim_size) - if isinstance(num_or_size_splits, (tensorflow.Tensor, tensorflow.Variable)): - num_or_size_splits = tensorflow.cast(num_or_size_splits, tensorflow.int32) - elif isinstance(num_or_size_splits, int) and with_remainder: - num_chunks = x.shape[axis] / num_or_size_splits - num_chunks_int = math.floor(num_chunks) - remainder = num_chunks - num_chunks_int - if remainder != 0: - num_or_size_splits = [num_or_size_splits] * num_chunks_int + [ - int(remainder * num_or_size_splits) - ] - return tensorflow.split(x, num_or_size_splits, axis) - - 
-@tensorflow_handle_methods -def tensorflow_split_bknd_( - self: tensorflow.Tensor, - /, - *, - copy: Optional[bool] = None, - num_or_size_splits: Optional[ - Union[int, Sequence[int], tensorflow.Tensor, tf.Tensor] - ] = None, - axis: int = 0, - with_remainder: bool = False, -): - return tensorflow_split( - self, - copy=copy, - num_or_size_splits=num_or_size_splits, - axis=axis, - with_remainder=with_remainder, - ) - - -def tensorflow_as_ivy_dev(device: str, /): - if isinstance(device, str) and "/" not in device: - return str(device) - dev_in_split = tensorflow_split_bknd_(device[1:], ":")[-2:] - if len(dev_in_split) == 1: - return str(dev_in_split[0]) - dev_type, dev_idx = dev_in_split[0], dev_in_split[1] - dev_type = dev_type.lower() - if dev_type == "cpu": - return str(dev_type) - return str(f"{dev_type}:{dev_idx}") - - -def tensorflow_stack( - arrays: Union[Tuple[tensorflow.Tensor], List[tensorflow.Tensor]], - /, - *, - axis: int = 0, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - try: - return tensorflow.experimental.numpy.stack(arrays, axis) - except ValueError as e: - raise Exception(e) from e - - -def tensorflow_stack_bknd_( - self: tensorflow.Tensor, - /, - arrays: Union[ - Tuple[Union[tensorflow.Tensor, tf.Tensor]], - List[Union[tensorflow.Tensor, tf.Tensor]], - ], - *, - axis: int = 0, - out: Optional[tensorflow.Tensor] = None, -): - if not isinstance(arrays, (tuple, list)): - arrays = [arrays] - if isinstance(arrays, tuple): - x = (self,) + arrays - else: - x = [self] + arrays - return tensorflow_stack(x, axis=axis, out=out) - - -def tensorflow_dev( - x: Union[tensorflow.Tensor, tensorflow.Variable, tensorflow.TensorArray], - /, - *, - as_native: bool = False, -): - if "keras.src.backend.tensorflow.core.Variable" in str(x.__class__): - x = x.value - if isinstance(x, tensorflow.TensorArray): - x = tensorflow_stack_bknd_(x) - dv = x.device - if as_native: - return dv - dv = dv if dv else tensorflow_default_device_bknd(as_native=False) - return tensorflow_as_ivy_dev(dv) - - -def tensorflow_default_device_bknd( - device: Optional[Union[str, str]] = None, - /, - *, - item: Optional[Union[list, tuple, dict, tensorflow.Tensor, tf.Tensor]] = None, - as_native: Optional[bool] = None, -): - if tensorflow_exists_bknd(device): - if as_native is True: - return tensorflow_as_native_dev(device) - elif as_native is False: - return tensorflow_as_ivy_dev(device) - return device - as_native = tensorflow_default_bknd(as_native, False) - if tensorflow_exists_bknd(item): - if isinstance(item, (list, tuple, dict)) and len(item) == 0: - pass - elif tensorflow_is_array_bknd(item): - return tensorflow_dev(item, as_native=as_native) - global default_device_stack - if not default_device_stack: - ret = "cpu" - else: - ret = default_device_stack[-1] - if as_native: - return tensorflow_as_native_dev(ret) - return tensorflow_as_ivy_dev(ret) - - -def tensorflow__get_preferred_device(args, kwargs): - device = None - if "device" in kwargs and kwargs["device"] is not None: - return device - if not False: - arr_arg = tensorflow__get_first_array(*args, **kwargs) - return tensorflow_default_device_bknd(item=arr_arg, as_native=True) - return tensorflow_default_device_bknd(as_native=True) - - -def tensorflow__check_in_nested_sequence(sequence, value=None, _type=None): - if sequence is value or isinstance(sequence, _type): - return True - elif isinstance(sequence, (tuple, list)): - if any(isinstance(_val, _type) or _val is value for _val in sequence): - return True - else: - return any( - 
tensorflow__check_in_nested_sequence(sub_sequence, value, _type) - for sub_sequence in sequence - if isinstance(sub_sequence, (tuple, list)) - ) - - -def tensorflow_is_variable(x, /, *, exclusive=False): - return isinstance(x, tensorflow.Variable) - - -def tensorflow_variable(x, /): - with tensorflow.device(tensorflow_dev(x, as_native=True)): - return tensorflow.Variable(x, trainable=True) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_stop_gradient( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - preserve_type: bool = True, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - is_var = tensorflow_is_variable(x) - x = tensorflow.stop_gradient(x) - if is_var and preserve_type: - return tensorflow_variable(x) - return x - - -def tensorflow_nested_map_bknd( - fn: Callable, - x: Union[tensorflow.Tensor, tf.Tensor, Iterable], - /, - include_derived: Optional[Union[Dict[str, bool], bool]] = None, - to_ignore: Optional[Union[type, Tuple[type]]] = None, - to_mutable: bool = False, - _tuple_check_fn: Optional[Callable] = None, - _list_check_fn: Optional[Callable] = None, - _dict_check_fn: Optional[Callable] = None, - shallow: bool = True, -): - to_ignore = tensorflow_default_bknd(to_ignore, ()) - if include_derived is True: - include_derived = {"tuple": True, "list": True, "dict": True} - elif not include_derived: - include_derived = {} - for t in ("tuple", "list", "dict"): - if t not in include_derived: - include_derived = tensorflow_set_item_bknd(include_derived, t, False) - class_instance = type(x) - if ( - hasattr(x, "is_tracked_proxy") - and hasattr(class_instance, "__bases__") - and not set(class_instance.__bases__).intersection(set(to_ignore)) - ): - to_ignore = to_ignore + (class_instance,) - tuple_check_fn = tensorflow_default_bknd( - _tuple_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["tuple"] - else lambda x_, t_: type(x_) is t_ - ), - ) - list_check_fn = tensorflow_default_bknd( - _list_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["list"] - else lambda x_, t_: type(x_) is t_ - ), - ) - dict_check_fn = tensorflow_default_bknd( - _dict_check_fn, - ( - (lambda x_, t_: isinstance(x_, t_)) - if include_derived["dict"] - else lambda x_, t_: type(x_) is t_ - ), - ) - if tuple_check_fn(x, tuple) and not isinstance(x, to_ignore): - ret_list = [ - tensorflow_nested_map_bknd( - fn, - i, - include_derived, - to_ignore, - to_mutable, - tuple_check_fn, - list_check_fn, - dict_check_fn, - shallow, - ) - for i in x - ] - if to_mutable: - return ret_list - elif hasattr(x, "_fields"): - return class_instance(**dict(zip(x._fields, ret_list))) - else: - return class_instance(ret_list) - elif list_check_fn(x, list) and not isinstance(x, to_ignore): - ret_list = [ - tensorflow_nested_map_bknd( - fn, - i, - include_derived, - to_ignore, - to_mutable, - tuple_check_fn, - list_check_fn, - dict_check_fn, - shallow, - ) - for i in x - ] - if shallow: - x = tensorflow_set_item_bknd(x, slice(None, None, None), ret_list[:]) - return x - return class_instance(ret_list) - elif (dict_check_fn(x, dict) or isinstance(x, UserDict)) and not isinstance( - x, to_ignore - ): - class_instance = type(x) - ret = { - k: tensorflow_nested_map_bknd( - fn, - v, - include_derived, - to_ignore, - to_mutable, - tuple_check_fn, - list_check_fn, - dict_check_fn, - shallow, - ) - for k, v in x.items() - } - if shallow: - x.update(ret) - return x - return class_instance(ret) - elif isinstance(x, slice): - return 
slice(*tensorflow_nested_map_bknd(fn, [x.start, x.stop, x.step])) - return fn(x) - - -def tensorflow__to_ivy_bknd_(x: Any): - if isinstance(x, tensorflow.Tensor): - return x - elif isinstance(x, tf.TensorShape): - return tuple(x) - elif isinstance(x, dict): - return x.to_ivy() - if tensorflow_is_native_array(x) or isinstance(x, np.ndarray): - return tensorflow.convert_to_tensor(x) - return x - - -def tensorflow_to_ivy_bknd_( - x: Union[tensorflow.Tensor, tf.Tensor, Iterable], - nested: bool = False, - include_derived: Optional[Dict[str, bool]] = None, -): - if nested: - return tensorflow_nested_map_bknd( - tensorflow__to_ivy_bknd_, x, include_derived, shallow=False - ) - return tensorflow__to_ivy_bknd_(x) - - -def tensorflow__asarray_to_native_arrays_and_back_bknd(fn: Callable): - @functools.wraps(fn) - def _asarray_to_native_arrays_and_back_wrapper(*args, dtype=None, **kwargs): - new_arg = args[0] - new_args = (new_arg,) + args[1:] - if dtype is not None: - dtype = tensorflow_default_dtype_bknd(dtype=dtype, as_native=True) - return tensorflow_to_ivy_bknd_(fn(*new_args, dtype=dtype, **kwargs)) - - _asarray_to_native_arrays_and_back_wrapper._asarray_to_native_arrays_and_back = True - return _asarray_to_native_arrays_and_back_wrapper - - -def tensorflow__flatten_nest_bknd(xs): - for x in xs: - if isinstance(x, Iterable) and not isinstance(x, (str, bytes)): - yield from tensorflow__flatten_nest_bknd(x) - else: - yield x - - -def tensorflow_promote_types_bknd( - type1: Union[str, tf.DType], - type2: Union[str, tf.DType], - /, - *, - array_api_promotion: bool = False, -): - if not (type1 and type2): - return type1 if type1 else type2 - query = [tensorflow_as_ivy_dtype(type1), tensorflow_as_ivy_dtype(type2)] - query = tuple(query) - if query not in promotion_table: - query = query[1], query[0] - - def _promote(query): - if array_api_promotion: - return tensorflow_get_item(array_api_promotion_table, query) - return tensorflow_get_item(promotion_table, query) - - return _promote(query) - - -def tensorflow__asarray_infer_dtype_bknd(fn: Callable): - @functools.wraps(fn) - def _asarray_infer_dtype_wrapper(*args, dtype=None, **kwargs): - def _infer_dtype(obj): - if isinstance(obj, tf.TensorShape): - obj = list(obj) - if hasattr(obj, "dtype"): - return obj.dtype.name if isinstance(obj, np.ndarray) else obj.dtype - else: - return tensorflow_default_dtype_bknd(item=obj) - - if not tensorflow_exists_bknd(dtype): - arr = args[0] - dtype_list = [ - tensorflow_nested_map_bknd( - lambda x: _infer_dtype(x), arr, shallow=False - ) - ] - dtype_list = tensorflow__flatten_nest_bknd(dtype_list) - dtype_list = list(set(dtype_list)) - if len(dtype_list) != 0: - dtype = dtype_list[0] - for dt in dtype_list[1:]: - dtype = tensorflow_promote_types_bknd(dtype, dt) - else: - dtype = tensorflow_default_float_dtype_bknd() - dtype = tensorflow_as_native_dtype(dtype) - return fn(*args, dtype=dtype, **kwargs) - - _asarray_infer_dtype_wrapper.infer_dtype = True - return _asarray_infer_dtype_wrapper - - -@tensorflow_handle_array_like_without_promotion -@tensorflow__asarray_to_native_arrays_and_back_bknd -@tensorflow__asarray_infer_dtype_bknd -def tensorflow_asarray( - obj: Union[ - tensorflow.Tensor, - tensorflow.Variable, - tensorflow.TensorShape, - bool, - int, - float, - tensorflow_NestedSequence_bknd, - SupportsBufferProtocol, - np.ndarray, - ], - /, - *, - copy: Optional[bool] = None, - dtype: Optional[tensorflow.DType] = None, - device: Optional[str] = None, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = 
None, -): - with tensorflow.device(device): - if tensorflow.is_tensor(obj): - ret = tensorflow.cast(obj, dtype) if obj.dtype != dtype else obj - elif ( - dtype is not None - and dtype.is_integer - and np.issubdtype(np.array(obj).dtype, np.floating) - ): - obj_np = np.array(obj) - ret = tensorflow.convert_to_tensor(obj_np, dtype) - else: - ret = tensorflow.convert_to_tensor(obj, dtype) - return ( - tensorflow.identity(ret) - if copy or tensorflow_as_native_dev(tensorflow_dev(ret)) != device - else ret - ) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_size(x: tensorflow.Tensor, /): - return functools.reduce(mul, x.shape) if len(x.shape) > 0 else 1 - - -def tensorflow_size_bknd_(self): - return tensorflow_size(self) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_unstack( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - copy: Optional[bool] = None, - axis: int = 0, - keepdims: bool = False, -): - if x.shape == (): - return [x] - ret = tensorflow.unstack(x, axis=axis) - if keepdims: - return [tensorflow.expand_dims(r, axis) for r in ret] - return ret - - -def tensorflow_unstack_bknd_( - self: tensorflow.Tensor, - /, - *, - copy: Optional[bool] = None, - axis: int = 0, - keepdims: bool = False, -): - return tensorflow_unstack(self, copy=copy, axis=axis, keepdims=keepdims) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_copy_array( - x: Union[tensorflow.Tensor, tensorflow.Variable, tensorflow.TensorArray], - *, - to_ivy_array: bool = True, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if isinstance(x, tensorflow.TensorArray): - x_wrapped = tensorflow_stack_bknd_(x) - y = tensorflow.TensorArray(x.dtype, tensorflow_size_bknd_(x)()) - x = tensorflow_unstack_bknd_(y, tensorflow_copy_array(x_wrapped)) - else: - x = tensorflow.identity(x) - if to_ivy_array: - return tensorflow_to_ivy_bknd_(x) - return x - - -def tensorflow_tile( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - repeats: Sequence[int], - *, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if x.shape == (): - x = tensorflow.reshape(x, (-1,)) - if isinstance(repeats, Number): - repeats = [repeats] - if isinstance(repeats, tensorflow.Tensor) and repeats.shape == (): - repeats = tensorflow.reshape(repeats, (-1,)) - if len(x.shape) < len(repeats): - while len(x.shape) != len(repeats): - x = tensorflow.expand_dims(x, 0) - elif len(x.shape) > len(repeats): - repeats = list(repeats) - while len(x.shape) != len(repeats): - repeats = [1] + repeats - return tensorflow.tile(x, repeats) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_nonzero( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - as_tuple: bool = True, - size: Optional[int] = None, - fill_value: Number = 0, -): - res = tensorflow.experimental.numpy.nonzero(x) - if size is not None: - dtype = tensorflow.int64 - if isinstance(fill_value, float): - dtype = tensorflow.float64 - res = tensorflow.cast(res, dtype) - diff = size - res[0].shape[0] - if diff > 0: - res = tensorflow.pad(res, [[0, 0], [0, diff]], constant_values=fill_value) - elif diff < 0: - res = tensorflow.slice(res, [0, 0], [-1, size]) - if as_tuple: - return tuple(res) - return tensorflow.stack(res, axis=1) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_diff( - x: Union[tensorflow.Tensor, tensorflow.Variable, list, tuple], - /, - *, - n: int = 1, - axis: int = -1, - prepend: Optional[ - Union[tensorflow.Tensor, 
tensorflow.Variable, int, float, list, tuple]
-    ] = None,
-    append: Optional[
-        Union[tensorflow.Tensor, tensorflow.Variable, int, float, list, tuple]
-    ] = None,
-    out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None,
-):
-    if n == 0:
-        return x
-    if prepend is not None:
-        x = tensorflow.experimental.numpy.append(
-            prepend, x, axis=axis if axis != -1 else None
-        )
-    if append is not None:
-        x = tensorflow.experimental.numpy.append(
-            x, append, axis=axis if axis != -1 else None
-        )
-    return tensorflow.experimental.numpy.diff(x, n=n, axis=axis)
-
-
-def tensorflow__parse_ellipsis_bknd(so, ndims):
-    pre = list()
-    for s in so:
-        if s is Ellipsis:
-            break
-        pre.append(s)
-    post = list()
-    for s in reversed(so):
-        if s is Ellipsis:
-            break
-        post.append(s)
-    ret = list(
-        pre
-        + [slice(None, None, None) for _ in range(ndims - len(pre) - len(post))]
-        + list(reversed(post))
-    )
-    return ret, (len(pre), ndims - len(post))
-
-
-def tensorflow_broadcast_arrays(*arrays: Union[tensorflow.Tensor, tensorflow.Variable]):
-    if len(arrays) > 1:
-        try:
-            desired_shape = tensorflow.broadcast_dynamic_shape(
-                tensorflow.shape(arrays[0]), tensorflow.shape(arrays[1])
-            )
-        except tensorflow.errors.InvalidArgumentError as e:
-            raise Exception(e) from e
-        if len(arrays) > 2:
-            for i in range(2, len(arrays)):
-                try:
-                    desired_shape = tensorflow.broadcast_dynamic_shape(
-                        desired_shape, tensorflow.shape(arrays[i])
-                    )
-                except tensorflow.errors.InvalidArgumentError as e:
-                    raise Exception(e) from e
-    else:
-        return [arrays[0]]
-    result = []
-    for tensor in arrays:
-        result.append(tensorflow.broadcast_to(tensor, desired_shape))
-    return result
-
-
-@tensorflow_handle_array_like_without_promotion
-def tensorflow_astype(
-    x: Union[tensorflow.Tensor, tensorflow.Variable],
-    dtype: Union[tf.DType, str],
-    /,
-    *,
-    copy: bool = True,
-    out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None,
-):
-    dtype = tensorflow_as_native_dtype(dtype)
-    if x.dtype == dtype:
-        return tensorflow.experimental.numpy.copy(x) if copy else x
-    return tensorflow.cast(x, dtype)
-
-
-def tensorflow_astype_bknd_(
-    self: tensorflow.Tensor,
-    dtype: str,
-    /,
-    *,
-    copy: bool = True,
-    out: Optional[tensorflow.Tensor] = None,
-):
-    return tensorflow_astype(self, dtype, copy=copy, out=out)
-
-
-@tensorflow_handle_array_like_without_promotion
-def tensorflow_where(
-    condition: Union[tensorflow.Tensor, tensorflow.Variable],
-    x1: Union[tensorflow.Tensor, tensorflow.Variable],
-    x2: Union[tensorflow.Tensor, tensorflow.Variable],
-    /,
-    *,
-    out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None,
-):
-    orig_x1 = x1
-    orig_x2 = x2
-    try:
-        dtype = (
-            x1.dtype
-            if hasattr(x1, "dtype")
-            else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd()
-        )
-        if not tensorflow_is_array_bknd(x1):
-            x1 = tensorflow_asarray(x1, dtype=dtype)
-        if not tensorflow_is_array_bknd(x2):
-            x2 = tensorflow_asarray(x2, dtype=dtype)
-    except Exception:
-        x1 = orig_x1
-        x2 = orig_x2
-    return tensorflow.cast(
-        tensorflow.experimental.numpy.where(condition, x1, x2), x1.dtype
-    )
-
-
-@tensorflow_handle_array_like_without_promotion
-def tensorflow_arange(
-    start: float,
-    /,
-    stop: Optional[float] = None,
-    step: float = 1,
-    *,
-    dtype: Optional[tensorflow.DType] = None,
-    device: Optional[str] = None,
-    out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None,
-):
-    if stop is None:
-        stop = start
-        start = 0
-    if step > 0 and start > stop or step < 0 and start < stop:
-        if isinstance(stop, float):
-            stop = 
float(start) - else: - stop = start - if isinstance(start, (float, int)): - start = tensorflow.convert_to_tensor(start) - if isinstance(stop, (float, int)): - stop = tensorflow.convert_to_tensor(stop) - if isinstance(step, (float, int)): - step = tensorflow.convert_to_tensor(step) - if dtype is None: - if isinstance(start, int) and isinstance(stop, int) and isinstance(step, int): - return tensorflow.cast( - tensorflow.range(start, stop, delta=step, dtype=tensorflow.int64), - tensorflow.int32, - ) - else: - return tensorflow.range(start, stop, delta=step) - else: - dtype = tensorflow_as_native_dtype(tensorflow_default_dtype_bknd(dtype=dtype)) - if dtype in [ - tensorflow.int8, - tensorflow.uint8, - tensorflow.int16, - tensorflow.uint16, - tensorflow.uint32, - tensorflow.uint64, - ]: - return tensorflow.cast( - tensorflow.range(start, stop, delta=step, dtype=tensorflow.int64), dtype - ) - else: - return tensorflow.range(start, stop, delta=step, dtype=dtype) - - -def tensorflow__parse_slice_bknd(idx, s): - step = 1 if idx.step is None else idx.step - if step > 0: - start = 0 if idx.start is None else idx.start - if start >= s: - stop = start - else: - if start <= -s: - start = 0 - elif start < 0: - start = start + s - stop = s if idx.stop is None else idx.stop - if stop > s: - stop = s - elif start <= -s: - stop = 0 - elif stop < 0: - stop = stop + s - else: - start = s - 1 if idx.start is None else idx.start - if start < -s: - stop = start - else: - if start >= s: - start = s - 1 - elif start < 0: - start = start + s - if idx.stop is None: - stop = -1 - else: - stop = idx.stop - if stop > s: - stop = s - elif stop < -s: - stop = -1 - elif stop == -s: - stop = 0 - elif stop < 0: - stop = stop + s - q_i = tensorflow_arange(start, stop, step) - ag__result_list_0 = [] - for q in q_i: - if 0 <= q < s: - res = q - ag__result_list_0.append(res) - q_i = ag__result_list_0 - q_i = ( - tensorflow_asarray(q_i) - if len(q_i) or start == stop or idx.stop is not None - else tensorflow_arange(0, s, 1) - ) - return q_i - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_shape( - x: Union[tensorflow.Tensor, tensorflow.Variable], /, *, as_array: bool = False -): - if as_array: - return tensorflow_asarray( - tensorflow.shape(x), dtype=tensorflow_default_int_dtype_bknd() - ) - else: - return tuple(x.shape) - - -def tensorflow__deep_flatten_bknd(iterable): - def _flatten_gen(iterable): - for item in iterable: - if isinstance(item, list): - yield from _flatten_gen(item) - else: - yield item - - return list(_flatten_gen(iterable)) - - -def tensorflow__calculate_out_shape_bknd(axis, array_shape): - if type(axis) not in (tuple, list): - axis = (axis,) - out_dims = len(axis) + len(array_shape) - norm_axis = normalize_axis_tuple(axis, out_dims) - shape_iter = iter(array_shape) - ag__result_list_0 = [] - for current_ax in range(out_dims): - res = 1 if current_ax in norm_axis else next(shape_iter) - ag__result_list_0.append(res) - out_shape = ag__result_list_0 - return out_shape - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_expand_dims( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - copy: Optional[bool] = None, - axis: Union[int, Sequence[int]] = 0, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - try: - out_shape = tensorflow__calculate_out_shape_bknd(axis, tensorflow.shape(x)) - ret = tensorflow.reshape(x, shape=out_shape) - return ret - except (tensorflow.errors.InvalidArgumentError, np.AxisError) as error: - raise Exception(error) 
from error
-
-
-def tensorflow_check_elem_in_list(elem, list, inverse=False, message=""):
-    if inverse and elem in list:
-        raise Exception(
-            message if message != "" else f"{elem} must not be one of {list}"
-        )
-    elif not inverse and elem not in list:
-        raise Exception(message if message != "" else f"{elem} must be one of {list}")
-
-
-def tensorflow__reshape_fortran_tf(x, shape):
-    if len(x.shape) > 0:
-        x = tensorflow.transpose(x)
-    return tensorflow.transpose(tensorflow.reshape(x, shape[::-1]))
-
-
-@tensorflow_handle_array_like_without_promotion
-def tensorflow_reshape(
-    x: Union[tensorflow.Tensor, tensorflow.Variable],
-    /,
-    shape: Union[tf.TensorShape, Sequence[int]],
-    *,
-    copy: Optional[bool] = None,
-    order: str = "C",
-    allowzero: bool = True,
-    out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None,
-):
-    tensorflow_check_elem_in_list(order, ["C", "F"])
-    if not allowzero:
-        shape = [
-            (new_s if con else old_s)
-            for new_s, con, old_s in zip(
-                shape, tensorflow.constant(shape) != 0, x.shape
-            )
-        ]
-    if order == "F":
-        return tensorflow__reshape_fortran_tf(x, shape)
-    return tensorflow.reshape(x, shape)
-
-
-def tensorflow_reshape_bknd_(
-    self: tensorflow.Tensor,
-    /,
-    shape: Union[tuple, tf.TensorShape, Sequence[int]],
-    *,
-    copy: Optional[bool] = None,
-    order: str = "C",
-    allowzero: bool = True,
-    out: Optional[tensorflow.Tensor] = None,
-):
-    return tensorflow_reshape(
-        self, shape, copy=copy, allowzero=allowzero, out=out, order=order
-    )
-
-
-@tensorflow_handle_array_like_without_promotion
-def tensorflow_meshgrid(
-    *arrays: Union[tensorflow.Tensor, tensorflow.Variable],
-    sparse: bool = False,
-    indexing: str = "xy",
-    out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None,
-):
-    if not sparse:
-        return tensorflow.meshgrid(*arrays, indexing=indexing)
-    sd = (1,) * len(arrays)
-    ag__result_list_0 = []
-    for i, a in enumerate(arrays):
-        res = tensorflow.reshape(
-            tensorflow.convert_to_tensor(a), sd[:i] + (-1,) + sd[i + 1 :]
-        )
-        ag__result_list_0.append(res)
-    res = ag__result_list_0
-    if indexing == "xy" and len(arrays) > 1:
-        res[0] = tensorflow.reshape(res[0], (1, -1) + sd[2:])
-        res[1] = tensorflow.reshape(res[1], (-1, 1) + sd[2:])
-    return res
-
-
-@tensorflow_infer_dtype
-@tensorflow_handle_array_like_without_promotion
-def tensorflow_empty(
-    shape: Union[tf.TensorShape, Sequence[int]],
-    *,
-    dtype: tensorflow.DType,
-    device: Optional[str] = None,
-    out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None,
-):
-    return tensorflow.experimental.numpy.empty(shape, dtype=dtype)
-
-
-def tensorflow__parse_query_bknd(query, x_shape, scatter=False):
-    query = (query,) if not isinstance(query, tuple) else query
-    ag__result_list_0 = []
-    for q in query:
-        res = tensorflow_asarray(q) if isinstance(q, (tuple, list, int)) else q
-        ag__result_list_0.append(res)
-    query = ag__result_list_0
-    ag__result_list_1 = []
-    for i, q in enumerate(query):
-        if tensorflow_is_array_bknd(q):
-            res = i
-            ag__result_list_1.append(res)
-    non_slice_q_idxs = ag__result_list_1
-    to_front = (
-        len(non_slice_q_idxs) > 1
-        and any(tensorflow_diff(non_slice_q_idxs) != 1)
-        and non_slice_q_idxs[-1] < len(x_shape)
-    )
-    ag__result_list_2 = []
-    for i, q in enumerate(query):
-        if q is None:
-            res = i
-            ag__result_list_2.append(res)
-    new_axes = ag__result_list_2
-    ag__result_list_3 = []
-    for q in query:
-        if q is not None:
-            res = q
-            ag__result_list_3.append(res)
-    query = ag__result_list_3
-    query = [Ellipsis] if query == [] else query
-    
ellipsis_inds = None - if any(q is Ellipsis for q in query): - query, ellipsis_inds = tensorflow__parse_ellipsis_bknd(query, len(x_shape)) - ag__result_list_4 = [] - for i, v in enumerate(query): - if tensorflow_is_array_bknd(v): - res = i - ag__result_list_4.append(res) - array_inds = ag__result_list_4 - if array_inds: - array_queries = tensorflow_broadcast_arrays( - *[v for i, v in enumerate(query) if i in array_inds] - ) - array_queries = [ - ( - tensorflow_nonzero(q, as_tuple=False)[0] - if tensorflow_is_bool_dtype_bknd(q) - else q - ) - for q in array_queries - ] - array_queries = [ - ( - tensorflow_astype_bknd_( - tensorflow_where( - arr < 0, arr + tensorflow_get_item(x_shape, i), arr - ), - tf.int64, - ) - if tensorflow_size_bknd_(arr) - else tensorflow_astype_bknd_(arr, tf.int64) - ) - for arr, i in zip(array_queries, array_inds) - ] - for idx, arr in zip(array_inds, array_queries): - query = tensorflow_set_item_bknd(query, idx, arr) - ag__result_list_5 = [] - for i, q in enumerate(query): - res = ( - tensorflow_astype_bknd_( - tensorflow__parse_slice_bknd(q, tensorflow_get_item(x_shape, i)), - tf.int64, - ) - if isinstance(q, slice) - else q - ) - ag__result_list_5.append(res) - query = ag__result_list_5 - if len(query) < len(x_shape): - query = query + [ - tensorflow_astype_bknd_(tensorflow_arange(0, s, 1), tf.int64) - for s in tensorflow_get_item(x_shape, slice(len(query), None, None)) - ] - if len(array_inds) and to_front: - target_shape = ( - [list(array_queries[0].shape)] - + [ - list(tensorflow_get_item(query, i).shape) - for i in range(len(query)) - if i not in array_inds - ] - + [[] for _ in range(len(array_inds) - 1)] - ) - elif len(array_inds): - target_shape = ( - [list(tensorflow_get_item(query, i).shape) for i in range(0, array_inds[0])] - + [list(tensorflow_shape(array_queries[0], as_array=True))] - + [[] for _ in range(len(array_inds) - 1)] - + [ - list(tensorflow_shape(tensorflow_get_item(query, i), as_array=True)) - for i in range(array_inds[-1] + 1, len(query)) - ] - ) - else: - target_shape = [list(q.shape) for q in query] - if ellipsis_inds is not None: - target_shape = ( - tensorflow_get_item(target_shape, slice(None, ellipsis_inds[0], None)) - + [ - tensorflow_get_item( - target_shape, slice(ellipsis_inds[0], ellipsis_inds[1], None) - ) - ] - + tensorflow_get_item(target_shape, slice(ellipsis_inds[1], None, None)) - ) - for i, ax in enumerate(new_axes): - if len(array_inds) and to_front: - ax = ax - (sum(1 for x in array_inds if x < ax) - 1) - ax = ax + i - target_shape = [ - *tensorflow_get_item(target_shape, slice(None, ax, None)), - 1, - *tensorflow_get_item(target_shape, slice(ax, None, None)), - ] - target_shape = tensorflow__deep_flatten_bknd(target_shape) - ag__result_list_6 = [] - for q in query: - res = tensorflow_expand_dims(q) if not len(q.shape) else q - ag__result_list_6.append(res) - query = ag__result_list_6 - if len(array_inds): - array_queries = [ - ( - tensorflow_reshape_bknd_(arr, (-1,)) - if len(arr.shape) > 1 - else tensorflow_expand_dims(arr) if not len(arr.shape) else arr - ) - for arr in array_queries - ] - array_queries = tensorflow_stack(array_queries, axis=1) - if len(array_inds) == len(query): - indices = tensorflow_reshape_bknd_(array_queries, (*target_shape, len(x_shape))) - elif len(array_inds) == 0: - indices = tensorflow_reshape_bknd_( - tensorflow_stack(tensorflow_meshgrid(*query, indexing="ij"), axis=-1), - (*target_shape, len(x_shape)), - ) - elif to_front: - post_array_queries = ( - tensorflow_reshape_bknd_( - 
tensorflow_stack( - tensorflow_meshgrid( - *[v for i, v in enumerate(query) if i not in array_inds], - indexing="ij", - ), - axis=-1, - ), - (-1, len(query) - len(array_inds)), - ) - if len(array_inds) < len(query) - else tensorflow_empty((1, 0)) - ) - indices = tensorflow_reshape_bknd_( - tensorflow_asarray( - [ - (*arr, *post) - for arr, post in itertools.product( - array_queries, post_array_queries - ) - ] - ), - (*target_shape, len(x_shape)), - ) - else: - pre_array_queries = ( - tensorflow_reshape_bknd_( - tensorflow_stack( - tensorflow_meshgrid( - *[v for i, v in enumerate(query) if i < array_inds[0]], - indexing="ij", - ), - axis=-1, - ), - (-1, array_inds[0]), - ) - if array_inds[0] > 0 - else tensorflow_empty((1, 0)) - ) - post_array_queries = ( - tensorflow_reshape_bknd_( - tensorflow_stack( - tensorflow_meshgrid( - *[v for i, v in enumerate(query) if i > array_inds[-1]], - indexing="ij", - ), - axis=-1, - ), - (-1, len(query) - 1 - array_inds[-1]), - ) - if array_inds[-1] < len(query) - 1 - else tensorflow_empty((1, 0)) - ) - indices = tensorflow_reshape_bknd_( - tensorflow_asarray( - [ - (*pre, *arr, *post) - for pre, arr, post in itertools.product( - pre_array_queries, array_queries, post_array_queries - ) - ] - ), - (*target_shape, len(x_shape)), - ) - return ( - tensorflow_astype_bknd_(indices, tf.int64), - target_shape, - array_inds if len(array_inds) and to_front else None, - ) - - -def tensorflow_get_num_dims(x, /, *, as_array=False): - return ( - tensorflow.cast(tensorflow.shape(tensorflow.shape(x))[0], tensorflow.int64) - if as_array - else int(tensorflow.shape(tensorflow.shape(x))) - ) - - -def tensorflow_to_numpy( - x: Union[tensorflow.Tensor, tensorflow.Variable], /, *, copy: bool = True -): - if ( - tensorflow_is_array_bknd(x) - and tensorflow_get_num_dims(x) == 0 - and tensorflow_as_native_dtype(x.dtype) is tensorflow.bfloat16 - ): - x = tensorflow.expand_dims(x, 0) - if copy: - return np.squeeze(np.array(tensorflow.convert_to_tensor(x)), 0) - else: - return np.squeeze(np.asarray(tensorflow.convert_to_tensor(x)), 0) - if copy: - return np.array(tensorflow.convert_to_tensor(x)) - else: - return np.asarray(tensorflow.convert_to_tensor(x)) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_to_scalar(x: Union[tensorflow.Tensor, tensorflow.Variable], /): - ret = tensorflow_to_numpy(x).item() - if x.dtype == tensorflow.bfloat16: - return float(ret) - return ret - - -def tensorflow_to_scalar_bknd_(self: tensorflow.Tensor): - return tensorflow_to_scalar(self) - - -def tensorflow_is_float_dtype_bknd( - dtype_in: Union[str, str, tensorflow.Tensor, tf.Tensor, Number], / -): - if tensorflow_is_array_bknd(dtype_in): - dtype_in = tensorflow_dtype(dtype_in) - elif isinstance(dtype_in, tuple): - dtype_in = tensorflow_default_int_dtype_bknd() - elif isinstance(dtype_in, np.ndarray): - return "float" in dtype_in.dtype.name - elif isinstance(dtype_in, Number): - return isinstance(dtype_in, (float, np.floating)) - elif isinstance(dtype_in, (list, tuple, dict)): - return bool( - tensorflow_nested_argwhere_bknd( - dtype_in, - lambda x: isinstance(x, (float, np.floating)) - or tensorflow_is_array_bknd(x) - and "float" in tensorflow_dtype(x), - ) - ) - return "float" in tensorflow_as_ivy_dtype_bknd(dtype_in) - - -def tensorflow_is_uint_dtype_bknd( - dtype_in: Union[str, str, tensorflow.Tensor, tf.Tensor, Number], / -): - if tensorflow_is_array_bknd(dtype_in): - dtype_in = tensorflow_dtype(dtype_in) - elif isinstance(dtype_in, tuple): - dtype_in = 
tensorflow_default_int_dtype_bknd() - elif isinstance(dtype_in, np.ndarray): - return "uint" in dtype_in.dtype.name - elif isinstance(dtype_in, Number): - return isinstance(dtype_in, np.unsignedinteger) - elif isinstance(dtype_in, (list, tuple, dict)): - return tensorflow_nested_argwhere_bknd( - dtype_in, - lambda x: isinstance(x, np.unsignedinteger) - or tensorflow_is_array_bknd(x) - and "uint" in tensorflow_dtype(x), - ) - return "uint" in tensorflow_as_ivy_dtype_bknd(dtype_in) - - -def tensorflow_default_uint_dtype_bknd( - *, - input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, - uint_dtype: Optional[Union[str, tf.DType]] = None, - as_native: bool = False, -): - global default_uint_dtype_stack - if tensorflow_exists_bknd(uint_dtype): - if as_native is True: - return tensorflow_as_native_dtype(uint_dtype) - return str(tensorflow_as_ivy_dtype(uint_dtype)) - as_native = tensorflow_default_bknd(as_native, False) - if tensorflow_exists_bknd(input): - if tensorflow_is_array_bknd(input): - ret = tensorflow_dtype(input) - elif isinstance(input, np.ndarray): - ret = input.dtype - elif isinstance(input, (list, tuple, dict)): - - def is_native(x): - return tensorflow_is_native_array(x) - - if tensorflow_nested_argwhere_bknd( - input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if is_native(x) - else x > 9223372036854775807 and x != math.inf - ), - stop_after_n_found=1, - ): - ret = tf.uint64 - elif default_uint_dtype_stack: - ret = default_uint_dtype_stack[-1] - else: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_uint_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "uint32" - elif isinstance(input, Number): - if input > 4294967295 and input != math.inf and backend != "torch": - ret = tf.uint64 - elif default_uint_dtype_stack: - ret = default_uint_dtype_stack[-1] - else: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_uint_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "uint32" - elif default_uint_dtype_stack: - ret = default_uint_dtype_stack[-1] - else: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_uint_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "uint32" - if as_native: - return tensorflow_as_native_dtype(ret) - return str(tensorflow_as_ivy_dtype(ret)) - - -def tensorflow_is_int_dtype_bknd( - dtype_in: Union[str, str, tensorflow.Tensor, tf.Tensor, Number], / -): - if tensorflow_is_array_bknd(dtype_in): - dtype_in = tensorflow_dtype(dtype_in) - elif isinstance(dtype_in, tuple): - dtype_in = tensorflow_default_int_dtype_bknd() - elif isinstance(dtype_in, np.ndarray): - return "int" in dtype_in.dtype.name - elif isinstance(dtype_in, Number): - return isinstance(dtype_in, (int, np.integer)) and not isinstance( - dtype_in, bool - ) - elif isinstance(dtype_in, (list, tuple, dict)): - - def nested_fun(x): - return ( - isinstance(x, (int, np.integer)) - or tensorflow_is_array_bknd(x) - and "int" in tensorflow_dtype(x) - ) and x is not bool - - return bool(tensorflow_nested_argwhere_bknd(dtype_in, nested_fun)) - return "int" in tensorflow_as_ivy_dtype(dtype_in) - - -def tensorflow_infer_default_dtype_bknd( - dtype: Union[str, tf.DType, str], as_native: bool = False -): - if tensorflow_is_complex_dtype_bknd(dtype): - default_dtype = tensorflow_default_complex_dtype_bknd(as_native=as_native) - elif tensorflow_is_float_dtype_bknd(dtype): - default_dtype = tensorflow_default_float_dtype_bknd(as_native=as_native) - elif tensorflow_is_uint_dtype_bknd(dtype): - default_dtype = 
tensorflow_default_uint_dtype_bknd(as_native=as_native) - elif tensorflow_is_int_dtype_bknd(dtype): - default_dtype = tensorflow_default_int_dtype_bknd(as_native=as_native) - elif as_native: - default_dtype = tensorflow_as_native_dtype("bool") - else: - default_dtype = tensorflow_as_ivy_dtype("bool") - return default_dtype - - -def tensorflow_dtype_bits(dtype_in: Union[tensorflow.DType, str, np.dtype], /): - dtype_str = tensorflow_as_ivy_dtype(dtype_in) - if "bool" in dtype_str: - return 1 - return int( - dtype_str.replace("tf.", "") - .replace("uint", "") - .replace("int", "") - .replace("bfloat", "") - .replace("float", "") - .replace("complex", "") - ) - - -def tensorflow__infer_dtype(dtype: tensorflow.DType): - default_dtype = tensorflow_infer_default_dtype_bknd(dtype) - if tensorflow_dtype_bits(dtype) < tensorflow_dtype_bits(default_dtype): - return default_dtype - return dtype - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_prod( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - axis: Optional[Union[int, Sequence[int]]] = None, - dtype: Optional[tensorflow.DType] = None, - keepdims: bool = False, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - dtype = tensorflow_as_native_dtype(dtype) - if dtype is None: - dtype = tensorflow__infer_dtype(x.dtype) - axis = tuple(axis) if isinstance(axis, list) else axis - return tensorflow.experimental.numpy.prod( - x, axis=axis, dtype=dtype, keepdims=keepdims - ) - - -def tensorflow__numel_bknd(shape): - shape = tuple(shape) - return tensorflow_to_scalar_bknd_(tensorflow_prod(shape)) if shape != () else 1 - - -def tensorflow_check_one_way_broadcastable(x1, x2): - if len(x1) > len(x2): - return False - for a, b in zip(x1[::-1], x2[::-1]): - if a in (1, b): - pass - else: - return False - return True - - -def tensorflow_check_shapes_broadcastable(var, data): - if not tensorflow_check_one_way_broadcastable(var, data): - raise Exception(f"Could not broadcast shape {data} to shape {var}.") - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_broadcast_to( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - shape: Union[tf.TensorShape, Sequence[int]], - *, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - tensorflow_check_shapes_broadcastable(x.shape, shape) - if tensorflow.rank(x) > len(shape): - return tensorflow.broadcast_to(tensorflow.reshape(x, -1), shape) - return tensorflow.broadcast_to(x, shape) - - -def tensorflow__broadcast_to_bknd(input, target_shape): - if tensorflow__numel_bknd(tuple(input.shape)) == tensorflow__numel_bknd( - tuple(target_shape) - ): - return tensorflow_reshape(input, target_shape) - else: - input = input if len(input.shape) else tensorflow_expand_dims(input, axis=0) - return tensorflow_broadcast_to(input, target_shape) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_any( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - axis: Optional[Union[int, Sequence[int]]] = None, - keepdims: bool = False, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - if axis is None: - num_dims = len(x.shape) - axis = tuple(range(num_dims)) - elif isinstance(axis, list): - axis = tuple(axis) - try: - return tensorflow.reduce_any( - tensorflow.cast(x, tensorflow.bool), axis=axis, keepdims=keepdims - ) - except tensorflow.errors.InvalidArgumentError as e: - raise Exception(e) from e - - -def tensorflow__broadcast_inputs(x1, x2): - x1_, x2_ = x1, x2 - iterables = list, tuple, 
-    if not isinstance(x1_, iterables):
-        x1_, x2_ = x2, x1
-    if not isinstance(x1_, iterables):
-        return [x1], [x2]
-    if not isinstance(x2_, iterables):
-        x1 = [x1] * len(x2)
-    return x1, x2
-
-
-def tensorflow_check_equal(x1, x2, inverse=False, message="", as_array=True):
-    def eq_fn(x1, x2):
-        return x1 == x2 if inverse else x1 != x2
-
-    def comp_fn(x1, x2):
-        return tensorflow_any(eq_fn(x1, x2))
-
-    if not as_array:
-
-        def iter_comp_fn(x1_, x2_):
-            return any(eq_fn(x1, x2) for x1, x2 in zip(x1_, x2_))
-
-        def comp_fn(x1, x2):
-            return iter_comp_fn(*tensorflow__broadcast_inputs(x1, x2))
-
-    eq = comp_fn(x1, x2)
-    if inverse and eq:
-        raise Exception(f"{x1} must not be equal to {x2}" if message == "" else message)
-    elif not inverse and eq:
-        raise Exception(f"{x1} must be equal to {x2}" if message == "" else message)
-
-
-def tensorflow_multiply(
-    x1: Union[float, tensorflow.Tensor, tensorflow.Variable],
-    x2: Union[float, tensorflow.Tensor, tensorflow.Variable],
-    /,
-    *,
-    out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None,
-):
-    orig_x1 = x1
-    orig_x2 = x2
-    try:
-        dtype = (
-            x1.dtype
-            if hasattr(x1, "dtype")
-            else x2.dtype if hasattr(x2, "dtype") else tensorflow_default_dtype_bknd()
-        )
-        if not tensorflow_is_array_bknd(x1):
-            x1 = tensorflow_asarray(x1, dtype=dtype)
-        if not tensorflow_is_array_bknd(x2):
-            x2 = tensorflow_asarray(x2, dtype=dtype)
-    except Exception:
-        x1 = orig_x1
-        x2 = orig_x2
-    return tensorflow.math.multiply(x1, x2)
-
-
-def tensorflow_check_gather_nd_input_valid(params, indices, batch_dims):
-    if batch_dims >= len(params.shape):
-        raise Exception(
-            f"batch_dims = {batch_dims} must be less than rank(`params`) = {len(params.shape)}."
-        )
-    if batch_dims >= len(indices.shape):
-        raise Exception(
-            f"batch_dims = {batch_dims} must be less than rank(`indices`) = {len(indices.shape)}."
-        )
-    if tensorflow_get_item(
-        params.shape, slice(0, batch_dims, None)
-    ) != tensorflow_get_item(indices.shape, slice(0, batch_dims, None)):
-        raise Exception(
-            f"batch dimensions must match in `params` and `indices`; saw {tensorflow_get_item(params.shape, slice(0, batch_dims, None))} vs. {tensorflow_get_item(indices.shape, slice(0, batch_dims, None))}"
-        )
-    if indices.shape[-1] > len(
-        tensorflow_get_item(params.shape, slice(batch_dims, None, None))
-    ):
-        raise Exception(
-            f"index innermost dimension length must be <= rank(`params[batch_dims:]`); saw: {indices.shape[-1]} vs. {len(tensorflow_get_item(params.shape, slice(batch_dims, None, None)))} ."
- ) - - -def tensorflow_gather_nd_helper(params, indices): - indices_shape = tensorflow.shape(indices) - params_shape = tensorflow.shape(params) - num_index_dims = indices_shape[-1] - result_dim_sizes_list = [ - tensorflow.math.reduce_prod(params_shape[i + 1 :]) - for i in range(len(params_shape) - 1) - ] + [1] - result_dim_sizes = tensorflow.convert_to_tensor( - result_dim_sizes_list, dtype=indices.dtype - ) - implicit_indices_factor = result_dim_sizes[num_index_dims - 1] - flat_params = tensorflow.reshape(params, (-1,)) - new_shape = [1] * (len(indices_shape) - 1) + [num_index_dims] - indices_scales = tensorflow.reshape(result_dim_sizes[0:num_index_dims], new_shape) - indices_for_flat_tiled = tensorflow.reshape( - tensorflow.reduce_sum(indices * indices_scales, -1, keepdims=True), (-1, 1) - ) - indices_for_flat_tiled = tensorflow.repeat( - indices_for_flat_tiled, implicit_indices_factor, axis=1 - ) - implicit_indices = tensorflow.repeat( - tensorflow.expand_dims(tensorflow.range(implicit_indices_factor), 0), - indices_for_flat_tiled.shape[0], - axis=0, - ) - indices_for_flat = indices_for_flat_tiled + implicit_indices - flat_indices_for_flat = tensorflow.reshape(indices_for_flat, (-1,)) - flat_gather = tensorflow.gather(flat_params, flat_indices_for_flat) - res = tensorflow.reshape( - flat_gather, - tensorflow.concat([indices_shape[:-1], params_shape[num_index_dims:]], 0), - ) - return res - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_gather_nd( - params: Union[tensorflow.Tensor, tensorflow.Variable], - indices: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - batch_dims: int = 0, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - tensorflow_check_gather_nd_input_valid(params, indices, batch_dims) - try: - return tensorflow.gather_nd(params, indices, batch_dims=batch_dims) - except Exception: - batch_dims %= len(params.shape) - result = [] - if batch_dims == 0: - result = tensorflow_gather_nd_helper(params, indices) - else: - for b in range(batch_dims): - if b == 0: - zip_list = list(zip(params, indices)) - else: - zip_list = [ - (p, i) - for z in [zip(p1, i1) for p1, i1 in zip_list] - for p, i in z - ] - for z in zip_list: - p, i = z[0], z[1] - r = tensorflow_gather_nd_helper(p, i) - result.append(r) - result = tensorflow.stack(result) - result = tensorflow.reshape( - result, - tensorflow.concat([params.shape[0:batch_dims], result.shape[1:]], 0), - ) - return result - - -def tensorflow__is_variable_bknd(x, exclusive=False, to_ignore=None): - x = x - return tensorflow_nested_map_bknd( - lambda x: tensorflow_is_variable(x, exclusive=exclusive), - x, - include_derived=True, - shallow=False, - to_ignore=to_ignore, - ) - - -def tensorflow_inplace_update( - x: Union[tensorflow.Tensor, tensorflow.Tensor], - val: Union[tensorflow.Tensor, tensorflow.Tensor], - /, - *, - ensure_in_backend: bool = False, - keep_input_dtype: bool = False, -): - if tensorflow_is_array_bknd(x) and tensorflow_is_array_bknd(val): - if keep_input_dtype: - val = tensorflow_astype(val, x.dtype) - (x_native, val_native), _ = (x, val), "_" - if tensorflow__is_variable_bknd(x_native): - x_native.assign(val_native) - if tensorflow_is_ivy_array_bknd(x): - x = x_native - else: - x = tensorflow.convert_to_tensor(x_native) - else: - x = x_native - return x - else: - return val - - -def tensorflow_scatter_nd( - indices: Union[tensorflow.Tensor, tensorflow.Variable], - updates: Union[tensorflow.Tensor, tensorflow.Variable], - /, - shape: Optional[Union[tf.TensorShape, 
Sequence[int]]] = None, - *, - reduction: str = "sum", - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - updates_dtype = updates.dtype - if tensorflow_exists_bknd(out): - dtype = tensorflow_promote_types_bknd(out.dtype, updates_dtype) - updates = tensorflow.cast( - updates, - ( - tensorflow_as_native_dtype(dtype) - if tensorflow_exists_bknd(out) - else updates_dtype - ), - ) - expected_shape = ( - list(tensorflow.shape(indices)[:-1]) - + list(out.shape[tensorflow.shape(indices)[-1] :]) - if tensorflow_exists_bknd(out) - else list(tensorflow.shape(indices)[:-1]) - + list(shape[tensorflow.shape(indices)[-1] :]) - ) - updates = tensorflow__broadcast_to_bknd(updates, expected_shape) - if len(updates.shape) == 0: - indices = tensorflow.expand_dims(indices, 0) - updates = tensorflow.expand_dims(updates, 0) - target = out - target_given = tensorflow_exists_bknd(target) - if tensorflow_exists_bknd(shape) and target_given: - tensorflow_check_equal(tuple(target.shape), tuple(shape), as_array=False) - if not target_given: - shape = list(shape) if tensorflow_exists_bknd(shape) else list(out.shape) - target = tensorflow.zeros(shape, dtype=updates.dtype) - if reduction == "sum": - res = tensorflow.tensor_scatter_nd_add(target, indices, updates) - elif reduction == "min": - res = tensorflow.tensor_scatter_nd_min(target, indices, updates) - elif reduction == "max": - res = tensorflow.tensor_scatter_nd_max(target, indices, updates) - elif reduction == "mul": - updates = tensorflow_multiply(tensorflow_gather_nd(target, indices), updates) - res = tensorflow.tensor_scatter_nd_update(target, indices, updates) - elif reduction == "replace": - res = tensorflow.tensor_scatter_nd_update(target, indices, updates) - else: - raise Exception( - f'reduction is {reduction}, but it must be one of "sum", "min", "max", "mul" or "replace"' - ) - if tensorflow_exists_bknd(out): - return tensorflow_inplace_update(out, res) - return res - - -def tensorflow_handle_set_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, val, **kwargs): - try: - inp.__setitem__(query, val) - res = inp - except IndexError: - raise - except Exception: - res = fn(inp, query, val, **kwargs) - return res - - return wrapper - - -@tensorflow_handle_set_item -def tensorflow_set_item_bknd( - x: Union[tensorflow.Tensor, tf.Tensor], - query: Union[tensorflow.Tensor, tf.Tensor, Tuple], - val: Union[tensorflow.Tensor, tf.Tensor], - /, - *, - copy: Optional[bool] = False, -): - if isinstance(query, (list, tuple)) and any( - [(q is Ellipsis or isinstance(q, slice) and q.stop is None) for q in query] - ): - x_stop_gradient = tensorflow_stop_gradient(x, preserve_type=False) - np_array = x_stop_gradient.numpy() - val_stop_gradient = tensorflow_stop_gradient(val, preserve_type=False) - np_array = tensorflow_set_item_bknd( - np_array, query, np.asarray(val_stop_gradient) - ) - return tensorflow_asarray(np_array) - if copy: - x = tensorflow_copy_array(x) - if not tensorflow_is_array_bknd(val): - val = tensorflow_asarray(val) - if 0 in x.shape or 0 in val.shape: - return x - if tensorflow_is_array_bknd(query) and tensorflow_is_bool_dtype_bknd(query): - if not len(query.shape): - query = tensorflow_tile(query, (x.shape[0],)) - indices = tensorflow_nonzero(query, as_tuple=False) - else: - indices, target_shape, _ = tensorflow__parse_query_bknd( - query, tensorflow_shape(x, as_array=True), scatter=True - ) - if indices is None: - return x - val = tensorflow_astype_bknd_(val, x.dtype) - ret = tensorflow_scatter_nd(indices, val, 
reduction="replace", out=x) - return ret - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_real( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - return tensorflow.math.real(x) - - -def tensorflow_real_bknd_(self): - return tensorflow_real(self) - - -@tensorflow_handle_array_like_without_promotion -def tensorflow_imag( - val: Union[tensorflow.Tensor, tensorflow.Variable], - /, - *, - out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None, -): - return tensorflow.math.imag(val, name=None) - - -def tensorflow_imag_bknd_(self): - return tensorflow_imag(self) - - -def tensorflow__check_complex128_bknd(input): - if tensorflow_is_array_bknd(input): - return tensorflow_dtype(input) == "complex128" - elif isinstance(input, np.ndarray): - return str(input.dtype) == "complex128" - if hasattr(input, "real") and hasattr(input, "imag"): - return tensorflow__check_float64_bknd( - tensorflow_real_bknd_(input) - ) and tensorflow__check_float64_bknd(tensorflow_imag_bknd_(input)) - return False - - -def tensorflow_default_complex_dtype_bknd( - *, - input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, - complex_dtype: Optional[Union[str, tf.DType]] = None, - as_native: bool = False, -): - global default_complex_dtype_stack - if tensorflow_exists_bknd(complex_dtype): - if as_native is True: - return tensorflow_as_native_dtype(complex_dtype) - return str(tensorflow_as_ivy_dtype(complex_dtype)) - as_native = tensorflow_default_bknd(as_native, False) - if tensorflow_exists_bknd(input): - if tensorflow_is_array_bknd(input): - ret = tensorflow_dtype(input) - elif isinstance(input, np.ndarray): - ret = str(input.dtype) - elif isinstance(input, (list, tuple, dict)): - if tensorflow_nested_argwhere_bknd( - input, - lambda x: tensorflow__check_complex128_bknd(x), - stop_after_n_found=1, - ): - ret = tf.complex128 - elif not default_complex_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_complex_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "complex64" - else: - ret = default_complex_dtype_stack[-1] - elif isinstance(input, Number): - if tensorflow__check_complex128_bknd(input): - ret = tf.complex128 - elif not default_complex_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_complex_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "complex64" - else: - ret = default_complex_dtype_stack[-1] - elif not default_complex_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_complex_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "complex64" - else: - ret = default_complex_dtype_stack[-1] - if as_native: - return tensorflow_as_native_dtype(ret) - return str(tensorflow_as_ivy_dtype(ret)) - - -def tensorflow_default_dtype_bknd( - *, - dtype: Optional[Union[str, str]] = None, - item: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, - as_native: bool = False, -): - if tensorflow_exists_bknd(dtype): - if as_native is True: - return tensorflow_as_native_dtype(dtype) - return tensorflow_as_ivy_dtype(dtype) - as_native = tensorflow_default_bknd(as_native, False) - if tensorflow_exists_bknd(item): - if hasattr(item, "override_dtype_check"): - return item.override_dtype_check() - elif isinstance(item, (list, tuple, dict)) and len(item) == 0: - pass - elif tensorflow_is_complex_dtype_bknd(item): - return tensorflow_default_complex_dtype_bknd( - input=item, as_native=as_native - ) - elif 
tensorflow_is_float_dtype_bknd(item): - return tensorflow_default_float_dtype_bknd(input=item, as_native=as_native) - elif tensorflow_is_uint_dtype_bknd(item): - return tensorflow_default_int_dtype_bknd(input=item, as_native=as_native) - elif tensorflow_is_int_dtype_bknd(item): - return tensorflow_default_int_dtype_bknd(input=item, as_native=as_native) - elif as_native: - return tensorflow_as_native_dtype("bool") - else: - return "bool" - global default_dtype_stack - if not default_dtype_stack: - global default_float_dtype_stack - if default_float_dtype_stack: - ret = default_float_dtype_stack[-1] - else: - ret = "float32" - else: - ret = default_dtype_stack[-1] - if as_native: - return tensorflow_as_native_dtype(ret) - return tensorflow_as_ivy_dtype(ret) - - -def tensorflow_default_float_dtype_bknd( - *, - input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, - float_dtype: Optional[Union[str, tf.DType]] = None, - as_native: bool = False, -): - global default_float_dtype_stack - if tensorflow_exists_bknd(float_dtype): - if as_native is True: - return tensorflow_as_native_dtype(float_dtype) - return str(tensorflow_as_ivy_dtype(float_dtype)) - as_native = tensorflow_default_bknd(as_native, False) - if tensorflow_exists_bknd(input): - if tensorflow_is_array_bknd(input): - ret = tensorflow_dtype(input) - elif isinstance(input, np.ndarray): - ret = str(input.dtype) - elif isinstance(input, (list, tuple, dict)): - if tensorflow_nested_argwhere_bknd( - input, lambda x: tensorflow__check_float64_bknd(x), stop_after_n_found=1 - ): - ret = tf.float64 - elif not default_float_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_float_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "float32" - else: - ret = default_float_dtype_stack[-1] - elif isinstance(input, Number): - if tensorflow__check_float64_bknd(input): - ret = tf.float64 - elif not default_float_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_float_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "float32" - else: - ret = default_float_dtype_stack[-1] - elif not default_float_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_float_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "float32" - else: - ret = default_float_dtype_stack[-1] - if as_native: - return tensorflow_as_native_dtype(ret) - return str(tensorflow_as_ivy_dtype(ret)) - - -def tensorflow_as_ivy_dtype( - dtype_in: Union[tensorflow.DType, str, int, float, complex, bool, np.dtype], / -): - if dtype_in is int: - return tensorflow_default_int_dtype_bknd() - if dtype_in is float: - return tensorflow_default_float_dtype_bknd() - if dtype_in is complex: - return tensorflow_default_complex_dtype_bknd() - if dtype_in is bool: - return str("bool") - if isinstance(dtype_in, np.dtype): - dtype_in = dtype_in.name - if isinstance(dtype_in, str): - if dtype_in in native_dtype_dict: - dtype_str = dtype_in - else: - raise Exception( - f"Cannot convert to ivy dtype. {dtype_in} is not supported by TensorFlow backend." 
- ) - else: - dtype_str = ivy_dtype_dict[dtype_in] - if "uint" in dtype_str: - return str(dtype_str) - elif "int" in dtype_str: - return str(dtype_str) - elif "float" in dtype_str: - return str(dtype_str) - elif "complex" in dtype_str: - return str(dtype_str) - elif "bool" in dtype_str: - return str("bool") - else: - raise Exception(f"Cannot recognize {dtype_str} as a valid Dtype.") - - -def tensorflow_default_int_dtype_bknd( - *, - input: Optional[Union[tensorflow.Tensor, tf.Tensor]] = None, - int_dtype: Optional[Union[str, tf.DType]] = None, - as_native: bool = False, -): - global default_int_dtype_stack - if tensorflow_exists_bknd(int_dtype): - if as_native is True: - return tensorflow_as_native_dtype(int_dtype) - return str(tensorflow_as_ivy_dtype(int_dtype)) - as_native = tensorflow_default_bknd(as_native, False) - if tensorflow_exists_bknd(input): - if tensorflow_is_array_bknd(input): - ret = tensorflow_dtype(input) - elif isinstance(input, tuple): - ret = tensorflow_default_int_dtype_bknd() - elif isinstance(input, np.ndarray): - ret = str(input.dtype) - elif isinstance(input, (list, tuple, dict)): - if tensorflow_nested_argwhere_bknd( - input, - lambda x: ( - tensorflow_dtype(x) == "uint64" - if tensorflow_is_array_bknd(x) - else x > 9223372036854775807 and x != math.inf - ), - stop_after_n_found=1, - ): - ret = tf.uint64 - elif tensorflow_nested_argwhere_bknd( - input, - lambda x: ( - tensorflow_dtype(x) == "int64" - if tensorflow_is_array_bknd(x) - else x > 2147483647 and x != math.inf - ), - stop_after_n_found=1, - ): - ret = tf.int64 - elif not default_int_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_int_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "int32" - else: - ret = default_int_dtype_stack[-1] - elif isinstance(input, Number): - if input > 9223372036854775807 and input != math.inf and backend != "torch": - ret = tf.uint64 - elif input > 2147483647 and input != math.inf: - ret = tf.int64 - elif not default_int_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_int_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "int32" - else: - ret = default_int_dtype_stack[-1] - elif not default_int_dtype_stack: - def_dtype = tensorflow_default_dtype_bknd() - if tensorflow_is_int_dtype_bknd(def_dtype): - ret = def_dtype - else: - ret = "int32" - else: - ret = default_int_dtype_stack[-1] - if as_native: - return tensorflow_as_native_dtype(ret) - return str(tensorflow_as_ivy_dtype(ret)) - - -def tensorflow_as_native_dtype( - dtype_in: Union[tensorflow.DType, str, bool, int, float, np.dtype], -): - if dtype_in is int: - return tensorflow_default_int_dtype_bknd(as_native=True) - if dtype_in is float: - return tensorflow_default_float_dtype_bknd(as_native=True) - if dtype_in is complex: - return tensorflow_default_complex_dtype_bknd(as_native=True) - if dtype_in is bool: - return tensorflow.bool - if isinstance(dtype_in, np.dtype): - dtype_in = dtype_in.name - if not isinstance(dtype_in, str): - return dtype_in - if dtype_in in native_dtype_dict: - return native_dtype_dict[str(dtype_in)] - else: - raise Exception( - f"Cannot convert to TensorFlow dtype. {dtype_in} is not supported by TensorFlow." 
- ) - - -def tensorflow_dtype( - x: Union[tensorflow.Tensor, tensorflow.Variable, np.ndarray], - *, - as_native: bool = False, -): - if as_native: - return tensorflow_as_native_dtype(x.dtype) - return tensorflow_as_ivy_dtype(x.dtype) - - -def tensorflow_is_bool_dtype_bknd( - dtype_in: Union[str, str, tensorflow.Tensor, tf.Tensor, Number], / -): - if tensorflow_is_array_bknd(dtype_in): - dtype_in = tensorflow_dtype(dtype_in) - elif isinstance(dtype_in, np.ndarray): - return "bool" in dtype_in.dtype.name - elif isinstance(dtype_in, Number): - return isinstance(dtype_in, (bool, np.bool_)) and not isinstance(dtype_in, bool) - elif isinstance(dtype_in, (list, tuple, dict)): - return bool( - tensorflow_nested_argwhere_bknd( - dtype_in, lambda x: isinstance(x, (bool, np.bool_)) and x is not int - ) - ) - return "bool" in tensorflow_as_ivy_dtype(dtype_in) - - -def tensorflow_handle_get_item(fn): - @functools.wraps(fn) - def wrapper(inp, query, **kwargs): - try: - res = inp.__getitem__(query) - except IndexError: - raise - except Exception: - res = fn(inp, query, **kwargs) - return res - - return wrapper - - -@tensorflow_handle_get_item -def tensorflow_get_item( - x: Union[tensorflow.Tensor, tensorflow.Variable], - /, - query: Union[tensorflow.Tensor, tensorflow.Variable, Tuple], - *, - copy: Optional[bool] = None, -): - if ( - tensorflow_is_array_bknd(query) - and tensorflow_is_bool_dtype_bknd(query) - and not len(query.shape) - ): - return tensorflow.expand_dims(x, 0) - return x[query] - - -def tensorflow_index_nest_bknd( - nest: Union[List, Tuple, Dict, tensorflow.Tensor, tf.Tensor, dict], - index: Union[List[int], Tuple[int], Iterable[int]], - /, -): - ret = nest - for i in index: - ret = tensorflow_get_item(ret, i) - return ret - - -def tensorflow__get_first_array(*args, **kwargs): - def array_fn(x): - return ( - tensorflow_is_array_bknd(x) - if not hasattr(x, "_ivy_array") - else tensorflow_is_array_bknd(x.ivy_array) - ) - - array_fn = array_fn if "array_fn" not in kwargs else kwargs["array_fn"] - arr = None - if args: - arr_idxs = tensorflow_nested_argwhere_bknd(args, array_fn, stop_after_n_found=1) - if arr_idxs: - arr = tensorflow_index_nest_bknd(args, arr_idxs[0]) - else: - arr_idxs = tensorflow_nested_argwhere_bknd( - kwargs, array_fn, stop_after_n_found=1 - ) - if arr_idxs: - arr = tensorflow_index_nest_bknd(kwargs, arr_idxs[0]) - elif kwargs: - arr_idxs = tensorflow_nested_argwhere_bknd( - kwargs, array_fn, stop_after_n_found=1 - ) - if arr_idxs: - arr = tensorflow_index_nest_bknd(kwargs, arr_idxs[0]) - return arr diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_zeros_output/run_0/tensorflow__stateful.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_zeros_output/run_0/tensorflow__stateful.py deleted file mode 100644 index dbad1e919ab1..000000000000 --- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_zeros_output/run_0/tensorflow__stateful.py +++ /dev/null @@ -1,1799 +0,0 @@ -# global -from __future__ import annotations -import re -import os -import tensorflow as tf -import functools -from tensorflow.python.util import nest -from typing import NamedTuple, Callable, Any, Tuple, List, Dict, Type, Union -import inspect -from collections import OrderedDict -from packaging.version import parse -import keras - - -def get_assignment_dict(): - # Traverse the call stack - lhs = None - for frame_info in inspect.stack(): - # Check if the code context is an assignment statement - if frame_info.code_context and "=" in frame_info.code_context[0]: - # Split the assignment and 
retrieve the LHS - lhs = frame_info.code_context[0].split("=")[0].strip() - if "self" not in lhs: - continue - break - - if not lhs: - return None, "" - - # Replace indexing with attribute access - lhs = re.sub(r"\[(\d+)\]", r".\1", lhs) - - # Split the LHS based on "." and get individual components - components = lhs.split(".") - - # Initialize the dictionary - assignment_dict = {} - - # Retrieve the live objects associated with each component - for i in range(len(components)): - # Construct the key - key = ".".join(components[: i + 1]) - - # Retrieve the value - if i == 0: - value = frame_info.frame.f_locals.get(components[i]) - else: - value = getattr(assignment_dict[".".join(components[:i])], components[i]) - - # Add the key-value pair to the dictionary - assignment_dict[key] = value - - return assignment_dict, lhs - - -def store_frame_info(fn): - @functools.wraps(fn) - def frame_info_wrapper(self, *args, **kwargs): - if self._previous_frame_info is None: - # store the info about the calling frame. - stack = inspect.stack() - self._previous_frame_info = stack[1] - res = fn(self, *args, **kwargs) - # reset the frame-info - self._previous_frame_info = None - return res - - return frame_info_wrapper - - -# A NodeDef holds two callables: -# - flatten_fn should take the collection and return a flat list of values. -# It can also return some context that is used in reconstructing the -# collection. -# - unflatten_fn should take a flat list of values and some context -# (returned by flatten_fn). It returns the collection by reconstructing -# it from the list and the context. -Context = Any -PyTree = Any -FlattenFunc = Callable[[PyTree], Tuple[List, Context]] -UnflattenFunc = Callable[[List, Context], PyTree] - - -class NodeDef(NamedTuple): - flatten_fn: FlattenFunc - unflatten_fn: UnflattenFunc - - -SUPPORTED_NODES: Dict[Type[Any], NodeDef] = {} - - -def _register_pytree_node( - typ: Any, flatten_fn: FlattenFunc, unflatten_fn: UnflattenFunc -) -> None: - SUPPORTED_NODES[typ] = NodeDef(flatten_fn, unflatten_fn) - - -def _dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]: - return list(d.values()), list(d.keys()) - - -def _dict_unflatten(values: List[Any], context: Context) -> Dict[Any, Any]: - return {key: value for key, value in zip(context, values)} - - -_register_pytree_node(dict, _dict_flatten, _dict_unflatten) - -if parse(keras.__version__).major > 2: - _register_pytree_node( - keras.src.utils.tracking.TrackedDict, _dict_flatten, _dict_unflatten - ) - - -def _get_node_type(pytree: Any) -> Any: - return type(pytree) - - -# A leaf is defined as anything that is not a Node. -def _is_leaf(pytree: PyTree) -> bool: - return _get_node_type(pytree) not in SUPPORTED_NODES.keys() - - -# A TreeSpec represents the structure of a pytree. 
It holds:
-# "type": the type of root Node of the pytree
-# context: some context that is useful in unflattening the pytree
-# children_specs: specs for each child of the root Node
-# num_leaves: the number of leaves
-class TreeSpec:
-    def __init__(self, type, context, children_specs):
-        self.type: Any = type
-        self.context: Context = context
-        self.children_specs: List["TreeSpec"] = children_specs
-        self.num_leaves: int = sum([spec.num_leaves for spec in self.children_specs])
-
-    def get_keychains(self, prefix="", sep="/"):
-        keychains = []
-        for key, child_spec in zip(self.context, self.children_specs):
-            new_prefix = prefix + key + sep if prefix else key + sep
-            if child_spec.children_specs:  # Non-leaf node
-                keychains.extend(child_spec.get_keychains(new_prefix, sep))
-            else:  # Leaf node
-                keychains.append(new_prefix[: -len(sep)])
-        return keychains
-
-    def __repr__(self, indent: int = 0) -> str:
-        repr_prefix: str = f"TreeSpec({self.type.__name__}, {self.context}, ["
-        children_specs_str: str = ""
-        if len(self.children_specs):
-            indent += len(repr_prefix)
-            children_specs_str += self.children_specs[0].__repr__(indent)
-            children_specs_str += "," if len(self.children_specs) > 1 else ""
-            children_specs_str += ",".join(
-                [
-                    "\n" + " " * indent + child.__repr__(indent)
-                    for child in self.children_specs[1:]
-                ]
-            )
-        repr_suffix: str = f"{children_specs_str}])"
-        return repr_prefix + repr_suffix
-
-
-class LeafSpec(TreeSpec):
-    def __init__(self) -> None:
-        super().__init__(None, None, [])
-        self.num_leaves = 1
-
-    def __repr__(self, indent: int = 0) -> str:
-        return "*"
-
-
-def tree_flatten(pytree: PyTree) -> Tuple[List[Any], TreeSpec]:
-    """Flattens a pytree into a list of values and a TreeSpec that can be used
-    to reconstruct the pytree."""
-    if _is_leaf(pytree):
-        return [pytree], LeafSpec()
-
-    node_type = _get_node_type(pytree)
-    flatten_fn = _dict_flatten
-    child_pytrees, context = flatten_fn(pytree)
-
-    # Recursively flatten the children
-    result: List[Any] = []
-    children_specs: List["TreeSpec"] = []
-    for child in child_pytrees:
-        flat, child_spec = tree_flatten(child)
-        result += flat
-        children_specs.append(child_spec)
-
-    return result, TreeSpec(node_type, context, children_specs)
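-# Editor's sketch (hypothetical, not part of the original file): a minimal
-# round-trip through the pytree utilities above and below, assuming plain
-# nested dicts; `_editor_pytree_roundtrip_demo` exists purely for
-# illustration and is never called by the module itself.
-def _editor_pytree_roundtrip_demo():
-    tree = {"w": {"kernel": 1, "bias": 2}, "cfg": 3}
-    leaves, spec = tree_flatten(tree)  # leaves == [1, 2, 3]
-    # scale every leaf, then rebuild the same nested structure
-    rebuilt = tree_unflatten([x * 10 for x in leaves], spec)
-    return rebuilt  # {"w": {"kernel": 10, "bias": 20}, "cfg": 30}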
-def tree_unflatten(values: List[Any], spec: TreeSpec) -> PyTree:
-    """Given a list of values and a TreeSpec, builds a pytree.
-
-    This is the inverse operation of `tree_flatten`.
-    """
-    if not isinstance(spec, TreeSpec):
-        raise TypeError(
-            f"tree_unflatten(values, spec): Expected `spec` to be instance of "
-            f"TreeSpec but got item of type {type(spec)}."
-        )
-    if len(values) != spec.num_leaves:
-        raise TypeError(
-            f"tree_unflatten(values, spec): `values` has length {len(values)} "
-            f"but the spec refers to a pytree that holds {spec.num_leaves} "
-            f"items ({spec})."
-        )
-    if isinstance(spec, LeafSpec):
-        return values[0]
-
-    unflatten_fn = _dict_unflatten
-
-    # Recursively unflatten the children
-    start = 0
-    end = 0
-    child_pytrees = []
-    for child_spec in spec.children_specs:
-        end += child_spec.num_leaves
-        child_pytrees.append(tree_unflatten(values[start:end], child_spec))
-        start = end
-
-    return unflatten_fn(child_pytrees, spec.context)
-
-
-def serialize_obj(obj):
-    if inspect.isclass(obj) or isinstance(obj, type):
-        return {"cls_module": obj.__module__, "cls_name": obj.__name__}
-    return obj
-
-
-def recursive_serialize(d):
-    if isinstance(d, dict):
-        return {k: recursive_serialize(v) for k, v in d.items()}
-    elif isinstance(d, list):
-        return [recursive_serialize(v) for v in d]
-    elif isinstance(d, tuple):
-        return tuple(recursive_serialize(v) for v in d)
-    else:
-        return serialize_obj(d)
-
-
-def deserialize_obj(serialized):
-    if (
-        isinstance(serialized, dict)
-        and "cls_module" in serialized
-        and "cls_name" in serialized
-    ):
-        module = __import__(serialized["cls_module"], fromlist=[serialized["cls_name"]])
-        cls = getattr(module, serialized["cls_name"])
-        return cls
-    return serialized
-
-
-def recursive_deserialize(d):
-    if isinstance(d, dict) and "cls_module" not in d:
-        return {k: recursive_deserialize(v) for k, v in d.items()}
-    elif isinstance(d, list):
-        return [recursive_deserialize(v) for v in d]
-    elif isinstance(d, tuple):
-        return tuple(recursive_deserialize(v) for v in d)
-    else:
-        return deserialize_obj(d)
-
-
-class ModelHelpers:
-    @staticmethod
-    @tf.autograph.experimental.do_not_convert
-    def _get_first_array(*args, **kwargs):
-        arr = None
-        flattened_args = tf.nest.flatten((args, kwargs))
-        arr_candidates = tf.nest.map_structure(
-            lambda x: x if isinstance(x, (tf.Tensor, tf.Variable)) else False,
-            flattened_args,
-        )
-        for arr_candidate in arr_candidates:
-            if arr_candidate is not False:
-                arr = arr_candidate
-                break
-        return arr
-
-    @staticmethod
-    @tf.autograph.experimental.do_not_convert
-    def _get_input_shapes(*args):
-        input_shapes = []
-        for x in args:
-            if isinstance(x, (tf.Tensor, tf.Variable)):
-                input_shapes.append(x.shape)
-            else:
-                try:
-                    x = tf.convert_to_tensor(x)
-                    input_shapes.append(x.shape)
-                except Exception:
-                    input_shapes.append(None)
-        return input_shapes
-
-    @staticmethod
-    @tf.autograph.experimental.do_not_convert
-    def _extract_v(v, keychain_mappings: dict, orig_key_chain, /):
-        if ModelHelpers._dict_has_key_chain(v, orig_key_chain):
-            ret_cont = ModelHelpers._dict_at_key_chain(v, orig_key_chain)
-        else:
-            ret_cont = dict()
-        for old_kc, new_kc in keychain_mappings.items():
-            if orig_key_chain in old_kc:
-                # Check if `v` contains `new_kc` before replacing in `ret_cont`
-                if ModelHelpers._dict_has_key_chain(v, new_kc):
-                    ret_cont = ModelHelpers._dict_set_at_key_chain(
-                        ret_cont,
-                        "/".join(old_kc.split("/")[1:]),
-                        ModelHelpers._dict_at_key_chain(v, new_kc),
-                    )
-                else:
-                    continue
-        return ret_cont
-
-    @staticmethod
-    @tf.autograph.experimental.do_not_convert
-    def _remove_duplicate_variables(vs, created, /):
-        created_ids = tf.nest.map_structure(lambda x: id(x), created)
-        vs_ids = tf.nest.map_structure(lambda x: id(x), vs)
-        ids = {}
-        duplicate_keychains = []
-        keychain_mappings = {}
-
-        def unique_callback(x, kc):
-            ids[x] = kc
-            return x
-
-        def found_dup_callback(x, kc):
-            if ids[x] == kc:
-                return x
-            duplicate_keychains.append(kc)
-            keychain_mappings[kc] = ids[x]
-            return x
-
-        created_ids = nest.map_structure_with_paths(
-            lambda kc, x: unique_callback(x, kc), created_ids
-        )
-        vs_ids = 
nest.map_structure_with_paths( - lambda kc, x: ( - unique_callback(x, kc) if x not in ids else found_dup_callback(x, kc) - ), - vs_ids, - ) - for dup_kc in duplicate_keychains: - vs = ModelHelpers._dict_prune_key_chain(vs, dup_kc) - return vs, keychain_mappings - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _dict_set_at_key_chain(in_dict, key_chain, val, inplace=False): - keys = re.split("[/.]", key_chain) - if inplace: - cont = in_dict - else: - cont = in_dict - sub_cont = cont - for key in keys[:-1]: - if key not in sub_cont: - sub_cont[key] = dict() - sub_cont = sub_cont[key] - sub_cont[keys[-1]] = val - return cont - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _dict_at_key_chain(dict, key_chain, ignore_key_errors=False): - keys = re.split("[/.]", key_chain) - ret = dict - for key in keys: - try: - ret = ret[key] - except KeyError as e: - if ignore_key_errors: - return - raise Exception(repr(e)) - return ret - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _dict_has_key_chain(dict, key_chain): - keys = re.split("[/.]", key_chain) - ret = dict - for key in keys: - try: - ret = ret[key] - except KeyError: - return False - return True - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _dict_prune_key_chain(in_dict, key_chain): - keys_in_chain = re.split("[/.]", key_chain) - out_dict = {} - for key, value in in_dict.items(): - if isinstance(value, dict): - if key == keys_in_chain[0]: - if len(keys_in_chain) == 1: - new_val = [] - else: - new_val = ModelHelpers._dict_prune_key_chain( - value, - "/".join(keys_in_chain[1:]), - ) - if len(new_val) > 0: - out_dict[key] = new_val - else: - if len(value) > 0: - out_dict[key] = value - else: - if len(keys_in_chain) != 1 or key != keys_in_chain[0]: - out_dict[key] = value - return out_dict - - @staticmethod - @tf.autograph.experimental.do_not_convert - def _addindent(s_, numSpaces): - s = s_.split("\n") - # don't do anything for single-line stuff - if len(s) == 1: - return s_ - first = s.pop(0) - s = [(numSpaces * " ") + line for line in s] - s = "\n".join(s) - s = first + "\n" + s - return s - - -class Layer(tf.keras.layers.Layer, ModelHelpers): - _build_mode = None - _with_partial_v = None - _store_vars = True - _built = False - _v = None - _buffers = None - _module_dict = None - _args = None - _kwargs = None - _module_graph = None - _target = None - _lazy_traced = False - _training = None - _dynamic_backend = None - _device = None - _dtype = None - _previous_frame_info = None - - def __init__( - self, - /, - *args, - v=None, - buffers=None, - build_mode="on_init", - store_vars=True, - with_partial_v=False, - dynamic_backend=None, - training=True, - dtype=None, - device=None, - module_dict=None, - **kwargs, - ): - super(Layer, self).__init__( - trainable=training, - dtype=dtype, - ) - self._build_mode = build_mode - self._with_partial_v = with_partial_v - self._store_vars = store_vars - self._built = False - self._v_from_constructor = v if isinstance(v, dict) or v is None else dict(v) - self._v = v if v is not None else dict() - self._buffers = dict(buffers or {}) - self._module_dict = module_dict if module_dict is not None else dict() - self._args = args - self._kwargs = kwargs - self._module_graph = None - self._target = None - self._lazy_traced = False - self._training = training - self._dynamic_backend = dynamic_backend - self._device = device or "cpu" - self._dtype = dtype or tf.float32 - if build_mode != "on_init": - return - self.build(*args, 
dynamic_backend=dynamic_backend, **kwargs) - - @tf.autograph.experimental.do_not_convert - def _find_variables( - self, - /, - *, - obj=None, - without_initialisation=False, - _visited=None, - trainable=True, - ): - _visited = _visited or {} - vs = dict() - if id(obj) in _visited: - return vs - _visited[id(obj)] = True - if isinstance(obj, Layer) and obj is not self: - fn = "_build_and_return_v" if trainable else "_build_and_return_buffers" - if not obj._built and without_initialisation: - return lambda: getattr(obj, fn)( - *obj._args, dynamic_backend=self._dynamic_backend, **obj._kwargs - ) - - return getattr(obj, fn)( - *obj._args, dynamic_backend=obj._dynamic_backend, **obj._kwargs - ) - elif isinstance(obj, tf.keras.layers.Layer) and obj is not self: - return obj.v if trainable else obj.buffers - - elif isinstance(obj, (list, tuple)): - for i, v in enumerate(obj): - ret = self._find_variables( - obj=v, - without_initialisation=without_initialisation, - _visited=_visited, - trainable=trainable, - ) - if ret: - vs[f"v{str(i)}"] = ret - return vs - elif isinstance(obj, dict): - for k, v in obj.items(): - ret = self._find_variables( - obj=v, - without_initialisation=without_initialisation, - _visited=_visited, - trainable=trainable, - ) - if ret: - vs[k[1:] if k[0] == "_" else k] = ret - return vs - elif not hasattr(obj, "__dict__"): - return vs - for k, v in obj.__dict__.items(): - if ( - v is not None - and k[0:2] != "__" - and not k.startswith( - ( - "_module_dict", - "_self_", - "_args", - "_kwargs", - ) - ) - ): - ret = self._find_variables( - obj=v, - without_initialisation=without_initialisation, - _visited=_visited, - trainable=trainable, - ) - if ret: - vs[k[1:] if k[0] == "_" else k] = ret - return vs - - @tf.autograph.experimental.do_not_convert - def _find_buffers(self): - if hasattr(self, "_module_dict"): - for key, sub_module in self._module_dict.items(): - if len(sub_module._buffers) > 0: - self._buffers[key] = sub_module._buffers - - @tf.autograph.experimental.do_not_convert - def _build_and_return_v(self, *args, **kwargs): - if not self._built: - self.build(*args, **kwargs) - return self.v - - @tf.autograph.experimental.do_not_convert - def _build_and_return_buffers(self, *args, **kwargs): - if not self._built: - self.build(*args, **kwargs) - return self.buffers - - @tf.autograph.experimental.do_not_convert - def _wrap_call_methods( - self, keychain_mappings, /, *, key="", obj=None, _visited=None - ): - _visited = _visited or {} - if id(obj) in _visited or not isinstance(key, str): - return - _visited[id(obj)] = True - if isinstance(obj, Model) and obj is not self: - orig_key_chain = key[1:] if key[0] == "_" else key - - obj.__call__ = self._fn_with_var_arg( - obj.__call__, self._extract_v, keychain_mappings, orig_key_chain - ) - return - elif isinstance(obj, (list, tuple)): - for i, val in enumerate(obj): - self._wrap_call_methods( - keychain_mappings, - key=f"{key}/v{str(i)}", - obj=val, - _visited=_visited, - ) - return - elif isinstance(obj, dict): - for k, val in obj.items(): - k = f"{key}/{k}" if key != "" and isinstance(k, str) else k - self._wrap_call_methods( - keychain_mappings, key=k, obj=val, _visited=_visited - ) - return - for k, val in obj.module_dict.items(): - if k.startswith(("__", "_self_")): - continue - k = f"{key}/{k}" if key != "" else k - if val is not None: - self._wrap_call_methods( - keychain_mappings, key=k, obj=val, _visited=_visited - ) - return - - @tf.autograph.experimental.do_not_convert - def _compute_module_dict(self): - self._module_dict 
= dict() - for key, value in self.__dict__.items(): - if isinstance(value, (Layer, tf.keras.layers.Layer)): - if ( - "stateful" in value.__module__ - or hasattr(value, "_frontend_module") - or not hasattr(value, "_module_dict") - ): - self._module_dict[key] = value - else: - self._module_dict[key] = value._module_dict - - @tf.autograph.experimental.do_not_convert - def _fn_with_var_arg_wrapper( - self, *a, fn, v_fn, keychain_mappings, orig_key_chain, **kw - ): - if "v" in kw: - del kw["v"] - v = v_fn(self.v, keychain_mappings, orig_key_chain) - return fn(*a, **kw, v=v) - - @tf.autograph.experimental.do_not_convert - def _fn_with_var_arg(self, fn, v_fn, /, keychain_mappings, orig_key_chain): - _fn_with_var_arg_wrapper = functools.partial( - self._fn_with_var_arg_wrapper, - fn=fn, - v_fn=v_fn, - keychain_mappings=keychain_mappings, - orig_key_chain=orig_key_chain, - ) - _fn_with_var_arg_wrapper.wrapped = True - return _fn_with_var_arg_wrapper - - @tf.autograph.experimental.do_not_convert - def _call(self, *args, v=None, buffers=None, **kwargs): - if not self._built or not self.built: - if not self._built: - first_arr = self._get_first_array(*args, **kwargs) - self.build( - *args, - **kwargs, - from_call=True, - dtype=first_arr.dtype if first_arr is not None else tf.float32, - ) - - if not self.built: - # Don't use `keras` build method - if os.environ.get("USE_KERAS_BUILD", "False").lower() == "false": - self.inputs = tf.nest.flatten(args) - else: - input_shapes = self._get_input_shapes(*args) - if len(input_shapes) == 0: - input_shapes = tf.TensorShape(None) - elif len(input_shapes) == 1: - input_shapes = input_shapes[0] - - super(Layer, self).build(tf.TensorShape(None)) # noqa: UP008 - - # If `v` was provided, replace with the module's v - replace_v = False - if v is not None: - v_orig = self.v - self._v = v - replace_v = True - - # If `buffers` were provided, replace with the module's buffers - replace_buffers = False - if buffers is not None: - buffers_orig = self.buffers - self._buffers = buffers - replace_buffers = True - - if replace_v or replace_buffers: - # Call the forward pass - ret = super(Layer, self).__call__(*args, **kwargs) # noqa: UP008 - # Replace v, buffers if needed - self._v = v_orig if replace_v else self._v - self._buffers = buffers_orig if replace_buffers else self._buffers - return ret - elif hasattr(self.__call__, "wrapped"): - return self.__call__(*args, **kwargs) - - # Get the signature of the call method - call_signature = inspect.signature(self.call) - - # Convert all positional arguments to keyword arguments based on the signature - new_kwargs = {} - for idx, (param_name, param) in enumerate(call_signature.parameters.items()): - if idx < len(args): - new_kwargs[param_name] = args[idx] - - # Merge the existing kwargs - new_kwargs.update(kwargs) - return super(Layer, self).__call__(**new_kwargs) # noqa: UP008 - - @tf.autograph.experimental.do_not_convert - def build( - self, - *args, - from_call=False, - device=None, - dtype=None, - dynamic_backend=None, - **kwargs, - ): - self._built = True - return - - def _lock_state(self): - pass - - @tf.autograph.experimental.do_not_convert - def register_buffer(self, name: str, value: Union[tf.Tensor, tf.Variable]): - self._buffers.update({name: value}) - return value - - @tf.autograph.experimental.do_not_convert - def register_parameter(self, name: str, value: Union[tf.Tensor, tf.Variable]): - self._v.update({name: value}) - - @tf.autograph.experimental.do_not_convert - def train(self, mode: bool = True): - self._training = 
mode - for module in self.children(): - if isinstance(module, tf.keras.layers.Layer) and not hasattr( - module, "train" - ): - module.trainable = mode - continue - module.train(mode) - self.trainable = mode - return self - - @tf.autograph.experimental.do_not_convert - def eval(self): - return self.train(mode=False) - - @tf.autograph.experimental.do_not_convert - def call(self, inputs, training=None, mask=None): - raise NotImplementedError( - "When subclassing the `Module` class, you should implement a `call` method." - ) - - def get_build_config(self): - config = super().get_build_config() - config = recursive_serialize(config) - return config - - def build_from_config(self, config): - config = recursive_deserialize(config) - return super().build_from_config(config) - - def get_config(self): - base_config = super().get_config() - config = {} - - # Get the names and values of positional arguments in __init__ - init_signature = inspect.signature(self.__init__) - arg_names = list(init_signature.parameters.keys()) - - # Include the positional arguments in the config - var_positional_arg_encountered = False - var_positional_arg_name = None - offset = 0 - for i, arg in enumerate(self._args): - arg_name = arg_names[min(i, len(arg_names) - 1)] - if var_positional_arg_encountered: - config.update( - { - f"{var_positional_arg_name}_{i - offset}": arg, - } - ) - elif ( - init_signature.parameters[arg_name].kind - == inspect.Parameter.VAR_POSITIONAL - ): - var_positional_arg_encountered = True - var_positional_arg_name = arg_name - offset = i - config.update( - { - f"{var_positional_arg_name}_{0}": arg, - } - ) - else: - config.update( - { - arg_name: arg, - } - ) - - # Include the keywords arguments in the config - kwargs = self._kwargs.copy() - kwargs.pop("devices", None) - config.update(**kwargs) - new_config = {**base_config, **config} - new_config = recursive_serialize(new_config) - return new_config - - @classmethod - def from_config(cls, config): - config = recursive_deserialize(config) - # Get the signature of the __init__ method - init_signature = inspect.signature(cls.__init__) - arg_names = list(init_signature.parameters.keys()) - - # Separate positional and keyword arguments based on the __init__ signature - args = [] - pos_or_kw = OrderedDict() - kwargs = {} - var_positional_args = [] - for arg_name in arg_names: - if ( - arg_name in config - and init_signature.parameters[arg_name].kind - == inspect.Parameter.KEYWORD_ONLY - ): - # Handle keyword arguments - kwargs[arg_name] = config.pop(arg_name) - elif ( - arg_name in config - and init_signature.parameters[arg_name].kind - == inspect.Parameter.POSITIONAL_OR_KEYWORD - ): - # Handle positional or keyword arguments - pos_or_kw[arg_name] = config.pop(arg_name) - elif any(re.match(rf"{arg_name}_\d+", key) for key in config.keys()): - # Handle variable positional arguments - var_positional_args.extend( - [ - config.pop(key) - for key in sorted(config.keys()) - if re.match(rf"{arg_name}_\d+", key) - ] - ) - - # Unpack positional arguments and the rest as keyword arguments - config.pop("name", None) - config.pop("trainable", None) - config.pop("dtype", None) - kwargs.update(config) - - # Determine the final args and kwargs - if var_positional_args: - args = list(pos_or_kw.values()) + var_positional_args - else: - kwargs.update(pos_or_kw) - - return cls(*args, **kwargs) - - # Methods to be Optionally Overridden # - # -----------------------------------# - - @tf.autograph.experimental.do_not_convert - def _create_variables(self, *, device=None, 
dtype=None): - return {} - - @tf.autograph.experimental.do_not_convert - def _build(self, *args, **kwargs) -> bool: - return True - - @tf.autograph.experimental.do_not_convert - def _forward(self, *args, **kwargs): - raise NotImplementedError( - "When subclassing the `Module` class, you should " - "implement a `_forward` method." - ) - - @tf.autograph.experimental.do_not_convert - def _extra_repr(self) -> str: - return "" - - # Properties # - # -----------# - - @property - def device(self): - return self._device - - @property - def dtype(self): - return self._dtype - - @property - def build_mode(self): - return self._build_mode - - @property - def training(self): - return self._training - - @property - def v(self): - return self._v - - @property - def buffers(self): - return self._buffers - - @property - def state_dict(self): - return {**self.v, **self.buffers} - - @property - def module_dict(self): - return self._module_dict - - @property - def layers(self): - return self._layers - - # Dunder Methods # - # ---------------# - @store_frame_info - @tf.autograph.experimental.do_not_convert - def __call__( - self, - *args, - v=None, - buffers=None, - **kwargs, - ): - # TODO: Temp workaround to avoid `call`` from being transformed by AutoGraph - if not hasattr(self.__class__.call, "autograph_info__"): - setattr(self.__class__.call, "autograph_info__", True) - ret = self._call(*args, v=v, buffers=buffers, **kwargs) - return ret - - @tf.autograph.experimental.do_not_convert - def __getattr__(self, name): - if name == "v": - if not super().__getattribute__("_v") and not getattr( # noqa: E501 - self, "_built", False - ): - return self._build_and_return_v( - *self._args, dynamic_backend=self._dynamic_backend, **self._kwargs - ) - - _dict = super().__getattribute__("__dict__") - if name in _dict: - return _dict[name] - - elif "_v" in _dict and name in _dict["_v"]: - return _dict["_v"][name] - - return super().__getattribute__(name) - - @tf.autograph.experimental.do_not_convert - def __setattr__(self, name, value): - if name in ["v", "buffers"]: - name = "_" + name - if isinstance(value, (Layer, tf.keras.layers.Layer)): - _dict = getattr(self, "__dict__", None) - if _dict: - _dict[name] = value - - # compute the module dict - self._compute_module_dict() - - obj_to_search = ( - None - if not isinstance(value, (Layer, tf.keras.layers.Layer)) - else ( - self._modules - if hasattr(self, "_modules") and self._modules - else self - ) - ) - found_vars = self._find_variables( - obj=obj_to_search, - without_initialisation=( - True - if self._v_from_constructor and not self._with_partial_v - else False - ), - ) - flattened_v, v_spec = tree_flatten(found_vars) - flattend_kc = v_spec.get_keychains() - for kc, v in zip(flattend_kc, flattened_v): - new_kc = kc.replace("/", ".") - if new_kc not in self.v: - self.register_parameter(new_kc, v) - - # once all variables built, find and assign buffers - found_buffers = self._find_variables( - obj=obj_to_search, - without_initialisation=( - True - if self._v_from_constructor and not self._with_partial_v - else False - ), - trainable=False, - ) - flattened_buf, buf_spec = tree_flatten(found_buffers) - flattend_kc = buf_spec.get_keychains() - for kc, buf in zip(flattend_kc, flattened_buf): - new_kc = kc.replace("/", ".") - self.register_buffer(new_kc, buf) - - super().__setattr__(name, value) - return - elif isinstance(value, tf.Variable) and not name.startswith("_"): - _dict = getattr(self, "__dict__", None) - if _dict: - _dict[name] = value - # Manual solution for cases 
where a `tf.int32` tensor - # is placed on the GPU. TensorFlow doesn't have dedicated - # kernels for placing `tf.int32` variables on the GPU and so - # we manually cast them to `tf.int64` here otherwise due to - # `tf.config.soft_device_placement(True)` by default, - # TensorFlow puts the `tf.int32` variables on CPU which causes - # unintended consequences downstream during tracing or - # `tf.function` compilation e.g. - # Ref: https://github.com/tensorflow/tensorflow/issues/9506 - # Ref: https://stackoverflow.com/questions/44813939/could-not-satisfy-explicit-device-specification-devicegpu0-because-no-devic - dtype = ( - tf.int64 - if value.dtype == tf.int32 and "gpu:" in value.device.lower() - else value.dtype - ) - cast_dtype = dtype != value.dtype - val = ( - value - if not cast_dtype - else tf.Variable(initial_value=tf.cast(value.value(), dtype), name=name) - ) - self.register_parameter(name, val) - super().__setattr__(name, val) - return - else: - try: - obj_to_search = getattr(self, name) - except AttributeError: - obj_to_search = None - if isinstance(obj_to_search, Layer): - # retrieve all hierarchical submodules - assign_dict, kc = get_assignment_dict() - - # Iterate over all submods in assign_dict - # updating their `v` and `buffers` with the - # new value - for key, submod in assign_dict.items(): - # Get the subkey to match - subkey = kc[len(key) :].lstrip(".") - - if hasattr(submod, "v"): - for v_key, v_value in submod.v.items(): - if v_key.startswith(subkey): - submod.register_parameter(v_key, value) - - # Repeat the same process for submod.buffers - if hasattr(submod, "buffers"): - for b_key, b_value in submod.buffers.items(): - if b_key.startswith(subkey): - submod.register_buffer(b_key, value) - - # finally update the module dict - self._module_dict[name] = value - - return super().__setattr__(name, value) - - @tf.autograph.experimental.do_not_convert - def __delattr__(self, name): - if hasattr(self, name): - if isinstance(getattr(self, name), (Layer, tf.keras.layers.Layer)): - super().__delattr__(name) - return - super().__delattr__(name) - - @tf.autograph.experimental.do_not_convert - def __repr__(self): - extra_lines = [] - extra_repr = self._extra_repr() - if extra_repr: - extra_lines = extra_repr.split("\n") - child_lines = [] - for key in self.v.keys(): - if isinstance(getattr(self, key, None), Layer): - mod_str = repr(getattr(self, key)) - mod_str = self._addindent(mod_str, 2) - child_lines.append(f"({key}): {mod_str}") - lines = extra_lines + child_lines - - main_str = f"{self.__class__.__name__}(" - if lines: - # simple one-liner info, which most builtin Modules will use - if len(extra_lines) == 1 and not child_lines: - main_str += extra_lines[0] - else: - main_str += "\n " + "\n ".join(lines) + "\n" - - main_str += ")" - return main_str - - -class Model(tf.keras.Model, ModelHelpers): - _build_mode = None - _with_partial_v = None - _store_vars = True - _built = False - _v = None - _buffers = None - _module_dict = None - _args = None - _kwargs = None - _module_graph = None - _target = None - _lazy_traced = False - _training = None - _dynamic_backend = None - _device = None - _dtype = None - _previous_frame_info = None - - def __init__( - self, - /, - *args, - v=None, - buffers=None, - build_mode="on_init", - store_vars=True, - with_partial_v=False, - dynamic_backend=None, - training=True, - dtype=None, - device=None, - module_dict=None, - **kwargs, - ): - super(Model, self).__init__( - trainable=training, - dtype=dtype, - ) - self._build_mode = build_mode - 
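-        # Editor's note (illustrative): any `build_mode` other than
-        # "on_init" defers building, so a subclass constructed with e.g.
-        # build_mode="on_call" (hypothetical value) is only built lazily,
-        # on the first forward pass via `_call`.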
self._with_partial_v = with_partial_v - self._store_vars = store_vars - self._built = False - self._v_from_constructor = v if isinstance(v, dict) or v is None else dict(v) - self._v = v if v is not None else dict() - self._buffers = dict(buffers or {}) - self._module_dict = module_dict if module_dict is not None else dict() - self._args = args - self._kwargs = kwargs - self._module_graph = None - self._target = None - self._lazy_traced = False - self._training = training - self._dynamic_backend = dynamic_backend - self._device = device or "cpu" - self._dtype = dtype or tf.float32 - if build_mode != "on_init": - return - self.build(*args, dynamic_backend=dynamic_backend, **kwargs) - - @tf.autograph.experimental.do_not_convert - def _find_variables( - self, - /, - *, - obj=None, - without_initialisation=False, - _visited=None, - trainable=True, - ): - _visited = _visited or {} - vs = dict() - if id(obj) in _visited: - return vs - _visited[id(obj)] = True - if isinstance(obj, (Layer, Model)) and obj is not self: - fn = "_build_and_return_v" if trainable else "_build_and_return_buffers" - if not obj._built and without_initialisation: - return lambda: getattr(obj, fn)( - *obj._args, dynamic_backend=self._dynamic_backend, **obj._kwargs - ) - - return getattr(obj, fn)( - *obj._args, dynamic_backend=obj._dynamic_backend, **obj._kwargs - ) - elif isinstance(obj, tf.keras.layers.Layer) and obj is not self: - return obj.v if trainable else obj.buffers - - elif isinstance(obj, (list, tuple)): - for i, v in enumerate(obj): - ret = self._find_variables( - obj=v, - without_initialisation=without_initialisation, - _visited=_visited, - trainable=trainable, - ) - if ret: - vs[f"v{str(i)}"] = ret - return vs - elif isinstance(obj, dict): - for k, v in obj.items(): - ret = self._find_variables( - obj=v, - without_initialisation=without_initialisation, - _visited=_visited, - trainable=trainable, - ) - if ret: - vs[k[1:] if k[0] == "_" else k] = ret - return vs - elif not hasattr(obj, "__dict__"): - return vs - for k, v in obj.__dict__.items(): - if ( - v is not None - and k[0:2] != "__" - and not k.startswith( - ( - "_module_dict", - "_self_", - "_args", - "_kwargs", - ) - ) - ): - ret = self._find_variables( - obj=v, - without_initialisation=without_initialisation, - _visited=_visited, - trainable=trainable, - ) - if ret: - vs[k[1:] if k[0] == "_" else k] = ret - return vs - - @tf.autograph.experimental.do_not_convert - def _find_buffers(self): - if hasattr(self, "_module_dict"): - for key, sub_module in self._module_dict.items(): - if len(sub_module._buffers) > 0: - self._buffers[key] = sub_module._buffers - - @tf.autograph.experimental.do_not_convert - def _build_and_return_v(self, *args, **kwargs): - if not self._built: - self.build(*args, **kwargs) - return self.v - - @tf.autograph.experimental.do_not_convert - def _build_and_return_buffers(self, *args, **kwargs): - if not self._built: - self.build(*args, **kwargs) - return self.buffers - - @tf.autograph.experimental.do_not_convert - def _wrap_call_methods( - self, keychain_mappings, /, *, key="", obj=None, _visited=None - ): - _visited = _visited or {} - if id(obj) in _visited or not isinstance(key, str): - return - _visited[id(obj)] = True - if isinstance(obj, (Layer, Model)) and obj is not self: - orig_key_chain = key[1:] if key[0] == "_" else key - - obj.__call__ = self._fn_with_var_arg( - obj.__call__, self._extract_v, keychain_mappings, orig_key_chain - ) - return - elif isinstance(obj, (list, tuple)): - for i, val in enumerate(obj): - 
-                self._wrap_call_methods(
-                    keychain_mappings,
-                    key=f"{key}/v{str(i)}",
-                    obj=val,
-                    _visited=_visited,
-                )
-            return
-        elif isinstance(obj, dict):
-            for k, val in obj.items():
-                k = f"{key}/{k}" if key != "" and isinstance(k, str) else k
-                self._wrap_call_methods(
-                    keychain_mappings, key=k, obj=val, _visited=_visited
-                )
-            return
-        for k, val in obj.module_dict.items():
-            if k.startswith(("__", "_self_")):
-                continue
-            k = f"{key}/{k}" if key != "" else k
-            if val is not None:
-                self._wrap_call_methods(
-                    keychain_mappings, key=k, obj=val, _visited=_visited
-                )
-        return
-
-    @tf.autograph.experimental.do_not_convert
-    def _compute_module_dict(self):
-        self._module_dict = dict()
-        for key, value in self.__dict__.items():
-            if isinstance(value, (Layer, tf.keras.layers.Layer, Model, tf.keras.Model)):
-                if (
-                    "stateful" in value.__module__
-                    or hasattr(value, "_frontend_module")
-                    or not hasattr(value, "_module_dict")
-                ):
-                    self._module_dict[key] = value
-                else:
-                    self._module_dict[key] = value._module_dict
-
-    @tf.autograph.experimental.do_not_convert
-    def _fn_with_var_arg_wrapper(
-        self, *a, fn, v_fn, keychain_mappings, orig_key_chain, **kw
-    ):
-        if "v" in kw:
-            del kw["v"]
-        v = v_fn(self.v, keychain_mappings, orig_key_chain)
-        return fn(*a, **kw, v=v)
-
-    @tf.autograph.experimental.do_not_convert
-    def _fn_with_var_arg(self, fn, v_fn, /, keychain_mappings, orig_key_chain):
-        _fn_with_var_arg_wrapper = functools.partial(
-            self._fn_with_var_arg_wrapper,
-            fn=fn,
-            v_fn=v_fn,
-            keychain_mappings=keychain_mappings,
-            orig_key_chain=orig_key_chain,
-        )
-        _fn_with_var_arg_wrapper.wrapped = True
-        return _fn_with_var_arg_wrapper
-
-    @tf.autograph.experimental.do_not_convert
-    def _call(self, *args, v=None, buffers=None, **kwargs):
-        if not self._built or not self.built:
-            if not self._built:
-                first_arr = self._get_first_array(*args, **kwargs)
-                self.build(
-                    *args,
-                    **kwargs,
-                    from_call=True,
-                    dtype=first_arr.dtype if first_arr is not None else tf.float32,
-                )
-
-            if not self.built:
-                # Don't use `keras` build method
-                if os.environ.get("USE_KERAS_BUILD", "False").lower() == "false":
-                    self.inputs = tf.nest.flatten(args)
-                else:
-                    input_shapes = self._get_input_shapes(*args)
-                    if len(input_shapes) == 0:
-                        input_shapes = tf.TensorShape(None)
-                    elif len(input_shapes) == 1:
-                        input_shapes = input_shapes[0]
-
-                super(Model, self).build(tf.TensorShape(None))  # noqa: UP008
-
-        # If `v` was provided, replace with the module's v
-        replace_v = False
-        if v is not None:
-            v_orig = self.v
-            self._v = v
-            replace_v = True
-
-        # If `buffers` were provided, replace with the module's buffers
-        replace_buffers = False
-        if buffers is not None:
-            buffers_orig = self.buffers
-            self._buffers = buffers
-            replace_buffers = True
-
-        if replace_v or replace_buffers:
-            # Call the forward pass
-            ret = super(Model, self).__call__(*args, **kwargs)  # noqa: UP008
-            # Replace v, buffers if needed
-            self._v = v_orig if replace_v else self._v
-            self._buffers = buffers_orig if replace_buffers else self._buffers
-            return ret
-        elif hasattr(self.__call__, "wrapped"):
-            return self.__call__(*args, **kwargs)
-
-        return super(Model, self).__call__(*args, **kwargs)  # noqa: UP008
-
-    @tf.autograph.experimental.do_not_convert
-    def build(
-        self,
-        *args,
-        from_call=False,
-        device=None,
-        dtype=None,
-        dynamic_backend=None,
-        **kwargs,
-    ):
-        self._built = True
-        return
-
-    def _lock_state(self):
-        pass
-
-    @tf.autograph.experimental.do_not_convert
-    def register_buffer(self, name: str, value: Union[tf.Tensor, tf.Variable]):
-        self._buffers.update({name: value})
-        return value
-
-    @tf.autograph.experimental.do_not_convert
-    def register_parameter(self, name: str, value: Union[tf.Tensor, tf.Variable]):
-        self._v.update({name: value})
-
-    @tf.autograph.experimental.do_not_convert
-    def train(self, mode: bool = True):
-        self._training = mode
-        for module in self.children():
-            if isinstance(module, tf.keras.layers.Layer) and not hasattr(
-                module, "train"
-            ):
-                module.trainable = mode
-                continue
-            module.train(mode)
-        self.trainable = mode
-        return self
-
-    @tf.autograph.experimental.do_not_convert
-    def eval(self):
-        return self.train(mode=False)
-
-    @tf.autograph.experimental.do_not_convert
-    def call(self, inputs, training=None, mask=None):
-        raise NotImplementedError(
-            "When subclassing the `Module` class, you should implement a `call` method."
-        )
-
-    def get_build_config(self):
-        config = super().get_build_config()
-        config = recursive_serialize(config)
-        return config
-
-    def build_from_config(self, config):
-        config = recursive_deserialize(config)
-        return super().build_from_config(config)
-
-    def get_config(self):
-        base_config = super().get_config()
-        config = {}
-
-        # Get the names and values of positional arguments in __init__
-        init_signature = inspect.signature(self.__init__)
-        arg_names = list(init_signature.parameters.keys())
-
-        # Include the positional arguments in the config
-        var_positional_arg_encountered = False
-        var_positional_arg_name = None
-        offset = 0
-        for i, arg in enumerate(self._args):
-            arg_name = arg_names[min(i, len(arg_names) - 1)]
-            if var_positional_arg_encountered:
-                config.update(
-                    {
-                        f"{var_positional_arg_name}_{i - offset}": arg,
-                    }
-                )
-            elif (
-                init_signature.parameters[arg_name].kind
-                == inspect.Parameter.VAR_POSITIONAL
-            ):
-                var_positional_arg_encountered = True
-                var_positional_arg_name = arg_name
-                offset = i
-                config.update(
-                    {
-                        f"{var_positional_arg_name}_{0}": arg,
-                    }
-                )
-            else:
-                config.update(
-                    {
-                        arg_name: arg,
-                    }
-                )
-
-        # Include the keyword arguments in the config
-        kwargs = self._kwargs.copy()
-        kwargs.pop("devices", None)
-        config.update(**kwargs)
-        new_config = {**base_config, **config}
-        new_config = recursive_serialize(new_config)
-        return new_config
-
-    @classmethod
-    def from_config(cls, config):
-        config = recursive_deserialize(config)
-        # Get the signature of the __init__ method
-        init_signature = inspect.signature(cls.__init__)
-        arg_names = list(init_signature.parameters.keys())
-
-        # Separate positional and keyword arguments based on the __init__ signature
-        args = []
-        pos_or_kw = OrderedDict()
-        kwargs = {}
-        var_positional_args = []
-        for arg_name in arg_names:
-            if (
-                arg_name in config
-                and init_signature.parameters[arg_name].kind
-                == inspect.Parameter.KEYWORD_ONLY
-            ):
-                # Handle keyword arguments
-                kwargs[arg_name] = config.pop(arg_name)
-            elif (
-                arg_name in config
-                and init_signature.parameters[arg_name].kind
-                == inspect.Parameter.POSITIONAL_OR_KEYWORD
-            ):
-                # Handle positional or keyword arguments
-                pos_or_kw[arg_name] = config.pop(arg_name)
-            elif any(re.match(rf"{arg_name}_\d+", key) for key in config.keys()):
-                # Handle variable positional arguments
-                var_positional_args.extend(
-                    [
-                        config.pop(key)
-                        for key in sorted(config.keys())
-                        if re.match(rf"{arg_name}_\d+", key)
-                    ]
-                )
-
-        # Unpack positional arguments and the rest as keyword arguments
-        config.pop("name", None)
-        config.pop("trainable", None)
-        config.pop("dtype", None)
-        kwargs.update(config)
-
-        # Determine the final args and kwargs
-        if var_positional_args:
-            args = list(pos_or_kw.values()) + var_positional_args
-        else:
-            kwargs.update(pos_or_kw)
-
-        return cls(*args, **kwargs)
-
-    # Methods to be Optionally Overridden #
-    # -----------------------------------#
-
-    @tf.autograph.experimental.do_not_convert
-    def _create_variables(self, *, device=None, dtype=None):
-        return {}
-
-    @tf.autograph.experimental.do_not_convert
-    def _build(self, *args, **kwargs) -> bool:
-        return True
-
-    @tf.autograph.experimental.do_not_convert
-    def _forward(self, *args, **kwargs):
-        raise NotImplementedError(
-            "When subclassing the `Module` class, you should "
-            "implement a `_forward` method."
-        )
-
-    @tf.autograph.experimental.do_not_convert
-    def _extra_repr(self) -> str:
-        return ""
-
-    # Properties #
-    # -----------#
-
-    @property
-    def device(self):
-        return self._device
-
-    @property
-    def dtype(self):
-        return self._dtype
-
-    @property
-    def build_mode(self):
-        return self._build_mode
-
-    @property
-    def training(self):
-        return self._training
-
-    @property
-    def v(self):
-        return self._v
-
-    @property
-    def buffers(self):
-        return self._buffers
-
-    @property
-    def state_dict(self):
-        return {**self.v, **self.buffers}
-
-    @property
-    def module_dict(self):
-        return self._module_dict
-
-    # Dunder Methods #
-    # ---------------#
-    @store_frame_info
-    @tf.autograph.experimental.do_not_convert
-    def __call__(
-        self,
-        *args,
-        v=None,
-        buffers=None,
-        **kwargs,
-    ):
-        # TODO: Temp workaround to avoid `call` from being transformed by AutoGraph
-        if not hasattr(self.__class__.call, "autograph_info__"):
-            setattr(self.__class__.call, "autograph_info__", True)
-        ret = self._call(*args, v=v, buffers=buffers, **kwargs)
-        return ret
-
-    @tf.autograph.experimental.do_not_convert
-    def __getattr__(self, name):
-        if name == "v":
-            if not super().__getattribute__("_v") and not getattr(  # noqa: E501
-                self, "_built", False
-            ):
-                return self._build_and_return_v(
-                    *self._args, dynamic_backend=self._dynamic_backend, **self._kwargs
-                )
-
-        _dict = super().__getattribute__("__dict__")
-        if name in _dict:
-            return _dict[name]
-
-        elif "_v" in _dict and name in _dict["_v"]:
-            return _dict["_v"][name]
-
-        return super().__getattribute__(name)
-
-    @tf.autograph.experimental.do_not_convert
-    def __setattr__(self, name, value):
-        if name in ["v", "buffers"]:
-            name = "_" + name
-        if isinstance(value, (Layer, tf.keras.layers.Layer, Model, tf.keras.Model)):
-            _dict = getattr(self, "__dict__", None)
-            if _dict:
-                _dict[name] = value
-
-            # compute the module dict
-            self._compute_module_dict()
-
-            obj_to_search = (
-                None
-                if not isinstance(value, (tf.keras.layers.Layer, Layer, Model))
-                else (
-                    self._modules
-                    if hasattr(self, "_modules") and self._modules
-                    else self
-                )
-            )
-            found_vars = self._find_variables(
-                obj=obj_to_search,
-                without_initialisation=(
-                    True
-                    if self._v_from_constructor and not self._with_partial_v
-                    else False
-                ),
-            )
-            flattened_v, v_spec = tree_flatten(found_vars)
-            flattened_kc = v_spec.get_keychains()
-            for kc, v in zip(flattened_kc, flattened_v):
-                new_kc = kc.replace("/", ".")
-                if new_kc not in self.v:
-                    self.register_parameter(new_kc, v)
-
-            # once all variables built, find and assign buffers
-            found_buffers = self._find_variables(
-                obj=obj_to_search,
-                without_initialisation=(
-                    True
-                    if self._v_from_constructor and not self._with_partial_v
-                    else False
-                ),
-                trainable=False,
-            )
-            flattened_buf, buf_spec = tree_flatten(found_buffers)
-            flattened_kc = buf_spec.get_keychains()
-            for kc, buf in zip(flattened_kc, flattened_buf):
-                new_kc = kc.replace("/", ".")
-                self.register_buffer(new_kc, buf)
-
-            super().__setattr__(name, value)
-            return
-        elif isinstance(value, tf.Variable) and not name.startswith("_"):
-            _dict = getattr(self, "__dict__", None)
-            if _dict:
-                _dict[name] = value
-
-            # Manual solution for cases where a `tf.int32` tensor
-            # is placed on the GPU. TensorFlow doesn't have dedicated
-            # kernels for placing `tf.int32` variables on the GPU, so
-            # we manually cast them to `tf.int64` here. Otherwise, because
-            # `tf.config.soft_device_placement(True)` is enabled by default,
-            # TensorFlow places the `tf.int32` variables on the CPU, which causes
-            # unintended consequences downstream during tracing or
-            # `tf.function` compilation.
-            # Ref: https://github.com/tensorflow/tensorflow/issues/9506
-            # Ref: https://stackoverflow.com/questions/44813939/could-not-satisfy-explicit-device-specification-devicegpu0-because-no-devic
-            dtype = (
-                tf.int64
-                if value.dtype == tf.int32 and "gpu:" in value.device.lower()
-                else value.dtype
-            )
-            cast_dtype = dtype != value.dtype
-            val = (
-                value
-                if not cast_dtype
-                else tf.Variable(initial_value=tf.cast(value.value(), dtype), name=name)
-            )
-            self.register_parameter(name, val)
-            super().__setattr__(name, val)
-        else:
-            try:
-                obj_to_search = getattr(self, name)
-            except AttributeError:
-                obj_to_search = None
-            if isinstance(obj_to_search, (Model, Layer)):
-                # retrieve all hierarchical submodules
-                assign_dict, kc = get_assignment_dict()
-
-                # Iterate over all submods in assign_dict
-                # updating their `v` and `buffers` with the
-                # new value
-                for key, submod in assign_dict.items():
-                    # Get the subkey to match
-                    subkey = kc[len(key) :].lstrip(".")
-
-                    if hasattr(submod, "v"):
-                        for v_key, v_value in submod.v.items():
-                            if v_key.startswith(subkey):
-                                submod.register_parameter(v_key, value)
-
-                    # Repeat the same process for submod.buffers
-                    if hasattr(submod, "buffers"):
-                        for b_key, b_value in submod.buffers.items():
-                            if b_key.startswith(subkey):
-                                submod.register_buffer(b_key, value)
-
-                # finally update the module dict
-                self._module_dict[name] = value
-
-        return super().__setattr__(name, value)
-
-    @tf.autograph.experimental.do_not_convert
-    def __delattr__(self, name):
-        if hasattr(self, name):
-            if isinstance(
-                getattr(self, name),
-                (Layer, tf.keras.layers.Layer, Model, tf.keras.Model),
-            ):
-                super().__delattr__(name)
-                return
-        super().__delattr__(name)
-
-    @tf.autograph.experimental.do_not_convert
-    def __repr__(self):
-        extra_lines = []
-        extra_repr = self._extra_repr()
-        if extra_repr:
-            extra_lines = extra_repr.split("\n")
-        child_lines = []
-        for key in self.v.keys():
-            if isinstance(getattr(self, key, None), (Layer, Model)):
-                mod_str = repr(getattr(self, key))
-                mod_str = self._addindent(mod_str, 2)
-                child_lines.append(f"({key}): {mod_str}")
-        lines = extra_lines + child_lines
-
-        main_str = f"{self.__class__.__name__}("
-        if lines:
-            # simple one-liner info, which most builtin Modules will use
-            if len(extra_lines) == 1 and not child_lines:
-                main_str += extra_lines[0]
-            else:
-                main_str += "\n " + "\n ".join(lines) + "\n"
-
-        main_str += ")"
-        return main_str
diff --git a/ivy/compiler/_cache/Translated_Outputs/tensorflow_zeros_output/run_0/tensorflow_zeros.py b/ivy/compiler/_cache/Translated_Outputs/tensorflow_zeros_output/run_0/tensorflow_zeros.py
deleted file mode 100644
index 96e2ca258d05..000000000000
--- a/ivy/compiler/_cache/Translated_Outputs/tensorflow_zeros_output/run_0/tensorflow_zeros.py
+++ /dev/null
@@ -1,21 +0,0 @@
-import tensorflow
-import tensorflow as tf
-
-from typing import Optional
-from typing import Sequence
-from typing import Union
-
-from .tensorflow__helpers import tensorflow_handle_array_like_without_promotion
-from .tensorflow__helpers import tensorflow_infer_dtype
-
-
-@tensorflow_infer_dtype
-@tensorflow_handle_array_like_without_promotion
-def tensorflow_zeros(
-    shape: Union[tf.TensorShape, Sequence[int]],
-    *,
-    dtype: tensorflow.DType,
-    device: Optional[str] = None,
-    out: Optional[Union[tensorflow.Tensor, tensorflow.Variable]] = None,
-):
-    return tensorflow.zeros(shape, dtype=dtype)
diff --git a/ivy/compiler/_cache/ivy_to_tensorflow_translation_cache.pkl b/ivy/compiler/_cache/ivy_to_tensorflow_translation_cache.pkl
index e7bed4f39758..14ed17dff688 100644
--- a/ivy/compiler/_cache/ivy_to_tensorflow_translation_cache.pkl
+++ b/ivy/compiler/_cache/ivy_to_tensorflow_translation_cache.pkl
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:9eb73e6415ee3c6c8206005981b415965cf7d69adbea4ee65558a4442b25441b
-size 133939495
+oid sha256:053b84d444e833cbd8f4f49b63ee48b25dd8a9a0260510e3ff9294c4701f4ca2
+size 433809516
diff --git a/ivy/compiler/_cache/torch_frontend_to_ivy_translation_cache.pkl b/ivy/compiler/_cache/torch_frontend_to_ivy_translation_cache.pkl
index baf5079aeebf..4d1b7c29d736 100644
--- a/ivy/compiler/_cache/torch_frontend_to_ivy_translation_cache.pkl
+++ b/ivy/compiler/_cache/torch_frontend_to_ivy_translation_cache.pkl
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:6c1f11c85cafe087a27c146fc1d4f719deccd487b00dee96b2615a00e1144a8a
-size 2560350
+oid sha256:b7ea0b0b9efd1928ae6efa0d689204e4ee3b4eb0e871d556240fdb33a9a6b7d4
+size 2052633
diff --git a/ivy/compiler/_cache/torch_to_torch_frontend_translation_cache.pkl b/ivy/compiler/_cache/torch_to_torch_frontend_translation_cache.pkl
index 221ac9c0972b..ac2914222d3f 100644
--- a/ivy/compiler/_cache/torch_to_torch_frontend_translation_cache.pkl
+++ b/ivy/compiler/_cache/torch_to_torch_frontend_translation_cache.pkl
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:313306050716d7370654efdd76a0ba21080d00529c6859edc9d0672e38f95dac
-size 1138395
+oid sha256:627733944bd8ad36dac7b5008e5a471abd90dcd7b703a7d864a7ad886473daf6
+size 807976
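
For reference, the `tf.int32` -> `tf.int64` cast performed in the `__setattr__` methods above can be exercised standalone, as in the minimal sketch below. The helper name `upcast_int32_variable_for_gpu` is illustrative only and is not part of this diff; the behaviour assumes TensorFlow's default soft device placement.

import tensorflow as tf

def upcast_int32_variable_for_gpu(value: tf.Variable, name: str) -> tf.Variable:
    # TensorFlow ships no GPU kernels for many tf.int32 variable ops, so with
    # tf.config.soft_device_placement(True) an int32 variable silently falls
    # back to the CPU; re-creating it as int64 lets it stay on the GPU.
    if value.dtype == tf.int32 and "gpu:" in value.device.lower():
        return tf.Variable(initial_value=tf.cast(value.value(), tf.int64), name=name)
    return value

# Example: on a CPU-only host the variable is returned unchanged.
v = tf.Variable(tf.zeros((2, 2), dtype=tf.int32), name="counts")
print(upcast_int32_variable_for_gpu(v, "counts").dtype)  # tf.int32 on CPU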