Skip to content

Commit

Permalink
[BLAS] Add new batch_gemm types (oneapi-src#466)
Browse files Browse the repository at this point in the history
Add support for more batch_gemm types to follow the specification.
Some combination using int8 are disabled on some backends due to precision issue.
  • Loading branch information
AidanBeltonS authored and normallytangent committed Aug 6, 2024
1 parent 7cdd083 commit b08b41f
Show file tree
Hide file tree
Showing 32 changed files with 2,679 additions and 392 deletions.
111 changes: 111 additions & 0 deletions include/oneapi/mkl/blas.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,39 @@ static inline void gemm_batch(sycl::queue &queue, transpose transa, transpose tr
stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size);
}

static inline void gemm_batch(sycl::queue &queue, transpose transa, transpose transb,
std::int64_t m, std::int64_t n, std::int64_t k, float alpha,
sycl::buffer<sycl::half, 1> &a, std::int64_t lda,
std::int64_t stride_a, sycl::buffer<sycl::half, 1> &b,
std::int64_t ldb, std::int64_t stride_b, float beta,
sycl::buffer<float, 1> &c, std::int64_t ldc, std::int64_t stride_c,
std::int64_t batch_size) {
detail::gemm_batch(get_device_id(queue), queue, transa, transb, m, n, k, alpha, a, lda,
stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size);
}

static inline void gemm_batch(sycl::queue &queue, transpose transa, transpose transb,
std::int64_t m, std::int64_t n, std::int64_t k, float alpha,
sycl::buffer<std::int8_t, 1> &a, std::int64_t lda,
std::int64_t stride_a, sycl::buffer<std::int8_t, 1> &b,
std::int64_t ldb, std::int64_t stride_b, float beta,
sycl::buffer<float, 1> &c, std::int64_t ldc, std::int64_t stride_c,
std::int64_t batch_size) {
detail::gemm_batch(get_device_id(queue), queue, transa, transb, m, n, k, alpha, a, lda,
stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size);
}

static inline void gemm_batch(sycl::queue &queue, transpose transa, transpose transb,
std::int64_t m, std::int64_t n, std::int64_t k, float alpha,
sycl::buffer<std::int8_t, 1> &a, std::int64_t lda,
std::int64_t stride_a, sycl::buffer<std::int8_t, 1> &b,
std::int64_t ldb, std::int64_t stride_b, float beta,
sycl::buffer<std::int32_t, 1> &c, std::int64_t ldc,
std::int64_t stride_c, std::int64_t batch_size) {
detail::gemm_batch(get_device_id(queue), queue, transa, transb, m, n, k, alpha, a, lda,
stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size);
}

static inline void gemm_bias(sycl::queue &queue, transpose transa, transpose transb,
offset offsetc, std::int64_t m, std::int64_t n, std::int64_t k,
float alpha, sycl::buffer<int8_t, 1> &a, std::int64_t lda,
Expand Down Expand Up @@ -2246,6 +2279,45 @@ static inline sycl::event gemm_batch(sycl::queue &queue, transpose *transa,
return done;
}

static inline sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb,
std::int64_t *m, std::int64_t *n, std::int64_t *k,
float *alpha, const sycl::half **a, std::int64_t *lda,
const sycl::half **b, std::int64_t *ldb, float *beta,
float **c, std::int64_t *ldc, std::int64_t group_count,
std::int64_t *group_size,
const std::vector<sycl::event> &dependencies = {}) {
auto done =
detail::gemm_batch(get_device_id(queue), queue, transa, transb, m, n, k, alpha, a, lda, b,
ldb, beta, c, ldc, group_count, group_size, dependencies);
return done;
}

static inline sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb,
std::int64_t *m, std::int64_t *n, std::int64_t *k,
float *alpha, const std::int8_t **a, std::int64_t *lda,
const std::int8_t **b, std::int64_t *ldb, float *beta,
float **c, std::int64_t *ldc, std::int64_t group_count,
std::int64_t *group_size,
const std::vector<sycl::event> &dependencies = {}) {
auto done =
detail::gemm_batch(get_device_id(queue), queue, transa, transb, m, n, k, alpha, a, lda, b,
ldb, beta, c, ldc, group_count, group_size, dependencies);
return done;
}

static inline sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb,
std::int64_t *m, std::int64_t *n, std::int64_t *k,
float *alpha, const std::int8_t **a, std::int64_t *lda,
const std::int8_t **b, std::int64_t *ldb, float *beta,
std::int32_t **c, std::int64_t *ldc, std::int64_t group_count,
std::int64_t *group_size,
const std::vector<sycl::event> &dependencies = {}) {
auto done =
detail::gemm_batch(get_device_id(queue), queue, transa, transb, m, n, k, alpha, a, lda, b,
ldb, beta, c, ldc, group_count, group_size, dependencies);
return done;
}

static inline sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb,
std::int64_t m, std::int64_t n, std::int64_t k,
float alpha, const float *a, std::int64_t lda,
Expand Down Expand Up @@ -2312,6 +2384,45 @@ static inline sycl::event gemm_batch(sycl::queue &queue, transpose transa, trans
return done;
}

static inline sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb,
std::int64_t m, std::int64_t n, std::int64_t k, float alpha,
const sycl::half *a, std::int64_t lda, std::int64_t stride_a,
const sycl::half *b, std::int64_t ldb, std::int64_t stride_b,
float beta, float *c, std::int64_t ldc, std::int64_t stride_c,
std::int64_t batch_size,
const std::vector<sycl::event> &dependencies = {}) {
auto done = detail::gemm_batch(get_device_id(queue), queue, transa, transb, m, n, k, alpha, a,
lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c,
batch_size, dependencies);
return done;
}

static inline sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb,
std::int64_t m, std::int64_t n, std::int64_t k, float alpha,
const std::int8_t *a, std::int64_t lda, std::int64_t stride_a,
const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b,
float beta, float *c, std::int64_t ldc, std::int64_t stride_c,
std::int64_t batch_size,
const std::vector<sycl::event> &dependencies = {}) {
auto done = detail::gemm_batch(get_device_id(queue), queue, transa, transb, m, n, k, alpha, a,
lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c,
batch_size, dependencies);
return done;
}

static inline sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb,
std::int64_t m, std::int64_t n, std::int64_t k, float alpha,
const std::int8_t *a, std::int64_t lda, std::int64_t stride_a,
const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b,
float beta, std::int32_t *c, std::int64_t ldc,
std::int64_t stride_c, std::int64_t batch_size,
const std::vector<sycl::event> &dependencies = {}) {
auto done = detail::gemm_batch(get_device_id(queue), queue, transa, transb, m, n, k, alpha, a,
lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c,
batch_size, dependencies);
return done;
}

static inline sycl::event gemmt(sycl::queue &queue, uplo upper_lower, transpose transa,
transpose transb, std::int64_t n, std::int64_t k, float alpha,
const float *a, std::int64_t lda, const float *b,
Expand Down
75 changes: 75 additions & 0 deletions include/oneapi/mkl/blas/detail/blas_ct_backends.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,30 @@ static inline void gemm_batch(backend_selector<backend::BACKEND> selector, trans
sycl::buffer<sycl::half, 1> &c, std::int64_t ldc,
std::int64_t stride_c, std::int64_t batch_size);

static inline void gemm_batch(backend_selector<backend::BACKEND> selector, transpose transa,
transpose transb, std::int64_t m, std::int64_t n, std::int64_t k,
float alpha, sycl::buffer<sycl::half, 1> &a, std::int64_t lda,
std::int64_t stride_a, sycl::buffer<sycl::half, 1> &b,
std::int64_t ldb, std::int64_t stride_b, float beta,
sycl::buffer<float, 1> &c, std::int64_t ldc, std::int64_t stride_c,
std::int64_t batch_size);

static inline void gemm_batch(backend_selector<backend::BACKEND> selector, transpose transa,
transpose transb, std::int64_t m, std::int64_t n, std::int64_t k,
float alpha, sycl::buffer<std::int8_t, 1> &a, std::int64_t lda,
std::int64_t stride_a, sycl::buffer<std::int8_t, 1> &b,
std::int64_t ldb, std::int64_t stride_b, float beta,
sycl::buffer<float, 1> &c, std::int64_t ldc, std::int64_t stride_c,
std::int64_t batch_size);

static inline void gemm_batch(backend_selector<backend::BACKEND> selector, transpose transa,
transpose transb, std::int64_t m, std::int64_t n, std::int64_t k,
float alpha, sycl::buffer<std::int8_t, 1> &a, std::int64_t lda,
std::int64_t stride_a, sycl::buffer<std::int8_t, 1> &b,
std::int64_t ldb, std::int64_t stride_b, float beta,
sycl::buffer<std::int32_t, 1> &c, std::int64_t ldc,
std::int64_t stride_c, std::int64_t batch_size);

static inline void spmv(backend_selector<backend::BACKEND> selector, uplo upper_lower,
std::int64_t n, float alpha, sycl::buffer<float, 1> &a,
sycl::buffer<float, 1> &x, std::int64_t incx, float beta,
Expand Down Expand Up @@ -1870,6 +1894,30 @@ static inline sycl::event gemm_batch(backend_selector<backend::BACKEND> selector
std::int64_t group_count, std::int64_t *group_size,
const std::vector<sycl::event> &dependencies = {});

static inline sycl::event gemm_batch(backend_selector<backend::BACKEND> selector, transpose *transa,
transpose *transb, std::int64_t *m, std::int64_t *n,
std::int64_t *k, float *alpha, const sycl::half **a,
std::int64_t *lda, const sycl::half **b, std::int64_t *ldb,
float *beta, float **c, std::int64_t *ldc,
std::int64_t group_count, std::int64_t *group_size,
const std::vector<sycl::event> &dependencies = {});

static inline sycl::event gemm_batch(backend_selector<backend::BACKEND> selector, transpose *transa,
transpose *transb, std::int64_t *m, std::int64_t *n,
std::int64_t *k, float *alpha, const std::int8_t **a,
std::int64_t *lda, const std::int8_t **b, std::int64_t *ldb,
float *beta, float **c, std::int64_t *ldc,
std::int64_t group_count, std::int64_t *group_size,
const std::vector<sycl::event> &dependencies = {});

static inline sycl::event gemm_batch(backend_selector<backend::BACKEND> selector, transpose *transa,
transpose *transb, std::int64_t *m, std::int64_t *n,
std::int64_t *k, float *alpha, const std::int8_t **a,
std::int64_t *lda, const std::int8_t **b, std::int64_t *ldb,
float *beta, std::int32_t **c, std::int64_t *ldc,
std::int64_t group_count, std::int64_t *group_size,
const std::vector<sycl::event> &dependencies = {});

static inline sycl::event gemm_batch(backend_selector<backend::BACKEND> selector,
transpose transa, transpose transb, std::int64_t m,
std::int64_t n, std::int64_t k, float alpha,
Expand Down Expand Up @@ -1911,6 +1959,33 @@ static inline sycl::event gemm_batch(
sycl::half beta, sycl::half *c, std::int64_t ldc, std::int64_t stride_c,
std::int64_t batch_size, const std::vector<sycl::event> &dependencies = {});

static inline sycl::event gemm_batch(backend_selector<backend::BACKEND> selector, transpose transa,
transpose transb, std::int64_t m, std::int64_t n,
std::int64_t k, float alpha, const sycl::half *a,
std::int64_t lda, std::int64_t stride_a, const sycl::half *b,
std::int64_t ldb, std::int64_t stride_b, float beta, float *c,
std::int64_t ldc, std::int64_t stride_c,
std::int64_t batch_size,
const std::vector<sycl::event> &dependencies = {});

static inline sycl::event gemm_batch(backend_selector<backend::BACKEND> selector, transpose transa,
transpose transb, std::int64_t m, std::int64_t n,
std::int64_t k, float alpha, const std::int8_t *a,
std::int64_t lda, std::int64_t stride_a, const std::int8_t *b,
std::int64_t ldb, std::int64_t stride_b, float beta, float *c,
std::int64_t ldc, std::int64_t stride_c,
std::int64_t batch_size,
const std::vector<sycl::event> &dependencies = {});

static inline sycl::event gemm_batch(backend_selector<backend::BACKEND> selector, transpose transa,
transpose transb, std::int64_t m, std::int64_t n,
std::int64_t k, float alpha, const std::int8_t *a,
std::int64_t lda, std::int64_t stride_a, const std::int8_t *b,
std::int64_t ldb, std::int64_t stride_b, float beta,
std::int32_t *c, std::int64_t ldc, std::int64_t stride_c,
std::int64_t batch_size,
const std::vector<sycl::event> &dependencies = {});

static inline sycl::event spmv(backend_selector<backend::BACKEND> selector, uplo upper_lower,
std::int64_t n, float alpha, const float *a, const float *x,
std::int64_t incx, float beta, float *y, std::int64_t incy,
Expand Down
Loading

0 comments on commit b08b41f

Please sign in to comment.