add inference_mode test coverage on the KT.regroup function (#2243)
Summary:
Pull Request resolved: #2243

# context
* help MRS analyze the root cause and prevention regarding S422627
> we want to try to avoid getting auto_grad versions of kernels into predictor; we want to try to raise developer awareness of the steps needed to successfully launch new kernels. The ST solution is just to ensure link_whole=True plus some cogswell tests on a CPU model. Our contribution could be to help develop awareness / add work to autoswap kernels to the inference version in model-processing workflows.
* usually we need to use "with torch.inference_mode()" to avoid the operator being dispatched to the autograd backend (see the toy sketch after the setup section).
* dstaay-fb raised a question: "if I jit.script the module, what kernel is recorded?"

# setup
* a very simple module
```
class MyModule(torch.nn.Module):
    def forward(self, inputs: List[KeyedTensor]) -> List[torch.Tensor]:
        # user provided, not model input
        groups = [["dense_0", "sparse_1", "dense_2"], ["dense_1", "sparse_0"]]
        return regroup_func(inputs, groups)
```
* run the scripted module
```
m = MyModule()
script_model = torch.jit.script(m)
with torch.inference_mode():
    outputs = script_model(inputs)
```
* the script_model only contains the function's name, not the dispatched implementation
```
(Pdb) print(script_model.graph)
graph(%self : __torch__.torchrec.sparse.tests.test_jagged_tensor.___torch_mangle_27.MyModule,
      %inputs.1 : __torch__.torchrec.sparse.jagged_tensor.KeyedTensor[]):
  %12 : Function = prim::Constant[name="regroup_kts"]()
  %7 : str = prim::Constant[value="sparse_0"]() # /data/sandcastle/boxes/fbsource/buck-out/v2/gen/fbcode/1587bf3ddf9259d6/torchrec/sparse/tests/__test_jagged_tensor__/test_jagged_tensor#link-tree/torchrec/sparse/tests/test_jagged_tensor.py:2491:74
  %6 : str = prim::Constant[value="dense_1"]() # /data/sandcastle/boxes/fbsource/buck-out/v2/gen/fbcode/1587bf3ddf9259d6/torchrec/sparse/tests/__test_jagged_tensor__/test_jagged_tensor#link-tree/torchrec/sparse/tests/test_jagged_tensor.py:2491:63
  %4 : str = prim::Constant[value="dense_2"]() # /data/sandcastle/boxes/fbsource/buck-out/v2/gen/fbcode/1587bf3ddf9259d6/torchrec/sparse/tests/__test_jagged_tensor__/test_jagged_tensor#link-tree/torchrec/sparse/tests/test_jagged_tensor.py:2491:50
  %3 : str = prim::Constant[value="sparse_1"]() # /data/sandcastle/boxes/fbsource/buck-out/v2/gen/fbcode/1587bf3ddf9259d6/torchrec/sparse/tests/__test_jagged_tensor__/test_jagged_tensor#link-tree/torchrec/sparse/tests/test_jagged_tensor.py:2491:38
  %2 : str = prim::Constant[value="dense_0"]() # /data/sandcastle/boxes/fbsource/buck-out/v2/gen/fbcode/1587bf3ddf9259d6/torchrec/sparse/tests/__test_jagged_tensor__/test_jagged_tensor#link-tree/torchrec/sparse/tests/test_jagged_tensor.py:2491:27
  %5 : str[] = prim::ListConstruct(%2, %3, %4)
  %8 : str[] = prim::ListConstruct(%6, %7)
  %groups.1 : str[][] = prim::ListConstruct(%5, %8)
  %13 : Tensor[] = prim::CallFunction(%12, %inputs.1, %groups.1) # /data/sandcastle/boxes/fbsource/buck-out/v2/gen/fbcode/1587bf3ddf9259d6/torchrec/sparse/tests/__test_jagged_tensor__/test_jagged_tensor#link-tree/torchrec/sparse/tests/test_jagged_tensor.py:2492:23
  return (%13)
```
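Not part of this diff, but to make the dispatch behavior concrete before the results: below is a minimal, self-contained sketch. The "toy" namespace, the toy::double op and its kernels are made up purely for illustration; it registers separate Autograd and CPU kernels via torch.library and prints which one a call lands on. On current PyTorch the plain call should hit the Autograd kernel first, while the call under torch.inference_mode() should go straight to the CPU backend kernel, mirroring the regroup_keyed_tensor_autograd vs regroup_keyed_tensor_cpu entries in the logs below.
```
import torch
from torch.library import Library

# Hypothetical "toy::double" op, used only to illustrate the dispatch behavior.
lib = Library("toy", "DEF")
lib.define("double(Tensor x) -> Tensor")

def double_cpu(x: torch.Tensor) -> torch.Tensor:
    print("toy::double -> CPU backend kernel")
    return x * 2

def double_autograd(x: torch.Tensor) -> torch.Tensor:
    print("toy::double -> Autograd kernel")
    return x * 2

lib.impl("double", double_cpu, "CPU")
lib.impl("double", double_autograd, "Autograd")

x = torch.randn(3)

# Outside inference_mode the call is dispatched to the Autograd kernel first.
torch.ops.toy.double(x)

# Inside inference_mode the autograd dispatch keys are skipped,
# so the call should land directly on the CPU backend kernel.
with torch.inference_mode():
    torch.ops.toy.double(x)
```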
# results
* log: P1492020919
* without inference_mode: the op is dispatched to the autograd backend every single time before being dispatched to the corresponding device backend
```
running test_regroup_scriptable () {'regroup_func': <function regroup_kts at 0x7f38a0aba320>, 'device_str': 'cpu'}
  regroup_keyed_tensor_autograd kt_regroup_arguments_cpu kt_regroup_arguments_impl PermuteMultiEmbeddingOp::forward permute_multi_embedding_function_cpu
running test_regroup_scriptable () {'regroup_func': <function regroup_kts at 0x7f38a0aba320>, 'device_str': 'cuda'}
  regroup_keyed_tensor_autograd kt_regroup_arguments_gpu kt_regroup_arguments_impl PermuteMultiEmbeddingOp::forward permute_multi_embedding_function_gpu
running test_regroup_scriptable () {'regroup_func': <function regroup_kts at 0x7f38a0aba320>, 'device_str': 'meta'}
  regroup_keyed_tensor_autograd kt_regroup_arguments_meta PermuteMultiEmbeddingOp::forward permute_multi_embedding_function_meta
running test_regroup_scriptable () {'regroup_func': <function permute_multi_embedding at 0x7f38a0aba290>, 'device_str': 'cpu'}
  kt_regroup_arguments_cpu kt_regroup_arguments_impl permute_multi_embedding_autograd PermuteMultiEmbeddingOp::forward permute_multi_embedding_function_cpu
running test_regroup_scriptable () {'regroup_func': <function permute_multi_embedding at 0x7f38a0aba290>, 'device_str': 'cuda'}
  kt_regroup_arguments_gpu kt_regroup_arguments_impl permute_multi_embedding_autograd PermuteMultiEmbeddingOp::forward permute_multi_embedding_function_gpu
running test_regroup_scriptable () {'regroup_func': <function permute_multi_embedding at 0x7f38a0aba290>, 'device_str': 'meta'}
  kt_regroup_arguments_meta permute_multi_embedding_autograd PermuteMultiEmbeddingOp::forward permute_multi_embedding_function_meta
```
* with inference_mode: the op is dispatched directly to the corresponding device backend
```
running test_regroup_scriptable_inference () {'regroup_func': <function regroup_kts at 0x7f38a0aba320>, 'device_str': 'cpu'}
  regroup_keyed_tensor_cpu kt_regroup_arguments_cpu kt_regroup_arguments_impl permute_multi_embedding_function_cpu
running test_regroup_scriptable_inference () {'regroup_func': <function regroup_kts at 0x7f38a0aba320>, 'device_str': 'cuda'}
  regroup_keyed_tensor_gpu kt_regroup_arguments_gpu kt_regroup_arguments_impl permute_multi_embedding_function_gpu
running test_regroup_scriptable_inference () {'regroup_func': <function regroup_kts at 0x7f38a0aba320>, 'device_str': 'meta'}
  regroup_keyed_tensor_meta kt_regroup_arguments_meta permute_multi_embedding_function_meta
running test_regroup_scriptable_inference () {'regroup_func': <function permute_multi_embedding at 0x7f38a0aba290>, 'device_str': 'cpu'}
  kt_regroup_arguments_cpu kt_regroup_arguments_impl permute_multi_embedding_cpu permute_multi_embedding_function_cpu
running test_regroup_scriptable_inference () {'regroup_func': <function permute_multi_embedding at 0x7f38a0aba290>, 'device_str': 'cuda'}
  kt_regroup_arguments_gpu kt_regroup_arguments_impl permute_multi_embedding_gpu permute_multi_embedding_function_gpu
running test_regroup_scriptable_inference () {'regroup_func': <function permute_multi_embedding at 0x7f38a0aba290>, 'device_str': 'meta'}
  kt_regroup_arguments_meta permute_multi_embedding_meta permute_multi_embedding_function_meta
```

Reviewed By: dstaay-fb

Differential Revision: D48359504

fbshipit-source-id: 7758c8fe1552a7ee8d8e7da740c9c1a4f74953bd
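As an illustrative addendum (this is not the exact test added in this diff; MyModule, the inputs, and the assertions are stand-ins): one way inference_mode coverage can assert that the inference path was taken is to check that outputs produced under torch.inference_mode() are inference tensors carrying no autograd metadata.
```
from typing import List

import torch

# Stand-in module; the real test scripts the KT.regroup module instead.
class MyModule(torch.nn.Module):
    def forward(self, inputs: List[torch.Tensor]) -> List[torch.Tensor]:
        return [torch.cat(inputs, dim=1)]

script_model = torch.jit.script(MyModule())
inputs = [torch.randn(2, 3), torch.randn(2, 4)]

with torch.inference_mode():
    assert torch.is_inference_mode_enabled()
    outputs = script_model(inputs)

# Tensors produced under inference_mode have no autograd metadata recorded.
for out in outputs:
    assert out.is_inference()
    assert not out.requires_grad
```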