Skip to content

Commit

Permalink
Better variable naming
Browse files Browse the repository at this point in the history
  • Loading branch information
PABannier committed Oct 21, 2023
1 parent a7e0350 commit 9d10023
Showing 1 changed file with 37 additions and 38 deletions.
75 changes: 37 additions & 38 deletions src/ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -14392,15 +14392,13 @@ static void ggml_compute_forward_conv_transpose_1d_f16_f32(
const int ith = params->ith;
const int nth = params->nth;

const int nk = ne00*ne01*ne02;

GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
GGML_ASSERT(nb10 == sizeof(float));

if (params->type == GGML_TASK_INIT) {
memset(params->wdata, 0, params->wsize);

// permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
// permute kernel data (src0) from [K, Cout, Cin] to [Cin, K, Cout]
{
ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;

Expand All @@ -14415,9 +14413,9 @@ static void ggml_compute_forward_conv_transpose_1d_f16_f32(
}
}

// permute source data (src1) from (L x Cin) to (Cin x L)
// permute source data (src1) from [L, Cin] to [Cin, L]
{
ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk;
ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + ne00*ne01*ne02;
ggml_fp16_t * dst_data = wdata;

for (int64_t i11 = 0; i11 < ne11; i11++) {
Expand All @@ -14435,7 +14433,7 @@ static void ggml_compute_forward_conv_transpose_1d_f16_f32(
return;
}

const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
const int32_t s0 = ((const int32_t *)(dst->op_params))[0];

// total rows in dst
const int nr = ne1;
Expand All @@ -14448,17 +14446,16 @@ static void ggml_compute_forward_conv_transpose_1d_f16_f32(
const int ir1 = MIN(ir0 + dr, nr);

ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
ggml_fp16_t * const wdata_src = wdata + nk;
ggml_fp16_t * const wdata_src = wdata + ne00*ne01*ne02;

for (int i1 = ir0; i1 < ir1; i1++) {
float * dst_data = (float *)((char *) dst->data + i1*nb1);
ggml_fp16_t * wdata_kernel = wdata + i1*ne02*ne00;
for (int i10 = 0; i10 < ne10; i10++) {
const int i1n = i10*ne11;
for (int i00 = 0; i00 < ne00; i00++) {
float v = 0;
ggml_vec_dot_f16(ne02, &v,
(ggml_fp16_t *) wdata_src + i1n,
(ggml_fp16_t *) wdata_src + i10*ne11,
(ggml_fp16_t *) wdata_kernel + i00*ne02);
dst_data[i10*s0 + i00] += v;
}
Expand All @@ -14483,38 +14480,43 @@ static void ggml_compute_forward_conv_transpose_1d_f32(
const int ith = params->ith;
const int nth = params->nth;

const int nk = ne00*ne01*ne02;

GGML_ASSERT(nb00 == sizeof(float));
GGML_ASSERT(nb10 == sizeof(float));

const int64_t K = ne00;
const int64_t OC = ne01;
const int64_t IC = ne02;

const int64_t L = ne10;
GGML_ASSERT(IC == ne11);

if (params->type == GGML_TASK_INIT) {
memset(params->wdata, 0, params->wsize);

// prepare kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
// reshape kernel data (src0) from [K, OC, IC] to [IC, K, OC]
{
float * const wdata = (float *) params->wdata + 0;

for (int64_t i02 = 0; i02 < ne02; i02++) {
for (int64_t i01 = 0; i01 < ne01; i01++) {
const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01);
float * dst_data = wdata + i01*ne00*ne02;
for (int64_t i00 = 0; i00 < ne00; i00++) {
dst_data[i01*ne00*ne02 + i00*ne02 + i02] = src[i00];
for (int64_t ic = 0; ic < IC; ic++) {
for (int64_t oc = 0; oc < OC; oc++) {
const float * const src = (float *)((char *) src0->data + ic*nb02 + oc*nb01);
float * dst_data = wdata + oc*K*IC;
for (int64_t k = 0; k < K; k++) {
dst_data[k*IC + ic] = src[k];
}
}
}
}

// prepare source data (src1)
// reshape source data (src1) from [L, IC] to [IC, L]
{
float * const wdata = (float *) params->wdata + nk;
float * const wdata = (float *) params->wdata + IC*K*OC;
float * dst_data = wdata;

for (int64_t i11 = 0; i11 < ne11; i11++) {
const float * const src = (float *)((char *) src1->data + i11*nb11);
for (int64_t i10 = 0; i10 < ne10; i10++) {
dst_data[i10*ne11 + i11] = src[i10];
for (int64_t ic = 0; ic < IC; ic++) {
const float * const src = (float *)((char *) src1->data + ic*nb11);
for (int64_t l = 0; l < L; l++) {
dst_data[l*IC + ic] = src[l];
}
}
}
Expand All @@ -14526,10 +14528,10 @@ static void ggml_compute_forward_conv_transpose_1d_f32(
return;
}

const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
const int32_t s0 = ((const int32_t *)(dst->op_params))[0];

// total rows in dst
const int nr = ne1;
const int nr = OC;

// rows per thread
const int dr = (nr + nth - 1)/nth;
Expand All @@ -14538,20 +14540,17 @@ static void ggml_compute_forward_conv_transpose_1d_f32(
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);

float * const wdata = (float *) params->wdata + 0;
float * const wdata_src = wdata + nk;
float * const wdata = (float *) params->wdata;

for (int i1 = ir0; i1 < ir1; i1++) {
float * dst_data = (float *)((char *) dst->data + i1*nb1);
float * wdata_kernel = wdata + i1*ne02*ne00;
for (int i10 = 0; i10 < ne10; i10++) {
const int i1n = i10*ne11;
for (int i00 = 0; i00 < ne00; i00++) {
for (int oc = ir0; oc < ir1; oc++) {
float * dst_data = (float *) ((char *) dst->data + oc*nb1);
for (int l = 0; l < L; l++) {
for (int k = 0; k < K; k++) {
float v = 0;
ggml_vec_dot_f32(ne02, &v,
wdata_src + i1n,
wdata_kernel + i00*ne02);
dst_data[i10*s0 + i00] += v;
ggml_vec_dot_f32(IC, &v,
wdata + oc*K*IC + k*IC,
wdata + IC*K*OC + l*IC);
dst_data[l*s0 + k] += v;
}
}
}
Expand Down

0 comments on commit 9d10023

Please sign in to comment.