diff --git a/ultraplot/axes/plot.py b/ultraplot/axes/plot.py index 91d1137d..dea23c03 100644 --- a/ultraplot/axes/plot.py +++ b/ultraplot/axes/plot.py @@ -9,8 +9,7 @@ import re import sys from numbers import Integral, Number -from typing import Any -from collections.abc import Iterable +from typing import Any, Iterable import matplotlib.artist as martist import matplotlib.axes as maxes @@ -1388,12 +1387,70 @@ def _add_auto_labels( formatter = constructor.Formatter( formatter, precision=precision, **formatter_kw ) # noqa: E501 - if isinstance(obj, mcontour.ContourSet): - self._add_contour_labels(obj, cobj, formatter, **labels_kw) - elif isinstance(obj, mcollections.Collection): - self._add_collection_labels(obj, formatter, **labels_kw) - else: - raise RuntimeError(f"Not possible to add labels to object {obj!r}.") + match obj: + case mcontour.ContourSet(): + self._add_contour_labels(obj, cobj, formatter, **labels_kw) + case mcollections.QuadMesh(): + self._add_quadmesh_labels(obj, formatter, **labels_kw) + case mcollections.Collection(): + self._add_collection_labels(obj, formatter, **labels_kw) + case _: + raise RuntimeError(f"Not possible to add labels to object {obj!r}.") + + + def _add_quadmesh_labels( + self, + obj, + fmt, + *, + c=None, + color=None, + colors=None, + size=None, + fontsize=None, + **kwargs, + ): + """ + Add labels to QuadMesh cells with support for shade-dependent text colors. + Values are inferred from the unnormalized mesh cell color. + """ + # Parse input args + obj.update_scalarmappable() + color = _not_none(c=c, color=color, colors=colors) + fontsize = _not_none(size=size, fontsize=fontsize, default=rc["font.smallsize"]) + kwargs.setdefault("ha", "center") + kwargs.setdefault("va", "center") + + # Get the mesh data + array = obj.get_array() + coords = obj.get_coordinates() # This gives vertices (11x11x2) + + # Calculate cell centers by averaging the four corners of each cell + x_centers = (coords[:-1, :-1, 0] + coords[1:, 1:, 0]) / 2 + y_centers = (coords[:-1, :-1, 1] + coords[1:, 1:, 1]) / 2 + + # Apply colors and create labels + labs = [] + for i, ((x, y), value) in enumerate(zip(zip(x_centers.flat, y_centers.flat), array.flat)): + # Skip masked or invalid values + if value is ma.masked or not np.isfinite(value): + continue + + # Handle discrete normalization if present + if isinstance(obj.norm, pcolors.DiscreteNorm): + value = obj.norm._norm.inverse(obj.norm(value)) + + # Determine text color based on background + icolor = color + if color is None: + _, _, lum = utils.to_xyz(obj.cmap(obj.norm(value)), "hcl") + icolor = "w" if lum < 50 else "k" + + # Create text label + lab = self.text(x, y, fmt(value), color=icolor, size=fontsize, **kwargs) + labs.append(lab) + + return labs def _add_collection_labels( self, @@ -1444,7 +1501,6 @@ def _add_collection_labels( y = (bbox.ymin + bbox.ymax) / 2 lab = self.text(x, y, fmt(value), color=icolor, size=fontsize, **kwargs) labs.append(lab) - obj.set_edgecolors(edgecolors) return labs @@ -2487,16 +2543,18 @@ def _parse_cycle( resolved_cycle = None case True: resolved_cycle = constructor.Cycle(rc["axes.prop_cycle"]) - case constructor.Cycle(): - resolved_cycle = constructor.Cycle(cycle) case str() if cycle.lower() == "none": resolved_cycle = None - case str() | int() | Iterable(): + case str() | int(): resolved_cycle = constructor.Cycle(cycle, **cycle_kw) + case constructor.Cycle(): + resolved_cycle = constructor.Cycle(cycle) case _: resolved_cycle = None - # Ignore cycle for single-column plotting unless cycle is different + # Ignore cycle for single-column plotting + resolved_cycle = None if ncycle == 1 else resolved_cycle + if resolved_cycle and resolved_cycle != self._active_cycle: self.set_prop_cycle(resolved_cycle)