Skip to content

Commit 18a7c17

Browse files
committed
new "seed" keyword argument in Lattice.track
1 parent 836663b commit 18a7c17

File tree

1 file changed

+12
-5
lines changed

1 file changed

+12
-5
lines changed

pyat/at/tracking/track.py

+12-5
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,11 @@ def _atpass_spawn(ring, seed, rank, rin, **kwargs):
3535
return rin, result
3636

3737

38-
def _pass(ring, r_in, pool_size, start_method, **kwargs):
38+
def _pass(ring, r_in, pool_size, start_method, seed, **kwargs):
3939
ctx = multiprocessing.get_context(start_method)
4040
# Split input in as many slices as processes
4141
args = enumerate(numpy.array_split(r_in, pool_size, axis=1))
4242
# Generate a new starting point for C RNGs
43-
seed = random.common.integers(0, high=_imax, dtype=int)
4443
global _globring
4544
_globring = ring
4645
if ctx.get_start_method() == 'fork':
@@ -63,34 +62,40 @@ def _element_pass(element: Element, r_in, **kwargs):
6362

6463
@fortran_align
6564
def _lattice_pass(lattice: list[Element], r_in, nturns: int = 1,
66-
refpts: Refpts = End, no_varelem=True, **kwargs):
65+
refpts: Refpts = End, no_varelem=True, seed: int | None = None, **kwargs):
6766
kwargs['reuse'] = kwargs.pop('keep_lattice', False)
6867
if no_varelem:
6968
lattice = disable_varelem(lattice)
7069
else:
7170
if sum(variable_refs(lattice)) > 0:
7271
kwargs['reuse'] = False
7372
refs = get_uint32_index(lattice, refpts)
73+
if seed is not None:
74+
reset_rng(seed=seed)
7475
return _atpass(lattice, r_in, nturns, refpts=refs, **kwargs)
7576

7677

7778
@fortran_align
7879
def _plattice_pass(lattice: list[Element], r_in, nturns: int = 1,
79-
refpts: Refpts = End, pool_size: int = None,
80+
refpts: Refpts = End, seed: int | None = None, pool_size: int = None,
8081
start_method: str = None, **kwargs):
8182
refpts = get_uint32_index(lattice, refpts)
8283
any_collective = has_collective(lattice)
8384
kwargs['reuse'] = kwargs.pop('keep_lattice', False)
8485
rshape = r_in.shape
8586
if len(rshape) >= 2 and rshape[1] > 1 and not any_collective:
87+
if seed is None:
88+
seed = random.common.integers(0, high=_imax, dtype=int)
8689
if pool_size is None:
8790
pool_size = min(len(r_in[0]), multiprocessing.cpu_count(),
8891
DConstant.patpass_poolsize)
8992
if start_method is None:
9093
start_method = DConstant.patpass_startmethod
91-
return _pass(lattice, r_in, pool_size, start_method, nturns=nturns,
94+
return _pass(lattice, r_in, pool_size, start_method, seed=seed, nturns=nturns,
9295
refpts=refpts, **kwargs)
9396
else:
97+
if seed is not None:
98+
reset_rng(seed=seed)
9499
if any_collective:
95100
warn(AtWarning('Collective PassMethod found: use single process'))
96101
else:
@@ -125,6 +130,8 @@ def lattice_track(lattice: Iterable[Element], r_in,
125130
in_place (bool): If True *r_in* is modified in-place and
126131
reports the coordinates at the end of the element.
127132
(default: False)
133+
seed (int | None): Seed for the random generators. If None (default)
134+
continue the sequence
128135
keep_lattice (bool): Use elements persisted from a previous
129136
call. If :py:obj:`True`, assume that the lattice has not changed
130137
since the previous call.

0 commit comments

Comments
 (0)