diff --git a/src/tike/opt.py b/src/tike/opt.py index 7125e899..1a921dd2 100644 --- a/src/tike/opt.py +++ b/src/tike/opt.py @@ -259,12 +259,18 @@ def line_search( # Decrease the step length while the step increases the cost function step_count = 0 first_step = step_length + step_is_decreasing = False while True: xsd = update_multi(x, step_length, d) fxsd = f(xsd) if fxsd <= fx + step_shrink * m: - break - step_length *= step_shrink + if step_is_decreasing: + break + step_length /= step_shrink + else: + step_length *= step_shrink + step_is_decreasing = True + if step_length < 1e-32: warnings.warn("Line search failed for conjugate gradient.") step_length, fxsd, xsd = 0, fx, x diff --git a/src/tike/ptycho/object.py b/src/tike/ptycho/object.py index 415f6aae..9c1a68d0 100644 --- a/src/tike/ptycho/object.py +++ b/src/tike/ptycho/object.py @@ -8,6 +8,7 @@ import dataclasses import logging import typing +import copy import cupy as cp import cupyx.scipy.ndimage @@ -25,6 +26,16 @@ class ObjectOptions: """Manage data and setting related to object correction.""" + convergence_tolerance: float = 0 + """Terminate reconstruction early when the mnorm of the object update is + less than this value.""" + + update_mnorm: typing.List[float] = dataclasses.field( + init=False, + default_factory=list, + ) + """A record of the previous mnorms of the object update.""" + positivity_constraint: float = 0 """This value is passed to the tike.ptycho.object.positivity_constraint function.""" @@ -66,32 +77,47 @@ class ObjectOptions: ) """Used for compact batch updates.""" + multigrid_update: typing.Union[npt.NDArray, None] = dataclasses.field( + init=False, + default_factory=lambda: None, + ) + """Used for multigrid updates.""" + clip_magnitude: bool = False """Whether to force the object magnitude to remain <= 1.""" def copy_to_device(self, comm): """Copy to the current GPU memory.""" + options = copy.copy(self) + options.update_mnorm = copy.copy(self.update_mnorm) if self.v is not None: - self.v = cp.asarray(self.v) + options.v = cp.asarray(self.v) if self.m is not None: - self.m = cp.asarray(self.m) + options.m = cp.asarray(self.m) if self.preconditioner is not None: - self.preconditioner = comm.pool.bcast([self.preconditioner]) - return self + options.preconditioner = comm.pool.bcast([self.preconditioner]) + if self.multigrid_update is not None: + options.multigrid_update = cp.asarray(self.multigrid_update) + return options def copy_to_host(self): """Copy to the host CPU memory.""" + options = copy.copy(self) + options.update_mnorm = copy.copy(self.update_mnorm) if self.v is not None: - self.v = cp.asnumpy(self.v) + options.v = cp.asnumpy(self.v) if self.m is not None: - self.m = cp.asnumpy(self.m) + options.m = cp.asnumpy(self.m) if self.preconditioner is not None: - self.preconditioner = cp.asnumpy(self.preconditioner[0]) - return self + options.preconditioner = cp.asnumpy(self.preconditioner[0]) + if self.multigrid_update is not None: + options.multigrid_update = cp.asnumpy(self.multigrid_update) + return options - def resample(self, factor: float) -> ObjectOptions: + def resample(self, factor: float, interp) -> ObjectOptions: """Return a new `ObjectOptions` with the parameters rescaled.""" - return ObjectOptions( + options = ObjectOptions( + convergence_tolerance=self.convergence_tolerance, positivity_constraint=self.positivity_constraint, smoothness_constraint=self.smoothness_constraint, use_adaptive_moment=self.use_adaptive_moment, @@ -99,6 +125,10 @@ def resample(self, factor: float) -> ObjectOptions: mdecay=self.mdecay, clip_magnitude=self.clip_magnitude, ) + options.update_mnorm = copy.copy(self.update_mnorm) + if self.multigrid_update is not None: + options.multigrid_update = interp(self.multigrid_update, factor) + return options # Momentum reset to zero when grid scale changes diff --git a/src/tike/ptycho/position.py b/src/tike/ptycho/position.py index 937b8254..c8b7e084 100644 --- a/src/tike/ptycho/position.py +++ b/src/tike/ptycho/position.py @@ -119,6 +119,7 @@ import dataclasses import logging import typing +import copy import cupy as cp import cupyx.scipy.ndimage @@ -445,21 +446,23 @@ def join(self, other, indices): def copy_to_device(self): """Copy to the current GPU memory.""" - self.initial_scan = cp.asarray(self.initial_scan) + options = copy.copy(self) + options.initial_scan = cp.asarray(self.initial_scan) if self.confidence is not None: - self.confidence = cp.asarray(self.confidence) + options.confidence = cp.asarray(self.confidence) if self.use_adaptive_moment: - self._momentum = cp.asarray(self._momentum) - return self + options._momentum = cp.asarray(self._momentum) + return options def copy_to_host(self): """Copy to the host CPU memory.""" - self.initial_scan = cp.asnumpy(self.initial_scan) + options = copy.copy(self) + options.initial_scan = cp.asnumpy(self.initial_scan) if self.confidence is not None: - self.confidence = cp.asnumpy(self.confidence) + options.confidence = cp.asnumpy(self.confidence) if self.use_adaptive_moment: - self._momentum = cp.asnumpy(self._momentum) - return self + options._momentum = cp.asnumpy(self._momentum) + return options def resample(self, factor: float) -> PositionOptions: """Return a new `PositionOptions` with the parameters scaled.""" diff --git a/src/tike/ptycho/probe.py b/src/tike/ptycho/probe.py index ef2d64c7..60a507d4 100644 --- a/src/tike/ptycho/probe.py +++ b/src/tike/ptycho/probe.py @@ -36,6 +36,7 @@ """ from __future__ import annotations +import copy import dataclasses import logging import typing @@ -135,29 +136,41 @@ class ProbeOptions: ) """The power of the primary probe modes at each iteration.""" + multigrid_update: typing.Union[npt.NDArray, None] = dataclasses.field( + init=False, + default_factory=lambda: None, + ) + """Used for multigrid updates.""" + def copy_to_device(self, comm): """Copy to the current GPU memory.""" + options = copy.copy(self) if self.v is not None: - self.v = cp.asarray(self.v) + options.v = cp.asarray(self.v) if self.m is not None: - self.m = cp.asarray(self.m) + options.m = cp.asarray(self.m) if self.preconditioner is not None: - self.preconditioner = comm.pool.bcast([self.preconditioner]) - return self + options.preconditioner = comm.pool.bcast([self.preconditioner]) + if self.multigrid_update is not None: + options.multigrid_update = cp.asarray(self.multigrid_update) + return options def copy_to_host(self): """Copy to the host CPU memory.""" + options = copy.copy(self) if self.v is not None: - self.v = cp.asnumpy(self.v) + options.v = cp.asnumpy(self.v) if self.m is not None: - self.m = cp.asnumpy(self.m) + options.m = cp.asnumpy(self.m) if self.preconditioner is not None: - self.preconditioner = cp.asnumpy(self.preconditioner[0]) - return self + options.preconditioner = cp.asnumpy(self.preconditioner[0]) + if self.multigrid_update is not None: + options.multigrid_update = cp.asnumpy(self.multigrid_update) + return options - def resample(self, factor: float) -> ProbeOptions: + def resample(self, factor: float, interp) -> ProbeOptions: """Return a new `ProbeOptions` with the parameters rescaled.""" - return ProbeOptions( + options = ProbeOptions( force_orthogonality=self.force_orthogonality, force_centered_intensity=self.force_centered_intensity, force_sparsity=self.force_sparsity, @@ -168,6 +181,9 @@ def resample(self, factor: float) -> ProbeOptions: probe_support_degree=self.probe_support_degree, probe_support_radius=self.probe_support_radius, ) + if self.multigrid_update is not None: + options.multigrid_update = interp(self.multigrid_update, factor) + return options # Momentum reset to zero when grid scale changes diff --git a/src/tike/ptycho/ptycho.py b/src/tike/ptycho/ptycho.py index de904e8d..9a668812 100644 --- a/src/tike/ptycho/ptycho.py +++ b/src/tike/ptycho/ptycho.py @@ -54,6 +54,7 @@ "simulate", "Reconstruction", "reconstruct_multigrid", + "reconstruct_multigrid_new", ] import copy @@ -317,8 +318,7 @@ def __init__( mpi = tike.communicators.NoMPIComm self.data = data - self.parameters = parameters - self._device_parameters = copy.deepcopy(parameters) + self.parameters = copy.deepcopy(parameters) self.device = cp.cuda.Device( num_gpu[0] if isinstance(num_gpu, tuple) else None) self.operator = tike.operators.Ptycho( @@ -343,9 +343,9 @@ def __enter__(self): odd_pool = self.comm.pool.num_workers % 2 ( self.comm.order, - self._device_parameters.scan, + self.parameters.scan, self.data, - self._device_parameters.eigen_weights, + self.parameters.eigen_weights, ) = tike.cluster.by_scan_grid( self.comm.pool, ( @@ -356,55 +356,53 @@ def __enter__(self): (tike.precision.floating, tike.precision.floating if self.data.itemsize > 2 else self.data.dtype, tike.precision.floating), - self._device_parameters.scan, + self.parameters.scan, self.data, - self._device_parameters.eigen_weights, + self.parameters.eigen_weights, ) - self._device_parameters.psi = self.comm.pool.bcast( - [self._device_parameters.psi.astype(tike.precision.cfloating)]) + self.parameters.psi = self.comm.pool.bcast( + [self.parameters.psi.astype(tike.precision.cfloating)]) - self._device_parameters.probe = self.comm.pool.bcast( - [self._device_parameters.probe.astype(tike.precision.cfloating)]) + self.parameters.probe = self.comm.pool.bcast( + [self.parameters.probe.astype(tike.precision.cfloating)]) - if self._device_parameters.probe_options is not None: - self._device_parameters.probe_options = self._device_parameters.probe_options.copy_to_device( + if self.parameters.probe_options is not None: + self.parameters.probe_options = self.parameters.probe_options.copy_to_device( self.comm,) - if self._device_parameters.object_options is not None: - self._device_parameters.object_options = self._device_parameters.object_options.copy_to_device( + if self.parameters.object_options is not None: + self.parameters.object_options = self.parameters.object_options.copy_to_device( self.comm,) - if self._device_parameters.eigen_probe is not None: - self._device_parameters.eigen_probe = self.comm.pool.bcast([ - self._device_parameters.eigen_probe.astype( - tike.precision.cfloating) - ]) + if self.parameters.eigen_probe is not None: + self.parameters.eigen_probe = self.comm.pool.bcast( + [self.parameters.eigen_probe.astype(tike.precision.cfloating)]) - if self._device_parameters.position_options is not None: + if self.parameters.position_options is not None: # TODO: Consider combining put/split, get/join operations? - self._device_parameters.position_options = self.comm.pool.map( + self.parameters.position_options = self.comm.pool.map( PositionOptions.copy_to_device, - (self._device_parameters.position_options.split(x) + (self.parameters.position_options.split(x) for x in self.comm.order), ) # Unique batch for each device self.batches = self.comm.pool.map( getattr(tike.cluster, - self._device_parameters.algorithm_options.batch_method), - self._device_parameters.scan, - num_cluster=self._device_parameters.algorithm_options.num_batch, + self.parameters.algorithm_options.batch_method), + self.parameters.scan, + num_cluster=self.parameters.algorithm_options.num_batch, ) - self._device_parameters.probe = _rescale_probe( + self.parameters.probe = _rescale_probe( self.operator, self.comm, self.data, - self._device_parameters.psi, - self._device_parameters.scan, - self._device_parameters.probe, - num_batch=self._device_parameters.algorithm_options.num_batch, + self.parameters.psi, + self.parameters.scan, + self.parameters.probe, + num_batch=self.parameters.algorithm_options.num_batch, ) return self @@ -412,119 +410,126 @@ def __enter__(self): def iterate(self, num_iter: int) -> None: """Advance the reconstruction by num_iter epochs.""" start = time.perf_counter() + psi_previous = self.parameters.psi[0].copy() for i in range(num_iter): - logger.info( - f"{self._device_parameters.algorithm_options.name} epoch " - f"{len(self._device_parameters.algorithm_options.times):,d}") + logger.info(f"{self.parameters.algorithm_options.name} epoch " + f"{len(self.parameters.algorithm_options.times):,d}") - if self._device_parameters.probe_options is not None: - if self._device_parameters.probe_options.force_centered_intensity: - self._device_parameters.probe = self.comm.pool.map( + if self.parameters.probe_options is not None: + if self.parameters.probe_options.force_centered_intensity: + self.parameters.probe = self.comm.pool.map( constrain_center_peak, - self._device_parameters.probe, + self.parameters.probe, ) - if self._device_parameters.probe_options.force_sparsity < 1: - self._device_parameters.probe = self.comm.pool.map( + if self.parameters.probe_options.force_sparsity < 1: + self.parameters.probe = self.comm.pool.map( constrain_probe_sparsity, - self._device_parameters.probe, - f=self._device_parameters.probe_options - .force_sparsity, + self.parameters.probe, + f=self.parameters.probe_options.force_sparsity, ) - if self._device_parameters.probe_options.force_orthogonality: + if self.parameters.probe_options.force_orthogonality: ( - self._device_parameters.probe, + self.parameters.probe, power, ) = (list(a) for a in zip(*self.comm.pool.map( tike.ptycho.probe.orthogonalize_eig, - self._device_parameters.probe, + self.parameters.probe, ))) - self._device_parameters.probe_options.power.append(power[0].get()) + self.parameters.probe_options.power.append(power[0].get()) - self._device_parameters = getattr( + self.parameters = getattr( solvers, - self._device_parameters.algorithm_options.name, + self.parameters.algorithm_options.name, )( self.operator, self.comm, data=self.data, batches=self.batches, - parameters=self._device_parameters, + parameters=self.parameters, ) - if self._device_parameters.object_options.clip_magnitude: - self._device_parameters.psi = self.comm.pool.map( + if self.parameters.object_options.clip_magnitude: + self.parameters.psi = self.comm.pool.map( _clip_magnitude, - self._device_parameters.psi, + self.parameters.psi, a_max=1.0, ) - if (self._device_parameters.position_options - and self._device_parameters.position_options[0] - .use_position_regularization): + if (self.parameters.position_options and self.parameters + .position_options[0].use_position_regularization): - (self._device_parameters.position_options + (self.parameters.position_options ) = affine_position_regularization( self.comm, - updated=self._device_parameters.scan, - position_options=self._device_parameters.position_options, + updated=self.parameters.scan, + position_options=self.parameters.position_options, ) - self._device_parameters.algorithm_options.times.append( - time.perf_counter() - start) + self.parameters.algorithm_options.times.append(time.perf_counter() - + start) start = time.perf_counter() - if tike.opt.is_converged(self._device_parameters.algorithm_options): + update_norm = tike.linalg.mnorm(self.parameters.psi[0] - + psi_previous) + self.parameters.object_options.update_mnorm.append( + update_norm.get()) + logger.info(f"The object update mean-norm is {update_norm:.3e}") + if (np.mean(self.parameters.object_options.update_mnorm[-5:]) < + self.parameters.object_options.convergence_tolerance): + logger.info( + f"The object seems converged. {update_norm:.3e} < " + f"{self.parameters.object_options.convergence_tolerance:.3e}" + ) break - def _get_result(self): + def get_result(self): """Return the current parameter estimates.""" - self.parameters.probe = self._device_parameters.probe[0].get() + reorder = np.argsort(np.concatenate(self.comm.order)) + parameters = solvers.PtychoParameters( + probe=self.parameters.probe[0].get(), + psi=self.parameters.psi[0].get(), + scan=self.comm.pool.gather_host( + self.parameters.scan, + axis=-2, + )[reorder], + algorithm_options=self.parameters.algorithm_options, + ) - self.parameters.psi = self._device_parameters.psi[0].get() + if self.parameters.eigen_probe is not None: + parameters.eigen_probe = self.parameters.eigen_probe[0].get() - reorder = np.argsort(np.concatenate(self.comm.order)) - self.parameters.scan = self.comm.pool.gather_host( - self._device_parameters.scan, - axis=-2, - )[reorder] - - if self._device_parameters.eigen_probe is not None: - self.parameters.eigen_probe = self._device_parameters.eigen_probe[ - 0].get() - - if self._device_parameters.eigen_weights is not None: - self.parameters.eigen_weights = self.comm.pool.gather( - self._device_parameters.eigen_weights, + if self.parameters.eigen_weights is not None: + parameters.eigen_weights = self.comm.pool.gather( + self.parameters.eigen_weights, axis=-3, )[reorder].get() - self.parameters.algorithm_options = self._device_parameters.algorithm_options - - if self._device_parameters.probe_options is not None: - self.parameters.probe_options = self._device_parameters.probe_options.copy_to_host( + if self.parameters.probe_options is not None: + parameters.probe_options = self.parameters.probe_options.copy_to_host( ) - if self._device_parameters.object_options is not None: - self.parameters.object_options = self._device_parameters.object_options.copy_to_host( + if self.parameters.object_options is not None: + parameters.object_options = self.parameters.object_options.copy_to_host( ) - if self._device_parameters.position_options is not None: - host_position_options = self._device_parameters.position_options[ - 0].empty() + if self.parameters.position_options is not None: + host_position_options = self.parameters.position_options[0].empty() for x, o in zip( self.comm.pool.map( PositionOptions.copy_to_host, - self._device_parameters.position_options, + self.parameters.position_options, ), self.comm.order, ): host_position_options = host_position_options.join(x, o) - self.parameters.position_options = host_position_options + parameters.position_options = host_position_options + + return parameters def __exit__(self, type, value, traceback): - self._get_result() + self.parameters = self.get_result() self.comm.__exit__(type, value, traceback) self.operator.__exit__(type, value, traceback) self.device.__exit__(type, value, traceback) @@ -534,29 +539,29 @@ def get_convergence( ) -> typing.Tuple[typing.List[typing.List[float]], typing.List[float]]: """Return the cost function values and times as a tuple.""" return ( - self._device_parameters.algorithm_options.costs, - self._device_parameters.algorithm_options.times, + self.parameters.algorithm_options.costs, + self.parameters.algorithm_options.times, ) def get_psi(self) -> np.array: """Return the current object estimate as a numpy array.""" - return self._device_parameters.psi[0].get() + return self.parameters.psi[0].get() def get_probe(self) -> typing.Tuple[np.array, np.array, np.array]: """Return the current probe, eigen_probe, weights as numpy arrays.""" reorder = np.argsort(np.concatenate(self.comm.order)) - if self._device_parameters.eigen_probe is None: + if self.parameters.eigen_probe is None: eigen_probe = None else: - eigen_probe = self._device_parameters.eigen_probe[0].get() - if self._device_parameters.eigen_weights is None: + eigen_probe = self.parameters.eigen_probe[0].get() + if self.parameters.eigen_weights is None: eigen_weights = None else: eigen_weights = self.comm.pool.gather( - self._device_parameters.eigen_weights, + self.parameters.eigen_weights, axis=-3, )[reorder].get() - probe = self._device_parameters.probe[0].get() + probe = self.parameters.probe[0].get() return probe, eigen_probe, eigen_weights def peek(self) -> typing.Tuple[np.array, np.array, np.array, np.array]: @@ -592,7 +597,7 @@ def append_new_data( if odd_pool else self.comm.pool.num_workers // 2, 1 if odd_pool else 2, ), - (self._device_parameters.scan[0].dtype, self.data[0].dtype), + (self.parameters.scan[0].dtype, self.data[0].dtype), new_scan, new_data, ) @@ -604,9 +609,9 @@ def append_new_data( new_data, axis=0, ) - self._device_parameters.scan = self.comm.pool.map( + self.parameters.scan = self.comm.pool.map( cp.append, - self._device_parameters.scan, + self.parameters.scan, new_scan, axis=0, ) @@ -619,15 +624,15 @@ def append_new_data( # Rebatch on each device self.batches = self.comm.pool.map( getattr(tike.cluster, - self._device_parameters.algorithm_options.batch_method), - self._device_parameters.scan, - num_cluster=self._device_parameters.algorithm_options.num_batch, + self.parameters.algorithm_options.batch_method), + self.parameters.scan, + num_cluster=self.parameters.algorithm_options.num_batch, ) - if self._device_parameters.eigen_weights is not None: - self._device_parameters.eigen_weights = self.comm.pool.map( + if self.parameters.eigen_weights is not None: + self.parameters.eigen_weights = self.comm.pool.map( cp.pad, - self._device_parameters.eigen_weights, + self.parameters.eigen_weights, pad_width=( (0, len(new_scan)), # position (0, 0), # eigen @@ -636,10 +641,10 @@ def append_new_data( mode='mean', ) - if self._device_parameters.position_options is not None: - self._device_parameters.position_options = self.comm.pool.map( + if self.parameters.position_options is not None: + self.parameters.position_options = self.comm.pool.map( PositionOptions.append, - self._device_parameters.position_options, + self.parameters.position_options, new_scan, ) @@ -709,7 +714,7 @@ def reconstruct_multigrid( num_gpu: typing.Union[int, typing.Tuple[int, ...]] = 1, use_mpi: bool = False, num_levels: int = 3, - interp=None, + interp: typing.Callable = solvers.options._resize_fft, ) -> solvers.PtychoParameters: """Solve the ptychography problem using a multi-grid method. @@ -758,3 +763,196 @@ def reconstruct_multigrid( resampled_parameters = context.parameters.resample(2.0, interp) raise RuntimeError('This should not happen.') + + +def reconstruct_multigrid_new( + data: npt.NDArray, + parameters: solvers.PtychoParameters, + model: str = 'gaussian', + num_gpu: typing.Union[int, typing.Tuple[int, ...]] = 1, + use_mpi: bool = False, + num_levels: int = 3, + level: int = 0, + interp: typing.Callable = solvers.options._resize_mean, +) -> solvers.PtychoParameters: + """Solve the ptychography problem using a multi-grid method. + + .. versionadded:: 0.23.2 + + Uses the same parameters as the functional reconstruct API. This function + applies a multi-grid approach to the problem by downsampling the real-space + input parameters and cropping the diffraction patterns to reduce the + computational cost of early iterations. + + Parameters + ---------- + num_levels : int > 0 + The number of times to reduce the problem by a factor of two. + + + .. seealso:: :py:func:`tike.ptycho.ptycho.reconstruct` + """ + + if level == 0 and (data.shape[-1] * 0.5**(num_levels - 1) < 64): + warnings.warn('Cropping diffraction patterns to less than 64 pixels ' + 'wide is not recommended because the full doughnut' + ' may not be visible.') + + with tike.ptycho.Reconstruction( + data=data, + parameters=parameters, + model=model, + num_gpu=num_gpu, + use_mpi=use_mpi, + ) as context: + + if context.parameters.object_options.multigrid_update is not None: + grad_psi = 0 + for n in range(context.parameters.algorithm_options.num_batch): + _grad_psi = context.comm.pool.map( + context.operator.grad_psi, + context.comm.pool.map(tike.opt.get_batch, context.data, context.batches, n=n), + context.parameters.psi, + context.comm.pool.map(tike.opt.get_batch, context.parameters.scan, context.batches, n=n), + context.parameters.probe, + ) + grad_psi += context.comm.Allreduce_reduce_gpu(_grad_psi)[0] + context.parameters.object_options.multigrid_update += -grad_psi + + if context.parameters.probe_options.multigrid_update is not None: + grad_probe = 0 + for n in range(context.parameters.algorithm_options.num_batch): + _grad_probe = context.comm.pool.map( + context.operator.grad_probe, + context.comm.pool.map(tike.opt.get_batch, context.data, context.batches, n=n), + context.parameters.psi, + context.comm.pool.map(tike.opt.get_batch, context.parameters.scan, context.batches, n=n), + context.parameters.probe, + ) + grad_probe += context.comm.Allreduce_reduce_gpu(_grad_probe)[0] + context.parameters.probe_options.multigrid_update += -grad_probe + + logging.info(f'Multigrid level {level} pre-smoothing') + + # pre-smoothing + context.iterate(4) + + if level + 1 < num_levels: + + # coarse-grid correction + parameters_coarser = context.get_result().resample(0.5, interp) + data_coarser = solvers.crop_fourier_space( + data, + data.shape[-1] // 2, + ) + + grad_psi = 0 + for n in range(context.parameters.algorithm_options.num_batch): + _grad_psi = context.comm.pool.map( + context.operator.grad_psi, + context.comm.pool.map(tike.opt.get_batch, context.data, context.batches, n=n), + context.parameters.psi, + context.comm.pool.map(tike.opt.get_batch, context.parameters.scan, context.batches, n=n), + context.parameters.probe, + ) + grad_psi += context.comm.Allreduce_reduce_cpu(_grad_psi) + if context.parameters.object_options.multigrid_update is None: + parameters_coarser.object_options.multigrid_update = interp( + grad_psi, 0.5) + else: + parameters_coarser.object_options.multigrid_update += interp( + grad_psi, 0.5) + + grad_probe = 0 + for n in range(context.parameters.algorithm_options.num_batch): + _grad_probe = context.comm.pool.map( + context.operator.grad_probe, + context.comm.pool.map(tike.opt.get_batch, context.data, context.batches, n=n), + context.parameters.psi, + context.comm.pool.map(tike.opt.get_batch, context.parameters.scan, context.batches, n=n), + context.parameters.probe, + ) + grad_probe += context.comm.Allreduce_reduce_cpu(_grad_probe) + if context.parameters.probe_options.multigrid_update is None: + parameters_coarser.probe_options.multigrid_update = interp( + grad_probe, 0.5) + else: + parameters_coarser.probe_options.multigrid_update += interp( + grad_probe, 0.5) + + parameters_coarser_updated = reconstruct_multigrid_new( + data=data_coarser, + parameters=parameters_coarser, + num_gpu=num_gpu, + model=model, + use_mpi=use_mpi, + num_levels=num_levels, + level=level + 1, + interp=interp, + ) + + context.parameters.algorithm_options.times = parameters_coarser_updated.algorithm_options.times + context.parameters.algorithm_options.costs = parameters_coarser_updated.algorithm_options.costs + context.parameters.object_options.update_mnorm = parameters_coarser_updated.object_options.update_mnorm + + def update_multi(x, gamma, dir): + + def f(x, dir): + return x + gamma * dir + + return list(context.comm.pool.map(f, x, dir)) + + def cost_function_psi(psi, **kwargs): + cost_out = context.comm.pool.map( + context.operator.cost, + context.data, + psi, + context.parameters.scan, + context.parameters.probe, + ) + return context.comm.Allreduce_mean(cost_out, axis=None).get() + + def cost_function_probe(probe, **kwargs): + cost_out = context.comm.pool.map( + context.operator.cost, + context.data, + context.parameters.psi, + context.parameters.scan, + probe, + ) + return context.comm.Allreduce_mean(cost_out, axis=None).get() + + logging.info(f'Multigrid level {level} upsample update') + + _, _, context.parameters.psi = tike.opt.line_search( + f=cost_function_psi, + x=context.parameters.psi, + d=context.comm.pool.bcast([ + interp( + parameters_coarser_updated.psi - parameters_coarser.psi, + 2.0, + ) + ]), + update_multi=update_multi, + ) + + _, _, context.parameters.probe = tike.opt.line_search( + f=cost_function_probe, + x=context.parameters.probe, + d=context.comm.pool.bcast([ + interp( + parameters_coarser_updated.probe - + parameters_coarser.probe, + 2.0, + ) + ]), + update_multi=update_multi, + ) + + logging.info(f'Multigrid level {level} post-smoothing') + + # post-smoothing + context.iterate(parameters.algorithm_options.num_iter) + + print(f"Return level {level}") + return context.parameters diff --git a/src/tike/ptycho/solvers/adam.py b/src/tike/ptycho/solvers/adam.py index feab4135..c814ca12 100644 --- a/src/tike/ptycho/solvers/adam.py +++ b/src/tike/ptycho/solvers/adam.py @@ -224,6 +224,8 @@ def _update_all( mdecay=object_options.mdecay, ) psi[0] = psi[0] - algorithm_options.step_length * dpsi / deno + if object_options.multigrid_update is not None: + psi[0] = psi[0] + object_options.multigrid_update / deno psi = comm.pool.bcast([psi[0]]) if probe_options: @@ -243,10 +245,12 @@ def _update_all( g=dprobe, v=probe_options.v, m=probe_options.m, - vdecay=object_options.vdecay, - mdecay=object_options.mdecay, + vdecay=probe_options.vdecay, + mdecay=probe_options.mdecay, ) probe[0] = probe[0] - algorithm_options.step_length * dprobe + if probe_options.multigrid_update is not None: + probe[0] = probe[0] + probe_options.multigrid_update / deno probe = comm.pool.bcast([probe[0]]) return psi, probe diff --git a/src/tike/ptycho/solvers/dm.py b/src/tike/ptycho/solvers/dm.py index b8932088..a962ab60 100644 --- a/src/tike/ptycho/solvers/dm.py +++ b/src/tike/ptycho/solvers/dm.py @@ -279,6 +279,8 @@ def _apply_update( mdecay=object_options.mdecay, ) new_psi = dpsi + psi[0] + if object_options.multigrid_update is not None: + new_psi = new_psi + object_options.multigrid_update psi = comm.pool.bcast([new_psi]) if recover_probe: @@ -301,6 +303,8 @@ def _apply_update( mdecay=probe_options.mdecay, ) new_probe = dprobe + probe[0] + if probe_options.multigrid_update is not None: + new_probe = new_probe + probe_options.multigrid_update probe = comm.pool.bcast([new_probe]) return psi, probe diff --git a/src/tike/ptycho/solvers/lstsq.py b/src/tike/ptycho/solvers/lstsq.py index eebbfd9b..ff121b38 100644 --- a/src/tike/ptycho/solvers/lstsq.py +++ b/src/tike/ptycho/solvers/lstsq.py @@ -236,6 +236,9 @@ def lstsq_grad( dpsi = beta_object * object_update_precond psi[0] = psi[0] + dpsi + if object_options.multigrid_update is not None: + psi[0] = psi[0] + object_options.multigrid_update + if object_options.use_adaptive_moment: ( dpsi, @@ -557,6 +560,10 @@ def _update_nearplane( mdecay=object_options.mdecay, ) psi[0] = psi[0] + dpsi + + if object_options.multigrid_update is not None: + psi[0] = psi[0] + object_options.multigrid_update + psi = comm.pool.bcast([psi[0]]) else: object_options.combined_update += object_upd_sum[0] diff --git a/src/tike/ptycho/solvers/options.py b/src/tike/ptycho/solvers/options.py index 0f86c26c..c8589d15 100644 --- a/src/tike/ptycho/solvers/options.py +++ b/src/tike/ptycho/solvers/options.py @@ -49,6 +49,7 @@ class IterativeOptions(abc.ABC): """The number of epochs to consider for convergence monitoring. Set to any value less than 2 to disable.""" + @dataclasses.dataclass class AdamOptions(IterativeOptions): name: str = dataclasses.field(default='adam_grad', init=False) @@ -178,9 +179,9 @@ def resample( if self.eigen_probe is not None else None, eigen_weights=self.eigen_weights, algorithm_options=self.algorithm_options, - probe_options=self.probe_options.resample(factor) + probe_options=self.probe_options.resample(factor, interp) if self.probe_options is not None else None, - object_options=self.object_options.resample(factor) + object_options=self.object_options.resample(factor, interp) if self.object_options is not None else None, position_options=self.position_options.resample(factor) if self.position_options is not None else None, @@ -264,3 +265,24 @@ def _resize_fft(x: np.ndarray, f: float) -> np.ndarray: norm='ortho', axes=(-2, -1), ) + + +def _resize_mean(x: np.ndarray, f: float) -> np.ndarray: + """Use an averaging filter to resize/resample the last 2 dimensions of x""" + if f == 1: + return x + if f < 1: + new_shape = ( + *x.shape[:-2], + int(x.shape[-2] * f), + int(1.0 / f), + int(x.shape[-1] * f), + int(1.0 / f), + ) + return np.sum(x.reshape(new_shape), axis=(-1, -3)) * (f * f) + else: + return np.repeat( + np.repeat(x, repeats=f, axis=-2), + repeats=int(f), + axis=-1, + ) * (f * f) diff --git a/src/tike/ptycho/solvers/rpie.py b/src/tike/ptycho/solvers/rpie.py index ba646464..e42b0ad9 100644 --- a/src/tike/ptycho/solvers/rpie.py +++ b/src/tike/ptycho/solvers/rpie.py @@ -354,6 +354,8 @@ def _update( mdecay=object_options.mdecay, ) psi[0] = psi[0] + dpsi / deno + if object_options.multigrid_update is not None: + psi[0] = psi[0] + object_options.multigrid_update / deno psi = comm.pool.bcast([psi[0]]) if probe_options: @@ -398,10 +400,12 @@ def _update( g=(dprobe)[0, 0, mode, :, :], v=probe_options.v, m=probe_options.m, - vdecay=object_options.vdecay, - mdecay=object_options.mdecay, + vdecay=probe_options.vdecay, + mdecay=probe_options.mdecay, ) probe[0] = probe[0] + dprobe / deno + if probe_options.multigrid_update is not None: + probe[0] = probe[0] + probe_options.multigrid_update / deno probe = comm.pool.bcast([probe[0]]) return psi, probe diff --git a/tests/ptycho/test_multigrid.py b/tests/ptycho/test_multigrid.py index c55e693e..9b3f25c9 100644 --- a/tests/ptycho/test_multigrid.py +++ b/tests/ptycho/test_multigrid.py @@ -14,6 +14,7 @@ _resize_cubic, _resize_lanczos, _resize_linear, + _resize_mean, ) from .templates import _mpi_size @@ -23,12 +24,36 @@ output_folder = os.path.join(result_dir, 'multigrid') +def test_resize_mean(): + x0 = np.array([[ + [0, 1], + [5, 7], + ]]) + x3 = np.array([[ + [0, 1*9], + [5*9, 7*9], + ]]) + x = np.array([[ + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1], + [5, 5, 5, 7, 7, 7], + [5, 5, 5, 7, 7, 7], + [5, 5, 5, 7, 7, 7], + ]]) * 9 + x1 = _resize_mean(x0, 3.0) + np.testing.assert_equal(x1, x) + x2 = _resize_mean(x, 1.0/3.0) + np.testing.assert_equal(x2, x3) + + @pytest.mark.parametrize("function", [ _resize_fft, _resize_spline, _resize_linear, _resize_cubic, _resize_lanczos, + _resize_mean, ]) def test_resample(function, filename='siemens-star-small.npz.bz2'): @@ -86,6 +111,49 @@ def template_consistent_algorithm(self, *, data, params): return parameters +@unittest.skipIf( + _mpi_size > 1, + reason="MPI not implemented for multi-grid.", +) +class ReconMultiGridNew(): + """Test ptychography multi-grid reconstruction method.""" + + def interp(self, x, f): + pass + + def template_consistent_algorithm(self, *, data, params): + """Check ptycho.solver.algorithm for consistency.""" + if _mpi_size > 1: + raise NotImplementedError() + + with cp.cuda.Device(self.gpu_indices[0]): + parameters = tike.ptycho.reconstruct_multigrid_new( + parameters=params, + data=self.data, + num_gpu=self.gpu_indices, + use_mpi=self.mpi_size > 1, + num_levels=2, + interp=self.interp, + ) + + print() + print('\n'.join( + f'{c[0]:1.3e}' for c in parameters.algorithm_options.costs)) + return parameters + + +class TestPtychoReconMultiGridMean( + ReconMultiGridNew, + PtychoRecon, + unittest.TestCase, +): + + post_name = '-multigrid-mean' + + def interp(self, x, f): + return _resize_mean(x, f) + + class TestPtychoReconMultiGridFFT( ReconMultiGrid, PtychoRecon,