Skip to content

Commit de71881

Browse files
committed
Remove duplicates
1 parent 603fdbf commit de71881

File tree

1 file changed

+33
-25
lines changed

1 file changed

+33
-25
lines changed

ggml/src/ggml-sycl/binbcast.cpp

+33-25
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
#include "binbcast.hpp"
2+
#include <cstddef>
3+
#include <cstdint>
24
#include <sycl/sycl.hpp>
35
#include "ggml.h"
46

@@ -85,15 +87,14 @@ static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t
8587
template<float (*bin_op)(const float, const float)>
8688
struct bin_bcast_sycl {
8789
template <typename src0_t, typename src1_t, typename dst_t>
88-
void operator()(ggml_backend_sycl_context & ctx,
89-
const struct ggml_tensor *src0,
90-
const struct ggml_tensor *src1, struct ggml_tensor *dst,
91-
const src0_t *src0_dd, const src1_t *src1_dd, dst_t *dst_dd,
92-
queue_ptr stream) {
93-
94-
GGML_TENSOR_BINARY_OP_LOCALS
95-
96-
int nr0 = ne10/ne0;
90+
void operator()(const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd, const int64_t ne00,
91+
const int64_t ne01, const int64_t ne02, const int64_t ne03, const int64_t ne10, const int64_t ne11,
92+
const int64_t ne12, const int64_t ne13, const int64_t ne0, const int64_t ne1, const int64_t ne2,
93+
const int64_t ne3, const size_t nb00, const size_t nb01, const size_t nb02, const size_t nb03,
94+
const size_t nb10, const size_t nb11, const size_t nb12, const size_t nb13, const size_t nb0,
95+
const size_t nb1, const size_t nb2, const size_t nb3, const bool src0_is_contiguos,
96+
const bool src1_is_contiguous, const bool dst_is_contiguous, queue_ptr stream) {
97+
int nr0 = ne10 / ne0;
9798
int nr1 = ne11/ne1;
9899
int nr2 = ne12/ne2;
99100
int nr3 = ne13/ne3;
@@ -120,7 +121,7 @@ struct bin_bcast_sycl {
120121
cnb[3] *= cne[3];
121122
};
122123

123-
if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
124+
if (src0_is_contiguos && src1_is_contiguous && dst_is_contiguous) {
124125
for (int i = 0; i < 4; i++) {
125126
if (nr[i] != 1) {
126127
break;
@@ -253,32 +254,39 @@ struct bin_bcast_sycl {
253254
});
254255
}
255256
}
256-
GGML_UNUSED(ctx);
257257
}
258258
};
259259

260260
template <class op>
261-
inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
262-
const ggml_tensor *src1, ggml_tensor *dst) {
261+
inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1,
262+
ggml_tensor * dst) {
263263
dpct::queue_ptr main_stream = ctx.stream();
264+
GGML_TENSOR_BINARY_OP_LOCALS
264265

265-
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
266-
op()(ctx, src0, src1, dst, (const float *)src0->data, (const float *)src1->data, (float *)dst->data, main_stream);
266+
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
267+
op()((const float *) src0->data, (const float *) src1->data, (float *) dst->data, ne00, ne01, ne02, ne03, ne10,
268+
ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2, nb3,
269+
ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst), main_stream);
267270
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
268-
op()(ctx, src0, src1, dst, (const sycl::half *)src0->data, (const sycl::half *)src1->data,
269-
(sycl::half *)dst->data, main_stream);
270-
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
271-
op()(ctx, src0, src1, dst, (const sycl::half *)src0->data, (const float *)src1->data, (sycl::half *)dst->data,
271+
op()((const sycl::half *) src0->data, (const sycl::half *) src1->data, (sycl::half *) dst->data, ne00, ne01,
272+
ne02, ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13,
273+
nb0, nb1, nb2, nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst),
272274
main_stream);
275+
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
276+
op()((const sycl::half *) src0->data, (const float *) src1->data, (sycl::half *) dst->data, ne00, ne01, ne02,
277+
ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1,
278+
nb2, nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst), main_stream);
273279
} else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) {
274-
op()(ctx, src0, src1, dst, (const int32_t *)src0->data, (const int32_t *)src1->data, (int32_t *)dst->data,
275-
main_stream);
280+
op()((const int32_t *) src0->data, (const int32_t *) src1->data, (int32_t *) dst->data, ne00, ne01, ne02, ne03,
281+
ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2,
282+
nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst), main_stream);
276283
} else if (src0->type == GGML_TYPE_I16 && src1->type == GGML_TYPE_I16 && dst->type == GGML_TYPE_I16) {
277-
op()(ctx, src0, src1, dst, (const int16_t *)src0->data, (const int16_t *)src1->data, (int16_t *)dst->data,
278-
main_stream);
284+
op()((const int16_t *) src0->data, (const int16_t *) src1->data, (int16_t *) dst->data, ne00, ne01, ne02, ne03,
285+
ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2,
286+
nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst), main_stream);
279287
} else {
280-
fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__,
281-
ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
288+
fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__, ggml_type_name(dst->type),
289+
ggml_type_name(src0->type), ggml_type_name(src1->type));
282290
GGML_ABORT("fatal error");
283291
}
284292
}

0 commit comments

Comments
 (0)