Skip to content

Commit

Permalink
Merge pull request #2249 from rcomer/gridliner-updates
Browse files Browse the repository at this point in the history
Make GridLiner into an Artist
  • Loading branch information
greglucas authored Oct 2, 2023
2 parents f2bb81d + d3a3249 commit e424784
Show file tree
Hide file tree
Showing 5 changed files with 147 additions and 47 deletions.
33 changes: 15 additions & 18 deletions lib/cartopy/mpl/geoaxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,6 @@ def __init__(self, *args, **kwargs):
self.projection = projection

super().__init__(*args, **kwargs)
self._gridliners = []
self.img_factories = []
self._done_img_factory = False

Expand Down Expand Up @@ -482,21 +481,19 @@ def _draw_preprocess(self, renderer):
if self.get_autoscale_on() and self.ignore_existing_data_limits:
self.autoscale_view()

# Adjust location of background patch so that new gridlines below are
# clipped correctly.
# Adjust location of background patch so that new gridlines generated
# by `draw` or `get_tightbbox` are clipped correctly.
self.patch._adjust_location()

self.apply_aspect()
for gl in self._gridliners:
gl._draw_gridliner(renderer=renderer)

def get_tightbbox(self, renderer, *args, **kwargs):
"""
Extend the standard behaviour of
:func:`matplotlib.axes.Axes.get_tightbbox`.
Adjust the axes aspect ratio, background patch location, and add
gridliners before calculating the tight bounding box.
Adjust the axes aspect ratio and background patch location before
calculating the tight bounding box.
"""
# Shared processing steps
self._draw_preprocess(renderer)
Expand All @@ -508,9 +505,8 @@ def draw(self, renderer=None, **kwargs):
"""
Extend the standard behaviour of :func:`matplotlib.axes.Axes.draw`.
Draw grid lines and image factory results before invoking standard
Matplotlib drawing. A global range is used if no limits have yet
been set.
Draw image factory results before invoking standard Matplotlib drawing.
A global range is used if no limits have yet been set.
"""
# Shared processing steps
self._draw_preprocess(renderer)
Expand All @@ -532,21 +528,22 @@ def draw(self, renderer=None, **kwargs):

def _update_title_position(self, renderer):
super()._update_title_position(renderer)
if not self._gridliners:
return

if self._autotitlepos is not None and not self._autotitlepos:
return

from cartopy.mpl.gridliner import Gridliner
gridliners = [a for a in self.artists if isinstance(a, Gridliner)]
if not gridliners:
return

# Get the max ymax of all top labels
top = -1
for gl in self._gridliners:
for gl in gridliners:
if gl.has_labels():
# Both top and geo labels can appear at the top of the axes
for label in (gl.top_label_artists +
gl.left_label_artists +
gl.right_label_artists):
# we skip bottom labels because they are usually
# not at the top
gl.geo_label_artists):
bb = label.get_tightbbox(renderer)
top = max(top, bb.ymax)
if top < 0:
Expand Down Expand Up @@ -1512,7 +1509,7 @@ def gridlines(self, crs=None, draw_labels=False,
labels_bbox_style=labels_bbox_style,
xpadding=xpadding, ypadding=ypadding, offset_angle=offset_angle,
auto_update=auto_update, formatter_kwargs=formatter_kwargs)
self._gridliners.append(gl)
self.add_artist(gl)
return gl

def _gen_axes_patch(self):
Expand Down
64 changes: 39 additions & 25 deletions lib/cartopy/mpl/gridliner.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
import warnings

import matplotlib
import matplotlib.artist
import matplotlib.collections as mcollections
import matplotlib.text
import matplotlib.ticker as mticker
import matplotlib.transforms as mtrans
import numpy as np
Expand Down Expand Up @@ -101,11 +103,7 @@ def _north_south_formatted(latitude, num_format='g'):
_north_south_formatted(v))


class Gridliner:
# NOTE: In future, one of these objects will be add-able to a GeoAxes (and
# maybe even a plain old mpl axes) and it will call the "_draw_gridliner"
# method on draw. This will enable automatic gridline resolution
# determination on zoom/pan.
class Gridliner(matplotlib.artist.Artist):
def __init__(self, axes, crs, draw_labels=False, xlocator=None,
ylocator=None, collection_kwargs=None,
xformatter=None, yformatter=None, dms=False,
Expand All @@ -115,7 +113,7 @@ def __init__(self, axes, crs, draw_labels=False, xlocator=None,
xpadding=5, ypadding=5, offset_angle=25,
auto_update=False, formatter_kwargs=None):
"""
Object used by :meth:`cartopy.mpl.geoaxes.GeoAxes.gridlines`
Artist used by :meth:`cartopy.mpl.geoaxes.GeoAxes.gridlines`
to add gridlines and tick labels to a map.
Parameters
Expand Down Expand Up @@ -234,7 +232,13 @@ def __init__(self, axes, crs, draw_labels=False, xlocator=None,
used for the map, meridians and parallels can cross both the X axis and
the Y axis.
"""
self.axes = axes
super().__init__()

# We do not want the labels clipped to axes.
self.set_clip_on(False)
# Backcompat: the LineCollection was previously added directly to the
# axes, having a default zorder of 2.
self.set_zorder(2)

#: The :class:`~matplotlib.ticker.Locator` to use for the x
#: gridlines and labels.
Expand Down Expand Up @@ -332,10 +336,10 @@ def __init__(self, axes, crs, draw_labels=False, xlocator=None,
raise ValueError(f"Invalid draw_labels argument: {value}")

if auto_inline:
if isinstance(self.axes.projection, _X_INLINE_PROJS):
if isinstance(axes.projection, _X_INLINE_PROJS):
self.x_inline = True
self.y_inline = False
elif isinstance(self.axes.projection, _POLAR_PROJS):
elif isinstance(axes.projection, _POLAR_PROJS):
self.x_inline = False
self.y_inline = True
else:
Expand Down Expand Up @@ -399,7 +403,7 @@ def __init__(self, axes, crs, draw_labels=False, xlocator=None,
#: Control the rotation of labels.
if rotate_labels is None:
rotate_labels = (
self.axes.projection.__class__ in _ROTATE_LABEL_PROJS)
axes.projection.__class__ in _ROTATE_LABEL_PROJS)
if not isinstance(rotate_labels, (bool, float, int)):
raise ValueError("Invalid rotate_labels argument")
self.rotate_labels = rotate_labels
Expand Down Expand Up @@ -436,10 +440,6 @@ def __init__(self, axes, crs, draw_labels=False, xlocator=None,
self._drawn = False
self._auto_update = auto_update

# Check visibility of labels at each draw event
# (or once drawn, only at resize event ?)
self.axes.figure.canvas.mpl_connect('draw_event', self._draw_event)

@property
def xlabels_top(self):
warnings.warn('The .xlabels_top attribute is deprecated. Please '
Expand Down Expand Up @@ -488,9 +488,6 @@ def ylabels_right(self, value):
'use .right_labels to toggle visibility instead.')
self.right_labels = value

def _draw_event(self, event):
self._draw_gridliner(renderer=event.renderer)

def has_labels(self):
return len(self._labels) != 0

Expand Down Expand Up @@ -629,13 +626,9 @@ def _draw_gridliner(self, nx=None, ny=None, renderer=None):
return
self._drawn = True

# Clear lists of artists
for lines in [*self.xline_artists, *self.yline_artists]:
lines.remove()
# Clear lists of child artists
self.xline_artists.clear()
self.yline_artists.clear()
for label in self._labels:
label.artist.remove()
self._labels.clear()

# Inits
Expand Down Expand Up @@ -673,6 +666,7 @@ def _draw_gridliner(self, nx=None, ny=None, renderer=None):
if not any(x in collection_kwargs for x in ['lw', 'linewidth']):
collection_kwargs.setdefault('linewidth',
matplotlib.rcParams['grid.linewidth'])
collection_kwargs.setdefault('clip_path', self.axes.patch)

# Meridians
lat_min, lat_max = lat_lim
Expand All @@ -696,7 +690,6 @@ def _draw_gridliner(self, nx=None, ny=None, renderer=None):
lon_lc = mcollections.LineCollection(lon_lines,
**collection_kwargs)
self.xline_artists.append(lon_lc)
self.axes.add_collection(lon_lc, autolim=False)

# Parallels
lon_min, lon_max = lon_lim
Expand All @@ -711,7 +704,6 @@ def _draw_gridliner(self, nx=None, ny=None, renderer=None):
lat_lc = mcollections.LineCollection(lat_lines,
**collection_kwargs)
self.yline_artists.append(lat_lc)
self.axes.add_collection(lat_lc, autolim=False)

#################
# Label drawing #
Expand Down Expand Up @@ -925,7 +917,9 @@ def update_artist(artist, renderer):

# Add text to the plot
text = formatter(tick_value)
artist = self.axes.text(x, y, text, **kw)
artist = matplotlib.text.Text(x, y, text, **kw)
artist.set_figure(self.axes.figure)
artist.axes = self.axes

# Update loc from spine overlapping now that we have a bbox
# of the label.
Expand Down Expand Up @@ -1239,6 +1233,26 @@ def _axes_domain(self, nx=None, ny=None):

return lon_range, lat_range

def get_visible_children(self):
r"""Return a list of the visible child `.Artist`\s."""
all_children = (self.xline_artists + self.yline_artists
+ self.label_artists)
return [c for c in all_children if c.get_visible()]

def get_tightbbox(self, renderer=None):
self._draw_gridliner(renderer=renderer)
bboxes = [c.get_tightbbox(renderer=renderer)
for c in self.get_visible_children()]
if bboxes:
return mtrans.Bbox.union(bboxes)
else:
return mtrans.Bbox.null()

def draw(self, renderer=None):
self._draw_gridliner(renderer=renderer)
for c in self.get_visible_children():
c.draw(renderer=renderer)


class Label:
"""Helper class to manage the attributes for a single label"""
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
3 changes: 2 additions & 1 deletion lib/cartopy/tests/mpl/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,6 @@ def pytest_itemcollected(item):
return
elif path.basename == 'tests':
subdir = item.fspath.relto(path)[:-len(item.fspath.ext)]
mpl_marker.kwargs['baseline_dir'] = f'baseline_images/{subdir}'
mpl_marker.kwargs.setdefault('baseline_dir',
f'baseline_images/{subdir}')
break
94 changes: 91 additions & 3 deletions lib/cartopy/tests/mpl/test_gridliner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
# See COPYING and COPYING.LESSER in the root of the repository for full
# licensing details.

import io
from unittest import mock

import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import numpy as np
Expand All @@ -13,7 +16,8 @@
import cartopy.crs as ccrs
from cartopy.mpl.geoaxes import GeoAxes
from cartopy.mpl.gridliner import (LATITUDE_FORMATTER, LONGITUDE_FORMATTER,
classic_formatter, classic_locator)
Gridliner, classic_formatter,
classic_locator)
from cartopy.mpl.ticker import LongitudeFormatter, LongitudeLocator


Expand Down Expand Up @@ -241,9 +245,14 @@ def test_grid_labels_tight():
fig.tight_layout()

# Ensure gridliners were drawn
num_gridliners_drawn = 0
for ax in fig.axes:
for gl in ax._gridliners:
assert hasattr(gl, '_drawn') and gl._drawn
for artist in ax.artists:
if isinstance(artist, Gridliner) and getattr(artist, '_drawn',
False):
num_gridliners_drawn += 1

assert num_gridliners_drawn == 4

return fig

Expand Down Expand Up @@ -432,3 +441,82 @@ def test_gridliner_formatter_kwargs():
fig.canvas.draw()
labels = [a.get_text() for a in gl.bottom_label_artists if a.get_visible()]
assert labels == ['75°O', '70°O', '65°O', '60°O', '55°O', '50°O', '45°O']


def test_gridliner_count_draws():
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())
ax.set_global()
gl = ax.gridlines()

with mock.patch.object(gl, '_draw_gridliner', return_value=None) as mocked:
ax.get_tightbbox(renderer=None)
mocked.assert_called_once()

with mock.patch.object(gl, '_draw_gridliner', return_value=None) as mocked:
fig.draw_without_rendering()
mocked.assert_called_once()


@pytest.mark.mpl_image_compare(
baseline_dir='baseline_images/mpl/test_mpl_integration',
filename='simple_global.png')
def test_gridliner_remove():
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())
ax.set_global()
ax.coastlines()
gl = ax.gridlines(draw_labels=True)
fig.draw_without_rendering() # Generate child artists
gl.remove()

assert gl not in ax.artists
assert not ax.collections

return fig


def test_gridliner_save_tight_bbox():
# Smoke test for save with auto_update=True and bbox_inches=Tight (gh2246).
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())
ax.set_global()
ax.gridlines(draw_labels=True, auto_update=True)
fig.savefig(io.BytesIO(), bbox_inches='tight')


@pytest.mark.mpl_image_compare(filename='gridliner_labels_title_adjust.png',
tolerance=grid_label_tol)
def test_gridliner_title_adjust():
# Test that title do not overlap labels
projs = [ccrs.Mercator(), ccrs.AlbersEqualArea(), ccrs.LambertConformal(),
ccrs.Orthographic()]

# Turn on automatic title placement (this is default in mpl rcParams but
# not in these tests).
plt.rcParams['axes.titley'] = None

fig = plt.figure(layout='constrained')
fig.get_layout_engine().set(h_pad=1/8)
for n, proj in enumerate(projs, 1):
ax = fig.add_subplot(2, 2, n, projection=proj)
ax.coastlines()
ax.gridlines(draw_labels=True)
ax.set_title(proj.__class__.__name__)

return fig


def test_gridliner_title_noadjust():
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())
ax.set_global()
ax.set_title('foo')
ax.gridlines(draw_labels=['left', 'right'], ylocs=[-60, 0, 60])
fig.draw_without_rendering()
pos = ax.title.get_position()

# Title position shouldn't change when a label is on the top boundary.
ax.set_extent([-180, 180, -60, 60])
fig.draw_without_rendering()
assert ax.title.get_position() == pos

0 comments on commit e424784

Please sign in to comment.