Skip to content

Commit bd6dcf7

Browse files
Refactor (#9121)
1 parent 53757d0 commit bd6dcf7

File tree

3 files changed

+21
-27
lines changed

3 files changed

+21
-27
lines changed

mne/viz/_brain/_brain.py

Lines changed: 12 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -445,9 +445,7 @@ def __init__(self, subject_id, hemi, surf, title=None,
445445
self.silhouette = True
446446
else:
447447
self.silhouette = silhouette
448-
# for now only one color bar can be added
449-
# since it is the same for all figures
450-
self._colorbar_added = False
448+
self._scalar_bar = None
451449
# for now only one time label can be added
452450
# since it is the same for all figures
453451
self._time_label_added = False
@@ -692,14 +690,12 @@ def _clean(self):
692690
# Qt LeaveEvent requires _Iren so we use _FakeIren instead of None
693691
# to resolve the ref to vtkGenericRenderWindowInteractor
694692
self.plotter._Iren = _FakeIren()
695-
if getattr(self.plotter, 'scalar_bar', None) is not None:
696-
self.plotter.scalar_bar = None
697693
if getattr(self.plotter, 'picker', None) is not None:
698694
self.plotter.picker = None
699695
# XXX end PyVista
700696
for key in ('plotter', 'window', 'dock', 'tool_bar', 'menu_bar',
701697
'status_bar', 'interactor', 'mpl_canvas', 'time_actor',
702-
'picked_renderer', 'act_data_smooth',
698+
'picked_renderer', 'act_data_smooth', '_scalar_bar',
703699
'actions', 'widgets', 'geo', '_data'):
704700
setattr(self, key, None)
705701

@@ -824,12 +820,11 @@ def _configure_time_label(self):
824820
self.time_actor.GetTextProperty().BoldOn()
825821

826822
def _configure_scalar_bar(self):
827-
if self._colorbar_added:
828-
scalar_bar = self.plotter.scalar_bar
829-
scalar_bar.SetOrientationToVertical()
830-
scalar_bar.SetHeight(0.6)
831-
scalar_bar.SetWidth(0.05)
832-
scalar_bar.SetPosition(0.02, 0.2)
823+
if self._scalar_bar is not None:
824+
self._scalar_bar.SetOrientationToVertical()
825+
self._scalar_bar.SetHeight(0.6)
826+
self._scalar_bar.SetWidth(0.05)
827+
self._scalar_bar.SetPosition(0.02, 0.2)
833828

834829
def _configure_dock_time_widget(self, layout=None):
835830
len_time = len(self._data['time']) - 1
@@ -1947,12 +1942,11 @@ def add_data(self, array, fmin=None, fmid=None, fmax=None,
19471942
)
19481943
self._data['time_actor'] = time_actor
19491944
self._time_label_added = True
1950-
if colorbar and not self._colorbar_added and do:
1945+
if colorbar and self._scalar_bar is None and do:
19511946
kwargs = dict(source=actor, n_labels=8, color=self._fg_color,
19521947
bgcolor=self._brain_color[:3])
19531948
kwargs.update(colorbar_kwargs or {})
1954-
self._renderer.scalarbar(**kwargs)
1955-
self._colorbar_added = True
1949+
self._scalar_bar = self._renderer.scalarbar(**kwargs)
19561950
self._renderer.set_camera(**views_dicts[hemi][v])
19571951

19581952
# 4) update the scalar bar and opacity
@@ -2650,9 +2644,6 @@ def update_lut(self, fmin=None, fmid=None, fmax=None, alpha=None):
26502644
# update our values
26512645
rng = self._cmap_range
26522646
ctable = self._data['ctable']
2653-
# in testing, no plotter; if colorbar=False, no scalar_bar
2654-
scalar_bar = getattr(
2655-
getattr(self._renderer, 'plotter', None), 'scalar_bar', None)
26562647
for hemi in ['lh', 'rh', 'vol']:
26572648
hemi_data = self._data.get(hemi)
26582649
if hemi_data is not None:
@@ -2663,25 +2654,22 @@ def update_lut(self, fmin=None, fmid=None, fmax=None, alpha=None):
26632654
opacity=alpha,
26642655
rng=rng)
26652656
self._renderer._set_colormap_range(
2666-
mesh._actor, ctable, scalar_bar, rng,
2657+
mesh._actor, ctable, self._scalar_bar, rng,
26672658
self._brain_color)
2668-
scalar_bar = None
26692659

26702660
grid_volume_pos = hemi_data.get('grid_volume_pos')
26712661
grid_volume_neg = hemi_data.get('grid_volume_neg')
26722662
for grid_volume in (grid_volume_pos, grid_volume_neg):
26732663
if grid_volume is not None:
26742664
self._renderer._set_volume_range(
26752665
grid_volume, ctable, hemi_data['alpha'],
2676-
scalar_bar, rng)
2677-
scalar_bar = None
2666+
self._scalar_bar, rng)
26782667

26792668
glyph_actor = hemi_data.get('glyph_actor')
26802669
if glyph_actor is not None:
26812670
for glyph_actor_ in glyph_actor:
26822671
self._renderer._set_colormap_range(
2683-
glyph_actor_, ctable, scalar_bar, rng)
2684-
scalar_bar = None
2672+
glyph_actor_, ctable, self._scalar_bar, rng)
26852673
if self.time_viewer:
26862674
with self._no_lut_update(f'update_lut {args}'):
26872675
for key in ('fmin', 'fmid', 'fmax'):

mne/viz/_brain/tests/test_brain.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -618,7 +618,7 @@ def test_brain_traces(renderer_interactive_pyvista, hemi, src, tmpdir,
618618
assert brain.traces_mode == 'vertex'
619619
assert hasattr(brain, "picked_points")
620620
assert hasattr(brain, "_spheres")
621-
assert brain.plotter.scalar_bar.GetNumberOfLabels() == 3
621+
assert brain._scalar_bar.GetNumberOfLabels() == 3
622622

623623
# add foci should work for volumes
624624
brain.add_foci([[0, 0, 0]], hemi='lh' if src == 'surface' else 'vol')

mne/viz/backends/_pyvista.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -607,15 +607,21 @@ def text3d(self, x, y, z, text, scale, color='white'):
607607

608608
def scalarbar(self, source, color="white", title=None, n_labels=4,
609609
bgcolor=None, **extra_kwargs):
610+
if isinstance(source, vtk.vtkMapper):
611+
mapper = source
612+
elif isinstance(source, vtk.vtkActor):
613+
mapper = source.GetMapper()
614+
else:
615+
mapper = None
610616
with warnings.catch_warnings():
611617
warnings.filterwarnings("ignore", category=FutureWarning)
612618
kwargs = dict(color=color, title=title, n_labels=n_labels,
613619
use_opacity=False, n_colors=256, position_x=0.15,
614620
position_y=0.05, width=0.7, shadow=False, bold=True,
615621
label_font_size=22, font_family=self.font_family,
616-
background_color=bgcolor)
622+
background_color=bgcolor, mapper=mapper)
617623
kwargs.update(extra_kwargs)
618-
self.plotter.add_scalar_bar(**kwargs)
624+
return self.plotter.add_scalar_bar(**kwargs)
619625

620626
def show(self):
621627
self.plotter.show()

0 commit comments

Comments
 (0)