diff --git a/torchax/pyproject.toml b/torchax/pyproject.toml
index 3882b77b78f..85d752964ab 100644
--- a/torchax/pyproject.toml
+++ b/torchax/pyproject.toml
@@ -51,3 +51,21 @@ packages = ["torchax"]
 
 [tool.pytest.ini_options]
 addopts="-n auto"
+
+[tool.ruff]
+# Equivalent to column_limit
+line-length = 80
+
+# Enable preview mode to use rules like E306
+preview = true
+
+indent-width = 2
+
+[tool.ruff.lint]
+select = [
+  "E", "F", "W", # Your existing rule selections
+  "Q",
+]
+# Enforces a blank line before nested functions and classes
+extend-select = ["E306"]
+
diff --git a/torchax/torchax/__init__.py b/torchax/torchax/__init__.py
index 617970a56c4..ff81c782c23 100644
--- a/torchax/torchax/__init__.py
+++ b/torchax/torchax/__init__.py
@@ -13,20 +13,21 @@
 VERSION = __version__
 
 __all__ = [
-    'default_env',
-    'extract_jax',
-    'enable_globally',
+  "default_env",
+  "extract_jax",
+  "enable_globally",
 ]
 
 from jax._src import xla_bridge
 
-os.environ.setdefault('ENABLE_RUNTIME_UPTIME_TELEMETRY', '1')
+os.environ.setdefault("ENABLE_RUNTIME_UPTIME_TELEMETRY", "1")
 
 # torchax:oss-begin
-if getattr(jax.config, 'jax_pjrt_client_create_options', None):
+if getattr(jax.config, "jax_pjrt_client_create_options", None):
   jax.config.update(
-      'jax_pjrt_client_create_options',
-      f'ml_framework_name:PyTorch/XLA2;ml_framework_version:{"v0.0.1"}')
+    "jax_pjrt_client_create_options",
+    f"ml_framework_name:PyTorch/XLA2;ml_framework_version:{'v0.0.1'}",
+  )
 # torchax:oss-end
 
 env = None
@@ -49,7 +50,7 @@ def extract_jax(mod: torch.nn.Module, env=None):
 
   states = env.t2j_copy(states)
 
-  #@jax.jit
+  # @jax.jit
   def jax_func(states, inputs):
     (states, inputs) = env.j2t_iso((states, inputs))
     with env:
@@ -79,29 +80,30 @@ def disable_temporarily():
   enable_globally()
 
 
-torch.utils.rename_privateuse1_backend('jax')
+torch.utils.rename_privateuse1_backend("jax")
 unsupported_dtype = [torch.quint8]
 torch.utils.generate_methods_for_privateuse1_backend(
-    for_tensor=True,
-    for_module=True,
-    for_storage=True,
-    unsupported_dtype=unsupported_dtype)
+  for_tensor=True,
+  for_module=True,
+  for_storage=True,
+  unsupported_dtype=unsupported_dtype,
+)
 
 import jax
 import torchax.device_module
 
-torch._register_device_module('jax', torchax.device_module)
+torch._register_device_module("jax", torchax.device_module)
 
 
 def enable_accuracy_mode():
-  jax.config.update('jax_enable_x64', True)
-  jax.config.update('jax_default_matmul_precision', 'highest')
+  jax.config.update("jax_enable_x64", True)
+  jax.config.update("jax_default_matmul_precision", "highest")
   default_env().config.internal_respect_torch_return_dtypes = True
 
 
 def enable_performance_mode():
-  jax.config.update('jax_enable_x64', False)
-  jax.config.update('jax_default_matmul_precision', 'default')
+  jax.config.update("jax_enable_x64", False)
+  jax.config.update("jax_default_matmul_precision", "default")
   default_env().config.internal_respect_torch_return_dtypes = False
@@ -109,15 +111,17 @@ def enable_performance_mode():
 class CompileOptions:
   # only valid if compiling nn.Module
   methods_to_compile: List[str] = dataclasses.field(
-      default_factory=lambda: ['forward'])
+    default_factory=lambda: ["forward"]
+  )
   jax_jit_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
 
-  mode: str = 'jax'  # or dynamo or export
+  mode: str = "jax"  # or dynamo or export
 
 
 def compile(fn, options: Optional[CompileOptions] = None):
   options = options or CompileOptions()
-  if options.mode == 'jax':
+  if options.mode == "jax":
     from torchax import interop
+
     if isinstance(fn, torch.nn.Module):
       module = 
interop.JittableModule(fn, extra_jit_args=options.jax_jit_kwargs) for n in options.methods_to_compile: diff --git a/torchax/torchax/amp.py b/torchax/torchax/amp.py index ef06e884a8a..6e38585acbd 100644 --- a/torchax/torchax/amp.py +++ b/torchax/torchax/amp.py @@ -36,23 +36,26 @@ class CastPolicy(enum.Enum): def execute_policy(policy, args, kwargs, target_lower_fp): - def is_float(a): return isinstance(a, torch.Tensor) and a.is_floating_point() + match policy: case CastPolicy.LOWER_PRECISION_FP: - return pytree.tree_map_only(is_float, lambda a: a.to(target_lower_fp), - (args, kwargs)) + return pytree.tree_map_only( + is_float, lambda a: a.to(target_lower_fp), (args, kwargs) + ) case CastPolicy.FP32: - return pytree.tree_map_only(is_float, lambda a: a.to(torch.float32), - (args, kwargs)) + return pytree.tree_map_only( + is_float, lambda a: a.to(torch.float32), (args, kwargs) + ) case CastPolicy.PROMOTE: dtypes = set(a.dtype for a in args) widest = max((dtype.itemsize, dtype) for dtype in dtypes)[1] - return pytree.tree_map_only(is_float, lambda a: a.to(widest), - (args, kwargs)) + return pytree.tree_map_only( + is_float, lambda a: a.to(widest), (args, kwargs) + ) case _: - raise AssertionError(f'Policy {policy} not implemented yet.') + raise AssertionError(f"Policy {policy} not implemented yet.") @contextlib.contextmanager @@ -60,6 +63,7 @@ def autocast(device, dtype=torch.bfloat16, env=None): del device if env is None: import torchax + env = torchax.default_env() env.autocast_dtype, old = dtype, env.autocast_dtype yield @@ -68,266 +72,135 @@ def autocast(device, dtype=torch.bfloat16, env=None): # https://github.com/pytorch/pytorch/blob/05faba40287cf7d8734da96cb2e904f39710bf29/aten/src/ATen/autocast_mode.cpp#L327 autocast_policy = { - torch.ops.aten.conv1d.default: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.conv1d.padding: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.conv2d.default: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.conv2d.padding: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.conv3d.default: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.conv3d.padding: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.bmm.default: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.mm.default: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.linalg_vecdot.default: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.baddbmm.default: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.addmm.default: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten._addmm_activation.default: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.addbmm.default: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.linear.default: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten._convolution.deprecated: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.matmul.default: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.conv_tbc.default: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.mkldnn_rnn_layer.default: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.conv_transpose1d.default: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.conv_transpose2d.input: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.conv_transpose3d.input: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.prelu.default: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten.scaled_dot_product_attention.default: - CastPolicy.LOWER_PRECISION_FP, - torch.ops.aten._native_multi_head_attention.default: - CastPolicy.LOWER_PRECISION_FP, - - # fp32 cast policy - torch.ops.aten.avg_pool3d.default: - CastPolicy.FP32, - 
torch.ops.aten.binary_cross_entropy.default: - CastPolicy.FP32, - torch.ops.aten.grid_sampler.default: - CastPolicy.FP32, - torch.ops.aten.polar.default: - CastPolicy.FP32, - torch.ops.aten.prod.default: - CastPolicy.FP32, - torch.ops.aten.prod.dim_int: - CastPolicy.FP32, - torch.ops.aten.prod.dim_Dimname: - CastPolicy.FP32, - torch.ops.aten.quantile.default: - CastPolicy.FP32, - torch.ops.aten.quantile.scalar: - CastPolicy.FP32, - torch.ops.aten.nanquantile.default: - CastPolicy.FP32, - torch.ops.aten.nanquantile.scalar: - CastPolicy.FP32, - torch.ops.aten.stft.default: - CastPolicy.FP32, - torch.ops.aten.stft.center: - CastPolicy.FP32, - torch.ops.aten.cdist.default: - CastPolicy.FP32, - torch.ops.aten.grid_sampler_2d.default: - CastPolicy.FP32, - torch.ops.aten._grid_sampler_2d_cpu_fallback.default: - CastPolicy.FP32, - torch.ops.aten.grid_sampler_3d.default: - CastPolicy.FP32, - torch.ops.aten.trace.default: - CastPolicy.FP32, - torch.ops.aten.view_as_complex.default: - CastPolicy.FP32, - torch.ops.aten.cholesky.default: - CastPolicy.FP32, - torch.ops.aten.cholesky_inverse.default: - CastPolicy.FP32, - torch.ops.aten.cholesky_solve.default: - CastPolicy.FP32, - torch.ops.aten.inverse.default: - CastPolicy.FP32, - torch.ops.aten.lu_solve.default: - CastPolicy.FP32, - torch.ops.aten.orgqr.default: - CastPolicy.FP32, - torch.ops.aten.ormqr.default: - CastPolicy.FP32, - torch.ops.aten.pinverse.default: - CastPolicy.FP32, - torch.ops.aten.max_pool3d.default: - CastPolicy.FP32, - torch.ops.aten.max_unpool2d.default: - CastPolicy.FP32, - torch.ops.aten.max_unpool3d.default: - CastPolicy.FP32, - torch.ops.aten.adaptive_avg_pool3d.default: - CastPolicy.FP32, - torch.ops.aten.reflection_pad1d.default: - CastPolicy.FP32, - torch.ops.aten.reflection_pad2d.default: - CastPolicy.FP32, - torch.ops.aten.replication_pad1d.default: - CastPolicy.FP32, - torch.ops.aten.replication_pad2d.default: - CastPolicy.FP32, - torch.ops.aten.replication_pad3d.default: - CastPolicy.FP32, - torch.ops.aten.mse_loss.default: - CastPolicy.FP32, - torch.ops.aten.cosine_embedding_loss.default: - CastPolicy.FP32, - torch.ops.aten.nll_loss.default: - CastPolicy.FP32, - torch.ops.aten.nll_loss2d.default: - CastPolicy.FP32, - torch.ops.aten.hinge_embedding_loss.default: - CastPolicy.FP32, - torch.ops.aten.poisson_nll_loss.default: - CastPolicy.FP32, - torch.ops.aten.smooth_l1_loss.default: - CastPolicy.FP32, - torch.ops.aten.cross_entropy_loss.default: - CastPolicy.FP32, - torch.ops.aten.l1_loss.default: - CastPolicy.FP32, - torch.ops.aten.huber_loss.default: - CastPolicy.FP32, - torch.ops.aten.margin_ranking_loss.default: - CastPolicy.FP32, - torch.ops.aten.soft_margin_loss.default: - CastPolicy.FP32, - torch.ops.aten.triplet_margin_loss.default: - CastPolicy.FP32, - torch.ops.aten.multi_margin_loss.default: - CastPolicy.FP32, - torch.ops.aten.ctc_loss.IntList: - CastPolicy.FP32, - torch.ops.aten.ctc_loss.Tensor: - CastPolicy.FP32, - torch.ops.aten.kl_div.default: - CastPolicy.FP32, - torch.ops.aten.multilabel_margin_loss.default: - CastPolicy.FP32, - torch.ops.aten.binary_cross_entropy_with_logits.default: - CastPolicy.FP32, - torch.ops.aten.fft_fft.default: - CastPolicy.FP32, - torch.ops.aten.fft_ifft.default: - CastPolicy.FP32, - torch.ops.aten.fft_fft2.default: - CastPolicy.FP32, - torch.ops.aten.fft_ifft2.default: - CastPolicy.FP32, - torch.ops.aten.fft_fftn.default: - CastPolicy.FP32, - torch.ops.aten.fft_ifftn.default: - CastPolicy.FP32, - torch.ops.aten.fft_rfft.default: - CastPolicy.FP32, - 
torch.ops.aten.fft_irfft.default: - CastPolicy.FP32, - torch.ops.aten.fft_rfft2.default: - CastPolicy.FP32, - torch.ops.aten.fft_irfft2.default: - CastPolicy.FP32, - torch.ops.aten.fft_rfftn.default: - CastPolicy.FP32, - torch.ops.aten.fft_irfftn.default: - CastPolicy.FP32, - torch.ops.aten.fft_hfft.default: - CastPolicy.FP32, - torch.ops.aten.fft_ihfft.default: - CastPolicy.FP32, - torch.ops.aten.linalg_cond.default: - CastPolicy.FP32, - torch.ops.aten.linalg_cond.p_str: - CastPolicy.FP32, - torch.ops.aten.linalg_matrix_rank.default: - CastPolicy.FP32, - torch.ops.aten.linalg_matrix_rank.tol_tensor: - CastPolicy.FP32, - torch.ops.aten.linalg_matrix_rank.atol_rtol_tensor: - CastPolicy.FP32, - torch.ops.aten.linalg_matrix_rank.atol_rtol_float: - CastPolicy.FP32, - torch.ops.aten.linalg_solve.default: - CastPolicy.FP32, - torch.ops.aten.linalg_cholesky.default: - CastPolicy.FP32, - torch.ops.aten.linalg_svdvals.default: - CastPolicy.FP32, - torch.ops.aten.linalg_eigvals.default: - CastPolicy.FP32, - torch.ops.aten.linalg_eigvalsh.default: - CastPolicy.FP32, - torch.ops.aten.linalg_inv.default: - CastPolicy.FP32, - torch.ops.aten.linalg_householder_product.default: - CastPolicy.FP32, - torch.ops.aten.linalg_tensorinv.default: - CastPolicy.FP32, - torch.ops.aten.linalg_tensorsolve.default: - CastPolicy.FP32, - torch.ops.aten.fake_quantize_per_tensor_affine.default: - CastPolicy.FP32, - torch.ops.aten.geqrf.default: - CastPolicy.FP32, - torch.ops.aten._lu_with_info.default: - CastPolicy.FP32, - torch.ops.aten.qr.default: - CastPolicy.FP32, - torch.ops.aten.svd.default: - CastPolicy.FP32, - torch.ops.aten.triangular_solve.default: - CastPolicy.FP32, - torch.ops.aten.fractional_max_pool2d.default: - CastPolicy.FP32, - torch.ops.aten.fractional_max_pool3d.default: - CastPolicy.FP32, - torch.ops.aten.adaptive_max_pool3d.default: - CastPolicy.FP32, - torch.ops.aten.multilabel_margin_loss_forward.default: - CastPolicy.FP32, - torch.ops.aten.linalg_qr.default: - CastPolicy.FP32, - torch.ops.aten.linalg_cholesky_ex.default: - CastPolicy.FP32, - torch.ops.aten.linalg_svd.default: - CastPolicy.FP32, - torch.ops.aten.linalg_eig.default: - CastPolicy.FP32, - torch.ops.aten.linalg_eigh.default: - CastPolicy.FP32, - torch.ops.aten.linalg_lstsq.default: - CastPolicy.FP32, - torch.ops.aten.linalg_inv_ex.default: - CastPolicy.FP32, - - # promote - torch.ops.aten.stack.default: - CastPolicy.PROMOTE, - torch.ops.aten.cat.default: - CastPolicy.PROMOTE, - torch.ops.aten.index_copy.default: - CastPolicy.PROMOTE, - torch.ops.aten.index_copy.dimname: - CastPolicy.PROMOTE, + torch.ops.aten.conv1d.default: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten.conv1d.padding: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten.conv2d.default: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten.conv2d.padding: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten.conv3d.default: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten.conv3d.padding: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten.bmm.default: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten.mm.default: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten.linalg_vecdot.default: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten.baddbmm.default: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten.addmm.default: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten._addmm_activation.default: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten.addbmm.default: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten.linear.default: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten._convolution.deprecated: 
CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten.matmul.default: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten.conv_tbc.default: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten.mkldnn_rnn_layer.default: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten.conv_transpose1d.default: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten.conv_transpose2d.input: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten.conv_transpose3d.input: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten.prelu.default: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten.scaled_dot_product_attention.default: CastPolicy.LOWER_PRECISION_FP, + torch.ops.aten._native_multi_head_attention.default: CastPolicy.LOWER_PRECISION_FP, + # fp32 cast policy + torch.ops.aten.avg_pool3d.default: CastPolicy.FP32, + torch.ops.aten.binary_cross_entropy.default: CastPolicy.FP32, + torch.ops.aten.grid_sampler.default: CastPolicy.FP32, + torch.ops.aten.polar.default: CastPolicy.FP32, + torch.ops.aten.prod.default: CastPolicy.FP32, + torch.ops.aten.prod.dim_int: CastPolicy.FP32, + torch.ops.aten.prod.dim_Dimname: CastPolicy.FP32, + torch.ops.aten.quantile.default: CastPolicy.FP32, + torch.ops.aten.quantile.scalar: CastPolicy.FP32, + torch.ops.aten.nanquantile.default: CastPolicy.FP32, + torch.ops.aten.nanquantile.scalar: CastPolicy.FP32, + torch.ops.aten.stft.default: CastPolicy.FP32, + torch.ops.aten.stft.center: CastPolicy.FP32, + torch.ops.aten.cdist.default: CastPolicy.FP32, + torch.ops.aten.grid_sampler_2d.default: CastPolicy.FP32, + torch.ops.aten._grid_sampler_2d_cpu_fallback.default: CastPolicy.FP32, + torch.ops.aten.grid_sampler_3d.default: CastPolicy.FP32, + torch.ops.aten.trace.default: CastPolicy.FP32, + torch.ops.aten.view_as_complex.default: CastPolicy.FP32, + torch.ops.aten.cholesky.default: CastPolicy.FP32, + torch.ops.aten.cholesky_inverse.default: CastPolicy.FP32, + torch.ops.aten.cholesky_solve.default: CastPolicy.FP32, + torch.ops.aten.inverse.default: CastPolicy.FP32, + torch.ops.aten.lu_solve.default: CastPolicy.FP32, + torch.ops.aten.orgqr.default: CastPolicy.FP32, + torch.ops.aten.ormqr.default: CastPolicy.FP32, + torch.ops.aten.pinverse.default: CastPolicy.FP32, + torch.ops.aten.max_pool3d.default: CastPolicy.FP32, + torch.ops.aten.max_unpool2d.default: CastPolicy.FP32, + torch.ops.aten.max_unpool3d.default: CastPolicy.FP32, + torch.ops.aten.adaptive_avg_pool3d.default: CastPolicy.FP32, + torch.ops.aten.reflection_pad1d.default: CastPolicy.FP32, + torch.ops.aten.reflection_pad2d.default: CastPolicy.FP32, + torch.ops.aten.replication_pad1d.default: CastPolicy.FP32, + torch.ops.aten.replication_pad2d.default: CastPolicy.FP32, + torch.ops.aten.replication_pad3d.default: CastPolicy.FP32, + torch.ops.aten.mse_loss.default: CastPolicy.FP32, + torch.ops.aten.cosine_embedding_loss.default: CastPolicy.FP32, + torch.ops.aten.nll_loss.default: CastPolicy.FP32, + torch.ops.aten.nll_loss2d.default: CastPolicy.FP32, + torch.ops.aten.hinge_embedding_loss.default: CastPolicy.FP32, + torch.ops.aten.poisson_nll_loss.default: CastPolicy.FP32, + torch.ops.aten.smooth_l1_loss.default: CastPolicy.FP32, + torch.ops.aten.cross_entropy_loss.default: CastPolicy.FP32, + torch.ops.aten.l1_loss.default: CastPolicy.FP32, + torch.ops.aten.huber_loss.default: CastPolicy.FP32, + torch.ops.aten.margin_ranking_loss.default: CastPolicy.FP32, + torch.ops.aten.soft_margin_loss.default: CastPolicy.FP32, + torch.ops.aten.triplet_margin_loss.default: CastPolicy.FP32, + torch.ops.aten.multi_margin_loss.default: CastPolicy.FP32, + torch.ops.aten.ctc_loss.IntList: 
CastPolicy.FP32, + torch.ops.aten.ctc_loss.Tensor: CastPolicy.FP32, + torch.ops.aten.kl_div.default: CastPolicy.FP32, + torch.ops.aten.multilabel_margin_loss.default: CastPolicy.FP32, + torch.ops.aten.binary_cross_entropy_with_logits.default: CastPolicy.FP32, + torch.ops.aten.fft_fft.default: CastPolicy.FP32, + torch.ops.aten.fft_ifft.default: CastPolicy.FP32, + torch.ops.aten.fft_fft2.default: CastPolicy.FP32, + torch.ops.aten.fft_ifft2.default: CastPolicy.FP32, + torch.ops.aten.fft_fftn.default: CastPolicy.FP32, + torch.ops.aten.fft_ifftn.default: CastPolicy.FP32, + torch.ops.aten.fft_rfft.default: CastPolicy.FP32, + torch.ops.aten.fft_irfft.default: CastPolicy.FP32, + torch.ops.aten.fft_rfft2.default: CastPolicy.FP32, + torch.ops.aten.fft_irfft2.default: CastPolicy.FP32, + torch.ops.aten.fft_rfftn.default: CastPolicy.FP32, + torch.ops.aten.fft_irfftn.default: CastPolicy.FP32, + torch.ops.aten.fft_hfft.default: CastPolicy.FP32, + torch.ops.aten.fft_ihfft.default: CastPolicy.FP32, + torch.ops.aten.linalg_cond.default: CastPolicy.FP32, + torch.ops.aten.linalg_cond.p_str: CastPolicy.FP32, + torch.ops.aten.linalg_matrix_rank.default: CastPolicy.FP32, + torch.ops.aten.linalg_matrix_rank.tol_tensor: CastPolicy.FP32, + torch.ops.aten.linalg_matrix_rank.atol_rtol_tensor: CastPolicy.FP32, + torch.ops.aten.linalg_matrix_rank.atol_rtol_float: CastPolicy.FP32, + torch.ops.aten.linalg_solve.default: CastPolicy.FP32, + torch.ops.aten.linalg_cholesky.default: CastPolicy.FP32, + torch.ops.aten.linalg_svdvals.default: CastPolicy.FP32, + torch.ops.aten.linalg_eigvals.default: CastPolicy.FP32, + torch.ops.aten.linalg_eigvalsh.default: CastPolicy.FP32, + torch.ops.aten.linalg_inv.default: CastPolicy.FP32, + torch.ops.aten.linalg_householder_product.default: CastPolicy.FP32, + torch.ops.aten.linalg_tensorinv.default: CastPolicy.FP32, + torch.ops.aten.linalg_tensorsolve.default: CastPolicy.FP32, + torch.ops.aten.fake_quantize_per_tensor_affine.default: CastPolicy.FP32, + torch.ops.aten.geqrf.default: CastPolicy.FP32, + torch.ops.aten._lu_with_info.default: CastPolicy.FP32, + torch.ops.aten.qr.default: CastPolicy.FP32, + torch.ops.aten.svd.default: CastPolicy.FP32, + torch.ops.aten.triangular_solve.default: CastPolicy.FP32, + torch.ops.aten.fractional_max_pool2d.default: CastPolicy.FP32, + torch.ops.aten.fractional_max_pool3d.default: CastPolicy.FP32, + torch.ops.aten.adaptive_max_pool3d.default: CastPolicy.FP32, + torch.ops.aten.multilabel_margin_loss_forward.default: CastPolicy.FP32, + torch.ops.aten.linalg_qr.default: CastPolicy.FP32, + torch.ops.aten.linalg_cholesky_ex.default: CastPolicy.FP32, + torch.ops.aten.linalg_svd.default: CastPolicy.FP32, + torch.ops.aten.linalg_eig.default: CastPolicy.FP32, + torch.ops.aten.linalg_eigh.default: CastPolicy.FP32, + torch.ops.aten.linalg_lstsq.default: CastPolicy.FP32, + torch.ops.aten.linalg_inv_ex.default: CastPolicy.FP32, + # promote + torch.ops.aten.stack.default: CastPolicy.PROMOTE, + torch.ops.aten.cat.default: CastPolicy.PROMOTE, + torch.ops.aten.index_copy.default: CastPolicy.PROMOTE, + torch.ops.aten.index_copy.dimname: CastPolicy.PROMOTE, } diff --git a/torchax/torchax/decompositions.py b/torchax/torchax/decompositions.py index 81cbcd02e3a..f6d0a2891b7 100644 --- a/torchax/torchax/decompositions.py +++ b/torchax/torchax/decompositions.py @@ -36,15 +36,14 @@ def _try_register(op, impl): @out_wrapper() def _reflection_pad(a: Tensor, padding: Tuple[int, ...]) -> Tensor: - def idx(left, middle, right): dim_idx = torch.arange(-left, middle + right, 
device=a.device) return middle - 1 - (middle - 1 - dim_idx.abs()).abs() return _reflection_or_replication_pad( - a, - padding, - idx, + a, + padding, + idx, ) @@ -55,31 +54,31 @@ def idx(left, middle, right): @out_wrapper() def _replication_pad(a: Tensor, padding: Tuple[int, ...]) -> Tensor: - def idx(left, middle, right): dim_idx = torch.arange(-left, middle + right, device=a.device) return torch.clamp(dim_idx, 0, middle - 1) return _reflection_or_replication_pad( - a, - padding, - idx, + a, + padding, + idx, ) decomp.global_decomposition_table["post_autograd"][ - aten.replication_pad2d.default] = _replication_pad + aten.replication_pad2d.default +] = _replication_pad def _reflection_or_replication_pad( - a: Tensor, - padding: Tuple[int, ...], - idx_fn: Callable[[int, int, int], Tensor], + a: Tensor, + padding: Tuple[int, ...], + idx_fn: Callable[[int, int, int], Tensor], ) -> Tensor: dim = len(padding) // 2 torch._check( - a.dim() in (dim + 1, dim + 2), - lambda: f"reflection_pad{dim}d requires {dim + 1}D or {dim + 2}D input", + a.dim() in (dim + 1, dim + 2), + lambda: f"reflection_pad{dim}d requires {dim + 1}D or {dim + 2}D input", ) inp_shape = a.shape[-dim:] nc_dim = a.dim() - dim @@ -144,11 +143,11 @@ def _sum_tensors(ts) -> Tensor: @register_decomposition(aten.grid_sampler_3d) def _grid_sampler_3d( - a: torch.Tensor, - grid: torch.Tensor, - interpolation_mode: int = 0, - padding_mode: int = 0, - align_corners: bool = False, + a: torch.Tensor, + grid: torch.Tensor, + interpolation_mode: int = 0, + padding_mode: int = 0, + align_corners: bool = False, ) -> Tensor: """References: https://github.com/pytorch/pytorch/blob/06a7dc21c1005750598c37f3adbc031183c74de6/torch/_decomp/decompositions.py#L4075 @@ -156,11 +155,13 @@ def _grid_sampler_3d( """ _expand_grid = False torch._check( - interpolation_mode in (0, 1), - lambda: f"Invalid interpolation mode {interpolation_mode}", + interpolation_mode in (0, 1), + lambda: f"Invalid interpolation mode {interpolation_mode}", ) torch._check( - padding_mode in (0, 1, 2), lambda: f"Invalid padding mode {padding_mode}") + padding_mode in (0, 1, 2), + lambda: f"Invalid padding mode {padding_mode}", + ) # a is 5D: [B, C, D, H, W] @@ -175,8 +176,9 @@ def unnormalize(coords: Tensor, size: int) -> Tensor: # Reflects coordinates until they fall between low and high (inclusive). # The bounds are passed as twice their value so that half-integer values # can be represented as ints. 
- def reflect_coordinates(coords: Tensor, twice_low: int, - twice_high: int) -> Tensor: + def reflect_coordinates( + coords: Tensor, twice_low: int, twice_high: int + ) -> Tensor: if twice_low == twice_high: return torch.zeros_like(coords) coords_min = twice_low / 2 @@ -184,8 +186,9 @@ def reflect_coordinates(coords: Tensor, twice_low: int, coords2 = (coords - coords_min).abs() extra = torch.fmod(coords2, coords_span) flips = (coords2 / coords_span).floor().to(dtype=torch.int8) - return torch.where(flips & 1 == 0, extra + coords_min, - coords_span + coords_min - extra) + return torch.where( + flips & 1 == 0, extra + coords_min, coords_span + coords_min - extra + ) def compute_coordinates(coords: Tensor, size: int) -> Tensor: if padding_mode == 0: # Zero @@ -224,15 +227,18 @@ def clip(xs: torch.Tensor, ys: torch.Tensor, zs, ws: torch.Tensor): # broadcasting with N_idx, C_idx for the purposes of advanced indexing c = C if _expand_grid else 1 return tuple( - torch.where(cond, t, 0).view(N, c, oD, oH, oW) for t in ( - xs.to(dtype=torch.int64), - ys.to(dtype=torch.int64), - zs.to(dtype=torch.int64), - ws, - )) - - def get_summand(ix: torch.Tensor, iy: torch.Tensor, iz: torch.Tensor, - w) -> Tensor: + torch.where(cond, t, 0).view(N, c, oD, oH, oW) + for t in ( + xs.to(dtype=torch.int64), + ys.to(dtype=torch.int64), + zs.to(dtype=torch.int64), + ws, + ) + ) + + def get_summand( + ix: torch.Tensor, iy: torch.Tensor, iz: torch.Tensor, w + ) -> Tensor: # Perform clipping, index into input tensor and multiply by weight idx_x, idx_y, idx_z, w_ = clip(ix, iy, iz, w) return a[N_idx, C_idx, idx_z, idx_y, idx_x] * w_ @@ -265,16 +271,18 @@ def get_summand(ix: torch.Tensor, iy: torch.Tensor, iz: torch.Tensor, w_seb = (ix - ix_nwf) * (iy - iy_nwf) * (id_ - id_nwf) return _sum_tensors( - get_summand(ix, iy, id_, w) for (ix, iy, id_, w) in ( - (ix_nwf, iy_nwf, id_nwf, w_nwf), - (ix_nef, iy_nef, id_nef, w_nef), - (ix_swf, iy_swf, id_swf, w_swf), - (ix_sef, iy_sef, id_sef, w_sef), - (ix_nwb, iy_nwb, id_nwb, w_nwb), - (ix_neb, iy_neb, id_neb, w_neb), - (ix_swb, iy_swb, id_swb, w_swb), - (ix_seb, iy_seb, id_seb, w_seb), - )) + get_summand(ix, iy, id_, w) + for (ix, iy, id_, w) in ( + (ix_nwf, iy_nwf, id_nwf, w_nwf), + (ix_nef, iy_nef, id_nef, w_nef), + (ix_swf, iy_swf, id_swf, w_swf), + (ix_sef, iy_sef, id_sef, w_sef), + (ix_nwb, iy_nwb, id_nwb, w_nwb), + (ix_neb, iy_neb, id_neb, w_neb), + (ix_swb, iy_swb, id_swb, w_swb), + (ix_seb, iy_seb, id_seb, w_seb), + ) + ) else: # interpolation_mode == 1: # Nearest ix = compute_source_index(x, iW) iy = compute_source_index(y, iH) @@ -288,482 +296,482 @@ def get_summand(ix: torch.Tensor, iy: torch.Tensor, iz: torch.Tensor, DECOMPOSITIONS = decomp.get_decompositions([ - torch.ops.aten.upsample_bicubic2d, - torch.ops.aten.upsample_nearest1d, - torch.ops.aten.upsample_nearest2d, - torch.ops.aten.upsample_nearest3d, - torch.ops.aten._upsample_nearest_exact1d, - torch.ops.aten._upsample_nearest_exact2d, - torch.ops.aten._upsample_nearest_exact3d, - torch.ops.aten._native_batch_norm_legit.no_stats, - torch.ops.aten._native_batch_norm_legit_functional.default, - torch.ops.aten._adaptive_avg_pool2d, - torch.ops.aten._adaptive_avg_pool3d, - torch.ops.aten.grid_sampler_2d, - torch.ops.aten.grid_sampler_3d, - torch.ops.aten.native_dropout, - torch.ops.aten.reflection_pad1d, - torch.ops.aten.reflection_pad2d, - torch.ops.aten.reflection_pad3d, - torch.ops.aten.replication_pad1d, - torch.ops.aten.replication_pad2d, - torch.ops.aten.replication_pad3d, - torch.ops.aten.bernoulli, - 
torch.ops.aten.rand_like, - torch.ops.aten._batch_norm_with_update, - torch.ops.aten.channel_shuffle, - torch.ops.aten.nll_loss2d_forward, - torch.ops.aten.nll_loss2d_backward, - torch.ops.aten.bernoulli_.Tensor, - torch.ops.aten.bernoulli_.float, - torch.ops.aten.log_normal, - torch.ops.aten.addcdiv.default, - torch.ops.aten.addcdiv.out, - torch.ops.aten.addcdiv_.default, - torch.ops.aten.addcmul.default, - torch.ops.aten.addcmul.out, - torch.ops.aten.addcmul_.default, - torch.ops.aten.addr.default, - torch.ops.aten.addr.out, - torch.ops.aten.affine_grid_generator.default, - torch.ops.aten.affine_grid_generator.out, - torch.ops.aten.alias_copy.default, - torch.ops.aten.alias_copy.out, - torch.ops.aten.all.default, - torch.ops.aten.all.dim, - torch.ops.aten.all.dims, - torch.ops.aten.all.out, - torch.ops.aten.all.dims_out, - torch.ops.aten.all.all_out, - torch.ops.aten.all.dimname, - torch.ops.aten.all.dimname_out, - torch.ops.aten.aminmax.default, - torch.ops.aten.aminmax.out, - torch.ops.aten.arange.default, - torch.ops.aten.arange.start, - torch.ops.aten.baddbmm.default, - torch.ops.aten.baddbmm.out, - torch.ops.aten.binary_cross_entropy.default, - torch.ops.aten.binary_cross_entropy.out, - torch.ops.aten.binary_cross_entropy_backward.default, - torch.ops.aten.binary_cross_entropy_backward.grad_input, - torch.ops.aten.binary_cross_entropy_with_logits.default, - torch.ops.aten.binary_cross_entropy_with_logits.out, - torch.ops.aten.block_diag.default, - torch.ops.aten.block_diag.out, - torch.ops.aten.celu.default, - torch.ops.aten.celu.out, - torch.ops.aten.celu_.default, - torch.ops.aten.channel_shuffle.default, - torch.ops.aten.channel_shuffle.out, - torch.ops.aten.clamp_max.default, - torch.ops.aten.clamp_max.Tensor, - torch.ops.aten.clamp_max.out, - torch.ops.aten.clamp_max.Tensor_out, - torch.ops.aten.clamp_min.default, - torch.ops.aten.clamp_min.Tensor, - torch.ops.aten.clamp_min.out, - torch.ops.aten.clamp_min.Tensor_out, - torch.ops.aten.col2im.default, - torch.ops.aten.col2im.out, - torch.ops.aten.count_nonzero.dim_IntList, - torch.ops.aten.count_nonzero.dim_IntList_out, - torch.ops.aten.count_nonzero.default, - torch.ops.aten.count_nonzero.out, - torch.ops.aten.linalg_cross.default, - torch.ops.aten.linalg_cross.out, - torch.ops.aten.cudnn_batch_norm.default, - torch.ops.aten.cudnn_batch_norm.out, - torch.ops.aten.cudnn_batch_norm_backward.default, - torch.ops.aten.cudnn_batch_norm_backward.out, - torch.ops.aten.miopen_batch_norm_backward.default, - torch.ops.aten.miopen_batch_norm_backward.out, - torch.ops.aten.deg2rad.default, - torch.ops.aten.deg2rad.out, - torch.ops.aten.deg2rad_.default, - torch.ops.aten.detach.default, - torch.ops.aten.diag_embed.default, - torch.ops.aten.diag_embed.out, - torch.ops.aten.diagonal_backward.default, - torch.ops.aten.diagonal_backward.out, - torch.ops.aten.dot.default, - torch.ops.aten.dot.out, - torch.ops.aten.vdot.default, - torch.ops.aten.vdot.out, - torch.ops.aten.elu.default, - torch.ops.aten.elu.out, - torch.ops.aten.elu_.default, - torch.ops.aten.elu_backward.default, - torch.ops.aten.elu_backward.grad_input, - torch.ops.aten.embedding_dense_backward.default, - torch.ops.aten.embedding_dense_backward.out, - torch.ops.aten.empty_like.default, - torch.ops.aten.empty_like.out, - torch.ops.aten._euclidean_dist.default, - torch.ops.aten.expand_copy.default, - torch.ops.aten.expand_copy.out, - torch.ops.aten.eye.default, - torch.ops.aten.eye.m, - torch.ops.aten.eye.out, - torch.ops.aten.eye.m_out, - torch.ops.aten.fill.Scalar, - 
torch.ops.aten.fill.Tensor, - torch.ops.aten.fill_.Scalar, - torch.ops.aten.fill_.Tensor, - torch.ops.aten.floor_divide.default, - torch.ops.aten.floor_divide.Scalar, - torch.ops.aten.floor_divide.out, - torch.ops.aten.floor_divide.Scalar_out, - torch.ops.aten.frac.default, - torch.ops.aten.frac.out, - torch.ops.aten.frac_.default, - torch.ops.aten.gelu_.default, - torch.ops.aten.gelu_backward.default, - torch.ops.aten.gelu_backward.grad_input, - torch.ops.aten.glu.default, - torch.ops.aten.glu.out, - torch.ops.aten.glu_backward.default, - torch.ops.aten.glu_backward.grad_input, - torch.ops.aten.hardshrink.default, - torch.ops.aten.hardshrink.out, - torch.ops.aten.hardsigmoid.default, - torch.ops.aten.hardsigmoid.out, - torch.ops.aten.hardsigmoid_.default, - torch.ops.aten.hardsigmoid_backward.default, - torch.ops.aten.hardsigmoid_backward.grad_input, - torch.ops.aten.hardswish.default, - torch.ops.aten.hardswish.out, - torch.ops.aten.hardswish_.default, - torch.ops.aten.hardswish_backward.default, - torch.ops.aten.hardswish_backward.out, - torch.ops.aten.hardtanh_.default, - torch.ops.aten.hardtanh_backward.default, - torch.ops.aten.hardtanh_backward.grad_input, - torch.ops.aten.heaviside.default, - torch.ops.aten.heaviside.out, - torch.ops.aten.heaviside_.default, - torch.ops.aten.huber_loss.default, - torch.ops.aten.huber_loss.out, - torch.ops.aten.huber_loss_backward.default, - torch.ops.aten.huber_loss_backward.out, - torch.ops.aten.im2col.default, - torch.ops.aten.im2col.out, - torch.ops.aten.index_add.default, - torch.ops.aten.index_add.out, - torch.ops.aten.index_add.dimname, - torch.ops.aten.index_add_.default, - torch.ops.aten.index_copy.default, - torch.ops.aten.index_copy.dimname, - torch.ops.aten.index_copy.out, - torch.ops.aten.index_copy_.default, - torch.ops.aten.index_copy_.dimname, - torch.ops.aten.index_fill.int_Tensor, - torch.ops.aten.index_fill.int_Scalar, - torch.ops.aten.index_fill.Dimname_Scalar, - torch.ops.aten.index_fill.Dimname_Tensor, - torch.ops.aten.index_fill.int_Scalar_out, - torch.ops.aten.index_fill.int_Tensor_out, - torch.ops.aten.index_fill_.int_Tensor, - torch.ops.aten.index_fill_.int_Scalar, - torch.ops.aten.index_fill_.Dimname_Scalar, - torch.ops.aten.index_fill_.Dimname_Tensor, - torch.ops.aten.isin.Tensor_Tensor, - torch.ops.aten.isin.Tensor_Tensor_out, - torch.ops.aten.isin.Tensor_Scalar, - torch.ops.aten.isin.Tensor_Scalar_out, - torch.ops.aten.isin.Scalar_Tensor, - torch.ops.aten.isin.Scalar_Tensor_out, - torch.ops.aten.isneginf.default, - torch.ops.aten.isneginf.out, - torch.ops.aten.isposinf.default, - torch.ops.aten.isposinf.out, - torch.ops.aten.leaky_relu_.default, - torch.ops.aten.leaky_relu_backward.default, - torch.ops.aten.leaky_relu_backward.grad_input, - torch.ops.aten.lerp.Scalar, - torch.ops.aten.lerp.Tensor, - torch.ops.aten.lerp.Scalar_out, - torch.ops.aten.lerp.Tensor_out, - torch.ops.aten.lerp_.Scalar, - torch.ops.aten.lerp_.Tensor, - torch.ops.aten.linspace.Tensor_Tensor, - torch.ops.aten.linspace.Tensor_Scalar, - torch.ops.aten.linspace.Scalar_Tensor, - torch.ops.aten.linspace.default, - torch.ops.aten.linspace.out, - torch.ops.aten.linspace.Tensor_Tensor_out, - torch.ops.aten.linspace.Tensor_Scalar_out, - torch.ops.aten.linspace.Scalar_Tensor_out, - torch.ops.aten.logaddexp.default, - torch.ops.aten.logaddexp.out, - torch.ops.aten.logaddexp2.default, - torch.ops.aten.logaddexp2.out, - torch.ops.aten.logit.default, - torch.ops.aten.logit.out, - torch.ops.aten.logit_.default, - torch.ops.aten.logit_backward.default, - 
torch.ops.aten.log_sigmoid_backward.default, - torch.ops.aten.log_sigmoid_backward.grad_input, - torch.ops.aten.log_sigmoid_forward.default, - torch.ops.aten.log_sigmoid_forward.output, - torch.ops.aten._log_softmax_backward_data.default, - torch.ops.aten._log_softmax_backward_data.out, - torch.ops.aten.logspace.Tensor_Tensor, - torch.ops.aten.logspace.Tensor_Scalar, - torch.ops.aten.logspace.Scalar_Tensor, - torch.ops.aten.logspace.default, - torch.ops.aten.logspace.out, - torch.ops.aten.logspace.Tensor_Tensor_out, - torch.ops.aten.logspace.Tensor_Scalar_out, - torch.ops.aten.logspace.Scalar_Tensor_out, - torch.ops.aten.logsumexp.default, - torch.ops.aten.masked_fill.Scalar, - torch.ops.aten.masked_fill.Tensor, - torch.ops.aten.masked_fill.Scalar_out, - torch.ops.aten.masked_fill.Tensor_out, - torch.ops.aten.masked_fill_.Scalar, - torch.ops.aten.masked_fill_.Tensor, - torch.ops.aten.mish.default, - torch.ops.aten.mish.out, - torch.ops.aten.mish_.default, - torch.ops.aten.mse_loss.default, - torch.ops.aten.mse_loss.out, - torch.ops.aten.mse_loss_backward.default, - torch.ops.aten.mse_loss_backward.grad_input, - torch.ops.aten.multi_margin_loss.default, - torch.ops.aten.multi_margin_loss.out, - torch.ops.aten.multilabel_margin_loss_forward.default, - torch.ops.aten.multilabel_margin_loss_forward.output, - torch.ops.aten.mv.default, - torch.ops.aten.mv.out, - torch.ops.aten.mvlgamma.default, - torch.ops.aten.mvlgamma.out, - torch.ops.aten.mvlgamma_.default, - torch.ops.aten.nansum.default, - torch.ops.aten.nansum.out, - torch.ops.aten.nan_to_num.default, - torch.ops.aten.nan_to_num.out, - torch.ops.aten.nan_to_num_.default, - torch.ops.aten.native_batch_norm_backward.default, - torch.ops.aten.native_batch_norm_backward.out, - torch.ops.aten.native_dropout_backward.default, - torch.ops.aten.native_dropout_backward.out, - torch.ops.aten.native_group_norm_backward.default, - torch.ops.aten.native_group_norm_backward.out, - torch.ops.aten.native_layer_norm_backward.default, - torch.ops.aten.native_layer_norm_backward.out, - torch.ops.aten.new_empty.default, - torch.ops.aten.new_empty.out, - torch.ops.aten.new_full.default, - torch.ops.aten.new_full.out, - torch.ops.aten.new_ones.default, - torch.ops.aten.new_ones.out, - torch.ops.aten.new_zeros.default, - torch.ops.aten.new_zeros.out, - torch.ops.aten.nll_loss2d_forward.default, - torch.ops.aten.nll_loss2d_forward.output, - torch.ops.aten.nll_loss2d_backward.default, - torch.ops.aten.nll_loss2d_backward.grad_input, - torch.ops.aten.nll_loss_backward.default, - torch.ops.aten.nll_loss_backward.grad_input, - torch.ops.aten.nll_loss_forward.default, - torch.ops.aten.nll_loss_forward.output, - torch.ops.aten.norm.Scalar, - torch.ops.aten.norm.ScalarOpt_dim, - torch.ops.aten.norm.names_ScalarOpt_dim, - torch.ops.aten.norm.ScalarOpt_dim_dtype, - torch.ops.aten.norm.dtype_out, - torch.ops.aten.norm.out, - torch.ops.aten.norm.ScalarOpt_dtype, - torch.ops.aten.norm.ScalarOpt_dtype_out, - torch.ops.aten.norm.Scalar_out, - torch.ops.aten.norm.names_ScalarOpt_dim_dtype, - torch.ops.aten.norm.names_dtype_out, - torch.ops.aten.norm.names_out, - torch.ops.aten.ones.default, - torch.ops.aten.ones_like.default, - torch.ops.aten.ones_like.out, - torch.ops.aten.pixel_shuffle.default, - torch.ops.aten.pixel_shuffle.out, - torch.ops.aten.pixel_unshuffle.default, - torch.ops.aten.pixel_unshuffle.out, - torch.ops.aten._prelu_kernel.default, - torch.ops.aten._prelu_kernel_backward.default, - torch.ops.aten._reshape_alias.default, - torch.ops.aten.rad2deg.default, - 
torch.ops.aten.rad2deg.out, - torch.ops.aten.rad2deg_.default, - torch.ops.aten.reflection_pad1d.default, - torch.ops.aten.reflection_pad1d.out, - torch.ops.aten.reflection_pad1d_backward.default, - torch.ops.aten.reflection_pad1d_backward.grad_input, - torch.ops.aten.reflection_pad2d.default, - torch.ops.aten.reflection_pad2d.out, - torch.ops.aten.reflection_pad2d_backward.default, - torch.ops.aten.reflection_pad2d_backward.grad_input, - torch.ops.aten.reflection_pad3d.default, - torch.ops.aten.reflection_pad3d.out, - torch.ops.aten.reflection_pad3d_backward.default, - torch.ops.aten.reflection_pad3d_backward.grad_input, - torch.ops.aten.replication_pad1d.default, - torch.ops.aten.replication_pad1d.out, - torch.ops.aten.replication_pad2d.default, - torch.ops.aten.replication_pad2d.out, - torch.ops.aten.replication_pad3d.default, - torch.ops.aten.replication_pad3d.out, - torch.ops.aten.renorm.default, - torch.ops.aten.renorm.out, - torch.ops.aten.renorm_.default, - torch.ops.aten.resize_as.default, - torch.ops.aten.resize_as.out, - torch.ops.aten.roll.default, - torch.ops.aten.roll.out, - torch.ops.aten.rot90.default, - torch.ops.aten.rot90.out, - torch.ops.aten.rrelu_with_noise.default, - torch.ops.aten.rrelu_with_noise.out, - torch.ops.aten.rrelu_with_noise_.default, - torch.ops.aten.rsub.Tensor, - torch.ops.aten.rsub.Scalar, - torch.ops.aten.rsub.Tensor_out, - torch.ops.aten.rsub.Scalar_out, - torch.ops.aten._safe_softmax.default, - torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default, - torch.ops.aten.select_backward.default, - torch.ops.aten.select_backward.out, - torch.ops.aten.select_scatter.default, - torch.ops.aten.select_scatter.out, - torch.ops.aten.sgn.default, - torch.ops.aten.sgn.out, - torch.ops.aten.sgn_.default, - torch.ops.aten.sigmoid_backward.default, - torch.ops.aten.sigmoid_backward.grad_input, - torch.ops.aten.silu.default, - torch.ops.aten.silu.out, - torch.ops.aten.silu_.default, - torch.ops.aten.silu_backward.default, - torch.ops.aten.silu_backward.grad_input, - torch.ops.aten.sinc.default, - torch.ops.aten.sinc.out, - torch.ops.aten.sinc_.default, - torch.ops.aten.slice_backward.default, - torch.ops.aten.slice_backward.out, - torch.ops.aten.smooth_l1_loss.default, - torch.ops.aten.smooth_l1_loss.out, - torch.ops.aten.smooth_l1_loss_backward.default, - torch.ops.aten.smooth_l1_loss_backward.grad_input, - torch.ops.aten.soft_margin_loss.default, - torch.ops.aten.soft_margin_loss.out, - torch.ops.aten.soft_margin_loss_backward.default, - torch.ops.aten.soft_margin_loss_backward.grad_input, - torch.ops.aten._softmax_backward_data.default, - torch.ops.aten._softmax_backward_data.out, - torch.ops.aten.softplus.default, - torch.ops.aten.softplus.out, - torch.ops.aten.softplus_backward.default, - torch.ops.aten.softplus_backward.grad_input, - torch.ops.aten.softshrink.default, - torch.ops.aten.softshrink.out, - torch.ops.aten.special_entr.default, - torch.ops.aten.special_entr.out, - torch.ops.aten.special_log_ndtr.default, - torch.ops.aten.special_log_ndtr.out, - torch.ops.aten.special_xlog1py.default, - torch.ops.aten.special_xlog1py.other_scalar, - torch.ops.aten.special_xlog1py.self_scalar, - torch.ops.aten.special_xlog1py.out, - torch.ops.aten.special_xlog1py.self_scalar_out, - torch.ops.aten.special_xlog1py.other_scalar_out, - torch.ops.aten.split.Tensor, - torch.ops.aten.split_with_sizes_copy.default, - torch.ops.aten.split_with_sizes_copy.out, - torch.ops.aten.squeeze.default, - torch.ops.aten.squeeze.dim, - torch.ops.aten.std.default, - 
torch.ops.aten.std.dim, - torch.ops.aten.std.correction, - torch.ops.aten.std.names_dim, - torch.ops.aten.std.names_out, - torch.ops.aten.std.out, - torch.ops.aten.std.correction_out, - torch.ops.aten.std.correction_names, - torch.ops.aten.std.correction_names_out, - torch.ops.aten.std_mean.default, - torch.ops.aten.std_mean.dim, - torch.ops.aten.std_mean.correction, - torch.ops.aten.std_mean.names_dim, - torch.ops.aten.std_mean.correction_names, - torch.ops.aten.std_mean.correction_out, - torch.ops.aten.stack.default, - torch.ops.aten.stack.out, - torch.ops.aten.sum.default, - torch.ops.aten.sum.out, - torch.ops.aten.t.default, - torch.ops.aten.t_copy.out, - torch.ops.aten.t_copy.default, - torch.ops.aten.take.default, - torch.ops.aten.take.out, - torch.ops.aten.tanh_backward.default, - torch.ops.aten.tanh_backward.grad_input, - torch.ops.aten.threshold.default, - torch.ops.aten.threshold.out, - torch.ops.aten.threshold_.default, - torch.ops.aten.threshold_backward.default, - torch.ops.aten.threshold_backward.grad_input, - torch.ops.aten.trace.default, - torch.ops.aten.trace.out, - torch.ops.aten.transpose.int, - torch.ops.aten.tril.default, - torch.ops.aten.tril.out, - torch.ops.aten.tril_.default, - torch.ops.aten.triu.default, - torch.ops.aten.triu.out, - torch.ops.aten.triu_.default, - torch.ops.aten.unbind.int, - torch.ops.aten.unbind.Dimname, - torch.ops.aten.unfold_backward.default, - torch.ops.aten.unfold_backward.out, - torch.ops.aten.unfold_copy.default, - torch.ops.aten.unfold_copy.out, - torch.ops.aten._unsafe_index.Tensor, - torch.ops.aten._unsafe_index_put.default, - torch.ops.aten._unsafe_masked_index.default, - torch.ops.aten._unsafe_masked_index_put_accumulate.default, - torch.ops.aten.unsafe_split.Tensor, - torch.ops.aten.unsafe_split_with_sizes.default, - torch.ops.aten.unsqueeze_copy.out, - torch.ops.aten.unsqueeze_copy.default, - torch.ops.aten._unsafe_view.default, - torch.ops.aten._unsafe_view.out, - torch.ops.aten.upsample_linear1d.default, - torch.ops.aten.upsample_linear1d.out, - torch.ops.aten.upsample_bilinear2d.vec, - torch.ops.aten.upsample_bilinear2d.default, - torch.ops.aten.upsample_bilinear2d.out, - torch.ops.aten.upsample_trilinear3d.vec, - torch.ops.aten.upsample_trilinear3d.default, - torch.ops.aten.upsample_trilinear3d.out, - torch.ops.aten.xlogy.Tensor, - torch.ops.aten.xlogy.Scalar_Other, - torch.ops.aten.xlogy.Scalar_Self, - torch.ops.aten.xlogy.OutTensor, - torch.ops.aten.xlogy.OutScalar_Self, - torch.ops.aten.xlogy.OutScalar_Other, - torch.ops.aten.xlogy_.Tensor, - torch.ops.aten.xlogy_.Scalar_Other, - torch.ops.aten.zero.default, - torch.ops.aten.zero.out, - torch.ops.aten.zero_.default, - torch.ops.aten.zeros.default, - torch.ops.aten.zeros_like.default, - torch.ops.aten.zeros_like.out, - torch.ops.aten._chunk_cat.default, - torch.ops.aten._chunk_cat.out, - torch.ops.aten._weight_norm_interface.default, - torch.ops.aten._weight_norm_interface.out, + torch.ops.aten.upsample_bicubic2d, + torch.ops.aten.upsample_nearest1d, + torch.ops.aten.upsample_nearest2d, + torch.ops.aten.upsample_nearest3d, + torch.ops.aten._upsample_nearest_exact1d, + torch.ops.aten._upsample_nearest_exact2d, + torch.ops.aten._upsample_nearest_exact3d, + torch.ops.aten._native_batch_norm_legit.no_stats, + torch.ops.aten._native_batch_norm_legit_functional.default, + torch.ops.aten._adaptive_avg_pool2d, + torch.ops.aten._adaptive_avg_pool3d, + torch.ops.aten.grid_sampler_2d, + torch.ops.aten.grid_sampler_3d, + torch.ops.aten.native_dropout, + torch.ops.aten.reflection_pad1d, 
+ torch.ops.aten.reflection_pad2d, + torch.ops.aten.reflection_pad3d, + torch.ops.aten.replication_pad1d, + torch.ops.aten.replication_pad2d, + torch.ops.aten.replication_pad3d, + torch.ops.aten.bernoulli, + torch.ops.aten.rand_like, + torch.ops.aten._batch_norm_with_update, + torch.ops.aten.channel_shuffle, + torch.ops.aten.nll_loss2d_forward, + torch.ops.aten.nll_loss2d_backward, + torch.ops.aten.bernoulli_.Tensor, + torch.ops.aten.bernoulli_.float, + torch.ops.aten.log_normal, + torch.ops.aten.addcdiv.default, + torch.ops.aten.addcdiv.out, + torch.ops.aten.addcdiv_.default, + torch.ops.aten.addcmul.default, + torch.ops.aten.addcmul.out, + torch.ops.aten.addcmul_.default, + torch.ops.aten.addr.default, + torch.ops.aten.addr.out, + torch.ops.aten.affine_grid_generator.default, + torch.ops.aten.affine_grid_generator.out, + torch.ops.aten.alias_copy.default, + torch.ops.aten.alias_copy.out, + torch.ops.aten.all.default, + torch.ops.aten.all.dim, + torch.ops.aten.all.dims, + torch.ops.aten.all.out, + torch.ops.aten.all.dims_out, + torch.ops.aten.all.all_out, + torch.ops.aten.all.dimname, + torch.ops.aten.all.dimname_out, + torch.ops.aten.aminmax.default, + torch.ops.aten.aminmax.out, + torch.ops.aten.arange.default, + torch.ops.aten.arange.start, + torch.ops.aten.baddbmm.default, + torch.ops.aten.baddbmm.out, + torch.ops.aten.binary_cross_entropy.default, + torch.ops.aten.binary_cross_entropy.out, + torch.ops.aten.binary_cross_entropy_backward.default, + torch.ops.aten.binary_cross_entropy_backward.grad_input, + torch.ops.aten.binary_cross_entropy_with_logits.default, + torch.ops.aten.binary_cross_entropy_with_logits.out, + torch.ops.aten.block_diag.default, + torch.ops.aten.block_diag.out, + torch.ops.aten.celu.default, + torch.ops.aten.celu.out, + torch.ops.aten.celu_.default, + torch.ops.aten.channel_shuffle.default, + torch.ops.aten.channel_shuffle.out, + torch.ops.aten.clamp_max.default, + torch.ops.aten.clamp_max.Tensor, + torch.ops.aten.clamp_max.out, + torch.ops.aten.clamp_max.Tensor_out, + torch.ops.aten.clamp_min.default, + torch.ops.aten.clamp_min.Tensor, + torch.ops.aten.clamp_min.out, + torch.ops.aten.clamp_min.Tensor_out, + torch.ops.aten.col2im.default, + torch.ops.aten.col2im.out, + torch.ops.aten.count_nonzero.dim_IntList, + torch.ops.aten.count_nonzero.dim_IntList_out, + torch.ops.aten.count_nonzero.default, + torch.ops.aten.count_nonzero.out, + torch.ops.aten.linalg_cross.default, + torch.ops.aten.linalg_cross.out, + torch.ops.aten.cudnn_batch_norm.default, + torch.ops.aten.cudnn_batch_norm.out, + torch.ops.aten.cudnn_batch_norm_backward.default, + torch.ops.aten.cudnn_batch_norm_backward.out, + torch.ops.aten.miopen_batch_norm_backward.default, + torch.ops.aten.miopen_batch_norm_backward.out, + torch.ops.aten.deg2rad.default, + torch.ops.aten.deg2rad.out, + torch.ops.aten.deg2rad_.default, + torch.ops.aten.detach.default, + torch.ops.aten.diag_embed.default, + torch.ops.aten.diag_embed.out, + torch.ops.aten.diagonal_backward.default, + torch.ops.aten.diagonal_backward.out, + torch.ops.aten.dot.default, + torch.ops.aten.dot.out, + torch.ops.aten.vdot.default, + torch.ops.aten.vdot.out, + torch.ops.aten.elu.default, + torch.ops.aten.elu.out, + torch.ops.aten.elu_.default, + torch.ops.aten.elu_backward.default, + torch.ops.aten.elu_backward.grad_input, + torch.ops.aten.embedding_dense_backward.default, + torch.ops.aten.embedding_dense_backward.out, + torch.ops.aten.empty_like.default, + torch.ops.aten.empty_like.out, + torch.ops.aten._euclidean_dist.default, + 
torch.ops.aten.expand_copy.default, + torch.ops.aten.expand_copy.out, + torch.ops.aten.eye.default, + torch.ops.aten.eye.m, + torch.ops.aten.eye.out, + torch.ops.aten.eye.m_out, + torch.ops.aten.fill.Scalar, + torch.ops.aten.fill.Tensor, + torch.ops.aten.fill_.Scalar, + torch.ops.aten.fill_.Tensor, + torch.ops.aten.floor_divide.default, + torch.ops.aten.floor_divide.Scalar, + torch.ops.aten.floor_divide.out, + torch.ops.aten.floor_divide.Scalar_out, + torch.ops.aten.frac.default, + torch.ops.aten.frac.out, + torch.ops.aten.frac_.default, + torch.ops.aten.gelu_.default, + torch.ops.aten.gelu_backward.default, + torch.ops.aten.gelu_backward.grad_input, + torch.ops.aten.glu.default, + torch.ops.aten.glu.out, + torch.ops.aten.glu_backward.default, + torch.ops.aten.glu_backward.grad_input, + torch.ops.aten.hardshrink.default, + torch.ops.aten.hardshrink.out, + torch.ops.aten.hardsigmoid.default, + torch.ops.aten.hardsigmoid.out, + torch.ops.aten.hardsigmoid_.default, + torch.ops.aten.hardsigmoid_backward.default, + torch.ops.aten.hardsigmoid_backward.grad_input, + torch.ops.aten.hardswish.default, + torch.ops.aten.hardswish.out, + torch.ops.aten.hardswish_.default, + torch.ops.aten.hardswish_backward.default, + torch.ops.aten.hardswish_backward.out, + torch.ops.aten.hardtanh_.default, + torch.ops.aten.hardtanh_backward.default, + torch.ops.aten.hardtanh_backward.grad_input, + torch.ops.aten.heaviside.default, + torch.ops.aten.heaviside.out, + torch.ops.aten.heaviside_.default, + torch.ops.aten.huber_loss.default, + torch.ops.aten.huber_loss.out, + torch.ops.aten.huber_loss_backward.default, + torch.ops.aten.huber_loss_backward.out, + torch.ops.aten.im2col.default, + torch.ops.aten.im2col.out, + torch.ops.aten.index_add.default, + torch.ops.aten.index_add.out, + torch.ops.aten.index_add.dimname, + torch.ops.aten.index_add_.default, + torch.ops.aten.index_copy.default, + torch.ops.aten.index_copy.dimname, + torch.ops.aten.index_copy.out, + torch.ops.aten.index_copy_.default, + torch.ops.aten.index_copy_.dimname, + torch.ops.aten.index_fill.int_Tensor, + torch.ops.aten.index_fill.int_Scalar, + torch.ops.aten.index_fill.Dimname_Scalar, + torch.ops.aten.index_fill.Dimname_Tensor, + torch.ops.aten.index_fill.int_Scalar_out, + torch.ops.aten.index_fill.int_Tensor_out, + torch.ops.aten.index_fill_.int_Tensor, + torch.ops.aten.index_fill_.int_Scalar, + torch.ops.aten.index_fill_.Dimname_Scalar, + torch.ops.aten.index_fill_.Dimname_Tensor, + torch.ops.aten.isin.Tensor_Tensor, + torch.ops.aten.isin.Tensor_Tensor_out, + torch.ops.aten.isin.Tensor_Scalar, + torch.ops.aten.isin.Tensor_Scalar_out, + torch.ops.aten.isin.Scalar_Tensor, + torch.ops.aten.isin.Scalar_Tensor_out, + torch.ops.aten.isneginf.default, + torch.ops.aten.isneginf.out, + torch.ops.aten.isposinf.default, + torch.ops.aten.isposinf.out, + torch.ops.aten.leaky_relu_.default, + torch.ops.aten.leaky_relu_backward.default, + torch.ops.aten.leaky_relu_backward.grad_input, + torch.ops.aten.lerp.Scalar, + torch.ops.aten.lerp.Tensor, + torch.ops.aten.lerp.Scalar_out, + torch.ops.aten.lerp.Tensor_out, + torch.ops.aten.lerp_.Scalar, + torch.ops.aten.lerp_.Tensor, + torch.ops.aten.linspace.Tensor_Tensor, + torch.ops.aten.linspace.Tensor_Scalar, + torch.ops.aten.linspace.Scalar_Tensor, + torch.ops.aten.linspace.default, + torch.ops.aten.linspace.out, + torch.ops.aten.linspace.Tensor_Tensor_out, + torch.ops.aten.linspace.Tensor_Scalar_out, + torch.ops.aten.linspace.Scalar_Tensor_out, + torch.ops.aten.logaddexp.default, + torch.ops.aten.logaddexp.out, + 
torch.ops.aten.logaddexp2.default, + torch.ops.aten.logaddexp2.out, + torch.ops.aten.logit.default, + torch.ops.aten.logit.out, + torch.ops.aten.logit_.default, + torch.ops.aten.logit_backward.default, + torch.ops.aten.log_sigmoid_backward.default, + torch.ops.aten.log_sigmoid_backward.grad_input, + torch.ops.aten.log_sigmoid_forward.default, + torch.ops.aten.log_sigmoid_forward.output, + torch.ops.aten._log_softmax_backward_data.default, + torch.ops.aten._log_softmax_backward_data.out, + torch.ops.aten.logspace.Tensor_Tensor, + torch.ops.aten.logspace.Tensor_Scalar, + torch.ops.aten.logspace.Scalar_Tensor, + torch.ops.aten.logspace.default, + torch.ops.aten.logspace.out, + torch.ops.aten.logspace.Tensor_Tensor_out, + torch.ops.aten.logspace.Tensor_Scalar_out, + torch.ops.aten.logspace.Scalar_Tensor_out, + torch.ops.aten.logsumexp.default, + torch.ops.aten.masked_fill.Scalar, + torch.ops.aten.masked_fill.Tensor, + torch.ops.aten.masked_fill.Scalar_out, + torch.ops.aten.masked_fill.Tensor_out, + torch.ops.aten.masked_fill_.Scalar, + torch.ops.aten.masked_fill_.Tensor, + torch.ops.aten.mish.default, + torch.ops.aten.mish.out, + torch.ops.aten.mish_.default, + torch.ops.aten.mse_loss.default, + torch.ops.aten.mse_loss.out, + torch.ops.aten.mse_loss_backward.default, + torch.ops.aten.mse_loss_backward.grad_input, + torch.ops.aten.multi_margin_loss.default, + torch.ops.aten.multi_margin_loss.out, + torch.ops.aten.multilabel_margin_loss_forward.default, + torch.ops.aten.multilabel_margin_loss_forward.output, + torch.ops.aten.mv.default, + torch.ops.aten.mv.out, + torch.ops.aten.mvlgamma.default, + torch.ops.aten.mvlgamma.out, + torch.ops.aten.mvlgamma_.default, + torch.ops.aten.nansum.default, + torch.ops.aten.nansum.out, + torch.ops.aten.nan_to_num.default, + torch.ops.aten.nan_to_num.out, + torch.ops.aten.nan_to_num_.default, + torch.ops.aten.native_batch_norm_backward.default, + torch.ops.aten.native_batch_norm_backward.out, + torch.ops.aten.native_dropout_backward.default, + torch.ops.aten.native_dropout_backward.out, + torch.ops.aten.native_group_norm_backward.default, + torch.ops.aten.native_group_norm_backward.out, + torch.ops.aten.native_layer_norm_backward.default, + torch.ops.aten.native_layer_norm_backward.out, + torch.ops.aten.new_empty.default, + torch.ops.aten.new_empty.out, + torch.ops.aten.new_full.default, + torch.ops.aten.new_full.out, + torch.ops.aten.new_ones.default, + torch.ops.aten.new_ones.out, + torch.ops.aten.new_zeros.default, + torch.ops.aten.new_zeros.out, + torch.ops.aten.nll_loss2d_forward.default, + torch.ops.aten.nll_loss2d_forward.output, + torch.ops.aten.nll_loss2d_backward.default, + torch.ops.aten.nll_loss2d_backward.grad_input, + torch.ops.aten.nll_loss_backward.default, + torch.ops.aten.nll_loss_backward.grad_input, + torch.ops.aten.nll_loss_forward.default, + torch.ops.aten.nll_loss_forward.output, + torch.ops.aten.norm.Scalar, + torch.ops.aten.norm.ScalarOpt_dim, + torch.ops.aten.norm.names_ScalarOpt_dim, + torch.ops.aten.norm.ScalarOpt_dim_dtype, + torch.ops.aten.norm.dtype_out, + torch.ops.aten.norm.out, + torch.ops.aten.norm.ScalarOpt_dtype, + torch.ops.aten.norm.ScalarOpt_dtype_out, + torch.ops.aten.norm.Scalar_out, + torch.ops.aten.norm.names_ScalarOpt_dim_dtype, + torch.ops.aten.norm.names_dtype_out, + torch.ops.aten.norm.names_out, + torch.ops.aten.ones.default, + torch.ops.aten.ones_like.default, + torch.ops.aten.ones_like.out, + torch.ops.aten.pixel_shuffle.default, + torch.ops.aten.pixel_shuffle.out, + torch.ops.aten.pixel_unshuffle.default, + 
torch.ops.aten.pixel_unshuffle.out, + torch.ops.aten._prelu_kernel.default, + torch.ops.aten._prelu_kernel_backward.default, + torch.ops.aten._reshape_alias.default, + torch.ops.aten.rad2deg.default, + torch.ops.aten.rad2deg.out, + torch.ops.aten.rad2deg_.default, + torch.ops.aten.reflection_pad1d.default, + torch.ops.aten.reflection_pad1d.out, + torch.ops.aten.reflection_pad1d_backward.default, + torch.ops.aten.reflection_pad1d_backward.grad_input, + torch.ops.aten.reflection_pad2d.default, + torch.ops.aten.reflection_pad2d.out, + torch.ops.aten.reflection_pad2d_backward.default, + torch.ops.aten.reflection_pad2d_backward.grad_input, + torch.ops.aten.reflection_pad3d.default, + torch.ops.aten.reflection_pad3d.out, + torch.ops.aten.reflection_pad3d_backward.default, + torch.ops.aten.reflection_pad3d_backward.grad_input, + torch.ops.aten.replication_pad1d.default, + torch.ops.aten.replication_pad1d.out, + torch.ops.aten.replication_pad2d.default, + torch.ops.aten.replication_pad2d.out, + torch.ops.aten.replication_pad3d.default, + torch.ops.aten.replication_pad3d.out, + torch.ops.aten.renorm.default, + torch.ops.aten.renorm.out, + torch.ops.aten.renorm_.default, + torch.ops.aten.resize_as.default, + torch.ops.aten.resize_as.out, + torch.ops.aten.roll.default, + torch.ops.aten.roll.out, + torch.ops.aten.rot90.default, + torch.ops.aten.rot90.out, + torch.ops.aten.rrelu_with_noise.default, + torch.ops.aten.rrelu_with_noise.out, + torch.ops.aten.rrelu_with_noise_.default, + torch.ops.aten.rsub.Tensor, + torch.ops.aten.rsub.Scalar, + torch.ops.aten.rsub.Tensor_out, + torch.ops.aten.rsub.Scalar_out, + torch.ops.aten._safe_softmax.default, + torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default, + torch.ops.aten.select_backward.default, + torch.ops.aten.select_backward.out, + torch.ops.aten.select_scatter.default, + torch.ops.aten.select_scatter.out, + torch.ops.aten.sgn.default, + torch.ops.aten.sgn.out, + torch.ops.aten.sgn_.default, + torch.ops.aten.sigmoid_backward.default, + torch.ops.aten.sigmoid_backward.grad_input, + torch.ops.aten.silu.default, + torch.ops.aten.silu.out, + torch.ops.aten.silu_.default, + torch.ops.aten.silu_backward.default, + torch.ops.aten.silu_backward.grad_input, + torch.ops.aten.sinc.default, + torch.ops.aten.sinc.out, + torch.ops.aten.sinc_.default, + torch.ops.aten.slice_backward.default, + torch.ops.aten.slice_backward.out, + torch.ops.aten.smooth_l1_loss.default, + torch.ops.aten.smooth_l1_loss.out, + torch.ops.aten.smooth_l1_loss_backward.default, + torch.ops.aten.smooth_l1_loss_backward.grad_input, + torch.ops.aten.soft_margin_loss.default, + torch.ops.aten.soft_margin_loss.out, + torch.ops.aten.soft_margin_loss_backward.default, + torch.ops.aten.soft_margin_loss_backward.grad_input, + torch.ops.aten._softmax_backward_data.default, + torch.ops.aten._softmax_backward_data.out, + torch.ops.aten.softplus.default, + torch.ops.aten.softplus.out, + torch.ops.aten.softplus_backward.default, + torch.ops.aten.softplus_backward.grad_input, + torch.ops.aten.softshrink.default, + torch.ops.aten.softshrink.out, + torch.ops.aten.special_entr.default, + torch.ops.aten.special_entr.out, + torch.ops.aten.special_log_ndtr.default, + torch.ops.aten.special_log_ndtr.out, + torch.ops.aten.special_xlog1py.default, + torch.ops.aten.special_xlog1py.other_scalar, + torch.ops.aten.special_xlog1py.self_scalar, + torch.ops.aten.special_xlog1py.out, + torch.ops.aten.special_xlog1py.self_scalar_out, + torch.ops.aten.special_xlog1py.other_scalar_out, + 
torch.ops.aten.split.Tensor, + torch.ops.aten.split_with_sizes_copy.default, + torch.ops.aten.split_with_sizes_copy.out, + torch.ops.aten.squeeze.default, + torch.ops.aten.squeeze.dim, + torch.ops.aten.std.default, + torch.ops.aten.std.dim, + torch.ops.aten.std.correction, + torch.ops.aten.std.names_dim, + torch.ops.aten.std.names_out, + torch.ops.aten.std.out, + torch.ops.aten.std.correction_out, + torch.ops.aten.std.correction_names, + torch.ops.aten.std.correction_names_out, + torch.ops.aten.std_mean.default, + torch.ops.aten.std_mean.dim, + torch.ops.aten.std_mean.correction, + torch.ops.aten.std_mean.names_dim, + torch.ops.aten.std_mean.correction_names, + torch.ops.aten.std_mean.correction_out, + torch.ops.aten.stack.default, + torch.ops.aten.stack.out, + torch.ops.aten.sum.default, + torch.ops.aten.sum.out, + torch.ops.aten.t.default, + torch.ops.aten.t_copy.out, + torch.ops.aten.t_copy.default, + torch.ops.aten.take.default, + torch.ops.aten.take.out, + torch.ops.aten.tanh_backward.default, + torch.ops.aten.tanh_backward.grad_input, + torch.ops.aten.threshold.default, + torch.ops.aten.threshold.out, + torch.ops.aten.threshold_.default, + torch.ops.aten.threshold_backward.default, + torch.ops.aten.threshold_backward.grad_input, + torch.ops.aten.trace.default, + torch.ops.aten.trace.out, + torch.ops.aten.transpose.int, + torch.ops.aten.tril.default, + torch.ops.aten.tril.out, + torch.ops.aten.tril_.default, + torch.ops.aten.triu.default, + torch.ops.aten.triu.out, + torch.ops.aten.triu_.default, + torch.ops.aten.unbind.int, + torch.ops.aten.unbind.Dimname, + torch.ops.aten.unfold_backward.default, + torch.ops.aten.unfold_backward.out, + torch.ops.aten.unfold_copy.default, + torch.ops.aten.unfold_copy.out, + torch.ops.aten._unsafe_index.Tensor, + torch.ops.aten._unsafe_index_put.default, + torch.ops.aten._unsafe_masked_index.default, + torch.ops.aten._unsafe_masked_index_put_accumulate.default, + torch.ops.aten.unsafe_split.Tensor, + torch.ops.aten.unsafe_split_with_sizes.default, + torch.ops.aten.unsqueeze_copy.out, + torch.ops.aten.unsqueeze_copy.default, + torch.ops.aten._unsafe_view.default, + torch.ops.aten._unsafe_view.out, + torch.ops.aten.upsample_linear1d.default, + torch.ops.aten.upsample_linear1d.out, + torch.ops.aten.upsample_bilinear2d.vec, + torch.ops.aten.upsample_bilinear2d.default, + torch.ops.aten.upsample_bilinear2d.out, + torch.ops.aten.upsample_trilinear3d.vec, + torch.ops.aten.upsample_trilinear3d.default, + torch.ops.aten.upsample_trilinear3d.out, + torch.ops.aten.xlogy.Tensor, + torch.ops.aten.xlogy.Scalar_Other, + torch.ops.aten.xlogy.Scalar_Self, + torch.ops.aten.xlogy.OutTensor, + torch.ops.aten.xlogy.OutScalar_Self, + torch.ops.aten.xlogy.OutScalar_Other, + torch.ops.aten.xlogy_.Tensor, + torch.ops.aten.xlogy_.Scalar_Other, + torch.ops.aten.zero.default, + torch.ops.aten.zero.out, + torch.ops.aten.zero_.default, + torch.ops.aten.zeros.default, + torch.ops.aten.zeros_like.default, + torch.ops.aten.zeros_like.out, + torch.ops.aten._chunk_cat.default, + torch.ops.aten._chunk_cat.out, + torch.ops.aten._weight_norm_interface.default, + torch.ops.aten._weight_norm_interface.out, ]) MUTABLE_DECOMPOSITION = [ - torch.ops.aten.bernoulli_.Tensor, - torch.ops.aten.bernoulli_.float, + torch.ops.aten.bernoulli_.Tensor, + torch.ops.aten.bernoulli_.float, ] diff --git a/torchax/torchax/distributed.py b/torchax/torchax/distributed.py index eb12f4eb2d5..b73bf202c63 100644 --- a/torchax/torchax/distributed.py +++ b/torchax/torchax/distributed.py @@ -51,51 +51,54 @@ def 
group_name(self): @staticmethod def _work( - tensors: Union[torch.Tensor, List[torch.Tensor], - List[List[torch.Tensor]]], + tensors: Union[torch.Tensor, List[torch.Tensor], List[List[torch.Tensor]]], ) -> dist.Work: fut = torch.futures.Future() fut.set_result(tensors) return torch._C._distributed_c10d._create_work_from_future(fut) def _allgather_base( - self, - output: torch.Tensor, - input: torch.Tensor, - opts=..., + self, + output: torch.Tensor, + input: torch.Tensor, + opts=..., ) -> dist.Work: assert isinstance(input, torchax.tensor.Tensor) assert isinstance(output, torchax.tensor.Tensor) torch.distributed._functional_collectives.all_gather_tensor_inplace( - output, input, group=self) + output, input, group=self + ) return self._work(output) def allreduce( - self, - tensors: List[torch.Tensor], - opts: dist.AllreduceOptions = ..., + self, + tensors: List[torch.Tensor], + opts: dist.AllreduceOptions = ..., ) -> dist.Work: assert len(tensors) == 1 assert isinstance(tensors[0], torchax.tensor.Tensor) torch.distributed._functional_collectives.all_reduce_inplace( - tensors[0], - torch.distributed._functional_collectives.REDUCE_OP_TO_STR[ - opts.reduceOp.op], - self, + tensors[0], + torch.distributed._functional_collectives.REDUCE_OP_TO_STR[ + opts.reduceOp.op + ], + self, ) return self._work(tensors) def broadcast( - self, - tensors: List[torch.Tensor], - opts: dist.BroadcastOptions = ..., + self, + tensors: List[torch.Tensor], + opts: dist.BroadcastOptions = ..., ) -> dist.Work: assert len(tensors) == 1 assert isinstance(tensors[0], torchax.tensor.Tensor) tensors[0].copy_( - torch.distributed._functional_collectives.broadcast( - tensors[0], opts.rootRank, group=self)) + torch.distributed._functional_collectives.broadcast( + tensors[0], opts.rootRank, group=self + ) + ) return self._work(tensors) @@ -103,9 +106,9 @@ def broadcast( dist.Backend.register_backend("jax", ProcessGroupJax, devices=["jax"]) -def jax_rendezvous_handler(url: str, - timeout: datetime.timedelta = ..., - **kwargs): +def jax_rendezvous_handler( + url: str, timeout: datetime.timedelta = ..., **kwargs +): """Initialize distributed store with JAX process IDs. Requires `$MASTER_ADDR` and `$MASTER_PORT`. 
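For reference, a minimal usage sketch of the "jax" process group wired up above, under two assumptions: that the rendezvous handler shown here is registered under a `jax://` init-method scheme elsewhere in this file, and that the launcher exports `MASTER_ADDR`/`MASTER_PORT` (the host and port below are illustrative placeholders only).

    import os
    import torch.distributed as dist
    import torchax

    # Enable torchax globally so torch ops dispatch through the JAX-backed env.
    torchax.enable_globally()

    # Normally provided by the launcher; illustrative defaults for a sketch.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "12355")

    # "jax" is the backend name registered via dist.Backend.register_backend above;
    # the "jax://" scheme is an assumption about how jax_rendezvous_handler is hooked up.
    dist.init_process_group(backend="jax", init_method="jax://")
    print(f"rank {dist.get_rank()} of {dist.get_world_size()}")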
@@ -117,10 +120,10 @@ def jax_rendezvous_handler(url: str, master_port = int(os.environ["MASTER_PORT"]) # TODO(wcromar): Use `torchrun`'s store if available store = dist.TCPStore( - master_ip, - master_port, - jax.process_count(), - is_master=jax.process_index() == 0, + master_ip, + master_port, + jax.process_count(), + is_master=jax.process_index() == 0, ) yield (store, jax.process_index(), jax.process_count()) @@ -142,9 +145,9 @@ def jax_wrapper(index, jax_args): torch_outputs = f(index, *args) return env.t2j_iso(torch_outputs) - jax_outputs = jax.pmap( - jax_wrapper, axis_name="torch_dist")(np.arange(jax.device_count()), - env.t2j_iso(args)) + jax_outputs = jax.pmap(jax_wrapper, axis_name="torch_dist")( + np.arange(jax.device_count()), env.t2j_iso(args) + ) return env.j2t_iso(jax_outputs) @@ -171,10 +174,10 @@ class DistributedDataParallel(torch.nn.Module): """ def __init__( - self, - module: torch.nn.Module, - env: Optional[torchax.tensor.Environment] = None, - **kwargs, + self, + module: torch.nn.Module, + env: Optional[torchax.tensor.Environment] = None, + **kwargs, ): if kwargs: logging.warning(f"Unsupported kwargs {kwargs}") @@ -182,15 +185,17 @@ def __init__( super().__init__() self._env = env or torchax.default_env() self._mesh = Mesh( - mesh_utils.create_device_mesh((jax.device_count(),)), - axis_names=("batch",), + mesh_utils.create_device_mesh((jax.device_count(),)), + axis_names=("batch",), ) replicated_state = torch_pytree.tree_map_only( - torch.Tensor, - lambda t: self._env.j2t_iso( - jax.device_put( - self._env.to_xla(t)._elem, NamedSharding(self._mesh, P()))), - module.state_dict(), + torch.Tensor, + lambda t: self._env.j2t_iso( + jax.device_put( + self._env.to_xla(t)._elem, NamedSharding(self._mesh, P()) + ) + ), + module.state_dict(), ) # TODO: broadcast module.load_state_dict(replicated_state, assign=True) @@ -205,23 +210,27 @@ def shard_input(self, inp): sharding = NamedSharding(self._mesh, P("batch")) return self._env.j2t_iso( - jax.make_array_from_single_device_arrays( - global_batch_shape, - NamedSharding(self._mesh, P("batch")), - arrays=[ - jax.device_put(self._env.to_xla(batch)._elem, device) for batch, - device in zip(per_replica_batches, sharding.addressable_devices) - ], - )) + jax.make_array_from_single_device_arrays( + global_batch_shape, + NamedSharding(self._mesh, P("batch")), + arrays=[ + jax.device_put(self._env.to_xla(batch)._elem, device) + for batch, device in zip( + per_replica_batches, sharding.addressable_devices + ) + ], + ) + ) def replicate_input(self, inp): return self._env.j2t_iso( - jax.device_put(inp._elem, NamedSharding(self._mesh, P()))) + jax.device_put(inp._elem, NamedSharding(self._mesh, P())) + ) def jit_step(self, func): - @functools.partial( - interop.jax_jit, kwargs_for_jax_jit={'donate_argnums': 0}) + interop.jax_jit, kwargs_for_jax_jit={"donate_argnums": 0} + ) def _jit_fn(states, args): self.load_state_dict(states) outputs = func(*args) diff --git a/torchax/torchax/export.py b/torchax/torchax/export.py index 987fb92ba6e..636de0db820 100644 --- a/torchax/torchax/export.py +++ b/torchax/torchax/export.py @@ -1,5 +1,6 @@ # pylint: disable """Utilities for exporting a torch program to jax/stablehlo.""" + import copy from typing import Any, Dict, Tuple import torch @@ -24,12 +25,13 @@ def __init__(self, graph_module): import torchax.ops.jtorch def call_function(self, target, args: Tuple, kwargs: Dict) -> Any: - if not isinstance(target, - (torch._ops.OpOverloadPacket, torch._ops.OpOverload)): + if not isinstance( + target, 
(torch._ops.OpOverloadPacket, torch._ops.OpOverload) + ): return super().call_function(target, args, kwargs) if DEBUG: - print('Running ', target.name(), '--------') + print("Running ", target.name(), "--------") op = ops_registry.all_aten_ops.get(target) if op is None: @@ -40,15 +42,15 @@ def call_function(self, target, args: Tuple, kwargs: Dict) -> Any: op = ops_registry.all_aten_ops.get(target.overloadpacket) if op is None: print(target.name(), target.tags) - raise RuntimeError('No lowering found for', target.name()) + raise RuntimeError("No lowering found for", target.name()) return op.func(*args, **kwargs) def run_node(self, n) -> Any: res = super().run_node(n) if DEBUG: - if n.op == 'call_function': - if hasattr(res, 'shape'): - print('Meta:', n.meta.get('val').shape, 'REAL: ', res.shape) + if n.op == "call_function": + if hasattr(res, "shape"): + print("Meta:", n.meta.get("val").shape, "REAL: ", res.shape) return res @@ -60,9 +62,12 @@ def run_node(self, n) -> Any: def _extract_states_from_exported_program(exported_model): # NOTE call convention: (parameters, buffers, user_inputs) - param_and_buffer_keys = exported_model.graph_signature.parameters + exported_model.graph_signature.buffers + param_and_buffer_keys = ( + exported_model.graph_signature.parameters + + exported_model.graph_signature.buffers + ) state_dict = copy.copy(exported_model.state_dict) - if (constants := getattr(exported_model, 'constants', None)) is not None: + if (constants := getattr(exported_model, "constants", None)) is not None: state_dict.update(constants) param_buffer_values = list(state_dict[key] for key in param_and_buffer_keys) @@ -80,19 +85,19 @@ def exported_program_to_jax(exported_program, export_raw: bool = False): func(state, input) would be how you call it. 
""" - if torch.__version__ >= '2.2': + if torch.__version__ >= "2.2": # torch version 2.1 didn't expose this yet exported_program = exported_program.run_decompositions() exported_program = exported_program.run_decompositions( - decompositions.DECOMPOSITIONS) + decompositions.DECOMPOSITIONS + ) if DEBUG: print(exported_program.graph_module.code) names, states = _extract_states_from_exported_program(exported_program) def _extract_args(args, kwargs): - flat_args, received_spec = pytree.tree_flatten( - (args, kwargs)) # type: ignore[possibly-undefined] + flat_args, received_spec = pytree.tree_flatten((args, kwargs)) # type: ignore[possibly-undefined] return flat_args num_mutations = len(exported_program.graph_signature.buffers_to_mutate) @@ -100,9 +105,9 @@ def _extract_args(args, kwargs): def func(states, inputs): args = _extract_args(inputs, {}) res = JaxInterpreter(exported_program.graph_module).run( - *states, - *args, - enable_io_processing=False, + *states, + *args, + enable_io_processing=False, ) res = res[num_mutations:] return res @@ -120,21 +125,21 @@ def extract_avals(exported): """ def _to_aval(arg_meta, symbolic_shapes): - """Convet from torch type to jax abstract value for export tracing - """ + """Convet from torch type to jax abstract value for export tracing""" def _get_dim(d): if isinstance(d, torch.SymInt): return symbolic_shapes[str(d)] return d - val = arg_meta['val'] - is_scalar = isinstance(val, float) or isinstance(val, int) or isinstance( - val, bool) + val = arg_meta["val"] + is_scalar = ( + isinstance(val, float) or isinstance(val, int) or isinstance(val, bool) + ) if is_scalar: - return jax.ShapeDtypeStruct([], type(arg_meta['val'])) + return jax.ShapeDtypeStruct([], type(arg_meta["val"])) - tensor_meta = arg_meta['tensor_meta'] + tensor_meta = arg_meta["tensor_meta"] shape = [_get_dim(d) for d in tensor_meta.shape] return jax.ShapeDtypeStruct(shape, mappings.t2j_dtype(tensor_meta.dtype)) @@ -142,8 +147,9 @@ def _get_inputs(exported): """Return placeholders with input metadata""" placeholders = [p for p in exported.graph.nodes if p.op == "placeholder"] input_placeholders = [ - p for p, s in zip(placeholders, exported.graph_signature.input_specs) - if s.kind == torch.export.graph_signature.InputKind.USER_INPUT + p + for p, s in zip(placeholders, exported.graph_signature.input_specs) + if s.kind == torch.export.graph_signature.InputKind.USER_INPUT ] return input_placeholders @@ -165,17 +171,27 @@ def _build_symbolic_constraints(symbol_name, torch_constraint): torch.export.Dim("a", min=5, max=10) ==> ("a >= 5", "a <= 10",) """ - if not isinstance(torch_constraint, torch.utils._sympy.value_ranges. 
- ValueRanges) or torch_constraint.is_bool: + if ( + not isinstance( + torch_constraint, + torch.utils._sympy.value_ranges.ValueRanges, + ) + or torch_constraint.is_bool + ): raise TypeError( - f"No symbolic constraint handler for: {torch_constraint}") + f"No symbolic constraint handler for: {torch_constraint}" + ) constraints = [] symbol = sympy.Symbol(symbol_name) if torch_constraint.lower != 2: constraints.append(symbol >= torch_constraint.lower) from sympy.core.singleton import S - if not torch_constraint.upper.is_infinite and torch_constraint.upper is not S.IntInfinity: + + if ( + not torch_constraint.upper.is_infinite + and torch_constraint.upper is not S.IntInfinity + ): constraints.append(symbol <= torch_constraint.upper) return tuple(sympy.pretty(c, use_unicode=False) for c in constraints) @@ -195,7 +211,8 @@ def _build_symbolic_shape(sym, constraint, free_symbols): constraints = _build_symbolic_constraints(symbol_name, constraint) if sym.is_symbol: symbolic_shape = jax.export.symbolic_shape( - symbol_name, constraints=constraints) + symbol_name, constraints=constraints + ) else: assert len(sym.free_symbols) > 0 scope = free_symbols[str(list(sym.free_symbols)[0])].scope @@ -209,10 +226,10 @@ def _build_symbolic_shape(sym, constraint, free_symbols): # have its own scope. symbolic_shapes = {} symbol_variables = [ - (s, v) for s, v in range_constraints.items() if s.is_symbol + (s, v) for s, v in range_constraints.items() if s.is_symbol ] symbol_exprs = [ - (s, v) for s, v in range_constraints.items() if not s.is_symbol + (s, v) for s, v in range_constraints.items() if not s.is_symbol ] for sym, constraint in symbol_variables + symbol_exprs: symbolic_shape = _build_symbolic_shape(sym, constraint, symbolic_shapes) @@ -223,10 +240,15 @@ def _build_symbolic_shape(sym, constraint, free_symbols): args = _get_inputs(exported) if DEBUG: - print('Inputs to aval:', args, '--------') - print('Symbolic shapes:', symbolic_shapes) + print("Inputs to aval:", args, "--------") + print("Symbolic shapes:", symbolic_shapes) for arg in args: - print('Meta2Aval', arg.meta, '--> ', _to_aval(arg.meta, symbolic_shapes)) + print( + "Meta2Aval", + arg.meta, + "--> ", + _to_aval(arg.meta, symbolic_shapes), + ) return [_to_aval(arg.meta, symbolic_shapes) for arg in args] diff --git a/torchax/torchax/flax.py b/torchax/torchax/flax.py index 28542d79c90..5503dd9b71f 100644 --- a/torchax/torchax/flax.py +++ b/torchax/torchax/flax.py @@ -6,13 +6,13 @@ class FlaxNNModule(torch.nn.Module): - def __init__(self, env, flax_module, sample_args, sample_kwargs=None): super().__init__() prng = env.prng_key sample_kwargs = sample_kwargs or {} - parameter_dict = tx.interop.call_jax(flax_module.init, prng, *sample_args, - **sample_kwargs) + parameter_dict = tx.interop.call_jax( + flax_module.init, prng, *sample_args, **sample_kwargs + ) self._params = self._encode_nested_dict(parameter_dict) @@ -35,5 +35,6 @@ def _decode_nested_dict(self, child_module): def forward(self, *args, **kwargs): nested_dict_params = self._decode_nested_dict(self._params) - return tx.interop.call_jax(self._flax_module.apply, nested_dict_params, - *args, **kwargs) + return tx.interop.call_jax( + self._flax_module.apply, nested_dict_params, *args, **kwargs + ) diff --git a/torchax/torchax/interop.py b/torchax/torchax/interop.py index 419e3232773..b0b3533fe22 100644 --- a/torchax/torchax/interop.py +++ b/torchax/torchax/interop.py @@ -32,14 +32,13 @@ def extract_one(module, prefix): elif isinstance(v, torch.Tensor): buffers[qual_name] = v for name, child 
in module.named_children(): - extract_one(child, prefix + name + '.') + extract_one(child, prefix + name + ".") - extract_one(m, '') + extract_one(m, "") return params, buffers def set_all_buffers(m, params, buffers): - def set_one(module, prefix): for k in dir(module): qual_name = prefix + k @@ -49,17 +48,15 @@ def set_one(module, prefix): print(k, potential_v) setattr(module, k, torch.nn.Parameter(potential_v)) for name, child in module.named_children(): - set_one(child, prefix + name + '.') + set_one(child, prefix + name + ".") - set_one(m, '') + set_one(m, "") class JittableModule(torch.nn.Module): - - def __init__(self, - m: torch.nn.Module, - extra_jit_args={}, - dedup_parameters=True): + def __init__( + self, m: torch.nn.Module, extra_jit_args={}, dedup_parameters=True + ): super().__init__() self.params, self.buffers = extract_all_buffers(m) self._model = m @@ -103,20 +100,20 @@ def functional_call(self, method_name, params, buffers, *args, **kwargs): return res def forward(self, *args, **kwargs): - if 'forward' not in self._jitted: + if "forward" not in self._jitted: jitted = jax_jit( - functools.partial(self.functional_call, 'forward'), - kwargs_for_jax_jit=self._extra_jit_args, + functools.partial(self.functional_call, "forward"), + kwargs_for_jax_jit=self._extra_jit_args, ) def jitted_forward(*args, **kwargs): return jitted(self.params, self.buffers, *args, **kwargs) - self._jitted['forward'] = jitted_forward - return self._jitted['forward'](*args, **kwargs) + self._jitted["forward"] = jitted_forward + return self._jitted["forward"](*args, **kwargs) def __getattr__(self, key): - if key == '_model': + if key == "_model": return super().__getattr__(key) if key in self._jitted: return self._jitted[key] @@ -124,8 +121,9 @@ def __getattr__(self, key): def make_jitted(self, key): jitted = jax_jit( - functools.partial(self.functional_call, key), - kwargs_for_jax_jit=self._extra_jit_args) + functools.partial(self.functional_call, key), + kwargs_for_jax_jit=self._extra_jit_args, + ) def call(*args, **kwargs): return jitted(self.params, self.buffers, *args, **kwargs) @@ -134,7 +132,6 @@ def call(*args, **kwargs): class CompileMixin: - def functional_call(self, method, params, buffers, *args, **kwargs): kwargs = kwargs or {} params_copy = copy.copy(params) @@ -147,19 +144,20 @@ def jit(self, method): jitted = jax_jit(functools.partial(self.functional_call, method_name)) def call(*args, **kwargs): - return jitted(self.named_paramters(), self.named_buffers(), *args, - **kwargs) + return jitted( + self.named_paramters(), self.named_buffers(), *args, **kwargs + ) return call def compile_nn_module(m: torch.nn.Module, methods=None): if methods is None: - methods = ['forward'] + methods = ["forward"] new_parent = type( - m.__class__.__name__ + '_with_CompileMixin', - (CompileMixin, m.__class__), + m.__class__.__name__ + "_with_CompileMixin", + (CompileMixin, m.__class__), ) m.__class__ = NewParent @@ -200,15 +198,17 @@ def _jax_view(t: TorchValue) -> JaxValue: jax_view = functools.partial(pytree.tree_map, _jax_view) -def call_jax(jax_func: JaxCallable, *args: TorchValue, - **kwargs: TorchValue) -> TorchValue: +def call_jax( + jax_func: JaxCallable, *args: TorchValue, **kwargs: TorchValue +) -> TorchValue: args, kwargs = jax_view((args, kwargs)) res: JaxValue = jax_func(*args, **kwargs) return torch_view(res) -def call_torch(torch_func: TorchCallable, *args: JaxValue, - **kwargs: JaxValue) -> JaxValue: +def call_torch( + torch_func: TorchCallable, *args: JaxValue, **kwargs: JaxValue +) -> 
JaxValue: args, kwargs = torch_view((args, kwargs)) with torchax.default_env(): res: TorchValue = torch_func(*args, **kwargs) @@ -218,10 +218,10 @@ def call_torch(torch_func: TorchCallable, *args: JaxValue, def j2t_autograd(fn, call_jax=call_jax): """Given a JAX function, returns a PyTorch autograd function implemented with `jax.vjp(fn)`. - It wraps `fn` with `jax.vjp` to compute both the output and residuals (intermediate - activations). The wrapped function is then run via `call_jax` and integrated into - the PyTorch autograd framework by saving the residuals into the context object. - """ + It wraps `fn` with `jax.vjp` to compute both the output and residuals (intermediate + activations). The wrapped function is then run via `call_jax` and integrated into + the PyTorch autograd framework by saving the residuals into the context object. + """ @wraps(fn) def inner(*args, **kwargs): @@ -229,12 +229,11 @@ def inner(*args, **kwargs): from jax.util import safe_zip class JaxFun(torch.autograd.Function): - @staticmethod def forward(ctx, tree_def, *flat_args_kwargs): - - tensors, other = util.partition(flat_args_kwargs, - lambda x: isinstance(x, torch.Tensor)) + tensors, other = util.partition( + flat_args_kwargs, lambda x: isinstance(x, torch.Tensor) + ) # We want the arguments that don't require grads to be closured? y, fun_vjp = call_jax(_jax_forward, fn, other, tree_def, tensors) @@ -252,8 +251,9 @@ def backward(ctx, *grad_out): assert len(grad_out) > 0 grad_out = grad_out if len(grad_out) > 1 else grad_out[0] - input_grads_structured = call_jax(_jax_backward, ctx.vjp_spec, - ctx.saved_tensors, grad_out) + input_grads_structured = call_jax( + _jax_backward, ctx.vjp_spec, ctx.saved_tensors, grad_out + ) # Construct the gradient tuple to be returned. # It needs to match the inputs to forward: (tree_def, *flat_inputs) @@ -261,8 +261,9 @@ def backward(ctx, *grad_out): # The subsequent gradients correspond to flat_inputs. # We need to put a None for inputs that did not require gradients. final_grads = [None] - for needs_grad, grad in safe_zip(ctx.needs_input_grad[1:], - input_grads_structured): + for needs_grad, grad in safe_zip( + ctx.needs_input_grad[1:], input_grads_structured + ): final_grads.append(grad if needs_grad else None) return tuple(final_grads) @@ -303,6 +304,7 @@ def _jax_backward(vjp_spec, saved_tensors, grad_out): Unflattening `saved_tensors` with `vjp_spec` should restore the original vjp function. 
""" from jax.tree_util import tree_unflatten + fun_vjp = tree_unflatten(vjp_spec, saved_tensors) return fun_vjp(grad_out) @@ -317,27 +319,31 @@ def wrap_jax_jit(torch_function, jax_jit_func=jax.jit, kwargs_for_jax=None): return torch_view(jitted) -def jax_jit(torch_function, - kwargs_for_jax_jit=None, - fix_for_buffer_donation=False): +def jax_jit( + torch_function, kwargs_for_jax_jit=None, fix_for_buffer_donation=False +): return wrap_jax_jit( - torch_function, jax_jit_func=jax.jit, kwargs_for_jax=kwargs_for_jax_jit) + torch_function, jax_jit_func=jax.jit, kwargs_for_jax=kwargs_for_jax_jit + ) def jax_shard_map(torch_function, kwargs_for_jax_shard_map=None): return wrap_jax_jit( - torch_function, - jax_jit_func=shard_map, - kwargs_for_jax=kwargs_for_jax_shard_map) + torch_function, + jax_jit_func=shard_map, + kwargs_for_jax=kwargs_for_jax_shard_map, + ) def jax_value_and_grad(torch_function, kwargs_for_value_and_grad=None): return wrap_jax_jit( - torch_function, - jax_jit_func=jax.value_and_grad, - kwargs_for_jax=kwargs_for_value_and_grad) + torch_function, + jax_jit_func=jax.value_and_grad, + kwargs_for_jax=kwargs_for_value_and_grad, + ) def gradient_checkpoint(torch_function, kwargs=None): return wrap_jax_jit( - torch_function, jax_jit_func=jax.checkpoint, kwargs_for_jax=kwargs) + torch_function, jax_jit_func=jax.checkpoint, kwargs_for_jax=kwargs + ) diff --git a/torchax/torchax/mesh_util.py b/torchax/torchax/mesh_util.py index 3f65b8440b5..7b74ea09142 100644 --- a/torchax/torchax/mesh_util.py +++ b/torchax/torchax/mesh_util.py @@ -80,12 +80,13 @@ def __call__(self, name, shapedtype): `_shard_first_multiple_of`. """ del name - sharding = _shard_first_multiple_of(self.axis_name, shapedtype.shape, - self.axis_size) + sharding = _shard_first_multiple_of( + self.axis_name, shapedtype.shape, self.axis_size + ) if not self.replicate_unshardable and all(s is None for s in sharding): raise AssertionError( - f"Unable to find a dim to shard because " - f"None of the dims ({shapedtype.shape}) in shape is multiple of {self.axis_size}" + f"Unable to find a dim to shard because " + f"None of the dims ({shapedtype.shape}) in shape is multiple of {self.axis_size}" ) return sharding @@ -145,15 +146,14 @@ def __init__(self, jax_mesh, sharder=None): self.jax_mesh = jax_mesh if sharder is None: assert len(self.jax_mesh.axis_names) == 1 - sharder = SingleAxisSharder(self.jax_mesh.axis_names[0], - len(self.mesh.device_ids)) + sharder = SingleAxisSharder( + self.jax_mesh.axis_names[0], len(self.mesh.device_ids) + ) self._sharder = sharder - def initialize_model_sharded(self, - model_class, - init_args, - init_kwargs=None, - override_sharder=None): + def initialize_model_sharded( + self, model_class, init_args, init_kwargs=None, override_sharder=None + ): """Initializes a PyTorch model with its parameters sharded across the mesh. 
This method orchestrates the initialization of a `torch.nn.Module` such @@ -194,8 +194,8 @@ def initialize_model_sharded(self, states = model.state_dict() output_shards = { - name: NamedSharding(self.jax_mesh, sharder(name, tensor)) - for name, tensor in states.items() + name: NamedSharding(self.jax_mesh, sharder(name, tensor)) + for name, tensor in states.items() } def model_initializer(): @@ -204,7 +204,9 @@ def model_initializer(): return dict(model.state_dict()) jitted = interop.jax_jit( - model_initializer, kwargs_for_jax_jit={"out_shardings": output_shards}) + model_initializer, + kwargs_for_jax_jit={"out_shardings": output_shards}, + ) weights_dict = jitted() model.load_state_dict(weights_dict, assign=True) diff --git a/torchax/torchax/ops/__init__.py b/torchax/torchax/ops/__init__.py index 71c1b137132..a6852161657 100644 --- a/torchax/torchax/ops/__init__.py +++ b/torchax/torchax/ops/__init__.py @@ -4,7 +4,7 @@ def all_aten_jax_ops(): import torchax.ops.ops_registry # type: ignore return { - key: val.func - for key, val in torchax.ops.ops_registry.all_aten_ops.items() - if val.is_jax_function + key: val.func + for key, val in torchax.ops.ops_registry.all_aten_ops.items() + if val.is_jax_function } diff --git a/torchax/torchax/ops/jaten.py b/torchax/torchax/ops/jaten.py index 8d2242fdb59..2d9a93a794f 100644 --- a/torchax/torchax/ops/jaten.py +++ b/torchax/torchax/ops/jaten.py @@ -16,40 +16,44 @@ from torchax import interop from torchax.ops import jax_reimplement from torchax.view import View + # Keys are OpOverload, value is a callable that takes # Tensor all_ops = {} def op(*aten, **kwargs): - def inner(func): for a in aten: ops_registry.register_torch_dispatch_op(a, func, **kwargs) continue if isinstance(a, torch._ops.OpOverloadPacket): - opname = a.default.name() if 'default' in a.overloads( - ) else a._qualified_op_name + opname = ( + a.default.name() + if "default" in a.overloads() + else a._qualified_op_name + ) elif isinstance(a, torch._ops.OpOverload): opname = a.name() else: - raise RuntimeError(f'oops {a}') + raise RuntimeError(f"oops {a}") torchfunc = functools.partial(interop.call_jax, func) # HACK: to_copy is where we make the initial conversion from CPU tensor to JAX tensor - torch.library.impl(opname, 'privateuseone')( - torchfunc if a != torch.ops.aten._to_copy else func) + torch.library.impl(opname, "privateuseone")( + torchfunc if a != torch.ops.aten._to_copy else func + ) return func return inner @op( - torch.ops.aten.view_copy, - torch.ops.aten.view, - torch.ops.aten._unsafe_view, - torch.ops.aten.reshape, + torch.ops.aten.view_copy, + torch.ops.aten.view, + torch.ops.aten._unsafe_view, + torch.ops.aten.reshape, ) def _aten_unsafe_view(x, shape): return jnp.reshape(x, shape) @@ -69,13 +73,11 @@ def _aten_add(x, y, *, alpha=1): return res -@op(torch.ops.aten.copy_, - is_jax_function=False, - is_view_op=True, - needs_env=True) +@op( + torch.ops.aten.copy_, is_jax_function=False, is_view_op=True, needs_env=True +) def _aten_copy(x, y, memory_format=None, env=None): - - if y.device.type == 'cpu': + if y.device.type == "cpu": y = env.to_xla(y) if isinstance(x, View): @@ -164,9 +166,9 @@ def _aten_complex(real, imag): Returns: A complex array with the specified real and imaginary parts. 
""" - return jnp.array( - real, dtype=jnp.float32) + 1j * jnp.array( - imag, dtype=jnp.float32) + return jnp.array(real, dtype=jnp.float32) + 1j * jnp.array( + imag, dtype=jnp.float32 + ) # aten.exponential_ @@ -215,9 +217,10 @@ def _aten_cholesky(input, upper=False): def _aten_linalg_cholesky_ex(input, upper=False, check_errors=False): if check_errors: raise NotImplementedError( - "check_errors=True is not supported in this JAX implementation. " - "Check for positive definiteness using jnp.linalg.eigvalsh before " - "calling this function.") + "check_errors=True is not supported in this JAX implementation. " + "Check for positive definiteness using jnp.linalg.eigvalsh before " + "calling this function." + ) L = jax.scipy.linalg.cholesky(input, lower=not upper) if len(L.shape) > 2: @@ -281,7 +284,8 @@ def _aten_searchsorted(sorted_sequence, values): new_dtype = mappings.t2j_dtype(torch.get_default_dtype()) res = jnp.searchsorted(sorted_sequence, values) if sorted_sequence.dtype == np.dtype( - np.int32) or sorted_sequence.dtype == np.dtype(np.int32): + np.int32 + ) or sorted_sequence.dtype == np.dtype(np.int32): # res = res.astype(new_dtype) res = res.astype(np.dtype(np.int64)) return res # jnp.searchsorted(sorted_sequence, values) @@ -396,7 +400,7 @@ def _aten_real(x): @op(torch.Tensor.resize_) -def _aten_resize_(x, size, interpolation='linear'): +def _aten_resize_(x, size, interpolation="linear"): new_size = tuple(size) return jax.numpy.resize(x, new_size) @@ -446,8 +450,9 @@ def _aten_softmax(x, dim, halftofloat=False): def _is_int(x): if isinstance(x, int): return True - if isinstance(x, jax.Array) and (x.dtype.name.startswith('int') or - x.dtype.name.startswith('uint')): + if isinstance(x, jax.Array) and ( + x.dtype.name.startswith("int") or x.dtype.name.startswith("uint") + ): return True return False @@ -459,19 +464,20 @@ def highest_precision_int_dtype(tensor1, tensor2): return tensor1.dtype dtype_hierarchy = { - 'uint8': 8, - 'int8': 8, - 'uint16': 16, - 'int16': 16, - 'uint32': 32, - 'int32': 32, - 'uint64': 64, - 'int64': 64, + "uint8": 8, + "int8": 8, + "uint16": 16, + "int16": 16, + "uint32": 32, + "int32": 32, + "uint64": 64, + "int64": 64, } return max( - tensor1.dtype, - tensor2.dtype, - key=lambda dtype: dtype_hierarchy[str(dtype)]) + tensor1.dtype, + tensor2.dtype, + key=lambda dtype: dtype_hierarchy[str(dtype)], + ) @op(torch.ops.aten.pow) @@ -481,7 +487,7 @@ def _aten_pow(x, y): y = float(y) if _is_int(x) and _is_int(y_orig): # Do the math in float then cast - res = jnp.power(jnp.astype(x, jnp.dtype('float')), y) + res = jnp.power(jnp.astype(x, jnp.dtype("float")), y) return res.astype(highest_precision_int_dtype(x, y_orig)) res = jnp.power(x, y) if isinstance(x, float): @@ -503,21 +509,21 @@ def _aten_view_as_complex(input): def _aten_div(x, y, rounding_mode=""): res_dtype = None if _is_int(x) and _is_int(y): - res_dtype = jnp.dtype('float32') + res_dtype = jnp.dtype("float32") - if (isinstance(x, float) or isinstance(y, float)): + if isinstance(x, float) or isinstance(y, float): res_dtype = new_dtype = mappings.t2j_dtype(torch.get_default_dtype()) if rounding_mode == "floor": res = jnp.floor_divide(x, y) if _is_int(x) and _is_int(y): - res_dtype = jnp.dtype('int64') + res_dtype = jnp.dtype("int64") else: res = x / y if rounding_mode == "trunc": res = jnp.trunc(res) if _is_int(x) and _is_int(y): - res_dtype = jnp.dtype('int64') + res_dtype = jnp.dtype("int64") if res_dtype: res = res.astype(res_dtype) return res @@ -543,11 +549,9 @@ def _aten_bmm(x, y): 
@op(torch.ops.aten.embedding) # embedding(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -def _aten_embedding(a, - w, - padding_idx=-1, - scale_grad_by_freq=False, - sparse=False): +def _aten_embedding( + a, w, padding_idx=-1, scale_grad_by_freq=False, sparse=False +): return jnp.take(a, w, axis=0) @@ -557,9 +561,9 @@ def _aten_embedding_renorm_(weight, indices, max_norm, norm_type): unique_indices = jnp.unique(indices) norm = jnp.linalg.norm( - _aten_embedding(weight, unique_indices), - ord=norm_type, - axis=1, + _aten_embedding(weight, unique_indices), + ord=norm_type, + axis=1, ) indice_idx = jnp.where(norm > max_norm) @@ -568,41 +572,44 @@ def _aten_embedding_renorm_(weight, indices, max_norm, norm_type): indices_to_update = unique_indices[indice_idx] - weight = weight.at[indices_to_update].set(weight[indices_to_update] * - scale[:, None]) + weight = weight.at[indices_to_update].set( + weight[indices_to_update] * scale[:, None] + ) return weight -#- func: _embedding_bag_forward_only( +# - func: _embedding_bag_forward_only( # Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, # int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1) -> (Tensor, Tensor, Tensor, Tensor) @op(torch.ops.aten._embedding_bag) @op(torch.ops.aten._embedding_bag_forward_only) -def _aten__embedding_bag(weight, - indices, - offsets=None, - scale_grad_by_freq=False, - mode=0, - sparse=False, - per_sample_weights=None, - include_last_offset=False, - padding_idx=-1): +def _aten__embedding_bag( + weight, + indices, + offsets=None, + scale_grad_by_freq=False, + mode=0, + sparse=False, + per_sample_weights=None, + include_last_offset=False, + padding_idx=-1, +): """Jax implementation of the PyTorch _embedding_bag function. - Args: - weight: The learnable weights of the module of shape (num_embeddings, embedding_dim). - indices: A LongTensor containing the indices to extract. - offsets: A LongTensor containing the starting offset of each bag. - scale_grad_by_freq: Whether to scale gradients by the inverse of frequency of the words in the mini-batch. - mode: 0 = "sum", 1 = "mean" or 2 = "max" - sparse: Whether the gradients with respect to weight should be a sparse tensor. - per_sample_weights: If given, each embedding vector is weighted by per_sample_weights - include_last_offset: Whether to include the last offset as a valid bag. - padding_idx: If specified, the entries at padding_idx do not contribute to the gradient. + Args: + weight: The learnable weights of the module of shape (num_embeddings, embedding_dim). + indices: A LongTensor containing the indices to extract. + offsets: A LongTensor containing the starting offset of each bag. + scale_grad_by_freq: Whether to scale gradients by the inverse of frequency of the words in the mini-batch. + mode: 0 = "sum", 1 = "mean" or 2 = "max" + sparse: Whether the gradients with respect to weight should be a sparse tensor. + per_sample_weights: If given, each embedding vector is weighted by per_sample_weights + include_last_offset: Whether to include the last offset as a valid bag. + padding_idx: If specified, the entries at padding_idx do not contribute to the gradient. - Returns: - A tuple of (output, offset2bag, bag_size, max_indices). - """ + Returns: + A tuple of (output, offset2bag, bag_size, max_indices). 
+ """ embedded = _aten_embedding(weight, indices, padding_idx) if offsets is None: @@ -626,8 +633,11 @@ def _aten__embedding_bag(weight, for bag in range(offsets_np.shape[0]): start = int(offsets_np[bag]) - end = int(indices.shape[0] if bag + - 1 == offsets_np.shape[0] else offsets_np[bag + 1]) + end = int( + indices.shape[0] + if bag + 1 == offsets_np.shape[0] + else offsets_np[bag + 1] + ) bag_size[bag] = end - start offset2bag = offset2bag.at[start:end].set(bag) @@ -639,7 +649,8 @@ def _aten__embedding_bag(weight, elif mode == 2: output_bag = jnp.max(embedded[start:end], axis=0) max_indices = max_indices.at[start:end].set( - jnp.argmax(embedded[start:end], axis=0)) + jnp.argmax(embedded[start:end], axis=0) + ) # The original code returned offset2bag, bag_size, and max_indices as numpy arrays. # Converting them to JAX arrays for consistency. @@ -658,7 +669,6 @@ def _aten_rsqrt(x): @op(torch.ops.aten.expand) @op(torch.ops.aten.expand_copy) def _aten_expand(x, dims): - def fix_dims(d, xs): if d == -1: return xs @@ -667,7 +677,7 @@ def fix_dims(d, xs): shape = list(x.shape) if len(shape) < len(dims): shape = [ - 1, + 1, ] * (len(dims) - len(shape)) + shape # make sure that dims and shape is the same by # left pad with 1s. Otherwise the zip below will @@ -781,8 +791,8 @@ def make_range(rank, dim, start, end): return tuple(res) return [ - x[make_range(rank, dim, start, end)] - for start, end in zip([0] + list(splits[:-1]), splits) + x[make_range(rank, dim, start, end)] + for start, end in zip([0] + list(splits[:-1]), splits) ] @@ -818,8 +828,9 @@ def _aten_ne(x, y): # >> [[0, 1, 2, 3]] shape (1, 4) def _indices_along_axis(x, axis): return jnp.expand_dims( - jnp.arange(x.shape[axis]), - axis=[d for d in range(len(x.shape)) if d != axis]) + jnp.arange(x.shape[axis]), + axis=[d for d in range(len(x.shape)) if d != axis], + ) def _broadcast_indices(indices, shape): @@ -837,19 +848,17 @@ def _aten_cummax(x, dim): indices = _broadcast_indices(indice_along_axis, x.shape) def cummax_reduce_func(carry, elem): - v1, v2 = carry['val'], elem['val'] - i1, i2 = carry['idx'], elem['idx'] + v1, v2 = carry["val"], elem["val"] + i1, i2 = carry["idx"], elem["idx"] v = jnp.maximum(v1, v2) i = jnp.where(v1 > v2, i1, i2) - return {'val': v, 'idx': i} + return {"val": v, "idx": i} res = jax.lax.associative_scan( - cummax_reduce_func, { - 'val': x, - 'idx': indices - }, axis=axis) - return res['val'], res['idx'] + cummax_reduce_func, {"val": x, "idx": indices}, axis=axis + ) + return res["val"], res["idx"] @op(torch.ops.aten.cummin) @@ -863,19 +872,17 @@ def _aten_cummin(x, dim): indices = _broadcast_indices(indice_along_axis, x.shape) def cummin_reduce_func(carry, elem): - v1, v2 = carry['val'], elem['val'] - i1, i2 = carry['idx'], elem['idx'] + v1, v2 = carry["val"], elem["val"] + i1, i2 = carry["idx"], elem["idx"] v = jnp.minimum(v1, v2) i = jnp.where(v1 < v2, i1, i2) - return {'val': v, 'idx': i} + return {"val": v, "idx": i} res = jax.lax.associative_scan( - cummin_reduce_func, { - 'val': x, - 'idx': indices - }, axis=axis) - return res['val'], res['idx'] + cummin_reduce_func, {"val": x, "idx": indices}, axis=axis + ) + return res["val"], res["idx"] @op(torch.ops.aten.cumsum) @@ -902,11 +909,9 @@ def _aten_cumprod(input, dim, dtype=None, out=None): @op(torch.ops.aten.native_layer_norm) -def _aten_native_layer_norm(input, - normalized_shape, - weight=None, - bias=None, - eps=1e-5): +def _aten_native_layer_norm( + input, normalized_shape, weight=None, bias=None, eps=1e-5 +): """Implements layer normalization in 
Jax as defined by `aten::native_layer_norm`. Args: @@ -971,8 +976,9 @@ def _aten_addbmm(input, batch1, batch2, *, beta=1, alpha=1): alpha = jnp.array(alpha).astype(batch1.dtype) beta = jnp.array(beta).astype(batch1.dtype) mm = jnp.einsum("bxy, byz -> xz", batch1, batch2) - return jax.lax.cond(beta == 0, lambda: alpha * mm, - lambda: beta * input + alpha * mm) + return jax.lax.cond( + beta == 0, lambda: alpha * mm, lambda: beta * input + alpha * mm + ) @op(torch.ops.aten.gelu) @@ -996,56 +1002,54 @@ def _aten_squeeze_dim(self, dim=None): # NOTE: torch leaves the dims that is not 1 unchanged, # but jax raises error. dim = [ - i if i >= 0 else (i + self.ndim) for i in dim if self.shape[i] == 1 + i if i >= 0 else (i + self.ndim) for i in dim if self.shape[i] == 1 ] return jnp.squeeze(self, dim) @op(torch.ops.aten.bucketize) -def _aten_bucketize(input, - boundaries, - *, - out_int32=False, - right=False, - out=None): +def _aten_bucketize( + input, boundaries, *, out_int32=False, right=False, out=None +): return_type = jnp.int32 if out_int32 else jnp.int64 return jnp.digitize(input, boundaries, right=not right).astype(return_type) @op(torch.ops.aten.conv2d) def _aten_conv2d( + input, + weight, + bias, + stride, + padding, + dilation, + groups, +): + return _aten_convolution( input, weight, bias, stride, padding, dilation, - groups, -): - return _aten_convolution( - input, - weight, - bias, - stride, - padding, - dilation, - transposed=False, - output_padding=1, - groups=groups) + transposed=False, + output_padding=1, + groups=groups, + ) @op(torch.ops.aten.convolution) def _aten_convolution( - input, - weight, - bias, - stride, - padding, - dilation, - transposed, - output_padding, - groups, + input, + weight, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, ): num_shape_dim = weight.ndim - 1 batch_dims = input.shape[:-num_shape_dim] @@ -1081,7 +1085,8 @@ def create_default_conv_dimension_numbers(num_spatial_dims): rhs_spec.append(i + 2) out_spec.append(i + 2) return jax.lax.ConvDimensionNumbers( - *map(tuple, (lhs_spec, rhs_spec, out_spec))) + *map(tuple, (lhs_spec, rhs_spec, out_spec)) + ) if transposed: rhs = jnp.flip(weight, range(2, 1 + num_shape_dim)) @@ -1092,27 +1097,27 @@ def create_default_conv_dimension_numbers(num_spatial_dims): rhs_shape.extend(rhs.shape[2:]) rhs = jnp.reshape(rhs, rhs_shape) res = jax.lax.conv_general_dilated( - input, - rhs, - (1,) * len(stride), - make_padding(padding, len(stride)), - lhs_dilation=stride, - rhs_dilation=dilation, - dimension_numbers=create_default_conv_dimension_numbers(len(stride)), - feature_group_count=groups, - batch_group_count=1, + input, + rhs, + (1,) * len(stride), + make_padding(padding, len(stride)), + lhs_dilation=stride, + rhs_dilation=dilation, + dimension_numbers=create_default_conv_dimension_numbers(len(stride)), + feature_group_count=groups, + batch_group_count=1, ) else: res = jax.lax.conv_general_dilated( - input, - weight, - stride, - make_padding(padding, len(stride)), - lhs_dilation=(1,) * len(stride), - rhs_dilation=dilation, - dimension_numbers=create_default_conv_dimension_numbers(len(stride)), - feature_group_count=groups, - batch_group_count=1, + input, + weight, + stride, + make_padding(padding, len(stride)), + lhs_dilation=(1,) * len(stride), + rhs_dilation=dilation, + dimension_numbers=create_default_conv_dimension_numbers(len(stride)), + feature_group_count=groups, + batch_group_count=1, ) if bias is not None: @@ -1129,8 +1134,9 @@ def 
create_default_conv_dimension_numbers(num_spatial_dims): # _native_batch_norm_legit(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, bool training, float momentum, float eps) @op(torch.ops.aten._native_batch_norm_legit.default) -def _aten__native_batch_norm_legit(input, weight, bias, running_mean, - running_var, training, momentum, eps): +def _aten__native_batch_norm_legit( + input, weight, bias, running_mean, running_var, training, momentum, eps +): """JAX implementation of batch normalization with optional parameters. Refers to https://github.com/pytorch/pytorch/blob/cd3a71f754a2248bcfe500de7c9860bd7d2002bf/torch/_decomp/decompositions.py#L1713. @@ -1165,7 +1171,7 @@ def _aten__native_batch_norm_legit(input, weight, bias, running_mean, else: rstd = jax.lax.rsqrt(running_var.reshape(reshape_dims) + eps) saved_mean = jnp.array( - [], dtype=input.dtype + [], dtype=input.dtype ) # No need to calculate batch statistics in inference mode saved_rstd = jnp.array([], dtype=input.dtype) @@ -1187,11 +1193,12 @@ def _aten__native_batch_norm_legit(input, weight, bias, running_mean, @op(torch.ops.aten._native_batch_norm_legit_no_training) -def _aten__native_batch_norm_legit_no_training(input, weight, bias, - running_mean, running_var, - momentum, eps): - return _aten__native_batch_norm_legit(input, weight, bias, running_mean, - running_var, False, momentum, eps) +def _aten__native_batch_norm_legit_no_training( + input, weight, bias, running_mean, running_var, momentum, eps +): + return _aten__native_batch_norm_legit( + input, weight, bias, running_mean, running_var, False, momentum, eps + ) @op(torch.ops.aten.relu) @@ -1205,7 +1212,7 @@ def _aten_cat(tensors, dims=0): # torch.cat will ignore the empty tensor, while jnp.concatenate # will error if the dims > 0. filtered_tensors = [ - t for t in tensors if not (t.ndim == 1 and t.shape[0] == 0) + t for t in tensors if not (t.ndim == 1 and t.shape[0] == 0) ] if filtered_tensors: return jnp.concatenate(filtered_tensors, dims) @@ -1213,12 +1220,12 @@ def _aten_cat(tensors, dims=0): def _ceil_mode_padding( - padding: list[int], - input_shape: list[int], - kernel_size: list[int], - stride: list[int], - dilation: list[int], - ceil_mode: bool, + padding: list[int], + input_shape: list[int], + kernel_size: list[int], + stride: list[int], + dilation: list[int], + ceil_mode: bool, ): """Creates low and high padding specification for the given padding (which is symmetric) and ceil mode. @@ -1230,13 +1237,21 @@ def _ceil_mode_padding( right_padding = left_padding input_size = input_shape[2 + i] - output_size_rem = (input_size + 2 * left_padding - - (kernel_size[i] - 1) * dilation[i] - 1) % stride[i] + output_size_rem = ( + input_size + 2 * left_padding - (kernel_size[i] - 1) * dilation[i] - 1 + ) % stride[i] if ceil_mode and output_size_rem != 0: extra_padding = stride[i] - output_size_rem - new_output_size = (input_size + left_padding + right_padding + - extra_padding - (kernel_size[i] - 1) * dilation[i] - - 1 + stride[i] - 1) // stride[i] + 1 + new_output_size = ( + input_size + + left_padding + + right_padding + + extra_padding + - (kernel_size[i] - 1) * dilation[i] + - 1 + + stride[i] + - 1 + ) // stride[i] + 1 # Ensure that the last pooling starts inside the image. 
size_to_compare = input_size + left_padding @@ -1249,12 +1264,9 @@ def _ceil_mode_padding( @op(torch.ops.aten.max_pool2d_with_indices) @op(torch.ops.aten.max_pool3d_with_indices) -def _aten_max_pool2d_with_indices(inputs, - kernel_size, - strides=None, - padding=0, - dilation=1, - ceil_mode=False): +def _aten_max_pool2d_with_indices( + inputs, kernel_size, strides=None, padding=0, dilation=1, ceil_mode=False +): num_batch_dims = len(inputs.shape) - len(kernel_size) - 1 kernel_size = tuple(kernel_size) # Default stride is kernel_size @@ -1269,13 +1281,16 @@ def _aten_max_pool2d_with_indices(inputs, input_shape = inputs.shape if num_batch_dims == 0: input_shape = [1, *input_shape] - padding = _ceil_mode_padding(padding, input_shape, kernel_size, strides, - dilation, ceil_mode) + padding = _ceil_mode_padding( + padding, input_shape, kernel_size, strides, dilation, ceil_mode + ) - assert len(kernel_size) == len( - strides), f"len({kernel_size=}) must equal len({strides=})" - assert len(kernel_size) == len( - dilation), f"len({kernel_size=}) must equal len({dilation=})" + assert len(kernel_size) == len(strides), ( + f"len({kernel_size=}) must equal len({strides=})" + ) + assert len(kernel_size) == len(dilation), ( + f"len({kernel_size=}) must equal len({dilation=})" + ) strides = (1,) * (1 + num_batch_dims) + strides dims = (1,) * (1 + num_batch_dims) + kernel_size dilation = (1,) * (1 + num_batch_dims) + dilation @@ -1294,14 +1309,16 @@ def _aten_max_pool2d_with_indices(inputs, if not isinstance(padding, str): padding = tuple(map(tuple, padding)) assert len(padding) == len(kernel_size), ( - f"padding {padding} must specify pads for same number of dims as " - f"kernel_size {kernel_size}") - assert all([len(x) == 2 for x in padding - ]), f"each entry in padding {padding} must be length 2" + f"padding {padding} must specify pads for same number of dims as " + f"kernel_size {kernel_size}" + ) + assert all([len(x) == 2 for x in padding]), ( + f"each entry in padding {padding} must be length 2" + ) padding = ((0, 0), (0, 0)) + padding - indices = jnp.arange(np.prod(inputs.shape[-len(kernel_size):])) - indices = indices.reshape(inputs.shape[-len(kernel_size):]) + indices = jnp.arange(np.prod(inputs.shape[-len(kernel_size) :])) + indices = indices.reshape(inputs.shape[-len(kernel_size) :]) indices = jnp.broadcast_to(indices, inputs.shape) def reduce_fn(a, b): @@ -1319,21 +1336,22 @@ def reduce_fn(a, b): # the indices tensor is usually unused in inference, separating the two # can help DCE computations for argmax. 
y = jax.lax.reduce_window( - inputs, - init_val, - jax.lax.max, - dims, - strides, - padding, - window_dilation=dilation) + inputs, + init_val, + jax.lax.max, + dims, + strides, + padding, + window_dilation=dilation, + ) indices, _ = jax.lax.reduce_window( - (indices, inputs), - (0, init_val), - reduce_fn, - dims, - strides, - padding, - window_dilation=dilation, + (indices, inputs), + (0, init_val), + reduce_fn, + dims, + strides, + padding, + window_dilation=dilation, ) if is_single_input: indices = jnp.squeeze(indices, axis=0) @@ -1354,17 +1372,19 @@ def _xla_aot_mark_sharding(t, mesh: str, partition_spec: str): from jax.sharding import PartitionSpec as P, NamedSharding import ast import torch_xla.distributed.spmd as xs + pmesh = xs.Mesh.from_str(mesh) assert pmesh is not None partition_spec_eval = ast.literal_eval(partition_spec) jmesh = pmesh.get_jax_mesh() return jax.lax.with_sharding_constraint( - t, NamedSharding(jmesh, P(*partition_spec_eval))) + t, NamedSharding(jmesh, P(*partition_spec_eval)) + ) @op(torch.ops.xla.einsum_linear_forward) def _xla_einsum_linear_forward(input, weight, bias): - with jax.named_scope('einsum_linear_forward'): - product = jax.numpy.einsum('...n,mn->...m', input, weight) + with jax.named_scope("einsum_linear_forward"): + product = jax.numpy.einsum("...n,mn->...m", input, weight) if bias is not None: return product + bias return product @@ -1378,10 +1398,9 @@ def _xla_einsum_linear_forward(input, weight, bias): @op(torch.ops.aten.min) def _aten_min(x, dim=None, keepdim=False): if dim is not None: - return _with_reduction_scalar(jnp.min, x, dim, - keepdim), _with_reduction_scalar( - jnp.argmin, x, dim, - keepdim).astype(jnp.int64) + return _with_reduction_scalar( + jnp.min, x, dim, keepdim + ), _with_reduction_scalar(jnp.argmin, x, dim, keepdim).astype(jnp.int64) else: return _with_reduction_scalar(jnp.min, x, dim, keepdim) @@ -1390,15 +1409,17 @@ def _aten_min(x, dim=None, keepdim=False): def _aten_mode(input, dim=-1, keepdim=False, *, out=None): if input.ndim == 0: # single number return input, jnp.array(0) - dim = (input.ndim + - dim) % input.ndim # jnp.scipy.stats.mode does not accept -1 as dim + dim = ( + input.ndim + dim + ) % input.ndim # jnp.scipy.stats.mode does not accept -1 as dim # keepdims must be True for accurate broadcasting mode, _ = jax.scipy.stats.mode(input, axis=dim, keepdims=True) mode_broadcast = jnp.broadcast_to(mode, input.shape) if not keepdim: mode = mode.squeeze(axis=dim) indices = jnp.argmax( - jnp.equal(mode_broadcast, input), axis=dim, keepdims=keepdim) + jnp.equal(mode_broadcast, input), axis=dim, keepdims=keepdim + ) return mode, indices @@ -1432,7 +1453,8 @@ def _aten_var(x, dim=None, *, correction=1, keepdim=False, out=None): @op(torch.ops.prims.broadcast_in_dim) def _prims_broadcast_in_dim(t, shape, broadcast_dimensions): return jax.lax.broadcast_in_dim( - t, shape, broadcast_dimensions=broadcast_dimensions) + t, shape, broadcast_dimensions=broadcast_dimensions + ) # aten.native_group_norm -- should use decomp table @@ -1475,15 +1497,17 @@ def group_norm_body(x): # Function to apply within each group normalized = (x - mean) * rstd return normalized, mean, rstd - normalized, group_mean, group_rstd = jax.lax.map(group_norm_body, - reshaped_input) + normalized, group_mean, group_rstd = jax.lax.map( + group_norm_body, reshaped_input + ) # Reshape back to original input shape output = jnp.reshape(normalized, input_shape) # **Affine transformation** - affine_shape = [-1 if i == 1 else 1 for i in range(input.ndim) - ] # Shape 
for broadcasting + affine_shape = [ + -1 if i == 1 else 1 for i in range(input.ndim) + ] # Shape for broadcasting if weight is not None and bias is not None: output = bias.reshape(affine_shape) + output * weight.reshape(affine_shape) elif weight is not None: @@ -1515,11 +1539,13 @@ def _aten_linalg_vector_norm(self, ord=2, dim=None, keepdim=False, dtype=None): The tensor containing the calculated vector norms. """ - if ord not in {2, float("inf"), float("-inf"), "fro" - } and not isinstance(ord, (int, float)): + if ord not in {2, float("inf"), float("-inf"), "fro"} and not isinstance( + ord, (int, float) + ): raise ValueError( - f"Unsupported ord value: {ord}. Supported values are 2, inf, -inf, and" - " 'fro'.") + f"Unsupported ord value: {ord}. Supported values are 2, inf, -inf, and" + " 'fro'." + ) # Special cases (for efficiency and clarity) if ord == 0: @@ -1527,13 +1553,14 @@ def _aten_linalg_vector_norm(self, ord=2, dim=None, keepdim=False, dtype=None): # float sets it to float64. set it back to input type result = jnp.astype(jnp.array(float(self != 0)), self.dtype) else: - result = _with_reduction_scalar(jnp.sum, jnp.where(self != 0, 1, 0), dim, - keepdim) + result = _with_reduction_scalar( + jnp.sum, jnp.where(self != 0, 1, 0), dim, keepdim + ) elif ord == 2: # Euclidean norm result = jnp.sqrt( - _with_reduction_scalar(jnp.sum, - jnp.abs(self)**2, dim, keepdim)) + _with_reduction_scalar(jnp.sum, jnp.abs(self) ** 2, dim, keepdim) + ) elif ord == float("inf"): result = _with_reduction_scalar(jnp.max, jnp.abs(self), dim, keepdim) @@ -1543,13 +1570,13 @@ def _aten_linalg_vector_norm(self, ord=2, dim=None, keepdim=False, dtype=None): elif ord == "fro": # Frobenius norm result = jnp.sqrt( - _with_reduction_scalar(jnp.sum, - jnp.abs(self)**2, dim, keepdim)) + _with_reduction_scalar(jnp.sum, jnp.abs(self) ** 2, dim, keepdim) + ) else: # General case (e.g., ord = 1, ord = 3) - result = _with_reduction_scalar(jnp.sum, - jnp.abs(self)**ord, dim, - keepdim)**(1.0 / ord) + result = _with_reduction_scalar( + jnp.sum, jnp.abs(self) ** ord, dim, keepdim + ) ** (1.0 / ord) # (Optional) dtype conversion if dtype is not None: @@ -1585,12 +1612,9 @@ def _aten_sinh(self): # aten.native_layer_norm_backward @op(torch.ops.aten.native_layer_norm_backward) -def _aten_native_layer_norm_backward(grad_out, - input, - normalized_shape, - weight, - bias, - eps=1e-5): +def _aten_native_layer_norm_backward( + grad_out, input, normalized_shape, weight, bias, eps=1e-5 +): """Implements the backward pass of layer normalization in Jax as defined by `aten::native_layer_norm_backward`. Args: @@ -1604,8 +1628,9 @@ def _aten_native_layer_norm_backward(grad_out, Returns: A tuple of (grad_input, grad_weight, grad_bias). 
""" - return jax.lax.native_layer_norm_backward(grad_out, input, normalized_shape, - weight, bias, eps) + return jax.lax.native_layer_norm_backward( + grad_out, input, normalized_shape, weight, bias, eps + ) # aten.reflection_pad3d_backward @@ -1721,8 +1746,11 @@ def _scatter_index(dim, index): target_shape = [1] * len(index_shape) target_shape[i] = index_shape[i] input_indexes.append( - jnp.broadcast_to( - jnp.arange(index_shape[i]).reshape(target_shape), index_shape)) + jnp.broadcast_to( + jnp.arange(index_shape[i]).reshape(target_shape), + index_shape, + ) + ) return tuple(input_indexes), tuple(source_indexes) @@ -1738,7 +1766,6 @@ def _aten_scatter_add(input, dim, index, src): # aten.masked_scatter @op(torch.ops.aten.masked_scatter) def _aten_masked_scatter(self, mask, source): - broadcast_shape = jnp.broadcast_shapes(self.shape, mask.shape) if self.shape != broadcast_shape: @@ -1751,7 +1778,7 @@ def _aten_masked_scatter(self, mask, source): source_flat = source.flatten() true_indices = jnp.where(mask_flat)[0] - self_flat = self_flat.at[true_indices].set(source_flat[:len(true_indices)]) + self_flat = self_flat.at[true_indices].set(source_flat[: len(true_indices)]) final_arr = self_flat.reshape(self.shape) return final_arr @@ -1813,13 +1840,9 @@ def _aten_atan(self): @op(torch.ops.aten.scatter_reduce) @op(torch.ops.aten.scatter) -def _aten_scatter_reduce(input, - dim, - index, - src, - reduce=None, - *, - include_self=True): +def _aten_scatter_reduce( + input, dim, index, src, reduce=None, *, include_self=True +): if not isinstance(src, jnp.ndarray): src = jnp.array(src, dtype=input.dtype) input_indexes, source_indexes = _scatter_index(dim, index) @@ -1906,8 +1929,9 @@ def pool(inputs, init, reduce_fn, window_shape, strides, padding): """ num_batch_dims = inputs.ndim - (len(window_shape) + 1) strides = strides or (1,) * len(window_shape) - assert len(window_shape) == len( - strides), f"len({window_shape}) must equal len({strides})" + assert len(window_shape) == len(strides), ( + f"len({window_shape}) must equal len({strides})" + ) strides = (1,) * (1 + num_batch_dims) + strides dims = (1,) * (1 + num_batch_dims) + window_shape @@ -1924,10 +1948,12 @@ def pool(inputs, init, reduce_fn, window_shape, strides, padding): if not isinstance(padding, str): padding = tuple(map(tuple, padding)) assert len(padding) == len(window_shape), ( - f"padding {padding} must specify pads for same number of dims as " - f"window_shape {window_shape}") - assert all([len(x) == 2 for x in padding - ]), f"each entry in padding {padding} must be length 2" + f"padding {padding} must specify pads for same number of dims as " + f"window_shape {window_shape}" + ) + assert all([len(x) == 2 for x in padding]), ( + f"each entry in padding {padding} must be length 2" + ) padding = ((0, 0), (0, 0)) + padding y = jax.lax.reduce_window(inputs, init, reduce_fn, dims, strides, padding) if is_single_input: @@ -1937,20 +1963,21 @@ def pool(inputs, init, reduce_fn, window_shape, strides, padding): @op(torch.ops.aten._adaptive_avg_pool2d) @op(torch.ops.aten._adaptive_avg_pool3d) -def adaptive_avg_pool2or3d(input: jnp.ndarray, - output_size: Tuple[int, int]) -> jnp.ndarray: +def adaptive_avg_pool2or3d( + input: jnp.ndarray, output_size: Tuple[int, int] +) -> jnp.ndarray: """ - Applies a 2/3D adaptive average pooling over an input signal composed of several input planes. + Applies a 2/3D adaptive average pooling over an input signal composed of several input planes. 
- See :class:`~torch.nn.AdaptiveAvgPool2d` for details and output shape. + See :class:`~torch.nn.AdaptiveAvgPool2d` for details and output shape. - Args: - input: input tensor - output_size: the target output size (single integer or double-integer tuple) + Args: + input: input tensor + output_size: the target output size (single integer or double-integer tuple) - Context: - https://github.com/pytorch/pytorch/blob/main/torch/_decomp/decompositions.py#L2401 - """ + Context: + https://github.com/pytorch/pytorch/blob/main/torch/_decomp/decompositions.py#L2401 + """ shape = input.shape ndim = len(shape) out_dim = len(output_size) @@ -1958,22 +1985,25 @@ def adaptive_avg_pool2or3d(input: jnp.ndarray, # Preconditions - assert ndim in ( - out_dim + 1, out_dim + 2 - ), f"adaptive_avg_pool{num_spatial_dim}d(): Expected {num_spatial_dim+1}D or {num_spatial_dim+2}D tensor, but got {ndim}" + assert ndim in (out_dim + 1, out_dim + 2), ( + f"adaptive_avg_pool{num_spatial_dim}d(): Expected {num_spatial_dim + 1}D or {num_spatial_dim + 2}D tensor, but got {ndim}" + ) for d in input.shape[-2:]: - assert d != 0, "adaptive_avg_pool{num_spactial_dim}d(): Expected input to have non-zero size for " \ - f"non-batch dimensions, but input has shape {tuple(shape)}." + assert d != 0, ( + "adaptive_avg_pool{num_spactial_dim}d(): Expected input to have non-zero size for " + f"non-batch dimensions, but input has shape {tuple(shape)}." + ) # Optimisation (we should also do this in the kernel implementation) if all(s % o == 0 for o, s in zip(output_size, shape[-out_dim:])): stride = tuple(i // o for i, o in zip(shape[-out_dim:], output_size)) - kernel = tuple(i - (o - 1) * s - for i, o, s in zip(shape[-out_dim:], output_size, stride)) + kernel = tuple( + i - (o - 1) * s for i, o, s in zip(shape[-out_dim:], output_size, stride) + ) return _aten_avg_pool( - input, - kernel, - strides=stride, + input, + kernel, + strides=stride, ) def start_index(a, b, c): @@ -2025,9 +2055,12 @@ def _unsqueeze_to_dim(x, dim): reduce_axis = (-3, -1) else: assert out_dim == 3 - vals = input[..., - _unsqueeze_to_dim(idx[0], 6), - _unsqueeze_to_dim(idx[1], 4), idx[2]] + vals = input[ + ..., + _unsqueeze_to_dim(idx[0], 6), + _unsqueeze_to_dim(idx[1], 4), + idx[2], + ] reduce_axis = (-5, -3, -1) # Shortcut for the simpler case @@ -2053,7 +2086,12 @@ def maybe_mask(vals, length, range_max, adaptive, dim): for i in range(len(length)): vals, length[i] = maybe_mask( - vals, length[i], range_max[i], adaptive=adaptive[i], dim=(i - out_dim)) + vals, + length[i], + range_max[i], + adaptive=adaptive[i], + dim=(i - out_dim), + ) # We unroll the sum as we assume that the kernels are going to be small ret = jnp.sum(vals, axis=reduce_axis) @@ -2066,13 +2104,13 @@ def maybe_mask(vals, length, range_max, adaptive, dim): @op(torch.ops.aten.avg_pool2d) @op(torch.ops.aten.avg_pool3d) def _aten_avg_pool( - inputs, - kernel_size, - strides=None, - padding=0, - ceil_mode=False, - count_include_pad=True, - divisor_override=None, + inputs, + kernel_size, + strides=None, + padding=0, + ceil_mode=False, + count_include_pad=True, + divisor_override=None, ): num_batch_dims = len(inputs.shape) - len(kernel_size) - 1 kernel_size = tuple(kernel_size) @@ -2085,8 +2123,14 @@ def _aten_avg_pool( input_shape = inputs.shape if num_batch_dims == 0: input_shape = [1, *input_shape] - padding = _ceil_mode_padding(padding, input_shape, kernel_size, strides, - [1] * len(kernel_size), ceil_mode) + padding = _ceil_mode_padding( + padding, + input_shape, + kernel_size, + strides, + [1] * 
len(kernel_size), + ceil_mode, + ) y = pool(inputs, 0.0, jax.lax.add, kernel_size, strides, padding) if divisor_override is not None: @@ -2119,12 +2163,12 @@ def _aten_avg_pool( if len(div_shape) - 2 == len(kernel_size): div_shape = (1,) + div_shape[1:] y = y / pool( - jnp.ones(div_shape, y.dtype), - jnp.array(0.0, y.dtype), - jax.lax.add, - kernel_size, - strides, - padding, + jnp.ones(div_shape, y.dtype), + jnp.array(0.0, y.dtype), + jax.lax.add, + kernel_size, + strides, + padding, ) return y.astype(inputs.dtype) @@ -2153,7 +2197,7 @@ def _helper(curr_dim_idx, sofar): @op(torch.ops.aten.reciprocal) def _aten_reciprocal(a): if _is_int(a): - return (1 / a).astype(jnp.dtype('float32')) + return (1 / a).astype(jnp.dtype("float32")) return 1 / a @@ -2203,10 +2247,9 @@ def _aten_round(input, decimals=0): @op(torch.ops.aten.max) def _aten_max(self, dim=None, keepdim=False): if dim is not None: - return _with_reduction_scalar(jnp.max, self, dim, - keepdim), _with_reduction_scalar( - jnp.argmax, self, dim, - keepdim).astype(jnp.int64) + return _with_reduction_scalar( + jnp.max, self, dim, keepdim + ), _with_reduction_scalar(jnp.argmax, self, dim, keepdim).astype(jnp.int64) else: return _with_reduction_scalar(jnp.max, self, dim, keepdim) @@ -2256,21 +2299,21 @@ def _aten_any(self, dim=None, keepdim=False): @op(torch.ops.aten.arange.default) @op_base.convert_dtype(use_default_dtype=False) def _aten_arange( - start, - end=None, - step=None, - *, - dtype=None, - layout=None, - requires_grad=False, - device=None, - pin_memory=False, + start, + end=None, + step=None, + *, + dtype=None, + layout=None, + requires_grad=False, + device=None, + pin_memory=False, ): return jnp.arange( - op_base.maybe_convert_constant_dtype(start, dtype), - op_base.maybe_convert_constant_dtype(end, dtype), - op_base.maybe_convert_constant_dtype(step, dtype), - dtype=dtype, + op_base.maybe_convert_constant_dtype(start, dtype), + op_base.maybe_convert_constant_dtype(end, dtype), + op_base.maybe_convert_constant_dtype(step, dtype), + dtype=dtype, ) @@ -2339,7 +2382,6 @@ def _aten_bitwise_xor(self, other): # aten.broadcast_tensors @op(torch.ops.aten.broadcast_tensors) def _aten_broadcast_tensors(*tensors): - def _get_broadcast_shape(shapes): """ Determines the output shape by broadcasting all input shapes. 
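The `_aten_broadcast_tensors` hunk above computes a common broadcast shape and then expands every input to it with `jax.lax.broadcast_in_dim`. As a minimal eager sketch of the same broadcasting rule (using the higher-level `jnp.broadcast_shapes` and `jnp.broadcast_to` helpers for illustration only, not the code in this diff):

import jax.numpy as jnp

# Three differently shaped inputs that broadcast to a common (3, 4) shape.
shapes = [(3, 1), (1, 4), (4,)]
out_shape = jnp.broadcast_shapes(*shapes)  # (3, 4)
broadcasted = [jnp.broadcast_to(jnp.zeros(s), out_shape) for s in shapes]
assert all(b.shape == (3, 4) for b in broadcasted)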
@@ -2381,12 +2423,16 @@ def _broadcast_dimensions(input_shape, output_shape): """ res = tuple( - i for i, (in_dim, out_dim) in enumerate(zip(input_shape, output_shape))) + i for i, (in_dim, out_dim) in enumerate(zip(input_shape, output_shape)) + ) return res # clean some function's previous wrap - if len(tensors) == 1 and len(tensors[0]) >= 1 and isinstance( - tensors[0][0], jax.Array): + if ( + len(tensors) == 1 + and len(tensors[0]) >= 1 + and isinstance(tensors[0][0], jax.Array) + ): tensors = tensors[0] # Get the shapes of all input tensors @@ -2395,9 +2441,10 @@ def _broadcast_dimensions(input_shape, output_shape): output_shape = _get_broadcast_shape(shapes) # Broadcast each tensor to the output shape broadcasted_tensors = [ - jax.lax.broadcast_in_dim(t, output_shape, - _broadcast_dimensions(t.shape, output_shape)) - for t in tensors + jax.lax.broadcast_in_dim( + t, output_shape, _broadcast_dimensions(t.shape, output_shape) + ) + for t in tensors ] return broadcasted_tensors @@ -2461,8 +2508,9 @@ def _aten_cdist_forward(x1, x2, p, compute_mode=""): @op(torch.ops.aten._pdist_forward) def _aten__pdist_forward(x, p=2): pairwise_dists = _aten_cdist_forward(x, x, p) - condensed_dists = pairwise_dists[jnp.triu_indices( - pairwise_dists.shape[0], k=1)] + condensed_dists = pairwise_dists[ + jnp.triu_indices(pairwise_dists.shape[0], k=1) + ] return condensed_dists @@ -2502,8 +2550,11 @@ def _aten_diagonal(input, offset=0, dim1=0, dim2=1): def diag_indices_with_offset(input_shape, offset, dim1=0, dim2=1): input_len = len(input_shape) if dim1 == dim2 or not (0 <= dim1 < input_len and 0 <= dim2 < input_len): - raise ValueError("dim1 and dim2 must be different and in range [0, " + - str(input_len - 1) + "]") + raise ValueError( + "dim1 and dim2 must be different and in range [0, " + + str(input_len - 1) + + "]" + ) size1, size2 = input_shape[dim1], input_shape[dim2] if offset >= 0: @@ -2605,12 +2656,9 @@ def _aten_exp2(input): # aten.fill @op(torch.ops.aten.fill) @op(torch.ops.aten.full_like) -def _aten_fill(x, - value, - dtype=None, - pin_memory=None, - memory_format=None, - device=None): +def _aten_fill( + x, value, dtype=None, pin_memory=None, memory_format=None, device=None +): if dtype is None: dtype = x.dtype else: @@ -2685,8 +2733,11 @@ def _aten_glu(x, dim=-1): # aten.hardtanh @op(torch.ops.aten.hardtanh) def _aten_hardtanh(input, min_val=-1, max_val=1, inplace=False): - if input.dtype == np.int64 and isinstance(max_val, float) and isinstance( - min_val, float): + if ( + input.dtype == np.int64 + and isinstance(max_val, float) + and isinstance(min_val, float) + ): min_val = int(min_val) max_val = int(max_val) return jnp.clip(input, min_val, max_val) @@ -2705,7 +2756,8 @@ def _aten_histc(input, bins=100, min=0, max=0): max = jnp.max(input) range_value = (min, max) hist, bin_edges = jnp.histogram( - input, bins=bins, range=range_value, weights=None, density=None) + input, bins=bins, range=range_value, weights=None, density=None + ) return hist @@ -2743,12 +2795,12 @@ def _aten_linalg_eig(A): @op(torch.ops.aten._linalg_eigh) -def _aten_linalg_eigh(A, UPLO='L'): +def _aten_linalg_eigh(A, UPLO="L"): return jnp.linalg.eigh(A, UPLO) @op(torch.ops.aten.linalg_lstsq) -def _aten_linalg_lstsq(A, B, rcond=None, driver='gelsy'): +def _aten_linalg_lstsq(A, B, rcond=None, driver="gelsy"): input_dtype = A.dtype m = A.shape[-2] @@ -2757,40 +2809,39 @@ def _aten_linalg_lstsq(A, B, rcond=None, driver='gelsy'): is_batched = A.ndim > 2 if is_batched: - batch_shape = jnp.broadcast_shapes(A.shape[:-2], 
B.shape[:-2]) batch_size = int(np.prod(batch_shape)) A_reshaped = A.reshape((batch_size,) + A.shape[-2:]) B_reshaped = B.reshape((batch_size,) + B.shape[-2:]) X, residuals, rank, singular_values = jax.vmap( - jnp.linalg.lstsq, in_axes=(0, - 0))(A_reshaped, B_reshaped, rcond=rcond) + jnp.linalg.lstsq, in_axes=(0, 0) + )(A_reshaped, B_reshaped, rcond=rcond) X = X.reshape(batch_shape + X.shape[-2:]) - if driver in ['gelsd', 'gelsy', 'gelss']: + if driver in ["gelsd", "gelsy", "gelss"]: rank = rank.reshape(batch_shape) else: rank = jnp.array([], dtype=jnp.int64) full_rank = jnp.all(rank == n) - if driver == 'gelsy' or m <= n or (not full_rank): + if driver == "gelsy" or m <= n or (not full_rank): residuals = jnp.array([], dtype=input_dtype) else: residuals = residuals.reshape(batch_shape + residuals.shape[-1:]) - if driver in ['gelsd', 'gelss']: - singular_values = singular_values.reshape(batch_shape + - singular_values.shape[-1:]) + if driver in ["gelsd", "gelss"]: + singular_values = singular_values.reshape( + batch_shape + singular_values.shape[-1:] + ) else: singular_values = jnp.array([], dtype=input_dtype) else: - X, residuals, rank, singular_values = jnp.linalg.lstsq(A, B, rcond=rcond) - if driver not in ['gelsd', 'gelsy', 'gelss']: + if driver not in ["gelsd", "gelsy", "gelss"]: rank = jnp.array([], dtype=jnp.int64) rank_value = None @@ -2799,11 +2850,11 @@ def _aten_linalg_lstsq(A, B, rcond=None, driver='gelsy'): rank = jnp.array(rank_value, dtype=jnp.int64) # When driver is ‘gels’, assume that A is full-rank. - full_rank = driver == 'gels' or rank_value == n - if driver == 'gelsy' or m <= n or (not full_rank): + full_rank = driver == "gels" or rank_value == n + if driver == "gelsy" or m <= n or (not full_rank): residuals = jnp.array([], dtype=input_dtype) - if driver not in ['gelsd', 'gelss']: + if driver not in ["gelsd", "gelss"]: singular_values = jnp.array([], dtype=input_dtype) return X, residuals, rank, singular_values @@ -2815,14 +2866,15 @@ def _aten_linalg_ldl_factor_ex(A, hermitian=False, check_errors=False): # https://github.com/jax-ml/jax/issues/12779 # TODO: Not tested for complex inputs. 
Does not support hermitian=True pivots = jnp.broadcast_to( - jnp.arange(1, A.shape[-1] + 1, dtype=jnp.int32), A.shape[:-1]) + jnp.arange(1, A.shape[-1] + 1, dtype=jnp.int32), A.shape[:-1] + ) info = jnp.zeros(A.shape[:-2], jnp.int32) C = jnp.linalg.cholesky(A) if C.size == 0: return C, pivots, info # Fill diagonals of stacked matrices - @functools.partial(jnp.vectorize, signature='(k,k),(k,k)->(k,k)') + @functools.partial(jnp.vectorize, signature="(k,k),(k,k)->(k,k)") def fill_diagonal_batch(x, y): return jnp.fill_diagonal(x, jnp.diag(y), inplace=False) @@ -3022,9 +3074,11 @@ def _aten_nonzero_static(input, size, fill_value=-1): if size < indices.shape[0]: indices = indices[:size] elif size > indices.shape[0]: - padding = jnp.full((size - indices.shape[0], indices.shape[1]), - fill_value, - dtype=indices.dtype) + padding = jnp.full( + (size - indices.shape[0], indices.shape[1]), + fill_value, + dtype=indices.dtype, + ) indices = jnp.concatenate((indices, padding)) return indices @@ -3035,9 +3089,9 @@ def _aten_nonzero_static(input, size, fill_value=-1): def _aten_nonzero(x, as_tuple=False): if jnp.ndim(x) == 0 and (as_tuple or x.item() == 0): return torch.empty(0, 0, dtype=torch.int64) - if jnp.ndim( - x - ) == 0: # when x is scalar, return torch.tensor([], size=(1, 0), dtype=torch.int64) + if ( + jnp.ndim(x) == 0 + ): # when x is scalar, return torch.tensor([], size=(1, 0), dtype=torch.int64) res = torch.empty(1, 0, dtype=torch.int64) return jnp.array(res.numpy()) index_tuple = jnp.nonzero(x) @@ -3078,28 +3132,30 @@ def _aten_put(self, index, source, accumulate=False): # aten.randperm # randperm.generator(SymInt n, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) @op(torch.ops.aten.randperm, needs_env=True) -def _aten_randperm(n, - *, - generator=None, - dtype=None, - layout=None, - device=None, - pin_memory=None, - env=None): +def _aten_randperm( + n, + *, + generator=None, + dtype=None, + layout=None, + device=None, + pin_memory=None, + env=None, +): """ - Generates a random permutation of integers from 0 to n-1. + Generates a random permutation of integers from 0 to n-1. - Args: - n: The upper bound (exclusive) of the permutation range. - generator: A PRNGKey used as the random key. If None, a new key is created. - dtype: The desired data type of the output array. Default is jnp.int64. - layout: The desired layout of the output array (e.g., 'row-major', 'column-major'). - device: The desired device on which to place the output array (e.g., jax.devices()[0]). - pin_memory: Whether to pin the output array's memory to the host. + Args: + n: The upper bound (exclusive) of the permutation range. + generator: A PRNGKey used as the random key. If None, a new key is created. + dtype: The desired data type of the output array. Default is jnp.int64. + layout: The desired layout of the output array (e.g., 'row-major', 'column-major'). + device: The desired device on which to place the output array (e.g., jax.devices()[0]). + pin_memory: Whether to pin the output array's memory to the host. - Returns: - A DeviceArray containing a random permutation of integers from 0 to n-1. - """ + Returns: + A DeviceArray containing a random permutation of integers from 0 to n-1. 
+ """ if dtype: dtype = mappings.t2j_dtype(dtype) else: @@ -3150,10 +3206,10 @@ def _aten_slice_scatter(input, src, dim=0, start=None, end=None, step=1): @op(torch.ops.aten.sort) def _aten_sort(a, dim=-1, descending=False, stable=False): if a.shape == (): - return (a, jnp.astype(0, 'int64')) + return (a, jnp.astype(0, "int64")) return ( - jnp.sort(a, axis=dim, stable=stable, descending=descending), - jnp.argsort(a, axis=dim, stable=stable, descending=descending), + jnp.sort(a, axis=dim, stable=stable, descending=descending), + jnp.argsort(a, axis=dim, stable=stable, descending=descending), ) @@ -3195,8 +3251,8 @@ def _aten_topk(input, k, dim=None, largest=True, sorted=True, *, out=None): if dim != -1 and dim != len(input.shape) - 1: transpose_shape = list(range(len(input.shape))) transpose_shape[dim], transpose_shape[-1] = ( - transpose_shape[-1], - transpose_shape[dim], + transpose_shape[-1], + transpose_shape[dim], ) input = jnp.transpose(input, transpose_shape) @@ -3205,7 +3261,8 @@ def _aten_topk(input, k, dim=None, largest=True, sorted=True, *, out=None): if sorted: values = jnp.sort(values, descending=True) indices = jnp.take_along_axis( - indices, jnp.argsort(values, axis=-1, descending=True), axis=-1) + indices, jnp.argsort(values, axis=-1, descending=True), axis=-1 + ) if not largest: values = -values # Negate values back if we found smallest @@ -3218,31 +3275,35 @@ def _aten_topk(input, k, dim=None, largest=True, sorted=True, *, out=None): # aten.tril_indices -#tril_indices(int row, int col, int offset=0, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) +# tril_indices(int row, int col, int offset=0, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) @op(torch.ops.aten.tril_indices) -def _aten_tril_indices(row, - col, - offset=0, - *, - dtype=jnp.int64.dtype, - layout=None, - device=None, - pin_memory=None): +def _aten_tril_indices( + row, + col, + offset=0, + *, + dtype=jnp.int64.dtype, + layout=None, + device=None, + pin_memory=None, +): a, b = jnp.tril_indices(row, offset, col) return jnp.stack((a, b)) # aten.tril_indices -#tril_indices(int row, int col, int offset=0, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) +# tril_indices(int row, int col, int offset=0, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) @op(torch.ops.aten.triu_indices) -def _aten_triu_indices(row, - col, - offset=0, - *, - dtype=jnp.int64.dtype, - layout=None, - device=None, - pin_memory=None): +def _aten_triu_indices( + row, + col, + offset=0, + *, + dtype=jnp.int64.dtype, + layout=None, + device=None, + pin_memory=None, +): a, b = jnp.triu_indices(row, offset, col) return jnp.stack((a, b)) @@ -3250,8 +3311,7 @@ def _aten_triu_indices(row, @op(torch.ops.aten.unbind_copy) def _aten_unbind(a, dim=0): return [ - jax.lax.index_in_dim(a, i, dim, keepdims=False) - for i in range(a.shape[dim]) + jax.lax.index_in_dim(a, i, dim, keepdims=False) for i in range(a.shape[dim]) ] @@ -3260,21 +3320,22 @@ def _aten_unbind(a, dim=0): # NOTE: Like the CUDA and CPU implementations, this implementation always sorts # the tensor regardless of the `sorted` argument passed to `torch.unique`. 
@op(torch.ops.aten.unique_dim) -def _aten_unique_dim(input_tensor, - dim, - sort=True, - return_inverse=False, - return_counts=False): +def _aten_unique_dim( + input_tensor, dim, sort=True, return_inverse=False, return_counts=False +): result_tensor_or_tuple = jnp.unique( - input_tensor, - return_index=False, - return_inverse=return_inverse, - return_counts=return_counts, - axis=dim, - equal_nan=False) + input_tensor, + return_index=False, + return_inverse=return_inverse, + return_counts=return_counts, + axis=dim, + equal_nan=False, + ) result_list = ( - list(result_tensor_or_tuple) if isinstance(result_tensor_or_tuple, tuple) - else [result_tensor_or_tuple]) + list(result_tensor_or_tuple) + if isinstance(result_tensor_or_tuple, tuple) + else [result_tensor_or_tuple] + ) if not return_inverse: result_list.insert(1, None) @@ -3298,12 +3359,13 @@ def _aten_unique_dim(input_tensor, @op(torch.ops.aten._unique) def _aten_unique(input_tensor, sort=True, return_inverse=False): result_tensor_or_tuple = jnp.unique( - input_tensor, - return_index=False, - return_inverse=return_inverse, - return_counts=False, - axis=None, - equal_nan=False) + input_tensor, + return_index=False, + return_inverse=return_inverse, + return_counts=False, + axis=None, + equal_nan=False, + ) if return_inverse: return result_tensor_or_tuple else: @@ -3315,24 +3377,23 @@ def _aten_unique(input_tensor, sort=True, return_inverse=False): # NOTE: Like the CUDA and CPU implementations, this implementation always sorts # the tensor regardless of the `sorted` argument passed to `torch.unique`. @op(torch.ops.aten._unique2) -def _aten_unique2(input_tensor, - sort=True, - return_inverse=False, - return_counts=False): +def _aten_unique2( + input_tensor, sort=True, return_inverse=False, return_counts=False +): return _aten_unique_dim( - input_tensor=input_tensor, - dim=None, - sort=sort, - return_inverse=return_inverse, - return_counts=return_counts) + input_tensor=input_tensor, + dim=None, + sort=sort, + return_inverse=return_inverse, + return_counts=return_counts, + ) # aten.unique_consecutive @op(torch.ops.aten.unique_consecutive) -def _aten_unique_consecutive(input_tensor, - return_inverse=False, - return_counts=None, - dim=None): +def _aten_unique_consecutive( + input_tensor, return_inverse=False, return_counts=None, dim=None +): # Explanation of computations (shown in 1D for simplicity): # # Input [a b b c c c d d d d e e e e e] @@ -3355,17 +3416,19 @@ def _aten_unique_consecutive(input_tensor, dim += ndim nd_slice_0 = tuple( - slice(None, -1) if d == dim else slice(None) for d in range(ndim)) + slice(None, -1) if d == dim else slice(None) for d in range(ndim) + ) nd_slice_1 = tuple( - slice(1, None) if d == dim else slice(None) for d in range(ndim)) + slice(1, None) if d == dim else slice(None) for d in range(ndim) + ) axes_to_reduce = tuple(d for d in range(ndim) if d != dim) - does_not_equal_prior = ( - jnp.any( - input_tensor[nd_slice_0] != input_tensor[nd_slice_1], - axis=axes_to_reduce, - keepdims=False)) + does_not_equal_prior = jnp.any( + input_tensor[nd_slice_0] != input_tensor[nd_slice_1], + axis=axes_to_reduce, + keepdims=False, + ) if input_tensor.shape[dim] != 0: # Prepend `True` to represent the first element of the input. 
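The comment block above walks through the run-length idea behind `_aten_unique_consecutive`. A minimal 1-D sketch of that computation (illustrative only; the lowering generalizes it to an arbitrary `dim` and handles the empty-dimension case):

import jax.numpy as jnp

x = jnp.array([1, 2, 2, 3, 3, 3, 4, 4])
# True wherever a new run starts; the first element always starts one.
starts_run = jnp.concatenate([jnp.array([True]), x[1:] != x[:-1]])
values = x[starts_run]                                     # [1 2 3 4]
start_idx = jnp.flatnonzero(starts_run)                    # [0 1 3 6]
counts = jnp.diff(jnp.append(start_idx, x.shape[0]))       # [1 2 3 2]
inverse = jnp.repeat(jnp.arange(values.shape[0]), counts)  # [0 1 1 2 2 2 3 3]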
@@ -3373,17 +3436,21 @@ def _aten_unique_consecutive(input_tensor, include_indices = jnp.argwhere(does_not_equal_prior)[:, 0] - output_tensor = input_tensor[tuple( - include_indices if d == dim else slice(None) for d in range(ndim))] + output_tensor = input_tensor[ + tuple(include_indices if d == dim else slice(None) for d in range(ndim)) + ] if return_inverse or return_counts: counts = ( - jnp.append(include_indices[1:], input_tensor.shape[dim]) - - include_indices[:]) + jnp.append(include_indices[1:], input_tensor.shape[dim]) + - include_indices[:] + ) inverse = ( - jnp.reshape(jnp.repeat(jnp.arange(len(counts)), counts), inverse_shape) - if return_inverse else None) + jnp.reshape(jnp.repeat(jnp.arange(len(counts)), counts), inverse_shape) + if return_inverse + else None + ) return output_tensor, inverse, counts @@ -3408,32 +3475,33 @@ def _aten_where(condition, x=None, y=None): # aten.to.dtype # Tensor(a) self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None @op(torch.ops.aten.to.dtype) -def _aten_to_dtype(a, - dtype, - non_blocking=False, - copy=False, - memory_format=None): +def _aten_to_dtype( + a, dtype, non_blocking=False, copy=False, memory_format=None +): if dtype: jaxdtype = mappings.t2j_dtype(dtype) return a.astype(jaxdtype) @op(torch.ops.aten.to.dtype_layout) -def _aten_to_dtype_layout(a, - *, - dtype=None, - layout=None, - device=None, - pin_memory=None, - non_blocking=False, - copy=False, - memory_format=None): +def _aten_to_dtype_layout( + a, + *, + dtype=None, + layout=None, + device=None, + pin_memory=None, + non_blocking=False, + copy=False, + memory_format=None, +): return _aten_to_dtype( - a, - dtype, - non_blocking=non_blocking, - copy=copy, - memory_format=memory_format) + a, + dtype, + non_blocking=non_blocking, + copy=copy, + memory_format=memory_format, + ) # aten.to.device @@ -3456,11 +3524,9 @@ def _aten_var_mean_correction(tensor, dim=None, correction=1, keepdim=False): @op(torch.ops.aten.scalar_tensor) @op_base.convert_dtype() -def _aten_scalar_tensor(s, - dtype=None, - layout=None, - device=None, - pin_memory=None): +def _aten_scalar_tensor( + s, dtype=None, layout=None, device=None, pin_memory=None +): return jnp.array(s, dtype=dtype) @@ -3470,9 +3536,16 @@ def _aten_to_device(x, device, dtype): @op(torch.ops.aten.max_pool2d_with_indices_backward) -def max_pool2d_with_indices_backward_custom(grad_output, self, kernel_size, - stride, padding, dilation, - ceil_mode, indices): +def max_pool2d_with_indices_backward_custom( + grad_output, + self, + kernel_size, + stride, + padding, + dilation, + ceil_mode, + indices, +): """ Approximates the gradient calculation of PyTorch's max_pool2d_with_indices_backward. 
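The docstring above flags this lowering as an approximation of `max_pool2d_with_indices_backward`. For intuition, here is a hedged per-plane sketch of one common way to route pooled gradients: scatter-add each output gradient onto the input position recorded in `indices` (the names `grad_out`, `indices`, and `input_hw` are illustrative and not taken from this diff):

import jax.numpy as jnp

def maxpool_backward_plane(grad_out, indices, input_hw):
  """Scatter pooled gradients back to their argmax positions in one (H, W) plane.

  `indices` holds flattened positions into the (H, W) input plane, as produced
  by max_pool2d_with_indices; overlapping windows accumulate via scatter-add.
  """
  h, w = input_hw
  grad_in = jnp.zeros(h * w, dtype=grad_out.dtype)
  grad_in = grad_in.at[indices.ravel()].add(grad_out.ravel())
  return grad_in.reshape(h, w)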
@@ -3528,15 +3601,15 @@ def _aten_tensor_split(ary, indices_or_sections, axis=0): @op(torch.ops.aten.randn, needs_env=True) @op_base.convert_dtype() def _randn( - *size, - generator=None, - out=None, - dtype=None, - layout=torch.strided, - device=None, - requires_grad=False, - pin_memory=False, - env=None, + *size, + generator=None, + out=None, + dtype=None, + layout=torch.strided, + device=None, + requires_grad=False, + pin_memory=False, + env=None, ): shape = size if len(shape) == 1 and isinstance(shape[0], (list, tuple)): @@ -3550,11 +3623,11 @@ def _randn( @op(torch.ops.aten.bernoulli.p, needs_env=True) def _aten_bernoulli( - self, - p=0.5, - *, - generator=None, - env=None, + self, + p=0.5, + *, + generator=None, + env=None, ): key = env.get_and_rotate_prng_key(generator) res = jax.random.uniform(key, self.shape) < p @@ -3571,14 +3644,14 @@ def geometric(self, p, *, generator=None, env=None): @op(torch.ops.aten.randn_like, needs_env=True) @op_base.convert_dtype() def _aten_randn_like( - x, - *, - dtype=None, - layout=None, - device=None, - pin_memory=False, - memory_format=torch.preserve_format, - env=None, + x, + *, + dtype=None, + layout=None, + device=None, + pin_memory=False, + memory_format=torch.preserve_format, + env=None, ): key = env.get_and_rotate_prng_key() return jax.random.normal(key, dtype=dtype or x.dtype, shape=x.shape) @@ -3587,15 +3660,15 @@ def _aten_randn_like( @op(torch.ops.aten.rand, needs_env=True) @op_base.convert_dtype() def _rand( - *size, - generator=None, - out=None, - dtype=None, - layout=torch.strided, - device=None, - requires_grad=False, - pin_memory=False, - env=None, + *size, + generator=None, + out=None, + dtype=None, + layout=torch.strided, + device=None, + requires_grad=False, + pin_memory=False, + env=None, ): shape = size if len(shape) == 1 and isinstance(shape[0], (list, tuple)): @@ -3618,30 +3691,40 @@ def _aten_allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False): @op(torch.ops.aten.native_batch_norm) -def _aten_native_batch_norm(input, - weight, - bias, - running_mean, - running_var, - training=False, - momentum=0.1, - eps=1e-5): - +def _aten_native_batch_norm( + input, + weight, + bias, + running_mean, + running_var, + training=False, + momentum=0.1, + eps=1e-5, +): if running_mean is None: running_mean = jnp.zeros( - input.shape[1], dtype=input.dtype) # Initialize running mean if None + input.shape[1], dtype=input.dtype + ) # Initialize running mean if None if running_var is None: running_var = jnp.ones( - input.shape[1], - dtype=input.dtype) # Initialize running variance if None + input.shape[1], dtype=input.dtype + ) # Initialize running variance if None if training: - return _aten__native_batch_norm_legit(input, weight, bias, running_mean, - running_var, training, momentum, eps) + return _aten__native_batch_norm_legit( + input, + weight, + bias, + running_mean, + running_var, + training, + momentum, + eps, + ) else: - return _aten__native_batch_norm_legit_no_training(input, weight, bias, - running_mean, running_var, - momentum, eps) + return _aten__native_batch_norm_legit_no_training( + input, weight, bias, running_mean, running_var, momentum, eps + ) @op(torch.ops.aten.normal, needs_env=True) @@ -3660,23 +3743,25 @@ def _aten_lift_fresh(self): @op(torch.ops.aten.uniform, needs_env=True) def _aten_uniform(self, from_=0, to=1, *, generator=None, env=None): - assert from_ <= to, f'Uniform from(passed in {from_}) must be less than to(passed in {to})' + assert from_ <= to, ( + f"Uniform from(passed in {from_}) must be less than 
to(passed in {to})" + ) shape = self.shape res = _rand(*shape, generator=generator, env=env) return res * (to - from_) + from_ -#func: randint.low_generator(SymInt low, SymInt high, SymInt[] size, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor +# func: randint.low_generator(SymInt low, SymInt high, SymInt[] size, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor @op(torch.ops.aten.randint, needs_env=True) @op_base.convert_dtype(use_default_dtype=False) def _aten_randint( - *args, - generator=None, - dtype=None, - env=None, - **kwargs, + *args, + generator=None, + dtype=None, + env=None, + **kwargs, ): if len(args) == 3: # low, high, size @@ -3686,7 +3771,8 @@ def _aten_randint( low = 0 else: raise AssertionError( - f'Expected at 2 or 3 args for Aten::randint, got {len(args)}') + f"Expected at 2 or 3 args for Aten::randint, got {len(args)}" + ) key = env.get_and_rotate_prng_key(generator) res = jax.random.randint(key, size, low, high) @@ -3695,17 +3781,19 @@ def _aten_randint( return res -@op(torch.ops.aten.randint_like, - torch.ops.aten.randint.generator, - needs_env=True) +@op( + torch.ops.aten.randint_like, + torch.ops.aten.randint.generator, + needs_env=True, +) @op_base.convert_dtype(use_default_dtype=False) def _aten_randint_like( - input, - *args, - generator=None, - dtype=None, - env=None, - **kwargs, + input, + *args, + generator=None, + dtype=None, + env=None, + **kwargs, ): if len(args) == 2: low, high = args @@ -3714,7 +3802,8 @@ def _aten_randint_like( low = 0 else: raise AssertionError( - f'Expected at 1 or 2 args for Aten::randint_like, got {len(args)}') + f"Expected at 1 or 2 args for Aten::randint_like, got {len(args)}" + ) shape = input.shape dtype = dtype or input.dtype @@ -3737,7 +3826,8 @@ def _aten_copysign(input, other, *, out=None): # regardless of their exact integer dtype, whereas jax.copysign returns # float64 when one or both of them is int64. 
if jnp.issubdtype(input.dtype, jnp.integer) and jnp.issubdtype( - other.dtype, jnp.integer): + other.dtype, jnp.integer + ): result = result.astype(jnp.float32) return result @@ -3773,7 +3863,6 @@ def _aten_special_laguerre_polynomial_l(self, n): @jnp.vectorize def vectorized(x, n_i): - def negative_n(x): return jnp.zeros_like(x) @@ -3787,18 +3876,21 @@ def zero_abs(x): return jnp.ones_like(x) def default(x): - def f(k, carry): p, q = carry - return (q, ((k * 2 + (jnp.ones_like(x) - x)) * q - k * p) / (k + 1)) + return ( + q, + ((k * 2 + (jnp.ones_like(x) - x)) * q - k * p) / (k + 1), + ) _, q = jax.lax.fori_loop(1, n_i, f, init_val=(1.0, jnp.ones_like(x) - x)) return q return jnp.piecewise( - x, [n_i == 1, n_i == 0, - jnp.abs(n_i) == jnp.zeros_like(x), n_i < 0], - [one_n, zero_n, zero_abs, negative_n, default]) + x, + [n_i == 1, n_i == 0, jnp.abs(n_i) == jnp.zeros_like(x), n_i < 0], + [one_n, zero_n, zero_abs, negative_n, default], + ) return vectorized(self, n.astype(jnp.int64)) @@ -3810,39 +3902,39 @@ def _aten_special_modified_bessel_i0(self): def small(x): A = jnp.array( - [ - -4.41534164647933937950e-18, - 3.33079451882223809783e-17, - -2.43127984654795469359e-16, - 1.71539128555513303061e-15, - -1.16853328779934516808e-14, - 7.67618549860493561688e-14, - -4.85644678311192946090e-13, - 2.95505266312963983461e-12, - -1.72682629144155570723e-11, - 9.67580903537323691224e-11, - -5.18979560163526290666e-10, - 2.65982372468238665035e-09, - -1.30002500998624804212e-08, - 6.04699502254191894932e-08, - -2.67079385394061173391e-07, - 1.11738753912010371815e-06, - -4.41673835845875056359e-06, - 1.64484480707288970893e-05, - -5.75419501008210370398e-05, - 1.88502885095841655729e-04, - -5.76375574538582365885e-04, - 1.63947561694133579842e-03, - -4.32430999505057594430e-03, - 1.05464603945949983183e-02, - -2.37374148058994688156e-02, - 4.93052842396707084878e-02, - -9.49010970480476444210e-02, - 1.71620901522208775349e-01, - -3.04682672343198398683e-01, - 6.76795274409476084995e-01, - ], - dtype=self.dtype, + [ + -4.41534164647933937950e-18, + 3.33079451882223809783e-17, + -2.43127984654795469359e-16, + 1.71539128555513303061e-15, + -1.16853328779934516808e-14, + 7.67618549860493561688e-14, + -4.85644678311192946090e-13, + 2.95505266312963983461e-12, + -1.72682629144155570723e-11, + 9.67580903537323691224e-11, + -5.18979560163526290666e-10, + 2.65982372468238665035e-09, + -1.30002500998624804212e-08, + 6.04699502254191894932e-08, + -2.67079385394061173391e-07, + 1.11738753912010371815e-06, + -4.41673835845875056359e-06, + 1.64484480707288970893e-05, + -5.75419501008210370398e-05, + 1.88502885095841655729e-04, + -5.76375574538582365885e-04, + 1.63947561694133579842e-03, + -4.32430999505057594430e-03, + 1.05464603945949983183e-02, + -2.37374148058994688156e-02, + 4.93052842396707084878e-02, + -9.49010970480476444210e-02, + 1.71620901522208775349e-01, + -3.04682672343198398683e-01, + 6.76795274409476084995e-01, + ], + dtype=self.dtype, ) def f(carry, val): @@ -3851,40 +3943,41 @@ def f(carry, val): return (p, q, ((x / 2.0) - 2.0) * q - p + val), None (p, _, a), _ = jax.lax.scan( - f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=A) + f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=A + ) return jnp.exp(x) * (0.5 * (a - p)) def default(x): B = jnp.array( - [ - -7.23318048787475395456e-18, - -4.83050448594418207126e-18, - 4.46562142029675999901e-17, - 3.46122286769746109310e-17, - -2.82762398051658348494e-16, - -3.42548561967721913462e-16, - 1.77256013305652638360e-15, - 
3.81168066935262242075e-15, - -9.55484669882830764870e-15, - -4.15056934728722208663e-14, - 1.54008621752140982691e-14, - 3.85277838274214270114e-13, - 7.18012445138366623367e-13, - -1.79417853150680611778e-12, - -1.32158118404477131188e-11, - -3.14991652796324136454e-11, - 1.18891471078464383424e-11, - 4.94060238822496958910e-10, - 3.39623202570838634515e-09, - 2.26666899049817806459e-08, - 2.04891858946906374183e-07, - 2.89137052083475648297e-06, - 6.88975834691682398426e-05, - 3.36911647825569408990e-03, - 8.04490411014108831608e-01, - ], - dtype=self.dtype, + [ + -7.23318048787475395456e-18, + -4.83050448594418207126e-18, + 4.46562142029675999901e-17, + 3.46122286769746109310e-17, + -2.82762398051658348494e-16, + -3.42548561967721913462e-16, + 1.77256013305652638360e-15, + 3.81168066935262242075e-15, + -9.55484669882830764870e-15, + -4.15056934728722208663e-14, + 1.54008621752140982691e-14, + 3.85277838274214270114e-13, + 7.18012445138366623367e-13, + -1.79417853150680611778e-12, + -1.32158118404477131188e-11, + -3.14991652796324136454e-11, + 1.18891471078464383424e-11, + 4.94060238822496958910e-10, + 3.39623202570838634515e-09, + 2.26666899049817806459e-08, + 2.04891858946906374183e-07, + 2.89137052083475648297e-06, + 6.88975834691682398426e-05, + 3.36911647825569408990e-03, + 8.04490411014108831608e-01, + ], + dtype=self.dtype, ) def f(carry, val): @@ -3893,7 +3986,8 @@ def f(carry, val): return (p, q, (32.0 / x - 2.0) * q - p + val), None (p, _, b), _ = jax.lax.scan( - f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=B) + f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=B + ) return jnp.exp(x) * (0.5 * (b - p)) / jnp.sqrt(x) @@ -3908,38 +4002,38 @@ def _aten_special_modified_bessel_i1(self): def small(x): A = jnp.array( - [ - 2.77791411276104639959e-18, - -2.11142121435816608115e-17, - 1.55363195773620046921e-16, - -1.10559694773538630805e-15, - 7.60068429473540693410e-15, - -5.04218550472791168711e-14, - 3.22379336594557470981e-13, - -1.98397439776494371520e-12, - 1.17361862988909016308e-11, - -6.66348972350202774223e-11, - 3.62559028155211703701e-10, - -1.88724975172282928790e-09, - 9.38153738649577178388e-09, - -4.44505912879632808065e-08, - 2.00329475355213526229e-07, - -8.56872026469545474066e-07, - 3.47025130813767847674e-06, - -1.32731636560394358279e-05, - 4.78156510755005422638e-05, - -1.61760815825896745588e-04, - 5.12285956168575772895e-04, - -1.51357245063125314899e-03, - 4.15642294431288815669e-03, - -1.05640848946261981558e-02, - 2.47264490306265168283e-02, - -5.29459812080949914269e-02, - 1.02643658689847095384e-01, - -1.76416518357834055153e-01, - 2.52587186443633654823e-01, - ], - dtype=self.dtype, + [ + 2.77791411276104639959e-18, + -2.11142121435816608115e-17, + 1.55363195773620046921e-16, + -1.10559694773538630805e-15, + 7.60068429473540693410e-15, + -5.04218550472791168711e-14, + 3.22379336594557470981e-13, + -1.98397439776494371520e-12, + 1.17361862988909016308e-11, + -6.66348972350202774223e-11, + 3.62559028155211703701e-10, + -1.88724975172282928790e-09, + 9.38153738649577178388e-09, + -4.44505912879632808065e-08, + 2.00329475355213526229e-07, + -8.56872026469545474066e-07, + 3.47025130813767847674e-06, + -1.32731636560394358279e-05, + 4.78156510755005422638e-05, + -1.61760815825896745588e-04, + 5.12285956168575772895e-04, + -1.51357245063125314899e-03, + 4.15642294431288815669e-03, + -1.05640848946261981558e-02, + 2.47264490306265168283e-02, + -5.29459812080949914269e-02, + 1.02643658689847095384e-01, + -1.76416518357834055153e-01, + 
2.52587186443633654823e-01, + ], + dtype=self.dtype, ) def f(carry, val): @@ -3948,42 +4042,45 @@ def f(carry, val): return (p, q, ((jnp.abs(x) / 2.0) - 2.0) * q - p + val), None (p, _, a), _ = jax.lax.scan( - f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=A) + f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=A + ) return jax.lax.cond( - x < 0, lambda: -(0.5 * (a - p) * jnp.abs(x) * jnp.exp(jnp.abs(x))), - lambda: 0.5 * (a - p) * jnp.abs(x) * jnp.exp(jnp.abs(x))) + x < 0, + lambda: -(0.5 * (a - p) * jnp.abs(x) * jnp.exp(jnp.abs(x))), + lambda: 0.5 * (a - p) * jnp.abs(x) * jnp.exp(jnp.abs(x)), + ) def default(x): B = jnp.array( - [ - 7.51729631084210481353e-18, - 4.41434832307170791151e-18, - -4.65030536848935832153e-17, - -3.20952592199342395980e-17, - 2.96262899764595013876e-16, - 3.30820231092092828324e-16, - -1.88035477551078244854e-15, - -3.81440307243700780478e-15, - 1.04202769841288027642e-14, - 4.27244001671195135429e-14, - -2.10154184277266431302e-14, - -4.08355111109219731823e-13, - -7.19855177624590851209e-13, - 2.03562854414708950722e-12, - 1.41258074366137813316e-11, - 3.25260358301548823856e-11, - -1.89749581235054123450e-11, - -5.58974346219658380687e-10, - -3.83538038596423702205e-09, - -2.63146884688951950684e-08, - -2.51223623787020892529e-07, - -3.88256480887769039346e-06, - -1.10588938762623716291e-04, - -9.76109749136146840777e-03, - 7.78576235018280120474e-01, - ], - dtype=self.dtype, + [ + 7.51729631084210481353e-18, + 4.41434832307170791151e-18, + -4.65030536848935832153e-17, + -3.20952592199342395980e-17, + 2.96262899764595013876e-16, + 3.30820231092092828324e-16, + -1.88035477551078244854e-15, + -3.81440307243700780478e-15, + 1.04202769841288027642e-14, + 4.27244001671195135429e-14, + -2.10154184277266431302e-14, + -4.08355111109219731823e-13, + -7.19855177624590851209e-13, + 2.03562854414708950722e-12, + 1.41258074366137813316e-11, + 3.25260358301548823856e-11, + -1.89749581235054123450e-11, + -5.58974346219658380687e-10, + -3.83538038596423702205e-09, + -2.63146884688951950684e-08, + -2.51223623787020892529e-07, + -3.88256480887769039346e-06, + -1.10588938762623716291e-04, + -9.76109749136146840777e-03, + 7.78576235018280120474e-01, + ], + dtype=self.dtype, ) def f(carry, val): @@ -3992,12 +4089,14 @@ def f(carry, val): return (p, q, (32.0 / jnp.abs(x) - 2.0) * q - p + val), None (p, _, b), _ = jax.lax.scan( - f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=B) + f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=B + ) return jax.lax.cond( - x < 0, lambda: -(jnp.exp(jnp.abs(x)) * - (0.5 * (b - p)) / jnp.sqrt(jnp.abs(x))), - lambda: jnp.exp(jnp.abs(x)) * (0.5 * (b - p)) / jnp.sqrt(jnp.abs(x))) + x < 0, + lambda: -(jnp.exp(jnp.abs(x)) * (0.5 * (b - p)) / jnp.sqrt(jnp.abs(x))), + lambda: jnp.exp(jnp.abs(x)) * (0.5 * (b - p)) / jnp.sqrt(jnp.abs(x)), + ) return jnp.piecewise(self, [self <= 8], [small, default]) @@ -4015,19 +4114,19 @@ def negative(x): def small(x): A = jnp.array( - [ - 1.37446543561352307156e-16, - 4.25981614279661018399e-14, - 1.03496952576338420167e-11, - 1.90451637722020886025e-09, - 2.53479107902614945675e-07, - 2.28621210311945178607e-05, - 1.26461541144692592338e-03, - 3.59799365153615016266e-02, - 3.44289899924628486886e-01, - -5.35327393233902768720e-01, - ], - dtype=self.dtype, + [ + 1.37446543561352307156e-16, + 4.25981614279661018399e-14, + 1.03496952576338420167e-11, + 1.90451637722020886025e-09, + 2.53479107902614945675e-07, + 2.28621210311945178607e-05, + 1.26461541144692592338e-03, + 3.59799365153615016266e-02, + 
3.44289899924628486886e-01, + -5.35327393233902768720e-01, + ], + dtype=self.dtype, ) def f(carry, val): @@ -4036,41 +4135,43 @@ def f(carry, val): return (p, q, (x * x - 2.0) * q - p + val), None (p, _, a), _ = jax.lax.scan( - f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=A) + f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=A + ) - return 0.5 * (a - p) - jnp.log( - 0.5 * x) * _aten_special_modified_bessel_i0(x) + return 0.5 * (a - p) - jnp.log(0.5 * x) * _aten_special_modified_bessel_i0( + x + ) def default(x): B = jnp.array( - [ - 5.30043377268626276149e-18, - -1.64758043015242134646e-17, - 5.21039150503902756861e-17, - -1.67823109680541210385e-16, - 5.51205597852431940784e-16, - -1.84859337734377901440e-15, - 6.34007647740507060557e-15, - -2.22751332699166985548e-14, - 8.03289077536357521100e-14, - -2.98009692317273043925e-13, - 1.14034058820847496303e-12, - -4.51459788337394416547e-12, - 1.85594911495471785253e-11, - -7.95748924447710747776e-11, - 3.57739728140030116597e-10, - -1.69753450938905987466e-09, - 8.57403401741422608519e-09, - -4.66048989768794782956e-08, - 2.76681363944501510342e-07, - -1.83175552271911948767e-06, - 1.39498137188764993662e-05, - -1.28495495816278026384e-04, - 1.56988388573005337491e-03, - -3.14481013119645005427e-02, - 2.44030308206595545468e+00, - ], - dtype=self.dtype, + [ + 5.30043377268626276149e-18, + -1.64758043015242134646e-17, + 5.21039150503902756861e-17, + -1.67823109680541210385e-16, + 5.51205597852431940784e-16, + -1.84859337734377901440e-15, + 6.34007647740507060557e-15, + -2.22751332699166985548e-14, + 8.03289077536357521100e-14, + -2.98009692317273043925e-13, + 1.14034058820847496303e-12, + -4.51459788337394416547e-12, + 1.85594911495471785253e-11, + -7.95748924447710747776e-11, + 3.57739728140030116597e-10, + -1.69753450938905987466e-09, + 8.57403401741422608519e-09, + -4.66048989768794782956e-08, + 2.76681363944501510342e-07, + -1.83175552271911948767e-06, + 1.39498137188764993662e-05, + -1.28495495816278026384e-04, + 1.56988388573005337491e-03, + -3.14481013119645005427e-02, + 2.44030308206595545468e00, + ], + dtype=self.dtype, ) def f(carry, val): @@ -4079,12 +4180,14 @@ def f(carry, val): return (p, q, (8.0 / x - 2.0) * q - p + val), None (p, _, b), _ = jax.lax.scan( - f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=B) + f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=B + ) return jnp.exp(-x) * (0.5 * (b - p)) / jnp.sqrt(x) - return jnp.piecewise(self, [self <= 2, self < 0, self == 0], - [small, negative, zero, default]) + return jnp.piecewise( + self, [self <= 2, self < 0, self == 0], [small, negative, zero, default] + ) @op(torch.ops.aten.special_modified_bessel_k1) @@ -4100,20 +4203,20 @@ def negative(x): def small(x): A = jnp.array( - [ - -7.02386347938628759343e-18, - -2.42744985051936593393e-15, - -6.66690169419932900609e-13, - -1.41148839263352776110e-10, - -2.21338763073472585583e-08, - -2.43340614156596823496e-06, - -1.73028895751305206302e-04, - -6.97572385963986435018e-03, - -1.22611180822657148235e-01, - -3.53155960776544875667e-01, - 1.52530022733894777053e+00, - ], - dtype=self.dtype, + [ + -7.02386347938628759343e-18, + -2.42744985051936593393e-15, + -6.66690169419932900609e-13, + -1.41148839263352776110e-10, + -2.21338763073472585583e-08, + -2.43340614156596823496e-06, + -1.73028895751305206302e-04, + -6.97572385963986435018e-03, + -1.22611180822657148235e-01, + -3.53155960776544875667e-01, + 1.52530022733894777053e00, + ], + dtype=self.dtype, ) def f(carry, val): @@ -4123,41 +4226,43 @@ def 
f(carry, val): return (p, q, a), None (p, _, a), _ = jax.lax.scan( - f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=A) + f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=A + ) - return jnp.log( - 0.5 * x) * _aten_special_modified_bessel_i1(x) + 0.5 * (a - p) / x + return ( + jnp.log(0.5 * x) * _aten_special_modified_bessel_i1(x) + 0.5 * (a - p) / x + ) def default(x): B = jnp.array( - [ - -5.75674448366501715755e-18, - 1.79405087314755922667e-17, - -5.68946255844285935196e-17, - 1.83809354436663880070e-16, - -6.05704724837331885336e-16, - 2.03870316562433424052e-15, - -7.01983709041831346144e-15, - 2.47715442448130437068e-14, - -8.97670518232499435011e-14, - +3.34841966607842919884e-13, - -1.28917396095102890680e-12, - 5.13963967348173025100e-12, - -2.12996783842756842877e-11, - 9.21831518760500529508e-11, - -4.19035475934189648750e-10, - 2.01504975519703286596e-09, - -1.03457624656780970260e-08, - 5.74108412545004946722e-08, - -3.50196060308781257119e-07, - 2.40648494783721712015e-06, - -1.93619797416608296024e-05, - 1.95215518471351631108e-04, - -2.85781685962277938680e-03, - 1.03923736576817238437e-01, - 2.72062619048444266945e+00, - ], - dtype=self.dtype, + [ + -5.75674448366501715755e-18, + 1.79405087314755922667e-17, + -5.68946255844285935196e-17, + 1.83809354436663880070e-16, + -6.05704724837331885336e-16, + 2.03870316562433424052e-15, + -7.01983709041831346144e-15, + 2.47715442448130437068e-14, + -8.97670518232499435011e-14, + +3.34841966607842919884e-13, + -1.28917396095102890680e-12, + 5.13963967348173025100e-12, + -2.12996783842756842877e-11, + 9.21831518760500529508e-11, + -4.19035475934189648750e-10, + 2.01504975519703286596e-09, + -1.03457624656780970260e-08, + 5.74108412545004946722e-08, + -3.50196060308781257119e-07, + 2.40648494783721712015e-06, + -1.93619797416608296024e-05, + 1.95215518471351631108e-04, + -2.85781685962277938680e-03, + 1.03923736576817238437e-01, + 2.72062619048444266945e00, + ], + dtype=self.dtype, ) def f(carry, val): @@ -4167,12 +4272,14 @@ def f(carry, val): return (p, q, b), None (p, _, b), _ = jax.lax.scan( - f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=B) + f, init=(jnp.zeros_like(x), jnp.zeros_like(x), 0), xs=B + ) return jnp.exp(-x) * (0.5 * (b - p)) / jnp.sqrt(x) - return jnp.piecewise(self, [self <= 2, self < 0, self == 0], - [small, negative, zero, default]) + return jnp.piecewise( + self, [self <= 2, self < 0, self == 0], [small, negative, zero, default] + ) @op(torch.ops.aten.polygamma) @@ -4198,104 +4305,120 @@ def very_small(x): def small(x): RP = jnp.array( - [ - -4.79443220978201773821e09, - 1.95617491946556577543e12, - -2.49248344360967716204e14, - 9.70862251047306323952e15, - ], - dtype=self.dtype, + [ + -4.79443220978201773821e09, + 1.95617491946556577543e12, + -2.49248344360967716204e14, + 9.70862251047306323952e15, + ], + dtype=self.dtype, ) RQ = jnp.array( - [ - 4.99563147152651017219e02, - 1.73785401676374683123e05, - 4.84409658339962045305e07, - 1.11855537045356834862e10, - 2.11277520115489217587e12, - 3.10518229857422583814e14, - 3.18121955943204943306e16, - 1.71086294081043136091e18, - ], - dtype=self.dtype, + [ + 4.99563147152651017219e02, + 1.73785401676374683123e05, + 4.84409658339962045305e07, + 1.11855537045356834862e10, + 2.11277520115489217587e12, + 3.10518229857422583814e14, + 3.18121955943204943306e16, + 1.71086294081043136091e18, + ], + dtype=self.dtype, ) rp = op_base.foreach_loop(RP, lambda carry, rp_i: carry * (x * x) + rp_i) rq = op_base.foreach_loop(RQ, lambda carry, rq_i: carry * (x * x) + 
rq_i) - return ((x * x - 5.78318596294678452118e00) * - (x * x - 3.04712623436620863991e01) * rp / rq) + return ( + (x * x - 5.78318596294678452118e00) + * (x * x - 3.04712623436620863991e01) + * rp + / rq + ) def default(x): PP = jnp.array( - [ - 7.96936729297347051624e-04, - 8.28352392107440799803e-02, - 1.23953371646414299388e00, - 5.44725003058768775090e00, - 8.74716500199817011941e00, - 5.30324038235394892183e00, - 9.99999999999999997821e-01, - ], - dtype=self.dtype, + [ + 7.96936729297347051624e-04, + 8.28352392107440799803e-02, + 1.23953371646414299388e00, + 5.44725003058768775090e00, + 8.74716500199817011941e00, + 5.30324038235394892183e00, + 9.99999999999999997821e-01, + ], + dtype=self.dtype, ) PQ = jnp.array( - [ - 9.24408810558863637013e-04, - 8.56288474354474431428e-02, - 1.25352743901058953537e00, - 5.47097740330417105182e00, - 8.76190883237069594232e00, - 5.30605288235394617618e00, - 1.00000000000000000218e00, - ], - dtype=self.dtype, + [ + 9.24408810558863637013e-04, + 8.56288474354474431428e-02, + 1.25352743901058953537e00, + 5.47097740330417105182e00, + 8.76190883237069594232e00, + 5.30605288235394617618e00, + 1.00000000000000000218e00, + ], + dtype=self.dtype, ) QP = jnp.array( - [ - -1.13663838898469149931e-02, - -1.28252718670509318512e00, - -1.95539544257735972385e01, - -9.32060152123768231369e01, - -1.77681167980488050595e02, - -1.47077505154951170175e02, - -5.14105326766599330220e01, - -6.05014350600728481186e00, - ], - dtype=self.dtype, + [ + -1.13663838898469149931e-02, + -1.28252718670509318512e00, + -1.95539544257735972385e01, + -9.32060152123768231369e01, + -1.77681167980488050595e02, + -1.47077505154951170175e02, + -5.14105326766599330220e01, + -6.05014350600728481186e00, + ], + dtype=self.dtype, ) QQ = jnp.array( - [ - 6.43178256118178023184e01, - 8.56430025976980587198e02, - 3.88240183605401609683e03, - 7.24046774195652478189e03, - 5.93072701187316984827e03, - 2.06209331660327847417e03, - 2.42005740240291393179e02, - ], - dtype=self.dtype, + [ + 6.43178256118178023184e01, + 8.56430025976980587198e02, + 3.88240183605401609683e03, + 7.24046774195652478189e03, + 5.93072701187316984827e03, + 2.06209331660327847417e03, + 2.42005740240291393179e02, + ], + dtype=self.dtype, ) pp = op_base.foreach_loop( - PP, lambda carry, pp_i: carry * (25.0 / (x * x)) + pp_i) + PP, lambda carry, pp_i: carry * (25.0 / (x * x)) + pp_i + ) pq = op_base.foreach_loop( - PQ, lambda carry, pq_i: carry * (25.0 / (x * x)) + pq_i) + PQ, lambda carry, pq_i: carry * (25.0 / (x * x)) + pq_i + ) qp = op_base.foreach_loop( - QP, lambda carry, qp_i: carry * (25.0 / (x * x)) + qp_i) + QP, lambda carry, qp_i: carry * (25.0 / (x * x)) + qp_i + ) qq = op_base.foreach_loop( - QQ, lambda carry, qq_i: carry * (25.0 / (x * x)) + qq_i) + QQ, lambda carry, qq_i: carry * (25.0 / (x * x)) + qq_i + ) - return ((pp / pq * jnp.cos(x - 0.785398163397448309615660845819875721) - - 5.0 / x * - (qp / qq) * jnp.sin(x - 0.785398163397448309615660845819875721)) * - 0.797884560802865355879892119868763737 / jnp.sqrt(x)) + return ( + ( + pp / pq * jnp.cos(x - 0.785398163397448309615660845819875721) + - 5.0 + / x + * (qp / qq) + * jnp.sin(x - 0.785398163397448309615660845819875721) + ) + * 0.797884560802865355879892119868763737 + / jnp.sqrt(x) + ) self = jnp.abs(self) # Last True condition in `piecewise` takes priority, but last function is # default. 
See https://github.com/numpy/numpy/issues/16475 - return jnp.piecewise(self, [self <= 5.0, self < 0.00001], - [small, very_small, default]) + return jnp.piecewise( + self, [self <= 5.0, self < 0.00001], [small, very_small, default] + ) @op(torch.ops.aten.special_bessel_j1) @@ -4305,106 +4428,122 @@ def _aten_special_bessel_j1(self): def small(x): RP = jnp.array( - [ - -8.99971225705559398224e08, - 4.52228297998194034323e11, - -7.27494245221818276015e13, - 3.68295732863852883286e15, - ], - dtype=self.dtype, + [ + -8.99971225705559398224e08, + 4.52228297998194034323e11, + -7.27494245221818276015e13, + 3.68295732863852883286e15, + ], + dtype=self.dtype, ) RQ = jnp.array( - [ - 6.20836478118054335476e02, - 2.56987256757748830383e05, - 8.35146791431949253037e07, - 2.21511595479792499675e10, - 4.74914122079991414898e12, - 7.84369607876235854894e14, - 8.95222336184627338078e16, - 5.32278620332680085395e18, - ], - dtype=self.dtype, + [ + 6.20836478118054335476e02, + 2.56987256757748830383e05, + 8.35146791431949253037e07, + 2.21511595479792499675e10, + 4.74914122079991414898e12, + 7.84369607876235854894e14, + 8.95222336184627338078e16, + 5.32278620332680085395e18, + ], + dtype=self.dtype, ) rp = op_base.foreach_loop(RP, lambda carry, rp_i: carry * (x * x) + rp_i) rq = op_base.foreach_loop(RQ, lambda carry, rq_i: carry * (x * x) + rq_i) - return (rp / rq * x * (x * x - 1.46819706421238932572e01) * - (x * x - 4.92184563216946036703e01)) + return ( + rp + / rq + * x + * (x * x - 1.46819706421238932572e01) + * (x * x - 4.92184563216946036703e01) + ) def default(x): PP = jnp.array( - [ - 7.62125616208173112003e-04, - 7.31397056940917570436e-02, - 1.12719608129684925192e00, - 5.11207951146807644818e00, - 8.42404590141772420927e00, - 5.21451598682361504063e00, - 1.00000000000000000254e00, - ], - dtype=self.dtype, + [ + 7.62125616208173112003e-04, + 7.31397056940917570436e-02, + 1.12719608129684925192e00, + 5.11207951146807644818e00, + 8.42404590141772420927e00, + 5.21451598682361504063e00, + 1.00000000000000000254e00, + ], + dtype=self.dtype, ) PQ = jnp.array( - [ - 5.71323128072548699714e-04, - 6.88455908754495404082e-02, - 1.10514232634061696926e00, - 5.07386386128601488557e00, - 8.39985554327604159757e00, - 5.20982848682361821619e00, - 9.99999999999999997461e-01, - ], - dtype=self.dtype, + [ + 5.71323128072548699714e-04, + 6.88455908754495404082e-02, + 1.10514232634061696926e00, + 5.07386386128601488557e00, + 8.39985554327604159757e00, + 5.20982848682361821619e00, + 9.99999999999999997461e-01, + ], + dtype=self.dtype, ) QP = jnp.array( - [ - 5.10862594750176621635e-02, - 4.98213872951233449420e00, - 7.58238284132545283818e01, - 3.66779609360150777800e02, - 7.10856304998926107277e02, - 5.97489612400613639965e02, - 2.11688757100572135698e02, - 2.52070205858023719784e01, - ], - dtype=self.dtype, + [ + 5.10862594750176621635e-02, + 4.98213872951233449420e00, + 7.58238284132545283818e01, + 3.66779609360150777800e02, + 7.10856304998926107277e02, + 5.97489612400613639965e02, + 2.11688757100572135698e02, + 2.52070205858023719784e01, + ], + dtype=self.dtype, ) QQ = jnp.array( - [ - 7.42373277035675149943e01, - 1.05644886038262816351e03, - 4.98641058337653607651e03, - 9.56231892404756170795e03, - 7.99704160447350683650e03, - 2.82619278517639096600e03, - 3.36093607810698293419e02, - ], - dtype=self.dtype, + [ + 7.42373277035675149943e01, + 1.05644886038262816351e03, + 4.98641058337653607651e03, + 9.56231892404756170795e03, + 7.99704160447350683650e03, + 2.82619278517639096600e03, + 3.36093607810698293419e02, + ], 
+ dtype=self.dtype, ) pp = op_base.foreach_loop( - PP, lambda carry, pp_i: carry * (25.0 / (x * x)) + pp_i) + PP, lambda carry, pp_i: carry * (25.0 / (x * x)) + pp_i + ) pq = op_base.foreach_loop( - PQ, lambda carry, pq_i: carry * (25.0 / (x * x)) + pq_i) + PQ, lambda carry, pq_i: carry * (25.0 / (x * x)) + pq_i + ) qp = op_base.foreach_loop( - QP, lambda carry, qp_i: carry * (25.0 / (x * x)) + qp_i) + QP, lambda carry, qp_i: carry * (25.0 / (x * x)) + qp_i + ) qq = op_base.foreach_loop( - QQ, lambda carry, qq_i: carry * (25.0 / (x * x)) + qq_i) + QQ, lambda carry, qq_i: carry * (25.0 / (x * x)) + qq_i + ) - return ((pp / pq * jnp.cos(x - 2.356194490192344928846982537459627163) - - 5.0 / x * - (qp / qq) * jnp.sin(x - 2.356194490192344928846982537459627163)) * - 0.797884560802865355879892119868763737 / jnp.sqrt(x)) + return ( + ( + pp / pq * jnp.cos(x - 2.356194490192344928846982537459627163) + - 5.0 + / x + * (qp / qq) + * jnp.sin(x - 2.356194490192344928846982537459627163) + ) + * 0.797884560802865355879892119868763737 + / jnp.sqrt(x) + ) # If x < 0, bessel_j1(x) = -bessel_j1(-x) sign = jnp.sign(self) self = jnp.abs(self) return sign * jnp.piecewise( - self, - [self <= 5.0], - [small, default], + self, + [self <= 5.0], + [small, default], ) @@ -4421,86 +4560,89 @@ def negative(x): def small(x): YP = jnp.array( - [ - 1.55924367855235737965e04, - -1.46639295903971606143e07, - 5.43526477051876500413e09, - -9.82136065717911466409e11, - 8.75906394395366999549e13, - -3.46628303384729719441e15, - 4.42733268572569800351e16, - -1.84950800436986690637e16, - ], - dtype=self.dtype, + [ + 1.55924367855235737965e04, + -1.46639295903971606143e07, + 5.43526477051876500413e09, + -9.82136065717911466409e11, + 8.75906394395366999549e13, + -3.46628303384729719441e15, + 4.42733268572569800351e16, + -1.84950800436986690637e16, + ], + dtype=self.dtype, ) YQ = jnp.array( - [ - 1.04128353664259848412e03, - 6.26107330137134956842e05, - 2.68919633393814121987e08, - 8.64002487103935000337e10, - 2.02979612750105546709e13, - 3.17157752842975028269e15, - 2.50596256172653059228e17, - ], - dtype=self.dtype, + [ + 1.04128353664259848412e03, + 6.26107330137134956842e05, + 2.68919633393814121987e08, + 8.64002487103935000337e10, + 2.02979612750105546709e13, + 3.17157752842975028269e15, + 2.50596256172653059228e17, + ], + dtype=self.dtype, ) yp = op_base.foreach_loop(YP, lambda carry, yp_i: carry * (x * x) + yp_i) yq = op_base.foreach_loop(YQ, lambda carry, yq_i: carry * (x * x) + yq_i) - return yp / yq + (0.636619772367581343075535053490057448 * jnp.log(x) * - _aten_special_bessel_j0(x)) + return yp / yq + ( + 0.636619772367581343075535053490057448 + * jnp.log(x) + * _aten_special_bessel_j0(x) + ) def default(x): PP = jnp.array( - [ - 7.96936729297347051624e-04, - 8.28352392107440799803e-02, - 1.23953371646414299388e00, - 5.44725003058768775090e00, - 8.74716500199817011941e00, - 5.30324038235394892183e00, - 9.99999999999999997821e-01, - ], - dtype=self.dtype, + [ + 7.96936729297347051624e-04, + 8.28352392107440799803e-02, + 1.23953371646414299388e00, + 5.44725003058768775090e00, + 8.74716500199817011941e00, + 5.30324038235394892183e00, + 9.99999999999999997821e-01, + ], + dtype=self.dtype, ) PQ = jnp.array( - [ - 9.24408810558863637013e-04, - 8.56288474354474431428e-02, - 1.25352743901058953537e00, - 5.47097740330417105182e00, - 8.76190883237069594232e00, - 5.30605288235394617618e00, - 1.00000000000000000218e00, - ], - dtype=self.dtype, + [ + 9.24408810558863637013e-04, + 8.56288474354474431428e-02, + 
1.25352743901058953537e00, + 5.47097740330417105182e00, + 8.76190883237069594232e00, + 5.30605288235394617618e00, + 1.00000000000000000218e00, + ], + dtype=self.dtype, ) QP = jnp.array( - [ - -1.13663838898469149931e-02, - -1.28252718670509318512e00, - -1.95539544257735972385e01, - -9.32060152123768231369e01, - -1.77681167980488050595e02, - -1.47077505154951170175e02, - -5.14105326766599330220e01, - -6.05014350600728481186e00, - ], - dtype=self.dtype, + [ + -1.13663838898469149931e-02, + -1.28252718670509318512e00, + -1.95539544257735972385e01, + -9.32060152123768231369e01, + -1.77681167980488050595e02, + -1.47077505154951170175e02, + -5.14105326766599330220e01, + -6.05014350600728481186e00, + ], + dtype=self.dtype, ) QQ = jnp.array( - [ - 6.43178256118178023184e01, - 8.56430025976980587198e02, - 3.88240183605401609683e03, - 7.24046774195652478189e03, - 5.93072701187316984827e03, - 2.06209331660327847417e03, - 2.42005740240291393179e02, - ], - dtype=self.dtype, + [ + 6.43178256118178023184e01, + 8.56430025976980587198e02, + 3.88240183605401609683e03, + 7.24046774195652478189e03, + 5.93072701187316984827e03, + 2.06209331660327847417e03, + 2.42005740240291393179e02, + ], + dtype=self.dtype, ) factor = 25.0 / (x * x) @@ -4509,15 +4651,22 @@ def default(x): qp = op_base.foreach_loop(QP, lambda carry, qp_i: carry * factor + qp_i) qq = op_base.foreach_loop(QQ, lambda carry, qq_i: carry * factor + qq_i) - return ((pp / pq * jnp.sin(x - 0.785398163397448309615660845819875721) + - 5.0 / x * - (qp / qq) * jnp.cos(x - 0.785398163397448309615660845819875721)) * - 0.797884560802865355879892119868763737 / jnp.sqrt(x)) + return ( + ( + pp / pq * jnp.sin(x - 0.785398163397448309615660845819875721) + + 5.0 + / x + * (qp / qq) + * jnp.cos(x - 0.785398163397448309615660845819875721) + ) + * 0.797884560802865355879892119868763737 + / jnp.sqrt(x) + ) return jnp.piecewise( - self, - [self <= 5.0, self < 0., self == 0.], - [small, negative, zero, default], + self, + [self <= 5.0, self < 0.0, self == 0.0], + [small, negative, zero, default], ) @@ -4534,86 +4683,87 @@ def negative(x): def small(x): YP = jnp.array( - [ - 1.26320474790178026440e09, - -6.47355876379160291031e11, - 1.14509511541823727583e14, - -8.12770255501325109621e15, - 2.02439475713594898196e17, - -7.78877196265950026825e17, - ], - dtype=self.dtype, + [ + 1.26320474790178026440e09, + -6.47355876379160291031e11, + 1.14509511541823727583e14, + -8.12770255501325109621e15, + 2.02439475713594898196e17, + -7.78877196265950026825e17, + ], + dtype=self.dtype, ) YQ = jnp.array( - [ - 5.94301592346128195359e02, - 2.35564092943068577943e05, - 7.34811944459721705660e07, - 1.87601316108706159478e10, - 3.88231277496238566008e12, - 6.20557727146953693363e14, - 6.87141087355300489866e16, - 3.97270608116560655612e18, - ], - dtype=self.dtype, + [ + 5.94301592346128195359e02, + 2.35564092943068577943e05, + 7.34811944459721705660e07, + 1.87601316108706159478e10, + 3.88231277496238566008e12, + 6.20557727146953693363e14, + 6.87141087355300489866e16, + 3.97270608116560655612e18, + ], + dtype=self.dtype, ) yp = op_base.foreach_loop(YP, lambda carry, yp_i: carry * (x * x) + yp_i) yq = op_base.foreach_loop(YQ, lambda carry, yq_i: carry * (x * x) + yq_i) - return (x * (yp / yq) + - (0.636619772367581343075535053490057448 * - (_aten_special_bessel_j1(x) * jnp.log(x) - 1.0 / x))) + return x * (yp / yq) + ( + 0.636619772367581343075535053490057448 + * (_aten_special_bessel_j1(x) * jnp.log(x) - 1.0 / x) + ) def default(x): PP = jnp.array( - [ - 7.62125616208173112003e-04, - 
7.31397056940917570436e-02, - 1.12719608129684925192e00, - 5.11207951146807644818e00, - 8.42404590141772420927e00, - 5.21451598682361504063e00, - 1.00000000000000000254e00, - ], - dtype=self.dtype, + [ + 7.62125616208173112003e-04, + 7.31397056940917570436e-02, + 1.12719608129684925192e00, + 5.11207951146807644818e00, + 8.42404590141772420927e00, + 5.21451598682361504063e00, + 1.00000000000000000254e00, + ], + dtype=self.dtype, ) PQ = jnp.array( - [ - 5.71323128072548699714e-04, - 6.88455908754495404082e-02, - 1.10514232634061696926e00, - 5.07386386128601488557e00, - 8.39985554327604159757e00, - 5.20982848682361821619e00, - 9.99999999999999997461e-01, - ], - dtype=self.dtype, + [ + 5.71323128072548699714e-04, + 6.88455908754495404082e-02, + 1.10514232634061696926e00, + 5.07386386128601488557e00, + 8.39985554327604159757e00, + 5.20982848682361821619e00, + 9.99999999999999997461e-01, + ], + dtype=self.dtype, ) QP = jnp.array( - [ - 5.10862594750176621635e-02, - 4.98213872951233449420e00, - 7.58238284132545283818e01, - 3.66779609360150777800e02, - 7.10856304998926107277e02, - 5.97489612400613639965e02, - 2.11688757100572135698e02, - 2.52070205858023719784e01, - ], - dtype=self.dtype, + [ + 5.10862594750176621635e-02, + 4.98213872951233449420e00, + 7.58238284132545283818e01, + 3.66779609360150777800e02, + 7.10856304998926107277e02, + 5.97489612400613639965e02, + 2.11688757100572135698e02, + 2.52070205858023719784e01, + ], + dtype=self.dtype, ) QQ = jnp.array( - [ - 7.42373277035675149943e01, - 1.05644886038262816351e03, - 4.98641058337653607651e03, - 9.56231892404756170795e03, - 7.99704160447350683650e03, - 2.82619278517639096600e03, - 3.36093607810698293419e02, - ], - dtype=self.dtype, + [ + 7.42373277035675149943e01, + 1.05644886038262816351e03, + 4.98641058337653607651e03, + 9.56231892404756170795e03, + 7.99704160447350683650e03, + 2.82619278517639096600e03, + 3.36093607810698293419e02, + ], + dtype=self.dtype, ) factor = 25.0 / (x * x) @@ -4622,15 +4772,22 @@ def default(x): qp = op_base.foreach_loop(QP, lambda carry, qp_i: carry * factor + qp_i) qq = op_base.foreach_loop(QQ, lambda carry, qq_i: carry * factor + qq_i) - return ((pp / pq * jnp.sin(x - 2.356194490192344928846982537459627163) + - 5.0 / x * - (qp / qq) * jnp.cos(x - 2.356194490192344928846982537459627163)) * - 0.797884560802865355879892119868763737 / jnp.sqrt(x)) + return ( + ( + pp / pq * jnp.sin(x - 2.356194490192344928846982537459627163) + + 5.0 + / x + * (qp / qq) + * jnp.cos(x - 2.356194490192344928846982537459627163) + ) + * 0.797884560802865355879892119868763737 + / jnp.sqrt(x) + ) return jnp.piecewise( - self, - [self <= 5.0, self < 0., self == 0.], - [small, negative, zero, default], + self, + [self <= 5.0, self < 0.0, self == 0.0], + [small, negative, zero, default], ) @@ -4641,13 +4798,13 @@ def _aten_special_chebyshev_polynomial_t(self, n): @jnp.vectorize def vectorized(x, n_i): - def negative_n(x): return jnp.zeros_like(x) def one_x(x): - return jnp.where((x > 0) | (n_i % 2 == 0), jnp.ones_like(x), - -jnp.ones_like(x)) + return jnp.where( + (x > 0) | (n_i % 2 == 0), jnp.ones_like(x), -jnp.ones_like(x) + ) def large_n_small_x(x): return jnp.cos(n_i * jnp.acos(x)) @@ -4659,18 +4816,24 @@ def one_n(x): return x def default(x): - def f(_, carry): p, q = carry return (q, 2 * x * q - p) - _, r = jax.lax.fori_loop(0, n_i - 1, f, init_val=(1., x)) + _, r = jax.lax.fori_loop(0, n_i - 1, f, init_val=(1.0, x)) return r - return jnp.piecewise(x, [ - n_i == 1, n_i == 0, (n_i == 6) & (jnp.abs(x) < 1), - jnp.abs(x) == 1., n_i < 
0 - ], [one_n, zero_n, large_n_small_x, one_x, negative_n, default]) + return jnp.piecewise( + x, + [ + n_i == 1, + n_i == 0, + (n_i == 6) & (jnp.abs(x) < 1), + jnp.abs(x) == 1.0, + n_i < 0, + ], + [one_n, zero_n, large_n_small_x, one_x, negative_n, default], + ) # Explcicitly vectorize since we must vectorizes over both self and n return vectorized(self, n.astype(jnp.int64)) @@ -4683,7 +4846,6 @@ def _aten_special_chebyshev_polynomial_u(self, n): @jnp.vectorize def vectorized(x, n_i): - def negative_n(x): return jnp.zeros_like(x) @@ -4693,9 +4855,9 @@ def one_x(x): def large_n_small_x(x): sin_acos_x = jnp.sin(jnp.acos(x)) return jnp.where( - sin_acos_x != 0, - jnp.sin((n_i + 1) * jnp.acos(x)) / sin_acos_x, - (n_i + 1) * jnp.cos((n_i + 1) * jnp.acos(x)) / x, + sin_acos_x != 0, + jnp.sin((n_i + 1) * jnp.acos(x)) / sin_acos_x, + (n_i + 1) * jnp.cos((n_i + 1) * jnp.acos(x)) / x, ) def zero_n(x): @@ -4705,7 +4867,6 @@ def one_n(x): return 2 * x def default(x): - def f(_, carry): p, q = carry return (q, 2 * x * q - p) @@ -4714,15 +4875,15 @@ def f(_, carry): return r return jnp.piecewise( - x, - [ - n_i == 1, - n_i == 0, - (n_i > 8) & (jnp.abs(x) < 1), - jnp.abs(x) == 1.0, - n_i < 0, - ], - [one_n, zero_n, large_n_small_x, one_x, negative_n, default], + x, + [ + n_i == 1, + n_i == 0, + (n_i > 8) & (jnp.abs(x) < 1), + jnp.abs(x) == 1.0, + n_i < 0, + ], + [one_n, zero_n, large_n_small_x, one_x, negative_n, default], ) return vectorized(self, n.astype(jnp.int64)) @@ -4747,7 +4908,6 @@ def _aten_special_hermite_polynomial_h(self, n): @jnp.vectorize def vectorized(x, n_i): - def negative_n(x): return jnp.zeros_like(x) @@ -4758,7 +4918,6 @@ def one_n(x): return 2 * x def default(x): - def f(k, carry): p, q = carry return (q, 2 * x * q - 2 * k * p) @@ -4766,8 +4925,11 @@ def f(k, carry): _, r = jax.lax.fori_loop(1, n_i, f, init_val=(1.0, 2 * x)) return r - return jnp.piecewise(x, [n_i == 1, n_i == 0, n_i < 0], - [one_n, zero_n, negative_n, default]) + return jnp.piecewise( + x, + [n_i == 1, n_i == 0, n_i < 0], + [one_n, zero_n, negative_n, default], + ) return vectorized(self, n.astype(jnp.int64)) @@ -4779,7 +4941,6 @@ def _aten_special_hermite_polynomial_he(self, n): @jnp.vectorize def vectorized(x, n_i): - def negative_n(x): return jnp.zeros_like(x) @@ -4790,7 +4951,6 @@ def one_n(x): return x def default(x): - def f(k, carry): p, q = carry return (q, x * q - k * p) @@ -4798,33 +4958,37 @@ def f(k, carry): _, r = jax.lax.fori_loop(1, n_i, f, init_val=(1.0, x)) return r - return jnp.piecewise(x, [n_i == 1.0, n_i == 0.0, n_i < 0], - [one_n, zero_n, negative_n, default]) + return jnp.piecewise( + x, + [n_i == 1.0, n_i == 0.0, n_i < 0], + [one_n, zero_n, negative_n, default], + ) return vectorized(self, n.astype(jnp.int64)) @op(torch.ops.aten.multinomial, needs_env=True) -def _aten_multinomial(input, - num_samples, - replacement=False, - *, - generator=None, - out=None, - env=None): - assert num_samples <= input.shape[ - -1] or replacement, "cannot take a larger sample than population when replacement=False" +def _aten_multinomial( + input, num_samples, replacement=False, *, generator=None, out=None, env=None +): + assert num_samples <= input.shape[-1] or replacement, ( + "cannot take a larger sample than population when replacement=False" + ) key = env.get_and_rotate_prng_key(generator) if input.ndim == 1: return jax.random.choice( - key, input.shape[-1], (num_samples,), replace=replacement, p=input) + key, input.shape[-1], (num_samples,), replace=replacement, p=input + ) else: return jnp.array([ - 
jax.random.choice( - key, - input.shape[-1], (num_samples,), - replace=replacement, - p=input[i, :]) for i in range(input.shape[0]) + jax.random.choice( + key, + input.shape[-1], + (num_samples,), + replace=replacement, + p=input[i, :], + ) + for i in range(input.shape[0]) ]) @@ -4852,13 +5016,13 @@ def _aten_flatten(x, start_dim=0, end_dim=-1): if end_dim < 0: end_dim += len(shape) # Handle negative indexing - new_shape = (*shape[:start_dim], -1, *shape[end_dim + 1:]) + new_shape = (*shape[:start_dim], -1, *shape[end_dim + 1 :]) return jnp.reshape(x, new_shape) @op(torch.ops.aten.new_empty) def _new_empty(self, size, **kwargs): - dtype = kwargs.get('dtype') + dtype = kwargs.get("dtype") if dtype is not None: dtype = mappings.t2j_dtype(dtype) else: @@ -4881,8 +5045,12 @@ def _aten_unsafe_index_put(self, indices, values, accumulate=False): return _aten_index_put(self, indices, values, accumulate) -@op(torch.ops.aten.conj_physical, torch.ops.aten.conj, - torch.ops.aten._conj_physical, torch.ops.aten._conj) +@op( + torch.ops.aten.conj_physical, + torch.ops.aten.conj, + torch.ops.aten._conj_physical, + torch.ops.aten._conj, +) def _aten_conj_physical(self): return jnp.conjugate(self) @@ -4952,18 +5120,16 @@ def _aten__linalg_solve_ex(a, b): # torch.linalg.solve_triangular @op(torch.ops.aten.linalg_solve_triangular) -def _aten_linalg_solve_triangular(a, - b, - *, - upper=True, - left=True, - unitriangular=False): +def _aten_linalg_solve_triangular( + a, b, *, upper=True, left=True, unitriangular=False +): if left is False: a = jnp.matrix_transpose(a) b = jnp.matrix_transpose(b) upper = not upper res = jax.scipy.linalg.solve_triangular( - a, b, lower=not upper, unit_diagonal=unitriangular) + a, b, lower=not upper, unit_diagonal=unitriangular + ) if left is False: res = jnp.matrix_transpose(res) return res @@ -4984,30 +5150,34 @@ def _aten__linalg_check_errors(*args, **kwargs): @op(torch.ops.aten.median) def _aten_median(self, dim=None, keepdim=False): output = _with_reduction_scalar( - functools.partial(jnp.quantile, q=0.5, method='lower'), - self, - dim=dim, - keepdim=keepdim).astype(self.dtype) + functools.partial(jnp.quantile, q=0.5, method="lower"), + self, + dim=dim, + keepdim=keepdim, + ).astype(self.dtype) if dim is None: return output else: - index = _with_reduction_scalar(_get_median_index, self, dim, - keepdim).astype(jnp.int64) + index = _with_reduction_scalar( + _get_median_index, self, dim, keepdim + ).astype(jnp.int64) return output, index @op(torch.ops.aten.nanmedian) def _aten_nanmedian(input, dim=None, keepdim=False, *, out=None): output = _with_reduction_scalar( - functools.partial(jnp.nanquantile, q=0.5, method='lower'), - input, - dim=dim, - keepdim=keepdim).astype(input.dtype) + functools.partial(jnp.nanquantile, q=0.5, method="lower"), + input, + dim=dim, + keepdim=keepdim, + ).astype(input.dtype) if dim is None: return output else: - index = _with_reduction_scalar(_get_median_index, input, dim, - keepdim).astype(jnp.int64) + index = _with_reduction_scalar( + _get_median_index, input, dim, keepdim + ).astype(jnp.int64) return output, index @@ -5028,18 +5198,20 @@ def _get_median_index(x, axis=None, keepdims=False): @op(torch.ops.aten.triangular_solve) -def _aten_triangular_solve(b, - a, - upper=True, - transpose=False, - unittriangular=False): - return (jax.lax.linalg.triangular_solve( +def _aten_triangular_solve( + b, a, upper=True, transpose=False, unittriangular=False +): + return ( + jax.lax.linalg.triangular_solve( a, b, left_side=True, lower=not upper, 
transpose_a=transpose, - unit_diagonal=unittriangular), a) + unit_diagonal=unittriangular, + ), + a, + ) # func: _fft_c2c(Tensor self, SymInt[] dim, int normalization, bool forward) -> Tensor @@ -5047,16 +5219,16 @@ def _aten_triangular_solve(b, def _aten__fft_c2c(self, dim, normalization, forward): if forward: norm = [ - 'backward', - 'ortho', - 'forward', + "backward", + "ortho", + "forward", ][normalization] return jnp.fft.fftn(self, axes=dim, norm=norm) else: norm = [ - 'forward', - 'ortho', - 'backward', + "forward", + "ortho", + "backward", ][normalization] return jnp.fft.ifftn(self, axes=dim, norm=norm) @@ -5064,9 +5236,9 @@ def _aten__fft_c2c(self, dim, normalization, forward): @op(torch.ops.aten._fft_r2c) def _aten__fft_r2c(self, dim, normalization, onesided): norm = [ - 'backward', - 'ortho', - 'forward', + "backward", + "ortho", + "forward", ][normalization] if onesided: return jnp.fft.rfftn(self, axes=dim, norm=norm) @@ -5077,9 +5249,9 @@ def _aten__fft_r2c(self, dim, normalization, onesided): @op(torch.ops.aten._fft_c2r) def _aten__fft_c2r(self, dim, normalization, last_dim_size): norm = [ - 'forward', - 'ortho', - 'backward', + "forward", + "ortho", + "backward", ][normalization] if len(dim) == 1: s = [last_dim_size] @@ -5089,17 +5261,15 @@ def _aten__fft_c2r(self, dim, normalization, last_dim_size): @op(torch.ops.aten._trilinear) -def _aten_trilinear(i1, - i2, - i3, - expand1, - expand2, - expand3, - sumdim, - unroll_dim=1): +def _aten_trilinear( + i1, i2, i3, expand1, expand2, expand3, sumdim, unroll_dim=1 +): return _aten_sum( - jnp.expand_dims(i1, expand1) * jnp.expand_dims(i2, expand2) * - jnp.expand_dims(i3, expand3), sumdim) + jnp.expand_dims(i1, expand1) + * jnp.expand_dims(i2, expand2) + * jnp.expand_dims(i3, expand3), + sumdim, + ) @op(torch.ops.aten.max_unpool2d) @@ -5107,7 +5277,8 @@ def _aten_trilinear(i1, def _aten_max_unpoolxd(input, indices, output_size, stride=None, padding=0): if output_size is None: raise ValueError( - "output_size value is not set correctly. It cannot be None or empty.") + "output_size value is not set correctly. It cannot be None or empty." 
+ ) output_size = [input.shape[0], input.shape[1]] + output_size output = jnp.zeros(output_size, dtype=input.dtype) @@ -5122,14 +5293,16 @@ def _aten_max_unpoolxd(input, indices, output_size, stride=None, padding=0): return output -def _aten_upsample(input, - output_size, - align_corners, - antialias, - method, - scale_factors=None, - scales_h=None, - scales_w=None): +def _aten_upsample( + input, + output_size, + align_corners, + antialias, + method, + scale_factors=None, + scales_h=None, + scales_w=None, +): # input: is of type jaxlib.xla_extension.ArrayImpl image = input @@ -5177,56 +5350,62 @@ def _aten_upsample(input, # https://github.com/jax-ml/jax/issues/11206 if align_corners: scale = jnp.array([ - (shape[i] - 1.0) / (image.shape[i] - 1.0) for i in spatial_dims + (shape[i] - 1.0) / (image.shape[i] - 1.0) for i in spatial_dims ]) translation = jnp.array([0 for i in spatial_dims]) return jax_reimplement.scale_and_translate( - image, - shape, - method=method, - scale=scale, - spatial_dims=spatial_dims, - translation=translation, - antialias=antialias, + image, + shape, + method=method, + scale=scale, + spatial_dims=spatial_dims, + translation=translation, + antialias=antialias, ) @op(torch.ops.aten._upsample_bilinear2d_aa) -def _aten_upsample_billinear_aa(input, - output_size, - align_corners, - scale_factors=None, - scales_h=None, - scales_w=None): +def _aten_upsample_billinear_aa( + input, + output_size, + align_corners, + scale_factors=None, + scales_h=None, + scales_w=None, +): return _aten_upsample( - input, - output_size, - align_corners, - True, # antialias - "bilinear", # method - scale_factors, - scales_h, - scales_w) + input, + output_size, + align_corners, + True, # antialias + "bilinear", # method + scale_factors, + scales_h, + scales_w, + ) @op(torch.ops.aten._upsample_bicubic2d_aa) -def _aten_upsample_bicubic2d_aa(input, - output_size, - align_corners, - scale_factors=None, - scales_h=None, - scales_w=None): +def _aten_upsample_bicubic2d_aa( + input, + output_size, + align_corners, + scale_factors=None, + scales_h=None, + scales_w=None, +): return _aten_upsample( - input, - output_size, - align_corners, - True, # antialias - "bicubic", # method - scale_factors, - scales_h, - scales_w) + input, + output_size, + align_corners, + True, # antialias + "bicubic", # method + scale_factors, + scales_h, + scales_w, + ) @op(torch.ops.aten.polar) @@ -5235,10 +5414,9 @@ def _aten_polar(abs, angle, *, out=None): @op(torch.ops.aten.cdist) -def _aten_cdist(x1, - x2, - p=2.0, - compute_mode='use_mm_for_euclid_dist_if_necessary'): +def _aten_cdist( + x1, x2, p=2.0, compute_mode="use_mm_for_euclid_dist_if_necessary" +): x1 = x1.astype(jnp.float32) x2 = x2.astype(jnp.float32) @@ -5247,17 +5425,18 @@ def _aten_cdist(x1, return _hamming_distance(x1, x2).astype(jnp.float32) elif p == 2.0: # Use optimized Euclidean distance calculation - if compute_mode == 'use_mm_for_euclid_dist_if_necessary' and ( - x1.shape[-2] > 25 or x2.shape[-2] > 25): + if compute_mode == "use_mm_for_euclid_dist_if_necessary" and ( + x1.shape[-2] > 25 or x2.shape[-2] > 25 + ): return _euclidean_mm(x1, x2) - elif compute_mode == 'use_mm_for_euclid_dist': + elif compute_mode == "use_mm_for_euclid_dist": return _euclidean_mm(x1, x2) else: return _euclidean_direct(x1, x2) else: # General p-norm distance calculation diff = jnp.abs(jnp.expand_dims(x1, -2) - jnp.expand_dims(x2, -3)) - return jnp.sum(jnp.power(diff, p), axis=-1).astype(jnp.float32)**(1 / p) + return jnp.sum(jnp.power(diff, p), axis=-1).astype(jnp.float32) ** (1 
/ p) def _hamming_distance(x1, x2): @@ -5338,7 +5517,7 @@ def _aten_lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True): # Extract lower triangle L = jnp.tril(LU_data, k=-1) - #emulate pytorch behavior: Add ones to the diagonal of L + # emulate pytorch behavior: Add ones to the diagonal of L eye = jnp.eye(n, m, dtype=LU_data.dtype) L = L + eye @@ -5413,7 +5592,11 @@ def update_indices(i, _indices): unpackedP = v_lu_unpack_2d(reshapedP, reshapedPivot) # reshape result back to P's shape - newRetshape = (*P.shape[:-2], unpackedP.shape[-2], unpackedP.shape[-1]) + newRetshape = ( + *P.shape[:-2], + unpackedP.shape[-2], + unpackedP.shape[-1], + ) P = unpackedP.reshape(newRetshape) else: # emulate pytroch behavior: return empty tensors @@ -5440,10 +5623,14 @@ def kthvalue(input, k, dim=None, keepdim=False, *, out=None): while dimension < 0: dimension = dimension + input.ndim values = jax.lax.index_in_dim( - jnp.partition(input, k - 1, dimension), k - 1, dimension, keepdim) + jnp.partition(input, k - 1, dimension), k - 1, dimension, keepdim + ) indices = jax.lax.index_in_dim( - jnp.argpartition(input, k - 1, dimension).astype('int64'), k - 1, - dimension, keepdim) + jnp.argpartition(input, k - 1, dimension).astype("int64"), + k - 1, + dimension, + keepdim, + ) return values, indices @@ -5454,14 +5641,14 @@ def _aten_take(self, index): # func: pad(Tensor self, SymInt[] pad, str mode="constant", float? value=None) -> Tensor @op(torch.ops.aten.pad) -def _aten_pad(self, pad, mode='constant', value=None): +def _aten_pad(self, pad, mode="constant", value=None): if not isinstance(pad, (tuple, list)) or len(pad) % 2 != 0: raise ValueError("Padding must be a sequence of even length.") num_dims = self.ndim if len(pad) > 2 * num_dims: raise ValueError( - f"Padding sequence length ({len(pad)}) exceeds 2 * number of dimensions ({2 * num_dims})." + f"Padding sequence length ({len(pad)}) exceeds 2 * number of dimensions ({2 * num_dims})." ) # JAX's pad function expects padding for each dimension as a tuple of (low, high) @@ -5485,99 +5672,75 @@ def _aten_pad(self, pad, mode='constant', value=None): if value is None: value = 0.0 return jnp.pad( - self, pad_width=jax_pad_width, mode="constant", constant_values=value) + self, + pad_width=jax_pad_width, + mode="constant", + constant_values=value, + ) elif mode == "reflect": return jnp.pad(self, pad_width=jax_pad_width, mode="reflect") elif mode == "edge": return jnp.pad(self, pad_width=jax_pad_width, mode="edge") else: raise ValueError( - f"Unsupported padding mode: {mode}. Expected 'constant', 'reflect', or 'edge'." + f"Unsupported padding mode: {mode}. Expected 'constant', 'reflect', or 'edge'." 
) mutation_ops_to_functional = { - torch.ops.aten.add_: - op_base.InplaceOp(torch.ops.aten.add), - torch.ops.aten.sub_: - op_base.InplaceOp(torch.ops.aten.sub), - torch.ops.aten.mul_: - op_base.InplaceOp(torch.ops.aten.mul), - torch.ops.aten.div_: - op_base.InplaceOp(torch.ops.aten.div), - torch.ops.aten.pow_: - op_base.InplaceOp(torch.ops.aten.pow), - torch.ops.aten.lt_: - op_base.InplaceOp(torch.ops.aten.lt), - torch.ops.aten.le_: - op_base.InplaceOp(torch.ops.aten.le), - torch.ops.aten.gt_: - op_base.InplaceOp(torch.ops.aten.gt), - torch.ops.aten.ge_: - op_base.InplaceOp(torch.ops.aten.ge), - torch.ops.aten.eq_: - op_base.InplaceOp(torch.ops.aten.eq), - torch.ops.aten.ne_: - op_base.InplaceOp(torch.ops.aten.ne), - torch.ops.aten.bernoulli_: - op_base.InplaceOp(torch.ops.aten.bernoulli.p), - torch.ops.aten.bernoulli_.float: - op_base.InplaceOp(_aten_bernoulli, is_jax_func=True), - torch.ops.aten.geometric_: - op_base.InplaceOp(torch.ops.aten.geometric), - torch.ops.aten.normal_: - op_base.InplaceOp(torch.ops.aten.normal), - torch.ops.aten.random_: - op_base.InplaceOp(torch.ops.aten.uniform), - torch.ops.aten.uniform_: - op_base.InplaceOp(torch.ops.aten.uniform), - torch.ops.aten.relu_: - op_base.InplaceOp(torch.ops.aten.relu), - # squeeze_ is expected to change tensor's shape. So replace with new value - torch.ops.aten.squeeze_: - op_base.InplaceOp(torch.ops.aten.squeeze, True), - torch.ops.aten.sqrt_: - op_base.InplaceOp(torch.ops.aten.sqrt), - torch.ops.aten.clamp_: - op_base.InplaceOp(torch.ops.aten.clamp), - torch.ops.aten.clamp_min_: - op_base.InplaceOp(torch.ops.aten.clamp_min), - torch.ops.aten.sigmoid_: - op_base.InplaceOp(torch.ops.aten.sigmoid), - torch.ops.aten.tanh_: - op_base.InplaceOp(torch.ops.aten.tanh), - torch.ops.aten.ceil_: - op_base.InplaceOp(torch.ops.aten.ceil), - torch.ops.aten.logical_not_: - op_base.InplaceOp(torch.ops.aten.logical_not), - torch.ops.aten.unsqueeze_: - op_base.InplaceOp(torch.ops.aten.unsqueeze), - torch.ops.aten.transpose_: - op_base.InplaceOp(torch.ops.aten.transpose), - torch.ops.aten.log_normal_: - op_base.InplaceOp(torch.ops.aten.log_normal), - torch.ops.aten.scatter_add_: - op_base.InplaceOp(torch.ops.aten.scatter_add), - torch.ops.aten.scatter_reduce_.two: - op_base.InplaceOp(torch.ops.aten.scatter_reduce), - torch.ops.aten.scatter_: - op_base.InplaceOp(torch.ops.aten.scatter), - torch.ops.aten.bitwise_or_: - op_base.InplaceOp(torch.ops.aten.bitwise_or), + torch.ops.aten.add_: op_base.InplaceOp(torch.ops.aten.add), + torch.ops.aten.sub_: op_base.InplaceOp(torch.ops.aten.sub), + torch.ops.aten.mul_: op_base.InplaceOp(torch.ops.aten.mul), + torch.ops.aten.div_: op_base.InplaceOp(torch.ops.aten.div), + torch.ops.aten.pow_: op_base.InplaceOp(torch.ops.aten.pow), + torch.ops.aten.lt_: op_base.InplaceOp(torch.ops.aten.lt), + torch.ops.aten.le_: op_base.InplaceOp(torch.ops.aten.le), + torch.ops.aten.gt_: op_base.InplaceOp(torch.ops.aten.gt), + torch.ops.aten.ge_: op_base.InplaceOp(torch.ops.aten.ge), + torch.ops.aten.eq_: op_base.InplaceOp(torch.ops.aten.eq), + torch.ops.aten.ne_: op_base.InplaceOp(torch.ops.aten.ne), + torch.ops.aten.bernoulli_: op_base.InplaceOp(torch.ops.aten.bernoulli.p), + torch.ops.aten.bernoulli_.float: op_base.InplaceOp( + _aten_bernoulli, is_jax_func=True + ), + torch.ops.aten.geometric_: op_base.InplaceOp(torch.ops.aten.geometric), + torch.ops.aten.normal_: op_base.InplaceOp(torch.ops.aten.normal), + torch.ops.aten.random_: op_base.InplaceOp(torch.ops.aten.uniform), + torch.ops.aten.uniform_: 
op_base.InplaceOp(torch.ops.aten.uniform), + torch.ops.aten.relu_: op_base.InplaceOp(torch.ops.aten.relu), + # squeeze_ is expected to change tensor's shape. So replace with new value + torch.ops.aten.squeeze_: op_base.InplaceOp(torch.ops.aten.squeeze, True), + torch.ops.aten.sqrt_: op_base.InplaceOp(torch.ops.aten.sqrt), + torch.ops.aten.clamp_: op_base.InplaceOp(torch.ops.aten.clamp), + torch.ops.aten.clamp_min_: op_base.InplaceOp(torch.ops.aten.clamp_min), + torch.ops.aten.sigmoid_: op_base.InplaceOp(torch.ops.aten.sigmoid), + torch.ops.aten.tanh_: op_base.InplaceOp(torch.ops.aten.tanh), + torch.ops.aten.ceil_: op_base.InplaceOp(torch.ops.aten.ceil), + torch.ops.aten.logical_not_: op_base.InplaceOp(torch.ops.aten.logical_not), + torch.ops.aten.unsqueeze_: op_base.InplaceOp(torch.ops.aten.unsqueeze), + torch.ops.aten.transpose_: op_base.InplaceOp(torch.ops.aten.transpose), + torch.ops.aten.log_normal_: op_base.InplaceOp(torch.ops.aten.log_normal), + torch.ops.aten.scatter_add_: op_base.InplaceOp(torch.ops.aten.scatter_add), + torch.ops.aten.scatter_reduce_.two: op_base.InplaceOp( + torch.ops.aten.scatter_reduce + ), + torch.ops.aten.scatter_: op_base.InplaceOp(torch.ops.aten.scatter), + torch.ops.aten.bitwise_or_: op_base.InplaceOp(torch.ops.aten.bitwise_or), } # Note: tuple comparisons work intuitively, e.g. `_jax_version >= (0, 4, 32)`. _jax_version = tuple(int(v) for v in jax.version._version.split(".")) mutation_needs_env = { - torch.ops.aten.bernoulli_, - torch.ops.aten.bernoulli_.float, + torch.ops.aten.bernoulli_, + torch.ops.aten.bernoulli_.float, } for operator, mutation in mutation_ops_to_functional.items(): ops_registry.register_torch_dispatch_op( - operator, - mutation, - is_jax_function=False, - is_view_op=True, - needs_env=(operator in mutation_needs_env)) + operator, + mutation, + is_jax_function=False, + is_view_op=True, + needs_env=(operator in mutation_needs_env), + ) diff --git a/torchax/torchax/ops/jax_reimplement.py b/torchax/torchax/ops/jax_reimplement.py index d9acc3be51a..f98dc240437 100644 --- a/torchax/torchax/ops/jax_reimplement.py +++ b/torchax/torchax/ops/jax_reimplement.py @@ -15,27 +15,40 @@ # JAX Link: https://github.com/jax-ml/jax/blob/18f48bd52abe907ff9818da52f3d195d32910c1b/jax/_src/image/scale.py#L52 -def compute_weight_mat(input_size: core.DimSize, output_size: core.DimSize, - scale, translation, kernel: Callable, antialias: bool): +def compute_weight_mat( + input_size: core.DimSize, + output_size: core.DimSize, + scale, + translation, + kernel: Callable, + antialias: bool, +): dtype = jnp.result_type(scale, translation) - inv_scale = 1. / scale + inv_scale = 1.0 / scale # When downsampling the kernel should be scaled since we want to low pass # filter and interpolate, but when upsampling it should not be since we only # want to interpolate. - kernel_scale = jnp.maximum(inv_scale, 1.) if antialias else 1. - sample_f = ((jnp.arange(output_size, dtype=dtype) + 0.5) * inv_scale - - translation * inv_scale - 0.5) + kernel_scale = jnp.maximum(inv_scale, 1.0) if antialias else 1.0 + sample_f = ( + (jnp.arange(output_size, dtype=dtype) + 0.5) * inv_scale + - translation * inv_scale + - 0.5 + ) x = ( - jnp.abs(sample_f[jnp.newaxis, :] - - jnp.arange(input_size, dtype=dtype)[:, jnp.newaxis]) / - kernel_scale) + jnp.abs( + sample_f[jnp.newaxis, :] + - jnp.arange(input_size, dtype=dtype)[:, jnp.newaxis] + ) + / kernel_scale + ) weights = kernel(x) total_weight_sum = jnp.sum(weights, axis=0, keepdims=True) weights = jnp.where( - jnp.abs(total_weight_sum) > 1000. 
* float(np.finfo(np.float32).eps), - jnp.divide(weights, jnp.where(total_weight_sum != 0, total_weight_sum, - 1)), 0) + jnp.abs(total_weight_sum) > 1000.0 * float(np.finfo(np.float32).eps), + jnp.divide(weights, jnp.where(total_weight_sum != 0, total_weight_sum, 1)), + 0, + ) # Zero out weights where the sample location is completely outside the input # range. # Note sample_f has already had the 0.5 removed, hence the weird range below. @@ -44,17 +57,28 @@ def compute_weight_mat(input_size: core.DimSize, output_size: core.DimSize, return weights input_size_minus_0_5 = core.dimension_as_value(input_size) - 0.5 return jnp.where( - jnp.logical_and(sample_f >= -0.5, sample_f - <= input_size_minus_0_5)[jnp.newaxis, :], weights, 0) + jnp.logical_and(sample_f >= -0.5, sample_f <= input_size_minus_0_5)[ + jnp.newaxis, : + ], + weights, + 0, + ) # (barney-s) -------------- END returning weights without zeroing --------------------- # JAX Link: https://github.com/jax-ml/jax/blob/18f48bd52abe907ff9818da52f3d195d32910c1b/jax/_src/image/scale.py#L86 -def _scale_and_translate(x, output_shape: core.Shape, - spatial_dims: Sequence[int], scale, translation, - kernel, antialias: bool, precision): +def _scale_and_translate( + x, + output_shape: core.Shape, + spatial_dims: Sequence[int], + scale, + translation, + kernel, + antialias: bool, + precision, +): input_shape = x.shape assert len(input_shape) == len(output_shape) assert len(spatial_dims) == len(scale) @@ -68,8 +92,9 @@ def _scale_and_translate(x, output_shape: core.Shape, d = canonicalize_axis(d, x.ndim) m = input_shape[d] n = output_shape[d] - w = compute_weight_mat(m, n, scale[i], translation[i], kernel, - antialias).astype(x.dtype) + w = compute_weight_mat( + m, n, scale[i], translation[i], kernel, antialias + ).astype(x.dtype) contractions.append(w) contractions.append([d, len(output_shape) + i]) out_indices[d] = len(output_shape) + i @@ -83,15 +108,16 @@ def _scale_and_translate(x, output_shape: core.Shape, # scale and translation here are scalar elements of an np.array, what is the # correct type annotation? def scale_and_translate( - image, - shape: core.Shape, - spatial_dims: Sequence[int], - scale, - translation, - # (barney-s) use string - method: str, #(barney-s) | ResizeMethod, - antialias: bool = True, - precision=lax.Precision.HIGHEST): + image, + shape: core.Shape, + spatial_dims: Sequence[int], + scale, + translation, + # (barney-s) use string + method: str, # (barney-s) | ResizeMethod, + antialias: bool = True, + precision=lax.Precision.HIGHEST, +): """Apply a scale and translation to an image. Generates a new image of shape 'shape' by resampling from the input image @@ -149,23 +175,35 @@ def scale_and_translate( """ shape = core.canonicalize_shape(shape) if len(shape) != image.ndim: - msg = ('shape must have length equal to the number of dimensions of x; ' - f' {shape} vs {image.shape}') + msg = ( + "shape must have length equal to the number of dimensions of x; " + f" {shape} vs {image.shape}" + ) raise ValueError(msg) if isinstance(method, str): method = ResizeMethod.from_string(method) if method == ResizeMethod.NEAREST: # Nearest neighbor is currently special-cased for straight resize, so skip # for now. - raise ValueError('Nearest neighbor resampling is not currently supported ' - 'for scale_and_translate.') + raise ValueError( + "Nearest neighbor resampling is not currently supported " + "for scale_and_translate." 
+ ) assert isinstance(method, ResizeMethod) kernel = _kernels[method] - image, = promote_dtypes_inexact(image) + (image,) = promote_dtypes_inexact(image) scale, translation = promote_dtypes_inexact(scale, translation) - return _scale_and_translate(image, shape, spatial_dims, scale, translation, - kernel, antialias, precision) + return _scale_and_translate( + image, + shape, + spatial_dims, + scale, + translation, + kernel, + antialias, + precision, + ) # END ----------------- END JAX code copied for testing ----------------------------- diff --git a/torchax/torchax/ops/jc10d.py b/torchax/torchax/ops/jc10d.py index 79544943f91..9cfbc8ba97e 100644 --- a/torchax/torchax/ops/jc10d.py +++ b/torchax/torchax/ops/jc10d.py @@ -6,7 +6,6 @@ def op(*aten, **kwargs): - def inner(func): for a in aten: ops_registry.register_torch_dispatch_op(a, func, **kwargs) @@ -22,7 +21,6 @@ def _c10d_all_gather(input, group_size: int, group_name: str): @op(torch.ops._c10d_functional.all_reduce) def _c10d_all_reduce(self, reduceOp: str, group_name: str): - if reduceOp == "sum": res = jax.lax.psum(self, axis_name="torch_dist") elif reduceOp == "avg": @@ -39,9 +37,9 @@ def _c10d_all_reduce(self, reduceOp: str, group_name: str): @op(torch.ops._c10d_functional.broadcast) def _c10d_broadcast(self, src: int, group_name: str): masked = jnp.where( - jax.lax.axis_index("torch_dist") == src, - self, - jnp.zeros_like(self), + jax.lax.axis_index("torch_dist") == src, + self, + jnp.zeros_like(self), ) return jax.lax.psum(masked, "torch_dist") diff --git a/torchax/torchax/ops/jimage.py b/torchax/torchax/ops/jimage.py index 947be0a5e3f..8ca0653ce46 100644 --- a/torchax/torchax/ops/jimage.py +++ b/torchax/torchax/ops/jimage.py @@ -7,19 +7,16 @@ def cubic_kernel(x, a=-0.75): absx = jnp.abs(x) x2 = absx * absx x3 = x2 * absx - cond1 = (absx <= 1) + cond1 = absx <= 1 cond2 = (absx > 1) & (absx < 2) f1 = (a + 2) * x3 - (a + 3) * x2 + 1 f2 = a * x3 - 5 * a * x2 + 8 * a * absx - 4 * a return jnp.where(cond1, f1, jnp.where(cond2, f2, 0.0)) -def compute_contribs(in_size, - out_size, - scale, - support=2.0, - align_corners=False, - dtype=None): +def compute_contribs( + in_size, out_size, scale, support=2.0, align_corners=False, dtype=None +): if align_corners: if out_size == 1: in_coords = jnp.zeros((1,), dtype=dtype) @@ -48,10 +45,10 @@ def gather_weights(img, idxs, axis): def interpolate_along_axis_bchw(img, idxs, weights, axis): """ - Interpolate along H (axis=2) or W (axis=3) for tensor (B, C, H, W). - idxs: (out_size, 4) int32 indices - weights: (out_size, 4) float32 weights - """ + Interpolate along H (axis=2) or W (axis=3) for tensor (B, C, H, W). 
+ idxs: (out_size, 4) int32 indices + weights: (out_size, 4) float32 weights + """ assert axis in (2, 3), "Axis must be 2 (H) or 3 (W)" out_size = idxs.shape[0] k = idxs.shape[1] # Typically 4 for cubic @@ -66,13 +63,15 @@ def gather_and_weight(i): def gather_one(offset): return jnp.take(img, idx[offset], axis=axis) # shape (B, C, H, W) - gathered = jnp.stack([gather_one(o) for o in range(k)], - axis=0) # (4, B, C, H, W) + gathered = jnp.stack( + [gather_one(o) for o in range(k)], axis=0 + ) # (4, B, C, H, W) weighted = jnp.tensordot(w, gathered, axes=(0, 0)) # (B, C, H, W) return weighted out = jax.vmap(gather_and_weight)( - jnp.arange(out_size)) # (out_size, B, C, H, W) + jnp.arange(out_size) + ) # (out_size, B, C, H, W) # Move the interpolated axis back into place if axis == 2: # interpolated over H @@ -94,20 +93,20 @@ def interpolate_bicubic_no_aa(img, out_h, out_w, align_corners=False): scale_x = out_w / w idxs_y, weights_y = compute_contribs( - h, - out_h, - scale_y, - align_corners=align_corners, - dtype=img.dtype, + h, + out_h, + scale_y, + align_corners=align_corners, + dtype=img.dtype, ) tmp = interpolate_along_axis_bchw(img, idxs_y, weights_y, axis=2) idxs_x, weights_x = compute_contribs( - w, - out_w, - scale_x, - align_corners=align_corners, - dtype=img.dtype, + w, + out_w, + scale_x, + align_corners=align_corners, + dtype=img.dtype, ) out = interpolate_along_axis_bchw(tmp, idxs_x, weights_x, axis=3) return out diff --git a/torchax/torchax/ops/jlibrary.py b/torchax/torchax/ops/jlibrary.py index 17cdb161c3c..90156b9c843 100644 --- a/torchax/torchax/ops/jlibrary.py +++ b/torchax/torchax/ops/jlibrary.py @@ -61,9 +61,7 @@ def register_torch_composite(composite_name, impl, *ops, **jit_args): @jaten.op(*ops) def _composite_impl(*args): - class ImplWrapper(torch.nn.Module): - def __init__(self): super().__init__() diff --git a/torchax/torchax/ops/jtorch.py b/torchax/torchax/ops/jtorch.py index 935c214d78f..a2bb6b1a6bd 100644 --- a/torchax/torchax/ops/jtorch.py +++ b/torchax/torchax/ops/jtorch.py @@ -25,7 +25,8 @@ def register_function(torch_func, **kwargs): @register_function(torch.as_tensor, is_jax_function=False, needs_env=True) @op_base.convert_dtype( - use_default_dtype=False) # Attempt to infer type from elements + use_default_dtype=False +) # Attempt to infer type from elements def _as_tensor(data, dtype=None, device=None, env=None): if isinstance(data, torch.Tensor): return env._to_copy(data, dtype, device) @@ -38,13 +39,14 @@ def _as_tensor(data, dtype=None, device=None, env=None): @register_function(torch.tensor) @op_base.convert_dtype( - use_default_dtype=False) # Attempt to infer type from elements + use_default_dtype=False +) # Attempt to infer type from elements def _tensor(data, *, dtype=None, **kwargs): python_types_to_torch_types = { - bool: jnp.bool, - int: jnp.int64, - float: jnp.float32, - complex: jnp.complex64, + bool: jnp.bool, + int: jnp.int64, + float: jnp.float32, + complex: jnp.complex64, } if not dtype: leaves = jax.tree_util.tree_leaves(data) @@ -52,7 +54,8 @@ def _tensor(data, *, dtype=None, **kwargs): dtype = python_types_to_torch_types.get(type(leaves[0])) return jnp.array( - data, dtype=dtype or mappings.t2j_dtype(torch.get_default_dtype())) + data, dtype=dtype or mappings.t2j_dtype(torch.get_default_dtype()) + ) @register_function(torch.allclose) @@ -91,7 +94,6 @@ def _diag(input, diagonal=0): @register_function(torch.einsum) @register_function(torch.ops.aten.einsum) def _einsum(equation, *operands): - def get_params(*a): inner_list = a[0] if not 
isinstance(inner_list, jax.Array): @@ -109,22 +111,23 @@ def get_params(*a): def _sdpa_reference( - query, - key, - value, - attn_mask=None, - dropout_p=0.0, - is_causal=False, - scale=None, - enable_gqa=False, + query, + key, + value, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + scale=None, + enable_gqa=False, ) -> torch.Tensor: L, S = query.size(-2), key.size(-2) scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device) if is_causal: assert attn_mask is None - temp_mask = torch.ones( - L, S, dtype=torch.bool, device=query.device).tril(diagonal=0) + temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril( + diagonal=0 + ) attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) attn_bias.to(query.dtype) if attn_mask is not None: @@ -152,28 +155,29 @@ def _tpu_flash_attention(query, key, value, env): def wrap_flash_attention(query, key, value): block_sizes = flash_attention.BlockSizes( - block_b=min(2, query.shape[0]), - block_q=min(512, query.shape[2]), - block_k_major=min(512, key.shape[2]), - block_k=min(512, key.shape[2]), - block_q_major_dkv=min(512, query.shape[2]), - block_k_major_dkv=min(512, key.shape[2]), - block_k_dkv=min(512, key.shape[2]), - block_q_dkv=min(512, query.shape[2]), - block_k_major_dq=min(512, key.shape[2]), - block_k_dq=min(256, key.shape[2]), - block_q_dq=min(1024, query.shape[2]), + block_b=min(2, query.shape[0]), + block_q=min(512, query.shape[2]), + block_k_major=min(512, key.shape[2]), + block_k=min(512, key.shape[2]), + block_q_major_dkv=min(512, query.shape[2]), + block_k_major_dkv=min(512, key.shape[2]), + block_k_dkv=min(512, key.shape[2]), + block_q_dkv=min(512, query.shape[2]), + block_k_major_dq=min(512, key.shape[2]), + block_k_dq=min(256, key.shape[2]), + block_q_dq=min(1024, query.shape[2]), ) return flash_attention.flash_attention( - query, key, value, causal=True, block_sizes=block_sizes) + query, key, value, causal=True, block_sizes=block_sizes + ) if env.config.shmap_flash_attention: wrap_flash_attention = shard_map( - wrap_flash_attention, - mesh=env._mesh, - in_specs=(fsdp_partition, fsdp_partition, fsdp_partition), - out_specs=fsdp_partition, - check_rep=False, + wrap_flash_attention, + mesh=env._mesh, + in_specs=(fsdp_partition, fsdp_partition, fsdp_partition), + out_specs=fsdp_partition, + check_rep=False, ) # return flash_attn_mapped(query, key, value) return wrap_flash_attention(query, key, value) @@ -185,8 +189,8 @@ def pad(tensor, pad, mode="constant", value=None): # dict provides a Torch-to-NumPy translation. Any string not in this dict will # be passed through as-is. 
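As an aside, a minimal sketch of how the mode translation defined just below is meant to behave, assuming only the two entries shown and pass-through for every other name; the helper `to_numpy_pad_mode` is illustrative and not part of torchax:

import jax.numpy as jnp

# Hypothetical helper mirroring MODE_NAME_TRANSLATION: only "circular" and
# "replicate" are renamed; anything else is handed to jnp.pad unchanged.
_TORCH_TO_NUMPY_PAD_MODE = {"circular": "wrap", "replicate": "edge"}

def to_numpy_pad_mode(mode: str) -> str:
  return _TORCH_TO_NUMPY_PAD_MODE.get(mode, mode)

x = jnp.arange(6.0).reshape(2, 3)
# torch's "replicate" corresponds to NumPy/JAX "edge" padding.
y = jnp.pad(x, ((0, 0), (1, 1)), mode=to_numpy_pad_mode("replicate"))
assert y.shape == (2, 5)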
MODE_NAME_TRANSLATION = { - "circular": "wrap", - "replicate": "edge", + "circular": "wrap", + "replicate": "edge", } numpy_mode = MODE_NAME_TRANSLATION.get(mode, mode) @@ -197,7 +201,7 @@ def pad(tensor, pad, mode="constant", value=None): nd_slice = [slice(None)] * num_prefix_dims for i in range(len(pad) - 2, -1, -2): - pad_start, pad_end = pad[i:i + 2] + pad_start, pad_end = pad[i : i + 2] slice_start, slice_end = None, None if pad_start < 0: @@ -230,39 +234,40 @@ def pad(tensor, pad, mode="constant", value=None): @register_function( - torch.nn.functional.scaled_dot_product_attention, - is_jax_function=False, - needs_env=True, + torch.nn.functional.scaled_dot_product_attention, + is_jax_function=False, + needs_env=True, ) @register_function( - torch.ops.aten.scaled_dot_product_attention, - is_jax_function=False, - needs_env=True) + torch.ops.aten.scaled_dot_product_attention, + is_jax_function=False, + needs_env=True, +) def scaled_dot_product_attention( - query, - key, - value, - attn_mask=None, - dropout_p=0.0, - is_causal=False, - scale=None, - enable_gqa=False, - env=None, + query, + key, + value, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + scale=None, + enable_gqa=False, + env=None, ) -> torch.Tensor: - if env.config.use_tpu_flash_attention: jquery, jkey, jvalue = env.t2j_iso((query, key, value)) res = _tpu_flash_attention(jquery, jkey, jvalue, env) return env.j2t_iso(res) - return _sdpa_reference(query, key, value, attn_mask, dropout_p, is_causal, - scale, enable_gqa) + return _sdpa_reference( + query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa + ) @register_function( - torch.Tensor.__getitem__, is_jax_function=False, is_view_op=True) + torch.Tensor.__getitem__, is_jax_function=False, is_view_op=True +) def getitem(self, indexes): - if isinstance(indexes, list) and isinstance(indexes[0], int): # list of int, i.e. 
x[[1, 2]] NOT x[1, 2] (the second would be tuple of int) indexes = (indexes,) @@ -271,10 +276,12 @@ def getitem(self, indexes): def is_narrow_slicing(): tensor_free = not pytree.tree_any( - lambda x: isinstance(x, torch.Tensor) or isinstance(x, jax.Array), - indexes) - list_free = not isinstance(indexes, tuple) or all( - [False if isinstance(x, list) else True for x in indexes]) + lambda x: isinstance(x, torch.Tensor) or isinstance(x, jax.Array), + indexes, + ) + list_free = not isinstance(indexes, tuple) or all([ + False if isinstance(x, list) else True for x in indexes + ]) return tensor_free and list_free if is_narrow_slicing(): @@ -343,15 +350,15 @@ def empty(*size: Sequence[int], dtype=None, **kwargs): @register_function(torch.arange, is_jax_function=False) def arange( - start, - end=None, - step=None, - out=None, - dtype=None, - layout=torch.strided, - device=None, - requires_grad=False, - pin_memory=None, + start, + end=None, + step=None, + out=None, + dtype=None, + layout=torch.strided, + device=None, + requires_grad=False, + pin_memory=None, ): if end is None: end = start @@ -363,14 +370,14 @@ def arange( @register_function(torch.empty_strided, is_jax_function=False) def empty_strided( - size, - stride, - *, - dtype=None, - layout=None, - device=None, - requires_grad=False, - pin_memory=False, + size, + stride, + *, + dtype=None, + layout=None, + device=None, + requires_grad=False, + pin_memory=False, ): return empty(size, dtype=dtype) @@ -389,14 +396,14 @@ def rand(*size, **kwargs): @register_function(torch.randn, is_jax_function=False) def randn( - *size, - generator=None, - out=None, - dtype=None, - layout=torch.strided, - device=None, - requires_grad=False, - pin_memory=False, + *size, + generator=None, + out=None, + dtype=None, + layout=torch.strided, + device=None, + requires_grad=False, + pin_memory=False, ): if len(size) == 1 and isinstance(size[0], collections.abc.Iterable): size = size[0] @@ -501,8 +508,8 @@ def linalg_tensorsolve(A, b, dims=None): if dims is not None: A = jnp.moveaxis(A, dims, len(dims) * (A.ndim - 1,)) dims = None - if A.shape[:b.ndim] != b.shape: - b = jnp.reshape(b, A.shape[:b.ndim]) + if A.shape[: b.ndim] != b.shape: + b = jnp.reshape(b, A.shape[: b.ndim]) return jnp.linalg.tensorsolve(A, b, axes=dims) @@ -516,53 +523,54 @@ def functional_linear(self, weights, bias=None): @register_function(torch.nn.functional.interpolate) def functional_interpolate( - input, - size: Tuple[int, int], - scale_factor: Optional[float], - mode: str, - align_corners: bool, - recompute_scale_factor: bool, - antialias: bool, + input, + size: Tuple[int, int], + scale_factor: Optional[float], + mode: str, + align_corners: bool, + recompute_scale_factor: bool, + antialias: bool, ): supported_methods = ( - "nearest", - "linear", - "bilinear", - "trilinear", - "cubic", - "bicubic", - "tricubic", - "lanczos3", - "lanczos5", + "nearest", + "linear", + "bilinear", + "trilinear", + "cubic", + "bicubic", + "tricubic", + "lanczos3", + "lanczos5", ) is_jax_supported = mode in supported_methods if not is_jax_supported: raise torchax.tensor.OperatorNotFound( - f"JAX does not support interpolation mode: {mode}. Supported modes are: {supported_methods}" + f"JAX does not support interpolation mode: {mode}. 
Supported modes are: {supported_methods}" ) # None check antialias = antialias or False align_corners = align_corners or False - if mode in ('cubic', 'bicubic', - 'tricubic') and not antialias and size is not None: + if ( + mode in ("cubic", "bicubic", "tricubic") + and not antialias + and size is not None + ): return jimage.interpolate_bicubic_no_aa( - input, - size[0], - size[1], - align_corners, + input, + size[0], + size[1], + align_corners, ) else: # fallback raise torchax.tensor.OperatorNotFound( - f"JAX does not support interpolation mode: {mode}. Supported modes are: {supported_methods}" + f"JAX does not support interpolation mode: {mode}. Supported modes are: {supported_methods}" ) @register_function(torch.Tensor.repeat_interleave) -def torch_Tensor_repeat_interleave(self, - repeats, - dim=None, - *, - output_size=None): +def torch_Tensor_repeat_interleave( + self, repeats, dim=None, *, output_size=None +): return jnp.repeat(self, repeats, axis=dim, total_repeat_length=output_size) diff --git a/torchax/torchax/ops/jtorchvision_nms.py b/torchax/torchax/ops/jtorchvision_nms.py index 57832b560b0..6a639ddd638 100644 --- a/torchax/torchax/ops/jtorchvision_nms.py +++ b/torchax/torchax/ops/jtorchvision_nms.py @@ -24,9 +24,11 @@ def _bbox_overlap(boxes, gt_boxes): iou: Intersection over union matrix of all input bounding boxes """ bb_y_min, bb_x_min, bb_y_max, bb_x_max = jnp.split( - ary=boxes, indices_or_sections=4, axis=2) + ary=boxes, indices_or_sections=4, axis=2 + ) gt_y_min, gt_x_min, gt_y_max, gt_x_max = jnp.split( - ary=gt_boxes, indices_or_sections=4, axis=2) + ary=gt_boxes, indices_or_sections=4, axis=2 + ) # Calculates the intersection area. i_xmin = jnp.maximum(bb_x_min, jnp.transpose(gt_x_min, [0, 2, 1])) @@ -51,10 +53,15 @@ def _self_suppression(in_args): iou, _, iou_sum = in_args batch_size = iou.shape[0] can_suppress_others = jnp.reshape( - jnp.max(iou, 1) <= 0.5, [batch_size, -1, 1]).astype(iou.dtype) - iou_suppressed = jnp.reshape( - (jnp.max(can_suppress_others * iou, 1) <= 0.5).astype( - iou.dtype), [batch_size, -1, 1]) * iou + jnp.max(iou, 1) <= 0.5, [batch_size, -1, 1] + ).astype(iou.dtype) + iou_suppressed = ( + jnp.reshape( + (jnp.max(can_suppress_others * iou, 1) <= 0.5).astype(iou.dtype), + [batch_size, -1, 1], + ) + * iou + ) iou_sum_new = jnp.sum(iou_suppressed, [1, 2]) return iou_suppressed, jnp.any(iou_sum - iou_sum_new > 0.5), iou_sum_new @@ -62,11 +69,18 @@ def _self_suppression(in_args): def _cross_suppression(in_args): boxes, box_slice, iou_threshold, inner_idx = in_args batch_size = boxes.shape[0] - new_slice = lax.dynamic_slice(boxes, [0, inner_idx * _NMS_TILE_SIZE, 0], - [batch_size, _NMS_TILE_SIZE, 4]) + new_slice = lax.dynamic_slice( + boxes, + [0, inner_idx * _NMS_TILE_SIZE, 0], + [batch_size, _NMS_TILE_SIZE, 4], + ) iou = _bbox_overlap(new_slice, box_slice) - ret_slice = jnp.expand_dims((jnp.all(iou < iou_threshold, [1])).astype( - box_slice.dtype), 2) * box_slice + ret_slice = ( + jnp.expand_dims( + (jnp.all(iou < iou_threshold, [1])).astype(box_slice.dtype), 2 + ) + * box_slice + ) return boxes, ret_slice, iou_threshold, inner_idx + 1 @@ -87,38 +101,47 @@ def _suppression_loop_body(in_args): batch_size = boxes.shape[0] # Iterates over tiles that can possibly suppress the current tile. 
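For orientation, a simplified, unbatched sketch of the pairwise-IoU test that `_bbox_overlap` performs on these tiles; the real function is batched and tiled, and the small epsilon is added here only for numerical safety rather than taken from the original:

import jax.numpy as jnp

def pairwise_iou(a, b):
  # a: (N, 4) and b: (M, 4) boxes laid out as [y_min, x_min, y_max, x_max].
  ay0, ax0, ay1, ax1 = jnp.split(a, 4, axis=1)  # each piece has shape (N, 1)
  by0, bx0, by1, bx1 = jnp.split(b, 4, axis=1)  # each piece has shape (M, 1)
  inter_h = jnp.maximum(jnp.minimum(ay1, by1.T) - jnp.maximum(ay0, by0.T), 0.0)
  inter_w = jnp.maximum(jnp.minimum(ax1, bx1.T) - jnp.maximum(ax0, bx0.T), 0.0)
  inter = inter_h * inter_w                     # (N, M) intersection areas
  area_a = (ay1 - ay0) * (ax1 - ax0)            # (N, 1)
  area_b = (by1 - by0) * (bx1 - bx0)            # (M, 1)
  return inter / (area_a + area_b.T - inter + 1e-9)

boxes = jnp.array([[0.0, 0.0, 2.0, 2.0],
                   [1.0, 1.0, 3.0, 3.0],
                   [10.0, 10.0, 12.0, 12.0]])
iou = pairwise_iou(boxes, boxes)  # ~1.0 on the diagonal, ~0.14 for the overlapping pair, 0.0 elsewhere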
- box_slice = lax.dynamic_slice(boxes, [0, idx * _NMS_TILE_SIZE, 0], - [batch_size, _NMS_TILE_SIZE, 4]) + box_slice = lax.dynamic_slice( + boxes, [0, idx * _NMS_TILE_SIZE, 0], [batch_size, _NMS_TILE_SIZE, 4] + ) def _loop_cond(in_args): _, _, _, inner_idx = in_args return inner_idx < idx - _, box_slice, _, _ = lax.while_loop(_loop_cond, _cross_suppression, - (boxes, box_slice, iou_threshold, 0)) + _, box_slice, _, _ = lax.while_loop( + _loop_cond, _cross_suppression, (boxes, box_slice, iou_threshold, 0) + ) # Iterates over the current tile to compute self-suppression. iou = _bbox_overlap(box_slice, box_slice) mask = jnp.expand_dims( - jnp.reshape(jnp.arange(_NMS_TILE_SIZE), [1, -1]) - > jnp.reshape(jnp.arange(_NMS_TILE_SIZE), [-1, 1]), 0) + jnp.reshape(jnp.arange(_NMS_TILE_SIZE), [1, -1]) + > jnp.reshape(jnp.arange(_NMS_TILE_SIZE), [-1, 1]), + 0, + ) iou *= (jnp.logical_and(mask, iou >= iou_threshold)).astype(iou.dtype) def _loop_cond2(in_args): _, loop_condition, _ = in_args return loop_condition - suppressed_iou, _, _ = lax.while_loop(_loop_cond2, _self_suppression, - (iou, True, jnp.sum(iou, [1, 2]))) + suppressed_iou, _, _ = lax.while_loop( + _loop_cond2, _self_suppression, (iou, True, jnp.sum(iou, [1, 2])) + ) suppressed_box = jnp.sum(suppressed_iou, 1) > 0 box_slice *= jnp.expand_dims(1.0 - suppressed_box.astype(box_slice.dtype), 2) # Uses box_slice to update the input boxes. - mask = jnp.reshape((jnp.equal(jnp.arange(num_tiles), - idx)).astype(boxes.dtype), [1, -1, 1, 1]) - boxes = jnp.tile(jnp.expand_dims( - box_slice, 1), [1, num_tiles, 1, 1]) * mask + jnp.reshape( - boxes, [batch_size, num_tiles, _NMS_TILE_SIZE, 4]) * (1 - mask) + mask = jnp.reshape( + (jnp.equal(jnp.arange(num_tiles), idx)).astype(boxes.dtype), + [1, -1, 1, 1], + ) + boxes = jnp.tile( + jnp.expand_dims(box_slice, 1), [1, num_tiles, 1, 1] + ) * mask + jnp.reshape(boxes, [batch_size, num_tiles, _NMS_TILE_SIZE, 4]) * ( + 1 - mask + ) boxes = jnp.reshape(boxes, [batch_size, -1, 4]) # Updates output_size. 
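The mask-and-blend update above is easier to follow with concrete shapes; here is a toy sketch with made-up sizes (the real code uses `_NMS_TILE_SIZE`-wide tiles and `jnp.tile` with a reshaped mask, whereas this version relies on broadcasting for brevity):

import jax.numpy as jnp

batch, num_tiles, tile, idx = 1, 3, 4, 1       # toy sizes; `idx` is the tile being rewritten
boxes = jnp.ones((batch, num_tiles, tile, 4))  # stand-in for the reshaped box tensor
box_slice = jnp.zeros((batch, tile, 4))        # the freshly suppressed tile

# One-hot selector over the tile axis, playing the role of `mask` above.
one_hot = (jnp.arange(num_tiles) == idx).astype(boxes.dtype).reshape(1, -1, 1, 1)
boxes = jnp.expand_dims(box_slice, 1) * one_hot + boxes * (1 - one_hot)
# Only tile `idx` was replaced; every other tile is untouched.
assert bool((boxes[:, idx] == 0).all()) and bool((boxes[:, 0] == 1).all())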
@@ -179,8 +202,10 @@ def non_max_suppression_padded(scores, boxes, max_output_size, iou_threshold): """ batch_size = boxes.shape[0] num_boxes = boxes.shape[1] - pad = int(jnp.ceil( - float(num_boxes) / _NMS_TILE_SIZE)) * _NMS_TILE_SIZE - num_boxes + pad = ( + int(jnp.ceil(float(num_boxes) / _NMS_TILE_SIZE)) * _NMS_TILE_SIZE + - num_boxes + ) boxes = jnp.pad(boxes.astype(jnp.float32), [[0, 0], [0, pad], [0, 0]]) scores = jnp.pad(scores.astype(jnp.float32), [[0, 0], [0, pad]]) num_boxes += pad @@ -188,29 +213,40 @@ def non_max_suppression_padded(scores, boxes, max_output_size, iou_threshold): def _loop_cond(in_args): unused_boxes, unused_threshold, output_size, idx = in_args return jnp.logical_and( - jnp.min(output_size) < max_output_size, idx - < num_boxes // _NMS_TILE_SIZE) + jnp.min(output_size) < max_output_size, + idx < num_boxes // _NMS_TILE_SIZE, + ) selected_boxes, _, output_size, _ = lax.while_loop( - _loop_cond, _suppression_loop_body, - (boxes, iou_threshold, jnp.zeros([batch_size], jnp.int32), 0)) + _loop_cond, + _suppression_loop_body, + (boxes, iou_threshold, jnp.zeros([batch_size], jnp.int32), 0), + ) idx = num_boxes - lax.top_k( - jnp.any(selected_boxes > 0, [2]).astype(jnp.int32) * - jnp.expand_dims(jnp.arange(num_boxes, 0, -1), 0), - max_output_size)[0].astype(jnp.int32) + jnp.any(selected_boxes > 0, [2]).astype(jnp.int32) + * jnp.expand_dims(jnp.arange(num_boxes, 0, -1), 0), + max_output_size, + )[0].astype(jnp.int32) idx = jnp.minimum(idx, num_boxes - 1) idx = jnp.reshape( - idx + jnp.reshape(jnp.arange(batch_size) * num_boxes, [-1, 1]), [-1]) + idx + jnp.reshape(jnp.arange(batch_size) * num_boxes, [-1, 1]), [-1] + ) return idx - boxes = jnp.reshape((jnp.reshape(boxes, [-1, 4]))[idx], - [batch_size, max_output_size, 4]) - boxes = boxes * (jnp.reshape(jnp.arange(max_output_size), [1, -1, 1]) - < jnp.reshape(output_size, [-1, 1, 1])).astype(boxes.dtype) + boxes = jnp.reshape( + (jnp.reshape(boxes, [-1, 4]))[idx], [batch_size, max_output_size, 4] + ) + boxes = boxes * ( + jnp.reshape(jnp.arange(max_output_size), [1, -1, 1]) + < jnp.reshape(output_size, [-1, 1, 1]) + ).astype(boxes.dtype) scores = jnp.reshape( - jnp.reshape(scores, [-1, 1])[idx], [batch_size, max_output_size]) - scores = scores * (jnp.reshape(jnp.arange(max_output_size), [1, -1]) - < jnp.reshape(output_size, [-1, 1])).astype(scores.dtype) + jnp.reshape(scores, [-1, 1])[idx], [batch_size, max_output_size] + ) + scores = scores * ( + jnp.reshape(jnp.arange(max_output_size), [1, -1]) + < jnp.reshape(output_size, [-1, 1]) + ).astype(scores.dtype) return scores, boxes @@ -221,14 +257,16 @@ def nms(boxes, scores, iou_threshold): max_output_size = boxes.shape[0] boxes = boxes.reshape((1, *boxes.shape)) scores = scores.reshape((1, *scores.shape)) - res = non_max_suppression_padded(scores, boxes, max_output_size, - iou_threshold) + res = non_max_suppression_padded( + scores, boxes, max_output_size, iou_threshold + ) return res try: import torch import torchvision + ops_registry.register_torch_dispatch_op(torch.ops.torchvision.nms, nms) except Exception: pass diff --git a/torchax/torchax/ops/mappings.py b/torchax/torchax/ops/mappings.py index 409a6d8350b..c5dc29cc854 100644 --- a/torchax/torchax/ops/mappings.py +++ b/torchax/torchax/ops/mappings.py @@ -29,8 +29,9 @@ def t2j(t, use_dlpack=True): # https://github.com/google/jax/issues/7657 # https://github.com/google/jax/issues/17784 if t.dtype == torch.bfloat16: - nparray = (t.cpu().detach().to(torch.float32).numpy() - ) # numpy don't support bfloat16 + nparray = ( + 
t.cpu().detach().to(torch.float32).numpy() + ) # numpy don't support bfloat16 else: nparray = t.cpu().detach().numpy() res = jnp.asarray(nparray) @@ -69,71 +70,52 @@ def j2t(x, use_dlpack=True): TORCH_DTYPE_TO_JAX = { - # NO_MAPPING : jnp.float0.dtype (signless scalar int), - torch.bool: - jnp.bool_.dtype, - # NO_MAPPING : jnp.int4.dtype, - torch.int8: - jnp.int8.dtype, - torch.int16: - jnp.int16.dtype, - torch.int32: - jnp.int32.dtype, - torch.int64: - jnp.int64.dtype, - torch.long: - jnp.int64.dtype, - # NO_MAPPING : jnp.uint4 - torch.uint8: - jnp.uint8.dtype, - torch.uint16: - jnp.uint16.dtype, - torch.uint32: - jnp.uint32.dtype, - torch.uint64: - jnp.uint64.dtype, - # NO_MAPPING : jnp.float8_e4m3b11fnuz.dtype, - torch.float8_e4m3fn: - jnp.float8_e4m3fn.dtype, - # NO_MAPPING : jnp.float8_e4m3fnuz.dtype, - torch.float8_e5m2: - jnp.float8_e5m2.dtype, - # NO_MAPPING : jnp.float8_e5m2fnuz.dtype, - torch.bfloat16: - jnp.bfloat16.dtype, - torch.half: - jnp.float16.dtype, - torch.float16: - jnp.float16.dtype, - torch.float32: - jnp.float32.dtype, - torch.float64: - jnp.float64.dtype, - torch.double: - jnp.double.dtype, - torch.complex64: - jnp.complex64.dtype, - torch.complex128: - jnp.complex128.dtype, - None: - None, + # NO_MAPPING : jnp.float0.dtype (signless scalar int), + torch.bool: jnp.bool_.dtype, + # NO_MAPPING : jnp.int4.dtype, + torch.int8: jnp.int8.dtype, + torch.int16: jnp.int16.dtype, + torch.int32: jnp.int32.dtype, + torch.int64: jnp.int64.dtype, + torch.long: jnp.int64.dtype, + # NO_MAPPING : jnp.uint4 + torch.uint8: jnp.uint8.dtype, + torch.uint16: jnp.uint16.dtype, + torch.uint32: jnp.uint32.dtype, + torch.uint64: jnp.uint64.dtype, + # NO_MAPPING : jnp.float8_e4m3b11fnuz.dtype, + torch.float8_e4m3fn: jnp.float8_e4m3fn.dtype, + # NO_MAPPING : jnp.float8_e4m3fnuz.dtype, + torch.float8_e5m2: jnp.float8_e5m2.dtype, + # NO_MAPPING : jnp.float8_e5m2fnuz.dtype, + torch.bfloat16: jnp.bfloat16.dtype, + torch.half: jnp.float16.dtype, + torch.float16: jnp.float16.dtype, + torch.float32: jnp.float32.dtype, + torch.float64: jnp.float64.dtype, + torch.double: jnp.double.dtype, + torch.complex64: jnp.complex64.dtype, + torch.complex128: jnp.complex128.dtype, + None: None, } JAX_DTYPE_TO_TORCH = {value: key for key, value in TORCH_DTYPE_TO_JAX.items()} # Add imprecise mappings for some JAX dtypes which don't have torch analogues -JAX_DTYPE_TO_TORCH[jnp.dtype('int4')] = torch.int8 -JAX_DTYPE_TO_TORCH[jnp.dtype('uint4')] = torch.uint8 +JAX_DTYPE_TO_TORCH[jnp.dtype("int4")] = torch.int8 +JAX_DTYPE_TO_TORCH[jnp.dtype("uint4")] = torch.uint8 def t2j_dtype(dtype): if dtype not in TORCH_DTYPE_TO_JAX: raise RuntimeError( - f'Attempting to convert unknown type: {dtype} to jax type,') + f"Attempting to convert unknown type: {dtype} to jax type," + ) return TORCH_DTYPE_TO_JAX[dtype] def j2t_dtype(dtype): if dtype not in JAX_DTYPE_TO_TORCH: raise RuntimeError( - f'Attempting to convert unknown type: {dtype} to torch type,') + f"Attempting to convert unknown type: {dtype} to torch type," + ) return JAX_DTYPE_TO_TORCH[dtype] diff --git a/torchax/torchax/ops/op_base.py b/torchax/torchax/ops/op_base.py index d69e85ae50a..78b63771d37 100644 --- a/torchax/torchax/ops/op_base.py +++ b/torchax/torchax/ops/op_base.py @@ -12,12 +12,13 @@ class InplaceOp: - - def __init__(self, - functional_op, - replace=False, - position_to_mutate=0, - is_jax_func=False): + def __init__( + self, + functional_op, + replace=False, + position_to_mutate=0, + is_jax_func=False, + ): self.functional = functional_op self.replace = 
replace self.position_to_mutate = position_to_mutate @@ -51,15 +52,14 @@ def __call__(self, *args, **kwargs): class OutVariant: - def __call__(self, *args, **kwargs): - to_mutate = kwargs['out'] - del kwargs['out'] + to_mutate = kwargs["out"] + del kwargs["out"] to_mutate._elem = self.functional(*args, **kwargs)._elem return to_mutate -P = ParamSpec('P') +P = ParamSpec("P") def convert_dtype(use_default_dtype: bool = True): @@ -73,11 +73,12 @@ def convert_dtype(use_default_dtype: bool = True): """ def decorator(func: types.TorchCallable): - @functools.wraps(func) - def wrapper(*args: P.args, - dtype: Optional[torch.dtype] = None, - **kwargs: P.kwargs): + def wrapper( + *args: P.args, + dtype: Optional[torch.dtype] = None, + **kwargs: P.kwargs, + ): if not dtype and use_default_dtype: dtype = torch.get_default_dtype() if isinstance(dtype, torch.dtype): @@ -92,8 +93,9 @@ def wrapper(*args: P.args, return decorator -def maybe_convert_constant_dtype(val: Optional[types.JaxValue], - dtype: Optional[jnp.dtype]): +def maybe_convert_constant_dtype( + val: Optional[types.JaxValue], dtype: Optional[jnp.dtype] +): """Optionally converts scalar constant's dtype using `numpy` Use in cases where you require a constant and can't handle a traced array. @@ -120,12 +122,15 @@ def wrapper(x: jax.Array, *args: P.args, **kwargs: P.kwargs): return wrapper -def foreach_loop(seq: jax.Array, - fn: Callable[[jax.Array, jax.Array], jax.Array], - init_val=0.0): +def foreach_loop( + seq: jax.Array, + fn: Callable[[jax.Array, jax.Array], jax.Array], + init_val=0.0, +): """Run `fn` for each element of 1D array `seq`. Similar to `functools.reduce`, but implemented with `jax.lax.fori_loop`.""" assert len(seq.shape) == 1 - return jax.lax.fori_loop(0, len(seq), lambda i, carry: fn(carry, seq[i]), - init_val) + return jax.lax.fori_loop( + 0, len(seq), lambda i, carry: fn(carry, seq[i]), init_val + ) diff --git a/torchax/torchax/ops/ops_registry.py b/torchax/torchax/ops/ops_registry.py index aa0d61cbb49..7ba29cdcbd3 100644 --- a/torchax/torchax/ops/ops_registry.py +++ b/torchax/torchax/ops/ops_registry.py @@ -19,37 +19,43 @@ class Operator: all_torch_functions: Dict[TorchCallable, Operator] = {} -def register_torch_dispatch_op(aten_op, - impl_callable, - is_jax_function=True, - is_user_defined=False, - needs_env=False, - is_view_op=False): +def register_torch_dispatch_op( + aten_op, + impl_callable, + is_jax_function=True, + is_user_defined=False, + needs_env=False, + is_view_op=False, +): op = Operator( - aten_op, - impl_callable, - is_jax_function=is_jax_function, - is_user_defined=is_user_defined, - needs_env=needs_env, - is_view_op=is_view_op) + aten_op, + impl_callable, + is_jax_function=is_jax_function, + is_user_defined=is_user_defined, + needs_env=needs_env, + is_view_op=is_view_op, + ) if aten_op in all_aten_ops: - logging.warning(f'Duplicate op registration for {aten_op}') + logging.warning(f"Duplicate op registration for {aten_op}") all_aten_ops[aten_op] = op return impl_callable -def register_torch_function_op(torch_func, - impl_callable, - is_jax_function=True, - is_user_defined=False, - needs_env=False, - is_view_op=False): +def register_torch_function_op( + torch_func, + impl_callable, + is_jax_function=True, + is_user_defined=False, + needs_env=False, + is_view_op=False, +): op = Operator( - torch_func, - impl_callable, - is_jax_function=is_jax_function, - is_user_defined=is_user_defined, - needs_env=needs_env, - is_view_op=is_view_op) + torch_func, + impl_callable, + is_jax_function=is_jax_function, + 
is_user_defined=is_user_defined, + needs_env=needs_env, + is_view_op=is_view_op, + ) all_torch_functions[torch_func] = op return impl_callable diff --git a/torchax/torchax/tensor.py b/torchax/torchax/tensor.py index afca83ac200..1038d05f425 100644 --- a/torchax/torchax/tensor.py +++ b/torchax/torchax/tensor.py @@ -46,7 +46,6 @@ def log_nested(env, message): class Tensor(torch.Tensor): - @staticmethod def __new__(cls, elem, env): dtype = mappings.j2t_dtype(elem.dtype) @@ -57,11 +56,11 @@ def __new__(cls, elem, env): if dtype is None: dtype = torch.float32 return torch.Tensor._make_wrapper_subclass( - cls, - shape, - dtype=dtype, - device="meta", - requires_grad=False, + cls, + shape, + dtype=dtype, + device="meta", + requires_grad=False, ) def __init__(self, elem: jax.Array, env: "Environment"): @@ -89,7 +88,8 @@ def flatten(self, start_dim=0, end_dim=-1): if end_dim == -1: end_dim = self.ndim new_shape = ( - self._elem.shape[:start_dim] + (-1,) + self._elem.shape[end_dim + 1:]) + self._elem.shape[:start_dim] + (-1,) + self._elem.shape[end_dim + 1 :] + ) new_elem = jnp.reshape(self._elem, new_shape) return Tensor(new_elem, self._env) # return torch.reshape(self, new_shape) @@ -110,9 +110,10 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None): if func == torch.ops._c10d_functional.wait_tensor.default: return args[0]._env.dispatch(func, types, args, kwargs) raise AssertionError( - 'torchax Tensors can only do math within the torchax environment.' - 'Please wrap your code with `with torchax.default_env()` or ' - 'call torchax.enable_globally() before.') + "torchax Tensors can only do math within the torchax environment." + "Please wrap your code with `with torchax.default_env()` or " + "call torchax.enable_globally() before." + ) def detach(self): return Tensor(jax.lax.stop_gradient(self.jax()), self._env) @@ -172,7 +173,8 @@ def shard_(self, sharding): def debug_accuracy(func, args, kwargs, current_output): args_torch, kwargs_torch, out_torch = torch_pytree.tree_map_only( - torch.Tensor, lambda x: x.torch(), (args, kwargs, current_output)) + torch.Tensor, lambda x: x.torch(), (args, kwargs, current_output) + ) with mode_utils.no_dispatch(), torch._C.DisableTorchFunction(): if "device" in kwargs_torch: @@ -187,7 +189,8 @@ def debug_accuracy(func, args, kwargs, current_output): ex = ex.to(real.dtype) try: if isinstance(ex, torch.Tensor) and not torch.allclose( - ex, real, atol=1e-3, equal_nan=True): + ex, real, atol=1e-3, equal_nan=True + ): import pdb pdb.set_trace() @@ -200,7 +203,6 @@ def debug_accuracy(func, args, kwargs, current_output): def _make_debug_msg(is_dispatch, log_args, func, args, kwargs): - def _display(a): if isinstance(a, torch.Tensor): return f"Tensor of {type(a)}: {a.dtype}{a.shape}" @@ -212,9 +214,11 @@ def _display(a): kwargs = kwargs or {} title = "DISPATCH" if is_dispatch else "FUNCTION" args_msg = "args: " + ",".join(_display(a) for a in args) if log_args else "" - kwargs_msg = ("kwargs: " + - ",".join(f"{key}: {_display(a)}" for key, a in kwargs.items()) - if log_args else "") + kwargs_msg = ( + "kwargs: " + ",".join(f"{key}: {_display(a)}" for key, a in kwargs.items()) + if log_args + else "" + ) return f"{title}: {_name_of_func(func)} {args_msg} ~ {kwargs_msg}" @@ -224,24 +228,27 @@ class XLAFunctionMode(torch.overrides.TorchFunctionMode): def __init__(self, env): self.env = env - def __torch_function__(self, - func, - types, - args=(), - kwargs=None) -> torch.Tensor: + def __torch_function__( + self, func, types, args=(), kwargs=None + ) -> 
torch.Tensor: message = f"FUNCTION: {_name_of_func(func)}" if self.env.config.debug_print_each_op_operands: message = message + "f" - message = _make_debug_msg(False, - self.env.config.debug_print_each_op_operands, - func, args, kwargs) + message = _make_debug_msg( + False, + self.env.config.debug_print_each_op_operands, + func, + args, + kwargs, + ) with log_nested(self.env, message): try: return self.env.dispatch(func, types, args, kwargs) except OperatorNotFound: pass if _name_of_func(func) in ( - "rot90"): # skip rot90 with k%4==0 due to no change + "rot90" + ): # skip rot90 with k%4==0 due to no change if len(args) >= 2 and type(args[1]) == int: if (args[1]) % 4 == 0: return args[0] @@ -249,24 +256,27 @@ def __torch_function__(self, class XLADispatchMode(torch_dispatch.TorchDispatchMode): - def __init__(self, env): self.env = env def __torch_dispatch__(self, func, types, args=(), kwargs=None): - message = _make_debug_msg(True, - self.env.config.debug_print_each_op_operands, - func, args, kwargs) + message = _make_debug_msg( + True, + self.env.config.debug_print_each_op_operands, + func, + args, + kwargs, + ) with log_nested(self.env, message): if isinstance(func, torch._ops.OpOverloadPacket): with self: return func(*args, **kwargs) # Only functions under these namespaces will be intercepted if func.namespace not in ( - "aten", - "_c10d_functional", - "torchvision", - "xla", + "aten", + "_c10d_functional", + "torchvision", + "xla", ): return func(*args, **kwargs) return self.env.dispatch(func, types, args, kwargs) @@ -280,18 +290,18 @@ def _name_of_func(func): # Constructors that don't take other tensor as input TENSOR_CONSTRUCTORS = { - torch.ones, - torch.zeros, - torch.empty, - torch.empty_strided, - torch.tensor, - torch.arange, - torch.eye, - torch.randn, - torch.rand, - torch.randint, - torch.full, - torch.as_tensor, + torch.ones, + torch.zeros, + torch.empty, + torch.empty_strided, + torch.tensor, + torch.arange, + torch.eye, + torch.randn, + torch.rand, + torch.randint, + torch.full, + torch.as_tensor, } # TODO(wen): use existing types, either from torch or jax @@ -328,7 +338,8 @@ def __init__(self, configuration=None): self.enabled = False self._prng_key = mutable_array( - jax.random.key(torch.initial_seed() % (1 << 63))) + jax.random.key(torch.initial_seed() % (1 << 63)) + ) self.autocast_dtype = None self._target_device = "cpu" @@ -355,7 +366,8 @@ def get_as_jax_device(self, device: Any): device = str(device) if not self.config.use_torch_native_for_cpu_tensor and device.startswith( - "cpu"): + "cpu" + ): return jax.devices("cpu")[0] if self.config.treat_cuda_as_jax_device and device.startswith("cuda"): @@ -381,8 +393,10 @@ def get_as_jax_device(self, device: Any): def load_ops(self): from torchax.ops import jaten, jtorch, jc10d, jtorchvision_nms - for k, v in itertools.chain(ops_registry.all_aten_ops.items(), - ops_registry.all_torch_functions.items()): + for k, v in itertools.chain( + ops_registry.all_aten_ops.items(), + ops_registry.all_torch_functions.items(), + ): if v.is_jax_function: self._ops[k] = v else: @@ -393,16 +407,15 @@ def load_ops(self): for k, v in DECOMPOSITIONS.items(): if k not in self._decomps: self._decomps[k] = ops_registry.Operator( - k, - v, - is_jax_function=False, - is_user_defined=False, - needs_env=False, - is_view_op=k in MUTABLE_DECOMPOSITION, + k, + v, + is_jax_function=False, + is_user_defined=False, + needs_env=False, + is_view_op=k in MUTABLE_DECOMPOSITION, ) def _get_op_or_decomp(self, func): - def _get_from_dict(op_dict, op): op = 
op_dict.get(func) if op is None and isinstance(func, torch._ops.OpOverloadPacket): @@ -419,7 +432,8 @@ def _get_from_dict(op_dict, op): if op is None: raise OperatorNotFound( - f"Operator with name {_name_of_func(func)} has no lowering") + f"Operator with name {_name_of_func(func)} has no lowering" + ) return op @@ -470,8 +484,9 @@ def _to_copy(self, the_tensor, new_dtype, new_device): return Tensor(arr, self) - def get_and_rotate_prng_key(self, - generator: Optional[torch.Generator] = None): + def get_and_rotate_prng_key( + self, generator: Optional[torch.Generator] = None + ): if generator is not None: with mode_utils.no_dispatch(), torch._C.DisableTorchFunction(): self._prng_key[...] = jax.random.key(generator.initial_seed() % (2**63)) @@ -522,10 +537,10 @@ def dispatch(self, func, types, args, kwargs): if func in TENSOR_CONSTRUCTORS: return self._handle_tensor_constructor(func, args, kwargs) if func in ( - torch.Tensor.to, - torch.ops.aten.lift_fresh.default, - torch.ops.aten._to_copy, - torch.ops.aten._to_copy.default, + torch.Tensor.to, + torch.ops.aten.lift_fresh.default, + torch.ops.aten._to_copy, + torch.ops.aten._to_copy.default, ): return self._torch_Tensor_to(args, kwargs) @@ -533,8 +548,9 @@ def dispatch(self, func, types, args, kwargs): # We should skip and let torch handle it. tensor_args = [ - t for t in torch_pytree.tree_flatten(args)[0] - if isinstance(t, torch.Tensor) + t + for t in torch_pytree.tree_flatten(args)[0] + if isinstance(t, torch.Tensor) ] def is_not_torchax_tensor(x): @@ -550,9 +566,9 @@ def is_not_torchax_tensor(x): old_args, old_kwargs = args, kwargs with self._dispatch_mode: args, kwargs = torch_pytree.tree_map_only( - torch.distributed._functional_collectives.AsyncCollectiveTensor, - torch.distributed._functional_collectives.wait_tensor, - (args, kwargs), + torch.distributed._functional_collectives.AsyncCollectiveTensor, + torch.distributed._functional_collectives.wait_tensor, + (args, kwargs), ) try: @@ -563,8 +579,12 @@ def is_not_torchax_tensor(x): if self.autocast_dtype is not None: autocast_policy = amp.autocast_policy.get(func) if autocast_policy is not None: - args, kwargs = amp.execute_policy(autocast_policy, args, kwargs, - self.autocast_dtype) + args, kwargs = amp.execute_policy( + autocast_policy, + args, + kwargs, + self.autocast_dtype, + ) if op.is_jax_function: args, kwargs = self.t2j_iso((args, kwargs)) @@ -633,17 +653,19 @@ def to_xla(self, torchvalues): def t2j_iso(self, torchtensors): """Convert torchax Tensor to jax array. - + This function will not copy, will just unwrap the inner jax array out. Note: iso is short for "isomorphic" """ def to_jax(x): if isinstance( - x, torch.distributed._functional_collectives.AsyncCollectiveTensor): + x, + torch.distributed._functional_collectives.AsyncCollectiveTensor, + ): x = x.wait() assert isinstance(x, Tensor) or isinstance(x, View), ( - f"Expect a Tensor or a View but got {type(x)}; usually this means there is a mixed math between XLATensor and torch.Tensor" + f"Expect a Tensor or a View but got {type(x)}; usually this means there is a mixed math between XLATensor and torch.Tensor" ) return x.jax() @@ -651,7 +673,6 @@ def to_jax(x): return res def v2t_iso(self, views): - def to_tensor(x): if isinstance(x, View): return x.torch() @@ -662,38 +683,41 @@ def to_tensor(x): def j2t_iso(self, jaxarray): """Convert jax array to torchax Tensor. 
- + This function will not copy, will just wrap the jax array with a torchax Tensor Note: iso is short for "isomorphic" """ - return torch_pytree.tree_map_only(jax.Array, lambda x: Tensor(x, self), - jaxarray) + return torch_pytree.tree_map_only( + jax.Array, lambda x: Tensor(x, self), jaxarray + ) def j2t_copy(self, args): """Convert torch.Tensor in cpu to a jax array - + This might involves copying the data (depending if dlpack is enabled) """ return torch_pytree.tree_map_only( - jax.Array, - lambda x: mappings.j2t(x, self.config.use_dlpack_for_data_conversion), - args) + jax.Array, + lambda x: mappings.j2t(x, self.config.use_dlpack_for_data_conversion), + args, + ) def t2j_copy(self, args): """Convert jax array to torch.Tensor in cpu. - + This might involves copying the data (depending if dlpack is enabled) """ return torch_pytree.tree_map_only( - torch.Tensor, - lambda x: mappings.t2j(x, self.config.use_dlpack_for_data_conversion), - args) + torch.Tensor, + lambda x: mappings.t2j(x, self.config.use_dlpack_for_data_conversion), + args, + ) def override_op_definition(self, op_to_override, op_impl): self._ops[op_to_override] = ops_registry.Operator( - op_to_override, - op_impl, - is_jax_function=False, - is_user_defined=True, - needs_env=False, + op_to_override, + op_impl, + is_jax_function=False, + is_user_defined=True, + needs_env=False, ) diff --git a/torchax/torchax/tf_integration.py b/torchax/torchax/tf_integration.py index c9842089bfc..2542df31858 100644 --- a/torchax/torchax/tf_integration.py +++ b/torchax/torchax/tf_integration.py @@ -13,34 +13,35 @@ def exported_program_to_tf_function(ep, enable_xla=True): wrapped = lambda *args: jax_program(weights, (args,)) avals = export.extract_avals(ep) input_signature = [ - tf.TensorSpec(shape=t.shape, dtype=t.dtype, name=f"args_{i}") - for i, t in enumerate(avals) + tf.TensorSpec(shape=t.shape, dtype=t.dtype, name=f"args_{i}") + for i, t in enumerate(avals) ] tf_f = tf.function( - jax2tf.convert( - wrapped, - with_gradient=False, - enable_xla=enable_xla, - ), - autograph=False, - input_signature=input_signature, + jax2tf.convert( + wrapped, + with_gradient=False, + enable_xla=enable_xla, + ), + autograph=False, + input_signature=input_signature, ) return tf_f -def exported_program_to_tf_module(ep: torch.export.ExportedProgram, - enable_xla=True) -> tf.Module: +def exported_program_to_tf_module( + ep: torch.export.ExportedProgram, enable_xla=True +) -> tf.Module: tfm = tf.Module() tfm.f = exported_program_to_tf_function(ep, enable_xla) return tfm def save_exported_program_as_tf_saved_model( - ep: torch.export.ExportedProgram, - saved_model_dir: os.PathLike, - serving_key: str = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY, - function_alias: str = "", - enable_xla=True, + ep: torch.export.ExportedProgram, + saved_model_dir: os.PathLike, + serving_key: str = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY, + function_alias: str = "", + enable_xla=True, ): """This function will export and save a pytorch ExportedProgram to tf.saved_model format. 
@@ -61,26 +62,28 @@ def save_exported_program_as_tf_saved_model( """ tfm = exported_program_to_tf_module(ep, enable_xla=enable_xla) signatures = { - serving_key: tfm.f.get_concrete_function(*tfm.f.input_signature) + serving_key: tfm.f.get_concrete_function(*tfm.f.input_signature) } - save_options = tf.saved_model.SaveOptions(function_aliases={ + save_options = tf.saved_model.SaveOptions( + function_aliases={ function_alias: tfm.f, - }) + } + ) tf.saved_model.save( - tfm, - saved_model_dir, - signatures=signatures, - options=save_options, + tfm, + saved_model_dir, + signatures=signatures, + options=save_options, ) def save_torch_module_as_tf_saved_model( - torch_model: torch.nn.Module, - args: Tuple[Any], - saved_model_dir: os.PathLike, - serving_key: str = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY, - function_alias: str = "", - enable_xla=True, + torch_model: torch.nn.Module, + args: Tuple[Any], + saved_model_dir: os.PathLike, + serving_key: str = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY, + function_alias: str = "", + enable_xla=True, ): """This function will export and save a pytorch nn.Module to tf.saved_model format. @@ -100,20 +103,23 @@ def save_torch_module_as_tf_saved_model( function for inference converter or other tools. """ ep = torch.export.export(torch_model, args) - save_exported_program_as_tf_saved_model(ep, saved_model_dir, serving_key, - function_alias, enable_xla) + save_exported_program_as_tf_saved_model( + ep, saved_model_dir, serving_key, function_alias, enable_xla + ) def exported_program_to_tflite_flatbuffer(ep: torch.export.ExportedProgram): tfm = exported_program_to_tf_module(ep) tf_concrete_func = tfm.f.get_concrete_function(*tfm.f.input_signature) converter = tf.lite.TFLiteConverter.from_concrete_functions( - [tf_concrete_func], tfm) + [tf_concrete_func], tfm + ) tflite_model = converter.convert() return tflite_model -def torch_module_to_tflite_flatbuffer(torch_model: torch.nn.Module, - args: Tuple[Any]): +def torch_module_to_tflite_flatbuffer( + torch_model: torch.nn.Module, args: Tuple[Any] +): ep = torch.export.export(torch_model, args) return exported_program_to_tflite_flatbuffer(ep) diff --git a/torchax/torchax/train.py b/torchax/torchax/train.py index fb4e16fc48e..c1be6ab9901 100644 --- a/torchax/torchax/train.py +++ b/torchax/torchax/train.py @@ -31,21 +31,22 @@ def make_train_step(model_fn, loss_fn, optax_optimizer, remat_policy=None): env = torchax.default_env() def loss(weights, buffers, args, label): # inputs are XLATensor - with env, jax.named_scope('compute_loss'): + with env, jax.named_scope("compute_loss"): res = model_fn(weights, buffers, args) l = loss_fn(res, label) return l - loss = interop.gradient_checkpoint(loss, kwargs={'policy': remat_policy}) + loss = interop.gradient_checkpoint(loss, kwargs={"policy": remat_policy}) grad_fn = interop.jax_value_and_grad(loss) - def step(weights, buffers, opt_state, args, label): #inputs are array - with jax.named_scope('compute_gradient'): + def step(weights, buffers, opt_state, args, label): # inputs are array + with jax.named_scope("compute_gradient"): loss, gradient = grad_fn(weights, buffers, args, label) with jax.named_scope("optimizer_updates"): - updates, opt_state = interop.call_jax(optax_optimizer.update, gradient, - opt_state, weights) + updates, opt_state = interop.call_jax( + optax_optimizer.update, gradient, opt_state, weights + ) weights = interop.call_jax(optax.apply_updates, weights, updates) return loss, weights, opt_state @@ -58,7 +59,6 @@ class Container: class 
ScannedModule(torch.nn.Module): - def __init__(self, module_list, checkpoint_policy=None): super().__init__() @@ -71,7 +71,7 @@ def __init__(self, module_list, checkpoint_policy=None): weights = self._stack_layer_weights(module_list) self.layer_weights_keys = list(self.c.one_mod.state_dict().keys()) self.params = torch.nn.ParameterDict({ - self._param_name_new(k): v for k, v in weights.items() + self._param_name_new(k): v for k, v in weights.items() }) def _stack_layer_weights(self, module_list): @@ -86,15 +86,15 @@ def _stack_layer_weights(self, module_list): return res def _param_name_new(self, old): - return '___'.join(old.split('.')) + return "___".join(old.split(".")) def _param_name_old(self, new): - return '.'.join(new.split('___')) + return ".".join(new.split("___")) def forward(self, *args, **kwargs): assert not kwargs weights = { - k: self.params[self._param_name_new(k)] for k in self.layer_weights_keys + k: self.params[self._param_name_new(k)] for k in self.layer_weights_keys } scan = interop.torch_view(jax.lax.scan) @@ -106,12 +106,12 @@ def eval_one_layer(args, weight): return (newh, *rest), None _eval_one_layer = interop.gradient_checkpoint( - eval_one_layer, - kwargs={'policy': self.checkpoint_policy}, + eval_one_layer, + kwargs={"policy": self.checkpoint_policy}, ) h, _ = scan( - _eval_one_layer, - args, - weights, + _eval_one_layer, + args, + weights, ) return h[0] diff --git a/torchax/torchax/types.py b/torchax/torchax/types.py index 72a2f678c96..d61e1444eb4 100644 --- a/torchax/torchax/types.py +++ b/torchax/torchax/types.py @@ -4,9 +4,9 @@ import jax.numpy as jnp import sys -P = ParamSpec('P') +P = ParamSpec("P") -TorchValue: TypeAlias = Union[torch.Tensor, torch.dtype, 'TorchCallable', Any] +TorchValue: TypeAlias = Union[torch.Tensor, torch.dtype, "TorchCallable", Any] TorchCallable: TypeAlias = Callable[P, TorchValue] -JaxValue: TypeAlias = Union[jax.Array, jnp.dtype, 'JaxCallable', Any] -JaxCallable: TypeAlias = Callable[P, JaxValue] \ No newline at end of file +JaxValue: TypeAlias = Union[jax.Array, jnp.dtype, "JaxCallable", Any] +JaxCallable: TypeAlias = Callable[P, JaxValue] diff --git a/torchax/torchax/util.py b/torchax/torchax/util.py index e34f77119d6..4b4c4297dcd 100644 --- a/torchax/torchax/util.py +++ b/torchax/torchax/util.py @@ -1,8 +1,9 @@ from typing import Any, Callable -def partition(original: list[Any], - func: Callable[[Any], bool]) -> tuple[list[Any], list[Any]]: +def partition( + original: list[Any], func: Callable[[Any], bool] +) -> tuple[list[Any], list[Any]]: """Partitions elements into two parallel lists based on a predicate function. Iterates through the 'original' list, applying 'func' to each element 'a'. diff --git a/torchax/torchax/view.py b/torchax/torchax/view.py index 040fa24ef9e..0012273aa21 100644 --- a/torchax/torchax/view.py +++ b/torchax/torchax/view.py @@ -23,76 +23,77 @@ class ViewInfoType(Enum): class ViewInfo(ABC): """ - Abstract base class for all view operations. - Defines the interface for applying and updating view transformations. - """ + Abstract base class for all view operations. + Defines the interface for applying and updating view transformations. + """ def __init__( - self, - view_info_type: ViewInfoType = ViewInfoType.INVALID, + self, + view_info_type: ViewInfoType = ViewInfoType.INVALID, ): """ - Initialize a ViewInfo object. + Initialize a ViewInfo object. 
- Args: - view_info_type: The type of view operation - """ + Args: + view_info_type: The type of view operation + """ self.view_info_type = view_info_type @abstractmethod - def update_tensor(self, new_value: jax.Array, - jax_array: jax.Array) -> jax.Array: + def update_tensor( + self, new_value: jax.Array, jax_array: jax.Array + ) -> jax.Array: """ - Apply this view transformation to a JAX array and update its value. + Apply this view transformation to a JAX array and update its value. - Args: - new_value: The new values to set in the view - jax_array: The parent array to update + Args: + new_value: The new values to set in the view + jax_array: The parent array to update - Returns: - Updated array - """ + Returns: + Updated array + """ pass @abstractmethod def transform_tensor(self, jax_array: jax.Array) -> jax.Array: """ - Apply this view transformation to a JAX array. + Apply this view transformation to a JAX array. - Args: - jax_array: The array to transform + Args: + jax_array: The array to transform - Returns: - Transformed array - """ + Returns: + Transformed array + """ pass @abstractmethod def calculate_output_shape(self, source: jax.Array) -> List[int]: """ - Calculate the resulting shape after applying this view. + Calculate the resulting shape after applying this view. - Args: - source: Original jax array before transformation + Args: + source: Original jax array before transformation - Returns: - Resulting shape after transformation - """ + Returns: + Resulting shape after transformation + """ pass class NarrowInfo(ViewInfo): """ - Represents a slicing operation on a tensor. - Handles operations like tensor[1:3, :, 2:5:2]. - """ + Represents a slicing operation on a tensor. + Handles operations like tensor[1:3, :, 2:5:2]. + """ def __init__(self, slices: Union[slice, Tuple[slice]]) -> None: """ - Args: - slices: The slice(s) to apply to the tensor. - E.g. jax_array.at[slices] will return the transformed tensor. - """ + Args: + slices: The slice(s) to apply to the tensor. + E.g. jax_array.at[slices] will return the transformed tensor. + """ super().__init__(ViewInfoType.NARROW) self.slices = slices @@ -107,8 +108,9 @@ def transform_tensor(self, jax_array: jax.Array) -> jax.Array: except IndexError as e: raise IndexError("Invalid slice operation") from e - def update_tensor(self, new_value: jax.Array, - jax_array: jax.Array) -> jax.Array: + def update_tensor( + self, new_value: jax.Array, jax_array: jax.Array + ) -> jax.Array: return jax_array.at[self.slices].set(new_value) def calculate_output_shape(self, source: jax.Array) -> List[int]: @@ -117,15 +119,13 @@ def calculate_output_shape(self, source: jax.Array) -> List[int]: class SelectInfo(ViewInfo): """ - Represents a selection operation on a tensor. - Typically used for indexing operations that select specific elements. - """ + Represents a selection operation on a tensor. + Typically used for indexing operations that select specific elements. 
+ """ - def __init__(self, - dim: int = 0, - start: int = 0, - end: int = 0, - stride: int = 0) -> None: + def __init__( + self, dim: int = 0, start: int = 0, end: int = 0, stride: int = 0 + ) -> None: super().__init__(ViewInfoType.SELECT) self.dim: int = dim self.start: int = start @@ -135,25 +135,31 @@ def __init__(self, def __eq__(self, other: object) -> bool: if not isinstance(other, SelectInfo): return False - return (self.dim == other.dim and self.start == other.start and - self.end == other.end and self.stride == other.stride) + return ( + self.dim == other.dim + and self.start == other.start + and self.end == other.end + and self.stride == other.stride + ) def transform_tensor(self, jax_array: jax.Array) -> jax.Array: raise NotImplementedError("SelectInfo.apply not implemented") - def update_tensor(self, new_value: jax.Array, - jax_array: jax.Array) -> jax.Array: + def update_tensor( + self, new_value: jax.Array, jax_array: jax.Array + ) -> jax.Array: raise NotImplementedError("SelectInfo.update not implemented") def calculate_output_shape(self, source: jax.Array) -> List[int]: raise NotImplementedError( - "SelectInfo.calculate_output_shape not implemented") + "SelectInfo.calculate_output_shape not implemented" + ) class AsStridedInfo(ViewInfo): """ - Information for as_strided operations. - """ + Information for as_strided operations. + """ def __init__(self, stride: List[int], offset: int = 0) -> None: super().__init__(ViewInfoType.AS_STRIDED) @@ -168,28 +174,30 @@ def __eq__(self, other: object) -> bool: def transform_tensor(self, jax_array: jax.Array) -> jax.Array: raise NotImplementedError("AsStridedInfo.apply not implemented") - def update_tensor(self, new_value: jax.Array, - jax_array: jax.Array) -> jax.Array: + def update_tensor( + self, new_value: jax.Array, jax_array: jax.Array + ) -> jax.Array: raise NotImplementedError("AsStridedInfo.update not implemented") def calculate_output_shape(self, source: jax.Array) -> List[int]: raise NotImplementedError( - "AsStridedInfo.calculate_output_shape not implemented") + "AsStridedInfo.calculate_output_shape not implemented" + ) class DiagonalInfo(ViewInfo): """ - Information for diagonal operations. - Extracts diagonal elements from a tensor. - """ + Information for diagonal operations. + Extracts diagonal elements from a tensor. 
+ """ def __init__(self, offset: int = 0, dim1: int = 0, dim2: int = 1) -> None: """ - Args: - offset: Offset from the main diagonal - dim1: First dimension for diagonal extraction - dim2: Second dimension for diagonal extraction - """ + Args: + offset: Offset from the main diagonal + dim1: First dimension for diagonal extraction + dim2: Second dimension for diagonal extraction + """ super().__init__(ViewInfoType.DIAGONAL) self.offset: int = offset self.dim1: int = dim1 @@ -198,47 +206,60 @@ def __init__(self, offset: int = 0, dim1: int = 0, dim2: int = 1) -> None: def __eq__(self, other: object) -> bool: if not isinstance(other, DiagonalInfo): return False - return (self.offset == other.offset and self.dim1 == other.dim1 and - self.dim2 == other.dim2) + return ( + self.offset == other.offset + and self.dim1 == other.dim1 + and self.dim2 == other.dim2 + ) def transform_tensor(self, jax_array: jax.Array) -> jax.Array: raise NotImplementedError("DiagonalInfo.apply not implemented") - def update_tensor(self, new_value: jax.Array, - jax_array: jax.Array) -> jax.Array: + def update_tensor( + self, new_value: jax.Array, jax_array: jax.Array + ) -> jax.Array: raise NotImplementedError("DiagonalInfo.update not implemented") def calculate_output_shape(self, source: jax.Array) -> List[int]: raise NotImplementedError( - "DiagonalInfo.calculate_output_shape not implemented") + "DiagonalInfo.calculate_output_shape not implemented" + ) class View(torch.Tensor): """ - A View is a reference to another Tensor or another View, - with a transformation applied to it. - """ + A View is a reference to another Tensor or another View, + with a transformation applied to it. + """ @staticmethod - def __new__(cls, parent: Union["torchax.Tensor", "View"], view_info: ViewInfo, - env: Any) -> "View": + def __new__( + cls, + parent: Union["torchax.Tensor", "View"], + view_info: ViewInfo, + env: Any, + ) -> "View": + """ + Args: + parent: Parent tensor or view + view_info: Information about the view transformation + env: Environment for tensor operations """ - Args: - parent: Parent tensor or view - view_info: Information about the view transformation - env: Environment for tensor operations - """ shape = view_info.calculate_output_shape(parent.jax()) return torch.Tensor._make_wrapper_subclass( - cls, - shape, - device="meta", - dtype=parent.dtype, - requires_grad=False, + cls, + shape, + device="meta", + dtype=parent.dtype, + requires_grad=False, ) - def __init__(self, parent: Union["torchax.Tensor", "View"], - view_info: ViewInfo, env: Any) -> None: + def __init__( + self, + parent: Union["torchax.Tensor", "View"], + view_info: ViewInfo, + env: Any, + ) -> None: super().__init__() self.parent = parent self.view_info = view_info @@ -246,8 +267,8 @@ def __init__(self, parent: Union["torchax.Tensor", "View"], def get_transformation_chain(self) -> List[ViewInfo]: """ - Get all view transformations from the source tensor to this view. - """ + Get all view transformations from the source tensor to this view. + """ if isinstance(self.parent, View): transformations = self.parent.get_transformation_chain() transformations.append(self.view_info) @@ -259,8 +280,8 @@ def get_transformation_chain(self) -> List[ViewInfo]: def source_jax(self) -> jax.Array: """ - Returns the source tensor. - """ + Returns the source tensor. 
+ """ if isinstance(self.parent, View): return self.parent.source_jax() else: @@ -268,8 +289,8 @@ def source_jax(self) -> jax.Array: def replace_source_jax(self, new_value: jax.Array) -> None: """ - Update the source tensor with new values. - """ + Update the source tensor with new values. + """ if isinstance(self.parent, View): self.parent.replace_source_jax(new_value) else: @@ -278,22 +299,22 @@ def replace_source_jax(self, new_value: jax.Array) -> None: def torch(self) -> "torchax.Tensor": """ - Returns a Torchax tensor representing this view after all transformations - """ + Returns a Torchax tensor representing this view after all transformations + """ from torchax.tensor import Tensor return Tensor(self.jax(), self._env) def update( - self, - new_values: Union[jax.Array, "View", "torchax.Tensor"], - view_infos: Optional[List[ViewInfo]] = None, + self, + new_values: Union[jax.Array, "View", "torchax.Tensor"], + view_infos: Optional[List[ViewInfo]] = None, ) -> None: """ - Update this view with new values, propagating changes back to source. - If view_infos is None, it will use the transformation chain - from the source tensor. - """ + Update this view with new values, propagating changes back to source. + If view_infos is None, it will use the transformation chain + from the source tensor. + """ if view_infos is None: view_infos = self.get_transformation_chain() @@ -311,13 +332,15 @@ def update( intermediate_values = [source_array] for view_info in view_infos[:-1]: intermediate_values.append( - view_info.transform_tensor(intermediate_values[-1])) + view_info.transform_tensor(intermediate_values[-1]) + ) # TODO: Investigate efficiency of this algorithm # Update the source array with the new value by # applying inverse transformations in reverse order for view_info, parent_array in zip( - reversed(view_infos), reversed(intermediate_values)): + reversed(view_infos), reversed(intermediate_values) + ): # Apply the inverse transformation to propagate changes back new_values = view_info.update_tensor(new_values, parent_array) @@ -326,21 +349,22 @@ def update( @classmethod def __torch_dispatch__( - cls, - func: Any, - types: Tuple[Any, ...], - args: Tuple[Any, ...] = (), - kwargs: Optional[dict] = None, + cls, + func: Any, + types: Tuple[Any, ...], + args: Tuple[Any, ...] = (), + kwargs: Optional[dict] = None, ) -> Any: raise AssertionError( - 'torchax Tensors can only do math within the torchax environment.' - 'Please wrap your code with `with torchax.default_env()` or ' - 'call torchax.enable_globally() before.') + "torchax Tensors can only do math within the torchax environment." + "Please wrap your code with `with torchax.default_env()` or " + "call torchax.enable_globally() before." + ) def create_sub_view(self, view_info: ViewInfo) -> "View": """ - Create a new view that is a child of this view. - """ + Create a new view that is a child of this view. + """ return View(self, view_info, self._env) def __str__(self) -> str: @@ -348,8 +372,8 @@ def __str__(self) -> str: def jax(self) -> jax.Array: """ - Returns a copy of the source tensor after transformations. - """ + Returns a copy of the source tensor after transformations. + """ result = self.source_jax() for view_info in self.get_transformation_chain(): result = view_info.transform_tensor(result)