diff --git a/src/tike/opt.py b/src/tike/opt.py index 7bcea03c..348bf173 100644 --- a/src/tike/opt.py +++ b/src/tike/opt.py @@ -12,8 +12,9 @@ import numpy as np +from tike.random import randomizer + logger = logging.getLogger(__name__) -randomizer = np.random.default_rng() def batch_indicies(n, m=1, use_random=True): diff --git a/src/tike/pca.py b/src/tike/pca.py new file mode 100644 index 00000000..6cf4819b --- /dev/null +++ b/src/tike/pca.py @@ -0,0 +1,163 @@ +import unittest +from unittest.case import skip + +import numpy as np + +from tike.linalg import hermitian as _hermitian +from tike.linalg import pca_eig + + +def pca_incremental(data, k, S=None, U=None): + """Principal component analysis with method IIIB from Arora et al (2012). + + This method iteratively updates a current guess for the principal + components of a population. + + Parameters + ---------- + data (..., N, D) + Array of N observations of a D dimensional space. + k : int + The desired number of principal components + + Returns + ------- + S (..., k) + The singular values corresponding to the current principal components + sorted largest to smallest. + U (..., D, k) + The current best principal components of the population. + + References + ---------- + Arora, R., Cotter, A., Livescu, K., & Srebro, N. (2012). Stochastic + optimization for PCA and PLS. 2012 50th Annual Allerton Conference on + Communication, Control, and Computing (Allerton), 863. + https://doi.org/10.1109/Allerton.2012.6483308 + + """ + ndim = data.shape[-1] + nsamples = data.shape[-2] + lead = data.shape[:-2] + if S is None or U is None: + # TODO: Better inital guess for one sample? + S = np.ones((*lead, k), dtype=data.dtype) + U = np.zeros((*lead, ndim, k), dtype=data.dtype) + U[..., list(range(k)), list(range(k))] = 1 + if S.shape != (*lead, k): + raise ValueError('S is the wrong shape', S.shape) + if U.shape != (*lead, ndim, k): + raise ValueError('U is the wrong shape', U.shape) + + for m in range(nsamples): + + x = data[..., m, :, None] + + # (k, d) x (d, 1) + xy = _hermitian(U) @ x + + # (d, 1) - (d, d) x (d, 1) + xp = x - U @ _hermitian(U) @ x + + # (..., 1, 1) + norm_xp = np.linalg.norm(xp, axis=-2, keepdims=True) + + # [ + # (k, k), (k, 1) + # (1, k), (1, 1) + # ] + Q = np.empty(shape=(*lead, k + 1, k + 1), dtype=x.dtype) + Q[..., :-1, :-1] = xy @ _hermitian(xy) + Q[..., -1:, -1:] = norm_xp * norm_xp + for i in range(k): + Q[..., i, i] = S[..., i] + Q[..., :-1, -1:] = norm_xp * xy + # Skip one assignment because matrix is conjugate symmetric + # Q[..., -1:, :-1] = norm_xp * _hermitian(xy) + S1, U1 = np.linalg.eigh(Q, UPLO='U') + + # [(d, k), (d, 1)] x (k + 1, k + 1) + Utilde = np.concatenate([U, xp / norm_xp], axis=-1) @ U1 + Stilde = S1 + + # Skip sorting because eigh() guarantees vectors already sorted + # order = np.argsort(Stilde, axis=-1) + # Stilde = np.take_along_axis(Stilde, order, axis=-1) + # Utilde = np.take_along_axis(Utilde, order[..., None, :], axis=-1) + S, U = Stilde[..., -1:-(k + 1):-1], Utilde[..., -1:-(k + 1):-1] + + return S, U + + +def pca_svd(data, k): + """Return k principal components via singular value decomposition. + + Parameters + ---------- + data (..., N, D) + Array of N observations of a D dimensional space. + + Returns + ------- + W (..., N, k) + The weights projecting the original observations onto k-fold subspace + C (..., k, D) + The k principal components sorted largest to smallest. + + """ + U, S, Vh = np.linalg.svd(data, full_matrices=False, compute_uv=True) + assert data.shape == ((U * S[..., None, :]) @ Vh).shape + # svd API states that values returned in descending order. i.e. + # the best vectors are first. + U = U[..., :k] + S = S[..., None, :k] + Vh = Vh[..., :k, :] + return U * S, Vh + + +class TestPrincipalComponentAnalysis(unittest.TestCase): + + def setUp(self, batch=2, sample=100, dimensions=4): + # generates some random uncentered data that is strongly biased towards + # having principal components. The first batch is flipped so the + # principal components start at the last dimension. + np.random.seed(0) + self.data = np.random.normal( + np.random.rand(dimensions), + 10 / (np.arange(dimensions) + 1), + size=[batch, sample, dimensions], + ) + self.data[0] = self.data[0, ..., ::-1] + + def print_metrics(self, W, C, k): + I = C @ _hermitian(C) + np.testing.assert_allclose( + I, + np.tile(np.eye(k), (C.shape[0], 1, 1)), + atol=1e-12, + ) + print( + 'reconstruction error: ', + np.linalg.norm(W @ C - self.data, axis=(1, 2)), + ) + @unittest.skip("Broken due to tsting API change.") + def test_numpy_eig(self, k=2): + S, U = pca_eig(self.data, k) + print('EIG COV principal components\n', U) + self.print_metrics(U, k) + + def test_numpy_svd(self, k=2): + W, C = pca_svd(self.data, k) + print('SVD principal components\n', C) + self.print_metrics(W, C, k) + + @unittest.skip("Broken due to tsting API change.") + def test_incremental_pca(self, k=2): + S, U = pca_incremental(self.data, k=k) + + print('INCREMENTAL principal components\n', U) + self.print_metrics(U, k) + + +if __name__ == "__main__": + unittest.main() diff --git a/src/tike/ptycho/object.py b/src/tike/ptycho/object.py index 93b7515d..05506a53 100644 --- a/src/tike/ptycho/object.py +++ b/src/tike/ptycho/object.py @@ -29,9 +29,37 @@ class ObjectOptions: """ - def __init__(self, positivity_constraint=0, smoothness_constraint=0): + def __init__( + self, + positivity_constraint=0, + smoothness_constraint=0, + use_adaptive_moment=False, + vdecay=0.6, + mdecay=0.666, + ): self.positivity_constraint = positivity_constraint self.smoothness_constraint = smoothness_constraint + self.use_adaptive_moment = use_adaptive_moment + self.vdecay = vdecay + self.mdecay = mdecay + self.v = None + self.m = None + + def put(self): + """Copy to the current GPU memory.""" + if self.v is not None: + self.v = cp.asarray(self.v) + if self.m is not None: + self.m = cp.asarray(self.m) + return self + + def get(self): + """Copy to the host CPU memory.""" + if self.v is not None: + self.v = cp.asnumpy(self.v) + if self.m is not None: + self.m = cp.asnumpy(self.m) + return self def positivity_constraint(x, r): diff --git a/src/tike/ptycho/probe.py b/src/tike/ptycho/probe.py index 214d4799..7c5b9675 100644 --- a/src/tike/ptycho/probe.py +++ b/src/tike/ptycho/probe.py @@ -38,9 +38,11 @@ import logging import cupy as cp +import cupyx.scipy.ndimage import numpy as np import tike.random +import tike.linalg logger = logging.getLogger(__name__) @@ -56,17 +58,49 @@ class ProbeOptions: The number of eigen probes/components. """ - def __init__(self, num_eigen_probes=0, orthogonality_constraint=True): + def __init__( + self, + num_eigen_probes=0, + orthogonality_constraint=True, + use_adaptive_moment=False, + vdecay=0.6, + mdecay=0.666, + centered_intensity_constraint=True, + sparsity_constraint=1, + ): self.orthogonality_constraint = orthogonality_constraint self._weights = None self._eigen_probes = None if num_eigen_probes > 0: pass + self.use_adaptive_moment = use_adaptive_moment + self.vdecay = vdecay + self.mdecay = mdecay + self.v = None + self.m = None + self.centered_intensity_constraint = centered_intensity_constraint + self.sparsity_constraint = sparsity_constraint @property def num_eigen_probes(self): return 0 if self._weights is None else self._weights.shape[-2] + def put(self): + """Copy to the current GPU memory.""" + if self.v is not None: + self.v = cp.asarray(self.v) + if self.m is not None: + self.m = cp.asarray(self.m) + return self + + def get(self): + """Copy to the host CPU memory.""" + if self.v is not None: + self.v = cp.asnumpy(self.v) + if self.m is not None: + self.m = cp.asnumpy(self.m) + return self + def get_varying_probe(shared_probe, eigen_probe=None, weights=None): """Construct the varying probes. @@ -434,9 +468,144 @@ def gaussian(size, rin=0.8, rout=1.0): return img +def opr(residual_gradient, eigen_probe, eigen_weights, n, alpha=0.5): + """Regularize multiple probes with orthogonal probe relaxation (OPR). + + Corrects for variable illumination across scan positions by regularizing + the probe at each position with the first `n` principal components of + the probes at all positions. + + The probes are regularized using a weighted sum as below: + + .. math:: + $$P_{1} = \alpha P_{0} + (1 - \alpha) \sum_{c=0}^{n}P_0^c$$ + + where `alpha` is the weighting paramters which determines the relative + importance of the regulariation. + + Parameters + ---------- + residual_gradient : (..., POSI, 1, SHARED, WIDE, HIGH) complex64 + The residual gradients for each position. + eigen_probe : (..., 1, EIGEN, SHARED, WIDE, HIGH) complex64 + The eigen probes for all positions. + eigen_weights : (..., POSI, EIGEN, SHARED) float32 + The relative intensity of the eigen probes at each position. + n : int + The number of components to use for regularization. + + Returns + ------- + eigen_probe : (..., 1, EIGEN, SHARED, WIDE, HIGH) complex64 + The eigen probes for all positions. + eigen_weights : (..., POSI, EIGEN, SHARED) float32 + The relative intensity of the eigen probes at each position. + + References + ---------- + Odstrcil, M., P. Baksh, S. A. Boden, R. Card, J. E. Chad, J. G. Frey, and + W. S. Brocklesby. 2016. “Ptychographic Coherent Diffractive Imaging with + Orthogonal Probe Relaxation.” Optics Express. + https://doi.org/10.1364/oe.24.008360. + + """ + # Flatten the last dimension of the probe and move incoherent mode + # dimension so the position dimensions are in the last two dimensions + residual_gradient = cp.moveaxis(residual_gradient, -5, -3) + shape = residual_gradient.shape # (..., 1, SHARED, POSI, WIDE, HIGH) + residual_gradient = residual_gradient.reshape( + *residual_gradient.shape[:-2], + residual_gradient.shape[-2] * residual_gradient.shape[-1], + ) + # residual_grad is now shape (..., 1, SHARED, POSI, WIDE*HIGH) + assert residual_gradient.shape[ + -2] > n, 'There cannot be more modes than positions.' + + _, U = tike.linalg.pca_eig(residual_gradient, k=n) + # U is shape (..., 1, SHARED, WIDE*HIGH, EIGEN) + weights = residual_gradient @ U + # weights is shape (..., 1, SHARED, POSI, EIGEN) + weights -= cp.mean(weights, axis=-2, keepdims=True) + + U = cp.moveaxis(U, -1, -3) + # U is shape (..., 1, EIGEN, SHARED, WIDE*HIGH) + U = U.reshape(*eigen_probe.shape) + weights = cp.moveaxis(weights[..., 0, :, :, :], -3, -1) + assert weights.shape == eigen_weights.shape + + # eigen_probe = (1 - alpha) * eigen_probe + alpha * U + # equivalent expression as above + eigen_probe += alpha * (U - eigen_probe) + + return eigen_probe, weights + + +def constrain_center_peak(probe): + """Force the peak illumination intensity to the center of the probe grid. + + After smoothing the intensity of the combined illumination with a gaussian + filter with standard deviation sigma, the probe is shifted such that the + maximum intensity is centered. + """ + half = probe.shape[-2] // 2, probe.shape[-1] // 2 + logger.info("Constrained probe intensity to center with sigma=%f", half[0]) + # First reshape the probe to 3D so it is a single stack of 2D images. + stack = probe.reshape((-1, *probe.shape[-2:])) + intensity = cupyx.scipy.ndimage.gaussian_filter( + input=np.sum(np.square(np.abs(stack)), axis=0), + sigma=half, + mode='wrap', + ) + # Find the maximum intensity in 2D. + center = np.argmax(intensity) + # Find the 2D coordinates of the maximum. + coords = cp.unravel_index(center, dims=probe.shape[-2:]) + # Shift each of the probes so the max is in the center. + p = np.roll(stack, half[0] - coords[0], axis=-2) + stack = np.roll(p, half[1] - coords[1], axis=-1) + # Reform to the original shape; make contiguous. + probe = stack.reshape(probe.shape) + return probe + + +def constrain_probe_sparsity(probe, f): + """Constrain the probe intensity so no more than f/1 elements are nonzero.""" + if f == 1: + return probe + logger.info("Constrained probe intensity spasity to %f", f) + # First reshape the probe to 3D so it is a single stack of 2D images. + stack = probe.reshape((-1, *probe.shape[-2:])) + intensity = np.sum(np.square(np.abs(stack)), axis=0) + sigma = probe.shape[-2] / 8, probe.shape[-1] / 8 + intensity = cupyx.scipy.ndimage.gaussian_filter( + input=intensity, + sigma=sigma, + mode='wrap', + ) + # Get the coordinates of the smallest k values + k = int((1 - f) * probe.shape[-1] * probe.shape[-2]) + smallest = np.argpartition(intensity, k, axis=None)[:k] + coords = cp.unravel_index(smallest, dims=probe.shape[-2:]) + # Set these k smallest values to zero in all probes + probe[..., coords[0], coords[1]] = 0 + return probe + + if __name__ == "__main__": - cp.random.seed(0) + cp.random.seed() x = (cp.random.rand(7, 1, 9, 3, 3) + 1j * cp.random.rand(7, 1, 9, 3, 3)).astype('complex64') x1 = orthogonalize_eig(x) assert x1.shape == x.shape, x1.shape + + probe_shape = (1, 52, 3, 5, 8, 8) + probe = cp.random.rand(*probe_shape) + 1j * cp.random.rand(*probe_shape) + probe = probe.astype('complex64') + reg_probes = opr(probe, n=7) + assert reg_probes.shape == probe_shape + + p = (cp.random.rand(3, 7, 7) * 100).astype(int) + p1 = constrain_center_peak(p) + print(p1) + p2 = constrain_probe_sparsity(p1, 0.6) + print(p2) diff --git a/src/tike/ptycho/ptycho.py b/src/tike/ptycho/ptycho.py index be32b2bc..045bd66e 100644 --- a/src/tike/ptycho/ptycho.py +++ b/src/tike/ptycho/ptycho.py @@ -69,7 +69,7 @@ from .object import get_padded_object from .position import (PositionOptions, check_allowed_positions, affine_position_regularization) -from .probe import get_varying_probe +from .probe import constrain_center_peak, constrain_probe_sparsity, get_varying_probe logger = logging.getLogger(__name__) @@ -173,6 +173,8 @@ def reconstruct( batch_size=None, initial_scan=None, position_options=None, + probe_options=None, + object_options=None, **kwargs ): # yapf: disable """Solve the ptychography problem using the given `algorithm`. @@ -205,6 +207,8 @@ def reconstruct( simultaneously per view. position_options : PositionOptions A class containing settings related to position correction. + probe_options : ProbeOptions + A class containing settings related to probe updates. num_gpu : int, tuple(int) The number of GPUs to use or a tuple of the device numbers of the GPUs to use. If the number of GPUs is less than the requested number, only @@ -280,6 +284,10 @@ def reconstruct( PositionOptions.put, result['position_options'], ) + if probe_options: + result['probe_options'] = probe_options.put() + if object_options: + result['object_options'] = object_options.put() for key, value in kwargs.items(): if np.ndim(value) > 0: kwargs[key] = comm.pool.bcast([value]) @@ -303,6 +311,19 @@ def reconstruct( logger.info(f"{algorithm} epoch {i:,d}") + if probe_options is not None: + if probe_options.centered_intensity_constraint: + result['probe'] = comm.pool.map( + constrain_center_peak, + result['probe'], + ) + if probe_options.sparsity_constraint < 1: + result['probe'] = comm.pool.map( + constrain_probe_sparsity, + result['probe'], + f=probe_options.sparsity_constraint, + ) + kwargs.update(result) result = getattr(solvers, algorithm)( operator, @@ -352,7 +373,10 @@ def reconstruct( for x, o in zip(result['position_options'], order) ] result['position_options'] = position_options - + if probe_options: + result['probe_options'] = result['probe_options'].get() + if object_options: + result['object_options'] = result['object_options'].get() if 'eigen_weights' in result: result['eigen_weights'] = comm.pool.gather( eigen_weights, diff --git a/src/tike/ptycho/solvers/__init__.py b/src/tike/ptycho/solvers/__init__.py index 95fa8fa3..2bd0c9c1 100644 --- a/src/tike/ptycho/solvers/__init__.py +++ b/src/tike/ptycho/solvers/__init__.py @@ -1,9 +1,11 @@ """Contains different solver implementations.""" +from .adam import adam_grad from .combined import cgrad from .divided import lstsq_grad __all__ = [ + 'adam_grad', 'cgrad', 'lstsq_grad', ] diff --git a/src/tike/ptycho/solvers/adam.py b/src/tike/ptycho/solvers/adam.py new file mode 100644 index 00000000..cb24af84 --- /dev/null +++ b/src/tike/ptycho/solvers/adam.py @@ -0,0 +1,366 @@ +import logging + +import cupy as cp +import numpy as np + +import tike.linalg +from tike.opt import batch_indicies, get_batch, adam, put_batch +from ..position import update_positions_pd, PositionOptions +from ..object import positivity_constraint, smoothness_constraint +from ..probe import constrain_variable_probe, get_varying_probe +from tike.pca import pca_svd + +logger = logging.getLogger(__name__) + + +def adam_grad( + op, comm, + data, probe, scan, psi, + cost=None, + eigen_probe=None, + eigen_weights=None, + num_batch=1, + subset_is_random=True, + probe_options=None, + position_options=None, + object_options=None, +): # yapf: disable + """Solve the ptychography problem using ADAptive Moment gradient descent. + + Parameters + ---------- + op : :py:class:`tike.operators.Ptycho` + A ptychography operator. + comm : :py:class:`tike.communicators.Comm` + An object which manages communications between both + GPUs and nodes. + + + .. seealso:: :py:mod:`tike.ptycho` + + """ + cost = np.inf + # Unique batch for each device + batches = [ + batch_indicies(s.shape[-2], num_batch, subset_is_random) for s in scan + ] + for n in range(num_batch): + + bdata = comm.pool.map(get_batch, data, batches, n=n) + bscan = comm.pool.map(get_batch, scan, batches, n=n) + + if isinstance(eigen_probe, list): + beigen_weights = comm.pool.map( + get_batch, + eigen_weights, + batches, + n=n, + ) + beigen_probe = eigen_probe + else: + beigen_probe = [None] * comm.pool.num_workers + beigen_weights = [None] * comm.pool.num_workers + + if position_options: + bposition_options = comm.pool.map(PositionOptions.split, + position_options, + [b[n] for b in batches]) + else: + bposition_options = None + + if object_options: + psi, cost, object_options = _update_object( + op, + comm, + bdata, + psi, + bscan, + probe, + eigen_probe=beigen_probe, + eigen_weights=beigen_weights, + object_options=object_options, + ) + psi = comm.pool.map(positivity_constraint, + psi, + r=object_options.positivity_constraint) + psi = comm.pool.map(smoothness_constraint, + psi, + a=object_options.smoothness_constraint) + + if probe_options: + for m in list(range(probe[0].shape[-3])): + probe, cost, probe_options = _update_probe( + op, + comm, + bdata, + psi, + bscan, + probe, + mode=[m], + probe_options=probe_options, + eigen_probe=beigen_probe, + eigen_weights=beigen_weights, + ) + + if position_options and comm.pool.num_workers == 1: + bscan, cost = update_positions_pd( + op, + comm.pool.gather(bdata, axis=-3), + psi[0], + probe[0], + comm.pool.gather(bscan, axis=-2), + ) + bscan = comm.pool.bcast([bscan]) + # TODO: Assign bscan into scan when positions are updated + + if isinstance(eigen_probe, list): + comm.pool.map( + put_batch, + beigen_weights, + eigen_weights, + batches, + n=n, + ) + + if isinstance(eigen_probe, list): + eigen_probe, eigen_weights = (list(a) for a in zip(*comm.pool.map( + constrain_variable_probe, + eigen_probe, + eigen_weights, + ))) + + result = { + 'psi': psi, + 'probe': probe, + 'cost': cost, + 'scan': scan, + 'probe_options': probe_options, + 'object_options': object_options, + 'position_options': position_options, + } + if isinstance(eigen_probe, list): + result['eigen_probe'] = eigen_probe + result['eigen_weights'] = eigen_weights + return result + + +def grad_probe(data, psi, scan, probe, mode=None, op=None): + """Compute the gradient with respect to the probe(s). + + Parameters + ---------- + mode : list(int) + Only return the gradient with resepect to these probes. + + """ + self = op + mode = list(range(probe.shape[-3])) if mode is None else mode + intensity, farplane = self._compute_intensity(data, psi, scan, probe) + # Use the average gradient for all probe positions + gradient = self.adj_probe( + farplane=self.propagation.grad( + data, + farplane[..., mode, :, :], + intensity, + ), + psi=psi, + scan=scan, + overwrite=True, + ) + mean_grad = self.xp.mean( + gradient, + axis=0, + keepdims=True, + ) + return mean_grad, gradient + + +def _update_probe( + op, + comm, + data, + psi, + scan, + probe, + mode, + probe_options, + eigen_probe, + eigen_weights, + step_length=0.1, +): + """Solve the probe recovery problem.""" + + def cost_function(probe): + unique_probe = comm.pool.map( + get_varying_probe, + probe, + eigen_probe, + eigen_weights, + ) + cost_out = comm.pool.map(op.cost, data, psi, scan, unique_probe) + if comm.use_mpi: + return comm.Allreduce_reduce(cost_out, 'cpu') + else: + return comm.reduce(cost_out, 'cpu') + + def grad(probe): + unique_probe = comm.pool.map( + get_varying_probe, + probe, + eigen_probe, + eigen_weights, + ) + mgrad_list, grad_list = zip(*comm.pool.map( + grad_probe, + data, + psi, + scan, + unique_probe, + mode=mode, + op=op, + )) + if comm.use_mpi: + return comm.Allreduce_reduce(mgrad_list, 'gpu'), grad_list + else: + return comm.reduce(mgrad_list, 'gpu'), grad_list + + def dir_multi(dir): + """Scatter dir to all GPUs""" + return comm.pool.bcast(dir) + + def update_multi(x, gamma, d): + + def f(x, d): + x[..., mode, :, :] = x[..., mode, :, :] - gamma * d + return x + + return comm.pool.map(f, x, d) + + d, gradient = grad(probe) + + probe_options.use_adaptive_moment = True + if probe_options.v is None or probe_options.m is None: + probe_options.v = cp.zeros_like(probe[0]) + probe_options.m = cp.zeros_like(probe[0]) + ( + d, + probe_options.v[..., mode, :, :], + probe_options.m[..., mode, :, :], + ) = adam( + g=d[0], + v=probe_options.v[..., mode, :, :], + m=probe_options.m[..., mode, :, :], + vdecay=probe_options.vdecay, + mdecay=probe_options.mdecay, + ) + d = [d] + + probe = update_multi( + probe, + gamma=step_length, + d=dir_multi(d), + ) + + if eigen_probe is not None: + residuals = comm.pool.map(cp.subtract, gradient, d) + comm.pool.map( + _update_eigen_modes, + residuals, + eigen_probe, + eigen_weights, + m=mode, + ) + + if probe[0].shape[-3] > 1 and probe_options.orthogonality_constraint: + probe = comm.pool.map(tike.linalg.orthogonalize_gs, + probe, + axis=(-2, -1)) + + cost = cost_function(probe) + + logger.info('%10s cost is %+12.5e', 'probe', cost) + return probe, cost, probe_options + + +def _update_eigen_modes(residuals, eigen_probe, eigen_weights, m, alpha=0.5): + residuals = cp.moveaxis(residuals, -5, -3) + residuals = residuals.reshape(*residuals.shape[:-2], -1) + W, C = pca_svd(residuals, k=eigen_probe.shape[-4]) + C = cp.moveaxis(C, -2, -3) + C = C.reshape(*C.shape[:-1], *eigen_probe.shape[-2:]) + W = cp.moveaxis(W[0], -2, -3) + W = cp.moveaxis(W, -1, -2) + eigen_probe[..., + m, :, :] = (1 - alpha) * eigen_probe[..., m, :, :] + alpha * C + eigen_weights[..., 1:, + m] = (1 - alpha) * eigen_weights[..., 1:, m] + alpha * W + + +def _update_object( + op, + comm, + data, + psi, + scan, + probe, + object_options, + eigen_probe, + eigen_weights, + step_length=0.1, +): + """Solve the object recovery problem.""" + + unique_probe = comm.pool.map( + get_varying_probe, + probe, + eigen_probe, + eigen_weights, + ) + + def cost_function_multi(psi, **kwargs): + cost_out = comm.pool.map(op.cost, data, psi, scan, unique_probe) + if comm.use_mpi: + return comm.Allreduce_reduce(cost_out, 'cpu') + else: + return comm.reduce(cost_out, 'cpu') + + def grad_multi(psi): + grad_list = comm.pool.map(op.grad_psi, data, psi, scan, unique_probe) + if comm.use_mpi: + return comm.Allreduce_reduce(grad_list, 'gpu') + else: + return comm.reduce(grad_list, 'gpu') + + def dir_multi(dir): + """Scatter dir to all GPUs""" + return comm.pool.bcast(dir) + + def update_multi(psi, gamma, dir): + + def f(psi, dir): + return psi + gamma * dir + + return list(comm.pool.map(f, psi, dir)) + + d = -grad_multi(psi)[0] + + object_options.use_adaptive_moment = True + d, object_options.v, object_options.m = adam( + g=d, + v=object_options.v, + m=object_options.m, + vdecay=object_options.vdecay, + mdecay=object_options.mdecay, + ) + d = [d] + + psi = update_multi( + psi, + gamma=step_length, + dir=dir_multi(d), + ) + + cost = cost_function_multi(psi) + + logger.info('%10s cost is %+12.5e', 'object', cost) + return psi, cost, object_options diff --git a/src/tike/ptycho/solvers/combined.py b/src/tike/ptycho/solvers/combined.py index 6ecf413e..2363b726 100644 --- a/src/tike/ptycho/solvers/combined.py +++ b/src/tike/ptycho/solvers/combined.py @@ -3,8 +3,9 @@ import numpy as np from tike.linalg import orthogonalize_gs -from tike.opt import conjugate_gradient, batch_indicies, get_batch +from tike.opt import conjugate_gradient, batch_indicies, get_batch, put_batch from ..position import update_positions_pd, PositionOptions +from ..probe import get_varying_probe, opr from ..object import positivity_constraint, smoothness_constraint logger = logging.getLogger(__name__) @@ -47,6 +48,18 @@ def cgrad( bdata = comm.pool.map(get_batch, data, batches, n=n) bscan = comm.pool.map(get_batch, scan, batches, n=n) + if isinstance(eigen_probe, list): + beigen_weights = comm.pool.map( + get_batch, + eigen_weights, + batches, + n=n, + ) + beigen_probe = eigen_probe + else: + beigen_probe = [None] * comm.pool.num_workers + beigen_weights = [None] * comm.pool.num_workers + if position_options: bposition_options = comm.pool.map(PositionOptions.split, position_options, @@ -55,15 +68,18 @@ def cgrad( bposition_options = None if object_options: - psi, cost = _update_object( + psi, cost, object_options = _update_object( op, comm, bdata, psi, bscan, probe, + eigen_probe=beigen_probe, + eigen_weights=beigen_weights, num_iter=cg_iter, step_length=step_length, + object_options=object_options, ) psi = comm.pool.map(positivity_constraint, psi, @@ -73,17 +89,19 @@ def cgrad( a=object_options.smoothness_constraint) if probe_options: - probe, cost = _update_probe( + probe, cost, probe_options = _update_probe( op, comm, bdata, psi, bscan, probe, + eigen_probe=beigen_probe, + eigen_weights=beigen_weights, num_iter=cg_iter, step_length=step_length, - probe_is_orthogonal=probe_options.orthogonality_constraint, mode=list(range(probe[0].shape[-3])), + probe_options=probe_options, ) if position_options and comm.pool.num_workers == 1: @@ -97,27 +115,72 @@ def cgrad( bscan = comm.pool.bcast([bscan]) # TODO: Assign bscan into scan when positions are updated - return {'psi': psi, 'probe': probe, 'cost': cost, 'scan': scan} - + if isinstance(eigen_probe, list): + comm.pool.map( + put_batch, + beigen_weights, + eigen_weights, + batches, + n=n, + ) -def _update_probe(op, comm, data, psi, scan, probe, num_iter, step_length, - probe_is_orthogonal, mode): + result = { + 'psi': psi, + 'probe': probe, + 'cost': cost, + 'scan': scan, + 'probe_options': probe_options, + 'object_options': object_options, + 'position_options': position_options, + } + if isinstance(eigen_probe, list): + result['eigen_probe'] = eigen_probe + result['eigen_weights'] = eigen_weights + return result + + +def _update_probe( + op, + comm, + data, + psi, + scan, + probe, + num_iter, + step_length, + mode, + probe_options, + eigen_probe, + eigen_weights, +): """Solve the probe recovery problem.""" def cost_function(probe): - cost_out = comm.pool.map(op.cost, data, psi, scan, probe) + unique_probe = comm.pool.map( + get_varying_probe, + probe, + eigen_probe, + eigen_weights, + ) + cost_out = comm.pool.map(op.cost, data, psi, scan, unique_probe) if comm.use_mpi: return comm.Allreduce_reduce(cost_out, 'cpu') else: return comm.reduce(cost_out, 'cpu') def grad(probe): + unique_probe = comm.pool.map( + get_varying_probe, + probe, + eigen_probe, + eigen_weights, + ) grad_list = comm.pool.map( op.grad_probe, data, psi, scan, - probe, + unique_probe, mode=mode, ) if comm.use_mpi: @@ -147,25 +210,44 @@ def f(x, d): step_length=step_length, ) - if probe[0].shape[-3] > 1 and probe_is_orthogonal: + if probe[0].shape[-3] > 1 and probe_options.orthogonality_constraint: probe = comm.pool.map(orthogonalize_gs, probe, axis=(-2, -1)) logger.info('%10s cost is %+12.5e', 'probe', cost) - return probe, cost - - -def _update_object(op, comm, data, psi, scan, probe, num_iter, step_length): + return probe, cost, probe_options + + +def _update_object( + op, + comm, + data, + psi, + scan, + probe, + num_iter, + step_length, + object_options, + eigen_probe, + eigen_weights, +): """Solve the object recovery problem.""" + unique_probe = comm.pool.map( + get_varying_probe, + probe, + eigen_probe, + eigen_weights, + ) + def cost_function_multi(psi, **kwargs): - cost_out = comm.pool.map(op.cost, data, psi, scan, probe) + cost_out = comm.pool.map(op.cost, data, psi, scan, unique_probe) if comm.use_mpi: return comm.Allreduce_reduce(cost_out, 'cpu') else: return comm.reduce(cost_out, 'cpu') def grad_multi(psi): - grad_list = comm.pool.map(op.grad_psi, data, psi, scan, probe) + grad_list = comm.pool.map(op.grad_psi, data, psi, scan, unique_probe) if comm.use_mpi: return comm.Allreduce_reduce(grad_list, 'gpu') else: @@ -194,4 +276,43 @@ def f(psi, dir): ) logger.info('%10s cost is %+12.5e', 'object', cost) - return psi, cost + return psi, cost, object_options + + +def _update_eigen_probe(op, comm, data, psi, scan, probe, eigen_probe, + eigen_weights, alpha): + """Update the eigen probes and weights.""" + unique_probe = get_varying_probe( + probe, + eigen_probe, + eigen_weights, + ) + # Compute the gradient for each probe positions + intensity, farplane = op._compute_intensity(data, psi, scan, unique_probe) + gradients = op.adj_probe( + farplane=op.propagation.grad( + data, + farplane, + intensity, + ), + psi=psi, + scan=scan, + overwrite=True, + ) + + # Get the residual gradient for each probe position + # TODO: Maybe subtracting this mean is not necessary because we already + # updated the main probe. Or maybe it is because it keeps the residuals + # zero-mean + residuals = gradients - np.mean(gradients, axis=-5, keepdims=True) + + # Perform principal component analysis on the residual gradients + eigen_probe, eigen_weights = opr( + residuals, + eigen_probe, + eigen_weights, + eigen_weights.shape[-2], + alpha=alpha, + ) + + return eigen_probe, eigen_weights diff --git a/src/tike/ptycho/solvers/divided.py b/src/tike/ptycho/solvers/divided.py index 460c4209..a422b353 100644 --- a/src/tike/ptycho/solvers/divided.py +++ b/src/tike/ptycho/solvers/divided.py @@ -169,6 +169,10 @@ def lstsq_grad( result['eigen_weights'] = eigen_weights if position_options: result['position_options'] = position_options + if probe_options: + result['probe_options'] = probe_options + if object_options: + result['object_options'] = object_options return result diff --git a/src/tike/random.py b/src/tike/random.py index 2ba0d9f4..2c7941e0 100644 --- a/src/tike/random.py +++ b/src/tike/random.py @@ -3,10 +3,12 @@ import cupy as cp import numpy as np +randomizer = np.random.default_rng() + def numpy_complex(*shape): """Return a complex random array in the range [-0.5, 0.5).""" - return (np.random.rand(*shape, 2) - 0.5).view('complex')[..., 0] + return (randomizer.random((*shape, 2)) - 0.5).view('complex')[..., 0] def cupy_complex(*shape): diff --git a/tests/test_ptycho.py b/tests/test_ptycho.py index c91c21dd..0bb6e3ab 100644 --- a/tests/test_ptycho.py +++ b/tests/test_ptycho.py @@ -51,6 +51,7 @@ import os import pickle import unittest +import warnings import numpy as np from mpi4py import MPI @@ -272,6 +273,27 @@ def template_consistent_algorithm(self, algorithm, params={}): print('\n'.join(f'{c:1.3e}' for c in result['costs'])) return result + def test_consistent_adam_grad(self): + """Check ptycho.solver.cgrad for consistency.""" + + eigen_probe, weights = tike.ptycho.probe.init_varying_probe( + self.scan, self.probe, 4) + + _save_ptycho_result( + self.template_consistent_algorithm( + 'adam_grad', + params={ + 'subset_is_random': True, + 'batch_size': int(self.data.shape[-3] / 1), + 'num_gpu': 2, + 'probe_options': ProbeOptions(), + 'object_options': ObjectOptions(), + 'use_mpi': _mpi_size > 1, + 'eigen_probe': eigen_probe, + 'eigen_weights': weights, + }, + ), f"{'mpi-' if _mpi_size > 1 else ''}adam_grad") + def test_consistent_cgrad(self): """Check ptycho.solver.cgrad for consistency.""" _save_ptycho_result( @@ -317,7 +339,7 @@ def test_consistent_lstsq_grad_variable_probe(self): """Check ptycho.solver.lstsq_grad for consistency.""" eigen_probe, weights = tike.ptycho.probe.init_varying_probe( - self.scan, self.probe, 3) + self.scan, self.probe, 5) _save_ptycho_result( self.template_consistent_algorithm( @@ -347,6 +369,26 @@ def test_consistent_lstsq_grad_variable_probe(self): }, ), f"{'mpi-' if _mpi_size > 1 else ''}lstsq_grad-variable-probe") + def test_consistent_cgrad_variable_probe(self): + """Check ptycho.solver.cgrad for consistency.""" + eigen_probe, weights = tike.ptycho.probe.init_varying_probe( + self.scan, self.probe, 3) + + _save_ptycho_result( + self.template_consistent_algorithm( + 'cgrad', + params={ + 'subset_is_random': True, + 'batch_size': int(self.data.shape[-3] / 3), + 'num_gpu': 2, + 'probe_options': ProbeOptions(), + 'object_options': ObjectOptions(), + 'use_mpi': _mpi_size > 1, + 'eigen_probe': eigen_probe, + 'eigen_weights': weights, + }, + ), f"{'mpi-' if _mpi_size > 1 else ''}cgrad-variable-probe") + def test_invaid_algorithm_name(self): """Check that wrong names are handled gracefully.""" with self.assertRaises(ValueError): @@ -387,6 +429,55 @@ def test_eigen_probe(self): assert eigen_probe[0].shape == new_probe[0].shape +def _save_eigen_probe(output_folder, eigen_probe): + import matplotlib.pyplot as plt + flattened = [] + for i in range(eigen_probe.shape[-4]): + probe = eigen_probe[..., i, :, :, :] + flattened.append( + np.concatenate( + probe.reshape((-1, *probe.shape[-2:])), + axis=1, + )) + flattened = np.concatenate(flattened, axis=0) + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=UserWarning) + plt.imsave( + f'{output_folder}/eigen-phase.png', + np.angle(flattened), + # The output of np.angle is locked to (-pi, pi] + cmap=plt.cm.twilight, + vmin=-np.pi, + vmax=np.pi, + ) + plt.imsave( + f'{output_folder}/eigen-ampli.png', + np.abs(flattened), + ) + + +def _save_probe(output_folder, probe): + import matplotlib.pyplot as plt + flattened = np.concatenate( + probe.reshape((-1, *probe.shape[-2:])), + axis=-1, + ) + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=UserWarning) + plt.imsave( + f'{output_folder}/probe-phase.png', + np.angle(flattened), + # The output of np.angle is locked to (-pi, pi] + cmap=plt.cm.twilight, + vmin=-np.pi, + vmax=np.pi, + ) + plt.imsave( + f'{output_folder}/probe-ampli.png', + np.abs(flattened), + ) + + def _save_ptycho_result(result, algorithm): try: import matplotlib.pyplot as plt @@ -411,19 +502,9 @@ def _save_ptycho_result(result, algorithm): f'{fname}/{0}-ampli.png', np.abs(result['psi']).astype('float32'), ) - for i in range(result['probe'].shape[-3]): - plt.imsave( - f'{fname}/{i}-probe-phase.png', - np.angle(result['probe'][0, 0, i]), - # The output of np.angle is locked to (-pi, pi] - cmap=plt.cm.twilight, - vmin=-np.pi, - vmax=np.pi, - ) - plt.imsave( - f'{fname}/{i}-probe-ampli.png', - np.abs(result['probe'][0, 0, i]), - ) + _save_probe(fname, result['probe']) + if 'eigen_probe' in result and result['eigen_probe'] is not None: + _save_eigen_probe(fname, result['eigen_probe']) except ImportError: pass