
[tensor wrapper subclass] Add Proxy, prim, and lookaside #1583

Draft
crcrpar wants to merge 1 commit into main

Conversation

crcrpar (Collaborator) commented Dec 23, 2024

What does this PR do?

Implement

  • a new Proxy to express tensor wrapper subclasses
  • a new prim and a new lookaside to support MySubclass(...) inside programs passed to thunder.jit

in order to support programs that only call the constructor of tensor wrapper subclasses.
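For concreteness, here is a minimal sketch of the kind of program this targets (the subclass is hypothetical; the PR's tests may define ScaleTensorSubclass differently):

import torch
import thunder

class ScaleTensorSubclass(torch.Tensor):
    # Hypothetical wrapper subclass: __new__ allocates the wrapper via
    # _make_wrapper_subclass, __init__ stashes the wrapped tensor and scale.
    @staticmethod
    def __new__(cls, data, scale):
        return torch.Tensor._make_wrapper_subclass(
            cls,
            data.size(),
            dtype=data.dtype,
            device=data.device,
            requires_grad=data.requires_grad,
        )

    def __init__(self, data, scale):
        self._data = data
        self._scale = scale

def f(x, scale):
    # The program only calls the subclass ctor; the new prim and lookaside
    # are what let thunder.jit trace this call.
    return ScaleTensorSubclass(x, scale)

jf = thunder.jit(f)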

Signed-off-by: Masaki Kozuki <[email protected]>
    meta=prims.tensor_subclass_ctor,
    fn=_tensor_subclass_ctor,
    bind_postprocess=_bind_postprocess_of_tensor_subclass_ctor,
    python_printer=prims.printer_of_tensor_subclass_ctor,
Collaborator

This is very clever. bind_postprocess together with python_printer achieves the effect that the generated execution Python program looks like what a person would write in a script:

t0 = ScaleTensorSubclass(x, scale)  # t0: "cpu f32[2, 2]"

even though bsym.args and bsym.kwargs are not just x, scale.
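For readers unfamiliar with the mechanism, here is a rough sketch (assumed printer signature and argument layout, not the PR's exact code) of how such a printer can hide the bookkeeping arguments:

from thunder.core import codeutils

def sketch_printer(bsym, out_printables, arg_printables, kwarg_printables):
    # Assumed layout of bsym.args: (subclass_type, name, shape, device, dtype,
    # requires_grad, tensors, non_tensors). Only the user-visible ctor inputs
    # (tensors and non_tensors) make it into the printed call.
    subclass_type, _, _, _, _, _, tensors, non_tensors = arg_printables
    out = codeutils.prettyprint(out_printables)
    ctor_args = ", ".join(codeutils.prettyprint(a) for a in [*tensors, *non_tensors])
    return f"{out} = {subclass_type.__name__}({ctor_args})"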

)


def printer_of_tensor_subclass_ctor(
Collaborator

Currently, the execution and initial traces look the same even though they contain prims.tensor_subclass_ctor and torchex.tensor_subclass_ctor, respectively, as bsym.sym. Could you modify the printer to distinguish between the two?
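One lightweight option (a hypothetical sketch, not code from the PR) would be to suffix the rendered line with the owning symbol's id so the two traces print differently:

def annotate_with_symbol(rendered_line: str, bsym) -> str:
    # Appends the symbol id so that prims.tensor_subclass_ctor and
    # torchex's tensor_subclass_ctor render distinguishably, e.g.
    # "t0 = ScaleTensorSubclass(x, scale)  # via <sym id>"
    return f"{rendered_line}  # via {bsym.sym.id}"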

    PrimIDs.TENSOR_SUBCLASS_CTOR,
    "tensor_subclass_ctor",
    meta=tensor_subclass_ctor_meta,
    python_printer=printer_of_tensor_subclass_ctor,
Collaborator

Is the initial trace with this printer able to work with TensorProxy inputs? Here's an example of what I mean:

In [1]: import torch

In [2]: import thunder

In [3]: a = torch.randn(3, 3)

In [4]: def f(x): return x.sin()

In [5]: jf = thunder.jit(f)

In [6]: jf(a);

In [7]: initial_trace = jf._lc_cs.last_traces[0]

In [8]: with thunder.core.trace.detached_trace():
   ...:     res = initial_trace.python_callable()(*initial_trace.args)
   ...: 

In [9]: res
Out[9]: (<TensorProxy(name="t0", dtype=thunder.dtypes.float32, shape=(3, 3))>,)

Here the initial trace is a valid Thunder program able to accept and return TensorProxies if there's an active trace (detached_trace in this case).

Collaborator (Author)

I hadn't thought of that. I'll try it out.

@@ -930,6 +931,42 @@ def thunder_function(*args, **kwargs):
    return interpreter_needs_wrap(checkpoint)(wrapped_thunder_function, *args, **kwargs)


@register_general_jit_lookaside(torch.Tensor._make_wrapper_subclass)
Collaborator

Is there a reason for this to be defined as a jit lookaside? It looks simple enough to live in thunder/torch/__init__.py with the @torchsymbol decorator.
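A sketch of that alternative (assumed decorator usage and defaults; torchsymbol and prims are in scope inside thunder/torch/__init__.py):

@torchsymbol(torch.Tensor._make_wrapper_subclass, id="torch.Tensor._make_wrapper_subclass")
def _make_wrapper_subclass(cls, size, *, dtype=None, device=None, requires_grad=False):
    # Delegates to the new prim, mirroring the argument order used elsewhere
    # in this review: (cls, name, shape, device, dtype, requires_grad, tensors, non_tensors).
    return prims.tensor_subclass_ctor(cls, None, size, device, dtype, requires_grad, [], [])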

Collaborator (Author)

Hmm, maybe not. I originally tried to make it more graceful in some ways, but currently it's quite simple, as you can see.

    output=self,
)
cls_module = inspect.getmodule(self._subclass_type)
bsym.sym = replace(bsym.sym, _module=cls_module)
Collaborator

Can you please remind me how Symbol._module is used? For hashing, printing, and what else?

Can this be part of bind_postprocess? _subclass_type is part of the arguments, so the module is available after binding.
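A sketch of that suggestion, folding the module fix-up into the PR's existing _bind_postprocess_of_tensor_subclass_ctor (the position of the subclass type among the bound arguments is assumed):

import inspect
from dataclasses import replace

def _bind_postprocess_of_tensor_subclass_ctor(bsym) -> None:
    # Assumes the subclass type is the first bound argument, matching the
    # (cls, name, shape, device, dtype, requires_grad, tensors, non_tensors) order.
    subclass_type = bsym.args[0]
    bsym.sym = replace(bsym.sym, _module=inspect.getmodule(subclass_type))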

Collaborator (Author)

Embarrassingly, I don't remember. One thing that might be related is:

def import_ctx(self):
    # NOTE This initializes the context (Accessing these properties is a function call with the desired side effect)
    self._out_printables, self._arg_printables, self._kwarg_printables  # type: ignore

    if self.sym is not None and self.sym.python_impl is not None:
        # NOTE BoundSymbols of Symbols with a python_impl defined are run in Python, and are assumed
        # to not need any imports to run properly, unless _import_prims is True
        if self.sym._print_as_impl:
            assert self.sym.module is not None  # TODO: Is this a valid assumption?
            module_name = self.sym.module.__name__
            import_ctx = {module_name: self.sym.module}
        else:
            import_ctx = {}
    elif self._call_ctx is not None:
        # NOTE If the call ctx was specified directly, then no import is needed to call the function
        import_ctx = {}
    else:
        from thunder.extend import TemporaryExecutor

        # BoundSymbols of Symbols without Python implementations (either because they
        # have Python implementations or defined call ctxs) are assumed to need
        # a module import to run properly
        if isinstance(self.sym.executor, TemporaryExecutor):
            import_ctx = {}
        else:
            assert self.sym.module is not None  # TODO: Is this a valid assumption?
            module_name = self.sym.module.__name__
            import_ctx = {module_name: self.sym.module}

            # TODO Include the other modules on the path?
            # Also includes the root module of this (potential) submodule
            if "." in module_name:
                root_name = module_name.split(".")[0]
                import_ctx[root_name] = sys.modules[root_name]

    self._import_ctx.update(import_ctx)
    return self._import_ctx
More specifically, this else path:
else:
    assert self.sym.module is not None  # TODO: Is this a valid assumption?
    module_name = self.sym.module.__name__
    import_ctx = {module_name: self.sym.module}

    # TODO Include the other modules on the path?
    # Also includes the root module of this (potential) submodule
    if "." in module_name:
        root_name = module_name.split(".")[0]
        import_ctx[root_name] = sys.modules[root_name]
for which Symbol._module cannot be None unless we change the condition.
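As a concrete illustration (hypothetical module placement), this is what that else path builds when Symbol._module points at the subclass's defining module:

import inspect
import sys

class ScaleTensorSubclass:  # stand-in; imagine it defined in my_pkg.subclasses
    pass

cls_module = inspect.getmodule(ScaleTensorSubclass)
module_name = cls_module.__name__  # e.g. "my_pkg.subclasses"
import_ctx = {module_name: cls_module}
# The else path also imports the root package of a dotted module path.
if "." in module_name:
    root_name = module_name.split(".")[0]
    import_ctx[root_name] = sys.modules[root_name]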

    requires_grad=requires_grad,
    tensors=tensors,
    non_tensors=non_tensors,
    history=[t.history for t in tensors],
Collaborator

What breaks without this?

Collaborator (Author)

Now it looks like a leftover from things I tried while working on understanding and implementing tensor wrapper subclass support. Let me see...

Comment on lines +958 to +966
subclass = SubclassTensorProxy(
    None,
    shape=usize,
    device=udevice,
    dtype=udtype,
    requires_grad=urequires_grad,
    history=ProvenanceRecord(PseudoInst.LOOKASIDE, [cls.provenance]),
    subclass_type=ucls,
)
Collaborator

Why does this lookaside construct a proxy that builds a bsym inside its __init__ method instead of calling prims.tensor_subclass_ctor?

Suggested change
-    subclass = SubclassTensorProxy(
-        None,
-        shape=usize,
-        device=udevice,
-        dtype=udtype,
-        requires_grad=urequires_grad,
-        history=ProvenanceRecord(PseudoInst.LOOKASIDE, [cls.provenance]),
-        subclass_type=ucls,
-    )
+    subclass = prims.tensor_subclass_ctor(ucls, None, usize, udevice, udtype, urequires_grad, [], [])

Tests from this PR pass when I replaced SubclassTensorProxy with prims.tensor_subclass_ctor and removed the bsym construction in the __init__ method:

diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py
index e8841954..e06b0ebd 100644
--- a/thunder/core/jit_ext.py
+++ b/thunder/core/jit_ext.py
@@ -955,15 +955,7 @@ def _make_wrapper_subclass(
     udevice = unwrap(device)
     urequires_grad = unwrap(requires_grad)
 
-    subclass = SubclassTensorProxy(
-        None,
-        shape=usize,
-        device=udevice,
-        dtype=udtype,
-        requires_grad=urequires_grad,
-        history=ProvenanceRecord(PseudoInst.LOOKASIDE, [cls.provenance]),
-        subclass_type=ucls,
-    )
+    subclass = prims.tensor_subclass_ctor(ucls, None, usize, udevice, udtype, urequires_grad, [], [])
     return wrap(subclass, provenance=ProvenanceRecord(PseudoInst.LOOKASIDE, [cls.provenance]))
 
 
diff --git a/thunder/core/proxies.py b/thunder/core/proxies.py
index c6c88656..424e60d6 100644
--- a/thunder/core/proxies.py
+++ b/thunder/core/proxies.py
@@ -1931,25 +1931,26 @@ class SubclassTensorProxy(TensorProxy):
             self._non_tensors = kwarg_non_tensors
             self._subclass_type = subclass_type
         else:
+            pass
             # TODO(crcrpar): Think about materializing `self` so that we can
             # call `__tensor_init__` to know each attribute names.
             from dataclasses import replace
             import inspect
             from thunder.core import prims
 
-            bsym = prims.tensor_subclass_ctor.bind(
-                self._subclass_type,
-                self.name,
-                self.shape,
-                self.device,
-                self.dtype,
-                self.requires_grad,
-                self._tensors,
-                self._non_tensors,
-                output=self,
-            )
-            cls_module = inspect.getmodule(self._subclass_type)
-            bsym.sym = replace(bsym.sym, _module=cls_module)
+            # bsym = prims.tensor_subclass_ctor.bind(
+            #     self._subclass_type,
+            #     self.name,
+            #     self.shape,
+            #     self.device,
+            #     self.dtype,
+            #     self.requires_grad,
+            #     self._tensors,
+            #     self._non_tensors,
+            #     output=self,
+            # )
+            # cls_module = inspect.getmodule(self._subclass_type)
+            # bsym.sym = replace(bsym.sym, _module=cls_module)
 
             # NOTE(crcrpar): A callable being `thunder.jit`ed can call `MySubclassTensor(...)`
             # inside of it either directly or indirectly: indirect way is to call it through
@@ -1959,11 +1960,11 @@ class SubclassTensorProxy(TensorProxy):
             # but not, otherwise. As [the lookasdie of `torch.autograd.Function`](
             # https://github.com/Lightning-AI/lightning-thunder/blob/3d42c10/thunder/core/jit_ext.py#L655)
             # puts the temporary scope to the current trace.
-            current_trace = get_tracectx()
-            if id(current_trace.bound_symbols) == id(cur_tail_scope := current_trace.scopes[-1]):
-                current_trace.add_bound_symbol(bsym)
-            else:
-                cur_tail_scope.append(bsym)
+            # current_trace = get_tracectx()
+            # if id(current_trace.bound_symbols) == id(cur_tail_scope := current_trace.scopes[-1]):
+            #     current_trace.add_bound_symbol(bsym)
+            # else:
+            #     cur_tail_scope.append(bsym)
 
             if not self._tensors and not self._non_tensors:
                 for a in tensors:

import inspect
from thunder.core import prims

bsym = prims.tensor_subclass_ctor.bind(
Collaborator

Tests pass when this piece is removed and prims.tensor_subclass_ctor is used in jit_ext.py. Is there a test that would require doing the bsym construction here?
