Add tests for aten::record_stream (#1058)
Currently there is no UT available in PyTorch to test `record_stream`. These two tests are adapted from the corresponding [cuda tests](https://github.com/pytorch/pytorch/blob/a7479fa2828ad55a056a60a629dff6b7a0cb6b98/test/test_cuda.py#L703). The only difference is that an actual expensive kernel is used in place of `torch.cuda._sleep` to create a delay on one stream; the add kernel used here produces sufficient delay given the maximum memory bandwidth of the currently supported GPUs.
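For illustration, here is a minimal sketch of that delay pattern, assuming an XPU-enabled PyTorch build where `torch.xpu` is available; the tensor shape and iteration count are illustrative placeholders rather than tuned values:

```python
import torch

# Two large operands; each x + y launches a memory-bandwidth-bound add kernel.
x = torch.randn(256, 1024, 1024, device="xpu")
y = torch.randn(256, 1024, 1024, device="xpu")

side_stream = torch.xpu.Stream()
with torch.xpu.stream(side_stream):
    # Repeated adds keep the stream busy, standing in for torch.cuda._sleep.
    for _ in range(30):
        z = x + y
```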
Showing 2 changed files with 75 additions and 0 deletions.
File renamed without changes.
@@ -0,0 +1,75 @@
import torch
from torch.testing._internal.common_utils import TestCase


class TestTorchMethod(TestCase):
    def test_record_stream(self):
        t = torch.FloatTensor([1, 2, 3, 4]).pin_memory()
        result = torch.FloatTensor(t.size()).to("xpu")
        stream = torch.xpu.Stream()
        ptr = [None]

        # Performs the CPU->GPU copy in a background stream
        def perform_copy():
            x = torch.randn(256, 1024, 1024, device="xpu")
            y = torch.randn(256, 1024, 1024, device="xpu")
            with torch.xpu.stream(stream):
                tmp = t.xpu(non_blocking=True)
                ptr[0] = tmp.data_ptr()
            torch.xpu.current_stream().wait_stream(stream)
            tmp.record_stream(torch.xpu.current_stream())
            for i in range(30):  # delay the copy
                z = x + y
            result.copy_(tmp)

        perform_copy()
        with torch.xpu.stream(stream):
            tmp2 = torch.FloatTensor(t.size()).to("xpu")
            tmp2.zero_()
            self.assertNotEqual(
                tmp2.data_ptr(), ptr[0], msg="allocation re-used too soon"
            )

        self.assertEqual(result.tolist(), [1, 2, 3, 4])

        # In the native allocator, we expect "tmp"'s side-stream-tagged block to be reused
        # in that side stream after result.copy_(tmp) in the main stream finishes.
        torch.xpu.current_stream().synchronize()
        with torch.xpu.stream(stream):
            tmp3 = torch.FloatTensor(t.size()).to("xpu")
            self.assertEqual(tmp3.data_ptr(), ptr[0], msg="allocation not re-used")

    def test_record_stream_on_shifted_view(self):
        # See PyTorch issue #27366
        # This test detects unexpected block reallocation. For a reliable test,
        # the stream to allocate tensors is isolated. The allocator will not
        # reuse free blocks which were allocated from another stream.
        x = torch.randn(256, 1024, 1024, device="xpu")
        y = torch.randn(256, 1024, 1024, device="xpu")

        stream_alloc = torch.xpu.Stream()
        with torch.xpu.stream(stream_alloc):
            base = torch.FloatTensor([10, 10]).xpu()

        # Record another stream on a shifted view tensor.
        view = base[5:]
        self.assertTrue(view.storage_offset() > 0)

        stream_record = torch.xpu.Stream()
        with torch.xpu.stream(stream_record):
            for i in range(30):
                z = x + y

        view.record_stream(stream_record)

        # Delete those tensors to make the block free soon.
        data_ptr = base.data_ptr()
        del base, view

        # A new tensor should not be allocated to the block above.
        stream_alloc.synchronize()

        with torch.xpu.stream(stream_alloc):
            try_realloc = torch.FloatTensor([10, 10]).xpu()

        self.assertNotEqual(try_realloc.data_ptr(), data_ptr)
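For context, this is roughly the application-side pattern the tests exercise, sketched under the same assumption of an XPU-enabled build: a tensor handed to a side stream is tagged with `record_stream` so the caching allocator does not recycle its block for other work until that stream's pending kernels finish.

```python
import torch

t = torch.empty(1024, device="xpu")
side = torch.xpu.Stream()

# Queue work that reads `t` on a side stream.
side.wait_stream(torch.xpu.current_stream())
with torch.xpu.stream(side):
    out = t * 2  # runs asynchronously on `side`

# Tag `t` with the side stream so its block is not reused too early,
# even if `t` is freed on the current stream right afterwards.
t.record_stream(side)
del t
```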