Skip to content

Commit

Permalink
Merge pull request #757 from pfebrer/plots_update
Browse files Browse the repository at this point in the history
Plot update on getting attributes
  • Loading branch information
zerothi authored Apr 23, 2024
2 parents 0b964fe + 499feb6 commit 8975066
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 13 deletions.
11 changes: 11 additions & 0 deletions src/sisl/viz/figure/figure.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,17 @@ def _build(self, plot_actions, *args, **kwargs):

return fig

@classmethod
def fig_has_attr(cls, key: str) -> bool:
"""Whether the figure that this class generates has a given attribute.
Parameters
-----------
key
the attribute to check for.
"""
return False

@staticmethod
def _sanitize_plot_actions(plot_actions):
def _flatten(plot_actions, out, level=0, root_i=0):
Expand Down
9 changes: 8 additions & 1 deletion src/sisl/viz/figure/matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,16 @@ def _iter_multiaxis(self, plot_actions):

yield sanitized_section_actions

@classmethod
def fig_has_attr(cls, key: str) -> bool:
return hasattr(plt.Axes, key) or hasattr(plt.Figure, key)

def __getattr__(self, key):
if key != "axes":
return getattr(self.axes, key)
if hasattr(self.axes, key):
return getattr(self.axes, key)
elif key != "figure" and hasattr(self.figure, key):
return getattr(self.figure, key)
raise AttributeError(key)

def clear(self, layout=False):
Expand Down
5 changes: 5 additions & 0 deletions src/sisl/viz/figure/plotly.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,11 @@ def _iter_animation(self, plot_actions):

self.update_layout(sliders=[slider], updatemenus=updatemenus)

@classmethod
def fig_has_attr(cls, key: str) -> bool:
print(key, hasattr(go.Figure, key))
return hasattr(go.Figure, key)

def __getattr__(self, key):
if key != "figure":
return getattr(self.figure, key)
Expand Down
30 changes: 18 additions & 12 deletions src/sisl/viz/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,31 @@
from sisl.messages import deprecate
from sisl.nodes import Workflow

from .figure import BACKENDS


class Plot(Workflow):
"""Base class for all plots"""

def __getattr__(self, key):
if key != "nodes":
# If an ipython key is requested, get the plot and look
# for the key in the plot. This is simply to enhance
# interactivity in a python notebook environment.
# However, this results in a (maybe undesired) behavior:
# The plot is updated when ipython requests it, without any
# explicit request to update it. This is how it has worked
# from the beggining, so it's probably best to keep it like
# this for now.
if "ipython" in key:
output = self.nodes.output.get()
# From the backend input, we find out which class is the figure going to be
# (even if no figure has been created yet or the latest figure was from a different backend)
# Then we check if the attribute will be available there. If it will, we update the plot and
# get the attribute on the updated plot.
# This is so that things like `plot.show()` work as expected.
# It has the downside that `.get()` is called even when for example a method of the figure is
# retreived to get its docs (e.g. in the helper messages of jupyter notebooks)
selected_backend = self.inputs.get("backend")
figure_cls = BACKENDS.get(selected_backend)
if figure_cls is not None and (
hasattr(figure_cls, key) or figure_cls.fig_has_attr(key)
):
return getattr(self.nodes.output.get(), key)
else:
output = self.nodes.output._output
return getattr(output, key)
raise AttributeError(
f"'{key}' not found in {self.__class__.__name__} with backend '{selected_backend}'"
)
else:
return super().__getattr__(key)

Expand Down

0 comments on commit 8975066

Please sign in to comment.