diff --git a/CHANGELOG.md b/CHANGELOG.md
index f4d8b7fc4..e4fb911a1 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ## [0.4.0] - 2023-MM-DD
 ### Added
+- Added support for `bfloat16` data type in `segment_matmul` and `grouped_matmul` (CPU only) ([#272](https://github.com/pyg-team/pyg-lib/pull/272))
 ### Changed
 - Added `--biased` parameter to run benchmarks for biased sampling ([#267](https://github.com/pyg-team/pyg-lib/pull/267))
 - Improved speed of biased sampling ([#270](https://github.com/pyg-team/pyg-lib/pull/270))
diff --git a/pyg_lib/csrc/ops/cpu/matmul_kernel.cpp b/pyg_lib/csrc/ops/cpu/matmul_kernel.cpp
index e9d15b299..06a020086 100644
--- a/pyg_lib/csrc/ops/cpu/matmul_kernel.cpp
+++ b/pyg_lib/csrc/ops/cpu/matmul_kernel.cpp
@@ -297,7 +297,8 @@ std::vector<at::Tensor> grouped_matmul_kernel(const at::TensorList input,
         {input_contig[i].size(0), other_contig[i].size(-1)}));
   }
 
-  AT_DISPATCH_ALL_TYPES(
+  AT_DISPATCH_ALL_TYPES_AND2(
+      at::ScalarType::Half, at::ScalarType::BFloat16,
       input_contig.front().scalar_type(), "grouped_matmul_kernel", [&] {
         if (mkl_path_available<scalar_t>() &&
             mkl_path_possible(input_contig, other_contig)) {
@@ -413,7 +414,8 @@ at::Tensor segment_matmul_kernel(const at::Tensor& input,
   const auto other_contig = other.contiguous();
   auto out = input_contig.new_empty({input.size(0), other.size(-1)});
 
-  AT_DISPATCH_ALL_TYPES(
+  AT_DISPATCH_ALL_TYPES_AND2(
+      at::ScalarType::Half, at::ScalarType::BFloat16,
       input_contig.scalar_type(), "segment_matmul_kernel", [&] {
         const auto n = other_contig.size(-1);
         const auto k = input_contig.size(-1);
diff --git a/pyg_lib/testing.py b/pyg_lib/testing.py
index 66e673079..d55c66c00 100644
--- a/pyg_lib/testing.py
+++ b/pyg_lib/testing.py
@@ -37,12 +37,13 @@ def onlyTriton(func: Callable) -> Callable:
 
 
 def withCUDA(func: Callable) -> Callable:
-    def wrapper(*args, **kwargs):
-        func(*args, device=torch.device('cpu'), **kwargs)
-        if torch.cuda.is_available():
-            func(*args, device=torch.device('cuda:0'), **kwargs)
+    import pytest
 
-    return wrapper
+    devices = [torch.device('cpu')]
+    if torch.cuda.is_available():
+        devices.append(torch.device('cuda:0'))
+
+    return pytest.mark.parametrize('device', devices)(func)
 
 
 def withDataset(group: str, name: str) -> Callable:
diff --git a/test/ops/test_matmul.py b/test/ops/test_matmul.py
index 3fa74ed97..b9001d3fc 100644
--- a/test/ops/test_matmul.py
+++ b/test/ops/test_matmul.py
@@ -1,5 +1,6 @@
 import os
 
+import pytest
 import torch
 
 import pyg_lib
@@ -11,11 +12,17 @@
 
 
 @withCUDA
-def test_segment_matmul_autograd(device):
-    inputs = torch.randn((8, 16), requires_grad=True, device=device)
+@pytest.mark.parametrize('dtype', [torch.float, torch.bfloat16])
+def test_segment_matmul_autograd(dtype, device):
+    if device.type == 'cuda' and dtype == torch.bfloat16:
+        pytest.skip('CUDA does not support bfloat16')
+
+    inputs = torch.randn((8, 16), requires_grad=True, device=device,
+                         dtype=dtype)
     ptr = torch.tensor([0, 5, 8]).to(torch.device(device))
-    other = torch.randn((2, 16, 32), requires_grad=True, device=device)
-    bias = torch.randn((2, 32), requires_grad=True, device=device)
+    other = torch.randn((2, 16, 32), requires_grad=True, device=device,
+                        dtype=dtype)
+    bias = torch.randn((2, 32), requires_grad=True, device=device, dtype=dtype)
 
     out = pyg_lib.ops.segment_matmul(inputs, ptr, other, bias)
     assert out.size() == (8, 32)
@@ -31,7 +38,11 @@ def test_segment_matmul_autograd(device):
 
 
 @withCUDA
-def test_grouped_matmul_autograd(device):
+@pytest.mark.parametrize('dtype', [torch.float, torch.bfloat16])
+def test_grouped_matmul_autograd(dtype, device):
+    if device.type == 'cuda' and dtype == torch.bfloat16:
+        pytest.skip('CUDA does not support bfloat16')
+
     inputs = [
         torch.randn(5, 16, device=device, requires_grad=True),
         torch.randn(6, 9, device=device, requires_grad=True),