add inference_mode test coverage on the KT.regroup function (#2243)
Summary:
Pull Request resolved: #2243

# context
* help MRS analyze the root cause of S422627 and work out prevention
> we want to try to avoid getting auto_grad versions of kernels into predictor; we want to raise developer awareness of the steps needed to successfully launch new kernels. The ST solution is just to ensure link_whole=True plus some cogswell tests on the CPU model. Our contribution could be to help develop awareness / add work to auto-swap kernels to the inference version in model-processing workflows.
* usually we need to wrap calls in `torch.inference_mode()` to avoid the operator being dispatched to the autograd backend (see the sketch at the end of this list)
* dstaay-fb raised a question: "if I jit.script the module, what kernel is recorded?"
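* as a general illustration (not part of this diff), a minimal sketch of how inference_mode changes dispatch for a plain op:
```
import torch

x = torch.randn(2, 3, requires_grad=True)

# normal mode: the op first hits the Autograd dispatch key, which
# records a grad_fn before re-dispatching to the device backend
y = x * 2
print(y.requires_grad, y.grad_fn is not None)  # True True

# inference_mode excludes the Autograd keys from dispatch, so the op
# goes straight to the CPU/CUDA kernel and no grad_fn is recorded
with torch.inference_mode():
    z = x * 2
    print(z.requires_grad, z.grad_fn is None, z.is_inference())  # False True True
```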

# setup
* a very simple module
```
class MyModule(torch.nn.Module):
    def forward(self, inputs: List[KeyedTensor]) -> List[torch.Tensor]:
        # user provided, not model input
        groups = [["dense_0", "sparse_1", "dense_2"], ["dense_1", "sparse_0"]]
        # regroup_func is parametrized over KeyedTensor.regroup,
        # regroup_kts, and permute_multi_embedding (see the diff below)
        return regroup_func(inputs, groups)
```
* run the scripted module under inference_mode
```
m = MyModule()
script_model = torch.jit.script(m)
with torch.inference_mode():
    outputs = script_model(inputs)
```
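* for completeness, `inputs` holds two KeyedTensors, built exactly as in the test body below:
```
# input construction copied from the test diff below; assumes
# import torch; from torchrec.sparse.jagged_tensor import KeyedTensor
key_dim = 1
kt_1 = KeyedTensor.from_tensor_list(
    ["dense_0", "dense_1", "dense_2"], [torch.randn(2, 3) for _ in range(3)], key_dim
)
kt_2 = KeyedTensor.from_tensor_list(
    ["sparse_0", "sparse_1"], [torch.randn(2, 3) for _ in range(2)], key_dim
)
inputs = [kt_1, kt_2]
```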
* the script_model records only the function's name (`regroup_kts` here), not a dispatched implementation; the kernel is still chosen by the dispatcher at call time
```
(Pdb) print(script_model.graph)
graph(%self : __torch__.torchrec.sparse.tests.test_jagged_tensor.___torch_mangle_27.MyModule,
      %inputs.1 : __torch__.torchrec.sparse.jagged_tensor.KeyedTensor[]):
  %12 : Function = prim::Constant[name="regroup_kts"]()
  %7 : str = prim::Constant[value="sparse_0"]() # /data/sandcastle/boxes/fbsource/buck-out/v2/gen/fbcode/1587bf3ddf9259d6/torchrec/sparse/tests/__test_jagged_tensor__/test_jagged_tensor#link-tree/torchrec/sparse/tests/test_jagged_tensor.py:2491:74
  %6 : str = prim::Constant[value="dense_1"]() # /data/sandcastle/boxes/fbsource/buck-out/v2/gen/fbcode/1587bf3ddf9259d6/torchrec/sparse/tests/__test_jagged_tensor__/test_jagged_tensor#link-tree/torchrec/sparse/tests/test_jagged_tensor.py:2491:63
  %4 : str = prim::Constant[value="dense_2"]() # /data/sandcastle/boxes/fbsource/buck-out/v2/gen/fbcode/1587bf3ddf9259d6/torchrec/sparse/tests/__test_jagged_tensor__/test_jagged_tensor#link-tree/torchrec/sparse/tests/test_jagged_tensor.py:2491:50
  %3 : str = prim::Constant[value="sparse_1"]() # /data/sandcastle/boxes/fbsource/buck-out/v2/gen/fbcode/1587bf3ddf9259d6/torchrec/sparse/tests/__test_jagged_tensor__/test_jagged_tensor#link-tree/torchrec/sparse/tests/test_jagged_tensor.py:2491:38
  %2 : str = prim::Constant[value="dense_0"]() # /data/sandcastle/boxes/fbsource/buck-out/v2/gen/fbcode/1587bf3ddf9259d6/torchrec/sparse/tests/__test_jagged_tensor__/test_jagged_tensor#link-tree/torchrec/sparse/tests/test_jagged_tensor.py:2491:27
  %5 : str[] = prim::ListConstruct(%2, %3, %4)
  %8 : str[] = prim::ListConstruct(%6, %7)
  %groups.1 : str[][] = prim::ListConstruct(%5, %8)
  %13 : Tensor[] = prim::CallFunction(%12, %inputs.1, %groups.1) # /data/sandcastle/boxes/fbsource/buck-out/v2/gen/fbcode/1587bf3ddf9259d6/torchrec/sparse/tests/__test_jagged_tensor__/test_jagged_tensor#link-tree/torchrec/sparse/tests/test_jagged_tensor.py:2492:23
  return (%13)
```
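* this can be checked directly on the graph: the regroup call is a `prim::CallFunction` on a `Function` constant, not an aten/backend op, so no kernel was baked into the graph; a quick sketch (reusing `script_model` from above):
```
# iterate the scripted graph's nodes; only prim::* kinds show up,
# confirming the kernel choice is deferred to execution time
for node in script_model.graph.nodes():
    print(node.kind())  # prim::Constant, prim::ListConstruct, prim::CallFunction
```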
# results
* log: P1492020919
* without inference_mode: every single call is dispatched to the autograd kernel first, which then re-dispatches to the corresponding device backend
```
running test_regroup_scriptable () {'regroup_func': <function regroup_kts at 0x7f38a0aba320>, 'device_str': 'cpu'}
regroup_keyed_tensor_autograd
kt_regroup_arguments_cpu
kt_regroup_arguments_impl
PermuteMultiEmbeddingOp::forward
permute_multi_embedding_function_cpu
running test_regroup_scriptable () {'regroup_func': <function regroup_kts at 0x7f38a0aba320>, 'device_str': 'cuda'}
regroup_keyed_tensor_autograd
kt_regroup_arguments_gpu
kt_regroup_arguments_impl
PermuteMultiEmbeddingOp::forward
permute_multi_embedding_function_gpu
running test_regroup_scriptable () {'regroup_func': <function regroup_kts at 0x7f38a0aba320>, 'device_str': 'meta'}
regroup_keyed_tensor_autograd
kt_regroup_arguments_meta
PermuteMultiEmbeddingOp::forward
permute_multi_embedding_function_meta
running test_regroup_scriptable () {'regroup_func': <function permute_multi_embedding at 0x7f38a0aba290>, 'device_str': 'cpu'}
kt_regroup_arguments_cpu
kt_regroup_arguments_impl
permute_multi_embedding_autograd
PermuteMultiEmbeddingOp::forward
permute_multi_embedding_function_cpu
running test_regroup_scriptable () {'regroup_func': <function permute_multi_embedding at 0x7f38a0aba290>, 'device_str': 'cuda'}
kt_regroup_arguments_gpu
kt_regroup_arguments_impl
permute_multi_embedding_autograd
PermuteMultiEmbeddingOp::forward
permute_multi_embedding_function_gpu
running test_regroup_scriptable () {'regroup_func': <function permute_multi_embedding at 0x7f38a0aba290>, 'device_str': 'meta'}
kt_regroup_arguments_meta
permute_multi_embedding_autograd
PermuteMultiEmbeddingOp::forward
permute_multi_embedding_function_meta
```
* with inference_mode: the op is dispatched directly to the corresponding device backend, skipping the autograd kernel
```
running test_regroup_scriptable_inference () {'regroup_func': <function regroup_kts at 0x7f38a0aba320>, 'device_str': 'cpu'}
regroup_keyed_tensor_cpu
kt_regroup_arguments_cpu
kt_regroup_arguments_impl
permute_multi_embedding_function_cpu
running test_regroup_scriptable_inference () {'regroup_func': <function regroup_kts at 0x7f38a0aba320>, 'device_str': 'cuda'}
regroup_keyed_tensor_gpu
kt_regroup_arguments_gpu
kt_regroup_arguments_impl
permute_multi_embedding_function_gpu
running test_regroup_scriptable_inference () {'regroup_func': <function regroup_kts at 0x7f38a0aba320>, 'device_str': 'meta'}
regroup_keyed_tensor_meta
kt_regroup_arguments_meta
permute_multi_embedding_function_meta
running test_regroup_scriptable_inference () {'regroup_func': <function permute_multi_embedding at 0x7f38a0aba290>, 'device_str': 'cpu'}
kt_regroup_arguments_cpu
kt_regroup_arguments_impl
permute_multi_embedding_cpu
permute_multi_embedding_function_cpu
running test_regroup_scriptable_inference () {'regroup_func': <function permute_multi_embedding at 0x7f38a0aba290>, 'device_str': 'cuda'}
kt_regroup_arguments_gpu
kt_regroup_arguments_impl
permute_multi_embedding_gpu
permute_multi_embedding_function_gpu
running test_regroup_scriptable_inference () {'regroup_func': <function permute_multi_embedding at 0x7f38a0aba290>, 'device_str': 'meta'}
kt_regroup_arguments_meta
permute_multi_embedding_meta
permute_multi_embedding_function_meta
```
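* for intuition, the same "skip the autograd kernel" behavior can be reproduced with a toy custom op outside torchrec; a rough sketch (the `demo` namespace and `times_two` kernels are made up for illustration, and `torch._C._AutoDispatchBelowAutograd` is a private API):
```
import torch

lib = torch.library.Library("demo", "DEF")  # hypothetical namespace
lib.define("times_two(Tensor x) -> Tensor")

def times_two_cpu(x: torch.Tensor) -> torch.Tensor:
    print("times_two_cpu")
    return x * 2

def times_two_autograd(x: torch.Tensor) -> torch.Tensor:
    print("times_two_autograd")
    # re-dispatch below the Autograd key to reach the backend kernel
    with torch._C._AutoDispatchBelowAutograd():
        return torch.ops.demo.times_two(x)

lib.impl("times_two", times_two_cpu, "CPU")
lib.impl("times_two", times_two_autograd, "Autograd")

x = torch.randn(3)
torch.ops.demo.times_two(x)      # prints times_two_autograd, then times_two_cpu
with torch.inference_mode():
    torch.ops.demo.times_two(x)  # prints times_two_cpu only
```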

Reviewed By: dstaay-fb

Differential Revision: D48359504

fbshipit-source-id: 7758c8fe1552a7ee8d8e7da740c9c1a4f74953bd
TroyGarden authored and facebook-github-bot committed Jul 24, 2024
1 parent 60b0347 commit abd692a
Showing 1 changed file with 82 additions and 3 deletions: torchrec/sparse/tests/test_jagged_tensor.py
```
@@ -2425,15 +2425,94 @@ def test_regroup_multiple_kt_duplicate_keys(self) -> None:
             )
         )
 
-    def test_regroup_scriptable(self) -> None:
+    @repeat_test(
+        regroup_func=[
+            KeyedTensor.regroup,
+            regroup_kts,
+            permute_multi_embedding,
+        ],
+        device_str=["cpu", "cuda", "meta"],
+    )
+    def test_regroup_scriptable(
+        self, regroup_func: Callable[..., List[torch.Tensor]], device_str: str
+    ) -> None:
+        if device_str == "cuda" and not torch.cuda.is_available():
+            return
+        else:
+            device = torch.device(device_str)
+
         class MyModule(torch.nn.Module):
             def forward(self, inputs: List[KeyedTensor]) -> List[torch.Tensor]:
                 # user provided, not model input
                 groups = [["dense_0", "sparse_1", "dense_2"], ["dense_1", "sparse_0"]]
-                return KeyedTensor.regroup(inputs, groups)
+                return regroup_func(inputs, groups)
 
         m = MyModule()
-        torch.jit.script(m)
+        script_model = torch.jit.script(m)
+        # input
+        key_dim = 1
+        tensor_list_1 = [torch.randn(2, 3, device=device) for i in range(3)]
+        keys_1 = ["dense_0", "dense_1", "dense_2"]
+        kt_1 = KeyedTensor.from_tensor_list(keys_1, tensor_list_1, key_dim)
+        tensor_list_2 = [torch.randn(2, 3, device=device) for i in range(2)]
+        keys_2 = ["sparse_0", "sparse_1"]
+        kt_2 = KeyedTensor.from_tensor_list(keys_2, tensor_list_2, key_dim)
+        inputs = [kt_1, kt_2]
+        outputs = script_model(inputs)
+        refs = _regroup_keyed_tensors(
+            inputs, [["dense_0", "sparse_1", "dense_2"], ["dense_1", "sparse_0"]]
+        )
+        for ref, output in zip(refs, outputs):
+            self.assertEqual(ref.device, output.device)
+            if device_str == "meta":
+                self.assertEqual(ref.shape, output.shape)
+            else:
+                torch.testing.assert_close(ref, output)
+
+    @repeat_test(
+        regroup_func=[
+            KeyedTensor.regroup,
+            regroup_kts,
+            permute_multi_embedding,
+        ],
+        device_str=["cpu", "cuda", "meta"],
+    )
+    def test_regroup_scriptable_inference(
+        self, regroup_func: Callable[..., List[torch.Tensor]], device_str: str
+    ) -> None:
+        if device_str == "cuda" and not torch.cuda.is_available():
+            return
+        else:
+            device = torch.device(device_str)
+
+        class MyModule(torch.nn.Module):
+            def forward(self, inputs: List[KeyedTensor]) -> List[torch.Tensor]:
+                # user provided, not model input
+                groups = [["dense_0", "sparse_1", "dense_2"], ["dense_1", "sparse_0"]]
+                return regroup_func(inputs, groups)
+
+        m = MyModule()
+        script_model = torch.jit.script(m)
+        with torch.inference_mode():
+            # input
+            key_dim = 1
+            tensor_list_1 = [torch.randn(2, 3, device=device) for i in range(3)]
+            keys_1 = ["dense_0", "dense_1", "dense_2"]
+            kt_1 = KeyedTensor.from_tensor_list(keys_1, tensor_list_1, key_dim)
+            tensor_list_2 = [torch.randn(2, 3, device=device) for i in range(2)]
+            keys_2 = ["sparse_0", "sparse_1"]
+            kt_2 = KeyedTensor.from_tensor_list(keys_2, tensor_list_2, key_dim)
+            inputs = [kt_1, kt_2]
+            outputs = script_model(inputs)
+            refs = _regroup_keyed_tensors(
+                inputs, [["dense_0", "sparse_1", "dense_2"], ["dense_1", "sparse_0"]]
+            )
+            for ref, output in zip(refs, outputs):
+                self.assertEqual(ref.device, output.device)
+                if device_str == "meta":
+                    self.assertEqual(ref.shape, output.shape)
+                else:
+                    torch.testing.assert_close(ref, output)
 
     def test_regroup_fxable(self) -> None:
         class MyModule(torch.nn.Module):
```
