-
Notifications
You must be signed in to change notification settings - Fork 86
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[tensor wrapper subclass] Add Proxy, prim
, and lookaside
#1583
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -63,6 +63,7 @@ | |||||||||||||||||||||
StringProxy, | ||||||||||||||||||||||
TensorProxy, | ||||||||||||||||||||||
FutureTensorProxy, | ||||||||||||||||||||||
SubclassTensorProxy, | ||||||||||||||||||||||
make_proxy_name, | ||||||||||||||||||||||
Variable, | ||||||||||||||||||||||
variableify, | ||||||||||||||||||||||
|
@@ -930,6 +931,42 @@ def thunder_function(*args, **kwargs): | |||||||||||||||||||||
return interpreter_needs_wrap(checkpoint)(wrapped_thunder_function, *args, **kwargs) | ||||||||||||||||||||||
|
||||||||||||||||||||||
|
||||||||||||||||||||||
@register_general_jit_lookaside(torch.Tensor._make_wrapper_subclass) | ||||||||||||||||||||||
def _make_wrapper_subclass( | ||||||||||||||||||||||
cls: torch._C._TensorMeta, | ||||||||||||||||||||||
size: Sequence[int], | ||||||||||||||||||||||
strides: Sequence[int] | None = None, | ||||||||||||||||||||||
storage_offset: int | None = None, | ||||||||||||||||||||||
memory_format: torch.memory_format | None = None, | ||||||||||||||||||||||
dtype: torch.dtype | None = None, | ||||||||||||||||||||||
layout: torch.layout | None = torch.strided, | ||||||||||||||||||||||
device: torch.device | None = None, | ||||||||||||||||||||||
pin_memory: bool = False, | ||||||||||||||||||||||
requires_grad: bool = False, | ||||||||||||||||||||||
dispatch_sizes_strides_policy: str | None = None, | ||||||||||||||||||||||
dispatch_device: bool = False, | ||||||||||||||||||||||
dispatch_layout: bool = False, | ||||||||||||||||||||||
_extra_dispatch_keys: torch.DispatchKeySet | None = None, | ||||||||||||||||||||||
storage_size: int | None = None, | ||||||||||||||||||||||
): | ||||||||||||||||||||||
ucls = unwrap(cls) | ||||||||||||||||||||||
usize = unwrap(size) | ||||||||||||||||||||||
udtype = unwrap(dtype) | ||||||||||||||||||||||
udevice = unwrap(device) | ||||||||||||||||||||||
urequires_grad = unwrap(requires_grad) | ||||||||||||||||||||||
|
||||||||||||||||||||||
subclass = SubclassTensorProxy( | ||||||||||||||||||||||
None, | ||||||||||||||||||||||
shape=usize, | ||||||||||||||||||||||
device=udevice, | ||||||||||||||||||||||
dtype=udtype, | ||||||||||||||||||||||
requires_grad=urequires_grad, | ||||||||||||||||||||||
history=ProvenanceRecord(PseudoInst.LOOKASIDE, [cls.provenance]), | ||||||||||||||||||||||
subclass_type=ucls, | ||||||||||||||||||||||
) | ||||||||||||||||||||||
Comment on lines
+958
to
+966
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why does this lookaside construct a proxy which has a bsym construction inside its init method instead of calling
Suggested change
Tests from this PR pass when I replaced diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py
index e8841954..e06b0ebd 100644
--- a/thunder/core/jit_ext.py
+++ b/thunder/core/jit_ext.py
@@ -955,15 +955,7 @@ def _make_wrapper_subclass(
udevice = unwrap(device)
urequires_grad = unwrap(requires_grad)
- subclass = SubclassTensorProxy(
- None,
- shape=usize,
- device=udevice,
- dtype=udtype,
- requires_grad=urequires_grad,
- history=ProvenanceRecord(PseudoInst.LOOKASIDE, [cls.provenance]),
- subclass_type=ucls,
- )
+ subclass = prims.tensor_subclass_ctor(ucls, None, usize, udevice, udtype, urequires_grad, [], [])
return wrap(subclass, provenance=ProvenanceRecord(PseudoInst.LOOKASIDE, [cls.provenance]))
diff --git a/thunder/core/proxies.py b/thunder/core/proxies.py
index c6c88656..424e60d6 100644
--- a/thunder/core/proxies.py
+++ b/thunder/core/proxies.py
@@ -1931,25 +1931,26 @@ class SubclassTensorProxy(TensorProxy):
self._non_tensors = kwarg_non_tensors
self._subclass_type = subclass_type
else:
+ pass
# TODO(crcrpar): Think about materializing `self` so that we can
# call `__tensor_init__` to know each attribute names.
from dataclasses import replace
import inspect
from thunder.core import prims
- bsym = prims.tensor_subclass_ctor.bind(
- self._subclass_type,
- self.name,
- self.shape,
- self.device,
- self.dtype,
- self.requires_grad,
- self._tensors,
- self._non_tensors,
- output=self,
- )
- cls_module = inspect.getmodule(self._subclass_type)
- bsym.sym = replace(bsym.sym, _module=cls_module)
+ # bsym = prims.tensor_subclass_ctor.bind(
+ # self._subclass_type,
+ # self.name,
+ # self.shape,
+ # self.device,
+ # self.dtype,
+ # self.requires_grad,
+ # self._tensors,
+ # self._non_tensors,
+ # output=self,
+ # )
+ # cls_module = inspect.getmodule(self._subclass_type)
+ # bsym.sym = replace(bsym.sym, _module=cls_module)
# NOTE(crcrpar): A callable being `thunder.jit`ed can call `MySubclassTensor(...)`
# inside of it either directly or indirectly: indirect way is to call it through
@@ -1959,11 +1960,11 @@ class SubclassTensorProxy(TensorProxy):
# but not, otherwise. As [the lookasdie of `torch.autograd.Function`](
# https://github.com/Lightning-AI/lightning-thunder/blob/3d42c10/thunder/core/jit_ext.py#L655)
# puts the temporary scope to the current trace.
- current_trace = get_tracectx()
- if id(current_trace.bound_symbols) == id(cur_tail_scope := current_trace.scopes[-1]):
- current_trace.add_bound_symbol(bsym)
- else:
- cur_tail_scope.append(bsym)
+ # current_trace = get_tracectx()
+ # if id(current_trace.bound_symbols) == id(cur_tail_scope := current_trace.scopes[-1]):
+ # current_trace.add_bound_symbol(bsym)
+ # else:
+ # cur_tail_scope.append(bsym)
if not self._tensors and not self._non_tensors:
for a in tensors: |
||||||||||||||||||||||
return wrap(subclass, provenance=ProvenanceRecord(PseudoInst.LOOKASIDE, [cls.provenance])) | ||||||||||||||||||||||
|
||||||||||||||||||||||
|
||||||||||||||||||||||
# Adds proxy methods | ||||||||||||||||||||||
# NOTE These methods map to themselves, which prevents the interpreter from looking into them | ||||||||||||||||||||||
# This is OK because these methods are written in a tracing-safe manner, and trying to | ||||||||||||||||||||||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -1,18 +1,23 @@ | ||||||||||||||||||||||
from __future__ import annotations | ||||||||||||||||||||||
from enum import auto, Enum | ||||||||||||||||||||||
from numbers import Number | ||||||||||||||||||||||
from functools import reduce, wraps | ||||||||||||||||||||||
import operator | ||||||||||||||||||||||
import builtins | ||||||||||||||||||||||
import math | ||||||||||||||||||||||
from types import NoneType | ||||||||||||||||||||||
from typing import Union, Type, Any, List, Dict, Tuple, Optional | ||||||||||||||||||||||
from typing import Union, Type, Any, List, Dict, Tuple, Optional, TYPE_CHECKING | ||||||||||||||||||||||
from collections.abc import Callable | ||||||||||||||||||||||
from collections.abc import Callable, Hashable, Sequence | ||||||||||||||||||||||
|
||||||||||||||||||||||
import torch | ||||||||||||||||||||||
|
||||||||||||||||||||||
from thunder.core.langctxs import LanguageContext, register_langctx, Languages, langctx | ||||||||||||||||||||||
|
||||||||||||||||||||||
if TYPE_CHECKING: | ||||||||||||||||||||||
from collections.abc import Iterable | ||||||||||||||||||||||
from thunder.core.codeutils import ContextObject | ||||||||||||||||||||||
|
||||||||||||||||||||||
# | ||||||||||||||||||||||
# Creates and registers the torch language context | ||||||||||||||||||||||
# | ||||||||||||||||||||||
|
@@ -77,6 +82,7 @@ def register_method(method_name: str, method: Callable, /) -> None: | |||||||||||||||||||||
TupleProxy, | ||||||||||||||||||||||
AnyProxy, | ||||||||||||||||||||||
IntegerProxy, | ||||||||||||||||||||||
SubclassTensorProxy, | ||||||||||||||||||||||
) | ||||||||||||||||||||||
import thunder.core.codeutils as codeutils | ||||||||||||||||||||||
from thunder.core.codeutils import Printable | ||||||||||||||||||||||
|
@@ -272,6 +278,8 @@ class PrimIDs(Enum): | |||||||||||||||||||||
COPY_ = auto() | ||||||||||||||||||||||
# | ||||||||||||||||||||||
SINK = auto() | ||||||||||||||||||||||
# Tensor Subclasses methods | ||||||||||||||||||||||
TENSOR_SUBCLASS_CTOR = auto() | ||||||||||||||||||||||
|
||||||||||||||||||||||
|
||||||||||||||||||||||
class OpTags(Enum): | ||||||||||||||||||||||
|
@@ -3543,7 +3551,10 @@ def transpose_meta(a: TensorProxy, /, permutation: tuple[int, ...]) -> TensorPro | |||||||||||||||||||||
view = make_prim(PrimIDs.VIEW, "view", meta=reshape_meta, tags=(OpTags.SHAPE_OP,)) | ||||||||||||||||||||||
|
||||||||||||||||||||||
|
||||||||||||||||||||||
def shallow_copy_meta(a: TensorProxy, /) -> TensorProxy: | ||||||||||||||||||||||
def shallow_copy_meta(a: TensorProxy | SubclassTensorProxy, /) -> TensorProxy: | ||||||||||||||||||||||
if isinstance(a, SubclassTensorProxy): | ||||||||||||||||||||||
# SubclassTensorProxy(like=...) would not copy some attrs such as `_tensors` while replace does. | ||||||||||||||||||||||
return a.replace() | ||||||||||||||||||||||
return TensorProxy(like=a) | ||||||||||||||||||||||
|
||||||||||||||||||||||
|
||||||||||||||||||||||
|
@@ -4048,3 +4059,139 @@ def sink_meta(*args, **kwargs): | |||||||||||||||||||||
|
||||||||||||||||||||||
# TODO do we want another tag to remove this after prologue is constructed? | ||||||||||||||||||||||
sink = make_prim(PrimIDs.SINK, "sink", meta=sink_meta, tags=(OpTags.DONT_DCE,)) | ||||||||||||||||||||||
|
||||||||||||||||||||||
|
||||||||||||||||||||||
def tensor_subclass_ctor_meta( | ||||||||||||||||||||||
cls, name, shape, device, dtype, requires_grad, tensors, non_tensors | ||||||||||||||||||||||
) -> SubclassTensorProxy: | ||||||||||||||||||||||
s = SubclassTensorProxy( | ||||||||||||||||||||||
name, | ||||||||||||||||||||||
subclass_type=cls, | ||||||||||||||||||||||
shape=shape, | ||||||||||||||||||||||
device=device, | ||||||||||||||||||||||
dtype=dtype, | ||||||||||||||||||||||
requires_grad=requires_grad, | ||||||||||||||||||||||
tensors=tensors, | ||||||||||||||||||||||
non_tensors=non_tensors, | ||||||||||||||||||||||
history=[t.history for t in tensors], | ||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What breaks without this? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. now it looks like a fallout of things I tried while working on understanding and implementing things for tensor wrapper subclasses. let me see... |
||||||||||||||||||||||
) | ||||||||||||||||||||||
return s | ||||||||||||||||||||||
|
||||||||||||||||||||||
|
||||||||||||||||||||||
def get_nested_types(collection): | ||||||||||||||||||||||
collection = utils.sequencify(collection) | ||||||||||||||||||||||
types_set = {type(t) for t in collection} | ||||||||||||||||||||||
|
||||||||||||||||||||||
def check_types(coll): | ||||||||||||||||||||||
for item in coll: | ||||||||||||||||||||||
types_set.add(type(item)) | ||||||||||||||||||||||
# Check if the item is a nested collection | ||||||||||||||||||||||
if baseutils.is_collection(item): | ||||||||||||||||||||||
# If it's a dictionary, check its values | ||||||||||||||||||||||
if isinstance(item, dict): | ||||||||||||||||||||||
check_types(item.values()) | ||||||||||||||||||||||
# Recursively check nested collections | ||||||||||||||||||||||
else: | ||||||||||||||||||||||
check_types(item) | ||||||||||||||||||||||
|
||||||||||||||||||||||
check_types(collection) | ||||||||||||||||||||||
return tuple(types_set) | ||||||||||||||||||||||
|
||||||||||||||||||||||
|
||||||||||||||||||||||
def filter_types(types: tuple[Any, ...]) -> tuple[Any, ...]: | ||||||||||||||||||||||
return tuple( | ||||||||||||||||||||||
filter( | ||||||||||||||||||||||
lambda t: ( | ||||||||||||||||||||||
t.__module__ != "builtins" | ||||||||||||||||||||||
and t != Number | ||||||||||||||||||||||
# note(crcrpar): maybe `thunder.core`? | ||||||||||||||||||||||
and not t.__module__.startswith("thunder.") | ||||||||||||||||||||||
and not t.__module__.startswith("torch.") | ||||||||||||||||||||||
), | ||||||||||||||||||||||
types, | ||||||||||||||||||||||
) | ||||||||||||||||||||||
) | ||||||||||||||||||||||
|
||||||||||||||||||||||
|
||||||||||||||||||||||
def printer_of_tensor_subclass_ctor( | ||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Currently, the execution and initial traces look the same even though they contain There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. how would we want to distinguish them? With comments (a.k.a bsym.header)? or signature? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could the TorchEx version be using the subclass constructor directly or somesuch? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The impl in torchex calls the dunder init of tensor subclass as in lightning-thunder/thunder/executors/torchex.py Lines 2225 to 2234 in 17b5570
|
||||||||||||||||||||||
bsym: BoundSymbol, | ||||||||||||||||||||||
out_printables: Any, | ||||||||||||||||||||||
arg_printables: Sequence[Printable], | ||||||||||||||||||||||
kwarg_printables: dict[str, Printable], | ||||||||||||||||||||||
) -> str | Iterable[str]: | ||||||||||||||||||||||
from itertools import chain | ||||||||||||||||||||||
|
||||||||||||||||||||||
baseutils.check(not kwarg_printables, lambda: f"No kwargs are supported but {kwarg_printables = }") | ||||||||||||||||||||||
|
||||||||||||||||||||||
# NOTE(crcrpar): It's not a context but at the moment Tensor subclass is treated as `ContextObject`. | ||||||||||||||||||||||
wrapped_cls: ContextObject | torch._C._TensorMeta = arg_printables[0] | ||||||||||||||||||||||
if isinstance(wrapped_cls, torch._C._TensorMeta): | ||||||||||||||||||||||
cls = wrapped_cls | ||||||||||||||||||||||
else: | ||||||||||||||||||||||
cls: torch._C._TensorMeta = wrapped_cls.obj | ||||||||||||||||||||||
tensors, non_tensors = arg_printables[-2:] | ||||||||||||||||||||||
new_non_tensors = [] | ||||||||||||||||||||||
for a in non_tensors: | ||||||||||||||||||||||
if isinstance(a, dtypes.dtype): | ||||||||||||||||||||||
new_non_tensors.append(dtypes.to_torch_dtype(a)) | ||||||||||||||||||||||
elif isinstance(a, devices.Device): | ||||||||||||||||||||||
new_non_tensors.append(devices.to_torch_device(a)) | ||||||||||||||||||||||
else: | ||||||||||||||||||||||
new_non_tensors.append(a) | ||||||||||||||||||||||
|
||||||||||||||||||||||
arg_str = ", ".join(codeutils.prettyprint(x) for x in [*tensors, *new_non_tensors]) | ||||||||||||||||||||||
kwarg_str = "" | ||||||||||||||||||||||
|
||||||||||||||||||||||
result_str: str | ||||||||||||||||||||||
if bsym.output is None or (baseutils.is_collection(bsym.output) and len(bsym.output) == 0): | ||||||||||||||||||||||
result_str = "" | ||||||||||||||||||||||
else: | ||||||||||||||||||||||
result_str = f"{codeutils.prettyprint(out_printables, literals_as_underscores=True)} = " | ||||||||||||||||||||||
|
||||||||||||||||||||||
# Creates a comment describing the output | ||||||||||||||||||||||
comment_str: str | ||||||||||||||||||||||
if isinstance(bsym.output, Proxy): | ||||||||||||||||||||||
comment_str = f" # {codeutils.prettyprint(out_printables, with_type=True)}" | ||||||||||||||||||||||
else: | ||||||||||||||||||||||
comment_str = "" | ||||||||||||||||||||||
|
||||||||||||||||||||||
cls_with_module = f"{cls.__name__}" | ||||||||||||||||||||||
s = f"{result_str}{cls_with_module}({arg_str}{', ' if (len(arg_str) > 0 and len(kwarg_str) > 0) else ''}{kwarg_str}){comment_str}" | ||||||||||||||||||||||
|
||||||||||||||||||||||
if bsym.header: | ||||||||||||||||||||||
header_lines = ( | ||||||||||||||||||||||
bsym.header | ||||||||||||||||||||||
if isinstance(bsym.header, Sequence) and not isinstance(bsym.header, str) | ||||||||||||||||||||||
else bsym.header.splitlines() | ||||||||||||||||||||||
) | ||||||||||||||||||||||
header_lines = (f"# {line}" for line in header_lines) | ||||||||||||||||||||||
return chain(header_lines, [s]) | ||||||||||||||||||||||
|
||||||||||||||||||||||
filtered_types = (cls,) | ||||||||||||||||||||||
if non_tensors: | ||||||||||||||||||||||
types = get_nested_types([t.obj if isinstance(t, codeutils.ContextObject) else t for t in non_tensors]) | ||||||||||||||||||||||
filtered_types += filter_types(types) | ||||||||||||||||||||||
new_imports = {t.__name__: t for t in filtered_types} | ||||||||||||||||||||||
bsym._import_ctx.update(new_imports) | ||||||||||||||||||||||
return s | ||||||||||||||||||||||
|
||||||||||||||||||||||
|
||||||||||||||||||||||
def bind_postprocess_of_tensor_subclass_ctor(bsym: BoundSymbol) -> None: | ||||||||||||||||||||||
cls = bsym.args[0] | ||||||||||||||||||||||
non_tensors = bsym.args[-1] | ||||||||||||||||||||||
|
||||||||||||||||||||||
filtered_types: tuple[Any, ...] = (cls,) | ||||||||||||||||||||||
if non_tensors: | ||||||||||||||||||||||
types = get_nested_types(non_tensors) | ||||||||||||||||||||||
filtered_types += filter_types(types) | ||||||||||||||||||||||
new_imports = {t.__name__: t for t in filtered_types} | ||||||||||||||||||||||
bsym._import_ctx.update(new_imports) | ||||||||||||||||||||||
|
||||||||||||||||||||||
|
||||||||||||||||||||||
tensor_subclass_ctor = make_prim( | ||||||||||||||||||||||
PrimIDs.TENSOR_SUBCLASS_CTOR, | ||||||||||||||||||||||
"tensor_subclass_ctor", | ||||||||||||||||||||||
meta=tensor_subclass_ctor_meta, | ||||||||||||||||||||||
python_printer=printer_of_tensor_subclass_ctor, | ||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is the initial trace with this printer able to work with TensorProxy inputs? Here's an example for what I mean: In [1]: import torch
In [2]: import thunder
In [3]: a = torch.randn(3, 3)
In [4]: def f(x): return x.sin()
In [5]: jf = thunder.jit(f)
In [6]: jf(a);
In [7]: initial_trace = jf._lc_cs.last_traces[0]
In [8]: with thunder.core.trace.detached_trace():
...: res = initial_trace.python_callable()(*initial_trace.args)
...:
In [9]: res
Out[9]: (<TensorProxy(name="t0", dtype=thunder.dtypes.float32, shape=(3, 3))>,) Here the initial trace is a valid Thunder program able to accept and return TensorProxies if there's an active trace ( There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've never thought of it. Will try it out |
||||||||||||||||||||||
_bind_postprocess=bind_postprocess_of_tensor_subclass_ctor, | ||||||||||||||||||||||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a reason for this to be defined as jit_lookaside? It looks simple enough to be just in
thunder/torch/__init__.py
with the@torchsymbol
decorator.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hmm, maybe no.
I originally tried to make it more graceful in a way but currently it's quite simple as you see.