Skip to content

Commit

Permalink
Refactor test_jagged_tensor (pytorch#2241)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#2241

# context
* refactor test_jagged_tensor file for better structure of operator test
* update the forward, backward, and device tests.
* **consolidate** _cpu/_gpu/_meta tests
* use a `repeat_test` decorator to iterate a test with a set of arguments

# usages
* list usage
```
    repeat_test(
        ["cpu", 32, [[3, 4], [5, 6, 7], [8]]],
        ["cuda", 128, [[96, 256], [512, 128, 768], [1024]]],
    )
    def test_multi_permute_backward(
        self, device_str: str, batch_size: int, lengths: List[List[int]]
    ) -> None:
        if device_str == "cuda" and not torch.cuda.is_available():
            return
        else:
            device = torch.device(device_str)
```
* dict usage
```
    repeat_test(device_str=["cpu", "cuda"], batch_size=[16, 1024])
    def test_multi_permute_noncontiguous(
        self, device_str: str, batch_size: int
    ) -> None:
```

Differential Revision: D43653576
  • Loading branch information
TroyGarden authored and facebook-github-bot committed Jul 23, 2024
1 parent 9264186 commit bdfdc27
Show file tree
Hide file tree
Showing 2 changed files with 319 additions and 498 deletions.
Loading

0 comments on commit bdfdc27

Please sign in to comment.