Skip to content

Commit

Permalink
Code clean-up for binary collisions (#4921)
Browse files Browse the repository at this point in the history
* code clean-up for binary collisions

* fix variable shadowing

* pass particle tiles by reference
  • Loading branch information
roelof-groenewald authored May 13, 2024
1 parent ee51424 commit c70a6c5
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 76 deletions.
6 changes: 2 additions & 4 deletions Source/Particles/Collision/BinaryCollision/BinaryCollision.H
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ class BinaryCollision final
using ParticleTileType = WarpXParticleContainer::ParticleTileType;
using ParticleTileDataType = ParticleTileType::ParticleTileDataType;
using ParticleBins = amrex::DenseBins<ParticleTileDataType>;
using SoaData_type = WarpXParticleContainer::ParticleTileType::ParticleTileDataType;
using index_type = ParticleBins::index_type;

public:
Expand Down Expand Up @@ -298,7 +297,6 @@ public:
auto dV = geom.CellSize(0) * geom.CellSize(1) * geom.CellSize(2);
#endif


/*
The following calculations are only required when creating product particles
*/
Expand Down Expand Up @@ -431,7 +429,7 @@ public:
// Create the new product particles and define their initial values
// num_added: how many particles of each product species have been created
const amrex::Vector<int> num_added = m_copy_transform_functor(n_total_pairs,
soa_1, soa_1,
ptile_1, ptile_1,
product_species_vector,
tile_products_data,
m1, m1,
Expand Down Expand Up @@ -648,7 +646,7 @@ public:
// Create the new product particles and define their initial values
// num_added: how many particles of each product species have been created
const amrex::Vector<int> num_added = m_copy_transform_functor(n_total_pairs,
soa_1, soa_2,
ptile_1, ptile_2,
product_species_vector,
tile_products_data,
m1, m2,
Expand Down
28 changes: 28 additions & 0 deletions Source/Particles/Collision/BinaryCollision/BinaryCollisionUtils.H
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,34 @@ namespace BinaryCollisionUtils{
// is calculated here.
lab_to_COM_lorentz_factor = g1_star*g2_star/static_cast<amrex::ParticleReal>(g1*g2);
}

/**
* \brief Subtract given weight from particle and set its ID to invalid
* if the weight reaches zero.
*/
AMREX_GPU_HOST_DEVICE AMREX_INLINE
void remove_weight_from_colliding_particle (
amrex::ParticleReal& weight, uint64_t& idcpu,
const amrex::ParticleReal reaction_weight )
{
// Remove weight from given particle
amrex::Gpu::Atomic::AddNoRet(&weight, -reaction_weight);

// If the colliding particle weight decreases to zero, remove particle by
// setting its id to invalid
if (weight <= std::numeric_limits<amrex::ParticleReal>::min())
{
#if defined(AMREX_USE_OMP)
#pragma omp atomic write
idcpu = amrex::ParticleIdCpus::Invalid;
#else
amrex::Gpu::Atomic::Exch(
(unsigned long long *)&idcpu,
(unsigned long long)amrex::ParticleIdCpus::Invalid
);
#endif
}
}
}

#endif // WARPX_BINARY_COLLISION_UTILS_H_
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@ class SplitAndScatterFunc
using ParticleTileDataType = typename ParticleTileType::ParticleTileDataType;
using ParticleBins = amrex::DenseBins<ParticleTileDataType>;
using index_type = typename ParticleBins::index_type;
using SoaData_type = typename WarpXParticleContainer::ParticleTileType::ParticleTileDataType;

using SoaData_type = WarpXParticleContainer::ParticleTileType::ParticleTileDataType;

public:
/**
Expand All @@ -53,8 +52,7 @@ public:
AMREX_INLINE
amrex::Vector<int> operator() (
const index_type& n_total_pairs,
// Tile& ptile1, Tile& ptile2,
const SoaData_type& /*soa_1*/, const SoaData_type& /*soa_2*/,
ParticleTileType& ptile1, ParticleTileType& ptile2,
const amrex::Vector<WarpXParticleContainer*>& pc_products,
ParticleTileType** AMREX_RESTRICT tile_products,
const amrex::ParticleReal m1, const amrex::ParticleReal m2,
Expand Down Expand Up @@ -93,9 +91,8 @@ public:
tile_products[i]->resize(products_np[i] + num_added);
}

// this works for DSMC since the colliding particles are also products
const auto soa_1 = tile_products[0]->getParticleTileData();
const auto soa_2 = tile_products[1]->getParticleTileData();
const auto soa_1 = ptile1.getParticleTileData();
const auto soa_2 = ptile2.getParticleTileData();

amrex::ParticleReal* AMREX_RESTRICT w1 = soa_1.m_rdata[PIdx::w];
amrex::ParticleReal* AMREX_RESTRICT w2 = soa_2.m_rdata[PIdx::w];
Expand Down Expand Up @@ -155,36 +152,10 @@ public:
soa_products_data[1].m_rdata[PIdx::w][product2_index] = p_pair_reaction_weight[i];

// Remove p_pair_reaction_weight[i] from the colliding particles' weights
amrex::Gpu::Atomic::AddNoRet(&w1[p_pair_indices_1[i]],
-p_pair_reaction_weight[i]);
amrex::Gpu::Atomic::AddNoRet(&w2[p_pair_indices_2[i]],
-p_pair_reaction_weight[i]);

// Note: Particle::atomicSetID should also be provided as a standalone helper function in AMReX
// to replace the following lambda.
auto const atomicSetIdInvalid = [] AMREX_GPU_DEVICE (uint64_t & idcpu)
{
#if defined(AMREX_USE_OMP)
#pragma omp atomic write
idcpu = amrex::ParticleIdCpus::Invalid;
#else
amrex::Gpu::Atomic::Exch(
(unsigned long long *)&idcpu,
(unsigned long long)amrex::ParticleIdCpus::Invalid
);
#endif
};

// If the colliding particle weight decreases to zero, remove particle by
// setting its id to invalid
if (w1[p_pair_indices_1[i]] <= std::numeric_limits<amrex::ParticleReal>::min())
{
atomicSetIdInvalid(idcpu1[p_pair_indices_1[i]]);
}
if (w2[p_pair_indices_2[i]] <= std::numeric_limits<amrex::ParticleReal>::min())
{
atomicSetIdInvalid(idcpu2[p_pair_indices_2[i]]);
}
BinaryCollisionUtils::remove_weight_from_colliding_particle(
w1[p_pair_indices_1[i]], idcpu1[p_pair_indices_1[i]], p_pair_reaction_weight[i]);
BinaryCollisionUtils::remove_weight_from_colliding_particle(
w2[p_pair_indices_2[i]], idcpu2[p_pair_indices_2[i]], p_pair_reaction_weight[i]);

// Set the child particle properties appropriately
auto& ux1 = soa_products_data[0].m_rdata[PIdx::ux][product1_index];
Expand Down
45 changes: 10 additions & 35 deletions Source/Particles/Collision/BinaryCollision/ParticleCreationFunc.H
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,7 @@ public:
* function specific to the considered binary collision.
*
* @param[in] n_total_pairs how many binary collisions have been performed in this tile
* @param[in, out] soa_1 struct of array data of the first colliding particle species
* @param[in, out] soa_2 struct of array data of the second colliding particle species
* @param[in, out] ptile1,ptile2 the particle tiles of the two colliding species
* @param[out] tile_products array containing tile data of the product particles.
* @param[in] m1 mass of the first colliding particle species
* @param[in] m2 mass of the second colliding particle species
Expand All @@ -94,7 +93,7 @@ public:
AMREX_INLINE
amrex::Vector<int> operator() (
const index_type& n_total_pairs,
const SoaData_type& soa_1, const SoaData_type& soa_2,
ParticleTileType& ptile1, ParticleTileType& ptile2,
const amrex::Vector<WarpXParticleContainer*>& pc_products,
ParticleTileType** AMREX_RESTRICT tile_products,
const amrex::ParticleReal& m1, const amrex::ParticleReal& m2,
Expand Down Expand Up @@ -129,6 +128,9 @@ public:
tile_products[i]->resize(products_np[i] + num_added);
}

const auto soa_1 = ptile1.getParticleTileData();
const auto soa_2 = ptile2.getParticleTileData();

amrex::ParticleReal* AMREX_RESTRICT w1 = soa_1.m_rdata[PIdx::w];
amrex::ParticleReal* AMREX_RESTRICT w2 = soa_2.m_rdata[PIdx::w];
uint64_t* AMREX_RESTRICT idcpu1 = soa_1.m_idcpu;
Expand Down Expand Up @@ -196,37 +198,10 @@ public:
}

// Remove p_pair_reaction_weight[i] from the colliding particles' weights
amrex::Gpu::Atomic::AddNoRet(&w1[p_pair_indices_1[i]],
-p_pair_reaction_weight[i]);
amrex::Gpu::Atomic::AddNoRet(&w2[p_pair_indices_2[i]],
-p_pair_reaction_weight[i]);

// Note: Particle::atomicSetID should also be provided as a standalone helper function in AMReX
// to replace the following lambda.
auto const atomicSetIdInvalid = [] AMREX_GPU_DEVICE (uint64_t & idcpu)
{
#if defined(AMREX_USE_OMP)
#pragma omp atomic write
idcpu = amrex::ParticleIdCpus::Invalid;
#else
amrex::Gpu::Atomic::Exch(
(unsigned long long *)&idcpu,
(unsigned long long)amrex::ParticleIdCpus::Invalid
);
#endif
};

// If the colliding particle weight decreases to zero, remove particle by
// setting its id to invalid
if (w1[p_pair_indices_1[i]] <= std::numeric_limits<amrex::ParticleReal>::min())
{
atomicSetIdInvalid(idcpu1[p_pair_indices_1[i]]);

}
if (w2[p_pair_indices_2[i]] <= std::numeric_limits<amrex::ParticleReal>::min())
{
atomicSetIdInvalid(idcpu2[p_pair_indices_2[i]]);
}
BinaryCollisionUtils::remove_weight_from_colliding_particle(
w1[p_pair_indices_1[i]], idcpu1[p_pair_indices_1[i]], p_pair_reaction_weight[i]);
BinaryCollisionUtils::remove_weight_from_colliding_particle(
w2[p_pair_indices_2[i]], idcpu2[p_pair_indices_2[i]], p_pair_reaction_weight[i]);

// Initialize the product particles' momentum, using a function depending on the
// specific collision type
Expand Down Expand Up @@ -323,7 +298,7 @@ public:
AMREX_INLINE
amrex::Vector<int> operator() (
const index_type& /*n_total_pairs*/,
const SoaData_type& /*soa_1*/, const SoaData_type& /*soa_2*/,
ParticleTileType& /*ptile1*/, ParticleTileType& /*ptile2*/,
amrex::Vector<WarpXParticleContainer*>& /*pc_products*/,
ParticleTileType** /*tile_products*/,
const amrex::ParticleReal& /*m1*/, const amrex::ParticleReal& /*m2*/,
Expand Down

0 comments on commit c70a6c5

Please sign in to comment.