diff --git a/thunder/__init__.py b/thunder/__init__.py
index c09c1cc9b5..104af7734f 100644
--- a/thunder/__init__.py
+++ b/thunder/__init__.py
@@ -369,7 +369,7 @@ def _alias_tensor_of_args_kwargs_dict(*args, **kwargs) -> dict[int, list[int]]:
     data_ptr_to_tensor_group_index = {}
     tensor_group_index_to_tensor_indices = defaultdict(list)
     for idx, t in enumerate(flat_args):
-        if pytorch.is_tensor(t) and t.layout == pytorch.strided:
+        if type(t) in {pytorch.Tensor, pytorch.nn.Parameter} and t.layout == pytorch.strided:
             data_ptr = t.untyped_storage().data_ptr()
             if data_ptr not in data_ptr_to_tensor_group_index:
                 data_ptr_to_tensor_group_index[data_ptr] = len(data_ptr_to_tensor_group_index)
diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py
index 7b1cf3eb87..cfbb8c8a9f 100644
--- a/thunder/core/jit_ext.py
+++ b/thunder/core/jit_ext.py
@@ -62,6 +62,7 @@
     NumberProxy,
     StringProxy,
     TensorProxy,
+    SubclassTensorProxy,
     FutureTensorProxy,
     make_proxy_name,
     Variable,
@@ -757,6 +758,42 @@ def grad_transform(*args, **kwargs):
         return forward_result
 
 
+@register_general_jit_lookaside(torch.Tensor._make_wrapper_subclass)
+def _make_wrapper_subclass(
+    cls: torch._C._TensorMeta,
+    size: Sequence[int],
+    strides: Sequence[int] | None = None,
+    storage_offset: int | None = None,
+    memory_format: torch.memory_format | None = None,
+    dtype: torch.dtype | None = None,
+    layout: torch.layout | None = torch.strided,
+    device: torch.device | None = None,
+    pin_memory: bool = False,
+    requires_grad: bool = False,
+    dispatch_sizes_strides_policy: str | None = None,
+    dispatch_device: bool = False,
+    dispatch_layout: bool = False,
+    _extra_dispatch_keys: torch.DispatchKeySet | None = None,
+    storage_size: int | None = None,
+):
+    ucls = unwrap(cls)
+    usize = unwrap(size)
+    udtype = unwrap(dtype)
+    udevice = unwrap(device)
+    urequires_grad = unwrap(requires_grad)
+
+    subclass = SubclassTensorProxy(
+        None,
+        shape=usize,
+        device=udevice,
+        dtype=udtype,
+        requires_grad=urequires_grad,
+        history=ProvenanceRecord(PseudoInst.LOOKASIDE, [cls.provenance]),
+        subclass_type=ucls,
+    )
+    return wrap(subclass, provenance=ProvenanceRecord(PseudoInst.LOOKASIDE, [cls.provenance]))
+
+
 @register_general_jit_lookaside(torch.autocast.__enter__)
 def autocast_enter(autocast_obj):
     unwrap_autocast_obj = unwrap(autocast_obj)
diff --git a/thunder/core/prims.py b/thunder/core/prims.py
index c17a28296c..c9468d653d 100644
--- a/thunder/core/prims.py
+++ b/thunder/core/prims.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from enum import auto, Enum
 from numbers import Number
 from functools import reduce, wraps
@@ -77,6 +79,7 @@ def register_method(method_name: str, method: Callable, /) -> None:
     TupleProxy,
     AnyProxy,
     IntegerProxy,
+    SubclassTensorProxy,
 )
 import thunder.core.codeutils as codeutils
 from thunder.core.codeutils import Printable
@@ -272,6 +275,8 @@ class PrimIDs(Enum):
     COPY_ = auto()
     #
     SINK = auto()
+    # Tensor subclass methods
+    TENSOR_SUBCLASS_CTOR = auto()
 
 
 class OpTags(Enum):
@@ -4048,3 +4053,27 @@ def sink_meta(*args, **kwargs):
 
 
 # TODO do we want another tag to remove this after prologue is constructed?
 sink = make_prim(PrimIDs.SINK, "sink", meta=sink_meta, tags=(OpTags.DONT_DCE,))
+
+
+def tensor_subclass_ctor_meta(
+    cls, name, shape, device, dtype, requires_grad, tensors, non_tensors
+) -> SubclassTensorProxy:
+    s = SubclassTensorProxy(
+        name,
+        subclass_type=cls,
+        shape=shape,
+        device=device,
+        dtype=dtype,
+        requires_grad=requires_grad,
+        tensors=tensors,
+        non_tensors=non_tensors,
+        history=[t.history for t in tensors],
+    )
+    return s
+
+
+tensor_subclass_ctor = make_prim(
+    PrimIDs.TENSOR_SUBCLASS_CTOR,
+    "tensor_subclass_ctor",
+    meta=tensor_subclass_ctor_meta,
+)
diff --git a/thunder/core/proxies.py b/thunder/core/proxies.py
index 2f2eb1c665..8cf179e263 100644
--- a/thunder/core/proxies.py
+++ b/thunder/core/proxies.py
@@ -1880,6 +1880,125 @@ def real(self):
         return method(self)
 
 
+class SubclassTensorProxy(TensorProxy):
+    _tensors: list[TensorProxy]
+    _non_tensors: list[Any]
+    _subclass_type: torch._C._TensorMeta
+
+    def __init__(self, *args, **kwargs):
+        from thunder.core.pytree import tree_flatten
+
+        kwarg_tensors = kwargs.pop("tensors", [])
+        kwarg_non_tensors = kwargs.pop("non_tensors", [])
+        subclass_type = kwargs.pop("subclass_type", None)
+
+        # If `tensors` (and `non_tensors`) are not empty, this is the path following
+        # `_make_wrapper_subclass`, where `self` should already have gotten its name.
+        flat_args, spec = tree_flatten((args, kwargs))
+        tensors = list(filter(lambda t: isinstance(t, TensorProxy), flat_args))
+        non_tensors = list(filter(lambda t: not isinstance(t, TensorProxy), flat_args))
+        has_name_before_init = hasattr(self, "_name")
+
+        is_dunder_init_following_make_wrapper_subclass: bool = False
+        if tensors:
+            baseutils.check(
+                has_name_before_init
+                and not kwarg_tensors
+                and not kwarg_non_tensors
+                and self._subclass_type is not None,
+                lambda: (
+                    f"{flat_args=} indicates this instance was created by "
+                    "`torch.Tensor._make_wrapper_subclass`'s lookaside but `name` is not set"
+                ),
+            )
+            is_dunder_init_following_make_wrapper_subclass = True
+
+        if not is_dunder_init_following_make_wrapper_subclass:
+            super().__init__(*args, **kwargs)
+
+            self._tensors = kwarg_tensors
+            self._non_tensors = kwarg_non_tensors
+            self._subclass_type = subclass_type
+        else:
+            # TODO(crcrpar): Think about materializing `self` so that we can
+            # call `__tensor_init__` to learn each attribute's name.
+            from thunder.core import prims
+
+            self._tensors = tensors
+            self._non_tensors = non_tensors
+            bsym = prims.tensor_subclass_ctor.bind(
+                self._subclass_type,
+                self.name,
+                self.shape,
+                self.device,
+                self.dtype,
+                self.requires_grad,
+                self._tensors,
+                self._non_tensors,
+                output=self,
+            )
+            # NOTE(crcrpar): A callable being `thunder.jit`ed can call `MySubclassTensor(...)`
+            # either directly or indirectly; the indirect way is to go through a custom
+            # `torch.autograd.Function`, as in
+            # https://github.com/pytorch/ao/blob/000a490/torchao/float8/float8_tensor.py#L139-L209.
+            # For a direct call, `trace.bound_symbols` and `trace.scopes[-1]` are identical;
+            # otherwise they are not, because [the lookaside of `torch.autograd.Function`](
+            # https://github.com/Lightning-AI/lightning-thunder/blob/3d42c10/thunder/core/jit_ext.py#L655)
+            # pushes a temporary scope onto the current trace.
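+            # For illustration only (hypothetical user code; `MyScale` and `Wrap` are
+            # made-up names, not part of thunder): a direct construction inside the
+            # jitted callable,
+            #
+            #     def f(x, s):
+            #         return MyScale(x, s)
+            #
+            # records `bsym` through `trace.bound_symbols`, while an indirect construction,
+            #
+            #     class Wrap(torch.autograd.Function):
+            #         @staticmethod
+            #         def forward(ctx, x, s):
+            #             return MyScale(x, s)
+            #
+            #         @staticmethod
+            #         def backward(ctx, g):
+            #             return g, None
+            #
+            #     def g(x, s):
+            #         return Wrap.apply(x, s)
+            #
+            # goes through the autograd.Function lookaside's temporary scope, `trace.scopes[-1]`.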
+            current_trace = get_tracectx()
+            if id(current_trace.bound_symbols) == id(cur_tail_scope := current_trace.scopes[-1]):
+                current_trace.add_bound_symbol(bsym)
+            else:
+                cur_tail_scope.append(bsym)
+
+    def replace(self, **changes):
+        r"""Return a copy of the SubclassTensorProxy object with new values for the specified fields as given to the constructor as arguments.
+        Valid keyword arguments are ``name``, ``history``, ``shape``, ``dtype``, ``device``, ``requires_grad``, ``distparallel_type``, ``thunder_fsdp_padding_size``.
+        ``like`` is also a valid keyword and will take metadata from the tensor proxy argument
+        in preference to the old values but overridable by keyword arguments.
+        Note that the copy will use the current (environment) tracectx."""
+
+        like = changes.get("like")
+        (
+            shape,
+            device,
+            dtype,
+            true_dtype,
+            numel,
+            ndim,
+            requires_grad,
+            grad,
+            distparallel_type,
+            thunder_fsdp_padding_size,
+        ) = _infer_tensor_properties(
+            like,
+            changes.get("shape", self._shape if like is None else None),
+            changes.get("device", self._device if like is None else None),
+            changes.get("dtype", self._dtype if like is None else None),
+            changes.get("requires_grad", self._requires_grad if like is None else None),
+            changes.get("grad", self._grad if like is None else None),
+            changes.get("distparallel_type", self._distparallel_type if like is None else None),
+            changes.get("thunder_fsdp_padding_size", self._thunder_fsdp_padding_size if like is None else None),
+        )
+        name = changes.get("name", self.name)
+        history = changes.get("history", self.history)
+        tags = changes.get("tags", self.tags)
+        return SubclassTensorProxy(
+            name=name,
+            tags=tags,
+            shape=shape,
+            device=device,
+            dtype=dtype,
+            requires_grad=requires_grad,
+            distparallel_type=distparallel_type,
+            thunder_fsdp_padding_size=thunder_fsdp_padding_size,
+            history=history,
+            tensors=self._tensors,
+            non_tensors=self._non_tensors,
+            subclass_type=self._subclass_type,
+        )
+
+
 class TorchAutogradFunctionCtxProxy(Proxy, TorchAutogradFunctionCtxProxyInterface):
     def __init__(
         self,
diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py
index afff715728..dc15653f2d 100644
--- a/thunder/executors/torchex.py
+++ b/thunder/executors/torchex.py
@@ -2174,3 +2174,15 @@ def _shape_impl(t):
 
 shallow_copy = ex.register_operator("shallow_copy", meta=prims.shallow_copy, fn=lambda x: x)
 _register_implementation(prims.shallow_copy, shallow_copy, checker=_always_executable)
+
+
+def _tensor_subclass_ctor(cls, name, shape, device, dtype, requires_grad, tensors, non_tensors):
+    return cls(*tensors, *non_tensors)
+
+
+tensor_subclass_ctor = ex.register_operator(
+    "tensor_subclass_ctor",
+    meta=prims.tensor_subclass_ctor,
+    fn=_tensor_subclass_ctor,
+)
+_register_implementation(prims.tensor_subclass_ctor, tensor_subclass_ctor, checker=_always_executable)
diff --git a/thunder/tests/test_tensor_subclass.py b/thunder/tests/test_tensor_subclass.py
new file mode 100644
index 0000000000..3652965a83
--- /dev/null
+++ b/thunder/tests/test_tensor_subclass.py
@@ -0,0 +1,182 @@
+from __future__ import annotations
+from typing import TYPE_CHECKING
+
+import torch
+from torch.utils import _pytree as pytree
+
+import thunder
+from thunder.tests.framework import instantiate
+from thunder.tests.make_tensor import make_tensor
+
+if TYPE_CHECKING:
+    from typing import Any
+
+
+@torch._dynamo.allow_in_graph
+class EncapsulateXandScale(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x: torch.Tensor, scale: torch.Tensor):
+        return ScaleTensorSubclass(x, scale)
+
+    @staticmethod
+    def backward(ctx, grad):
+        return grad, None
+
+
+def encapsulate_x_and_scale(x, scale) -> ScaleTensorSubclass:
+    return EncapsulateXandScale.apply(x, scale)
+
+
+@torch._dynamo.allow_in_graph
+class ToScaleTensorSubclass(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x: torch.Tensor):
+        return ScaleTensorSubclass.from_tensor(x)
+
+    @staticmethod
+    def backward(ctx, grad):
+        return grad
+
+
+def to_scale_tensor_subclass(x: torch.Tensor) -> ScaleTensorSubclass:
+    return ToScaleTensorSubclass.apply(x)
+
+
+class ScaleTensorSubclass(torch.Tensor):
+    _x: torch.Tensor
+    _scale: torch.Tensor
+    __slots__ = ["_x", "_scale"]
+
+    def __new__(cls, x: torch.Tensor, scale: torch.Tensor):
+        assert scale.numel() == 1, f"Invalid `scale`: {scale}"
+        dtype = x.dtype
+        device = x.device
+        self = torch.Tensor._make_wrapper_subclass(
+            cls,
+            x.size(),
+            dtype=dtype,
+            device=device,
+            # strides=x.stride(),
+            # storage_offset=x.storage_offset(),
+            # layout=x.layout,
+            # requires_grad=x.requires_grad,
+        )
+        self._x = x
+        self._scale = scale
+
+        return self
+
+    # ref: https://github.com/albanD/subclass_zoo/blob/ec47458/base_tensor.py#L22
+    __torch_function__ = torch._C._disabled_torch_function_impl
+
+    def __repr__(self):
+        return f"ScaleTensorSubclass(dtype={self._x.dtype}, device={self._x.device}, x={self._x}, scale={self._scale})"
+
+    def __tensor_flatten__(self) -> tuple[list[str], dict[str, Any]]:
+        return ["_x", "_scale"], {}
+
+    @staticmethod
+    def __tensor_unflatten__(
+        inner_tensors: dict[str, torch.Tensor],
+        metadata: dict[str, Any],
+        outer_size,
+        outer_stride,
+    ) -> ScaleTensorSubclass:
+        return ScaleTensorSubclass(inner_tensors["_x"], inner_tensors["_scale"])
+
+    @staticmethod
+    def from_tensor(x: torch.Tensor) -> ScaleTensorSubclass:
+        scale = x.abs().max()
+        return ScaleTensorSubclass(x, scale)
+
+    @classmethod
+    def __torch_dispatch__(cls, aten_ir_op: torch._ops.OpOverload, types, args=(), kwargs=None):
+
+        def allowed_subclass(typ):
+            return (
+                issubclass(cls, typ)
+                or issubclass(torch._subclasses.FakeTensor, typ)
+                or issubclass(torch._subclasses.functional_tensor.FunctionalTensor, typ)
+            )
+
+        def maybe_unwrap_and_scale(t: ScaleTensorSubclass | Any):
+            if isinstance(t, ScaleTensorSubclass):
+                if t.is_floating_point():
+                    return t._x * t._scale
+                else:
+                    return t._x
+            return t
+
+        if not all(allowed_subclass(t) for t in types):
+            raise NotImplementedError(f"Unsupported types are included: {types}")
+
+        scales = tuple(t._scale for t in pytree.tree_flatten((args, kwargs))[0] if isinstance(t, ScaleTensorSubclass))
+        unwrapped_args, unwrapped_kwargs = pytree.tree_map(maybe_unwrap_and_scale, (args, kwargs))
+        out = aten_ir_op(*unwrapped_args, **unwrapped_kwargs)
+        if not isinstance(out, torch.Tensor):
+            return out
+        else:
+            return ScaleTensorSubclass(out, scales[0])
+
+
+@instantiate(
+    dtypes=(thunder.core.dtypes.float32,),
+)
+def test_func_of_subclass_ctor_wrapper(executor, device, _):
+
+    def f(x: torch.Tensor, scale: torch.Tensor) -> ScaleTensorSubclass:
+        y = ScaleTensorSubclass(x, scale)
+        return y
+
+    jitted = executor.make_callable(f)
+
+    dtype = torch.float32
+    shape = (2, 2)
+    x = make_tensor(shape, device=device, dtype=dtype)
+    scale = make_tensor((), device=device, dtype=dtype)
+
+    expected = f(x, scale)
+    actual = jitted(x, scale)
+    torch.testing.assert_close((expected._x, expected._scale), (actual._x, actual._scale))
+
+    def f(x: torch.Tensor, scale: torch.Tensor):
+        y = ScaleTensorSubclass(x, scale)
+        z = ScaleTensorSubclass(y._x, y._scale)
+        return z
+
+    jitted = executor.make_callable(f)
+
+    expected = f(x, scale)
+    actual = jitted(x, scale)
+    torch.testing.assert_close((expected._x, expected._scale), (actual._x, actual._scale))
+
+
+@instantiate(
+    dtypes=(thunder.core.dtypes.float32,),
+)
+def test_func_calling_converter(executor, device, _):
+
+    def f(x: torch.Tensor, scale: torch.Tensor) -> ScaleTensorSubclass:
+        y = encapsulate_x_and_scale(x, scale)
+        return y
+
+    jitted = executor.make_callable(f)
+
+    dtype = torch.float32
+    shape = (2, 2)
+    x = make_tensor(shape, device=device, dtype=dtype)
+    scale = make_tensor((), device=device, dtype=dtype)
+
+    expected = f(x, scale)
+    actual = jitted(x, scale)
+    torch.testing.assert_close((expected._x, expected._scale), (actual._x, actual._scale))
+
+    def g(x: torch.Tensor) -> ScaleTensorSubclass:
+        y = to_scale_tensor_subclass(x)
+        return y
+
+    jitted = thunder.jit(g)
+    x = make_tensor(shape, device=device, dtype=dtype)
+
+    expected = g(x)
+    actual = jitted(x)
+    torch.testing.assert_close((expected._x, expected._scale), (actual._x, actual._scale))
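
A minimal usage sketch of what the patch enables (not part of the patch; it reuses the `ScaleTensorSubclass` defined in the new test file above, and everything else is existing thunder/PyTorch API):

    import torch
    import thunder

    from thunder.tests.test_tensor_subclass import ScaleTensorSubclass


    def f(x: torch.Tensor, scale: torch.Tensor) -> ScaleTensorSubclass:
        # Constructing the subclass hits the new `torch.Tensor._make_wrapper_subclass`
        # lookaside, which yields a SubclassTensorProxy and records a
        # `tensor_subclass_ctor` bound symbol that the torch executor runs as
        # `ScaleTensorSubclass(*tensors, *non_tensors)`.
        return ScaleTensorSubclass(x, scale)


    jitted = thunder.jit(f)
    x = torch.randn(2, 2)
    scale = torch.randn(())
    out = jitted(x, scale)
    assert isinstance(out, ScaleTensorSubclass)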