diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index c9dc61c8a7..9e4b8ff699 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -1154,7 +1154,8 @@ def _setup_figure(self, p: Plot, common: PlotData, layers: list[Layer]) -> None: # axis / tick labels can be shown on interior shared axes if desired axis_obj = getattr(ax, f"{axis}axis") - visible_side = {"x": "bottom", "y": "left"}.get(axis) + # This allows correct handling for twin{x/y} axises (GH 3614) + visible_side = axis_obj.get_label_position() show_axis_label = ( sub[visible_side] or not p._pair_spec.get("cross", True) @@ -1172,8 +1173,9 @@ def _setup_figure(self, p: Plot, common: PlotData, layers: list[Layer]) -> None: ) ) for group in ("major", "minor"): - side = {"x": "bottom", "y": "left"}[axis] - axis_obj.set_tick_params(**{f"label{side}": show_tick_labels}) + axis_obj.set_tick_params( + **{f"label{visible_side}": show_tick_labels} + ) for t in getattr(axis_obj, f"get_{group}ticklabels")(): t.set_visible(show_tick_labels) diff --git a/tests/_core/test_plot.py b/tests/_core/test_plot.py index 5554ea650f..40f281cd71 100644 --- a/tests/_core/test_plot.py +++ b/tests/_core/test_plot.py @@ -1124,6 +1124,24 @@ def test_on_axes(self): p = Plot([1], [2]).on(ax).add(m).plot() assert m.passed_axes == [ax] assert p._figure is ax.figure + assert ax.yaxis.get_label_position() == "left" + assert ax.yaxis.get_ticks_position() == "left" + assert ax.xaxis.get_label_position() == "bottom" + assert ax.xaxis.get_ticks_position() == "bottom" + + @pytest.mark.parametrize("axis,exp_position", [("x", "top"), ("y", "right")]) + def test_on_secondary_axes(self, axis, exp_position): + + ax = mpl.figure.Figure().subplots() + twinned_ax = {"x": "y", "y": "x"}[axis] + ax2 = getattr(ax, f"twin{twinned_ax}")() + m = MockMark() + p = Plot([1], [2]).on(ax2).add(m).plot() + assert m.passed_axes == [ax2] + assert p._figure is ax2.figure + targetaxis = getattr(ax2, f"{axis}axis") + assert targetaxis.get_label_position() == exp_position + assert targetaxis.get_ticks_position() == exp_position @pytest.mark.parametrize("facet", [True, False]) def test_on_figure(self, facet):