diff --git a/docs/conf.py b/docs/conf.py index 265dcae4..538deb47 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -16,7 +16,7 @@ sphinx.util.osutil.ENOENT = errno.ENOENT project = "MPoL" -copyright = "2019-22, Ian Czekala" +copyright = "2019-24, Ian Czekala" author = "Ian Czekala" # The full version, including alpha/beta/rc tags diff --git a/docs/index.md b/docs/index.md index 2fde0caa..89ff8d83 100644 --- a/docs/index.md +++ b/docs/index.md @@ -23,7 +23,7 @@ rml_intro.md installation.md units-and-conventions.md developer-documentation.md -api.rst +apidocs/index ``` ```{toctree} diff --git a/pyproject.toml b/pyproject.toml index 9a1b6097..c9ba4508 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ dev = [ "mypy", "frank>=1.2.1", "sphinx>=5.3.0", + "sphinx-autodoc2", "jupytext", "ipython!=8.7.0", # broken version for syntax higlight https://github.com/spatialaudio/nbsphinx/issues/687 "nbsphinx", diff --git a/src/mpol/losses.py b/src/mpol/losses.py index e1193636..3e5bce85 100644 --- a/src/mpol/losses.py +++ b/src/mpol/losses.py @@ -356,13 +356,9 @@ def TV_image(sky_cube: torch.Tensor, epsilon: float = 1e-10) -> torch.Tensor: total variation loss """ - # diff the cube in ll and remove the last row - diff_ll = sky_cube[:, 0:-1, 1:] - sky_cube[:, 0:-1, 0:-1] - - # diff the cube in mm and remove the last column - diff_mm = sky_cube[:, 1:, 0:-1] - sky_cube[:, 0:-1, 0:-1] - - loss = torch.sum(torch.sqrt(diff_ll**2 + diff_mm**2 + epsilon)) + diff_ll = torch.diff(sky_cube[:, 0:-1, :], dim=2) + diff_mm = torch.diff(sky_cube[:, :, 0:-1], dim=1) + loss = torch.sqrt(diff_ll**2 + diff_mm**2 + epsilon).sum() return loss @@ -427,11 +423,8 @@ def TSV(sky_cube: torch.Tensor) -> torch.Tensor: """ - # diff the cube in ll and remove the last row - diff_ll = sky_cube[:, 0:-1, 1:] - sky_cube[:, 0:-1, 0:-1] - - # diff the cube in mm and remove the last column - diff_mm = sky_cube[:, 1:, 0:-1] - sky_cube[:, 0:-1, 0:-1] + diff_ll = torch.diff(sky_cube[:, 0:-1, :], dim=2) + diff_mm = torch.diff(sky_cube[:, :, 0:-1], dim=1) loss = torch.sum(diff_ll**2 + diff_mm**2)