Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
gyzhou2000 committed Jul 10, 2024
1 parent 8576663 commit ffde26a
Showing 1 changed file with 17 additions and 7 deletions.
24 changes: 17 additions & 7 deletions gammagl/ops/sparse/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@
from ._convert import c_ind2ptr, c_ptr2ind
from ._neighbor_sample import c_neighbor_sample, c_hetero_neighbor_sample

from ._sparse_cuda import cuda_torch_ind2ptr, cuda_torch_ptr2ind, cuda_torch_neighbor_sample, cuda_torch_sample_adj
try:
from ._sparse_cuda import (cuda_torch_ind2ptr, cuda_torch_ptr2ind, cuda_torch_neighbor_sample, cuda_torch_sample_adj)
except:
Warning("cuda sparse ops load failed.")

from gammagl.utils.platform_utils import out_tensor_list, Tensor, out_tensor

Expand All @@ -39,10 +42,14 @@ def ind2ptr(
num_worker: int = 0) -> Tensor:
if isinstance(ind, numpy.ndarray):
return c_ind2ptr(ind, M, num_worker)
elif(str(ind.device)=='cpu'):
elif(tlx.BACKEND != "torch" or str(ind.device)=='cpu'):
return c_ind2ptr(ind, M, num_worker)
else:
return cuda_torch_ind2ptr(ind, M)
try:
return cuda_torch_ind2ptr(ind, M)
except Error as e:
print("cuda_torch_ind2ptr error")


# @out_tensor
# def cu_ind2ptr(
Expand All @@ -59,10 +66,13 @@ def ptr2ind(
num_worker: int = 1):
if isinstance(ptr, numpy.ndarray):
return c_ptr2ind(ptr, E, num_worker)
elif(str(ptr.device)=='cpu'):
elif(tlx.BACKEND != "torch" or str(ptr.device)=='cpu'):
return c_ptr2ind(ptr, E, num_worker)
else:
return cuda_torch_ptr2ind(ptr, E)
try:
return cuda_torch_ptr2ind(ptr, E)
except Error as e:
print("cuda_torch_ptr2ind error")


@out_tensor_list
Expand All @@ -73,7 +83,7 @@ def neighbor_sample(
num_neighbors: List,
replace: bool,
directed: bool) -> Tuple[List, List, List, List]:
if(str(colptr.device)=='cpu'):
if(tlx.BACKEND != "torch" or str(colptr.device)=='cpu'):
start = time()
res = c_neighbor_sample(colptr, row, input_node, num_neighbors, replace, directed)
print(f'c算子 cost {time() - start}s')
Expand Down Expand Up @@ -135,7 +145,7 @@ def sample_adj(
idx: Tensor,
num_neighbors: int,
replace: bool = False):
if(str(rowptr.device)=='cpu'):
if(tlx.BACKEND != "torch" or str(rowptr.device)=='cpu'):
# num_neighbors = tlx.convert_to_tensor([num_neighbors], dtype=tlx.int64).to('cpu')
# rowptr = rowptr.to('cuda:2')
# col = col.to('cuda:2')
Expand Down

0 comments on commit ffde26a

Please sign in to comment.