Skip to content

Commit

Permalink
Remove internal ml argument and move GPU code to Acc visitor (#919)
Browse files Browse the repository at this point in the history
  • Loading branch information
olupton authored Aug 31, 2022
1 parent aa5046e commit 60249f1
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 56 deletions.
42 changes: 36 additions & 6 deletions src/codegen/codegen_acc_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -265,11 +265,26 @@ void CodegenAccVisitor::print_newtonspace_transfer_to_device() const {
}


void CodegenAccVisitor::print_instance_variable_transfer_to_device(
std::vector<std::string> const& ptr_members) const {
void CodegenAccVisitor::print_instance_struct_transfer_routine_declarations() {
if (info.artificial_cell) {
return;
}
printer->fmt_line(
"static inline void copy_instance_to_device(NrnThread* nt, Memb_list* ml, {} const* inst);",
instance_struct());
printer->fmt_line("static inline void delete_instance_from_device({}* inst);",
instance_struct());
}


void CodegenAccVisitor::print_instance_struct_transfer_routines(
std::vector<std::string> const& ptr_members) {
if (info.artificial_cell) {
return;
}
printer->fmt_start_block(
"static inline void copy_instance_to_device(NrnThread* nt, Memb_list* ml, {} const* inst)",
instance_struct());
printer->start_block("if (!nt->compute_gpu)");
printer->add_line("return;");
printer->end_block(1);
Expand All @@ -285,18 +300,33 @@ void CodegenAccVisitor::print_instance_variable_transfer_to_device(
printer->add_line("auto* d_ml = cnrn_target_deviceptr(ml);");
printer->add_line("void* d_inst_void = d_inst;");
printer->add_line("cnrn_target_memcpy_to_device(&(d_ml->instance), &d_inst_void);");
printer->end_block(2); // copy_instance_to_device

printer->fmt_start_block("static inline void delete_instance_from_device({}* inst)",
instance_struct());
printer->start_block("if (cnrn_target_is_present(inst))");
printer->add_line("cnrn_target_delete(inst);");
printer->end_block(1);
printer->end_block(2); // delete_instance_from_device
}


void CodegenAccVisitor::print_instance_variable_deletion_from_device() const {
void CodegenAccVisitor::print_instance_struct_copy_to_device() {
if (info.artificial_cell) {
return;
}
printer->start_block("if (cnrn_target_is_present(&inst))");
printer->add_line("cnrn_target_delete(&inst);");
printer->end_block(1);
printer->add_line("copy_instance_to_device(nt, ml, inst);");
}


void CodegenAccVisitor::print_instance_struct_delete_from_device() {
if (info.artificial_cell) {
return;
}
printer->add_line("delete_instance_from_device(inst);");
}


void CodegenAccVisitor::print_deriv_advance_flag_transfer_to_device() const {
printer->add_line("nrn_pragma_acc(update device (deriv_advance_flag) if(nt->compute_gpu))");
printer->add_line("nrn_pragma_omp(target update to(deriv_advance_flag) if(nt->compute_gpu))");
Expand Down
15 changes: 10 additions & 5 deletions src/codegen/codegen_acc_visitor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,17 @@ class CodegenAccVisitor: public CodegenCVisitor {
/// transfer newtonspace structure to device
void print_newtonspace_transfer_to_device() const override;

/// copy the instance struct to the device
void print_instance_variable_transfer_to_device(
std::vector<std::string> const& ptr_members) const override;
/// declare helper functions for copying the instance struct to the device
void print_instance_struct_transfer_routine_declarations() override;

/// delete the instance struct from the device
void print_instance_variable_deletion_from_device() const override;
/// define helper functions for copying the instance struct to the device
void print_instance_struct_transfer_routines(std::vector<std::string> const&) override;

/// call helper function for copying the instance struct to the device
void print_instance_struct_copy_to_device() override;

/// call helper function that deletes the instance struct from the device
void print_instance_struct_delete_from_device() override;

// update derivimplicit advance flag on the gpu device
void print_deriv_advance_flag_transfer_to_device() const override;
Expand Down
45 changes: 10 additions & 35 deletions src/codegen/codegen_c_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1041,17 +1041,6 @@ void CodegenCVisitor::print_channel_iteration_tiling_block_end() {
}


void CodegenCVisitor::print_instance_variable_transfer_to_device(
std::vector<std::string> const& ptr_members) const {
// backend specific, do nothing
}


void CodegenCVisitor::print_instance_variable_deletion_from_device() const {
// backend specific, do nothing
}


void CodegenCVisitor::print_deriv_advance_flag_transfer_to_device() const {
// backend specific, do nothing
}
Expand Down Expand Up @@ -1892,9 +1881,9 @@ void CodegenCVisitor::print_eigen_linear_solver(const std::string& float_type, i

std::string CodegenCVisitor::internal_method_arguments() {
if (ion_variable_struct_required()) {
return "id, pnodecount, inst, ionvar, data, indexes, thread, nt, ml, v";
return "id, pnodecount, inst, ionvar, data, indexes, thread, nt, v";
}
return "id, pnodecount, inst, data, indexes, thread, nt, ml, v";
return "id, pnodecount, inst, data, indexes, thread, nt, v";
}


Expand Down Expand Up @@ -1926,7 +1915,6 @@ CodegenCVisitor::ParamVector CodegenCVisitor::internal_method_parameters() {
params.emplace_back("const ", "Datum*", "", "indexes");
params.emplace_back(param_type_qualifier(), "ThreadDatum*", "", "thread");
params.emplace_back(param_type_qualifier(), "NrnThread*", param_ptr_qualifier(), "nt");
params.emplace_back(param_type_qualifier(), "Memb_list*", param_ptr_qualifier(), "ml");
params.emplace_back("", "double", "", "v");
return params;
}
Expand Down Expand Up @@ -1961,9 +1949,9 @@ std::string CodegenCVisitor::nrn_thread_arguments() {
*/
std::string CodegenCVisitor::nrn_thread_internal_arguments() {
if (ion_variable_struct_required()) {
return "id, pnodecount, inst, ionvar, data, indexes, thread, nt, ml, v";
return "id, pnodecount, inst, ionvar, data, indexes, thread, nt, v";
}
return "id, pnodecount, inst, data, indexes, thread, nt, ml, v";
return "id, pnodecount, inst, data, indexes, thread, nt, v";
}


Expand Down Expand Up @@ -3200,18 +3188,15 @@ void CodegenCVisitor::print_instance_variable_setup() {
printer->fmt_line("assert(ml->global_variables_size == sizeof({}));", global_struct());
};

printer->fmt_line(
"static inline void copy_instance_to_device(NrnThread* nt, Memb_list* ml, {} const* inst);",
instance_struct());
printer->fmt_line("static inline void delete_instance_from_device({}& inst);",
instance_struct());
printer->add_newline();
// Must come before print_instance_struct_copy_to_device and
// print_instance_struct_delete_from_device
print_instance_struct_transfer_routine_declarations();

printer->add_line("// Deallocate the instance structure");
printer->fmt_start_block("static void {}(NrnThread* nt, Memb_list* ml, int type)",
method_name(naming::NRN_PRIVATE_DESTRUCTOR_METHOD));
cast_inst_and_assert_validity();
printer->add_line("delete_instance_from_device(*inst);");
print_instance_struct_delete_from_device();
printer->add_line("delete inst;");
printer->add_line("ml->instance = nullptr;");
printer->add_line("ml->global_variables = nullptr;");
Expand Down Expand Up @@ -3269,20 +3254,10 @@ void CodegenCVisitor::print_instance_variable_setup() {
printer->fmt_line("inst->{} = {};", name, variable);
ptr_members.push_back(std::move(name));
}
printer->add_line("copy_instance_to_device(nt, ml, inst);");
print_instance_struct_copy_to_device();
printer->end_block(2); // setup_instance

printer->add_line("// Set up the device-side copy of the instance structure");
printer->fmt_start_block(
"static inline void copy_instance_to_device(NrnThread* nt, Memb_list* ml, {} const* inst)",
instance_struct());
print_instance_variable_transfer_to_device(ptr_members);
printer->end_block(2); // copy_instance_to_device

printer->fmt_start_block("static inline void delete_instance_from_device({}& inst)",
instance_struct());
print_instance_variable_deletion_from_device();
printer->end_block(2); // delete_instance_from_device
print_instance_struct_transfer_routines(ptr_members);
}


Expand Down
34 changes: 27 additions & 7 deletions src/codegen/codegen_c_visitor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1080,19 +1080,39 @@ class CodegenCVisitor: public visitor::ConstAstVisitor {


/**
* Print the code to copy instance struct members to the device,
* substituting host pointers for device ones.
* Print declarations of the functions used by \ref
* print_instance_struct_copy_to_device and \ref
* print_instance_struct_delete_from_device.
*/
virtual void print_instance_struct_transfer_routine_declarations() {}

/**
* Print the definitions of the functions used by \ref
* print_instance_struct_copy_to_device and \ref
* print_instance_struct_delete_from_device. Declarations of these functions
* are printed by \ref print_instance_struct_transfer_routine_declarations.
*
* This updates the (pointer) member variables in the device copy of the
* instance struct to contain device pointers, which is why you must pass a
* list of names of those member variables.
*
* \param ptr_members Members to update.
* \param ptr_members List of instance struct member names.
*/
virtual void print_instance_variable_transfer_to_device(
std::vector<std::string> const& ptr_members) const;
virtual void print_instance_struct_transfer_routines(
std::vector<std::string> const& /* ptr_members */) {}


/**
* Print the code to delete the instance structure from the device.
* Transfer the instance struct to the device. This calls a function
* declared by \ref print_instance_struct_transfer_routine_declarations.
*/
virtual void print_instance_struct_copy_to_device() {}

/**
* Delete the instance struct from the device. This calls a function
* declared by \ref print_instance_struct_transfer_routine_declarations.
*/
virtual void print_instance_variable_deletion_from_device() const;
virtual void print_instance_struct_delete_from_device() {}


/**
Expand Down
3 changes: 0 additions & 3 deletions test/unit/codegen/codegen_c_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@ SCENARIO("Check instance variable definition order", "[codegen][var_order]") {
inst->ion_cao = nt->_data;
inst->ion_ica = nt->_data;
inst->ion_dicadv = nt->_data;
copy_instance_to_device(nt, ml, inst);
}
)";
auto const expected = reindent_text(generated_code);
Expand Down Expand Up @@ -151,7 +150,6 @@ SCENARIO("Check instance variable definition order", "[codegen][var_order]") {
inst->v_unused = ml->data+4*pnodecount;
inst->ion_cai = nt->_data;
inst->ion_cao = nt->_data;
copy_instance_to_device(nt, ml, inst);
}
)";

Expand Down Expand Up @@ -231,7 +229,6 @@ SCENARIO("Check instance variable definition order", "[codegen][var_order]") {
inst->ion_ilca = nt->_data;
inst->ion_elca = nt->_data;
inst->style_lca = ml->pdata;
copy_instance_to_device(nt, ml, inst);
}
)";

Expand Down

0 comments on commit 60249f1

Please sign in to comment.