Run TestAOTAutograd test suite with cache (pytorch#128222)
This diff introduces TestAOTAutogradWithCache, which runs the AOTAutograd test suite with both dynamo and AOTAutogradCache enabled.

To do this, for any verify_aot_autograd() call in the original tests, we run compiled_f an extra time so the second run exercises the cache-hit path. We also turn on a new strict mode that throws whenever the cache misses or bypasses for unexpected reasons, such as BypassAOTAutogradCache or FXGraphCacheMiss.
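
As a rough illustration, here is a minimal sketch of opting into that strict mode with the two config flags this PR adds; the same patch pattern appears in the diff below, and the compiled function itself is illustrative:

```python
# Minimal sketch, assuming the config flags added in this PR: with
# strict_autograd_cache on, an unexpected BypassAOTAutogradCache or
# FXGraphCacheMiss raises instead of just bumping a counter.
import torch
import torch._functorch.config
import torch._inductor.config

def f(x):
    return (x * x).sum()

with torch._inductor.config.patch("fx_graph_cache", True), \
     torch._functorch.config.patch(
         {"enable_autograd_cache": True, "strict_autograd_cache": True}
     ):
    compiled = torch.compile(f)
    compiled(torch.randn(4, requires_grad=True)).backward()
```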

We use a mocked, in-memory version of FXGraphCache to reduce the number of variables in these tests. The normal tests in test_aot_autograd_cache.py still run with the real FXGraphCache. I might change my mind and unmock these in the future.
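
The mock is wired in by monkey-patching the real cache's lookup; a condensed sketch of the pattern used in the diff below (MockFXGraphCache is the in-memory stand-in defined in this PR's test file):

```python
# Condensed sketch of the patching pattern used in verify_aot_autograd below.
from unittest.mock import patch

inductor_cache = MockFXGraphCache()  # in-memory stand-in defined in this PR
with patch(
    "torch._inductor.codecache.FxGraphCache._lookup_graph",
    new=inductor_cache._lookup_graph,
):
    ...  # run the compiled function; graph lookups now hit the in-memory dict
```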

In total, 87 of the tests pass naturally. None of the tests fail in non-strict cache mode, so the cache never crashes; it just misses more often than we'd like. The remaining 27 tests fail in strict mode for relatively simple (though not necessarily easy to fix) reasons. I'll fix the remaining test failures in the next few PRs.

Pull Request resolved: pytorch#128222
Approved by: https://github.com/bdhirsh
jamesjwu authored and pytorchmergebot committed Jun 22, 2024
1 parent c5b9ee7 commit 5b14943
Showing 3 changed files with 215 additions and 32 deletions.
230 changes: 201 additions & 29 deletions test/functorch/test_aotdispatch.py
@@ -21,7 +21,6 @@
import torch._dynamo as torchdynamo
import torch.nn as nn
import torch.utils._pytree as pytree

from functorch import grad, jacrev, make_fx, vjp, vmap
from functorch.compile import (
aot_function,
@@ -33,20 +32,22 @@
default_partition,
get_aot_compilation_context,
make_boxed_compiler,
make_boxed_func,
memory_efficient_fusion,
min_cut_rematerialization_partition,
nnc_jit,
nop,
)
from functorch.experimental import control_flow
from torch._decomp import decomposition_table
from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
from torch._functorch.aot_autograd import aot_export_joint_simple, aot_export_module
from torch._higher_order_ops.out_dtype import out_dtype
from torch._inductor.codecache import compiled_fx_graph_hash
from torch._subclasses.fake_tensor import DynamicOutputShapeException, FakeTensorMode
from torch.fx.experimental.proxy_tensor import is_sym_node
from torch.fx.experimental.symbolic_shapes import GuardOnDataDependentSymNode, ShapeEnv
from torch.nn.utils.rnn import PackedSequence

from torch.testing._internal.common_device_type import (
instantiate_device_type_tests,
ops,
@@ -359,7 +360,7 @@ def verify_aot_autograd(
# TODO: probably consolidate all tests to make inp a Callable.
make_inputs_subclasses: bool = False,
):
for keep_input_mutations in [True] if keep_inp_mutations else [True, False]:
def make_inputs(inp_):
# Some tests pass in a callable for inp, to generate the inputs
# (useful if we want to generate complicated aliasing inputs)
if isinstance(inp_, Callable):
@@ -368,47 +369,45 @@ def verify_aot_autograd(
# (The idea is that we might want to compile a function with the graph inputs,
# but test autograd backprop all the way through the actual inputs)
with TwoTensorMode() if make_inputs_subclasses else nullcontext():
inp_copy, graph_inps_copy = inp_callable()
inp, graph_inps = inp_callable()
else:
inp_copy = []
inp = []
# Our input clones need to mimic when inputs are duplicates of one another
dupes_map = {}
for i, x in enumerate(inp_):
if x in dupes_map:
x_dupe_idx = dupes_map[x]
inp_copy.append(inp_copy[x_dupe_idx])
inp.append(inp[x_dupe_idx])
else:
dupes_map[x] = i
if not isinstance(x, torch.Tensor):
x_copy = x
x_copy2 = x
else:
x_copy = x.clone().detach().requires_grad_(x.requires_grad)
x_copy2 = x.clone().detach().requires_grad_(x.requires_grad)
if x.requires_grad and not x.is_leaf:
x_copy = x_copy.clone()
x_copy2 = x_copy2.clone()
inp_copy.append(x_copy)
inp.append(x_copy2)

inp.append(x_copy)

if test_mutation:
# For graphs where we mutate inputs, need our test to make sure inputs aren't leaves
graph_inps = [x.add(1) for x in inp]
graph_inps_copy = [x.add(1) for x in inp_copy]
else:
graph_inps = inp
graph_inps_copy = inp_copy
fw_graph_cell = [None]
compiled_f = self.run_autograd(
f, fw_graph_cell, decompositions, keep_input_mutations, dynamic
)
ref_out, ref_grad = outs_and_grads(f, graph_inps, inp)
test_out, test_grad = outs_and_grads(compiled_f, graph_inps_copy, inp_copy)
self.assertEqual(ref_grad, test_grad)

return inp, graph_inps

def check_results(
ref_results,
test_results,
ref_graph_inps,
test_graph_inps,
ref_inp,
test_inp,
):
ref_out, ref_grad = ref_results
test_out, test_grad = test_results
self.assertEqual(ref_grad, test_grad)
if isinstance(ref_out, torch.Tensor):
self.assertTrue(isinstance(test_out, torch.Tensor))
ref_out, test_out = [ref_out], [test_out]
@@ -417,10 +416,10 @@ def verify_aot_autograd(
self.assertEqual(ref_o.requires_grad, test_o.requires_grad)
self.assertEqual(ref_o.is_leaf, test_o.is_leaf)
ref_is_view_of_non_interm = is_in_base(
ref_o, graph_inps
ref_o, ref_graph_inps
) or is_in_base(ref_o, ref_out)
test_is_view_of_non_interm = is_in_base(
test_o, graph_inps_copy
test_o, test_graph_inps
) or is_in_base(test_o, test_out)
self.assertEqual(
ref_is_view_of_non_interm, test_is_view_of_non_interm
@@ -429,13 +428,46 @@ def verify_aot_autograd(
if test_mutation:
# This tests that autograd meta is set properly on the output so we can
# mutate it.
ref_o.mul_(2)
test_o.mul_(2)
ref_o.add_(2)
test_o.add_(2)
self.assertEqual(ref_o, test_o)
# Reverse the modification
ref_o.sub_(2)
test_o.sub_(2)
self.assertEqual(ref_o, test_o)
for ref_i, test_i in zip(inp, inp_copy):
for ref_i, test_i in zip(ref_inp, test_inp):
if isinstance(ref_i, torch.Tensor):
self.assertEqual(ref_i.requires_grad, test_i.requires_grad)
self.assertEqual(ref_i, test_i)

for keep_input_mutations in [True] if keep_inp_mutations else [True, False]:
inp, graph_inps = make_inputs(inp_)
test_inp, test_graph_inps = make_inputs(inp_)
fw_graph_cell = [None]
compiled_f = self.run_autograd(
f, fw_graph_cell, decompositions, keep_input_mutations, dynamic
)
ref_results = outs_and_grads(f, graph_inps, inp)
test_results = outs_and_grads(compiled_f, test_graph_inps, test_inp)

check_results(
ref_results, test_results, graph_inps, test_graph_inps, inp, test_inp
)
if isinstance(self, TestAOTAutogradWithCache):
# When testing with cache, run compiled_f a second time
cached_inp, cached_graph_inps = make_inputs(inp_)
cached_results = outs_and_grads(
compiled_f, cached_graph_inps, cached_inp
)
check_results(
ref_results,
cached_results,
graph_inps,
cached_graph_inps,
inp,
cached_inp,
)

return fw_graph_cell[0]

def test_non_tensor_and_none_inputs(self):
@@ -5962,6 +5994,9 @@ def assertExpectedInline(self, *args, **kwargs):
# only that the outputs match, etc.
pass

def make_compiler(self, graph_cell):
return make_boxed_compiler(partial(extract_graph, graph_cell=graph_cell))

# Compiler to pass to dynamo
def run_autograd(
self,
@@ -5979,10 +6014,8 @@ def dynamo_compiler(gm, inputs, **kwargs):
result = aot_module_simplified(
gm,
inputs,
fw_compiler=make_boxed_compiler(
partial(extract_graph, graph_cell=fw_graph_cell)
),
bw_compiler=nop,
fw_compiler=self.make_compiler(fw_graph_cell),
bw_compiler=self.make_compiler([None]),
decompositions=decompositions,
keep_inference_input_mutations=keep_input_mutations,
# Dynamic is calculated from whether the inputs have fake tensors
@@ -6002,5 +6035,144 @@ def torch_compile_wrapper(*args, **kwargs):
return torch_compile_wrapper


class MockFXGraphCache:
"""
In-memory version of FXGraphCache so we can isolate testing of AOTAutogradCache from the real FXGraphCache
"""

def __init__(self):
self.cache = {}

def save(self, key, gm):
self.cache[key] = gm

def load(self, gm, inputs):
key = compiled_fx_graph_hash(gm, inputs, {}, {})
if key in self.cache:
gm = make_boxed_func(gm)
gm._fx_graph_cache_key = key
return gm
else:
self.save(key, gm)
gm = make_boxed_func(gm)
gm._fx_graph_cache_key = key
return gm

def _lookup_graph(self, key, inputs, local, remote_cache):
gm = self.cache.get(key)
if gm is not None:
gm = make_boxed_func(gm)
return gm


# The following tests fail in strict caching mode (i.e., they bypass or miss
# the cache instead of hitting it). They will be fixed in the PRs stacked above this one.
FAILING_CACHE_TESTS = (
# BypassAOTAutogradCache: unsupported nodes
"test_backward_mutation_data",
"test_backward_mutation_metadata",
"test_custom_autograd",
"test_inner_grad",
"test_input_mutation_set__nop",
"test_nonidempotent_amp", # einsum
# Pickle error: OutputAliasInfo/functional tensor
"test_input_aliased_with_mutation_output_alias",
"test_input_data_and_metadata_mutation",
"test_input_mutation_aliases_and_output_alias",
"test_input_mutation_alias_everything",
"test_input_mutation_and_output_view",
"test_input_mutation_output_view_multiple",
"test_input_output_aliase_custom_autograd_function",
"test_input_output_view_metadata_mutate_multiple",
"test_input_output_view_mutate_multiple",
"test_input_output_view_simple",
"test_output_aliases_intermediate_and_returned",
"test_output_aliases_intermediate_and_returned_different_grad",
"test_output_aliases_intermediate_and_returned_flipped",
"test_output_aliases_intermediate_multiple",
"test_output_aliases_intermediate_multiple_mixed",
"test_output_aliases_intermediate_returned_multiple_times",
"test_output_aliases_multiple_inputs_get_correct_one",
"test_output_all_alias_types",
"test_some_outputs_dont_require_grad_view",
"test_view_and_inplace_view",
"test_view_detach",
"test_some_output_requires_grad_input_doesnt",
)


@xfail_inherited_tests(FAILING_CACHE_TESTS)
class TestAOTAutogradWithCache(TestAOTAutogradWithDynamo):
"""
Runs the TestAOTAutogradWithDynamo tests with AOTAutogradCache enabled and FXGraphCache mocked
"""

def make_compiler(self, fw_graph_cell):
mock_inductor_cache = self.inductor_cache

def compiler(gm, inputs):
nonlocal mock_inductor_cache, fw_graph_cell
result = mock_inductor_cache.load(gm, inputs)
fw_graph_cell[0] = gm
return result

return compiler

def run_autograd(
self,
f: Callable,
fw_graph_cell: List[Optional[Callable]],
decompositions: Optional[Dict],
keep_input_mutations: bool,
dynamic: bool,
):
return super().run_autograd(
f,
fw_graph_cell,
decompositions,
keep_input_mutations,
dynamic,
)

@torch._functorch.config.patch(
{"enable_autograd_cache": True, "strict_autograd_cache": True}
)
@torch._inductor.config.patch("fx_graph_cache", True)
def verify_aot_autograd(
self,
f,
inp_: Union[Callable, List[Any]],
*,
test_mutation: bool = False,
keep_inp_mutations: bool = False,
decompositions: Optional[Dict] = None,
dynamic: bool = False,
# Only active when inp_ is Callable.
# TODO: probably consolidate all tests to make inp a Callable.
make_inputs_subclasses: bool = False,
):
self.inductor_cache = MockFXGraphCache()
AOTAutogradCache.clear()
with patch(
"torch._inductor.codecache.FxGraphCache._lookup_graph",
new=self.inductor_cache._lookup_graph,
):
return super().verify_aot_autograd(
f,
inp_,
test_mutation=test_mutation,
keep_inp_mutations=keep_inp_mutations,
decompositions=decompositions,
dynamic=dynamic,
make_inputs_subclasses=make_inputs_subclasses,
)

def test_input_mutation_false_aliasing(self):
# This test is disabled because it fails in strict cache mode
# But also can't be xfailed because it causes undefined behavior for
# ASAN
self.skipTest("Skipping because it fails in strict cache mode")


if __name__ == "__main__":
run_tests()
14 changes: 11 additions & 3 deletions torch/_functorch/_aot_autograd/autograd_cache.py
@@ -439,11 +439,15 @@ def load(
log.info("AOTAutograd cache miss for key %s", cache_key)
counters["aot_autograd"]["autograd_cache_miss"] += 1
# Count missing the FXGraphCache as a miss not a bypass
except FXGraphCacheMiss:
except FXGraphCacheMiss as e:
counters["aot_autograd"]["autograd_cache_miss"] += 1
except BypassAOTAutogradCache:
if config.strict_autograd_cache:
raise e
except BypassAOTAutogradCache as e:
cache_key = None
counters["aot_autograd"]["autograd_cache_bypass"] += 1
if config.strict_autograd_cache:
raise e
if compiled_fn is None:
# Set the cache key so we can save a cache result later
aot_config.cache_key = cache_key
@@ -470,6 +474,8 @@ def _lookup(key: str) -> Optional[AOTAutogradCacheEntry]:
return entry
except Exception as e:
log.warning("AOTAutograd cache unable to load compiled graph: %s", e)
if config.strict_autograd_cache:
raise e
return None

@staticmethod
@@ -479,7 +485,9 @@ def save(key: str, entry: AOTAutogradCacheEntry):
content = pickle.dumps(entry)
except Exception as e:
log.warning("AOTAutograd cache unable to serialize compiled graph: %s", e)
raise e
if config.strict_autograd_cache:
raise e
return None
subdir = os.path.join(AOTAutogradCache._get_tmp_dir(), key)
if not os.path.exists(subdir):
os.makedirs(subdir, exist_ok=True)
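
For reference, a hedged sketch of how a test might assert on the counters bumped in the load() path above; `compiled_f` and `args` are illustrative placeholders, and the counter name comes from this diff:

```python
# Sketch: asserting cache behavior through torch._dynamo's global counters,
# which the AOTAutogradCache.load() path above increments.
import torch
from torch._dynamo.utils import counters

counters.clear()
compiled_f(*args)  # first run: expect an AOTAutograd cache miss
assert counters["aot_autograd"]["autograd_cache_miss"] == 1

torch._dynamo.reset()  # drop dynamo's compiled code so the cache is consulted again
compiled_f(*args)  # second run: served from cache, so the miss count is unchanged
assert counters["aot_autograd"]["autograd_cache_miss"] == 1
```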
3 changes: 3 additions & 0 deletions torch/_functorch/config.py
@@ -178,6 +178,9 @@

enable_autograd_cache = os.environ.get("ENABLE_AOT_AUTOGRAD_CACHE", "0") == "1"

# Error on BypassAOTAutogradCache instead of just a warning
# Used for tests
strict_autograd_cache = False

if TYPE_CHECKING:
from torch.utils._config_typing import * # noqa: F401, F403
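
Because enable_autograd_cache is read from the environment when the config module is imported, a process can opt in like this (a sketch; the variable must be set before torch is imported):

```python
# Sketch: ENABLE_AOT_AUTOGRAD_CACHE is read once, at config-module import
# time, so it must be set before importing torch.
import os

os.environ["ENABLE_AOT_AUTOGRAD_CACHE"] = "1"

import torch  # torch._functorch.config.enable_autograd_cache is now True
```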
