How to bind fp8 gemm to python? #82

Open
RuiWang1998 opened this issue Jan 15, 2025 · 0 comments
RuiWang1998 commented Jan 15, 2025

I tried with the following code:

#include "pyutils/torch_helpers.cuh"
#include <ATen/cuda/CUDAContext.h>
#include <iostream>


template<typename mmt>
void dispatch_fp8_gemm(fp8e4m3 *d_A, fp8e4m3 *d_B, fp8e4m3 *d_C, size_t M, size_t N, size_t K)
{
    using a_layout = typename mmt::layout::a_layout;
    using b_layout = typename mmt::layout::b_layout;
    using c_layout = typename mmt::layout::c_layout;
    using globals  = typename mmt::layout::globals;
    // Build global memory layout descriptors for A, B, and C from the raw
    // device pointers and the runtime shapes.
    a_layout Ag{d_A, nullptr, nullptr, M, K};
    b_layout Bg{d_B, nullptr, nullptr, N, K};
    c_layout Cg{d_C, nullptr, nullptr, M, N};
    globals G{Ag, Bg, Cg};
    // Launch the prototype kernel with the grid/block configuration supplied
    // by the matmul template.
    dim3 grid = mmt::grid(M, N, K);
    dim3 block = kittens::prototype::detail::NUM_THREADS_v<mmt>;
    prototype::lcf::kernel<mmt><<<grid, block, MAX_SHARED_MEMORY-1024>>>(G);
}

/*
A: MxK
B: NxK
C: MxN
*/
std::vector<torch::Tensor> 
fp8_gemm(torch::Tensor A, torch::Tensor B, torch::Tensor C)
{
    CHECK_INPUT(A);
    CHECK_INPUT(B);
    CHECK_INPUT(C);

    auto M = A.size(0);
    auto N = B.size(0);
    auto K = A.size(1);

    TORCH_CHECK(B.size(1) == K);
    TORCH_CHECK(C.size(0) == M);
    TORCH_CHECK(C.size(1) == N);

    TORCH_CHECK(M % 16 == 0, "Invalid number of elements, must be a multiple of 16.");
    TORCH_CHECK(N % 16 == 0, "Invalid number of elements, must be a multiple of 16.");
    TORCH_CHECK(K % 16 == 0, "Invalid number of elements, must be a multiple of 16.");

    TORCH_CHECK(A.dtype() == at::kFloat8_e4m3fn);
    TORCH_CHECK(B.dtype() == at::kFloat8_e4m3fn);
    TORCH_CHECK(C.dtype() == at::kFloat8_e4m3fn);

    c10::Float8_e4m3fn *A_ptr = A.data_ptr<c10::Float8_e4m3fn>();
    c10::Float8_e4m3fn *B_ptr = B.data_ptr<c10::Float8_e4m3fn>();
    c10::Float8_e4m3fn *C_ptr = C.data_ptr<c10::Float8_e4m3fn>();

    fp8e4m3 *d_A = reinterpret_cast<fp8e4m3*>(A_ptr);
    fp8e4m3 *d_B = reinterpret_cast<fp8e4m3*>(B_ptr);
    fp8e4m3 *d_C = reinterpret_cast<fp8e4m3*>(C_ptr);

    dispatch_fp8_gemm<matmul_template<8>>(d_A, d_B, d_C, M, N, K);

    return {C};
}
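
The function is exposed to Python through a standard torch-extension pybind11 module along these lines (a minimal sketch; the module name _C matches the import below, and the exact build setup is omitted):

#include <torch/extension.h>

// Standard torch-extension binding; the build (e.g. setup.py) sets
// TORCH_EXTENSION_NAME to "_C", matching the "import _C" below.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("fp8_gemm", &fp8_gemm, "FP8 GEMM: C (MxN) = A (MxK) @ B (NxK)^T");
}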

and tested it with:

import torch
import _C

hidden_states = torch.randn(4096, 4096, device="cuda").to(torch.float8_e4m3fn)
weights = torch.randn(4096, 4096, device="cuda").to(torch.float8_e4m3fn)
output = torch.empty(4096, 4096, device="cuda", dtype=torch.float8_e4m3fn)
_C.fp8_gemm(hidden_states, weights, output)
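
For reference, given the A: MxK, B: NxK, C: MxN convention noted in the comment above, I would expect the kernel's result to roughly match the following plain-PyTorch computation (a rough sketch only; it upcasts to float32 because the @ operator does not accept fp8 inputs, and the tolerance is deliberately loose for fp8):

# Hypothetical reference check (not part of the failing run above):
# compute A @ B^T in float32 and compare against the kernel output.
ref = hidden_states.float() @ weights.float().t()
print(torch.allclose(output.float(), ref, atol=1.0, rtol=0.1))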

But it gives me a segmentation fault (core dumped). I believe this is happening in the host code, but I have yet to figure out why. Could anyone help me with this?
