diff --git a/src/tike/operators/cupy/__init__.py b/src/tike/operators/cupy/__init__.py index cb2a9d6f..a67a72f5 100644 --- a/src/tike/operators/cupy/__init__.py +++ b/src/tike/operators/cupy/__init__.py @@ -11,11 +11,12 @@ from .convolution import * from .flow import * from .lamino import * -from .operator import * from .objective import * +from .operator import * from .pad import * from .patch import * from .propagation import * from .ptycho import * from .rotate import * from .shift import * +from .zernike import * diff --git a/src/tike/operators/cupy/zernike.py b/src/tike/operators/cupy/zernike.py new file mode 100644 index 00000000..0e6ffc3d --- /dev/null +++ b/src/tike/operators/cupy/zernike.py @@ -0,0 +1,61 @@ +"""Defines an inverse-Zernike transform operator.""" + +__author__ = "Daniel Ching" +__copyright__ = "Copyright (c) 2024, UChicago Argonne, LLC." + +import numpy.typing as npt +import numpy as np +import tike.zernike +import tike.linalg + +from .operator import Operator + + +class Zernike(Operator): + """Reconstruct an image from coefficients and zernike basis using CuPy. + + Take an (..., W) array of zernike coefficients and reconstruct an image + from them. + + Parameters + ---------- + size : int + The pixel width and height of the reconstruction. + weights: (..., W) complex64 + The zernike coefficients + + .. versionadded:: 0.25.5 + + """ + + def fwd( + self, + weights: npt.NDArray[np.csingle], + size: int, + degree_max: int, + **kwargs, + ) -> npt.NDArray[np.csingle]: + basis = tike.zernike.basis( + size=size, + degree_min=0, + degree_max=degree_max, + xp=self.xp, + ) + # (..., W) @ (W, size, size) + return np.einsum("...c,cwh->...wh", weights, basis) + + def adj( + self, + images: npt.NDArray[np.csingle], + size: int, + degree_max: int, + **kwargs, + ) -> npt.NDArray[np.csingle]: + basis = tike.zernike.basis( + size=size, + degree_min=0, + degree_max=degree_max, + xp=self.xp, + ) + # (..., size, size) @ (W, size, size) + return np.einsum("...wh,cwh->...c", images, basis) diff --git a/src/tike/zernike.py b/src/tike/zernike.py new file mode 100644 index 00000000..90a8bc51 --- /dev/null +++ b/src/tike/zernike.py @@ -0,0 +1,163 @@ +"""Provide functions to evaluate Zernike polynomials on a discrete grid. + + +References +---------- +@article{Niu_2022, +doi = {10.1088/2040-8986/ac9e08}, +url = {https://dx.doi.org/10.1088/2040-8986/ac9e08}, +year = {2022}, +month = {nov}, +publisher = {IOP Publishing}, +volume = {24}, +number = {12}, +pages = {123001}, +author = {Kuo Niu and Chao Tian}, +title = {Zernike polynomials and their applications}, +journal = {Journal of Optics}, +abstract = {The Zernike polynomials are a complete set of continuous functions orthogonal over a unit circle. Since first developed by Zernike in 1934, they have been in widespread use in many fields ranging from optics, vision sciences, to image processing. However, due to the lack of a unified definition, many confusing indices have been used in the past decades and mathematical properties are scattered in the literature. This review provides a comprehensive account of Zernike circle polynomials and their noncircular derivatives, including history, definitions, mathematical properties, roles in wavefront fitting, relationships with optical aberrations, and connections with other polynomials. We also survey state-of-the-art applications of Zernike polynomials in a range of fields, including the diffraction theory of aberrations, optical design, optical testing, ophthalmic optics, adaptive optics, and image analysis. Owing to their elegant and rigorous mathematical properties, the range of scientific and industrial applications of Zernike polynomials is likely to expand. This review is expected to clear up the confusion of different indices, provide a self-contained reference guide for beginners as well as specialists, and facilitate further developments and applications of the Zernike polynomials.} +} +""" +import typing + +import numpy as np + + +def Z(m: int, n: int, radius: np.array, angle: np.array) -> np.array: + """Return values of Zernike[m,n] polynomial at given radii, angles. + + Values outside valid radius will be zero. + + Parameters + ---------- + m : int + Angular frequency of the polynomial. + n : int + Radial degree of the polynomial. + radius: float [0, 1] + The radial coordinates of the evaluated polynomial. + angle: float radians + The angular coordinates of the evaluated polynomial. + + """ + if n < 0: + raise ValueError("Radial degree must be non-negative.") + _m_ = np.abs(m) + if _m_ > n: + raise ValueError("Angular frequency must be less than radial degree.") + if m < 0: + return np.sqrt(2 * (n + 1)) * R(_m_, n, radius) * np.sin(m * angle) + if m == 0: + return np.sqrt(n + 1) * R(_m_, n, radius) + if m > 0: + return np.sqrt(2 * (n + 1)) * R(_m_, n, radius) * np.cos(m * angle) + + +def N(m: int, n: int) -> float: + """Zernike normalization factor.""" + if m == 0: + return np.sqrt(n + 1) + return np.sqrt(2 * (n + 1)) + + +def R(m: int, n: int, radius: np.array) -> np.array: + """Return the values of the Zernike radial polynomial at the given radii. + + This polynomial matches Figure 3 in Lakshminarayanan & Fleck (2011). + + Parameters + ---------- + m : int + Angular frequency of the polynomial. + n : int + Radial degree of the polynomial. + radius: float [0, 1] + The radial coordinates of the evaluated polynomial. + + References + ---------- + Vasudevan Lakshminarayanan & Andre Fleck (2011): Zernike polynomials: a + guide, Journal of ModernOptics, 58:7, 545-561 + http://dx.doi.org/10.1080/09500340.2011.554896 + + """ + # Initialize with k=0 case because this term will always be included + sign = -1 + result = 0 * radius + for k in range(0, (n - m) // 2 + 1): + sign = -sign + b0 = _bino(n - k, k) + b1 = _bino(n - 2 * k, (n - m) // 2 - k) + result += sign * b0 * b1 * radius ** (n - 2 * k) + # Smooth the sharp edges of the polynomial with a supergaussian window + # Higher smoothing degree makes the window edge sharper + smoothing_degree = 32 + result *= np.exp(-(radius ** (2 * smoothing_degree))) + return result + + +def _bino(a: int, b: int) -> int: + """Return the approximate binomial coeffient (a b).""" + result = 1 + for i in range(1, b + 1): + result *= (a - i + 1) / i + return result + + +def _bino1(a: int, b: int, xp=np) -> int: + """Return the approximate binomial coeffient (a b).""" + result = np.arange(a, a - b, -1) / np.arange(1, b + 1) + return np.prod(result) + + +def basis(size: int, degree_min: int, degree_max: int, xp=np) -> np.array: + """Return all circular Zernike basis up to given radial degree. + + Parameters + ---------- + size : int + The width of the discrete basis in pixel. + degree : int + The maximum radial degree of the polynomial (not inclusive). The number + of degrees included in the set of bases. + + Returns + ------- + basis : (degree, size, size) + The Zernike bases. + + """ + endpoint = 1.0 - 1.0 / (2 * size) + x = xp.linspace(-endpoint, endpoint, size, endpoint=True) + coords = xp.stack(xp.meshgrid(x, x, indexing="ij"), axis=0) + radius = xp.linalg.norm(coords, axis=0) + theta = xp.arctan2(coords[0], coords[1]) + + basis = [] + for m, n in valid_indices(degree_min, degree_max): + basis.append(Z(m, n, radius, theta)) + + basis = xp.stack(basis, axis=0) + return basis + + +def valid_indices( + degree_min: int, + degree_max: int, +) -> typing.Generator[typing.Tuple[int, int], None, None]: + """Enumerate all valid zernike indices (m,n) up to the given degree.""" + for n in range(degree_min, degree_max): + for m in range(-n, n + 1): + if (n - abs(m)) % 2 == 0: + yield m, n + + +def num_basis_less_than_degree(degree_max: int) -> int: + """Return number of zernike basis in degrees < degree_max (strictly).""" + # And odd times even number is even and always cleanly divisible by 2. + return (degree_max) * (degree_max + 1) // 2 + + +def degree_max_from_num_basis(num_basis: int) -> int: + """Return the max degree (non-inclusive) required to have at least some number of basis.""" + return int(np.ceil(0.5 * (-1 + np.sqrt(8 * num_basis)))) diff --git a/tests/data/probe.png b/tests/data/probe.png new file mode 100644 index 00000000..cce3eb97 Binary files /dev/null and b/tests/data/probe.png differ diff --git a/tests/operators/test_zernike.py b/tests/operators/test_zernike.py new file mode 100644 index 00000000..05a4639b --- /dev/null +++ b/tests/operators/test_zernike.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import unittest + +import numpy as np +from tike.operators import Zernike +import tike.precision +import tike.linalg + +from .util import random_complex, OperatorTests + +__author__ = "Daniel Ching" +__copyright__ = "Copyright (c) 2024, UChicago Argonne, LLC." +__docformat__ = "restructuredtext en" + + +class TestZernike(unittest.TestCase, OperatorTests): + """Test the Zernike operator.""" + + def setUp(self): + self.nscan = 21 + self.nprobe = 11 + self.nbasis = 128 + self.size = 16 + self.degree_max = tike.zernike.degree_max_from_num_basis(self.nbasis) + self.nbasis = tike.zernike.num_basis_less_than_degree(self.degree_max) + + basis = tike.zernike.basis(size=3, degree_min=0, degree_max=self.degree_max) + assert basis.shape == (self.nbasis, 3, 3), basis.size + + self.operator = Zernike() + self.operator.__enter__() + self.xp = self.operator.xp + + np.random.seed(0) + images = random_complex(self.nscan, self.nprobe, self.size, self.size) + weights = random_complex(self.nscan, self.nprobe, self.nbasis) + + self.m = self.xp.asarray(weights) + self.m_name = "weights" + self.kwargs = { + "size": self.size, + "degree_max": self.degree_max, + } + + self.d = self.xp.asarray(images) + self.d_name = "images" + + print(self.operator) + + @unittest.skip("FIXME: This operator is not scaled.") + def test_scaled(self): + pass diff --git a/tests/test_zernike.py b/tests/test_zernike.py new file mode 100644 index 00000000..a5d8b025 --- /dev/null +++ b/tests/test_zernike.py @@ -0,0 +1,156 @@ +import unittest +import os + +import tike.zernike +import tike.linalg +import tike.view +import matplotlib.pyplot as plt +import numpy as np + +testdir = os.path.dirname(__file__) + + +class TestZernike(unittest.TestCase): + def test_zernike_preview(self): + fname = os.path.join(testdir, "result", "zernike") + os.makedirs(fname, exist_ok=True) + for i, Z in enumerate( + tike.zernike.basis( + 256, + degree_min=0, + degree_max=6, + ) + ): + plt.figure() + tike.view.plot_complex(Z, rmin=-1, rmax=1) + plt.savefig(os.path.join(fname, f"zernike-{i:02d}.png")) + plt.close() + + def _radial_template(self, m=0): + fname = os.path.join(testdir, "result", "zernike") + os.makedirs(fname, exist_ok=True) + + radius = np.linspace(0, 1, 200) + + plt.figure() + labels = [] + for n in range(0, 9): + if (n + m) % 2 == 0: + v = tike.zernike.R(m, n, radius) + plt.plot( + radius, + v, + ) + labels.append(n) + plt.legend(labels) + plt.ylim([-1, 1]) + plt.savefig(os.path.join(fname, f"radial-function-{m}.png")) + plt.close() + + def test_radial(self): + self._radial_template(0) + + def test_radial_1(self): + self._radial_template(1) + + def test_radial_2(self): + self._radial_template(2) + + def test_transform(self): + fname = os.path.join(testdir, "result", "zernike") + os.makedirs(fname, exist_ok=True) + + # import libimage + # f0 = libimage.load("cryptomeria", 256) + f0 = plt.imread(os.path.join(testdir, "data", "probe.png")) + size = f0.shape[-1] + plt.imsave(os.path.join(fname, "basis-0.png"), f0, vmin=0, vmax=1.0) + + f0 = f0.reshape(size * size, 1) + + _basis = [] + + for d in range(0, 64): + _basis.append( + tike.zernike.basis( + size, + degree_min=d, + degree_max=d + 1, + ) + ) + + basis = np.concatenate(_basis, axis=0) + print(f"degree {d} - {len(basis)}") + basis = np.moveaxis(basis, 0, -1) + basis = basis.reshape(size * size, -1) + # weight only pixels inside basis + # w = (basis[..., 0] > 0).astype("float32") + + # x = tike.linalg.lstsq(basis, f0, )#weights=w) + x, _, _, _ = np.linalg.lstsq(basis, f0, rcond=1e-9) + + f1 = basis @ x + f1 = f1.reshape(size, size) + plt.imsave(os.path.join(fname, f"basis-{d:02d}.png"), f1, vmin=0, vmax=1.0) + + plt.figure() + plt.title(f"basis weights for {d} degree polynomials") + plt.bar(list(range(x.size)), x.flatten()) + plt.savefig(os.path.join(fname, f"basis-w-{d:02d}.png")) + plt.close() + + def test_transform1(self): + fname = os.path.join(testdir, "result", "zernike") + os.makedirs(fname, exist_ok=True) + + # import libimage + # f0 = libimage.load("cryptomeria", 256) + # print(f0.max()) + f0 = plt.imread(os.path.join(testdir, "data", "probe.png")) + size = f0.shape[-1] + plt.imsave(os.path.join(fname, "basis1-0.png"), f0, vmin=0, vmax=1.0) + + f0 = f0.reshape(1, size * size) + + _basis = [] + + print(f"This image has {size * size} pixels.") + + for d in range(0, 64): + more_basis = tike.zernike.basis( + size, + degree_min=d, + degree_max=d + 1, + ) + _basis.append( + # Normalize the basis (size-dependent) + more_basis + / tike.linalg.norm( + more_basis, + axis=(-2, -1), + keepdims=True, + ) + ) + + basis = np.concatenate(_basis, axis=0) + + print(f"Adding degree {d} - {len(basis)} total basis functions") + basis = basis.reshape(-1, size * size) + + # print(basis.shape, f0.shape) + # y = tike.linalg.inner(basis[0], basis[-1], axis=-1, keepdims=True) + # print(f"orthogonality {y}") + x = tike.linalg.inner(f0, basis, axis=-1, keepdims=True) + # print(x.shape) + f1 = x.T @ basis + f1 = f1.reshape(size, size) + print(f1.max()) + plt.imsave(os.path.join(fname, f"basis1-{d:02d}.png"), f1, vmin=0, vmax=1.0) + + plt.figure() + plt.title(f"basis weights for {d} degree polynomials") + plt.bar(list(range(x.size)), x.flatten()) + plt.savefig(os.path.join(fname, f"basis1-w-{d:02d}.png")) + plt.close() + + # print(f"{x.flatten()[:16]}")