Skip to content

Commit

Permalink
Overload the intersection and union operators
Browse files Browse the repository at this point in the history
  • Loading branch information
RemiLehe committed Oct 22, 2024
1 parent 4aac40b commit eee9ad6
Showing 1 changed file with 35 additions and 3 deletions.
38 changes: 35 additions & 3 deletions openpmd_viewer/openpmd_timeseries/particle_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class ParticleTracker( object ):
to be stored in the openPMD files.
"""

def __init__(self, ts, species=None, t=None,
def __init__(self, ts=None, species=None, t=None,
iteration=None, select=None, preserve_particle_index=False):
"""
Initialize an instance of `ParticleTracker`: select particles at
Expand Down Expand Up @@ -69,8 +69,9 @@ def __init__(self, ts, species=None, t=None,
'x' : [-4., 10.] (Particles having x between -4 and 10)
'ux' : [-0.1, 0.1] (Particles having ux between -0.1 and 0.1 mc)
'uz' : [5., None] (Particles with uz above 5 mc).
Can also be a 1d array of interegers corresponding to the
selected particles `id`
Can also be a 1d array of integers corresponding to the
selected particles `id`. In this case, the arguments `ts`, `t`
and `iteration` do not need to be passed.
preserve_particle_index: bool, optional
When retrieving particles at a several iterations,
Expand Down Expand Up @@ -105,6 +106,37 @@ def __init__(self, ts, species=None, t=None,
self.species = species
self.preserve_particle_index = preserve_particle_index

def __and__(self, other):
"""
Define the intersection of two ParticleTracker instances.
This selects the particles that are present in both instances.
"""
# Check that both instances are consistent
assert self.species == other.species
assert self.preserve_particle_index == other.preserve_particle_index

# Find the intersection of the selected particles
pid = np.intersect1d( self.selected_pid, other.selected_pid )
pt = ParticleTracker( species=self.species, select=pid,
preserve_particle_index=self.preserve_particle_index )
return pt

def __or__(self, other):
"""
Define the union of two ParticleTracker instances.
This selects the particles that are present in at least one of the instances.
"""
# Check that both instances are consistent
assert self.species == other.species
assert self.preserve_particle_index == other.preserve_particle_index

# Find the union of the selected particles
pid = np.union1d( self.selected_pid, other.selected_pid )
pt = ParticleTracker( species=self.species, select=pid,
preserve_particle_index=self.preserve_particle_index )
return pt

def extract_tracked_particles( self, iteration, data_reader, data_list,
species, extensions ):
Expand Down

0 comments on commit eee9ad6

Please sign in to comment.