Colormap not preserved in scatter plot when generating baseline plot. #84

Open
rcjackson opened this issue May 15, 2019 · 5 comments
@rcjackson

rcjackson commented May 15, 2019

Hello, when we create a baseline image for a unit test that uses a scatter plot, pytest-mpl does not preserve the colormap used in the plot.

For example, the plot we wish to compare against is this:
[image: myfig]

When we run the unit test by simply importing it and showing the resulting figure, we get the figure above. However, whenever we use pytest to generate the baseline figure, we get the figure below:
[image: test_time_height_scatter]

Main test:

@pytest.mark.mpl_image_compare(tolerance=30)
def test_time_height_scatter():
    sonde_ds = arm.read_netcdf(
        sample_files.EXAMPLE_SONDE1)

    display = TimeSeriesDisplay({'sgpsondewnpnC1.b1': sonde_ds},
                                figsize=(7, 3))
    display.time_height_scatter('tdry', day_night_background=True)
    sonde_ds.close()

    return display.fig

time_height_scatter routine:

    def time_height_scatter(
            self, data_field=None, dsname=None, cmap='rainbow',
            alt_label=None, alt_field='alt', cb_label=None, **kwargs):
        """
        Create a time series plot of altitude and a data variable, with
        color also indicating value via a color bar. The color bar is
        positioned to serve both as the indicator of the color intensity
        and the second y-axis.

        Parameters
        ----------
        data_field: str
            Name of data field in the object to plot on second y-axis.
        dsname: str or None
            The name of the datastream to plot.
        cmap: str
            Colorbar color map to use.
        alt_label: str
            Altitude first y-axis label to use. If not set will try to use
            long_name and units.
        alt_field: str
            Name of the altitude field in the object to plot on first y-axis.
        cb_label: str
            Colorbar label to use. If not set will try to use
            long_name and units.
        **kwargs: keyword arguments
            Any other keyword arguments that will be passed
            into TimeSeriesDisplay.plot module when the figure
            is made.
        """
        if dsname is None and len(self._arm.keys()) > 1:
            raise ValueError(("You must choose a datastream when there are 2 "
                              "or more datasets in the TimeSeriesDisplay "
                              "object."))
        elif dsname is None:
            dsname = list(self._arm.keys())[0]

        # Get data and dimensions
        data = self._arm[dsname][data_field]
        altitude = self._arm[dsname][alt_field]
        dim = list(self._arm[dsname][data_field].dims)
        xdata = self._arm[dsname][dim[0]]

        if alt_label is None:
            try:
                alt_label = (altitude.attrs['long_name'] +
                             ''.join([' (', altitude.attrs['units'], ')']))
            except KeyError:
                alt_label = alt_field

        if cb_label is None:
            try:
                cb_label = (data.attrs['long_name'] +
                            ''.join([' (', data.attrs['units'], ')']))
            except KeyError:
                cb_label = data_field

        colorbar_map = plt.get_cmap(cmap)
        self.fig.subplots_adjust(left=0.1, right=0.86,
                                 bottom=0.16, top=0.91)
        ax1 = self.plot(alt_field, color='black', **kwargs)
        ax1.set_ylabel(alt_label)
        ax2 = ax1.twinx()
        sc = ax2.scatter(xdata.values, data.values, c=data.values,
                         marker='.', cmap=colorbar_map)
        cbaxes = self.fig.add_axes(
            [self.fig.subplotpars.right + 0.02, self.fig.subplotpars.bottom,
             0.02, self.fig.subplotpars.top - self.fig.subplotpars.bottom])
        cbar = plt.colorbar(sc, cax=cbaxes)
        ax2.set_ylim(sc.get_clim())
        cbar.ax.set_ylabel(cb_label)
        ax2.set_yticklabels([])

        return self.axes[0]

Any help you could provide on this would be appreciated.

@dopplershift
Contributor

Any chance you can boil that down to a basic matplotlib plot that reproduces the failure? A lot of that example relies on self, and time_height_scatter doesn't set .fig, which is what interacts with pytest-mpl, so we'd need the whole class. I think you're in the best position to boil that down to what's relevant.
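
For illustration, a stripped-down reproduction along those lines might look like the sketch below. It is only a sketch: the synthetic data stands in for the sonde file, and the helper name make_scatter_figure is hypothetical, not part of the original class.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so no display is needed
import matplotlib.pyplot as plt
import numpy as np


def make_scatter_figure():
    """Scatter plot colored by value, with a color bar (synthetic data)."""
    # Assumption: random data replaces the 'tdry' field from the sonde file.
    rng = np.random.default_rng(0)
    x = np.linspace(0.0, 1.0, 500)
    values = 20.0 - 60.0 * x + rng.normal(scale=1.0, size=x.size)

    fig, ax = plt.subplots(figsize=(7, 3))
    # Same marker and colormap choices as the reported routine.
    sc = ax.scatter(x, values, c=values, marker=".", cmap="rainbow")
    fig.colorbar(sc, ax=ax)
    return fig
```

Returning this figure from a test decorated with @pytest.mark.mpl_image_compare, and comparing the generated baseline against the same figure saved directly with fig.savefig, would show whether pytest-mpl alone changes the rendering.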

@dvalters

Sorry to resurrect this, but I'm having a similar issue as well. I'm going to try to create an MWE that reproduces it, if that helps.

@thatlittleboy

I faced a similar issue today when plotting with fill_between. I don't have an MWE, but for future users who come across this issue: try setting edgecolor="face" (i.e., the same color as the facecolor).

Those black lines are essentially all the edges being coloured in. It seems pytest-mpl forces edgecolor to be black for some reason (this doesn't happen when I run the exact same plotting function in JupyterLab, for example).
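
A small sketch of that workaround, assuming the black edges come from Matplotlib's classic style being active while the artists are created. The data and variable names here are made up for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so no display is needed
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(-2.5, 2.5, 5)
y = np.array([0.1, 0.3, 0.9, 0.5, 0.2])
face = plt.get_cmap("coolwarm")(0.25)

# Assumption: the classic style is what is in effect during the test run.
with plt.style.context("classic"):
    fig, ax = plt.subplots()
    # No edgecolor given: the style decides how the edges are drawn.
    forced = ax.fill_between(x, -y, y, facecolor=face)
    # edgecolor="face": the edges follow the fill color.
    matched = ax.fill_between(x, -y / 2, y / 2, facecolor=face,
                              edgecolor="face")

print("style-chosen edge:", forced.get_edgecolor())
print("edgecolor='face': ", matched.get_edgecolor())
```

With many stacked layers, the first variant is the one that ends up dominated by edge lines.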

@ConorMacBride
Member

I believe the differences are due to pytest-mpl using the classic style by default (Matplotlib changed its default style in v2.0). To use Matplotlib's current default style in pytest-mpl, decorate test functions with @pytest.mark.mpl_image_compare(style="default") instead. Please let me know if that makes the plots identical.

I hope to make this the default in pytest-mpl v1.0.0 (#198).
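
To see which settings differ between the two styles, something like the sketch below can help. The three rcParams keys are only a guess at the ones most relevant to this thread, and style_settings is a hypothetical helper.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so no display is needed
import matplotlib.pyplot as plt

# Assumption: these are the settings most likely to matter here.
KEYS = ("patch.force_edgecolor", "patch.edgecolor", "image.cmap")


def style_settings(style):
    """Return the selected rcParams as they are under the given style."""
    # after_reset=True starts from Matplotlib's defaults before applying
    # the style, so earlier rcParams changes do not leak in.
    with plt.style.context(style, after_reset=True):
        return {key: plt.rcParams[key] for key in KEYS}


classic = style_settings("classic")
default = style_settings("default")
for key in KEYS:
    print(f"{key}: classic={classic[key]!r} default={default[key]!r}")
```

The classic style forces patch edges to be drawn, which would explain the outlines reported above.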

@thatlittleboy

@ConorMacBride Hi, thanks for the hint. Yes, setting style="default" fixes it.

For what it's worth, I made a (somewhat) minimal example; I'm assuming it's tangentially related to what the OP was experiencing.

@pytest.mark.mpl_image_compare(
    style="default",  # vs. "classic"
)
def test_myfoo():
    import matplotlib.pyplot as plt
    import numpy as np

    ys = np.array([[1.23710908e-03, 4.20836477e-02, 2.09052698e-01, 2.67737557e-01,
        1.31218302e-01],
       [1.43456874e-01, 2.25156978e-01, 3.83926106e-01, 4.14150010e-01,
        2.48079254e-01],
       [1.43456874e-01, 2.27790562e-01, 1.25732949e+00, 4.39741108e-01,
        2.48079254e-01],
       [1.50913680e-01, 2.93258751e-01, 1.43550621e+00, 6.55657398e-01,
        3.99194379e-01],
       [2.41509546e-01, 4.45408842e-01, 1.60328764e+00, 7.91195155e-01,
        4.65439861e-01],
       [2.41509546e-01, 4.45408842e-01, 3.94855154e+00, 7.91195155e-01,
        4.65439861e-01],
       [2.79338293e-01, 6.54052129e-01, 4.29487580e+00, 1.12027195e+00,
        5.35642929e-01],
       [2.79344299e-01, 6.59965244e-01, 4.47781657e+00, 1.47472077e+00,
        6.50638813e-01],
       [2.79344299e-01, 6.59965244e-01, 4.47781657e+00, 2.15331765e+00,
        6.51761649e-01],
       [4.08001626e-01, 8.93606884e-01, 4.70345632e+00, 2.32287025e+00,
        7.76499905e-01],
       [5.06513650e-01, 1.05100248e+00, 4.91174350e+00, 2.57349145e+00,
        9.51840954e-01],
       [5.06513650e-01, 1.05100248e+00, 4.91178730e+00, 2.58122180e+00,
        9.51840954e-01],
       [6.26462036e-01, 1.28269152e+00, 5.12224572e+00, 2.65823967e+00,
        9.58692172e-01],
       [6.26462164e-01, 1.32131699e+00, 5.95439519e+00, 2.74620416e+00,
        9.58693576e-01],
       [7.13215309e-01, 1.52007822e+00, 6.19287378e+00, 2.93934942e+00,
        1.10517416e+00],
       [7.13215318e-01, 1.66888175e+00, 6.53449663e+00, 2.93934967e+00,
        1.10517416e+00],
       [8.25241450e-01, 1.79552434e+00, 6.65547289e+00, 3.03345712e+00,
        1.15874248e+00],
       [8.33397700e-01, 1.89528376e+00, 6.93623434e+00, 3.37110135e+00,
        1.38036646e+00],
       [9.36683719e-01, 2.14890549e+00, 7.22137506e+00, 3.59091813e+00,
        1.48106671e+00],
       [9.36683719e-01, 2.14908847e+00, 7.55258998e+00, 3.93106153e+00,
        1.48127267e+00]])
    x_points = np.array([-2.55298982, -1.34730371, -0.1416176, 1.06406851, 2.26975462])

    ys = np.cumsum(ys, axis=0)
    scale = ys.max() * 2 / 0.8

    fig, ax = plt.subplots()
    for i in range(4, -1, -1):
        y = ys[i, :] / scale
        c = plt.get_cmap("coolwarm")(i/4)
        plt.fill_between(x_points, -y, y, facecolor=c)

    plt.tight_layout()
    return fig

[image]

My actual image includes many more "layers" than this (200 vs. the 5 drawn here), so it ends up almost completely black because of the edges being drawn in.
