[tensor wrapper subclass] Add Proxy, prim, and lookaside #1583

Draft · wants to merge 1 commit into base: main
37 changes: 37 additions & 0 deletions thunder/core/jit_ext.py
@@ -63,6 +63,7 @@
StringProxy,
TensorProxy,
FutureTensorProxy,
SubclassTensorProxy,
make_proxy_name,
Variable,
variableify,
@@ -930,6 +931,42 @@ def thunder_function(*args, **kwargs):
return interpreter_needs_wrap(checkpoint)(wrapped_thunder_function, *args, **kwargs)


@register_general_jit_lookaside(torch.Tensor._make_wrapper_subclass)
Collaborator:
Is there a reason for this to be defined as jit_lookaside? It looks simple enough to be just in thunder/torch/__init__.py with the @torchsymbol decorator.

Collaborator (author):
Hmm, maybe not. I originally tried to make it more graceful, but currently it's quite simple, as you can see.

def _make_wrapper_subclass(
cls: torch._C._TensorMeta,
size: Sequence[int],
strides: Sequence[int] | None = None,
storage_offset: int | None = None,
memory_format: torch.memory_format | None = None,
dtype: torch.dtype | None = None,
layout: torch.layout | None = torch.strided,
device: torch.device | None = None,
pin_memory: bool = False,
requires_grad: bool = False,
dispatch_sizes_strides_policy: str | None = None,
dispatch_device: bool = False,
dispatch_layout: bool = False,
_extra_dispatch_keys: torch.DispatchKeySet | None = None,
storage_size: int | None = None,
):
ucls = unwrap(cls)
usize = unwrap(size)
udtype = unwrap(dtype)
udevice = unwrap(device)
urequires_grad = unwrap(requires_grad)

subclass = SubclassTensorProxy(
None,
shape=usize,
device=udevice,
dtype=udtype,
requires_grad=urequires_grad,
history=ProvenanceRecord(PseudoInst.LOOKASIDE, [cls.provenance]),
subclass_type=ucls,
)
Comment on lines +958 to +966
Collaborator:
Why does this lookaside construct a proxy which has a bsym construction inside its init method instead of calling prims.tensor_subclass_ctor?

Suggested change:

-    subclass = SubclassTensorProxy(
-        None,
-        shape=usize,
-        device=udevice,
-        dtype=udtype,
-        requires_grad=urequires_grad,
-        history=ProvenanceRecord(PseudoInst.LOOKASIDE, [cls.provenance]),
-        subclass_type=ucls,
-    )
+    subclass = prims.tensor_subclass_ctor(ucls, None, usize, udevice, udtype, urequires_grad, [], [])

Tests from this PR pass when I replaced SubclassTensorProxy with prims.tensor_subclass_ctor and removed the bsym construction in the init method:

diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py
index e8841954..e06b0ebd 100644
--- a/thunder/core/jit_ext.py
+++ b/thunder/core/jit_ext.py
@@ -955,15 +955,7 @@ def _make_wrapper_subclass(
     udevice = unwrap(device)
     urequires_grad = unwrap(requires_grad)
 
-    subclass = SubclassTensorProxy(
-        None,
-        shape=usize,
-        device=udevice,
-        dtype=udtype,
-        requires_grad=urequires_grad,
-        history=ProvenanceRecord(PseudoInst.LOOKASIDE, [cls.provenance]),
-        subclass_type=ucls,
-    )
+    subclass = prims.tensor_subclass_ctor(ucls, None, usize, udevice, udtype, urequires_grad, [], [])
     return wrap(subclass, provenance=ProvenanceRecord(PseudoInst.LOOKASIDE, [cls.provenance]))
 
 
diff --git a/thunder/core/proxies.py b/thunder/core/proxies.py
index c6c88656..424e60d6 100644
--- a/thunder/core/proxies.py
+++ b/thunder/core/proxies.py
@@ -1931,25 +1931,26 @@ class SubclassTensorProxy(TensorProxy):
             self._non_tensors = kwarg_non_tensors
             self._subclass_type = subclass_type
         else:
+            pass
             # TODO(crcrpar): Think about materializing `self` so that we can
             # call `__tensor_init__` to know each attribute names.
             from dataclasses import replace
             import inspect
             from thunder.core import prims
 
-            bsym = prims.tensor_subclass_ctor.bind(
-                self._subclass_type,
-                self.name,
-                self.shape,
-                self.device,
-                self.dtype,
-                self.requires_grad,
-                self._tensors,
-                self._non_tensors,
-                output=self,
-            )
-            cls_module = inspect.getmodule(self._subclass_type)
-            bsym.sym = replace(bsym.sym, _module=cls_module)
+            # bsym = prims.tensor_subclass_ctor.bind(
+            #     self._subclass_type,
+            #     self.name,
+            #     self.shape,
+            #     self.device,
+            #     self.dtype,
+            #     self.requires_grad,
+            #     self._tensors,
+            #     self._non_tensors,
+            #     output=self,
+            # )
+            # cls_module = inspect.getmodule(self._subclass_type)
+            # bsym.sym = replace(bsym.sym, _module=cls_module)
 
             # NOTE(crcrpar): A callable being `thunder.jit`ed can call `MySubclassTensor(...)`
             # inside of it either directly or indirectly: indirect way is to call it through
@@ -1959,11 +1960,11 @@ class SubclassTensorProxy(TensorProxy):
             # but not, otherwise. As [the lookasdie of `torch.autograd.Function`](
             # https://github.com/Lightning-AI/lightning-thunder/blob/3d42c10/thunder/core/jit_ext.py#L655)
             # puts the temporary scope to the current trace.
-            current_trace = get_tracectx()
-            if id(current_trace.bound_symbols) == id(cur_tail_scope := current_trace.scopes[-1]):
-                current_trace.add_bound_symbol(bsym)
-            else:
-                cur_tail_scope.append(bsym)
+            # current_trace = get_tracectx()
+            # if id(current_trace.bound_symbols) == id(cur_tail_scope := current_trace.scopes[-1]):
+            #     current_trace.add_bound_symbol(bsym)
+            # else:
+            #     cur_tail_scope.append(bsym)
 
             if not self._tensors and not self._non_tensors:
                 for a in tensors:

return wrap(subclass, provenance=ProvenanceRecord(PseudoInst.LOOKASIDE, [cls.provenance]))
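
As a side note, here is a minimal sketch (hypothetical, untested) of the @torchsymbol alternative raised in the first thread above, combined with the prims.tensor_subclass_ctor call from the suggested change, as it might look in thunder/torch/__init__.py (where torchsymbol, prims, Sequence, and SubclassTensorProxy are already in scope); it also assumes the decorator can be pointed at torch.Tensor._make_wrapper_subclass directly:

@torchsymbol(torch.Tensor._make_wrapper_subclass)
def _make_wrapper_subclass(
    cls: torch._C._TensorMeta,
    size: Sequence[int],
    *,
    dtype: torch.dtype | None = None,
    device: torch.device | None = None,
    requires_grad: bool = False,
    **kwargs,  # remaining torch keyword arguments (strides, storage_offset, ...) elided here
) -> SubclassTensorProxy:
    # torchsymbol-registered functions receive proxies directly, so no unwrap() is needed;
    # delegate straight to the prim added by this PR.
    return prims.tensor_subclass_ctor(cls, None, size, device, dtype, requires_grad, [], [])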


# Adds proxy methods
# NOTE These methods map to themselves, which prevents the interpreter from looking into them
# This is OK because these methods are written in a tracing-safe manner, and trying to
151 changes: 149 additions & 2 deletions thunder/core/prims.py
@@ -1,18 +1,23 @@
from __future__ import annotations
from enum import auto, Enum
from numbers import Number
from functools import reduce, wraps
import operator
import builtins
import math
from types import NoneType
from typing import Union, Type, Any, List, Dict, Tuple, Optional
from typing import Union, Type, Any, List, Dict, Tuple, Optional, TYPE_CHECKING
from collections.abc import Callable
from collections.abc import Callable, Hashable, Sequence

import torch

from thunder.core.langctxs import LanguageContext, register_langctx, Languages, langctx

if TYPE_CHECKING:
from collections.abc import Iterable
from thunder.core.codeutils import ContextObject

#
# Creates and registers the torch language context
#
@@ -77,6 +82,7 @@ def register_method(method_name: str, method: Callable, /) -> None:
TupleProxy,
AnyProxy,
IntegerProxy,
SubclassTensorProxy,
)
import thunder.core.codeutils as codeutils
from thunder.core.codeutils import Printable
@@ -272,6 +278,8 @@ class PrimIDs(Enum):
COPY_ = auto()
#
SINK = auto()
# Tensor Subclasses methods
TENSOR_SUBCLASS_CTOR = auto()


class OpTags(Enum):
@@ -3543,7 +3551,10 @@ def transpose_meta(a: TensorProxy, /, permutation: tuple[int, ...]) -> TensorPro
view = make_prim(PrimIDs.VIEW, "view", meta=reshape_meta, tags=(OpTags.SHAPE_OP,))


def shallow_copy_meta(a: TensorProxy, /) -> TensorProxy:
def shallow_copy_meta(a: TensorProxy | SubclassTensorProxy, /) -> TensorProxy:
if isinstance(a, SubclassTensorProxy):
# SubclassTensorProxy(like=...) would not copy some attrs such as `_tensors` while replace does.
return a.replace()
return TensorProxy(like=a)


@@ -4048,3 +4059,139 @@ def sink_meta(*args, **kwargs):

# TODO do we want another tag to remove this after prologue is constructed?
sink = make_prim(PrimIDs.SINK, "sink", meta=sink_meta, tags=(OpTags.DONT_DCE,))


def tensor_subclass_ctor_meta(
cls, name, shape, device, dtype, requires_grad, tensors, non_tensors
) -> SubclassTensorProxy:
s = SubclassTensorProxy(
name,
subclass_type=cls,
shape=shape,
device=device,
dtype=dtype,
requires_grad=requires_grad,
tensors=tensors,
non_tensors=non_tensors,
history=[t.history for t in tensors],
Collaborator:
What breaks without this?

Collaborator (author):
Now it looks like a leftover from things I tried while working on understanding and implementing tensor wrapper subclasses. Let me see...

)
return s


def get_nested_types(collection):
collection = utils.sequencify(collection)
types_set = {type(t) for t in collection}

def check_types(coll):
for item in coll:
types_set.add(type(item))
# Check if the item is a nested collection
if baseutils.is_collection(item):
# If it's a dictionary, check its values
if isinstance(item, dict):
check_types(item.values())
# Recursively check nested collections
else:
check_types(item)

check_types(collection)
return tuple(types_set)


def filter_types(types: tuple[Any, ...]) -> tuple[Any, ...]:
return tuple(
filter(
lambda t: (
t.__module__ != "builtins"
and t != Number
# note(crcrpar): maybe `thunder.core`?
and not t.__module__.startswith("thunder.")
and not t.__module__.startswith("torch.")
),
types,
)
)
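# Usage sketch (illustrative): filter_types(get_nested_types(non_tensors)) keeps only
# user-defined classes -- e.g. a hypothetical MyQuantConfig dataclass carried in
# `non_tensors` -- and drops builtins, Number, and thunder.*/torch.* types, so that only
# those user classes need to be added to the printed trace's import context.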


def printer_of_tensor_subclass_ctor(
Collaborator:
Currently, the execution and initial traces look the same even though they contain prims.tensor_subclass_ctor and torchex.tensor_subclass_ctor as bsym.sym. Could you modify the printer to distinguish between the two?

Collaborator (author):
How would we want to distinguish them? With comments (a.k.a. bsym.header), or by signature?

Collaborator (t-vi, Jan 27, 2025):
Could the TorchEx version use the subclass constructor directly or somesuch?
Or you could do it with prims.tensor_subclass_ctor, similar to how we have ltorch.foo and torch.foo for torch symbols.
(In fact, we often use prims.foo, so that should be easy, too.)

Collaborator (author):
The impl in torchex calls the dunder init of the tensor subclass, as in:

def _tensor_subclass_ctor(cls, name, shape, device, dtype, requires_grad, tensors, non_tensors):
new_non_tensors = []
for a in non_tensors:
if isinstance(a, dtypes.dtype):
new_non_tensors.append(to_torch_dtype(a))
elif isinstance(a, devices.Device):
new_non_tensors.append(to_torch_device(a))
else:
new_non_tensors.append(a)
return cls(*tensors, *new_non_tensors)

bsym: BoundSymbol,
out_printables: Any,
arg_printables: Sequence[Printable],
kwarg_printables: dict[str, Printable],
) -> str | Iterable[str]:
from itertools import chain

baseutils.check(not kwarg_printables, lambda: f"No kwargs are supported but {kwarg_printables = }")

# NOTE(crcrpar): It's not a context but at the moment Tensor subclass is treated as `ContextObject`.
wrapped_cls: ContextObject | torch._C._TensorMeta = arg_printables[0]
if isinstance(wrapped_cls, torch._C._TensorMeta):
cls = wrapped_cls
else:
cls: torch._C._TensorMeta = wrapped_cls.obj
tensors, non_tensors = arg_printables[-2:]
new_non_tensors = []
for a in non_tensors:
if isinstance(a, dtypes.dtype):
new_non_tensors.append(dtypes.to_torch_dtype(a))
elif isinstance(a, devices.Device):
new_non_tensors.append(devices.to_torch_device(a))
else:
new_non_tensors.append(a)

arg_str = ", ".join(codeutils.prettyprint(x) for x in [*tensors, *new_non_tensors])
kwarg_str = ""

result_str: str
if bsym.output is None or (baseutils.is_collection(bsym.output) and len(bsym.output) == 0):
result_str = ""
else:
result_str = f"{codeutils.prettyprint(out_printables, literals_as_underscores=True)} = "

# Creates a comment describing the output
comment_str: str
if isinstance(bsym.output, Proxy):
comment_str = f" # {codeutils.prettyprint(out_printables, with_type=True)}"
else:
comment_str = ""

cls_with_module = f"{cls.__name__}"
s = f"{result_str}{cls_with_module}({arg_str}{', ' if (len(arg_str) > 0 and len(kwarg_str) > 0) else ''}{kwarg_str}){comment_str}"

if bsym.header:
header_lines = (
bsym.header
if isinstance(bsym.header, Sequence) and not isinstance(bsym.header, str)
else bsym.header.splitlines()
)
header_lines = (f"# {line}" for line in header_lines)
return chain(header_lines, [s])

filtered_types = (cls,)
if non_tensors:
types = get_nested_types([t.obj if isinstance(t, codeutils.ContextObject) else t for t in non_tensors])
filtered_types += filter_types(types)
new_imports = {t.__name__: t for t in filtered_types}
bsym._import_ctx.update(new_imports)
return s


def bind_postprocess_of_tensor_subclass_ctor(bsym: BoundSymbol) -> None:
cls = bsym.args[0]
non_tensors = bsym.args[-1]

filtered_types: tuple[Any, ...] = (cls,)
if non_tensors:
types = get_nested_types(non_tensors)
filtered_types += filter_types(types)
new_imports = {t.__name__: t for t in filtered_types}
bsym._import_ctx.update(new_imports)


tensor_subclass_ctor = make_prim(
PrimIDs.TENSOR_SUBCLASS_CTOR,
"tensor_subclass_ctor",
meta=tensor_subclass_ctor_meta,
python_printer=printer_of_tensor_subclass_ctor,
Collaborator:
Is the initial trace with this printer able to work with TensorProxy inputs? Here's an example for what I mean:

In [1]: import torch

In [2]: import thunder

In [3]: a = torch.randn(3, 3)

In [4]: def f(x): return x.sin()

In [5]: jf = thunder.jit(f)

In [6]: jf(a);

In [7]: initial_trace = jf._lc_cs.last_traces[0]

In [8]: with thunder.core.trace.detached_trace():
   ...:     res = initial_trace.python_callable()(*initial_trace.args)
   ...: 

In [9]: res
Out[9]: (<TensorProxy(name="t0", dtype=thunder.dtypes.float32, shape=(3, 3))>,)

Here the initial trace is a valid Thunder program able to accept and return TensorProxies if there's an active trace (detached_trace in this case).

Collaborator (author):
I hadn't thought of that. Will try it out.

_bind_postprocess=bind_postprocess_of_tensor_subclass_ctor,
)
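
Following up on the re-execution check suggested in the last thread above, here is a hedged sketch of what the analogous check might look like for a program that constructs a wrapper subclass. The ScaleTensor class, f, and all other names below are illustrative stand-ins (not taken from this PR's tests), and whether the thunder.jit portion runs depends on the draft's current state:

import torch
import thunder
from thunder.core.trace import detached_trace


class ScaleTensor(torch.Tensor):
    # Minimal wrapper subclass: wraps a data tensor plus a Python float scale.
    @staticmethod
    def __new__(cls, data: torch.Tensor, scale: float):
        return torch.Tensor._make_wrapper_subclass(
            cls,
            data.shape,
            dtype=data.dtype,
            device=data.device,
            requires_grad=data.requires_grad,
        )

    def __init__(self, data: torch.Tensor, scale: float):
        self._data = data
        self._scale = scale

    def __tensor_flatten__(self):
        return ["_data"], {"scale": self._scale}

    @staticmethod
    def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride):
        return ScaleTensor(inner_tensors["_data"], metadata["scale"])


def f(x: torch.Tensor) -> ScaleTensor:
    return ScaleTensor(x, 2.0)


jf = thunder.jit(f)
jf(torch.randn(3, 3))
initial_trace = jf._lc_cs.last_traces[0]

# If the printer emits a valid Thunder program, the initial trace should also run on
# TensorProxy inputs under an active (detached) trace, mirroring the sin() example above.
with detached_trace():
    res = initial_trace.python_callable()(*initial_trace.args)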