Skip to content

Multi grid #262

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 15 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions src/tike/opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,12 +259,18 @@ def line_search(
# Decrease the step length while the step increases the cost function
step_count = 0
first_step = step_length
step_is_decreasing = False
while True:
xsd = update_multi(x, step_length, d)
fxsd = f(xsd)
if fxsd <= fx + step_shrink * m:
break
step_length *= step_shrink
if step_is_decreasing:
break
step_length /= step_shrink
else:
step_length *= step_shrink
step_is_decreasing = True

if step_length < 1e-32:
warnings.warn("Line search failed for conjugate gradient.")
step_length, fxsd, xsd = 0, fx, x
Expand Down
50 changes: 40 additions & 10 deletions src/tike/ptycho/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import dataclasses
import logging
import typing
import copy

import cupy as cp
import cupyx.scipy.ndimage
Expand All @@ -25,6 +26,16 @@
class ObjectOptions:
"""Manage data and setting related to object correction."""

convergence_tolerance: float = 0
"""Terminate reconstruction early when the mnorm of the object update is
less than this value."""

update_mnorm: typing.List[float] = dataclasses.field(
init=False,
default_factory=list,
)
"""A record of the previous mnorms of the object update."""

positivity_constraint: float = 0
"""This value is passed to the tike.ptycho.object.positivity_constraint
function."""
Expand Down Expand Up @@ -66,39 +77,58 @@ class ObjectOptions:
)
"""Used for compact batch updates."""

multigrid_update: typing.Union[npt.NDArray, None] = dataclasses.field(
init=False,
default_factory=lambda: None,
)
"""Used for multigrid updates."""

clip_magnitude: bool = False
"""Whether to force the object magnitude to remain <= 1."""

def copy_to_device(self, comm):
"""Copy to the current GPU memory."""
options = copy.copy(self)
options.update_mnorm = copy.copy(self.update_mnorm)
if self.v is not None:
self.v = cp.asarray(self.v)
options.v = cp.asarray(self.v)
if self.m is not None:
self.m = cp.asarray(self.m)
options.m = cp.asarray(self.m)
if self.preconditioner is not None:
self.preconditioner = comm.pool.bcast([self.preconditioner])
return self
options.preconditioner = comm.pool.bcast([self.preconditioner])
if self.multigrid_update is not None:
options.multigrid_update = cp.asarray(self.multigrid_update)
return options

def copy_to_host(self):
"""Copy to the host CPU memory."""
options = copy.copy(self)
options.update_mnorm = copy.copy(self.update_mnorm)
if self.v is not None:
self.v = cp.asnumpy(self.v)
options.v = cp.asnumpy(self.v)
if self.m is not None:
self.m = cp.asnumpy(self.m)
options.m = cp.asnumpy(self.m)
if self.preconditioner is not None:
self.preconditioner = cp.asnumpy(self.preconditioner[0])
return self
options.preconditioner = cp.asnumpy(self.preconditioner[0])
if self.multigrid_update is not None:
options.multigrid_update = cp.asnumpy(self.multigrid_update)
return options

def resample(self, factor: float) -> ObjectOptions:
def resample(self, factor: float, interp) -> ObjectOptions:
"""Return a new `ObjectOptions` with the parameters rescaled."""
return ObjectOptions(
options = ObjectOptions(
convergence_tolerance=self.convergence_tolerance,
positivity_constraint=self.positivity_constraint,
smoothness_constraint=self.smoothness_constraint,
use_adaptive_moment=self.use_adaptive_moment,
vdecay=self.vdecay,
mdecay=self.mdecay,
clip_magnitude=self.clip_magnitude,
)
options.update_mnorm = copy.copy(self.update_mnorm)
if self.multigrid_update is not None:
options.multigrid_update = interp(self.multigrid_update, factor)
return options
# Momentum reset to zero when grid scale changes


Expand Down
19 changes: 11 additions & 8 deletions src/tike/ptycho/position.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@
import dataclasses
import logging
import typing
import copy

import cupy as cp
import cupyx.scipy.ndimage
Expand Down Expand Up @@ -445,21 +446,23 @@ def join(self, other, indices):

def copy_to_device(self):
"""Copy to the current GPU memory."""
self.initial_scan = cp.asarray(self.initial_scan)
options = copy.copy(self)
options.initial_scan = cp.asarray(self.initial_scan)
if self.confidence is not None:
self.confidence = cp.asarray(self.confidence)
options.confidence = cp.asarray(self.confidence)
if self.use_adaptive_moment:
self._momentum = cp.asarray(self._momentum)
return self
options._momentum = cp.asarray(self._momentum)
return options

def copy_to_host(self):
"""Copy to the host CPU memory."""
self.initial_scan = cp.asnumpy(self.initial_scan)
options = copy.copy(self)
options.initial_scan = cp.asnumpy(self.initial_scan)
if self.confidence is not None:
self.confidence = cp.asnumpy(self.confidence)
options.confidence = cp.asnumpy(self.confidence)
if self.use_adaptive_moment:
self._momentum = cp.asnumpy(self._momentum)
return self
options._momentum = cp.asnumpy(self._momentum)
return options

def resample(self, factor: float) -> PositionOptions:
"""Return a new `PositionOptions` with the parameters scaled."""
Expand Down
36 changes: 26 additions & 10 deletions src/tike/ptycho/probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
"""

from __future__ import annotations
import copy
import dataclasses
import logging
import typing
Expand Down Expand Up @@ -135,29 +136,41 @@ class ProbeOptions:
)
"""The power of the primary probe modes at each iteration."""

multigrid_update: typing.Union[npt.NDArray, None] = dataclasses.field(
init=False,
default_factory=lambda: None,
)
"""Used for multigrid updates."""

def copy_to_device(self, comm):
"""Copy to the current GPU memory."""
options = copy.copy(self)
if self.v is not None:
self.v = cp.asarray(self.v)
options.v = cp.asarray(self.v)
if self.m is not None:
self.m = cp.asarray(self.m)
options.m = cp.asarray(self.m)
if self.preconditioner is not None:
self.preconditioner = comm.pool.bcast([self.preconditioner])
return self
options.preconditioner = comm.pool.bcast([self.preconditioner])
if self.multigrid_update is not None:
options.multigrid_update = cp.asarray(self.multigrid_update)
return options

def copy_to_host(self):
"""Copy to the host CPU memory."""
options = copy.copy(self)
if self.v is not None:
self.v = cp.asnumpy(self.v)
options.v = cp.asnumpy(self.v)
if self.m is not None:
self.m = cp.asnumpy(self.m)
options.m = cp.asnumpy(self.m)
if self.preconditioner is not None:
self.preconditioner = cp.asnumpy(self.preconditioner[0])
return self
options.preconditioner = cp.asnumpy(self.preconditioner[0])
if self.multigrid_update is not None:
options.multigrid_update = cp.asnumpy(self.multigrid_update)
return options

def resample(self, factor: float) -> ProbeOptions:
def resample(self, factor: float, interp) -> ProbeOptions:
"""Return a new `ProbeOptions` with the parameters rescaled."""
return ProbeOptions(
options = ProbeOptions(
force_orthogonality=self.force_orthogonality,
force_centered_intensity=self.force_centered_intensity,
force_sparsity=self.force_sparsity,
Expand All @@ -168,6 +181,9 @@ def resample(self, factor: float) -> ProbeOptions:
probe_support_degree=self.probe_support_degree,
probe_support_radius=self.probe_support_radius,
)
if self.multigrid_update is not None:
options.multigrid_update = interp(self.multigrid_update, factor)
return options
# Momentum reset to zero when grid scale changes


Expand Down
Loading