[GPU] Use onednn gemm instead of inner product (#27628)
SD3 with dynamic shapes hits an FC unfusion pattern that forces execution as sub-graphs, which caused poor performance. Using the oneDNN gemm (matmul) primitive instead of inner product avoids the unfusion.


### Tickets:
 - *152851, 157100*

---------

Signed-off-by: hyunback <[email protected]>
hyunback authored Dec 31, 2024
1 parent 517ad68 commit ca501ca
Showing 6 changed files with 91 additions and 158 deletions.
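
For context, the change swaps the oneDNN primitive behind the fully connected implementation. Below is a minimal standalone sketch (illustrative shapes, a CPU engine for easy testing; not code from this patch) of the two descriptor constructions: inner product takes weights as {OC, IC} and a flattened 2-D activation view, which is what previously forced reshapes and the unfusion pattern for 3-D inputs, while matmul takes {K, N} weights and accepts batched shapes directly.

```cpp
#include <dnnl.hpp>  // oneDNN v3.x C++ API

int main() {
    using dt  = dnnl::memory::data_type;
    using tag = dnnl::memory::format_tag;
    dnnl::engine eng(dnnl::engine::kind::cpu, 0);  // CPU just for illustration
    const dnnl::memory::dim M = 2, K = 16, N = 32; // illustrative FC sizes

    // inner product: src {MB, IC}, weights {OC, IC}, dst {MB, OC} -- a 3-D
    // activation must first be flattened to 2-D.
    dnnl::inner_product_forward::primitive_desc ip_pd(
        eng, dnnl::prop_kind::forward_inference,
        dnnl::memory::desc({M, K}, dt::f32, tag::ab),
        dnnl::memory::desc({N, K}, dt::f32, tag::any),
        dnnl::memory::desc({M, N}, dt::f32, tag::ab));

    // matmul (gemm): src {M, K}, weights {K, N}, dst {M, N}; batched 3-D
    // shapes are also supported directly, so no reshape/unfusion is needed.
    dnnl::matmul::primitive_desc mm_pd(
        eng,
        dnnl::memory::desc({M, K}, dt::f32, tag::ab),
        dnnl::memory::desc({K, N}, dt::f32, tag::any),
        dnnl::memory::desc({M, N}, dt::f32, tag::ab));
    return 0;
}
```
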
14 changes: 12 additions & 2 deletions src/plugins/intel_gpu/src/graph/graph_optimizer/reorder_inputs.cpp
@@ -1028,6 +1028,15 @@ void reorder_inputs::run(program& p, reorder_factory& rf) {
if (fc_layout.is_dynamic() || data_layout.is_dynamic())
continue;

auto same_spatial = [](layout a, layout b) {
if (a.get_spatial_rank() != b.get_spatial_rank())
return false;
for (size_t i = 0; i < a.get_spatial_rank(); i++) {
if (a.spatial(i) != b.spatial(i))
return false;
}
return true;
};
// fc_b | fc_f | data_b | data_f | broadcast condition
// ---------+-----------+-----------+-----------+--------------------
// 1 | 1 | 1 | 1 | no broadcast
@@ -1043,11 +1052,12 @@ void reorder_inputs::run(program& p, reorder_factory& rf) {
// N | 1 | N | 1 | no broadcast
// N | 1 | N | N | N/A
// N | N | 1 | 1 | implicit broadcast
// N | N | 1 | N | explicit broadcast
// N | N | N | 1 | explicit broadcast
// N | N | 1 | N | explicit broadcast when spatial different
// N | N | N | 1 | explicit broadcast when spatial different
// N | N | N | N | no broadcast
if ((fc_layout.batch() == 1 || fc_layout.feature() == 1) ||
(data_layout.batch() == 1 && data_layout.feature() == 1) ||
((data_layout.batch() == 1 || data_layout.feature() == 1) && same_spatial(fc_layout, data_layout)) ||
(fc_layout.count() == data_layout.count())) {
continue;
}
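
A condensed, self-contained sketch of the check this hunk implements (the same predicate is mirrored at runtime in primitive_inst.cpp below); the shape struct is a hypothetical stand-in for cldnn::layout:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical minimal stand-in for cldnn::layout, just enough for the check.
struct Shape {
    long batch, feature;
    std::vector<long> spatial;
    long count() const {
        long c = batch * feature;
        for (long s : spatial) c *= s;
        return c;
    }
};

// True when the eltwise data operand can be fused into the FC without an
// explicit broadcast reorder (mirrors the table in the hunk above).
bool fusable_without_explicit_broadcast(const Shape& fc, const Shape& data) {
    auto same_spatial = [](const Shape& a, const Shape& b) {
        if (a.spatial.size() != b.spatial.size()) return false;
        for (size_t i = 0; i < a.spatial.size(); i++)
            if (a.spatial[i] != b.spatial[i]) return false;
        return true;
    };
    return (fc.batch == 1 || fc.feature == 1)                     // trivially broadcastable FC side
        || (data.batch == 1 && data.feature == 1)                 // scalar-like data: implicit broadcast
        || ((data.batch == 1 || data.feature == 1)
            && same_spatial(fc, data))                            // new: spatial dims must match
        || (fc.count() == data.count());                          // same element count: no broadcast
}

int main() {
    Shape fc   {5, 3, {3, 1}};
    Shape good {5, 1, {3, 1}};  // feature == 1 and same spatial -> fusable
    Shape bad  {5, 1, {7, 1}};  // feature == 1 but spatial differs -> needs explicit broadcast
    assert(fusable_without_explicit_broadcast(fc, good));
    assert(!fusable_without_explicit_broadcast(fc, bad));
    return 0;
}
```
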
@@ -98,44 +98,6 @@ struct fully_connected_onednn : typed_primitive_onednn_impl<fully_connected> {
return args;
}

static std::shared_ptr<WeightsReorderParams> get_weights_reorder(const kernel_impl_params& impl_params, const dnnl::primitive_desc& pd) {
auto input_layout = impl_params.get_input_layout(0);
auto source_weights_layout = impl_params.get_input_layout(1);
auto cldnn_prim = impl_params.typed_desc<fully_connected>();

auto input_pshape = input_layout.get_partial_shape();
auto weights_pshape = source_weights_layout.get_partial_shape();

int64_t feature = input_pshape[std::min(cldnn_prim->input_size, static_cast<size_t>(4)) - 1].get_length();
if (cldnn_prim->input_size == 3) {
feature = std::max({input_layout.spatial(0), input_layout.spatial(1), input_layout.spatial(2)});
}
auto target_weights_layout = source_weights_layout;
if (weights_pshape.size() != 2) {
target_weights_layout.set_partial_shape(reshape_to_2d(weights_pshape, feature));
}

auto target_weights_desc = pd.weights_desc(0);

auto shape_consistent = onednn::keep_weights_reorder_shape_consistent(source_weights_layout, target_weights_desc);
OPENVINO_ASSERT(shape_consistent, "[GPU] Input shape and output shape of weight reorder should be same.");

auto source_weights_desc = onednn::layout_to_memory_desc(source_weights_layout);

const bool weights_format = true;
const bool grouped = false;

auto traits = convert_memory_desc_to_traits(target_weights_desc, weights_format, grouped);

target_weights_layout.format = format(traits);

return std::make_shared<WeightsReorderParamsOneDNN>(source_weights_layout,
target_weights_layout,
source_weights_desc,
target_weights_desc,
false);
}

static void transform_layouts(layout& input_layout, layout& weights_layout, layout& output_layout, size_t prim_input_size) {
auto input_pshape = input_layout.get_partial_shape();
auto weights_pshape = weights_layout.get_partial_shape();
@@ -164,43 +126,6 @@ struct fully_connected_onednn : typed_primitive_onednn_impl<fully_connected> {
}
}

static std::shared_ptr<dnnl::inner_product_forward::primitive_desc>
get_inner_product_primitive_descriptor(const kernel_impl_params& impl_params,
cldnn::engine& engine,
size_t prim_input_size,
bool has_bias,
const dnnl::primitive_attr& attr = dnnl::primitive_attr()) {
auto input_layout = impl_params.get_input_layout(0);
auto weights_layout = impl_params.get_input_layout(1);
auto output_layout = impl_params.get_output_layout();

transform_layouts(input_layout, weights_layout, output_layout, prim_input_size);

auto input_md = onednn::layout_to_memory_desc(input_layout, dnnl::memory::format_tag::undef, false);
auto weights_md = onednn::layout_to_memory_desc(weights_layout, dnnl::memory::format_tag::any);
auto output_md = onednn::layout_to_memory_desc(output_layout, dnnl::memory::format_tag::ab, false);

if (has_bias) {
auto bias_md = onednn::layout_to_memory_desc(impl_params.get_input_layout(2), dnnl::memory::format_tag::any, true);
return std::make_shared<dnnl::inner_product_forward::primitive_desc>(
engine.get_onednn_engine(),
dnnl::prop_kind::forward_inference,
input_md,
weights_md,
bias_md,
output_md,
attr);
} else {
return std::make_shared<dnnl::inner_product_forward::primitive_desc>(
engine.get_onednn_engine(),
dnnl::prop_kind::forward_inference,
input_md,
weights_md,
output_md,
attr);
}
}

static std::shared_ptr<dnnl::matmul::primitive_desc>
get_matmul_primitive_descriptor(const kernel_impl_params& impl_params,
cldnn::engine& engine,
@@ -219,7 +144,11 @@ struct fully_connected_onednn : typed_primitive_onednn_impl<fully_connected> {
auto output_md = onednn::layout_to_memory_desc(output_layout, dnnl::memory::format_tag::ab, false);

if (has_bias) {
auto bias_md = onednn::layout_to_memory_desc(impl_params.get_input_layout(2), dnnl::memory::format_tag::ab, false);
dnnl::memory::format_tag target_fmt = dnnl::memory::format_tag::ab;
auto bias_l = impl_params.get_input_layout(2);
if (bias_l.get_shape().size() == 1)
target_fmt = dnnl::memory::format_tag::ba;
auto bias_md = onednn::layout_to_memory_desc(impl_params.get_input_layout(2), target_fmt, false);
return std::make_shared<dnnl::matmul::primitive_desc>(
engine.get_onednn_engine(),
input_md,
@@ -335,13 +264,8 @@ struct fully_connected_onednn : typed_primitive_onednn_impl<fully_connected> {
_attrs->set_zero_points(DNNL_ARG_SRC, GROUPED, dnnl::memory::dims{1, src_group_size}, dnnl::memory::data_type::u8);
}

if (is_compressed) {
auto prim_desc = get_matmul_primitive_descriptor(*impl_params, ib.get_engine(), input_size, has_bias, *_attrs);
_pd = *prim_desc;
} else {
auto prim_desc = get_inner_product_primitive_descriptor(*impl_params, ib.get_engine(), input_size, has_bias, *_attrs);
_pd = *prim_desc;
}
auto prim_desc = get_matmul_primitive_descriptor(*impl_params, ib.get_engine(), input_size, has_bias, *_attrs);
_pd = *prim_desc;

std::vector<uint8_t> prim_cache;
ib >> prim_cache;
@@ -426,10 +350,10 @@ struct fully_connected_onednn : typed_primitive_onednn_impl<fully_connected> {
prim_onednn->_dzp_data_type = dzp_data_type;
return prim_onednn;
} else {
auto prim_desc = get_inner_product_primitive_descriptor(impl_params, impl_params.prog->get_engine(),
prim->input_size, !prim->bias.empty(), *attr);
auto prim_desc = get_matmul_primitive_descriptor(impl_params, impl_params.prog->get_engine(),
prim->input_size, !prim->bias.empty(), *attr);

return cldnn::make_unique<fully_connected_onednn>(engine, config, attr, *prim_desc, get_weights_reorder(impl_params, *prim_desc));
return cldnn::make_unique<fully_connected_onednn>(engine, config, attr, *prim_desc);
}
}
};
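
The bias tweak above matters because oneDNN matmul expects a bias descriptor of the same rank as the destination, broadcastable over M (i.e. {1, N} for a 2-D output); picking format_tag::ba for a rank-1 cldnn bias presumably makes layout_to_memory_desc produce exactly that {1, N} mapping instead of {N, 1}. A minimal sketch of the equivalent raw-oneDNN construction, with assumed f16 types and illustrative shapes (not the plugin's exact code):

```cpp
#include <dnnl.hpp>
#include <memory>

// Sketch of an FC expressed as oneDNN matmul with optional bias (assumed
// shapes/types; the plugin derives these from cldnn layouts instead).
std::shared_ptr<dnnl::matmul::primitive_desc>
make_fc_as_matmul(dnnl::engine& eng,
                  dnnl::memory::dim M, dnnl::memory::dim K, dnnl::memory::dim N,
                  bool has_bias,
                  const dnnl::primitive_attr& attr = dnnl::primitive_attr()) {
    using dt  = dnnl::memory::data_type;
    using tag = dnnl::memory::format_tag;
    auto src = dnnl::memory::desc({M, K}, dt::f16, tag::ab);
    auto wei = dnnl::memory::desc({K, N}, dt::f16, tag::any); // let oneDNN pick the weights layout
    auto dst = dnnl::memory::desc({M, N}, dt::f16, tag::ab);
    if (has_bias) {
        // Bias must be dst-rank and broadcastable over M: {1, N}, dense in N.
        auto bias = dnnl::memory::desc({1, N}, dt::f16, tag::ab);
        return std::make_shared<dnnl::matmul::primitive_desc>(eng, src, wei, bias, dst, attr);
    }
    return std::make_shared<dnnl::matmul::primitive_desc>(eng, src, wei, dst, attr);
}
```
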
12 changes: 11 additions & 1 deletion src/plugins/intel_gpu/src/graph/primitive_inst.cpp
@@ -2671,7 +2671,6 @@ bool primitive_inst::is_valid_fusion() const {
auto gemm_dims = onednn::convert_gemm_tensor(gemm_layout.get_tensor(),
cldnn::format::dimension(gemm_layout.format),
false);

auto data_dims = onednn::convert_gemm_tensor(data_layout.get_tensor(),
cldnn::format::dimension(data_layout.format),
false);
@@ -2685,8 +2684,19 @@ bool primitive_inst::is_valid_fusion() const {
const auto fc_dims = fc_layout.get_dims();
const auto data_dims = data_layout.get_dims();

auto same_spatial = [](layout a, layout b) {
if (a.get_spatial_rank() != b.get_spatial_rank())
return false;
for (size_t i = 0; i < a.get_spatial_rank(); i++) {
if (a.spatial(i) != b.spatial(i))
return false;
}
return true;
};

if (!(fc_dims[0] == 1 || fc_dims[1] == 1) &&
!(data_dims[0] == 1 && data_dims[1] == 1) &&
!((data_dims[0] == 1 || data_dims[1] == 1) && same_spatial(fc_layout, data_layout)) &&
!(fc_layout.count() == data_layout.count())) {
return false;
}
72 changes: 52 additions & 20 deletions src/plugins/intel_gpu/src/graph/program_node.cpp
@@ -1558,30 +1558,62 @@ void program_node::create_onednn_primitive_attributes(
} else if (desc.is_type<eltwise>()) {
auto dep_idx = desc.outer_dep_start_idx;
auto in = get_input_layout(dep_idx);
auto in_origin = in;
auto set_binary_op = [&](dnnl::algorithm alg, onednn_post_op_type op_type) {
if (is_type<fully_connected>()) {
auto prim = this->as<fully_connected>().get_primitive();
if (prim->input_size == 3) {
cldnn::onednn::combine_bf_with_first_spatial_dim(in);
auto fc_needs_full_tensor = [&]() {
for (size_t i = 0; i < cldnn_post_ops.size(); i++) {
auto& desc = cldnn_post_ops[i];
if (desc.is_type<eltwise>()) {
auto prim = this->as<fully_connected>().get_primitive();
auto dep_idx = desc.outer_dep_start_idx;
auto in = get_input_layout(dep_idx);
if (prim->input_size == 3 && in.batch() > 1 && in.feature() > 1)
return true;
}
auto mem_desc = onednn::layout_to_memory_desc(in, dnnl::memory::format_tag::ab);
post_ops.append_binary(alg, mem_desc);
update_onednn_post_op_list(op_type, dep_idx, dnnl::memory::format_tag::ab, false,
mem_desc.get_dims(), mem_desc.get_data_type());
} else if (is_type<gemm>()) {
}
return false;
};
auto set_binary_op = [&](dnnl::algorithm alg, onednn_post_op_type op_type) {
if (is_type<fully_connected>() || is_type<gemm>()) {
size_t rank = cldnn::format::dimension(in.format);
auto in_pshape = in.get_partial_shape();
auto out_pshape = get_output_layout().get_partial_shape();
size_t ones_to_add = std::max(out_pshape.size(), static_cast<size_t>(rank)) - in_pshape.size();
if (ones_to_add > 0) {
layout new_layout = in;
ov::PartialShape new_input_pshape;
std::vector<ov::Dimension> dims(in_pshape.begin(), in_pshape.begin() + in_pshape.size());
new_input_pshape = ov::PartialShape(dims);
new_input_pshape.insert(new_input_pshape.begin(), ones_to_add, 1ul);
new_layout.set_partial_shape(new_input_pshape);
in = new_layout;
size_t ones_to_add = 0;

if (is_type<fully_connected>()) {
auto prim = this->as<fully_connected>().get_primitive();
if (prim->input_size == in_pshape.size()) {
if (prim->input_size == 3 && !fc_needs_full_tensor()) {
cldnn::onednn::combine_bf_with_first_spatial_dim(in);
in_pshape = in.get_partial_shape();
}
ones_to_add = std::max(out_pshape.size(), static_cast<size_t>(rank)) - in_pshape.size();
} else {
if (prim->input_size == 3)
cldnn::onednn::combine_bf_with_first_spatial_dim(in);
ones_to_add = std::max(in_pshape.size(), prim->input_size) - std::min(in_pshape.size(), prim->input_size);
}
if (ones_to_add > 0) {
layout new_layout = in;
ov::PartialShape new_input_pshape;
auto last = in_pshape.begin() + in_pshape.size();
if (in_pshape.size() > prim->input_size)
last -= ones_to_add;
std::vector<ov::Dimension> dims(in_pshape.begin(), last);
new_input_pshape = ov::PartialShape(dims);
new_input_pshape.insert(new_input_pshape.begin(), ones_to_add, 1ul);
new_layout.set_partial_shape(new_input_pshape);
in = new_layout;
}
} else {
ones_to_add = std::max(out_pshape.size(), static_cast<size_t>(rank)) - in_pshape.size();
if (ones_to_add > 0) {
layout new_layout = in;
ov::PartialShape new_input_pshape;
std::vector<ov::Dimension> dims(in_pshape.begin(), in_pshape.begin() + in_pshape.size());
new_input_pshape = ov::PartialShape(dims);
new_input_pshape.insert(new_input_pshape.begin(), ones_to_add, 1ul);
new_layout.set_partial_shape(new_input_pshape);
in = new_layout;
}
}
size_t in_batched_size = in.count() / (in.spatial(0) * in.spatial(1));
dnnl::memory::dims dims = onednn::convert_gemm_tensor(in.get_tensor(), rank, in_batched_size == 1);
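
The restructured lambda above boils down to padding the binary post-op operand's shape with leading 1s until its rank matches the output, optionally folding dims first for the 3-D FC case. A plain-vector sketch of that alignment, with the cldnn/ov types swapped out for brevity:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Prepend leading 1s so a binary post-op operand matches the target rank,
// since the oneDNN binary post-op requires rank-aligned dims.
std::vector<long> align_rank(std::vector<long> dims, size_t target_rank) {
    if (dims.size() < target_rank)
        dims.insert(dims.begin(), target_rank - dims.size(), 1);
    return dims;
}

int main() {
    // A {32} operand broadcast against a {2, 32} FC output becomes {1, 32}.
    assert(align_rank({32}, 2) == (std::vector<long>{1, 32}));
    // For a 3-D FC, the plugin may first fold the operand toward a 2-D view
    // (cldnn::onednn::combine_bf_with_first_spatial_dim) when no post-op
    // needs the full tensor, then rank-align as above.
    assert(align_rank({5, 3}, 4) == (std::vector<long>{1, 1, 5, 3}));
    return 0;
}
```
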
@@ -162,13 +162,13 @@ class FullyConnectedFusingTestOneDNN : public BaseFusingTest<fully_connected_tes
#define CASE_FC_FP32_1 { 1, 3 }, { 1, 4 }, { 4, 3 }, data_types::f32, format::bfyx, data_types::f32, format::oiyx, data_types::f32, format::bfyx
#define CASE_FC_FP32_2 { 2, 3 }, { 2, 4 }, { 4, 3 }, data_types::f32, format::yxfb, data_types::f32, format::oiyx, data_types::f32, format::bfyx
#define CASE_FC_FP32_3 { 2, 32 }, { 2, 16 }, { 16, 32 }, data_types::f32, format::bfyx, data_types::i8, format::oiyx, data_types::f32, format::bfyx
#define CASE_FC_FP32_3D_1 { 5, 3, 3 }, { 5, 3, 5 }, { 5, 3, 1 }, data_types::f32, format::bfyx, data_types::f32, format::os_iyx_osv16, data_types::f32, format::bfyx
#define CASE_FC_FP32_3D_2 { 2, 1, 1 }, { 2, 1, 32 }, { 32, 1, 1 }, data_types::f32, format::bfyx, data_types::f32, format::os_iyx_osv16, data_types::f32, format::bfyx
#define CASE_FC_FP32_3D_3 { 2, 32, 32 }, { 2, 32, 16 }, { 16, 32, 1 }, data_types::f32, format::bfyx, data_types::f32, format::os_iyx_osv16, data_types::f32, format::bfyx
#define CASE_FC_FP32_3D_1 { 5, 3, 3 }, { 5, 3, 5 }, { 5, 3, 1 }, data_types::f32, format::bfyx, data_types::f32, format::oiyx, data_types::f32, format::bfyx
#define CASE_FC_FP32_3D_2 { 2, 1, 1 }, { 2, 1, 32 }, { 32, 1, 1 }, data_types::f32, format::bfyx, data_types::f32, format::oiyx, data_types::f32, format::bfyx
#define CASE_FC_FP32_3D_3 { 2, 32, 32 }, { 2, 32, 16 }, { 16, 32, 1 }, data_types::f32, format::bfyx, data_types::f32, format::oiyx, data_types::f32, format::bfyx

#define DYN_CASE_FC_FP32_3D_1 { 5, 3, 3 }, { 5, 3, 5 }, { 5, 3 }, data_types::f32, format::bfyx, data_types::f32, format::os_iyx_osv16, data_types::f32, format::bfyx
#define DYN_CASE_FC_FP32_3D_2 { 2, 1, 1 }, { 2, 1, 32 }, { 32, 1 }, data_types::f32, format::bfyx, data_types::f32, format::os_iyx_osv16, data_types::f32, format::bfyx
#define DYN_CASE_FC_FP32_3D_3 { 2, 32, 32 }, { 2, 32, 16 }, { 16, 32 }, data_types::f32, format::bfyx, data_types::f32, format::os_iyx_osv16, data_types::f32, format::bfyx
#define DYN_CASE_FC_FP32_3D_1 { 5, 3, 3 }, { 5, 3, 5 }, { 5, 3 }, data_types::f32, format::bfyx, data_types::f32, format::oiyx, data_types::f32, format::bfyx
#define DYN_CASE_FC_FP32_3D_2 { 2, 1, 1 }, { 2, 1, 32 }, { 32, 1 }, data_types::f32, format::bfyx, data_types::f32, format::oiyx, data_types::f32, format::bfyx
#define DYN_CASE_FC_FP32_3D_3 { 2, 32, 32 }, { 2, 32, 16 }, { 16, 32 }, data_types::f32, format::bfyx, data_types::f32, format::oiyx, data_types::f32, format::bfyx

#define CASE_FC_U8S8_1 { 1, 3 }, { 1, 4 }, { 4, 3 }, data_types::u8, format::bfyx, data_types::i8, format::oiyx, data_types::f32, format::bfyx
#define CASE_FC_U8S8_2 { 2, 3 }, { 2, 4 }, { 4, 3 }, data_types::u8, format::b_fs_yx_fsv4, data_types::i8, format::oiyx, data_types::f32, format::bfyx