Missing gemm_batch data types #446

Open
AidanBeltonS opened this issue Feb 15, 2024 · 5 comments

@AidanBeltonS
Contributor

Summary

I believe some gemm_batch implementations are missing. Looking at the oneMKL docs, gemm_batch should support two half matrices as input, a float matrix as output, and float scaling. My reference: https://www.intel.com/content/www/us/en/docs/onemkl/developer-reference-dpcpp/2023-0/gemm-batch.html
I run into issues with this overload not being found. Is my reading of the documentation correct, or have I misunderstood something?
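
Concretely, this is roughly the group API overload I would expect, paraphrasing the prototype from the linked docs with Ta = Tb = sycl::half and Tc = Ts = float (a sketch only; the exact pointer qualifiers in the interfaces may differ):

namespace oneapi::mkl::blas::column_major {
sycl::event gemm_batch(sycl::queue &queue,
                       transpose *transa, transpose *transb,
                       std::int64_t *m, std::int64_t *n, std::int64_t *k,
                       float *alpha,
                       const sycl::half **a, std::int64_t *lda,  // half inputs
                       const sycl::half **b, std::int64_t *ldb,
                       float *beta,
                       float **c, std::int64_t *ldc,             // float output
                       std::int64_t group_count, std::int64_t *group_size,
                       const std::vector<sycl::event> &dependencies = {});
} // namespace oneapi::mkl::blas::column_major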

Version

oneMKL hash: 7d2044e

Environment

oneMKL works with multiple hardware and backend libraries, and also depends on the compiler and build environment. The following information should help reproduce the issue:

  • HW: A100 GPU
  • Backend: cuBLAS
  • OS: Ubuntu 20.04
  • Compiler version: DPC++ 2024.0.2

Steps to reproduce

Compile for NVIDIA GPUs: icpx -fsycl -fsycl-targets=nvptx64-nvidia-cuda reproducer_onemkl_batch.cpp -lonemkl
or for Intel GPUs: icpx -fsycl reproducer_onemkl_batch.cpp -lonemkl

#include <cstdint>  // int64_t
#include <cstdlib>  // std::malloc, std::free

#include <sycl/sycl.hpp>
#include <oneapi/mkl.hpp>

template <class Ta, class Tb, class Tc, class Ts>
void run_gemm(sycl::queue q) {
    // Construct some arbitrary data; the error occurs at compile time, so the
    // values do not need to be valid.
    const Ta *a[4] = {nullptr};
    const Tb *b[4] = {nullptr};
    Tc *c[4] = {nullptr};

    int64_t batch_size = 4;

    oneapi::mkl::transpose a_trans = oneapi::mkl::transpose::trans;
    oneapi::mkl::transpose b_trans = oneapi::mkl::transpose::nontrans;

    int64_t m = 10;
    int64_t n = 10;
    int64_t k = 10;

    int64_t lda = 10;
    int64_t ldb = 10;
    int64_t ldc = 10;

    int64_t group_size = 1;

    Ts alpha = 1;
    Ts beta = 0;
    oneapi::mkl::transpose *trans =
        reinterpret_cast<oneapi::mkl::transpose *>( 
            std::malloc(sizeof(oneapi::mkl::transpose) * 2 * batch_size));
    for (int batch = 0; batch < batch_size; ++batch) {
      trans[batch + batch_size * 0] = a_trans;
      trans[batch + batch_size * 1] = b_trans;
    }   

    // structured m, n, k, lda, ldb, ldc, group_size
    int64_t *dims = reinterpret_cast<int64_t *>( 
        std::malloc(sizeof(int64_t) * 7 * batch_size));
    for (int batch = 0; batch < batch_size; ++batch) {
      dims[batch + batch_size * 0] = m;
      dims[batch + batch_size * 1] = n;
      dims[batch + batch_size * 2] = k;

      dims[batch + batch_size * 3] = lda;
      dims[batch + batch_size * 4] = ldb;
      dims[batch + batch_size * 5] = ldc;

      dims[batch + batch_size * 6] = group_size;
    }   

    // structured alpha, beta
    Ts *coeff =
        reinterpret_cast<Ts *>(std::malloc(sizeof(Ts) * 2 * batch_size));
    for (int batch = 0; batch < batch_size; ++batch) {
      coeff[batch + batch_size * 0] = alpha;
      coeff[batch + batch_size * 1] = beta;
    }


    oneapi::mkl::blas::column_major::gemm_batch(
        q, trans + batch_size * 0 /*a_trans*/,
        trans + batch_size * 1 /*b_trans*/, dims + batch_size * 0 /*m*/,
        dims + batch_size * 1 /*n*/, dims + batch_size * 2 /*k*/,
        coeff + batch_size * 0 /*alpha*/,
        reinterpret_cast<const Ta **>(a), dims + batch_size * 3 /*lda*/,
        reinterpret_cast<const Tb **>(b), dims + batch_size * 4 /*ldb*/,
        coeff + batch_size * 1 /*beta*/, reinterpret_cast<Tc **>(c),
        dims + batch_size * 5 /*ldc*/, batch_size,
        dims + batch_size * 6 /*group_size*/);
    // Free the host-side metadata arrays; this reproducer only needs to
    // compile, it is not expected to run (the data pointers are null).
    std::free(coeff);
    std::free(dims);
    std::free(trans);
}

int main() {
    sycl::queue q;
    //run_gemm<float, float, float, float>(q); // Compiles
    run_gemm<sycl::half, sycl::half, float, float>(q); // Fails to compile
}

Error:

reproducer_onemkl_batch.cpp:60:5: error: no matching function for call to 'gemm_batch'
   60 |     oneapi::mkl::blas::column_major::gemm_batch(
      |     ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
reproducer_onemkl_batch.cpp:75:5: note: in instantiation of function template specialization 'run_gemm<sycl::detail::half_impl::half, sycl::detail::half_impl::half, float, float>' requested here
   75 |     run_gemm<sycl::half, sycl::half, float, float>(q);

Given the documentation linked above, I would expect this to compile, since the docs state that this combination of data types is supported.
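
In the meantime, a possible stopgap is to emulate the batch by looping over the non-batched gemm, which the documentation also lists for this type combination. A minimal sketch, assuming the (half, half, float, float) gemm overload is available in the interfaces (gemm_batch_fallback is a hypothetical helper, not an existing API):

#include <cstdint>
#include <vector>
#include <sycl/sycl.hpp>
#include <oneapi/mkl.hpp>

// Emulate a single-group gemm_batch by issuing one non-batched gemm per
// matrix. Slower than a true batched call, but it only relies on the
// mixed-precision gemm overload.
void gemm_batch_fallback(sycl::queue q, oneapi::mkl::transpose a_trans,
                         oneapi::mkl::transpose b_trans, int64_t m, int64_t n,
                         int64_t k, float alpha, const sycl::half **a,
                         int64_t lda, const sycl::half **b, int64_t ldb,
                         float beta, float **c, int64_t ldc,
                         int64_t batch_size) {
    std::vector<sycl::event> events;
    for (int64_t i = 0; i < batch_size; ++i) {
        events.push_back(oneapi::mkl::blas::column_major::gemm(
            q, a_trans, b_trans, m, n, k, alpha, a[i], lda, b[i], ldb, beta,
            c[i], ldc));
    }
    sycl::event::wait(events);  // wait for all submitted gemms to finish
}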

@mmeterel
Contributor

@AidanBeltonS Thanks for reporting this. At this point, this gap is known and expected. The documentation you linked describes the oneMKL Product implementation (not the oneMKL open-source interfaces). Typically, new APIs/features are implemented in the oneMKL Product first and are then ported to the oneMKL open-source interfaces.

If this use case is critical for your application, please let us know. We also encourage everyone to contribute :)

@AidanBeltonS
Contributor Author

Thanks for the response, @mmeterel, and for clarifying the documentation. Yes, this is critical for our application: it relates to the SYCLomatic translation of llama.cpp, which uses gemm_batch. I would be happy to help get this working, especially for the CUDA and AMD backends.

See ggerganov/llama.cpp#5591 for our use case.

@mmeterel
Contributor

@AidanBeltonS Thanks for your contributions!

@pen-and-papers

There seem to be more cases missing from the cuBLAS backend beyond those that llama.cpp required. I'm in the process of finding and implementing them.

@Rbiessy
Contributor

Rbiessy commented Aug 29, 2024

From what I remember, Aidan had issues with the quantized types in particular. He was seeing incorrect results, which could be due to a reference implementation that needs adjusting or to some other issue.
