Skip to content

Commit

Permalink
Prevent deepcopy in subscript in for loops
Browse files Browse the repository at this point in the history
  • Loading branch information
advikkabra committed Jul 29, 2024
1 parent 42f385f commit eb59248
Showing 1 changed file with 27 additions and 24 deletions.
51 changes: 27 additions & 24 deletions src/lpython/semantics/python_ast_to_asr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5808,6 +5808,7 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
std::string loop_src_var_name = "";
ASR::expr_t *loop_end = nullptr, *loop_start = nullptr, *inc = nullptr;
ASR::expr_t *for_iter_type = nullptr;
ASR::expr_t *loop_src_var = nullptr;
if (AST::is_a<AST::Call_t>(*x.m_iter)) {
AST::Call_t *c = AST::down_cast<AST::Call_t>(x.m_iter);
std::string call_name;
Expand Down Expand Up @@ -5852,6 +5853,8 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {

if (ASR::is_a<ASR::Dict_t>(*loop_src_var_ttype) ||
ASR::is_a<ASR::Set_t>(*loop_src_var_ttype)) {
loop_src_var = ASRUtils::EXPR(
ASR::make_Var_t(al, x.base.base.loc, current_scope->resolve_symbol(loop_src_var_name)));
is_explicit_iterator_required = false;
for_each = true;
} else {
Expand All @@ -5867,34 +5870,35 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
visit_Subscript(*sbt);
ASR::expr_t *target = ASRUtils::EXPR(tmp);
ASR::ttype_t *loop_src_var_ttype = ASRUtils::expr_type(target);
// Create a temporary variable that will contain the evaluated value of Subscript
std::string tmp_assign_name = current_scope->get_unique_name("__tmp_assign_for_loop", false);
SetChar variable_dependencies_vec;
variable_dependencies_vec.reserve(al, 1);
ASRUtils::collect_variable_dependencies(al, variable_dependencies_vec, loop_src_var_ttype);
ASR::asr_t* tmp_assign_variable = ASR::make_Variable_t(al, sbt->base.base.loc, current_scope,
s2c(al, tmp_assign_name), variable_dependencies_vec.p, variable_dependencies_vec.size(),
ASR::intentType::Local, nullptr, nullptr, ASR::storage_typeType::Default,
loop_src_var_ttype, nullptr, ASR::abiType::Source, ASR::accessType::Public, ASR::presenceType::Required, false
);
ASR::symbol_t *tmp_assign_variable_sym = ASR::down_cast<ASR::symbol_t>(tmp_assign_variable);
current_scope->add_symbol(tmp_assign_name, tmp_assign_variable_sym);

// Assign the Subscript expr to temporary variable
ASR::asr_t* assign = ASR::make_Assignment_t(al, x.base.base.loc,
ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, tmp_assign_variable_sym)),
target, nullptr);
if (current_body != nullptr) {
current_body->push_back(al, ASRUtils::STMT(assign));
} else {
global_init.push_back(al, assign);
}
loop_src_var_name = tmp_assign_name;
if (ASR::is_a<ASR::Dict_t>(*loop_src_var_ttype) ||
ASR::is_a<ASR::Set_t>(*loop_src_var_ttype)) {
loop_src_var = target;
is_explicit_iterator_required = false;
for_each = true;
} else {
// Create a temporary variable that will contain the evaluated value of Subscript
std::string tmp_assign_name = current_scope->get_unique_name("__tmp_assign_for_loop", false);
SetChar variable_dependencies_vec;
variable_dependencies_vec.reserve(al, 1);
ASRUtils::collect_variable_dependencies(al, variable_dependencies_vec, loop_src_var_ttype);
ASR::asr_t* tmp_assign_variable = ASR::make_Variable_t(al, sbt->base.base.loc, current_scope,
s2c(al, tmp_assign_name), variable_dependencies_vec.p, variable_dependencies_vec.size(),
ASR::intentType::Local, nullptr, nullptr, ASR::storage_typeType::Default,
loop_src_var_ttype, nullptr, ASR::abiType::Source, ASR::accessType::Public, ASR::presenceType::Required, false
);
ASR::symbol_t *tmp_assign_variable_sym = ASR::down_cast<ASR::symbol_t>(tmp_assign_variable);
current_scope->add_symbol(tmp_assign_name, tmp_assign_variable_sym);

// Assign the Subscript expr to temporary variable
ASR::asr_t* assign = ASR::make_Assignment_t(al, x.base.base.loc,
ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, tmp_assign_variable_sym)),
target, nullptr);
if (current_body != nullptr) {
current_body->push_back(al, ASRUtils::STMT(assign));
} else {
global_init.push_back(al, assign);
}
loop_src_var_name = tmp_assign_name;
loop_end = for_iterable_helper(loop_src_var_name, x.base.base.loc, explicit_iter_name);
for_iter_type = loop_end;
LCOMPILERS_ASSERT(loop_end);
Expand Down Expand Up @@ -6007,7 +6011,6 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {

if (for_each) {
current_scope = parent_scope;
ASR::expr_t* loop_src_var = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, current_scope->resolve_symbol(loop_src_var_name)));
tmp = ASR::make_ForEach_t(al, x.base.base.loc, target, loop_src_var, body.p, body.size());
for_each = false;
return;
Expand Down

0 comments on commit eb59248

Please sign in to comment.