Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stochastic Gradient Descent (SGD) example and functionality #252

Merged
merged 25 commits into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
3a87e7d
Merge pull request #243 from MPoL-dev/WIP-v0.3
iancze Dec 27, 2023
e2e3acb
anticipated changes.
iancze Jan 5, 2024
aaacddc
removed stale comment.
iancze Jan 5, 2024
860be57
changed basecube default value.
iancze Jan 6, 2024
5b48924
adding conftest changes
iancze Jan 6, 2024
296f0bf
removed basecube from test entirely.
iancze Jan 6, 2024
65173f7
added untested convolve method.
iancze Jan 7, 2024
8437f53
implemented UV Gaussian taper, image convolve, and tests.
iancze Jan 8, 2024
c27f230
added tests for non-circular PSF with rotation.
iancze Jan 8, 2024
8ccad6c
fixing type error.
iancze Jan 8, 2024
f396e15
Merge branch 'main' of https://github.com/MPoL-dev/MPoL into sgd
iancze Jan 10, 2024
d1ab601
removing autodoc2, unused.
iancze Jan 10, 2024
5554125
added ruff as test and dev dependency
iancze Jan 18, 2024
5e03444
initial ruff fixes.
iancze Jan 18, 2024
dc54cc8
ruff passing with basics.
iancze Jan 18, 2024
0e0a576
fix docstring @ kristin.hopley
iancze Jan 26, 2024
f9cc5cb
switched to float32 default throughout codebase.
iancze Feb 29, 2024
13a3a06
fixed ruff checks.
iancze Feb 29, 2024
3c5da5b
Merge pull request #255 from MPoL-dev/float32
iancze Feb 29, 2024
c1b70ef
updated docstring and changed default Omega to 0.
iancze Mar 1, 2024
0b9e416
working rotating Gaussian convolve.
iancze Mar 4, 2024
add01fe
added Gauss Fourier option for Base Cube.
iancze Mar 5, 2024
0b5db31
removed nchan from gauss
iancze Mar 5, 2024
1b3271a
added fixed resolution base gauss.
iancze Mar 5, 2024
4c7b453
consolidated Gauss routines.
iancze Mar 6, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,9 @@ jobs:
- name: Install test dependencies
run: |
pip install .[test]
- name: Lint with flake8
- name: Lint with ruff
run: |
pip install flake8
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
ruff check .
- name: Check types with MyPy
run: |
mypy src/mpol --pretty
Expand Down
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -144,4 +144,6 @@ plotsdir
runs

# hatch-generated version file
src/mpol/mpol_version.py
src/mpol/mpol_version.py

.ruff_cache
1 change: 1 addition & 0 deletions CONTRIBUTORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ Contributors
* Hannah Grzybowski, `@hgrzy`
* Mary Ogborn
* Tyler Quinn, `@trq5014`
* Kristin Hopley
20 changes: 10 additions & 10 deletions docs/_static/baselines/src/print_conversions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
import csv

import numpy as np
from mpol.constants import c_ms

import argparse

parser = argparse.ArgumentParser(
Expand All @@ -6,11 +11,6 @@
parser.add_argument("outfile", help="Destination to save CSV table.")
args = parser.parse_args()

import csv

import numpy as np

from mpol.constants import c_ms

header = ["baseline", "100 GHz (Band 3)", "230 GHz (Band 6)", "340 GHz (Band 7)"]

Expand All @@ -20,18 +20,18 @@

def format_baseline(baseline_m):
if baseline_m < 1e3:
return "{:.0f} m".format(baseline_m)
return f"{baseline_m:.0f} m"
elif baseline_m < 1e6:
return "{:.0f} km".format(baseline_m * 1e-3)
return f"{baseline_m * 1e-3:.0f} km"


def format_lambda(lam):
if lam < 1e3:
return "{:.0f}".format(lam) + " :math:`\lambda`"
return f"{lam:.0f}" + r" :math:`\lambda`"
elif lam < 1e6:
return "{:.0f}".format(lam * 1e-3) + " :math:`\mathrm{k}\lambda`"
return f"{lam * 1e-3:.0f}" + r" :math:`\mathrm{k}\lambda`"
else:
return "{:.0f}".format(lam * 1e-6) + " :math:`\mathrm{M}\lambda`"
return f"{lam * 1e-6:.0f}" + r" :math:`\mathrm{M}\lambda`"


data = []
Expand Down
14 changes: 7 additions & 7 deletions docs/_static/fftshift/src/plot.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
import argparse

parser = argparse.ArgumentParser(description="Create the fftshift plot")
parser.add_argument("outfile", help="Destination to save plot.")
args = parser.parse_args()

import matplotlib.pyplot as plt
import numpy as np
from astropy.io import fits
from astropy.utils.data import download_file
from matplotlib import patches
from matplotlib.colors import LogNorm
from matplotlib.gridspec import GridSpec

from mpol import coordinates

import argparse

parser = argparse.ArgumentParser(description="Create the fftshift plot")
parser.add_argument("outfile", help="Destination to save plot.")
args = parser.parse_args()


fname = download_file(
"https://zenodo.org/record/4711811/files/logo_cont.fits",
cache=True,
Expand Down
6 changes: 5 additions & 1 deletion docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
# Changelog

## v0.3.0

- removed explicit type declarations in base MPoL modules. Previously, core representations were set to be in `float64` or `complex128`. Now, core MPoL representations (e.g., {class}`mpol.images.BaseCube`) will follow the [default tensor type](https://pytorch.org/docs/stable/generated/torch.set_default_tensor_type.html), which is commonly `torch.float32`. If you want your model to run fully in `float32` or `complex64`, then be sure that your data is also in these formats, since otherwise PyTorch will promote downstream tensors as needed. Fully `float32` or `complex64` models should be able to run on Apple MPS [#254](https://github.com/MPoL-dev/MPoL/issues/254)
- added {meth}`mpol.utils.convolve_packed_cube` method to convolve a 3D packed image cube with a 2D Gaussian. You can specify major axis, minor axis, and rotation angle.
- added {meth}`mpol.utils.uv_gaussian_taper` to calculate a Gaussian tapering window in the visibility plane.
- added the `vis_ext_Mlam` instance attribute to {class}`mpol.coordinates.GridCoords` for convenience plotting of visibility grids with axes labels in units of M$\lambda$.
- Updated [MPoL-dev/examples](https://github.com/MPoL-dev/examples) with Stochastic Gradient Descent Example.
- Standardized nomenclature of {class}`mpol.coordinates.GridCoords` and {class}`mpol.fourier.FourierCube` to use `sky_cube` for a normal image and `ground_cube` for a normal visibility cube (rather than `sky_` for visibility quantities). Routines use `packed_cube` instead of `cube` internally to be clear when packed format is preferred.
- Modified {class}`mpol.coordinates.GridCoords` object to use cached properties [#187](https://github.com/MPoL-dev/MPoL/pull/187).
- Changed the base spatial frequency unit from k$\lambda$ to $\lambda$, addressing [#223](https://github.com/MPoL-dev/MPoL/issues/223). This will affect most users data-reading routines!
Expand Down
3 changes: 1 addition & 2 deletions docs/conf.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import os

# -- Project information -----------------------------------------------------
from pkg_resources import DistributionNotFound, get_distribution
Expand Down Expand Up @@ -46,7 +45,7 @@
autodoc_mock_imports = ["torch", "torchvision"]
autodoc_member_order = "bysource"
# https://github.com/sphinx-doc/sphinx/issues/9709
# bug that if we set this here, we can't list individual members in the
# bug that if we set this here, we can't list individual members in the
# actual API doc
# autodoc_default_options = {"members": None}

Expand Down
21 changes: 18 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ dev = [
"mypy",
"frank>=1.2.1",
"sphinx>=7.2.0",
"sphinx-autodoc2",
"jupytext",
"ipython!=8.7.0", # broken version for syntax higlight https://github.com/spatialaudio/nbsphinx/issues/687
"nbsphinx",
Expand All @@ -51,7 +50,8 @@ dev = [
"asdf",
"pyro-ppl",
"arviz[all]",
"visread>=0.0.4"
"visread>=0.0.4",
"ruff"
]
test = [
"pytest",
Expand All @@ -62,6 +62,7 @@ test = [
"mypy",
"visread>=0.0.4",
"frank>=1.2.1",
"ruff"
]

[project.urls]
Expand Down Expand Up @@ -105,4 +106,18 @@ module = [
"MPoL.precomposed",
"MPoL.utils"
]
disallow_untyped_defs = true
disallow_untyped_defs = true

[tool.ruff]
target-version = "py310"
line-length = 88
# will enable after sorting module locations
# select = ["F", "I", "E", "W", "YTT", "B", "Q", "PLE", "PLR", "PLW", "UP"]
lint.ignore = [
"E741", # Allow ambiguous variable names
"PLR0911", # Allow many return statements
"PLR0913", # Allow many arguments to functions
"PLR0915", # Allow many statements
"PLR2004", # Allow magic numbers in comparisons
]
exclude = []
1 change: 0 additions & 1 deletion src/mpol/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
from mpol.mpol_version import __version__
zenodo_record = 10064221
25 changes: 14 additions & 11 deletions src/mpol/coordinates.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations
from functools import cached_property

from functools import cached_property
from typing import Any

import numpy as np
Expand All @@ -10,7 +10,7 @@

import mpol.constants as const
from mpol.exceptions import CellSizeError
from mpol.utils import get_max_spatial_freq, get_maximum_cell_size
from mpol.utils import get_maximum_cell_size


class GridCoords:
Expand Down Expand Up @@ -79,6 +79,7 @@ class GridCoords:
:ivar vis_ext: length-4 list of (left, right, bottom, top) expected by routines
like ``matplotlib.pyplot.imshow`` in the ``extent`` parameter assuming
``origin='lower'``. Units of [:math:`\lambda`]
:ivar vis_ext_Mlam: like vis_ext, but in units of [:math:`\mathrm{M}\lambda`].
"""

def __init__(self, cell_size: float, npix: int):
Expand Down Expand Up @@ -205,16 +206,18 @@ def vis_ext(self) -> list[float]:
self.u_bin_max,
self.v_bin_min,
self.v_bin_max,
] # [kλ]
] # [λ]

@property
def vis_ext_Mlam(self) -> list[float]:
return [1e-6 * edge for edge in self.vis_ext]

# --------------------------------------------------------------------------
# Non-identical u & v properties
# --------------------------------------------------------------------------
@cached_property
def ground_u_centers_2D(self) -> npt.NDArray[np.floating[Any]]:
# only useful for plotting a sky_vis
# uu increasing, no fftshift
# tile replicates the 1D u_centers array to a 2D array the size of the full UV grid
# tile replicates the 1D u_centers array to a 2D array the size of the full
# UV grid
return np.tile(self.u_centers, (self.npix_u, 1))

@cached_property
Expand Down Expand Up @@ -304,10 +307,10 @@ def check_data_fit(

Parameters
----------
uu : :class:`torch.Tensor` of `torch.double`
uu : :class:`torch.Tensor`
u spatial frequency coordinates.
Units of [:math:`\lambda`]
vv : :class:`torch.Tensor` of `torch.double`
vv : :class:`torch.Tensor`
v spatial frequency coordinates.
Units of [:math:`\lambda`]

Expand Down Expand Up @@ -354,6 +357,6 @@ def __eq__(self, other: Any) -> bool:
# don't attempt to compare against different types
return NotImplemented

# GridCoords objects are considered equal if they have the same cell_size and npix, since
# all other attributes are derived from these two core properties.
# GridCoords objects are considered equal if they have the same cell_size and
# npix, since all other attributes are derived from these two core properties.
return bool(self.cell_size == other.cell_size and self.npix == other.npix)
8 changes: 3 additions & 5 deletions src/mpol/crossval.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import copy
import logging
from collections import defaultdict
from typing import Any

import numpy as np
Expand All @@ -11,11 +10,9 @@
from numpy.typing import NDArray

from mpol.datasets import Dartboard, GriddedDataset
from mpol.precomposed import GriddedNet

# from mpol.training import TrainTest, train_to_dirty_image
# from mpol.training import TrainTest, train_to_dirty_image
from mpol.plot import split_diagnostics_fig
from mpol.utils import loglinspace


# class CrossValidate:
Expand Down Expand Up @@ -59,7 +56,8 @@
# Number of k-folds to use in cross-validation
# split_method : str, default='dartboard'
# Method to split full dataset into train/test subsets
# dartboard_q_edges, dartboard_phi_edges : list of float, default=None, unit=[klambda]
# dartboard_q_edges, dartboard_phi_edges : list of float, default=None,
# unit=[klambda]
# Radial and azimuthal bin edges of the cells used to split the dataset
# if `split_method`==`dartboard` (see `datasets.Dartboard`)
# split_diag_fig : bool, default=False
Expand Down
5 changes: 2 additions & 3 deletions src/mpol/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@
from numpy import floating, integer
from numpy.typing import ArrayLike, NDArray

from mpol.coordinates import GridCoords

from mpol import utils
from mpol.coordinates import GridCoords


class GriddedDataset(torch.nn.Module):
Expand All @@ -20,7 +19,7 @@ class GriddedDataset(torch.nn.Module):
If providing this, cannot provide ``cell_size`` or ``npix``.
vis_gridded : :class:`torch.Tensor` of :class:`torch.complex128`
the gridded visibility data stored in a "packed" format (pre-shifted for fft)
weight_gridded : :class:`torch.Tensor` of :class:`torch.double`
weight_gridded : :class:`torch.Tensor`
the weights corresponding to the gridded visibility data,
also in a packed format
mask : :class:`torch.Tensor` of :class:`torch.bool`
Expand Down
Loading
Loading