Skip to content

Commit

Permalink
optimize memory
Browse files Browse the repository at this point in the history
  • Loading branch information
liruipeng committed Mar 28, 2024
1 parent 2462a81 commit 4a651ca
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 10 deletions.
34 changes: 25 additions & 9 deletions bbb/jaccalc.c
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ jac_calc_seq_c_(Int *neq,
return 0;
}

#define UEDGE_WITH_OMP 1

#if defined(UEDGE_WITH_OMP)

#define USE_OMP_VERSION 1
Expand All @@ -54,6 +56,8 @@ typedef struct
Int num_threads;
Int neq;
Int nnzmx;
Int nnzmx_t;
real nnzmx_f;
real *yl;
real *yldot00;
real *wk;
Expand Down Expand Up @@ -92,26 +96,33 @@ int
jac_calc_omp_data_init(Int num_threads,
Int neq,
Int nnzmx,
real nnzmx_f,
jac_calc_omp_data *data)
{
if (num_threads == data -> num_threads &&
neq == data -> neq &&
nnzmx == data -> nnzmx)
nnzmx == data -> nnzmx &&
nnzmx_f == data -> nnzmx_f)
{
return 0;
}

const Int nnzmx_0 = (nnzmx + num_threads - 1) / num_threads;
const Int nnzmx_t = (Int) (nnzmx * nnzmx_f + nnzmx_0 * (1.0 - nnzmx_f));

data -> num_threads = num_threads;
data -> neq = neq;
data -> nnzmx = nnzmx;
data -> nnzmx_t = nnzmx_t;
data -> nnzmx_f = nnzmx_f;

Int nt = num_threads - 1;

data -> yl = (real *) realloc(data -> yl, nt * (neq + 2) * sizeof(real));
data -> yldot00 = (real *) realloc(data -> yldot00, nt * (neq + 2) * sizeof(real));
data -> wk = (real *) realloc(data -> wk, nt * neq * sizeof(real));
data -> jac = (real *) realloc(data -> jac, nt * nnzmx * sizeof(real));
data -> ja = (Int *) realloc(data -> ja, nt * nnzmx * sizeof(Int));
data -> jac = (real *) realloc(data -> jac, nt * nnzmx_t * sizeof(real));
data -> ja = (Int *) realloc(data -> ja, nt * nnzmx_t * sizeof(Int));
data -> yldot_pert = (real *) realloc(data -> yldot_pert, nt * neq * sizeof(real));

return 0;
Expand Down Expand Up @@ -149,15 +160,16 @@ jac_calc_omp_get_thread_data(Int thread_id,
thread_data -> num_threads = data -> num_threads;
thread_data -> neq = data -> neq;
thread_data -> nnzmx = data -> nnzmx;
thread_data -> nnzmx_t = data -> nnzmx_t;

if (thread_id)
{
Int j, tid = thread_id - 1;
thread_data -> yl = data -> yl + tid * (data -> neq + 2);
thread_data -> yldot00 = data -> yldot00 + tid * (data -> neq + 2);
thread_data -> wk = data -> wk + tid * (data -> neq);
thread_data -> jac = data -> jac + tid * (data -> nnzmx);
thread_data -> ja = data -> ja + tid * (data -> nnzmx);
thread_data -> jac = data -> jac + tid * (data -> nnzmx_t);
thread_data -> ja = data -> ja + tid * (data -> nnzmx_t);
thread_data -> yldot_pert = data -> yldot_pert + tid * (data -> neq);

for (j = 0; j < data -> neq + 2; j++)
Expand Down Expand Up @@ -197,6 +209,7 @@ jac_calc_omp_c_(Int *neq,
Int *mu,
real *wk,
Int *nnzmx,
real nnzmx_f,
real *jac,
Int *ja,
Int *ia,
Expand All @@ -209,10 +222,12 @@ jac_calc_omp_c_(Int *neq,
{
printf(" =============================================\n"
" Jac_calc OpenMP C version, Num. Threads = %ld\n"
" ** n = %ld, nnzmx = %ld, nnzmx_f = %.2f ** \n"
#if SEQ_CHECK
" ** Check with serial version is ON ** \n"
#endif
" =============================================\n", num_threads);
" =============================================\n",
num_threads, *neq, *nnzmx, nnzmx_f);
}

#if SEQ_CHECK
Expand All @@ -230,7 +245,7 @@ jac_calc_omp_c_(Int *neq,
{
jcod = jac_calc_omp_data_create();
}
jac_calc_omp_data_init(num_threads, *neq, *nnzmx, jcod);
jac_calc_omp_data_init(num_threads, *neq, *nnzmx, nnzmx_f, jcod);

Int *neq_all = (Int *) malloc((num_threads + 1) * sizeof(Int));
Int *nnz_all = (Int *) malloc((num_threads + 1) * sizeof(Int));
Expand All @@ -248,7 +263,7 @@ jac_calc_omp_c_(Int *neq,
for (iv = iv_start + 1; iv <= iv_end; iv++)
{
jac_calc_iv_(&iv, &jcod_t.neq, &t_t, jcod_t.yl, jcod_t.yldot00, &ml_t, &mu_t, jcod_t.wk,
&jcod_t.nnzmx, jcod_t.jac, jcod_t.ja, ia, jcod_t.yldot_pert, &nnz_t);
&jcod_t.nnzmx_t, jcod_t.jac, jcod_t.ja, ia, jcod_t.yldot_pert, &nnz_t);
}

if (PRINT_LEVEL > 1)
Expand Down Expand Up @@ -321,6 +336,7 @@ jac_calc_c_(Int *neq,
Int *mu,
real *wk,
Int *nnzmx,
real *nnzmx_f,
real *jac,
Int *ja,
Int *ia,
Expand All @@ -333,7 +349,7 @@ jac_calc_c_(Int *neq,
#endif

#if USE_OMP_VERSION
int ret = jac_calc_omp_c_(neq, t, yl, yldot00, ml, mu, wk, nnzmx, jac, ja, ia, yldot_pert, nnz);
int ret = jac_calc_omp_c_(neq, t, yl, yldot00, ml, mu, wk, nnzmx, *nnzmx_f, jac, ja, ia, yldot_pert, nnz);
#else
int ret = jac_calc_seq_c_(neq, t, yl, yldot00, ml, mu, wk, nnzmx, jac, ja, ia, yldot_pert, nnz);
#endif
Expand Down
2 changes: 1 addition & 1 deletion bbb/oderhs.m
Original file line number Diff line number Diff line change
Expand Up @@ -8792,7 +8792,7 @@ c jcsc(neq+1) = nnz
c##############################################################

call jac_calc_c(neq, t, yl, yldot00, ml, mu, wk,
. nnzmx, rcsc, icsc, jcsc, yldot_pert, nnz)
. nnzmx, 0.0, rcsc, icsc, jcsc, yldot_pert, nnz)

c ... Convert Jacobian from compressed sparse column to compressed
c sparse row format.
Expand Down

0 comments on commit 4a651ca

Please sign in to comment.