-
Notifications
You must be signed in to change notification settings - Fork 11.6k
sycl : Implemented reorder Q4_K mmvq #13109
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -357,6 +357,31 @@ static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8 | |||||||||||||||||
} | ||||||||||||||||||
#endif | ||||||||||||||||||
|
||||||||||||||||||
template <typename dst_t> | ||||||||||||||||||
inline void dequantize_q4_K_common(dst_t * __restrict__ y, const uint8_t * __restrict__ qs_ptr, const float dall, | ||||||||||||||||||
const float dmin, uint8_t * __restrict__ scales_local, | ||||||||||||||||||
const sycl::nd_item<3> & item_ct1, int il, int ir) { | ||||||||||||||||||
const int is = 2 * il; | ||||||||||||||||||
const int n = 4; | ||||||||||||||||||
|
||||||||||||||||||
item_ct1.barrier(sycl::access::fence_space::local_space); | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Let's move this barrier outside this function, to just before the call site. |
||||||||||||||||||
|
||||||||||||||||||
uint8_t sc, m; | ||||||||||||||||||
get_scale_min_k4(is + 0, scales_local, sc, m); | ||||||||||||||||||
const float d1 = dall * sc; | ||||||||||||||||||
const float m1 = dmin * m; | ||||||||||||||||||
|
||||||||||||||||||
get_scale_min_k4(is + 1, scales_local, sc, m); | ||||||||||||||||||
const float d2 = dall * sc; | ||||||||||||||||||
const float m2 = dmin * m; | ||||||||||||||||||
|
||||||||||||||||||
sycl::vec<uint8_t, n> q_vec = vec_aligned_load<uint8_t, n>(qs_ptr + 32 * il + n * ir); | ||||||||||||||||||
for (int l = 0; l < n; ++l) { | ||||||||||||||||||
y[l + 0] = d1 * (q_vec[l] & 0xF) - m1; | ||||||||||||||||||
y[l + 32] = d2 * (q_vec[l] >> 4) - m2; | ||||||||||||||||||
} | ||||||||||||||||||
} | ||||||||||||||||||
|
||||||||||||||||||
template<typename dst_t> | ||||||||||||||||||
static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy, | ||||||||||||||||||
uint8_t* scales_local, const sycl::nd_item<3> &item_ct1) { | ||||||||||||||||||
|
@@ -365,36 +390,21 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restri | |||||||||||||||||
const int64_t i = item_ct1.get_group(2); | ||||||||||||||||||
|
||||||||||||||||||
#if QK_K == 256 | ||||||||||||||||||
// assume 32 threads | ||||||||||||||||||
const int64_t tid = item_ct1.get_local_id(2); | ||||||||||||||||||
const int64_t il = tid/8; | ||||||||||||||||||
const int64_t ir = tid%8; | ||||||||||||||||||
const int64_t is = 2*il; | ||||||||||||||||||
const int64_t n = 4; | ||||||||||||||||||
const int64_t il = tid / 8; | ||||||||||||||||||
const int64_t ir = tid % 8; | ||||||||||||||||||
|
||||||||||||||||||
dst_t * y = yy + i*QK_K + 64*il + n*ir; | ||||||||||||||||||
dst_t * y = yy + i * QK_K + 64 * il + 4 * ir; | ||||||||||||||||||
|
||||||||||||||||||
const sycl::half2 dm = x[i].dm; | ||||||||||||||||||
const float dall = dm[0]; | ||||||||||||||||||
const float dmin = dm[1]; | ||||||||||||||||||
|
||||||||||||||||||
if (tid < 12) | ||||||||||||||||||
if (tid < 12) { | ||||||||||||||||||
scales_local[tid] = x[i].scales[tid]; | ||||||||||||||||||
item_ct1.barrier(sycl::access::fence_space::local_space); | ||||||||||||||||||
|
||||||||||||||||||
uint8_t sc, m; | ||||||||||||||||||
get_scale_min_k4(is + 0, scales_local, sc, m); | ||||||||||||||||||
const float d1 = dall * sc; | ||||||||||||||||||
const float m1 = dmin * m; | ||||||||||||||||||
get_scale_min_k4(is + 1, scales_local, sc, m); | ||||||||||||||||||
const float d2 = dall * sc; | ||||||||||||||||||
const float m2 = dmin * m; | ||||||||||||||||||
|
||||||||||||||||||
sycl::vec<uint8_t, n> q_vec = vec_aligned_load<uint8_t, n>(x[i].qs + 32*il + n*ir); | ||||||||||||||||||
for (int l = 0; l < n; ++l) { | ||||||||||||||||||
y[l + 0] = d1 * (q_vec[l] & 0xF) - m1; | ||||||||||||||||||
y[l +32] = d2 * (q_vec[l] >> 4) - m2; | ||||||||||||||||||
} | ||||||||||||||||||
|
||||||||||||||||||
dequantize_q4_K_common(y, x[i].qs, dall, dmin, scales_local, item_ct1, il, ir); | ||||||||||||||||||
#else | ||||||||||||||||||
const int64_t tid = item_ct1.get_local_id(2); | ||||||||||||||||||
const uint8_t * q = x[i].qs; | ||||||||||||||||||
|
@@ -406,6 +416,35 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restri | |||||||||||||||||
#endif | ||||||||||||||||||
} | ||||||||||||||||||
|
||||||||||||||||||
template <typename dst_t> | ||||||||||||||||||
static void dequantize_block_q4_K_reorder(const void * __restrict__ vx, dst_t * __restrict__ yy, uint8_t * scales_local, | ||||||||||||||||||
const sycl::nd_item<3> & item_ct1, int64_t nb) { | ||||||||||||||||||
Comment on lines
+420
to
+421
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think we can remove the `__restrict__` qualifiers here. [inline code lost in extraction; inferred from the follow-up comment]
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Also, just one more comment to cover removing all other usages of restrict as well, instead of highlighting each one.
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It seems [remainder of this comment lost in extraction]
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Also, I think we can start using 1D kernel launches as we start adding kernels manually. 3-dimensional kernels are not a deliberate choice the backend makes, but rather a result of the direct translation of dim3 from the CUDA backend.
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Similarly for all other usages of the 3D nd_item as well. |
||||||||||||||||||
const int64_t i = item_ct1.get_group(2); | ||||||||||||||||||
const int64_t tid = item_ct1.get_local_id(2); | ||||||||||||||||||
const int64_t il = tid / 8; | ||||||||||||||||||
const int64_t ir = tid % 8; | ||||||||||||||||||
|
||||||||||||||||||
dst_t * y = yy + i * QK_K + 64 * il + 4 * ir; | ||||||||||||||||||
|
||||||||||||||||||
const uint8_t * base = static_cast<const uint8_t *>(vx); | ||||||||||||||||||
const size_t qs_offset = i * (QK_K / 2); | ||||||||||||||||||
const size_t scales_offset = nb * (QK_K / 2) + i * K_SCALE_SIZE; | ||||||||||||||||||
const size_t dm_offset = nb * (QK_K / 2) + nb * K_SCALE_SIZE + i * sizeof(ggml_half2); | ||||||||||||||||||
|
||||||||||||||||||
const uint8_t * qs_ptr = base + qs_offset; | ||||||||||||||||||
const uint8_t * scales_ptr = base + scales_offset; | ||||||||||||||||||
const ggml_half2 * dm_ptr = reinterpret_cast<const ggml_half2 *>(base + dm_offset); | ||||||||||||||||||
|
||||||||||||||||||
const float dall = dm_ptr->x(); | ||||||||||||||||||
const float dmin = dm_ptr->y(); | ||||||||||||||||||
Comment on lines
+436
to
+439
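There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change [the suggested code was lost in extraction]
You can fetch the data you need in one read (for example, load the whole ggml_half2 into a local value once and take x() and y() from that copy) and avoid making twice the number of trips to memory. |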
||||||||||||||||||
|
||||||||||||||||||
if (tid < 12) { | ||||||||||||||||||
scales_local[tid] = scales_ptr[tid]; | ||||||||||||||||||
} | ||||||||||||||||||
|
||||||||||||||||||
dequantize_q4_K_common(y, qs_ptr, dall, dmin, scales_local, item_ct1, il, ir); | ||||||||||||||||||
} | ||||||||||||||||||
|
||||||||||||||||||
template<typename dst_t> | ||||||||||||||||||
static void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy, | ||||||||||||||||||
const sycl::nd_item<3> &item_ct1) { | ||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.