diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py
index 83175438c5e9e..777c1f3a36421 100644
--- a/test/functorch/test_aotdispatch.py
+++ b/test/functorch/test_aotdispatch.py
@@ -21,7 +21,6 @@
 import torch._dynamo as torchdynamo
 import torch.nn as nn
 import torch.utils._pytree as pytree
-
 from functorch import grad, jacrev, make_fx, vjp, vmap
 from functorch.compile import (
     aot_function,
@@ -33,6 +32,7 @@
     default_partition,
     get_aot_compilation_context,
     make_boxed_compiler,
+    make_boxed_func,
     memory_efficient_fusion,
     min_cut_rematerialization_partition,
     nnc_jit,
@@ -40,13 +40,14 @@
 )
 from functorch.experimental import control_flow
 from torch._decomp import decomposition_table
+from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
 from torch._functorch.aot_autograd import aot_export_joint_simple, aot_export_module
 from torch._higher_order_ops.out_dtype import out_dtype
+from torch._inductor.codecache import compiled_fx_graph_hash
 from torch._subclasses.fake_tensor import DynamicOutputShapeException, FakeTensorMode
 from torch.fx.experimental.proxy_tensor import is_sym_node
 from torch.fx.experimental.symbolic_shapes import GuardOnDataDependentSymNode, ShapeEnv
 from torch.nn.utils.rnn import PackedSequence
-
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests,
     ops,
@@ -359,7 +360,7 @@ def verify_aot_autograd(
         # TODO: probably consolidate all tests to make inp a Callable.
         make_inputs_subclasses: bool = False,
     ):
-        for keep_input_mutations in [True] if keep_inp_mutations else [True, False]:
+        def make_inputs(inp_):
             # Some tests pass in a callable for inp, to generate the inputs
             # (useful if we want to generate complicated aliasing inputs)
             if isinstance(inp_, Callable):
@@ -368,47 +369,45 @@ def verify_aot_autograd(
                 # (The idea is that we might want to compile a function with the graph inputs,
                 # but test autograd backprop all the way through the actual inputs)
                 with TwoTensorMode() if make_inputs_subclasses else nullcontext():
-                    inp_copy, graph_inps_copy = inp_callable()
                     inp, graph_inps = inp_callable()
             else:
-                inp_copy = []
                 inp = []
                 # Our input clones need to mimic when inputs are duplicates of one another
                 dupes_map = {}
                 for i, x in enumerate(inp_):
                     if x in dupes_map:
                         x_dupe_idx = dupes_map[x]
-                        inp_copy.append(inp_copy[x_dupe_idx])
                         inp.append(inp[x_dupe_idx])
                     else:
                         dupes_map[x] = i
                         if not isinstance(x, torch.Tensor):
                             x_copy = x
-                            x_copy2 = x
                         else:
                             x_copy = x.clone().detach().requires_grad_(x.requires_grad)
-                            x_copy2 = x.clone().detach().requires_grad_(x.requires_grad)
                             if x.requires_grad and not x.is_leaf:
                                 x_copy = x_copy.clone()
-                                x_copy2 = x_copy2.clone()
-                        inp_copy.append(x_copy)
-                        inp.append(x_copy2)
+
+                        inp.append(x_copy)
 
             if test_mutation:
                 # For graphs where we mutate inputs, need our test to make sure inputs aren't leaves
                 graph_inps = [x.add(1) for x in inp]
-                graph_inps_copy = [x.add(1) for x in inp_copy]
             else:
                 graph_inps = inp
-                graph_inps_copy = inp_copy
-            fw_graph_cell = [None]
-            compiled_f = self.run_autograd(
-                f, fw_graph_cell, decompositions, keep_input_mutations, dynamic
-            )
-            ref_out, ref_grad = outs_and_grads(f, graph_inps, inp)
-            test_out, test_grad = outs_and_grads(compiled_f, graph_inps_copy, inp_copy)
-            self.assertEqual(ref_grad, test_grad)
+            return inp, graph_inps
+
+        def check_results(
+            ref_results,
+            test_results,
+            ref_graph_inps,
+            test_graph_inps,
+            ref_inp,
+            test_inp,
+        ):
+            ref_out, ref_grad = ref_results
+            test_out, test_grad = test_results
+            self.assertEqual(ref_grad, test_grad)
             if isinstance(ref_out, torch.Tensor):
                 self.assertTrue(isinstance(test_out, torch.Tensor))
                 ref_out, test_out = [ref_out], [test_out]
@@ -417,10 +416,10 @@ def verify_aot_autograd(
                 self.assertEqual(ref_o.requires_grad, test_o.requires_grad)
                 self.assertEqual(ref_o.is_leaf, test_o.is_leaf)
                 ref_is_view_of_non_interm = is_in_base(
-                    ref_o, graph_inps
+                    ref_o, ref_graph_inps
                 ) or is_in_base(ref_o, ref_out)
                 test_is_view_of_non_interm = is_in_base(
-                    test_o, graph_inps_copy
+                    test_o, test_graph_inps
                 ) or is_in_base(test_o, test_out)
                 self.assertEqual(
                     ref_is_view_of_non_interm, test_is_view_of_non_interm
@@ -429,13 +428,46 @@ def verify_aot_autograd(
                 if test_mutation:
                     # This tests that autograd meta is set properly on the output we can
                     # mutate it.
-                    ref_o.mul_(2)
-                    test_o.mul_(2)
+                    ref_o.add_(2)
+                    test_o.add_(2)
+                    self.assertEqual(ref_o, test_o)
+                    # Reverse the modification
+                    ref_o.sub_(2)
+                    test_o.sub_(2)
                     self.assertEqual(ref_o, test_o)
-            for ref_i, test_i in zip(inp, inp_copy):
+            for ref_i, test_i in zip(ref_inp, test_inp):
                 if isinstance(ref_i, torch.Tensor):
                     self.assertEqual(ref_i.requires_grad, test_i.requires_grad)
                 self.assertEqual(ref_i, test_i)
+
+        for keep_input_mutations in [True] if keep_inp_mutations else [True, False]:
+            inp, graph_inps = make_inputs(inp_)
+            test_inp, test_graph_inps = make_inputs(inp_)
+            fw_graph_cell = [None]
+            compiled_f = self.run_autograd(
+                f, fw_graph_cell, decompositions, keep_input_mutations, dynamic
+            )
+            ref_results = outs_and_grads(f, graph_inps, inp)
+            test_results = outs_and_grads(compiled_f, test_graph_inps, test_inp)
+
+            check_results(
+                ref_results, test_results, graph_inps, test_graph_inps, inp, test_inp
+            )
+            if isinstance(self, TestAOTAutogradWithCache):
+                # When testing with cache, run compiled_f a second time
+                cached_inp, cached_graph_inps = make_inputs(inp_)
+                cached_results = outs_and_grads(
+                    compiled_f, cached_graph_inps, cached_inp
+                )
+                check_results(
+                    ref_results,
+                    cached_results,
+                    graph_inps,
+                    cached_graph_inps,
+                    inp,
+                    cached_inp,
+                )
+
         return fw_graph_cell[0]
 
     def test_non_tensor_and_none_inputs(self):
@@ -5962,6 +5994,9 @@ def assertExpectedInline(self, *args, **kwargs):
         # only that the outputs match, etc.
         pass
 
+    def make_compiler(self, graph_cell):
+        return make_boxed_compiler(partial(extract_graph, graph_cell=graph_cell))
+
     # Compiler to passes to dynamo
     def run_autograd(
         self,
@@ -5979,10 +6014,8 @@ def dynamo_compiler(gm, inputs, **kwargs):
             result = aot_module_simplified(
                 gm,
                 inputs,
-                fw_compiler=make_boxed_compiler(
-                    partial(extract_graph, graph_cell=fw_graph_cell)
-                ),
-                bw_compiler=nop,
+                fw_compiler=self.make_compiler(fw_graph_cell),
+                bw_compiler=self.make_compiler([None]),
                 decompositions=decompositions,
                 keep_inference_input_mutations=keep_input_mutations,
                 # Dynamic is calculated from whether the inputs have fake tensors
@@ -6002,5 +6035,144 @@ def torch_compile_wrapper(*args, **kwargs):
         return torch_compile_wrapper
 
 
+class MockFXGraphCache:
+    """
+    In memory version of FXGraphCache so we can isolate testing for FXGraphCache
+    """
+
+    def __init__(self):
+        self.cache = {}
+
+    def save(self, key, gm):
+        self.cache[key] = gm
+
+    def load(self, gm, inputs):
+        key = compiled_fx_graph_hash(gm, inputs, {}, {})
+        if key in self.cache:
+            gm = make_boxed_func(gm)
+            gm._fx_graph_cache_key = key
+            return gm
+        else:
+            self.save(key, gm)
+            gm = make_boxed_func(gm)
+            gm._fx_graph_cache_key = key
+            return gm
+
+    def _lookup_graph(self, key, inputs, local, remote_cache):
+        gm = self.cache.get(key)
+        if gm is not None:
+            gm = make_boxed_func(gm)
+        return gm
+
+
+# The following tests fail in strict caching mode (i.e. they bypass or
+# cache miss instead of cache hitting). They will be fixed in the PRs above this.
+FAILING_CACHE_TESTS = (
+    # BypassAOTAutogradCache: unsupported nodes
+    "test_backward_mutation_data",
+    "test_backward_mutation_metadata",
+    "test_custom_autograd",
+    "test_inner_grad",
+    "test_input_mutation_set__nop",
+    "test_nonidempotent_amp",  # einsum
+    # Pickle error: OutputAliasInfo/functional tensor
+    "test_input_aliased_with_mutation_output_alias",
+    "test_input_data_and_metadata_mutation",
+    "test_input_mutation_aliases_and_output_alias",
+    "test_input_mutation_alias_everything",
+    "test_input_mutation_and_output_view",
+    "test_input_mutation_output_view_multiple",
+    "test_input_output_aliase_custom_autograd_function",
+    "test_input_output_view_metadata_mutate_multiple",
+    "test_input_output_view_mutate_multiple",
+    "test_input_output_view_simple",
+    "test_output_aliases_intermediate_and_returned",
+    "test_output_aliases_intermediate_and_returned_different_grad",
+    "test_output_aliases_intermediate_and_returned_flipped",
+    "test_output_aliases_intermediate_multiple",
+    "test_output_aliases_intermediate_multiple_mixed",
+    "test_output_aliases_intermediate_returned_multiple_times",
+    "test_output_aliases_multiple_inputs_get_correct_one",
+    "test_output_all_alias_types",
+    "test_some_outputs_dont_require_grad_view",
+    "test_view_and_inplace_view",
+    "test_view_detach",
+    "test_some_output_requires_grad_input_doesnt",
+)
+
+
+@xfail_inherited_tests(FAILING_CACHE_TESTS)
+class TestAOTAutogradWithCache(TestAOTAutogradWithDynamo):
+    """
+    TestAOTAutogradWithDynamo, but with AOTAutogradCache and a mocked FXGraphCache enabled
+    """
+
+    def make_compiler(self, fw_graph_cell):
+        mock_inductor_cache = self.inductor_cache
+
+        def compiler(gm, inputs):
+            nonlocal mock_inductor_cache, fw_graph_cell
+            result = mock_inductor_cache.load(gm, inputs)
+            fw_graph_cell[0] = gm
+            return result
+
+        return compiler
+
+    def run_autograd(
+        self,
+        f: Callable,
+        fw_graph_cell: List[Optional[Callable]],
+        decompositions: Optional[Dict],
+        keep_input_mutations: bool,
+        dynamic: bool,
+    ):
+        return super().run_autograd(
+            f,
+            fw_graph_cell,
+            decompositions,
+            keep_input_mutations,
+            dynamic,
+        )
+
+    @torch._functorch.config.patch(
+        {"enable_autograd_cache": True, "strict_autograd_cache": True}
+    )
+    @torch._inductor.config.patch("fx_graph_cache", True)
+    def verify_aot_autograd(
+        self,
+        f,
+        inp_: Union[Callable, List[Any]],
+        *,
+        test_mutation: bool = False,
+        keep_inp_mutations: bool = False,
+        decompositions: Optional[Dict] = None,
+        dynamic: bool = False,
+        # Only active when inp_ is Callable.
+        # TODO: probably consolidate all tests to make inp a Callable.
+        make_inputs_subclasses: bool = False,
+    ):
+        self.inductor_cache = MockFXGraphCache()
+        AOTAutogradCache.clear()
+        with patch(
+            "torch._inductor.codecache.FxGraphCache._lookup_graph",
+            new=self.inductor_cache._lookup_graph,
+        ):
+            return super().verify_aot_autograd(
+                f,
+                inp_,
+                test_mutation=test_mutation,
+                keep_inp_mutations=keep_inp_mutations,
+                decompositions=decompositions,
+                dynamic=dynamic,
+                make_inputs_subclasses=make_inputs_subclasses,
+            )
+
+    def test_input_mutation_false_aliasing(self):
+        # This test is disabled because it fails in strict cache mode
+        # But also can't be xfailed because it causes undefined behavior for
+        # ASAN
+        self.skipTest("Skipping because it fails in strict cache mode")
+
+
 if __name__ == "__main__":
     run_tests()
diff --git a/torch/_functorch/_aot_autograd/autograd_cache.py b/torch/_functorch/_aot_autograd/autograd_cache.py
index 030fa0495c753..c58fd7dafd5de 100644
--- a/torch/_functorch/_aot_autograd/autograd_cache.py
+++ b/torch/_functorch/_aot_autograd/autograd_cache.py
@@ -439,11 +439,15 @@ def load(
             log.info("AOTAutograd cache miss for key %s", cache_key)
             counters["aot_autograd"]["autograd_cache_miss"] += 1
         # Count missing the FXGraphCache as a miss not a bypass
-        except FXGraphCacheMiss:
+        except FXGraphCacheMiss as e:
             counters["aot_autograd"]["autograd_cache_miss"] += 1
-        except BypassAOTAutogradCache:
+            if config.strict_autograd_cache:
+                raise e
+        except BypassAOTAutogradCache as e:
             cache_key = None
             counters["aot_autograd"]["autograd_cache_bypass"] += 1
+            if config.strict_autograd_cache:
+                raise e
         if compiled_fn is None:
             # Set the cache key so we can save a cache result later
             aot_config.cache_key = cache_key
@@ -470,6 +474,8 @@ def _lookup(key: str) -> Optional[AOTAutogradCacheEntry]:
             return entry
         except Exception as e:
             log.warning("AOTAutograd cache unable to load compiled graph: %s", e)
+            if config.strict_autograd_cache:
+                raise e
             return None
 
     @staticmethod
@@ -479,7 +485,9 @@ def save(key: str, entry: AOTAutogradCacheEntry):
             content = pickle.dumps(entry)
         except Exception as e:
             log.warning("AOTAutograd cache unable to serialize compiled graph: %s", e)
-            raise e
+            if config.strict_autograd_cache:
+                raise e
+            return None
         subdir = os.path.join(AOTAutogradCache._get_tmp_dir(), key)
         if not os.path.exists(subdir):
             os.makedirs(subdir, exist_ok=True)
diff --git a/torch/_functorch/config.py b/torch/_functorch/config.py
index ab768dc13d4a0..9f741cf96170d 100644
--- a/torch/_functorch/config.py
+++ b/torch/_functorch/config.py
@@ -178,6 +178,9 @@
 enable_autograd_cache = os.environ.get("ENABLE_AOT_AUTOGRAD_CACHE", "0") == "1"
 
+# Error on BypassAOTAutogradCache instead of just a warning
+# Used for tests
+strict_autograd_cache = False
 
 if TYPE_CHECKING:
     from torch.utils._config_typing import *  # noqa: F401, F403
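
Usage sketch (not part of the patch): the snippet below shows how the new strict_autograd_cache flag added in torch/_functorch/config.py is expected to be exercised together with enable_autograd_cache and fx_graph_cache, mirroring the decorators on TestAOTAutogradWithCache.verify_aot_autograd. The config names and AOTAutogradCache.clear() come from the patch; the toy function fn, the wrapper name, and the tensor shapes are illustrative assumptions.

import torch
import torch._dynamo
import torch._functorch.config as functorch_config
import torch._inductor.config as inductor_config
from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache


@functorch_config.patch({"enable_autograd_cache": True, "strict_autograd_cache": True})
@inductor_config.patch("fx_graph_cache", True)
def roundtrip_through_cache():
    # Hypothetical toy workload; any compiled autograd graph would do.
    def fn(x):
        return (x.sin() + x.cos()).sum()

    AOTAutogradCache.clear()
    compiled = torch.compile(fn, backend="inductor")

    x = torch.randn(8, requires_grad=True)
    compiled(x).backward()  # first run: autograd_cache_miss, an entry is saved

    torch._dynamo.reset()  # drop compiled code so the next call must go back to the cache
    y = torch.randn(8, requires_grad=True)
    compiled(y).backward()  # second run: expected to hit the AOTAutograd cache
    # With strict_autograd_cache=True, a BypassAOTAutogradCache or a failed
    # cache load/save now raises instead of only logging a warning.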