Skip to content
This repository was archived by the owner on Jan 13, 2025. It is now read-only.

Fix joint_matrix implementation to match latest api #491

Merged
Changes from 1 commit
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
cdc3486
Fix joint_matrix_mad api call
muhammad-tanvir-1211 Dec 20, 2023
548321c
Fixed output for matrices with no corner cases
muhammad-tanvir-1211 Dec 26, 2023
65d720b
Added support for corner cases
muhammad-tanvir-1211 Dec 26, 2023
1f1b41c
Fixed the transpose implementation
muhammad-tanvir-1211 Dec 27, 2023
95b11bd
Fixed corner case output
muhammad-tanvir-1211 Dec 28, 2023
e85a2fa
Code working for gemm and gemm_batched
muhammad-tanvir-1211 Jan 1, 2024
6d659af
Unrolled the Global Memory write loops
muhammad-tanvir-1211 Jan 5, 2024
277d787
Reorder computation loops
muhammad-tanvir-1211 Jan 5, 2024
eca4634
Updated comments and removed redundant code
muhammad-tanvir-1211 Jan 8, 2024
37d26e1
Fixed build with release compiler
muhammad-tanvir-1211 Jan 10, 2024
9d2f540
Fix loop bounds
muhammad-tanvir-1211 Jan 16, 2024
5e8d8c6
Add functions for checking lower precision
pgorlani Jan 18, 2024
60fa20a
Fixed synchronization for double buffering
muhammad-tanvir-1211 Jan 25, 2024
ba4cd87
Restriced VectorSize to 1 for joint_matrix
muhammad-tanvir-1211 Jan 31, 2024
78af1da
Fixed race condition
muhammad-tanvir-1211 Jan 31, 2024
9767329
Fixed compilation error with bfloat16 type
muhammad-tanvir-1211 Jan 31, 2024
d1f348b
Increase error margins for trsm tests when joint_matrix is used
muhammad-tanvir-1211 Feb 1, 2024
0c0adaa
Address feedback
muhammad-tanvir-1211 Feb 1, 2024
23dff65
Added joint_matrix tests
muhammad-tanvir-1211 Feb 6, 2024
fe55302
Address test commit feedback
muhammad-tanvir-1211 Feb 22, 2024
50c150b
Reduced the initializer range
muhammad-tanvir-1211 Feb 22, 2024
6b87355
Merge branch 'master' into joint_matrix_fix
muhammad-tanvir-1211 Feb 27, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Restriced VectorSize to 1 for joint_matrix
* Updated the load/store file with proper vectorized load/store for future use.
muhammad-tanvir-1211 committed Feb 6, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
commit ba4cd8743346755febb009830169f9da4014ee05
27 changes: 24 additions & 3 deletions src/operations/blas3/gemm_load_store_joint_matrix.hpp
Original file line number Diff line number Diff line change
@@ -86,6 +86,7 @@ struct PacketizeJointMatrix {
*dest = round_to_tf32(val);
}
}

/*! @brief Performs a vectorised load using sycl::vec::load when the current
* block is internal. In the case where k < the
* number of elements being loaded then edge loads will be element wise with
@@ -114,6 +115,7 @@ struct PacketizeJointMatrix {
}
store<trans, ld>(packet, dest);
}

/*! @brief Store a vector packet into local memory when the source is
* transposed. This will untranspose the elements individually when storing so
* the data in local memory is always consistent.
@@ -156,16 +158,35 @@ struct PacketizeJointMatrix {
address_t::local_space>,
DestPointerType>::value) {
using dtype = cl::sycl::half;
*dest = static_cast<dtype>(packet[0]);
cl::sycl::vec<dtype, vector_size> new_vec{};
for (index_t i = 0; i < packet_size; i++) {
reinterpret_cast<dtype *>(&new_vec)[i] =
static_cast<dtype>(reinterpret_cast<value_t *>(&packet)[i]);
}
new_vec.template store<address_t::local_space>(
0, cl::sycl::multi_ptr<dtype, address_t::local_space>(dest));
} else if constexpr (std::is_same<cl::sycl::multi_ptr<
cl::sycl::ext::oneapi::bfloat16,
address_t::local_space>,
DestPointerType>::value) {
using dtype = cl::sycl::ext::oneapi::bfloat16;
*dest = static_cast<dtype>(packet[0]);
cl::sycl::vec<dtype, vector_size> new_vec;
for (index_t i = 0; i < packet_size; i++) {
reinterpret_cast<dtype *>(&new_vec)[i] =
static_cast<dtype>(reinterpret_cast<value_t *>(&packet)[i]);
}
new_vec.template store<address_t::local_space>(
0, cl::sycl::multi_ptr<dtype, address_t::local_space>(dest));
} else {
using namespace cl::sycl::ext::oneapi::experimental::matrix;
*dest = round_to_tf32(packet[0]);
using dtype = float;
cl::sycl::vec<dtype, vector_size> new_vec;
for (index_t i = 0; i < packet_size; i++) {
reinterpret_cast<dtype *>(&new_vec)[i] =
round_to_tf32(reinterpret_cast<value_t *>(&packet)[i]);
}
new_vec.template store<address_t::local_space>(
0, cl::sycl::multi_ptr<dtype, address_t::local_space>(dest));
}
}
};
41 changes: 18 additions & 23 deletions src/operations/blas3/gemm_local_joint_matrix.hpp
Original file line number Diff line number Diff line change
@@ -83,7 +83,6 @@ class Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize, TileType,
using value_t = element_t;
using index_t = typename std::make_signed<typename input_t::index_t>::type;
using packetize_t = PacketizeJointMatrix<VectorSize, value_t, index_t>;
using vector_t = typename packetize_t::PacketType;
using address_t = cl::sycl::access::address_space;

// enable easier access to tile dimensions
@@ -156,6 +155,9 @@ class Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize, TileType,
static_assert(std::is_same<value_t, float>::value,
"This code is only supported for float data type.");

static_assert(VectorSize == 1,
"Vectorization not supported for joint_matrix.");

//! @brief leading dimension of block of A in local
static constexpr index_t ldsa =
(trans_a ? cl_elems : block_rows) +
@@ -366,7 +368,6 @@ class Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize, TileType,
batch_size_);
}
} else {
using address_t = cl::sycl::access::address_space;
auto input_scratch = *reinterpret_cast<cl::sycl::multi_ptr<
typename tile_type::jmInpType, address_t::local_space> *>(&scratch);

@@ -721,25 +722,22 @@ class Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize, TileType,
index_t item_id, InputPointerType ptr, index_t ld,
ScratchPointerType scratch, RowPredicate in_row, ColPredicate in_col) {
constexpr index_t bs = rows * cols;
constexpr index_t multiplier = internal ? packetize_t::packet_size : 1;
constexpr index_t loop_iterations = (bs - 1) / (wg_size * multiplier) + 1;
constexpr index_t loop_iterations = (bs - 1) / wg_size + 1;
#pragma unroll
for (index_t i = 0; i < loop_iterations; ++i) {
if (!do_check<((bs % (wg_size * multiplier)) != 0)>(
item_id + i * (wg_size * multiplier) < bs))
if (!do_check<((bs % wg_size) != 0)>(item_id + i * wg_size < bs))
continue;
const index_t col_ofs = i * ((wg_size * multiplier) / rows);
const index_t col_ofs = i * (wg_size / rows);
const bool in_range =
do_check<check_row_limit>(
in_row(((item_id * multiplier) % rows), multiplier - 1)) &&
do_check<check_row_limit>(in_row((item_id % rows), 0)) &&
do_check<check_col_limit>(
in_col((item_id * multiplier / rows), col_ofs));
in_col((item_id / rows), col_ofs));

packetize_t::template load<trans, internal, lds>(
in_range, ptr + col_ofs * ld, scratch + col_ofs * lds,
[&](const index_t &ofs) {
return in_row((item_id * multiplier) % rows, ofs) &&
in_col((item_id * multiplier) / rows, col_ofs);
return in_row(item_id % rows, ofs) &&
in_col(item_id / rows, col_ofs);
});
}
}
@@ -751,24 +749,21 @@ class Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize, TileType,
index_t item_id, InputPointerType ptr, index_t ld,
ScratchPointerType scratch, RowPredicate in_row, ColPredicate in_col) {
constexpr index_t bs = rows * cols;
constexpr index_t multiplier = internal ? packetize_t::packet_size : 1;
constexpr index_t loop_iterations = (bs - 1) / (wg_size * multiplier) + 1;
constexpr index_t loop_iterations = (bs - 1) / wg_size + 1;
#pragma unroll
for (index_t i = 0; i < loop_iterations; ++i) {
if (!do_check<((bs % (wg_size * multiplier)) != 0)>(
item_id + i * (wg_size * multiplier) < bs))
if (!do_check<((bs % wg_size) != 0)>(item_id + i * wg_size < bs))
continue;
const index_t row_ofs = i * ((wg_size * multiplier) / cols);
const bool in_range = do_check<check_row_limit>(in_row(
(item_id * multiplier) / cols, row_ofs)) &&
do_check<check_col_limit>(in_col(
(item_id * multiplier) % cols, multiplier - 1));
const index_t row_ofs = i * (wg_size / cols);
const bool in_range =
do_check<check_row_limit>(in_row(item_id / cols, row_ofs)) &&
do_check<check_col_limit>(in_col(item_id % cols, 0));

packetize_t::template load<trans, internal, lds>(
in_range, ptr + row_ofs * ld, scratch + row_ofs * lds,
[&](const index_t &ofs) PORTBLAS_ALWAYS_INLINE {
return in_col((item_id * multiplier) % cols, ofs) &&
in_row((item_id * multiplier) / cols, row_ofs);
return in_col(item_id % cols, ofs) &&
in_row(item_id / cols, row_ofs);
});
}
}