Skip to content

Commit

Permalink
is measure
Browse files Browse the repository at this point in the history
  • Loading branch information
PABannier committed Oct 2, 2023
1 parent 3055096 commit cf85c76
Showing 1 changed file with 17 additions and 9 deletions.
26 changes: 17 additions & 9 deletions encodec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,8 @@ static struct ggml_tensor * forward_pass_lstm_unilayer(
struct ggml_tensor * weight_ih,
struct ggml_tensor * weight_hh,
struct ggml_tensor * bias_ih,
struct ggml_tensor * bias_hh) {
struct ggml_tensor * bias_hh,
bool is_measure) {

const int input_dim = inp->ne[1];
const int hidden_dim = weight_ih->ne[1]/4;
Expand All @@ -151,8 +152,10 @@ static struct ggml_tensor * forward_pass_lstm_unilayer(
struct ggml_tensor * c_t = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, hidden_dim);
struct ggml_tensor * h_t = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, hidden_dim);

h_t = ggml_set_zero(h_t);
c_t = ggml_set_zero(c_t);
if (is_measure) {
h_t = ggml_set_zero(h_t);
c_t = ggml_set_zero(c_t);
}

struct ggml_tensor * current = ggml_cont(ctx0, ggml_transpose(ctx0, inp));

Expand All @@ -169,7 +172,7 @@ static struct ggml_tensor * forward_pass_lstm_unilayer(

struct ggml_tensor * i_t = encodec_sigmoid(ctx0, ggml_view_1d(ctx0, out_gates, hidden_dim, 0*sizeof(float)*hidden_dim));
struct ggml_tensor * f_t = encodec_sigmoid(ctx0, ggml_view_1d(ctx0, out_gates, hidden_dim, 1*sizeof(float)*hidden_dim));
struct ggml_tensor * g_t = ggml_tanh (ctx0, ggml_view_1d(ctx0, out_gates, hidden_dim, 2*sizeof(float)*hidden_dim));
struct ggml_tensor * g_t = ggml_tanh (ctx0, ggml_view_1d(ctx0, out_gates, hidden_dim, 2*sizeof(float)*hidden_dim));
struct ggml_tensor * o_t = encodec_sigmoid(ctx0, ggml_view_1d(ctx0, out_gates, hidden_dim, 3*sizeof(float)*hidden_dim));

c_t = ggml_add(ctx0, ggml_mul(ctx0, f_t, c_t), ggml_mul(ctx0, i_t, g_t));
Expand Down Expand Up @@ -612,11 +615,13 @@ static struct ggml_cgraph * encodec_build_graph(

// first lstm layer
struct ggml_tensor * hs1 = forward_pass_lstm_unilayer(
ctx0, cur, lstm.l0_ih_w, lstm.l0_hh_w, lstm.l0_ih_b, lstm.l0_hh_b);
ctx0, cur, lstm.l0_ih_w, lstm.l0_hh_w, lstm.l0_ih_b, lstm.l0_hh_b,
ggml_allocr_is_measure(ectx.allocr));

// second lstm layer
struct ggml_tensor * out = forward_pass_lstm_unilayer(
ctx0, hs1, lstm.l1_ih_w, lstm.l1_hh_w, lstm.l1_ih_b, lstm.l1_hh_b);
ctx0, hs1, lstm.l1_ih_w, lstm.l1_hh_w, lstm.l1_ih_b, lstm.l1_hh_b,
ggml_allocr_is_measure(ectx.allocr));

inpL = ggml_add(ctx0, inpL, out);
}
Expand Down Expand Up @@ -718,7 +723,8 @@ static struct ggml_cgraph * encodec_build_graph(
const int stride = hparams.stride;

struct ggml_tensor * inpL = strided_conv_1d(
ctx0, quantized_out, model.decoder.init_conv_w, model.decoder.init_conv_b, stride);
ctx0, quantized_out, model.decoder.init_conv_w,
model.decoder.init_conv_b, stride);

// lstm
{
Expand All @@ -728,11 +734,13 @@ static struct ggml_cgraph * encodec_build_graph(

// first lstm layer
struct ggml_tensor * hs1 = forward_pass_lstm_unilayer(
ctx0, cur, lstm.l0_ih_w, lstm.l0_hh_w, lstm.l0_ih_b, lstm.l0_hh_b);
ctx0, cur, lstm.l0_ih_w, lstm.l0_hh_w, lstm.l0_ih_b, lstm.l0_hh_b,
ggml_allocr_is_measure(ectx.allocr));

// second lstm layer
struct ggml_tensor * out = forward_pass_lstm_unilayer(
ctx0, hs1, lstm.l1_ih_w, lstm.l1_hh_w, lstm.l1_ih_b, lstm.l1_hh_b);
ctx0, hs1, lstm.l1_ih_w, lstm.l1_hh_w, lstm.l1_ih_b, lstm.l1_hh_b,
ggml_allocr_is_measure(ectx.allocr));

inpL = ggml_add(ctx0, inpL, out);
}
Expand Down

0 comments on commit cf85c76

Please sign in to comment.