DM-40359: Implement Butler.get for matplotlib generated png files #877

Open
wants to merge 1 commit into
base: main
3 changes: 3 additions & 0 deletions mypy.ini
@@ -18,6 +18,9 @@ ignore_missing_imports = True
[mypy-numpy.*]
ignore_missing_imports = True

[mypy-matplotlib.*]
ignore_missing_imports = True

[mypy-pyarrow.*]
ignore_missing_imports = True

4 changes: 4 additions & 0 deletions python/lsst/daf/butler/configs/storageClasses.yaml
@@ -342,10 +342,14 @@ storageClasses:
dict: lsst.utils.packages.Packages
NumpyArray:
pytype: numpy.ndarray
converters:
matplotlib.figure.Figure: lsst.daf.butler.formatters.matplotlib.MatplotlibFormatter.dummyConverter
Thumbnail:
pytype: numpy.ndarray
Plot:
pytype: matplotlib.figure.Figure
converters:
numpy.ndarray: lsst.daf.butler.formatters.matplotlib.MatplotlibFormatter.fromArray
MetricValue:
pytype: lsst.verify.Measurement
StampsBase:
19 changes: 18 additions & 1 deletion python/lsst/daf/butler/formatters/matplotlib.py
@@ -27,6 +27,9 @@

from typing import Any

import matplotlib.pyplot as plt
import numpy as np

from .file import FileFormatter


@@ -38,8 +41,22 @@ class MatplotlibFormatter(FileFormatter):

def _readFile(self, path: str, pytype: type[Any] | None = None) -> Any:
# docstring inherited from FileFormatter._readFile
- raise NotImplementedError(f"matplotlib figures cannot be read by the butler; path is {path}")
+ return plt.imread(path)
Member

Given that this is returning something unexpected (and effectively backwards from expectations), please add a comment saying that it's not returning a figure.


def _writeFile(self, inMemoryDataset: Any) -> None:
# docstring inherited from FileFormatter._writeFile
inMemoryDataset.savefig(self.fileDescriptor.location.path)

@staticmethod
def fromArray(cls: np.ndarray) -> plt.Figure:
Member

This does not need to be a static method in the formatter class. It's okay for this to be a private function in the file that is not exported and is only referenced from the yaml file. The first parameter should not be named cls; it should be array or something similar.

"""Convert an array into a Figure."""
fig = plt.figure()
plt.imshow(cls)
return fig
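
A minimal sketch of the reviewer's suggestion for `fromArray`, assuming a module-level private function is acceptable. The name `_array_to_figure` is hypothetical and is not part of the PR. The sketch also builds the `Figure` directly rather than through `plt.figure()`, which differs from the PR code and keeps the new figure out of pyplot's global figure registry.

```python
import numpy as np
from matplotlib.figure import Figure


def _array_to_figure(array: np.ndarray) -> Figure:
    """Wrap an image array in a new Figure (hypothetical helper name)."""
    fig = Figure()
    ax = fig.add_subplot()
    # Draw the array as an image on the single axes of the figure.
    ax.imshow(array)
    return fig


# Example call with a small placeholder array.
figure = _array_to_figure(np.zeros((4, 4)))
```

If this approach were adopted, the `converters` entry in storageClasses.yaml would need to point at the module-level function instead of a formatter method.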

@staticmethod
def dummyCovnerter(cls: np.ndarray) -> np.ndarray:
Member

The name has a typo in it. Also see the comment about naming I made above.

I am confused by the signature. Surely this should take a plt.Figure and return an np.ndarray? Is there no way to do that conversion? It is going to be problematic in some contexts (e.g., the in-memory datastore) if this method is not implemented.

Contributor Author

No, this is just what the name says: a dummy converter. This is what I was trying to get across in Slack: in some cases we want to just load the PNG file directly, although a numpy array is fine. Yes, it is possible to convert a Figure to a numpy array, but it's messy. You basically have to create a dummy file, save the figure to the dummy file, then read the dummy file as a numpy array. 👎

I did a cursory Stack Overflow search, and that was the recommended solution. I also looked at the matplotlib codebase a bit, and it does indeed look like it would be non-trivial to convert the figure to a numpy array directly. This is probably why the method wasn't implemented in the first place.

The other solution that I thought of is that we could define a PNG storage class, which would be an io.BytesIO instance, and convert plots and arrays to and from the PNG class. In other words:

import io

import matplotlib.pyplot as plt


def fig_to_png(fig):
    # Render the figure into an in-memory PNG buffer.
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    buf.seek(0)  # rewind so the buffer can be read back
    return buf


def png_to_array(png):
    # Decode the PNG bytes into a numpy array.
    return plt.imread(png, format="png")


def png_to_fig(png):
    arr = png_to_array(png)
    fig = plt.figure()
    plt.imshow(arr)
    return fig

Member

Completely untested, but could something like this work?

def fig_to_np(fig: matplotlib.figure.Figure) -> np.ndarray:
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    buf.seek(0)  # rewind before reading the PNG back
    return plt.imread(buf, format="png")

"""This converter exists to trick the Butler into allowing
a numpy array on read with ``storageClass='NumpyArray'``.
"""
return cls
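
The in-memory conversion discussed in the review thread above can be sketched as a self-contained script. This is an illustration under stated assumptions, not the PR's implementation: it assumes the headless `Agg` backend, and the figure size and plotted line are placeholders. The `buf.seek(0)` call is needed because `savefig` leaves the buffer position at the end of the written data.

```python
import io

import matplotlib

matplotlib.use("Agg")  # assumption: headless backend, no display needed

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.figure import Figure


def fig_to_np(fig: Figure) -> np.ndarray:
    """Render a figure to PNG in memory and decode it as an array."""
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    buf.seek(0)  # rewind so imread starts at the beginning of the PNG
    return plt.imread(buf, format="png")


# Placeholder figure used only to exercise the conversion.
fig = Figure(figsize=(2, 2), dpi=50)
fig.add_subplot().plot([0, 1], [0, 1])
image = fig_to_np(fig)
```

A converter along these lines would take a Figure and return an array, which is the signature the reviewer expected, at the cost of one PNG encode and decode per conversion.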
11 changes: 9 additions & 2 deletions tests/test_matplotlibFormatter.py
@@ -27,6 +27,8 @@
import unittest
from random import Random

import numpy as np
Member

This needs to be a protected import since numpy is not a butler requirement (although effectively it is, since astropy won't work without it).
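
One possible shape for the protected import requested here, sketched in isolation. The test case name `ArrayReadTestCase` and its body are hypothetical; the existing test module may prefer a different skip mechanism, so treat this as an illustration of the pattern only.

```python
import unittest

try:
    import numpy as np
except ImportError:
    # numpy is optional; tests that need it are skipped below.
    np = None


@unittest.skipIf(np is None, "numpy is not available")
class ArrayReadTestCase(unittest.TestCase):
    """Hypothetical test case that only runs when numpy imports."""

    def test_array_type(self):
        image = np.zeros((2, 2))
        self.assertIsInstance(image, np.ndarray)
```

With this pattern the module still imports cleanly when numpy is missing, and the affected tests are reported as skipped rather than as errors.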


try:
import matplotlib

@@ -78,8 +80,13 @@ def testMatplotlibFormatter(self):
pyplot.gcf().savefig(file.name)
self.assertTrue(filecmp.cmp(local.ospath, file.name, shallow=True))
self.assertTrue(butler.exists(ref))
- with self.assertRaises(ValueError):
-     butler.get(ref)

fig = butler.get(ref)
# Ensure that the result is a figure
self.assertTrue(isinstance(fig, pyplot.Figure))
Member

Suggested change
- self.assertTrue(isinstance(fig, pyplot.Figure))
+ self.assertIsInstance(fig, pyplot.Figure)

image = butler.get(ref, storageClass="NumpyArray")
self.assertTrue(isinstance(image, np.ndarray))
Member

Suggested change
- self.assertTrue(isinstance(image, np.ndarray))
+ self.assertIsInstance(image, np.ndarray)


butler.pruneDatasets([ref], unstore=True, purge=True)
self.assertFalse(butler.exists(ref))
