Skip to content

Commit

Permalink
MAINT: flatten _fft/_ifft for clarity
Browse files Browse the repository at this point in the history
  • Loading branch information
ken-lauer committed Dec 17, 2024
1 parent 1d5a7fc commit 81c49a3
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 71 deletions.
61 changes: 61 additions & 0 deletions pmd_beamphysics/tools.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import datetime
import logging
from typing import Sequence

import h5py
import numpy as np
import scipy.fft

logger = logging.getLogger(__name__)
_global_fft_workers = -1
Expand Down Expand Up @@ -211,3 +213,62 @@ def set_num_fft_workers(workers: int):
_global_fft_workers = workers

logger.info(f"Set number of FFT workers to: {workers}")


def fft_phased(
array: np.ndarray,
axes: Sequence[int],
phasors,
workers=-1,
) -> np.ndarray:
"""
Compute the N-D discrete Fourier Transform with phasors applied.
Ortho normalization is used.
Parameters
----------
array : np.ndarray
Input array which can be complex.
axes : sequence of int
Axis indices to apply the FFT to.
phasors : np.ndarray
Apply these per-dimension phasors after performing the FFT.
workers : int, default=-1
Maximum number of workers to use for parallel computation. If negative,
the value wraps around from ``os.cpu_count()``.
"""
array_fft = scipy.fft.fftn(array, axes=axes, workers=workers, norm="ortho")
for phasor in phasors:
array_fft *= phasor
return scipy.fft.fftshift(array_fft, axes=axes)


def ifft_phased(
array: np.ndarray,
axes: Sequence[int],
phasors,
workers=-1,
) -> np.ndarray:
"""
Compute the N-D inverse discrete Fourier Transform with phasors applied.
Ortho normalization is used.
Parameters
----------
array : np.ndarray
Input array which can be complex.
axes : Sequence[int]
Axis indices to apply the FFT to.
phasors : np.ndarray
Apply the complex conjugate of these per-dimension phasors after the
inverse FFT.
workers : int, default=-1
Maximum number of workers to use for parallel computation. If negative,
the value wraps around from ``os.cpu_count()``.
"""
array_fft = scipy.fft.ifftn(array, axes=axes, workers=workers, norm="ortho")
for phasor in phasors:
array_fft *= np.conj(phasor)
return scipy.fft.ifftshift(array_fft, axes=axes)
133 changes: 62 additions & 71 deletions pmd_beamphysics/wavefront.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,65 +146,6 @@ def from_domain(cls, domain: tuple[np.ndarray, ...]) -> _NiceXYZ:
)


def fft_phased(
array: np.ndarray,
axes: Sequence[int],
phasors,
workers=-1,
) -> np.ndarray:
"""
Compute the N-D discrete Fourier Transform with phasors applied.
Ortho normalization is used.
Parameters
----------
array : np.ndarray
Input array which can be complex.
axes : sequence of int
Axis indices to apply the FFT to.
phasors : np.ndarray
Apply these per-dimension phasors after performing the FFT.
workers : int, default=-1
Maximum number of workers to use for parallel computation. If negative,
the value wraps around from ``os.cpu_count()``.
"""
array_fft = scipy.fft.fftn(array, axes=axes, workers=workers, norm="ortho")
for phasor in phasors:
array_fft *= phasor
return scipy.fft.fftshift(array_fft, axes=axes)


def ifft_phased(
array: np.ndarray,
axes: Sequence[int],
phasors,
workers=-1,
) -> np.ndarray:
"""
Compute the N-D inverse discrete Fourier Transform with phasors applied.
Ortho normalization is used.
Parameters
----------
array : np.ndarray
Input array which can be complex.
axes : Sequence[int]
Axis indices to apply the FFT to.
phasors : np.ndarray
Apply the complex conjugate of these per-dimension phasors after the
inverse FFT.
workers : int, default=-1
Maximum number of workers to use for parallel computation. If negative,
the value wraps around from ``os.cpu_count()``.
"""
array_fft = scipy.fft.ifftn(array, axes=axes, workers=workers, norm="ortho")
for phasor in phasors:
array_fft *= np.conj(phasor)
return scipy.fft.ifftshift(array_fft, axes=axes)


def nd_kspace_domains(
coeffs: Sequence[float],
sizes: Sequence[int],
Expand Down Expand Up @@ -1024,7 +965,21 @@ def with_rmesh(
pad: Sequence[int] | int | None = None,
fix_pad: bool = True,
) -> Wavefront:
"""Create a new Wavefront instance, replacing the `rmesh`."""
"""
Create a new Wavefront instance, replacing the `rmesh`.
Parameters
----------
rmesh : np.ndarray
pad : int or sequence of int, optional
New padding settings.
fix : bool, optional
Fix padding for efficiency.
Returns
-------
Wavefront
"""
if pad is None:
pad = self.pad
return Wavefront(
Expand All @@ -1040,6 +995,20 @@ def with_padding(
pad: int | Sequence[int],
fix: bool = True,
) -> Wavefront:
"""
Get a new Wavefront instance with adjusted padding settings.
Parameters
----------
pad : int or sequence of int
The new padding settings.
fix : bool, optional
Fix padding for efficiency.
Returns
-------
Wavefront
"""
ndim = len(self._grid)
if isinstance(pad, int):
pad = (pad,) * ndim
Expand All @@ -1050,14 +1019,12 @@ def with_padding(
f"Got {len(pad)} but expected {ndim}"
)

if fix:
pad = fix_padding(self._grid, pad=pad)

return Wavefront(
rmesh=self.rmesh,
wavelength=self.wavelength,
metadata=self.metadata,
pad=pad,
fix_pad=fix,
)

def with_padding_divergence(
Expand All @@ -1067,6 +1034,28 @@ def with_padding_divergence(
beam_size: float = 1e-4,
fix: bool = True,
) -> Wavefront:
"""
Using divergence settings, create a new Wavefront with padding adjusted by the factor:
2.0 * (theta_max * drift_distance) / beam_size
Parameters
----------
theta_max : float, optional
Maximum divergence angle in radians. Default is 5e-5.
drift_distance : float, optional
Maximum distance over which the beam will be drifted in meters.
Default is 1.0.
beam_size : float, optional
Size of the beam in meters. Default is 1e-4.
fix : bool, optional
Fix padding for efficiency.
Returns
-------
Wavefront
The new wavefront instance with the padding settings.
"""
pad_factor = transverse_divergence_padding_factor(
theta_max=theta_max,
drift_distance=drift_distance,
Expand Down Expand Up @@ -1200,12 +1189,11 @@ def _fft(self):
rmesh_grid=self._grid,
pad=self._padding,
)
return fft_phased(
dfl_pad,
axes=(0, 1, 2),
phasors=phasors,
workers=workers,
)
axes = (0, 1, 2)
array_fft = scipy.fft.fftn(dfl_pad, axes=axes, workers=workers, norm="ortho")
for phasor in phasors:
array_fft *= phasor
return scipy.fft.fftshift(array_fft, axes=axes)

def _ifft(self):
"""
Expand All @@ -1226,12 +1214,15 @@ def _ifft(self):
pad=self._padding,
)

full_ifft = ifft_phased(
array_fft = scipy.fft.ifftn(
self._kmesh,
axes=(0, 1, 2),
phasors=phasors,
workers=workers,
norm="ortho",
)
for phasor in phasors:
array_fft *= np.conj(phasor)
full_ifft = scipy.fft.ifftshift(array_fft, axes=(0, 1, 2))

# Remove padding from the inverse fft result:
ifft_slices = tuple(slice(pad, -pad) for pad in self._padding)
Expand Down

0 comments on commit 81c49a3

Please sign in to comment.