[tensor wrapper subclass] Add Proxy, prim, and lookaside #1583
base: main
Conversation
To support programs that only call the constructor of tensor wrapper subclasses.

Signed-off-by: Masaki Kozuki <[email protected]>
```python
meta=prims.tensor_subclass_ctor,
fn=_tensor_subclass_ctor,
bind_postprocess=_bind_postprocess_of_tensor_subclass_ctor,
python_printer=prims.printer_of_tensor_subclass_ctor,
```
This is very clever. The `bind_postprocess` together with the `python_printer` achieves the effect that the generated execution Python program looks like what people would write in a script:

```python
t0 = ScaleTensorSubclass(x, scale)  # t0: "cpu f32[2, 2]"
```

even though `bsym.args` and `bsym.kwargs` are not just `x, scale`.
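For intuition, here is a self-contained toy sketch of that printing trick. This is not Thunder's actual API; all names are illustrative. The point is that the bound symbol carries extra bookkeeping in its arguments, while the printer emits only the user-visible constructor call.

```python
from dataclasses import dataclass, field


@dataclass
class ToyBoundSymbol:
    # Hypothetical stand-in for a bound symbol; names are illustrative only.
    subclass_name: str  # e.g. "ScaleTensorSubclass"
    out_name: str       # e.g. "t0"
    user_args: list     # what the user passed, e.g. ["x", "scale"]
    bookkeeping: dict = field(default_factory=dict)  # shape/dtype/device metadata


def toy_printer(bsym: ToyBoundSymbol) -> str:
    # Drop the bookkeeping arguments and print only the ctor call a user would write.
    args = ", ".join(bsym.user_args)
    note = bsym.bookkeeping.get("annotation")
    comment = f'  # {bsym.out_name}: "{note}"' if note else ""
    return f"{bsym.out_name} = {bsym.subclass_name}({args}){comment}"


bsym = ToyBoundSymbol("ScaleTensorSubclass", "t0", ["x", "scale"], {"annotation": "cpu f32[2, 2]"})
print(toy_printer(bsym))  # t0 = ScaleTensorSubclass(x, scale)  # t0: "cpu f32[2, 2]"
```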
```python
)


def printer_of_tensor_subclass_ctor(
```
Currently, the execution and initial traces look the same even though they contain `prims.tensor_subclass_ctor` and `torchex.tensor_subclass_ctor` as `bsym.sym`. Could you modify the printer to distinguish between the two?
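A runnable toy of one way to draw that distinction (again not Thunder's printer API, just the idea): qualify the printed name by the symbol's owner, so the prim and the executor symbol render differently.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToySym:
    # Hypothetical stand-in for bsym.sym; `executor` is None for a prim.
    name: str
    executor: Optional[str] = None


def qualified_name(sym: ToySym) -> str:
    owner = sym.executor if sym.executor is not None else "prims"
    return f"{owner}.{sym.name}"


print(qualified_name(ToySym("tensor_subclass_ctor")))             # prims.tensor_subclass_ctor
print(qualified_name(ToySym("tensor_subclass_ctor", "torchex")))  # torchex.tensor_subclass_ctor
```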
```python
PrimIDs.TENSOR_SUBCLASS_CTOR,
"tensor_subclass_ctor",
meta=tensor_subclass_ctor_meta,
python_printer=printer_of_tensor_subclass_ctor,
```
Is the initial trace with this printer able to work with TensorProxy inputs? Here's an example for what I mean:
```
In [1]: import torch

In [2]: import thunder

In [3]: a = torch.randn(3, 3)

In [4]: def f(x): return x.sin()

In [5]: jf = thunder.jit(f)

In [6]: jf(a);

In [7]: initial_trace = jf._lc_cs.last_traces[0]

In [8]: with thunder.core.trace.detached_trace():
   ...:     res = initial_trace.python_callable()(*initial_trace.args)
   ...:

In [9]: res
Out[9]: (<TensorProxy(name="t0", dtype=thunder.dtypes.float32, shape=(3, 3))>,)
```
Here the initial trace is a valid Thunder program, able to accept and return TensorProxies if there's an active trace (`detached_trace` in this case).
I hadn't thought of that. I will try it out.
```diff
@@ -930,6 +931,42 @@ def thunder_function(*args, **kwargs):
     return interpreter_needs_wrap(checkpoint)(wrapped_thunder_function, *args, **kwargs)
 
 
+@register_general_jit_lookaside(torch.Tensor._make_wrapper_subclass)
```
Is there a reason for this to be defined as a jit lookaside? It looks simple enough to live in thunder/torch/__init__.py with the `@torchsymbol` decorator.
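For concreteness, a sketch of what the suggested registration might look like in thunder/torch/__init__.py. This is an untested assumption about the wiring (imports as in that module), not code from the PR:

```python
# Hypothetical sketch: register the ctor as a torchsymbol instead of a jit
# lookaside, forwarding to the prim introduced by this PR.
@torchsymbol(torch.Tensor._make_wrapper_subclass)
def _make_wrapper_subclass(cls, size, *, device=None, dtype=None, requires_grad=False):
    return prims.tensor_subclass_ctor(cls, None, size, device, dtype, requires_grad, [], [])
```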
Hmm, maybe not. I originally tried to make it more elaborate, but as you can see it's currently quite simple.
```python
    output=self,
)
cls_module = inspect.getmodule(self._subclass_type)
bsym.sym = replace(bsym.sym, _module=cls_module)
```
Can you please remind me how `Symbol._module` is used? For hashing, printing, and what else?

Can this be part of `bind_postprocess`? `_subclass_type` is part of the arguments, so the module is available after binding.
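For illustration, a minimal sketch of that idea, assuming the `_bind_postprocess_of_tensor_subclass_ctor` hook from this PR and that `_subclass_type` is the first positional argument of the bound symbol:

```python
import inspect
from dataclasses import replace


def _bind_postprocess_of_tensor_subclass_ctor(bsym) -> None:
    # Sketch only: recover the subclass type from the bound arguments and point
    # the symbol's _module at the module defining the subclass, so the printed
    # program imports it correctly. Mirrors the replace() call currently done
    # in SubclassTensorProxy.__init__.
    subclass_type = bsym.args[0]
    bsym.sym = replace(bsym.sym, _module=inspect.getmodule(subclass_type))
```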
Embarrassingly, I don't remember. One thing that might be related is lightning-thunder/thunder/core/symbol.py, lines 606 to 641 in 52ee541:
```python
def import_ctx(self):
    # NOTE This initializes the context (Accessing these properties is a function call with the desired side effect)
    self._out_printables, self._arg_printables, self._kwarg_printables  # type: ignore

    if self.sym is not None and self.sym.python_impl is not None:
        # NOTE BoundSymbols of Symbols with a python_impl defined are run in Python, and are assumed
        #   to not need any imports to run properly, unless _import_prims is True
        if self.sym._print_as_impl:
            assert self.sym.module is not None  # TODO: Is this a valid assumption?
            module_name = self.sym.module.__name__
            import_ctx = {module_name: self.sym.module}
        else:
            import_ctx = {}
    elif self._call_ctx is not None:
        # NOTE If the call ctx was specified directly, then no import is needed to call the function
        import_ctx = {}
    else:
        from thunder.extend import TemporaryExecutor

        # BoundSymbols of Symbols without Python implementations (either because they
        #   have Python implementations or defined call ctxs) are assumed to need
        #   a module import to run properly
        if isinstance(self.sym.executor, TemporaryExecutor):
            import_ctx = {}
        else:
            assert self.sym.module is not None  # TODO: Is this a valid assumption?
            module_name = self.sym.module.__name__
            import_ctx = {module_name: self.sym.module}

            # TODO Include the other modules on the path?
            #   Also includes the root module of this (potential) submodule
            if "." in module_name:
                root_name = module_name.split(".")[0]
                import_ctx[root_name] = sys.modules[root_name]

    self._import_ctx.update(import_ctx)
    return self._import_ctx
```
and in particular lightning-thunder/thunder/core/symbol.py, lines 629 to 638 in 52ee541:

```python
else:
    assert self.sym.module is not None  # TODO: Is this a valid assumption?
    module_name = self.sym.module.__name__
    import_ctx = {module_name: self.sym.module}

    # TODO Include the other modules on the path?
    #   Also includes the root module of this (potential) submodule
    if "." in module_name:
        root_name = module_name.split(".")[0]
        import_ctx[root_name] = sys.modules[root_name]
```
`Symbol._module` cannot be None unless we change the condition.
```python
requires_grad=requires_grad,
tensors=tensors,
non_tensors=non_tensors,
history=[t.history for t in tensors],
```
What breaks without this?
Now it looks like fallout from things I tried while working on understanding and implementing support for tensor wrapper subclasses. Let me see...
```python
subclass = SubclassTensorProxy(
    None,
    shape=usize,
    device=udevice,
    dtype=udtype,
    requires_grad=urequires_grad,
    history=ProvenanceRecord(PseudoInst.LOOKASIDE, [cls.provenance]),
    subclass_type=ucls,
)
```
Why does this lookaside construct a proxy that builds a bound symbol inside its `__init__` method instead of calling `prims.tensor_subclass_ctor`?
Suggested change:

```diff
-subclass = SubclassTensorProxy(
-    None,
-    shape=usize,
-    device=udevice,
-    dtype=udtype,
-    requires_grad=urequires_grad,
-    history=ProvenanceRecord(PseudoInst.LOOKASIDE, [cls.provenance]),
-    subclass_type=ucls,
-)
+subclass = prims.tensor_subclass_ctor(ucls, None, usize, udevice, udtype, urequires_grad, [], [])
```
Tests from this PR pass when I replaced the `SubclassTensorProxy` construction with `prims.tensor_subclass_ctor` and removed the bound symbol construction in the `__init__` method:
```diff
diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py
index e8841954..e06b0ebd 100644
--- a/thunder/core/jit_ext.py
+++ b/thunder/core/jit_ext.py
@@ -955,15 +955,7 @@ def _make_wrapper_subclass(
     udevice = unwrap(device)
     urequires_grad = unwrap(requires_grad)
 
-    subclass = SubclassTensorProxy(
-        None,
-        shape=usize,
-        device=udevice,
-        dtype=udtype,
-        requires_grad=urequires_grad,
-        history=ProvenanceRecord(PseudoInst.LOOKASIDE, [cls.provenance]),
-        subclass_type=ucls,
-    )
+    subclass = prims.tensor_subclass_ctor(ucls, None, usize, udevice, udtype, urequires_grad, [], [])
 
     return wrap(subclass, provenance=ProvenanceRecord(PseudoInst.LOOKASIDE, [cls.provenance]))
diff --git a/thunder/core/proxies.py b/thunder/core/proxies.py
index c6c88656..424e60d6 100644
--- a/thunder/core/proxies.py
+++ b/thunder/core/proxies.py
@@ -1931,25 +1931,26 @@ class SubclassTensorProxy(TensorProxy):
             self._non_tensors = kwarg_non_tensors
             self._subclass_type = subclass_type
         else:
+            pass
             # TODO(crcrpar): Think about materializing `self` so that we can
             # call `__tensor_init__` to know each attribute names.
             from dataclasses import replace
             import inspect
             from thunder.core import prims
 
-            bsym = prims.tensor_subclass_ctor.bind(
-                self._subclass_type,
-                self.name,
-                self.shape,
-                self.device,
-                self.dtype,
-                self.requires_grad,
-                self._tensors,
-                self._non_tensors,
-                output=self,
-            )
-            cls_module = inspect.getmodule(self._subclass_type)
-            bsym.sym = replace(bsym.sym, _module=cls_module)
+            # bsym = prims.tensor_subclass_ctor.bind(
+            #     self._subclass_type,
+            #     self.name,
+            #     self.shape,
+            #     self.device,
+            #     self.dtype,
+            #     self.requires_grad,
+            #     self._tensors,
+            #     self._non_tensors,
+            #     output=self,
+            # )
+            # cls_module = inspect.getmodule(self._subclass_type)
+            # bsym.sym = replace(bsym.sym, _module=cls_module)
 
             # NOTE(crcrpar): A callable being `thunder.jit`ed can call `MySubclassTensor(...)`
             # inside of it either directly or indirectly: indirect way is to call it through
@@ -1959,11 +1960,11 @@ class SubclassTensorProxy(TensorProxy):
             # but not, otherwise. As [the lookasdie of `torch.autograd.Function`](
             # https://github.com/Lightning-AI/lightning-thunder/blob/3d42c10/thunder/core/jit_ext.py#L655)
             # puts the temporary scope to the current trace.
-            current_trace = get_tracectx()
-            if id(current_trace.bound_symbols) == id(cur_tail_scope := current_trace.scopes[-1]):
-                current_trace.add_bound_symbol(bsym)
-            else:
-                cur_tail_scope.append(bsym)
+            # current_trace = get_tracectx()
+            # if id(current_trace.bound_symbols) == id(cur_tail_scope := current_trace.scopes[-1]):
+            #     current_trace.add_bound_symbol(bsym)
+            # else:
+            #     cur_tail_scope.append(bsym)
 
         if not self._tensors and not self._non_tensors:
             for a in tensors:
```
```python
import inspect
from thunder.core import prims

bsym = prims.tensor_subclass_ctor.bind(
```
Tests pass when this piece is removed and `prims.tensor_subclass_ctor` is used in jit_ext.py. Is there a test that would require doing the bound symbol construction here?
What does this PR do?

Add support for calling `MySubclass(...)` inside programs passed to `thunder.jit`.
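For context, a sketch of the kind of program this enables. `ScaleTensorSubclass` here is a minimal illustrative wrapper subclass loosely modeled on the one named in the review discussion above, not the PR's actual test class:

```python
import torch
import thunder


class ScaleTensorSubclass(torch.Tensor):
    # Minimal wrapper subclass built via torch.Tensor._make_wrapper_subclass,
    # the call this PR adds a lookaside for.
    def __new__(cls, data: torch.Tensor, scale: torch.Tensor):
        t = torch.Tensor._make_wrapper_subclass(
            cls,
            data.shape,
            device=data.device,
            dtype=data.dtype,
            requires_grad=data.requires_grad,
        )
        t._data = data
        t.scale = scale
        return t


def f(x: torch.Tensor, scale: torch.Tensor):
    # The subclass ctor call happens inside the jitted program.
    return ScaleTensorSubclass(x, scale)


jf = thunder.jit(f)
out = jf(torch.randn(2, 2), torch.tensor(0.5))
```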