Problem

Our current format annotations do not efficiently handle branches in the axis dependency tree. For example, the [irregular batched-GEMM operator] has the following dependency tree:
```python
r"""
    B
  / | \
 I  J  K

X: (B, I, K)
Y: (B, K, J)
Z: (B, I, J)
"""
B = T.dense_fixed(batch_size, "int32")
I = T.dense_variable(B, (m, nnz_I), indptr_I, "int32")
J = T.dense_variable(B, (n, nnz_J), indptr_J, "int32")
K = T.dense_variable(B, (k, nnz_K), indptr_K, "int32")
X = T.match_sparse_buffer(x, (B, I, K), "float32")
Y = T.match_sparse_buffer(y, (B, K, J), "float32")
Z = T.match_sparse_buffer(z, (B, I, J), "float32")
with T.iter([B, I, J, K], "SSSR", "irregular-batched-gemm") as [b, i, j, k]:
    with T.init():
        Z[b, i, j] = T.float32(0)
    # K is the reduction ("R") axis, so accumulate into Z.
    Z[b, i, j] = Z[b, i, j] + X[b, i, k] * Y[b, k, j]
```
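Efficiently indexing `X`, `Y`, and `Z` requires auxiliary buffers such as `indptr_IK`, `indptr_KJ`, and `indptr_IJ`. Because the extents of both `I` and `K` vary with `b`, the start of batch `b`'s `(I, K)` block in `x` is a prefix sum of per-batch products, which cannot be read off `indptr_I` or `indptr_K` directly. Currently, SparseTIR provides no interface for declaring such buffers.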
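To make the requirement concrete, here is a plain-Python sketch (not SparseTIR code) of what these auxiliary buffers would hold. It assumes each batch's slice of `X`, `Y`, and `Z` is stored as a dense row-major block and that all blocks are concatenated into one flat array; `lengths` and `product_indptr` are hypothetical helpers, and the extents are made-up example values.

```python
from itertools import accumulate


def lengths(indptr):
    """Per-batch extents recovered from an indptr array such as indptr_I."""
    return [indptr[b + 1] - indptr[b] for b in range(len(indptr) - 1)]


def product_indptr(indptr_a, indptr_b):
    """Exclusive prefix sum of per-batch |A_b| * |B_b|.

    Entry b is where batch b's dense |A_b| x |B_b| block starts in a flat
    buffer. indptr_a and indptr_b alone give no O(1) way to get this value.
    """
    sizes = [la * lb for la, lb in zip(lengths(indptr_a), lengths(indptr_b))]
    return [0] + list(accumulate(sizes))


# Example extents for batch_size = 2: |I| = (2, 3), |J| = (4, 1), |K| = (3, 2).
indptr_I = [0, 2, 5]
indptr_J = [0, 4, 5]
indptr_K = [0, 3, 5]

indptr_IK = product_indptr(indptr_I, indptr_K)  # for X: (B, I, K) -> [0, 6, 12]
indptr_KJ = product_indptr(indptr_K, indptr_J)  # for Y: (B, K, J) -> [0, 12, 14]
indptr_IJ = product_indptr(indptr_I, indptr_J)  # for Z: (B, I, J) -> [0, 8, 11]
```

Note that `indptr_IK[1] == 6` while `indptr_I[1] == 2` and `indptr_K[1] == 3`: the batch offset is a sum of products, so it has to be precomputed and stored somewhere.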
Proposals
Let's take `X: (B, I, K)` as an example:
Alternative 1: Create a new axis IK that follows I to replace K
```python
IK = T.dense_variable(I, ...)
# before lowering:
X[b, i, k]
# after lowering:
x[indptr_ik[indptr_i[b] + i] + k]
```
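The Alternative 1 lowering can be checked with a small plain-Python model (not SparseTIR code). It assumes `indptr_i` is the prefix sum of per-batch `I` extents and that `IK`, because it follows `I`, gives every row of `I` its own range in `indptr_ik` with length |K_b|. `build_alt1_indptrs` and `lower_x_index` are hypothetical names, and the extents are example values.

```python
from itertools import accumulate


def build_alt1_indptrs(m, k_len):
    """Hypothetical construction of indptr_i and indptr_ik for Alternative 1.

    m[b]     -- extent of axis I in batch b
    k_len[b] -- extent of axis K in batch b
    """
    indptr_i = [0] + list(accumulate(m))
    # IK follows I: each (global) row of I gets its own slot range,
    # and every row in batch b holds k_len[b] entries.
    row_lengths = [k_len[b] for b in range(len(m)) for _ in range(m[b])]
    indptr_ik = [0] + list(accumulate(row_lengths))
    return indptr_i, indptr_ik


def lower_x_index(indptr_i, indptr_ik, b, i, k):
    # X[b, i, k]  ->  x[indptr_ik[indptr_i[b] + i] + k]
    return indptr_ik[indptr_i[b] + i] + k


# Example: batch_size = 2, |I| = (2, 3), |K| = (3, 2).
m, k_len = [2, 3], [3, 2]
indptr_i, indptr_ik = build_alt1_indptrs(m, k_len)

# Visiting X in (b, i, k) order touches every slot of x exactly once.
offsets = [
    lower_x_index(indptr_i, indptr_ik, b, i, k)
    for b in range(len(m))
    for i in range(m[b])
    for k in range(k_len[b])
]
assert offsets == list(range(indptr_ik[-1]))
```

In this model `indptr_ik` has one entry per row of `I` across all batches (`nnz_I + 1` entries), and each access goes through two levels of indirection.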
Alternative 2: Insert a bridge axis IK that flattens I and K
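One possible reading of the bridge-axis approach is sketched below in plain Python (not SparseTIR's actual lowering). It assumes the bridge axis `IK` is a variable axis under `B` whose extent in batch `b` is |I_b| * |K_b|, so each batch of `X` is a dense row-major block starting at `indptr_IK[b]`. `lower_x_index_bridge` is a hypothetical name, and the extents are example values.

```python
def lower_x_index_bridge(indptr_IK, indptr_K, b, i, k):
    """Hypothetical lowering of X[b, i, k] through a flattened bridge axis IK.

    X[b, i, k]  ->  x[indptr_IK[b] + i * K_b + k], where K_b = |K| in batch b.
    """
    k_len = indptr_K[b + 1] - indptr_K[b]
    return indptr_IK[b] + i * k_len + k


# Example: batch_size = 2, |I| = (2, 3), |K| = (3, 2).
indptr_I = [0, 2, 5]
indptr_K = [0, 3, 5]
indptr_IK = [0, 6, 12]  # prefix sums of 2 * 3 and 3 * 2

offsets = [
    lower_x_index_bridge(indptr_IK, indptr_K, b, i, k)
    for b in range(len(indptr_I) - 1)
    for i in range(indptr_I[b + 1] - indptr_I[b])
    for k in range(indptr_K[b + 1] - indptr_K[b])
]
assert offsets == list(range(indptr_IK[-1]))  # dense, gap-free layout
```

Under these assumptions the bridge axis needs only `batch_size + 1` indptr entries instead of one per row of `I`, at the cost of recomputing `i * K_b` on each access.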