Skip to content

DEV: Prototype some zernike polynomials #188

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 13 commits into
base: main
Choose a base branch
from
3 changes: 2 additions & 1 deletion src/tike/operators/cupy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@
from .convolution import *
from .flow import *
from .lamino import *
from .operator import *
from .objective import *
from .operator import *
from .pad import *
from .patch import *
from .propagation import *
from .ptycho import *
from .rotate import *
from .shift import *
from .zernike import *
61 changes: 61 additions & 0 deletions src/tike/operators/cupy/zernike.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
"""Defines an inverse-Zernike transform operator."""

__author__ = "Daniel Ching"
__copyright__ = "Copyright (c) 2024, UChicago Argonne, LLC."

import numpy.typing as npt
import numpy as np
import tike.zernike
import tike.linalg

from .operator import Operator


class Zernike(Operator):
"""Reconstruct an image from coefficients and zernike basis using CuPy.

Take an (..., W) array of zernike coefficients and reconstruct an image
from them.

Parameters
----------
size : int
The pixel width and height of the reconstruction.
weights: (..., W) complex64
The zernike coefficients

.. versionadded:: 0.25.5

"""

def fwd(
self,
weights: npt.NDArray[np.csingle],
size: int,
degree_max: int,
**kwargs,
) -> npt.NDArray[np.csingle]:
basis = tike.zernike.basis(
size=size,
degree_min=0,
degree_max=degree_max,
xp=self.xp,
)
# (..., W) @ (W, size, size)
return np.einsum("...c,cwh->...wh", weights, basis)

def adj(
self,
images: npt.NDArray[np.csingle],
size: int,
degree_max: int,
**kwargs,
) -> npt.NDArray[np.csingle]:
basis = tike.zernike.basis(
size=size,
degree_min=0,
degree_max=degree_max,
xp=self.xp,
)
# (..., size, size) @ (W, size, size)
return np.einsum("...wh,cwh->...c", images, basis)
163 changes: 163 additions & 0 deletions src/tike/zernike.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
"""Provide functions to evaluate Zernike polynomials on a discrete grid.


References
----------
@article{Niu_2022,
doi = {10.1088/2040-8986/ac9e08},
url = {https://dx.doi.org/10.1088/2040-8986/ac9e08},
year = {2022},
month = {nov},
publisher = {IOP Publishing},
volume = {24},
number = {12},
pages = {123001},
author = {Kuo Niu and Chao Tian},
title = {Zernike polynomials and their applications},
journal = {Journal of Optics},
abstract = {The Zernike polynomials are a complete set of continuous functions orthogonal over a unit circle. Since first developed by Zernike in 1934, they have been in widespread use in many fields ranging from optics, vision sciences, to image processing. However, due to the lack of a unified definition, many confusing indices have been used in the past decades and mathematical properties are scattered in the literature. This review provides a comprehensive account of Zernike circle polynomials and their noncircular derivatives, including history, definitions, mathematical properties, roles in wavefront fitting, relationships with optical aberrations, and connections with other polynomials. We also survey state-of-the-art applications of Zernike polynomials in a range of fields, including the diffraction theory of aberrations, optical design, optical testing, ophthalmic optics, adaptive optics, and image analysis. Owing to their elegant and rigorous mathematical properties, the range of scientific and industrial applications of Zernike polynomials is likely to expand. This review is expected to clear up the confusion of different indices, provide a self-contained reference guide for beginners as well as specialists, and facilitate further developments and applications of the Zernike polynomials.}
}
"""
import typing

import numpy as np


def Z(m: int, n: int, radius: np.array, angle: np.array) -> np.array:
"""Return values of Zernike[m,n] polynomial at given radii, angles.

Values outside valid radius will be zero.

Parameters
----------
m : int
Angular frequency of the polynomial.
n : int
Radial degree of the polynomial.
radius: float [0, 1]
The radial coordinates of the evaluated polynomial.
angle: float radians
The angular coordinates of the evaluated polynomial.

"""
if n < 0:
raise ValueError("Radial degree must be non-negative.")
_m_ = np.abs(m)
if _m_ > n:
raise ValueError("Angular frequency must be less than radial degree.")
if m < 0:
return np.sqrt(2 * (n + 1)) * R(_m_, n, radius) * np.sin(m * angle)
if m == 0:
return np.sqrt(n + 1) * R(_m_, n, radius)
if m > 0:
return np.sqrt(2 * (n + 1)) * R(_m_, n, radius) * np.cos(m * angle)


def N(m: int, n: int) -> float:
"""Zernike normalization factor."""
if m == 0:
return np.sqrt(n + 1)
return np.sqrt(2 * (n + 1))


def R(m: int, n: int, radius: np.array) -> np.array:
"""Return the values of the Zernike radial polynomial at the given radii.

This polynomial matches Figure 3 in Lakshminarayanan & Fleck (2011).

Parameters
----------
m : int
Angular frequency of the polynomial.
n : int
Radial degree of the polynomial.
radius: float [0, 1]
The radial coordinates of the evaluated polynomial.

References
----------
Vasudevan Lakshminarayanan & Andre Fleck (2011): Zernike polynomials: a
guide, Journal of ModernOptics, 58:7, 545-561
http://dx.doi.org/10.1080/09500340.2011.554896

"""
# Initialize with k=0 case because this term will always be included
sign = -1
result = 0 * radius
for k in range(0, (n - m) // 2 + 1):
sign = -sign
b0 = _bino(n - k, k)
b1 = _bino(n - 2 * k, (n - m) // 2 - k)
result += sign * b0 * b1 * radius ** (n - 2 * k)
# Smooth the sharp edges of the polynomial with a supergaussian window
# Higher smoothing degree makes the window edge sharper
smoothing_degree = 32
result *= np.exp(-(radius ** (2 * smoothing_degree)))
return result


def _bino(a: int, b: int) -> int:
"""Return the approximate binomial coeffient (a b)."""
result = 1
for i in range(1, b + 1):
result *= (a - i + 1) / i
return result


def _bino1(a: int, b: int, xp=np) -> int:
"""Return the approximate binomial coeffient (a b)."""
result = np.arange(a, a - b, -1) / np.arange(1, b + 1)
return np.prod(result)


def basis(size: int, degree_min: int, degree_max: int, xp=np) -> np.array:
"""Return all circular Zernike basis up to given radial degree.

Parameters
----------
size : int
The width of the discrete basis in pixel.
degree : int
The maximum radial degree of the polynomial (not inclusive). The number
of degrees included in the set of bases.

Returns
-------
basis : (degree, size, size)
The Zernike bases.

"""
endpoint = 1.0 - 1.0 / (2 * size)
x = xp.linspace(-endpoint, endpoint, size, endpoint=True)
coords = xp.stack(xp.meshgrid(x, x, indexing="ij"), axis=0)
radius = xp.linalg.norm(coords, axis=0)
theta = xp.arctan2(coords[0], coords[1])

basis = []
for m, n in valid_indices(degree_min, degree_max):
basis.append(Z(m, n, radius, theta))

basis = xp.stack(basis, axis=0)
return basis


def valid_indices(
degree_min: int,
degree_max: int,
) -> typing.Generator[typing.Tuple[int, int], None, None]:
"""Enumerate all valid zernike indices (m,n) up to the given degree."""
for n in range(degree_min, degree_max):
for m in range(-n, n + 1):
if (n - abs(m)) % 2 == 0:
yield m, n


def num_basis_less_than_degree(degree_max: int) -> int:
"""Return number of zernike basis in degrees < degree_max (strictly)."""
# And odd times even number is even and always cleanly divisible by 2.
return (degree_max) * (degree_max + 1) // 2


def degree_max_from_num_basis(num_basis: int) -> int:
"""Return the max degree (non-inclusive) required to have at least some number of basis."""
return int(np.ceil(0.5 * (-1 + np.sqrt(8 * num_basis))))
Binary file added tests/data/probe.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
54 changes: 54 additions & 0 deletions tests/operators/test_zernike.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import unittest

import numpy as np
from tike.operators import Zernike
import tike.precision
import tike.linalg

from .util import random_complex, OperatorTests

__author__ = "Daniel Ching"
__copyright__ = "Copyright (c) 2024, UChicago Argonne, LLC."
__docformat__ = "restructuredtext en"


class TestZernike(unittest.TestCase, OperatorTests):
"""Test the Zernike operator."""

def setUp(self):
self.nscan = 21
self.nprobe = 11
self.nbasis = 128
self.size = 16
self.degree_max = tike.zernike.degree_max_from_num_basis(self.nbasis)
self.nbasis = tike.zernike.num_basis_less_than_degree(self.degree_max)

basis = tike.zernike.basis(size=3, degree_min=0, degree_max=self.degree_max)
assert basis.shape == (self.nbasis, 3, 3), basis.size

self.operator = Zernike()
self.operator.__enter__()
self.xp = self.operator.xp

np.random.seed(0)
images = random_complex(self.nscan, self.nprobe, self.size, self.size)
weights = random_complex(self.nscan, self.nprobe, self.nbasis)

self.m = self.xp.asarray(weights)
self.m_name = "weights"
self.kwargs = {
"size": self.size,
"degree_max": self.degree_max,
}

self.d = self.xp.asarray(images)
self.d_name = "images"

print(self.operator)

@unittest.skip("FIXME: This operator is not scaled.")
def test_scaled(self):
pass
Loading