[wip] ruff linter on torchax #9396

Open · wants to merge 8 commits into base: master
18 changes: 18 additions & 0 deletions torchax/pyproject.toml
@@ -51,3 +51,21 @@ packages = ["torchax"]

[tool.pytest.ini_options]
addopts="-n auto"

[tool.ruff]
# Equivalent to column_limit
line-length = 80

# Enable preview mode to use rules like E306
preview = true

indent-width = 2

[tool.ruff.lint]
select = [
"E", "F", "W", # Your existing rule selections
"Q",
]
# Enforces a blank line before nested functions and classes
extend-select = ["E306"]
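
For illustration only (not part of this diff), the E306 rule selected above flags a nested def that directly follows other statements without a blank line; preview mode is needed because ruff ships E306 as a preview rule:

def outer():
  x = 1
  def inner():  # E306: a blank line is expected before this nested definition
    return x
  return inner


def outer_ok():
  x = 1

  def inner():  # OK once a blank line precedes the nested definition
    return x

  return inner

The widespread single- to double-quote changes in the rest of this diff match the "Q" (flake8-quotes) selection, which prefers double quotes by default.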

47 changes: 46 additions & 1 deletion torchax/test/test_interop.py
@@ -2,8 +2,21 @@
import torch
import unittest
import torchax
from torchax import interop
from torchax import interop, jax_device
import torchax
import jax


def is_tpu_available():
  """Checks if any TPU devices are available to JAX."""
  try:
    # jax.devices('tpu') will return a list of TPU devices if available.
    # If no TPUs are found or JAX is not configured for TPU,
    # it will raise a RuntimeError.
    tpu_devices = jax.devices('tpu')
    return len(tpu_devices) > 0
  except RuntimeError:
    return False
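
As a side note (a hedged alternative not used in this PR, which checks TPU availability inline inside the test body), the same helper could gate an entire test with standard unittest decorators:

class TpuOnlyTest(unittest.TestCase):

  @unittest.skipUnless(is_tpu_available(), "requires at least one JAX TPU device")
  def test_runs_only_on_tpu(self):
    # Only executes when jax.devices('tpu') reports at least one device.
    self.assertGreater(len(jax.devices("tpu")), 0)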


class InteropTest(unittest.TestCase):
@@ -116,6 +129,38 @@ def forward(self, x):
    # assert
    torch.testing.assert_allclose(actual, expected)

  def test_to_jax_device(self):
    a = torch.ones(3, 3)

    with jax_device("cpu"):
      # move torch.tensor to torchax.tensor CPU
      b = a.to("jax")
      self.assertEqual(b.jax_device.platform, "cpu")
      self.assertEqual(b.device.type, "jax")

    if is_tpu_available():
      # move torch.tensor to torchax.tensor TPU
      with jax_device("tpu"):
        c = a.to("jax")
        self.assertEqual(c.jax_device.platform, "tpu")
        self.assertEqual(c.device.type, "jax")

      # move torchax.tensor on CPU to TPU
      with jax_device("tpu"):
        self.assertEqual(b.jax_device.platform, "cpu")
        self.assertEqual(c.device.type, "jax")
        c = b.to("jax")
        self.assertEqual(c.jax_device.platform, "tpu")
        self.assertEqual(c.device.type, "jax")

      # move torchax.tensor on TPU to CPU
      with jax_device("cpu"):
        self.assertEqual(c.jax_device.platform, "tpu")
        self.assertEqual(c.device.type, "jax")
        d = c.to("jax")
        self.assertEqual(d.jax_device.platform, "cpu")
        self.assertEqual(d.device.type, "jax")


if __name__ == '__main__':
unittest.main()
86 changes: 61 additions & 25 deletions torchax/torchax/__init__.py
@@ -7,25 +7,27 @@
from torch.utils import _pytree as pytree
from torchax import tensor
from torchax import distributed # noqa: F401
from contextlib import contextmanager

__version__ = "0.0.4"
VERSION = __version__

__all__ = [
'default_env',
'extract_jax',
'enable_globally',
"default_env",
"extract_jax",
"enable_globally",
]

from jax._src import xla_bridge

os.environ.setdefault('ENABLE_RUNTIME_UPTIME_TELEMETRY', '1')
os.environ.setdefault("ENABLE_RUNTIME_UPTIME_TELEMETRY", "1")

# torchax:oss-begin
if getattr(jax.config, 'jax_pjrt_client_create_options', None):
if getattr(jax.config, "jax_pjrt_client_create_options", None):
jax.config.update(
'jax_pjrt_client_create_options',
f'ml_framework_name:PyTorch/XLA2;ml_framework_version:{"v0.0.1"}')
"jax_pjrt_client_create_options",
f"ml_framework_name:PyTorch/XLA2;ml_framework_version:{'v0.0.1'}",
)
# torchax:oss-end

env = None
@@ -48,7 +50,7 @@ def extract_jax(mod: torch.nn.Module, env=None):

states = env.t2j_copy(states)

#@jax.jit
# @jax.jit
def jax_func(states, inputs):
(states, inputs) = env.j2t_iso((states, inputs))
with env:
@@ -78,53 +80,87 @@ def disable_temporarily():
enable_globally()


torch.utils.rename_privateuse1_backend('jax')
torch.utils.rename_privateuse1_backend("jax")
unsupported_dtype = [torch.quint8]
torch.utils.generate_methods_for_privateuse1_backend(
for_tensor=True,
for_module=True,
for_storage=True,
unsupported_dtype=unsupported_dtype)
for_tensor=True,
for_module=True,
for_storage=True,
unsupported_dtype=unsupported_dtype,
)

import jax
import torchax.device_module

torch._register_device_module('jax', torchax.device_module)
torch._register_device_module("jax", torchax.device_module)


def enable_accuracy_mode():
jax.config.update('jax_enable_x64', True)
jax.config.update('jax_default_matmul_precision', 'highest')
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_default_matmul_precision", "highest")
default_env().config.internal_respect_torch_return_dtypes = True


def enable_performance_mode():
jax.config.update('jax_enable_x64', False)
jax.config.update('jax_default_matmul_precision', 'default')
jax.config.update("jax_enable_x64", False)
jax.config.update("jax_default_matmul_precision", "default")
default_env().config.internal_respect_torch_return_dtypes = False


@dataclasses.dataclass
class CompileOptions:
# only valid if compiling nn.Module
methods_to_compile: List[str] = dataclasses.field(
default_factory=lambda: ['forward'])
default_factory=lambda: ["forward"]
)
jax_jit_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
mode: str = 'jax' # or dynamo or export
mode: str = "jax" # or dynamo or export


def compile(fn, options: Optional[CompileOptions] = None):
options = options or CompileOptions()
if options.mode == 'jax':
if options.mode == "jax":
from torchax import interop

if isinstance(fn, torch.nn.Module):
module = interop.JittableModule(fn, extra_jit_args=options.jax_jit_kwargs)
for n in options.methods_to_compile:
module.make_jitted(n)
return module
else:
return interop.jax_jit(fn)
elif options.mode == 'dynamo':
raise RuntimeError('dynamo mode is not supported yet')
elif options.mode == 'export':
raise RuntimeError('export mode is not supported yet')
elif options.mode == "dynamo":
raise RuntimeError("dynamo mode is not supported yet")
elif options.mode == "export":
raise RuntimeError("export mode is not supported yet")


@contextmanager
def jax_device(target_device: str, env: tensor.Environment | None = None):
"""
to("jax") cannot differentiate the device/platform (cpu vs tpu).
Use this context manager to control jax array's storage device

Examples:

a = torch.ones(3, 3)

with jax_device("cpu"):
b = a.to("jax")

with jax_device("tpu"):
c = a.to("jax")

with jax_device("tpu"):
c = b.to("jax")

"""
if env is None:
env = default_env()

prev_target_device = env.target_device
try:
env.target_device = target_device
yield env
finally:
env.target_device = prev_target_device
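
A small usage sketch (mirroring the docstring above; the only addition is passing an explicit environment via the optional env argument, here the default one):

import torch
import torchax
from torchax import jax_device

env = torchax.default_env()
a = torch.ones(3, 3)

with jax_device("cpu", env=env):
  b = a.to("jax")  # backed by a JAX array on the CPU platform
  assert b.jax_device.platform == "cpu"
  assert b.device.type == "jax"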