diff --git a/torchdynamo/eval_frame.py b/torchdynamo/eval_frame.py
index fca066eb18..809002cb07 100644
--- a/torchdynamo/eval_frame.py
+++ b/torchdynamo/eval_frame.py
@@ -1,3 +1,4 @@
+import contextlib
 import functools
 import logging
 import threading
@@ -24,26 +25,32 @@ def nothing():
     pass
 
 
+null_context = contextlib.nullcontext
+
 unset = object()
 
 compile_lock = threading.Lock()
 
 
 class _TorchDynamoContext:
-    def __init__(self, callback, on_enter=nothing):
+    def __init__(self, callback, on_enter=nothing, backend_ctx_ctor=null_context):
         super().__init__()
         assert callable(callback) or callback is False or callback is None
         self.callback = callback
         self.prior = unset
         self.on_enter = on_enter
+        self.extra_ctx_ctor = backend_ctx_ctor
 
     def __enter__(self):
         self.on_enter()
         self.prior = set_eval_frame(self.callback)
+        self.backend_ctx = self.extra_ctx_ctor()
+        self.backend_ctx.__enter__()
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         set_eval_frame(self.prior)
         self.prior = unset
+        self.backend_ctx.__exit__(exc_type, exc_val, exc_tb)
 
     def __call__(self, fn):
         assert callable(fn)
@@ -69,8 +76,12 @@ def _fn(*args, **kwargs):
 
 
 class OptimizeContext(_TorchDynamoContext):
-    def __init__(self, callback):
-        super().__init__(callback=callback, on_enter=install_generation_tagging_new)
+    def __init__(self, callback, backend_ctx_ctor):
+        super().__init__(
+            callback=callback,
+            on_enter=install_generation_tagging_new,
+            backend_ctx_ctor=backend_ctx_ctor,
+        )
 
 
 class RunOnlyContext(_TorchDynamoContext):
@@ -107,8 +118,10 @@ def catch_errors(frame, cache_size):
     return catch_errors
 
 
-def _optimize_catch_errors(compile_fn):
-    return OptimizeContext(catch_errors_wrapper(compile_fn))
+def _optimize_catch_errors(compile_fn, backend_ctx_ctor=null_context):
+    return OptimizeContext(
+        catch_errors_wrapper(compile_fn), backend_ctx_ctor=backend_ctx_ctor
+    )
 
 
 def optimize(backend, nopython=False):
@@ -117,10 +130,13 @@ def optimize(backend, nopython=False):
         backend() to optimize extracted graphs.
 
     Args:
-        backend: One of two things:
-            - Either, a function taking a torch.fx.GraphModule and
+        backend: One of the two things:
+            - Either, a function/callable taking a torch.fx.GraphModule and
            example_inputs and returning a python callable that runs the
            graph faster.
+            One can also provide additional context for the backend, like
+            torch.jit.fuser("fuser2"), by setting the backend_ctx_ctor attribute.
+            See AOTAutogradMemoryEfficientFusionWithContext for the usage.
             - Or, a string backend name in `torchdynamo.list_backends()`
         nopython: If True, graph breaks will be errors and there will
             be a single whole-program graph.
@@ -136,16 +152,25 @@ def toy_example(a, b):
         with torchdynamo.optimize(my_compiler):
             ...
     """
+
+    backend_ctx_ctor = null_context
+    if hasattr(backend, "backend_ctx_ctor"):
+        backend_ctx_ctor = getattr(backend, "backend_ctx_ctor")
+
     if nopython:
-        return optimize_assert(backend)
-    return _optimize_catch_errors(convert_frame.convert_frame(backend))
+        return optimize_assert(backend, backend_ctx_ctor)
+    return _optimize_catch_errors(
+        convert_frame.convert_frame(backend), backend_ctx_ctor
+    )
 
 
-def optimize_assert(backend):
+def optimize_assert(backend, backend_ctx_ctor=null_context):
     """
     The same as `torchdynamo.optimize(backend, nopython=True)`
     """
-    return _optimize_catch_errors(convert_frame.convert_frame_assert(backend))
+    return _optimize_catch_errors(
+        convert_frame.convert_frame_assert(backend), backend_ctx_ctor
+    )
 
 
 def run(fn=None):
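With the hook above, any backend callable can carry a `backend_ctx_ctor` attribute naming a context-manager constructor; `optimize()` picks it up via hasattr()/getattr() and `_TorchDynamoContext` enters and exits the resulting context around the frame-evaluation hook. A minimal sketch of a custom backend opting in (`my_backend` and its trivial body are hypothetical; only the attribute lookup is what the patch implements):

    import torch
    import torchdynamo

    def my_backend(gm: torch.fx.GraphModule, example_inputs):
        # Hypothetical no-op backend: run the captured graph unchanged.
        return gm.forward

    # optimize() inspects this attribute; the context it constructs is
    # entered in __enter__ and exited in __exit__ of the dynamo context.
    my_backend.backend_ctx_ctor = lambda: torch.jit.fuser("fuser2")

    with torchdynamo.optimize(my_backend):
        ...  # compiled regions now execute under the fuser2 (nvFuser) context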
""" + + backend_ctx_ctor = null_context + if hasattr(backend, "backend_ctx_ctor"): + backend_ctx_ctor = getattr(backend, "backend_ctx_ctor") + if nopython: - return optimize_assert(backend) - return _optimize_catch_errors(convert_frame.convert_frame(backend)) + return optimize_assert(backend, backend_ctx_ctor) + return _optimize_catch_errors( + convert_frame.convert_frame(backend), backend_ctx_ctor + ) -def optimize_assert(backend): +def optimize_assert(backend, backend_ctx_ctor=null_context): """ The same as `torchdynamo.optimize(backend, nopython=True)` """ - return _optimize_catch_errors(convert_frame.convert_frame_assert(backend)) + return _optimize_catch_errors( + convert_frame.convert_frame_assert(backend), backend_ctx_ctor + ) def run(fn=None): diff --git a/torchdynamo/optimizations/training.py b/torchdynamo/optimizations/training.py index f03ea787dd..d7de789268 100644 --- a/torchdynamo/optimizations/training.py +++ b/torchdynamo/optimizations/training.py @@ -143,4 +143,14 @@ def candidate(self): return BACKENDS["aot_autograd"](self.gm, self.example_inputs) -aot_autograd_speedup_strategy = AOTAutogradMemoryEfficientFusion.compile_fn +class AOTAutogradMemoryEfficientFusionWithContext: + """Pass nvfuser context to TorchDynamo""" + + def __init__(self): + self.backend_ctx_ctor = lambda: torch.jit.fuser("fuser2") + + def __call__(self, gm: torch.fx.GraphModule, example_inputs): + return AOTAutogradMemoryEfficientFusion.compile_fn(gm, example_inputs) + + +aot_autograd_speedup_strategy = AOTAutogradMemoryEfficientFusionWithContext()