Skip to content

Commit

Permalink
addressed #153 internally, keeping result the same using kadri's pyto…
Browse files Browse the repository at this point in the history
…rch adaptation, but still distinct from eht

imaging.
  • Loading branch information
iancze committed Dec 28, 2023
1 parent c2e5370 commit fe9c786
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 14 deletions.
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
sphinx.util.osutil.ENOENT = errno.ENOENT

project = "MPoL"
copyright = "2019-22, Ian Czekala"
copyright = "2019-24, Ian Czekala"
author = "Ian Czekala"

# The full version, including alpha/beta/rc tags
Expand Down
2 changes: 1 addition & 1 deletion docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ rml_intro.md
installation.md
units-and-conventions.md
developer-documentation.md
api.rst
apidocs/index
```

```{toctree}
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ dev = [
"mypy",
"frank>=1.2.1",
"sphinx>=5.3.0",
"sphinx-autodoc2",
"jupytext",
"ipython!=8.7.0", # broken version for syntax higlight https://github.com/spatialaudio/nbsphinx/issues/687
"nbsphinx",
Expand Down
17 changes: 5 additions & 12 deletions src/mpol/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,13 +356,9 @@ def TV_image(sky_cube: torch.Tensor, epsilon: float = 1e-10) -> torch.Tensor:
total variation loss
"""

# diff the cube in ll and remove the last row
diff_ll = sky_cube[:, 0:-1, 1:] - sky_cube[:, 0:-1, 0:-1]

# diff the cube in mm and remove the last column
diff_mm = sky_cube[:, 1:, 0:-1] - sky_cube[:, 0:-1, 0:-1]

loss = torch.sum(torch.sqrt(diff_ll**2 + diff_mm**2 + epsilon))
diff_ll = torch.diff(sky_cube[:, 0:-1, :], dim=2)
diff_mm = torch.diff(sky_cube[:, :, 0:-1], dim=1)
loss = torch.sqrt(diff_ll**2 + diff_mm**2 + epsilon).sum()

return loss

Expand Down Expand Up @@ -427,11 +423,8 @@ def TSV(sky_cube: torch.Tensor) -> torch.Tensor:
"""

# diff the cube in ll and remove the last row
diff_ll = sky_cube[:, 0:-1, 1:] - sky_cube[:, 0:-1, 0:-1]

# diff the cube in mm and remove the last column
diff_mm = sky_cube[:, 1:, 0:-1] - sky_cube[:, 0:-1, 0:-1]
diff_ll = torch.diff(sky_cube[:, 0:-1, :], dim=2)
diff_mm = torch.diff(sky_cube[:, :, 0:-1], dim=1)

loss = torch.sum(diff_ll**2 + diff_mm**2)

Expand Down

0 comments on commit fe9c786

Please sign in to comment.