diff --git a/src/tensor/untyped_tensor.cxx b/src/tensor/untyped_tensor.cxx index 8d3ed7cd..bced3ac2 100644 --- a/src/tensor/untyped_tensor.cxx +++ b/src/tensor/untyped_tensor.cxx @@ -737,9 +737,10 @@ namespace CTF_int { this->data = sr->alloc(this->size); //CTF_int::alloc_ptr(this->size*sr->el_size, (void**)&this->data); #endif - #if DEBUG >= 2 + #if DEBUG >= 1 if (wrld->rank == 0) printf("New tensor %s defined of size %ld elms (%ld bytes):\n",name, this->size,this->size*sr->el_size); + this->print_lens(); this->print_map(stdout); #endif sr->init(this->size, this->data); @@ -1789,6 +1790,7 @@ namespace CTF_int { } bool can_merge_split(tensor const * input, tensor const * output){ + if ((output->order == 1 || input->order == 1) && input->wrld->np > 1) return false; if (output->order != input->order){ if (output->is_sparse || input->is_sparse || output->has_symmetry() || input->has_symmetry()){ return false; @@ -1802,94 +1804,81 @@ namespace CTF_int { } return true; } - - int tensor::reshape(tensor const * old_tsr, char const * alpha, char const * beta){ + + static int reshape_tensor(tensor * new_tsr, tensor const * old_tsr, char const * alpha, char const * beta){ char * pairs; int64_t n; -#if DEBUG >=1 - if (this->wrld->rank == 0){ - printf("CTF: Performing reshape from shape "); - for (int i=0; iorder; i++){ - printf("%ld ",old_tsr->lens[i]); - } - printf("to shape "); - for (int i=0; iorder; i++){ - printf("%ld ",this->lens[i]); - } - printf("\n"); - } -#endif tensor * talias; - if (!this->is_sparse && !old_tsr->is_sparse){ - talias = this->get_no_unit_len_alias(); - if (talias != this){ + if (!new_tsr->is_sparse && !old_tsr->is_sparse){ + talias = new_tsr->get_no_unit_len_alias(); + if (talias != new_tsr){ int stat = talias->reshape(old_tsr, alpha, beta); delete talias; return stat; } talias = ((tensor*)old_tsr)->get_no_unit_len_alias(); if (talias != old_tsr){ - int stat = this->reshape(talias, alpha, beta); + int stat = new_tsr->reshape(talias, alpha, beta); delete talias; return stat; } } - if (beta == NULL || this->sr->isequal(beta,this->sr->addid())){ - bool did_lens_change = this->order != old_tsr->order; + if (beta == NULL || new_tsr->sr->isequal(beta,new_tsr->sr->addid())){ + bool did_lens_change = new_tsr->order != old_tsr->order; if (!did_lens_change){ for (int i=0; iorder; i++){ - if (old_tsr->lens[i] != this->lens[i]) + if (old_tsr->lens[i] != new_tsr->lens[i]) did_lens_change = true; } } if (!did_lens_change){ bool is_map_changed = false; - if (topo != old_tsr->topo) is_map_changed = true; + if (new_tsr->topo != old_tsr->topo) is_map_changed = true; //topo = old_tsr->topo; - for (int i=0; iedge_map+i)){ + for (int i=0; iorder; i++){ + if (!comp_dim_map(new_tsr->edge_map+i, old_tsr->edge_map+i)){ //edge_map[i].clear(); //copy_mapping(1, old_tsr->edge_map+i, edge_map+i); is_map_changed = true; } } if (!is_map_changed){ - if (!this->is_sparse){ + if (!new_tsr->is_sparse){ IASSERT(!old_tsr->is_sparse); - memcpy(this->data, old_tsr->data, this->sr->el_size*this->size); + memcpy(new_tsr->data, old_tsr->data, new_tsr->sr->el_size*new_tsr->size); } else { IASSERT(old_tsr->is_sparse); - this->set_zero(); - this->data = this->sr->pair_alloc(old_tsr->nnz_loc); - this->sr->copy_pairs(this->data, old_tsr->data, old_tsr->nnz_loc); - memcpy(this->nnz_blk, old_tsr->nnz_blk, old_tsr->calc_nvirt()*sizeof(int64_t)); - this->set_new_nnz_glb(this->nnz_blk); + new_tsr->set_zero(); + new_tsr->data = new_tsr->sr->pair_alloc(old_tsr->nnz_loc); + new_tsr->sr->copy_pairs(new_tsr->data, old_tsr->data, old_tsr->nnz_loc); + memcpy(new_tsr->nnz_blk, old_tsr->nnz_blk, old_tsr->calc_nvirt()*sizeof(int64_t)); + new_tsr->set_new_nnz_glb(new_tsr->nnz_blk); } return SUCCESS; } } } - if (can_merge_split(old_tsr, this)){ + if (can_merge_split(old_tsr, new_tsr)){ bool is_mode_merge = true; bool is_mode_split = true; int i=0,j=0; - while (iorder && jorder){ - if (this->lens[i] == old_tsr->lens[j]){ + while (iorder && jorder){ + if (new_tsr->lens[i] == old_tsr->lens[j]){ i++; j++; continue; } - int64_t sm_len = std::min(this->lens[i], old_tsr->lens[j]); - //printf("HERE [%d] %ld; [%d] %ld \n",i,this->lens[i], j,old_tsr->lens[j]); + int64_t sm_len = std::min(new_tsr->lens[i], old_tsr->lens[j]); + //printf("HERE [%d] %ld; [%d] %ld \n",i,new_tsr->lens[i], j,old_tsr->lens[j]); if (sm_len < old_tsr->lens[j]){ while (sm_len < old_tsr->lens[j]){ i++; - assert(iorder); - sm_len *= this->lens[i]; + assert(iorder); + sm_len *= new_tsr->lens[i]; is_mode_merge = false; - //printf("HERE2 [%d] %ld; [%d] %ld \n",i,this->lens[i], j,old_tsr->lens[j]); + //printf("HERE2 [%d] %ld; [%d] %ld \n",i,new_tsr->lens[i], j,old_tsr->lens[j]); } if (sm_len > old_tsr->lens[j]){ is_mode_merge = false; @@ -1897,14 +1886,14 @@ namespace CTF_int { break; } } else { - while (sm_len < this->lens[i]){ + while (sm_len < new_tsr->lens[i]){ j++; assert(jorder); sm_len *= old_tsr->lens[j]; is_mode_split = false; - //printf("HERE3 [%d] %ld; [%d] %ld \n",i,this->lens[i], j,old_tsr->lens[j]); + //printf("HERE3 [%d] %ld; [%d] %ld \n",i,new_tsr->lens[i], j,old_tsr->lens[j]); } - if (sm_len > this->lens[i]){ + if (sm_len > new_tsr->lens[i]){ is_mode_merge = false; is_mode_split = false; break; @@ -1914,38 +1903,62 @@ namespace CTF_int { j++; } if (is_mode_merge && !is_mode_split){ - return this->merge_modes((tensor*)old_tsr, alpha, beta); + return new_tsr->merge_modes((tensor*)old_tsr, alpha, beta); } if (!is_mode_merge && is_mode_split){ - return this->split_modes((tensor*)old_tsr, alpha, beta); - // tensor * copy_tsr = new tensor(this, 1, 1); - // tensor * res_tsr = new tensor(this, 0, 1); - // tensor * zero_tsr = new tensor(this->sr, 0, (int64_t*)NULL, NULL, this->wrld); + return new_tsr->split_modes((tensor*)old_tsr, alpha, beta); + // tensor * copy_tsr = new tensor(new_tsr, 1, 1); + // tensor * res_tsr = new tensor(new_tsr, 0, 1); + // tensor * zero_tsr = new tensor(new_tsr->sr, 0, (int64_t*)NULL, NULL, new_tsr->wrld); // int stat2 = copy_tsr->split_modes((tensor*)old_tsr, alpha, beta); - // if (beta == NULL || this->sr->isequal(beta,this->sr->addid())) - // this->set_zero(); + // if (beta == NULL || new_tsr->sr->isequal(beta,new_tsr->sr->addid())) + // new_tsr->set_zero(); // int stat = old_tsr->read_local_nnz(&n, &pairs, true); // //if (stat != SUCCESS) return stat; - // stat = this->write(n, alpha, beta, pairs, 'w'); - // this->sr->pair_dealloc(pairs); - // zero_tsr->operator[]("") = this->operator[](get_default_inds(this->order))-copy_tsr->operator[](get_default_inds(this->order)); - // res_tsr->operator[](get_default_inds(this->order)) = this->operator[](get_default_inds(this->order))-copy_tsr->operator[](get_default_inds(this->order)); + // stat = new_tsr->write(n, alpha, beta, pairs, 'w'); + // new_tsr->sr->pair_dealloc(pairs); + // zero_tsr->operator[]("") = new_tsr->operator[](get_default_inds(new_tsr->order))-copy_tsr->operator[](get_default_inds(new_tsr->order)); + // res_tsr->operator[](get_default_inds(new_tsr->order)) = new_tsr->operator[](get_default_inds(new_tsr->order))-copy_tsr->operator[](get_default_inds(new_tsr->order)); // char * val = new char[sr->el_size]; // res_tsr->reduce_sum(val); // printf("sum is ..."); // sr->print(val); // zero_tsr->print(); - // this->compare(copy_tsr, stdout, sr->addid()); + // new_tsr->compare(copy_tsr, stdout, sr->addid()); // printf("\n"); // return stat; } } - if (beta == NULL || this->sr->isequal(beta,this->sr->addid())) - this->set_zero(); + if (beta == NULL || new_tsr->sr->isequal(beta,new_tsr->sr->addid())) + new_tsr->set_zero(); int stat = old_tsr->read_local_nnz(&n, &pairs, true); if (stat != SUCCESS) return stat; - stat = this->write(n, alpha, beta, pairs, 'w'); - this->sr->pair_dealloc(pairs); + stat = new_tsr->write(n, alpha, beta, pairs, 'w'); + new_tsr->sr->pair_dealloc(pairs); + return stat; + } + + int tensor::reshape(tensor const * old_tsr, char const * alpha, char const * beta){ +#if DEBUG >=1 + if (this->wrld->rank == 0){ + printf("CTF: Performing reshape from shape "); + for (int i=0; iorder; i++){ + printf("%ld ",old_tsr->lens[i]); + } + printf("to shape "); + for (int i=0; iorder; i++){ + printf("%ld ",this->lens[i]); + } + printf("\n"); + } +#endif + #ifdef PROFILE_MEMORY + start_memprof(this->wrld->rank); + #endif + int stat = reshape_tensor(this,old_tsr,alpha,beta); + #ifdef PROFILE_MEMORY + stop_memprof(); + #endif return stat; }