Skip to content

Commit

Permalink
Update docstrings for sphinx rtd
Browse files Browse the repository at this point in the history
  • Loading branch information
DeanHazineh committed Sep 1, 2024
1 parent 8803426 commit aaa7cfa
Show file tree
Hide file tree
Showing 11 changed files with 217 additions and 25 deletions.
2 changes: 2 additions & 0 deletions dflat/GDSII/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@
assemble_ellipse_gds,
assemble_fin_gds,
)

__all__ = ["assemble_cylinder_gds", "assemble_ellipse_gds", "assemble_fin_gds"]
2 changes: 2 additions & 0 deletions dflat/propagation/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from .propagators import ASMPropagation, FresnelPropagation, PointSpreadFunction

__all__ = ["ASMPropagation", "FresnelPropagation", "PointSpreadFunction"]
134 changes: 122 additions & 12 deletions dflat/propagation/propagators.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,34 @@


class BaseFrequencySpace(nn.Module):
"""Base class for frequency space propagation methods.
This class provides common initialization and validation for frequency space
propagation methods such as Fresnel and Angular Spectrum Method (ASM).
Attributes:
in_size (np.ndarray): Input field size [height, width].
in_dx_m (np.ndarray): Input pixel size in meters [dy, dx].
out_distance_m (float): Propagation distance in meters.
out_size (np.ndarray): Output field size [height, width].
out_dx_m (np.ndarray): Output pixel size in meters [dy, dx].
wavelength_set_m (np.ndarray): Set of wavelengths in meters.
out_resample_dx_m (np.ndarray): Output resampling pixel size in meters [dy, dx].
manual_upsample_factor (float): Manual upsampling factor for input field.
radial_symmetry (bool): If True, assume radial symmetry in the input field.
Args:
in_size (List[int]): Input field size [height, width].
in_dx_m (List[float]): Input pixel size in meters [dy, dx].
out_distance_m (float): Propagation distance in meters.
out_size (List[int]): Output field size [height, width].
out_dx_m (List[float]): Output pixel size in meters [dy, dx].
wavelength_set_m (List[float]): Set of wavelengths in meters.
out_resample_dx_m (List[float], optional): Output resampling pixel size in meters [dy, dx].
manual_upsample_factor (float, optional): Manual upsampling factor for input field. Defaults to 1.
radial_symmetry (bool, optional): If True, assume radial symmetry in the input field. Defaults to False.
"""

def __init__(
self,
in_size,
Expand Down Expand Up @@ -100,6 +128,18 @@ def _unit_conversion(self):


class FresnelPropagation(BaseFrequencySpace):
"""Fresnel propagation method for optical field propagation.
This class implements the Fresnel propagation method for propagating optical fields
from an input plane to an output plane.
Note:
The Fresnel propagation method is suitable for propagation distances where the
paraxial approximation holds. To generate fields on a specified, uniform output grid,
the input grid is sometimes upsampled and/or zero-padded. The memory/output has a non-linear relationship
with requested output grid size.
"""

def __init__(
self,
in_size,
Expand All @@ -115,6 +155,20 @@ def __init__(
*args,
**kwargs,
):
"""Initializes the propagation class.
Args:
in_size (list): input grid shape as [H, W].
in_dx_m (list): input grid discretization (in meters) as [dy, dx]
out_distance_m (float): output plane distance
out_size (list): output grid shape as [H, W]
out_dx_m (list): output grid discretization (in meters) as [dy, dx]
wavelength_set_m (list): List of wavelengths (in meters) corresponding to the wavelength dimension stack in forward.
out_resample_dx_m (list, optional): List of output grid discretizations to resample by area sum (area averaging for phase). This can be used to compute at a sub-pixel scale then return the integrated field on each pixel. Defaults to None.
manual_upsample_factor (int, optional): Force factor to manually upsample (nearest neighbor) the input lens. This can improve fourier space calculation accuracy. Defaults to 1.
radial_symmetry (bool, optional): Flag to use radial symmetry during calculations. Note that we expect radial field profiles to be passed in if True. Defaults to False.
verbose (bool, optional): If True, prints information about the actual grid sizes etc that will be used in the back-end calculation. This may often be larger than user defined sizes due to fourier space rules. Defaults to False.
"""
super().__init__(
in_size,
in_dx_m,
Expand Down Expand Up @@ -254,8 +308,11 @@ def forward(self, amplitude, phase, **kwargs):
"""Propagates a complex field from an input plane to a planar output plane a distance out_distance_m.
Args:
amplitude (float): Field amplitude of shape (Batch, Lambda, *in_size) or (Batch, Lambda, 1, in_size_r).
phase (float): Field phase of shape (Batch, Lambda, *in_size) or (Batch, Lambda, 1, in_size_r).
amplitude (float): Field amplitude of shape (Batch, Lambda, H W) or (Batch, Lambda, 1, R).
phase (float): Field phase of shape (Batch, Lambda, H W) or (Batch, Lambda, 1, R).
Returns:
list: amplitude and phase with the same shape.
"""
if "wavelength_set_m" in kwargs:
raise ValueError(
Expand Down Expand Up @@ -402,6 +459,18 @@ def _resample_field(self, amplitude, phase):


class ASMPropagation(BaseFrequencySpace):
"""Angular Spectrum Method (ASM) for optical field propagation.
This class implements the Angular Spectrum Method for propagating optical fields
from an input plane to an output plane.
Note:
The ASM is suitable for a wide range of propagation distances and can handle
non-paraxial cases more accurately than the Fresnel method. The output grid for ASM methods will always be forced to match the input grid.
Consequently, in the back-end, we upsample and pad the input profile to match your target output grid. This affects memory and computation costs
in a sometimes non-intuitive way for users.
"""

def __init__(
self,
in_size,
Expand All @@ -416,6 +485,22 @@ def __init__(
FFTPadFactor=1.0,
verbose=False,
):
"""Initializes the propagation class.
Args:
in_size (list): input grid shape as [H, W].
in_dx_m (list): input grid discretization (in meters) as [dy, dx]
out_distance_m (float): output plane distance
out_size (list): output grid shape as [H, W]
out_dx_m (list): output grid discretization (in meters) as [dy, dx]
wavelength_set_m (list): List of wavelengths (in meters) corresponding to the wavelength dimension stack in forward.
out_resample_dx_m (list, optional): List of output grid discretizations to resample by area sum (area averaging for phase). This can be used to compute at a sub-pixel scale then return the integrated field on each pixel. Defaults to None.
manual_upsample_factor (int, optional): Force factor to manually upsample (nearest neighbor) the input lens. This can improve fourier space calculation accuracy. Defaults to 1.
FFTPadFactor (float, optional): Force a larger zero-pad factor during FFT used for frequency-space convolution. This is for developer debug/testing.
radial_symmetry (bool, optional): Flag to use radial symmetry during calculations. Note that we expect radial field profiles to be passed in if True. Defaults to False.
verbose (bool, optional): If True, prints information about the actual grid sizes etc that will be used in the back-end calculation. This may often be larger than user defined sizes due to fourier space rules. Defaults to False.
"""

super().__init__(
in_size,
in_dx_m,
Expand Down Expand Up @@ -486,8 +571,11 @@ def forward(self, amplitude, phase, **kwargs):
"""Propagates a complex field from an input plane to a planar output plane a distance out_distance_m.
Args:
amplitude (float): Field amplitude of shape (Batch, Lambda, *in_size) or (Batch, Lambda, 1, in_size_r).
phase (float): Field phase of shape (Batch, Lambda, *in_size) or (Batch, Lambda, 1, in_size_r).
amplitude (float): Field amplitude of shape (Batch, Lambda, H W) or (Batch, Lambda, 1, R).
phase (float): Field phase of shape (Batch, Lambda, H W) or (Batch, Lambda, 1, R).
Returns:
list: amplitude and phase on the output grid. Shape of tensors same as passed in.
"""
if "wavelength_set_m" in kwargs:
raise ValueError(
Expand Down Expand Up @@ -526,7 +614,7 @@ def _forward(self, amplitude, phase, **kwargs):
amplitude, phase = self._regularize_field(amplitude, phase)

# propagate by the asm method
amplitude, phase = self.ASM_transform(amplitude, phase)
amplitude, phase = self._ASM_transform(amplitude, phase)

# Transform field back to the specified output grid and convert to 2D
amplitude, phase = self._resample_field(amplitude, phase)
Expand All @@ -540,7 +628,7 @@ def _forward(self, amplitude, phase, **kwargs):

return amplitude, phase

def ASM_transform(self, amplitude, phase):
def _ASM_transform(self, amplitude, phase):
init_shape = amplitude.shape
dtype = amplitude.dtype
device = amplitude.device
Expand Down Expand Up @@ -641,6 +729,14 @@ def _regularize_field(self, amplitude, phase):


class PointSpreadFunction(nn.Module):
"""Calculates the Point Spread Function (PSF) for an optical system.
This class uses either the Angular Spectrum Method (ASM) or Fresnel propagation
to calculate the PSF of an optical system for given input fields and point source locations.
Note: Normalize_to_aperture in the forward argument enables re-normalization of the output PSF relative to the total energy incident at the input plane.
"""

def __init__(
self,
in_size,
Expand All @@ -655,6 +751,20 @@ def __init__(
diffraction_engine="ASM",
verbose=False,
):
"""Initializes the point-spread function class.
Args:
in_size (list): input grid shape as [H, W].
in_dx_m (list): input grid discretization (in meters) as [dy, dx]
out_distance_m (float): output plane distance
out_size (list): output grid shape as [H, W]
out_dx_m (list): output grid discretization (in meters) as [dy, dx]
wavelength_set_m (list): List of wavelengths (in meters) corresponding to the wavelength dimension stack in forward.
out_resample_dx_m (list, optional): List of output grid discretizations to resample by area sum (area averaging for phase). This can be used to compute at a sub-pixel scale then return the integrated field on each pixel. Defaults to None.
manual_upsample_factor (int, optional): Force factor to manually upsample (nearest neighbor) the input lens. This can improve fourier space calculation accuracy. Defaults to 1.
radial_symmetry (bool, optional): Flag to use radial symmetry during calculations. Note that we expect radial field profiles to be passed in if True. Defaults to False.
verbose (bool, optional): If True, prints information about the actual grid sizes etc that will be used in the back-end calculation. This may often be larger than user defined sizes due to fourier space rules. Defaults to False.
"""
super().__init__()

assert isinstance(
Expand Down Expand Up @@ -704,19 +814,19 @@ def forward(
normalize_to_aperture=True,
**kwargs,
):
"""_summary_
"""Computes the pont-spread function for teh amplitude and phase profile given a list of point-source locations.
Args:
amplitude (tensor): Lens amplitude of shape [... L H W], where L may equal 1 for broadcasting.
phase (tensor): Lens phase of shape [... L H W], where L may equal 1 for broadcasting.
amplitude (tensor): Lens amplitude of shape [Batch L H W], where L may equal 1 for broadcasting.
phase (tensor): Lens phase of shape [Batch L H W], where L may equal 1 for broadcasting.
ps_locs_m (tensor): Array point-source locations of shape [N x 3] where each column corresponds to Y, X, Depth
aperture (Tensor, optional): Field aperture applied on the lens the same rank as amplitude
and with the same H W dimensions. Defaults to None.
normalize_to_aperture (bool, optional): If true the energy in the PSF will be normalized to the total energy
incident on the optic/aperture. Defaults to True.
Returns:
List: Returns point-spread function intensity and phase of shape [B P Z L H W].
List: Returns point-spread function intensity and phase of shape [Batch Num_point_sources Lambda H W].
"""
if "wavelength_set_m" in kwargs:
raise ValueError(
Expand Down Expand Up @@ -787,7 +897,7 @@ def forward(

amplitude = amplitude**2
normalization = (
np.prod(self.out_resample_dx) / self.aperture_energy(aperture)
np.prod(self.out_resample_dx) / self._aperture_energy(aperture)
).to(dtype=amplitude.dtype, device=amplitude.device)
if normalize_to_aperture:
return amplitude * normalization, phase
Expand Down Expand Up @@ -837,7 +947,7 @@ def _incident_wavefront(self, amplitude, phase, ps_locs_m):
)

@torch.no_grad()
def aperture_energy(self, aperture):
def _aperture_energy(self, aperture):
in_size = self.in_size
sz = [
1,
Expand Down
8 changes: 8 additions & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
API Reference
=============

.. toctree::
:maxdepth: 2

api/propagation
api/GDSII
15 changes: 15 additions & 0 deletions docs/api/GDSII.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
GDSII Helpers
=============

Module Overview
---------------
This module is used to generate GDS files for metasurface shape designs so they can be fabricated or sent to foundries. The core functionality is built on top of GDSPY.


Public Functions
----------------

.. automodule:: dflat.GDSII
:members:
:undoc-members:
:show-inheritance:
14 changes: 14 additions & 0 deletions docs/api/propagation.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
Field Propagation
=================

Module Overview
---------------
This module provides auto-differentiable field propagation suitable for propagating complex fields from one plane to another or computing the optical point-spread function.

Public Functions
----------------

.. automodule:: dflat.propagation
:members:
:undoc-members:
:show-inheritance:
42 changes: 36 additions & 6 deletions docs/conf.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
# Configuration file for the Sphinx documentation builder.
# Configuration file for the Sphinx documentation builder.
#
# For the full list of built-in configuration values, see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html

# -- Project information -----------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
import os
import sys

sys.path.insert(0, os.path.abspath("..")) # Adjust this path as necessary


project = "dflat"
copyright = "2024, Dean Hazineh"
Expand All @@ -14,17 +19,42 @@
# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration

extensions = []

extensions = [
"sphinx.ext.autodoc",
"sphinx.ext.napoleon",
"sphinx_design",
]
templates_path = ["_templates"]
exclude_patterns = []

# Napoleon settings
napoleon_google_docstring = True
napoleon_numpy_docstring = False # Disable numpy-style docstrings
napoleon_include_init_with_doc = False
napoleon_include_private_with_doc = False
napoleon_include_special_with_doc = True
napoleon_use_admonition_for_examples = False
napoleon_use_admonition_for_notes = False
napoleon_use_admonition_for_references = False
napoleon_use_ivar = False
napoleon_use_param = True
napoleon_use_rtype = True
napoleon_preprocess_types = False
napoleon_type_aliases = None
napoleon_attr_annotations = True

autodoc_default_options = {
"members": True,
"member-order": "bysource",
"special-members": "__init__",
"undoc-members": True,
"exclude-members": "__weakref__",
"imported-members": True,
}


# -- Options for HTML output -------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output

extensions = [
"sphinx_design",
]
html_theme = "sphinx_rtd_theme"
html_static_path = ["_static"]
19 changes: 13 additions & 6 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,22 @@
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.
dflat documentation
===================

Add your content using ``reStructuredText`` syntax. See the
`reStructuredText <https://www.sphinx-doc.org/en/master/usage/restructuredtext/index.html>`_
documentation for details.
DFlat End-to-End Optimization
=============================

Official ReadTheDocs Page

.. toctree::
:maxdepth: 2
:caption: Contents:

introduction
installation
api

Indices and tables
==================

* :ref:`genindex`
* :ref:`modindex`
* :ref:`search`
2 changes: 2 additions & 0 deletions docs/installation.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
installation
============
2 changes: 2 additions & 0 deletions docs/introduction.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
introduction
============
2 changes: 1 addition & 1 deletion docs/notes.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
sphinx-build -b html source/ _build
sphinx-build -b html docs/ _build

0 comments on commit aaa7cfa

Please sign in to comment.