Merge remote-tracking branch 'upstream/develop' into einsum_support
mirza-halilcevic committed May 23, 2024
2 parents 6813ef2 + 76d1c14 commit bc7d5cb
Showing 20 changed files with 315 additions and 67 deletions.
5 changes: 4 additions & 1 deletion CMakeLists.txt
@@ -343,11 +343,14 @@ if(MIGRAPHX_USE_ROCBLAS)
list(APPEND PACKAGE_DEPENDS rocblas)
endif()

rocm_package_add_deb_dependencies(SHARED_DEPENDS "hip-dev")
rocm_package_add_rpm_dependencies(SHARED_DEPENDS "hip-devel")

rocm_create_package(
NAME MIGraphX
DESCRIPTION "AMD's graph optimizer"
MAINTAINER "AMDMIGraphX Maintainer <[email protected]>"
LDCONFIG
PTH
DEPENDS miopen-hip ${DEPENDS_HIP_RUNTIME} hip-base half ${PACKAGE_DEPENDS}
DEPENDS miopen-hip ${DEPENDS_HIP_RUNTIME} half ${PACKAGE_DEPENDS}
)
2 changes: 1 addition & 1 deletion Dockerfile
@@ -38,7 +38,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
libpython3.8 \
wget \
rocm-device-libs \
hip-base \
hip-dev \
libnuma-dev \
miopen-hip \
rocblas \
2 changes: 1 addition & 1 deletion docs/sphinx/requirements.in
@@ -1,2 +1,2 @@
rocm-docs-core==1.1.1
rocm-docs-core==1.1.2
sphinx-collapse
4 changes: 2 additions & 2 deletions docs/sphinx/requirements.txt
@@ -88,11 +88,11 @@ pyyaml==6.0
# myst-parser
# rocm-docs-core
# sphinx-external-toc
requests==2.31.0
requests==2.32.0
# via
# pygithub
# sphinx
rocm-docs-core==1.1.1
rocm-docs-core==1.1.2
# via -r requirements.in
smmap==5.0.0
# via gitdb
2 changes: 1 addition & 1 deletion hip-clang.docker
@@ -27,7 +27,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
software-properties-common \
wget \
rocm-device-libs \
hip-base \
hip-dev \
libnuma-dev \
miopen-hip \
rocblas \
11 changes: 10 additions & 1 deletion src/include/migraphx/tf.hpp
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
@@ -45,6 +45,15 @@ struct tf_options
MIGRAPHX_TF_EXPORT program parse_tf(const std::string& name,
const tf_options& options = tf_options{});

/// Create a program from a tf buffer
MIGRAPHX_TF_EXPORT program parse_tf_buffer(const std::string& buffer,
const tf_options& options = tf_options{});

/// Create a program from a tf buffer given a raw pointer and size
MIGRAPHX_TF_EXPORT program parse_tf_buffer(const void* data,
std::size_t size,
const tf_options& options = tf_options{});

MIGRAPHX_TF_EXPORT std::vector<std::string> get_tf_operators();

} // namespace MIGRAPHX_INLINE_NS
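For context, a minimal sketch of how the two new overloads could be called. The model path and the read-into-string scaffolding are illustrative assumptions, not part of this diff:

#include <fstream>
#include <sstream>
#include <string>
#include <migraphx/tf.hpp>

int main()
{
    // Read a serialized GraphDef into memory (hypothetical path)
    std::ifstream file("model.pb", std::ios::binary);
    std::stringstream ss;
    ss << file.rdbuf();
    std::string buffer = ss.str();

    // std::string overload, default tf_options
    migraphx::program p1 = migraphx::parse_tf_buffer(buffer);

    // raw pointer + size overload
    migraphx::program p2 = migraphx::parse_tf_buffer(buffer.data(), buffer.size());
    return 0;
}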
73 changes: 59 additions & 14 deletions src/simplify_reshapes.cpp
@@ -244,6 +244,21 @@ struct find_nested_slice
}
};

/**
* Example case
* From:
* param0: lens = [3, 4], strides = [4, 1]
* param1: lens = [3, 4], strides = [4, 1]
* mb0: multibroadcast(param0, output_lens = [2, 3, 4])
* mb1: multibroadcast(param1, output_lens = [2, 3, 4])
* concat(mb0, mb1, axis = 2)
*
* To:
* param0: lens = [3, 4], strides = [4, 1]
* param1: lens = [3, 4], strides = [4, 1]
* con0: concat(param0, param1, axis = 1)
* multibroadcast(con0, lens = [2, 3, 4])
*/
struct find_concat_multibroadcasts
{
auto matcher() const
@@ -253,32 +268,62 @@ struct find_concat_multibroadcasts

void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto op = any_cast<op::concat>(ins->get_operator());
auto out_lens = ins->get_shape().lens();
auto inputs = ins->inputs();
auto in_strides = inputs.front()->get_shape().strides();
auto concat_ins = mr.result;
auto concat_op = any_cast<op::concat>(concat_ins->get_operator());
auto concat_out_lens = concat_ins->get_shape().lens();
auto concat_inputs = concat_ins->inputs();
auto front_mb_strides = concat_inputs.front()->get_shape().strides();
assert(concat_op.axis >= 0);

// Only apply when concat axis is not a broadcasted dimension
if(std::any_of(inputs.begin(), inputs.end(), [&](auto i) {
return i->get_shape().strides()[op.axis] == 0;
if(std::any_of(concat_inputs.begin(), concat_inputs.end(), [&](auto i) {
return i->get_shape().strides()[concat_op.axis] == 0;
}))
{
return;
}

// Use inputs of multibroadcast ops as inputs to new concat op
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [](auto i) {
// Get the inputs of multibroadcast ops. Will be used as inputs to new concat op
std::vector<instruction_ref> mb_inputs(concat_inputs.size());
std::transform(concat_inputs.begin(), concat_inputs.end(), mb_inputs.begin(), [](auto i) {
return i->inputs().front();
});

// Check that the inputs into the multibroadcasts have the same rank
const auto& first_shape = mb_inputs.front()->get_shape();
if(not std::all_of(mb_inputs.begin() + 1, mb_inputs.end(), [&](auto mb_in) {
return mb_in->get_shape().ndim() == first_shape.ndim();
}))
{
return;
}

// Reduce axis by number of leading broadcasted dimensions
if(inputs.front()->get_shape().lens().size() < out_lens.size())
op.axis -= std::count(in_strides.begin(), in_strides.begin() + op.axis, 0);
if(mb_inputs.front()->get_shape().lens().size() < concat_out_lens.size())
{
concat_op.axis -=
std::count(front_mb_strides.begin(), front_mb_strides.begin() + concat_op.axis, 0);
}

auto concat = m.insert_instruction(ins, op, inputs);
m.replace_instruction(
ins, migraphx::make_op("multibroadcast", {{"out_lens", out_lens}}), concat);
// Inputs to multibroadcasts should have the same dimensions except for the axis to
// concatenate over
const auto& front_in_lens = mb_inputs.front()->get_shape().lens();
if(not std::all_of(mb_inputs.begin() + 1, mb_inputs.end(), [&](auto input_to_mb) {
const auto& lens = input_to_mb->get_shape().lens();
return std::equal(
lens.begin(), lens.begin() + concat_op.axis, front_in_lens.begin()) and
std::equal(lens.begin() + concat_op.axis + 1,
lens.end(),
front_in_lens.begin() + concat_op.axis + 1);
}))
{
return;
}

auto new_concat_ins = m.insert_instruction(concat_ins, concat_op, mb_inputs);
m.replace_instruction(concat_ins,
migraphx::make_op("multibroadcast", {{"out_lens", concat_out_lens}}),
new_concat_ins);
}
};

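As a rough illustration, the pattern from the doc comment above could be built with the internal C++ API like this (a sketch; the shapes come from the comment, while the parameter names and pass invocation are assumptions based on how MIGraphX modules are typically constructed in tests):

// Build concat(multibroadcast(p0), multibroadcast(p1), axis = 2)
migraphx::program prog;
auto* mm = prog.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {3, 4}};
auto p0  = mm->add_parameter("p0", s);
auto p1  = mm->add_parameter("p1", s);
auto mb0 = mm->add_instruction(
    migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 4}}}), p0);
auto mb1 = mm->add_instruction(
    migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 4}}}), p1);
mm->add_instruction(migraphx::make_op("concat", {{"axis", 2}}), mb0, mb1);
// After the rewrite the graph should read:
//   concat(p0, p1, axis = 1) -> multibroadcast(out_lens = [2, 3, 4])
// The concat axis drops from 2 to 1 because the single leading
// broadcast dimension is not present on the multibroadcast inputs.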
4 changes: 2 additions & 2 deletions src/targets/gpu/include/migraphx/gpu/rocblas.hpp
@@ -38,13 +38,13 @@ using rocblas_handle_ptr = MIGRAPHX_MANAGE_PTR(rocblas_handle, rocblas_destroy_h

rocblas_handle_ptr create_rocblas_handle_ptr();
rocblas_handle_ptr create_rocblas_handle_ptr(hipStream_t s);

#endif
struct context;

MIGRAPHX_GPU_EXPORT bool get_compute_fp32_flag();

MIGRAPHX_GPU_EXPORT bool rocblas_fp8_available();
#endif

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
54 changes: 33 additions & 21 deletions src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
@@ -30,6 +30,18 @@

namespace migraphx {

template <typename T>
struct acc_type
{
using type = float;
};

template <>
struct acc_type<double>
{
using type = double;
};

template <class T, index_int N, class Op>
constexpr auto vec_reduce(const array<T, N>& a, Op op)
{
@@ -50,33 +62,33 @@ __device__ void generic_binary_layernorm(
using reduce_output = reduce::with_axis<Input1, Axis>;

block::template run<reduce_output>([&](auto, auto r) {
auto input = r.inner([&](auto x1, auto x2) { return op(x1, x2); })(input1, input2);
using value_type = typename Input1::type;
using vec_value_type = vec_type<value_type>;
using value_type = typename Input1::type;
using vec_value_type = typename acc_type<vec_type<value_type>>::type;

auto input = r.inner([&](auto x1, auto x2) {
return migraphx::convert<vec_value_type>(op(x1, x2));
})(input1, input2);

constexpr auto relements = r.template elements<Input1>();
constexpr auto relements_r = vec_value_type{1.0 / relements};
auto relements_rsqrt = sqrt(relements_r);

auto means = r.reduce(op::sum{},
make_array<vec_value_type>(vec_value_type{0}, vec_value_type{0}),
[&](auto x) {
auto x_out = x * relements_r;
// dividing x by sqrt(relements) before squaring allows computing
// higher values before overflow in low precision
auto x2_sqrt = x * relements_rsqrt;
return make_array(x_out, x2_sqrt * x2_sqrt);
})(input);
auto means = r.reduce(op::sum{}, make_array<vec_value_type>(0, 0), [&](auto x) {
auto x_out = x * relements_r;
// dividing x by sqrt(relements) before squaring allows computing
// higher values before overflow in low precision
auto x2_sqrt = x * relements_rsqrt;
return make_array(x_out, x2_sqrt * x2_sqrt);
})(input);

auto mean_x = means[0];
auto mean_x2 = means[1];
auto variance = mean_x2 - (mean_x * mean_x);
value_type eps_val = implicit_conversion(eps);
auto mean_x = means[0];
auto mean_x2 = means[1];
auto variance = mean_x2 - (mean_x * mean_x);
vec_value_type eps_val = implicit_conversion(eps);
auto rsqrt_val = rsqrt(variance + eps_val);

r.inner([&](auto& y, auto x, auto... xs) {
auto m = x - mean_x;

// m * rsqrt(mean(m ^ 2) + epsilon)
y = compute(m * rsqrt(variance + eps_val), xs...);
y = compute(migraphx::convert<vec_type<value_type>>((x - mean_x) * rsqrt_val), xs...);
})(output, input, inputs...);
});
}
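The new acc_type trait makes the reduction accumulate in float for every input type except double. A standalone sketch of the scheme (plain C++, not the kernel code; the loop and names are illustrative assumptions):

#include <cmath>
#include <cstddef>

template <typename T> struct acc_type         { using type = float;  };
template <>           struct acc_type<double> { using type = double; };

// variance = mean(x^2) - mean(x)^2; x is scaled by 1/sqrt(n) before
// squaring so the squares stay small enough to avoid overflow.
template <typename T>
typename acc_type<T>::type variance(const T* x, std::size_t n)
{
    using acc_t = typename acc_type<T>::type;
    acc_t inv_n  = acc_t{1} / n;
    acc_t inv_sq = std::sqrt(inv_n);
    acc_t mean_x = 0, mean_x2 = 0;
    for(std::size_t i = 0; i < n; ++i)
    {
        acc_t xi = static_cast<acc_t>(x[i]); // accumulate in acc_t, not T
        mean_x += xi * inv_n;
        acc_t xs = xi * inv_sq;
        mean_x2 += xs * xs; // sum((x/sqrt(n))^2) == mean(x^2)
    }
    return mean_x2 - mean_x * mean_x;
}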
8 changes: 6 additions & 2 deletions src/targets/gpu/rocblas.cpp
@@ -48,7 +48,7 @@ rocblas_handle_ptr create_rocblas_handle_ptr(hipStream_t s)
rocblas_set_stream(rb.get(), s);
return rb;
}

#endif
bool get_compute_fp32_flag()
{
const auto device_name = trim(split_string(get_device_name(), ':').front());
@@ -57,13 +57,17 @@ bool get_compute_fp32_flag()

bool rocblas_fp8_available()
{
#if MIGRAPHX_USE_ROCBLAS
#ifndef MIGRAPHX_USE_ROCBLAS_FP8_API
return false;
#else
return gfx_has_fp8_intrinsics();
#endif
}
#else
return false;
#endif
}

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
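Because the interleaved fragments above are hard to read, this is what rocblas_fp8_available() looks like after the hunk, reconstructed from the added lines (indentation assumed):

bool rocblas_fp8_available()
{
#if MIGRAPHX_USE_ROCBLAS
#ifndef MIGRAPHX_USE_ROCBLAS_FP8_API
    return false;
#else
    return gfx_has_fp8_intrinsics();
#endif
#else
    return false;
#endif
}

With the whole body guarded, callers such as target.cpp below no longer need their own #if MIGRAPHX_USE_ROCBLAS around the call.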
2 changes: 0 additions & 2 deletions src/targets/gpu/target.cpp
@@ -97,13 +97,11 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
unsupported_types.erase(shape::type_t::tuple_type);
// whitelist supported ops for FP8
std::set<std::string> unsupported_fp8_ops = {};
#if MIGRAPHX_USE_ROCBLAS
if(not gpu::rocblas_fp8_available())
{
unsupported_fp8_ops.insert("dot");
unsupported_fp8_ops.insert("quant_dot");
}
#endif
// MIOpen doesn't have support for fp8 pooling yet.
unsupported_fp8_ops.insert("pooling");
if(not gpu::gfx_has_fp8_intrinsics())
26 changes: 21 additions & 5 deletions src/tf/tf.cpp
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
@@ -37,9 +37,9 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

program parse_tf(const std::string& name, const tf_options& options)
template <class... Ts>
program parse_tf_from(const tf_options& options, Ts&&... xs)
{
std::fstream input(name.c_str(), std::ios::in | std::ios::binary);
tf::tf_parser parser;
parser.is_nhwc = options.is_nhwc;
parser.batch_size = options.batch_size;
@@ -50,19 +50,35 @@ program parse_tf(const std::string& name, const tf_options& options)
// Log the program when it can't be parsed
try
{
parser.parse_from(input);
parser.parse_from(std::forward<Ts>(xs)...);
}
catch(...)
{
std::cerr << parser.prog << std::endl;
throw;
}
#else
parser.parse_from(input);
parser.parse_from(std::forward<Ts>(xs)...);
#endif
return std::move(parser.prog);
}

program parse_tf(const std::string& name, const tf_options& options)
{
std::fstream input(name.c_str(), std::ios::in | std::ios::binary);
return parse_tf_from(options, input);
}

program parse_tf_buffer(const std::string& buffer, const tf_options& options)
{
return parse_tf_from(options, buffer.data(), buffer.size());
}

program parse_tf_buffer(const void* data, std::size_t size, const tf_options& options)
{
return parse_tf_from(options, data, size);
}

std::vector<std::string> get_tf_operators() { return tf::get_op_parsers(); }

} // namespace MIGRAPHX_INLINE_NS
13 changes: 13 additions & 0 deletions src/tf/tf_parser.cpp
@@ -393,6 +393,19 @@ void tf_parser::parse_from(std::istream& is)
}
}

void tf_parser::parse_from(const void* data, std::size_t size)
{
tensorflow::GraphDef graph;
if(graph.ParseFromArray(data, size))
{
this->parse_graph(graph);
}
else
{
throw std::runtime_error("Failed reading tf buffer array");
}
}

shape::type_t tf_parser::parse_type(const tensorflow::DataType t) const
{
shape::type_t shape_type{};
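A short sketch of how the new overload behaves (assumed test scaffolding, not part of this diff):

// Round-trip: serialize a GraphDef, then parse it back from memory.
tensorflow::GraphDef graph;
std::string bytes;
graph.SerializeToString(&bytes); // protobuf MessageLite API

migraphx::tf::tf_parser parser;
parser.parse_from(bytes.data(), bytes.size());
// ParseFromArray returns false on a truncated or corrupt buffer, in
// which case parse_from throws
// std::runtime_error("Failed reading tf buffer array").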