diff --git a/lib/cartopy/mpl/geoaxes.py b/lib/cartopy/mpl/geoaxes.py index d1a94cb7b..535aa2924 100644 --- a/lib/cartopy/mpl/geoaxes.py +++ b/lib/cartopy/mpl/geoaxes.py @@ -410,7 +410,6 @@ def __init__(self, *args, **kwargs): self.projection = projection super().__init__(*args, **kwargs) - self._gridliners = [] self.img_factories = [] self._done_img_factory = False @@ -482,21 +481,19 @@ def _draw_preprocess(self, renderer): if self.get_autoscale_on() and self.ignore_existing_data_limits: self.autoscale_view() - # Adjust location of background patch so that new gridlines below are - # clipped correctly. + # Adjust location of background patch so that new gridlines generated + # by `draw` or `get_tightbbox` are clipped correctly. self.patch._adjust_location() self.apply_aspect() - for gl in self._gridliners: - gl._draw_gridliner(renderer=renderer) def get_tightbbox(self, renderer, *args, **kwargs): """ Extend the standard behaviour of :func:`matplotlib.axes.Axes.get_tightbbox`. - Adjust the axes aspect ratio, background patch location, and add - gridliners before calculating the tight bounding box. + Adjust the axes aspect ratio and background patch location before + calculating the tight bounding box. """ # Shared processing steps self._draw_preprocess(renderer) @@ -508,9 +505,8 @@ def draw(self, renderer=None, **kwargs): """ Extend the standard behaviour of :func:`matplotlib.axes.Axes.draw`. - Draw grid lines and image factory results before invoking standard - Matplotlib drawing. A global range is used if no limits have yet - been set. + Draw image factory results before invoking standard Matplotlib drawing. + A global range is used if no limits have yet been set. """ # Shared processing steps self._draw_preprocess(renderer) @@ -532,21 +528,22 @@ def draw(self, renderer=None, **kwargs): def _update_title_position(self, renderer): super()._update_title_position(renderer) - if not self._gridliners: - return if self._autotitlepos is not None and not self._autotitlepos: return + from cartopy.mpl.gridliner import Gridliner + gridliners = [a for a in self.artists if isinstance(a, Gridliner)] + if not gridliners: + return + # Get the max ymax of all top labels top = -1 - for gl in self._gridliners: + for gl in gridliners: if gl.has_labels(): + # Both top and geo labels can appear at the top of the axes for label in (gl.top_label_artists + - gl.left_label_artists + - gl.right_label_artists): - # we skip bottom labels because they are usually - # not at the top + gl.geo_label_artists): bb = label.get_tightbbox(renderer) top = max(top, bb.ymax) if top < 0: @@ -1512,7 +1509,7 @@ def gridlines(self, crs=None, draw_labels=False, labels_bbox_style=labels_bbox_style, xpadding=xpadding, ypadding=ypadding, offset_angle=offset_angle, auto_update=auto_update, formatter_kwargs=formatter_kwargs) - self._gridliners.append(gl) + self.add_artist(gl) return gl def _gen_axes_patch(self): diff --git a/lib/cartopy/mpl/gridliner.py b/lib/cartopy/mpl/gridliner.py index dcc22f523..2469e251b 100644 --- a/lib/cartopy/mpl/gridliner.py +++ b/lib/cartopy/mpl/gridliner.py @@ -9,7 +9,9 @@ import warnings import matplotlib +import matplotlib.artist import matplotlib.collections as mcollections +import matplotlib.text import matplotlib.ticker as mticker import matplotlib.transforms as mtrans import numpy as np @@ -101,11 +103,7 @@ def _north_south_formatted(latitude, num_format='g'): _north_south_formatted(v)) -class Gridliner: - # NOTE: In future, one of these objects will be add-able to a GeoAxes (and - # maybe even a plain old mpl axes) and it will call the "_draw_gridliner" - # method on draw. This will enable automatic gridline resolution - # determination on zoom/pan. +class Gridliner(matplotlib.artist.Artist): def __init__(self, axes, crs, draw_labels=False, xlocator=None, ylocator=None, collection_kwargs=None, xformatter=None, yformatter=None, dms=False, @@ -115,7 +113,7 @@ def __init__(self, axes, crs, draw_labels=False, xlocator=None, xpadding=5, ypadding=5, offset_angle=25, auto_update=False, formatter_kwargs=None): """ - Object used by :meth:`cartopy.mpl.geoaxes.GeoAxes.gridlines` + Artist used by :meth:`cartopy.mpl.geoaxes.GeoAxes.gridlines` to add gridlines and tick labels to a map. Parameters @@ -234,7 +232,13 @@ def __init__(self, axes, crs, draw_labels=False, xlocator=None, used for the map, meridians and parallels can cross both the X axis and the Y axis. """ - self.axes = axes + super().__init__() + + # We do not want the labels clipped to axes. + self.set_clip_on(False) + # Backcompat: the LineCollection was previously added directly to the + # axes, having a default zorder of 2. + self.set_zorder(2) #: The :class:`~matplotlib.ticker.Locator` to use for the x #: gridlines and labels. @@ -332,10 +336,10 @@ def __init__(self, axes, crs, draw_labels=False, xlocator=None, raise ValueError(f"Invalid draw_labels argument: {value}") if auto_inline: - if isinstance(self.axes.projection, _X_INLINE_PROJS): + if isinstance(axes.projection, _X_INLINE_PROJS): self.x_inline = True self.y_inline = False - elif isinstance(self.axes.projection, _POLAR_PROJS): + elif isinstance(axes.projection, _POLAR_PROJS): self.x_inline = False self.y_inline = True else: @@ -399,7 +403,7 @@ def __init__(self, axes, crs, draw_labels=False, xlocator=None, #: Control the rotation of labels. if rotate_labels is None: rotate_labels = ( - self.axes.projection.__class__ in _ROTATE_LABEL_PROJS) + axes.projection.__class__ in _ROTATE_LABEL_PROJS) if not isinstance(rotate_labels, (bool, float, int)): raise ValueError("Invalid rotate_labels argument") self.rotate_labels = rotate_labels @@ -436,10 +440,6 @@ def __init__(self, axes, crs, draw_labels=False, xlocator=None, self._drawn = False self._auto_update = auto_update - # Check visibility of labels at each draw event - # (or once drawn, only at resize event ?) - self.axes.figure.canvas.mpl_connect('draw_event', self._draw_event) - @property def xlabels_top(self): warnings.warn('The .xlabels_top attribute is deprecated. Please ' @@ -488,9 +488,6 @@ def ylabels_right(self, value): 'use .right_labels to toggle visibility instead.') self.right_labels = value - def _draw_event(self, event): - self._draw_gridliner(renderer=event.renderer) - def has_labels(self): return len(self._labels) != 0 @@ -629,13 +626,9 @@ def _draw_gridliner(self, nx=None, ny=None, renderer=None): return self._drawn = True - # Clear lists of artists - for lines in [*self.xline_artists, *self.yline_artists]: - lines.remove() + # Clear lists of child artists self.xline_artists.clear() self.yline_artists.clear() - for label in self._labels: - label.artist.remove() self._labels.clear() # Inits @@ -673,6 +666,7 @@ def _draw_gridliner(self, nx=None, ny=None, renderer=None): if not any(x in collection_kwargs for x in ['lw', 'linewidth']): collection_kwargs.setdefault('linewidth', matplotlib.rcParams['grid.linewidth']) + collection_kwargs.setdefault('clip_path', self.axes.patch) # Meridians lat_min, lat_max = lat_lim @@ -696,7 +690,6 @@ def _draw_gridliner(self, nx=None, ny=None, renderer=None): lon_lc = mcollections.LineCollection(lon_lines, **collection_kwargs) self.xline_artists.append(lon_lc) - self.axes.add_collection(lon_lc, autolim=False) # Parallels lon_min, lon_max = lon_lim @@ -711,7 +704,6 @@ def _draw_gridliner(self, nx=None, ny=None, renderer=None): lat_lc = mcollections.LineCollection(lat_lines, **collection_kwargs) self.yline_artists.append(lat_lc) - self.axes.add_collection(lat_lc, autolim=False) ################# # Label drawing # @@ -925,7 +917,9 @@ def update_artist(artist, renderer): # Add text to the plot text = formatter(tick_value) - artist = self.axes.text(x, y, text, **kw) + artist = matplotlib.text.Text(x, y, text, **kw) + artist.set_figure(self.axes.figure) + artist.axes = self.axes # Update loc from spine overlapping now that we have a bbox # of the label. @@ -1239,6 +1233,26 @@ def _axes_domain(self, nx=None, ny=None): return lon_range, lat_range + def get_visible_children(self): + r"""Return a list of the visible child `.Artist`\s.""" + all_children = (self.xline_artists + self.yline_artists + + self.label_artists) + return [c for c in all_children if c.get_visible()] + + def get_tightbbox(self, renderer=None): + self._draw_gridliner(renderer=renderer) + bboxes = [c.get_tightbbox(renderer=renderer) + for c in self.get_visible_children()] + if bboxes: + return mtrans.Bbox.union(bboxes) + else: + return mtrans.Bbox.null() + + def draw(self, renderer=None): + self._draw_gridliner(renderer=renderer) + for c in self.get_visible_children(): + c.draw(renderer=renderer) + class Label: """Helper class to manage the attributes for a single label""" diff --git a/lib/cartopy/tests/mpl/baseline_images/mpl/test_gridliner/gridliner_labels_title_adjust.png b/lib/cartopy/tests/mpl/baseline_images/mpl/test_gridliner/gridliner_labels_title_adjust.png new file mode 100644 index 000000000..9ed9c4ffe Binary files /dev/null and b/lib/cartopy/tests/mpl/baseline_images/mpl/test_gridliner/gridliner_labels_title_adjust.png differ diff --git a/lib/cartopy/tests/mpl/conftest.py b/lib/cartopy/tests/mpl/conftest.py index 60f172364..03f307562 100644 --- a/lib/cartopy/tests/mpl/conftest.py +++ b/lib/cartopy/tests/mpl/conftest.py @@ -35,5 +35,6 @@ def pytest_itemcollected(item): return elif path.basename == 'tests': subdir = item.fspath.relto(path)[:-len(item.fspath.ext)] - mpl_marker.kwargs['baseline_dir'] = f'baseline_images/{subdir}' + mpl_marker.kwargs.setdefault('baseline_dir', + f'baseline_images/{subdir}') break diff --git a/lib/cartopy/tests/mpl/test_gridliner.py b/lib/cartopy/tests/mpl/test_gridliner.py index 120295c4a..92884cf36 100644 --- a/lib/cartopy/tests/mpl/test_gridliner.py +++ b/lib/cartopy/tests/mpl/test_gridliner.py @@ -4,6 +4,9 @@ # See COPYING and COPYING.LESSER in the root of the repository for full # licensing details. +import io +from unittest import mock + import matplotlib.pyplot as plt import matplotlib.ticker as mticker import numpy as np @@ -13,7 +16,8 @@ import cartopy.crs as ccrs from cartopy.mpl.geoaxes import GeoAxes from cartopy.mpl.gridliner import (LATITUDE_FORMATTER, LONGITUDE_FORMATTER, - classic_formatter, classic_locator) + Gridliner, classic_formatter, + classic_locator) from cartopy.mpl.ticker import LongitudeFormatter, LongitudeLocator @@ -241,9 +245,14 @@ def test_grid_labels_tight(): fig.tight_layout() # Ensure gridliners were drawn + num_gridliners_drawn = 0 for ax in fig.axes: - for gl in ax._gridliners: - assert hasattr(gl, '_drawn') and gl._drawn + for artist in ax.artists: + if isinstance(artist, Gridliner) and getattr(artist, '_drawn', + False): + num_gridliners_drawn += 1 + + assert num_gridliners_drawn == 4 return fig @@ -432,3 +441,82 @@ def test_gridliner_formatter_kwargs(): fig.canvas.draw() labels = [a.get_text() for a in gl.bottom_label_artists if a.get_visible()] assert labels == ['75°O', '70°O', '65°O', '60°O', '55°O', '50°O', '45°O'] + + +def test_gridliner_count_draws(): + fig = plt.figure() + ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree()) + ax.set_global() + gl = ax.gridlines() + + with mock.patch.object(gl, '_draw_gridliner', return_value=None) as mocked: + ax.get_tightbbox(renderer=None) + mocked.assert_called_once() + + with mock.patch.object(gl, '_draw_gridliner', return_value=None) as mocked: + fig.draw_without_rendering() + mocked.assert_called_once() + + +@pytest.mark.mpl_image_compare( + baseline_dir='baseline_images/mpl/test_mpl_integration', + filename='simple_global.png') +def test_gridliner_remove(): + fig = plt.figure() + ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree()) + ax.set_global() + ax.coastlines() + gl = ax.gridlines(draw_labels=True) + fig.draw_without_rendering() # Generate child artists + gl.remove() + + assert gl not in ax.artists + assert not ax.collections + + return fig + + +def test_gridliner_save_tight_bbox(): + # Smoke test for save with auto_update=True and bbox_inches=Tight (gh2246). + fig = plt.figure() + ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree()) + ax.set_global() + ax.gridlines(draw_labels=True, auto_update=True) + fig.savefig(io.BytesIO(), bbox_inches='tight') + + +@pytest.mark.mpl_image_compare(filename='gridliner_labels_title_adjust.png', + tolerance=grid_label_tol) +def test_gridliner_title_adjust(): + # Test that title do not overlap labels + projs = [ccrs.Mercator(), ccrs.AlbersEqualArea(), ccrs.LambertConformal(), + ccrs.Orthographic()] + + # Turn on automatic title placement (this is default in mpl rcParams but + # not in these tests). + plt.rcParams['axes.titley'] = None + + fig = plt.figure(layout='constrained') + fig.get_layout_engine().set(h_pad=1/8) + for n, proj in enumerate(projs, 1): + ax = fig.add_subplot(2, 2, n, projection=proj) + ax.coastlines() + ax.gridlines(draw_labels=True) + ax.set_title(proj.__class__.__name__) + + return fig + + +def test_gridliner_title_noadjust(): + fig = plt.figure() + ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree()) + ax.set_global() + ax.set_title('foo') + ax.gridlines(draw_labels=['left', 'right'], ylocs=[-60, 0, 60]) + fig.draw_without_rendering() + pos = ax.title.get_position() + + # Title position shouldn't change when a label is on the top boundary. + ax.set_extent([-180, 180, -60, 60]) + fig.draw_without_rendering() + assert ax.title.get_position() == pos