@@ -35,12 +35,11 @@ def _atpass_spawn(ring, seed, rank, rin, **kwargs):
35
35
return rin , result
36
36
37
37
38
- def _pass (ring , r_in , pool_size , start_method , ** kwargs ):
38
+ def _pass (ring , r_in , pool_size , start_method , seed , ** kwargs ):
39
39
ctx = multiprocessing .get_context (start_method )
40
40
# Split input in as many slices as processes
41
41
args = enumerate (numpy .array_split (r_in , pool_size , axis = 1 ))
42
42
# Generate a new starting point for C RNGs
43
- seed = random .common .integers (0 , high = _imax , dtype = int )
44
43
global _globring
45
44
_globring = ring
46
45
if ctx .get_start_method () == 'fork' :
@@ -63,34 +62,40 @@ def _element_pass(element: Element, r_in, **kwargs):
63
62
64
63
@fortran_align
65
64
def _lattice_pass (lattice : list [Element ], r_in , nturns : int = 1 ,
66
- refpts : Refpts = End , no_varelem = True , ** kwargs ):
65
+ refpts : Refpts = End , no_varelem = True , seed : int | None = None , ** kwargs ):
67
66
kwargs ['reuse' ] = kwargs .pop ('keep_lattice' , False )
68
67
if no_varelem :
69
68
lattice = disable_varelem (lattice )
70
69
else :
71
70
if sum (variable_refs (lattice )) > 0 :
72
71
kwargs ['reuse' ] = False
73
72
refs = get_uint32_index (lattice , refpts )
73
+ if seed is not None :
74
+ reset_rng (seed = seed )
74
75
return _atpass (lattice , r_in , nturns , refpts = refs , ** kwargs )
75
76
76
77
77
78
@fortran_align
78
79
def _plattice_pass (lattice : list [Element ], r_in , nturns : int = 1 ,
79
- refpts : Refpts = End , pool_size : int = None ,
80
+ refpts : Refpts = End , seed : int | None = None , pool_size : int = None ,
80
81
start_method : str = None , ** kwargs ):
81
82
refpts = get_uint32_index (lattice , refpts )
82
83
any_collective = has_collective (lattice )
83
84
kwargs ['reuse' ] = kwargs .pop ('keep_lattice' , False )
84
85
rshape = r_in .shape
85
86
if len (rshape ) >= 2 and rshape [1 ] > 1 and not any_collective :
87
+ if seed is None :
88
+ seed = random .common .integers (0 , high = _imax , dtype = int )
86
89
if pool_size is None :
87
90
pool_size = min (len (r_in [0 ]), multiprocessing .cpu_count (),
88
91
DConstant .patpass_poolsize )
89
92
if start_method is None :
90
93
start_method = DConstant .patpass_startmethod
91
- return _pass (lattice , r_in , pool_size , start_method , nturns = nturns ,
94
+ return _pass (lattice , r_in , pool_size , start_method , seed = seed , nturns = nturns ,
92
95
refpts = refpts , ** kwargs )
93
96
else :
97
+ if seed is not None :
98
+ reset_rng (seed = seed )
94
99
if any_collective :
95
100
warn (AtWarning ('Collective PassMethod found: use single process' ))
96
101
else :
@@ -125,6 +130,8 @@ def lattice_track(lattice: Iterable[Element], r_in,
125
130
in_place (bool): If True *r_in* is modified in-place and
126
131
reports the coordinates at the end of the element.
127
132
(default: False)
133
+ seed (int | None): Seed for the random generators. If None (default)
134
+ continue the sequence
128
135
keep_lattice (bool): Use elements persisted from a previous
129
136
call. If :py:obj:`True`, assume that the lattice has not changed
130
137
since the previous call.
0 commit comments