Skip to content

Commit f364e18

Browse files
[Matrix][SYCL] Rename wi_slice with wi_data (#5728)
1 parent 5023657 commit f364e18

File tree

3 files changed

+12
-12
lines changed

3 files changed

+12
-12
lines changed

sycl/doc/extensions/experimental/sycl_ext_oneapi_matrix.asciidoc

+4-4
Original file line numberDiff line numberDiff line change
@@ -548,19 +548,19 @@ While this provides fast element indexing on the GPU compared to the non-restric
548548
However using the `mma` ptx instructions as opposed to the `wmma` ptx instructions the mapping is known. Knowing this mapping is important for the user to implement new operations like sum of rows of a matrix for quantized algorithms.
549549

550550
#### proposal: Explicit conversion in the interface from SIMD to SPMD
551-
We introduce a new function `get_wi_slice` that provides any portion of the matrix that the user wants but in a SPMD array object:.
551+
We introduce a new function `get_wi_data` that provides any portion of the matrix that the user wants but in a SPMD array object:.
552552

553553
```c++
554554
namespace sycl::ext::oneapi::experimental::matrix {
555555
template <typename Group, typename T, size_t NumRows, size_t NumCols, matrix_layout L>
556-
marray<T, n_rows * n_cols> get_wi_slice(joint_matrix<T, NumRows, NumCols, L, Group> &m, size_t row_index,
556+
marray<T, n_rows * n_cols> get_wi_data(joint_matrix<T, NumRows, NumCols, L, Group> &m, size_t row_index,
557557
size_t col_index, size_t n_rows, size_t n_cols);
558558
}
559559
```
560560

561561
Example where each WI gets 1 column:
562562
```c++
563-
marray<T,msize> wi_C = get_wi_slice(C, 0, wi_idx, msize, 1, matrix_layout::row_major);
563+
marray<T,msize> wi_C = get_wi_data(C, 0, wi_idx, msize, 1, matrix_layout::row_major);
564564
for (int i = 0; i < msize; i++)
565565
row_sum += wi_C[i];
566566
```
@@ -582,7 +582,7 @@ We did not utilize this extension for this matrix API version because sub-group
582582
-- Yes, this will be addressed in the next revision where `use` argument will be introduced to distinguish between right (B) , left (A), and accumulator matrix.
583583
- Ronan Keryell: "It would be interesting to investigate whether providing also member functions would simplify the API. Provide both so it is possible to use the best one for each use case, while waiting for https://en.wikipedia.org/wiki/Uniform_Function_Call_Syntax to land into C++?"
584584

585-
- In the future looking APIs, `get_wi_slice` (that is currently under design) returns an owned object. Should this return a view object to make sure the original matrix C is changed after its slices are modified.
585+
- In the future looking APIs, `get_wi_data` (that is currently under design) returns an owned object. Should this return a view object to make sure the original matrix C is changed after its slices are modified.
586586

587587
## TODO List
588588
- Add support for fill matrix and element-wise operations features

sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp

+5-5
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ template <int D> struct spv_scope_traits<sycl::group<D>> {
4747
template <typename T, size_t NumRows, size_t NumCols,
4848
matrix_layout Layout = matrix_layout::row_major,
4949
typename Group = sycl::sub_group>
50-
class wi_slice;
50+
class wi_data;
5151

5252
template <typename T, size_t NumRows, size_t NumCols,
5353
matrix_layout Layout = matrix_layout::row_major,
@@ -64,9 +64,9 @@ struct joint_matrix {
6464
#endif // __SYCL_DEVICE_ONLY__
6565
}
6666

67-
inline __SYCL_ALWAYS_INLINE wi_slice<T, NumRows, NumCols, Layout, Group>
67+
inline __SYCL_ALWAYS_INLINE wi_data<T, NumRows, NumCols, Layout, Group>
6868
get_wi_data() {
69-
return wi_slice<T, NumRows, NumCols, Layout, Group>(*this);
69+
return wi_data<T, NumRows, NumCols, Layout, Group>(*this);
7070
}
7171
};
7272

@@ -455,11 +455,11 @@ class wi_element<uint16_t, NumRows, NumCols, Layout, Group> {
455455

456456
template <typename T, size_t NumRows, size_t NumCols, matrix_layout Layout,
457457
typename Group>
458-
class wi_slice {
458+
class wi_data {
459459
joint_matrix<T, NumRows, NumCols, Layout, Group> &M;
460460

461461
public:
462-
wi_slice(joint_matrix<T, NumRows, NumCols, Layout, Group> &Mat) : M(Mat) {}
462+
wi_data(joint_matrix<T, NumRows, NumCols, Layout, Group> &Mat) : M(Mat) {}
463463
size_t length() {
464464
#ifdef __SYCL_DEVICE_ONLY__
465465
return __spirv_JointMatrixWorkItemLengthINTEL(M.spvm);

sycl/test/matrix/matrix-elemwise-ops.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,9 @@ void matrix_multiply(big_matrix<T1, NUM_ROWS_C, NUM_COLS_C> &C,
8787
N * 4, matrix_layout::packed_b);
8888
sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
8989
}
90-
auto wi_slice_c = sub_c.get_wi_data();
91-
for (int i = 0; i < wi_slice_c.length(); i++) {
92-
wi_slice_c[i] *= 2;
90+
auto wi_data_c = sub_c.get_wi_data();
91+
for (int i = 0; i < wi_data_c.length(); i++) {
92+
wi_data_c[i] *= 2;
9393
}
9494
joint_matrix_store(sg, sub_c,
9595
accC.get_pointer() + (sg_startx * TM) * N +

0 commit comments

Comments
 (0)