From 5b1494321341a5abd72076d8e984f0f9ff3bc69e Mon Sep 17 00:00:00 2001 From: James Wu Date: Fri, 21 Jun 2024 08:12:44 -0700 Subject: [PATCH] Run TestAOTAutograd test suite with cache (#128222) This diff introduces AOTAutogradTestWithCache, which runs AOTAutogradTests with both dynamo and AOTAutogradCache. To do this, for any verify_aot_autograd() calls in the original tests, we run compiled_f an extra time. We also turn on a new strict mode that throws any time a cache is missed due to weird reasons, like BypassAOTAutogradCache or FxGraphCacheMiss. We use a mocked version of FXGraphCache to decrease the number of variables for these tests. The normal tests in test_aot_autograd_cache.py will still run with FXGraphCache. I might change my mind and unmock these in the future. In total, 87 of the tests pass naturally. None of the tests fail in non strict cache mode, so the cache never crashes, it just misses more often than we'd like. The remaining 27 tests fail due to relatively simple (though not necessarily easy to fix) reasons. I'll fix the remaining test failures in the next few PRs. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128222 Approved by: https://github.com/bdhirsh --- test/functorch/test_aotdispatch.py | 230 +++++++++++++++--- .../_aot_autograd/autograd_cache.py | 14 +- torch/_functorch/config.py | 3 + 3 files changed, 215 insertions(+), 32 deletions(-) diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index 83175438c5e9ed..777c1f3a364215 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -21,7 +21,6 @@ import torch._dynamo as torchdynamo import torch.nn as nn import torch.utils._pytree as pytree - from functorch import grad, jacrev, make_fx, vjp, vmap from functorch.compile import ( aot_function, @@ -33,6 +32,7 @@ default_partition, get_aot_compilation_context, make_boxed_compiler, + make_boxed_func, memory_efficient_fusion, min_cut_rematerialization_partition, nnc_jit, @@ -40,13 +40,14 @@ ) from functorch.experimental import control_flow from torch._decomp import decomposition_table +from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache from torch._functorch.aot_autograd import aot_export_joint_simple, aot_export_module from torch._higher_order_ops.out_dtype import out_dtype +from torch._inductor.codecache import compiled_fx_graph_hash from torch._subclasses.fake_tensor import DynamicOutputShapeException, FakeTensorMode from torch.fx.experimental.proxy_tensor import is_sym_node from torch.fx.experimental.symbolic_shapes import GuardOnDataDependentSymNode, ShapeEnv from torch.nn.utils.rnn import PackedSequence - from torch.testing._internal.common_device_type import ( instantiate_device_type_tests, ops, @@ -359,7 +360,7 @@ def verify_aot_autograd( # TODO: probably consolidate all tests to make inp a Callable. make_inputs_subclasses: bool = False, ): - for keep_input_mutations in [True] if keep_inp_mutations else [True, False]: + def make_inputs(inp_): # Some tests pass in a callable for inp, to generate the inputs # (useful if we want to generate complicated aliasing inputs) if isinstance(inp_, Callable): @@ -368,47 +369,45 @@ def verify_aot_autograd( # (The idea is that we might want to compile a function with the graph inputs, # but test autograd backprop all the way through the actual inputs) with TwoTensorMode() if make_inputs_subclasses else nullcontext(): - inp_copy, graph_inps_copy = inp_callable() inp, graph_inps = inp_callable() else: - inp_copy = [] inp = [] # Our input clones need to mimic when inputs are duplicates of one another dupes_map = {} for i, x in enumerate(inp_): if x in dupes_map: x_dupe_idx = dupes_map[x] - inp_copy.append(inp_copy[x_dupe_idx]) inp.append(inp[x_dupe_idx]) else: dupes_map[x] = i if not isinstance(x, torch.Tensor): x_copy = x - x_copy2 = x else: x_copy = x.clone().detach().requires_grad_(x.requires_grad) - x_copy2 = x.clone().detach().requires_grad_(x.requires_grad) if x.requires_grad and not x.is_leaf: x_copy = x_copy.clone() - x_copy2 = x_copy2.clone() - inp_copy.append(x_copy) - inp.append(x_copy2) + + inp.append(x_copy) if test_mutation: # For graphs where we mutate inputs, need our test to make sure inputs aren't leaves graph_inps = [x.add(1) for x in inp] - graph_inps_copy = [x.add(1) for x in inp_copy] else: graph_inps = inp - graph_inps_copy = inp_copy - fw_graph_cell = [None] - compiled_f = self.run_autograd( - f, fw_graph_cell, decompositions, keep_input_mutations, dynamic - ) - ref_out, ref_grad = outs_and_grads(f, graph_inps, inp) - test_out, test_grad = outs_and_grads(compiled_f, graph_inps_copy, inp_copy) - self.assertEqual(ref_grad, test_grad) + return inp, graph_inps + + def check_results( + ref_results, + test_results, + ref_graph_inps, + test_graph_inps, + ref_inp, + test_inp, + ): + ref_out, ref_grad = ref_results + test_out, test_grad = test_results + self.assertEqual(ref_grad, test_grad) if isinstance(ref_out, torch.Tensor): self.assertTrue(isinstance(test_out, torch.Tensor)) ref_out, test_out = [ref_out], [test_out] @@ -417,10 +416,10 @@ def verify_aot_autograd( self.assertEqual(ref_o.requires_grad, test_o.requires_grad) self.assertEqual(ref_o.is_leaf, test_o.is_leaf) ref_is_view_of_non_interm = is_in_base( - ref_o, graph_inps + ref_o, ref_graph_inps ) or is_in_base(ref_o, ref_out) test_is_view_of_non_interm = is_in_base( - test_o, graph_inps_copy + test_o, test_graph_inps ) or is_in_base(test_o, test_out) self.assertEqual( ref_is_view_of_non_interm, test_is_view_of_non_interm @@ -429,13 +428,46 @@ def verify_aot_autograd( if test_mutation: # This tests that autograd meta is set properly on the output we can # mutate it. - ref_o.mul_(2) - test_o.mul_(2) + ref_o.add_(2) + test_o.add_(2) + self.assertEqual(ref_o, test_o) + # Reverse the modification + ref_o.sub_(2) + test_o.sub_(2) self.assertEqual(ref_o, test_o) - for ref_i, test_i in zip(inp, inp_copy): + for ref_i, test_i in zip(ref_inp, test_inp): if isinstance(ref_i, torch.Tensor): self.assertEqual(ref_i.requires_grad, test_i.requires_grad) self.assertEqual(ref_i, test_i) + + for keep_input_mutations in [True] if keep_inp_mutations else [True, False]: + inp, graph_inps = make_inputs(inp_) + test_inp, test_graph_inps = make_inputs(inp_) + fw_graph_cell = [None] + compiled_f = self.run_autograd( + f, fw_graph_cell, decompositions, keep_input_mutations, dynamic + ) + ref_results = outs_and_grads(f, graph_inps, inp) + test_results = outs_and_grads(compiled_f, test_graph_inps, test_inp) + + check_results( + ref_results, test_results, graph_inps, test_graph_inps, inp, test_inp + ) + if isinstance(self, TestAOTAutogradWithCache): + # When testing with cache, run compiled_f a second time + cached_inp, cached_graph_inps = make_inputs(inp_) + cached_results = outs_and_grads( + compiled_f, cached_graph_inps, cached_inp + ) + check_results( + ref_results, + cached_results, + graph_inps, + cached_graph_inps, + inp, + cached_inp, + ) + return fw_graph_cell[0] def test_non_tensor_and_none_inputs(self): @@ -5962,6 +5994,9 @@ def assertExpectedInline(self, *args, **kwargs): # only that the outputs match, etc. pass + def make_compiler(self, graph_cell): + return make_boxed_compiler(partial(extract_graph, graph_cell=graph_cell)) + # Compiler to passes to dynamo def run_autograd( self, @@ -5979,10 +6014,8 @@ def dynamo_compiler(gm, inputs, **kwargs): result = aot_module_simplified( gm, inputs, - fw_compiler=make_boxed_compiler( - partial(extract_graph, graph_cell=fw_graph_cell) - ), - bw_compiler=nop, + fw_compiler=self.make_compiler(fw_graph_cell), + bw_compiler=self.make_compiler([None]), decompositions=decompositions, keep_inference_input_mutations=keep_input_mutations, # Dynamic is calculated from whether the inputs have fake tensors @@ -6002,5 +6035,144 @@ def torch_compile_wrapper(*args, **kwargs): return torch_compile_wrapper +class MockFXGraphCache: + """ + In memory version of FXGraphCache so we can isolate testing for FXGraphCache + """ + + def __init__(self): + self.cache = {} + + def save(self, key, gm): + self.cache[key] = gm + + def load(self, gm, inputs): + key = compiled_fx_graph_hash(gm, inputs, {}, {}) + if key in self.cache: + gm = make_boxed_func(gm) + gm._fx_graph_cache_key = key + return gm + else: + self.save(key, gm) + gm = make_boxed_func(gm) + gm._fx_graph_cache_key = key + return gm + + def _lookup_graph(self, key, inputs, local, remote_cache): + gm = self.cache.get(key) + if gm is not None: + gm = make_boxed_func(gm) + return gm + + +# The following tests fail in strict caching mode (i.e. they bypass or +# cache miss instead of cache hitting). They will be fixed in the PRs above this. +FAILING_CACHE_TESTS = ( + # BypassAOTAutogradCache: unsupported nodes + "test_backward_mutation_data", + "test_backward_mutation_metadata", + "test_custom_autograd", + "test_inner_grad", + "test_input_mutation_set__nop", + "test_nonidempotent_amp", # einsum + # Pickle error: OutputAliasInfo/functional tensor + "test_input_aliased_with_mutation_output_alias", + "test_input_data_and_metadata_mutation", + "test_input_mutation_aliases_and_output_alias", + "test_input_mutation_alias_everything", + "test_input_mutation_and_output_view", + "test_input_mutation_output_view_multiple", + "test_input_output_aliase_custom_autograd_function", + "test_input_output_view_metadata_mutate_multiple", + "test_input_output_view_mutate_multiple", + "test_input_output_view_simple", + "test_output_aliases_intermediate_and_returned", + "test_output_aliases_intermediate_and_returned_different_grad", + "test_output_aliases_intermediate_and_returned_flipped", + "test_output_aliases_intermediate_multiple", + "test_output_aliases_intermediate_multiple_mixed", + "test_output_aliases_intermediate_returned_multiple_times", + "test_output_aliases_multiple_inputs_get_correct_one", + "test_output_all_alias_types", + "test_some_outputs_dont_require_grad_view", + "test_view_and_inplace_view", + "test_view_detach", + "test_some_output_requires_grad_input_doesnt", +) + + +@xfail_inherited_tests(FAILING_CACHE_TESTS) +class TestAOTAutogradWithCache(TestAOTAutogradWithDynamo): + """ + In memory version of FXGraphCache so we can isolate testing for FXGraphCache + """ + + def make_compiler(self, fw_graph_cell): + mock_inductor_cache = self.inductor_cache + + def compiler(gm, inputs): + nonlocal mock_inductor_cache, fw_graph_cell + result = mock_inductor_cache.load(gm, inputs) + fw_graph_cell[0] = gm + return result + + return compiler + + def run_autograd( + self, + f: Callable, + fw_graph_cell: List[Optional[Callable]], + decompositions: Optional[Dict], + keep_input_mutations: bool, + dynamic: bool, + ): + return super().run_autograd( + f, + fw_graph_cell, + decompositions, + keep_input_mutations, + dynamic, + ) + + @torch._functorch.config.patch( + {"enable_autograd_cache": True, "strict_autograd_cache": True} + ) + @torch._inductor.config.patch("fx_graph_cache", True) + def verify_aot_autograd( + self, + f, + inp_: Union[Callable, List[Any]], + *, + test_mutation: bool = False, + keep_inp_mutations: bool = False, + decompositions: Optional[Dict] = None, + dynamic: bool = False, + # Only active when inp_ is Callable. + # TODO: probably consolidate all tests to make inp a Callable. + make_inputs_subclasses: bool = False, + ): + self.inductor_cache = MockFXGraphCache() + AOTAutogradCache.clear() + with patch( + "torch._inductor.codecache.FxGraphCache._lookup_graph", + new=self.inductor_cache._lookup_graph, + ): + return super().verify_aot_autograd( + f, + inp_, + test_mutation=test_mutation, + keep_inp_mutations=keep_inp_mutations, + decompositions=decompositions, + dynamic=dynamic, + make_inputs_subclasses=make_inputs_subclasses, + ) + + def test_input_mutation_false_aliasing(self): + # This test is disabled because it fails in strict cache mode + # But also can't be xfailed because it causes undefined behavior for + # ASAN + self.skipTest("Skipping because it fails in strict cache mode") + + if __name__ == "__main__": run_tests() diff --git a/torch/_functorch/_aot_autograd/autograd_cache.py b/torch/_functorch/_aot_autograd/autograd_cache.py index 030fa0495c753d..c58fd7dafd5dea 100644 --- a/torch/_functorch/_aot_autograd/autograd_cache.py +++ b/torch/_functorch/_aot_autograd/autograd_cache.py @@ -439,11 +439,15 @@ def load( log.info("AOTAutograd cache miss for key %s", cache_key) counters["aot_autograd"]["autograd_cache_miss"] += 1 # Count missing the FXGraphCache as a miss not a bypass - except FXGraphCacheMiss: + except FXGraphCacheMiss as e: counters["aot_autograd"]["autograd_cache_miss"] += 1 - except BypassAOTAutogradCache: + if config.strict_autograd_cache: + raise e + except BypassAOTAutogradCache as e: cache_key = None counters["aot_autograd"]["autograd_cache_bypass"] += 1 + if config.strict_autograd_cache: + raise e if compiled_fn is None: # Set the cache key so we can save a cache result later aot_config.cache_key = cache_key @@ -470,6 +474,8 @@ def _lookup(key: str) -> Optional[AOTAutogradCacheEntry]: return entry except Exception as e: log.warning("AOTAutograd cache unable to load compiled graph: %s", e) + if config.strict_autograd_cache: + raise e return None @staticmethod @@ -479,7 +485,9 @@ def save(key: str, entry: AOTAutogradCacheEntry): content = pickle.dumps(entry) except Exception as e: log.warning("AOTAutograd cache unable to serialize compiled graph: %s", e) - raise e + if config.strict_autograd_cache: + raise e + return None subdir = os.path.join(AOTAutogradCache._get_tmp_dir(), key) if not os.path.exists(subdir): os.makedirs(subdir, exist_ok=True) diff --git a/torch/_functorch/config.py b/torch/_functorch/config.py index ab768dc13d4a0c..9f741cf96170d7 100644 --- a/torch/_functorch/config.py +++ b/torch/_functorch/config.py @@ -178,6 +178,9 @@ enable_autograd_cache = os.environ.get("ENABLE_AOT_AUTOGRAD_CACHE", "0") == "1" +# Error on BypassAOTAutogradCache instead of just a warning +# Used for tests +strict_autograd_cache = False if TYPE_CHECKING: from torch.utils._config_typing import * # noqa: F401, F403