Skip to content

Commit

Permalink
fixed nnz value (#17)
Browse files Browse the repository at this point in the history
fixes for the rspmm kernels
  • Loading branch information
migalkin authored Mar 30, 2024
1 parent 871add0 commit c414f83
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 5 deletions.
5 changes: 4 additions & 1 deletion ultra/rspmm/rspmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,10 @@ def generalized_rspmm(edge_index, edge_type, edge_weight, relation, input, sum="
def load_extension(name, sources, extra_cflags=None, extra_cuda_cflags=None, **kwargs):
if extra_cflags is None:
extra_cflags = ["-Ofast"]
if torch.backends.openmp.is_available():
# PyTorch 2.2.1+ on Apple Silicon is now compiled by default with OpenMP
# However, properly installing OpenMP on macs and wiring it up to the compiler is tedious
# So on macs we turn off OpenMP (as the default behavior in all torch < 2.2.1 versions)
if torch.backends.openmp.is_available() and not sys.platform.startswith('darwin'):
extra_cflags += ["-fopenmp", "-DAT_PARALLEL_OPENMP"]
else:
extra_cflags.append("-DAT_PARALLEL_NATIVE")
Expand Down
4 changes: 2 additions & 2 deletions ultra/rspmm/source/rspmm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ Tensor rspmm_forward_cpu(const Tensor &edge_index_, const Tensor &edge_type_, co
const Tensor relation = relation_.contiguous();
const Tensor input = input_.contiguous();

int64_t nnz = edge_index.size(0);
int64_t nnz = edge_index.size(1);
int64_t num_row = input.size(0);
int64_t dim = input.size(1);
Tensor output = at::empty({num_row, dim}, input.options());
Expand Down Expand Up @@ -183,7 +183,7 @@ std::tuple<Tensor, Tensor, Tensor> rspmm_backward_cpu(
const Tensor output = output_.contiguous();
const Tensor output_grad = output_grad_.contiguous();

int64_t nnz = edge_index.size(0);
int64_t nnz = edge_index.size(1);
int64_t num_row = input.size(0);
int64_t dim = input.size(1);
Tensor weight_grad = at::zeros_like(edge_weight);
Expand Down
4 changes: 2 additions & 2 deletions ultra/rspmm/source/rspmm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ Tensor rspmm_forward_cuda(const Tensor &edge_index_, const Tensor &edge_type_, c
const Tensor relation = relation_.contiguous();
const Tensor input = input_.contiguous();

int64_t nnz = edge_index.size(0);
int64_t nnz = edge_index.size(1);
int64_t num_row = input.size(0);
int64_t dim = input.size(1);
Tensor output = at::empty({num_row, dim}, input.options());
Expand Down Expand Up @@ -289,7 +289,7 @@ std::tuple<Tensor, Tensor, Tensor> rspmm_backward_cuda(
const Tensor output = output_.contiguous();
const Tensor output_grad = output_grad_.contiguous();

int64_t nnz = edge_index.size(0);
int64_t nnz = edge_index.size(1);
int64_t num_row = input.size(0);
int64_t dim = input.size(1);
Tensor weight_grad = at::zeros_like(edge_weight);
Expand Down

0 comments on commit c414f83

Please sign in to comment.