Skip to content

Commit

Permalink
Add OCP FP8 formats (#3399)
Browse files Browse the repository at this point in the history
  • Loading branch information
CharlieL7 authored Sep 13, 2024
1 parent 8ff42e8 commit 9dcea5c
Show file tree
Hide file tree
Showing 156 changed files with 959 additions and 121 deletions.
4 changes: 3 additions & 1 deletion src/api/include/migraphx/migraphx.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@
m(int64_type, int64_t) \
m(uint32_type, uint32_t) \
m(uint64_type, uint64_t) \
m(fp8e4m3fnuz_type, migraphx::fp8::fp8e4m3fnuz)
m(fp8e4m3fnuz_type, migraphx::fp8::fp8e4m3fnuz) \
m(fp8e4m3fn_type, migraphx::fp8::fp8e4m3fn) \
m(fp8e5m2_type, migraphx::fp8::fp8e5m2)
// clang-format on

#ifdef __cplusplus
Expand Down
5 changes: 3 additions & 2 deletions src/autocast_fp8.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/fp8_types.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
Expand All @@ -37,7 +38,7 @@ void autocast_fp8_pass::apply(module& m) const
for(auto ins : iterator_for(m))
{
const auto& ins_name = ins->name();
if(ins_name == "@param" and contains(fp8_types, ins->get_shape().type()))
if(ins_name == "@param" and contains(fp8_types{}.get(), ins->get_shape().type()))
{
shape::type_t fp8_type = ins->get_shape().type();
migraphx::shape new_shape = ins->get_shape().with_type(target_type);
Expand All @@ -58,7 +59,7 @@ void autocast_fp8_pass::apply(module& m) const
std::vector<instruction_ref> new_inputs;
std::transform(
inputs.begin(), inputs.end(), std::back_inserter(new_inputs), [&](auto i) {
if(contains(fp8_types, i->get_shape().type()))
if(contains(fp8_types{}.get(), i->get_shape().type()))
{
return m.insert_instruction(
ins,
Expand Down
2 changes: 1 addition & 1 deletion src/driver/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,7 @@ struct compiler
ap.set_value(true));
ap(to_fp16, {"--fp16"}, ap.help("Quantize for fp16"), ap.set_value(true));
ap(to_int8, {"--int8"}, ap.help("Quantize for int8"), ap.set_value(true));
ap(to_fp8, {"--fp8"}, ap.help("Quantize for fp8e4m3fnuz type"), ap.set_value(true));
ap(to_fp8, {"--fp8"}, ap.help("Quantize for fp8"), ap.set_value(true));
}

auto params(const program& p)
Expand Down
1 change: 0 additions & 1 deletion src/include/migraphx/autocast_fp8.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ This pass will convert model with fp8 input parameter to model with fp32
input parameter and internally add casts to fp8 for those converted params.*/
struct MIGRAPHX_EXPORT autocast_fp8_pass
{
std::set<shape::type_t> fp8_types = {migraphx::shape::fp8e4m3fnuz_type};
shape::type_t target_type = migraphx::shape::float_type;
std::string name() const { return "autocast_fp8_pass"; }
void apply(module& m) const;
Expand Down
21 changes: 20 additions & 1 deletion src/include/migraphx/float8.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
#include <string>
#include <utility>
#include <migraphx/config.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/float8_impl.hpp>

namespace migraphx {
Expand Down Expand Up @@ -392,7 +393,7 @@ namespace std {
{ \
}; \
template <class U> \
struct common_type<U, T> : std::common_type<float, U> \
struct common_type<U, T> : std::common_type<U, float> \
{ \
}; \
template <> \
Expand All @@ -405,6 +406,24 @@ MIGRAPHX_FP8_STD_OVERLOADS(migraphx::fp8::fp8e4m3fn)
MIGRAPHX_FP8_STD_OVERLOADS(migraphx::fp8::fp8e5m2)
MIGRAPHX_FP8_STD_OVERLOADS(migraphx::fp8::fp8e4m3fnuz)
MIGRAPHX_FP8_STD_OVERLOADS(migraphx::fp8::fp8e5m2fnuz)

// Resolve the ambiguity left by the MIGRAPHX_FP8_STD_OVERLOADS templates above:
// for a pair of two *distinct* fp8 types, both generated partial
// specializations of common_type match equally well, so spell out a full
// specialization (in both argument orders) that promotes the pair to float.
#define MIGRAPHX_FP8_COMMON_TYPE_OVERLOAD_RESOLUTION(T, U) \
template <>                                                \
struct common_type<T, U> : std::common_type<float, float>  \
{                                                          \
};                                                         \
template <>                                                \
struct common_type<U, T> : std::common_type<float, float>  \
{                                                          \
};

MIGRAPHX_FP8_COMMON_TYPE_OVERLOAD_RESOLUTION(migraphx::fp8::fp8e4m3fn, migraphx::fp8::fp8e5m2)
MIGRAPHX_FP8_COMMON_TYPE_OVERLOAD_RESOLUTION(migraphx::fp8::fp8e4m3fn, migraphx::fp8::fp8e4m3fnuz)
MIGRAPHX_FP8_COMMON_TYPE_OVERLOAD_RESOLUTION(migraphx::fp8::fp8e4m3fn, migraphx::fp8::fp8e5m2fnuz)
MIGRAPHX_FP8_COMMON_TYPE_OVERLOAD_RESOLUTION(migraphx::fp8::fp8e5m2, migraphx::fp8::fp8e4m3fnuz)
MIGRAPHX_FP8_COMMON_TYPE_OVERLOAD_RESOLUTION(migraphx::fp8::fp8e5m2, migraphx::fp8::fp8e5m2fnuz)
MIGRAPHX_FP8_COMMON_TYPE_OVERLOAD_RESOLUTION(migraphx::fp8::fp8e4m3fnuz, migraphx::fp8::fp8e5m2fnuz)
} // namespace std
// NOLINTEND
// =================================================================================================
Expand Down
38 changes: 38 additions & 0 deletions src/include/migraphx/fp8_types.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_FP8_TYPES_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_FP8_TYPES_HPP
#include <migraphx/shape.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct fp8_types
{
const std::set<shape::type_t> types = {
shape::fp8e4m3fnuz_type, shape::fp8e4m3fn_type, shape::fp8e5m2_type};

std::set<shape::type_t> get() const { return types; }
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_FP8_TYPES_HPP
24 changes: 24 additions & 0 deletions src/include/migraphx/half.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,30 @@ struct common_type<migraphx::half, migraphx::fp8::fp8e4m3fnuz>
using type = float;
};

// Mixed arithmetic between migraphx::half and the OCP fp8 types promotes both
// operands to float. common_type is specialized for both argument orders
// because the primary template does not provide this symmetry automatically.
template <>
struct common_type<migraphx::fp8::fp8e4m3fn, migraphx::half>
{
using type = float;
};

template <>
struct common_type<migraphx::half, migraphx::fp8::fp8e4m3fn>
{
using type = float;
};

template <>
struct common_type<migraphx::fp8::fp8e5m2, migraphx::half>
{
using type = float;
};

template <>
struct common_type<migraphx::half, migraphx::fp8::fp8e5m2>
{
using type = float;
};

template <>
struct common_type<migraphx::half, migraphx::half>
{
Expand Down
10 changes: 5 additions & 5 deletions src/include/migraphx/op/quant_convolution.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,13 @@ struct quant_convolution
MIGRAPHX_THROW("quant_convolution: input k-dims does not match attribute size");
}

// all input type must be int8_type and output is float_type
std::set<migraphx::shape::type_t> supported_types = {shape::int8_type,
shape::fp8e4m3fnuz_type};
// all input type must be int8_type or fp8 types
// output should be float_type
std::set<migraphx::shape::type_t> supported_types = {
shape::int8_type, shape::fp8e4m3fnuz_type, shape::fp8e4m3fn_type, shape::fp8e5m2_type};
if(not contains(supported_types, t))
{
MIGRAPHX_THROW("QUANT_CONVOLUTION: only accept input and weights of type int8_t or "
"fp8e4m3fnuz_type");
MIGRAPHX_THROW("QUANT_CONVOLUTION: only accept input and weights of type int8 or fp8");
}

std::vector<size_t> output_lens{input.lens()[0], weights.lens()[0]};
Expand Down
15 changes: 9 additions & 6 deletions src/include/migraphx/op/quant_dot.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <migraphx/config.hpp>
#include <migraphx/gemm.hpp>
#include <migraphx/value.hpp>
#include <migraphx/fp8_types.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
Expand All @@ -44,12 +45,14 @@ struct quant_dot
const shape& a = inputs.at(0);
const shape& b = inputs.at(1);
auto t = a.type();
std::set<migraphx::shape::type_t> suppported_types = {
shape::int8_type, shape::uint8_type, shape::fp8e4m3fnuz_type};
if(not contains(suppported_types, t))
std::set<migraphx::shape::type_t> supported_types = {shape::int8_type,
shape::uint8_type,
shape::fp8e4m3fnuz_type,
shape::fp8e4m3fn_type,
shape::fp8e5m2_type};
if(not contains(supported_types, t))
{
MIGRAPHX_THROW(
"QUANT_DOT: only support data type int8_t, uint8_t and fp8e4m3fnuz_type");
MIGRAPHX_THROW("QUANT_DOT: only support data type int8_t, uint8_t and fp8 types");
}

if(not std::all_of(
Expand All @@ -76,7 +79,7 @@ struct quant_dot

auto out_lens = a.lens();
out_lens[dim_1] = b.lens()[dim_1];
if(t == shape::fp8e4m3fnuz_type)
if(contains(fp8_types{}.get(), t))
{
return {shape::float_type, out_lens};
} // else int8 gemm
Expand Down
6 changes: 4 additions & 2 deletions src/include/migraphx/shape.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,10 @@ struct MIGRAPHX_EXPORT shape
m(int64_type, int64_t) \
m(uint32_type, uint32_t) \
m(uint64_type, uint64_t) \
m(fp8e4m3fnuz_type, migraphx::fp8::fp8e4m3fnuz)
// clang-format on
m(fp8e4m3fnuz_type, migraphx::fp8::fp8e4m3fnuz) \
m(fp8e4m3fn_type, migraphx::fp8::fp8e4m3fn) \
m(fp8e5m2_type, migraphx::fp8::fp8e5m2)
// clang-format on

#define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) x,
enum type_t
Expand Down
8 changes: 8 additions & 0 deletions src/include/migraphx/type_traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,14 @@ MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, migraphx::fp8::fp8e4m3fnuz)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, migraphx::fp8::fp8e4m3fnuz)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, migraphx::fp8::fp8e4m3fnuz)

MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, migraphx::fp8::fp8e4m3fn)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, migraphx::fp8::fp8e4m3fn)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, migraphx::fp8::fp8e4m3fn)

MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, migraphx::fp8::fp8e5m2)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, migraphx::fp8::fp8e5m2)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, migraphx::fp8::fp8e5m2)

template <class T>
using accumulator_type =
std::conditional_t<is_floating_point<T>{},
Expand Down
5 changes: 3 additions & 2 deletions src/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
#include <migraphx/param_utils.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/json.hpp>
#include <migraphx/fp8_types.hpp>
#include <iostream>
#include <sstream>
#include <algorithm>
Expand Down Expand Up @@ -814,8 +815,8 @@ void module::finalize(std::vector<context>& contexts)
}
}
#ifndef BUILD_DEV
if(std::any_of(this->begin(), this->end(), [](const auto i) {
return i.get_shape().type() == migraphx::shape::fp8e4m3fnuz_type;
if(std::any_of(this->begin(), this->end(), [&](const auto i) {
return contains(fp8_types{}.get(), i.get_shape().type());
}))
{
std::cout << "[Warning] : MIGraphX has BETA support for FP8. Using FP8 may result in "
Expand Down
25 changes: 23 additions & 2 deletions src/py/migraphx_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,13 +152,34 @@ struct npy_format_descriptor<migraphx::fp8::fp8e4m3fnuz>
{
static std::string format()
{
// following: https://docs.python.org/3/library/struct.html#format-characters
// TODO: need to figure out correct encoding
// TODO: no standard format in numpy for fp8
return "z";
}
static constexpr auto name() { return _("fp8e4m3fnuz"); }
};

// pybind11 numpy buffer-protocol descriptor for the OCP fp8e4m3fn type.
template <>
struct npy_format_descriptor<migraphx::fp8::fp8e4m3fn>
{
static std::string format()
{
// TODO: no standard format in numpy for fp8
// NOTE(review): "z" is not a documented struct format character
// (https://docs.python.org/3/library/struct.html#format-characters);
// confirm how consumers interpret it.
return "z";
}
static constexpr auto name() { return _("fp8e4m3fn"); }
};

// pybind11 numpy buffer-protocol descriptor for the OCP fp8e5m2 type.
template <>
struct npy_format_descriptor<migraphx::fp8::fp8e5m2>
{
static std::string format()
{
// TODO: no standard format in numpy for fp8
// NOTE(review): same placeholder "z" format as fp8e4m3fn above — confirm.
return "z";
}
static constexpr auto name() { return _("fp8e5m2"); }
};

} // namespace detail
} // namespace pybind11

Expand Down
25 changes: 22 additions & 3 deletions src/quantization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
#include <migraphx/pass_manager.hpp>
#include <migraphx/normalize_ops.hpp>
#include <set>
#include <map>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
Expand Down Expand Up @@ -75,8 +76,10 @@ void quantize_8bits(program& prog,
std::shared_ptr<std::vector<std::pair<float, float>>> quant_8bit_params =
std::make_shared<std::vector<std::pair<float, float>>>();
std::shared_ptr<std::vector<float>> max_abs_vals = std::make_shared<std::vector<float>>();

float quantized_range = (precision == shape::type_t::int8_type) ? 127.0 : 240.0;
std::map<shape::type_t, float> type_ranges = {{shape::type_t::int8_type, 127.0},
{shape::type_t::fp8e4m3fnuz_type, 240.0},
{shape::type_t::fp8e4m3fn_type, 448.0}};
float quantized_range = type_ranges.at(precision);
auto calc_quant_params = [&](std::size_t ins_index, std::vector<argument> args) {
std::pair<float, float> param_pair{64.0f, 0.0f};
// scale and shift is need for only int8 type, and we do not
Expand Down Expand Up @@ -180,7 +183,23 @@ void quantize_fp8(program& prog, const target& t, const std::vector<parameter_ma
supported_ins_names.insert(ins->name());
}
}
quantize_8bits(prog, t, shape::fp8e4m3fnuz_type, calibration, supported_ins_names);
auto gfx_has_fp8fnuz = [&]() {
if(t.name() == "gpu")
{
auto context_value = t.get_context().to_value();
auto device_name = context_value["gfx_name"].to<std::string>();
return (starts_with(device_name, "gfx9") and device_name >= "gfx940");
}
return false;
};
if(gfx_has_fp8fnuz())
{
quantize_8bits(prog, t, shape::fp8e4m3fnuz_type, calibration, supported_ins_names);
}
else
{
quantize_8bits(prog, t, shape::fp8e4m3fn_type, calibration, supported_ins_names);
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
3 changes: 2 additions & 1 deletion src/rewrite_reduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,8 @@ struct find_reduce_mean

auto n = input->get_shape().elements() / ins->get_shape().elements();

if(n >= max_n / 4 and size < 3)
// Convert the accumulator to int32/float if the type is 8-bit (size == 1),
// or if it is smaller than 3 bytes and n >= max_n / 4
if(size == 1 or (n >= max_n / 4 and size < 3))
{
shape::type_t t = is_integral ? shape::int32_type : shape::float_type;
input = m.insert_instruction(ins, make_op("convert", {{"target_type", t}}), input);
Expand Down
2 changes: 2 additions & 0 deletions src/simplify_qdq.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,8 @@ struct match_find_quantizable_ops
auto zp2 = r.instructions["zp2"];
// Only INT8 or FP8 type currently supported
std::set<migraphx::shape::type_t> supported_types = {migraphx::shape::fp8e4m3fnuz_type,
migraphx::shape::fp8e4m3fn_type,
migraphx::shape::fp8e5m2_type,
migraphx::shape::int8_type};
if(not contains(supported_types, dq1->inputs().front()->get_shape().type()) or
not contains(supported_types, dq2->inputs().front()->get_shape().type()))
Expand Down
4 changes: 2 additions & 2 deletions src/targets/cpu/lowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ struct cpu_apply
// skip lowering if any of the inputs is fp8, since oneDNN does not
// support fp8 yet
if(std::any_of(it->inputs().begin(), it->inputs().end(), [](const auto& i) {
return i->get_shape().type() == migraphx::shape::fp8e4m3fnuz_type;
return contains(fp8_types{}.get(), i->get_shape().type());
}))
continue;
if(it->name() == "pow")
Expand All @@ -390,7 +390,7 @@ struct cpu_apply
// skip lowering if any of the inputs is fp8, since oneDNN does not
// support fp8 yet
if(std::any_of(it->inputs().begin(), it->inputs().end(), [](const auto& i) {
return i->get_shape().type() == migraphx::shape::fp8e4m3fnuz_type;
return contains(fp8_types{}.get(), i->get_shape().type());
}))
continue;
if(it->name() == "pooling")
Expand Down
Loading

0 comments on commit 9dcea5c

Please sign in to comment.