1
1
#include " binbcast.hpp"
2
+ #include < cstddef>
3
+ #include < cstdint>
2
4
#include < sycl/sycl.hpp>
3
5
#include " ggml.h"
4
6
@@ -85,15 +87,14 @@ static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t
85
87
template <float (*bin_op)(const float , const float )>
86
88
struct bin_bcast_sycl {
87
89
template <typename src0_t , typename src1_t , typename dst_t >
88
- void operator ()(ggml_backend_sycl_context & ctx,
89
- const struct ggml_tensor *src0,
90
- const struct ggml_tensor *src1, struct ggml_tensor *dst,
91
- const src0_t *src0_dd, const src1_t *src1_dd, dst_t *dst_dd,
92
- queue_ptr stream) {
93
-
94
- GGML_TENSOR_BINARY_OP_LOCALS
95
-
96
- int nr0 = ne10/ne0;
90
+ void operator ()(const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd, const int64_t ne00,
91
+ const int64_t ne01, const int64_t ne02, const int64_t ne03, const int64_t ne10, const int64_t ne11,
92
+ const int64_t ne12, const int64_t ne13, const int64_t ne0, const int64_t ne1, const int64_t ne2,
93
+ const int64_t ne3, const size_t nb00, const size_t nb01, const size_t nb02, const size_t nb03,
94
+ const size_t nb10, const size_t nb11, const size_t nb12, const size_t nb13, const size_t nb0,
95
+ const size_t nb1, const size_t nb2, const size_t nb3, const bool src0_is_contiguos,
96
+ const bool src1_is_contiguous, const bool dst_is_contiguous, queue_ptr stream) {
97
+ int nr0 = ne10 / ne0;
97
98
int nr1 = ne11/ne1;
98
99
int nr2 = ne12/ne2;
99
100
int nr3 = ne13/ne3;
@@ -120,7 +121,7 @@ struct bin_bcast_sycl {
120
121
cnb[3 ] *= cne[3 ];
121
122
};
122
123
123
- if (ggml_is_contiguous (src0) && ggml_is_contiguous (src1) && ggml_is_contiguous (dst) ) {
124
+ if (src0_is_contiguos && src1_is_contiguous && dst_is_contiguous ) {
124
125
for (int i = 0 ; i < 4 ; i++) {
125
126
if (nr[i] != 1 ) {
126
127
break ;
@@ -253,32 +254,39 @@ struct bin_bcast_sycl {
253
254
});
254
255
}
255
256
}
256
- GGML_UNUSED (ctx);
257
257
}
258
258
};
259
259
260
260
template <class op >
261
- inline void ggml_sycl_op_bin_bcast (ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
262
- const ggml_tensor *src1, ggml_tensor * dst) {
261
+ inline void ggml_sycl_op_bin_bcast (ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1 ,
262
+ ggml_tensor * dst) {
263
263
dpct::queue_ptr main_stream = ctx.stream ();
264
+ GGML_TENSOR_BINARY_OP_LOCALS
264
265
265
- if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
266
- op ()(ctx, src0, src1, dst, (const float *)src0->data , (const float *)src1->data , (float *)dst->data , main_stream);
266
+ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
267
+ op ()((const float *) src0->data , (const float *) src1->data , (float *) dst->data , ne00, ne01, ne02, ne03, ne10,
268
+ ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2, nb3,
269
+ ggml_is_contiguous (src0), ggml_is_contiguous (src1), ggml_is_contiguous (dst), main_stream);
267
270
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
268
- op ()(ctx, src0, src1, dst, (const sycl::half *)src0->data , (const sycl::half *)src1->data ,
269
- (sycl::half *)dst->data , main_stream);
270
- } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
271
- op ()(ctx, src0, src1, dst, (const sycl::half *)src0->data , (const float *)src1->data , (sycl::half *)dst->data ,
271
+ op ()((const sycl::half *) src0->data , (const sycl::half *) src1->data , (sycl::half *) dst->data , ne00, ne01,
272
+ ne02, ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13,
273
+ nb0, nb1, nb2, nb3, ggml_is_contiguous (src0), ggml_is_contiguous (src1), ggml_is_contiguous (dst),
272
274
main_stream);
275
+ } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
276
+ op ()((const sycl::half *) src0->data , (const float *) src1->data , (sycl::half *) dst->data , ne00, ne01, ne02,
277
+ ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1,
278
+ nb2, nb3, ggml_is_contiguous (src0), ggml_is_contiguous (src1), ggml_is_contiguous (dst), main_stream);
273
279
} else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) {
274
- op ()(ctx, src0, src1, dst, (const int32_t *)src0->data , (const int32_t *)src1->data , (int32_t *)dst->data ,
275
- main_stream);
280
+ op ()((const int32_t *) src0->data , (const int32_t *) src1->data , (int32_t *) dst->data , ne00, ne01, ne02, ne03,
281
+ ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2,
282
+ nb3, ggml_is_contiguous (src0), ggml_is_contiguous (src1), ggml_is_contiguous (dst), main_stream);
276
283
} else if (src0->type == GGML_TYPE_I16 && src1->type == GGML_TYPE_I16 && dst->type == GGML_TYPE_I16) {
277
- op ()(ctx, src0, src1, dst, (const int16_t *)src0->data , (const int16_t *)src1->data , (int16_t *)dst->data ,
278
- main_stream);
284
+ op ()((const int16_t *) src0->data , (const int16_t *) src1->data , (int16_t *) dst->data , ne00, ne01, ne02, ne03,
285
+ ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2,
286
+ nb3, ggml_is_contiguous (src0), ggml_is_contiguous (src1), ggml_is_contiguous (dst), main_stream);
279
287
} else {
280
- fprintf (stderr, " %s: unsupported types: dst: %s, src0: %s, src1: %s\n " , __func__,
281
- ggml_type_name (dst-> type ), ggml_type_name (src0->type ), ggml_type_name (src1->type ));
288
+ fprintf (stderr, " %s: unsupported types: dst: %s, src0: %s, src1: %s\n " , __func__, ggml_type_name (dst-> type ),
289
+ ggml_type_name (src0->type ), ggml_type_name (src1->type ));
282
290
GGML_ABORT (" fatal error" );
283
291
}
284
292
}
0 commit comments