Skip to content

Commit

Permalink
fix and slightly refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
cvanelteren committed Feb 21, 2025
1 parent f21a814 commit 542603e
Showing 1 changed file with 71 additions and 13 deletions.
84 changes: 71 additions & 13 deletions ultraplot/axes/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
import re
import sys
from numbers import Integral, Number
from typing import Any
from collections.abc import Iterable
from typing import Any, Iterable

import matplotlib.artist as martist
import matplotlib.axes as maxes
Expand Down Expand Up @@ -1388,12 +1387,70 @@ def _add_auto_labels(
formatter = constructor.Formatter(
formatter, precision=precision, **formatter_kw
) # noqa: E501
if isinstance(obj, mcontour.ContourSet):
self._add_contour_labels(obj, cobj, formatter, **labels_kw)
elif isinstance(obj, mcollections.Collection):
self._add_collection_labels(obj, formatter, **labels_kw)
else:
raise RuntimeError(f"Not possible to add labels to object {obj!r}.")
match obj:
case mcontour.ContourSet():
self._add_contour_labels(obj, cobj, formatter, **labels_kw)
case mcollections.QuadMesh():
self._add_quadmesh_labels(obj, formatter, **labels_kw)
case mcollections.Collection():
self._add_collection_labels(obj, formatter, **labels_kw)
case _:
raise RuntimeError(f"Not possible to add labels to object {obj!r}.")


def _add_quadmesh_labels(
self,
obj,
fmt,
*,
c=None,
color=None,
colors=None,
size=None,
fontsize=None,
**kwargs,
):
"""
Add labels to QuadMesh cells with support for shade-dependent text colors.
Values are inferred from the unnormalized mesh cell color.
"""
# Parse input args
obj.update_scalarmappable()
color = _not_none(c=c, color=color, colors=colors)
fontsize = _not_none(size=size, fontsize=fontsize, default=rc["font.smallsize"])
kwargs.setdefault("ha", "center")
kwargs.setdefault("va", "center")

# Get the mesh data
array = obj.get_array()
coords = obj.get_coordinates() # This gives vertices (11x11x2)

# Calculate cell centers by averaging the four corners of each cell
x_centers = (coords[:-1, :-1, 0] + coords[1:, 1:, 0]) / 2
y_centers = (coords[:-1, :-1, 1] + coords[1:, 1:, 1]) / 2

# Apply colors and create labels
labs = []
for i, ((x, y), value) in enumerate(zip(zip(x_centers.flat, y_centers.flat), array.flat)):
# Skip masked or invalid values
if value is ma.masked or not np.isfinite(value):
continue

# Handle discrete normalization if present
if isinstance(obj.norm, pcolors.DiscreteNorm):
value = obj.norm._norm.inverse(obj.norm(value))

# Determine text color based on background
icolor = color
if color is None:
_, _, lum = utils.to_xyz(obj.cmap(obj.norm(value)), "hcl")
icolor = "w" if lum < 50 else "k"

# Create text label
lab = self.text(x, y, fmt(value), color=icolor, size=fontsize, **kwargs)
labs.append(lab)

return labs

def _add_collection_labels(
self,
Expand Down Expand Up @@ -1444,7 +1501,6 @@ def _add_collection_labels(
y = (bbox.ymin + bbox.ymax) / 2
lab = self.text(x, y, fmt(value), color=icolor, size=fontsize, **kwargs)
labs.append(lab)

obj.set_edgecolors(edgecolors)
return labs

Expand Down Expand Up @@ -2487,16 +2543,18 @@ def _parse_cycle(
resolved_cycle = None
case True:
resolved_cycle = constructor.Cycle(rc["axes.prop_cycle"])
case constructor.Cycle():
resolved_cycle = constructor.Cycle(cycle)
case str() if cycle.lower() == "none":
resolved_cycle = None
case str() | int() | Iterable():
case str() | int():
resolved_cycle = constructor.Cycle(cycle, **cycle_kw)
case constructor.Cycle():
resolved_cycle = constructor.Cycle(cycle)
case _:
resolved_cycle = None

# Ignore cycle for single-column plotting unless cycle is different
# Ignore cycle for single-column plotting
resolved_cycle = None if ncycle == 1 else resolved_cycle

if resolved_cycle and resolved_cycle != self._active_cycle:
self.set_prop_cycle(resolved_cycle)

Expand Down

0 comments on commit 542603e

Please sign in to comment.