1
+ import contextlib
1
2
import functools
2
3
import logging
3
4
import threading
@@ -24,26 +25,32 @@ def nothing():
24
25
pass
25
26
26
27
28
# Default backend context constructor: contextlib.nullcontext is a no-op
# context manager, used when a backend declares no extra setup/teardown.
null_context = contextlib.nullcontext

# Sentinel object marking "no prior eval-frame callback stored"; distinct
# from None, which is a legal callback value (see _TorchDynamoContext).
unset = object()

# Module-wide lock. NOTE(review): the name suggests it serializes backend
# compilation, but its acquisition sites are outside this view — confirm.
compile_lock = threading.Lock()
class _TorchDynamoContext :
33
- def __init__ (self , callback , on_enter = nothing ):
36
+ def __init__ (self , callback , on_enter = nothing , backend_ctx_ctor = null_context ):
34
37
super ().__init__ ()
35
38
assert callable (callback ) or callback is False or callback is None
36
39
self .callback = callback
37
40
self .prior = unset
38
41
self .on_enter = on_enter
42
+ self .extra_ctx_ctor = backend_ctx_ctor
39
43
40
44
def __enter__ (self ):
41
45
self .on_enter ()
42
46
self .prior = set_eval_frame (self .callback )
47
+ self .backend_ctx = self .extra_ctx_ctor ()
48
+ self .backend_ctx .__enter__ ()
43
49
44
50
    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore whatever eval-frame callback was installed before __enter__
        # and drop our reference to it (back to the sentinel).
        set_eval_frame(self.prior)
        self.prior = unset
        # NOTE(review): the backend context exits *after* the eval-frame hook
        # is restored — not strict LIFO relative to __enter__ — and its return
        # value is discarded, so it cannot suppress an in-flight exception.
        # Presumably intentional; confirm against upstream usage.
        self.backend_ctx.__exit__(exc_type, exc_val, exc_tb)
47
54
48
55
def __call__ (self , fn ):
49
56
assert callable (fn )
@@ -69,8 +76,12 @@ def _fn(*args, **kwargs):
69
76
70
77
71
78
class OptimizeContext(_TorchDynamoContext):
    """Context manager/decorator that installs a compile callback.

    Thin specialization of _TorchDynamoContext that always runs
    install_generation_tagging_new on entry and forwards the backend
    context constructor unchanged.
    """

    def __init__(self, callback, backend_ctx_ctor):
        super().__init__(
            on_enter=install_generation_tagging_new,
            callback=callback,
            backend_ctx_ctor=backend_ctx_ctor,
        )
74
85
75
86
76
87
class RunOnlyContext (_TorchDynamoContext ):
@@ -107,8 +118,10 @@ def catch_errors(frame, cache_size):
107
118
return catch_errors
108
119
109
120
110
def _optimize_catch_errors(compile_fn, backend_ctx_ctor=null_context):
    """Wrap ``compile_fn`` in the error-catching frame handler and return
    an OptimizeContext usable as a context manager or decorator."""
    wrapped_callback = catch_errors_wrapper(compile_fn)
    return OptimizeContext(wrapped_callback, backend_ctx_ctor=backend_ctx_ctor)
112
125
113
126
114
127
def optimize(backend, nopython=False):
    """The main entrypoint of TorchDynamo.  Do graph capture and call
    backend() to optimize extracted graphs.

    Args:
        backend: One of the two things:
            - Either, a function/callable taking a torch.fx.GraphModule and
            example_inputs and returning a python callable that runs the
            graph faster.
            One can also provide additional context for the backend, like
            torch.jit.fuser("fuser2"), by setting the backend_ctx_ctor attribute.
            See AOTAutogradMemoryEfficientFusionWithContext for the usage.
            - Or, a string backend name in `torchdynamo.list_backends()`
        nopython: If True, graph breaks will be errors and there will
            be a single whole-program graph.

    Example Usage:

        with torchdynamo.optimize(my_compiler):
            ...
    """
    # A backend may declare extra setup it needs while compiled code runs
    # (e.g. a fuser context) via a `backend_ctx_ctor` attribute; fall back
    # to the no-op context otherwise.  Single getattr with a default
    # replaces the redundant hasattr-then-getattr double lookup.
    backend_ctx_ctor = getattr(backend, "backend_ctx_ctor", null_context)

    if nopython:
        return optimize_assert(backend, backend_ctx_ctor)
    return _optimize_catch_errors(
        convert_frame.convert_frame(backend), backend_ctx_ctor
    )
142
165
143
166
144
def optimize_assert(backend, backend_ctx_ctor=null_context):
    """
    The same as `torchdynamo.optimize(backend, nopython=True)`
    """
    # convert_frame_assert raises on graph breaks instead of falling back.
    frame_callback = convert_frame.convert_frame_assert(backend)
    return _optimize_catch_errors(frame_callback, backend_ctx_ctor)
149
174
150
175
151
176
def run (fn = None ):
0 commit comments