Commit

idiomatic variable names + memset for dst->data
PABannier committed Oct 22, 2023
1 parent 9d10023 commit d5b778d
Showing 1 changed file with 33 additions and 25 deletions.
58 changes: 33 additions & 25 deletions src/ggml.c
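
Reader's note (not part of the commit): in the diff below, K = ne00 is the kernel length, OC = ne01 the number of output channels, IC = ne02 the number of input channels, and L = ne10 the input length (the new GGML_ASSERT checks that src1 also has IC channels). The init phase repacks both operands into params->wdata with the IC dimension innermost, so the main loop can form each output element with a single contiguous ggml_vec_dot_f16 over IC. A minimal sketch of the implied wdata indexing; the helper names are hypothetical and only illustrate the offsets used in the diff:

#include <stdint.h>

// Sketch only -- index helpers for the wdata layout set up during GGML_TASK_INIT.
// Kernel block: permuted to [IC, K, OC], with ic contiguous.
static inline int64_t wdata_kernel_idx(int64_t oc, int64_t k, int64_t ic,
                                       int64_t K, int64_t IC) {
    return oc*K*IC + k*IC + ic;
}

// Source block: permuted to [IC, L], stored right after the K*OC*IC kernel elements.
static inline int64_t wdata_src_idx(int64_t l, int64_t ic,
                                    int64_t K, int64_t OC, int64_t IC) {
    return K*OC*IC + l*IC + ic;
}

The ggml_vec_dot_f16(IC, ...) call in the last hunk reads IC consecutive fp16 values starting at wdata_kernel_idx(oc, k, 0, K, IC) and wdata_src_idx(l, 0, K, OC, IC), which matches the two wdata offsets it is passed.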
@@ -14395,37 +14395,47 @@ static void ggml_compute_forward_conv_transpose_1d_f16_f32(
     GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
     GGML_ASSERT(nb10 == sizeof(float));
 
+    const int64_t K = ne00;
+    const int64_t OC = ne01;
+    const int64_t IC = ne02;
+
+    const int64_t L = ne10;
+    GGML_ASSERT(IC == ne11);
+
     if (params->type == GGML_TASK_INIT) {
         memset(params->wdata, 0, params->wsize);
 
-        // permute kernel data (src0) from [K, Cout, Cin] to [Cin, K, Cout]
+        // permute kernel data (src0) from [K, OC, IC] to [IC, K, OC]
         {
             ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
 
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                for (int64_t i01 = 0; i01 < ne01; i01++) {
-                    const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01);
-                    ggml_fp16_t * dst_data = wdata + i01*ne00*ne02;
-                    for (int64_t i00 = 0; i00 < ne00; i00++) {
-                        dst_data[i00*ne02 + i02] = src[i00];
+            for (int64_t ic = 0; ic < IC; ic++) {
+                for (int64_t oc = 0; oc < OC; oc++) {
+                    const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + ic*nb02 + oc*nb01);
+                    ggml_fp16_t * dst_data = wdata + oc*K*IC;
+                    for (int64_t k = 0; k < K; k++) {
+                        dst_data[k*IC + ic] = src[k];
                     }
                 }
            }
         }
 
-        // permute source data (src1) from [L, Cin] to [Cin, L]
+        // permute source data (src1) from [L, IC] to [IC, L]
         {
-            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + ne00*ne01*ne02;
+            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + K*OC*IC;
             ggml_fp16_t * dst_data = wdata;
 
-            for (int64_t i11 = 0; i11 < ne11; i11++) {
-                const float * const src = (float *)((char *) src1->data + i11*nb11);
-                for (int64_t i10 = 0; i10 < ne10; i10++) {
-                    dst_data[i10*ne11 + i11] = GGML_FP32_TO_FP16(src[i10]);
+            for (int64_t ic = 0; ic < IC; ic++) {
+                const float * const src = (float *)((char *) src1->data + ic*nb11);
+                for (int64_t l = 0; l < L; l++) {
+                    dst_data[l*IC + ic] = GGML_FP32_TO_FP16(src[l]);
                 }
             }
         }
 
+        // need to zero dst since we are accumulating into it
+        memset(dst->data, 0, ggml_nbytes(dst));
+
         return;
     }
 
@@ -14436,7 +14446,7 @@ static void ggml_compute_forward_conv_transpose_1d_f16_f32(
     const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
 
     // total rows in dst
-    const int nr = ne1;
+    const int nr = OC;
 
     // rows per thread
     const int dr = (nr + nth - 1)/nth;
@@ -14445,19 +14455,17 @@
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
-    ggml_fp16_t * const wdata     = (ggml_fp16_t *) params->wdata + 0;
-    ggml_fp16_t * const wdata_src = wdata + ne00*ne01*ne02;
+    ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
 
-    for (int i1 = ir0; i1 < ir1; i1++) {
-        float * dst_data = (float *)((char *) dst->data + i1*nb1);
-        ggml_fp16_t * wdata_kernel = wdata + i1*ne02*ne00;
-        for (int i10 = 0; i10 < ne10; i10++) {
-            for (int i00 = 0; i00 < ne00; i00++) {
+    for (int oc = ir0; oc < ir1; oc++) {
+        float * dst_data = (float *)((char *) dst->data + oc*nb1);
+        for (int l = 0; l < L; l++) {
+            for (int k = 0; k < K; k++) {
                 float v = 0;
-                ggml_vec_dot_f16(ne02, &v,
-                        (ggml_fp16_t *) wdata_src + i10*ne11,
-                        (ggml_fp16_t *) wdata_kernel + i00*ne02);
-                dst_data[i10*s0 + i00] += v;
+                ggml_vec_dot_f16(IC, &v,
+                        wdata + oc*K*IC + k*IC,
+                        wdata + IC*K*OC + l*IC);
+                dst_data[l*s0 + k] += v;
             }
         }
     }
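
Reader's note (not part of the commit): the added comment and memset exist because contributions from different input positions l land on overlapping output indices l*s0 + k whenever the stride s0 is smaller than K, so dst must start from zero before the += accumulation. A minimal float32 reference of the same computation, assuming the usual transposed-convolution output length OL = (L - 1)*s0 + K (the output-size formula itself is not part of this diff):

#include <string.h>

// Naive reference sketch (illustration only, assumed semantics):
//   kernel: [K, OC, IC]  element (k, oc, ic) at kernel[ic*OC*K + oc*K + k]
//   src:    [L, IC]      element (l, ic)     at src[ic*L + l]
//   dst:    [OL, OC]     element at dst[oc*OL + l*s0 + k]
static void conv_transpose_1d_ref(const float * kernel, const float * src, float * dst,
                                  int K, int OC, int IC, int L, int s0) {
    const int OL = (L - 1)*s0 + K;
    // same role as the added memset(dst->data, 0, ggml_nbytes(dst))
    memset(dst, 0, sizeof(float)*(size_t)(OL*OC));
    for (int oc = 0; oc < OC; oc++) {
        for (int l = 0; l < L; l++) {
            for (int k = 0; k < K; k++) {
                float v = 0.0f;
                for (int ic = 0; ic < IC; ic++) {
                    v += kernel[ic*OC*K + oc*K + k] * src[ic*L + l];
                }
                dst[oc*OL + l*s0 + k] += v;  // overlapping writes when s0 < K
            }
        }
    }
}

This mirrors the ggml path: the ic loop corresponds to ggml_vec_dot_f16 over IC, and the += into dst_data[l*s0 + k] is why dst->data has to be zeroed during GGML_TASK_INIT.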
