Skip to content

Commit 18a0364

Browse files
authored
Pass backend-related ctx to TorchDynamo Optimize Context (#201)
* Pass backend-related ctx to TorchDynamo Optimize Context
* Reinit the backend ctx for every frame
* Doc
1 parent 962f893 commit 18a0364

File tree

2 files changed

+47
-12
lines changed

2 files changed

+47
-12
lines changed

torchdynamo/eval_frame.py

+36-11
Original file line number · Diff line number · Diff line change
@@ -1,3 +1,4 @@
1+
import contextlib
12
import functools
23
import logging
34
import threading
@@ -24,26 +25,32 @@ def nothing():
2425
pass
2526

2627

28+
null_context = contextlib.nullcontext
29+
2730
unset = object()
2831

2932
compile_lock = threading.Lock()
3033

3134

3235
class _TorchDynamoContext:
33-
def __init__(self, callback, on_enter=nothing):
36+
def __init__(self, callback, on_enter=nothing, backend_ctx_ctor=null_context):
3437
super().__init__()
3538
assert callable(callback) or callback is False or callback is None
3639
self.callback = callback
3740
self.prior = unset
3841
self.on_enter = on_enter
42+
self.extra_ctx_ctor = backend_ctx_ctor
3943

4044
def __enter__(self):
4145
self.on_enter()
4246
self.prior = set_eval_frame(self.callback)
47+
self.backend_ctx = self.extra_ctx_ctor()
48+
self.backend_ctx.__enter__()
4349

4450
def __exit__(self, exc_type, exc_val, exc_tb):
4551
set_eval_frame(self.prior)
4652
self.prior = unset
53+
self.backend_ctx.__exit__(exc_type, exc_val, exc_tb)
4754

4855
def __call__(self, fn):
4956
assert callable(fn)
@@ -69,8 +76,12 @@ def _fn(*args, **kwargs):
6976

7077

7178
class OptimizeContext(_TorchDynamoContext):
72-
def __init__(self, callback):
73-
super().__init__(callback=callback, on_enter=install_generation_tagging_new)
79+
def __init__(self, callback, backend_ctx_ctor):
80+
super().__init__(
81+
callback=callback,
82+
on_enter=install_generation_tagging_new,
83+
backend_ctx_ctor=backend_ctx_ctor,
84+
)
7485

7586

7687
class RunOnlyContext(_TorchDynamoContext):
@@ -107,8 +118,10 @@ def catch_errors(frame, cache_size):
107118
return catch_errors
108119

109120

110-
def _optimize_catch_errors(compile_fn):
111-
return OptimizeContext(catch_errors_wrapper(compile_fn))
121+
def _optimize_catch_errors(compile_fn, backend_ctx_ctor=null_context):
122+
return OptimizeContext(
123+
catch_errors_wrapper(compile_fn), backend_ctx_ctor=backend_ctx_ctor
124+
)
112125

113126

114127
def optimize(backend, nopython=False):
@@ -117,10 +130,13 @@ def optimize(backend, nopython=False):
117130
backend() to optimize extracted graphs.
118131
119132
Args:
120-
backend: One of two things:
121-
- Either, a function taking a torch.fx.GraphModule and
133+
backend: One of the two things:
134+
- Either, a function/callable taking a torch.fx.GraphModule and
122135
example_inputs and returning a python callable that runs the
123136
graph faster.
137+
One can also provide additional context for the backend, like
138+
torch.jit.fuser("fuser2"), by setting the backend_ctx_ctor attribute.
139+
See AOTAutogradMemoryEfficientFusionWithContext for the usage.
124140
- Or, a string backend name in `torchdynamo.list_backends()`
125141
nopython: If True, graph breaks will be errors and there will
126142
be a single whole-program graph.
@@ -136,16 +152,25 @@ def toy_example(a, b):
136152
with torchdynamo.optimize(my_compiler):
137153
...
138154
"""
155+
156+
backend_ctx_ctor = null_context
157+
if hasattr(backend, "backend_ctx_ctor"):
158+
backend_ctx_ctor = getattr(backend, "backend_ctx_ctor")
159+
139160
if nopython:
140-
return optimize_assert(backend)
141-
return _optimize_catch_errors(convert_frame.convert_frame(backend))
161+
return optimize_assert(backend, backend_ctx_ctor)
162+
return _optimize_catch_errors(
163+
convert_frame.convert_frame(backend), backend_ctx_ctor
164+
)
142165

143166

144-
def optimize_assert(backend):
167+
def optimize_assert(backend, backend_ctx_ctor=null_context):
145168
"""
146169
The same as `torchdynamo.optimize(backend, nopython=True)`
147170
"""
148-
return _optimize_catch_errors(convert_frame.convert_frame_assert(backend))
171+
return _optimize_catch_errors(
172+
convert_frame.convert_frame_assert(backend), backend_ctx_ctor
173+
)
149174

150175

151176
def run(fn=None):

torchdynamo/optimizations/training.py

+11-1
Original file line number · Diff line number · Diff line change
@@ -143,4 +143,14 @@ def candidate(self):
143143
return BACKENDS["aot_autograd"](self.gm, self.example_inputs)
144144

145145

146-
aot_autograd_speedup_strategy = AOTAutogradMemoryEfficientFusion.compile_fn
146+
class AOTAutogradMemoryEfficientFusionWithContext:
147+
"""Pass nvfuser context to TorchDynamo"""
148+
149+
def __init__(self):
150+
self.backend_ctx_ctor = lambda: torch.jit.fuser("fuser2")
151+
152+
def __call__(self, gm: torch.fx.GraphModule, example_inputs):
153+
return AOTAutogradMemoryEfficientFusion.compile_fn(gm, example_inputs)
154+
155+
156+
aot_autograd_speedup_strategy = AOTAutogradMemoryEfficientFusionWithContext()

0 commit comments

Comments (0)