-
Notifications
You must be signed in to change notification settings - Fork 60
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Operator] Add cov op #276
Open
RubiaCx
wants to merge
9
commits into
FlagOpen:master
Choose a base branch
from
RubiaCx:develop
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 7 commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
be80e23
[Operator] Add cov op
RubiaCx dfc6237
[Operator] Add cov op
RubiaCx 571308b
[Operator] Add covariance (cov) op & Remove unused code
RubiaCx 620d842
[Operator] Correct covariance (cov) op indexing issue
RubiaCx 5e581b6
Merge remote-tracking branch 'upstream/master' into develop
RubiaCx 8d36687
[Operator] Fix covariance (cov) op shape issues and handle edge cases
RubiaCx 69f1eda
[Operator] Reduce MAX_GRID_NUM and implement sub-block handling for l…
RubiaCx 831fec0
Merge to keep branches in sync
RubiaCx cb44c17
Merge master into COV op with major updates
RubiaCx File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
import logging | ||
import torch | ||
import triton | ||
import triton.language as tl | ||
import triton.testing | ||
|
||
from ..utils import libentry | ||
|
||
@libentry() | ||
@triton.jit | ||
def mean_kernel( | ||
X, | ||
mean, | ||
M, | ||
N, | ||
weights, | ||
row_offset: tl.constexpr, | ||
BLOCK_SIZE: tl.constexpr | ||
): | ||
row = tl.program_id(0) + row_offset | ||
if row >= M: | ||
return | ||
|
||
acc = 0.0 | ||
for block_start in range(0, N, BLOCK_SIZE): | ||
cols = block_start + tl.arange(0, BLOCK_SIZE) | ||
mask = cols < N | ||
x = tl.load(X + row * N + cols, mask=mask, other=0.0) | ||
w = tl.load(weights + cols, mask=mask, other=0.0) | ||
acc += tl.sum(x * w, axis=0) | ||
tl.atomic_add(mean + row, acc) | ||
|
||
@libentry() | ||
@triton.jit | ||
def covariance_kernel( | ||
X, | ||
cov_matrix, | ||
mean, | ||
M, | ||
N, | ||
weights, | ||
row_offset: tl.constexpr, | ||
col_offset: tl.constexpr, | ||
BLOCK_SIZE: tl.constexpr | ||
): | ||
row = tl.program_id(0) + row_offset | ||
col = tl.program_id(1) + col_offset | ||
if row >= M or col >= M: | ||
return | ||
|
||
acc = 0.0 | ||
mean_row = tl.load(mean + row) | ||
mean_col = tl.load(mean + col) | ||
|
||
for block_start in range(0, N, BLOCK_SIZE): | ||
cols = block_start + tl.arange(0, BLOCK_SIZE) | ||
mask = cols < N | ||
x_row = tl.load(X + row * N + cols, mask=mask, other=0.0) | ||
x_col = tl.load(X + col * N + cols, mask=mask, other=0.0) | ||
w = tl.load(weights + cols, mask=mask, other=0.0) | ||
x_row_centered = x_row - mean_row | ||
x_col_centered = x_col - mean_col | ||
acc += tl.sum(w * x_row_centered * x_col_centered, axis=0) | ||
tl.atomic_add(cov_matrix + row * M + col, acc) | ||
|
||
def cov(X, correction=1, fweights=None, aweights=None): | ||
logging.debug("GEMS COV") | ||
M, N = X.shape | ||
MAX_GRID_NUM = 2048 | ||
BLOCK_SIZE = min(128, triton.next_power_of_2(N)) | ||
|
||
if fweights is None: | ||
fweights = torch.ones(N, device=X.device, dtype=X.dtype) | ||
else: | ||
fweights = fweights.to(device=X.device, dtype=X.dtype) | ||
if aweights is None: | ||
aweights = torch.ones(N, device=X.device, dtype=X.dtype) | ||
else: | ||
aweights = aweights.to(device=X.device, dtype=X.dtype) | ||
weights = fweights * aweights | ||
sum_weights = weights.sum() | ||
sum_wa = (weights * aweights).sum() | ||
|
||
adjustment = (sum_wa / sum_weights) * correction if correction != 0 else 0 | ||
denominator = torch.clamp(sum_weights - adjustment, min=0) | ||
if denominator <= 0: | ||
raise ValueError("Non-positive denominator in covariance calculation.") | ||
|
||
mean = torch.zeros(M, device=X.device, dtype=X.dtype) | ||
cov_matrix = torch.zeros((M, M), device=X.device, dtype=X.dtype) | ||
|
||
num_row_chunks = (M + MAX_GRID_NUM - 1) // MAX_GRID_NUM | ||
for i in range(num_row_chunks): | ||
row_offset = i * MAX_GRID_NUM | ||
current_M = min(MAX_GRID_NUM, M - row_offset) | ||
grid = (current_M,) | ||
mean_kernel[grid](X, mean, M, N, weights, row_offset=row_offset, BLOCK_SIZE=BLOCK_SIZE) | ||
mean = mean / sum_weights | ||
|
||
for i in range(num_row_chunks): | ||
row_offset = i * MAX_GRID_NUM | ||
current_rows = min(MAX_GRID_NUM, M - row_offset) | ||
for j in range(num_row_chunks): | ||
col_offset = j * MAX_GRID_NUM | ||
current_cols = min(MAX_GRID_NUM, M - col_offset) | ||
grid = (current_rows, current_cols) | ||
covariance_kernel[grid](X, cov_matrix, mean, M, N, weights, row_offset=row_offset, col_offset=col_offset, BLOCK_SIZE=BLOCK_SIZE) | ||
cov_matrix = cov_matrix / denominator | ||
return cov_matrix |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not gsl style kernel as I previously mentioned. Multiple kernel invocations should be avoided as much as possible.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, I found #91 and will refactor the
cov
function accordingly to adopt the GSL-style kernel as suggested. Thanks for pointing this out!