diff --git a/src/ParticleActions.cxx b/src/ParticleActions.cxx index 83cf8b6..0894e0d 100644 --- a/src/ParticleActions.cxx +++ b/src/ParticleActions.cxx @@ -27,11 +27,27 @@ float FGridEvalPoly(float r2) namespace HACCabana { -ParticleActions::ParticleActions() {}; +ParticleActions::ParticleActions() +{ + aosoa_device = aosoa_device_type("aosoa_device", 0); +}; -ParticleActions::ParticleActions(Particles *P_) : P(P_) +ParticleActions::ParticleActions(Particles *P_, const float cm_size, const float min_pos, const float max_pos ) : P(P_) { - ; + aosoa_device = aosoa_device_type("aosoa_device", P->num_p); + auto position = Cabana::slice(aosoa_device, "position"); + + // create the cell list on the GPU + // NOTE: fuzz particles (outside of overload) are not included + float dx = cm_size; + float x_min = min_pos; + float x_max = max_pos; + + float grid_delta[3] = {dx, dx, dx}; + float grid_min[3] = {x_min, x_min, x_min}; + float grid_max[3] = {x_max, x_max, x_max}; + + cell_list = neighbor_type(position, P->begin, P->end, grid_delta, grid_min, grid_max); }; ParticleActions::~ParticleActions() @@ -136,25 +152,14 @@ void ParticleActions::updateVel(\ Kokkos::fence(); } -void ParticleActions::subCycle(TimeStepper &ts, const int nsub, const float gpscal, const float rmax2, const float rsm2, - const float cm_size, const float min_pos, const float max_pos) +void ParticleActions::subCycle(TimeStepper &ts, const int nsub, const float gpscal, const float rmax2, const float rsm2) { // copy particles to GPU - Cabana::AoSoA aosoa_device("aosoa_device", P->num_p); + aosoa_device.resize(P->num_p); Cabana::deep_copy(aosoa_device, P->aosoa_host); - // create the cell list on the GPU - // NOTE: fuzz particles (outside of overload) are not included - float dx = cm_size; - float x_min = min_pos; - float x_max = max_pos; - - float grid_delta[3] = {dx, dx, dx}; - float grid_min[3] = {x_min, x_min, x_min}; - float grid_max[3] = {x_max, x_max, x_max}; - auto position = Cabana::slice(aosoa_device, "position"); - Cabana::LinkedCellList cell_list(position, P->begin, P->end, grid_delta, grid_min, grid_max); + cell_list.build(position); Cabana::permute(cell_list, aosoa_device); Kokkos::fence(); diff --git a/src/ParticleActions.h b/src/ParticleActions.h index 0b1d45a..c0fc430 100644 --- a/src/ParticleActions.h +++ b/src/ParticleActions.h @@ -18,19 +18,23 @@ namespace HACCabana { private: Particles *P; + aosoa_device_type aosoa_device; public: using device_exec = Kokkos::DefaultExecutionSpace::execution_space; using device_mem = Kokkos::DefaultExecutionSpace::memory_space; using device_type = Kokkos::Device; + using aosoa_device_type = Cabana::AoSoA; + using neighbor_type = Cabana::LinkedCellList; //using device_scratch = Kokkos::ScratchMemorySpace; + neighbor_type cell_list; + ParticleActions(); - ParticleActions(Particles *P_); + ParticleActions(Particles *P_, const float cm_size, const float min_pos, const float max_pos); ~ParticleActions(); void setParticles(Particles *P_); - void subCycle(TimeStepper &ts, const int nsub, const float gpscal, const float rmax2, const float rsm2,\ - const float cm_size, const float min_pos, const float max_pos); + void subCycle(TimeStepper &ts, const int nsub, const float gpscal, const float rmax2, const float rsm2); void updatePos(Cabana::AoSoA aosoa_device,\ float prefactor); void updateVel(Cabana::AoSoA aosoa_device,\ diff --git a/src/driver_gpu.cxx b/src/driver_gpu.cxx index b7edcae..fed5c91 100644 --- a/src/driver_gpu.cxx +++ b/src/driver_gpu.cxx @@ -152,8 +152,8 @@ int main( int argc, char* argv[] ) P.reorder(min_alive_pos, max_alive_pos); // TODO:assumes local extent equals the global extent cout << "\t" << P.end-P.begin << " particles in [" << min_alive_pos << "," << max_alive_pos << "]" << endl; - HACCabana::ParticleActions PA(&P); - PA.subCycle(ts, Params.nsub, Params.gpscal, Params.rmax*Params.rmax, Params.rsm*Params.rsm, Params.cm_size, Params.oL, Params.rL+Params.oL); + HACCabana::ParticleActions PA(&P, Params.cm_size, Params.oL, Params.rL+Params.oL); + PA.subCycle(ts, Params.nsub, Params.gpscal, Params.rmax*Params.rmax, Params.rsm*Params.rsm); // verify against the answer from the simulation // --------------------------------------------------------------------------------------------------------------------------