
Add gather_backward op #363

Open · wants to merge 5 commits into master
Conversation

@awayzjj (Collaborator) commented Dec 15, 2024

PR Category

Type of Change

Description

Issue

Closes #317

Progress

  • Change is properly reviewed (1 reviewer required, 2 recommended).
  • Change responds to an issue.
  • Change is fully covered by a UT.

Performance

[two benchmark screenshots, not reproduced here]

@0x45f self-requested a review on December 16, 2024 at 08:56
@@ -258,3 +260,8 @@ def gather(inp, dim, index, out=None, sparse_grad=False):

_gather_func(inp_strided, out, index, dim, stride_dim, M, N)
return out


def gather_backward(grad, self, dim, index, sparse_grad):
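For orientation, here is a minimal eager-mode sketch of what gather's backward computes (a plain-PyTorch reference, not the PR's kernel: each gradient element is routed back to the position it was gathered from, and repeated indices accumulate):

import torch

def gather_backward_ref(grad, self, dim, index, sparse_grad=False):
    # Reference semantics for dense gradients only.
    assert not sparse_grad, "sparse_grad path not sketched here"
    # Zero-filled buffer shaped like the forward input, then route each
    # grad element back through the same index; collisions sum.
    out = grad.new_zeros(self.shape)
    return out.scatter_add_(dim, index, grad)

# Quick check against autograd:
x = torch.randn(4, 4, dtype=torch.float64, requires_grad=True)
idx = torch.randint(0, 4, (2, 3))
torch.gather(x, 0, idx).backward(torch.ones(2, 3, dtype=torch.float64))
ref = gather_backward_ref(torch.ones(2, 3, dtype=torch.float64), x, 0, idx)
assert torch.allclose(x.grad, ref)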
Review comment (Collaborator):
Please add a log here and check whether the relevant code is executed when running the unit test, because the coverage CI does not pass.
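For reference, a minimal sketch of such a log, assuming the op logs via Python's standard logging module (the exact message and logger setup are assumptions, not the PR's code):

import logging

def gather_backward(grad, self, dim, index, sparse_grad):
    # A debug log makes it easy to confirm the op is actually executed
    # when the unit tests run under the coverage CI.
    logging.debug("GEMS GATHER BACKWARD")
    ...  # rest of the implementation unchanged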

@awayzjj (Author) replied:

Thanks, I added the log.

@0x45f (Collaborator) commented Dec 17, 2024

It seems that the speedup ratio is not very ideal.

@0x45f (Collaborator) commented Dec 19, 2024

> It seems that the speedup ratio is not very ideal.

I did some profiling work and found some code that can be optimized. I will do some tests locally first.

@0x45f (Collaborator) commented Dec 24, 2024

I modified gather_backward:

  1. Use new_zeros instead of zeros_like.
  2. Use an in-place scatter (a sketch of both changes follows below).
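A sketch contrasting the two versions (plain torch calls stand in for FlagGems' own scatter kernel; function names are assumed for illustration, not copied from the PR):

import torch

# Before: zeros_like follows self's memory layout, and the out-of-place
# scatter materializes an extra intermediate tensor.
def gather_backward_before(grad, self, dim, index):
    out = torch.zeros_like(self)
    return torch.scatter(out, dim, index, grad)

# After: new_zeros allocates a plain contiguous buffer with grad's dtype
# and device, and the in-place scatter_ writes into it directly.
def gather_backward_after(grad, self, dim, index):
    out = grad.new_zeros(self.shape)
    out.scatter_(dim, index, grad)
    return out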

Perf on A100:
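(In the tables below, Gems Speedup is Torch Latency divided by Gems Latency, so values above 1.000 mean Gems is faster; e.g. 0.021504 / 0.134144 ≈ 0.160 in the first row. Size Detail appears to list [input shape, dim, index shape].)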

Before:

Operator: gather_backward backward  Performance Test (dtype=torch.float16, mode=cuda, level=comprehensive)
Size         Torch Latency (ms)    Gems Latency (ms)         Gems Speedup         Size Detail
------------------------------------------------------------------------------------------
SUCCESS               0.021504            0.134144               0.160          [torch.Size([64, 64]), 0, torch.Size([9, 33])]
SUCCESS               0.021504            0.126976               0.169          [torch.Size([256, 256]), 0, torch.Size([254, 231])]
SUCCESS               0.027648            0.135168               0.205          [torch.Size([1024, 1024]), 1, torch.Size([778, 430])]
SUCCESS               0.144384            0.219136               0.659          [torch.Size([4096, 4096]), 0, torch.Size([3997, 233])]
SUCCESS               0.443392            0.652288               0.680          [torch.Size([1024, 65536]), 1, torch.Size([887, 277])]
SUCCESS               0.021504            0.136192               0.158          [torch.Size([1024, 256]), 1, torch.Size([546, 118])]
SUCCESS               0.040960            0.120832               0.339          [torch.Size([1024, 4096]), 0, torch.Size([651, 251])]


Operator: gather_backward backward  Performance Test (dtype=torch.float32, mode=cuda, level=comprehensive)
Size         Torch Latency (ms)    Gems Latency (ms)         Gems Speedup         Size Detail
------------------------------------------------------------------------------------------
SUCCESS               0.014336            0.133120               0.108          [torch.Size([64, 64]), 0, torch.Size([4, 2])]
SUCCESS               0.019456            0.132096               0.147          [torch.Size([256, 256]), 1, torch.Size([111, 217])]
SUCCESS               0.029696            0.132096               0.225          [torch.Size([1024, 1024]), 0, torch.Size([455, 897])]
SUCCESS               0.343040            1.383424               0.248          [torch.Size([4096, 4096]), 1, torch.Size([1910, 2832])]
SUCCESS               5.706752            6.578176               0.868          [torch.Size([1024, 65536]), 0, torch.Size([449, 60242])]
SUCCESS               0.019456            0.130048               0.150          [torch.Size([1024, 256]), 1, torch.Size([45, 214])]
SUCCESS               0.072704            0.167936               0.433          [torch.Size([1024, 4096]), 0, torch.Size([381, 2429])]


Operator: gather_backward backward  Performance Test (dtype=torch.bfloat16, mode=cuda, level=comprehensive)
Size         Torch Latency (ms)    Gems Latency (ms)         Gems Speedup         Size Detail
------------------------------------------------------------------------------------------
SUCCESS               0.032768            0.130048               0.252          [torch.Size([64, 64]), 0, torch.Size([43, 55])]
SUCCESS               0.029696            0.132096               0.225          [torch.Size([256, 256]), 0, torch.Size([156, 146])]
SUCCESS               0.033792            0.132096               0.256          [torch.Size([1024, 1024]), 1, torch.Size([806, 71])]
SUCCESS               0.649216            1.077248               0.603          [torch.Size([4096, 4096]), 1, torch.Size([1989, 3312])]
SUCCESS               1.816576            4.820992               0.377          [torch.Size([1024, 65536]), 1, torch.Size([355, 48120])]
SUCCESS               0.036864            0.125952               0.293          [torch.Size([1024, 256]), 1, torch.Size([178, 225])]
SUCCESS               0.097280            0.138240               0.704          [torch.Size([1024, 4096]), 0, torch.Size([336, 3222])]

After:

Operator: gather_backward backward  Performance Test (dtype=torch.float16, mode=cuda, level=comprehensive)
Size         Torch Latency (ms)    Gems Latency (ms)         Gems Speedup         Size Detail
------------------------------------------------------------------------------------------
SUCCESS               0.022528            0.056320               0.400          [torch.Size([64, 64]), 0, torch.Size([9, 33])]
SUCCESS               0.021504            0.021504               1.000          [torch.Size([256, 256]), 0, torch.Size([254, 231])]
SUCCESS               0.027648            0.046080               0.600          [torch.Size([1024, 1024]), 1, torch.Size([778, 430])]
SUCCESS               0.144384            0.199680               0.723          [torch.Size([4096, 4096]), 0, torch.Size([3997, 233])]
SUCCESS               0.443392            0.447488               0.991          [torch.Size([1024, 65536]), 1, torch.Size([887, 277])]
SUCCESS               0.021504            0.025600               0.840          [torch.Size([1024, 256]), 1, torch.Size([546, 118])]
SUCCESS               0.038912            0.046080               0.844          [torch.Size([1024, 4096]), 0, torch.Size([651, 251])]


Operator: gather_backward backward  Performance Test (dtype=torch.float32, mode=cuda, level=comprehensive)
Size         Torch Latency (ms)    Gems Latency (ms)         Gems Speedup         Size Detail
------------------------------------------------------------------------------------------
SUCCESS               0.013312            0.015360               0.867          [torch.Size([64, 64]), 0, torch.Size([4, 2])]
SUCCESS               0.017408            0.017408               1.000          [torch.Size([256, 256]), 1, torch.Size([111, 217])]
SUCCESS               0.029696            0.052224               0.569          [torch.Size([1024, 1024]), 0, torch.Size([455, 897])]
SUCCESS               0.344064            0.861184               0.400          [torch.Size([4096, 4096]), 1, torch.Size([1910, 2832])]
SUCCESS               5.700608            5.012480               1.137          [torch.Size([1024, 65536]), 0, torch.Size([449, 60242])]
SUCCESS               0.019456            0.018432               1.056          [torch.Size([1024, 256]), 1, torch.Size([45, 214])]
SUCCESS               0.072704            0.122880               0.592          [torch.Size([1024, 4096]), 0, torch.Size([381, 2429])]


Operator: gather_backward backward  Performance Test (dtype=torch.bfloat16, mode=cuda, level=comprehensive)
Size         Torch Latency (ms)    Gems Latency (ms)         Gems Speedup         Size Detail
------------------------------------------------------------------------------------------
SUCCESS               0.031744            0.014336               2.214          [torch.Size([64, 64]), 0, torch.Size([43, 55])]
SUCCESS               0.029696            0.017408               1.706          [torch.Size([256, 256]), 0, torch.Size([156, 146])]
SUCCESS               0.032768            0.024576               1.333          [torch.Size([1024, 1024]), 1, torch.Size([806, 71])]
SUCCESS               0.649216            0.613376               1.058          [torch.Size([4096, 4096]), 1, torch.Size([1989, 3312])]
SUCCESS               1.816576            3.298304               0.551          [torch.Size([1024, 65536]), 1, torch.Size([355, 48120])]
SUCCESS               0.035840            0.020480               1.750          [torch.Size([1024, 256]), 1, torch.Size([178, 225])]
SUCCESS               0.096256            0.102400               0.940          [torch.Size([1024, 4096]), 0, torch.Size([336, 3222])]


Linked issue (may be closed by merging): Code Contribution: 【Lv3】【Operator Development】gather_backward