Skip to content

Commit

Permalink
FP8 OCP to FP8 FNUZ on hardware with only FP8 FNUZ support (#3684)
Browse files Browse the repository at this point in the history
NANOO is short for NAN On Overflow, the data type comes from this paper: https://arxiv.org/pdf/2206.02915
Implements the method written about in Convert OCP FP8 model to FNUZ model inside MIGraphX #2717
This pass must run before simplify_qdq so that the adjusted scales and zero points are propagated to after the quantized operator.
The test in test/fp8_ocp_to_nanoo.cpp checks the pass works with simplify_qdq and does the expected operations
The test in test/ref/fp8_ocp_to_nanoo.cpp checks the pass produces the same result before and after
I will make a separate PR that removes the gpu context changes to get the gfx number
Fixed the cpp_generator that was using __builtin_nan incorrectly
  • Loading branch information
CharlieL7 authored Jan 13, 2025
1 parent 6d02806 commit eb1717d
Show file tree
Hide file tree
Showing 13 changed files with 887 additions and 84 deletions.
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ add_library(migraphx
file_buffer.cpp
fileutils.cpp
fp_to_double.cpp
fp8_ocp_to_fnuz.cpp
fuse_concat.cpp
fuse_pointwise.cpp
fuse_pointwise_reduce.cpp
Expand Down
2 changes: 1 addition & 1 deletion src/cpp_generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ cpp_generator::function cpp_generator::generate_module(const module& m,
string_literal = "-__builtin_huge_val()";
}
else if(std::isnan(static_cast<double>(x)))
string_literal = "__builtin_nan()";
string_literal = "__builtin_nan(\"0\")";
else
string_literal = ins->get_literal().to_string();
});
Expand Down
178 changes: 178 additions & 0 deletions src/fp8_ocp_to_fnuz.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/fp8_ocp_to_fnuz.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/match/dq_helpers.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace {

using fp8::fp8e4m3fnuz;

std::unordered_set<std::string> get_quantizable_op_names()
{
static std::unordered_set<std::string> s = {"convolution", "dot"};
return s;
}

struct match_fp8ocp_convert_to_fp8fnuz
{
auto matcher() const
{
auto dq1 = match::arg(0)(
skip_post_dq_ops(match::dequantizelinear_op("scale1", "zp1").bind("dq1")));
auto dq2 = match::arg(1)(
skip_post_dq_ops(match::dequantizelinear_op("scale2", "zp2").bind("dq2")));
return match::name(get_quantizable_op_names())(dq1, dq2);
}

static auto bit_cast_and_handle_specials(module& m,
const instruction_ref dq,
const instruction_ref x,
const instruction_ref bits_0x80_lit,
const instruction_ref bits_0x7f_lit,
const instruction_ref bits_0xff_lit,
const instruction_ref bits_0x00_lit)
{
auto x_lens = x->get_shape().lens();
auto cast_input = m.insert_instruction(
dq, make_op("bit_cast", {{"target_type", shape::fp8e4m3fnuz_type}}), x);
auto mb_bits_0x80_lit = m.insert_instruction(
dq, make_op("multibroadcast", {{"out_lens", x_lens}}), bits_0x80_lit);
auto mb_bits_0x7f_lit = m.insert_instruction(
dq, make_op("multibroadcast", {{"out_lens", x_lens}}), bits_0x7f_lit);
auto mb_bits_0xff_lit = m.insert_instruction(
dq, make_op("multibroadcast", {{"out_lens", x_lens}}), bits_0xff_lit);
auto mb_zero_lit = m.insert_instruction(
dq, make_op("multibroadcast", {{"out_lens", x_lens}}), bits_0x00_lit);
// negative zero in fp8e4m3fn to zero in fp8e4m3fnuz
// a == 0x80 ? 0x0 : a
auto is_neg_zero = m.insert_instruction(dq, make_op("equal"), cast_input, mb_bits_0x80_lit);
auto ret = m.insert_instruction(dq, make_op("where"), is_neg_zero, mb_zero_lit, cast_input);

// positive and negative NaN in fp8e4m3fn to NaN in fp8e4m3fnuz
// (a == 0x7f or a == 0xff) ? 0x80 : a
auto eq_0x7f = m.insert_instruction(dq, make_op("equal"), ret, mb_bits_0x7f_lit);

auto eq_0xff = m.insert_instruction(dq, make_op("equal"), ret, mb_bits_0xff_lit);

auto cond = m.insert_instruction(dq, make_op("logical_or"), eq_0x7f, eq_0xff);
ret = m.insert_instruction(dq, make_op("where"), cond, mb_bits_0x80_lit, ret);
return ret;
}

// Add the same broadcast instructions after adjusted scales or
// adjusted zero points from after the originals. Similar to
// propagate_quantized_ins in simplify_qdq.
static auto propagate_broadcasts(module& m,
const instruction_ref adj,
const instruction_ref ori,
const instruction_ref start,
const instruction_ref insert_pt)
{
auto prev_ins = start;
std::vector<instruction_ref> ins_between;
// matcher skips continguous, multi/broadcasts and transposes, collect all those
// instructions
while(prev_ins != ori)
{
ins_between.push_back(prev_ins);
prev_ins = prev_ins->inputs().front();
}
auto ret = adj;
for(auto ins : reverse_iterator_for(ins_between))
{
ret = m.insert_instruction(insert_pt, (*ins)->get_operator(), {ret});
}
return ret;
}

static auto cast_to_fnuz(module& m,
const instruction_ref dq,
const instruction_ref input,
const instruction_ref dq_scale,
const instruction_ref dq_zp)
{
auto x = input;
std::vector<fp8e4m3fnuz> bits_0x80 = {fp8e4m3fnuz(0x80, fp8e4m3fnuz::from_bits())};
auto bits_0x80_lit = m.add_literal(shape{shape::fp8e4m3fnuz_type, {1}, {0}}, bits_0x80);

std::vector<fp8e4m3fnuz> bits_0x7f = {fp8e4m3fnuz(0x7f, fp8e4m3fnuz::from_bits())};
auto bits_0x7f_lit = m.add_literal(shape{shape::fp8e4m3fnuz_type, {1}, {0}}, bits_0x7f);

std::vector<fp8e4m3fnuz> bits_0xff = {fp8e4m3fnuz(0xff, fp8e4m3fnuz::from_bits())};
auto bits_0xff_lit = m.add_literal(shape{shape::fp8e4m3fnuz_type, {1}, {0}}, bits_0xff);

std::vector<fp8e4m3fnuz> bits_0x00 = {fp8e4m3fnuz(0x00, fp8e4m3fnuz::from_bits())};
auto bits_0x00_lit = m.add_literal(shape{shape::fp8e4m3fnuz_type, {1}, {0}}, bits_0x00);

x = bit_cast_and_handle_specials(
m, dq, x, bits_0x80_lit, bits_0x7f_lit, bits_0xff_lit, bits_0x00_lit);
auto adj_dq_zp = bit_cast_and_handle_specials(
m, dq, dq_zp, bits_0x80_lit, bits_0x7f_lit, bits_0xff_lit, bits_0x00_lit);

// adj_scale = 2 * scale
auto two_lit = m.add_literal(literal{shape{dq_scale->get_shape().type()}, {2}});
two_lit = m.insert_instruction(
dq, make_op("multibroadcast", {{"out_lens", dq_scale->get_shape().lens()}}), two_lit);
auto adj_dq_scale = m.insert_instruction(dq, make_op("mul"), dq_scale, two_lit);

adj_dq_scale = propagate_broadcasts(m, adj_dq_scale, dq_scale, dq->inputs().at(1), dq);
adj_dq_zp = propagate_broadcasts(m, adj_dq_zp, dq_zp, dq->inputs().at(2), dq);
m.replace_instruction(dq, make_op("dequantizelinear"), x, adj_dq_scale, adj_dq_zp);
}

auto apply(module& m, const match::matcher_result& r) const
{
auto dq1 = r.instructions["dq1"];
auto dq2 = r.instructions["dq2"];
auto scale1 = r.instructions["scale1"];
auto scale2 = r.instructions["scale2"];
auto zp1 = r.instructions["zp1"];
auto zp2 = r.instructions["zp2"];

std::set<migraphx::shape::type_t> supported_types = {migraphx::shape::fp8e4m3fn_type};
if(not contains(supported_types, dq1->inputs().front()->get_shape().type()) or
not contains(supported_types, dq2->inputs().front()->get_shape().type()))
return;

cast_to_fnuz(m, dq1, dq1->inputs().front(), scale1, zp1);
cast_to_fnuz(m, dq2, dq2->inputs().front(), scale2, zp2);
}
};

} // namespace

void fp8_ocp_to_fnuz::apply(module_pass_manager& mpm) const
{
module_ref mm = &mpm.get_module();
match::find_matches(*mm, match_fp8ocp_convert_to_fp8fnuz{});
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
49 changes: 49 additions & 0 deletions src/include/migraphx/fp8_ocp_to_fnuz.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_FP8_OCP_TO_FNUZ_HPP
#define MIGRAPHX_GUARD_RTGLIB_FP8_OCP_TO_FNUZ_HPP

#include <migraphx/config.hpp>
#include <migraphx/pass_manager.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

/**
* Convert fp8e4m3fn to fp8e4m3fnuz for hardware that only supports fp8e4m3fnuz data types
* intrinsically. Conversion uses the same bit representation and adjusts scaling factors at the
* dequantization. Using the same bit representation from fp8e4m3fn to fp8e4m3fnuz halves the
* floating point representation. This pass should run before simplify_qdq so that the scales and
* zero points calculated by simplify_qdq have the correct adjusted scaling factors
*/
struct MIGRAPHX_EXPORT fp8_ocp_to_fnuz
{
std::string name() const { return "fp8_ocp_to_fnuz"; }
void apply(module_pass_manager& mpm) const;
};

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
62 changes: 62 additions & 0 deletions src/include/migraphx/match/dq_helpers.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@

/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MATCH_DQ_HELPERS_HPP
#define MIGRAPHX_GUARD_MATCH_DQ_HELPERS_HPP

#include <migraphx/config.hpp>
#include <migraphx/matcher.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace match {

/**
* Find dequantizelinear (DQ) instruction with constant scale and zero point input
* while skipping broadcast instructions between DQ and scale/zero point. Used
* in simplify_qdq and fp8_ocp_to_fnuz.
*/
inline auto dequantizelinear_op(const std::string& scale, const std::string& zp)
{
return match::name("dequantizelinear")(
match::arg(1)(match::skip_broadcasts(match::is_constant().bind(scale))),
match::arg(2)(match::skip_broadcasts(match::is_constant().bind(zp))));
}

/**
* Skip certain operators after DQ instruction.
* Used in simplify_qdq and fp8_ocp_to_fnuz.
*/
template <class... Ms>
auto skip_post_dq_ops(Ms... ms)
{
return match::skip(match::name(
"broadcast", "multibroadcast", "contiguous", "transpose", "reshape", "convert"))(ms...);
}

} // namespace match
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
1 change: 1 addition & 0 deletions src/include/migraphx/op/bit_cast.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ struct bit_cast : unary<bit_cast>
args[0].visit([&](auto input) {
using itype = typename decltype(input)::value_type;
if constexpr(sizeof(otype) == sizeof(itype))

{
par_transform(input.begin(), input.end(), output.begin(), [&](auto x) {
return migraphx::bit_cast<otype>(x);
Expand Down
26 changes: 7 additions & 19 deletions src/simplify_qdq.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,12 @@
#include <migraphx/op/quant_dot.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/fp8_types.hpp>
#include <migraphx/match/dq_helpers.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace {

template <class... Ms>
auto skip_post_dq_ops(Ms... ms)
{
return match::skip(match::name(
"broadcast", "multibroadcast", "contiguous", "transpose", "reshape", "convert"))(ms...);
}

std::unordered_set<std::string> get_quantizable_op_names()
{
static std::unordered_set<std::string> s = {"convolution", "dot"};
Expand Down Expand Up @@ -117,20 +111,12 @@ struct match_find_quantizable_ops
return qinp;
}

static auto dequantizelinear_op(const std::string& scale, const std::string& zp)
{
return match::name("dequantizelinear")(
match::arg(0)(match::skip(match::name("quantizelinear"))(match::any())),
match::arg(1)(match::skip_broadcasts(match::is_constant().bind(scale))),
match::arg(2)(match::skip_broadcasts(match::is_constant().bind(zp))));
}

auto matcher() const
{
auto dq1 =
match::arg(0)(skip_post_dq_ops(dequantizelinear_op("scale1", "zp1").bind("dq1")));
auto dq2 =
match::arg(1)(skip_post_dq_ops(dequantizelinear_op("scale2", "zp2").bind("dq2")));
auto dq1 = match::arg(0)(
skip_post_dq_ops(match::dequantizelinear_op("scale1", "zp1").bind("dq1")));
auto dq2 = match::arg(1)(
skip_post_dq_ops(match::dequantizelinear_op("scale2", "zp2").bind("dq2")));
return match::name(get_quantizable_op_names())(dq1, dq2);
}

Expand Down Expand Up @@ -231,7 +217,9 @@ struct match_find_quantizable_ops
is_valid_qparam(zp1, out_lens, out_lens.size() - 2) and
is_valid_qparam(scale2, out_lens, out_lens.size() - 1) and
is_valid_qparam(zp2, out_lens, out_lens.size() - 1)))
{
return;
}

// This implementation supports both arguments being per-axis affine quantized
// In practice, inputs are per-tensor affine and weights are per-axis symmetric
Expand Down
3 changes: 3 additions & 0 deletions src/targets/gpu/target.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <migraphx/eliminate_data_type.hpp>
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/eliminate_pad.hpp>
#include <migraphx/fp8_ocp_to_fnuz.hpp>
#include <migraphx/fuse_concat.hpp>
#include <migraphx/fuse_pointwise_reduce.hpp>
#include <migraphx/inline_module.hpp>
Expand Down Expand Up @@ -179,6 +180,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination{},
eliminate_identity{},
dead_code_elimination{},
enable_pass(not gpu::gfx_has_fp8ocp_intrinsics() and gpu::gfx_has_fp8fnuz_intrinsics(), fp8_ocp_to_fnuz{}),
enable_pass(not gpu::gfx_has_fp8ocp_intrinsics() and gpu::gfx_has_fp8fnuz_intrinsics(), dead_code_elimination{}),
simplify_qdq{},
enable_pass(not mlir_enabled(), rewrite_quantization{}),
dead_code_elimination{},
Expand Down
Loading

0 comments on commit eb1717d

Please sign in to comment.