
Commit a133b50

davidberard98 authored and pytorchmergebot committed
[JIT] Partially support ForwardRef type annotations for NamedTuple attributes (pytorch#96933)
**Summary**

NamedTuple attributes can be annotated to declare their type:
```python
class MyNamedTuple(NamedTuple):
    x: int
    y: torch.Tensor
    z: MyOtherType
```
Normally in python you can also declare your types as strings, `x: 'int'`. But NamedTuples previously didn't support this, because their annotation evaluation process was slightly different. This PR updates the NamedTuple attribute type annotation evaluation method to support ForwardRef declarations (i.e. declaring as strings).

**Details**

Below I repeat the comment I left in _jit_internal.py:

NamedTuple types are slightly different from normal types. Normally, annotations are evaluated like this (during jit.script):
1. Load strings of python code into c++ and parse.
2. Get annotations as strings.
3. Use the PythonResolver's resolution callback (rcb) to convert the string into a python object.
4. Call into annotations.py:ann_to_type to convert the python object from step 3 into a type that torchscript understands.

NamedTuples are more complicated, because they have sub-types. Normally, once we have the NamedTuple type object from step 3, we can just look at the annotation literal values and use ann_to_type directly on them. But sometimes, users will annotate with string literals, e.g.
```python
x: 'int'
```
This also happens with PEP 563 (`from __future__ import annotations`). These annotations appear in the annotation dict as `ForwardRef('int')`. Then, we need to convert the string into a python object. This requires having local context for custom objects or imported types; rcb() is what gives us this. So, we plumb rcb through the stack so it can be used in this context for the if block below.

FAQ:
- Why do we need this special handling for NamedTuple when string annotations work fine for normal types? Normally, we parse the string directly and then call rcb() directly from C++.
- Why not use ForwardRef._evaluate? For that, we need globals() and locals() for the local context where the NamedTuple was defined. rcb is what lets us look these up. So, basically, rcb does the hard work for us.
- What is rcb? rcb is a ResolutionCallback - a python callable that takes a string and returns a type. It's generated by `createResolutionCallback.*` in _jit_internal.py.

**Why is this only partial support**: This only plumbs the rcb through some paths. In particular, the `toSugaredValue` path uses a fake rcb.

**Alternatives**: We could also treat this the way we treat non-nn.Module classes: evaluate them separately, ahead of time. That solution is probably better, but it would require a riskier refactor of the way NamedTuples are handled.

Fixes pytorch#95858

Pull Request resolved: pytorch#96933
Approved by: https://github.com/qihqi
1 parent d850c33 · commit a133b50

8 files changed: +160 −18

test/jit/test_list_dict.py (+57 −1)

```diff
@@ -17,7 +17,7 @@
 # Make the helper files in test/ importable
 pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
 sys.path.append(pytorch_test_dir)
-from torch.testing._internal.jit_utils import JitTestCase
+from torch.testing._internal.jit_utils import JitTestCase, make_global
 from torch.testing._internal.common_utils import skipIfTorchDynamo

 if __name__ == '__main__':
@@ -2084,6 +2084,62 @@ def forward(self):
         for name in ['a', 'b', 'c']:
             self.assertEqual(getattr(out_loaded, name), getattr(out, name))

+    def test_namedtuple_inside_forwardref(self):
+        class FeatureVector(NamedTuple):
+            float_features: 'float'
+            sequence_features: 'List[float]'
+            time_since_first: 'float'
+
+        @torch.jit.script
+        def foo(x) -> float:
+            fv = FeatureVector(3.0, [3.0], 3.0)
+            rv = fv.float_features
+            for val in fv.sequence_features:
+                rv += val
+            rv *= fv.time_since_first
+            return rv
+
+        self.assertEqual(foo(torch.rand(3, 4)), 18.0)
+
+    def test_namedtuple_input_forwardref(self):
+        class MyNamedTuple(NamedTuple):
+            a : int
+            b : float
+            c : torch.Tensor
+
+        make_global(MyNamedTuple)
+
+        nt = MyNamedTuple(4, 2.5, torch.rand((2, 2)))
+
+        def fn(obj: MyNamedTuple):
+            return ((obj.c + obj.b) ** obj.a).sin()
+
+        expected = fn(nt)
+        fn_s = torch.jit.script(fn)
+        actual = fn_s(nt)
+        self.assertEqual(expected, actual)
+
+    # see #95858
+    @unittest.expectedFailure
+    def test_namedtuple_resolution_forwardref(self):
+        class TheType(NamedTuple):
+            t: 'int'
+
+        class MyModule(types.ModuleType):
+            def __init__(self):
+                super().__init__('MyModule')
+
+            def __getattr__(self, attr):
+                return TheType
+
+        some_module = MyModule()
+
+        def fn() -> some_module.Type:
+            return some_module.Type(1)
+
+        self.checkScript(fn, [])
+
+
 class TestScriptDict(JitTestCase):
     """
     This class contains a suite of tests for torch.jit.script, a
```

test/jit/test_save_load.py (+18)

```diff
@@ -433,6 +433,24 @@ def forward(self, x: FooTuple) -> torch.Tensor:
         output = m_loaded(FooTuple(a=5))
         self.assertEqual(output, torch.tensor(3))

+    def test_save_namedtuple_input_only_forwardref(self):
+        """
+        Even if a NamedTuple is only used as an input argument, saving and
+        loading should work correctly.
+        """
+        global FooTuple  # see [local resolution in python]
+
+        class FooTuple(NamedTuple):
+            a: 'int'
+
+        class MyModule(torch.nn.Module):
+            def forward(self, x: FooTuple) -> torch.Tensor:
+                return torch.tensor(3)
+
+        m_loaded = self.getExportImportCopy(torch.jit.script(MyModule()))
+        output = m_loaded(FooTuple(a=5))
+        self.assertEqual(output, torch.tensor(3))
+
     def test_save_namedtuple_output_only(self):
         """
         Even if a NamedTuple is only used as an output argument, saving and
```

torch/_jit_internal.py (+49 −2)

```diff
@@ -23,6 +23,7 @@
     Callable,
     Dict,
     Final,
+    ForwardRef,
     Generic,
     List,
     Optional,
@@ -1199,7 +1200,7 @@ def _try_get_dispatched_fn(fn):


 def _get_named_tuple_properties(
-    obj, loc: Optional[torch._C._jit_tree_views.SourceRange] = None
+    obj, loc: Optional[torch._C._jit_tree_views.SourceRange] = None, rcb=None
 ):
     if loc is None:
         loc = fake_range()
@@ -1225,7 +1226,53 @@ def _get_named_tuple_properties(
     annotations = []
     for field in obj._fields:
         if field in obj_annotations:
-            the_type = torch.jit.annotations.ann_to_type(obj_annotations[field], loc)
+            field_type = obj_annotations[field]
+            # [Note: ForwardRef annotations in NamedTuple attributes]
+            # NamedTuple types are slightly different from normal types.
+            #
+            # Normally, annotations are evaluated like this (during jit.script):
+            # 1. Load strings of python code into c++ and parse.
+            # 2. Get annotations as strings.
+            # 3. Use the PythonResolver's resolution callback (rcb) to convert
+            #    the string into a python object.
+            # 4. Call into annotations.py:ann_to_type to convert the python
+            #    object from step 3 into a type that torchscript understands.
+            #
+            # NamedTuples are more complicated, because they have sub-types.
+            # Normally, once we have the NamedTuple type object from step 3,
+            # we can just look at the annotation literal values and use
+            # ann_to_type directly on them.
+            #
+            # But sometimes, users will annotate with string literals, e.g.
+            #   x: 'int'
+            # This also happens with PEP 563 (from __future__ import annotations)
+            #
+            # These annotations appear in the annotation dict as ForwardRef('int').
+            #
+            # Then, we need to convert the string into a python object. This
+            # requires having local context for custom objects or imported types.
+            # rcb() is what gives us this. So, we plumb rcb through the stack so
+            # it can be used in this context for the if block below.
+            #
+            # FAQ:
+            # - Why do we need this special handling for NamedTuple but string
+            #   annotations work fine for normal types? Normally, we parse the
+            #   string directly and then call rcb() directly from C++.
+            # - Why not use ForwardRef._evaluate? For that, we need globals()
+            #   and locals() for the local context where the NamedTuple was
+            #   defined. rcb is what lets us look up into these. So, basically,
+            #   rcb does the hard work for us.
+            if isinstance(field_type, ForwardRef) and rcb is not None:
+                rcb_type = rcb(field_type.__forward_arg__)
+                # rcb returns None if it can't find anything.
+                if rcb_type is None:
+                    raise ValueError(
+                        f"Unknown type annotation: '{field_type}' in NamedTuple {obj.__name__}."
+                        f" Likely due to partial support for ForwardRef parameters in NamedTuples, see #95858."
+                        f" Issue occurred at {loc.highlight()}"
+                    )
+                field_type = rcb_type
+            the_type = torch.jit.annotations.ann_to_type(field_type, loc, rcb)
             annotations.append(the_type)
         else:
             annotations.append(torch._C.TensorType.getInferred())
```

torch/csrc/jit/python/pybind_utils.h (+2)

```diff
@@ -57,6 +57,8 @@
 namespace torch {
 namespace jit {

+using ResolutionCallback = std::function<py::object(std::string)>;
+
 void clear_registered_instances(void* ptr);

 TORCH_PYTHON_API IValue toIValue(
```

torch/csrc/jit/python/python_sugared_value.cpp (+17 −4)

```diff
@@ -1006,13 +1006,19 @@ bool isNamedTupleClass(const py::object& obj) {
   return is_tuple_class == 1 && py::hasattr(obj, "_fields");
 }

-TypePtr registerNamedTuple(const py::object& obj, const SourceRange& loc) {
+TypePtr registerNamedTuple(
+    const py::object& obj,
+    const SourceRange& loc,
+    const ResolutionCallback& rcb) {
   TORCH_INTERNAL_ASSERT(isNamedTupleClass(obj));
   auto qualifiedName = c10::QualifiedName(py::cast<std::string>(
       py::module::import("torch._jit_internal").attr("_qualified_name")(obj)));

-  py::object props = py::module::import("torch._jit_internal")
-                         .attr("_get_named_tuple_properties")(obj, loc);
+  // Note: we need to pass rcb to resolve ForwardRef annotations. See
+  // [Note: ForwardRef annotations in NamedTuple attributes]
+  py::object props =
+      py::module::import("torch._jit_internal")
+          .attr("_get_named_tuple_properties")(obj, loc, py::cpp_function(rcb));

   std::string unqualName;
   std::vector<std::string> field_names;
@@ -1290,7 +1296,14 @@ std::shared_ptr<SugaredValue> toSugaredValue(
   }

   if (isNamedTupleClass(obj)) {
-    auto tuple_type = registerNamedTuple(obj, loc)->expect<TupleType>();
+    // The use of fakeRcb here prevents us from correctly resolving ForwardRef
+    // annotations on NamedTuple attributes for instances whose types are
+    // inferred. See #95858 for more details, as well as
+    // [Note: ForwardRef annotations in NamedTuple attributes]
+    auto fakeRcb =
+        py::module::import("torch.jit.annotations").attr("_fake_rcb");
+    auto tuple_type =
+        registerNamedTuple(obj, loc, fakeRcb)->expect<TupleType>();
     return std::make_shared<NamedTupleConstructor>(tuple_type);
   }
```

torch/csrc/jit/python/python_sugared_value.h (+4 −1)

```diff
@@ -242,7 +242,10 @@ struct VISIBILITY_HIDDEN ModuleValue : public SugaredValue {
 };

 bool isNamedTupleClass(const py::object& obj);
-TypePtr registerNamedTuple(const py::object& obj, const SourceRange& loc);
+TypePtr registerNamedTuple(
+    const py::object& obj,
+    const SourceRange& loc,
+    const ResolutionCallback& rcb);

 void recurseThroughNestedModules(
     const SourceRange& loc,
```

torch/csrc/jit/python/script_init.cpp (+4 −4)

```diff
@@ -76,7 +76,6 @@ namespace torch::jit {
 using ::c10::Argument;
 using ::c10::FunctionSchema;

-using ResolutionCallback = std::function<py::object(std::string)>;
 using FunctionDefaults = std::unordered_map<std::string, py::object>;
 using ClassMethodDefaults = std::unordered_map<std::string, FunctionDefaults>;

@@ -136,7 +135,7 @@ struct PythonResolver : public Resolver {
     }

     if (isNamedTupleClass(obj)) {
-      return registerNamedTuple(obj, loc);
+      return registerNamedTuple(obj, loc, rcb_);
     }

     auto qualifiedName = c10::QualifiedName(
@@ -157,8 +156,9 @@ struct PythonResolver : public Resolver {
       return nullptr;
     }

-    auto annotation_type = py::module::import("torch.jit.annotations")
-                               .attr("try_ann_to_type")(obj, loc);
+    auto annotation_type =
+        py::module::import("torch.jit.annotations")
+            .attr("try_ann_to_type")(obj, loc, py::cpp_function(rcb_));
     if (!annotation_type.is_none()) {
       return py::cast<TypePtr>(annotation_type);
     }
```

torch/jit/annotations.py (+9 −6)

```diff
@@ -315,8 +315,11 @@ def is_tensor(ann):
     return False


+def _fake_rcb(inp):
+    return None

-def try_ann_to_type(ann, loc):
+
+def try_ann_to_type(ann, loc, rcb=None):
     if ann is inspect.Signature.empty:
         return TensorType.getInferred()
     if ann is None:
@@ -410,13 +413,13 @@ def try_ann_to_type(ann, loc):
         return torch.jit._script._recursive_compile_class(ann, loc)

     # Maybe resolve a NamedTuple to a Tuple Type
-    def fake_rcb(key):
-        return None
-    return torch._C._resolve_type_from_object(ann, loc, fake_rcb)
+    if rcb is None:
+        rcb = _fake_rcb
+    return torch._C._resolve_type_from_object(ann, loc, rcb)


-def ann_to_type(ann, loc):
-    the_type = try_ann_to_type(ann, loc)
+def ann_to_type(ann, loc, rcb=None):
+    the_type = try_ann_to_type(ann, loc, rcb)
     if the_type is not None:
         return the_type
     raise ValueError(f"Unknown type annotation: '{ann}' at {loc.highlight()}")
```
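Hoisting `fake_rcb` to a module-level `_fake_rcb` lets the C++ side import it by name (as `toSugaredValue` now does) while `try_ann_to_type` falls back to it when no callback is passed. A standalone sketch of the same optional-callback default (`resolve_type` is a hypothetical name, not the library's API):

```python
def _fake_rcb(inp):
    # Module-level so other code can import it by qualified name;
    # always fails to resolve.
    return None

def resolve_type(name, rcb=None):
    # Fall back to the always-None callback instead of special-casing
    # `rcb is None` at every use site.
    if rcb is None:
        rcb = _fake_rcb
    return rcb(name)

print(resolve_type('int'))                              # unresolved: None
print(resolve_type('int', lambda n: {'int': int}.get(n)))
```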
