Skip to content

Commit

Permalink
removed persistent buffer and fixed some typos.
Browse files Browse the repository at this point in the history
  • Loading branch information
iancze committed Feb 11, 2023
1 parent b6d9b24 commit a32f97e
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 10 deletions.
4 changes: 2 additions & 2 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,6 @@ Cross-validation


Plotting
------
--------

.. automodule:: mpol.plot
.. automodule:: mpol.plot
2 changes: 1 addition & 1 deletion docs/ci-tutorials/initializedirtyimage.md
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ rml = precomposed.SimpleNet(coords=coords)
rml.state_dict() # the now uninitialized parameters of the model (the ones we started with)
```

Here you can clearly see the ``state_dict`` is in its original state, before the training loop changed the paramters through the optimization function. Loading our saved dirty image state into the model is as simple as
Here you can clearly see the ``state_dict`` is in its original state, before the training loop changed the parameters through the optimization function. Loading our saved dirty image state into the model is as simple as

```{code-cell}
rml.load_state_dict(torch.load("dirty_image_model.pt"))
Expand Down
8 changes: 4 additions & 4 deletions src/mpol/fourier.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@ class FourierCube(nn.Module):
cell_size (float): the width of an image-plane pixel [arcseconds]
npix (int): the number of pixels per image side
coords (GridCoords): an object already instantiated from the GridCoords class. If providing this, cannot provide ``cell_size`` or ``npix``.
persistent_vis (Boolean): should the visibility cube be stored as part of the modules `state_dict`? If `True`, the state of the UV grid will be stored. It is recommended to use `False` for most applications, since the visibility cube will rarely be a direct parameter of the model.
"""

def __init__(self, cell_size=None, npix=None, coords=None):
def __init__(self, cell_size=None, npix=None, coords=None, persistent_vis=False):
super().__init__()

# we don't want to bother with the nchan argument here, so
Expand All @@ -41,8 +42,7 @@ def __init__(self, cell_size=None, npix=None, coords=None):

self.coords = GridCoords(cell_size=cell_size, npix=npix)

self.register_buffer("vis", None)

self.register_buffer("vis", None, persistent=persistent_vis)

def forward(self, cube):
"""
Expand All @@ -62,7 +62,7 @@ def forward(self, cube):
# since it needs to correct for the spacing of the input grid.
# See MPoL documentation and/or TMS Eqn A8.18 for more information.
self.vis = self.coords.cell_size**2 * torch.fft.fftn(cube, dim=(1, 2))

return self.vis

@property
Expand Down
6 changes: 3 additions & 3 deletions test/images_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def test_odd_npix():
images.BaseCube.from_image_properties(npix=853, nchan=30, cell_size=0.015)

with pytest.raises(ValueError, match=expected_error_message):
images.ImageCube.from_image_properteis(npix=853, nchan=30, cell_size=0.015)
images.ImageCube.from_image_properties(npix=853, nchan=30, cell_size=0.015)


def test_negative_cell_size():
Expand All @@ -24,11 +24,11 @@ def test_negative_cell_size():
images.BaseCube.from_image_properties(npix=800, nchan=30, cell_size=-0.015)

with pytest.raises(ValueError, match=expected_error_message):
images.ImageCube.from_image_properteis(npix=800, nchan=30, cell_size=-0.015)
images.ImageCube.from_image_properties(npix=800, nchan=30, cell_size=-0.015)


def test_single_chan():
im = images.ImageCube.from_image_properteis(cell_size=0.015, npix=800)
im = images.ImageCube.from_image_properties(cell_size=0.015, npix=800)
assert im.nchan == 1


Expand Down

0 comments on commit a32f97e

Please sign in to comment.