Skip to content

Commit

Permalink
Merge branch 'develop' into 3432-improve-simplify_algebra-to-find-mor…
Browse files Browse the repository at this point in the history
…e-horizontal-fusion-opportunities
  • Loading branch information
aarushjain29 authored Jan 15, 2025
2 parents 8bd694e + 22e323c commit c8290eb
Show file tree
Hide file tree
Showing 102 changed files with 3,830 additions and 167 deletions.
6 changes: 4 additions & 2 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -351,12 +351,14 @@ jobs:
docker-images: true

- uses: actions/[email protected]
with:
fetch-depth: 0 # Fetch the entire repository history and all branches
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: 3.8
- name: run License Check
run: python3 tools/check_stamped.py ${GITHUB_HEAD_REF:-${GITHUB_REF#refs/heads/}}
- name: Run License Check
run: python3 tools/check_stamped.py origin/${{ github.event_name == 'pull_request' && github.base_ref || 'develop' }}

linux:

Expand Down
6 changes: 3 additions & 3 deletions .github/workflows/config.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#=====ROCM INFO=====
ROCM_VERSION : '6.0.2'
ROCM_VERSION : '6.3.1'
#default ROCm version to be used
ROCM_BASE_IMAGE : 'rocm/dev-ubuntu-20.04'
ROCM_BASE_IMAGE : 'rocm/dev-ubuntu-22.04'
#base image from dockerhub to be used
ROCM_BUILT_IMAGE : 'rocm-migraphx'
#name of the docker image built upon ROCm base
Expand All @@ -26,4 +26,4 @@ PERFORMANCE_TEST_TIMEOUT : '30m'

#===== W A R N I N G =====
#VARIABLE NAMES NOT TO BE CHANGED, VALUES ONLY!
#VALUES MUST BE ENGLOSED IN SINGLE QUOTES!
#VALUES MUST BE ENGLOSED IN SINGLE QUOTES!
4 changes: 2 additions & 2 deletions .github/workflows/performance.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ on:
rocm_release:
description: ROCm Version
required: true
default: '6.0.2'
default: '6.3.1'
performance_reports_repo:
description: Repository where performance reports are stored
required: true
Expand Down Expand Up @@ -96,4 +96,4 @@ jobs:
secrets:
gh_token: ${{ secrets.MIGRAPHX_BOT_TOKEN }}
mail_user: ${{ secrets.MAIL_USERNAME }}
mail_pass: ${{ secrets.MAIL_PASSWORD }}
mail_pass: ${{ secrets.MAIL_PASSWORD }}
2 changes: 0 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,6 @@ rocm_enable_clang_tidy(
-clang-diagnostic-disabled-macro-expansion
-clang-diagnostic-extern-c-compat
-clang-diagnostic-unused-command-line-argument
-cppcoreguidelines-avoid-capture-default-when-capturing-this
-cppcoreguidelines-avoid-const-or-ref-data-members
-cppcoreguidelines-avoid-do-while
-cppcoreguidelines-explicit-virtual-functions
Expand All @@ -222,7 +221,6 @@ rocm_enable_clang_tidy(
-cppcoreguidelines-pro-type-reinterpret-cast
-cppcoreguidelines-pro-type-union-access
-cppcoreguidelines-pro-type-vararg
-cppcoreguidelines-rvalue-reference-param-not-moved
-cppcoreguidelines-special-member-functions
-cppcoreguidelines-use-default-member-init
-cppcoreguidelines-virtual-class-destructor
Expand Down
2 changes: 1 addition & 1 deletion LICENSE
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
The MIT License (MIT)

Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
Expand Down
4 changes: 4 additions & 0 deletions docs/driver/read.rst
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ Print out the program as cpp program.

Print out program as json.

.. option:: --netron

Print out program as a Netron viewable json file.

.. option:: --text

Print out program in text format.
Expand Down
2 changes: 2 additions & 0 deletions docs/migraphx-driver.rst
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ To learn which options can be used with which commands, see the :ref:`MIGraphX d
- Prints the program in .txt format
* - --binary
- Prints the program in binary format
* - --netron
- Prints the program in Netron viewable JSON format
* - --output | -o
- Writes output in a file
* - --fill0
Expand Down
22 changes: 11 additions & 11 deletions docs/reference/driver-options.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ read

Loads and prints input graph.

.. include:: ./driver/read.rst
.. include:: ../driver/read.rst

compile
-------
Expand All @@ -25,8 +25,8 @@ compile

Compiles and prints input graph.

.. include:: ./driver/read.rst
.. include:: ./driver/compile.rst
.. include:: ../driver/read.rst
.. include:: ../driver/compile.rst

run
---
Expand All @@ -35,8 +35,8 @@ run

Loads and prints input graph.

.. include:: ./driver/read.rst
.. include:: ./driver/compile.rst
.. include:: ../driver/read.rst
.. include:: ../driver/compile.rst

perf
----
Expand All @@ -45,8 +45,8 @@ perf

Compiles and runs input graph then prints performance report.

.. include:: ./driver/read.rst
.. include:: ./driver/compile.rst
.. include:: ../driver/read.rst
.. include:: ../driver/compile.rst

.. option:: --iterations, -n [unsigned int]

Expand All @@ -59,8 +59,8 @@ verify

Runs reference and CPU or GPU implementations and checks outputs for consistency.

.. include:: ./driver/read.rst
.. include:: ./driver/compile.rst
.. include:: ../driver/read.rst
.. include:: ../driver/compile.rst

.. option:: --rms-tol [double]

Expand Down Expand Up @@ -104,5 +104,5 @@ Here is how you can use ``roctx`` combined with :doc:`rocprof <rocprofiler:rocpr
Running :doc:`rocprof <rocprofiler:rocprofv1>` generates trace information for HIP, HCC and ROCTX in separate ``.txt`` files.
To understand the interactions between API calls, utilize the :ref:`roctx.py <tools>` helper script.

.. include:: ./driver/read.rst
.. include:: ./driver/compile.rst
.. include:: ../driver/read.rst
.. include:: ../driver/compile.rst
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ add_library(migraphx
file_buffer.cpp
fileutils.cpp
fp_to_double.cpp
fp8_ocp_to_fnuz.cpp
fuse_concat.cpp
fuse_pointwise.cpp
fuse_pointwise_reduce.cpp
Expand Down
2 changes: 1 addition & 1 deletion src/cpp_generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ cpp_generator::function cpp_generator::generate_module(const module& m,
string_literal = "-__builtin_huge_val()";
}
else if(std::isnan(static_cast<double>(x)))
string_literal = "__builtin_nan()";
string_literal = "__builtin_nan(\"0\")";
else
string_literal = ins->get_literal().to_string();
});
Expand Down
178 changes: 178 additions & 0 deletions src/fp8_ocp_to_fnuz.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/fp8_ocp_to_fnuz.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/match/dq_helpers.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace {

using fp8::fp8e4m3fnuz;

std::unordered_set<std::string> get_quantizable_op_names()
{
static std::unordered_set<std::string> s = {"convolution", "dot"};
return s;
}

struct match_fp8ocp_convert_to_fp8fnuz
{
auto matcher() const
{
auto dq1 = match::arg(0)(
skip_post_dq_ops(match::dequantizelinear_op("scale1", "zp1").bind("dq1")));
auto dq2 = match::arg(1)(
skip_post_dq_ops(match::dequantizelinear_op("scale2", "zp2").bind("dq2")));
return match::name(get_quantizable_op_names())(dq1, dq2);
}

static auto bit_cast_and_handle_specials(module& m,
const instruction_ref dq,
const instruction_ref x,
const instruction_ref bits_0x80_lit,
const instruction_ref bits_0x7f_lit,
const instruction_ref bits_0xff_lit,
const instruction_ref bits_0x00_lit)
{
auto x_lens = x->get_shape().lens();
auto cast_input = m.insert_instruction(
dq, make_op("bit_cast", {{"target_type", shape::fp8e4m3fnuz_type}}), x);
auto mb_bits_0x80_lit = m.insert_instruction(
dq, make_op("multibroadcast", {{"out_lens", x_lens}}), bits_0x80_lit);
auto mb_bits_0x7f_lit = m.insert_instruction(
dq, make_op("multibroadcast", {{"out_lens", x_lens}}), bits_0x7f_lit);
auto mb_bits_0xff_lit = m.insert_instruction(
dq, make_op("multibroadcast", {{"out_lens", x_lens}}), bits_0xff_lit);
auto mb_zero_lit = m.insert_instruction(
dq, make_op("multibroadcast", {{"out_lens", x_lens}}), bits_0x00_lit);
// negative zero in fp8e4m3fn to zero in fp8e4m3fnuz
// a == 0x80 ? 0x0 : a
auto is_neg_zero = m.insert_instruction(dq, make_op("equal"), cast_input, mb_bits_0x80_lit);
auto ret = m.insert_instruction(dq, make_op("where"), is_neg_zero, mb_zero_lit, cast_input);

// positive and negative NaN in fp8e4m3fn to NaN in fp8e4m3fnuz
// (a == 0x7f or a == 0xff) ? 0x80 : a
auto eq_0x7f = m.insert_instruction(dq, make_op("equal"), ret, mb_bits_0x7f_lit);

auto eq_0xff = m.insert_instruction(dq, make_op("equal"), ret, mb_bits_0xff_lit);

auto cond = m.insert_instruction(dq, make_op("logical_or"), eq_0x7f, eq_0xff);
ret = m.insert_instruction(dq, make_op("where"), cond, mb_bits_0x80_lit, ret);
return ret;
}

// Add the same broadcast instructions after adjusted scales or
// adjusted zero points from after the originals. Similar to
// propagate_quantized_ins in simplify_qdq.
static auto propagate_broadcasts(module& m,
const instruction_ref adj,
const instruction_ref ori,
const instruction_ref start,
const instruction_ref insert_pt)
{
auto prev_ins = start;
std::vector<instruction_ref> ins_between;
// matcher skips continguous, multi/broadcasts and transposes, collect all those
// instructions
while(prev_ins != ori)
{
ins_between.push_back(prev_ins);
prev_ins = prev_ins->inputs().front();
}
auto ret = adj;
for(auto ins : reverse_iterator_for(ins_between))
{
ret = m.insert_instruction(insert_pt, (*ins)->get_operator(), {ret});
}
return ret;
}

static auto cast_to_fnuz(module& m,
const instruction_ref dq,
const instruction_ref input,
const instruction_ref dq_scale,
const instruction_ref dq_zp)
{
auto x = input;
std::vector<fp8e4m3fnuz> bits_0x80 = {fp8e4m3fnuz(0x80, fp8e4m3fnuz::from_bits())};
auto bits_0x80_lit = m.add_literal(shape{shape::fp8e4m3fnuz_type, {1}, {0}}, bits_0x80);

std::vector<fp8e4m3fnuz> bits_0x7f = {fp8e4m3fnuz(0x7f, fp8e4m3fnuz::from_bits())};
auto bits_0x7f_lit = m.add_literal(shape{shape::fp8e4m3fnuz_type, {1}, {0}}, bits_0x7f);

std::vector<fp8e4m3fnuz> bits_0xff = {fp8e4m3fnuz(0xff, fp8e4m3fnuz::from_bits())};
auto bits_0xff_lit = m.add_literal(shape{shape::fp8e4m3fnuz_type, {1}, {0}}, bits_0xff);

std::vector<fp8e4m3fnuz> bits_0x00 = {fp8e4m3fnuz(0x00, fp8e4m3fnuz::from_bits())};
auto bits_0x00_lit = m.add_literal(shape{shape::fp8e4m3fnuz_type, {1}, {0}}, bits_0x00);

x = bit_cast_and_handle_specials(
m, dq, x, bits_0x80_lit, bits_0x7f_lit, bits_0xff_lit, bits_0x00_lit);
auto adj_dq_zp = bit_cast_and_handle_specials(
m, dq, dq_zp, bits_0x80_lit, bits_0x7f_lit, bits_0xff_lit, bits_0x00_lit);

// adj_scale = 2 * scale
auto two_lit = m.add_literal(literal{shape{dq_scale->get_shape().type()}, {2}});
two_lit = m.insert_instruction(
dq, make_op("multibroadcast", {{"out_lens", dq_scale->get_shape().lens()}}), two_lit);
auto adj_dq_scale = m.insert_instruction(dq, make_op("mul"), dq_scale, two_lit);

adj_dq_scale = propagate_broadcasts(m, adj_dq_scale, dq_scale, dq->inputs().at(1), dq);
adj_dq_zp = propagate_broadcasts(m, adj_dq_zp, dq_zp, dq->inputs().at(2), dq);
m.replace_instruction(dq, make_op("dequantizelinear"), x, adj_dq_scale, adj_dq_zp);
}

auto apply(module& m, const match::matcher_result& r) const
{
auto dq1 = r.instructions["dq1"];
auto dq2 = r.instructions["dq2"];
auto scale1 = r.instructions["scale1"];
auto scale2 = r.instructions["scale2"];
auto zp1 = r.instructions["zp1"];
auto zp2 = r.instructions["zp2"];

std::set<migraphx::shape::type_t> supported_types = {migraphx::shape::fp8e4m3fn_type};
if(not contains(supported_types, dq1->inputs().front()->get_shape().type()) or
not contains(supported_types, dq2->inputs().front()->get_shape().type()))
return;

cast_to_fnuz(m, dq1, dq1->inputs().front(), scale1, zp1);
cast_to_fnuz(m, dq2, dq2->inputs().front(), scale2, zp2);
}
};

} // namespace

void fp8_ocp_to_fnuz::apply(module_pass_manager& mpm) const
{
module_ref mm = &mpm.get_module();
match::find_matches(*mm, match_fp8ocp_convert_to_fp8fnuz{});
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
2 changes: 1 addition & 1 deletion src/include/migraphx/base64.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

/// encode string to base64
std::string base64_encode(const std::string& str);
std::string MIGRAPHX_EXPORT base64_encode(const std::string& str);

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
Expand Down
Loading

0 comments on commit c8290eb

Please sign in to comment.