[Bugfix] Replace global cudaStream in Filter with runtime calls (fix dmlc#5153) (dmlc#5157)

* Add failing unit test

* Add fix

* Remove extra newline

* skip cpu test

Co-authored-by: Xin Yao <[email protected]>
nv-dlasalle and yaox12 authored Jan 12, 2023
1 parent 84e4d02 commit 751b4c2
Showing 2 changed files with 27 additions and 5 deletions.
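
Background on the fix (an explanatory note, not part of the commit): a namespace-scope `cudaStream_t cudaStream = runtime::getCurrentCUDAStream();` is evaluated exactly once, during static initialization when the library is loaded, so every later kernel launch in cuda_filter.cu ran on whatever stream was current at load time (normally the default stream) rather than on the stream the caller is actually using, e.g. a PyTorch side stream. The commit replaces the global with per-call queries. The sketch below is a minimal plain-C++ illustration of that difference; `Stream`, `current_stream()` and `set_current_stream()` are hypothetical stand-ins, not DGL or CUDA APIs.

#include <cstdio>

// Hypothetical stand-ins for runtime::getCurrentCUDAStream() and the
// thread-local "current stream" that frameworks such as PyTorch maintain.
using Stream = int;
static Stream g_current_stream = 0;  // 0 = default stream
Stream current_stream() { return g_current_stream; }
void set_current_stream(Stream s) { g_current_stream = s; }

// Buggy pattern: the stream is captured once, at static-initialization time,
// before the caller has had any chance to make a side stream current.
static Stream captured_at_load = current_stream();

void launch_captured() {
  std::printf("captured-at-load launch uses stream %d\n", captured_at_load);
}

// Fixed pattern (what the commit does): query the current stream inside the
// function, so each call picks up whatever stream is current right now.
void launch_per_call() {
  std::printf("per-call launch uses stream %d\n", current_stream());
}

int main() {
  set_current_stream(7);  // analogous to entering torch.cuda.stream(s)
  launch_captured();      // still stream 0 -- the stale, load-time value
  launch_per_call();      // stream 7 -- follows the caller's current stream
  return 0;
}

The new test_filter_multistream test added below exercises exactly this scenario by driving Filter from inside a torch.cuda.stream(s) context.
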
6 changes: 3 additions & 3 deletions src/array/cuda/cuda_filter.cu
@@ -18,8 +18,6 @@ namespace array {
 
 namespace {
 
-cudaStream_t cudaStream = runtime::getCurrentCUDAStream();
-
 template <typename IdType, bool include>
 __global__ void _IsInKernel(
     DeviceOrderedHashTable<IdType> table, const IdType* const array,
@@ -46,6 +44,7 @@ IdArray _PerformFilter(const OrderedHashTable<IdType>& table, IdArray test) {
   const auto& ctx = test->ctx;
   auto device = runtime::DeviceAPI::Get(ctx);
   const int64_t size = test->shape[0];
+  cudaStream_t cudaStream = runtime::getCurrentCUDAStream();
 
   if (size == 0) {
     return test;
@@ -108,7 +107,8 @@ template <typename IdType>
 class CudaFilterSet : public Filter {
  public:
   explicit CudaFilterSet(IdArray array)
-      : table_(array->shape[0], array->ctx, cudaStream) {
+      : table_(array->shape[0], array->ctx, runtime::getCurrentCUDAStream()) {
+    cudaStream_t cudaStream = runtime::getCurrentCUDAStream();
     table_.FillWithUnique(
         static_cast<const IdType*>(array->data), array->shape[0], cudaStream);
   }
26 changes: 24 additions & 2 deletions tests/compute/test_filter.py
@@ -1,11 +1,11 @@
 import unittest
 
 import backend as F
+import numpy as np
+from test_utils import parametrize_idtype
 
 import dgl
-import numpy as np
 from dgl.utils import Filter
-from test_utils import parametrize_idtype
 
 
 def test_graph_filter():
@@ -71,6 +71,28 @@ def test_array_filter(idtype):
     assert F.array_equal(ye_act, ye_exp)
 
 
+@unittest.skipIf(
+    dgl.backend.backend_name != "pytorch",
+    reason="Multiple streams are only supported by pytorch backend",
+)
+@unittest.skipIf(
+    F._default_context_str == "cpu", reason="CPU not yet supported"
+)
+@parametrize_idtype
+def test_filter_multistream(idtype):
+    # this is a smoke test to ensure we do not trip any internal assertions
+    import torch
+
+    s = torch.cuda.Stream(device=F.ctx())
+    with torch.cuda.stream(s):
+        # we must do multiple runs such that the stream is busy as we launch
+        # work
+        for i in range(10):
+            f = Filter(F.arange(1000, 4000, dtype=idtype, ctx=F.ctx()))
+            x = F.randint([30000], dtype=idtype, ctx=F.ctx(), low=0, high=50000)
+            xi = f.find_included_indices(x)
+
+
 if __name__ == "__main__":
     test_graph_filter()
     test_array_filter()
