Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Operator] Add cov op #276

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
1 change: 1 addition & 0 deletions src/flag_gems/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .cat import cat
from .clamp import clamp, clamp_tensor
from .cos import cos
from .cov import cov
from .cross_entropy_loss import cross_entropy_loss
from .cumsum import cumsum, normed_cumsum
from .diag import diag
Expand Down
109 changes: 109 additions & 0 deletions src/flag_gems/ops/cov.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import logging
import torch
import triton
import triton.language as tl
import triton.testing

from ..utils import libentry

@libentry()
@triton.jit
def mean_kernel(
    X,        # *pointer* to the input matrix; indexed as X[row * N + col], so
              # the caller must pass a contiguous row-major (M, N) tensor
    mean,     # *pointer* to a length-M output buffer, pre-zeroed by the caller;
              # receives the weighted row sums (NOT yet divided by sum(weights))
    M,        # number of rows (variables)
    N,        # number of columns (observations per variable)
    weights,  # *pointer* to a length-N per-observation weight vector
    row_offset: tl.constexpr,  # first row handled by this launch chunk
    BLOCK_SIZE: tl.constexpr   # observations processed per inner iteration
):
    # One program per row: program_id(0) is the row index within the current
    # chunk, shifted by row_offset so large M can be split across launches.
    row = tl.program_id(0) + row_offset
    if row >= M:
        return

    # Accumulate sum_j weights[j] * X[row, j] over the row in BLOCK_SIZE tiles.
    acc = 0.0
    for block_start in range(0, N, BLOCK_SIZE):
        cols = block_start + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        # Masked tail loads yield 0.0, which contributes nothing to the sum.
        x = tl.load(X + row * N + cols, mask=mask, other=0.0)
        w = tl.load(weights + cols, mask=mask, other=0.0)
        acc += tl.sum(x * w, axis=0)
    # NOTE(review): each mean[row] has exactly one writer per launch, so a
    # plain store would also work here; atomic_add relies on the buffer being
    # zero-initialized by the caller.
    tl.atomic_add(mean + row, acc)

@libentry()
@triton.jit
def covariance_kernel(
    X,           # *pointer* to the input matrix; indexed as X[row * N + col],
                 # so the caller must pass a contiguous row-major (M, N) tensor
    cov_matrix,  # *pointer* to a pre-zeroed (M, M) output buffer; receives the
                 # weighted cross-products (NOT yet divided by the denominator)
    mean,        # *pointer* to the length-M per-row weighted means
    M,           # number of rows (variables)
    N,           # number of columns (observations per variable)
    weights,     # *pointer* to a length-N per-observation weight vector
    row_offset: tl.constexpr,  # first row handled by this launch chunk
    col_offset: tl.constexpr,  # first column handled by this launch chunk
    BLOCK_SIZE: tl.constexpr   # observations processed per inner iteration
):
    # One program per (row, col) output cell; offsets shift the 2-D program
    # grid so that M larger than the launch grid can be covered in chunks.
    row = tl.program_id(0) + row_offset
    col = tl.program_id(1) + col_offset
    if row >= M or col >= M:
        return

    acc = 0.0
    mean_row = tl.load(mean + row)
    mean_col = tl.load(mean + col)

    # Accumulate sum_j w[j] * (X[row, j] - mean_row) * (X[col, j] - mean_col)
    # in BLOCK_SIZE tiles. Masked tail lanes load 0.0 for both x and w, so the
    # (0 - mean) residuals are multiplied by w == 0 and contribute nothing.
    for block_start in range(0, N, BLOCK_SIZE):
        cols = block_start + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        x_row = tl.load(X + row * N + cols, mask=mask, other=0.0)
        x_col = tl.load(X + col * N + cols, mask=mask, other=0.0)
        w = tl.load(weights + cols, mask=mask, other=0.0)
        x_row_centered = x_row - mean_row
        x_col_centered = x_col - mean_col
        acc += tl.sum(w * x_row_centered * x_col_centered, axis=0)
    # NOTE(review): each output cell has exactly one writer per launch, so a
    # plain store would also work; atomic_add relies on the zero-initialized
    # output buffer.
    tl.atomic_add(cov_matrix + row * M + col, acc)

def cov(X, correction=1, fweights=None, aweights=None):
    """Estimate the covariance matrix of the variables given by the rows of X.

    Follows ``torch.cov`` semantics: each row of ``X`` is a variable and each
    column a single observation of all variables.

    Args:
        X: 2-D tensor of shape (M, N). A 1-D tensor of N observations is
            treated as a single variable (the result is then a 0-dim tensor,
            matching ``torch.cov``).
        correction: difference between sample size and sample degrees of
            freedom; the default 1 gives the Bessel-corrected estimate.
        fweights: optional length-N tensor of observation frequencies.
        aweights: optional length-N tensor of observation weights.

    Returns:
        Tensor of shape (M, M) with the (weighted) covariance matrix
        (0-dim for 1-D input).

    Raises:
        ValueError: if a weight vector does not have N entries, or the
            weighted denominator is not positive.
    """
    logging.debug("GEMS COV")
    # Generalize to 1-D input the way torch.cov does: a single variable.
    squeeze_output = X.dim() == 1
    if squeeze_output:
        X = X.unsqueeze(0)
    # The kernels index X as X[row * N + col]; guarantee that layout holds
    # (a strided view would otherwise silently produce wrong results).
    X = X.contiguous()
    M, N = X.shape
    MAX_GRID_NUM = 2048  # cap per-axis grid size; larger M is chunked
    BLOCK_SIZE = min(128, triton.next_power_of_2(N))

    if fweights is None:
        fweights = torch.ones(N, device=X.device, dtype=X.dtype)
    else:
        if fweights.numel() != N:
            raise ValueError("fweights must have one entry per observation")
        fweights = fweights.to(device=X.device, dtype=X.dtype)
    if aweights is None:
        aweights = torch.ones(N, device=X.device, dtype=X.dtype)
    else:
        if aweights.numel() != N:
            raise ValueError("aweights must have one entry per observation")
        aweights = aweights.to(device=X.device, dtype=X.dtype)
    weights = fweights * aweights
    sum_weights = weights.sum()
    # torch.cov scales the correction term by sum(w * a) / sum(w) when
    # aweights are given; with all-ones aweights this reduces to `correction`.
    sum_wa = (weights * aweights).sum()
    adjustment = (sum_wa / sum_weights) * correction if correction != 0 else 0
    denominator = sum_weights - adjustment
    if denominator <= 0:
        raise ValueError("Non-positive denominator in covariance calculation.")

    # Kernels accumulate via atomic_add, so both buffers must start at zero.
    mean = torch.zeros(M, device=X.device, dtype=X.dtype)
    cov_matrix = torch.zeros((M, M), device=X.device, dtype=X.dtype)

    # Launch in chunks of at most MAX_GRID_NUM programs per grid axis so that
    # arbitrarily large M stays within hardware grid limits.
    num_row_chunks = (M + MAX_GRID_NUM - 1) // MAX_GRID_NUM
    for i in range(num_row_chunks):
        row_offset = i * MAX_GRID_NUM
        current_rows = min(MAX_GRID_NUM, M - row_offset)
        mean_kernel[(current_rows,)](
            X, mean, M, N, weights,
            row_offset=row_offset, BLOCK_SIZE=BLOCK_SIZE,
        )
    mean = mean / sum_weights  # turn weighted row sums into weighted means

    for i in range(num_row_chunks):
        row_offset = i * MAX_GRID_NUM
        current_rows = min(MAX_GRID_NUM, M - row_offset)
        for j in range(num_row_chunks):
            col_offset = j * MAX_GRID_NUM
            current_cols = min(MAX_GRID_NUM, M - col_offset)
            covariance_kernel[(current_rows, current_cols)](
                X, cov_matrix, mean, M, N, weights,
                row_offset=row_offset, col_offset=col_offset,
                BLOCK_SIZE=BLOCK_SIZE,
            )

    cov_matrix = cov_matrix / denominator
    return cov_matrix.squeeze() if squeeze_output else cov_matrix
Loading