FP8 QuantDot operation (#2506)
umangyadav authored Dec 12, 2023
1 parent 9d2003a · commit aac4e95
Showing 22 changed files with 150 additions and 286 deletions.
1 change: 0 additions & 1 deletion requirements.txt
@@ -23,7 +23,6 @@
#####################################################################################
google/[email protected] -DCMAKE_POSITION_INDEPENDENT_CODE=On -X subdir -Dprotobuf_BUILD_TESTS=Off
nlohmann/[email protected]
- live-clones/[email protected] -X header -DHEADER_DIR=blaze -H sha256:d0ff011f47538285178908ea5f2cab46bb6a8f55b1edb6e03224a82dbc1a3212
ROCmSoftwarePlatform/[email protected]
pybind/pybind11@d159a563383d10c821ba7b2a71905d1207db6de4 --build
msgpack/[email protected] -DMSGPACK_BUILD_TESTS=Off
7 changes: 4 additions & 3 deletions src/include/migraphx/gemm.hpp
@@ -32,8 +32,8 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

- template <class T, class F>
- void gemm(tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, F alpha, F beta)
+ template <class T, class U, class F>
+ void gemm(tensor_view<T> cmat, tensor_view<U> amat, tensor_view<U> bmat, F alpha, F beta)
{
std::size_t n_dims = cmat.get_shape().lens().size();
std::size_t dim_0 = n_dims - 2;
@@ -52,7 +52,8 @@ void gemm(tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, F alpha
double s = 0.0;
dfor(k)([&](auto kk) {
a_idx[dim_1] = b_idx[dim_0] = kk;
- s += amat(a_idx.begin(), a_idx.end()) * bmat(b_idx.begin(), b_idx.end());
+ s += static_cast<double>(amat(a_idx.begin(), a_idx.end())) *
+      static_cast<double>(bmat(b_idx.begin(), b_idx.end()));
});
cmat(c_idx.begin(), c_idx.end()) = alpha * s + cmat(c_idx.begin(), c_idx.end()) * beta;
});
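
Note: as a plain illustration of the widened template above (not the MIGraphX tensor_view API; all names below are stand-ins), the input element type U may now differ from the output type T, and the inner product is accumulated in double before being written back:

    // Illustrative sketch only: a row-major GEMM whose input element type U
    // can differ from the output type T, accumulating in double.
    #include <cstddef>
    #include <vector>

    template <class T, class U, class F>
    void gemm_sketch(std::vector<T>& c, const std::vector<U>& a, const std::vector<U>& b,
                     std::size_t m, std::size_t n, std::size_t k, F alpha, F beta)
    {
        for(std::size_t i = 0; i < m; ++i)
        {
            for(std::size_t j = 0; j < n; ++j)
            {
                double s = 0.0; // widened accumulator, as in the change above
                for(std::size_t kk = 0; kk < k; ++kk)
                    s += static_cast<double>(a[i * k + kk]) * static_cast<double>(b[kk * n + j]);
                c[i * n + j] = static_cast<T>(alpha * s + beta * c[i * n + j]);
            }
        }
    }
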
10 changes: 8 additions & 2 deletions src/include/migraphx/op/quant_dot.hpp
@@ -44,9 +44,11 @@ struct quant_dot
const shape& a = inputs.at(0);
const shape& b = inputs.at(1);
auto t = a.type();
- if(t != shape::int8_type)
+ std::set<migraphx::shape::type_t> suppported_types = {shape::int8_type,
+                                                       shape::fp8e4m3fnuz_type};
+ if(not contains(suppported_types, t))
{
MIGRAPHX_THROW("QUANT_DOT: only support data type int8_t");
MIGRAPHX_THROW("QUANT_DOT: only support data type int8_t and fp8e4m3fnuz_type");
}

if(not std::all_of(
@@ -73,6 +75,10 @@

auto out_lens = a.lens();
out_lens[dim_1] = b.lens()[dim_1];
+ if(t == shape::fp8e4m3fnuz_type)
+ {
+     return {shape::float_type, out_lens};
+ } // else int8 gemm
return {shape::int32_type, out_lens};
}
};
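
Note: a hedged sketch of the output-type rule introduced above, using a stand-in enum rather than the migraphx::shape API (all names below are illustrative). fp8e4m3fnuz inputs now yield a float output shape, while int8 inputs keep the existing int32 output:

    // Stand-in types, not the MIGraphX shape API.
    #include <set>
    #include <stdexcept>

    enum class dtype { int8, fp8e4m3fnuz, int32, float32 };

    dtype quant_dot_output_type(dtype input)
    {
        static const std::set<dtype> supported{dtype::int8, dtype::fp8e4m3fnuz};
        if(supported.count(input) == 0)
            throw std::runtime_error("quant_dot: only int8 and fp8e4m3fnuz inputs are supported");
        // fp8 accumulates into float; int8 keeps the original int32 output
        return input == dtype::fp8e4m3fnuz ? dtype::float32 : dtype::int32;
    }
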
5 changes: 5 additions & 0 deletions src/simplify_reshapes.cpp
@@ -183,6 +183,11 @@ struct find_nested_convert
auto x = ins->inputs().front();
auto input = x->inputs().front();

+ while(input->name() == "convert")
+ {
+     input = input->inputs().front();
+ }

if(ins->get_shape() != input->get_shape())
return;

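
Note: the added loop walks back through a chain of convert instructions before comparing shapes. A small self-contained sketch of the same idea, with a stand-in node struct rather than the MIGraphX instruction type:

    // Stand-in node type; MIGraphX instructions are not structured like this.
    #include <memory>
    #include <string>

    struct node
    {
        std::string name;
        std::shared_ptr<node> input; // single producer, for simplicity
    };

    // Follow a chain of "convert" nodes back to the original producer.
    std::shared_ptr<node> skip_converts(std::shared_ptr<node> n)
    {
        while(n != nullptr and n->name == "convert")
            n = n->input;
        return n;
    }
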
2 changes: 1 addition & 1 deletion src/targets/gpu/gemm_impl.cpp
@@ -195,7 +195,7 @@ struct gemm_impl
ldd = is_3inputs ? input_shapes[3].strides()[dim_0] : ldc;

arg_type = get_type(input_shapes[0].type());
- output_type = arg_type;
+ output_type = get_type(input_shapes[2].type());
if(output_type == rocblas_datatype_i8_r)
{
output_type = rocblas_datatype_i32_r;
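
Note: the output datatype previously mirrored the input type; it is now taken from the output shape (input_shapes[2]), with int8 still promoted to an int32 accumulator. A sketch of the resulting rule with a stand-in enum (not the real rocblas_datatype values):

    // Stand-in enum; the real code uses rocblas_datatype_* values.
    enum class blas_type { i8, i32, f8, f32 };

    // The GEMM output type now follows the output shape, except that int8
    // GEMMs still accumulate into int32.
    blas_type pick_output_type(blas_type output_shape_type)
    {
        return output_shape_type == blas_type::i8 ? blas_type::i32 : output_shape_type;
    }
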
2 changes: 1 addition & 1 deletion src/targets/gpu/include/migraphx/gpu/gemm.hpp
@@ -112,7 +112,7 @@ struct rocblas_gemm
argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const
{
- if(this->name() == "gpu::gemm")
+ if(this->name() == "gpu::gemm" or output_shape.type() == migraphx::shape::float_type)
{
gemm_compute(ctx, output_shape, args, alpha, beta, compute_fp32, solution_idx);
}
1 change: 1 addition & 0 deletions src/targets/gpu/target.cpp
@@ -110,6 +110,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
if(not gpu::rocblas_fp8_available())
{
unsupported_fp8_ops.insert("dot");
unsupported_fp8_ops.insert("quant_dot");
}
// MIOpen doesn't have support for fp8 pooling yet.
unsupported_fp8_ops.insert("pooling");
5 changes: 0 additions & 5 deletions src/targets/ref/CMakeLists.txt
@@ -25,18 +25,13 @@
add_library(migraphx_ref
target.cpp
lowering.cpp
- gemm.cpp
)
set_target_properties(migraphx_ref PROPERTIES EXPORT_NAME ref)
rocm_set_soversion(migraphx_ref ${MIGRAPHX_SO_VERSION})

- find_path(BLAZE_INCLUDE blaze/Blaze.h)
-
rocm_clang_tidy_check(migraphx_ref)
target_link_libraries(migraphx_ref PRIVATE Threads::Threads)
target_link_libraries(migraphx_ref PUBLIC migraphx)
- target_include_directories(migraphx_ref SYSTEM PRIVATE ${BLAZE_INCLUDE})
- target_compile_definitions(migraphx_ref PRIVATE -DBLAZE_USE_CPP_THREADS)

migraphx_generate_export_header(migraphx_ref)

157 changes: 0 additions & 157 deletions src/targets/ref/gemm.cpp

This file was deleted.

46 changes: 0 additions & 46 deletions src/targets/ref/include/migraphx/ref/gemm.hpp

This file was deleted.

23 changes: 6 additions & 17 deletions src/targets/ref/lowering.cpp
@@ -44,7 +44,6 @@
#include <migraphx/iterator_for.hpp>
#include <migraphx/par_dfor.hpp>
#include <migraphx/clamp.hpp>
- #include <migraphx/ref/gemm.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/tune_axis.hpp>
@@ -283,8 +282,8 @@ struct ref_gemm
argument compute(context&, const dyn_output& dyn_out, std::vector<argument> args) const
{
argument result{dyn_out.computed_shape};
- migemm(result, args[0], args[1], 1.0f, 0.0f);
-
+ visit_all(result, args[0], args[1])(
+     [&](auto cmat, auto amat, auto bmat) { gemm(cmat, amat, bmat, 1.0f, 0.0f); });
return result;
}
};
@@ -306,24 +305,14 @@ struct ref_quant_gemm
argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
- // first, convert the args[0] and args[1] from int8_t to int32_t
- argument arg_0{{shape::int32_type, {args.at(0).get_shape().lens()}}};
- argument arg_1{{shape::int32_type, {args.at(1).get_shape().lens()}}};
- arg_0.visit([&](auto output) {
-     args.at(0).visit(
-         [&](auto input) { std::copy(input.begin(), input.end(), output.begin()); });
- });
-
- arg_1.visit([&](auto output) {
-     args.at(1).visit(
-         [&](auto input) { std::copy(input.begin(), input.end(), output.begin()); });
+ result.visit([&](auto cmat) {
+     visit_all(args.at(0), args.at(1))(
+         [&](auto amat, auto bmat) { return gemm(cmat, amat, bmat, 1.0f, 0.0f); });
});

- migemm(result, arg_0, arg_1, int32_t{1}, int32_t{0});
-
return result;
}
};

MIGRAPHX_REGISTER_OP(ref_gemm)

template <class Op>
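
Note: with migemm and the Blaze-backed ref gemm removed, both reference paths now funnel into the shared templated gemm, instantiated with distinct output and input element types. A self-contained sketch of that pattern (float stands in for fp8e4m3fnuz, since standard C++ has no fp8 type; tiny_gemm is illustrative, not the MIGraphX function):

    // Illustrative only: one templated kernel serving both the int8 -> int32
    // and the fp8 -> float reference paths.
    #include <cstddef>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    template <class T, class U>
    void tiny_gemm(std::vector<T>& c, const std::vector<U>& a, const std::vector<U>& b,
                   std::size_t m, std::size_t n, std::size_t k)
    {
        for(std::size_t i = 0; i < m; ++i)
            for(std::size_t j = 0; j < n; ++j)
            {
                double s = 0.0;
                for(std::size_t kk = 0; kk < k; ++kk)
                    s += static_cast<double>(a[i * k + kk]) * static_cast<double>(b[kk * n + j]);
                c[i * n + j] = static_cast<T>(s);
            }
    }

    int main()
    {
        // int8 inputs, int32 output (the existing quantized path)
        std::vector<int8_t> a8{1, 2, 3, 4}, b8{5, 6, 7, 8};
        std::vector<int32_t> c32(4, 0);
        tiny_gemm(c32, a8, b8, 2, 2, 2);

        // fp8-like inputs (float as a stand-in), float output (the new path)
        std::vector<float> af{1.f, 2.f, 3.f, 4.f}, bf{5.f, 6.f, 7.f, 8.f};
        std::vector<float> cf(4, 0.f);
        tiny_gemm(cf, af, bf, 2, 2, 2);

        std::cout << c32[0] << " " << cf[0] << "\n"; // both print 19
    }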