Skip to content

Commit

Permalink
Merge branch 'xsuite:main' into qgaussian_derivative
Browse files Browse the repository at this point in the history
  • Loading branch information
ewaagaard authored Jul 15, 2024
2 parents 38aae6b + 25a44ec commit f760b68
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 17 deletions.
28 changes: 14 additions & 14 deletions tests/test_slicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,48 +298,48 @@ def test_slicer_moments_single_bunch(test_context):
atol=1e-12)
xo.assert_allclose(sl.num_particles[0, :], [0, 0, p.weight.sum()], rtol=0,
atol=1e-12)
xo.assert_allclose(sl.sum(c1_name), [0, 0, (c1 * p.weight).sum()],
xo.assert_allclose(sl.sum(c1_name)[0], [0, 0, (c1 * p.weight).sum()],
rtol=0, atol=1e-12)
xo.assert_allclose(sl.sum(c2_name), [0, 0, (c2 * p.weight).sum()],
xo.assert_allclose(sl.sum(c2_name)[0], [0, 0, (c2 * p.weight).sum()],
rtol=0, atol=1e-12)
xo.assert_allclose(sl.sum('zeta'),
xo.assert_allclose(sl.sum('zeta')[0],
[0, 0, (p.zeta * p.weight).sum()],
rtol=0, atol=1e-12)
xo.assert_allclose(sl.sum(c1_name + c1_name),
xo.assert_allclose(sl.sum(c1_name + c1_name)[0],
[0, 0, (c1 ** 2 * p.weight).sum()],
rtol=0, atol=1e-12)
xo.assert_allclose(sl.sum(c2_name + c2_name),
xo.assert_allclose(sl.sum(c2_name + c2_name)[0],
[0, 0, (c2 ** 2 * p.weight).sum()],
rtol=0, atol=1e-12)
xo.assert_allclose(sl.sum(c1_name + c2_name),
xo.assert_allclose(sl.sum(c1_name + c2_name)[0],
[0, 0, (c1 * c2 * p.weight).sum()],
rtol=0, atol=1e-12)
xo.assert_allclose(sl.sum('zetazeta'),
xo.assert_allclose(sl.sum('zetazeta')[0],
[0, 0, (p.zeta ** 2 * p.weight).sum()],
rtol=0, atol=1e-12)
xo.assert_allclose(sl.sum(c1_name + 'zeta'),
xo.assert_allclose(sl.sum(c1_name + 'zeta')[0],
[0, 0, (c1 * p.zeta * p.weight).sum()],
rtol=0, atol=1e-12)
xo.assert_allclose(sl.mean(c1_name),
xo.assert_allclose(sl.mean(c1_name)[0],
[0, 0, (c1 * p.weight).sum() / p.weight.sum()],
rtol=0, atol=1e-12)
xo.assert_allclose(sl.mean(c2_name),
xo.assert_allclose(sl.mean(c2_name)[0],
[0, 0, (c2 * p.weight).sum() / p.weight.sum()],
rtol=0, atol=1e-12)
xo.assert_allclose(sl.mean(c1_name + c1_name),
xo.assert_allclose(sl.mean(c1_name + c1_name)[0],
[0, 0,
(c1 ** 2 * p.weight).sum() / p.weight.sum()],
rtol=0, atol=1e-12)
xo.assert_allclose(sl.mean(c2_name + c2_name),
xo.assert_allclose(sl.mean(c2_name + c2_name)[0],
[0, 0,
(c2 ** 2 * p.weight).sum() / p.weight.sum()],
rtol=0, atol=1e-12)
xo.assert_allclose(sl.mean(c1_name + c2_name),
xo.assert_allclose(sl.mean(c1_name + c2_name)[0],
[0, 0,
(c1 * c2 * p.weight).sum() / p.weight.sum()],
rtol=0, atol=1e-12)
xo.assert_allclose(
sl.mean(c1_name + 'zeta'),
sl.mean(c1_name + 'zeta')[0],
[0, 0, (c1 * p.zeta * p.weight).sum() / p.weight.sum()],
rtol=0, atol=1e-12)
xo.assert_allclose(
Expand Down
2 changes: 1 addition & 1 deletion xfields/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.20.0"
__version__ = "0.20.2"
12 changes: 10 additions & 2 deletions xfields/solvers/fftsolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,10 @@ def solve(self, rho):
_workspace_dev.T[:,:,:] *= (
self._gint_rep_transf_dev.T) # phi_rep_hat
except Exception: # pyopencl does not support array broadcasting (used in 2.5D)
# Check F-contiguity
assert _workspace_dev.strides[0] == 16
assert self._gint_rep_transf_dev.strides[0] == 16

for ii in range(self.nz):
_workspace_dev.T[ii,:,:] *= (
self._gint_rep_transf_dev.T[0, :, :]) # phi_rep_hat
Expand Down Expand Up @@ -192,7 +196,9 @@ def __init__(self, dx, dy, dz, nx, ny, nz, context=None, fftplan=None):
del(temp_dev)

# Transform the green function
gint_rep_transf = np.fft.fftn(gint_rep, axes=(0,1))
gint_rep_transf = np.zeros((2*nx, 2*ny),
dtype=np.complex128, order='F')
gint_rep_transf[:, :] = np.fft.fftn(gint_rep, axes=(0,1))

# Transfer to GPU (if needed)
gint_rep_transf_dev = context.nparray_to_context_array(
Expand Down Expand Up @@ -247,7 +253,9 @@ def __init__(self, dx, dy, dz, nx, ny, nz, context=None, fftplan=None):
del(temp_dev)

# Transform the green function
gint_rep_transf = np.fft.fftn(gint_rep, axes=(0,1))
gint_rep_transf = np.zeros((2*nx, 2*ny),
dtype=np.complex128, order='F')
gint_rep_transf[:, :] = np.fft.fftn(gint_rep, axes=(0,1))

# Transfer to GPU (if needed)
gint_rep_transf_dev = context.nparray_to_context_array(
Expand Down

0 comments on commit f760b68

Please sign in to comment.