Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add gather_backward op #363

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 34 additions & 21 deletions benchmark/test_select_and_slice_perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,36 +99,49 @@ def scatter_input_fn(shape, dtype, device):
bench.run()


@pytest.mark.gather
def test_perf_gather():
def gather_input_fn(shape, dtype, device):
inp = torch.randn(shape, dtype=dtype, device=device)
def gather_input_fn(shape, dtype, device):
inp = torch.randn(shape, dtype=dtype, device=device)

dim = random.choice([0, 1])
size_dim = shape[dim]
index_shape = [
random.randint(1, shape[0]),
random.randint(1, shape[1]),
]
index = torch.empty(tuple(index_shape), dtype=torch.long, device=device)

m, n = index_shape

index_size_dim = index_shape[dim]
# make unique indices
for i in range(1 if dim == 0 else m):
for j in range(1 if dim == 1 else n):
ii = [i, j]
ii[dim] = slice(0, index.size(dim) + 1)
index[tuple(ii)] = torch.randperm(size_dim)[0:index_size_dim]

dim = random.choice([0, 1])
size_dim = shape[dim]
index_shape = [
random.randint(1, shape[0]),
random.randint(1, shape[1]),
]
index = torch.empty(tuple(index_shape), dtype=torch.long, device=device)
yield inp, dim, index

m, n = index_shape

index_size_dim = index_shape[dim]
# make unique indices
for i in range(1 if dim == 0 else m):
for j in range(1 if dim == 1 else n):
ii = [i, j]
ii[dim] = slice(0, index.size(dim) + 1)
index[tuple(ii)] = torch.randperm(size_dim)[0:index_size_dim]
@pytest.mark.gather
def test_perf_gather():
bench = TensorSelectBenchmark(
op_name="gather",
torch_op=torch.gather,
input_fn=gather_input_fn,
dtypes=FLOAT_DTYPES,
)
bench.run()

yield inp, dim, index

@pytest.mark.gather_backward
def test_perf_gather_backward():
bench = TensorSelectBenchmark(
op_name="gather",
op_name="gather_backward",
torch_op=torch.gather,
input_fn=gather_input_fn,
dtypes=FLOAT_DTYPES,
is_backward=True,
)
bench.run()

Expand Down
1 change: 1 addition & 0 deletions src/flag_gems/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ def enable(lib=aten_lib, unused=None):
("scatter.src", scatter, Autograd.disable),
("scatter.reduce", scatter, Autograd.disable),
("gather", gather, Autograd.disable),
("gather_backward", gather_backward, Autograd.disable),
("isclose", isclose, Autograd.disable),
("allclose", allclose, Autograd.disable),
("fill.Scalar", fill_scalar, Autograd.disable),
Expand Down
3 changes: 2 additions & 1 deletion src/flag_gems/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from .flip import flip
from .full import full
from .full_like import full_like
from .gather import gather
from .gather import gather, gather_backward
from .ge import ge, ge_scalar
from .gelu import gelu
from .groupnorm import group_norm
Expand Down Expand Up @@ -174,6 +174,7 @@
"fill_tensor",
"exponential_",
"gather",
"gather_backward",
"flip",
"ones_like",
"full_like",
Expand Down
7 changes: 7 additions & 0 deletions src/flag_gems/ops/gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from flag_gems.utils.code_utils import IndentedBuffer, NameSpace
from flag_gems.utils.shape_utils import restride_dim

from .scatter import scatter


def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
code.writeline("import torch")
Expand Down Expand Up @@ -258,3 +260,8 @@ def gather(inp, dim, index, out=None, sparse_grad=False):

_gather_func(inp_strided, out, index, dim, stride_dim, M, N)
return out


def gather_backward(grad, self, dim, index, sparse_grad):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add log here and check whether the relevant code is executed when running the unit test, because the coverage CI does not pass

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, I added the log.

result = torch.zeros_like(self)
return scatter(result, dim, index, grad, reduce="add")
13 changes: 12 additions & 1 deletion tests/test_reduction_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,9 @@ def test_accuracy_scatter_mul(src_shape, inp_shape, dim, dtype):
@pytest.mark.parametrize("dim", [0, 1, 2])
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_accuracy_gather(inp_shape, dim, dtype):
inp = torch.randn(inp_shape, dtype=dtype, device=flag_gems.device)
inp = torch.randn(
inp_shape, dtype=dtype, device=flag_gems.device, requires_grad=True
)
size_dim = inp_shape[dim]

import random
Expand Down Expand Up @@ -540,6 +542,15 @@ def test_accuracy_gather(inp_shape, dim, dtype):

gems_assert_equal(res_out, ref_out)

out_grad = torch.randn_like(res_out)
ref_grad = to_reference(out_grad)

(ref_in_grad,) = torch.autograd.grad(ref_out, ref_inp, ref_grad)
with flag_gems.use_gems():
(res_in_grad,) = torch.autograd.grad(res_out, inp, out_grad)
res_in_grad = to_reference(res_in_grad)
gems_assert_equal(res_in_grad, ref_in_grad)


@pytest.mark.select_scatter
@pytest.mark.parametrize("shape", REDUCTION_SHAPES)
Expand Down
Loading