From f41169fbd43e13a3314a350b2ad675052e4d43ad Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 23 Mar 2016 12:24:26 +0100 Subject: [PATCH 0001/1059] Revert to branch point of master..kill-kwik --- .gitignore | 2 +- Makefile | 16 +++++++++++++++- README.md | 7 ++++--- phy/plot/features.py | 2 +- phy/plot/traces.py | 4 ++-- phy/plot/waveforms.py | 4 ++-- phy/traces/detect.py | 23 ++--------------------- phy/traces/tests/test_detect.py | 28 ---------------------------- phy/utils/_color.py | 7 ------- setup.cfg | 1 + setup.py | 25 +++++++++++++++++++++++++ 11 files changed, 53 insertions(+), 66 deletions(-) diff --git a/.gitignore b/.gitignore index b7aba4db4..4da4f02fe 100644 --- a/.gitignore +++ b/.gitignore @@ -11,7 +11,7 @@ wiki *.orig .eggs __pycache__ - +_old *.py[cod] .coverage* *credentials diff --git a/Makefile b/Makefile index fb4d079dd..2a1349eba 100644 --- a/Makefile +++ b/Makefile @@ -24,7 +24,7 @@ lint: flake8 phy test: lint - py.test + python setup.py test coverage: coverage --html @@ -38,3 +38,17 @@ unit-tests: lint integration-tests: lint python setup.py test -a tests +apidoc: + python tools/api.py + +build: + python setup.py sdist --formats=zip + +upload: + python setup.py sdist --formats=zip upload + +release-test: + python tools/release.py release_test + +release: + python tools/release.py release diff --git a/README.md b/README.md index 85482706f..8ec57faa0 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,11 @@ # phy project -[![Build Status](https://travis-ci.org/kwikteam/phy.svg?branch=master)](https://travis-ci.org/kwikteam/phy) -[![Build Status](https://ci.appveyor.com/api/projects/status/fuoyuo113domjplr/branch/master?svg=true)](https://ci.appveyor.com/project/kwikteam/phy/) -[![codecov.io](https://img.shields.io/codecov/c/github/kwikteam/phy.svg?)](http://codecov.io/github/kwikteam/phy?branch=master) +[![Build Status](https://img.shields.io/travis/kwikteam/phy.svg)](https://travis-ci.org/kwikteam/phy) +[![Build Status](https://img.shields.io/appveyor/ci/kwikteam/phy.svg)](https://ci.appveyor.com/project/kwikteam/phy/) +[![codecov.io](https://img.shields.io/codecov/c/github/kwikteam/phy.svg)](http://codecov.io/github/kwikteam/phy?branch=master) [![Documentation Status](https://readthedocs.org/projects/phy/badge/?version=latest)](https://readthedocs.org/projects/phy/?badge=latest) [![PyPI release](https://img.shields.io/pypi/v/phy.svg)](https://pypi.python.org/pypi/phy) +[![GitHub release](https://img.shields.io/github/release/kwikteam/phy.svg)](https://github.com/kwikteam/phy/releases/latest) [![Join the chat at https://gitter.im/kwikteam/phy](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/kwikteam/phy?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) [**phy**](https://github.com/kwikteam/phy) is an open source electrophysiological data analysis package in Python for neuronal recordings made with high-density multielectrode arrays containing up to thousands of channels. diff --git a/phy/plot/features.py b/phy/plot/features.py index 230d615fc..34c76b1a7 100644 --- a/phy/plot/features.py +++ b/phy/plot/features.py @@ -618,7 +618,7 @@ def on_key_press(self, event): """Handle key press events.""" coeff = .25 if 'Alt' in event.modifiers: - if event.key == '+' or event.key == '=': + if event.key == '+': self.marker_size += coeff if event.key == '-': self.marker_size -= coeff diff --git a/phy/plot/traces.py b/phy/plot/traces.py index 32c2d3d47..889d141a4 100644 --- a/phy/plot/traces.py +++ b/phy/plot/traces.py @@ -288,12 +288,12 @@ def on_key_press(self, event): ctrl = 'Control' in event.modifiers # Box scale. - if ctrl and key in ('+', '-', '='): + if ctrl and key in ('+', '-'): coeff = 1.1 u = self.channel_scale if key == '-': self.channel_scale = u / coeff - elif key == '+' or key == '=': + elif key == '+': self.channel_scale = u * coeff diff --git a/phy/plot/waveforms.py b/phy/plot/waveforms.py index 1b937b402..de9b5665a 100644 --- a/phy/plot/waveforms.py +++ b/phy/plot/waveforms.py @@ -241,7 +241,7 @@ class WaveformView(BaseSpikeCanvas): """A VisPy canvas displaying waveforms.""" _visual_class = WaveformVisual _arrows = ('Left', 'Right', 'Up', 'Down') - _pm = ('+', '-', '=') + _pm = ('+', '-') _events = ('channel_click',) _key_pressed = None _show_mean = False @@ -417,7 +417,7 @@ def on_key_press(self, event): self.box_scale = (u * coeff, v) elif key in ('Down', '-'): self.box_scale = (u, v / coeff) - elif key in ('Up', '+', '='): + elif key in ('Up', '+'): self.box_scale = (u, v * coeff) # Probe scale. diff --git a/phy/traces/detect.py b/phy/traces/detect.py index 3c90361f4..44af233b0 100644 --- a/phy/traces/detect.py +++ b/phy/traces/detect.py @@ -155,8 +155,7 @@ def _to_list(x): def connected_components(weak_crossings=None, strong_crossings=None, probe_adjacency_list=None, - join_size=None, - channels=None): + join_size=None): """Find all connected components in binary arrays of threshold crossings. Parameters @@ -168,13 +167,10 @@ def connected_components(weak_crossings=None, `(n_samples, n_channels)` array with strong threshold crossings probe_adjacency_list : dict A dict `{channel: [neighbors]}` - channels : array - An (n_channels,) array with a list of all non-dead channels join_size : int The number of samples defining the tolerance in time for finding connected components - Returns ------- @@ -194,15 +190,6 @@ def connected_components(weak_crossings=None, if probe_adjacency_list is None: probe_adjacency_list = {} - if channels is None: - channels = [] - - # If the channels aren't referenced at all but exist in 'channels', add a - # trivial self-connection so temporal floodfill will work. If this channel - # is dead, it should be removed from 'channels'. - probe_adjacency_list.update({i: {i} for i in channels - if not probe_adjacency_list.get(i)}) - # Make sure the values are sets. probe_adjacency_list = {c: set(cs) for c, cs in probe_adjacency_list.items()} @@ -350,23 +337,17 @@ class FloodFillDetector(object): for every sample in the component. """ - def __init__(self, probe_adjacency_list=None, join_size=None, - channels_per_group=None): + def __init__(self, probe_adjacency_list=None, join_size=None): self._adjacency_list = probe_adjacency_list self._join_size = join_size - self._channels_per_group = channels_per_group def __call__(self, weak_crossings=None, strong_crossings=None): weak_crossings = _as_array(weak_crossings, np.bool) strong_crossings = _as_array(strong_crossings, np.bool) - all_channels = sorted([item for sublist - in self._channels_per_group.values() - for item in sublist]) cc = connected_components(weak_crossings=weak_crossings, strong_crossings=strong_crossings, probe_adjacency_list=self._adjacency_list, - channels=all_channels, join_size=self._join_size, ) # cc is a list of list of pairs (sample, channel) diff --git a/phy/traces/tests/test_detect.py b/phy/traces/tests/test_detect.py index fc511d59f..10c4780eb 100644 --- a/phy/traces/tests/test_detect.py +++ b/phy/traces/tests/test_detect.py @@ -251,11 +251,8 @@ def test_flood_fill(): graph = {0: [1, 2], 1: [0, 2], 2: [0, 1], 3: []} - channels = {0: [1, 2, 3]} - ff = FloodFillDetector(probe_adjacency_list=graph, join_size=1, - channels_per_group=channels ) weak = [[0, 0, 0, 0], @@ -294,28 +291,3 @@ def test_flood_fill(): ] cc = ff(weak, strong) _assert_components_equal(cc, comps) - - channels = {0: [1, 2, 3, 4]} - - ff = FloodFillDetector(probe_adjacency_list=graph, - join_size=2, - channels_per_group=channels - ) - - weak = [[0, 0, 0, 0, 0], - [0, 1, 0, 1, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 1], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 1], - [0, 0, 0, 0, 0], - ] - - comps = [[(1, 1)], - [(1, 3)], - [(4, 4), (6, 4)], - ] - - cc = ff(weak, weak) - _assert_components_equal(cc, comps) diff --git a/phy/utils/_color.py b/phy/utils/_color.py index ae0fcb24b..6edc3cc0a 100644 --- a/phy/utils/_color.py +++ b/phy/utils/_color.py @@ -49,13 +49,6 @@ def _random_bright_color(): [228, 31, 228], [2, 217, 2], [255, 147, 2], - [212, 150, 70], - [205, 131, 201], - [201, 172, 36], - [150, 179, 62], - [95, 188, 122], - [129, 173, 190], - [231, 107, 119], ]) diff --git a/setup.cfg b/setup.cfg index e53fb1af0..42eef79ea 100644 --- a/setup.cfg +++ b/setup.cfg @@ -3,5 +3,6 @@ universal = 1 [pytest] addopts = --cov-report term-missing --cov phy -s + [flake8] ignore=E265 diff --git a/setup.py b/setup.py index 7b03d24c0..35b0c26c3 100644 --- a/setup.py +++ b/setup.py @@ -10,15 +10,38 @@ import os import os.path as op +import sys import re from setuptools import setup +from setuptools.command.test import test as TestCommand #------------------------------------------------------------------------------ # Setup #------------------------------------------------------------------------------ +class PyTest(TestCommand): + user_options = [('pytest-args=', 'a', "String of arguments to pass to py.test")] + + def initialize_options(self): + TestCommand.initialize_options(self) + self.pytest_args = '--cov-report term-missing --cov=phy phy tests' + + def finalize_options(self): + TestCommand.finalize_options(self) + self.test_args = [] + self.test_suite = True + + def run_tests(self): + #import here, cause outside the eggs aren't loaded + import pytest + pytest_string = '-s ' + self.pytest_args + print("Running: py.test " + pytest_string) + errno = pytest.main(pytest_string) + sys.exit(errno) + + def _package_tree(pkgroot): path = op.dirname(__file__) subdirs = [op.relpath(i[0], path).replace(op.sep, '.') @@ -58,6 +81,7 @@ def _package_tree(pkgroot): ], }, include_package_data=True, + # zip_safe=False, keywords='phy,data analysis,electrophysiology,neuroscience', classifiers=[ 'Development Status :: 4 - Beta', @@ -70,4 +94,5 @@ def _package_tree(pkgroot): 'Programming Language :: Python :: 3', 'Programming Language :: Python :: 3.4', ], + cmdclass={'test': PyTest}, ) From 96eceb235165b33ebeee6eb4c95ebab1800ef988 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 22 Sep 2015 09:55:47 +0200 Subject: [PATCH 0002/1059] WIP: remove old files --- phy/cluster/manual/gui.py | 619 --------- phy/cluster/manual/tests/test_gui.py | 227 ---- phy/cluster/manual/tests/test_view_models.py | 238 ---- phy/cluster/manual/view_models.py | 1098 ---------------- phy/detect/__init__.py | 0 phy/detect/default_settings.py | 44 - phy/detect/spikedetekt.py | 610 --------- phy/detect/store.py | 208 --- phy/detect/tests/__init__.py | 0 phy/detect/tests/test_spikedetekt.py | 157 --- phy/detect/tests/test_store.py | 133 -- phy/io/base.py | 519 -------- phy/io/kwik/__init__.py | 12 - phy/io/kwik/creator.py | 463 ------- phy/io/kwik/mock.py | 133 -- phy/io/kwik/model.py | 1243 ------------------ phy/io/kwik/sparse_kk2.py | 75 -- phy/io/kwik/store_items.py | 711 ---------- phy/io/kwik/tests/__init__.py | 0 phy/io/kwik/tests/test_creator.py | 217 --- phy/io/kwik/tests/test_mock.py | 34 - phy/io/kwik/tests/test_model.py | 384 ------ phy/io/kwik/tests/test_sparse_kk2.py | 89 -- phy/io/kwik/tests/test_store_items.py | 167 --- phy/io/sparse.py | 155 --- phy/io/store.py | 752 ----------- phy/io/tests/test_base.py | 85 -- phy/io/tests/test_sparse.py | 122 -- phy/io/tests/test_store.py | 423 ------ phy/scripts/__init__.py | 11 - phy/scripts/phy_script.py | 479 ------- phy/scripts/tests/__init__.py | 0 phy/scripts/tests/test_phy_script.py | 42 - phy/session/__init__.py | 11 - phy/session/default_settings.py | 35 - phy/session/session.py | 468 ------- phy/session/tests/__init__.py | 0 phy/session/tests/test_session.py | 401 ------ 38 files changed, 10365 deletions(-) delete mode 100644 phy/cluster/manual/gui.py delete mode 100644 phy/cluster/manual/tests/test_gui.py delete mode 100644 phy/cluster/manual/tests/test_view_models.py delete mode 100644 phy/cluster/manual/view_models.py delete mode 100644 phy/detect/__init__.py delete mode 100644 phy/detect/default_settings.py delete mode 100644 phy/detect/spikedetekt.py delete mode 100644 phy/detect/store.py delete mode 100644 phy/detect/tests/__init__.py delete mode 100644 phy/detect/tests/test_spikedetekt.py delete mode 100644 phy/detect/tests/test_store.py delete mode 100644 phy/io/base.py delete mode 100644 phy/io/kwik/__init__.py delete mode 100644 phy/io/kwik/creator.py delete mode 100644 phy/io/kwik/mock.py delete mode 100644 phy/io/kwik/model.py delete mode 100644 phy/io/kwik/sparse_kk2.py delete mode 100644 phy/io/kwik/store_items.py delete mode 100644 phy/io/kwik/tests/__init__.py delete mode 100644 phy/io/kwik/tests/test_creator.py delete mode 100644 phy/io/kwik/tests/test_mock.py delete mode 100644 phy/io/kwik/tests/test_model.py delete mode 100644 phy/io/kwik/tests/test_sparse_kk2.py delete mode 100644 phy/io/kwik/tests/test_store_items.py delete mode 100644 phy/io/sparse.py delete mode 100644 phy/io/store.py delete mode 100644 phy/io/tests/test_base.py delete mode 100644 phy/io/tests/test_sparse.py delete mode 100644 phy/io/tests/test_store.py delete mode 100644 phy/scripts/__init__.py delete mode 100644 phy/scripts/phy_script.py delete mode 100644 phy/scripts/tests/__init__.py delete mode 100644 phy/scripts/tests/test_phy_script.py delete mode 100644 phy/session/__init__.py delete mode 100644 phy/session/default_settings.py delete mode 100644 phy/session/session.py delete mode 100644 phy/session/tests/__init__.py delete mode 100644 phy/session/tests/test_session.py diff --git a/phy/cluster/manual/gui.py b/phy/cluster/manual/gui.py deleted file mode 100644 index 385ca4978..000000000 --- a/phy/cluster/manual/gui.py +++ /dev/null @@ -1,619 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import print_function - -"""GUI creator.""" - - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import numpy as np - -import phy -from ...gui.base import BaseGUI -from ...gui.qt import _prompt -from .view_models import (WaveformViewModel, - FeatureGridViewModel, - FeatureViewModel, - CorrelogramViewModel, - TraceViewModel, - StatsViewModel, - ) -from ...utils.logging import debug, info, warn -from ...io.kwik.model import cluster_group_id -from ._history import GlobalHistory -from ._utils import ClusterMetadataUpdater -from .clustering import Clustering -from .wizard import Wizard, WizardViewModel - - -#------------------------------------------------------------------------------ -# Manual clustering window -#------------------------------------------------------------------------------ - -def _check_list_argument(arg, name='clusters'): - if not isinstance(arg, (list, tuple, np.ndarray)): - raise ValueError("The argument should be a list or an array.") - if len(name) == 0: - raise ValueError("No {0} were selected.".format(name)) - - -def _to_wizard_group(group_id): - """Return the group name required by the wizard, as a function - of the Kwik cluster group.""" - if hasattr(group_id, '__len__'): - group_id = group_id[0] - return { - 0: 'ignored', - 1: 'ignored', - 2: 'good', - 3: None, - None: None, - }.get(group_id, 'good') - - -def _process_ups(ups): - """This function processes the UpdateInfo instances of the two - undo stacks (clustering and cluster metadata) and concatenates them - into a single UpdateInfo instance.""" - if len(ups) == 0: - return - elif len(ups) == 1: - return ups[0] - elif len(ups) == 2: - up = ups[0] - up.update(ups[1]) - return up - else: - raise NotImplementedError() - - -class ClusterManualGUI(BaseGUI): - """Manual clustering GUI. - - This object represents a main window with: - - * multiple views - * high-level clustering methods - * global keyboard shortcuts - - Events - ------ - - cluster - select - request_save - - """ - - _vm_classes = { - 'waveforms': WaveformViewModel, - 'features': FeatureViewModel, - 'features_grid': FeatureGridViewModel, - 'correlograms': CorrelogramViewModel, - 'traces': TraceViewModel, - 'wizard': WizardViewModel, - 'stats': StatsViewModel, - } - - def __init__(self, model=None, store=None, cluster_ids=None, **kwargs): - self.store = store - self.wizard = Wizard() - self._is_dirty = False - self._cluster_ids = cluster_ids - super(ClusterManualGUI, self).__init__(model=model, - vm_classes=self._vm_classes, - **kwargs) - - def _initialize_views(self): - # The wizard needs to be started *before* the views are created, - # so that the first cluster selection is already set for the views - # when they're created. - self.connect(self._connect_view, event='add_view') - - @self.main_window.connect_ - def on_close_gui(): - # Return False if the close event should be discarded, so that - # the close Qt event is discarded. - return self._prompt_save() - - self.on_open() - self.wizard.start() - if self._cluster_ids is None: - self._cluster_ids = self.wizard.selection - super(ClusterManualGUI, self)._initialize_views() - - # View methods - # --------------------------------------------------------------------- - - @property - def title(self): - """Title of the main window.""" - name = self.__class__.__name__ - filename = getattr(self.model, 'kwik_path', 'mock') - clustering = self.model.clustering - channel_group = self.model.channel_group - template = ("{filename} (shank {channel_group}, " - "{clustering} clustering) " - "- {name} - phy {version}") - return template.format(name=name, - version=phy.__version_git__, - filename=filename, - channel_group=channel_group, - clustering=clustering, - ) - - def _connect_view(self, view): - """Connect a view to the GUI's events (select and cluster).""" - @self.connect - def on_select(cluster_ids, auto_update=True): - view.select(cluster_ids, auto_update=auto_update) - - @self.connect - def on_cluster(up): - view.on_cluster(up) - - def _connect_store(self): - @self.connect - def on_cluster(up=None): - self.store.update_spikes_per_cluster(self.model.spikes_per_cluster) - # No need to delete the old clusters from the store, we can keep - # them for possible undo, and regularly clean up the store. - for item in self.store.items.values(): - item.on_cluster(up) - - def _set_default_view_connections(self): - """Set view connections.""" - - # Select feature dimension from waveform view. - @self.connect_views('waveforms', 'features') - def channel_click(waveforms, features): - - @waveforms.view.connect - def on_channel_click(e): - # The box id is set when the features grid view is to be - # updated. - if e.box_idx is not None: - return - dimension = (e.channel_idx, 0) - features.set_dimension(e.ax, dimension) - features.update() - - # Select feature grid dimension from waveform view. - @self.connect_views('waveforms', 'features_grid') - def channel_click_grid(waveforms, features_grid): - - @waveforms.view.connect - def on_channel_click(e): - # The box id is set when the features grid view is to be - # updated. - if e.box_idx is None: - return - if not (1 <= e.box_idx <= features_grid.n_rows - 1): - return - dimension = (e.channel_idx, 0) - box = (e.box_idx, e.box_idx) - features_grid.set_dimension(e.ax, box, dimension) - features_grid.update() - - # Enlarge feature subplot. - @self.connect_views('features_grid', 'features') - def enlarge(grid, features): - - @grid.view.connect - def on_enlarge(e): - features.set_dimension('x', e.x_dim, smart=False) - features.set_dimension('y', e.y_dim, smart=False) - features.update() - - def _view_model_kwargs(self, name): - kwargs = {'model': self.model, - 'store': self.store, - 'wizard': self.wizard, - 'cluster_ids': self._cluster_ids, - } - return kwargs - - # Creation methods - # --------------------------------------------------------------------- - - def _get_clusters(self, which): - # Move best/match/both to noise/mua/good. - return { - 'best': [self.wizard.best], - 'match': [self.wizard.match], - 'both': [self.wizard.best, self.wizard.match], - }[which] - - def _create_actions(self): - for action in ['reset_gui', - 'save', - 'exit', - 'show_shortcuts', - 'select', - # Wizard. - 'reset_wizard', - 'first', - 'last', - 'next', - 'previous', - 'pin', - 'unpin', - # Actions. - 'merge', - 'split', - 'undo', - 'redo', - # Views. - 'toggle_correlogram_normalization', - 'toggle_waveforms_mean', - 'toggle_waveforms_overlap', - 'show_features_time', - ]: - self._add_gui_shortcut(action) - - def _make_func(which, group): - """Return a function that moves best/match/both clusters to - a group.""" - - def func(): - clusters = self._get_clusters(which) - if None in clusters: - return - self.move(clusters, group) - - name = 'move_{}_to_{}'.format(which, group) - func.__name__ = name - setattr(self, name, func) - return name - - for which in ('best', 'match', 'both'): - for group in ('noise', 'mua', 'good'): - self._add_gui_shortcut(_make_func(which, group)) - - def _create_cluster_metadata(self): - self._cluster_metadata_updater = ClusterMetadataUpdater( - self.model.cluster_metadata) - - @self.connect - def on_cluster(up): - for cluster in up.metadata_changed: - group_0 = self._cluster_metadata_updater.group(cluster) - group_1 = self.model.cluster_metadata.group(cluster) - assert group_0 == group_1 - - def _create_clustering(self): - self.clustering = Clustering(self.model.spike_clusters) - - @self.connect - def on_cluster(up): - spc = self.clustering.spikes_per_cluster - self.model.update_spikes_per_cluster(spc) - - def _create_global_history(self): - self._global_history = GlobalHistory(process_ups=_process_ups) - - def _create_wizard(self): - - # Initialize the groups for the wizard. - def _group(cluster): - group_id = self._cluster_metadata_updater.group(cluster) - return _to_wizard_group(group_id) - - groups = {cluster: _group(cluster) - for cluster in self.clustering.cluster_ids} - self.wizard.cluster_groups = groups - - self.wizard.reset() - - # Set the similarity and quality functions for the wizard. - @self.wizard.set_similarity_function - def similarity(target, candidate): - """Compute the distance between the mean masked features.""" - - mu_0 = self.store.mean_features(target).ravel() - mu_1 = self.store.mean_features(candidate).ravel() - - omeg_0 = self.store.mean_masks(target) - omeg_1 = self.store.mean_masks(candidate) - - omeg_0 = np.repeat(omeg_0, self.model.n_features_per_channel) - omeg_1 = np.repeat(omeg_1, self.model.n_features_per_channel) - - d_0 = mu_0 * omeg_0 - d_1 = mu_1 * omeg_1 - - # WARNING: "-" because this is a distance, not a score. - return -np.linalg.norm(d_0 - d_1) - - @self.wizard.set_quality_function - def quality(cluster): - """Return the maximum mean_masks across all channels - for a given cluster.""" - return self.store.mean_masks(cluster).max() - - @self.connect - def on_cluster(up): - # HACK: get the current group as it is not available in `up` - # currently. - if up.description.startswith('metadata'): - up = up.copy() - cluster = up.metadata_changed[0] - group = self.model.cluster_metadata.group(cluster) - up.metadata_value = _to_wizard_group(group) - - # This called for both regular and history actions. - # Save the wizard selection and update the wizard. - self.wizard.on_cluster(up) - - # Update the wizard selection after a clustering action. - self._wizard_select_after_clustering(up) - - def _wizard_select_after_clustering(self, up): - # Make as few updates as possible in the views after clustering - # actions. This allows for better before/after comparisons. - if up.added: - self.select(up.added, auto_update=False) - elif up.description == 'metadata_group': - # Select the last selected clusters after undo/redo. - if up.history and up.selection: - self.select(up.selection, auto_update=False) - return - cluster = up.metadata_changed[0] - if cluster == self.wizard.best: - self.wizard.next_best() - elif cluster == self.wizard.match: - self.wizard.next_match() - self._wizard_select() - elif up.selection: - self.select(up.selection, auto_update=False) - - # Open data - # ------------------------------------------------------------------------- - - def on_open(self): - """Reinitialize the GUI after new data has been loaded.""" - self._create_global_history() - # This connects the callback that updates the model spikes_per_cluster. - self._create_clustering() - self._create_cluster_metadata() - # This connects the callback that updates the store. - self._connect_store() - self._create_wizard() - self._is_dirty = False - - @self.connect - def on_cluster(up): - self._is_dirty = True - - def save(self): - """Save the changes.""" - # The session saves the model when this event is emitted. - self.emit('request_save') - self._is_dirty = False - - def _prompt_save(self): - """Display a prompt for saving and return whether the GUI should be - closed.""" - if (self.settings.get('prompt_save_on_exit', False) and - self._is_dirty): - res = _prompt(self.main_window, - "Do you want to save your changes?", - ('save', 'cancel', 'close')) - if res == 'save': - self.save() - return True - elif res == 'cancel': - return False - elif res == 'close': - return True - return True - - # General actions - # --------------------------------------------------------------------- - - def start(self): - """Start the wizard.""" - self.wizard.start() - self._cluster_ids = self.wizard.selection - - @property - def cluster_ids(self): - """Array of all cluster ids used in the current clustering.""" - return self.clustering.cluster_ids - - @property - def n_clusters(self): - """Number of clusters in the current clustering.""" - return self.clustering.n_clusters - - # View-related actions - # --------------------------------------------------------------------- - - def toggle_correlogram_normalization(self): - """Toggle CCG normalization in the correlograms views.""" - for vm in self.get_views('correlograms'): - vm.toggle_normalization() - - def toggle_waveforms_mean(self): - """Toggle mean mode in the waveform views.""" - for vm in self.get_views('waveforms'): - vm.show_mean = not(vm.show_mean) - - def toggle_waveforms_overlap(self): - """Toggle cluster overlap in the waveform views.""" - for vm in self.get_views('waveforms'): - vm.overlap = not(vm.overlap) - - def show_features_time(self): - """Set the x dimension to time in all feature views.""" - for vm in self.get_views('features'): - vm.set_dimension('x', 'time') - vm.update() - - # Selection - # --------------------------------------------------------------------- - - def select(self, cluster_ids, **kwargs): - """Select clusters.""" - cluster_ids = list(cluster_ids) - assert len(cluster_ids) == len(set(cluster_ids)) - # Do not re-select an already-selected list of clusters. - if cluster_ids == self._cluster_ids: - return - if not set(cluster_ids) <= set(self.clustering.cluster_ids): - n_selected = len(cluster_ids) - cluster_ids = [cl for cl in cluster_ids - if cl in self.clustering.cluster_ids] - n_kept = len(cluster_ids) - warn("{} of the {} selected clusters do not exist.".format( - n_selected - n_kept, n_selected)) - if len(cluster_ids) >= 14: - warn("You cannot select more than 13 clusters in the GUI.") - return - debug("Select clusters {0:s}.".format(str(cluster_ids))) - self._cluster_ids = cluster_ids - self.emit('select', cluster_ids, **kwargs) - - @property - def selected_clusters(self): - """The list of selected clusters.""" - return self._cluster_ids - - # Wizard list - # --------------------------------------------------------------------- - - def _wizard_select(self, **kwargs): - self.select(self.wizard.selection, **kwargs) - - def reset_wizard(self): - """Restart the wizard.""" - self.wizard.start() - self._wizard_select() - - def first(self): - """Go to the first cluster proposed by the wizard.""" - self.wizard.first() - self._wizard_select() - - def last(self): - """Go to the last cluster proposed by the wizard.""" - self.wizard.last() - self._wizard_select() - - def next(self): - """Go to the next cluster proposed by the wizard.""" - self.wizard.next() - self._wizard_select() - - def previous(self): - """Go to the previous cluster proposed by the wizard.""" - self.wizard.previous() - self._wizard_select() - - def pin(self): - """Pin the current best cluster.""" - cluster = (self.selected_clusters[0] - if len(self.selected_clusters) else None) - self.wizard.pin(cluster) - self._wizard_select() - - def unpin(self): - """Unpin the current best cluster.""" - self.wizard.unpin() - self._wizard_select() - - # Cluster actions - # --------------------------------------------------------------------- - - def merge(self, clusters=None): - """Merge some clusters.""" - if clusters is None: - clusters = self.selected_clusters - clusters = list(clusters) - if len(clusters) <= 1: - return - up = self.clustering.merge(clusters) - up.selection = self.selected_clusters - info("Merge clusters {} to {}.".format(str(clusters), - str(up.added[0]))) - self._global_history.action(self.clustering) - self.emit('cluster', up=up) - return up - - def _spikes_to_split(self): - """Find the spikes lasso selected in a feature view for split.""" - for features in self.get_views('features', 'features_grid'): - spikes = features.spikes_in_lasso() - if spikes is not None: - features.lasso.clear() - return spikes - - def split(self, spikes=None): - """Make a new cluster out of some spikes. - - Notes - ----- - - Spikes belonging to affected clusters, but not part of the `spikes` - array, will move to brand new cluster ids. This is because a new - cluster id must be used as soon as a cluster changes. - - """ - if spikes is None: - spikes = self._spikes_to_split() - if spikes is None: - return - if not len(spikes): - info("No spikes to split.") - return - _check_list_argument(spikes, 'spikes') - info("Split {0:d} spikes.".format(len(spikes))) - up = self.clustering.split(spikes) - up.selection = self.selected_clusters - self._global_history.action(self.clustering) - self.emit('cluster', up=up) - return up - - def move(self, clusters, group): - """Move some clusters to a cluster group. - - Here is the list of cluster groups: - - * 0=Noise - * 1=MUA - * 2=Good - * 3=Unsorted - - """ - _check_list_argument(clusters) - info("Move clusters {0} to {1}.".format(str(clusters), group)) - group_id = cluster_group_id(group) - up = self._cluster_metadata_updater.set_group(clusters, group_id) - up.selection = self.selected_clusters - self._global_history.action(self._cluster_metadata_updater) - self.emit('cluster', up=up) - return up - - def _undo_redo(self, up): - if up: - info("{} {}.".format(up.history.title(), - up.description, - )) - self.emit('cluster', up=up) - - def undo(self): - """Undo the last clustering action.""" - up = self._global_history.undo() - self._undo_redo(up) - return up - - def redo(self): - """Redo the last undone action.""" - # debug("The saved selection before the undo is {}.".format(clusters)) - up = self._global_history.redo() - if up: - up.selection = self.selected_clusters - self._undo_redo(up) - return up diff --git a/phy/cluster/manual/tests/test_gui.py b/phy/cluster/manual/tests/test_gui.py deleted file mode 100644 index 31e5d6abe..000000000 --- a/phy/cluster/manual/tests/test_gui.py +++ /dev/null @@ -1,227 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Tests of manual clustering GUI.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -from pytest import mark -import numpy as np -from numpy.testing import assert_allclose as ac -from numpy.testing import assert_array_equal as ae - -from ..gui import ClusterManualGUI -from ....utils.settings import _load_default_settings -from ....utils.logging import set_level -from ....utils.array import _spikes_in_clusters -from ....io.mock import MockModel -from ....io.kwik.store_items import create_store - - -# Skip these tests in "make test-quick". -pytestmark = mark.long() - - -#------------------------------------------------------------------------------ -# Kwik tests -#------------------------------------------------------------------------------ - -def setup(): - set_level('debug') - - -def _start_manual_clustering(config='none', shortcuts=None): - if config is 'none': - config = [] - model = MockModel() - spc = model.spikes_per_cluster - store = create_store(model, spikes_per_cluster=spc) - gui = ClusterManualGUI(model=model, store=store, - config=config, shortcuts=shortcuts) - return gui - - -def test_gui_clustering(qtbot): - - gui = _start_manual_clustering() - gui.show() - qtbot.addWidget(gui.main_window) - - cs = gui.store - spike_clusters = gui.model.spike_clusters.copy() - - f = gui.model.features - m = gui.model.masks - - def _check_arrays(cluster, clusters_for_sc=None, spikes=None): - """Check the features and masks in the cluster store - of a given custer.""" - if spikes is None: - if clusters_for_sc is None: - clusters_for_sc = [cluster] - spikes = _spikes_in_clusters(spike_clusters, clusters_for_sc) - shape = (len(spikes), - len(gui.model.channel_order), - gui.model.n_features_per_channel) - ac(cs.features(cluster), f[spikes, :].reshape(shape)) - ac(cs.masks(cluster), m[spikes]) - - _check_arrays(0) - _check_arrays(2) - - # Merge two clusters. - clusters = [0, 2] - up = gui.merge(clusters) - new = up.added[0] - _check_arrays(new, clusters) - - # Split some spikes. - spikes = [2, 3, 5, 7, 11, 13] - # clusters = np.unique(spike_clusters[spikes]) - up = gui.split(spikes) - _check_arrays(new + 1, spikes=spikes) - - # Undo. - gui.undo() - _check_arrays(new, clusters) - - # Undo. - gui.undo() - _check_arrays(0) - _check_arrays(2) - - # Redo. - gui.redo() - _check_arrays(new, clusters) - - # Split some spikes. - spikes = [5, 7, 11, 13, 17, 19] - # clusters = np.unique(spike_clusters[spikes]) - gui.split(spikes) - _check_arrays(new + 1, spikes=spikes) - - # Test merge-undo-different-merge combo. - spc = gui.clustering.spikes_per_cluster.copy() - clusters = gui.cluster_ids[:3] - up = gui.merge(clusters) - _check_arrays(up.added[0], spikes=up.spike_ids) - # Undo. - gui.undo() - for cluster in clusters: - _check_arrays(cluster, spikes=spc[cluster]) - # Another merge. - clusters = gui.cluster_ids[1:5] - up = gui.merge(clusters) - _check_arrays(up.added[0], spikes=up.spike_ids) - - # Move a cluster to a group. - cluster = gui.cluster_ids[0] - gui.move([cluster], 2) - assert len(gui.store.mean_probe_position(cluster)) == 2 - - spike_clusters_new = gui.model.spike_clusters.copy() - # Check that the spike clusters have changed. - assert not np.all(spike_clusters_new == spike_clusters) - ac(gui.model.spike_clusters, gui.clustering.spike_clusters) - - gui.close() - - -def test_gui_move_wizard(qtbot): - gui = _start_manual_clustering() - qtbot.addWidget(gui.main_window) - gui.show() - - gui.next() - gui.pin() - gui.next() - best = gui.wizard.best - assert gui.selected_clusters[0] == best - match = gui.selected_clusters[1] - gui.move([gui.wizard.match], 'mua') - assert gui.selected_clusters[0] == best - assert gui.selected_clusters[1] != match - - gui.close() - - -def test_gui_wizard(qtbot): - gui = _start_manual_clustering() - n = gui.n_clusters - qtbot.addWidget(gui.main_window) - gui.show() - - clusters = np.arange(gui.n_clusters) - best_clusters = gui.wizard.best_clusters() - - # assert gui.wizard.best_clusters(1)[0] == best_clusters[0] - ae(np.unique(best_clusters), clusters) - assert len(gui.wizard.most_similar_clusters()) == n - 1 - - assert len(gui.wizard.most_similar_clusters(0, n_max=3)) == 3 - - clusters = gui.cluster_ids[:2] - up = gui.merge(clusters) - new = up.added[0] - assert np.all(np.in1d(gui.wizard.best_clusters(), - np.arange(clusters[-1] + 1, new + 1))) - assert np.all(np.in1d(gui.wizard.most_similar_clusters(new), - np.arange(clusters[-1] + 1, new))) - - gui.close() - - -@mark.long -def test_gui_history(qtbot): - - gui = _start_manual_clustering() - qtbot.addWidget(gui.main_window) - gui.show() - - gui.wizard.start() - - spikes = _spikes_in_clusters(gui.model.spike_clusters, - gui.wizard.selection) - gui.split(spikes[::3]) - gui.undo() - gui.wizard.next() - gui.redo() - gui.undo() - - for _ in range(10): - gui.merge(gui.wizard.selection) - gui.wizard.next() - gui.undo() - gui.wizard.next() - gui.redo() - gui.wizard.next() - - spikes = _spikes_in_clusters(gui.model.spike_clusters, - gui.wizard.selection) - if len(spikes): - gui.split(spikes[::10]) - gui.wizard.next() - gui.undo() - gui.merge(gui.wizard.selection) - gui.wizard.next() - gui.wizard.next() - - gui.wizard.next_best() - ae(gui.model.spike_clusters, gui.clustering.spike_clusters) - - gui.close() - - -@mark.long -def test_gui_gui(qtbot): - settings = _load_default_settings() - config = settings['cluster_manual_config'] - shortcuts = settings['cluster_manual_shortcuts'] - - gui = _start_manual_clustering(config=config, - shortcuts=shortcuts, - ) - qtbot.addWidget(gui.main_window) - gui.show() - gui.close() diff --git a/phy/cluster/manual/tests/test_view_models.py b/phy/cluster/manual/tests/test_view_models.py deleted file mode 100644 index f973fcc0b..000000000 --- a/phy/cluster/manual/tests/test_view_models.py +++ /dev/null @@ -1,238 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Tests of view model.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import os - -from pytest import mark - -from ....utils.array import _spikes_per_cluster -from ....utils.logging import set_level -from ....utils.testing import (show_test_start, - show_test_stop, - show_test_run, - ) -from ....io.kwik.mock import create_mock_kwik -from ....io.kwik import KwikModel, create_store -from ..view_models import (WaveformViewModel, - FeatureGridViewModel, - CorrelogramViewModel, - TraceViewModel, - ) - - -# Skip these tests in "make test-quick". -pytestmark = mark.long() - - -#------------------------------------------------------------------------------ -# Utilities -#------------------------------------------------------------------------------ - -_N_CLUSTERS = 5 -_N_SPIKES = 200 -_N_CHANNELS = 28 -_N_FETS = 3 -_N_SAMPLES_TRACES = 10000 -_N_FRAMES = int((float(os.environ.get('PHY_EVENT_LOOP_DELAY', 0)) * 60) or 2) - - -def setup(): - set_level('info') - - -def _test_empty(tempdir, view_model_class, stop=True, **kwargs): - # Create the test HDF5 file in the temporary directory. - filename = create_mock_kwik(tempdir, - n_clusters=1, - n_spikes=1, - n_channels=_N_CHANNELS, - n_features_per_channel=_N_FETS, - n_samples_traces=_N_SAMPLES_TRACES) - model = KwikModel(filename) - spikes_per_cluster = _spikes_per_cluster(model.spike_ids, - model.spike_clusters) - store = create_store(model, - path=tempdir, - spikes_per_cluster=spikes_per_cluster, - features_masks_chunk_size=10, - waveforms_n_spikes_max=10, - waveforms_excerpt_size=5, - ) - store.generate() - - vm = view_model_class(model=model, store=store, **kwargs) - vm.on_open() - - # Show the view. - show_test_start(vm.view) - show_test_run(vm.view, _N_FRAMES) - vm.select([0]) - show_test_run(vm.view, _N_FRAMES) - vm.select([]) - show_test_run(vm.view, _N_FRAMES) - - if stop: - show_test_stop(vm.view) - - return vm - - -def _test_view_model(tempdir, view_model_class, stop=True, **kwargs): - - # Create the test HDF5 file in the temporary directory. - filename = create_mock_kwik(tempdir, - n_clusters=_N_CLUSTERS, - n_spikes=_N_SPIKES, - n_channels=_N_CHANNELS, - n_features_per_channel=_N_FETS, - n_samples_traces=_N_SAMPLES_TRACES) - model = KwikModel(filename) - spikes_per_cluster = _spikes_per_cluster(model.spike_ids, - model.spike_clusters) - store = create_store(model, - path=tempdir, - spikes_per_cluster=spikes_per_cluster, - features_masks_chunk_size=15, - waveforms_n_spikes_max=20, - waveforms_excerpt_size=5, - ) - store.generate() - - vm = view_model_class(model=model, store=store, **kwargs) - vm.on_open() - show_test_start(vm.view) - - vm.select([2]) - show_test_run(vm.view, _N_FRAMES) - - vm.select([2, 3]) - show_test_run(vm.view, _N_FRAMES) - - vm.select([3, 2]) - show_test_run(vm.view, _N_FRAMES) - - if stop: - show_test_stop(vm.view) - - return vm - - -#------------------------------------------------------------------------------ -# Waveforms -#------------------------------------------------------------------------------ - -def test_waveforms_full(tempdir): - vm = _test_view_model(tempdir, WaveformViewModel, stop=False) - vm.overlap = True - show_test_run(vm.view, _N_FRAMES) - vm.show_mean = True - show_test_run(vm.view, _N_FRAMES) - show_test_stop(vm.view) - - -def test_waveforms_empty(tempdir): - _test_empty(tempdir, WaveformViewModel) - - -#------------------------------------------------------------------------------ -# Features -#------------------------------------------------------------------------------ - -def test_features_empty(tempdir): - _test_empty(tempdir, FeatureGridViewModel) - - -def test_features_full(tempdir): - _test_view_model(tempdir, FeatureGridViewModel, - marker_size=8, n_spikes_max=20) - - -def test_features_lasso(tempdir): - vm = _test_view_model(tempdir, - FeatureGridViewModel, - marker_size=8, - stop=False, - ) - show_test_run(vm.view, _N_FRAMES) - box = (1, 2) - vm.view.lasso.box = box - x, y = 0., 1. - vm.view.lasso.add((x, x)) - vm.view.lasso.add((y, x)) - vm.view.lasso.add((y, y)) - vm.view.lasso.add((x, y)) - show_test_run(vm.view, _N_FRAMES) - # Find spikes in lasso. - spikes = vm.spikes_in_lasso() - # Change their clusters. - vm.model.spike_clusters[spikes] = 3 - sc = vm.model.spike_clusters - vm.view.visual.spike_clusters = sc[vm.view.visual.spike_ids] - show_test_run(vm.view, _N_FRAMES) - show_test_stop(vm.view) - - -#------------------------------------------------------------------------------ -# Correlograms -#------------------------------------------------------------------------------ - -def test_ccg_empty(tempdir): - _test_empty(tempdir, - CorrelogramViewModel, - binsize=20, - winsize_bins=51, - n_excerpts=100, - excerpt_size=100, - ) - - -def test_ccg_simple(tempdir): - _test_view_model(tempdir, - CorrelogramViewModel, - binsize=10, - winsize_bins=61, - n_excerpts=80, - excerpt_size=120, - ) - - -def test_ccg_full(tempdir): - vm = _test_view_model(tempdir, - CorrelogramViewModel, - binsize=20, - winsize_bins=51, - n_excerpts=100, - excerpt_size=100, - stop=False, - ) - show_test_run(vm.view, _N_FRAMES) - vm.change_bins(half_width=100., bin=1.) - show_test_run(vm.view, _N_FRAMES) - show_test_stop(vm.view) - - -#------------------------------------------------------------------------------ -# Traces -#------------------------------------------------------------------------------ - -def test_traces_empty(tempdir): - _test_empty(tempdir, TraceViewModel) - - -def test_traces_simple(tempdir): - _test_view_model(tempdir, TraceViewModel) - - -def test_traces_full(tempdir): - vm = _test_view_model(tempdir, TraceViewModel, stop=False) - vm.move_right() - show_test_run(vm.view, _N_FRAMES) - vm.move_left() - show_test_run(vm.view, _N_FRAMES) - - show_test_stop(vm.view) diff --git a/phy/cluster/manual/view_models.py b/phy/cluster/manual/view_models.py deleted file mode 100644 index 0f9554108..000000000 --- a/phy/cluster/manual/view_models.py +++ /dev/null @@ -1,1098 +0,0 @@ -# -*- coding: utf-8 -*- - -"""View model for clustered data.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import os.path as op - -import numpy as np -from six import string_types - -from ...io.kwik.model import _DEFAULT_GROUPS -from ...utils.array import _unique, _spikes_in_clusters, _as_array -from ...utils.selector import Selector -from ...utils._misc import _show_shortcuts -from ...utils._types import _is_integer, _is_float -from ...utils._color import _selected_clusters_colors -from ...utils import _as_list -from ...stats.ccg import correlograms, _symmetrize_correlograms -from ...plot.ccg import CorrelogramView -from ...plot.features import FeatureView -from ...plot.waveforms import WaveformView -from ...plot.traces import TraceView -from ...gui.base import BaseViewModel, HTMLViewModel -from ...gui._utils import _read - - -#------------------------------------------------------------------------------ -# Misc -#------------------------------------------------------------------------------ - -def _create_view(cls, backend=None, **kwargs): - if backend in ('pyqt4', None): - kwargs.update({'always_on_top': True}) - return cls(**kwargs) - - -def _oddify(x): - return x if x % 2 == 1 else x + 1 - - -#------------------------------------------------------------------------------ -# Base view models -#------------------------------------------------------------------------------ - -class BaseClusterViewModel(BaseViewModel): - """Interface between a view and a model.""" - _view_class = None - - def __init__(self, model=None, - store=None, wizard=None, - cluster_ids=None, **kwargs): - assert store is not None - self._store = store - self._wizard = wizard - - super(BaseClusterViewModel, self).__init__(model=model, - **kwargs) - - self._cluster_ids = None - if cluster_ids is not None: - self.select(_as_list(cluster_ids)) - - @property - def store(self): - """The cluster store.""" - return self._store - - @property - def wizard(self): - """The wizard.""" - return self._wizard - - @property - def cluster_ids(self): - """Selected clusters.""" - return self._cluster_ids - - @property - def n_clusters(self): - """Number of selected clusters.""" - return len(self._cluster_ids) - - # Public methods - #-------------------------------------------------------------------------- - - def select(self, cluster_ids, **kwargs): - """Select a list of clusters.""" - cluster_ids = _as_list(cluster_ids) - self._cluster_ids = cluster_ids - self.on_select(cluster_ids, **kwargs) - - # Callback methods - #-------------------------------------------------------------------------- - - def on_select(self, cluster_ids, **kwargs): - """Update the view after a new selection has been made. - - Must be overriden.""" - - def on_cluster(self, up): - """Called when a clustering action occurs. - - May be overriden.""" - - -def _css_cluster_colors(): - colors = _selected_clusters_colors() - # HACK: this is the maximum number of clusters that can be displayed - # in an HTML view. If this number is exceeded, cluster colors will be - # wrong for the extra clusters. - n = 32 - - def _color(i): - i = i % len(colors) - c = colors[i] - c = (255 * c).astype(np.int32) - return 'rgb({}, {}, {})'.format(*c) - - return ''.join(""".cluster_{i} {{ - color: {color}; - }}\n""".format(i=i, color=_color(i)) - for i in range(n)) - - -class HTMLClusterViewModel(BaseClusterViewModel, HTMLViewModel): - """HTML view model that displays per-cluster information.""" - - def get_css(self, **kwargs): - # TODO: improve this - # Currently, child classes *must* append some CSS to this parent's - # method. - return _css_cluster_colors() - - def on_select(self, cluster_ids, **kwargs): - """Update the view after a new selection has been made.""" - self.update(cluster_ids=cluster_ids) - - def on_cluster(self, up): - """Update the view after a clustering action.""" - self.update(cluster_ids=self._cluster_ids, up=up) - - -class VispyViewModel(BaseClusterViewModel): - """Create a VisPy view from a model. - - This object uses an internal `Selector` instance to manage spike and - cluster selection. - - """ - _imported_params = ('n_spikes_max', 'excerpt_size') - keyboard_shortcuts = {} - scale_factor = 1. - - def __init__(self, **kwargs): - super(VispyViewModel, self).__init__(**kwargs) - - # Call on_close() when the view is closed. - @self._view.connect - def on_close(e): - self.on_close() - - def _create_view(self, **kwargs): - n_spikes_max = kwargs.get('n_spikes_max', None) - excerpt_size = kwargs.get('excerpt_size', None) - backend = kwargs.get('backend', None) - position = kwargs.get('position', None) - size = kwargs.get('size', None) - - # Create the spike/cluster selector. - self._selector = Selector(self._model.spike_clusters, - n_spikes_max=n_spikes_max, - excerpt_size=excerpt_size, - ) - - # Create the VisPy canvas. - view = _create_view(self._view_class, - backend=backend, - position=position or (200, 200), - size=size or (600, 600), - ) - view.connect(self.on_key_press) - return view - - @property - def selector(self): - """A Selector instance managing the selected spikes and clusters.""" - return self._selector - - @property - def cluster_ids(self): - """Selected clusters.""" - return self._selector.selected_clusters - - @property - def spike_ids(self): - """Selected spikes.""" - return self._selector.selected_spikes - - @property - def n_spikes(self): - """Number of selected spikes.""" - return self._selector.n_spikes - - def update_spike_clusters(self, spikes=None, spike_clusters=None): - """Update the spike clusters and cluster colors.""" - if spikes is None: - spikes = self.spike_ids - if spike_clusters is None: - spike_clusters = self.model.spike_clusters[spikes] - n_clusters = len(_unique(spike_clusters)) - visual = self._view.visual - # This updates the list of unique clusters in the view. - visual.spike_clusters = spike_clusters - visual.cluster_colors = _selected_clusters_colors(n_clusters) - - def select(self, cluster_ids, **kwargs): - """Select a set of clusters.""" - self._selector.selected_clusters = cluster_ids - self.on_select(cluster_ids, **kwargs) - - def on_select(self, cluster_ids, **kwargs): - """Update the view after a new selection has been made. - - Must be overriden. - - """ - self.update_spike_clusters() - self._view.update() - - def on_close(self): - """Clear the view when the model is closed.""" - self._view.visual.spike_clusters = [] - self._view.update() - - def on_key_press(self, event): - """Called when a key is pressed.""" - if event.key == 'h' and 'control' not in event.modifiers: - shortcuts = self._view.keyboard_shortcuts - shortcuts.update(self.keyboard_shortcuts) - _show_shortcuts(shortcuts, name=self.name) - - def update(self): - """Update the view.""" - self.view.update() - - -#------------------------------------------------------------------------------ -# Stats panel -#------------------------------------------------------------------------------ - -class StatsViewModel(HTMLClusterViewModel): - """Display cluster statistics.""" - - def get_html(self, cluster_ids=None, up=None): - """Return the HTML table with the cluster statistics.""" - stats = self.store.items['statistics'] - names = stats.fields - if cluster_ids is None: - return '' - # Only keep scalar stats. - _arrays = {name: isinstance(getattr(self.store, name)(cluster_ids[0]), - np.ndarray) for name in names} - names = sorted([name for name in _arrays if not _arrays[name]]) - # Add the cluster group as a first statistic. - names = ['cluster_group'] + names - # Generate the table. - html = '' - for i, cluster in enumerate(cluster_ids): - html += '{cluster}'.format( - cluster=cluster, style='cluster_{}'.format(i)) - html += '' - cluster_groups = self.model.cluster_groups - group_names = dict(_DEFAULT_GROUPS) - for name in names: - html += '' - html += '{name}'.format(name=name) - for i, cluster in enumerate(cluster_ids): - if name == 'cluster_group': - # Get the cluster group. - group_id = cluster_groups.get(cluster, -1) - value = group_names.get(group_id, 'unknown') - else: - value = getattr(self.store, name)(cluster) - if _is_float(value): - value = '{:.3f}'.format(value) - elif _is_integer(value): - value = '{:d}'.format(value) - else: - value = str(value) - html += '{value}'.format( - value=value, style='cluster_{}'.format(i)) - html += '' - return '
\n' + html + '
' - - def get_css(self, cluster_ids=None, up=None): - css = super(StatsViewModel, self).get_css(cluster_ids=cluster_ids, - up=up) - static_path = op.join(op.dirname(op.realpath(__file__)), - 'static/') - css += _read('styles.css', static_path=static_path) - return css - - -#------------------------------------------------------------------------------ -# Kwik view models -#------------------------------------------------------------------------------ - -class WaveformViewModel(VispyViewModel): - """Waveforms.""" - _view_class = WaveformView - _view_name = 'waveforms' - _imported_params = ('scale_factor', 'box_scale', 'probe_scale', - 'overlap', 'show_mean') - - def on_open(self): - """Initialize the view when the model is opened.""" - super(WaveformViewModel, self).on_open() - # Waveforms. - self.view.visual.channel_positions = self.model.probe.positions - self.view.visual.channel_order = self.model.channel_order - # Mean waveforms. - self.view.mean.channel_positions = self.model.probe.positions - self.view.mean.channel_order = self.model.channel_order - if self.scale_factor is None: - self.scale_factor = 1. - - def _load_waveforms(self): - # NOTE: we load all spikes from the store. - # The waveforms store item is responsible for making a subselection - # of spikes both on disk and in the view. - waveforms = self.store.load('waveforms', - clusters=self.cluster_ids, - ) - return waveforms - - def _load_mean_waveforms(self): - mean_waveforms = self.store.load('mean_waveforms', - clusters=self.cluster_ids, - ) - mean_masks = self.store.load('mean_masks', - clusters=self.cluster_ids, - ) - return mean_waveforms, mean_masks - - def update_spike_clusters(self, spikes=None): - """Update the view's spike clusters.""" - super(WaveformViewModel, self).update_spike_clusters(spikes=spikes) - self._view.mean.spike_clusters = np.sort(self.cluster_ids) - self._view.mean.cluster_colors = self._view.visual.cluster_colors - - def on_select(self, clusters, **kwargs): - """Update the view when the selection changes.""" - # Get the spikes of the stored waveforms. - n_clusters = len(clusters) - waveforms = self._load_waveforms() - spikes = self.store.items['waveforms'].spikes_in_clusters(clusters) - n_spikes = len(spikes) - _, self._n_samples, self._n_channels = waveforms.shape - mean_waveforms, mean_masks = self._load_mean_waveforms() - - self.update_spike_clusters(spikes) - - # Cluster display order. - self.view.visual.cluster_order = clusters - self.view.mean.cluster_order = clusters - - # Waveforms. - assert waveforms.shape[0] == n_spikes - self.view.visual.waveforms = waveforms * self.scale_factor - - assert mean_waveforms.shape == (n_clusters, - self._n_samples, - self._n_channels) - self.view.mean.waveforms = mean_waveforms * self.scale_factor - - # Masks. - masks = self.store.load('masks', clusters=clusters, spikes=spikes) - assert masks.shape == (n_spikes, self._n_channels) - self.view.visual.masks = masks - - assert mean_masks.shape == (n_clusters, self._n_channels) - self.view.mean.masks = mean_masks - - # Spikes. - self.view.visual.spike_ids = spikes - self.view.mean.spike_ids = np.arange(len(clusters)) - - self.view.update() - - def on_close(self): - """Clear the view when the model is closed.""" - self.view.visual.channel_positions = [] - self.view.mean.channel_positions = [] - super(WaveformViewModel, self).on_close() - - @property - def box_scale(self): - """Scale of the waveforms. - - This is a pair of scalars. - - """ - return self.view.box_scale - - @box_scale.setter - def box_scale(self, value): - self.view.box_scale = value - - @property - def probe_scale(self): - """Scale of the probe. - - This is a pair of scalars. - - """ - return self.view.probe_scale - - @probe_scale.setter - def probe_scale(self, value): - self.view.probe_scale = value - - @property - def overlap(self): - """Whether to overlap waveforms.""" - return self.view.overlap - - @overlap.setter - def overlap(self, value): - self.view.overlap = value - - @property - def show_mean(self): - """Whether to show mean waveforms.""" - return self.view.show_mean - - @show_mean.setter - def show_mean(self, value): - self.view.show_mean = value - - def exported_params(self, save_size_pos=True): - """Parameters to save automatically when the view is closed.""" - params = super(WaveformViewModel, self).exported_params(save_size_pos) - params.update({ - 'scale_factor': self.scale_factor, - 'box_scale': self.view.box_scale, - 'probe_scale': self.view.probe_scale, - 'overlap': self.view.overlap, - 'show_mean': self.view.show_mean, - }) - return params - - -class CorrelogramViewModel(VispyViewModel): - """Correlograms.""" - _view_class = CorrelogramView - _view_name = 'correlograms' - binsize = 20 # in number of samples - winsize_bins = 41 # in number of bins - _imported_params = ('binsize', 'winsize_bins', 'lines', 'normalization') - _normalization = 'equal' # or 'independent' - _ccgs = None - - def change_bins(self, bin=None, half_width=None): - """Change the parameters of the correlograms. - - Parameters - ---------- - bin : float (ms) - Bin size. - half_width : float (ms) - Half window size. - - """ - sr = float(self.model.sample_rate) - - # Default values. - if bin is None: - bin = 1000. * self.binsize / sr # in ms - if half_width is None: - half_width = 1000. * (float((self.winsize_bins // 2) * - self.binsize / sr)) - - bin = np.clip(bin, .1, 1e3) # in ms - self.binsize = int(sr * bin * .001) # in s - - half_width = np.clip(half_width, .1, 1e3) # in ms - self.winsize_bins = 2 * int(half_width / bin) + 1 - - self.select(self.cluster_ids) - - def on_select(self, clusters, **kwargs): - """Update the view when the selection changes.""" - super(CorrelogramViewModel, self).on_select(clusters) - spikes = self.spike_ids - self.view.cluster_ids = clusters - - # Compute the correlograms. - spike_samples = self.model.spike_samples[spikes] - spike_clusters = self.view.visual.spike_clusters - - ccgs = correlograms(spike_samples, - spike_clusters, - cluster_order=clusters, - binsize=self.binsize, - # NOTE: this must be an odd number, for symmetry - winsize_bins=_oddify(self.winsize_bins), - ) - self._ccgs = _symmetrize_correlograms(ccgs) - # Normalize the CCGs. - self.view.correlograms = self._normalize(self._ccgs) - - # Take the cluster order into account. - self.view.visual.cluster_order = clusters - self.view.update() - - def _normalize(self, ccgs): - if not len(ccgs): - return ccgs - if self._normalization == 'equal': - return ccgs * (1. / max(1., ccgs.max())) - elif self._normalization == 'independent': - return ccgs * (1. / np.maximum(1., ccgs.max(axis=2)[:, :, None])) - - @property - def normalization(self): - """Correlogram normalization: `equal` or `independent`.""" - return self._normalization - - @normalization.setter - def normalization(self, value): - self._normalization = value - if self._ccgs is not None: - self.view.visual.correlograms = self._normalize(self._ccgs) - self.view.update() - - @property - def lines(self): - return self.view.lines - - @lines.setter - def lines(self, value): - self.view.lines = value - - def toggle_normalization(self): - """Change the correlogram normalization.""" - self.normalization = ('equal' if self._normalization == 'independent' - else 'independent') - - def exported_params(self, save_size_pos=True): - """Parameters to save automatically when the view is closed.""" - params = super(CorrelogramViewModel, self).exported_params( - save_size_pos) - params.update({ - 'normalization': self.normalization, - }) - return params - - -class TraceViewModel(VispyViewModel): - """Traces.""" - _view_class = TraceView - _view_name = 'traces' - _imported_params = ('scale_factor', 'channel_scale', 'interval_size') - interval_size = .25 - - def __init__(self, **kwargs): - self._interval = None - super(TraceViewModel, self).__init__(**kwargs) - - def _load_traces(self, interval): - start, end = interval - spikes = self.spike_ids - - # Load the traces. - # debug("Loading traces...") - # Using channel_order ensures that we get rid of the dead channels. - # We also keep the channel order as specified by the PRM file. - # WARNING: HDF5 does not support out-of-order indexing (...!!) - traces = self.model.traces[start:end, :][:, self.model.channel_order] - - # Normalize and set the traces. - traces_f = np.empty_like(traces, dtype=np.float32) - traces_f[...] = traces * self.scale_factor - # Detrend the traces. - m = np.mean(traces_f[::10, :], axis=0) - traces_f -= m - self.view.visual.traces = traces_f - - # Keep the spikes in the interval. - spike_samples = self.model.spike_samples[spikes] - a, b = spike_samples.searchsorted(interval) - spikes = spikes[a:b] - self.view.visual.n_spikes = len(spikes) - self.view.visual.spike_ids = spikes - - if len(spikes) == 0: - return - - # We update the spike clusters according to the subselection of spikes. - # We don't update the list of unique clusters, which only change - # when selecting or clustering, not when changing the interval. - # self.update_spike_clusters(spikes) - self.view.visual.spike_clusters = self.model.spike_clusters[spikes] - - # Set the spike samples. - spike_samples = self.model.spike_samples[spikes] - # This is in unit of samples relative to the start of the interval. - spike_samples = spike_samples - start - self.view.visual.spike_samples = spike_samples - self.view.visual.offset = start - - # Load the masks. - # TODO: ensure model.masks is always 2D, even with 1 spike - masks = np.atleast_2d(self._model.masks[spikes]) - self.view.visual.masks = masks - - @property - def interval(self): - """The interval of the view, in unit of sample.""" - return self._interval - - @interval.setter - def interval(self, value): - if self.model.traces is None: - return - if not isinstance(value, tuple) or len(value) != 2: - raise ValueError("The interval should be a (start, end) tuple.") - # Restrict the interval to the boundaries of the traces. - start, end = value - start, end = int(start), int(end) - n = self.model.traces.shape[0] - if start < 0: - end += (-start) - start = 0 - elif end >= n: - start -= (end - n) - end = n - start = np.clip(start, 0, end) - end = np.clip(end, start, n) - assert 0 <= start < end <= n - self._interval = (start, end) - self._load_traces((start, end)) - self.view.update() - - @property - def channel_scale(self): - """Vertical scale of the traces.""" - return self.view.channel_scale - - @channel_scale.setter - def channel_scale(self, value): - self.view.channel_scale = value - - def move(self, amount): - """Move the current interval by a given amount (in samples).""" - amount = int(amount) - start, end = self.interval - self.interval = start + amount, end + amount - - def move_right(self, fraction=.05): - """Move the current interval to the right.""" - start, end = self.interval - self.move(int(+(end - start) * fraction)) - - def move_left(self, fraction=.05): - """Move the current interval to the left.""" - start, end = self.interval - self.move(int(-(end - start) * fraction)) - - keyboard_shortcuts = { - 'scroll_left': 'ctrl+left', - 'scroll_right': 'ctrl+right', - 'fast_scroll_left': 'shift+left', - 'fast_scroll_right': 'shift+right', - } - - def on_key_press(self, event): - """Called when a key is pressed.""" - super(TraceViewModel, self).on_key_press(event) - key = event.key - if 'Control' in event.modifiers: - if key == 'Left': - self.move_left() - elif key == 'Right': - self.move_right() - if 'Shift' in event.modifiers: - if key == 'Left': - self.move_left(1) - elif key == 'Right': - self.move_right(1) - - def on_open(self): - """Initialize the view when the model is opened.""" - super(TraceViewModel, self).on_open() - self.view.visual.n_samples_per_spike = self.model.n_samples_waveforms - self.view.visual.sample_rate = self.model.sample_rate - if self.scale_factor is None: - self.scale_factor = 1. - if self.interval_size is None: - self.interval_size = .25 - self.select([]) - - def on_select(self, clusters, **kwargs): - """Update the view when the selection changes.""" - # Get the spikes in the selected clusters. - spikes = self.spike_ids - n_clusters = len(clusters) - spike_clusters = self.model.spike_clusters[spikes] - - # Update the clusters of the trace view. - visual = self._view.visual - visual.spike_clusters = spike_clusters - visual.cluster_ids = clusters - visual.cluster_order = clusters - visual.cluster_colors = _selected_clusters_colors(n_clusters) - - # Select the default interval. - half_size = int(self.interval_size * self.model.sample_rate / 2.) - if len(spikes) > 0: - # Center the default interval around the first spike. - sample = self._model.spike_samples[spikes[0]] - else: - sample = half_size - # Load traces by setting the interval. - visual._update_clusters_automatically = False - self.interval = sample - half_size, sample + half_size - - def exported_params(self, save_size_pos=True): - """Parameters to save automatically when the view is closed.""" - params = super(TraceViewModel, self).exported_params(save_size_pos) - params.update({ - 'scale_factor': self.scale_factor, - 'channel_scale': self.channel_scale, - }) - return params - - -#------------------------------------------------------------------------------ -# Feature view models -#------------------------------------------------------------------------------ - -def _best_channels(cluster, model=None, store=None): - """Return the channels with the largest mean features.""" - n_fet = model.n_features_per_channel - score = store.mean_features(cluster) - score = score.reshape((-1, n_fet)).mean(axis=1) - assert len(score) == len(model.channel_order) - channels = np.argsort(score)[::-1] - return channels - - -def _dimensions(x_channels, y_channels): - """Default dimensions matrix.""" - # time, depth time, (x, 0) time, (y, 0) time, (z, 0) - # time, (x', 0) (x', 0), (x, 0) (x', 1), (y, 0) (x', 2), (z, 0) - # time, (y', 0) (y', 0), (x, 1) (y', 1), (y, 1) (y', 2), (z, 1) - # time, (z', 0) (z', 0), (x, 2) (z', 1), (y, 2) (z', 2), (z, 2) - - n = len(x_channels) - assert len(y_channels) == n - y_dim = {} - x_dim = {} - # TODO: depth - x_dim[0, 0] = 'time' - y_dim[0, 0] = 'time' - - # Time in first column and first row. - for i in range(1, n + 1): - x_dim[0, i] = 'time' - y_dim[0, i] = (x_channels[i - 1], 0) - x_dim[i, 0] = 'time' - y_dim[i, 0] = (y_channels[i - 1], 0) - - for i in range(1, n + 1): - for j in range(1, n + 1): - x_dim[i, j] = (x_channels[i - 1], j - 1) - y_dim[i, j] = (y_channels[j - 1], i - 1) - - return x_dim, y_dim - - -class BaseFeatureViewModel(VispyViewModel): - """Features.""" - _view_class = FeatureView - _view_name = 'base_features' - _imported_params = ('scale_factor', 'n_spikes_max_bg', 'marker_size') - n_spikes_max_bg = 10000 - - def __init__(self, *args, **kwargs): - self._extra_features = {} - super(BaseFeatureViewModel, self).__init__(*args, **kwargs) - - def _rescale_features(self, features): - # WARNING: convert features to a 3D array - # (n_spikes, n_channels, n_features) - # because that's what the FeatureView expects currently. - n_fet = self.model.n_features_per_channel - n_channels = len(self.model.channel_order) - shape = (-1, n_channels, n_fet) - features = features[:, :n_fet * n_channels].reshape(shape) - # Scale factor. - return features * self.scale_factor - - @property - def lasso(self): - """The spike lasso visual.""" - return self.view.lasso - - def spikes_in_lasso(self): - """Return the spike ids from the selected clusters within the lasso.""" - if not len(self.cluster_ids) or self.view.lasso.n_points <= 2: - return - clusters = self.cluster_ids - # Load *all* features from the selected clusters. - features = self.store.load('features', clusters=clusters) - # Find the corresponding spike_ids. - spike_ids = _spikes_in_clusters(self.model.spike_clusters, clusters) - assert features.shape[0] == len(spike_ids) - # Extract the extra features for all the spikes in the clusters. - extra_features = self._subset_extra_features(spike_ids) - # Rescale the features (and not the extra features!). - features = features * self.scale_factor - # Extract the two relevant dimensions. - points = self.view.visual.project(self.view.lasso.box, - features=features, - extra_features=extra_features, - ) - # Find the points within the lasso. - in_lasso = self.view.lasso.in_lasso(points) - return spike_ids[in_lasso] - - @property - def marker_size(self): - """Marker size, in pixels.""" - return self.view.marker_size - - @marker_size.setter - def marker_size(self, value): - self.view.marker_size = value - - @property - def n_features(self): - """Number of features.""" - return self.view.background.n_features - - @property - def n_rows(self): - """Number of rows in the view. - - To be overriden. - - """ - return 1 - - def dimensions_for_clusters(self, cluster_ids): - """Return the x and y dimensions most appropriate for the set of - selected clusters. - - To be overriden. - - TODO: make this customizable. - - """ - return {}, {} - - def set_dimension(self, axis, box, dim, smart=True): - """Set a dimension. - - Parameters - ---------- - - axis : str - `'x'` or `'y'` - box : tuple - A `(i, j)` pair. - dim : str or tuple - A feature name, or a tuple `(channel_id, feature_id)`. - smart : bool - Whether to ensure the two axes in the subplot are different, - to avoid a useless `x=y` diagonal situation. - - """ - if smart: - dim = self.view.smart_dimension(axis, box, dim) - self.view.set_dimensions(axis, {box: dim}) - - def add_extra_feature(self, name, array): - """Add an extra feature. - - Parameters - ---------- - - name : str - The feature's name. - array : ndarray - A `(n_spikes,)` array with the feature's value for every spike. - - """ - assert isinstance(name, string_types) - array = _as_array(array) - n_spikes = self.model.n_spikes - if array.shape != (n_spikes,): - raise ValueError("The extra feature needs to be a 1D vector with " - "`n_spikes={}` values.".format(n_spikes)) - self._extra_features[name] = (array, array.min(), array.max()) - - def _subset_extra_features(self, spikes): - return {name: (array[spikes], m, M) - for name, (array, m, M) in self._extra_features.items()} - - def _add_extra_features_in_view(self, spikes): - """Add the extra features in the view, by selecting only the - displayed spikes.""" - subset_extra = self._subset_extra_features(spikes) - for name, (sub_array, m, M) in subset_extra.items(): - # Make the extraction for the background spikes. - array, _, _ = self._extra_features[name] - sub_array_bg = array[self.view.background.spike_ids] - self.view.add_extra_feature(name, sub_array, m, M, - array_bg=sub_array_bg) - - @property - def x_dim(self): - """List of x dimensions in every row/column.""" - return self.view.x_dim - - @property - def y_dim(self): - """List of y dimensions in every row/column.""" - return self.view.y_dim - - def on_open(self): - """Initialize the view when the model is opened.""" - # Get background features. - # TODO OPTIM: precompute this once for all and store in the cluster - # store. But might be unnecessary. - if self.n_spikes_max_bg is not None: - k = max(1, self.model.n_spikes // self.n_spikes_max_bg) - else: - k = 1 - if self.model.features is not None: - # Background features. - features_bg = self.store.load('features', - spikes=slice(None, None, k)) - self.view.background.features = self._rescale_features(features_bg) - self.view.background.spike_ids = self.model.spike_ids[::k] - # Register the time dimension. - self.add_extra_feature('time', self.model.spike_samples) - # Add the subset extra features to the visuals. - self._add_extra_features_in_view(slice(None, None, k)) - # Number of rows: number of features + 1 for - self.view.init_grid(self.n_rows) - - def on_select(self, clusters, auto_update=True): - """Update the view when the selection changes.""" - super(BaseFeatureViewModel, self).on_select(clusters) - spikes = self.spike_ids - - features = self.store.load('features', - clusters=clusters, - spikes=spikes) - masks = self.store.load('masks', - clusters=clusters, - spikes=spikes) - - nc = len(self.model.channel_order) - nf = self.model.n_features_per_channel - features = features.reshape((len(spikes), nc, nf)) - self.view.visual.features = self._rescale_features(features) - self.view.visual.masks = masks - - # Spikes. - self.view.visual.spike_ids = spikes - - # Extra features (including time). - self._add_extra_features_in_view(spikes) - - # Cluster display order. - self.view.visual.cluster_order = clusters - - # Set default dimensions. - if auto_update: - x_dim, y_dim = self.dimensions_for_clusters(clusters) - self.view.set_dimensions('x', x_dim) - self.view.set_dimensions('y', y_dim) - - keyboard_shortcuts = { - 'increase_scale': 'ctrl+', - 'decrease_scale': 'ctrl-', - } - - def on_key_press(self, event): - """Handle key press events.""" - super(BaseFeatureViewModel, self).on_key_press(event) - key = event.key - ctrl = 'Control' in event.modifiers - if ctrl and key in ('+', '-'): - k = 1.1 if key == '+' else .9 - self.scale_factor *= k - self.view.visual.features *= k - self.view.background.features *= k - self.view.update() - - def exported_params(self, save_size_pos=True): - """Parameters to save automatically when the view is closed.""" - params = super(BaseFeatureViewModel, - self).exported_params(save_size_pos) - params.update({ - 'scale_factor': self.scale_factor, - 'marker_size': self.marker_size, - }) - return params - - -class FeatureGridViewModel(BaseFeatureViewModel): - """Features grid""" - _view_name = 'features_grid' - - keyboard_shortcuts = { - 'enlarge_subplot': 'ctrl+click', - 'increase_scale': 'ctrl+', - 'decrease_scale': 'ctrl-', - } - - @property - def n_rows(self): - """Number of rows in the grid view.""" - return self.n_features + 1 - - def dimensions_for_clusters(self, cluster_ids): - """Return the x and y dimensions most appropriate for the set of - selected clusters. - - TODO: make this customizable. - - """ - n = len(cluster_ids) - if not n: - return {}, {} - x_channels = _best_channels(cluster_ids[min(1, n - 1)], - model=self.model, - store=self.store, - ) - y_channels = _best_channels(cluster_ids[0], - model=self.model, - store=self.store, - ) - y_channels = y_channels[:self.n_rows - 1] - # For the x axis, remove the channels that already are in - # the y axis. - x_channels = [c for c in x_channels if c not in y_channels] - # Now, select the right number of channels in the x axis. - x_channels = x_channels[:self.n_rows - 1] - return _dimensions(x_channels, y_channels) - - -class FeatureViewModel(BaseFeatureViewModel): - """Feature view with a single subplot.""" - _view_name = 'features' - _x_dim = 'time' - _y_dim = (0, 0) - - keyboard_shortcuts = { - 'increase_scale': 'ctrl+', - 'decrease_scale': 'ctrl-', - } - - @property - def n_rows(self): - """Number of rows.""" - return 1 - - @property - def x_dim(self): - """x dimension.""" - return self._x_dim - - @property - def y_dim(self): - """y dimension.""" - return self._y_dim - - def set_dimension(self, axis, dim, smart=True): - """Set a (smart) dimension. - - "smart" means that the dimension may be changed if it is the same - than the other dimension, to avoid x=y. - - """ - super(FeatureViewModel, self).set_dimension(axis, (0, 0), dim, - smart=smart) - - def dimensions_for_clusters(self, cluster_ids): - """Current dimensions.""" - return self._x_dim, self._y_dim diff --git a/phy/detect/__init__.py b/phy/detect/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/phy/detect/default_settings.py b/phy/detect/default_settings.py deleted file mode 100644 index 6b6794eda..000000000 --- a/phy/detect/default_settings.py +++ /dev/null @@ -1,44 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Default settings for spike detection.""" - - -# ----------------------------------------------------------------------------- -# Spike detection -# ----------------------------------------------------------------------------- - -spikedetekt = { - 'filter_low': 500., - 'filter_high_factor': 0.95 * .5, # will be multiplied by the sample rate - 'filter_butter_order': 3, - - # Data chunks. - 'chunk_size_seconds': 1., - 'chunk_overlap_seconds': .015, - - # Threshold. - 'n_excerpts': 50, - 'excerpt_size_seconds': 1., - 'use_single_threshold': True, - 'threshold_strong_std_factor': 4.5, - 'threshold_weak_std_factor': 2., - 'detect_spikes': 'negative', - - # Connected components. - 'connected_component_join_size': 1, - - # Spike extractions. - 'extract_s_before': 10, - 'extract_s_after': 10, - 'weight_power': 2, - - # Features. - 'n_features_per_channel': 3, - 'pca_n_waveforms_max': 10000, - - # Waveform filtering in GUI. - 'waveform_filter': True, - 'waveform_dc_offset': None, - 'waveform_scale_factor': None, - -} diff --git a/phy/detect/spikedetekt.py b/phy/detect/spikedetekt.py deleted file mode 100644 index 3615a4823..000000000 --- a/phy/detect/spikedetekt.py +++ /dev/null @@ -1,610 +0,0 @@ -# -*- coding: utf-8 -*- - -"""SpikeDetekt algorithm.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -from collections import defaultdict - -import numpy as np - -from ..utils.array import (get_excerpts, - chunk_bounds, - data_chunk, - _as_array, - _concatenate, - ) -from ..utils._types import Bunch -from ..utils.event import EventEmitter, ProgressReporter -from ..utils.logging import debug, info -from ..electrode.mea import (_channels_per_group, - _probe_adjacency_list, - ) -from ..traces import (Filter, Thresholder, compute_threshold, - FloodFillDetector, WaveformExtractor, PCA, - ) -from .store import SpikeDetektStore - - -#------------------------------------------------------------------------------ -# Spike detection class -#------------------------------------------------------------------------------ - -def _find_dead_channels(channels_per_group, n_channels): - all_channels = sorted([item for sublist in channels_per_group.values() - for item in sublist]) - dead = np.setdiff1d(np.arange(n_channels), all_channels) - debug("Using dead channels: {}.".format(dead)) - return dead - - -def _keep_spikes(samples, bounds): - """Only keep spikes within the bounds `bounds=(start, end)`.""" - start, end = bounds - return (start <= samples) & (samples <= end) - - -def _split_spikes(groups, idx=None, **arrs): - """Split spike data according to the channel group.""" - # split: {group: {'spike_samples': ..., 'waveforms':, 'masks':}} - dtypes = {'spike_samples': np.float64, - 'waveforms': np.float32, - 'masks': np.float32, - } - groups = _as_array(groups) - if idx is not None: - n_spikes_chunk = np.sum(idx) - # First, remove the overlapping bands. - groups = groups[idx] - arrs_bis = arrs.copy() - for key, arr in arrs.items(): - arrs_bis[key] = arr[idx] - assert len(arrs_bis[key]) == n_spikes_chunk - # Then, split along the group. - groups_u = np.unique(groups) - out = {} - for group in groups_u: - i = (groups == group) - out[group] = {} - for key, arr in arrs_bis.items(): - out[group][key] = _concat(arr[i], dtypes.get(key, None)) - return out - - -def _array_list(arrs): - out = np.empty((len(arrs),), dtype=np.object) - out[:] = arrs - return out - - -def _concat(arr, dtype=None): - out = np.array([_[...] for _ in arr], dtype=dtype) - return out - - -def _cut_traces(traces, interval_samples): - n_samples, n_channels = traces.shape - # Take a subset if necessary. - if interval_samples is not None: - start, end = interval_samples - assert start <= end - traces = traces[start:end, ...] - n_samples = traces.shape[0] - else: - start, end = 0, n_samples - assert 0 <= start < end - if start > 0: - # TODO: add offset to the spike samples... - raise NotImplementedError("Need to add `start` to the " - "spike samples") - return traces, start - - -#------------------------------------------------------------------------------ -# Spike detection class -#------------------------------------------------------------------------------ - -_spikes_message = "{n_spikes:d} spikes in chunk {value:d}/{value_max:d}." - - -class SpikeDetektProgress(ProgressReporter): - _progress_messages = { - 'detect': ("Detecting spikes: {progress:.2f}%. " + _spikes_message, - "Spike detection complete: {n_spikes_total:d} " + - "spikes detected."), - - 'excerpt': ("Extracting waveforms subset for PCs: " + - "{progress:.2f}%. " + _spikes_message, - "Waveform subset extraction complete: " + - "{n_spikes_total} spikes."), - - 'pca': ("Performing PCA: {progress:.2f}%.", - "Principal waveform components computed."), - - 'extract': ("Extracting spikes: {progress:.2f}%. " + _spikes_message, - "Spike extraction complete: {n_spikes_total:d} " + - "spikes extracted."), - - } - - def __init__(self, n_chunks=None): - super(SpikeDetektProgress, self).__init__() - self.n_chunks = n_chunks - - def start_step(self, name, value_max): - self._iter = 0 - self.reset(value_max) - self.set_progress_message(self._progress_messages[name][0], - line_break=True) - self.set_complete_message(self._progress_messages[name][1]) - - -class SpikeDetekt(EventEmitter): - """Spike detection class. - - Parameters - ---------- - - tempdir : str - Path to the temporary directory used by the algorithm. It should be - on a SSD for best performance. - probe : dict - The probe dictionary. - **kwargs : dict - Spike detection parameters. - - """ - def __init__(self, tempdir=None, probe=None, **kwargs): - super(SpikeDetekt, self).__init__() - self._tempdir = tempdir - self._dead_channels = None - # Load a probe. - if probe is not None: - kwargs['probe_channels'] = _channels_per_group(probe) - kwargs['probe_adjacency_list'] = _probe_adjacency_list(probe) - self._kwargs = kwargs - self._n_channels_per_group = { - group: len(channels) - for group, channels in self._kwargs['probe_channels'].items() - } - self._groups = sorted(self._n_channels_per_group) - self._n_features = self._kwargs['n_features_per_channel'] - before = self._kwargs['extract_s_before'] - after = self._kwargs['extract_s_after'] - self._n_samples_waveforms = before + after - - # Processing objects creation - # ------------------------------------------------------------------------- - - def _create_filter(self): - rate = self._kwargs['sample_rate'] - low = self._kwargs['filter_low'] - high = self._kwargs['filter_high_factor'] * rate - order = self._kwargs['filter_butter_order'] - return Filter(rate=rate, - low=low, - high=high, - order=order, - ) - - def _create_thresholder(self, thresholds=None): - mode = self._kwargs['detect_spikes'] - return Thresholder(mode=mode, thresholds=thresholds) - - def _create_detector(self): - graph = self._kwargs['probe_adjacency_list'] - probe_channels = self._kwargs['probe_channels'] - join_size = self._kwargs['connected_component_join_size'] - return FloodFillDetector(probe_adjacency_list=graph, - join_size=join_size, - channels_per_group=probe_channels, - ) - - def _create_extractor(self, thresholds): - before = self._kwargs['extract_s_before'] - after = self._kwargs['extract_s_after'] - weight_power = self._kwargs['weight_power'] - probe_channels = self._kwargs['probe_channels'] - return WaveformExtractor(extract_before=before, - extract_after=after, - weight_power=weight_power, - channels_per_group=probe_channels, - thresholds=thresholds, - ) - - def _create_pca(self): - n_pcs = self._kwargs['n_features_per_channel'] - return PCA(n_pcs=n_pcs) - - # Misc functions - # ------------------------------------------------------------------------- - - def update_params(self, **kwargs): - self._kwargs.update(kwargs) - - # Processing functions - # ------------------------------------------------------------------------- - - def apply_filter(self, data): - """Filter the traces.""" - filter = self._create_filter() - return filter(data).astype(np.float32) - - def find_thresholds(self, traces): - """Find weak and strong thresholds in filtered traces.""" - rate = self._kwargs['sample_rate'] - n_excerpts = self._kwargs['n_excerpts'] - excerpt_size = int(self._kwargs['excerpt_size_seconds'] * rate) - single = bool(self._kwargs['use_single_threshold']) - strong_f = self._kwargs['threshold_strong_std_factor'] - weak_f = self._kwargs['threshold_weak_std_factor'] - - info("Finding the thresholds...") - excerpt = get_excerpts(traces, - n_excerpts=n_excerpts, - excerpt_size=excerpt_size) - excerpt_f = self.apply_filter(excerpt) - thresholds = compute_threshold(excerpt_f, - single_threshold=single, - std_factor=(weak_f, strong_f)) - debug("Thresholds: {}.".format(thresholds)) - return {'weak': thresholds[0], - 'strong': thresholds[1]} - - def detect(self, traces_f, thresholds=None, dead_channels=None): - """Detect connected waveform components in filtered traces. - - Parameters - ---------- - - traces_f : array - An `(n_samples, n_channels)` array with the filtered data. - thresholds : dict - The weak and strong thresholds. - dead_channels : array-like - Array of dead channels. - - Returns - ------- - - components : list - A list of `(n, 2)` arrays with `sample, channel` pairs. - - """ - # Threshold the data following the weak and strong thresholds. - thresholder = self._create_thresholder(thresholds) - # Transform the filtered data according to the detection mode. - traces_t = thresholder.transform(traces_f) - # Compute the threshold crossings. - weak = thresholder.detect(traces_t, 'weak') - strong = thresholder.detect(traces_t, 'strong') - # Force crossings to be False on dead channels. - if dead_channels is not None and len(dead_channels): - assert dead_channels.max() < traces_f.shape[1] - weak[:, dead_channels] = 0 - strong[:, dead_channels] = 0 - else: - debug("No dead channels specified.") - # Run the detection. - detector = self._create_detector() - return detector(weak_crossings=weak, - strong_crossings=strong) - - def extract_spikes(self, components, traces_f, - thresholds=None, keep_bounds=None): - """Extract spikes from connected components. - - Returns a split object. - - Parameters - ---------- - components : list - List of connected components. - traces_f : array - Filtered data. - thresholds : dict - The weak and strong thresholds. - keep_bounds : tuple - (keep_start, keep_end). - - """ - n_spikes = len(components) - if n_spikes == 0: - return {} - - # Transform the filtered data according to the detection mode. - thresholder = self._create_thresholder() - traces_t = thresholder.transform(traces_f) - # Extract all waveforms. - extractor = self._create_extractor(thresholds) - groups, samples, waveforms, masks = zip(*[extractor(component, - data=traces_f, - data_t=traces_t, - ) - for component in components]) - - # Create the return arrays. - groups = np.array(groups, dtype=np.int32) - assert groups.shape == (n_spikes,) - assert groups.dtype == np.int32 - - samples = np.array(samples, dtype=np.float64) - assert samples.shape == (n_spikes,) - assert samples.dtype == np.float64 - - # These are lists of arrays of various shapes (because of various - # groups). - waveforms = _array_list(waveforms) - assert waveforms.shape == (n_spikes,) - assert waveforms.dtype == np.object - - masks = _array_list(masks) - assert masks.dtype == np.object - assert masks.shape == (n_spikes,) - - # Reorder the spikes. - idx = np.argsort(samples) - groups = groups[idx] - samples = samples[idx] - waveforms = waveforms[idx] - masks = masks[idx] - - # Remove spikes in the overlapping bands. - # WARNING: add keep_start to spike_samples, because spike_samples - # is relative to the start of the chunk. - (keep_start, keep_end) = keep_bounds - idx = _keep_spikes(samples + keep_start, (keep_start, keep_end)) - - # Split the data according to the channel groups. - split = _split_spikes(groups, idx=idx, spike_samples=samples, - waveforms=waveforms, masks=masks) - # split: {group: {'spike_samples': ..., 'waveforms':, 'masks':}} - return split - - def waveform_pcs(self, waveforms, masks): - """Compute waveform principal components. - - Returns - ------- - - pcs : array - An `(n_features, n_samples, n_channels)` array. - - """ - pca = self._create_pca() - if waveforms is None or not len(waveforms): - return - assert (waveforms.shape[0], waveforms.shape[2]) == masks.shape - return pca.fit(waveforms, masks) - - def features(self, waveforms, pcs): - """Extract features from waveforms. - - Returns - ------- - - features : array - An `(n_spikes, n_channels, n_features)` array. - - """ - pca = self._create_pca() - out = pca.transform(waveforms, pcs=pcs) - assert out.dtype == np.float32 - return out - - # Chunking - # ------------------------------------------------------------------------- - - def iter_chunks(self, n_samples): - """Iterate over chunks.""" - rate = self._kwargs['sample_rate'] - chunk_size = int(self._kwargs['chunk_size_seconds'] * rate) - overlap = int(self._kwargs['chunk_overlap_seconds'] * rate) - for chunk_idx, bounds in enumerate(chunk_bounds(n_samples, chunk_size, - overlap=overlap)): - yield Bunch(bounds=bounds, - s_start=bounds[0], - s_end=bounds[1], - keep_start=bounds[2], - keep_end=bounds[3], - keep_bounds=(bounds[2:4]), - key=bounds[2], - chunk_idx=chunk_idx, - ) - - def n_chunks(self, n_samples): - """Number of chunks.""" - return len(list(self.iter_chunks(n_samples))) - - def chunk_keys(self, n_samples): - return [chunk.key for chunk in self.iter_chunks(n_samples)] - - # Output data - # ------------------------------------------------------------------------- - - def output_data(self): - """Bunch of values to be returned by the algorithm.""" - sc = self._store.spike_counts - chunk_keys = self._store.chunk_keys - output = Bunch(groups=self._groups, - n_chunks=len(chunk_keys), - chunk_keys=chunk_keys, - spike_samples=self._store.spike_samples(), - masks=self._store.masks(), - features=self._store.features(), - spike_counts=sc, - n_spikes_total=sc(), - n_spikes_per_group={group: sc(group=group) - for group in self._groups}, - n_spikes_per_chunk={chunk_key: sc(chunk_key=chunk_key) - for chunk_key in chunk_keys}, - ) - return output - - # Main loop - # ------------------------------------------------------------------------- - - def _iter_spikes(self, n_samples, step_spikes=1, thresholds=None): - """Iterate over extracted spikes (possibly subset). - - Yield a split dictionary `{group: {'waveforms': ..., ...}}`. - - """ - for chunk in self.iter_chunks(n_samples): - - # Extract a few components. - components = self._store.load(name='components', - chunk_key=chunk.key) - if components is None or not len(components): - yield chunk, {} - continue - - k = np.clip(step_spikes, 1, len(components)) - components = components[::k] - - # Get the filtered chunk. - chunk_f = self._store.load(name='filtered', - chunk_key=chunk.key) - - # Extract the spikes from the chunk. - split = self.extract_spikes(components, chunk_f, - keep_bounds=chunk.keep_bounds, - thresholds=thresholds) - - yield chunk, split - - def step_detect(self, traces=None, thresholds=None): - n_samples, n_channels = traces.shape - n_chunks = self.n_chunks(n_samples) - - # Pass 1: find the connected components and count the spikes. - self._pr.start_step('detect', n_chunks) - - # Dictionary {chunk_key: components}. - # Every chunk has a unique key: the `keep_start` integer. - n_spikes_total = 0 - for chunk in self.iter_chunks(n_samples): - chunk_data = data_chunk(traces, chunk.bounds, with_overlap=True) - - # Apply the filter. - data_f = self.apply_filter(chunk_data) - assert data_f.dtype == np.float32 - assert data_f.shape == chunk_data.shape - - # Save the filtered chunk. - self._store.store(name='filtered', chunk_key=chunk.key, - data=data_f) - - # Detect spikes in the filtered chunk. - components = self.detect(data_f, thresholds=thresholds, - dead_channels=self._dead_channels) - self._store.store(name='components', chunk_key=chunk.key, - data=components) - - # Report progress. - n_spikes_chunk = len(components) - n_spikes_total += n_spikes_chunk - self._pr.increment(n_spikes=n_spikes_chunk, - n_spikes_total=n_spikes_total) - - return n_spikes_total - - def step_excerpt(self, n_samples=None, - n_spikes_total=None, thresholds=None): - self._pr.start_step('excerpt', self.n_chunks(n_samples)) - - k = int(n_spikes_total / float(self._kwargs['pca_n_waveforms_max'])) - w_subset = defaultdict(list) - m_subset = defaultdict(list) - n_spikes_total = 0 - for chunk, split in self._iter_spikes(n_samples, step_spikes=k, - thresholds=thresholds): - n_spikes_chunk = 0 - for group, out in split.items(): - w_subset[group].append(out['waveforms']) - m_subset[group].append(out['masks']) - assert len(out['masks']) == len(out['waveforms']) - n_spikes_chunk += len(out['masks']) - - n_spikes_total += n_spikes_chunk - self._pr.increment(n_spikes=n_spikes_chunk, - n_spikes_total=n_spikes_total) - for group in self._groups: - w_subset[group] = _concatenate(w_subset[group]) - m_subset[group] = _concatenate(m_subset[group]) - - return w_subset, m_subset - - def step_pcs(self, w_subset=None, m_subset=None): - self._pr.start_step('pca', len(self._groups)) - pcs = {} - for group in self._groups: - # Perform PCA and return the components. - pcs[group] = self.waveform_pcs(w_subset[group], - m_subset[group]) - self._pr.increment() - return pcs - - def step_extract(self, n_samples=None, - pcs=None, thresholds=None): - self._pr.start_step('extract', self.n_chunks(n_samples)) - # chunk_counts = defaultdict(dict) # {group: {key: n_spikes}}. - n_spikes_total = 0 - for chunk, split in self._iter_spikes(n_samples, - thresholds=thresholds): - # Delete filtered and components cache files. - self._store.delete(name='filtered', chunk_key=chunk.key) - self._store.delete(name='components', chunk_key=chunk.key) - # split: {group: {'spike_samples': ..., 'waveforms':, 'masks':}} - for group, out in split.items(): - out['features'] = self.features(out['waveforms'], pcs[group]) - self._store.append(group=group, - chunk_key=chunk.key, - spike_samples=out['spike_samples'], - features=out['features'], - masks=out['masks'], - spike_offset=chunk.s_start, - ) - n_spikes_total = self._store.spike_counts() - n_spikes_chunk = self._store.spike_counts(chunk_key=chunk.key) - self._pr.increment(n_spikes_total=n_spikes_total, - n_spikes=n_spikes_chunk) - - def run_serial(self, traces, interval_samples=None): - """Run SpikeDetekt using one CPU.""" - traces, offset = _cut_traces(traces, interval_samples) - n_samples, n_channels = traces.shape - - # Initialize the main loop. - chunk_keys = self.chunk_keys(n_samples) - n_chunks = len(chunk_keys) - self._pr = SpikeDetektProgress(n_chunks=n_chunks) - self._store = SpikeDetektStore(self._tempdir, - groups=self._groups, - chunk_keys=chunk_keys) - - # Find the weak and strong thresholds. - thresholds = self.find_thresholds(traces) - - # Find dead channels. - probe_channels = self._kwargs['probe_channels'] - self._dead_channels = _find_dead_channels(probe_channels, n_channels) - - # Spike detection. - n_spikes_total = self.step_detect(traces=traces, - thresholds=thresholds) - - # Excerpt waveforms. - w_subset, m_subset = self.step_excerpt(n_samples=n_samples, - n_spikes_total=n_spikes_total, - thresholds=thresholds) - - # Compute the PCs. - pcs = self.step_pcs(w_subset=w_subset, m_subset=m_subset) - - # Compute all features. - self.step_extract(n_samples=n_samples, pcs=pcs, thresholds=thresholds) - - return self.output_data() diff --git a/phy/detect/store.py b/phy/detect/store.py deleted file mode 100644 index 179f511af..000000000 --- a/phy/detect/store.py +++ /dev/null @@ -1,208 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Spike detection store.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import os -import os.path as op -from collections import defaultdict - -import numpy as np -from six import string_types - -from ..utils.array import (_save_arrays, - _load_arrays, - _concatenate, - ) -from ..utils.logging import debug -from ..utils.settings import _ensure_dir_exists - - -#------------------------------------------------------------------------------ -# Spike counts -#------------------------------------------------------------------------------ - -class SpikeCounts(object): - """Count spikes in chunks and channel groups.""" - def __init__(self, counts=None, groups=None, chunk_keys=None): - self._groups = groups - self._chunk_keys = chunk_keys - self._counts = counts or defaultdict(lambda: defaultdict(int)) - - def append(self, group=None, chunk_key=None, count=None): - self._counts[group][chunk_key] += count - - @property - def counts(self): - return self._counts - - def per_group(self, group): - return sum(self._counts.get(group, {}).values()) - - def per_chunk(self, chunk_key): - return sum(self._counts[group].get(chunk_key, 0) - for group in self._groups) - - def __call__(self, group=None, chunk_key=None): - if group is not None and chunk_key is not None: - return self._counts.get(group, {}).get(chunk_key, 0) - elif group is not None: - return self.per_group(group) - elif chunk_key is not None: - return self.per_chunk(chunk_key) - elif group is None and chunk_key is None: - return sum(self.per_group(group) for group in self._groups) - - -#------------------------------------------------------------------------------ -# Spike detection store -#------------------------------------------------------------------------------ - -class ArrayStore(object): - def __init__(self, root_dir): - self._root_dir = op.realpath(root_dir) - _ensure_dir_exists(self._root_dir) - - def _rel_path(self, **kwargs): - """Relative to the root.""" - raise NotImplementedError() - - def _path(self, **kwargs): - """Absolute path of a data file.""" - path = op.realpath(op.join(self._root_dir, self._rel_path(**kwargs))) - _ensure_dir_exists(op.dirname(path)) - assert path.endswith('.npy') - return path - - def _offsets_path(self, path): - assert path.endswith('.npy') - return op.splitext(path)[0] + '.offsets.npy' - - def _contains_multiple_arrays(self, path): - return op.exists(path) and op.exists(self._offsets_path(path)) - - def store(self, data=None, **kwargs): - """Store an array or list of arrays.""" - path = self._path(**kwargs) - if isinstance(data, list): - if not data: - return - _save_arrays(path, data) - elif isinstance(data, np.ndarray): - dtype = data.dtype - if not data.size: - return - assert dtype != np.object - np.save(path, data) - # debug("Store {}.".format(path)) - - def load(self, **kwargs): - path = self._path(**kwargs) - if not op.exists(path): - debug("File `{}` doesn't exist.".format(path)) - return - # Multiple arrays: - # debug("Load {}.".format(path)) - if self._contains_multiple_arrays(path): - return _load_arrays(path) - else: - return np.load(path) - - def delete(self, **kwargs): - path = self._path(**kwargs) - if op.exists(path): - os.remove(path) - # debug("Deleted `{}`.".format(path)) - offsets_path = self._offsets_path(path) - if op.exists(offsets_path): - os.remove(offsets_path) - # debug("Deleted `{}`.".format(offsets_path)) - - -class SpikeDetektStore(ArrayStore): - """Store the following items: - - * filtered - * components - * spike_samples - * features - * masks - - """ - def __init__(self, root_dir, groups=None, chunk_keys=None): - super(SpikeDetektStore, self).__init__(root_dir) - self._groups = groups - self._chunk_keys = chunk_keys - self._spike_counts = SpikeCounts(groups=groups, chunk_keys=chunk_keys) - - def _rel_path(self, name=None, chunk_key=None, group=None): - assert chunk_key >= 0 - assert group is None or group >= 0 - assert isinstance(name, string_types) - group = group if group is not None else 'all' - return 'group_{group}/{name}/chunk_{chunk:d}.npy'.format( - chunk=chunk_key, name=name, group=group) - - @property - def groups(self): - return self._groups - - @property - def chunk_keys(self): - return self._chunk_keys - - def _iter(self, group=None, name=None): - for chunk_key in self.chunk_keys: - yield self.load(group=group, chunk_key=chunk_key, name=name) - - def spike_samples(self, group=None): - if group is None: - return {group: self.spike_samples(group) for group in self._groups} - return self.concatenate(self._iter(group=group, name='spike_samples')) - - def features(self, group=None): - """Yield chunk features.""" - if group is None: - return {group: self.features(group) for group in self._groups} - return self._iter(group=group, name='features') - - def masks(self, group=None): - """Yield chunk masks.""" - if group is None: - return {group: self.masks(group) for group in self._groups} - return self._iter(group=group, name='masks') - - @property - def spike_counts(self): - return self._spike_counts - - def append(self, group=None, chunk_key=None, - spike_samples=None, features=None, masks=None, - spike_offset=0): - if spike_samples is None or len(spike_samples) == 0: - return - n = len(spike_samples) - assert features.shape[0] == n - assert masks.shape[0] == n - spike_samples = spike_samples + spike_offset - - self.store(group=group, chunk_key=chunk_key, - name='features', data=features) - self.store(group=group, chunk_key=chunk_key, - name='masks', data=masks) - self.store(group=group, chunk_key=chunk_key, - name='spike_samples', data=spike_samples) - self._spike_counts.append(group=group, chunk_key=chunk_key, count=n) - - def concatenate(self, arrays): - return _concatenate(arrays) - - def delete_all(self, name): - """Delete all files for a given data name.""" - for group in self._groups: - for chunk_key in self._chunk_keys: - super(SpikeDetektStore, self).delete(name=name, group=group, - chunk_key=chunk_key) diff --git a/phy/detect/tests/__init__.py b/phy/detect/tests/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/phy/detect/tests/test_spikedetekt.py b/phy/detect/tests/test_spikedetekt.py deleted file mode 100644 index 1f49e267a..000000000 --- a/phy/detect/tests/test_spikedetekt.py +++ /dev/null @@ -1,157 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Tests of clustering algorithms.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import numpy as np -from numpy.testing import assert_equal as ae -from pytest import mark - -from ...utils.logging import set_level -from ...utils.testing import show_test -from ..spikedetekt import (SpikeDetekt, _split_spikes, _concat, _concatenate) - - -#------------------------------------------------------------------------------ -# Tests spike detection -#------------------------------------------------------------------------------ - -def setup(): - set_level('info') - - -def test_split_spikes(): - groups = np.zeros(10, dtype=np.int) - groups[1::2] = 1 - - idx = np.ones(10, dtype=np.bool) - idx[0] = False - idx[-1] = False - - a = np.random.rand(10, 2) - b = np.random.rand(10, 3, 2) - - out = _split_spikes(groups, idx, a=a, b=b) - - assert sorted(out) == [0, 1] - assert sorted(out[0]) == ['a', 'b'] - assert sorted(out[1]) == ['a', 'b'] - - ae(out[0]['a'], a[1:-1][1::2]) - ae(out[0]['b'], b[1:-1][1::2]) - - ae(out[1]['a'], a[1:-1][::2]) - ae(out[1]['b'], b[1:-1][::2]) - - -def test_spike_detect_methods(tempdir, raw_dataset): - params = raw_dataset.params - probe = raw_dataset.probe - sample_rate = raw_dataset.sample_rate - sd = SpikeDetekt(tempdir=tempdir, - probe=raw_dataset.probe, - sample_rate=sample_rate, - **params) - traces = raw_dataset.traces - n_samples = raw_dataset.n_samples - n_channels = raw_dataset.n_channels - - # Filter the data. - traces_f = sd.apply_filter(traces) - assert traces_f.shape == traces.shape - assert not np.any(np.isnan(traces_f)) - - # Thresholds. - thresholds = sd.find_thresholds(traces) - assert np.all(0 <= thresholds['weak']) - assert np.all(thresholds['weak'] <= thresholds['strong']) - - # Spike detection. - traces_f[1000:1010, :3] *= 5 - traces_f[2000:2010, [0, 2]] *= 5 - traces_f[3000:3020, :] *= 5 - components = sd.detect(traces_f, thresholds) - assert isinstance(components, list) - # n_spikes = len(components) - n_samples_waveforms = (params['extract_s_before'] + - params['extract_s_after']) - - # Spike extraction. - split = sd.extract_spikes(components, traces_f, thresholds, - keep_bounds=(0, n_samples)) - - if not split: - return - samples = _concat(split[0]['spike_samples'], np.float64) - waveforms = _concat(split[0]['waveforms'], np.float32) - masks = _concat(split[0]['masks'], np.float32) - - n_spikes = len(samples) - n_channels = len(probe['channel_groups'][0]['channels']) - - assert samples.dtype == np.float64 - assert samples.shape == (n_spikes,) - assert waveforms.shape == (n_spikes, n_samples_waveforms, n_channels) - assert masks.shape == (n_spikes, n_channels) - assert 0. <= masks.min() < masks.max() <= 1. - assert not np.any(np.isnan(samples)) - assert not np.any(np.isnan(waveforms)) - assert not np.any(np.isnan(masks)) - - # PCA. - pcs = sd.waveform_pcs(waveforms, masks) - n_pcs = params['n_features_per_channel'] - assert pcs.shape == (n_pcs, n_samples_waveforms, n_channels) - assert not np.any(np.isnan(pcs)) - - # Features. - features = sd.features(waveforms, pcs) - assert features.shape == (n_spikes, n_channels, n_pcs) - assert not np.any(np.isnan(features)) - - -@mark.long -def test_spike_detect_real_data(tempdir, raw_dataset): - - params = raw_dataset.params - probe = raw_dataset.probe - sample_rate = raw_dataset.sample_rate - sd = SpikeDetekt(tempdir=tempdir, - probe=probe, - sample_rate=sample_rate, - **params) - traces = raw_dataset.traces - n_samples = raw_dataset.n_samples - npc = params['n_features_per_channel'] - n_samples_w = params['extract_s_before'] + params['extract_s_after'] - - # Run the detection. - out = sd.run_serial(traces, interval_samples=(0, n_samples)) - - channels = probe['channel_groups'][0]['channels'] - n_channels = len(channels) - - spike_samples = _concatenate(out.spike_samples[0]) - masks = _concatenate(out.masks[0]) - features = _concatenate(out.features[0]) - n_spikes = out.n_spikes_per_group[0] - - if n_spikes: - assert spike_samples.shape == (n_spikes,) - assert masks.shape == (n_spikes, n_channels) - assert features.shape == (n_spikes, n_channels, npc) - - # There should not be any spike with only masked channels. - assert np.all(masks.max(axis=1) > 0) - - # Plot... - from phy.plot.traces import plot_traces - c = plot_traces(traces[:30000, channels], - spike_samples=spike_samples, - masks=masks, - n_samples_per_spike=n_samples_w, - show=False) - show_test(c) diff --git a/phy/detect/tests/test_store.py b/phy/detect/tests/test_store.py deleted file mode 100644 index 5ada61b66..000000000 --- a/phy/detect/tests/test_store.py +++ /dev/null @@ -1,133 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Tests of spike detection store.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -from pytest import yield_fixture -import numpy as np -from numpy.testing import assert_equal as ae - -from ...utils.logging import set_level -from ..store import SpikeCounts, ArrayStore, SpikeDetektStore - - -#------------------------------------------------------------------------------ -# Tests spike detection store -#------------------------------------------------------------------------------ - -def setup(): - set_level('debug') - - -@yield_fixture(params=['from_dict', 'append']) -def spike_counts(request): - groups = [0, 2] - chunk_keys = [10, 20, 30] - if request.param == 'from_dict': - c = {0: {10: 100, 20: 200}, - 2: {10: 1, 30: 300}, - } - sc = SpikeCounts(c, groups=groups, chunk_keys=chunk_keys) - elif request.param == 'append': - sc = SpikeCounts(groups=groups, chunk_keys=chunk_keys) - sc.append(group=0, chunk_key=10, count=100) - sc.append(group=0, chunk_key=20, count=200) - sc.append(group=2, chunk_key=10, count=1) - sc.append(group=2, chunk_key=30, count=300) - yield sc - - -def test_spike_counts(spike_counts): - assert spike_counts() == 601 - - assert spike_counts(group=0) == 300 - assert spike_counts(group=1) == 0 - assert spike_counts(group=2) == 301 - - assert spike_counts(chunk_key=10) == 101 - assert spike_counts(chunk_key=20) == 200 - assert spike_counts(chunk_key=30) == 300 - - -class TestArrayStore(ArrayStore): - def _rel_path(self, a=None, b=None): - return '{}/{}.npy'.format(a or 'none_a', b or 'none_b') - - -def test_array_store(tempdir): - store = TestArrayStore(tempdir) - - store.store(a=1, b=1, data=np.arange(11)) - store.store(a=1, b=2, data=np.arange(12)) - store.store(b=2, data=np.arange(2)) - store.store(b=3, data=None) - store.store(a=1, data=[np.arange(3), np.arange(3, 8)]) - - ae(store.load(a=0, b=1), None) - ae(store.load(a=1, b=1), np.arange(11)) - ae(store.load(a=1, b=2), np.arange(12)) - ae(store.load(b=2), np.arange(2)) - ae(store.load(b=3), None) - ae(store.load(a=1)[0], np.arange(3)) - ae(store.load(a=1)[1], np.arange(3, 8)) - - store.delete(b=2) - store.delete(a=1, b=2) - - -def test_spikedetekt_store(tempdir): - groups = [0, 2] - chunk_keys = [10, 20, 30] - - n_channels = 4 - npc = 2 - - _keys = [(0, 10), (0, 20), (2, 10), (2, 30)] - _counts = [100, 200, 1, 300] - - # Generate random spike samples, features, and masks. - s = {k: np.arange(c) for k, c in zip(_keys, _counts)} - f = {k: np.random.rand(c, n_channels, npc) - for k, c in zip(_keys, _counts)} - m = {k: np.random.rand(c, n_channels) - for k, c in zip(_keys, _counts)} - - store = SpikeDetektStore(tempdir, groups=groups, chunk_keys=chunk_keys) - - # Save data. - for group, chunk_key in _keys: - spike_samples = s.get((group, chunk_key), None) - features = f.get((group, chunk_key), None) - masks = m.get((group, chunk_key), None) - - store.append(group=group, chunk_key=chunk_key, - spike_samples=spike_samples, - features=features, - masks=masks, - spike_offset=chunk_key, - ) - - # Load data. - for group in groups: - # Check spike samples. - ae(store.spike_samples(group), - np.hstack([s[key] + key[1] for key in _keys if key[0] == group])) - - # Check features and masks. - for name in ('features', 'masks'): - # Actual data. - data_dict = f if name == 'features' else m - # Stored data (generator). - data_gen = getattr(store, name)(group) - # Go through all chunks. - for chunk_key, data in zip(chunk_keys, data_gen): - if (group, chunk_key) in _keys: - ae(data_dict[group, chunk_key], data) - else: - assert data is None - - # Test spike counts. - test_spike_counts(store.spike_counts) diff --git a/phy/io/base.py b/phy/io/base.py deleted file mode 100644 index 02c7050cf..000000000 --- a/phy/io/base.py +++ /dev/null @@ -1,519 +0,0 @@ -# -*- coding: utf-8 -*- - -"""The BaseModel class holds the data from an experiment.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import os.path as op -from collections import defaultdict - -import numpy as np - -import six -from ..utils import debug, EventEmitter -from ..utils._types import _as_list, _is_list -from ..utils.settings import (Settings, - _ensure_dir_exists, - _phy_user_dir, - ) - - -#------------------------------------------------------------------------------ -# ClusterMetadata class -#------------------------------------------------------------------------------ - -class ClusterMetadata(object): - """Hold cluster metadata. - - Features - -------- - - * New metadata fields can be easily registered - * Arbitrary functions can be used for default values - - Notes - ---- - - If a metadata field `group` is registered, then two methods are - dynamically created: - - * `group(cluster)` returns the group of a cluster, or the default value - if the cluster doesn't exist. - * `set_group(cluster, value)` sets a value for the `group` metadata field. - - """ - def __init__(self, data=None): - self._fields = {} - self._data = defaultdict(dict) - # Fill the existing values. - if data is not None: - self._data.update(data) - - @property - def data(self): - return self._data - - def _get_one(self, cluster, field): - """Return the field value for a cluster, or the default value if it - doesn't exist.""" - if cluster in self._data: - if field in self._data[cluster]: - return self._data[cluster][field] - elif field in self._fields: - # Call the default field function. - return self._fields[field](cluster) - else: - return None - else: - if field in self._fields: - return self._fields[field](cluster) - else: - return None - - def _get(self, clusters, field): - if _is_list(clusters): - return [self._get_one(cluster, field) - for cluster in _as_list(clusters)] - else: - return self._get_one(clusters, field) - - def _set_one(self, cluster, field, value): - """Set a field value for a cluster.""" - self._data[cluster][field] = value - - def _set(self, clusters, field, value): - clusters = _as_list(clusters) - for cluster in clusters: - self._set_one(cluster, field, value) - - def default(self, func): - """Register a new metadata field with a function - returning the default value of a cluster.""" - field = func.__name__ - # Register the decorated function as the default field function. - self._fields[field] = func - # Create self.(clusters). - setattr(self, field, lambda clusters: self._get(clusters, field)) - # Create self.set_(clusters, value). - setattr(self, 'set_{0:s}'.format(field), - lambda clusters, value: self._set(clusters, field, value)) - return func - - -#------------------------------------------------------------------------------ -# BaseModel class -#------------------------------------------------------------------------------ - -class BaseModel(object): - """This class holds data from an experiment. - - This base class must be derived. - - """ - def __init__(self): - self.name = 'model' - self._channel_group = None - self._clustering = None - - @property - def path(self): - return None - - # Channel groups - # ------------------------------------------------------------------------- - - @property - def channel_group(self): - return self._channel_group - - @channel_group.setter - def channel_group(self, value): - assert isinstance(value, six.integer_types) - self._channel_group = value - self._channel_group_changed(value) - - def _channel_group_changed(self, value): - """Called when the channel group changes. - - May be implemented by child classes. - - """ - pass - - @property - def channel_groups(self): - """List of channel groups. - - May be implemented by child classes. - - """ - return [] - - # Clusterings - # ------------------------------------------------------------------------- - - @property - def clustering(self): - return self._clustering - - @clustering.setter - def clustering(self, value): - # The clustering is specified by a string. - assert isinstance(value, six.string_types) - self._clustering = value - self._clustering_changed(value) - - def _clustering_changed(self, value): - """Called when the clustering changes. - - May be implemented by child classes. - - """ - pass - - @property - def clusterings(self): - """List of clusterings. - - May be implemented by child classes. - - """ - return [] - - # Data - # ------------------------------------------------------------------------- - - @property - def metadata(self): - """A dictionary holding metadata about the experiment. - - May be implemented by child classes. - - """ - raise NotImplementedError() - - @property - def traces(self): - """Traces (may be memory-mapped). - - May be implemented by child classes. - - """ - raise NotImplementedError() - - @property - def spike_samples(self): - """Spike times from the current channel_group. - - Must be implemented by child classes. - - """ - raise NotImplementedError() - - @property - def sample_rate(self): - pass - - @property - def spike_times(self): - """Spike times from the current channel_group. - - This is a NumPy array containing `float64` values (in seconds). - - The spike times of all recordings are concatenated. There is no gap - between consecutive recordings, currently. - - """ - return self.spike_samples.astype(np.float64) / self.sample_rate - - @property - def spike_clusters(self): - """Spike clusters from the current channel_group. - - Must be implemented by child classes. - - """ - raise NotImplementedError() - - def spike_train(self, cluster_id): - """Return the spike times of a given cluster.""" - return self.spike_times[self.spikes_per_cluster[cluster_id]] - - @property - def spikes_per_cluster(self): - """Spikes per cluster dictionary. - - Must be implemented by child classes. - - """ - raise NotImplementedError() - - def update_spikes_per_cluster(self, spc): - raise NotImplementedError() - - @property - def cluster_metadata(self): - """ClusterMetadata instance holding information about the clusters. - - Must be implemented by child classes. - - """ - raise NotImplementedError() - - @property - def cluster_groups(self): - """Groups of all clusters in the current channel group and clustering. - - This is a regular Python dictionary. - - """ - return {cluster: self.cluster_metadata.group(cluster) - for cluster in self.cluster_ids} - - @property - def features(self): - """Features from the current channel_group (may be memory-mapped). - - May be implemented by child classes. - - """ - raise NotImplementedError() - - @property - def masks(self): - """Masks from the current channel_group (may be memory-mapped). - - May be implemented by child classes. - - """ - raise NotImplementedError() - - @property - def waveforms(self): - """Waveforms from the current channel_group (may be memory-mapped). - - May be implemented by child classes. - - """ - raise NotImplementedError() - - @property - def probe(self): - """A Probe instance. - - May be implemented by child classes. - - """ - raise NotImplementedError() - - def save(self): - """Save the data. - - May be implemented by child classes. - - """ - raise NotImplementedError() - - def close(self): - """Close the model and the underlying files. - - May be implemented by child classes. - - """ - pass - - -#------------------------------------------------------------------------------ -# Session -#------------------------------------------------------------------------------ - -class BaseSession(EventEmitter): - """Give access to the data, views, and GUIs in an interactive session. - - The model must implement: - - * `model(path)` - * `model.path` - * `model.close()` - - Events - ------ - - open - close - - """ - def __init__(self, - model=None, - path=None, - phy_user_dir=None, - default_settings_paths=None, - vm_classes=None, - gui_classes=None, - ): - super(BaseSession, self).__init__() - - self.model = None - if phy_user_dir is None: - phy_user_dir = _phy_user_dir() - _ensure_dir_exists(phy_user_dir) - self.phy_user_dir = phy_user_dir - - self._create_settings(default_settings_paths) - - if gui_classes is None: - gui_classes = self.settings['gui_classes'] - - # HACK: avoid Qt import here - try: - from ..gui.base import WidgetCreator - self._gui_creator = WidgetCreator(widget_classes=gui_classes) - except ImportError: - self._gui_creator = None - - self.connect(self.on_open) - self.connect(self.on_close) - - # Custom `on_open()` callback function. - if 'on_open' in self.settings: - @self.connect - def on_open(): - self.settings['on_open'](self) - - self._pre_open() - if model or path: - self.open(path, model=model) - - def _create_settings(self, default_settings_paths): - self.settings = Settings(phy_user_dir=self.phy_user_dir, - default_paths=default_settings_paths, - ) - - @self.connect - def on_open(): - # Initialize the settings with the model's path. - self.settings.on_open(self.experiment_path) - - # Methods to override - # ------------------------------------------------------------------------- - - def _pre_open(self): - pass - - def _create_model(self, path): - """Create a model from a path. - - Must be overriden. - - """ - pass - - def _save_model(self): - """Save a model. - - Must be overriden. - - """ - pass - - def on_open(self): - pass - - def on_close(self): - pass - - # File-related actions - # ------------------------------------------------------------------------- - - def open(self, path=None, model=None): - """Open a dataset.""" - # Close the session if it is already open. - if self.model: - self.close() - if model is None: - model = self._create_model(path) - self.model = model - self.experiment_path = (op.realpath(path) - if path else self.phy_user_dir) - self.emit('open') - - def reopen(self): - self.open(model=self.model) - - def save(self): - self._save_model() - - def close(self): - """Close the currently-open dataset.""" - self.model.close() - self.emit('close') - self.model = None - - # Views and GUIs - # ------------------------------------------------------------------------- - - def show_gui(self, name=None, show=True, **kwargs): - """Show a new GUI.""" - if name is None: - gui_classes = list(self._gui_creator.widget_classes.keys()) - if gui_classes: - name = gui_classes[0] - - # Get the default GUI config. - params = {p: self.settings.get('{}_{}'.format(name, p), None) - for p in ('config', 'shortcuts', 'snippets', 'state')} - params.update(kwargs) - - # Create the GUI. - gui = self._gui_creator.add(name, - model=self.model, - settings=self.settings, - **params) - gui._save_state = True - - # Connect the 'open' event. - self.connect(gui.on_open) - - @gui.main_window.connect_ - def on_close_gui(): - self.unconnect(gui.on_open) - # Save the params of every view in the GUI. - for vm in gui.views: - self.save_view_params(vm, save_size_pos=False) - gs = gui.main_window.save_geometry_state() - gs['view_count'] = gui.view_count() - if not gui._save_state: - gs['state'] = None - gs['geometry'] = None - self.settings['{}_state'.format(name)] = gs - self.settings.save() - - # HACK: do not save GUI state when views have been closed or reset - # in the session, otherwise Qt messes things up in the GUI. - @gui.connect - def on_close_view(view): - gui._save_state = False - - @gui.connect - def on_reset_gui(): - gui._save_state = False - - # Custom `on_gui_open()` callback. - if 'on_gui_open' in self.settings: - self.settings['on_gui_open'](self, gui) - - if show: - gui.show() - - return gui - - def save_view_params(self, vm, save_size_pos=True): - """Save the parameters exported by a view model instance.""" - to_save = vm.exported_params(save_size_pos=save_size_pos) - for key, value in to_save.items(): - assert vm.name - name = '{}_{}'.format(vm.name, key) - self.settings[name] = value - debug("Save {0}={1} for {2}.".format(name, value, vm.name)) diff --git a/phy/io/kwik/__init__.py b/phy/io/kwik/__init__.py deleted file mode 100644 index f1aabd050..000000000 --- a/phy/io/kwik/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -# -*- coding: utf-8 -*- -# flake8: noqa - -"""Kwik format.""" - -from .store_items import (FeatureMasks, - Waveforms, - ClusterStatistics, - create_store, - ) -from .creator import KwikCreator, create_kwik -from .model import KwikModel diff --git a/phy/io/kwik/creator.py b/phy/io/kwik/creator.py deleted file mode 100644 index 2a7400863..000000000 --- a/phy/io/kwik/creator.py +++ /dev/null @@ -1,463 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Kwik creator.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import os -import os.path as op - -import numpy as np -from h5py import Dataset -from six import string_types -from six.moves import zip - -import phy -from ...electrode.mea import load_probe -from ..h5 import open_h5 -from ..traces import _dat_n_samples -from ...utils._misc import _read_python -from ...utils._types import _as_array -from ...utils.array import _unique -from ...utils.logging import warn -from ...utils.settings import _load_default_settings - - -#------------------------------------------------------------------------------ -# Kwik creator -#------------------------------------------------------------------------------ - -def _write_by_chunk(dset, arrs): - # Note: arrs should be a generator for performance reasons. - assert isinstance(dset, Dataset) - # Start the data. - offset = 0 - for arr in arrs: - n = arr.shape[0] - arr = arr[...] - # Match the shape of the chunk array with the dset shape. - assert arr.shape == (n,) + dset.shape[1:] - dset[offset:offset + n, ...] = arr - offset += arr.shape[0] - # Check that the copy is complete. - assert offset == dset.shape[0] - - -def _concat(arrs): - return np.hstack([_[...] for _ in arrs]) - - -_DEFAULT_GROUPS = [(0, 'Noise'), - (1, 'MUA'), - (2, 'Good'), - (3, 'Unsorted'), - ] - - -class KwikCreator(object): - """Create and modify a `.kwik` file.""" - def __init__(self, basename=None, kwik_path=None, kwx_path=None): - # Find the .kwik filename. - if kwik_path is None: - assert basename is not None - if basename.endswith('.kwik'): - basename, _ = op.splitext(basename) - kwik_path = basename + '.kwik' - self.kwik_path = kwik_path - if basename is None: - basename, _ = op.splitext(kwik_path) - self.basename = basename - - # Find the .kwx filename. - if kwx_path is None: - basename, _ = op.splitext(kwik_path) - kwx_path = basename + '.kwx' - self.kwx_path = kwx_path - - def create_empty(self): - """Create empty `.kwik` and `.kwx` files.""" - assert not op.exists(self.kwik_path) - with open_h5(self.kwik_path, 'w') as f: - f.write_attr('/', 'kwik_version', 2) - f.write_attr('/', 'name', self.basename) - f.write_attr('/', 'creator_version', 'phy ' - + phy.__version_git__) - - assert not op.exists(self.kwx_path) - with open_h5(self.kwx_path, 'w') as f: - f.write_attr('/', 'kwik_version', 2) - f.write_attr('/', 'creator_version', 'phy ' - + phy.__version_git__) - - def set_metadata(self, path, **kwargs): - """Set metadata fields in a HDF5 path.""" - assert isinstance(path, string_types) - assert path - with open_h5(self.kwik_path, 'a') as f: - for key, value in kwargs.items(): - f.write_attr(path, key, value) - - def set_probe(self, probe): - """Save a probe dictionary in the file.""" - with open_h5(self.kwik_path, 'a') as f: - probe = probe['channel_groups'] - for group, d in probe.items(): - group = int(group) - channels = np.array(list(d['channels']), dtype=np.int32) - - # Write the channel order. - f.write_attr('/channel_groups/{:d}'.format(group), - 'channel_order', channels) - - # Write the probe adjacency graph. - graph = d.get('graph', []) - graph = np.array(graph, dtype=np.int32) - f.write_attr('/channel_groups/{:d}'.format(group), - 'adjacency_graph', graph) - - # Write the channel positions. - positions = d.get('geometry', {}) - for channel in channels: - channel = int(channel) - # Get the channel position. - if channel in positions: - position = positions[channel] - else: - # Default position. - position = (0, channel) - path = '/channel_groups/{:d}/channels/{:d}'.format( - group, channel) - - f.write_attr(path, 'name', str(channel)) - - position = np.array(position, dtype=np.float32) - f.write_attr(path, 'position', position) - - def add_spikes(self, - group=None, - spike_samples=None, - spike_recordings=None, - masks=None, - features=None, - n_channels=None, - n_features=None, - ): - """Add spikes in the file. - - Parameters - ---------- - - group : int - Channel group. - spike_samples : ndarray - The spike times in number of samples. - spike_recordings : ndarray (optional) - The recording indices of every spike. - masks : ndarray or list of ndarrays - The masks. - features : ndarray or list of ndarrays - The features. - - Note - ---- - - The features and masks can be passed as lists or generators of - (memmapped) data chunks in order to avoid loading the entire arrays - in RAM. - - """ - assert group >= 0 - assert n_channels >= 0 - assert n_features >= 0 - - if spike_samples is None: - return - if isinstance(spike_samples, list): - spike_samples = _concat(spike_samples) - spike_samples = _as_array(spike_samples, dtype=np.float64).ravel() - n_spikes = len(spike_samples) - if spike_recordings is None: - spike_recordings = np.zeros(n_spikes, dtype=np.int32) - spike_recordings = spike_recordings.ravel() - - # Add spikes in the .kwik file. - assert op.exists(self.kwik_path) - with open_h5(self.kwik_path, 'a') as f: - # This method can only be called once. - if '/channel_groups/{:d}/spikes/time_samples'.format(group) in f: - raise RuntimeError("Spikes have already been added to this " - "dataset.") - time_samples = spike_samples.astype(np.uint64) - frac = ((spike_samples - time_samples) * 255).astype(np.uint8) - f.write('/channel_groups/{:d}/spikes/time_samples'.format(group), - time_samples) - f.write('/channel_groups/{}/spikes/time_fractional'.format(group), - frac) - f.write('/channel_groups/{:d}/spikes/recording'.format(group), - spike_recordings) - - if masks is None and features is None: - return - # Add features and masks in the .kwx file. - assert masks is not None - assert features is not None - - # Determine the shape of the features_masks array. - shape = (n_spikes, n_channels * n_features, 2) - - def transform_f(f): - return f[...].reshape((-1, n_channels * n_features)) - - def transform_m(m): - return np.repeat(m[...], 3, axis=1) - - assert op.exists(self.kwx_path) - with open_h5(self.kwx_path, 'a') as f: - fm = f.write('/channel_groups/{:d}/features_masks'.format(group), - shape=shape, dtype=np.float32) - - # Write the features and masks in one block. - if (isinstance(features, np.ndarray) and - isinstance(masks, np.ndarray)): - fm[:, :, 0] = transform_f(features) - fm[:, :, 1] = transform_m(masks) - # Write the features and masks chunk by chunk. - else: - # Concatenate the features/masks chunks. - fm_arrs = (np.dstack((transform_f(fet), transform_m(m))) - for (fet, m) in zip(features, masks) - if fet is not None and m is not None) - _write_by_chunk(fm, fm_arrs) - - def add_recording(self, id=None, raw_path=None, - start_sample=None, sample_rate=None): - """Add a recording. - - Parameters - ---------- - - id : int - The recording id (0, 1, 2, etc.). - raw_path : str - Path to the file containing the raw data. - start_sample : int - The offset of the recording, in number of samples. - sample_rate : float - The sample rate of the recording - - """ - path = '/recordings/{:d}'.format(id) - start_sample = int(start_sample) - sample_rate = float(sample_rate) - - with open_h5(self.kwik_path, 'a') as f: - f.write_attr(path, 'name', 'recording_{:d}'.format(id)) - f.write_attr(path, 'start_sample', start_sample) - f.write_attr(path, 'sample_rate', sample_rate) - f.write_attr(path, 'start_time', start_sample / sample_rate) - if raw_path: - if op.splitext(raw_path)[1] == '.kwd': - f.write_attr(path + '/raw', 'hdf5_path', raw_path) - elif op.splitext(raw_path)[1] == '.dat': - f.write_attr(path + '/raw', 'dat_path', raw_path) - - def _add_recordings_from_dat(self, files, sample_rate=None, - n_channels=None, dtype=None): - start_sample = 0 - for i, filename in enumerate(files): - # WARNING: different sample rates in recordings is not - # supported yet. - self.add_recording(id=i, - start_sample=start_sample, - sample_rate=sample_rate, - raw_path=filename, - ) - assert op.splitext(filename)[1] == '.dat' - # Compute the offset for different recordings. - start_sample += _dat_n_samples(filename, - n_channels=n_channels, - dtype=dtype) - - def _add_recordings_from_kwd(self, file, sample_rate=None): - assert file.endswith('.kwd') - start_sample = 0 - with open_h5(file, 'r') as f: - recordings = f.children('/recordings') - for recording in recordings: - path = '/recordings/{}'.format(recording) - if f.has_attr(path, 'sample_rate'): - sample_rate = f.read_attr(path, 'sample_rate') - assert sample_rate > 0 - self.add_recording(id=int(recording), - start_sample=start_sample, - sample_rate=sample_rate, - raw_path=file, - ) - start_sample += f.read(path + '/data').shape[0] - - def add_cluster_group(self, - group=None, - id=None, - name=None, - clustering=None, - ): - """Add a cluster group. - - Parameters - ---------- - - group : int - The channel group. - id : int - The cluster group id. - name : str - The cluster group name. - clustering : str - The name of the clustering. - - """ - assert group >= 0 - cg_path = ('/channel_groups/{0:d}/' - 'cluster_groups/{1:s}/{2:d}').format(group, - clustering, - id, - ) - with open_h5(self.kwik_path, 'a') as f: - f.write_attr(cg_path, 'name', name) - - def add_clustering(self, - group=None, - name=None, - spike_clusters=None, - cluster_groups=None, - ): - """Add a clustering. - - Parameters - ---------- - - group : int - The channel group. - name : str - The clustering name. - spike_clusters : ndarray - The spike clusters assignements. This is `(n_spikes,)` array. - cluster_groups : dict - The cluster group of every cluster. - - """ - if cluster_groups is None: - cluster_groups = {} - path = '/channel_groups/{0:d}/spikes/clusters/{1:s}'.format( - group, name) - - with open_h5(self.kwik_path, 'a') as f: - assert not f.exists(path) - - # Save spike_clusters. - spike_clusters = spike_clusters.astype(np.int32).ravel() - f.write(path, spike_clusters) - cluster_ids = _unique(spike_clusters) - - # Create cluster metadata. - for cluster in cluster_ids: - cluster_path = '/channel_groups/{0:d}/clusters/{1:s}/{2:d}'. \ - format(group, name, cluster) - - # Default group: unsorted. - cluster_group = cluster_groups.get(cluster, 3) - f.write_attr(cluster_path, 'cluster_group', cluster_group) - - # Create cluster group metadata. - for group_id, cg_name in _DEFAULT_GROUPS: - self.add_cluster_group(id=group_id, - name=cg_name, - clustering=name, - group=group, - ) - - -def create_kwik(prm_file=None, kwik_path=None, overwrite=False, - probe=None, **kwargs): - """Create a new Kwik dataset from a PRM file.""" - prm = _read_python(prm_file) if prm_file else {} - - if prm: - assert 'spikedetekt' in prm - assert 'traces' in prm - sample_rate = prm['traces']['sample_rate'] - - if 'sample_rate' in kwargs: - sample_rate = kwargs['sample_rate'] - - assert sample_rate > 0 - - # Default SpikeDetekt parameters. - settings = _load_default_settings() - params = settings['spikedetekt'] - params.update(settings['traces']) - # Update with PRM and user parameters. - if prm: - params['experiment_name'] = prm['experiment_name'] - params['prb_file'] = prm['prb_file'] - params.update(prm['spikedetekt']) - params.update(prm['klustakwik2']) - params.update(prm['traces']) - params.update(kwargs) - - kwik_path = kwik_path or params['experiment_name'] + '.kwik' - kwx_path = op.splitext(kwik_path)[0] + '.kwx' - if op.exists(kwik_path): - if overwrite: - os.remove(kwik_path) - os.remove(kwx_path) - else: - raise IOError("The `.kwik` file already exists. Please use " - "the `--overwrite` option.") - - # Ensure the probe file exists if it is required. - if probe is None: - probe = load_probe(params['prb_file']) - assert probe - - # KwikCreator. - creator = KwikCreator(kwik_path) - creator.create_empty() - creator.set_probe(probe) - - # Add the recordings. - raw_data_files = params.get('raw_data_files', None) - if isinstance(raw_data_files, string_types): - if raw_data_files.endswith('.raw.kwd'): - creator._add_recordings_from_kwd(raw_data_files, - sample_rate=sample_rate, - ) - else: - raw_data_files = [raw_data_files] - if isinstance(raw_data_files, list) and len(raw_data_files): - if len(raw_data_files) > 1: - raise NotImplementedError("There is no support for " - "multiple .dat files yet.") - # The dtype must be a string so that it can be serialized in HDF5. - if not params.get('dtype', None): - warn("The `dtype` parameter is mandatory. Using a default value " - "of `int16` for now. Please update your `.prm` file.") - params['dtype'] = 'int16' - assert 'dtype' in params and isinstance(params['dtype'], string_types) - dtype = np.dtype(params['dtype']) - assert dtype is not None - - # The number of channels in the .dat file *must* be specified. - n_channels = params['n_channels'] - assert n_channels > 0 - creator._add_recordings_from_dat(raw_data_files, - sample_rate=sample_rate, - n_channels=n_channels, - dtype=dtype, - ) - - creator.set_metadata('/application_data/spikedetekt', **params) - - return kwik_path diff --git a/phy/io/kwik/mock.py b/phy/io/kwik/mock.py deleted file mode 100644 index 357286c8a..000000000 --- a/phy/io/kwik/mock.py +++ /dev/null @@ -1,133 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Mock Kwik files.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import os.path as op - -import numpy as np - -from ...electrode.mea import staggered_positions -from ..mock import (artificial_spike_samples, - artificial_spike_clusters, - artificial_features, - artificial_masks, - artificial_traces) -from ..h5 import open_h5 -from .model import _create_clustering - - -#------------------------------------------------------------------------------ -# Mock Kwik file -#------------------------------------------------------------------------------ - -def create_mock_kwik(dir_path, n_clusters=None, n_spikes=None, - n_channels=None, n_features_per_channel=None, - n_samples_traces=None, - with_kwx=True, - with_kwd=True, - add_original=True, - ): - """Create a test kwik file.""" - filename = op.join(dir_path, '_test.kwik') - kwx_filename = op.join(dir_path, '_test.kwx') - kwd_filename = op.join(dir_path, '_test.raw.kwd') - - # Create the kwik file. - with open_h5(filename, 'w') as f: - f.write_attr('/', 'kwik_version', 2) - - def _write_metadata(key, value): - f.write_attr('/application_data/spikedetekt', key, value) - - _write_metadata('sample_rate', 20000.) - - # Filter parameters. - _write_metadata('filter_low', 500.) - _write_metadata('filter_high_factor', 0.95 * .5) - _write_metadata('filter_butter_order', 3) - - _write_metadata('extract_s_before', 15) - _write_metadata('extract_s_after', 25) - - _write_metadata('n_features_per_channel', n_features_per_channel) - - # Create spike times. - spike_samples = artificial_spike_samples(n_spikes).astype(np.int64) - spike_recordings = np.zeros(n_spikes, dtype=np.uint16) - # Size of the first recording. - recording_size = 2 * n_spikes // 3 - if recording_size > 0: - # Find the recording offset. - recording_offset = spike_samples[recording_size] - recording_offset += spike_samples[recording_size + 1] - recording_offset //= 2 - spike_recordings[recording_size:] = 1 - # Make sure the spike samples of the second recording start over. - spike_samples[recording_size:] -= spike_samples[recording_size] - spike_samples[recording_size:] += 10 - else: - recording_offset = 1 - - if spike_samples.max() >= n_samples_traces: - raise ValueError("There are too many spikes: decrease 'n_spikes'.") - - f.write('/channel_groups/1/spikes/time_samples', spike_samples) - f.write('/channel_groups/1/spikes/recording', spike_recordings) - f.write_attr('/channel_groups/1', 'channel_order', - np.arange(1, n_channels - 1)[::-1]) - graph = np.array([[1, 2], [2, 3]]) - f.write_attr('/channel_groups/1', 'adjacency_graph', graph) - - # Create channels. - positions = staggered_positions(n_channels) - for channel in range(n_channels): - group = '/channel_groups/1/channels/{0:d}'.format(channel) - f.write_attr(group, 'name', str(channel)) - f.write_attr(group, 'position', positions[channel]) - - # Create spike clusters. - clusterings = [('main', n_clusters)] - if add_original: - clusterings += [('original', n_clusters * 2)] - for clustering, n_clusters_rec in clusterings: - spike_clusters = artificial_spike_clusters(n_spikes, - n_clusters_rec) - groups = {0: 0, 1: 1, 2: 2} - _create_clustering(f, clustering, 1, spike_clusters, groups) - - # Create recordings. - f.write_attr('/recordings/0', 'name', 'recording_0') - f.write_attr('/recordings/1', 'name', 'recording_1') - - f.write_attr('/recordings/0/raw', 'hdf5_path', kwd_filename) - f.write_attr('/recordings/1/raw', 'hdf5_path', kwd_filename) - - # Create the kwx file. - if with_kwx: - with open_h5(kwx_filename, 'w') as f: - f.write_attr('/', 'kwik_version', 2) - features = artificial_features(n_spikes, - (n_channels - 2) * - n_features_per_channel) - masks = artificial_masks(n_spikes, - (n_channels - 2) * - n_features_per_channel) - fm = np.dstack((features, masks)).astype(np.float32) - f.write('/channel_groups/1/features_masks', fm) - - # Create the raw kwd file. - if with_kwd: - with open_h5(kwd_filename, 'w') as f: - f.write_attr('/', 'kwik_version', 2) - traces = artificial_traces(n_samples_traces, n_channels) - # TODO: int16 traces - f.write('/recordings/0/data', - traces[:recording_offset, ...].astype(np.float32)) - f.write('/recordings/1/data', - traces[recording_offset:, ...].astype(np.float32)) - - return filename diff --git a/phy/io/kwik/model.py b/phy/io/kwik/model.py deleted file mode 100644 index c8724cd8a..000000000 --- a/phy/io/kwik/model.py +++ /dev/null @@ -1,1243 +0,0 @@ -# -*- coding: utf-8 -*- - -"""The KwikModel class manages in-memory structures and Kwik file open/save.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import os.path as op -from random import randint -import os - -import numpy as np - -from .creator import KwikCreator -import six -from ..base import BaseModel, ClusterMetadata -from ..h5 import open_h5, File -from ..traces import _dat_to_traces -from ...traces.waveform import WaveformLoader, SpikeLoader -from ...traces.filter import bandpass_filter, apply_filter -from ...electrode.mea import MEA -from ...utils.logging import debug, warn -from ...utils.array import (PartialArray, - _concatenate_virtual_arrays, - _spikes_per_cluster, - _unique, - ) -from ...utils.settings import _load_default_settings -from ...utils._types import _is_integer, _as_array - - -#------------------------------------------------------------------------------ -# Kwik utility functions -#------------------------------------------------------------------------------ - -def _to_int_list(l): - """Convert int strings to ints.""" - return [int(_) for _ in l] - - -def _list_int_children(group): - """Return the list of int children of a HDF5 group.""" - return sorted(_to_int_list(group.keys())) - - -# TODO: refactor the functions below with h5.File.children(). - -def _list_channel_groups(kwik): - """Return the list of channel groups in a kwik file.""" - if 'channel_groups' in kwik: - return _list_int_children(kwik['/channel_groups']) - else: - return [] - - -def _list_recordings(kwik): - """Return the list of recordings in a kwik file.""" - if '/recordings' in kwik: - recordings = _list_int_children(kwik['/recordings']) - else: - recordings = [] - # TODO: return a dictionary of recordings instead of a list of recording - # ids. - # return {rec: Bunch({ - # 'start': kwik['/recordings/{0}'.format(rec)].attrs['start_sample'] - # }) for rec in recordings} - return recordings - - -def _list_channels(kwik, channel_group=None): - """Return the list of channels in a kwik file.""" - assert isinstance(channel_group, six.integer_types) - path = '/channel_groups/{0:d}/channels'.format(channel_group) - if path in kwik: - channels = _list_int_children(kwik[path]) - return channels - else: - return [] - - -def _list_clusterings(kwik, channel_group=None): - """Return the list of clusterings in a kwik file.""" - if channel_group is None: - raise RuntimeError("channel_group must be specified when listing " - "the clusterings.") - assert isinstance(channel_group, six.integer_types) - path = '/channel_groups/{0:d}/clusters'.format(channel_group) - if path not in kwik: - return [] - clusterings = sorted(kwik[path].keys()) - # Ensure 'main' is the first if it exists. - if 'main' in clusterings: - clusterings.remove('main') - clusterings = ['main'] + clusterings - return clusterings - - -def _concatenate_spikes(spikes, recs, offsets): - """Concatenate spike samples belonging to consecutive recordings.""" - assert offsets is not None - spikes = _as_array(spikes) - offsets = _as_array(offsets) - recs = _as_array(recs) - return (spikes + offsets[recs]).astype(np.uint64) - - -def _create_cluster_group(f, group_id, name, - clustering=None, - channel_group=None, - write_color=True, - ): - cg_path = ('/channel_groups/{0:d}/' - 'cluster_groups/{1:s}/{2:d}').format(channel_group, - clustering, - group_id, - ) - kv_path = cg_path + '/application_data/klustaviewa' - f.write_attr(cg_path, 'name', name) - if write_color: - f.write_attr(kv_path, 'color', randint(2, 10)) - - -def _create_clustering(f, name, - channel_group=None, - spike_clusters=None, - cluster_groups=None, - ): - if cluster_groups is None: - cluster_groups = {} - assert isinstance(f, File) - path = '/channel_groups/{0:d}/spikes/clusters/{1:s}'.format(channel_group, - name) - assert not f.exists(path) - - # Save spike_clusters. - f.write(path, spike_clusters.astype(np.int32)) - - cluster_ids = _unique(spike_clusters) - - # Create cluster metadata. - for cluster in cluster_ids: - cluster_path = '/channel_groups/{0:d}/clusters/{1:s}/{2:d}'.format( - channel_group, name, cluster) - kv_path = cluster_path + '/application_data/klustaviewa' - - # Default group: unsorted. - cluster_group = cluster_groups.get(cluster, 3) - f.write_attr(cluster_path, 'cluster_group', cluster_group) - f.write_attr(kv_path, 'color', randint(2, 10)) - - # Create cluster group metadata. - for group_id, cg_name in _DEFAULT_GROUPS: - _create_cluster_group(f, group_id, cg_name, - clustering=name, - channel_group=channel_group, - ) - - -def list_kwik(folders): - """Return the list of Kwik files found in a list of folders.""" - ret = [] - for d in folders: - for root, dirs, files in os.walk(os.path.expanduser(d)): - for f in files: - if f.endswith(".kwik"): - ret.append(os.path.join(root, f)) - return ret - - -def _open_h5_if_exists(kwik_path, file_type, mode=None): - basename, ext = op.splitext(kwik_path) - path = '{basename}.{ext}'.format(basename=basename, ext=file_type) - return open_h5(path, mode=mode) if op.exists(path) else None - - -def _read_traces(kwik, dtype=None, n_channels=None): - kwd_path = None - dat_path = None - if '/recordings' not in kwik: - return (None, None) - recordings = kwik.children('/recordings') - traces = [] - opened_files = [] - for recording in recordings: - # Is there a path specified to a .raw.kwd file which exists in - # [KWIK]/recordings/[X]/raw? If so, open it. - path = '/recordings/{}/raw'.format(recording) - if kwik.has_attr(path, 'hdf5_path'): - kwd_path = kwik.read_attr(path, 'hdf5_path')[:-8] - kwd = _open_h5_if_exists(kwd_path, 'raw.kwd') - if kwd is None: - debug("{} not found, trying same basename in KWIK dir" - .format(kwd_path)) - else: - debug("Loading traces: {}" - .format(kwd_path)) - traces.append(kwd.read('/recordings/{}/data' - .format(recording))) - opened_files.append(kwd) - continue - # Is there a path specified to a .dat file which exists? - elif kwik.has_attr(path, 'dat_path'): - assert dtype is not None - assert n_channels - dat_path = kwik.read_attr(path, 'dat_path') - if not op.exists(dat_path): - debug("{} not found, trying same basename in KWIK dir" - .format(dat_path)) - else: - debug("Loading traces: {}" - .format(dat_path)) - dat = _dat_to_traces(dat_path, dtype=dtype, - n_channels=n_channels) - traces.append(dat) - opened_files.append(dat) - continue - - # Does a file exist with the same name in the current directory? - # If so, open it. - if kwd_path is not None: - rel_path = str(op.basename(kwd_path)) - rel_path = op.join(op.dirname(op.realpath(kwik.filename)), - rel_path) - kwd = _open_h5_if_exists(rel_path, 'raw.kwd') - if kwd is None: - debug("{} not found, trying experiment basename in KWIK dir" - .format(rel_path)) - else: - debug("Loading traces: {}" - .format(rel_path)) - traces.append(kwd.read('/recordings/{}/data' - .format(recording))) - opened_files.append(kwd) - continue - elif dat_path is not None: - rel_path = op.basename(dat_path) - rel_path = op.join(op.dirname(op.realpath(kwik.filename)), - rel_path) - if not op.exists(rel_path): - debug("{} not found, trying experiment basename in KWIK dir" - .format(rel_path)) - else: - debug("Loading traces: {}" - .format(rel_path)) - dat = _dat_to_traces(rel_path, dtype=dtype, - n_channels=n_channels) - traces.append(dat) - opened_files.append(dat) - continue - - # Finally, is there a `raw.kwd` with the experiment basename in the - # current directory? If so, open it. - kwd = _open_h5_if_exists(kwik.filename, 'raw.kwd') - if kwd is None: - warn("Could not find any data source for traces (raw.kwd or " - ".dat or .bin.) Waveforms and traces will not be available.") - else: - debug("Successfully loaded basename.raw.kwd in same directory") - traces.append(kwd.read('/recordings/{}/data'.format(recording))) - opened_files.append(kwd) - - return traces, opened_files - - -_DEFAULT_GROUPS = [(0, 'Noise'), - (1, 'MUA'), - (2, 'Good'), - (3, 'Unsorted'), - ] - - -"""Metadata fields that must be provided when creating the Kwik file.""" -_mandatory_metadata_fields = ('dtype', - 'n_channels', - 'prb_file', - 'raw_data_files', - ) - - -def cluster_group_id(name_or_id): - """Return the id of a cluster group from its name.""" - if isinstance(name_or_id, six.string_types): - d = {group.lower(): id for id, group in _DEFAULT_GROUPS} - return d[name_or_id.lower()] - else: - assert _is_integer(name_or_id) - return name_or_id - - -#------------------------------------------------------------------------------ -# KwikModel class -#------------------------------------------------------------------------------ - -class KwikModel(BaseModel): - """Holds data contained in a kwik file.""" - - """Names of the default cluster groups.""" - default_cluster_groups = dict(_DEFAULT_GROUPS) - - def __init__(self, kwik_path=None, - channel_group=None, - clustering=None, - waveform_filter=True, - ): - super(KwikModel, self).__init__() - - # Initialize fields. - self._spike_samples = None - self._spike_clusters = None - self._spikes_per_cluster = None - self._metadata = None - self._clustering = clustering or 'main' - self._probe = None - self._channels = [] - self._channel_order = None - self._features = None - self._features_masks = None - self._masks = None - self._waveforms = None - self._cluster_metadata = None - self._clustering_metadata = {} - self._traces = None - self._recording_offsets = None - self._waveform_loader = None - self._waveform_filter = waveform_filter - self._opened_files = [] - - # Open the experiment. - self.kwik_path = kwik_path - self.open(kwik_path, - channel_group=channel_group, - clustering=clustering) - - @property - def path(self): - return self.kwik_path - - # Internal properties and methods - # ------------------------------------------------------------------------- - - def _check_kwik_version(self): - # This class only works with kwik version 2 for now. - kwik_version = self._kwik.read_attr('/', 'kwik_version') - if kwik_version != 2: - raise IOError("The kwik version is {v} != 2.".format(kwik_version)) - - @property - def _channel_groups_path(self): - return '/channel_groups/{0:d}'.format(self._channel_group) - - @property - def _spikes_path(self): - return '{0:s}/spikes'.format(self._channel_groups_path) - - @property - def _channels_path(self): - return '{0:s}/channels'.format(self._channel_groups_path) - - @property - def _clusters_path(self): - return '{0:s}/clusters'.format(self._channel_groups_path) - - def _cluster_path(self, cluster): - return '{0:s}/{1:d}'.format(self._clustering_path, cluster) - - @property - def _spike_clusters_path(self): - return '{0:s}/clusters/{1:s}'.format(self._spikes_path, - self._clustering) - - @property - def _clustering_path(self): - return '{0:s}/{1:s}'.format(self._clusters_path, self._clustering) - - # Loading and saving - # ------------------------------------------------------------------------- - - def _open_kwik_if_needed(self, mode=None): - if not self._kwik.is_open(): - self._kwik.open(mode=mode) - return True - else: - if mode is not None: - self._kwik.mode = mode - return False - - @property - def n_samples_waveforms(self): - return (self._metadata['extract_s_before'] + - self._metadata['extract_s_after']) - - def _create_waveform_loader(self): - """Create a waveform loader.""" - n_samples = (self._metadata['extract_s_before'], - self._metadata['extract_s_after']) - order = self._metadata['filter_butter_order'] - rate = self._metadata['sample_rate'] - low = self._metadata['filter_low'] - high = self._metadata['filter_high_factor'] * rate - b_filter = bandpass_filter(rate=rate, - low=low, - high=high, - order=order) - - if self._metadata.get('waveform_filter', True): - debug("Enable waveform filter.") - - def filter(x): - return apply_filter(x, b_filter) - - filter_margin = order * 3 - else: - debug("Disable waveform filter.") - filter = None - filter_margin = 0 - - dc_offset = self._metadata.get('waveform_dc_offset', None) - scale_factor = self._metadata.get('waveform_scale_factor', None) - self._waveform_loader = WaveformLoader(n_samples=n_samples, - filter=filter, - filter_margin=filter_margin, - dc_offset=dc_offset, - scale_factor=scale_factor, - ) - - def _update_waveform_loader(self): - if self._traces is not None: - self._waveform_loader.traces = self._traces - else: - self._waveform_loader.traces = np.zeros((0, self.n_channels), - dtype=np.float32) - - # Update the list of channels for the waveform loader. - self._waveform_loader.channels = self._channel_order - - def _create_cluster_metadata(self): - self._cluster_metadata = ClusterMetadata() - - @self._cluster_metadata.default - def group(cluster): - # Default group is unsorted. - return 3 - - def _load_meta(self): - """Load metadata from kwik file.""" - # Automatically load all metadata from spikedetekt group. - path = '/application_data/spikedetekt/' - params = {} - for attr in self._kwik.attrs(path): - params[attr] = self._kwik.read_attr(path, attr) - # Make sure all params are there. - default_params = {} - settings = _load_default_settings() - default_params.update(settings['traces']) - default_params.update(settings['spikedetekt']) - default_params.update(settings['klustakwik2']) - for name, default_value in default_params.items(): - if name not in params: - params[name] = default_value - self._metadata = params - - def _load_probe(self): - # Re-create the probe from the Kwik file. - channel_groups = {} - for group in self._channel_groups: - cg_p = '/channel_groups/{:d}'.format(group) - c_p = cg_p + '/channels' - channels = self._kwik.read_attr(cg_p, 'channel_order') - graph = self._kwik.read_attr(cg_p, 'adjacency_graph') - positions = { - channel: self._kwik.read_attr(c_p + '/' + str(channel), - 'position') - for channel in channels - } - channel_groups[group] = { - 'channels': channels, - 'graph': graph, - 'geometry': positions, - } - probe = {'channel_groups': channel_groups} - self._probe = MEA(probe=probe) - - def _load_recordings(self): - # Load recordings. - self._recordings = _list_recordings(self._kwik.h5py_file) - # This will be updated later if a KWD file is present. - self._recording_offsets = [0] * (len(self._recordings) + 1) - - def _load_channels(self): - self._channels = np.array(_list_channels(self._kwik.h5py_file, - self._channel_group)) - self._channel_order = self._probe.channels - assert set(self._channel_order) <= set(self._channels) - - def _load_channel_groups(self, channel_group=None): - self._channel_groups = _list_channel_groups(self._kwik.h5py_file) - if channel_group is None and self._channel_groups: - # Choose the default channel group if not specified. - channel_group = self._channel_groups[0] - # Load the channel group. - self._channel_group = channel_group - - def _load_features_masks(self): - - # Load features masks. - path = '{0:s}/features_masks'.format(self._channel_groups_path) - - nfpc = self._metadata['n_features_per_channel'] - nc = len(self.channel_order) - - if self._kwx is not None: - self._kwx = _open_h5_if_exists(self.kwik_path, 'kwx') - if path not in self._kwx: - debug("There are no features and masks in the `.kwx` file.") - # No need to keep the file open if it is empty. - self._kwx.close() - return - fm = self._kwx.read(path) - self._features_masks = fm - self._features = PartialArray(fm, 0) - - # This partial array simulates a (n_spikes, n_channels) array. - self._masks = PartialArray(fm, (slice(0, nfpc * nc, nfpc), 1)) - assert self._masks.shape == (self.n_spikes, nc) - - def _load_spikes(self): - # Load spike samples. - path = '{0:s}/time_samples'.format(self._spikes_path) - - # Concatenate the spike samples from consecutive recordings. - if path not in self._kwik: - debug("There are no spikes in the dataset.") - return - _spikes = self._kwik.read(path)[:] - self._spike_recordings = self._kwik.read( - '{0:s}/recording'.format(self._spikes_path))[:] - self._spike_samples = _concatenate_spikes(_spikes, - self._spike_recordings, - self._recording_offsets) - - def _load_spike_clusters(self): - self._spike_clusters = self._kwik.read(self._spike_clusters_path)[:] - - def _save_spike_clusters(self, spike_clusters): - assert spike_clusters.shape == self._spike_clusters.shape - assert spike_clusters.dtype == self._spike_clusters.dtype - self._spike_clusters = spike_clusters - sc = self._kwik.read(self._spike_clusters_path) - sc[:] = spike_clusters - - def _load_clusterings(self, clustering=None): - # Once the channel group is loaded, list the clusterings. - self._clusterings = _list_clusterings(self._kwik.h5py_file, - self.channel_group) - # Choose the first clustering (should always be 'main'). - if clustering is None and self.clusterings: - clustering = self.clusterings[0] - # Load the specified clustering. - self._clustering = clustering - - def _load_cluster_groups(self): - clusters = self._kwik.groups(self._clustering_path) - clusters = [int(cluster) for cluster in clusters] - for cluster in clusters: - path = self._cluster_path(cluster) - group = self._kwik.read_attr(path, 'cluster_group') - self._cluster_metadata.set_group([cluster], group) - - def _save_cluster_groups(self, cluster_groups): - assert isinstance(cluster_groups, dict) - for cluster, group in cluster_groups.items(): - path = self._cluster_path(cluster) - self._kwik.write_attr(path, 'cluster_group', group) - self._cluster_metadata.set_group([cluster], group) - - def _load_clustering_metadata(self): - attrs = self._kwik.attrs(self._clustering_path) - metadata = {} - for attr in attrs: - try: - metadata[attr] = self._kwik.read_attr(self._clustering_path, - attr) - except OSError: - debug("Error when reading `{}:{}`.".format( - self._clustering_path, attr)) - self._clustering_metadata = metadata - - def _save_clustering_metadata(self, metadata): - if not metadata: - return - assert isinstance(metadata, dict) - for name, value in metadata.items(): - path = self._clustering_path - self._kwik.write_attr(path, name, value) - self._clustering_metadata.update(metadata) - - def _load_traces(self): - n_channels = self._metadata.get('n_channels', None) - dtype = self._metadata.get('dtype', None) - dtype = np.dtype(dtype) if dtype else None - traces, opened_files = _read_traces(self._kwik, - dtype=dtype, - n_channels=n_channels) - - if not traces: - return - - # Update the list of opened files for cleanup. - self._opened_files.extend(opened_files) - - # Set the recordings offsets (no delay between consecutive recordings). - i = 0 - self._recording_offsets = [] - for trace in traces: - self._recording_offsets.append(i) - i += trace.shape[0] - self._traces = _concatenate_virtual_arrays(traces) - - def has_kwx(self): - """Returns whether the `.kwx` file is present. - - If not, the features and masks won't be available. - - """ - return self._kwx is not None - - def open(self, kwik_path, channel_group=None, clustering=None): - """Open a Kwik dataset. - - The `.kwik` and `.kwx` must be in the same folder with the - same basename. - - The files containing the traces (`.raw.kwd` or `.dat` / `.bin`) are - determined according to the following logic: - - - Is there a path specified to a file which exists in - [KWIK]/recordings/[X]/raw? If so, open it. - - If this file does not exist, does a file exist with the same name - in the current directory? If so, open it. - - If such a file does not exist, or no filename is specified in - the [KWIK], then is there a `raw.kwd` with the experiment basename - in the current directory? If so, open it. - - If not, return with a warning. - - Notes - ----- - - The `.kwik` file is opened in read-only mode, and is automatically - closed when this function returns. It is temporarily reopened when - the channel group or clustering changes. - - The `.kwik` file is temporarily opened in append mode when saving. - - The `.kwx` and `.raw.kwd` or `.dat` / `.bin` files stay open in - read-only mode as long as `model.close()` is not called. This is - because there might be read accesses to `features_masks` (`.kwx`) - and waveforms (`.raw.kwd` or `.dat` / `.bin`) while the dataset is - opened. - - Parameters - ---------- - - kwik_path : str - Path to a `.kwik` file. - channel_group : int or None (default is None) - The channel group (shank) index to use. This can be changed - later after the file has been opened. By default, the first - channel group is used. - clustering : str or None (default is None) - The clustering to use. This can be changed later after the file - has been opened. By default, the `main` clustering is used. An - error is raised if the `main` clustering doesn't exist. - - """ - - if kwik_path is None: - raise ValueError("No kwik_path specified.") - - if not kwik_path.endswith('.kwik'): - raise ValueError("File does not end in .kwik") - - # Open the file. - kwik_path = op.realpath(kwik_path) - self.kwik_path = kwik_path - self.name = op.splitext(op.basename(kwik_path))[0] - - # Open the KWIK file. - self._kwik = _open_h5_if_exists(kwik_path, 'kwik') - if self._kwik is None: - raise IOError("File `{0}` doesn't exist.".format(kwik_path)) - if not self._kwik.is_open(): - raise IOError("File `{0}` failed to open.".format(kwik_path)) - self._check_kwik_version() - - # Open the KWX and KWD files. - self._kwx = _open_h5_if_exists(kwik_path, 'kwx') - if self._kwx is None: - warn("The `.kwx` file hasn't been found. " - "Features and masks won't be available.") - - # KwikCreator instance. - self.creator = KwikCreator(kwik_path=kwik_path) - - # Load the data. - self._load_meta() - - # This needs metadata. - self._create_waveform_loader() - - self._load_recordings() - - # This generates the recording offset. - self._load_traces() - - self._load_channel_groups(channel_group) - - # Load the probe. - self._load_probe() - - # This needs channel groups. - self._load_clusterings(clustering) - - # This needs the recording offsets. - # This loads channels, channel_order, spikes, probe. - self._channel_group_changed(self._channel_group) - - # This loads spike clusters and cluster groups. - self._clustering_changed(clustering) - - # This needs channels, channel_order, and waveform loader. - self._update_waveform_loader() - - # No need to keep the kwik file open. - self._kwik.close() - - def save(self, spike_clusters, cluster_groups, clustering_metadata=None): - """Save the spike clusters and cluster groups in the Kwik file.""" - - # REFACTOR: with() to open/close the file if needed - to_close = self._open_kwik_if_needed(mode='a') - - self._save_spike_clusters(spike_clusters) - self._save_cluster_groups(cluster_groups) - self._save_clustering_metadata(clustering_metadata) - - if to_close: - self._kwik.close() - - def describe(self): - """Display information about the dataset.""" - def _print(name, value): - print("{0: <24}{1}".format(name, value)) - _print("Kwik file", self.kwik_path) - _print("Recordings", self.n_recordings) - - # List of channel groups. - cg = ['{:d}'.format(g) + ('*' if g == self.channel_group else '') - for g in self.channel_groups] - _print("List of shanks", ', '.join(cg)) - - # List of clusterings. - cl = ['{:s}'.format(c) + ('*' if c == self.clustering else '') - for c in self.clusterings] - _print("Clusterings", ', '.join(cl)) - - _print("Channels", self.n_channels) - _print("Spikes", self.n_spikes) - _print("Clusters", self.n_clusters) - _print("Duration", "{:.0f}s".format(self.duration)) - - # Changing channel group and clustering - # ------------------------------------------------------------------------- - - def _channel_group_changed(self, value): - """Called when the channel group changes.""" - if value not in self.channel_groups: - raise ValueError("The channel group {0} is invalid.".format(value)) - self._channel_group = value - - # Load data. - _to_close = self._open_kwik_if_needed() - - if self._kwik.h5py_file: - clusterings = _list_clusterings(self._kwik.h5py_file, - self._channel_group) - else: - warn(".kwik filepath doesn't exist.") - clusterings = None - - if clusterings: - if 'main' in clusterings: - self._load_clusterings('main') - self.clustering = 'main' - else: - self._load_clusterings(clusterings[0]) - self.clustering = clusterings[0] - - self._probe.change_channel_group(value) - self._load_channels() - self._load_spikes() - self._load_features_masks() - - # Recalculate spikes_per_cluster manually - self._spikes_per_cluster = \ - _spikes_per_cluster(self.spike_ids, self._spike_clusters) - - if _to_close: - self._kwik.close() - - # Update the list of channels for the waveform loader. - self._waveform_loader.channels = self._channel_order - - def _clustering_changed(self, value): - """Called when the clustering changes.""" - if value is None: - return - if value not in self.clusterings: - raise ValueError("The clustering {0} is invalid.".format(value)) - self._clustering = value - - # Load data. - _to_close = self._open_kwik_if_needed() - self._create_cluster_metadata() - self._load_spike_clusters() - self._load_cluster_groups() - self._load_clustering_metadata() - if _to_close: - self._kwik.close() - - # Managing cluster groups - # ------------------------------------------------------------------------- - - def _write_cluster_group(self, group_id, name, write_color=True): - if group_id <= 3: - raise ValueError("Default groups cannot be changed.") - - _to_close = self._open_kwik_if_needed(mode='a') - - _create_cluster_group(self._kwik, group_id, name, - clustering=self._clustering, - channel_group=self._channel_group, - write_color=write_color, - ) - - if _to_close: - self._kwik.close() - - def add_cluster_group(self, group_id, name): - """Add a new cluster group.""" - self._write_cluster_group(group_id, name, write_color=True) - - def rename_cluster_group(self, group_id, name): - """Rename an existing cluster group.""" - self._write_cluster_group(group_id, name, write_color=False) - - def delete_cluster_group(self, group_id): - if group_id <= 3: - raise ValueError("Default groups cannot be deleted.") - - path = ('/channel_groups/{0:d}/' - 'cluster_groups/{1:s}/{2:d}').format(self._channel_group, - self._clustering, - group_id, - ) - - _to_close = self._open_kwik_if_needed(mode='a') - - self._kwik.delete(path) - - if _to_close: - self._kwik.close() - - # Managing clusterings - # ------------------------------------------------------------------------- - - def add_clustering(self, name, spike_clusters): - """Save a new clustering to the file.""" - if name in self._clusterings: - raise ValueError("The clustering '{0}' ".format(name) + - "already exists.") - assert len(spike_clusters) == self.n_spikes - - _to_close = self._open_kwik_if_needed(mode='a') - - _create_clustering(self._kwik, - name, - channel_group=self._channel_group, - spike_clusters=spike_clusters, - ) - - # Update the list of clusterings. - self._load_clusterings(self._clustering) - - if _to_close: - self._kwik.close() - - def _move_clustering(self, old_name, new_name, copy=None): - if not copy and old_name == self._clustering: - raise ValueError("You cannot move the current clustering.") - if new_name in self._clusterings: - raise ValueError("The clustering '{0}' ".format(new_name) + - "already exists.") - - _to_close = self._open_kwik_if_needed(mode='a') - - if copy: - func = self._kwik.copy - else: - func = self._kwik.move - - # /channel_groups/x/spikes/clusters/ - p = self._spikes_path + '/clusters/' - func(p + old_name, p + new_name) - - # /channel_groups/x/clusters/ - p = self._clusters_path + '/' - func(p + old_name, p + new_name) - - # /channel_groups/x/cluster_groups/ - p = self._channel_groups_path + '/cluster_groups/' - func(p + old_name, p + new_name) - - # Update the list of clusterings. - self._load_clusterings(self._clustering) - - if _to_close: - self._kwik.close() - - def rename_clustering(self, old_name, new_name): - """Rename a clustering in the `.kwik` file.""" - self._move_clustering(old_name, new_name, copy=False) - - def copy_clustering(self, name, new_name): - """Copy a clustering in the `.kwik` file.""" - self._move_clustering(name, new_name, copy=True) - - def delete_clustering(self, name): - """Delete a clustering.""" - if name == self._clustering: - raise ValueError("You cannot delete the current clustering.") - if name not in self._clusterings: - raise ValueError(("The clustering {0} " - "doesn't exist.").format(name)) - - _to_close = self._open_kwik_if_needed(mode='a') - - # /channel_groups/x/spikes/clusters/ - parent = self._kwik.read(self._spikes_path + '/clusters/') - del parent[name] - - # /channel_groups/x/clusters/ - parent = self._kwik.read(self._clusters_path) - del parent[name] - - # /channel_groups/x/cluster_groups/ - parent = self._kwik.read(self._channel_groups_path + - '/cluster_groups/') - del parent[name] - - # Update the list of clusterings. - self._load_clusterings(self._clustering) - - if _to_close: - self._kwik.close() - - # Data - # ------------------------------------------------------------------------- - - @property - def duration(self): - """Duration of the experiment (in seconds).""" - if self._traces is None: - return 0. - return float(self.traces.shape[0]) / self.sample_rate - - @property - def channel_groups(self): - """List of channel groups found in the Kwik file.""" - return self._channel_groups - - @property - def n_features_per_channel(self): - """Number of features per channel (generally 3).""" - return self._metadata['n_features_per_channel'] - - @property - def channels(self): - """List of all channels in the current channel group. - - This list comes from the /channel_groups HDF5 group in the Kwik file. - - """ - # TODO: rename to channel_ids? - return self._channels - - @property - def channel_order(self): - """List of kept channels in the current channel group. - - If you want the channels used in the features, masks, and waveforms, - this is the property you want to use, and not `model.channels`. - - The channel order is the same than the one from the PRB file. - This order was used when generating the features and masks - in SpikeDetekt2. The same order is used in phy when loading the - waveforms from the traces file(s). - - """ - return self._channel_order - - @property - def n_channels(self): - """Number of all channels in the current channel group.""" - return len(self._channels) - - @property - def recordings(self): - """List of recording indices found in the Kwik file.""" - return self._recordings - - @property - def n_recordings(self): - """Number of recordings found in the Kwik file.""" - return len(self._recordings) - - @property - def clusterings(self): - """List of clusterings found in the Kwik file. - - The first one is always `main`. - - """ - return self._clusterings - - @property - def clustering(self): - """The currently-active clustering. - - Default is `main`. - - """ - return self._clustering - - @clustering.setter - def clustering(self, value): - """Change the currently-active clustering.""" - self._clustering_changed(value) - - @property - def clustering_metadata(self): - """A dictionary of key-value metadata specific to the current - clustering.""" - return self._clustering_metadata - - @property - def metadata(self): - """A dictionary holding metadata about the experiment. - - This information comes from the PRM file. It was used by - SpikeDetekt2 and KlustaKwik during automatic clustering. - - """ - return self._metadata - - @property - def probe(self): - """A `Probe` instance representing the probe used for the recording. - - This object contains information about the adjacency graph and - the channel positions. - - """ - return self._probe - - @property - def traces(self): - """Raw traces as found in the traces file(s). - - This object is memory-mapped to the HDF5 file, or `.dat` / `.bin` file, - or both. - - """ - return self._traces - - @property - def spike_samples(self): - """Spike samples from the current channel group. - - This is a NumPy array containing `uint64` values (number of samples - in unit of the sample rate). - - The spike times of all recordings are concatenated. There is no gap - between consecutive recordings, currently. - - """ - return self._spike_samples - - @property - def sample_rate(self): - """Sample rate of the recording. - - This value is found in the metadata coming from the PRM file. - - """ - return float(self._metadata['sample_rate']) - - @property - def spike_recordings(self): - """The recording index for each spike. - - This is a NumPy array of integers with `n_spikes` elements. - - """ - return self._spike_recordings - - @property - def n_spikes(self): - """Number of spikes in the current channel group.""" - return (len(self._spike_samples) - if self._spike_samples is not None else 0) - - @property - def features(self): - """Features from the current channel group. - - This is memory-mapped to the `.kwx` file. - - Note: in general, it is better to use the cluster store to access - the features and masks of some clusters. - - """ - return self._features - - @property - def masks(self): - """Masks from the current channel group. - - This is memory-mapped to the `.kwx` file. - - Note: in general, it is better to use the cluster store to access - the features and masks of some clusters. - - """ - return self._masks - - @property - def features_masks(self): - """Features-masks from the current channel group. - - This is memory-mapped to the `.kwx` file. - - Note: in general, it is better to use the cluster store to access - the features and masks of some clusters. - - """ - return self._features_masks - - @property - def waveforms(self): - """High-passed filtered waveforms from the current channel group. - - This is a virtual array mapped to the traces file(s). Filtering is - done on the fly. - - The shape is `(n_spikes, n_samples, n_channels)`. - - """ - return SpikeLoader(self._waveform_loader, self.spike_samples) - - @property - def spike_clusters(self): - """Spike clusters from the current channel group and clustering. - - Every element is the cluster identifier of a spike. - - The shape is `(n_spikes,)`. - - """ - return self._spike_clusters - - @property - def spikes_per_cluster(self): - """Spikes per cluster from the current channel group and clustering.""" - if self._spikes_per_cluster is None: - if self._spike_clusters is None: - self._spikes_per_cluster = {0: self.spike_ids} - else: - self._spikes_per_cluster = \ - _spikes_per_cluster(self.spike_ids, self._spike_clusters) - return self._spikes_per_cluster - - def update_spikes_per_cluster(self, spc): - self._spikes_per_cluster = spc - - @property - def cluster_metadata(self): - """Metadata about the clusters in the current channel group and - clustering. - - `cluster_metadata.group(cluster_id)` returns the group of a given - cluster. The default group is 3 (unsorted). - - """ - return self._cluster_metadata - - @property - def cluster_ids(self): - """List of cluster ids from the current channel group and clustering. - - This is a sorted list of unique cluster ids as found in the current - `spike_clusters` array. - - """ - return _unique(self._spike_clusters) - - @property - def spike_ids(self): - """List of spike ids.""" - return np.arange(self.n_spikes, dtype=np.int32) - - @property - def n_clusters(self): - """Number of clusters in the current channel group and clustering.""" - return len(self.cluster_ids) - - # Close - # ------------------------------------------------------------------------- - - def close(self): - """Close the `.kwik` and `.kwx` files if they are open, and cleanup - handles to all raw data files""" - - debug("Closing files") - if self._kwx is not None: - self._kwx.close() - for f in self._opened_files: - # upside-down if statement to avoid false positive lint error - if not (isinstance(f, np.ndarray)): - f.close() - else: - del f - self._kwik.close() diff --git a/phy/io/kwik/sparse_kk2.py b/phy/io/kwik/sparse_kk2.py deleted file mode 100644 index a28257b6f..000000000 --- a/phy/io/kwik/sparse_kk2.py +++ /dev/null @@ -1,75 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Convert data to sparse structures for KK2.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import numpy as np - -from ...utils.array import chunk_bounds -import six - - -#------------------------------------------------------------------------------ -# KlustaKwik2 functions -#------------------------------------------------------------------------------ - -def sparsify_features_masks(features, masks, chunk_size=10000): - from klustakwik2 import RawSparseData - - assert features.ndim == 2 - assert masks.ndim == 2 - assert features.shape == masks.shape - - n_spikes, num_features = features.shape - - # Stage 1: read min/max of fet values for normalization - # and count total number of unmasked features. - vmin = np.ones(num_features) * np.inf - vmax = np.ones(num_features) * (-np.inf) - total_unmasked_features = 0 - for _, _, i, j in chunk_bounds(n_spikes, chunk_size): - f, m = features[i:j], masks[i:j] - inds = m > 0 - # Replace the masked values by NaN. - vmin = np.minimum(np.min(f, axis=0), vmin) - vmax = np.maximum(np.max(f, axis=0), vmax) - total_unmasked_features += inds.sum() - # Stage 2: read data line by line, normalising - vdiff = vmax - vmin - vdiff[vdiff == 0] = 1 - fetsum = np.zeros(num_features) - fet2sum = np.zeros(num_features) - nsum = np.zeros(num_features) - all_features = np.zeros(total_unmasked_features) - all_fmasks = np.zeros(total_unmasked_features) - all_unmasked = np.zeros(total_unmasked_features, dtype=int) - offsets = np.zeros(n_spikes + 1, dtype=int) - curoff = 0 - for i in six.moves.range(n_spikes): - fetvals, fmaskvals = (features[i] - vmin) / vdiff, masks[i] - inds = (fmaskvals > 0).nonzero()[0] - masked_inds = (fmaskvals == 0).nonzero()[0] - all_features[curoff:curoff + len(inds)] = fetvals[inds] - all_fmasks[curoff:curoff + len(inds)] = fmaskvals[inds] - all_unmasked[curoff:curoff + len(inds)] = inds - offsets[i] = curoff - curoff += len(inds) - fetsum[masked_inds] += fetvals[masked_inds] - fet2sum[masked_inds] += fetvals[masked_inds] ** 2 - nsum[masked_inds] += 1 - offsets[-1] = curoff - - nsum[nsum == 0] = 1 - noise_mean = fetsum / nsum - noise_variance = fet2sum / nsum - noise_mean ** 2 - - return RawSparseData(noise_mean, - noise_variance, - all_features, - all_fmasks, - all_unmasked, - offsets, - ) diff --git a/phy/io/kwik/store_items.py b/phy/io/kwik/store_items.py deleted file mode 100644 index f8d8f8f5f..000000000 --- a/phy/io/kwik/store_items.py +++ /dev/null @@ -1,711 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import print_function - -"""Store items for Kwik.""" - - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import os -import os.path as op - -import numpy as np - -from ...utils.selector import Selector -from ...utils.array import (_index_of, - _spikes_per_cluster, - _concatenate_per_cluster_arrays, - ) -from ..store import ClusterStore, FixedSizeItem, VariableSizeItem - - -#------------------------------------------------------------------------------ -# Utility functions -#------------------------------------------------------------------------------ - -def _default_array(shape, value=0, n_spikes=0, dtype=np.float32): - shape = (n_spikes,) + shape[1:] - out = np.empty(shape, dtype=dtype) - out.fill(value) - return out - - -def _atleast_nd(arr, ndim): - if arr.ndim == ndim: - return arr - assert arr.ndim < ndim - if ndim - arr.ndim == 1: - return arr[None, ...] - elif ndim - arr.ndim == 2: - return arr[None, None, ...] - - -def _mean(arr, shape): - if arr is not None: - # Make sure (possibly lazy) memmapped arrays are fully loaded. - arr = arr[...] - assert isinstance(arr, np.ndarray) - if arr.shape[0]: - return arr.mean(axis=0) - return np.zeros(shape, dtype=np.float32) - - -#------------------------------------------------------------------------------ -# Store items -#------------------------------------------------------------------------------ - -class FeatureMasks(VariableSizeItem): - """Store all features and masks of all clusters.""" - name = 'features and masks' - fields = ['features', 'masks'] - - def __init__(self, *args, **kwargs): - # Size of the chunk used when reading features and masks from the HDF5 - # .kwx file. - self.chunk_size = kwargs.pop('chunk_size') - - super(FeatureMasks, self).__init__(*args, **kwargs) - - self.n_features = self.model.n_features_per_channel - self.n_channels = len(self.model.channel_order) - self.n_spikes = self.model.n_spikes - self.n_chunks = self.n_spikes // self.chunk_size + 1 - self._shapes['masks'] = (-1, self.n_channels) - self._shapes['features'] = (-1, self.n_channels, self.n_features) - - def _store(self, - cluster, - chunk_spikes, - chunk_spikes_per_cluster, - chunk_features_masks, - ): - - nc = self.n_channels - nf = self.n_features - - # Number of spikes in the cluster and in the current - # chunk. - ns = len(chunk_spikes_per_cluster[cluster]) - - # Find the indices of the spikes in that cluster - # relative to the chunk. - idx = _index_of(chunk_spikes_per_cluster[cluster], chunk_spikes) - - # Extract features and masks for that cluster, in the - # current chunk. - tmp = chunk_features_masks[idx, :] - - # NOTE: channel order has already been taken into account - # by SpikeDetekt2 when saving the features and masks. - # All we need to know here is the number of channels - # in channel_order, there is no need to reorder. - - # Features. - f = tmp[:, :nc * nf, 0] - assert f.shape == (ns, nc * nf) - f = f.ravel().astype(np.float32) - - # Masks. - m = tmp[:, :nc * nf, 1][:, ::nf] - assert m.shape == (ns, nc) - m = m.ravel().astype(np.float32) - - # Save the data to disk. - self.disk_store.store(cluster, - features=f, - masks=m, - append=True, - ) - - def mean_masks(self, cluster): - masks = self.cluster_store.masks(cluster) - mean_masks = _mean(masks, (self.n_channels,)) - return mean_masks - - def mean_features(self, cluster): - features = self.cluster_store.features(cluster) - mean_features = _mean(features, (self.n_channels,)) - return mean_features - - def _store_means(self, cluster): - if not self.disk_store: - return - self.disk_store.store(cluster, - mean_masks=self.mean_masks(cluster), - mean_features=self.mean_features(cluster), - ) - - def is_consistent(self, cluster, spikes): - """Return whether the filesizes of the two cluster store files - (`.features` and `.masks`) are correct.""" - cluster_size = len(spikes) - expected_file_sizes = [('masks', (cluster_size * - self.n_channels * - 4)), - ('features', (cluster_size * - self.n_channels * - self.n_features * - 4))] - for name, expected_file_size in expected_file_sizes: - path = self.disk_store._cluster_path(cluster, name) - if not op.exists(path): - return False - actual_file_size = os.stat(path).st_size - if expected_file_size != actual_file_size: - return False - return True - - def store_all(self, mode=None): - """Store the features, masks, and their means of the clusters that - need it. - - Parameters - ---------- - - mode : str or None - How to choose whether cluster files need to be re-generated. - Can be one of the following options: - - * `None` or `default`: only regenerate the missing or inconsistent - clusters - * `force`: fully regenerate all clusters - * `read-only`: just load the existing files, do not write anything - - """ - - # No need to regenerate the cluster store if it exists and is valid. - clusters_to_generate = self.to_generate(mode=mode) - need_generate = len(clusters_to_generate) > 0 - - if need_generate: - - self._pr.value_max = self.n_chunks + 1 - - fm = self.model.features_masks - assert fm.shape[0] == self.n_spikes - - for i in range(self.n_chunks): - a, b = i * self.chunk_size, (i + 1) * self.chunk_size - - # Load a chunk from HDF5. - chunk_features_masks = fm[a:b] - assert isinstance(chunk_features_masks, np.ndarray) - if chunk_features_masks.shape[0] == 0: - break - - chunk_spike_clusters = self.model.spike_clusters[a:b] - chunk_spikes = np.arange(a, b) - - # Split the spikes. - chunk_spc = _spikes_per_cluster(chunk_spikes, - chunk_spike_clusters) - - # Go through the clusters appearing in the chunk and that - # need to be re-generated. - clusters = (set(chunk_spc.keys()). - intersection(set(clusters_to_generate))) - for cluster in sorted(clusters): - self._store(cluster, - chunk_spikes, - chunk_spc, - chunk_features_masks, - ) - - # Update the progress reporter. - self._pr.value += 1 - - # Store mean features and masks on disk. - self._pr.value = 0 - self._pr.value_max = len(clusters_to_generate) - for cluster in clusters_to_generate: - self._store_means(cluster) - self._pr.value += 1 - - self._pr.set_complete() - - def load(self, cluster, name): - """Load features or masks for a cluster. - - This uses the cluster store if possible, otherwise it falls back - to the model (much slower). - - """ - assert name in ('features', 'masks') - dtype = np.float32 - shape = self._shapes[name] - if self.disk_store: - data = self.disk_store.load(cluster, name, dtype, shape) - if data is not None: - return data - # Fallback to load_spikes if the data could not be obtained from - # the store. - spikes = self.spikes_per_cluster[cluster] - return self.load_spikes(spikes, name) - - def load_spikes(self, spikes, name): - """Load features or masks for an array of spikes.""" - assert name in ('features', 'masks') - shape = self._shapes[name] - data = getattr(self.model, name) - if data is not None and (isinstance(spikes, slice) or len(spikes)): - out = data[spikes] - # Default masks and features. - else: - out = _default_array(shape, - value=0. if name == 'features' else 1., - n_spikes=len(spikes), - ) - return out.reshape(shape) - - def on_merge(self, up): - """Create the cluster store files of the merged cluster - from the files of the old clusters. - - This is basically a concatenation of arrays, but the spike order - needs to be taken into account. - - """ - if not self.disk_store: - return - clusters = up.deleted - spc = up.old_spikes_per_cluster - # We load all masks and features of the merged clusters. - for name, shape in [('features', - (-1, self.n_channels, self.n_features)), - ('masks', - (-1, self.n_channels)), - ]: - arrays = {cluster: self.disk_store.load(cluster, - name, - dtype=np.float32, - shape=shape) - for cluster in clusters} - # Then, we concatenate them using the right insertion order - # as defined by the spikes. - - # OPTIM: this could be made a bit faster by passing - # both arrays at once. - concat = _concatenate_per_cluster_arrays(spc, arrays) - - # Finally, we store the result into the new cluster. - self.disk_store.store(up.added[0], **{name: concat}) - - def on_assign(self, up): - """Create the cluster store files of the new clusters - from the files of the old clusters. - - The files of all old clusters are loaded, re-split and concatenated - to form the new cluster files. - - """ - if not self.disk_store: - return - for name, shape in [('features', - (-1, self.n_channels, self.n_features)), - ('masks', - (-1, self.n_channels)), - ]: - # Load all data from the old clusters. - old_arrays = {cluster: self.disk_store.load(cluster, - name, - dtype=np.float32, - shape=shape) - for cluster in up.deleted} - # Create the new arrays. - for new in up.added: - # Find the old clusters which are parents of the current - # new cluster. - old_clusters = [o - for (o, n) in up.descendants - if n == new] - # Spikes per old cluster, used to create - # the concatenated array. - spc = {} - old_arrays_sub = {} - # Find the relative spike indices of every old cluster - # for the current new cluster. - for old in old_clusters: - # Find the spike indices in the old and new cluster. - old_spikes = up.old_spikes_per_cluster[old] - new_spikes = up.new_spikes_per_cluster[new] - old_in_new = np.in1d(old_spikes, new_spikes) - old_spikes_subset = old_spikes[old_in_new] - spc[old] = old_spikes_subset - # Extract the data from the old cluster to - # be moved to the new cluster. - old_spikes_rel = _index_of(old_spikes_subset, - old_spikes) - old_arrays_sub[old] = old_arrays[old][old_spikes_rel] - # Construct the array of the new cluster. - concat = _concatenate_per_cluster_arrays(spc, - old_arrays_sub) - # Save it in the cluster store. - self.disk_store.store(new, **{name: concat}) - - def on_cluster(self, up=None): - super(FeatureMasks, self).on_cluster(up) - # No need to change anything in the store if this is an undo or - # a redo. - if up is None or up.history is not None: - return - # Store the means of the new clusters. - for cluster in up.added: - self._store_means(cluster) - - -class Waveforms(VariableSizeItem): - """A cluster store item that manages the waveforms of all clusters.""" - name = 'waveforms' - fields = ['waveforms'] - - def __init__(self, *args, **kwargs): - self.n_spikes_max = kwargs.pop('n_spikes_max') - self.excerpt_size = kwargs.pop('excerpt_size') - # Size of the chunk used when reading waveforms from the raw data. - self.chunk_size = kwargs.pop('chunk_size') - - super(Waveforms, self).__init__(*args, **kwargs) - - self.n_channels = len(self.model.channel_order) - self.n_samples = self.model.n_samples_waveforms - self.n_spikes = self.model.n_spikes - - self._shapes['waveforms'] = (-1, self.n_samples, self.n_channels) - self._selector = Selector(self.model.spike_clusters, - n_spikes_max=self.n_spikes_max, - excerpt_size=self.excerpt_size, - ) - - # Get or create the subset spikes per cluster dictionary. - spc = (self.disk_store.load_file('waveforms_spikes') - if self.disk_store else None) - if spc is None: - spc = self._selector.subset_spikes_clusters(self.model.cluster_ids) - if self.disk_store: - self.disk_store.save_file('waveforms_spikes', spc) - self._spikes_per_cluster = spc - - def _subset_spikes_cluster(self, cluster, force=False): - if force or cluster not in self._spikes_per_cluster: - spikes = self._selector.subset_spikes_clusters([cluster])[cluster] - # Persist the new _spikes_per_cluster array on disk. - self._spikes_per_cluster[cluster] = spikes - if self.disk_store: - self.disk_store.save_file('waveforms_spikes', - self._spikes_per_cluster) - return self._spikes_per_cluster[cluster] - - @property - def spikes_per_cluster(self): - """Spikes per cluster.""" - # WARNING: this is read-only, because this item is responsible - # for the spike subselection. - return self._spikes_per_cluster - - def _store_mean(self, cluster): - if not self.disk_store: - return - self.disk_store.store(cluster, - mean_waveforms=self.mean_waveforms(cluster), - ) - - def waveforms_and_mean(self, cluster): - spikes = self._subset_spikes_cluster(cluster, force=True) - waveforms = self.model.waveforms[spikes].astype(np.float32) - mean_waveforms = _mean(waveforms, (self.n_samples, self.n_channels)) - return waveforms, mean_waveforms - - def mean_waveforms(self, cluster): - return self.waveforms_and_mean(cluster)[1] - - def store(self, cluster): - """Store waveforms and mean waveforms.""" - if not self.disk_store: - return - # NOTE: make sure to erase old spikes for that cluster. - # Typical case merge, undo, different merge. - waveforms, mean_waveforms = self.waveforms_and_mean(cluster) - self.disk_store.store(cluster, - waveforms=waveforms, - mean_waveforms=mean_waveforms, - ) - - def store_all(self, mode=None): - if not self.disk_store: - return - - clusters_to_generate = self.to_generate(mode=mode) - need_generate = len(clusters_to_generate) > 0 - - if need_generate: - spc = {cluster: self._subset_spikes_cluster(cluster, force=True) - for cluster in clusters_to_generate} - - # All spikes to fetch and save in the store. - spike_ids = _concatenate_per_cluster_arrays(spc, spc) - - # Load waveforms chunk by chunk for I/O contiguity. - n_chunks = len(spike_ids) // self.chunk_size + 1 - self._pr.value_max = n_chunks + 1 - - for i in range(n_chunks): - a, b = i * self.chunk_size, (i + 1) * self.chunk_size - spk = spike_ids[a:b] - - # Load a chunk of waveforms. - chunk_waveforms = self.model.waveforms[spk] - assert isinstance(chunk_waveforms, np.ndarray) - if chunk_waveforms.shape[0] == 0: - break - - chunk_spike_clusters = self.model.spike_clusters[spk] - - # Split the spikes. - chunk_spc = _spikes_per_cluster(spk, chunk_spike_clusters) - - # Go through the clusters appearing in the chunk and that - # need to be re-generated. - clusters = (set(chunk_spc.keys()). - intersection(set(clusters_to_generate))) - for cluster in sorted(clusters): - i = _index_of(chunk_spc[cluster], spk) - w = chunk_waveforms[i].astype(np.float32) - self.disk_store.store(cluster, - waveforms=w, - append=True - ) - - self._pr.increment() - - # Store mean waveforms on disk. - self._pr.value = 0 - self._pr.value_max = len(clusters_to_generate) - for cluster in clusters_to_generate: - self._store_mean(cluster) - self._pr.value += 1 - - def is_consistent(self, cluster, spikes): - """Return whether the waveforms and spikes match.""" - path_w = self.disk_store._cluster_path(cluster, 'waveforms') - if not op.exists(path_w): - return False - file_size_w = os.stat(path_w).st_size - n_spikes_w = file_size_w // (self.n_channels * self.n_samples * 4) - if n_spikes_w != len(self.spikes_per_cluster[cluster]): - return False - return True - - def load(self, cluster, name='waveforms'): - """Load waveforms for a cluster. - - This uses the cluster store if possible, otherwise it falls back - to the model (much slower). - - """ - assert name == 'waveforms' - dtype = np.float32 - if self.disk_store: - data = self.disk_store.load(cluster, - name, - dtype, - self._shapes[name], - ) - if data is not None: - return data - # Fallback to load_spikes if the data could not be obtained from - # the store. - spikes = self._subset_spikes_cluster(cluster) - return self.load_spikes(spikes, name) - - def load_spikes(self, spikes, name): - """Load waveforms for an array of spikes.""" - assert name == 'waveforms' - data = getattr(self.model, name) - shape = self._shapes[name] - if data is not None and (isinstance(spikes, slice) or len(spikes)): - return data[spikes] - # Default waveforms. - return _default_array(shape, value=0., n_spikes=len(spikes)) - - -class ClusterStatistics(FixedSizeItem): - """Manage cluster statistics.""" - name = 'statistics' - - def __init__(self, *args, **kwargs): - super(ClusterStatistics, self).__init__(*args, **kwargs) - self._funcs = {} - self.n_channels = len(self.model.channel_order) - self.n_samples_waveforms = self.model.n_samples_waveforms - self.n_features = self.model.n_features_per_channel - self._shapes = { - 'mean_masks': (-1, self.n_channels), - 'mean_features': (-1, self.n_channels, self.n_features), - 'mean_waveforms': (-1, self.n_samples_waveforms, self.n_channels), - 'mean_probe_position': (-1, 2), - 'main_channels': (-1, self.n_channels), - 'n_unmasked_channels': (-1,), - 'n_spikes': (-1,), - } - self.fields = list(self._shapes.keys()) - - def add(self, name, func, shape): - """Add a new statistics.""" - self.fields.append(name) - self._funcs[name] = func - self._shapes[name] = shape - - def remove(self, name): - """Remove a statistics.""" - self.fields.remove(name) - del self.funcs[name] - - def _load_mean(self, cluster, name): - if self.disk_store: - # Load from the disk store if possible. - return self.disk_store.load(cluster, name, np.float32, - self._shapes[name][1:]) - else: - # Otherwise compute the mean directly from the model. - item_name = ('waveforms' - if name == 'mean_waveforms' - else 'features and masks') - item = self.cluster_store.items[item_name] - return getattr(item, name)(cluster) - - def mean_masks(self, cluster): - return self._load_mean(cluster, 'mean_masks') - - def mean_features(self, cluster): - return self._load_mean(cluster, 'mean_features') - - def mean_waveforms(self, cluster): - return self._load_mean(cluster, 'mean_waveforms') - - def unmasked_channels(self, cluster): - mean_masks = self.load(cluster, 'mean_masks') - return np.nonzero(mean_masks > .1)[0] - - def n_unmasked_channels(self, cluster): - unmasked_channels = self.load(cluster, 'unmasked_channels') - return len(unmasked_channels) - - def mean_probe_position(self, cluster): - mean_masks = self.load(cluster, 'mean_masks') - if mean_masks is not None and mean_masks.shape[0]: - mean_cluster_position = (np.sum(self.model.probe.positions * - mean_masks[:, np.newaxis], axis=0) / - max(1, np.sum(mean_masks))) - else: - mean_cluster_position = np.zeros((2,), dtype=np.float32) - - return mean_cluster_position - - def n_spikes(self, cluster): - return len(self._spikes_per_cluster[cluster]) - - def main_channels(self, cluster): - mean_masks = self.load(cluster, 'mean_masks') - unmasked_channels = self.load(cluster, 'unmasked_channels') - # Weighted mean of the channels, weighted by the mean masks. - main_channels = np.argsort(mean_masks)[::-1] - main_channels = np.array([c for c in main_channels - if c in unmasked_channels]) - return main_channels - - def store_default(self, cluster): - """Compute the built-in statistics for one cluster. - - The mean masks, features, and waveforms are loaded from disk. - - """ - self.memory_store.store(cluster, - mean_masks=self.mean_masks(cluster), - mean_features=self.mean_features(cluster), - mean_waveforms=self.mean_waveforms(cluster), - ) - - n_spikes = self.n_spikes(cluster) - - # Note: some of the default statistics below rely on other statistics - # computed previously and stored in the memory store. - unmasked_channels = self.unmasked_channels(cluster) - self.memory_store.store(cluster, - unmasked_channels=unmasked_channels) - - n_unmasked_channels = self.n_unmasked_channels(cluster) - self.memory_store.store(cluster, - n_unmasked_channels=n_unmasked_channels) - - mean_probe_position = self.mean_probe_position(cluster) - main_channels = self.main_channels(cluster) - n_unmasked_channels = self.n_unmasked_channels(cluster) - self.memory_store.store(cluster, - mean_probe_position=mean_probe_position, - main_channels=main_channels, - n_unmasked_channels=n_unmasked_channels, - n_spikes=n_spikes, - ) - - def store(self, cluster, name=None): - """Compute all statistics for one cluster.""" - if name is None: - self.store_default(cluster) - for func in self._funcs.values(): - func(cluster) - else: - assert name in self._funcs - self._funcs[name](cluster) - - def load(self, cluster, name): - """Return a cluster statistic.""" - if cluster in self.memory_store: - return self.memory_store.load(cluster, name) - else: - # If the item hadn't been stored, compute it here by calling - # the corresponding method. - if hasattr(self, name): - return getattr(self, name)(cluster) - # Custom statistic. - else: - return self._funcs[name](cluster) - - def is_consistent(self, cluster, spikes): - """Return whether a cluster is consistent.""" - return cluster in self.memory_store - - -#------------------------------------------------------------------------------ -# Store creation -#------------------------------------------------------------------------------ - -def create_store(model, - spikes_per_cluster=None, - path=None, - features_masks_chunk_size=100000, - waveforms_n_spikes_max=None, - waveforms_excerpt_size=None, - waveforms_chunk_size=1000, - ): - """Create a cluster store for a model.""" - assert spikes_per_cluster is not None - cluster_store = ClusterStore(model=model, - spikes_per_cluster=spikes_per_cluster, - path=path, - ) - - # Create the FeatureMasks store item. - # chunk_size is the number of spikes to load at once from - # the features_masks array. - cluster_store.register_item(FeatureMasks, - chunk_size=features_masks_chunk_size, - ) - cluster_store.register_item(Waveforms, - n_spikes_max=waveforms_n_spikes_max, - excerpt_size=waveforms_excerpt_size, - chunk_size=waveforms_chunk_size, - ) - cluster_store.register_item(ClusterStatistics) - return cluster_store diff --git a/phy/io/kwik/tests/__init__.py b/phy/io/kwik/tests/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/phy/io/kwik/tests/test_creator.py b/phy/io/kwik/tests/test_creator.py deleted file mode 100644 index 4ca7b7683..000000000 --- a/phy/io/kwik/tests/test_creator.py +++ /dev/null @@ -1,217 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Tests of Kwik file creator.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import os.path as op - -import numpy as np -from numpy.testing import assert_array_equal as ae -from numpy.testing import assert_allclose as ac -from pytest import raises - -from ...h5 import open_h5 -from ..creator import KwikCreator, _write_by_chunk, create_kwik -from ..mock import (artificial_spike_samples, - artificial_features, - artificial_masks, - ) - - -#------------------------------------------------------------------------------ -# Tests -#------------------------------------------------------------------------------ - -def test_write_by_chunk(tempdir): - n = 5 - arrs = [np.random.rand(i + 1, 3).astype(np.float32) for i in range(n)] - n_tot = sum(_.shape[0] for _ in arrs) - - path = op.join(tempdir, 'test.h5') - with open_h5(path, 'w') as f: - ds = f.write('/test', shape=(n_tot, 3), dtype=np.float32) - _write_by_chunk(ds, arrs) - with open_h5(path, 'r') as f: - ds = f.read('/test')[...] - offset = 0 - for i, arr in enumerate(arrs): - size = arr.shape[0] - assert size == (i + 1) - ac(ds[offset:offset + size, ...], arr) - offset += size - - -def test_creator_simple(tempdir): - basename = op.join(tempdir, 'my_file') - - creator = KwikCreator(basename) - - # Test create empty files. - creator.create_empty() - assert op.exists(basename + '.kwik') - assert op.exists(basename + '.kwx') - - # Test metadata. - creator.set_metadata('/application_data/spikedetekt', - a=1, b=2., c=[0, 1]) - - with open_h5(creator.kwik_path, 'r') as f: - assert f.read_attr('/application_data/spikedetekt', 'a') == 1 - assert f.read_attr('/application_data/spikedetekt', 'b') == 2. - ae(f.read_attr('/application_data/spikedetekt', 'c'), [0, 1]) - - # Test add spikes in one block. - n_spikes = 100 - n_channels = 8 - n_features = 3 - - spike_samples = artificial_spike_samples(n_spikes) - features = artificial_features(n_spikes, n_channels, n_features) - masks = artificial_masks(n_spikes, n_channels) - - creator.add_spikes(group=0, - spike_samples=spike_samples, - features=features.astype(np.float32), - masks=masks.astype(np.float32), - n_channels=n_channels, - n_features=n_features, - ) - - # Test the spike samples. - with open_h5(creator.kwik_path, 'r') as f: - s = f.read('/channel_groups/0/spikes/time_samples')[...] - assert s.dtype == np.uint64 - ac(s, spike_samples) - - # Test the features and masks. - with open_h5(creator.kwx_path, 'r') as f: - fm = f.read('/channel_groups/0/features_masks')[...] - assert fm.dtype == np.float32 - ac(fm[:, :, 0], features.reshape((-1, n_channels * n_features))) - ac(fm[:, ::n_features, 1], masks) - - # Spikes can only been added once. - with raises(RuntimeError): - creator.add_spikes(group=0, - spike_samples=spike_samples, - n_channels=n_channels, - n_features=n_features) - - -def test_creator_chunks(tempdir): - basename = op.join(tempdir, 'my_file') - - creator = KwikCreator(basename) - creator.create_empty() - - # Test add spikes in one block. - n_spikes = 100 - n_channels = 8 - n_features = 3 - - spike_samples = artificial_spike_samples(n_spikes) - features = artificial_features(n_spikes, n_channels, - n_features).astype(np.float32) - masks = artificial_masks(n_spikes, n_channels).astype(np.float32) - - def _split(arr): - n = n_spikes // 10 - return [arr[k:k + n, ...] for k in range(0, n_spikes, n)] - - creator.add_spikes(group=0, - spike_samples=spike_samples, - features=_split(features), - masks=_split(masks), - n_channels=n_channels, - n_features=n_features, - ) - - # Test the spike samples. - with open_h5(creator.kwik_path, 'r') as f: - s = f.read('/channel_groups/0/spikes/time_samples')[...] - assert s.dtype == np.uint64 - ac(s, spike_samples) - - # Test the features and masks. - with open_h5(creator.kwx_path, 'r') as f: - fm = f.read('/channel_groups/0/features_masks')[...] - assert fm.dtype == np.float32 - ac(fm[:, :, 0], features.reshape((-1, n_channels * n_features))) - ac(fm[:, ::n_features, 1], masks) - - -def test_creator_add_no_spikes(tempdir): - basename = op.join(tempdir, 'my_file') - - creator = KwikCreator(basename) - creator.create_empty() - - creator.add_spikes(group=0, n_channels=4, n_features=2) - - -def test_creator_metadata(tempdir): - basename = op.join(tempdir, 'my_file') - - creator = KwikCreator(basename) - - # Add recording. - creator.add_recording(0, start_sample=20000, sample_rate=10000) - - with open_h5(creator.kwik_path, 'r') as f: - assert f.read_attr('/recordings/0', 'start_sample') == 20000 - assert f.read_attr('/recordings/0', 'start_time') == 2. - assert f.read_attr('/recordings/0', 'sample_rate') == 10000 - - # Add probe. - channels = [0, 3, 1] - graph = [[0, 3], [1, 0]] - probe = {'channel_groups': { - 0: {'channels': channels, - 'graph': graph, - 'geometry': {0: (10, 10)}, - }}} - creator.set_probe(probe) - - with open_h5(creator.kwik_path, 'r') as f: - ae(f.read_attr('/channel_groups/0', 'channel_order'), channels) - ae(f.read_attr('/channel_groups/0', 'adjacency_graph'), graph) - for channel in channels: - path = '/channel_groups/0/channels/{:d}'.format(channel) - position = (10, 10) if channel == 0 else (0, channel) - ae(f.read_attr(path, 'position'), position) - - # Add clustering. - creator.add_clustering(group=0, - name='main', - spike_clusters=np.arange(10), - cluster_groups={3: 1}, - ) - - with open_h5(creator.kwik_path, 'r') as f: - sc = f.read('/channel_groups/0/spikes/clusters/main') - ae(sc, np.arange(10)) - - for cluster in range(10): - path = '/channel_groups/0/clusters/main/{:d}'.format(cluster) - cg = f.read_attr(path, 'cluster_group') - assert cg == 3 if cluster != 3 else 1 - - -def test_create_kwik(tempdir): - - channels = [0, 3, 1] - graph = [[0, 3], [1, 0]] - probe = {'channel_groups': { - 0: {'channels': channels, - 'graph': graph, - 'geometry': {0: (10, 10)}, - }}} - - kwik_path = op.join(tempdir, 'test.kwik') - create_kwik(kwik_path=kwik_path, - probe=probe, - sample_rate=20000, - ) diff --git a/phy/io/kwik/tests/test_mock.py b/phy/io/kwik/tests/test_mock.py deleted file mode 100644 index 49563e559..000000000 --- a/phy/io/kwik/tests/test_mock.py +++ /dev/null @@ -1,34 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Tests of mock Kwik file creation.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -from ...h5 import open_h5 -from ..mock import create_mock_kwik - - -#------------------------------------------------------------------------------ -# Tests -#------------------------------------------------------------------------------ - -def test_create_kwik(tempdir): - - n_clusters = 10 - n_spikes = 50 - n_channels = 28 - n_fets = 2 - n_samples_traces = 3000 - - # Create the test HDF5 file in the temporary directory. - filename = create_mock_kwik(tempdir, - n_clusters=n_clusters, - n_spikes=n_spikes, - n_channels=n_channels, - n_features_per_channel=n_fets, - n_samples_traces=n_samples_traces) - - with open_h5(filename) as f: - assert f diff --git a/phy/io/kwik/tests/test_model.py b/phy/io/kwik/tests/test_model.py deleted file mode 100644 index b51ab3933..000000000 --- a/phy/io/kwik/tests/test_model.py +++ /dev/null @@ -1,384 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Tests of Kwik file opening routines.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import os.path as op - -import numpy as np -from numpy.testing import assert_array_equal as ae -from pytest import raises - -from ....electrode.mea import MEA, staggered_positions -from ....utils.logging import StringLogger, register, unregister -from ..model import (KwikModel, - _list_channel_groups, - _list_channels, - _list_recordings, - _list_clusterings, - _concatenate_spikes, - ) -from ..mock import create_mock_kwik -from ..creator import create_kwik - - -#------------------------------------------------------------------------------ -# Tests -#------------------------------------------------------------------------------ - -_N_CLUSTERS = 10 -_N_SPIKES = 100 -_N_CHANNELS = 28 -_N_FETS = 2 -_N_SAMPLES_TRACES = 10000 - - -def test_kwik_utility(tempdir): - - channels = list(range(_N_CHANNELS)) - - # Create the test HDF5 file in the temporary directory. - filename = create_mock_kwik(tempdir, - n_clusters=_N_CLUSTERS, - n_spikes=_N_SPIKES, - n_channels=_N_CHANNELS, - n_features_per_channel=_N_FETS, - n_samples_traces=_N_SAMPLES_TRACES) - model = KwikModel(filename) - - model._kwik.open() - assert _list_channel_groups(model._kwik.h5py_file) == [1] - assert _list_recordings(model._kwik.h5py_file) == [0, 1] - assert _list_clusterings(model._kwik.h5py_file, 1) == ['main', - 'original', - ] - assert _list_channels(model._kwik.h5py_file, 1) == channels - - -def test_concatenate_spikes(): - spikes = [2, 3, 5, 0, 11, 1] - recs = [0, 0, 0, 1, 1, 2] - offsets = [0, 7, 100] - concat = _concatenate_spikes(spikes, recs, offsets) - ae(concat, [2, 3, 5, 7, 18, 101]) - - -def test_kwik_empty(tempdir): - - channels = [0, 3, 1] - graph = [[0, 3], [1, 0]] - probe = {'channel_groups': { - 0: {'channels': channels, - 'graph': graph, - 'geometry': {0: (10, 10)}, - }}} - sample_rate = 20000 - - kwik_path = op.join(tempdir, 'test.kwik') - create_kwik(kwik_path=kwik_path, probe=probe, sample_rate=sample_rate) - - model = KwikModel(kwik_path) - ae(model.channels, sorted(channels)) - ae(model.channel_order, channels) - - assert model.sample_rate == sample_rate - assert model.n_channels == 3 - assert model.spike_samples is None - assert model.has_kwx() - assert model.n_spikes == 0 - assert model.n_clusters == 0 - model.describe() - - -def test_kwik_open_full(tempdir): - - # Create the test HDF5 file in the temporary directory. - filename = create_mock_kwik(tempdir, - n_clusters=_N_CLUSTERS, - n_spikes=_N_SPIKES, - n_channels=_N_CHANNELS, - n_features_per_channel=_N_FETS, - n_samples_traces=_N_SAMPLES_TRACES) - - with raises(ValueError): - KwikModel() - - # NOTE: n_channels - 2 because we use a special channel order. - nc = _N_CHANNELS - 2 - - # Test implicit open() method. - kwik = KwikModel(filename) - kwik.describe() - - kwik.metadata - ae(kwik.channels, np.arange(_N_CHANNELS)) - assert kwik.n_channels == _N_CHANNELS - assert kwik.n_spikes == _N_SPIKES - ae(kwik.channel_order, np.arange(1, _N_CHANNELS - 1)[::-1]) - - assert kwik.spike_samples.shape == (_N_SPIKES,) - assert kwik.spike_samples.dtype == np.uint64 - - # Make sure the spike samples are increasing, even with multiple - # recordings. - # WARNING: need to cast to int64, otherwise negative values will - # overflow and be positive, making the test pass while the - # spike samples are *not* increasing! - assert np.all(np.diff(kwik.spike_samples.astype(np.int64)) >= 0) - - assert kwik.spike_times.shape == (_N_SPIKES,) - assert kwik.spike_times.dtype == np.float64 - - assert kwik.spike_recordings.shape == (_N_SPIKES,) - assert kwik.spike_recordings.dtype == np.uint16 - - assert kwik.spike_clusters.shape == (_N_SPIKES,) - assert kwik.spike_clusters.min() in (0, 1, 2) - assert kwik.spike_clusters.max() in(_N_CLUSTERS - 2, _N_CLUSTERS - 1) - - assert kwik.features.shape == (_N_SPIKES, nc * _N_FETS) - kwik.features[0, ...] - - assert kwik.masks.shape == (_N_SPIKES, nc) - - assert kwik.traces.shape == (_N_SAMPLES_TRACES, _N_CHANNELS) - - assert kwik.waveforms[0].shape == (1, 40, nc) - assert kwik.waveforms[-1].shape == (1, 40, nc) - assert kwik.waveforms[-10].shape == (1, 40, nc) - assert kwik.waveforms[10].shape == (1, 40, nc) - assert kwik.waveforms[[10, 20]].shape == (2, 40, nc) - with raises(IndexError): - kwik.waveforms[_N_SPIKES + 10] - - with raises(ValueError): - kwik.clustering = 'foo' - with raises(ValueError): - kwik.channel_group = 42 - assert kwik.n_recordings == 2 - - # Test cluster groups. - for cluster in range(_N_CLUSTERS): - print(cluster) - assert kwik.cluster_metadata.group(cluster) == min(cluster, 3) - for cluster, group in kwik.cluster_groups.items(): - assert group == min(cluster, 3) - - # Test probe. - assert isinstance(kwik.probe, MEA) - assert kwik.probe.positions.shape == (nc, 2) - ae(kwik.probe.positions, staggered_positions(_N_CHANNELS)[1:-1][::-1]) - - kwik.close() - - -def test_kwik_open_no_kwx(tempdir): - - # Create the test HDF5 file in the temporary directory. - filename = create_mock_kwik(tempdir, - n_clusters=_N_CLUSTERS, - n_spikes=_N_SPIKES, - n_channels=_N_CHANNELS, - n_features_per_channel=_N_FETS, - n_samples_traces=_N_SAMPLES_TRACES, - with_kwx=False) - - # Test implicit open() method. - kwik = KwikModel(filename) - kwik.close() - - -def test_kwik_open_no_kwd(tempdir): - - # Create the test HDF5 file in the temporary directory. - filename = create_mock_kwik(tempdir, - n_clusters=_N_CLUSTERS, - n_spikes=_N_SPIKES, - n_channels=_N_CHANNELS, - n_features_per_channel=_N_FETS, - n_samples_traces=_N_SAMPLES_TRACES, - with_kwd=False) - - # Test implicit open() method. - kwik = KwikModel(filename) - l = StringLogger(level='debug') - register(l) - kwik.waveforms[:] - # Enusure that there is no error message. - assert not str(l).strip() - kwik.close() - unregister(l) - - -def test_kwik_save(tempdir): - - # Create the test HDF5 file in the temporary directory. - filename = create_mock_kwik(tempdir, - n_clusters=_N_CLUSTERS, - n_spikes=_N_SPIKES, - n_channels=_N_CHANNELS, - n_features_per_channel=_N_FETS, - n_samples_traces=_N_SAMPLES_TRACES) - - kwik = KwikModel(filename) - - cluster_groups = {cluster: kwik.cluster_metadata.group(cluster) - for cluster in range(_N_CLUSTERS)} - sc_0 = kwik.spike_clusters.copy() - sc_1 = sc_0.copy() - new_cluster = _N_CLUSTERS + 10 - sc_1[_N_SPIKES // 2:] = new_cluster - cluster_groups[new_cluster] = 7 - ae(kwik.spike_clusters, sc_0) - - assert kwik.cluster_metadata.group(new_cluster) == 3 - kwik.save(sc_1, cluster_groups, {'test': (1, 2.)}) - ae(kwik.spike_clusters, sc_1) - assert kwik.cluster_metadata.group(new_cluster) == 7 - - kwik.close() - - kwik = KwikModel(filename) - ae(kwik.spike_clusters, sc_1) - assert kwik.cluster_metadata.group(new_cluster) == 7 - ae(kwik.clustering_metadata['test'], [1, 2]) - - -def test_kwik_clusterings(tempdir): - - # Create the test HDF5 file in the temporary directory. - filename = create_mock_kwik(tempdir, - n_clusters=_N_CLUSTERS, - n_spikes=_N_SPIKES, - n_channels=_N_CHANNELS, - n_features_per_channel=_N_FETS, - n_samples_traces=_N_SAMPLES_TRACES) - - kwik = KwikModel(filename) - assert kwik.clusterings == ['main', 'original'] - - # The default clustering is 'main'. - assert kwik.n_spikes == _N_SPIKES - assert kwik.n_clusters == _N_CLUSTERS - assert kwik.cluster_groups[_N_CLUSTERS - 1] == 3 - ae(kwik.cluster_ids, np.arange(_N_CLUSTERS)) - - # Change clustering. - kwik.clustering = 'original' - n_clu = kwik.n_clusters - assert kwik.n_spikes == _N_SPIKES - # Some clusters may be empty with a small number of spikes like here - assert _N_CLUSTERS * 2 - 4 <= n_clu <= _N_CLUSTERS * 2 - assert kwik.cluster_groups[n_clu - 1] == 3 - assert len(kwik.cluster_ids) == n_clu - - -def test_kwik_manage_clusterings(tempdir): - - # Create the test HDF5 file in the temporary directory. - filename = create_mock_kwik(tempdir, - n_clusters=_N_CLUSTERS, - n_spikes=_N_SPIKES, - n_channels=_N_CHANNELS, - n_features_per_channel=_N_FETS, - n_samples_traces=_N_SAMPLES_TRACES) - - kwik = KwikModel(filename) - spike_clusters = kwik.spike_clusters - assert kwik.clusterings == ['main', 'original'] - - # Test renaming. - kwik.clustering = 'original' - with raises(ValueError): - kwik.rename_clustering('a', 'b') - with raises(ValueError): - kwik.rename_clustering('original', 'b') - with raises(ValueError): - kwik.rename_clustering('main', 'original') - - kwik.clustering = 'main' - kwik.rename_clustering('original', 'original_2') - assert kwik.clusterings == ['main', 'original_2'] - with raises(ValueError): - kwik.clustering = 'original' - kwik.clustering = 'original_2' - n_clu = kwik.n_clusters - if (n_clu - 1) in kwik.cluster_groups: - assert kwik.cluster_groups[n_clu - 1] == 3 - assert len(kwik.cluster_ids) == n_clu - - # Test copy. - with raises(ValueError): - kwik.copy_clustering('a', 'b') - with raises(ValueError): - kwik.copy_clustering('original', 'b') - with raises(ValueError): - kwik.copy_clustering('main', 'original_2') - - # You cannot move the current clustering, but you can copy it. - with raises(ValueError): - kwik.rename_clustering('original_2', 'original_2_copy') - kwik.copy_clustering('original_2', 'original_2_copy') - kwik.delete_clustering('original_2_copy') - - kwik.clustering = 'main' - kwik.copy_clustering('original_2', 'original') - assert kwik.clusterings == ['main', 'original', 'original_2'] - - kwik.clustering = 'original' - cg = kwik.cluster_groups - ci = kwik.cluster_ids - - kwik.clustering = 'original_2' - assert kwik.cluster_groups == cg - ae(kwik.cluster_ids, ci) - - # Test delete. - with raises(ValueError): - kwik.delete_clustering('a') - kwik.delete_clustering('original') - kwik.clustering = 'main' - kwik.delete_clustering('original_2') - assert kwik.clusterings == ['main', 'original'] - - # Test add. - sc = np.ones(_N_SPIKES, dtype=np.int32) - sc[1] = sc[-2] = 3 - kwik.add_clustering('new', sc) - ae(kwik.spike_clusters, spike_clusters) - kwik.clustering = 'new' - ae(kwik.spike_clusters, sc) - assert kwik.n_clusters == 2 - ae(kwik.cluster_ids, [1, 3]) - assert kwik.cluster_groups == {1: 3, - 3: 3} - - -def test_kwik_manage_cluster_groups(tempdir): - - # Create the test HDF5 file in the temporary directory. - filename = create_mock_kwik(tempdir, - n_clusters=_N_CLUSTERS, - n_spikes=_N_SPIKES, - n_channels=_N_CHANNELS, - n_features_per_channel=_N_FETS, - n_samples_traces=_N_SAMPLES_TRACES) - - kwik = KwikModel(filename) - - with raises(ValueError): - kwik.delete_cluster_group(2) - with raises(ValueError): - kwik.add_cluster_group(1, 'new') - with raises(ValueError): - kwik.rename_cluster_group(1, 'renamed') - - kwik.add_cluster_group(4, 'new') - kwik.rename_cluster_group(4, 'renamed') - - kwik.delete_cluster_group(4) - with raises(ValueError): - kwik.delete_cluster_group(4) diff --git a/phy/io/kwik/tests/test_sparse_kk2.py b/phy/io/kwik/tests/test_sparse_kk2.py deleted file mode 100644 index 6766719a1..000000000 --- a/phy/io/kwik/tests/test_sparse_kk2.py +++ /dev/null @@ -1,89 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Tests of sparse KK2 routines.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import numpy as np -from numpy.testing import assert_array_equal as ae -from numpy.testing import assert_array_almost_equal as aee - -from ..sparse_kk2 import sparsify_features_masks - - -#------------------------------------------------------------------------------ -# Tests -#------------------------------------------------------------------------------ - -def test_sparsify_features_masks(): - # Data as arrays - fet = np.array([[1, 3, 5, 7, 11], - [6, 7, 8, 9, 10], - [11, 12, 13, 14, 15], - [16, 17, 18, 19, 20]], dtype=float) - - fmask = np.array([[1, 0.5, 0, 0, 0], - [0, 1, 1, 0, 0], - [0, 1, 1, 0, 0], - [0, 0, 0, 1, 1]], dtype=float) - - # Normalisation to [0, 1] - fet_original = fet.copy() - fet = (fet - np.amin(fet, axis=0)) / \ - (np.amax(fet, axis=0) - np.amin(fet, axis=0)) - - nanmasked_fet = fet.copy() - nanmasked_fet[fmask > 0] = np.nan - - # Correct computation of the corrected data and correction term - x = fet - w = fmask - nu = np.nanmean(nanmasked_fet, axis=0)[np.newaxis, :] - sigma2 = np.nanvar(nanmasked_fet, axis=0)[np.newaxis, :] - y = w * x + (1 - w) * nu - z = w * x * x + (1 - w) * (nu * nu + sigma2) - correction_terms = z - y * y - features = y - - data = sparsify_features_masks(fet_original, fmask) - - aee(data.noise_mean, np.nanmean(nanmasked_fet, axis=0)) - aee(data.noise_variance, np.nanvar(nanmasked_fet, axis=0)) - assert np.amin(data.features) == 0 - assert np.amax(data.features) == 1 - assert len(data.offsets) == 5 - - for i in range(4): - data_u = data.unmasked[data.offsets[i]:data.offsets[i + 1]] - true_u, = fmask[i, :].nonzero() - ae(data_u, true_u) - data_f = data.features[data.offsets[i]:data.offsets[i + 1]] - true_f = fet[i, data_u] - ae(data_f, true_f) - data_m = data.masks[data.offsets[i]:data.offsets[i + 1]] - true_m = fmask[i, data_u] - ae(data_m, true_m) - - # PART 2: Check that converting to SparseData is correct - # compute unique masks and apply correction terms to data - data = data.to_sparse_data() - - assert data.num_spikes == 4 - assert data.num_features == 5 - assert data.num_masks == 3 - - for i in range(4): - data_u = data.unmasked[data.unmasked_start[i]:data.unmasked_end[i]] - true_u, = fmask[i, :].nonzero() - ae(data_u, true_u) - data_f = data.features[data.values_start[i]:data.values_end[i]] - true_f = features[i, data_u] - aee(data_f, true_f) - data_c = data.correction_terms[data.values_start[i]:data.values_end[i]] - true_c = correction_terms[i, data_u] - aee(data_c, true_c) - data_m = data.masks[data.values_start[i]:data.values_end[i]] - true_m = fmask[i, data_u] - aee(data_m, true_m) diff --git a/phy/io/kwik/tests/test_store_items.py b/phy/io/kwik/tests/test_store_items.py deleted file mode 100644 index 026afa293..000000000 --- a/phy/io/kwik/tests/test_store_items.py +++ /dev/null @@ -1,167 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Tests of Kwik store items.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import numpy as np -from numpy.testing import assert_array_equal as ae - -from ....utils.array import _spikes_per_cluster, _spikes_in_clusters -from ..model import (KwikModel, - ) -from ..mock import create_mock_kwik -from ..store_items import create_store - - -#------------------------------------------------------------------------------ -# Tests -#------------------------------------------------------------------------------ - -_N_CLUSTERS = 5 -_N_SPIKES = 100 -_N_CHANNELS = 28 -_N_FETS = 2 -_N_SAMPLES_TRACES = 10000 - - -def test_kwik_store(tempdir): - - # Create the test HDF5 file in the temporary directory. - filename = create_mock_kwik(tempdir, - n_clusters=_N_CLUSTERS, - n_spikes=_N_SPIKES, - n_channels=_N_CHANNELS, - n_features_per_channel=_N_FETS, - n_samples_traces=_N_SAMPLES_TRACES) - - nc = _N_CHANNELS - 2 - nf = _N_FETS - - model = KwikModel(filename) - ns = model.n_samples_waveforms - spc = _spikes_per_cluster(np.arange(_N_SPIKES), model.spike_clusters) - clusters = sorted(spc.keys()) - - # We initialize the ClusterStore. - cs = create_store(model, - path=tempdir, - spikes_per_cluster=spc, - features_masks_chunk_size=15, - waveforms_chunk_size=15, - waveforms_n_spikes_max=5, - waveforms_excerpt_size=2, - ) - - # We add a custom statistic function. - def mean_features_bis(cluster): - fet = cs.features(cluster) - cs.memory_store.store(cluster, mean_features_bis=fet.mean(axis=0)) - - cs.items['statistics'].add('mean_features_bis', - mean_features_bis, - (-1, nc), - ) - cs.register_field('mean_features_bis', 'statistics') - - waveforms_item = cs.items['waveforms'] - - # Now we generate the store. - cs.generate() - - # One cluster at a time. - for cluster in clusters: - # Check features. - fet_store = cs.features(cluster) - fet_expected = model.features[spc[cluster]].reshape((-1, nc, nf)) - ae(fet_store, fet_expected) - - # Check masks. - masks_store = cs.masks(cluster) - masks_expected = model.masks[spc[cluster]] - ae(masks_store, masks_expected) - - # Check waveforms. - waveforms_store = cs.waveforms(cluster) - # Find the spikes. - spikes = waveforms_item.spikes_per_cluster[cluster] - ae(waveforms_item.spikes_per_cluster[cluster], spikes) - waveforms_expected = model.waveforms[spikes] - ae(waveforms_store, waveforms_expected) - - # Check some statistics. - ae(cs.mean_features(cluster), fet_expected.mean(axis=0)) - ae(cs.mean_features_bis(cluster), fet_expected.mean(axis=0)) - ae(cs.mean_masks(cluster), masks_expected.mean(axis=0)) - ae(cs.mean_waveforms(cluster), waveforms_expected.mean(axis=0)) - - assert cs.n_unmasked_channels(cluster) >= 0 - assert cs.main_channels(cluster).shape == (nc,) - assert cs.mean_probe_position(cluster).shape == (2,) - - # Multiple clusters. - for clusters in (clusters[::2], [clusters[0]], []): - n_clusters = len(clusters) - spikes = _spikes_in_clusters(model.spike_clusters, clusters) - n_spikes = len(spikes) - - # Features. - if n_clusters: - fet_expected = model.features[spikes].reshape((n_spikes, - nc, nf)) - else: - fet_expected = np.zeros((0, nc, nf), dtype=np.float32) - ae(cs.load('features', clusters=clusters), fet_expected) - ae(cs.load('features', spikes=spikes), fet_expected) - - # Masks. - if n_clusters: - masks_expected = model.masks[spikes] - else: - masks_expected = np.ones((0, nc), dtype=np.float32) - ae(cs.load('masks', clusters=clusters), masks_expected) - ae(cs.load('masks', spikes=spikes), masks_expected) - - # Waveforms. - spc = waveforms_item.spikes_per_cluster - if n_clusters: - spikes = _spikes_in_clusters(spc, clusters) - waveforms_expected = model.waveforms[spikes] - else: - spikes = np.array([], dtype=np.int64) - waveforms_expected = np.zeros((0, ns, nc), dtype=np.float32) - ae(cs.load('waveforms', clusters=clusters), waveforms_expected) - ae(cs.load('waveforms', spikes=spikes), waveforms_expected) - - assert (cs.load('mean_features', clusters=clusters).shape == - (n_clusters, nc, nf)) - assert (cs.load('mean_features_bis', clusters=clusters).shape == - (n_clusters, nc, nf) if n_clusters else (0,)) - assert (cs.load('mean_masks', clusters=clusters).shape == - (n_clusters, nc)) - assert (cs.load('mean_waveforms', clusters=clusters).shape == - (n_clusters, ns, nc)) - - assert (cs.load('n_unmasked_channels', clusters=clusters).shape == - (n_clusters,)) - assert (cs.load('main_channels', clusters=clusters).shape == - (n_clusters, nc)) - assert (cs.load('mean_probe_position', clusters=clusters).shape == - (n_clusters, 2)) - - # Slice spikes. - spikes = slice(None, None, 3) - - # Features. - fet_expected = model.features[spikes].reshape((-1, nc, nf)) - ae(cs.load('features', spikes=spikes), fet_expected) - - # Masks. - masks_expected = model.masks[spikes] - ae(cs.load('masks', spikes=spikes), masks_expected) - - # Waveforms. - waveforms_expected = model.waveforms[spikes] - ae(cs.load('waveforms', spikes=spikes), waveforms_expected) diff --git a/phy/io/sparse.py b/phy/io/sparse.py deleted file mode 100644 index 01fe470f7..000000000 --- a/phy/io/sparse.py +++ /dev/null @@ -1,155 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Sparse matrix structures.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import numpy as np - -from ..utils._types import _as_array - - -#------------------------------------------------------------------------------ -# Sparse CSR -#------------------------------------------------------------------------------ - -def _csr_from_dense(dense): - """Create a CSR structure from a dense NumPy array.""" - raise NotImplementedError(("Creating CSR from dense matrix is not " - "implemented yet.")) - - -def _check_sparse_components(shape=None, data=None, - channels=None, spikes_ptr=None): - """Ensure the components of a sparse matrix are consistent.""" - if not isinstance(shape, (list, tuple, np.ndarray)): - raise ValueError("The shape is required and should be an " - "array-like list ({0} was given).".format(shape)) - if len(shape) != data.ndim + 1: - raise ValueError("'shape' {shape} and {ndim}D-array 'data' are " - "not consistent.".format(shape=shape, - ndim=data.ndim)) - if channels.ndim != 1: - raise ValueError("'channels' should be a 1D array.") - if spikes_ptr.ndim != 1: - raise ValueError("'spikes_ptr' should be a 1D array.") - nitems = data.shape[-1] - if nitems > np.prod(shape): - raise ValueError("'data' is too large (n={0:d}) ".format(nitems) + - " for the specified shape " - "{shape}.".format(shape=shape)) - if len(channels) != (shape[1]): - raise ValueError(("'channels' should have " - "{nexp} elements, " - "not {nact}.").format(nexp=(shape[1]), - nact=len(channels))) - if len(spikes_ptr) != (shape[0] + 1): - raise ValueError(("'spikes_ptr' should have " - "{nexp} elements, " - "not {nact}.").format(nexp=(shape[0] + 1), - nact=len(spikes_ptr))) - if len(data) != len(channels): - raise ValueError("'data' (n={0:d}) and ".format(len(data)) + - "'channels' (n={0:d}) ".format(len(channels)) + - "should have the same length") - return True - - -class SparseCSR(object): - """Sparse CSR matrix data structure.""" - def __init__(self, shape=None, data=None, channels=None, spikes_ptr=None): - # Ensure the arguments are all arrays. - data = _as_array(data) - channels = _as_array(channels) - spikes_ptr = _as_array(spikes_ptr) - # Ensure the arguments are consistent. - assert _check_sparse_components(shape=shape, - data=data, - channels=channels, - spikes_ptr=spikes_ptr) - nitems = data.shape[-1] - # Structure info. - self._nitems = nitems - # Create the structure. - self._shape = shape - self._data = data - self._channels = channels - self._spikes_ptr = spikes_ptr - - @property - def shape(self): - """Shape of the array.""" - return self._shape - - def __eq__(self, other): - return np.all((self._data == other._data) & - (self._channels == other._channels) & - (self._spikes_ptr == other._spikes_ptr)) - - # I/O methods - # ------------------------------------------------------------------------- - - def save_h5(self, f, path): - """Save the array in an HDF5 file.""" - f.write_attr(path, 'sparse_type', 'csr') - f.write_attr(path, 'shape', self._shape) - f.write(path + '/data', self._data) - f.write(path + '/channels', self._channels) - f.write(path + '/spikes_ptr', self._spikes_ptr) - - @staticmethod - def load_h5(f, path): - """Load a SparseCSR array from an HDF5 file.""" - f.read_attr(path, 'sparse_type') == 'csr' - shape = f.read_attr(path, 'shape') - data = f.read(path + '/data')[...] - channels = f.read(path + '/channels')[...] - spikes_ptr = f.read(path + '/spikes_ptr')[...] - return SparseCSR(shape=shape, - data=data, - channels=channels, - spikes_ptr=spikes_ptr) - - -def csr_matrix(dense=None, shape=None, - data=None, channels=None, spikes_ptr=None): - """Create a CSR matrix from a dense matrix, or from sparse data.""" - if dense is not None: - # Ensure 'dense' is a ndarray. - dense = _as_array(dense) - return _csr_from_dense(dense) - if data is None or channels is None or spikes_ptr is None: - raise ValueError("data, channels, and spikes_ptr must be specified.") - return SparseCSR(shape=shape, - data=data, channels=channels, spikes_ptr=spikes_ptr) - - -#------------------------------------------------------------------------------ -# HDF5 functions -#------------------------------------------------------------------------------ - -def save_h5(f, path, arr, overwrite=False): - """Save a sparse array into an HDF5 file.""" - if isinstance(arr, SparseCSR): - arr.save_h5(f, path) - elif isinstance(arr, np.ndarray): - f.write(path, arr, overwrite=overwrite) - else: - raise ValueError("The array should be a SparseCSR or " - "dense NumPy array.") - - -def load_h5(f, path): - """Load a sparse array from an HDF5 file.""" - # Sparse array. - if f.has_attr(path, 'sparse_type'): - if f.read_attr(path, 'sparse_type') == 'csr': - return SparseCSR.load_h5(f, path) - else: - raise NotImplementedError("Only SparseCSR arrays are implemented " - "currently.") - # Regular dense dataset. - else: - return f.read(path)[...] diff --git a/phy/io/store.py b/phy/io/store.py deleted file mode 100644 index 1dd8ab799..000000000 --- a/phy/io/store.py +++ /dev/null @@ -1,752 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Cluster store.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -from collections import OrderedDict -import os -import os.path as op -import re - -import numpy as np -from six import string_types - -from ..utils._types import _as_int, _is_integer, _is_array_like -from ..utils._misc import _load_json, _save_json -from ..utils.array import (PerClusterData, _spikes_in_clusters, - _subset_spc, _load_ndarray, - _save_arrays, _load_arrays, - ) -from ..utils.event import ProgressReporter -from ..utils.logging import debug, info, warn -from ..utils.settings import _ensure_dir_exists - - -#------------------------------------------------------------------------------ -# Utility functions -#------------------------------------------------------------------------------ - -def _directory_size(path): - """Return the total size in bytes of a directory.""" - total_size = 0 - for dirpath, dirnames, filenames in os.walk(path): - for f in filenames: - fp = os.path.join(dirpath, f) - total_size += os.path.getsize(fp) - return total_size - - -def _file_cluster_id(path): - return int(op.splitext(op.basename(path))[0]) - - -def _default_array(shape, value=0, n_spikes=0, dtype=np.float32): - shape = (n_spikes,) + shape[1:] - out = np.empty(shape, dtype=dtype) - out.fill(value) - return out - - -def _assert_per_cluster_data_compatible(d_0, d_1): - n_0 = {k: len(v) for (k, v) in d_0.items()} - n_1 = {k: len(v) for (k, v) in d_1.items()} - if n_0 != n_1: - raise IOError("Inconsistency in the cluster store: please remove " - "`./.phy/cluster_store/`.") - - -#------------------------------------------------------------------------------ -# Base store -#------------------------------------------------------------------------------ - -class BaseStore(object): - def __init__(self, root_dir): - self._root_dir = op.realpath(root_dir) - _ensure_dir_exists(self._root_dir) - - def _location(self, filename): - return op.join(self._root_dir, filename) - - def _offsets_path(self, path): - assert path.endswith('.npy') - return op.splitext(path)[0] + '.offsets.npy' - - def _contains_multiple_arrays(self, path): - return op.exists(path) and op.exists(self._offsets_path(path)) - - def _save(self, filename, data): - """Save an array or list of arrays.""" - path = self._location(filename) - dir_path = op.dirname(path) - if not op.exists(dir_path): - os.makedirs(dir_path) - if isinstance(data, list): - if not data: - return - _save_arrays(path, data) - elif isinstance(data, np.ndarray): - dtype = data.dtype - if not data.size: - return - assert dtype != np.object - np.save(path, data) - - def _open(self, filename): - path = self._location(filename) - if not op.exists(path): - debug("File `{}` doesn't exist.".format(path)) - return - # Multiple arrays: - if self._contains_multiple_arrays(path): - return _load_arrays(path) - else: - return np.load(path) - - def _delete_file(self, filename): - path = self._location(filename) - if op.exists(path): - os.remove(path) - offsets_path = self._offsets_path(path) - if op.exists(offsets_path): - os.remove(offsets_path) - - def describe(self): - raise NotImplementedError() - - -#------------------------------------------------------------------------------ -# Memory store -#------------------------------------------------------------------------------ - -class MemoryStore(object): - """Store cluster-related data in memory.""" - def __init__(self): - self._ds = {} - - def store(self, cluster, **data): - """Store cluster-related data.""" - if cluster not in self._ds: - self._ds[cluster] = {} - self._ds[cluster].update(data) - - def load(self, cluster, keys=None): - """Load cluster-related data.""" - if keys is None: - return self._ds.get(cluster, {}) - else: - if isinstance(keys, string_types): - return self._ds.get(cluster, {}).get(keys, None) - assert isinstance(keys, (list, tuple)) - return {key: self._ds.get(cluster, {}).get(key, None) - for key in keys} - - @property - def cluster_ids(self): - """List of cluster ids in the store.""" - return sorted(self._ds.keys()) - - def erase(self, clusters): - """Delete some clusters from the store.""" - assert isinstance(clusters, list) - for cluster in clusters: - if cluster in self._ds: - del self._ds[cluster] - - def clear(self): - """Clear the store completely by deleting all clusters.""" - self.erase(self.cluster_ids) - - def __contains__(self, item): - return item in self._ds - - -#------------------------------------------------------------------------------ -# Disk store -#------------------------------------------------------------------------------ - -class DiskStore(object): - """Store cluster-related data in HDF5 files.""" - def __init__(self, directory): - assert directory is not None - # White list of extensions, to be sure we don't erase - # the wrong files. - self._allowed_extensions = set() - self._directory = op.realpath(op.expanduser(directory)) - - @property - def path(self): - return self._directory - - # Internal methods - # ------------------------------------------------------------------------- - - def _check_extension(self, file): - """Check that a file extension belongs to the white list of - allowed extensions. This is for safety.""" - _, extension = op.splitext(file) - extension = extension[1:] - if extension not in self._allowed_extensions: - raise RuntimeError("The extension '{0}' ".format(extension) + - "hasn't been registered.") - - def _cluster_path(self, cluster, key): - """Return the absolute path of a cluster in the disk store.""" - # TODO: subfolders - # Example of filename: `123.mykey`. - cluster = _as_int(cluster) - filename = '{0:d}.{1:s}'.format(cluster, key) - return op.realpath(op.join(self._directory, filename)) - - def _cluster_file_exists(self, cluster, key): - """Return whether a cluster file exists.""" - cluster = _as_int(cluster) - return op.exists(self._cluster_path(cluster, key)) - - def _is_cluster_file(self, path): - """Return whether a filename is of the form `xxx.yyy` where xxx is a - numbe and yyy belongs to the set of allowed extensions.""" - filename = op.basename(path) - extensions = '({0})'.format('|'.join(sorted(self._allowed_extensions))) - regex = r'^[0-9]+\.' + extensions + '$' - return re.match(regex, filename) is not None - - # Public methods - # ------------------------------------------------------------------------- - - def register_file_extensions(self, extensions): - """Register file extensions explicitely. This is a security - to make sure that we don't accidentally delete the wrong files.""" - if isinstance(extensions, string_types): - extensions = [extensions] - assert isinstance(extensions, list) - for extension in extensions: - self._allowed_extensions.add(extension) - - def store(self, cluster, append=False, **data): - """Store a NumPy array to disk.""" - # Do not create the file if there's nothing to write. - if not data: - return - mode = 'wb' if not append else 'ab' - for key, value in data.items(): - assert isinstance(value, np.ndarray) - path = self._cluster_path(cluster, key) - self._check_extension(path) - assert self._is_cluster_file(path) - with open(path, mode) as f: - value.tofile(f) - - def _get(self, cluster, key, dtype=None, shape=None): - # The cluster doesn't exist: return None for all keys. - if not self._cluster_file_exists(cluster, key): - return None - else: - return _load_ndarray(self._cluster_path(cluster, key), - dtype=dtype, shape=shape, lazy=False) - - def load(self, cluster, keys, dtype=None, shape=None): - """Load cluster-related data. Return a file handle, to be used - with np.fromfile() once the dtype and shape are known.""" - assert keys is not None - if isinstance(keys, string_types): - return self._get(cluster, keys, dtype=dtype, shape=shape) - assert isinstance(keys, list) - out = {} - for key in keys: - out[key] = self._get(cluster, key, dtype=dtype, shape=shape) - return out - - def save_file(self, filename, data): - path = op.realpath(op.join(self._directory, filename)) - _save_json(path, data) - - def load_file(self, filename): - path = op.realpath(op.join(self._directory, filename)) - if not op.exists(path): - return None - try: - return _load_json(path) - except ValueError as e: - warn("Error when loading `{}`: {}.".format(path, e)) - return None - - @property - def files(self): - """List of files present in the directory.""" - if not op.exists(self._directory): - return [] - return sorted(filter(self._is_cluster_file, - os.listdir(self._directory))) - - @property - def cluster_ids(self): - """List of cluster ids in the store.""" - clusters = set([_file_cluster_id(file) for file in self.files]) - return sorted(clusters) - - def erase(self, clusters): - """Delete some clusters from the store.""" - for cluster in clusters: - for key in self._allowed_extensions: - path = self._cluster_path(cluster, key) - if not op.exists(path): - continue - # Safety first: http://bit.ly/1ITJyF6 - self._check_extension(path) - if self._is_cluster_file(path): - os.remove(path) - else: - raise RuntimeError("The file {0} was about ".format(path) + - "to be removed, but it doesn't appear " - "to be a valid cluster file.") - - def clear(self): - """Clear the store completely by deleting all clusters.""" - self.erase(self.cluster_ids) - - -#------------------------------------------------------------------------------ -# Store item -#------------------------------------------------------------------------------ - -class StoreItem(object): - """A class describing information stored in the cluster store. - - Parameters - ---------- - - name : str - Name of the item. - fields : list - A list of field names. - model : Model - A `Model` instance for the current dataset. - memory_store : MemoryStore - The `MemoryStore` instance for the current dataset. - disk_store : DiskStore - The DiskStore instance for the current dataset. - - """ - name = 'item' - fields = None # list of names - - def __init__(self, cluster_store=None): - self.cluster_store = cluster_store - self.model = cluster_store.model - self.memory_store = cluster_store.memory_store - self.disk_store = cluster_store.disk_store - self._spikes_per_cluster = cluster_store.spikes_per_cluster - self._pr = ProgressReporter() - self._pr.set_progress_message('Initializing ' + self.name + - ': {progress:.1f}%.') - self._pr.set_complete_message(self.name.capitalize() + ' initialized.') - self._shapes = {} - - def empty_values(self, name): - """Return a null array of the right shape for a given field.""" - return _default_array(self._shapes.get(name, (-1,)), value=0.) - - @property - def progress_reporter(self): - """Progress reporter instance.""" - return self._pr - - @property - def spikes_per_cluster(self): - """Spikes per cluster.""" - return self._spikes_per_cluster - - @spikes_per_cluster.setter - def spikes_per_cluster(self, value): - self._spikes_per_cluster = value - - def spikes_in_clusters(self, clusters): - """Return the spikes belonging to clusters.""" - return _spikes_in_clusters(self._spikes_per_cluster, clusters) - - @property - def cluster_ids(self): - """Array of cluster ids.""" - return sorted(self._spikes_per_cluster) - - def is_consistent(self, cluster, spikes): - """Return whether the stored item is consistent. - - To be overriden.""" - return False - - def to_generate(self, mode=None): - """Return the list of clusters that need to be regenerated.""" - if mode in (None, 'default'): - return [cluster for cluster in self.cluster_ids - if not self.is_consistent(cluster, - self.spikes_per_cluster[cluster], - )] - elif mode == 'force': - return self.cluster_ids - elif mode == 'read-only': - return [] - else: - raise ValueError("`mode` should be None, `default`, `force`, " - "or `read-only`.") - - def store(self, cluster): - """Store data for a cluster from the model to the store. - - May be overridden. - - """ - pass - - def store_all(self, mode=None, **kwargs): - """Copy all data for that item from the model to the cluster store.""" - clusters = self.to_generate(mode) - self._pr.value_max = len(clusters) - for cluster in clusters: - self.store(cluster, **kwargs) - self._pr.value += 1 - self._pr.set_complete() - - def load(self, cluster, name): - """Load data for one cluster.""" - raise NotImplementedError() - - def load_multi(self, clusters, name): - """Load data for several clusters.""" - raise NotImplementedError() - - def load_spikes(self, spikes, name): - """Load data from an array of spikes.""" - raise NotImplementedError() - - def on_merge(self, up): - """Called when a new merge occurs. - - May be overriden if there's an efficient way to update the data - after a merge. - - """ - self.on_assign(up) - - def on_assign(self, up): - """Called when a new split occurs. - - May be overriden. - - """ - for cluster in up.added: - self.store(cluster) - - def on_cluster(self, up=None): - """Called when the clusters change. - - Old data is kept on disk and in memory, which is useful for - undo and redo. The `cluster_store.clean()` method can be called to - delete the old files. - - Nothing happens during undo and redo (the data is already there). - - """ - # No need to change anything in the store if this is an undo or - # a redo. - if up is None or up.history is not None: - return - if up.description == 'merge': - self.on_merge(up) - elif up.description == 'assign': - self.on_assign(up) - - -class FixedSizeItem(StoreItem): - """Store data which size doesn't depend on the cluster size.""" - def load_multi(self, clusters, name): - """Load data for several clusters.""" - if not len(clusters): - return self.empty_values(name) - return np.array([self.load(cluster, name) - for cluster in clusters]) - - -class VariableSizeItem(StoreItem): - """Store data which size does depend on the cluster size.""" - def load_multi(self, clusters, name, spikes=None): - """Load data for several clusters. - - A subset of spikes can also be specified. - - """ - if not len(clusters) or (spikes is not None and not len(spikes)): - return self.empty_values(name) - arrays = {cluster: self.load(cluster, name) - for cluster in clusters} - spc = _subset_spc(self._spikes_per_cluster, clusters) - _assert_per_cluster_data_compatible(spc, arrays) - pcd = PerClusterData(spc=spc, - arrays=arrays, - ) - if spikes is not None: - pcd = pcd.subset(spike_ids=spikes) - assert pcd.array.shape[0] == len(spikes) - return pcd.array - - -#------------------------------------------------------------------------------ -# Cluster store -#------------------------------------------------------------------------------ - -class ClusterStore(object): - """Hold per-cluster information on disk and in memory. - - Note - ---- - - Currently, this is used to accelerate access to per-cluster data - and statistics. All data is dynamically updated when clustering - changes occur. - - """ - def __init__(self, - model=None, - spikes_per_cluster=None, - path=None, - ): - self._model = model - self._spikes_per_cluster = spikes_per_cluster - self._memory = MemoryStore() - self._disk = DiskStore(path) if path is not None else None - self._items = OrderedDict() - self._item_per_field = {} - - # Core methods - #-------------------------------------------------------------------------- - - @property - def model(self): - """Model.""" - return self._model - - @property - def memory_store(self): - """Hold some cluster statistics.""" - return self._memory - - @property - def disk_store(self): - """Manage the cache of per-cluster voluminous data.""" - return self._disk - - @property - def spikes_per_cluster(self): - """Dictionary `{cluster_id: spike_ids}`.""" - return self._spikes_per_cluster - - def update_spikes_per_cluster(self, spikes_per_cluster): - self._spikes_per_cluster = spikes_per_cluster - for item in self._items.values(): - try: - item.spikes_per_cluster = spikes_per_cluster - except AttributeError: - debug("Skipping set spikes_per_cluster on " - "store item {}.".format(item.name)) - - @property - def cluster_ids(self): - """All cluster ids appearing in the `spikes_per_cluster` dictionary.""" - return sorted(self._spikes_per_cluster) - - # Store items - #-------------------------------------------------------------------------- - - @property - def items(self): - """Dictionary of registered store items.""" - return self._items - - def register_field(self, name, item_name=None): - """Register a new piece of data to store on memory or on disk. - - Parameters - ---------- - - name : str - The name of the field. - item_name : str - The name of the item. - - - """ - self._item_per_field[name] = self._items[item_name] - if self._disk: - self._disk.register_file_extensions(name) - - # Create the load function. - def _make_func(name): - def load(*args, **kwargs): - return self.load(name, *args, **kwargs) - return load - - load = _make_func(name) - - # We create the `self.()` method for loading. - # We need to ensure that the method name isn't already attributed. - assert not hasattr(self, name) - setattr(self, name, load) - - def register_item(self, item_cls, **kwargs): - """Register a `StoreItem` class in the store. - - A `StoreItem` class is responsible for storing some data to disk - and memory. It must register one or several pieces of data. - - """ - # Instantiate the item. - item = item_cls(cluster_store=self, **kwargs) - assert item.fields is not None - - # Register the StoreItem instance. - self._items[item.name] = item - - # Register all fields declared by the store item. - for field in item.fields: - self.register_field(field, item_name=item.name) - - return item - - # Files - #-------------------------------------------------------------------------- - - @property - def path(self): - """Path to the disk store cache.""" - return self.disk_store.path - - @property - def old_clusters(self): - """Clusters in the disk store that are no longer in the clustering.""" - return sorted(set(self.disk_store.cluster_ids) - - set(self.cluster_ids)) - - @property - def files(self): - """List of files present in the disk store.""" - return self.disk_store.files - - # Status - #-------------------------------------------------------------------------- - - @property - def total_size(self): - """Total size of the disk store.""" - return _directory_size(self.path) - - def is_consistent(self): - """Return whether the cluster store is probably consistent. - - Return true if all cluster stores files exist and have the expected - file size. - - """ - valid = set(self.cluster_ids) - # All store items should be consistent on all valid clusters. - consistent = all(all(item.is_consistent(clu, - self.spikes_per_cluster.get(clu, [])) - for clu in valid) - for item in self._items.values()) - return consistent - - @property - def status(self): - """Return the current status of the cluster store.""" - in_store = set(self.disk_store.cluster_ids) - valid = set(self.cluster_ids) - invalid = in_store - valid - - n_store = len(in_store) - n_old = len(invalid) - size = self.total_size / (1024. ** 2) - consistent = str(self.is_consistent()).rjust(5) - - status = '' - header = "Cluster store status ({0})".format(self.path) - status += header + '\n' - status += '-' * len(header) + '\n' - status += "Number of clusters in the store {0: 4d}\n".format(n_store) - status += "Number of old clusters {0: 4d}\n".format(n_old) - status += "Total size (MB) {0: 7.0f}\n".format(size) - status += "Consistent {0}\n".format(consistent) - return status - - def display_status(self): - """Display the current status of the cluster store.""" - print(self.status) - - # Store management - #-------------------------------------------------------------------------- - - def clear(self): - """Erase all files in the store.""" - self.memory_store.clear() - self.disk_store.clear() - info("Cluster store cleared.") - - def clean(self): - """Erase all old files in the store.""" - to_delete = self.old_clusters - self.memory_store.erase(to_delete) - self.disk_store.erase(to_delete) - n = len(to_delete) - info("{0} clusters deleted from the cluster store.".format(n)) - - def generate(self, mode=None): - """Generate the cluster store. - - Parameters - ---------- - - mode : str (default is None) - How the cluster store should be generated. Options are: - - * None or `default`: only regenerate the missing or inconsistent - clusters - * `force`: fully regenerate the cluster - * `read-only`: just load the existing files, do not write anything - - """ - assert isinstance(self._spikes_per_cluster, dict) - if hasattr(self._model, 'name'): - name = self._model.name - else: - name = 'the current model' - debug("Initializing the cluster store for {0:s}.".format(name)) - for item in self._items.values(): - item.store_all(mode) - - # Load - #-------------------------------------------------------------------------- - - def load(self, name, clusters=None, spikes=None): - """Load some data for a number of clusters and spikes.""" - item = self._item_per_field[name] - # Clusters requested. - if clusters is not None: - if _is_integer(clusters): - # Single cluster case. - return item.load(clusters, name) - clusters = np.unique(clusters) - if spikes is None: - return item.load_multi(clusters, name) - else: - return item.load_multi(clusters, name, spikes=spikes) - # Spikes requested. - elif spikes is not None: - assert clusters is None - if _is_array_like(spikes): - spikes = np.unique(spikes) - out = item.load_spikes(spikes, name) - assert isinstance(out, np.ndarray) - if _is_array_like(spikes): - assert out.shape[0] == len(spikes) - return out diff --git a/phy/io/tests/test_base.py b/phy/io/tests/test_base.py deleted file mode 100644 index c400c0433..000000000 --- a/phy/io/tests/test_base.py +++ /dev/null @@ -1,85 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Tests of the BaseModel class.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -from pytest import raises - -from ..base import BaseModel, ClusterMetadata - - -#------------------------------------------------------------------------------ -# Tests -#------------------------------------------------------------------------------ - -def test_base_cluster_metadata(): - meta = ClusterMetadata() - - @meta.default - def group(cluster): - return 3 - - @meta.default - def color(cluster): - return 0 - - assert meta.group(0) is not None - assert meta.group(2) == 3 - assert meta.group(10) == 3 - - meta.set_color(10, 5) - assert meta.color(10) == 5 - - # Alternative __setitem__ syntax. - meta.set_color([10, 11], 5) - assert meta.color(10) == 5 - assert meta.color(11) == 5 - - meta.set_color([10, 11], 6) - assert meta.color(10) == 6 - assert meta.color(11) == 6 - assert meta.color([10, 11]) == [6, 6] - - meta.set_color(10, 20) - assert meta.color(10) == 20 - - -def test_base(): - model = BaseModel() - - assert model.channel_group is None - - model.channel_group = 1 - assert model.channel_group == 1 - - assert model.channel_groups == [] - assert model.clusterings == [] - - model.clustering = 'original' - assert model.clustering == 'original' - - with raises(NotImplementedError): - model.metadata - with raises(NotImplementedError): - model.traces - with raises(NotImplementedError): - model.spike_samples - with raises(NotImplementedError): - model.spike_clusters - with raises(NotImplementedError): - model.cluster_metadata - with raises(NotImplementedError): - model.features - with raises(NotImplementedError): - model.masks - with raises(NotImplementedError): - model.waveforms - with raises(NotImplementedError): - model.probe - with raises(NotImplementedError): - model.save() - - model.close() diff --git a/phy/io/tests/test_sparse.py b/phy/io/tests/test_sparse.py deleted file mode 100644 index a92502692..000000000 --- a/phy/io/tests/test_sparse.py +++ /dev/null @@ -1,122 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Tests of sparse matrix structures.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import os.path as op - -import numpy as np -from numpy.testing import assert_array_equal as ae -from pytest import raises - -from ..sparse import csr_matrix, SparseCSR, load_h5, save_h5 -from ..h5 import open_h5 - - -#------------------------------------------------------------------------------ -# Tests -#------------------------------------------------------------------------------ - -def _dense_matrix_example(): - """Sparse matrix example: - - * 1 2 * * - 3 * * 4 * - * * * * * - * * 5 * * - - Return a dense array. - - """ - arr = np.zeros((4, 5)) - arr[0, 1] = 1 - arr[0, 2] = 2 - arr[1, 0] = 3 - arr[1, 3] = 4 - arr[3, 2] = 5 - return arr - - -def _sparse_matrix_example(): - """Return a sparse representation of the sparse matrix example.""" - shape = (4, 5) - data = np.array([1, 2, 3, 4, 5]) - channels = np.array([1, 2, 0, 3, 2]) - spikes_ptr = np.array([0, 2, 4, 4, 5]) - return shape, data, channels, spikes_ptr - - -def test_sparse_csr_check(): - """Test the checks performed when creating a sparse matrix.""" - dense = _dense_matrix_example() - shape, data, channels, spikes_ptr = _sparse_matrix_example() - - # Dense to sparse conversion not implemented yet. - with raises(NotImplementedError): - csr_matrix(dense) - - # Need the three sparse components and the shape. - with raises(ValueError): - csr_matrix(data=data, channels=channels) - with raises(ValueError): - csr_matrix(data=data, spikes_ptr=spikes_ptr) - with raises(ValueError): - csr_matrix(channels=channels, spikes_ptr=spikes_ptr) - with raises(ValueError): - csr_matrix(data=data, channels=channels, spikes_ptr=spikes_ptr) - with raises(ValueError): - csr_matrix(shape=shape, - data=data[:-1], channels=channels, spikes_ptr=spikes_ptr) - with raises(ValueError): - csr_matrix(shape=shape, channels=[[0]]) - with raises(ValueError): - csr_matrix(shape=shape, spikes_ptr=[0]) - with raises(ValueError): - csr_matrix(shape=(4, 5, 6), data=data, channels=np.zeros((2, 2)), - spikes_ptr=spikes_ptr) - with raises(ValueError): - csr_matrix(shape=shape, data=data, channels=np.zeros((2, 2)), - spikes_ptr=spikes_ptr) - with raises(ValueError): - csr_matrix(shape=shape, data=data, channels=channels, - spikes_ptr=np.zeros((2, 2))) - with raises(ValueError): - csr_matrix(shape=shape, data=np.zeros((100)), channels=channels, - spikes_ptr=spikes_ptr) - with raises(ValueError): - csr_matrix(shape=shape, data=data, channels=np.zeros(100), - spikes_ptr=spikes_ptr) - with raises(ValueError): - csr_matrix(shape=shape, data=data, channels=channels, - spikes_ptr=np.zeros(100)) - - # This one should pass. - sparse = csr_matrix(shape=shape, - data=data, channels=channels, spikes_ptr=spikes_ptr) - assert isinstance(sparse, SparseCSR) - ae(sparse.shape, shape) - ae(sparse._data, data) - ae(sparse._channels, channels) - ae(sparse._spikes_ptr, spikes_ptr) - - -def test_sparse_hdf5(tempdir): - """Test the checks performed when creating a sparse matrix.""" - shape, data, channels, spikes_ptr = _sparse_matrix_example() - sparse = csr_matrix(shape=shape, data=data, channels=channels, - spikes_ptr=spikes_ptr) - dense = _dense_matrix_example() - - with open_h5(op.join(tempdir, 'test.h5'), 'w') as f: - path_sparse = '/my_sparse_array' - save_h5(f, path_sparse, sparse) - sparse_bis = load_h5(f, path_sparse) - assert sparse == sparse_bis - - path_dense = '/my_dense_array' - save_h5(f, path_dense, dense) - dense_bis = load_h5(f, path_dense) - ae(dense, dense_bis) diff --git a/phy/io/tests/test_store.py b/phy/io/tests/test_store.py deleted file mode 100644 index 158d7b495..000000000 --- a/phy/io/tests/test_store.py +++ /dev/null @@ -1,423 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Test cluster store.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import os -import os.path as op - -import numpy as np -from numpy.testing import assert_array_equal as ae -from numpy.testing import assert_allclose as ac - -from ...utils._types import Bunch -from ...utils.array import _spikes_per_cluster -from ..store import (MemoryStore, - BaseStore, - DiskStore, - ClusterStore, - VariableSizeItem, - FixedSizeItem, - ) - - -#------------------------------------------------------------------------------ -# Test data stores -#------------------------------------------------------------------------------ - -def test_memory_store(): - ms = MemoryStore() - assert ms.load(2) == {} - - assert ms.load(3).get('key', None) is None - assert ms.load(3) == {} - assert ms.load(3, ['key']) == {'key': None} - assert ms.load(3) == {} - assert ms.cluster_ids == [] - - ms.store(3, key='a') - assert ms.load(3) == {'key': 'a'} - assert ms.load(3, ['key']) == {'key': 'a'} - assert ms.load(3, 'key') == 'a' - assert ms.cluster_ids == [3] - - ms.store(3, key_bis='b') - assert ms.load(3) == {'key': 'a', 'key_bis': 'b'} - assert ms.load(3, ['key']) == {'key': 'a'} - assert ms.load(3, ['key_bis']) == {'key_bis': 'b'} - assert ms.load(3, ['key', 'key_bis']) == {'key': 'a', 'key_bis': 'b'} - assert ms.load(3, 'key_bis') == 'b' - assert ms.cluster_ids == [3] - - ms.erase([2, 3]) - assert ms.load(3) == {} - assert ms.load(3, ['key']) == {'key': None} - assert ms.cluster_ids == [] - - -def test_base_store(tempdir): - store = BaseStore(tempdir) - - store._save('11.npy', np.arange(11)) - store._save('12.npy', np.arange(12)) - store._save('2.npy', np.arange(2)) - store._save('3.npy', None) - store._save('1.npy', [np.arange(3), np.arange(3, 8)]) - - ae(store._open('01.npy'), None) - ae(store._open('11.npy'), np.arange(11)) - ae(store._open('12.npy'), np.arange(12)) - ae(store._open('2.npy'), np.arange(2)) - ae(store._open('3.npy'), None) - ae(store._open('1.npy')[0], np.arange(3)) - ae(store._open('1.npy')[1], np.arange(3, 8)) - - store._delete_file('2.npy') - store._delete_file('12.npy') - ae(store._open('2.npy'), None) - - -def test_disk_store(tempdir): - - dtype = np.float32 - sha = (2, 4) - shb = (3, 5) - a = np.random.rand(*sha).astype(dtype) - b = np.random.rand(*shb).astype(dtype) - - def _assert_equal(d_0, d_1): - """Test the equality of two dictionaries containing NumPy arrays.""" - assert sorted(d_0.keys()) == sorted(d_1.keys()) - for key in d_0.keys(): - ac(d_0[key], d_1[key]) - - ds = DiskStore(tempdir) - - ds.register_file_extensions(['key', 'key_bis']) - assert ds.cluster_ids == [] - - ds.store(3, key=a) - _assert_equal(ds.load(3, - ['key'], - dtype=dtype, - shape=sha, - ), - {'key': a}) - loaded = ds.load(3, 'key', dtype=dtype, shape=sha) - ac(loaded, a) - - # Loading a non-existing key returns None. - assert ds.load(3, 'key_bis') is None - assert ds.cluster_ids == [3] - - ds.store(3, key_bis=b) - _assert_equal(ds.load(3, ['key'], dtype=dtype, shape=sha), {'key': a}) - _assert_equal(ds.load(3, ['key_bis'], - dtype=dtype, - shape=shb, - ), - {'key_bis': b}) - _assert_equal(ds.load(3, - ['key', 'key_bis'], - dtype=dtype, - ), - {'key': a.ravel(), 'key_bis': b.ravel()}) - ac(ds.load(3, 'key_bis', dtype=dtype, shape=shb), b) - assert ds.cluster_ids == [3] - - ds.erase([2, 3]) - assert ds.load(3, ['key']) == {'key': None} - assert ds.cluster_ids == [] - - # Test load/save file. - ds.save_file('test', {'a': a}) - ds = DiskStore(tempdir) - data = ds.load_file('test') - ae(data['a'], a) - assert ds.load_file('test2') is None - - -def test_cluster_store_1(tempdir): - - # We define some data and a model. - n_spikes = 100 - n_clusters = 10 - - spike_ids = np.arange(n_spikes) - spike_clusters = np.random.randint(size=n_spikes, - low=0, high=n_clusters) - spikes_per_cluster = _spikes_per_cluster(spike_ids, spike_clusters) - - model = {'spike_clusters': spike_clusters} - - # We initialize the ClusterStore. - cs = ClusterStore(model=model, - path=tempdir, - spikes_per_cluster=spikes_per_cluster, - ) - - # We create a n_spikes item to be stored in memory, - # and we define how to generate it for a given cluster. - class MyItem(FixedSizeItem): - name = 'my item' - fields = ['n_spikes'] - - def store(self, cluster): - spikes = self.spikes_per_cluster[cluster] - self.memory_store.store(cluster, n_spikes=len(spikes)) - - def load(self, cluster, name): - return self.memory_store.load(cluster, name) - - def on_cluster(self, up): - if up.description == 'merge': - n = sum(len(up.old_spikes_per_cluster[cl]) - for cl in up.deleted) - self.memory_store.store(up.added[0], n_spikes=n) - else: - super(MyItem, self).on_cluster(up) - - item = cs.register_item(MyItem) - item.progress_reporter.set_progress_message("Progress {progress}.\n") - item.progress_reporter.set_complete_message("Finished.\n") - - # Now we generate the store. - cs.generate() - - # We check that the n_spikes field has successfully been created. - for cluster in sorted(spikes_per_cluster): - assert cs.n_spikes(cluster) == len(spikes_per_cluster[cluster]) - - # Merge. - spc = spikes_per_cluster.copy() - spikes = np.sort(np.concatenate([spc[0], spc[1]])) - spc[20] = spikes - del spc[0] - del spc[1] - up = Bunch(description='merge', - added=[20], - deleted=[0, 1], - spike_ids=spikes, - new_spikes_per_cluster=spc, - old_spikes_per_cluster=spikes_per_cluster,) - - cs.items['my item'].on_cluster(up) - - # Check the list of clusters in the store. - ae(cs.memory_store.cluster_ids, list(range(0, n_clusters)) + [20]) - ae(cs.disk_store.cluster_ids, []) - assert cs.n_spikes(20) == len(spikes) - - # Recreate the cluster store. - cs = ClusterStore(model=model, - spikes_per_cluster=spikes_per_cluster, - path=tempdir, - ) - cs.register_item(MyItem) - cs.generate() - ae(cs.memory_store.cluster_ids, list(range(n_clusters))) - ae(cs.disk_store.cluster_ids, []) - - -def test_cluster_store_multi(): - """This tests the cluster store when a store item has several fields.""" - - cs = ClusterStore(spikes_per_cluster={0: [0, 2], 1: [1, 3, 4]}) - - class MyItem(FixedSizeItem): - name = 'my item' - fields = ['d', 'm'] - - def store(self, cluster): - spikes = self.spikes_per_cluster[cluster] - self.memory_store.store(cluster, d=len(spikes), m=len(spikes) ** 2) - - def load(self, cluster, name): - return self.memory_store.load(cluster, name) - - cs.register_item(MyItem) - - cs.generate() - - assert cs.memory_store.load(0, ['d', 'm']) == {'d': 2, 'm': 4} - assert cs.d(0) == 2 - assert cs.m(0) == 4 - - assert cs.memory_store.load(1, ['d', 'm']) == {'d': 3, 'm': 9} - assert cs.d(1) == 3 - assert cs.m(1) == 9 - - -def test_cluster_store_load(tempdir): - - # We define some data and a model. - n_spikes = 100 - n_clusters = 10 - - spike_ids = np.arange(n_spikes) - spike_clusters = np.random.randint(size=n_spikes, - low=0, high=n_clusters) - spikes_per_cluster = _spikes_per_cluster(spike_ids, spike_clusters) - model = {'spike_clusters': spike_clusters} - - # We initialize the ClusterStore. - cs = ClusterStore(model=model, - spikes_per_cluster=spikes_per_cluster, - path=tempdir, - ) - - # We create a n_spikes item to be stored in memory, - # and we define how to generate it for a given cluster. - class MyItem(VariableSizeItem): - name = 'my item' - fields = ['spikes_square'] - - def store(self, cluster): - spikes = spikes_per_cluster[cluster] - data = (spikes ** 2).astype(np.int32) - self.disk_store.store(cluster, spikes_square=data) - - def load(self, cluster, name): - return self.disk_store.load(cluster, name, np.int32) - - def load_spikes(self, spikes, name): - return (spikes ** 2).astype(np.int32) - - cs.register_item(MyItem) - cs.generate() - - # All spikes in cluster 1. - cluster = 1 - spikes = spikes_per_cluster[cluster] - ae(cs.load('spikes_square', clusters=[cluster]), spikes ** 2) - - # Some spikes in several clusters. - clusters = [2, 3, 5] - spikes = np.concatenate([spikes_per_cluster[cl][::3] - for cl in clusters]) - ae(cs.load('spikes_square', spikes=spikes), np.unique(spikes) ** 2) - - # Empty selection. - assert len(cs.load('spikes_square', clusters=[])) == 0 - assert len(cs.load('spikes_square', spikes=[])) == 0 - - -def test_cluster_store_management(tempdir): - - # We define some data and a model. - n_spikes = 100 - n_clusters = 10 - - spike_ids = np.arange(n_spikes) - spike_clusters = np.random.randint(size=n_spikes, - low=0, high=n_clusters) - spikes_per_cluster = _spikes_per_cluster(spike_ids, spike_clusters) - - model = Bunch({'spike_clusters': spike_clusters, - 'cluster_ids': np.arange(n_clusters), - }) - - # We initialize the ClusterStore. - cs = ClusterStore(model=model, - spikes_per_cluster=spikes_per_cluster, - path=tempdir, - ) - - # We create a n_spikes item to be stored in memory, - # and we define how to generate it for a given cluster. - class MyItem(VariableSizeItem): - name = 'my item' - fields = ['spikes_square'] - - def store(self, cluster): - spikes = self.spikes_per_cluster[cluster] - if not self.is_consistent(cluster, spikes): - data = (spikes ** 2).astype(np.int32) - self.disk_store.store(cluster, spikes_square=data) - - def is_consistent(self, cluster, spikes): - data = self.disk_store.load(cluster, - 'spikes_square', - dtype=np.int32, - ) - if data is None: - return False - if len(data) != len(spikes): - return False - expected = (spikes ** 2).astype(np.int32) - return np.all(data == expected) - - cs.register_item(MyItem) - cs.update_spikes_per_cluster(spikes_per_cluster) - - def _check_to_generate(cs, clusters): - item = cs.items['my item'] - ae(item.to_generate(), clusters) - ae(item.to_generate(None), clusters) - ae(item.to_generate('default'), clusters) - ae(item.to_generate('force'), np.arange(n_clusters)) - ae(item.to_generate('read-only'), []) - - # Check the list of clusters to generate. - _check_to_generate(cs, np.arange(n_clusters)) - - # Generate the store. - cs.generate() - - # Check the status. - assert 'True' in cs.status - - # We re-initialize the ClusterStore. - cs = ClusterStore(model=model, - spikes_per_cluster=spikes_per_cluster, - path=tempdir, - ) - cs.register_item(MyItem) - cs.update_spikes_per_cluster(spikes_per_cluster) - - # Check the list of clusters to generate. - _check_to_generate(cs, []) - cs.display_status() - - # We erase a file. - path = op.join(cs.path, '1.spikes_square') - os.remove(path) - - # Check the list of clusters to generate. - _check_to_generate(cs, [1]) - assert '9' in cs.status - assert 'False' in cs.status - - cs.generate() - - # Check the status. - assert 'True' in cs.status - - # Now, we make new assignements. - spike_clusters = np.random.randint(size=n_spikes, - low=n_clusters, high=n_clusters + 5) - spikes_per_cluster = _spikes_per_cluster(spike_ids, spike_clusters) - cs.update_spikes_per_cluster(spikes_per_cluster) - - # All files are now old and should be removed by clean(). - assert not cs.is_consistent() - item = cs.items['my item'] - ae(item.to_generate(), np.arange(n_clusters, n_clusters + 5)) - - ae(cs.cluster_ids, np.arange(n_clusters, n_clusters + 5)) - ae(cs.old_clusters, np.arange(n_clusters)) - cs.clean() - - ae(cs.cluster_ids, np.arange(n_clusters, n_clusters + 5)) - ae(cs.old_clusters, []) - ae(item.to_generate(), np.arange(n_clusters, n_clusters + 5)) - assert not cs.is_consistent() - cs.generate() - - assert cs.is_consistent() - ae(cs.cluster_ids, np.arange(n_clusters, n_clusters + 5)) - ae(cs.old_clusters, []) - ae(item.to_generate(), []) diff --git a/phy/scripts/__init__.py b/phy/scripts/__init__.py deleted file mode 100644 index 1e350cdd4..000000000 --- a/phy/scripts/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -# -*- coding: utf-8 -*- -# flake8: noqa - -"""Main CLI tool.""" - - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -from .phy_script import main diff --git a/phy/scripts/phy_script.py b/phy/scripts/phy_script.py deleted file mode 100644 index 2cc6d7c0c..000000000 --- a/phy/scripts/phy_script.py +++ /dev/null @@ -1,479 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import print_function - -"""phy main CLI tool. - -Usage: - - phy --help - -""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import sys -import os.path as op -import argparse -from textwrap import dedent - -import numpy as np -from six import exec_, string_types - - -#------------------------------------------------------------------------------ -# Parser utilities -#------------------------------------------------------------------------------ - -class CustomFormatter(argparse.ArgumentDefaultsHelpFormatter, - argparse.RawDescriptionHelpFormatter): - pass - - -class Parser(argparse.ArgumentParser): - def error(self, message): - sys.stderr.write(message + '\n\n') - self.print_help() - sys.exit(2) - - -_examples = dedent(""" - -examples: - phy -v display the version of phy - phy download hybrid_120sec.dat -o data/ - download a sample raw data file in `data/` - phy describe my_file.kwik - display information about a Kwik dataset - phy spikesort my_params.prm - run the whole suite (spike detection and clustering) - phy detect my_params.prm - run spike detection on a parameters file - phy cluster-auto my_file.kwik - run klustakwik on a dataset (after spike detection) - phy cluster-manual my_file.kwik - run the manual clustering GUI - -""") - - -#------------------------------------------------------------------------------ -# Parser creator -#------------------------------------------------------------------------------ - -class ParserCreator(object): - def __init__(self): - self.create_main() - self.create_download() - self.create_traces() - self.create_describe() - self.create_spikesort() - self.create_detect() - self.create_auto() - self.create_manual() - self.create_notebook() - - @property - def parser(self): - return self._parser - - def _add_sub_parser(self, name, desc): - p = self._subparsers.add_parser(name, help=desc, description=desc) - self._add_options(p) - return p - - def _add_options(self, parser): - parser.add_argument('--debug', '-d', - action='store_true', - help='activate debug logging mode') - - parser.add_argument('--hide-traceback', - action='store_true', - help='hide the traceback for cleaner error ' - 'messages') - - parser.add_argument('--profiler', '-p', - action='store_true', - help='activate the profiler') - - parser.add_argument('--line-profiler', '-lp', - dest='line_profiler', - action='store_true', - help='activate the line-profiler -- you ' - 'need to decorate the functions ' - 'to profile with `@profile` ' - 'in the code') - - parser.add_argument('--ipython', '-i', action='store_true', - help='launch the script in an interactive ' - 'IPython console') - - parser.add_argument('--pdb', action='store_true', - help='activate the Python debugger') - - def create_main(self): - import phy - - desc = sys.modules['phy'].__doc__ - self._parser = Parser(description=desc, - epilog=_examples, - formatter_class=CustomFormatter, - ) - self._parser.add_argument('--version', '-v', - action='version', - version=phy.__version_git__, - help='print the version of phy') - self._add_options(self._parser) - self._subparsers = self._parser.add_subparsers(dest='command', - title='subcommand', - ) - - def create_download(self): - desc = 'download a sample dataset' - p = self._add_sub_parser('download', desc) - p.add_argument('file', help='dataset filename') - p.add_argument('--output-dir', '-o', help='output directory') - p.add_argument('--base', - default='cortexlab', - choices=('cortexlab', 'github'), - help='data repository name: `cortexlab` or `github`', - ) - p.set_defaults(func=download) - - def create_describe(self): - desc = 'describe a `.kwik` file' - p = self._add_sub_parser('describe', desc) - p.add_argument('file', help='path to a `.kwik` file') - p.add_argument('--clustering', default='main', - help='name of the clustering to use') - p.set_defaults(func=describe) - - def create_traces(self): - desc = 'show the traces of a raw data file' - p = self._add_sub_parser('traces', desc) - p.add_argument('file', help='path to a `.kwd` or `.dat` file') - p.add_argument('--interval', - help='detection interval in seconds (e.g. `0,10`)') - p.add_argument('--n-channels', '-n', - help='number of channels in the recording ' - '(only required when using a flat binary file)') - p.add_argument('--dtype', - help='NumPy data type ' - '(only required when using a flat binary file)', - default='int16', - ) - p.add_argument('--sample-rate', '-s', - help='sample rate in Hz ' - '(only required when using a flat binary file)') - p.set_defaults(func=traces) - - def create_spikesort(self): - desc = 'launch the whole spike sorting pipeline on a `.prm` file' - p = self._add_sub_parser('spikesort', desc) - p.add_argument('file', help='path to a `.prm` file') - p.add_argument('--kwik-path', help='filename of the `.kwik` file ' - 'to create (by default, `"experiment_name".kwik`)') - p.add_argument('--overwrite', action='store_true', default=False, - help='overwrite the `.kwik` file ') - p.add_argument('--interval', - help='detection interval in seconds (e.g. `0,10`)') - p.set_defaults(func=spikesort) - - def create_detect(self): - desc = 'launch the spike detection algorithm on a `.prm` file' - p = self._add_sub_parser('detect', desc) - p.add_argument('file', help='path to a `.prm` file') - p.add_argument('--kwik-path', help='filename of the `.kwik` file ' - 'to create (by default, `"experiment_name".kwik`)') - p.add_argument('--overwrite', action='store_true', default=False, - help='overwrite the `.kwik` file ') - p.add_argument('--interval', - help='detection interval in seconds (e.g. `0,10`)') - p.set_defaults(func=detect) - - def create_auto(self): - desc = 'launch the automatic clustering algorithm on a `.kwik` file' - p = self._add_sub_parser('cluster-auto', desc) - p.add_argument('file', help='path to a `.kwik` file') - p.add_argument('--clustering', default='main', - help='name of the clustering to use') - p.add_argument('--channel-group', default=None, - help='channel group to cluster') - p.set_defaults(func=cluster_auto) - - def create_manual(self): - desc = 'launch the manual clustering GUI on a `.kwik` file' - p = self._add_sub_parser('cluster-manual', desc) - p.add_argument('file', help='path to a `.kwik` file') - p.add_argument('--clustering', default='main', - help='name of the clustering to use') - p.add_argument('--channel-group', default=None, - help='channel group to manually cluster') - p.add_argument('--cluster-ids', '-c', - help='list of clusters to select initially') - p.add_argument('--no-store', action='store_true', default=False, - help='do not create the store (faster loading time, ' - 'slower GUI)') - p.set_defaults(func=cluster_manual) - - def create_notebook(self): - # TODO - pass - - def parse(self, args): - try: - return self._parser.parse_args(args) - except SystemExit as e: - if e.code != 0: - raise e - - -#------------------------------------------------------------------------------ -# Subcommand functions -#------------------------------------------------------------------------------ - -def _get_kwik_path(args): - kwik_path = args.file - - if not op.exists(kwik_path): - raise IOError("The file `{}` doesn't exist.".format(kwik_path)) - - return kwik_path - - -def _create_session(args, **kwargs): - from phy.session import Session - kwik_path = _get_kwik_path(args) - session = Session(kwik_path, **kwargs) - return session - - -def describe(args): - from phy.io.kwik import KwikModel - path = _get_kwik_path(args) - model = KwikModel(path, clustering=args.clustering) - return 'model.describe()', dict(model=model) - - -def download(args): - from phy import download_sample_data - download_sample_data(args.file, - output_dir=args.output_dir, - base=args.base, - ) - - -def traces(args): - from vispy.app import run - from phy.plot.traces import TraceView - from phy.io.h5 import open_h5 - from phy.io.traces import read_kwd, read_dat - - path = args.file - if path.endswith('.kwd'): - f = open_h5(args.file) - traces = read_kwd(f) - elif path.endswith(('.dat', '.bin')): - if not args.n_channels: - raise ValueError("Please specify `--n-channels`.") - if not args.dtype: - raise ValueError("Please specify `--dtype`.") - if not args.sample_rate: - raise ValueError("Please specify `--sample-rate`.") - n_channels = int(args.n_channels) - dtype = np.dtype(args.dtype) - traces = read_dat(path, dtype=dtype, n_channels=n_channels) - - start, end = map(int, args.interval.split(',')) - sample_rate = float(args.sample_rate) - start = int(sample_rate * start) - end = int(sample_rate * end) - - c = TraceView(keys='interactive') - c.visual.traces = .01 * traces[start:end, ...] - c.show() - run() - - return None, None - - -def detect(args): - from phy.io import create_kwik - - assert args.file.endswith('.prm') - kwik_path = args.kwik_path - kwik_path = create_kwik(args.file, - overwrite=args.overwrite, - kwik_path=kwik_path) - - interval = args.interval - if interval is not None: - interval = list(map(float, interval.split(','))) - - # Create the session with the newly-created .kwik file. - args.file = kwik_path - session = _create_session(args, use_store=False) - return ('session.detect(interval=interval)', - dict(session=session, interval=interval)) - - -def cluster_auto(args): - from phy.utils._misc import _read_python - from phy.session import Session - - assert args.file.endswith('.prm') - - channel_group = (int(args.channel_group) - if args.channel_group is not None else None) - - params = _read_python(args.file) - kwik_path = params['experiment_name'] + '.kwik' - session = Session(kwik_path) - - ns = dict(session=session, - clustering=args.clustering, - channel_group=channel_group, - ) - cmd = ('session.cluster(' - 'clustering=clustering, ' - 'channel_group=channel_group)') - return (cmd, ns) - - -def spikesort(args): - from phy.io import create_kwik - - assert args.file.endswith('.prm') - kwik_path = args.kwik_path - kwik_path = create_kwik(args.file, - overwrite=args.overwrite, - kwik_path=kwik_path, - ) - # Create the session with the newly-created .kwik file. - args.file = kwik_path - session = _create_session(args, use_store=False) - - interval = args.interval - if interval is not None: - interval = list(map(float, interval.split(','))) - - ns = dict(session=session, - interval=interval, - n_s_clusters=100, # TODO: better handling of KK parameters - ) - cmd = ('session.detect(interval=interval); session.cluster();') - return (cmd, ns) - - -def cluster_manual(args): - channel_group = (int(args.channel_group) - if args.channel_group is not None else None) - session = _create_session(args, - clustering=args.clustering, - channel_group=channel_group, - use_store=not(args.no_store), - ) - cluster_ids = (list(map(int, args.cluster_ids.split(','))) - if args.cluster_ids else None) - - session.model.describe() - - from phy.gui import start_qt_app - start_qt_app() - - gui = session.show_gui(cluster_ids=cluster_ids, show=False) - print("\nPress `ctrl+h` to see the list of keyboard shortcuts.\n") - return 'gui.show()', dict(session=session, gui=gui, requires_qt=True) - - -#------------------------------------------------------------------------------ -# Main functions -#------------------------------------------------------------------------------ - -def main(args=None): - p = ParserCreator() - if args is None: - args = sys.argv[1:] - elif isinstance(args, string_types): - args = args.split(' ') - args = p.parse(args) - if args is None: - return - - if args.profiler or args.line_profiler: - from phy.utils.testing import _enable_profiler, _profile - prof = _enable_profiler(args.line_profiler) - else: - prof = None - - import phy - if args.debug: - phy.debug() - - # Hide the traceback. - if args.hide_traceback: - def exception_handler(exception_type, exception, traceback): - print("{}: {}".format(exception_type.__name__, exception)) - - sys.excepthook = exception_handler - - # Activate IPython debugger. - if args.pdb: - from IPython.core import ultratb - sys.excepthook = ultratb.FormattedTB(mode='Verbose', - color_scheme='Linux', - call_pdb=1, - ) - - func = getattr(args, 'func', None) - if func is None: - p.parser.print_help() - return - - out = func(args) - if not out: - return - cmd, ns = out - if not cmd: - return - requires_qt = ns.pop('requires_qt', False) - requires_vispy = ns.pop('requires_vispy', False) - - # Default variables in namespace. - ns.update(phy=phy, path=args.file) - if 'session' in ns: - ns['model'] = ns['session'].model - - # Interactive mode with IPython. - if args.ipython: - print("\nStarting IPython...") - from IPython import start_ipython - args_ipy = ["-i", "-c='{}'".format(cmd)] - if requires_qt or requires_vispy: - # Activate Qt event loop integration with Qt. - args_ipy += ["--gui=qt"] - start_ipython(args_ipy, user_ns=ns) - else: - if not prof: - exec_(cmd, {}, ns) - else: - _profile(prof, cmd, {}, ns) - - if requires_qt: - # Launch the Qt app. - from phy.gui import run_qt_app - run_qt_app() - elif requires_vispy: - # Launch the VisPy Qt app. - from vispy.app import use_app, run - use_app('pyqt4') - run() - - -#------------------------------------------------------------------------------ -# Entry point -#------------------------------------------------------------------------------ - -if __name__ == '__main__': - main() diff --git a/phy/scripts/tests/__init__.py b/phy/scripts/tests/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/phy/scripts/tests/test_phy_script.py b/phy/scripts/tests/test_phy_script.py deleted file mode 100644 index 763e6ec05..000000000 --- a/phy/scripts/tests/test_phy_script.py +++ /dev/null @@ -1,42 +0,0 @@ -# -*- coding: utf-8 -*-1 - -"""Tests of the script.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -from ..phy_script import ParserCreator, main - - -#------------------------------------------------------------------------------ -# Script tests -#------------------------------------------------------------------------------ - -def test_parse_version(): - p = ParserCreator() - p.parse(['--version']) - - -def test_parse_cluster_manual(): - p = ParserCreator() - args = p.parse(['cluster-manual', 'test', '-i', '--debug']) - assert args.command == 'cluster-manual' - assert args.ipython - assert args.debug - assert not args.profiler - assert not args.line_profiler - - -def test_parse_cluster_auto(): - p = ParserCreator() - args = p.parse(['cluster-auto', 'test', '-lp']) - assert args.command == 'cluster-auto' - assert not args.ipython - assert not args.debug - assert not args.profiler - assert args.line_profiler - - -def test_download(chdir_tempdir): - main('download hybrid_10sec.prm') diff --git a/phy/session/__init__.py b/phy/session/__init__.py deleted file mode 100644 index 741969821..000000000 --- a/phy/session/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -# -*- coding: utf-8 -*- -# flake8: noqa - -"""Interactive session around a model.""" - - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -from .session import Session diff --git a/phy/session/default_settings.py b/phy/session/default_settings.py deleted file mode 100644 index 7e6ddc6cb..000000000 --- a/phy/session/default_settings.py +++ /dev/null @@ -1,35 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Default settings for the session.""" - - -# ----------------------------------------------------------------------------- -# Session settings -# ----------------------------------------------------------------------------- - -def on_open(session): - """You can update the session when a model is opened. - - For example, you can register custom statistics with - `session.register_statistic`. - - """ - pass - - -def on_gui_open(session, gui): - """You can customize a GUI when it is open.""" - pass - - -def on_view_open(gui, view): - """You can customize a view when it is open.""" - pass - - -# ----------------------------------------------------------------------------- -# Misc settings -# ----------------------------------------------------------------------------- - -# Logging level in the log file. -log_file_level = 'debug' diff --git a/phy/session/session.py b/phy/session/session.py deleted file mode 100644 index f3131e474..000000000 --- a/phy/session/session.py +++ /dev/null @@ -1,468 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import print_function - -"""Session structure.""" - - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import os.path as op -import shutil - -import numpy as np - -from ..utils.array import _concatenate -from ..utils.logging import debug, info, FileLogger, unregister, register -from ..utils.settings import _ensure_dir_exists -from ..io.base import BaseSession -from ..io.kwik.model import KwikModel -from ..io.kwik.store_items import create_store -# HACK: avoid Qt import -try: - from ..cluster.manual.gui import ClusterManualGUI -except ImportError: - class ClusterManualGUI(object): - _vm_classes = {} -from ..cluster.algorithms.klustakwik import KlustaKwik -from ..detect.spikedetekt import SpikeDetekt - - -#------------------------------------------------------------------------------ -# Session class -#------------------------------------------------------------------------------ - -def _process_ups(ups): - """This function processes the UpdateInfo instances of the two - undo stacks (clustering and cluster metadata) and concatenates them - into a single UpdateInfo instance.""" - if len(ups) == 0: - return - elif len(ups) == 1: - return ups[0] - elif len(ups) == 2: - up = ups[0] - up.update(ups[1]) - return up - else: - raise NotImplementedError() - - -class Session(BaseSession): - """A manual clustering session. - - This is the main object used for manual clustering. It implements - all common actions: - - * Loading a dataset (`.kwik` file) - * Listing the clusters - * Changing the current channel group or current clustering - * Showing views (waveforms, features, correlograms, etc.) - * Clustering actions: merge, split, undo, redo - * Wizard: cluster quality, best clusters, most similar clusters - * Save back to .kwik - - """ - - _vm_classes = ClusterManualGUI._vm_classes - _gui_classes = {'cluster_manual': ClusterManualGUI} - - def __init__(self, - kwik_path=None, - clustering=None, - channel_group=None, - model=None, - use_store=True, - phy_user_dir=None, - waveform_filter=True, - ): - self._clustering = clustering - self._channel_group = channel_group - self._use_store = use_store - self._file_logger = None - self._waveform_filter = waveform_filter - if kwik_path: - kwik_path = op.realpath(kwik_path) - super(Session, self).__init__(model=model, - path=kwik_path, - phy_user_dir=phy_user_dir, - vm_classes=self._vm_classes, - gui_classes=self._gui_classes, - ) - - def _backup_kwik(self, kwik_path): - """Save a copy of the Kwik file before opening it.""" - if kwik_path is None: - return - backup_kwik_path = kwik_path + '.bak' - if not op.exists(backup_kwik_path): - info("Saving a backup of the Kwik file " - "in {0}.".format(backup_kwik_path)) - shutil.copyfile(kwik_path, backup_kwik_path) - - def _create_model(self, path): - model = KwikModel(path, - clustering=self._clustering, - channel_group=self._channel_group, - waveform_filter=self._waveform_filter, - ) - self._create_logger(path) - return model - - def _create_logger(self, path): - path = op.splitext(path)[0] + '.log' - level = self.settings['log_file_level'] - if not self._file_logger: - self._file_logger = FileLogger(filename=path, level=level) - register(self._file_logger) - - def _save_model(self): - """Save the spike clusters and cluster groups to the Kwik file.""" - groups = {cluster: self.model.cluster_metadata.group(cluster) - for cluster in self.cluster_ids} - self.model.save(self.model.spike_clusters, - groups, - clustering_metadata=self.model.clustering_metadata, - ) - info("Saved {0:s}.".format(self.model.kwik_path)) - - # File-related actions - # ------------------------------------------------------------------------- - - def open(self, kwik_path=None, model=None): - """Open a `.kwik` file.""" - self._backup_kwik(kwik_path) - return super(Session, self).open(model=model, path=kwik_path) - - @property - def kwik_path(self): - """Path to the `.kwik` file.""" - return self.model.path - - @property - def has_unsaved_changes(self): - """Whether there are unsaved changes in the model. - - If true, a prompt message for saving will be displayed when closing - the GUI. - - """ - # TODO - pass - - # Properties - # ------------------------------------------------------------------------- - - @property - def n_spikes(self): - """Number of spikes in the current channel group.""" - return self.model.n_spikes - - @property - def cluster_ids(self): - """Array of all cluster ids used in the current clustering.""" - return self.model.cluster_ids - - @property - def n_clusters(self): - """Number of clusters in the current clustering.""" - return self.model.n_clusters - - # Event callbacks - # ------------------------------------------------------------------------- - - def _create_cluster_store(self): - # Do not create the store if there is only one cluster. - if self.model.n_clusters <= 1 or not self._use_store: - # Just use a mock store. - self.store = create_store(self.model, - self.model.spikes_per_cluster, - ) - return - - # Kwik store in experiment_dir/name.phy/1/main/cluster_store. - store_path = op.join(self.settings.exp_settings_dir, - 'cluster_store', - str(self.model.channel_group), - self.model.clustering - ) - _ensure_dir_exists(store_path) - - # Instantiate the store. - spc = self.model.spikes_per_cluster - cs = self.settings['features_masks_chunk_size'] - wns = self.settings['waveforms_n_spikes_max'] - wes = self.settings['waveforms_excerpt_size'] - self.store = create_store(self.model, - path=store_path, - spikes_per_cluster=spc, - features_masks_chunk_size=cs, - waveforms_n_spikes_max=wns, - waveforms_excerpt_size=wes, - ) - - # Generate the cluster store if it doesn't exist or is invalid. - # If the cluster store already exists and is consistent - # with the data, it is not recreated. - self.store.generate() - - def change_channel_group(self, channel_group): - """Change the current channel group.""" - self._channel_group = channel_group - self.model.channel_group = channel_group - info("Switched to channel group {}.".format(channel_group)) - self.emit('open') - - def change_clustering(self, clustering): - """Change the current clustering.""" - self._clustering = clustering - self.model.clustering = clustering - info("Switched to `{}` clustering.".format(clustering)) - self.emit('open') - - def on_open(self): - self._create_cluster_store() - - def on_close(self): - if self._file_logger: - unregister(self._file_logger) - self._file_logger = None - - def register_statistic(self, func=None, shape=(-1,)): - """Decorator registering a custom cluster statistic. - - Parameters - ---------- - - func : function - A function that takes a cluster index as argument, and returns - some statistics (generally a NumPy array). - - Notes - ----- - - This function will be called on every cluster when a dataset is opened. - It is also automatically called on new clusters when clusters change. - You can access the data from the model and from the cluster store. - - """ - if func is not None: - return self.register_statistic()(func) - - def decorator(func): - - name = func.__name__ - - def _wrapper(cluster): - out = func(cluster) - self.store.memory_store.store(cluster, **{name: out}) - - # Add the statistics. - stats = self.store.items['statistics'] - stats.add(name, _wrapper, shape) - # Register it in the global cluster store. - self.store.register_field(name, 'statistics') - # Compute it on all existing clusters. - stats.store_all(name=name, mode='force') - info("Registered statistic `{}`.".format(name)) - - return decorator - - # Spike sorting - # ------------------------------------------------------------------------- - - def detect(self, traces=None, - interval=None, - algorithm='spikedetekt', - **kwargs): - """Detect spikes in traces. - - Parameters - ---------- - - traces : array - An `(n_samples, n_channels)` array. If unspecified, the Kwik - file's raw data is used. - interval : tuple (optional) - A tuple `(start, end)` (in seconds) where to detect spikes. - algorithm : str - The algorithm name. Only `spikedetekt` currently. - **kwargs : dictionary - Algorithm parameters. - - Returns - ------- - - result : dict - A `{channel_group: tuple}` mapping, where the tuple is: - - * `spike_times` : the spike times (in seconds). - * `masks`: the masks of the spikes `(n_spikes, n_channels)`. - - """ - assert algorithm == 'spikedetekt' - # Create `.phy/spikedetekt/` directory for temporary files. - sd_dir = op.join(self.settings.exp_settings_dir, 'spikedetekt') - _ensure_dir_exists(sd_dir) - # Default interval. - if interval is not None: - (start_sec, end_sec) = interval - sr = self.model.sample_rate - interval_samples = (int(start_sec * sr), - int(end_sec * sr)) - else: - interval_samples = None - # Find the raw traces. - traces = traces if traces is not None else self.model.traces - # Take the parameters in the Kwik file, coming from the PRM file. - params = self.model.metadata - params.update(kwargs) - # Probe parameters required by SpikeDetekt. - params['probe_channels'] = self.model.probe.channels_per_group - params['probe_adjacency_list'] = self.model.probe.adjacency - # Start the spike detection. - debug("Running SpikeDetekt with the following parameters: " - "{}.".format(params)) - sd = SpikeDetekt(tempdir=sd_dir, **params) - out = sd.run_serial(traces, interval_samples=interval_samples) - n_features = params['n_features_per_channel'] - - # Add the spikes in the `.kwik` and `.kwx` files. - for group in out.groups: - spike_samples = _concatenate(out.spike_samples[group]) - n_spikes = len(spike_samples) if spike_samples is not None else 0 - n_channels = sd._n_channels_per_group[group] - self.model.creator.add_spikes(group=group, - spike_samples=spike_samples, - spike_recordings=None, # TODO - masks=out.masks[group], - features=out.features[group], - n_channels=n_channels, - n_features=n_features, - ) - sc = np.zeros(n_spikes, dtype=np.int32) - self.model.creator.add_clustering(group=group, - name='main', - spike_clusters=sc) - self.emit('open') - - if out.groups: - self.change_channel_group(out.groups[0]) - - def cluster(self, - clustering=None, - algorithm='klustakwik', - spike_ids=None, - channel_group=None, - **kwargs): - """Run an automatic clustering algorithm on all or some of the spikes. - - Parameters - ---------- - - clustering : str - The name of the clustering in which to save the results. - algorithm : str - The algorithm name. Only `klustakwik` currently. - spike_ids : array-like - Array of spikes to cluster. - - Returns - ------- - - spike_clusters : array - The spike_clusters assignements returned by the algorithm. - - """ - if clustering is None: - clustering = 'main' - if channel_group is not None: - self.change_channel_group(channel_group) - - kk2_dir = op.join(self.settings.exp_settings_dir, 'klustakwik2') - _ensure_dir_exists(kk2_dir) - - # Take KK2's default parameters. - from klustakwik2.default_parameters import default_parameters - params = default_parameters.copy() - # Update the PRM ones, by filtering them. - params.update({k: v for k, v in self.model.metadata.items() - if k in default_parameters}) - # Update the ones passed to the function. - params.update(kwargs) - - # Original spike_clusters array. - if self.model.spike_clusters is None: - n_spikes = (len(spike_ids) if spike_ids is not None - else self.model.n_spikes) - spike_clusters_orig = np.zeros(n_spikes, dtype=np.int32) - else: - spike_clusters_orig = self.model.spike_clusters.copy() - - # HACK: there needs to be one clustering. - if 'empty' not in self.model.clusterings: - self.model.add_clustering('empty', spike_clusters_orig) - - # Instantiate the KlustaKwik instance. - kk = KlustaKwik(**params) - - # Save the current clustering in the Kwik file. - @kk.connect - def on_iter(sc): - # Update the original spike clusters. - spike_clusters = spike_clusters_orig.copy() - spike_clusters[spike_ids] = sc - # Save to a text file. - path = op.join(kk2_dir, 'spike_clusters.txt') - # Backup. - if op.exists(path): - shutil.copy(path, path + '~') - np.savetxt(path, spike_clusters, fmt='%d') - - info("Running {}...".format(algorithm)) - # Run KK. - sc = kk.cluster(model=self.model, spike_ids=spike_ids) - info("The automatic clustering process has finished.") - - # Save the results in the Kwik file. - spike_clusters = spike_clusters_orig.copy() - spike_clusters[spike_ids] = sc - - # Add a new clustering and switch to it. - if clustering in self.model.clusterings: - self.change_clustering('empty') - self.model.delete_clustering(clustering) - self.model.add_clustering(clustering, spike_clusters) - - # Copy the main clustering to original (only if this is the very - # first run of the clustering algorithm). - if clustering == 'main': - self.model.copy_clustering('main', 'original') - self.change_clustering(clustering) - - # Set the new clustering metadata. - params = kk.params - params['version'] = kk.version - metadata = {'{}_{}'.format(algorithm, name): value - for name, value in params.items()} - self.model.clustering_metadata.update(metadata) - self.save() - info("The clustering has been saved in the " - "`{}` clustering in the `.kwik` file.".format(clustering)) - self.model.delete_clustering('empty') - return sc - - # GUI - # ------------------------------------------------------------------------- - - def show_gui(self, **kwargs): - """Show a GUI.""" - gui = super(Session, self).show_gui(store=self.store, - **kwargs) - - @gui.connect - def on_request_save(): - self.save() - - return gui diff --git a/phy/session/tests/__init__.py b/phy/session/tests/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/phy/session/tests/test_session.py b/phy/session/tests/test_session.py deleted file mode 100644 index db44052ad..000000000 --- a/phy/session/tests/test_session.py +++ /dev/null @@ -1,401 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Tests of session structure.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import os -import os.path as op - -import numpy as np -from numpy.testing import assert_allclose as ac -from numpy.testing import assert_equal as ae -from pytest import raises, yield_fixture, mark - -from ..session import Session -from ...utils.array import _spikes_in_clusters -from ...utils.logging import set_level -from ...io.mock import MockModel, artificial_traces -from ...io.kwik.mock import create_mock_kwik -from ...io.kwik.creator import create_kwik - - -# Skip these tests in "make test-quick". -pytestmark = mark.long() - - -#------------------------------------------------------------------------------ -# Fixtures -#------------------------------------------------------------------------------ - -def setup(): - set_level('debug') - - -def teardown(): - set_level('info') - - -n_clusters = 5 -n_spikes = 50 -n_channels = 28 -n_fets = 2 -n_samples_traces = 3000 - - -def _start_manual_clustering(kwik_path=None, - model=None, - tempdir=None, - chunk_size=None, - ): - session = Session(phy_user_dir=tempdir) - session.open(kwik_path=kwik_path, model=model) - session.settings['waveforms_scale_factor'] = 1. - session.settings['features_scale_factor'] = 1. - session.settings['traces_scale_factor'] = 1. - session.settings['prompt_save_on_exit'] = False - if chunk_size is not None: - session.settings['features_masks_chunk_size'] = chunk_size - return session - - -@yield_fixture -def session(tempdir): - # Create the test HDF5 file in the temporary directory. - kwik_path = create_mock_kwik(tempdir, - n_clusters=n_clusters, - n_spikes=n_spikes, - n_channels=n_channels, - n_features_per_channel=n_fets, - n_samples_traces=n_samples_traces) - - session = _start_manual_clustering(kwik_path=kwik_path, - tempdir=tempdir) - session.tempdir = tempdir - - yield session - - session.close() - - -#------------------------------------------------------------------------------ -# Tests -#------------------------------------------------------------------------------ - -def test_store_corruption(tempdir): - # Create the test HDF5 file in the temporary directory. - kwik_path = create_mock_kwik(tempdir, - n_clusters=n_clusters, - n_spikes=n_spikes, - n_channels=n_channels, - n_features_per_channel=n_fets, - n_samples_traces=n_samples_traces) - - session = Session(kwik_path, phy_user_dir=tempdir) - store_path = session.store.path - session.close() - - # Corrupt a file in the store. - fn = op.join(store_path, 'waveforms_spikes') - with open(fn, 'rb') as f: - contents = f.read() - with open(fn, 'wb') as f: - f.write(contents[1:-1]) - - session = Session(kwik_path, phy_user_dir=tempdir) - session.close() - - -def test_session_one_cluster(tempdir): - session = Session(phy_user_dir=tempdir) - # The disk store is not created if there is only one cluster. - session.open(model=MockModel(n_clusters=1)) - assert session.store.disk_store is None - - -def test_session_store_features(tempdir): - """Check that the cluster store works for features and masks.""" - - model = MockModel(n_spikes=50, n_clusters=3) - s0 = np.nonzero(model.spike_clusters == 0)[0] - s1 = np.nonzero(model.spike_clusters == 1)[0] - - session = _start_manual_clustering(model=model, - tempdir=tempdir, - chunk_size=4, - ) - - f = session.store.features(0) - m = session.store.masks(1) - w = session.store.waveforms(1) - - assert f.shape == (len(s0), 28, 2) - assert m.shape == (len(s1), 28,) - assert w.shape == (len(s1), model.n_samples_waveforms, 28,) - - ac(f, model.features[s0].reshape((f.shape[0], -1, 2)), 1e-3) - ac(m, model.masks[s1], 1e-3) - - -def test_session_gui_clustering(qtbot, session): - - cs = session.store - spike_clusters = session.model.spike_clusters.copy() - - f = session.model.features - m = session.model.masks - - def _check_arrays(cluster, clusters_for_sc=None, spikes=None): - """Check the features and masks in the cluster store - of a given custer.""" - if spikes is None: - if clusters_for_sc is None: - clusters_for_sc = [cluster] - spikes = _spikes_in_clusters(spike_clusters, clusters_for_sc) - shape = (len(spikes), - len(session.model.channel_order), - session.model.n_features_per_channel) - ac(cs.features(cluster), f[spikes, :].reshape(shape)) - ac(cs.masks(cluster), m[spikes]) - - _check_arrays(0) - _check_arrays(2) - - gui = session.show_gui() - qtbot.addWidget(gui.main_window) - - # Merge two clusters. - clusters = [0, 2] - gui.merge(clusters) # Create cluster 5. - _check_arrays(5, clusters) - - # Split some spikes. - spikes = [2, 3, 5, 7, 11, 13] - # clusters = np.unique(spike_clusters[spikes]) - gui.split(spikes) # Create cluster 6 and more. - _check_arrays(6, spikes=spikes) - - # Undo. - gui.undo() - _check_arrays(5, clusters) - - # Undo. - gui.undo() - _check_arrays(0) - _check_arrays(2) - - # Redo. - gui.redo() - _check_arrays(5, clusters) - - # Split some spikes. - spikes = [5, 7, 11, 13, 17, 19] - # clusters = np.unique(spike_clusters[spikes]) - gui.split(spikes) # Create cluster 6 and more. - _check_arrays(6, spikes=spikes) - - # Test merge-undo-different-merge combo. - spc = gui.clustering.spikes_per_cluster.copy() - clusters = gui.cluster_ids[:3] - up = gui.merge(clusters) - _check_arrays(up.added[0], spikes=up.spike_ids) - # Undo. - gui.undo() - for cluster in clusters: - _check_arrays(cluster, spikes=spc[cluster]) - # Another merge. - clusters = gui.cluster_ids[1:5] - up = gui.merge(clusters) - _check_arrays(up.added[0], spikes=up.spike_ids) - - # Move a cluster to a group. - cluster = gui.cluster_ids[0] - gui.move([cluster], 2) - assert len(gui.store.mean_probe_position(cluster)) == 2 - - # Save. - spike_clusters_new = gui.model.spike_clusters.copy() - # Check that the spike clusters have changed. - assert not np.all(spike_clusters_new == spike_clusters) - ac(session.model.spike_clusters, gui.clustering.spike_clusters) - session.save() - - # Re-open the file and check that the spike clusters and - # cluster groups have correctly been saved. - session = _start_manual_clustering(kwik_path=session.model.path, - tempdir=session.tempdir) - ac(session.model.spike_clusters, gui.clustering.spike_clusters) - ac(session.model.spike_clusters, spike_clusters_new) - #  Check the cluster groups. - clusters = gui.clustering.cluster_ids - groups = session.model.cluster_groups - assert groups[cluster] == 2 - - gui.close() - - -def test_session_gui_multiple_clusterings(qtbot, session): - - gui = session.show_gui() - qtbot.addWidget(gui.main_window) - - assert session.model.n_spikes == n_spikes - assert session.model.n_clusters == n_clusters - assert len(session.model.cluster_ids) == n_clusters - assert gui.clustering.n_clusters == n_clusters - assert session.model.cluster_metadata.group(1) == 1 - - # Change clustering. - with raises(ValueError): - session.change_clustering('automat') - session.change_clustering('original') - - n_clusters_2 = session.model.n_clusters - assert session.model.n_spikes == n_spikes - assert session.model.n_clusters == n_clusters_2 - assert len(session.model.cluster_ids) == n_clusters_2 - assert gui.clustering.n_clusters == n_clusters_2 - assert session.model.cluster_metadata.group(2) == 2 - - # Merge the clusters and save, for the current clustering. - gui.clustering.merge(gui.clustering.cluster_ids) - session.save() - - # Re-open the session. - session = _start_manual_clustering(kwik_path=session.model.path, - tempdir=session.tempdir) - - # The default clustering is the main one: nothing should have - # changed here. - assert session.model.n_clusters == n_clusters - - session.change_clustering('original') - assert session.model.n_spikes == n_spikes - assert session.model.n_clusters == 1 - assert session.model.cluster_ids == n_clusters_2 - - gui.close() - - -def test_session_kwik(session): - - # Check backup. - assert op.exists(op.join(session.tempdir, session.kwik_path + '.bak')) - - cs = session.store - nc = n_channels - 2 - - # Check the stored items. - for cluster in range(n_clusters): - n_spikes = len(session.model.spikes_per_cluster[cluster]) - n_unmasked_channels = cs.n_unmasked_channels(cluster) - - assert cs.features(cluster).shape == (n_spikes, nc, n_fets) - assert cs.masks(cluster).shape == (n_spikes, nc) - assert cs.mean_masks(cluster).shape == (nc,) - assert n_unmasked_channels <= nc - assert cs.mean_probe_position(cluster).shape == (2,) - assert cs.main_channels(cluster).shape == (n_unmasked_channels,) - - -def test_session_gui_statistics(qtbot, session): - """Test registration of new statistic.""" - - gui = session.show_gui() - qtbot.addWidget(gui.main_window) - - @session.register_statistic - def n_spikes_2(cluster): - return gui.clustering.cluster_counts.get(cluster, 0) ** 2 - - store = gui.store - stats = store.items['statistics'] - - def _check(): - for clu in gui.cluster_ids: - assert (store.n_spikes_2(clu) == - store.features(clu).shape[0] ** 2) - - assert 'n_spikes_2' in stats.fields - _check() - - # Merge the clusters and check that the statistics has been - # recomputed for the new cluster. - clusters = gui.cluster_ids - gui.merge(clusters) - _check() - assert gui.cluster_ids == [max(clusters) + 1] - - gui.undo() - _check() - - gui.merge(gui.cluster_ids[::2]) - _check() - - gui.close() - - -@mark.parametrize('spike_ids', [None, np.arange(20)]) -def test_session_auto(session, spike_ids): - set_level('info') - sc = session.cluster(num_starting_clusters=10, - spike_ids=spike_ids, - clustering='test', - ) - assert session.model.clustering == 'test' - assert session.model.clusterings == ['main', - 'original', 'test'] - - # Re-open the dataset and check that the clustering has been saved. - session = _start_manual_clustering(kwik_path=session.model.path, - tempdir=session.tempdir) - if spike_ids is None: - spike_ids = slice(None, None, None) - assert len(session.model.spike_clusters) == n_spikes - assert not np.all(session.model.spike_clusters[spike_ids] == sc) - - assert 'klustakwik_version' not in session.model.clustering_metadata - kk_dir = op.join(session.settings.exp_settings_dir, 'klustakwik2') - - # Check temporary files with latest clustering. - files = os.listdir(kk_dir) - assert 'spike_clusters.txt' in files - if len(files) >= 2: - assert 'spike_clusters.txt~' in files - sc_txt = np.loadtxt(op.join(kk_dir, 'spike_clusters.txt')) - ae(sc, sc_txt[spike_ids]) - - for clustering in ('test',): - session.change_clustering(clustering) - assert len(session.model.spike_clusters) == n_spikes - assert np.all(session.model.spike_clusters[spike_ids] == sc) - - # Test clustering metadata. - session.change_clustering('test') - metadata = session.model.clustering_metadata - assert 'klustakwik_version' in metadata - assert metadata['klustakwik_num_starting_clusters'] == 10 - - -def test_session_detect(tempdir): - channels = range(n_channels) - graph = [[i, i + 1] for i in range(n_channels - 1)] - probe = {'channel_groups': { - 0: {'channels': channels, - 'graph': graph, - }}} - sample_rate = 10000 - n_samples_traces = 10000 - traces = artificial_traces(n_samples_traces, n_channels) - assert traces is not None - - kwik_path = op.join(tempdir, 'test.kwik') - create_kwik(kwik_path=kwik_path, probe=probe, sample_rate=sample_rate) - session = Session(kwik_path, phy_user_dir=tempdir) - session.detect(traces=traces) - m = session.model - if m.n_spikes > 0: - shape = (m.n_spikes, n_channels * m.n_features_per_channel) - assert m.features.shape == shape From 0b0a75341d5f770cd17fed269a1e49cfd2329d5c Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 22 Sep 2015 10:03:02 +0200 Subject: [PATCH 0003/1059] phy.utils tests pass --- conftest.py | 47 ------------ phy/io/__init__.py | 4 - phy/io/mock.py | 128 ------------------------------- phy/utils/tests/test_settings.py | 10 +-- 4 files changed, 5 insertions(+), 184 deletions(-) diff --git a/conftest.py b/conftest.py index dc101ae01..b1acaf10d 100644 --- a/conftest.py +++ b/conftest.py @@ -8,15 +8,9 @@ import os -import numpy as np from pytest import yield_fixture -from phy.electrode.mea import load_probe -from phy.io.mock import artificial_traces -from phy.utils._types import Bunch from phy.utils.tempdir import TemporaryDirectory -from phy.utils.settings import _load_default_settings -from phy.utils.datasets import download_test_data #------------------------------------------------------------------------------ @@ -42,44 +36,3 @@ def chdir_tempdir(): def tempdir_bis(): with TemporaryDirectory() as tempdir: yield tempdir - - -@yield_fixture(params=['null', 'artificial', 'real']) -def raw_dataset(request): - sample_rate = 20000 - params = _load_default_settings()['spikedetekt'] - data_type = request.param - if data_type == 'real': - path = download_test_data('test-32ch-10s.dat') - traces = np.fromfile(path, dtype=np.int16).reshape((200000, 32)) - traces = traces[:45000] - n_samples, n_channels = traces.shape - params['use_single_threshold'] = False - probe = load_probe('1x32_buzsaki') - else: - probe = {'channel_groups': { - 0: {'channels': [0, 1, 2], - 'graph': [[0, 1], [0, 2], [1, 2]], - }, - 1: {'channels': [3], - 'graph': [], - 'geometry': {3: [0., 0.]}, - } - }} - if data_type == 'null': - n_samples, n_channels = 25000, 4 - traces = np.zeros((n_samples, n_channels)) - elif data_type == 'artificial': - n_samples, n_channels = 25000, 4 - traces = artificial_traces(n_samples, n_channels) - traces[5000:5010, 1] *= 5 - traces[15000:15010, 3] *= 5 - n_samples_w = params['extract_s_before'] + params['extract_s_after'] - yield Bunch(n_channels=n_channels, - n_samples=n_samples, - sample_rate=sample_rate, - n_samples_waveforms=n_samples_w, - traces=traces, - params=params, - probe=probe, - ) diff --git a/phy/io/__init__.py b/phy/io/__init__.py index 7c7cb105b..17b52558c 100644 --- a/phy/io/__init__.py +++ b/phy/io/__init__.py @@ -3,9 +3,5 @@ """Input/output.""" -from .base import BaseModel, BaseSession from .h5 import File, open_h5 -from .store import ClusterStore, StoreItem from .traces import read_dat, read_kwd -from .kwik.creator import KwikCreator, create_kwik -from .kwik.model import KwikModel diff --git a/phy/io/mock.py b/phy/io/mock.py index c24260fbd..ecfcbea60 100644 --- a/phy/io/mock.py +++ b/phy/io/mock.py @@ -9,11 +9,6 @@ import numpy as np import numpy.random as nr -from ..utils._color import _random_color -from ..utils.array import _unique, _spikes_per_cluster -from .base import BaseModel, ClusterMetadata -from ..electrode.mea import MEA, staggered_positions - #------------------------------------------------------------------------------ # Artificial data @@ -49,126 +44,3 @@ def artificial_spike_samples(n_spikes, max_isi=50): def artificial_correlograms(n_clusters, n_samples): return nr.uniform(size=(n_clusters, n_clusters, n_samples)) - - -#------------------------------------------------------------------------------ -# Artificial Model -#------------------------------------------------------------------------------ - -class MockModel(BaseModel): - n_channels = 28 - n_features_per_channel = 2 - n_features = 28 * n_features_per_channel - n_spikes = 1000 - n_samples_traces = 20000 - n_samples_waveforms = 40 - n_clusters = 10 - sample_rate = 20000. - - def __init__(self, n_spikes=None, n_clusters=None): - super(MockModel, self).__init__() - if n_spikes is not None: - self.n_spikes = n_spikes - if n_clusters is not None: - self.n_clusters = n_clusters - self.name = 'mock' - self._clustering = 'main' - nfpc = self.n_features_per_channel - self._metadata = {'description': 'A mock model.', - 'n_features_per_channel': nfpc} - self._cluster_metadata = ClusterMetadata() - - @self._cluster_metadata.default - def group(cluster): - if cluster <= 2: - return cluster - # Default group is unsorted. - return 3 - - @self._cluster_metadata.default - def color(cluster): - return _random_color() - - positions = staggered_positions(self.n_channels) - self._probe = MEA(channels=self.channels, positions=positions) - self._traces = artificial_traces(self.n_samples_traces, - self.n_channels) - self._spike_clusters = artificial_spike_clusters(self.n_spikes, - self.n_clusters) - self._spike_ids = np.arange(self.n_spikes).astype(np.int64) - self._spikes_per_cluster = _spikes_per_cluster(self._spike_ids, - self._spike_clusters) - self._spike_samples = artificial_spike_samples(self.n_spikes, 30) - assert self._spike_samples[-1] < self.n_samples_traces - self._features = artificial_features(self.n_spikes, self.n_features) - self._masks = artificial_masks(self.n_spikes, self.n_channels) - self._features_masks = np.dstack((self._features, - np.repeat(self._masks, - nfpc, - axis=1))) - self._waveforms = artificial_waveforms(self.n_spikes, - self.n_samples_waveforms, - self.n_channels) - - @property - def channels(self): - return np.arange(self.n_channels) - - @property - def channel_order(self): - return self.channels - - @property - def metadata(self): - return self._metadata - - @property - def traces(self): - return self._traces - - @property - def spike_samples(self): - return self._spike_samples - - @property - def spike_clusters(self): - return self._spike_clusters - - @property - def spikes_per_cluster(self): - return self._spikes_per_cluster - - def update_spikes_per_cluster(self, spc): - self._spikes_per_cluster = spc - - @property - def spike_ids(self): - return self._spike_ids - - @property - def cluster_ids(self): - return _unique(self._spike_clusters) - - @property - def cluster_metadata(self): - return self._cluster_metadata - - @property - def features(self): - return self._features - - @property - def masks(self): - return self._masks - - @property - def features_masks(self): - return self._features_masks - - @property - def waveforms(self): - return self._waveforms - - @property - def probe(self): - return self._probe diff --git a/phy/utils/tests/test_settings.py b/phy/utils/tests/test_settings.py index f1dc7f2d0..997d2e57c 100644 --- a/phy/utils/tests/test_settings.py +++ b/phy/utils/tests/test_settings.py @@ -34,12 +34,12 @@ def test_recursive_dirs(): def test_load_default_settings(): settings = _load_default_settings() keys = settings.keys() - assert 'log_file_level' in keys - assert 'on_open' in keys - assert 'spikedetekt' in keys - assert 'klustakwik2' in keys + # assert 'log_file_level' in keys + # assert 'on_open' in keys + # assert 'spikedetekt' in keys + # assert 'klustakwik2' in keys assert 'traces' in keys - assert 'cluster_manual_config' in keys + # assert 'cluster_manual_config' in keys def test_base_settings(): From 0ae0070a437a258cb30b5a303b4084b1ad951b0e Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 22 Sep 2015 10:05:12 +0200 Subject: [PATCH 0004/1059] Bump version --- phy/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/phy/__init__.py b/phy/__init__.py index 8846b78d8..1a72b8313 100644 --- a/phy/__init__.py +++ b/phy/__init__.py @@ -22,7 +22,7 @@ __author__ = 'Kwik team' __email__ = 'cyrille.rossant at gmail.com' -__version__ = '0.2.2' +__version__ = '0.3.0.dev0' __version_git__ = __version__ + _git_version() From 131dec9c11724ecee2ccaf637f14c3ba46350d54 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 22 Sep 2015 10:25:37 +0200 Subject: [PATCH 0005/1059] WIP: improving VisPy testing --- phy/utils/testing.py | 55 +++++++++++++++----------------------------- 1 file changed, 18 insertions(+), 37 deletions(-) diff --git a/phy/utils/testing.py b/phy/utils/testing.py index e5bd2b3d4..292bb1834 100644 --- a/phy/utils/testing.py +++ b/phy/utils/testing.py @@ -156,47 +156,28 @@ def _profile(prof, statement, glob, loc): # Testing VisPy canvas #------------------------------------------------------------------------------ -def _frame(canvas): - canvas.update() - canvas.app.process_events() - time.sleep(1. / 60.) - - -def show_test(canvas, n_frames=2): +def show_test(canvas): """Show a VisPy canvas for a fraction of second.""" - with canvas as c: - show_test_run(c, n_frames) - - -def show_test_start(canvas): - """This is the __enter__ of with canvas.""" - canvas.show() - canvas._backend._vispy_warmup() - - -def show_test_run(canvas, n_frames=2): - """Display frames of a canvas.""" - if n_frames == 0: - while not canvas._closed: - _frame(canvas) - else: - for _ in range(n_frames): - _frame(canvas) - if canvas._closed: - return + with canvas: + # Interactive mode for tests. + if '-i' in sys.argv: + while not canvas._closed: + canvas.update() + canvas.app.process_events() + time.sleep(1. / 60) + else: + canvas.update() + canvas.app.process_events() -def show_test_stop(canvas): - """This is the __exit__ of with canvas.""" - # ensure all GL calls are complete - if not canvas._closed: - canvas._backend._vispy_set_current() - canvas.context.finish() - canvas.close() - time.sleep(0.025) # ensure window is really closed/destroyed +# TODO +# def test_1(guibot): +# c = Canvas() +# guibot.add(c) +# guibot.wait(c) -def show_colored_canvas(color, n_frames=5): +def show_colored_canvas(color): """Show a transient VisPy canvas with a uniform background color.""" from vispy import app, gloo c = app.Canvas() @@ -205,4 +186,4 @@ def show_colored_canvas(color, n_frames=5): def on_draw(e): gloo.clear(color) - show_test(c, n_frames=n_frames) + show_test(c) From 745be8e852efcf813a445e2ae5acedc36b93d864 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 22 Sep 2015 11:04:20 +0200 Subject: [PATCH 0006/1059] Improve logging in phy.utils --- conftest.py | 6 + phy/__init__.py | 53 ++++++-- phy/io/h5.py | 14 +-- phy/utils/__init__.py | 1 - phy/utils/array.py | 16 +-- phy/utils/datasets.py | 20 +-- phy/utils/logging.py | 210 ------------------------------- phy/utils/settings.py | 28 ++--- phy/utils/testing.py | 14 ++- phy/utils/tests/test_datasets.py | 20 ++- phy/utils/tests/test_logging.py | 72 ----------- 11 files changed, 105 insertions(+), 349 deletions(-) delete mode 100644 phy/utils/logging.py delete mode 100644 phy/utils/tests/test_logging.py diff --git a/conftest.py b/conftest.py index b1acaf10d..54053af9d 100644 --- a/conftest.py +++ b/conftest.py @@ -6,10 +6,12 @@ # Imports #------------------------------------------------------------------------------ +import logging import os from pytest import yield_fixture +from phy import add_default_handler from phy.utils.tempdir import TemporaryDirectory @@ -17,6 +19,10 @@ # Common fixtures #------------------------------------------------------------------------------ +logging.getLogger().setLevel(logging.DEBUG) +add_default_handler('DEBUG') + + @yield_fixture def tempdir(): with TemporaryDirectory() as tempdir: diff --git a/phy/__init__.py b/phy/__init__.py index 1a72b8313..94e0a8fe0 100644 --- a/phy/__init__.py +++ b/phy/__init__.py @@ -8,10 +8,12 @@ # Imports #------------------------------------------------------------------------------ +import logging import os.path as op -from pkg_resources import get_distribution, DistributionNotFound +import sys + +from six import StringIO -from .utils.logging import _default_logger, set_level from .utils.datasets import download_sample_data from .utils._misc import _git_version @@ -26,19 +28,48 @@ __version_git__ = __version__ + _git_version() -__all__ = ['debug', 'set_level'] +# Set a null handler on the root logger +logger = logging.getLogger() +logger.setLevel(logging.DEBUG) +logger.addHandler(logging.NullHandler()) + + +_logger_fmt = '%(asctime)s [%(levelname)s] %(caller)s %(message)s' +_logger_date_fmt = '%H:%M:%S' + + +class _Formatter(logging.Formatter): + def format(self, record): + # Only keep the first character in the level name. + record.levelname = record.levelname[0] + filename = op.splitext(op.basename(record.pathname))[0] + record.caller = '{:s}:{:d}'.format(filename, record.lineno).ljust(16) + return super(_Formatter, self).format(record) + + +def add_default_handler(level='INFO'): + handler = logging.StreamHandler() + handler.setLevel(level) + + formatter = _Formatter(fmt=_logger_fmt, + datefmt=_logger_date_fmt) + handler.setFormatter(formatter) + + logger.addHandler(handler) -# Set up the default logger. -_default_logger() +def string_handler(level='INFO'): + buffer = StringIO() + for handler in logger.handlers: + logger.removeHandler(handler) + handler = logging.StreamHandler(buffer) + logger.addHandler(handler) + return buffer -def debug(enable=True): - """Enable debug logging mode.""" - if enable: - set_level('debug') - else: - set_level('info') +if '--debug' in sys.argv: # pragma: no cover + add_default_handler('DEBUG') + logger.info("Activate DEBUG level.") def test(): diff --git a/phy/io/h5.py b/phy/io/h5.py index cdebc29ba..c42d5d0f7 100644 --- a/phy/io/h5.py +++ b/phy/io/h5.py @@ -6,11 +6,13 @@ # Imports #------------------------------------------------------------------------------ +import logging + import numpy as np import h5py from six import string_types -from ..utils.logging import debug, warn +logger = logging.getLogger(__name__) #------------------------------------------------------------------------------ @@ -180,8 +182,8 @@ def read_attr(self, path, attr_name): out = out[0].decode('UTF-8') return out except (TypeError, IOError): - debug("Unable to read attribute `{}` at `{}`.".format( - attr_name, path)) + logger.debug("Unable to read attribute `%s` at `%s`.", + attr_name, path) return else: raise KeyError("The attribute '{0:s}' ".format(attr_name) + @@ -210,11 +212,9 @@ def write_attr(self, path, attr_name, value): self._h5py_file.create_group(path) try: self._h5py_file[path].attrs[attr_name] = value - # debug("Write `{}={}` at `{}`.".format(attr_name, - # str(value), path)) except TypeError: - warn("Unable to write attribute `{}={}` at `{}`.".format( - attr_name, value, path)) + logger.warn("Unable to write attribute `%s=%s` at `%s`.", + attr_name, value, path) def attrs(self, path='/'): """Return the list of attributes at the given path.""" diff --git a/phy/utils/__init__.py b/phy/utils/__init__.py index 66243eac6..e13bbe23b 100644 --- a/phy/utils/__init__.py +++ b/phy/utils/__init__.py @@ -6,5 +6,4 @@ from ._types import _is_array_like, _as_array, _as_tuple, _as_list, Bunch from .datasets import download_file, download_sample_data from .event import EventEmitter, ProgressReporter -from .logging import debug, info, warn, register, unregister, set_level from .settings import Settings, _ensure_dir_exists diff --git a/phy/utils/array.py b/phy/utils/array.py index c0a8a5a64..f4094774f 100644 --- a/phy/utils/array.py +++ b/phy/utils/array.py @@ -6,19 +6,21 @@ # Imports #------------------------------------------------------------------------------ -import os -import os.path as op -from math import floor -from operator import mul from functools import reduce +import logging import math +from math import floor +from operator import mul +import os +import os.path as op import numpy as np from six import integer_types, string_types -from .logging import warn from ._types import _as_tuple, _as_array +logger = logging.getLogger(__name__) + #------------------------------------------------------------------------------ # Utility functions @@ -811,8 +813,8 @@ def __getitem__(self, item): # Concatenate all chunks. l = [chunk_start] if rec_stop - rec_start >= 2: - warn("Loading a full virtual array: this might be slow " - "and something might be wrong.") + logger.warn("Loading a full virtual array: this might be slow " + "and something might be wrong.") l += [self.arrs[r][...] for r in range(rec_start + 1, rec_stop)] l += [chunk_stop] diff --git a/phy/utils/datasets.py b/phy/utils/datasets.py index 9480af76a..c3dc1a49d 100644 --- a/phy/utils/datasets.py +++ b/phy/utils/datasets.py @@ -7,13 +7,15 @@ #------------------------------------------------------------------------------ import hashlib +import logging import os import os.path as op -from .logging import debug, warn from .settings import _phy_user_dir, _ensure_dir_exists from .event import ProgressReporter +logger = logging.getLogger(__name__) + #------------------------------------------------------------------------------ # Utility functions @@ -61,7 +63,7 @@ def _download(url, stream=None): from requests import get r = get(url, stream=stream) if r.status_code != 200: - debug("Error while downloading `{}`.".format(url)) + logger.debug("Error while downloading %s.", url) r.raise_for_status() return r @@ -136,16 +138,16 @@ def download_file(url, output_path=None): if op.exists(output_path): checked = _check_md5_of_url(output_path, url) if checked is False: - debug("The file `{}` already exists ".format(output_path) + - "but is invalid: redownloading.") + logger.debug("The file `%s` already exists " + "but is invalid: redownloading.", output_path) elif checked is True: - debug("The file `{}` already exists: ".format(output_path) + - "skipping.") + logger.debug("The file `%s` already exists: skipping.", + output_path) return r = _download(url, stream=True) _save_stream(r, output_path) if _check_md5_of_url(output_path, url) is False: - debug("The checksum doesn't match: retrying the download.") + logger.debug("The checksum doesn't match: retrying the download.") r = _download(url, stream=True) _save_stream(r, output_path) if _check_md5_of_url(output_path, url) is False: @@ -187,5 +189,5 @@ def download_sample_data(filename, output_dir=None, base='cortexlab'): try: download_file(url, output_path=output_path) except Exception as e: - warn("An error occurred while downloading `{}` to `{}`: {}".format( - url, output_path, str(e))) + logger.warn("An error occurred while downloading `%s` to `%s`: %s", + url, output_path, str(e)) diff --git a/phy/utils/logging.py b/phy/utils/logging.py deleted file mode 100644 index 2608de6ce..000000000 --- a/phy/utils/logging.py +++ /dev/null @@ -1,210 +0,0 @@ -from __future__ import absolute_import - -"""Logger utility classes and functions.""" - - -# ----------------------------------------------------------------------------- -# Imports -# ----------------------------------------------------------------------------- - -import os.path as op -import sys -import logging - -from six import iteritems, string_types - - -# ----------------------------------------------------------------------------- -# Stream classes -# ----------------------------------------------------------------------------- - -class StringStream(object): - """Logger stream used to store all logs in a string.""" - def __init__(self): - self.string = "" - - def write(self, line): - self.string += line - - def flush(self): - pass - - def __repr__(self): - return self.string - - -# ----------------------------------------------------------------------------- -# Custom formatter -# ----------------------------------------------------------------------------- - -_LONG_FORMAT = ('%(asctime)s [%(levelname)s] ' - '%(caller)s %(message)s', - '%Y-%m-%d %H:%M:%S') -_SHORT_FORMAT = ('%(asctime)s [%(levelname)s] %(message)s', - '%H:%M:%S') - - -class Formatter(logging.Formatter): - def format(self, record): - # Only keep the first character in the level name. - record.levelname = record.levelname[0] - filename = op.splitext(op.basename(record.pathname))[0] - record.caller = '{:s}:{:d}'.format(filename, record.lineno).ljust(16) - return super(Formatter, self).format(record) - - -# ----------------------------------------------------------------------------- -# Logger classes -# ----------------------------------------------------------------------------- - -class Logger(object): - """Save logging information to a stream.""" - def __init__(self, - fmt=None, - stream=None, - level=None, - name=None, - handler=None, - ): - if stream is None: - stream = sys.stdout - if name is None: - name = self.__class__.__name__ - self.name = name - if handler is None: - self.stream = stream - self.handler = logging.StreamHandler(self.stream) - else: - self.handler = handler - self.level = level - self.fmt = fmt - # Set the level and corresponding formatter. - self.set_level(level, fmt) - - def set_level(self, level=None, fmt=None): - # Default level and format. - if level is None: - level = self.level or logging.INFO - if isinstance(level, string_types): - level = getattr(logging, level.upper()) - fmt = fmt or self.fmt - # Create the Logger object. - self._logger = logging.getLogger(self.name) - # Create the formatter. - if fmt is None: - fmt, datefmt = (_LONG_FORMAT if level == logging.DEBUG - else _SHORT_FORMAT) - else: - datefmt = None - formatter = Formatter(fmt=fmt, datefmt=datefmt) - self.handler.setFormatter(formatter) - # Configure the logger. - self._logger.setLevel(level) - self._logger.propagate = False - self._logger.addHandler(self.handler) - - def close(self): - pass - - def debug(self, msg): - self._logger.debug(msg) - - def info(self, msg): - self._logger.info(msg) - - def warn(self, msg): - self._logger.warn(msg) - - -class StringLogger(Logger): - """Log to a string.""" - def __init__(self, **kwargs): - kwargs['stream'] = StringStream() - super(StringLogger, self).__init__(**kwargs) - - def __repr__(self): - return self.stream.__repr__() - - -class ConsoleLogger(Logger): - """Log to the standard output.""" - def __init__(self, **kwargs): - kwargs['stream'] = sys.stdout - super(ConsoleLogger, self).__init__(**kwargs) - - -class FileLogger(Logger): - """Log to a file.""" - def __init__(self, filename=None, **kwargs): - kwargs['handler'] = logging.FileHandler(filename) - super(FileLogger, self).__init__(**kwargs) - - def close(self): - self.handler.close() - self._logger.removeHandler(self.handler) - del self.handler - del self._logger - - -# ----------------------------------------------------------------------------- -# Global variables -# ----------------------------------------------------------------------------- - -LOGGERS = {} - - -def register(logger): - """Register a logger.""" - name = logger.name - if name not in LOGGERS: - LOGGERS[name] = logger - - -def unregister(logger): - """Unregister a logger.""" - name = logger.name - if name in LOGGERS: - LOGGERS[name].close() - del LOGGERS[name] - - -def _log(level, *msg): - if isinstance(msg, tuple): - msg = ' '.join(str(_) for _ in msg) - for name, logger in iteritems(LOGGERS): - getattr(logger, level)(msg) - - -def debug(*msg): - """Generate a debug message.""" - _log('debug', *msg) - - -def info(*msg): - """Generate an info message.""" - _log('info', *msg) - - -def warn(*msg): - """Generate a warning.""" - _log('warn', *msg) - - -def set_level(level): - """Set the level of all registered loggers. - - Parameters - ---------- - - level : str - Can be `warn`, `info`, or `debug`. - - """ - for name, logger in iteritems(LOGGERS): - logger.set_level(level) - - -def _default_logger(level='info'): - """Create a default logger in `info` mode by default.""" - register(ConsoleLogger()) - set_level(level) diff --git a/phy/utils/settings.py b/phy/utils/settings.py index dad69b24f..92d799275 100644 --- a/phy/utils/settings.py +++ b/phy/utils/settings.py @@ -6,12 +6,14 @@ # Imports #------------------------------------------------------------------------------ +import logging import os import os.path as op -from .logging import debug, warn from ._misc import _load_json, _save_json, _read_python +logger = logging.getLogger(__name__) + #------------------------------------------------------------------------------ # Settings @@ -22,7 +24,7 @@ def _create_empty_settings(path): # If the file exists of if it is an internal settings file, skip. if op.exists(path) or op.splitext(path)[1] == '': return - debug("Creating empty settings file: {}.".format(path)) + logger.debug("Creating empty settings file: %s.", path) with open(path, 'a') as f: f.write("# Settings file. Refer to phy's documentation " "for more details.\n") @@ -90,8 +92,6 @@ def _update(self, d): def _try_load_json(self, path): try: self._update(_load_json(path)) - # debug("Loaded internal settings file " - # "from `{}`.".format(path)) return True except Exception: return False @@ -99,8 +99,6 @@ def _try_load_json(self, path): def _try_load_python(self, path): try: self._update(_read_python(path)) - # debug("Loaded internal settings file " - # "from `{}`.".format(path)) return True except Exception: return False @@ -109,7 +107,7 @@ def load(self, path): """Load a settings file.""" path = op.realpath(path) if not op.exists(path): - debug("Settings file `{}` doesn't exist.".format(path)) + logger.debug("Settings file `{}` doesn't exist.".format(path)) return # Try JSON first, then Python. has_ext = op.splitext(path)[1] != '' @@ -117,19 +115,19 @@ def load(self, path): if self._try_load_json(path): return if not self._try_load_python(path): - warn("Unable to read '{}'. ".format(path) + - "Please try to delete this file.") + logger.warn("Unable to read '%s'. " + "Please try to delete this file.", path) def save(self, path): """Save the settings to a JSON file.""" path = op.realpath(path) try: _save_json(path, self._to_save) - debug("Saved internal settings file " - "to `{}`.".format(path)) + logger.debug("Saved internal settings file " + "to `%s`.", path) except Exception as e: - warn("Unable to save the internal settings file " - "to `{}`:\n{}".format(path, str(e))) + logger.warn("Unable to save the internal settings file " + "to `%s`:\n%s", path, str(e)) self._to_save = {} @@ -171,8 +169,8 @@ def _load_user_settings(self): def on_open(self, path): """Initialize settings when loading an experiment.""" if path is None: - debug("Unable to initialize the settings for unspecified " - "model path.") + logger.debug("Unable to initialize the settings for unspecified " + "model path.") return # Get the experiment settings path. path = op.realpath(op.expanduser(path)) diff --git a/phy/utils/testing.py b/phy/utils/testing.py index 292bb1834..7dd0fa12e 100644 --- a/phy/utils/testing.py +++ b/phy/utils/testing.py @@ -6,13 +6,14 @@ # Imports #------------------------------------------------------------------------------ -import sys -import time from contextlib import contextmanager -from timeit import default_timer from cProfile import Profile -import os.path as op import functools +import logging +import os.path as op +import sys +import time +from timeit import default_timer from numpy.testing import assert_array_equal as ae from numpy.testing import assert_allclose as ac @@ -20,9 +21,10 @@ from six.moves import builtins from ._types import _is_array_like -from .logging import info from .settings import _ensure_dir_exists +logger = logging.getLogger(__name__) + #------------------------------------------------------------------------------ # Utility functions @@ -67,7 +69,7 @@ def benchmark(name='', repeats=1): start = default_timer() yield duration = (default_timer() - start) * 1000. - info("{} took {:.6f}ms.".format(name, duration / repeats)) + logger.info("%s took %.6fms.", name, duration / repeats) class ContextualProfile(Profile): diff --git a/phy/utils/tests/test_datasets.py b/phy/utils/tests/test_datasets.py index 77f356743..90dbe7ece 100644 --- a/phy/utils/tests/test_datasets.py +++ b/phy/utils/tests/test_datasets.py @@ -6,6 +6,7 @@ # Imports #------------------------------------------------------------------------------ +import logging import os.path as op from itertools import product @@ -13,24 +14,23 @@ from numpy.testing import assert_array_equal as ae import responses from pytest import raises, yield_fixture +from six import StringIO +from phy import string_handler from ..datasets import (download_file, download_test_data, download_sample_data, _check_md5_of_url, _BASE_URL, ) -from ..logging import register, StringLogger, set_level + +logger = logging.getLogger(__name__) #------------------------------------------------------------------------------ # Fixtures #------------------------------------------------------------------------------ -def setup(): - set_level('debug') - - # Test URL and data _URL = 'http://test/data' _DATA = np.linspace(0., 1., 100000).astype(np.float32) @@ -116,25 +116,23 @@ def test_download_not_found(tempdir): @responses.activate def test_download_already_exists_invalid(tempdir, mock_url): - logger = StringLogger(level='debug') - register(logger) + buffer = string_handler() path = op.join(tempdir, 'test') # Create empty file. open(path, 'a').close() _check(_dl(path)) - assert 'redownload' in str(logger) + assert 'redownload' in buffer.getvalue() @responses.activate def test_download_already_exists_valid(tempdir, mock_url): - logger = StringLogger(level='debug') - register(logger) + buffer = string_handler() path = op.join(tempdir, 'test') # Create valid file. with open(path, 'ab') as f: f.write(_DATA.tostring()) _check(_dl(path)) - assert 'skip' in str(logger) + assert 'skip' in buffer.getvalue() @responses.activate diff --git a/phy/utils/tests/test_logging.py b/phy/utils/tests/test_logging.py deleted file mode 100644 index 49c427809..000000000 --- a/phy/utils/tests/test_logging.py +++ /dev/null @@ -1,72 +0,0 @@ -"""Unit tests for logging module.""" - -# ----------------------------------------------------------------------------- -# Imports -# ----------------------------------------------------------------------------- - -import os - -from ..logging import (StringLogger, ConsoleLogger, debug, info, warn, - FileLogger, register, unregister, - set_level) - - -# ----------------------------------------------------------------------------- -# Tests -# ----------------------------------------------------------------------------- -def test_string_logger(): - l = StringLogger(fmt='') - l.info("test 1") - l.info("test 2") - - log = str(l) - logs = log.split('\n') - - assert "test 1" in logs[0] - assert "test 2" in logs[1] - - -def test_console_logger(): - l = ConsoleLogger(fmt='') - l.info("test 1") - l.info("test 2") - - l = ConsoleLogger(level='debug') - l.info("test 1") - l.info("test 2") - - -def test_file_logger(): - logfile = os.path.join(os.path.dirname(os.path.abspath(__file__)), - 'log.txt') - l = FileLogger(logfile, fmt='', level='debug') - l.debug("test file 1") - l.debug("test file 2") - l.info("test file info") - l.warn("test file warn") - l.close() - - with open(logfile, 'r') as f: - contents = f.read() - - assert contents.strip().startswith("test file 1\ntest file 2") - - os.remove(logfile) - - -def test_register(): - l = StringLogger(fmt='') - register(l) - - set_level('info') - debug("test D1") - info("test I1") - warn("test W1") - - set_level('Debug') - debug("test D2") - info("test I2") - warn("test W2") - assert len(str(l).strip().split('\n')) == 5 - - unregister(l) From b605cce21e3d8e9e267b5fd443ab3668fb21d63b Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 22 Sep 2015 11:06:17 +0200 Subject: [PATCH 0007/1059] Update phy.traces --- phy/traces/waveform.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/phy/traces/waveform.py b/phy/traces/waveform.py index 4231bcee4..79faf13b4 100644 --- a/phy/traces/waveform.py +++ b/phy/traces/waveform.py @@ -6,12 +6,15 @@ # Imports #------------------------------------------------------------------------------ +import logging + import numpy as np from scipy.interpolate import interp1d from ..utils._types import _as_array, Bunch from ..utils.array import _pad -from ..utils.logging import warn + +logger = logging.getLogger(__name__) #------------------------------------------------------------------------------ @@ -150,7 +153,7 @@ def align(self, waveform, s_aligned): f = interp1d(old_s, waveform, bounds_error=True, kind='cubic', axis=0) except ValueError: - warn("Interpolation error at time {0:d}".format(s)) + logger.warn("Interpolation error at time %d", s) return waveform return f(new_s) From e75c68ef59bf4b87e532751630982acf4c4530f5 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 22 Sep 2015 11:10:11 +0200 Subject: [PATCH 0008/1059] Update phy.plot --- phy/plot/_vispy_utils.py | 10 ++++++---- phy/plot/tests/test_utils.py | 19 ++++--------------- 2 files changed, 10 insertions(+), 19 deletions(-) diff --git a/phy/plot/_vispy_utils.py b/phy/plot/_vispy_utils.py index 19103f5df..fab6230a0 100644 --- a/phy/plot/_vispy_utils.py +++ b/phy/plot/_vispy_utils.py @@ -7,8 +7,9 @@ # Imports #------------------------------------------------------------------------------ -import os.path as op from functools import wraps +import logging +import os.path as op import numpy as np @@ -17,9 +18,10 @@ from ..utils._types import _as_array, _as_list from ..utils.array import _unique, _in_polygon -from ..utils.logging import debug from ._panzoom import PanZoom +logger = logging.getLogger(__name__) + #------------------------------------------------------------------------------ # Misc @@ -447,13 +449,13 @@ def add(self, xy): """Add a new point.""" self._points.append((xy)) self._update_points() - debug("Add lasso point.") + logger.debug("Add lasso point.") def clear(self): """Remove all points.""" self._points = [] self._update_points() - debug("Clear lasso.") + logger.debug("Clear lasso.") def in_lasso(self, points): """Find points within the lasso. diff --git a/phy/plot/tests/test_utils.py b/phy/plot/tests/test_utils.py index 2268c6f1b..f6fe6c985 100644 --- a/phy/plot/tests/test_utils.py +++ b/phy/plot/tests/test_utils.py @@ -10,10 +10,7 @@ from vispy import app -from ...utils.testing import (show_test_start, - show_test_run, - show_test_stop, - ) +from ...utils.testing import show_test from .._vispy_utils import LassoVisual from .._panzoom import PanZoom, PanZoomGrid @@ -26,9 +23,6 @@ # Tests VisPy #------------------------------------------------------------------------------ -_N_FRAMES = 2 - - class TestCanvas(app.Canvas): _pz = None @@ -57,12 +51,9 @@ def on_resize(self, event): self.context.set_viewport(0, 0, event.size[0], event.size[1]) -def _show_visual(visual, grid=False, stop=True): +def _show_visual(visual, grid=False): view = TestCanvas(visual, grid=grid) - show_test_start(view) - show_test_run(view, _N_FRAMES) - if stop: - show_test_stop(view) + show_test(view) return view @@ -74,7 +65,5 @@ def test_lasso(): [-.8, +.8], [-.8, -.8], ] - view = _show_visual(lasso, grid=True, stop=False) + view = _show_visual(lasso, grid=True) view.visual.add([+.8, -.8]) - show_test_run(view, _N_FRAMES) - show_test_stop(view) From 7e0a04621c6762bf846495d9f201d4cf757935c9 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 22 Sep 2015 11:10:59 +0200 Subject: [PATCH 0009/1059] Update phy.io --- phy/io/tests/test_mock.py | 18 +----------------- 1 file changed, 1 insertion(+), 17 deletions(-) diff --git a/phy/io/tests/test_mock.py b/phy/io/tests/test_mock.py index 1097a93b9..124b5da65 100644 --- a/phy/io/tests/test_mock.py +++ b/phy/io/tests/test_mock.py @@ -16,7 +16,7 @@ artificial_spike_clusters, artificial_features, artificial_masks, - MockModel) + ) #------------------------------------------------------------------------------ @@ -61,19 +61,3 @@ def _test_artificial(n_spikes=None, n_clusters=None): def test_artificial(): _test_artificial(n_spikes=100, n_clusters=10) _test_artificial(n_spikes=0, n_clusters=0) - - -def test_mock_model(): - model = MockModel() - - assert model.metadata['description'] == 'A mock model.' - assert model.traces.ndim == 2 - assert model.spike_samples.ndim == 1 - assert model.spike_clusters.ndim == 1 - assert model.features.ndim == 2 - assert model.masks.ndim == 2 - assert model.waveforms.ndim == 3 - - assert isinstance(model.probe, MEA) - with raises(NotImplementedError): - model.save() From 1fe6bf06f712368d107860b58e5bd816c77a1b3f Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 22 Sep 2015 11:18:46 +0200 Subject: [PATCH 0010/1059] Update phy.gui --- phy/gui/base.py | 24 +++++----- phy/gui/qt.py | 12 +++-- phy/gui/tests/test_base.py | 97 -------------------------------------- phy/gui/tests/test_dock.py | 9 ---- phy/gui/tests/test_qt.py | 9 ---- 5 files changed, 20 insertions(+), 131 deletions(-) diff --git a/phy/gui/base.py b/phy/gui/base.py index 4df3d8cbe..3cc3e6ab6 100644 --- a/phy/gui/base.py +++ b/phy/gui/base.py @@ -9,14 +9,17 @@ from collections import Counter import inspect +import logging from six import string_types, PY3 from ..utils._misc import _show_shortcuts -from ..utils import debug, info, warn, EventEmitter +from ..utils import EventEmitter from ._utils import _read from .dock import DockWindow +logger = logging.getLogger(__name__) + #------------------------------------------------------------------------------ # BaseViewModel @@ -263,10 +266,10 @@ def on_close(e=None): def remove(self, widget): """Remove a widget.""" if widget in self._widgets: - debug("Remove widget {}.".format(widget)) + logger.debug("Remove widget %s.", widget) self._widgets.remove(widget) else: - debug("Unable to remove widget {}.".format(widget)) + logger.debug("Unable to remove widget %s.", widget) #------------------------------------------------------------------------------ @@ -428,7 +431,7 @@ def _load_config(self, config=None, # Add the right number of views of each type. if current_count.get(name, 0) >= requested_count.get(name, 0): continue - debug("Adding {} view in GUI.".format(name)) + logger.debug("Adding %s view in GUI.", name) # GUI-specific keyword arguments position, size, maximized self.add_view(name, **kwargs) if name not in current_count: @@ -505,14 +508,14 @@ def process_snippet(self, snippet): snippet = snippet[len(cmd):].strip() func = self._snippets.get(cmd, None) if func is None: - info("The snippet `{}` could not be found.".format(cmd)) + logger.info("The snippet `%s` could not be found.", cmd) return try: - info("Processing snippet `{}`.".format(cmd)) + logger.info("Processing snippet `%s`.", cmd) func(self, snippet) except Exception as e: - warn("Error when executing snippet `{}`: {}.".format( - cmd, str(e))) + logger.warn("Error when executing snippet `%s`: %s.", + cmd, str(e)) def _snippet_action_name(self, char): return self._snippet_chars.index(char) @@ -556,7 +559,7 @@ def enter(): ) def enable_snippet_mode(self): - info("Snippet mode enabled, press `escape` to leave this mode.") + logger.info("Snippet mode enabled, press `escape` to leave this mode.") self._remove_actions() self._create_snippet_actions() self._snippet_message = ':' @@ -567,7 +570,7 @@ def disable_snippet_mode(self): self._remove_actions() self._set_default_shortcuts() self._create_actions() - info("Snippet mode disabled.") + logger.info("Snippet mode disabled.") #-------------------------------------------------------------------------- # Public methods @@ -593,7 +596,6 @@ def add_view(self, item, title=None, **kwargs): # View model parameters from settings. vm_class = self._view_creator._widget_classes[name] kwargs.update(vm_class.get_params(self.settings)) - # debug("Create {} with {}.".format(name, kwargs)) item = self._view_creator.add(item, **kwargs) # Set the view name if necessary. if not item._view_name: diff --git a/phy/gui/qt.py b/phy/gui/qt.py index dfb46e0d7..f82db6f39 100644 --- a/phy/gui/qt.py +++ b/phy/gui/qt.py @@ -6,12 +6,14 @@ # Imports # ----------------------------------------------------------------------------- +import contextlib +import logging import os import sys -import contextlib from ..utils._misc import _is_interactive -from ..utils.logging import info, warn + +logger = logging.getLogger(__name__) # ----------------------------------------------------------------------------- @@ -35,7 +37,7 @@ def _check_qt(): if not _PYQT: - warn("PyQt is not available.") + logger.warn("PyQt is not available.") return False return True @@ -119,9 +121,9 @@ def enable_qt(): ip.enable_gui('qt') global _APP_RUNNING _APP_RUNNING = True - info("Qt event loop activated.") + logger.info("Qt event loop activated.") except: - warn("Qt event loop not activated.") + logger.warn("Qt event loop not activated.") # ----------------------------------------------------------------------------- diff --git a/phy/gui/tests/test_base.py b/phy/gui/tests/test_base.py index ea606eedc..492fc30a5 100644 --- a/phy/gui/tests/test_base.py +++ b/phy/gui/tests/test_base.py @@ -7,8 +7,6 @@ # Imports #------------------------------------------------------------------------------ -import os.path as op - from pytest import raises, mark from ..base import (BaseViewModel, @@ -20,8 +18,6 @@ _set_qt_widget_position_size, ) from ...utils.event import EventEmitter -from ...utils.logging import set_level -from ...io.base import BaseModel, BaseSession # Skip these tests in "make test-quick". @@ -32,14 +28,6 @@ # Base tests #------------------------------------------------------------------------------ -def setup(): - set_level('debug') - - -def teardown(): - set_level('info') - - def test_base_view_model(qtbot): class MyViewModel(BaseViewModel): _view_name = 'main_window' @@ -202,88 +190,3 @@ def _keystroke(char=None): gui.reset_gui() gui.close() - - -def test_base_session(tempdir, qtbot): - - phy_dir = op.join(tempdir, 'test.phy') - - model = BaseModel() - - class V1(HTMLViewModel): - def get_html(self, **kwargs): - return 'view 1' - - class V2(HTMLViewModel): - def get_html(self, **kwargs): - return 'view 2' - - vm_classes = {'v1': V1, 'v2': V2} - - config = [('v1', {'position': 'right'}), - ('v2', {'position': 'left'}), - ('v2', {'position': 'bottom'}), - ] - - shortcuts = {'test': 't', 'exit': 'ctrl+q'} - - class TestGUI(BaseGUI): - def __init__(self, **kwargs): - super(TestGUI, self).__init__(vm_classes=vm_classes, - **kwargs) - self.on_open() - - def _create_actions(self): - self._add_gui_shortcut('test') - self._add_gui_shortcut('exit') - - def test(self): - self.show_shortcuts() - self.reset_gui() - - gui_classes = {'gui': TestGUI} - - default_settings_path = op.join(tempdir, 'default_settings.py') - - with open(default_settings_path, 'w') as f: - f.write("gui_config = {}\n".format(str(config)) + - "gui_shortcuts = {}".format(str(shortcuts))) - - session = BaseSession(model=model, - phy_user_dir=phy_dir, - default_settings_paths=[default_settings_path], - vm_classes=vm_classes, - gui_classes=gui_classes, - ) - - # New GUI. - gui = session.show_gui('gui') - qtbot.addWidget(gui.main_window) - qtbot.waitForWindowShown(gui.main_window) - - # Remove a v2 view. - v2 = gui.get_views('v2') - assert len(v2) == 2 - v2[0].close() - gui.close() - - # Reopen and check that the v2 is gone. - gui = session.show_gui('gui') - qtbot.addWidget(gui.main_window) - qtbot.waitForWindowShown(gui.main_window) - - v2 = gui.get_views('v2') - assert len(v2) == 1 - - gui.reset_gui() - v2 = gui.get_views('v2') - assert len(v2) == 2 - gui.close() - - gui = session.show_gui('gui') - qtbot.addWidget(gui.main_window) - qtbot.waitForWindowShown(gui.main_window) - - v2 = gui.get_views('v2') - assert len(v2) == 2 - gui.close() diff --git a/phy/gui/tests/test_dock.py b/phy/gui/tests/test_dock.py index 8ee2ce02f..0a2916972 100644 --- a/phy/gui/tests/test_dock.py +++ b/phy/gui/tests/test_dock.py @@ -12,7 +12,6 @@ from ..dock import DockWindow from ...utils._color import _random_color -from ...utils.logging import set_level # Skip these tests in "make test-quick". @@ -23,14 +22,6 @@ # Tests #------------------------------------------------------------------------------ -def setup(): - set_level('debug') - - -def teardown(): - set_level('info') - - def _create_canvas(): """Create a VisPy canvas with a color background.""" c = app.Canvas() diff --git a/phy/gui/tests/test_qt.py b/phy/gui/tests/test_qt.py index d147a610e..fa4540402 100644 --- a/phy/gui/tests/test_qt.py +++ b/phy/gui/tests/test_qt.py @@ -13,7 +13,6 @@ _set_qt_widget_position_size, _prompt, ) -from ...utils.logging import set_level # Skip these tests in "make test-quick". @@ -24,14 +23,6 @@ # Tests #------------------------------------------------------------------------------ -def setup(): - set_level('debug') - - -def teardown(): - set_level('info') - - def test_wrap(qtbot): view = QtWebKit.QWebView() From 073e0afaa6bace8b17ad2b700db92c3de9dd4f71 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 22 Sep 2015 11:23:49 +0200 Subject: [PATCH 0011/1059] Update phy.cluster.algorithms --- phy/cluster/algorithms/klustakwik.py | 85 ++++++++++++++++--- .../algorithms/tests/test_klustakwik.py | 37 ++------ 2 files changed, 79 insertions(+), 43 deletions(-) diff --git a/phy/cluster/algorithms/klustakwik.py b/phy/cluster/algorithms/klustakwik.py index 46144f431..4233e8e12 100644 --- a/phy/cluster/algorithms/klustakwik.py +++ b/phy/cluster/algorithms/klustakwik.py @@ -6,9 +6,74 @@ # Imports #------------------------------------------------------------------------------ -from ...utils.array import PartialArray +import numpy as np +import six + +from ...utils.array import chunk_bounds from ...utils.event import EventEmitter -from ...io.kwik.sparse_kk2 import sparsify_features_masks + + +#------------------------------------------------------------------------------ +# Sparse structures +#------------------------------------------------------------------------------ + +def sparsify_features_masks(features, masks, chunk_size=10000): + from klustakwik2 import RawSparseData + + assert features.ndim == 2 + assert masks.ndim == 2 + assert features.shape == masks.shape + + n_spikes, num_features = features.shape + + # Stage 1: read min/max of fet values for normalization + # and count total number of unmasked features. + vmin = np.ones(num_features) * np.inf + vmax = np.ones(num_features) * (-np.inf) + total_unmasked_features = 0 + for _, _, i, j in chunk_bounds(n_spikes, chunk_size): + f, m = features[i:j], masks[i:j] + inds = m > 0 + # Replace the masked values by NaN. + vmin = np.minimum(np.min(f, axis=0), vmin) + vmax = np.maximum(np.max(f, axis=0), vmax) + total_unmasked_features += inds.sum() + # Stage 2: read data line by line, normalising + vdiff = vmax - vmin + vdiff[vdiff == 0] = 1 + fetsum = np.zeros(num_features) + fet2sum = np.zeros(num_features) + nsum = np.zeros(num_features) + all_features = np.zeros(total_unmasked_features) + all_fmasks = np.zeros(total_unmasked_features) + all_unmasked = np.zeros(total_unmasked_features, dtype=int) + offsets = np.zeros(n_spikes + 1, dtype=int) + curoff = 0 + for i in six.moves.range(n_spikes): + fetvals, fmaskvals = (features[i] - vmin) / vdiff, masks[i] + inds = (fmaskvals > 0).nonzero()[0] + masked_inds = (fmaskvals == 0).nonzero()[0] + all_features[curoff:curoff + len(inds)] = fetvals[inds] + all_fmasks[curoff:curoff + len(inds)] = fmaskvals[inds] + all_unmasked[curoff:curoff + len(inds)] = inds + offsets[i] = curoff + curoff += len(inds) + fetsum[masked_inds] += fetvals[masked_inds] + fet2sum[masked_inds] += fetvals[masked_inds] ** 2 + nsum[masked_inds] += 1 + offsets[-1] = curoff + + nsum[nsum == 0] = 1 + noise_mean = fetsum / nsum + noise_variance = fet2sum / nsum - noise_mean ** 2 + + return RawSparseData(noise_mean, + noise_variance, + all_features, + all_fmasks, + all_unmasked, + offsets, + ) #------------------------------------------------------------------------------ @@ -26,7 +91,6 @@ def __init__(self, **kwargs): self.version = __version__ def cluster(self, - model=None, spike_ids=None, features=None, masks=None, @@ -39,12 +103,6 @@ def cluster(self, Emit the `iter` event at every KlustaKwik iteration. """ - # Get the features and masks. - if model is not None: - if features is None: - features = PartialArray(model.features_masks, 0) - if masks is None: - masks = PartialArray(model.features_masks, 1) # Select some spikes if needed. if spike_ids is not None: features = features[spike_ids] @@ -70,14 +128,15 @@ def f(_): return spike_clusters -def cluster(model, algorithm='klustakwik', spike_ids=None, **kwargs): +def cluster(features=None, masks=None, algorithm='klustakwik', + spike_ids=None, **kwargs): """Launch an automatic clustering algorithm on the model. Parameters ---------- - model : BaseModel - A model. + features : ndarray + masks : ndarray algorithm : str Only 'klustakwik' is supported currently. **kwargs @@ -86,4 +145,4 @@ def cluster(model, algorithm='klustakwik', spike_ids=None, **kwargs): """ assert algorithm == 'klustakwik' kk = KlustaKwik(**kwargs) - return kk.cluster(model=model, spike_ids=spike_ids) + return kk.cluster(features=features, masks=masks, spike_ids=spike_ids) diff --git a/phy/cluster/algorithms/tests/test_klustakwik.py b/phy/cluster/algorithms/tests/test_klustakwik.py index c0315b6b0..276d87309 100644 --- a/phy/cluster/algorithms/tests/test_klustakwik.py +++ b/phy/cluster/algorithms/tests/test_klustakwik.py @@ -6,46 +6,23 @@ # Imports #------------------------------------------------------------------------------ -from ....utils.logging import set_level -from ....io.kwik import KwikModel -from ....io.kwik.mock import create_mock_kwik +from phy.io.mock import artificial_features, artificial_masks from ..klustakwik import cluster -#------------------------------------------------------------------------------ -# Fixtures -#------------------------------------------------------------------------------ - -def setup(): - set_level('info') - - -def teardown(): - set_level('info') - - -sample_rate = 10000 -n_samples = 25000 -n_channels = 4 - - #------------------------------------------------------------------------------ # Tests clustering #------------------------------------------------------------------------------ def test_cluster(tempdir): + n_channels = 4 n_spikes = 100 - filename = create_mock_kwik(tempdir, - n_clusters=1, - n_spikes=n_spikes, - n_channels=8, - n_features_per_channel=3, - n_samples_traces=5000) - model = KwikModel(filename) - - spike_clusters = cluster(model, num_starting_clusters=10) + features = artificial_features(n_spikes, n_channels * 3) + masks = artificial_masks(n_spikes, n_channels * 3) + + spike_clusters = cluster(features, masks, num_starting_clusters=10) assert len(spike_clusters) == n_spikes - spike_clusters = cluster(model, num_starting_clusters=10, + spike_clusters = cluster(features, masks, num_starting_clusters=10, spike_ids=range(100)) assert len(spike_clusters) == 100 From 9af8859d412b14839abe525647ee46bb9f591538 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 22 Sep 2015 11:29:05 +0200 Subject: [PATCH 0012/1059] Update phy.cluster.manual --- phy/cluster/manual/__init__.py | 9 --- phy/cluster/manual/_utils.py | 81 +++++++++++++++++++++++++- phy/cluster/manual/tests/test_utils.py | 24 +++----- phy/cluster/manual/wizard.py | 21 ------- phy/utils/__init__.py | 3 +- 5 files changed, 91 insertions(+), 47 deletions(-) diff --git a/phy/cluster/manual/__init__.py b/phy/cluster/manual/__init__.py index 8d7154907..c6fa9b085 100644 --- a/phy/cluster/manual/__init__.py +++ b/phy/cluster/manual/__init__.py @@ -3,14 +3,5 @@ """Manual clustering facilities.""" -from .view_models import (BaseClusterViewModel, - HTMLClusterViewModel, - StatsViewModel, - FeatureViewModel, - WaveformViewModel, - TraceViewModel, - CorrelogramViewModel, - ) from .clustering import Clustering from .wizard import Wizard -from .gui import ClusterManualGUI diff --git a/phy/cluster/manual/_utils.py b/phy/cluster/manual/_utils.py index 95887a607..c7af20a59 100644 --- a/phy/cluster/manual/_utils.py +++ b/phy/cluster/manual/_utils.py @@ -7,9 +7,10 @@ #------------------------------------------------------------------------------ from copy import deepcopy +from collections import defaultdict from ._history import History -from ...utils import Bunch, _as_list +from phy.utils import Bunch, _as_list, _is_list #------------------------------------------------------------------------------ @@ -79,6 +80,84 @@ def __repr__(self): # ClusterMetadataUpdater class #------------------------------------------------------------------------------ +class ClusterMetadata(object): + """Hold cluster metadata. + + Features + -------- + + * New metadata fields can be easily registered + * Arbitrary functions can be used for default values + + Notes + ---- + + If a metadata field `group` is registered, then two methods are + dynamically created: + + * `group(cluster)` returns the group of a cluster, or the default value + if the cluster doesn't exist. + * `set_group(cluster, value)` sets a value for the `group` metadata field. + + """ + def __init__(self, data=None): + self._fields = {} + self._data = defaultdict(dict) + # Fill the existing values. + if data is not None: + self._data.update(data) + + @property + def data(self): + return self._data + + def _get_one(self, cluster, field): + """Return the field value for a cluster, or the default value if it + doesn't exist.""" + if cluster in self._data: + if field in self._data[cluster]: + return self._data[cluster][field] + elif field in self._fields: + # Call the default field function. + return self._fields[field](cluster) + else: + return None + else: + if field in self._fields: + return self._fields[field](cluster) + else: + return None + + def _get(self, clusters, field): + if _is_list(clusters): + return [self._get_one(cluster, field) + for cluster in _as_list(clusters)] + else: + return self._get_one(clusters, field) + + def _set_one(self, cluster, field, value): + """Set a field value for a cluster.""" + self._data[cluster][field] = value + + def _set(self, clusters, field, value): + clusters = _as_list(clusters) + for cluster in clusters: + self._set_one(cluster, field, value) + + def default(self, func): + """Register a new metadata field with a function + returning the default value of a cluster.""" + field = func.__name__ + # Register the decorated function as the default field function. + self._fields[field] = func + # Create self.(clusters). + setattr(self, field, lambda clusters: self._get(clusters, field)) + # Create self.set_(clusters, value). + setattr(self, 'set_{0:s}'.format(field), + lambda clusters, value: self._set(clusters, field, value)) + return func + + class ClusterMetadataUpdater(object): """Handle cluster metadata changes.""" def __init__(self, cluster_metadata): diff --git a/phy/cluster/manual/tests/test_utils.py b/phy/cluster/manual/tests/test_utils.py index 4b707080b..9c0f4e3bb 100644 --- a/phy/cluster/manual/tests/test_utils.py +++ b/phy/cluster/manual/tests/test_utils.py @@ -6,23 +6,17 @@ # Imports #------------------------------------------------------------------------------ -from ....utils.logging import set_level, debug -from .._utils import ClusterMetadataUpdater, UpdateInfo -from ....io.kwik.model import ClusterMetadata +import logging + +from .._utils import ClusterMetadata, ClusterMetadataUpdater, UpdateInfo + +logger = logging.getLogger(__name__) #------------------------------------------------------------------------------ # Tests #------------------------------------------------------------------------------ -def setup(): - set_level('debug') - - -def teardown(): - set_level('info') - - def test_metadata_history(): """Test ClusterMetadataUpdater history.""" @@ -114,7 +108,7 @@ def color(cluster): def test_update_info(): - debug(UpdateInfo(deleted=range(5), added=[5], description='merge')) - debug(UpdateInfo(deleted=range(5), added=[5], description='assign')) - debug(UpdateInfo(deleted=range(5), added=[5], - description='assign', history='undo')) + logger.debug(UpdateInfo(deleted=range(5), added=[5], description='merge')) + logger.debug(UpdateInfo(deleted=range(5), added=[5], description='assign')) + logger.debug(UpdateInfo(deleted=range(5), added=[5], + description='assign', history='undo')) diff --git a/phy/cluster/manual/wizard.py b/phy/cluster/manual/wizard.py index b0c3c05d0..fe30b54d5 100644 --- a/phy/cluster/manual/wizard.py +++ b/phy/cluster/manual/wizard.py @@ -6,12 +6,9 @@ # Imports #------------------------------------------------------------------------------ -import os.path as op from operator import itemgetter from ...utils import _is_array_like -from .view_models import HTMLClusterViewModel -from ...gui._utils import _read #------------------------------------------------------------------------------ @@ -482,21 +479,3 @@ def get_panel_params(self): best_group=self._group(self.best) or 'unsorted', match_group=self._group(self.match) or 'unsorted', ) - - -#------------------------------------------------------------------------------ -# Wizard view model -#------------------------------------------------------------------------------ - -class WizardViewModel(HTMLClusterViewModel): - def get_html(self, **kwargs): - static_path = op.join(op.dirname(op.realpath(__file__)), 'static') - params = self._wizard.get_panel_params() - html = _read('wizard.html', static_path=static_path) - return html.format(**params) - - def get_css(self, **kwargs): - css = super(WizardViewModel, self).get_css(**kwargs) - static_path = op.join(op.dirname(op.realpath(__file__)), 'static') - css += _read('styles.css', static_path=static_path) - return css diff --git a/phy/utils/__init__.py b/phy/utils/__init__.py index e13bbe23b..6ea22eb9d 100644 --- a/phy/utils/__init__.py +++ b/phy/utils/__init__.py @@ -3,7 +3,8 @@ """Utilities.""" -from ._types import _is_array_like, _as_array, _as_tuple, _as_list, Bunch +from ._types import (_is_array_like, _as_array, _as_tuple, _as_list, + Bunch, _is_list) from .datasets import download_file, download_sample_data from .event import EventEmitter, ProgressReporter from .settings import Settings, _ensure_dir_exists From 77a10683da94490093f9275d31eeb4fb8a55eb91 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 22 Sep 2015 13:11:00 +0200 Subject: [PATCH 0013/1059] Update tests --- phy/io/tests/test_mock.py | 2 -- phy/traces/waveform.py | 2 +- phy/utils/testing.py | 3 ++- phy/utils/tests/test_datasets.py | 1 - phy/utils/tests/test_testing.py | 18 ++++++++++++- setup.py | 5 ++-- tests/__init__.py | 2 -- tests/scripts/test_phy_spikesort.py | 39 ----------------------------- 8 files changed, 23 insertions(+), 49 deletions(-) delete mode 100644 tests/__init__.py delete mode 100644 tests/scripts/test_phy_spikesort.py diff --git a/phy/io/tests/test_mock.py b/phy/io/tests/test_mock.py index 124b5da65..3476608ca 100644 --- a/phy/io/tests/test_mock.py +++ b/phy/io/tests/test_mock.py @@ -8,9 +8,7 @@ import numpy as np from numpy.testing import assert_array_equal as ae -from pytest import raises -from ...electrode.mea import MEA from ..mock import (artificial_waveforms, artificial_traces, artificial_spike_clusters, diff --git a/phy/traces/waveform.py b/phy/traces/waveform.py index 79faf13b4..db96e0329 100644 --- a/phy/traces/waveform.py +++ b/phy/traces/waveform.py @@ -354,7 +354,7 @@ def __getitem__(self, item): try: waveforms[i, ...] = self._load_at(time) except ValueError as e: - warn("Error while loading waveform: {0}".format(str(e))) + logger.warn("Error while loading waveform: %s", str(e)) if self._dc_offset: waveforms -= self._dc_offset if self._scale_factor: diff --git a/phy/utils/testing.py b/phy/utils/testing.py index 7dd0fa12e..710b8101d 100644 --- a/phy/utils/testing.py +++ b/phy/utils/testing.py @@ -10,6 +10,7 @@ from cProfile import Profile import functools import logging +import os import os.path as op import sys import time @@ -162,7 +163,7 @@ def show_test(canvas): """Show a VisPy canvas for a fraction of second.""" with canvas: # Interactive mode for tests. - if '-i' in sys.argv: + if 'PYTEST_INTERACT' in os.environ: while not canvas._closed: canvas.update() canvas.app.process_events() diff --git a/phy/utils/tests/test_datasets.py b/phy/utils/tests/test_datasets.py index 90dbe7ece..9e5d0b1c4 100644 --- a/phy/utils/tests/test_datasets.py +++ b/phy/utils/tests/test_datasets.py @@ -14,7 +14,6 @@ from numpy.testing import assert_array_equal as ae import responses from pytest import raises, yield_fixture -from six import StringIO from phy import string_handler from ..datasets import (download_file, diff --git a/phy/utils/tests/test_testing.py b/phy/utils/tests/test_testing.py index c443a3135..28e2c99f0 100644 --- a/phy/utils/tests/test_testing.py +++ b/phy/utils/tests/test_testing.py @@ -6,7 +6,12 @@ # Imports #------------------------------------------------------------------------------ -from ..testing import captured_output +import time + +from vispy.app import Canvas + +from ..testing import (benchmark, captured_output, show_test, + ) #------------------------------------------------------------------------------ @@ -17,3 +22,14 @@ def test_captured_output(): with captured_output() as (out, err): print('Hello world!') assert out.getvalue().strip() == 'Hello world!' + + +def test_benchmark(): + with benchmark(): + time.sleep(.002) + + +def test_canvas(): + c = Canvas(keys='interactive') + with benchmark(): + show_test(c) diff --git a/setup.py b/setup.py index 35b0c26c3..44a65d897 100644 --- a/setup.py +++ b/setup.py @@ -22,11 +22,12 @@ #------------------------------------------------------------------------------ class PyTest(TestCommand): - user_options = [('pytest-args=', 'a', "String of arguments to pass to py.test")] + user_options = [('pytest-args=', 'a', + "String of arguments to pass to py.test")] def initialize_options(self): TestCommand.initialize_options(self) - self.pytest_args = '--cov-report term-missing --cov=phy phy tests' + self.pytest_args = '--cov-report term-missing --cov=phy phy' def finalize_options(self): TestCommand.finalize_options(self) diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index 5788a4a03..000000000 --- a/tests/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# -*- coding: utf-8 -*- -"""Integration and functional tests for phy.""" diff --git a/tests/scripts/test_phy_spikesort.py b/tests/scripts/test_phy_spikesort.py deleted file mode 100644 index df4c5520f..000000000 --- a/tests/scripts/test_phy_spikesort.py +++ /dev/null @@ -1,39 +0,0 @@ -# -*- coding: utf-8 -*-1 - -"""Tests of phy spike sorting commands.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -from phy.scripts import main - - -#------------------------------------------------------------------------------ -# Tests -#------------------------------------------------------------------------------ - -def test_version(): - main('-v') - - -def test_cluster_auto_prm(chdir_tempdir): - main('download hybrid_10sec.dat') - main('download hybrid_10sec.prm') - main('detect hybrid_10sec.prm') - main('cluster-auto hybrid_10sec.prm --channel-group=0') - - -def test_quick_start(chdir_tempdir): - main('download hybrid_10sec.dat') - main('download hybrid_10sec.prm') - main('spikesort hybrid_10sec.prm') - # TODO: implement auto-close - # main('cluster-manual hybrid_10sec.kwik') - - -# def test_traces(chdir_tempdir): - # TODO: implement auto-close - # main('download hybrid_10sec.dat') - # main('traces --n-channels=32 --dtype=int16 ' - # '--sample-rate=20000 --interval=0,3 hybrid_10sec.dat') From 6e4bc4edb6871dd70828b4ccbde5c67faad860d2 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 22 Sep 2015 13:11:45 +0200 Subject: [PATCH 0014/1059] Update .gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 4da4f02fe..c6acc7356 100644 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,7 @@ wiki .*fuse* *.orig .eggs +.profile __pycache__ _old *.py[cod] From 5a5ba8bbd484f427260f699101e5e754e4a6c5d1 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 22 Sep 2015 13:45:36 +0200 Subject: [PATCH 0015/1059] Increase coverage in color module --- phy/utils/tests/test_color.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/phy/utils/tests/test_color.py b/phy/utils/tests/test_color.py index 94eb4ccf8..ddbf88fc2 100644 --- a/phy/utils/tests/test_color.py +++ b/phy/utils/tests/test_color.py @@ -8,7 +8,9 @@ from pytest import mark -from .._color import _random_color, _is_bright, _random_bright_color +from .._color import (_random_color, _is_bright, _random_bright_color, + _selected_clusters_colors, + ) from ..testing import show_colored_canvas @@ -24,4 +26,11 @@ def test_random_color(): color = _random_color() show_colored_canvas(color) - assert _is_bright(_random_bright_color()) + for _ in range(10): + assert _is_bright(_random_bright_color()) + + +def test_selected_clusters_colors(): + assert _selected_clusters_colors().ndim == 2 + assert len(_selected_clusters_colors(3)) == 3 + assert len(_selected_clusters_colors(10)) == 10 From fedff54e15c0e5b347d96c202d5e827da28477f8 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 22 Sep 2015 13:49:12 +0200 Subject: [PATCH 0016/1059] Fix random seed in tests --- conftest.py | 4 ++++ phy/utils/_color.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/conftest.py b/conftest.py index 54053af9d..1e56705d9 100644 --- a/conftest.py +++ b/conftest.py @@ -7,6 +7,7 @@ #------------------------------------------------------------------------------ import logging +import numpy as np import os from pytest import yield_fixture @@ -22,6 +23,9 @@ logging.getLogger().setLevel(logging.DEBUG) add_default_handler('DEBUG') +# Fix the random seed in the tests. +np.random.seed(2015) + @yield_fixture def tempdir(): diff --git a/phy/utils/_color.py b/phy/utils/_color.py index 6edc3cc0a..da2dd1e39 100644 --- a/phy/utils/_color.py +++ b/phy/utils/_color.py @@ -8,7 +8,7 @@ import numpy as np -from random import uniform +from numpy.random import uniform from colorsys import hsv_to_rgb From 100cbb896857ed5953063ce808989da202da1dc1 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 22 Sep 2015 13:54:30 +0200 Subject: [PATCH 0017/1059] Add QByteArray tests --- phy/utils/_misc.py | 19 +------------------ phy/utils/tests/test_misc.py | 19 ++++++++++++++++++- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/phy/utils/_misc.py b/phy/utils/_misc.py index db70effbf..e7420cc96 100644 --- a/phy/utils/_misc.py +++ b/phy/utils/_misc.py @@ -17,28 +17,11 @@ import numpy as np from six import string_types, exec_ -from six.moves import builtins, cPickle +from six.moves import builtins from ._types import _is_integer -#------------------------------------------------------------------------------ -# Pickle utility functions -#------------------------------------------------------------------------------ - -def _load_pickle(path): - path = op.realpath(op.expanduser(path)) - assert op.exists(path) - with open(path, 'rb') as f: - return cPickle.load(f) - - -def _save_pickle(path, data): - path = op.realpath(op.expanduser(path)) - with open(path, 'wb') as f: - cPickle.dump(data, f, protocol=2) - - #------------------------------------------------------------------------------ # JSON utility functions #------------------------------------------------------------------------------ diff --git a/phy/utils/tests/test_misc.py b/phy/utils/tests/test_misc.py index a40488136..b77775d6a 100644 --- a/phy/utils/tests/test_misc.py +++ b/phy/utils/tests/test_misc.py @@ -13,8 +13,11 @@ import numpy as np from numpy.testing import assert_array_equal as ae from pytest import raises +from six import string_types -from .._misc import _git_version, _load_json, _save_json +from .._misc import (_git_version, _load_json, _save_json, + _encode_qbytearray, _decode_qbytearray, + ) #------------------------------------------------------------------------------ @@ -53,6 +56,20 @@ def test_json_numpy(tempdir): assert d['b'] == d_bis['b'] +def test_qbytearray(): + + from phy.gui.qt import QtCore + arr = QtCore.QByteArray() + arr.append('1') + arr.append('2') + arr.append('3') + + encoded = _encode_qbytearray(arr) + assert isinstance(encoded, string_types) + decoded = _decode_qbytearray(encoded) + assert arr == decoded + + def test_git_version(): v = _git_version() From 9fe6ecd8aeefc9334cf4b5fa76f2dad687b3c999 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 22 Sep 2015 13:59:21 +0200 Subject: [PATCH 0018/1059] WIP: increase coverage of utils._misc --- phy/utils/_misc.py | 2 +- phy/utils/tests/test_misc.py | 37 +++++++++++++++++++++--------------- 2 files changed, 23 insertions(+), 16 deletions(-) diff --git a/phy/utils/_misc.py b/phy/utils/_misc.py index e7420cc96..d99c54bfe 100644 --- a/phy/utils/_misc.py +++ b/phy/utils/_misc.py @@ -51,7 +51,7 @@ def default(self, obj): return {'__qbytearray__': _encode_qbytearray(obj)} elif isinstance(obj, np.generic): return np.asscalar(obj) - return super(_CustomEncoder, self).default(obj) + return super(_CustomEncoder, self).default(obj) # pragma: no cover def _json_custom_hook(d): diff --git a/phy/utils/tests/test_misc.py b/phy/utils/tests/test_misc.py index b77775d6a..1c3cb2242 100644 --- a/phy/utils/tests/test_misc.py +++ b/phy/utils/tests/test_misc.py @@ -24,8 +24,29 @@ # Tests #------------------------------------------------------------------------------ +def test_qbytearray(tempdir): + + from phy.gui.qt import QtCore + arr = QtCore.QByteArray() + arr.append('1') + arr.append('2') + arr.append('3') + + encoded = _encode_qbytearray(arr) + assert isinstance(encoded, string_types) + decoded = _decode_qbytearray(encoded) + assert arr == decoded + + # Test JSON serialization of QByteArray. + d = {'arr': arr} + path = op.join(tempdir, 'test') + _save_json(path, d) + d_bis = _load_json(path) + assert d == d_bis + + def test_json_simple(tempdir): - d = {'a': 1, 'b': 'bb', 3: '33'} + d = {'a': 1, 'b': 'bb', 3: '33', 'mock': {'mock': True}} path = op.join(tempdir, 'test') _save_json(path, d) @@ -56,20 +77,6 @@ def test_json_numpy(tempdir): assert d['b'] == d_bis['b'] -def test_qbytearray(): - - from phy.gui.qt import QtCore - arr = QtCore.QByteArray() - arr.append('1') - arr.append('2') - arr.append('3') - - encoded = _encode_qbytearray(arr) - assert isinstance(encoded, string_types) - decoded = _decode_qbytearray(encoded) - assert arr == decoded - - def test_git_version(): v = _git_version() From 4b1a708456a764fa6778d9ee8e0c48dfca6ad940 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 22 Sep 2015 14:03:39 +0200 Subject: [PATCH 0019/1059] WIP: increase coverage --- phy/gui/base.py | 18 +++++++++++++++++- phy/utils/_misc.py | 38 ++------------------------------------ 2 files changed, 19 insertions(+), 37 deletions(-) diff --git a/phy/gui/base.py b/phy/gui/base.py index 3cc3e6ab6..cba29ddab 100644 --- a/phy/gui/base.py +++ b/phy/gui/base.py @@ -13,7 +13,6 @@ from six import string_types, PY3 -from ..utils._misc import _show_shortcuts from ..utils import EventEmitter from ._utils import _read from .dock import DockWindow @@ -291,6 +290,23 @@ def _assert_counters_equal(c_0, c_1): assert c_0 == c_1 +def _show_shortcut(shortcut): + if isinstance(shortcut, string_types): + return shortcut + elif isinstance(shortcut, tuple): + return ', '.join(shortcut) + + +def _show_shortcuts(shortcuts, name=''): + print() + if name: + name = ' for ' + name + print('Keyboard shortcuts' + name) + for name in sorted(shortcuts): + print('{0:<40}: {1:s}'.format(name, _show_shortcut(shortcuts[name]))) + print() + + class BaseGUI(EventEmitter): """Base GUI. diff --git a/phy/utils/_misc.py b/phy/utils/_misc.py index d99c54bfe..a2c4f075c 100644 --- a/phy/utils/_misc.py +++ b/phy/utils/_misc.py @@ -13,7 +13,6 @@ import os import sys import subprocess -from inspect import getargspec import numpy as np from six import string_types, exec_ @@ -118,23 +117,7 @@ def _read_python(path): return metadata -def _fun_arg_count(f): - """Return the number of arguments of a function. - - WARNING: with methods, only works if the first argument is named 'self'. - - """ - args = getargspec(f).args - if args and args[0] == 'self': - args = args[1:] - return len(args) - - -def _is_in_ipython(): - return '__IPYTHON__' in dir(builtins) - - -def _is_interactive(): +def _is_interactive(): # pragma: no cover """Determine whether the user has requested interactive mode.""" # The Python interpreter sets sys.flags correctly, so use them! if sys.flags.interactive: @@ -154,23 +137,6 @@ def _is_interactive(): return False -def _show_shortcut(shortcut): - if isinstance(shortcut, string_types): - return shortcut - elif isinstance(shortcut, tuple): - return ', '.join(shortcut) - - -def _show_shortcuts(shortcuts, name=''): - print() - if name: - name = ' for ' + name - print('Keyboard shortcuts' + name) - for name in sorted(shortcuts): - print('{0:<40}: {1:s}'.format(name, _show_shortcut(shortcuts[name]))) - print() - - def _git_version(): curdir = os.getcwd() filedir, _ = op.split(__file__) @@ -182,7 +148,7 @@ def _git_version(): '--always', '--tags'], stderr=fnull).strip().decode('ascii')) return version - except (OSError, subprocess.CalledProcessError): + except (OSError, subprocess.CalledProcessError): # pragma: no cover return "" finally: os.chdir(curdir) From 7e74a5ddd12eb8877ebd5fe277ad61775fe40a10 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 22 Sep 2015 14:17:46 +0200 Subject: [PATCH 0020/1059] WIP: increase coverage in utils --- phy/utils/_types.py | 17 +++++----- phy/utils/tests/test_array.py | 12 +------ phy/utils/tests/test_types.py | 59 +++++++++++++++++++++++++++++++++-- 3 files changed, 64 insertions(+), 24 deletions(-) diff --git a/phy/utils/_types.py b/phy/utils/_types.py index 8b1f11fb1..f68432b39 100644 --- a/phy/utils/_types.py +++ b/phy/utils/_types.py @@ -43,17 +43,14 @@ def _is_float(x): return isinstance(x, (float, np.float32, np.float64)) -def _as_int(x): - if isinstance(x, integer_types): - return x - x = np.asscalar(x) - return x - - def _as_list(obj): """Ensure an object is a list.""" - if isinstance(obj, string_types): + if obj is None: + return None + elif isinstance(obj, string_types): return [obj] + elif isinstance(obj, tuple): + return list(obj) elif not hasattr(obj, '__len__'): return [obj] else: @@ -70,6 +67,8 @@ def _as_array(arr, dtype=None): Avoid a copy if possible. """ + if arr is None: + return None if isinstance(arr, np.ndarray) and dtype is None: return arr if isinstance(arr, integer_types + (float,)): @@ -88,8 +87,6 @@ def _as_tuple(item): """Ensure an item is a tuple.""" if item is None: return None - # elif hasattr(item, '__len__'): - # return tuple(item) elif not isinstance(item, tuple): return (item,) else: diff --git a/phy/utils/tests/test_array.py b/phy/utils/tests/test_array.py index c7688ed75..52c042dd7 100644 --- a/phy/utils/tests/test_array.py +++ b/phy/utils/tests/test_array.py @@ -12,7 +12,7 @@ import numpy as np from pytest import raises, mark -from .._types import _as_array, _as_tuple +from .._types import _as_array from ..array import (_unique, _normalize, _index_of, @@ -167,16 +167,6 @@ def test_index_of(): ae(_index_of(arr, lookup), [1, 2, 2, 1, 1, 0, 2]) -def test_as_tuple(): - assert _as_tuple(3) == (3,) - assert _as_tuple((3,)) == (3,) - assert _as_tuple(None) is None - assert _as_tuple((None,)) == (None,) - assert _as_tuple((3, 4)) == (3, 4) - assert _as_tuple([3]) == ([3], ) - assert _as_tuple([3, 4]) == ([3, 4], ) - - def test_len_index(): arr = np.arange(10) diff --git a/phy/utils/tests/test_types.py b/phy/utils/tests/test_types.py index bbb7fdf3d..517e7a211 100644 --- a/phy/utils/tests/test_types.py +++ b/phy/utils/tests/test_types.py @@ -8,7 +8,9 @@ import numpy as np -from .._types import Bunch, _is_integer +from .._types import (Bunch, _is_integer, _is_list, _is_float, + _as_list, _is_array_like, _as_array, _as_tuple, + ) #------------------------------------------------------------------------------ @@ -21,9 +23,60 @@ def test_bunch(): assert obj.a == 1 obj.b = 2 assert obj['b'] == 2 + assert obj.copy() == obj -def test_integer(): +def test_number(): + assert not _is_integer(None) + assert not _is_integer(3.) assert _is_integer(3) assert _is_integer(np.arange(1)[0]) - assert not _is_integer(3.) + + assert not _is_float(None) + assert not _is_float(3) + assert not _is_float(np.array([3])[0]) + assert _is_float(3.) + assert _is_float(np.array([3.])[0]) + + +def test_list(): + assert not _is_list(None) + assert not _is_list(()) + assert _is_list([]) + + assert _as_list(None) is None + assert _as_list(3) == [3] + assert _as_list([3]) == [3] + assert _as_list((3,)) == [3] + assert _as_list('3') == ['3'] + assert np.all(_as_list(np.array([3])) == np.array([3])) + + +def test_as_tuple(): + assert _as_tuple(3) == (3,) + assert _as_tuple((3,)) == (3,) + assert _as_tuple(None) is None + assert _as_tuple((None,)) == (None,) + assert _as_tuple((3, 4)) == (3, 4) + assert _as_tuple([3]) == ([3], ) + assert _as_tuple([3, 4]) == ([3, 4], ) + + +def test_array(): + def _check(arr): + assert isinstance(arr, np.ndarray) + assert np.all(arr == [3]) + + _check(_as_array(3)) + _check(_as_array(3.)) + _check(_as_array([3])) + + _check(_as_array(3, np.float)) + _check(_as_array(3., np.float)) + _check(_as_array([3], np.float)) + + assert _as_array(None) is None + assert not _is_array_like(None) + assert not _is_array_like(3) + assert _is_array_like([3]) + assert _is_array_like(np.array([3])) From 86cd53a3b311d4355ebf161db17cfcd686609897 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 22 Sep 2015 14:34:57 +0200 Subject: [PATCH 0021/1059] Remove virtual array functions --- phy/io/traces.py | 19 +++-- phy/utils/array.py | 130 ----------------------------- phy/utils/tests/test_array.py | 148 ---------------------------------- 3 files changed, 13 insertions(+), 284 deletions(-) diff --git a/phy/io/traces.py b/phy/io/traces.py index e61651c90..990221834 100644 --- a/phy/io/traces.py +++ b/phy/io/traces.py @@ -10,8 +10,6 @@ import numpy as np -from ..utils.array import _concatenate_virtual_arrays - #------------------------------------------------------------------------------ # Raw data readers @@ -23,14 +21,23 @@ def read_kwd(kwd_handle): The output is a memory-mapped file. """ + import dask + f = kwd_handle if '/recordings' not in f: return recordings = f.children('/recordings') - traces = [] - for recording in recordings: - traces.append(f.read('/recordings/{}/data'.format(recording))) - return _concatenate_virtual_arrays(traces) + + def _read(idx): + return f.read('/recordings/{}/data'.format(recordings[idx])) + + dsk = {('data', idx): (_read, idx) for idx in range(len(recordings))} + + chunks = (tuple(_read(idx).shape[0] for idx in range(len(recordings))), + tuple(_read(idx).shape[1] for idx in range(len(recordings))), + ) + + return dask.Array(dsk, 'data', chunks) def read_dat(filename, dtype=None, shape=None, offset=0, n_channels=None): diff --git a/phy/utils/array.py b/phy/utils/array.py index f4094774f..8111ebeef 100644 --- a/phy/utils/array.py +++ b/phy/utils/array.py @@ -724,133 +724,3 @@ def subset(self, spike_ids=None, clusters=None, spc=None): _as_array(self._spc[cluster])) arrays_s[cluster] = _as_array(self._arrays[cluster])[spk_rel] return PerClusterData(spc=spc, arrays=arrays_s) - - -# ----------------------------------------------------------------------------- -# PartialArray -# ----------------------------------------------------------------------------- - -class PartialArray(object): - """Proxy to a view of an array, allowing selection along the first - dimensions and fixing the trailing dimensions.""" - def __init__(self, arr, trailing_index=None): - self._arr = arr - self._trailing_index = _as_tuple(trailing_index) - self.shape = _partial_shape(arr.shape, self._trailing_index) - self.dtype = arr.dtype - self.ndim = len(self.shape) - - def __getitem__(self, item): - if self._trailing_index is None: - return self._arr[item] - else: - item = _as_tuple(item) - k = len(item) - n = len(self._arr.shape) - t = len(self._trailing_index) - if k < (n - t): - item += (slice(None, None, None),) * (n - k - t) - item += self._trailing_index - if len(item) != n: - raise ValueError("The array selection is invalid: " - "{0}".format(str(item))) - return self._arr[item] - - def __len__(self): - return self.shape[0] - - -class ConcatenatedArrays(object): - """This object represents a concatenation of several memory-mapped - arrays.""" - def __init__(self, arrs): - assert isinstance(arrs, list) - self.arrs = arrs - self.offsets = np.concatenate([[0], np.cumsum([arr.shape[0] - for arr in arrs])], - axis=0) - self.dtype = arrs[0].dtype if arrs else None - self.shape = (self.offsets[-1],) + arrs[0].shape[1:] - - def _get_recording(self, index): - """Return the recording that contains a given index.""" - assert index >= 0 - recs = np.nonzero((index - self.offsets[:-1]) >= 0)[0] - if len(recs) == 0: - # If the index is greater than the total size, - # return the last recording. - return len(self.arrs) - 1 - # Return the last recording such that the index is greater than - # its offset. - return recs[-1] - - def __getitem__(self, item): - # Get the start and stop indices of the requested item. - start, stop = _start_stop(item) - # Return the concatenation of all arrays. - if start is None and stop is None: - return np.concatenate(self.arrs, axis=0) - if start is None: - start = 0 - if stop is None: - stop = self.offsets[-1] - if stop < 0: - stop = self.offsets[-1] + stop - # Get the recording indices of the first and last item. - rec_start = self._get_recording(start) - rec_stop = self._get_recording(stop) - assert 0 <= rec_start <= rec_stop < len(self.arrs) - # Find the start and stop relative to the arrays. - start_rel = start - self.offsets[rec_start] - stop_rel = stop - self.offsets[rec_stop] - # Single array case. - if rec_start == rec_stop: - # Apply the rest of the index. - return _fill_index(self.arrs[rec_start][start_rel:stop_rel], - item) - chunk_start = self.arrs[rec_start][start_rel:] - chunk_stop = self.arrs[rec_stop][:stop_rel] - # Concatenate all chunks. - l = [chunk_start] - if rec_stop - rec_start >= 2: - logger.warn("Loading a full virtual array: this might be slow " - "and something might be wrong.") - l += [self.arrs[r][...] for r in range(rec_start + 1, - rec_stop)] - l += [chunk_stop] - # Apply the rest of the index. - return _fill_index(np.concatenate(l, axis=0), item) - - def __len__(self): - return self.shape[0] - - -class VirtualMappedArray(object): - """A virtual mapped array that yields null arrays to any selection.""" - def __init__(self, shape, dtype, fill=0): - self.shape = shape - self.dtype = dtype - self.ndim = len(self.shape) - self._fill = fill - - def __getitem__(self, item): - if isinstance(item, integer_types): - return self._fill * np.ones(self.shape[1:], dtype=self.dtype) - else: - assert not isinstance(item, tuple) - n = _len_index(item, max_len=self.shape[0]) - return self._fill * np.ones((n,) + self.shape[1:], - dtype=self.dtype) - - def __len__(self): - return self.shape[0] - - -def _concatenate_virtual_arrays(arrs): - """Return a virtual concatenate of several NumPy arrays.""" - n = len(arrs) - if n == 0: - return None - elif n == 1: - return arrs[0] - return ConcatenatedArrays(arrs) diff --git a/phy/utils/tests/test_array.py b/phy/utils/tests/test_array.py index 52c042dd7..f32ba97aa 100644 --- a/phy/utils/tests/test_array.py +++ b/phy/utils/tests/test_array.py @@ -27,13 +27,9 @@ excerpts, data_chunk, get_excerpts, - PartialArray, - VirtualMappedArray, PerClusterData, - _partial_shape, _range_from_slice, _pad, - _concatenate_virtual_arrays, _load_arrays, _save_arrays, ) @@ -195,31 +191,6 @@ def __getitem__(self, item): _check[start::3, 10] -def test_virtual_mapped_array(): - shape = (10, 2) - dtype = np.float32 - arr = VirtualMappedArray(shape, dtype, 1) - arr_actual = np.ones(shape, dtype=dtype) - - class _Check(object): - def __getitem__(self, item): - ae(arr[item], arr_actual[item]) - - _check = _Check() - - for start in (0, 1, 2): - _check[start] - _check[start:1] - _check[start:2] - _check[start:3] - _check[start:3:2] - _check[start:5] - _check[start:5:2] - _check[start:] - _check[start::2] - _check[start::3] - - def test_as_array(): ae(_as_array(3), [3]) ae(_as_array([3]), [3]) @@ -230,60 +201,6 @@ def test_as_array(): _as_array(map) -def test_concatenate_virtual_arrays(): - arr1 = np.random.rand(5, 2) - arr2 = np.random.rand(4, 2) - - def _concat(*arrs): - return np.concatenate(arrs, axis=0) - - # Single array. - concat = _concatenate_virtual_arrays([arr1]) - ae(concat[:], arr1) - ae(concat[1:], arr1[1:]) - ae(concat[:3], arr1[:3]) - ae(concat[1:4], arr1[1:4]) - - # Two arrays. - concat = _concatenate_virtual_arrays([arr1, arr2]) - # First array. - ae(concat[1:], _concat(arr1[1:], arr2)) - ae(concat[:3], arr1[:3]) - ae(concat[1:4], arr1[1:4]) - # Second array. - ae(concat[5:], arr2) - ae(concat[6:], arr2[1:]) - ae(concat[5:8], arr2[:3]) - ae(concat[7:9], arr2[2:]) - ae(concat[7:12], arr2[2:]) - ae(concat[5:-1], arr2[:-1]) - # Both arrays. - ae(concat[:], _concat(arr1, arr2)) - ae(concat[1:], _concat(arr1[1:], arr2)) - ae(concat[:-1], _concat(arr1, arr2[:-1])) - ae(concat[:9], _concat(arr1, arr2)) - ae(concat[:10], _concat(arr1, arr2)) - ae(concat[:8], _concat(arr1, arr2[:-1])) - ae(concat[1:7], _concat(arr1[1:], arr2[:-2])) - ae(concat[4:7], _concat(arr1[4:], arr2[:-2])) - - # Check second axis. - for idx in (slice(None, None, None), - 0, - 1, - [0], - [1], - [0, 1], - [1, 0], - ): - # First array. - ae(concat[1:4, idx], arr1[1:4, idx]) - # Second array. - ae(concat[6:, idx], arr2[1:, idx]) - # Both arrays. - ae(concat[1:7, idx], _concat(arr1[1:, idx], arr2[:-2, idx])) - - def test_in_polygon(): polygon = [[0, 0], [1, 0], [1, 1], [0, 1], [0, 0]] points = np.random.uniform(size=(100, 2), low=-1, high=1) @@ -533,68 +450,3 @@ def _check(pcd): # From dicts. pcd = PerClusterData(spc=spc, arrays=arrays) _check(pcd) - - -#------------------------------------------------------------------------------ -# Test PartialArray -#------------------------------------------------------------------------------ - -def test_partial_shape(): - - _partial_shape(None, ()) - _partial_shape((), None) - _partial_shape((), ()) - _partial_shape(None, None) - - assert _partial_shape((5, 3), 1) == (5,) - assert _partial_shape((5, 3), (1,)) == (5,) - assert _partial_shape((5, 10, 2), 1) == (5, 10) - with raises(ValueError): - _partial_shape((5, 10, 2), (1, 2)) - assert _partial_shape((5, 10, 3), (1, 2)) == (5,) - assert _partial_shape((5, 10, 3), (slice(None, None, None), 2)) == (5, 10) - assert _partial_shape((5, 10, 3), (slice(1, None, None), 2)) == (5, 9) - assert _partial_shape((5, 10, 3), (slice(1, 5, None), 2)) == (5, 4) - assert _partial_shape((5, 10, 3), (slice(4, None, 3), 2)) == (5, 2) - - -def test_partial_array(): - # 2D array. - arr = np.random.rand(5, 2) - - ae(PartialArray(arr)[:], arr) - - pa = PartialArray(arr, 1) - assert pa.shape == (5,) - ae(pa[0], arr[0, 1]) - ae(pa[0:2], arr[0:2, 1]) - ae(pa[[1, 2]], arr[[1, 2], 1]) - with raises(ValueError): - pa[[1, 2], 0] - - # 3D array. - arr = np.random.rand(5, 3, 2) - - pa = PartialArray(arr, (2, 1)) - assert pa.shape == (5,) - ae(pa[0], arr[0, 2, 1]) - ae(pa[0:2], arr[0:2, 2, 1]) - ae(pa[[1, 2]], arr[[1, 2], 2, 1]) - with raises(ValueError): - pa[[1, 2], 0] - - pa = PartialArray(arr, (1,)) - assert pa.shape == (5, 3) - ae(pa[0, 2], arr[0, 2, 1]) - ae(pa[0:2, 1], arr[0:2, 1, 1]) - ae(pa[[1, 2], 0], arr[[1, 2], 0, 1]) - ae(pa[[1, 2]], arr[[1, 2], :, 1]) - - # Slice and 3D. - arr = np.random.rand(5, 10, 2) - - pa = PartialArray(arr, (slice(1, None, 3), 1)) - assert pa.shape == (5, 3) - ae(pa[0], arr[0, 1::3, 1]) - ae(pa[0:2], arr[0:2, 1::3, 1]) - ae(pa[[1, 2]], arr[[1, 2], 1::3, 1]) From 882e9394963e7aedaa1a9c9f22d48b16ccc980d1 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 22 Sep 2015 14:48:57 +0200 Subject: [PATCH 0022/1059] Remove some array functions --- phy/utils/array.py | 164 +--------------------------------- phy/utils/tests/test_array.py | 52 +---------- 2 files changed, 7 insertions(+), 209 deletions(-) diff --git a/phy/utils/array.py b/phy/utils/array.py index 8111ebeef..02fc921a3 100644 --- a/phy/utils/array.py +++ b/phy/utils/array.py @@ -6,18 +6,14 @@ # Imports #------------------------------------------------------------------------------ -from functools import reduce import logging import math from math import floor -from operator import mul -import os import os.path as op import numpy as np -from six import integer_types, string_types -from ._types import _as_tuple, _as_array +from ._types import _as_array logger = logging.getLogger(__name__) @@ -118,6 +114,8 @@ def _index_of(arr, lookup): """ # Equivalent of np.digitize(arr, lookup) - 1, but much faster. # TODO: assertions to disable in production for performance reasons. + # TODO: np.searchsorted(lookup, arr) is faster on small arrays with large + # values m = (lookup.max() if len(lookup) else 0) + 1 tmp = np.zeros(m + 1, dtype=np.int) # Ensure that -1 values are kept. @@ -127,29 +125,6 @@ def _index_of(arr, lookup): return tmp[arr] -def index_of_small(arr, lookup): - """Faster on small arrays with large values.""" - return np.searchsorted(lookup, arr) - - -def _partial_shape(shape, trailing_index): - """Return the shape of a partial array.""" - if shape is None: - shape = () - if trailing_index is None: - trailing_index = () - trailing_index = _as_tuple(trailing_index) - # Length of the selection items for the partial array. - len_item = len(shape) - len(trailing_index) - # Array for the trailing dimensions. - _arr = np.empty(shape=shape[len_item:]) - try: - trailing_arr = _arr[trailing_index] - except IndexError: - raise ValueError("The partial shape index is invalid.") - return shape[:len_item] + trailing_arr.shape - - def _pad(arr, n, dir='right'): """Pad an array with zeros along the first axis. @@ -189,55 +164,6 @@ def _pad(arr, n, dir='right'): return out -def _start_stop(item): - """Find the start and stop indices of a __getitem__ item. - - This is used only by ConcatenatedArrays. - - Only two cases are supported currently: - - * Single integer. - * Contiguous slice in the first dimension only. - - """ - if isinstance(item, tuple): - item = item[0] - if isinstance(item, slice): - # Slice. - if item.step not in (None, 1): - return NotImplementedError() - return item.start, item.stop - elif isinstance(item, (list, np.ndarray)): - # List or array of indices. - return np.min(item), np.max(item) - else: - # Integer. - return item, item + 1 - - -def _len_index(item, max_len=0): - """Return the expected length of the output of __getitem__(item).""" - if isinstance(item, (list, np.ndarray)): - return len(item) - elif isinstance(item, slice): - stop = item.stop or max_len - start = item.start or 0 - step = item.step or 1 - start = np.clip(start, 0, stop) - assert 0 <= start <= stop - return 1 + ((stop - 1 - start) // step) - else: - return 1 - - -def _fill_index(arr, item): - if isinstance(item, tuple): - item = (slice(None, None, None),) + item[1:] - return arr[item] - else: - return arr - - def _in_polygon(points, polygon): """Return the points that are inside a polygon.""" from matplotlib.path import Path @@ -249,94 +175,10 @@ def _in_polygon(points, polygon): return path.contains_points(points) -def _concatenate(arrs): - if arrs is None: - return - arrs = [_as_array(arr) for arr in arrs if arr is not None] - if not arrs: - return - return np.concatenate(arrs, axis=0) - - # ----------------------------------------------------------------------------- # I/O functions # ----------------------------------------------------------------------------- -def _prod(l): - return reduce(mul, l, 1) - - -class LazyMemmap(object): - """A memmapped array that only opens the file handle when required.""" - def __init__(self, path, dtype=None, shape=None, mode='r'): - assert isinstance(path, string_types) - assert dtype - self._path = path - self._f = None - self.mode = mode - self.dtype = dtype - self.shape = shape - self.ndim = len(shape) if shape else None - - def _open_file_if_needed(self): - if self._f is None: - self._f = np.memmap(self._path, - dtype=self.dtype, - shape=self.shape, - mode=self.mode, - ) - self.shape = self._f.shape - self.ndim = self._f.ndim - - def __getitem__(self, item): - self._open_file_if_needed() - return self._f[item] - - def __len__(self): - self._open_file_if_needed() - return len(self._f) - - def __del__(self): - if self._f is not None: - del self._f - - -def _memmap(f_or_path, dtype=None, shape=None, lazy=True): - if not lazy: - return np.memmap(f_or_path, dtype=dtype, shape=shape, mode='r') - else: - return LazyMemmap(f_or_path, dtype=dtype, shape=shape, mode='r') - - -def _file_size(f_or_path): - if isinstance(f_or_path, string_types): - with open(f_or_path, 'rb') as f: - return _file_size(f) - else: - return os.fstat(f_or_path.fileno()).st_size - - -def _load_ndarray(f_or_path, dtype=None, shape=None, memmap=False, lazy=False): - if dtype is None: - return f_or_path - else: - if not memmap: - arr = np.fromfile(f_or_path, dtype=dtype) - if shape is not None: - arr = arr.reshape(shape) - else: - # memmap doesn't accept -1 in shapes, but we can compute - # the missing dimension from the file size. - if shape and shape[0] == -1: - n_bytes = _file_size(f_or_path) - n_items = n_bytes // np.dtype(dtype).itemsize - n_rows = n_items // _prod(shape[1:]) - shape = (n_rows,) + shape[1:] - assert _prod(shape) == n_items - arr = _memmap(f_or_path, dtype=dtype, shape=shape, lazy=lazy) - return arr - - def _save_arrays(path, arrays): """Save multiple arrays in a single file by concatenating them along the first axis. diff --git a/phy/utils/tests/test_array.py b/phy/utils/tests/test_array.py index f32ba97aa..af6f790ef 100644 --- a/phy/utils/tests/test_array.py +++ b/phy/utils/tests/test_array.py @@ -7,7 +7,6 @@ #------------------------------------------------------------------------------ import os.path as op -from itertools import product import numpy as np from pytest import raises, mark @@ -17,8 +16,6 @@ _normalize, _index_of, _in_polygon, - _load_ndarray, - _len_index, _spikes_in_clusters, _spikes_per_cluster, _flatten_spikes_per_cluster, @@ -163,34 +160,6 @@ def test_index_of(): ae(_index_of(arr, lookup), [1, 2, 2, 1, 1, 0, 2]) -def test_len_index(): - arr = np.arange(10) - - class _Check(object): - def __getitem__(self, item): - if isinstance(item, tuple): - item, max_len = item - else: - max_len = None - assert _len_index(item, max_len) == (len(arr[item]) - if hasattr(arr[item], - '__len__') else 1) - - _check = _Check() - - for start in (0, 1, 2): - _check[start] - _check[start:1] - _check[start:2] - _check[start:3] - _check[start:3:2] - _check[start:5] - _check[start:5:2] - _check[start:, 10] - _check[start::2, 10] - _check[start::3, 10] - - def test_as_array(): ae(_as_array(3), [3]) ae(_as_array([3]), [3]) @@ -216,23 +185,6 @@ def test_in_polygon(): # Test I/O functions #------------------------------------------------------------------------------ -@mark.parametrize('memmap,lazy', product([False, True], [False, True])) -def test_load_ndarray(tempdir, memmap, lazy): - n, m = 10000, 100 - dtype = np.float32 - arr = np.random.randn(n, m).astype(dtype) - path = op.join(tempdir, 'test') - with open(path, 'wb') as f: - arr.tofile(f) - arr_m = _load_ndarray(path, - dtype=dtype, - shape=(n, m), - memmap=memmap, - lazy=lazy, - ) - ae(arr, arr_m[...]) - - @mark.parametrize('n', [20, 0]) def test_load_save_arrays(tempdir, n): path = op.join(tempdir, 'test.npy') @@ -317,6 +269,10 @@ def test_get_excerpts(): subdata = get_excerpts(data, n_excerpts=10, excerpt_size=5) ae(subdata, data) + data = np.random.rand(10, 2) + subdata = get_excerpts(data, n_excerpts=1, excerpt_size=10) + ae(subdata, data) + #------------------------------------------------------------------------------ # Test spike clusters functions From 97b3ee19cdd3ba8b2065e53bdc6420c1421a1b1b Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 22 Sep 2015 15:10:00 +0200 Subject: [PATCH 0023/1059] Increase coverage in datasets --- phy/utils/datasets.py | 18 ++++++------------ phy/utils/tests/test_datasets.py | 8 +++++++- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/phy/utils/datasets.py b/phy/utils/datasets.py index c3dc1a49d..392cf28b7 100644 --- a/phy/utils/datasets.py +++ b/phy/utils/datasets.py @@ -112,7 +112,7 @@ def _validate_output_dir(output_dir): return output_dir -def download_file(url, output_path=None): +def download_file(url, output_path): """Download a binary file from an URL. The checksum will be downloaded from `URL + .md5`. If this download @@ -123,18 +123,12 @@ def download_file(url, output_path=None): url : str The file's URL. - output_path : str or None - The path where the file is to be saved. - - Returns - ------- - output_path : str - The path where the file was downloaded. + The path where the file is to be saved. """ - if output_path is None: - output_path = url.split('/')[-1] + output_path = op.realpath(output_path) + assert output_path is not None if op.exists(output_path): checked = _check_md5_of_url(output_path, url) if checked is False: @@ -143,7 +137,7 @@ def download_file(url, output_path=None): elif checked is True: logger.debug("The file `%s` already exists: skipping.", output_path) - return + return output_path r = _download(url, stream=True) _save_stream(r, output_path) if _check_md5_of_url(output_path, url) is False: @@ -153,7 +147,7 @@ def download_file(url, output_path=None): if _check_md5_of_url(output_path, url) is False: raise RuntimeError("The checksum of the downloaded file " "doesn't match the provided checksum.") - return output_path + return def download_test_data(name, phy_user_dir=None, force=False): diff --git a/phy/utils/tests/test_datasets.py b/phy/utils/tests/test_datasets.py index 9e5d0b1c4..0f5d6a7fa 100644 --- a/phy/utils/tests/test_datasets.py +++ b/phy/utils/tests/test_datasets.py @@ -50,7 +50,7 @@ def _add_mock_response(url, body, file_type='binary'): def mock_url(): _add_mock_response(_URL, _DATA.tostring()) _add_mock_response(_URL + '.md5', _CHECKSUM + ' ' + op.basename(_URL)) - yield + yield _URL responses.reset() @@ -81,6 +81,7 @@ def mock_urls(request): def _dl(path): + assert path download_file(_URL, path) with open(path, 'rb') as f: data = f.read() @@ -170,6 +171,9 @@ def test_download_sample_data(tempdir): data = f.read() ae(np.fromstring(data, np.float32), _DATA) + # Warning. + download_sample_data(name + '_', tempdir) + responses.reset() @@ -186,4 +190,6 @@ def test_dat_file(tempdir): arr = np.fromfile(f, dtype=np.int16).reshape((-1, 4)) assert arr.shape == (20000, 4) + assert download_test_data(fn, tempdir) == path + responses.reset() From 2727dcd1c18d9e79c79844f80d83d38603902ab5 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 22 Sep 2015 15:15:04 +0200 Subject: [PATCH 0024/1059] Increase coverage in datasets --- phy/utils/datasets.py | 13 +++++-------- phy/utils/tests/test_datasets.py | 7 +++++++ 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/phy/utils/datasets.py b/phy/utils/datasets.py index 392cf28b7..d5c577b97 100644 --- a/phy/utils/datasets.py +++ b/phy/utils/datasets.py @@ -30,7 +30,7 @@ def _remote_file_size(path): import requests - try: + try: # pragma: no cover response = requests.head(path) return int(response.headers.get('content-length', 0)) except Exception: @@ -54,15 +54,14 @@ def _save_stream(r, path): downloaded += len(chunk) if i % 100 == 0: pr.value = downloaded - if size: - assert size == downloaded + assert ((size == downloaded) if size else True) pr.set_complete() def _download(url, stream=None): from requests import get r = get(url, stream=stream) - if r.status_code != 200: + if r.status_code != 200: # pragma: no cover logger.debug("Error while downloading %s.", url) r.raise_for_status() return r @@ -86,9 +85,7 @@ def _md5(path, blocksize=2 ** 20): def _check_md5(path, checksum): - if checksum is None: - return - return _md5(path) == checksum + return (_md5(path) == checksum) if checksum else None def _check_md5_of_url(output_path, url): @@ -108,7 +105,7 @@ def _validate_output_dir(output_dir): output_dir = output_dir + '/' output_dir = op.realpath(op.dirname(output_dir)) if not op.exists(output_dir): - os.mkdir(output_dir) + os.makedirs(output_dir) return output_dir diff --git a/phy/utils/tests/test_datasets.py b/phy/utils/tests/test_datasets.py index 0f5d6a7fa..7e86ab552 100644 --- a/phy/utils/tests/test_datasets.py +++ b/phy/utils/tests/test_datasets.py @@ -21,6 +21,7 @@ download_sample_data, _check_md5_of_url, _BASE_URL, + _validate_output_dir, ) logger = logging.getLogger(__name__) @@ -96,6 +97,12 @@ def _check(data): # Test utility functions #------------------------------------------------------------------------------ +def test_validate_output_dir(chdir_tempdir): + _validate_output_dir(None) + _validate_output_dir(op.join(chdir_tempdir, 'a/b/c')) + assert op.exists(op.join(chdir_tempdir, 'a/b/c/')) + + @responses.activate def test_check_md5_of_url(tempdir, mock_url): output_path = op.join(tempdir, 'data') From cef4c1a3305ba7e98c2212f096986cd8d2425bc0 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 22 Sep 2015 15:21:27 +0200 Subject: [PATCH 0025/1059] Increase coverage in event --- phy/utils/event.py | 2 +- phy/utils/tests/test_event.py | 13 +++++++++++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/phy/utils/event.py b/phy/utils/event.py index 47250a561..04e799ac9 100644 --- a/phy/utils/event.py +++ b/phy/utils/event.py @@ -153,7 +153,7 @@ def format_field(self, value, spec): def _default_on_progress(message, value, value_max, end='\r', **kwargs): - if value_max == 0: + if value_max == 0: # pragma: no cover return if value <= value_max: progress = 100 * value / float(value_max) diff --git a/phy/utils/tests/test_event.py b/phy/utils/tests/test_event.py index 1b0d19d2f..8a2d19aa2 100644 --- a/phy/utils/tests/test_event.py +++ b/phy/utils/tests/test_event.py @@ -20,6 +20,9 @@ def test_event_system(): _list = [] + with raises(ValueError): + ev.connect(lambda x: x) + @ev.connect(set_method=True) def on_my_event(arg, kwarg=None): _list.append((arg, kwarg)) @@ -62,6 +65,7 @@ def on_complete(): pr.value_max = 10 pr.value = 0 pr.value = 5 + assert pr.value == 5 assert pr.progress == .5 assert not pr.is_complete() pr.value = 10 @@ -89,7 +93,8 @@ def on_complete(): def test_progress_message(): """Test messages with the progress reporter.""" pr = ProgressReporter() - pr.set_progress_message("The progress is {progress}%. ({hello})") + pr.reset(5) + pr.set_progress_message("The progress is {progress}%. ({hello:d})") pr.set_complete_message("Finished {hello}.") pr.value_max = 10 @@ -97,6 +102,10 @@ def test_progress_message(): print() pr.value = 5 print() - pr.increment(hello='hello world') + pr.increment() + print() + pr.increment(hello='hello') + print() + pr.increment(hello=3) print() pr.value = 10 From 6db0d48acc6cb542d28ad0e627e4cea4d9037d1b Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 22 Sep 2015 16:06:12 +0200 Subject: [PATCH 0026/1059] WIP: improve settings --- phy/utils/settings.py | 52 +++++++++++---------------- phy/utils/tests/test_settings.py | 62 +++++++++++++++++++++++++------- 2 files changed, 71 insertions(+), 43 deletions(-) diff --git a/phy/utils/settings.py b/phy/utils/settings.py index 92d799275..112ed397e 100644 --- a/phy/utils/settings.py +++ b/phy/utils/settings.py @@ -10,6 +10,8 @@ import os import os.path as op +from six import string_types + from ._misc import _load_json, _save_json, _read_python logger = logging.getLogger(__name__) @@ -89,43 +91,34 @@ def _update(self, d): else: self._store[k] = v - def _try_load_json(self, path): - try: - self._update(_load_json(path)) - return True - except Exception: - return False - - def _try_load_python(self, path): - try: - self._update(_read_python(path)) - return True - except Exception: - return False - def load(self, path): """Load a settings file.""" + if not isinstance(path, string_types): + logger.warn("The settings file `%s` is invalid.", path) + return path = op.realpath(path) if not op.exists(path): - logger.debug("Settings file `{}` doesn't exist.".format(path)) + logger.debug("The settings file `%s` doesn't exist.", path) return - # Try JSON first, then Python. - has_ext = op.splitext(path)[1] != '' - if not has_ext: - if self._try_load_json(path): - return - if not self._try_load_python(path): - logger.warn("Unable to read '%s'. " - "Please try to delete this file.", path) + try: + if op.splitext(path)[1] == '.py': + self._update(_read_python(path)) + elif op.splitext(path)[1] == '.json': + self._update(_load_json(path)) + else: + logger.warn("The settings file %s must have the extension " + "'.py' or '.json'.", path) + except Exception as e: + logger.warn("Unable to read %s. " + "Please try to delete this file. %s", path, str(e)) def save(self, path): """Save the settings to a JSON file.""" path = op.realpath(path) try: _save_json(path, self._to_save) - logger.debug("Saved internal settings file " - "to `%s`.", path) - except Exception as e: + logger.debug("Saved internal settings file to `%s`.", path) + except Exception as e: # pragma: no cover logger.warn("Unable to save the internal settings file " "to `%s`:\n%s", path, str(e)) self._to_save = {} @@ -163,15 +156,12 @@ def _load_user_settings(self): # Load the user's internal settings. self.internal_settings_path = op.join(self.phy_user_dir, - 'internal_settings') + 'internal_settings.json') self._bs.load(self.internal_settings_path) def on_open(self, path): """Initialize settings when loading an experiment.""" - if path is None: - logger.debug("Unable to initialize the settings for unspecified " - "model path.") - return + assert path is not None # Get the experiment settings path. path = op.realpath(op.expanduser(path)) self.exp_path = path diff --git a/phy/utils/tests/test_settings.py b/phy/utils/tests/test_settings.py index 997d2e57c..e1bd07fbc 100644 --- a/phy/utils/tests/test_settings.py +++ b/phy/utils/tests/test_settings.py @@ -8,19 +8,36 @@ import os.path as op -from pytest import raises +from pytest import raises, yield_fixture, mark from ..settings import (BaseSettings, Settings, _load_default_settings, _recursive_dirs, + _phy_user_dir, ) +#------------------------------------------------------------------------------ +# Fixtures +#------------------------------------------------------------------------------ + +@yield_fixture(params=['py', 'json']) +def settings(request): + if request.param == 'py': + yield ('py', '''a = 4\nb = 5\nd = {'k1': 2, 'k2': '3'}''') + elif request.param == 'json': + yield ('json', '''{"a": 4, "b": 5, "d": {"k1": 2, "k2": "3"}}''') + + #------------------------------------------------------------------------------ # Test settings #------------------------------------------------------------------------------ +def test_phy_user_dir(): + assert op.exists(_phy_user_dir()) + + def test_recursive_dirs(): dirs = list(_recursive_dirs()) assert len(dirs) >= 5 @@ -53,13 +70,11 @@ def test_base_settings(): assert s['a'] == 3 -def test_user_settings(tempdir): - path = op.join(tempdir, 'test.py') - - # Create a simple settings file. - contents = '''a = 4\nb = 5\nd = {'k1': 2, 'k2': 3}\n''' +def test_base_settings_file(tempdir, settings): + ext, settings = settings + path = op.join(tempdir, 'test.' + ext) with open(path, 'w') as f: - f.write(contents) + f.write(settings) s = BaseSettings() @@ -67,21 +82,35 @@ def test_user_settings(tempdir): s['c'] = 6 assert s['a'] == 3 - # Now, set the settings file. + s.load(path=None) + + # Now, load the settings file. s.load(path=path) assert s['a'] == 4 assert s['b'] == 5 assert s['c'] == 6 - assert s['d'] == {'k1': 2, 'k2': 3} + assert s['d'] == {'k1': 2, 'k2': '3'} s = BaseSettings() s['d'] = {'k2': 30, 'k3': 40} s.load(path=path) - assert s['d'] == {'k1': 2, 'k2': 3, 'k3': 40} + assert s['d'] == {'k1': 2, 'k2': '3', 'k3': 40} + + +def test_base_settings_invalid(tempdir, settings): + ext, settings = settings + settings = settings[:-2] + path = op.join(tempdir, 'test.' + ext) + with open(path, 'w') as f: + f.write(settings) + + s = BaseSettings() + s.load(path) + assert 'a' not in s def test_internal_settings(tempdir): - path = op.join(tempdir, 'test') + path = op.join(tempdir, 'test.json') s = BaseSettings() @@ -104,6 +133,10 @@ def test_internal_settings(tempdir): assert s['c'] == 6 +def test_settings_nodir(): + Settings() + + def test_settings_manager(tempdir, tempdir_bis): tempdir_exp = tempdir_bis sm = Settings(tempdir) @@ -111,15 +144,17 @@ def test_settings_manager(tempdir, tempdir_bis): # Check paths. assert sm.phy_user_dir == tempdir assert sm.internal_settings_path == op.join(tempdir, - 'internal_settings') + 'internal_settings.json') assert sm.user_settings_path == op.join(tempdir, 'user_settings.py') # User settings. with raises(KeyError): sm['a'] + assert sm.get('a', None) is None # Artificially populate the user settings. sm._bs._store['a'] = 3 assert sm['a'] == 3 + assert sm.get('a') == 3 # Internal settings. sm['c'] = 5 @@ -151,3 +186,6 @@ def test_settings_manager(tempdir, tempdir_bis): sm.on_open(path) assert sm['c'] == 50 assert 'a' not in sm + + assert len(sm.keys()) >= 10 + assert str(sm).startswith(' Date: Tue, 22 Sep 2015 16:08:50 +0200 Subject: [PATCH 0027/1059] Increase coverage in settings --- phy/utils/tests/test_settings.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/phy/utils/tests/test_settings.py b/phy/utils/tests/test_settings.py index e1bd07fbc..44f6bc351 100644 --- a/phy/utils/tests/test_settings.py +++ b/phy/utils/tests/test_settings.py @@ -70,6 +70,14 @@ def test_base_settings(): assert s['a'] == 3 +def test_base_settings_wrong_extension(tempdir): + path = op.join(tempdir, 'test') + with open(path, 'w'): + pass + s = BaseSettings() + s.load(path=path) + + def test_base_settings_file(tempdir, settings): ext, settings = settings path = op.join(tempdir, 'test.' + ext) @@ -82,6 +90,7 @@ def test_base_settings_file(tempdir, settings): s['c'] = 6 assert s['a'] == 3 + # Warning: wrong path. s.load(path=None) # Now, load the settings file. From b97b2679fa13d4eba727300f28e0e08df728b45c Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 22 Sep 2015 16:38:50 +0200 Subject: [PATCH 0028/1059] WIP: increase coverage --- phy/utils/tests/test_misc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/phy/utils/tests/test_misc.py b/phy/utils/tests/test_misc.py index 1c3cb2242..72cdec267 100644 --- a/phy/utils/tests/test_misc.py +++ b/phy/utils/tests/test_misc.py @@ -88,5 +88,5 @@ def test_git_version(): stderr=fnull) assert v is not "", "git_version failed to return" assert v[:6] == "-git-v", "Git version does not begin in -git-v" - except (OSError, subprocess.CalledProcessError): + except (OSError, subprocess.CalledProcessError): # pragma: no cover assert v == "" From 1cdc33a36328a42e55df5a4ca27497e932f2f181 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 22 Sep 2015 16:41:40 +0200 Subject: [PATCH 0029/1059] WIP: increase coverage in testing --- phy/utils/testing.py | 7 +++---- phy/utils/tests/test_testing.py | 7 +++++++ 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/phy/utils/testing.py b/phy/utils/testing.py index 710b8101d..19d0a14b8 100644 --- a/phy/utils/testing.py +++ b/phy/utils/testing.py @@ -52,10 +52,9 @@ def _assert_equal(d_0, d_1): ac(d_0, d_1) # Compare dicts recursively. elif isinstance(d_0, dict): - assert sorted(d_0) == sorted(d_1) - for (k_0, k_1) in zip(sorted(d_0), sorted(d_1)): - assert k_0 == k_1 - _assert_equal(d_0[k_0], d_1[k_1]) + assert set(d_0) == set(d_1) + for k_0 in d_0: + _assert_equal(d_0[k_0], d_1[k_0]) else: # General comparison. assert d_0 == d_1 diff --git a/phy/utils/tests/test_testing.py b/phy/utils/tests/test_testing.py index 28e2c99f0..61f43846f 100644 --- a/phy/utils/tests/test_testing.py +++ b/phy/utils/tests/test_testing.py @@ -8,9 +8,11 @@ import time +import numpy as np from vispy.app import Canvas from ..testing import (benchmark, captured_output, show_test, + _assert_equal, ) @@ -24,6 +26,11 @@ def test_captured_output(): assert out.getvalue().strip() == 'Hello world!' +def test_assert_equal(): + d = {'a': {'b': np.random.rand(5), 3: 'c'}, 'b': 2.} + _assert_equal(d, d.copy()) + + def test_benchmark(): with benchmark(): time.sleep(.002) From 0644a1dfc68455dd1f787088bad6322227e5d591 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 22 Sep 2015 16:46:42 +0200 Subject: [PATCH 0030/1059] WIP: increase testing coverage --- phy/utils/tests/test_testing.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/phy/utils/tests/test_testing.py b/phy/utils/tests/test_testing.py index 61f43846f..ae9bcfd2f 100644 --- a/phy/utils/tests/test_testing.py +++ b/phy/utils/tests/test_testing.py @@ -6,13 +6,16 @@ # Imports #------------------------------------------------------------------------------ +import os.path as op import time import numpy as np +from pytest import mark from vispy.app import Canvas from ..testing import (benchmark, captured_output, show_test, - _assert_equal, + _assert_equal, _enable_profiler, _profile, + show_colored_canvas, ) @@ -36,7 +39,17 @@ def test_benchmark(): time.sleep(.002) +@mark.parametrize('line_by_line', [False, True]) +def test_profile(chdir_tempdir, line_by_line): + prof = _enable_profiler(line_by_line=line_by_line) + _profile(prof, 'import time; time.sleep(.001)', {}, {}) + assert op.exists(op.join(chdir_tempdir, '.profile', 'stats.txt')) + + def test_canvas(): c = Canvas(keys='interactive') - with benchmark(): - show_test(c) + show_test(c) + + +def test_show_colored_canvas(): + show_colored_canvas((.6, 0, .8)) From a3e5af9c26ad7e7db71be82912ff9aa713eda37f Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 22 Sep 2015 17:02:48 +0200 Subject: [PATCH 0031/1059] Increase testing coverage --- phy/utils/testing.py | 10 +++++----- phy/utils/tests/test_testing.py | 21 +++++++++++++-------- 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/phy/utils/testing.py b/phy/utils/testing.py index 19d0a14b8..5f2c41729 100644 --- a/phy/utils/testing.py +++ b/phy/utils/testing.py @@ -72,7 +72,7 @@ def benchmark(name='', repeats=1): logger.info("%s took %.6fms.", name, duration / repeats) -class ContextualProfile(Profile): +class ContextualProfile(Profile): # pragma: no cover def __init__(self, *args, **kwds): super(ContextualProfile, self).__init__(*args, **kwds) self.enable_count = 0 @@ -118,7 +118,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.disable_by_count() -def _enable_profiler(line_by_line=False): +def _enable_profiler(line_by_line=False): # pragma: no cover if 'profile' in builtins.__dict__: return builtins.__dict__['profile'] if line_by_line: @@ -138,13 +138,13 @@ def _profile(prof, statement, glob, loc): # Capture stdout. old_stdout = sys.stdout sys.stdout = output = StringIO() - try: + try: # pragma: no cover from line_profiler import LineProfiler if isinstance(prof, LineProfiler): prof.print_stats() else: prof.print_stats('cumulative') - except ImportError: + except ImportError: # pragma: no cover prof.print_stats('cumulative') sys.stdout = old_stdout stats = output.getvalue() @@ -162,7 +162,7 @@ def show_test(canvas): """Show a VisPy canvas for a fraction of second.""" with canvas: # Interactive mode for tests. - if 'PYTEST_INTERACT' in os.environ: + if 'PYTEST_INTERACT' in os.environ: # pragma: no cover while not canvas._closed: canvas.update() canvas.app.process_events() diff --git a/phy/utils/tests/test_testing.py b/phy/utils/tests/test_testing.py index ae9bcfd2f..d4175f0e3 100644 --- a/phy/utils/tests/test_testing.py +++ b/phy/utils/tests/test_testing.py @@ -6,11 +6,13 @@ # Imports #------------------------------------------------------------------------------ +from copy import deepcopy import os.path as op import time import numpy as np from pytest import mark +from six.moves import builtins from vispy.app import Canvas from ..testing import (benchmark, captured_output, show_test, @@ -31,7 +33,9 @@ def test_captured_output(): def test_assert_equal(): d = {'a': {'b': np.random.rand(5), 3: 'c'}, 'b': 2.} - _assert_equal(d, d.copy()) + d_bis = deepcopy(d) + d_bis['a']['b'] = d_bis['a']['b'] + 1e-8 + _assert_equal(d, d_bis) def test_benchmark(): @@ -39,13 +43,6 @@ def test_benchmark(): time.sleep(.002) -@mark.parametrize('line_by_line', [False, True]) -def test_profile(chdir_tempdir, line_by_line): - prof = _enable_profiler(line_by_line=line_by_line) - _profile(prof, 'import time; time.sleep(.001)', {}, {}) - assert op.exists(op.join(chdir_tempdir, '.profile', 'stats.txt')) - - def test_canvas(): c = Canvas(keys='interactive') show_test(c) @@ -53,3 +50,11 @@ def test_canvas(): def test_show_colored_canvas(): show_colored_canvas((.6, 0, .8)) + + +@mark.parametrize('line_by_line', [False, True]) +def test_profile(chdir_tempdir, line_by_line): + # Remove the profile from the builtins. + prof = _enable_profiler(line_by_line=line_by_line) + _profile(prof, 'import time; time.sleep(.001)', {}, {}) + assert op.exists(op.join(chdir_tempdir, '.profile', 'stats.txt')) From a42f1c4c226dc1ff1946e7392548d44b8777fafb Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 22 Sep 2015 17:09:35 +0200 Subject: [PATCH 0032/1059] WIP: increase coverage in phy.traces --- phy/traces/detect.py | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/phy/traces/detect.py b/phy/traces/detect.py index 44af233b0..5f0970abd 100644 --- a/phy/traces/detect.py +++ b/phy/traces/detect.py @@ -101,8 +101,7 @@ def __init__(self, assert mode in ('positive', 'negative', 'both') if isinstance(thresholds, (float, int, np.ndarray)): thresholds = {'default': thresholds} - if thresholds is None: - thresholds = {} + thresholds = thresholds if thresholds is not None else {} assert isinstance(thresholds, dict) self._mode = mode self._thresholds = thresholds @@ -144,14 +143,6 @@ def __call__(self, data, threshold=None): # Connected components # ----------------------------------------------------------------------------- -def _to_tuples(x): - return ((i, j) for (i, j) in x) - - -def _to_list(x): - return [(i, j) for (i, j) in x] - - def connected_components(weak_crossings=None, strong_crossings=None, probe_adjacency_list=None, @@ -187,8 +178,7 @@ def connected_components(weak_crossings=None, """ - if probe_adjacency_list is None: - probe_adjacency_list = {} + probe_adjacency_list = probe_adjacency_list or {} # Make sure the values are sets. probe_adjacency_list = {c: set(cs) @@ -242,7 +232,7 @@ def connected_components(weak_crossings=None, for j_s in range(i_s - join_size, i_s + 1): # Allow us to leave out a channel from the graph to exclude bad # channels - if i_ch not in mgraph: + if i_ch not in mgraph: # pragma: no cover continue for j_ch in mgraph[i_ch]: # Label of the adjacent element. From 707a83d58134ba038b7dc31e309061c2ad235d9e Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 22 Sep 2015 17:54:45 +0200 Subject: [PATCH 0033/1059] WIP --- phy/traces/pca.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/phy/traces/pca.py b/phy/traces/pca.py index e9ff85093..f7e962dd5 100644 --- a/phy/traces/pca.py +++ b/phy/traces/pca.py @@ -58,7 +58,7 @@ def _compute_pcs(x, n_pcs=None, masks=None): # Don't compute the cov matrix if there are no unmasked spikes # on that channel. alpha = 1. / n_spikes - if x_channel.shape[0] <= 1: + if x_channel.shape[0] <= 1: # pragma: no cover cov = alpha * cov_reg else: cov_channel = np.cov(x_channel, rowvar=0) From 93857ca3dafd4d210d0f4ca748f5d13ffb8a10cd Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 22 Sep 2015 17:55:20 +0200 Subject: [PATCH 0034/1059] Add filter getter --- phy/traces/filter.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/phy/traces/filter.py b/phy/traces/filter.py index 89b748be3..5f38546cb 100644 --- a/phy/traces/filter.py +++ b/phy/traces/filter.py @@ -59,6 +59,18 @@ def __call__(self, data): return apply_filter(data, filter=self._filter) +def _filter_and_margin(**kwargs): + + b_filter = bandpass_filter(**kwargs) + + def filter(x): + return apply_filter(x, b_filter) + + filter_margin = kwargs['order'] * 3 + + return filter, filter_margin + + #------------------------------------------------------------------------------ # Whitening #------------------------------------------------------------------------------ From 121b51acd684bc7b9bf5050e1a9374dd8d7b0dbd Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 22 Sep 2015 19:20:53 +0200 Subject: [PATCH 0035/1059] Arbitrary axis in filter --- phy/traces/filter.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/phy/traces/filter.py b/phy/traces/filter.py index 5f38546cb..cfa7ef16e 100644 --- a/phy/traces/filter.py +++ b/phy/traces/filter.py @@ -25,13 +25,13 @@ def bandpass_filter(rate=None, low=None, high=None, order=None): 'pass') -def apply_filter(x, filter=None): +def apply_filter(x, filter=None, axis=0): """Apply a filter to an array.""" x = _as_array(x) - if x.shape[0] == 0: + if x.shape[axis] == 0: return x b, a = filter - return signal.filtfilt(b, a, x, axis=0) + return signal.filtfilt(b, a, x, axis=axis) class Filter(object): @@ -63,8 +63,8 @@ def _filter_and_margin(**kwargs): b_filter = bandpass_filter(**kwargs) - def filter(x): - return apply_filter(x, b_filter) + def filter(x, axis=0): + return apply_filter(x, b_filter, axis=axis) filter_margin = kwargs['order'] * 3 From 8f17102b04d60a1816ae1bed52ded372ea57b748 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 22 Sep 2015 20:30:46 +0200 Subject: [PATCH 0036/1059] Update waveform loader --- phy/traces/tests/test_waveform.py | 133 ++++++++++++++++++------------ phy/traces/waveform.py | 78 +++++++++++------- 2 files changed, 125 insertions(+), 86 deletions(-) diff --git a/phy/traces/tests/test_waveform.py b/phy/traces/tests/test_waveform.py index 7c2735046..f9155e45d 100644 --- a/phy/traces/tests/test_waveform.py +++ b/phy/traces/tests/test_waveform.py @@ -6,13 +6,18 @@ # Imports #------------------------------------------------------------------------------ +from itertools import product + import numpy as np from numpy.testing import assert_array_equal as ae import numpy.random as npr -from pytest import raises +from pytest import raises, yield_fixture -from ...io.mock import artificial_traces -from ..waveform import _slice, WaveformLoader, WaveformExtractor +from phy.io.mock import artificial_traces, artificial_spike_samples +from phy.utils import Bunch +from ..waveform import (_slice, WaveformLoader, WaveformExtractor, + SpikeLoader, + ) from ..filter import bandpass_filter, apply_filter @@ -112,47 +117,74 @@ def test_extract_simple(): #------------------------------------------------------------------------------ -# Tests loader +# Tests utility functions #------------------------------------------------------------------------------ def test_slice(): assert _slice(0, (20, 20)) == slice(0, 20, None) -def test_loader(): - n_samples_trace, n_channels = 10000, 100 - n_samples = 40 - n_spikes = n_samples_trace // (2 * n_samples) +#------------------------------------------------------------------------------ +# Tests loader +#------------------------------------------------------------------------------ + +@yield_fixture(params=[(None, None), (-1, 0)]) +def waveform_loader(request): + scale_factor, dc_offset = request.param + + n_samples_trace, n_channels = 1000, 5 + h = 10 + n_samples_waveforms = 2 * h + n_spikes = n_samples_trace // (2 * n_samples_waveforms) traces = artificial_traces(n_samples_trace, n_channels) - spike_samples = np.cumsum(npr.randint(low=0, high=2 * n_samples, - size=n_spikes)) + spike_samples = artificial_spike_samples(n_spikes, + max_isi=2 * n_samples_waveforms) with raises(ValueError): WaveformLoader(traces) - # Create a loader. - loader = WaveformLoader(traces, n_samples=n_samples) - assert id(loader.traces) == id(traces) - loader.traces = traces - - # Extract a waveform. - t = spike_samples[10] - waveform = loader._load_at(t) - assert waveform.shape == (n_samples, n_channels) - ae(waveform, traces[t - 20:t + 20, :]) + loader = WaveformLoader(traces=traces, + n_samples_waveforms=n_samples_waveforms, + scale_factor=scale_factor, + dc_offset=dc_offset, + ) + b = Bunch(loader=loader, + n_samples_waveforms=n_samples_waveforms, + n_spikes=n_spikes, + spike_samples=spike_samples, + ) + yield b + + +def test_loader_simple(waveform_loader): + b = waveform_loader + spike_samples = b.spike_samples + loader = b.loader + traces = loader.traces + dc_offset = loader.dc_offset or 0. + scale_factor = loader.scale_factor or 1 + n_samples_traces, n_channels = traces.shape + n_samples_waveforms = b.n_samples_waveforms + h = n_samples_waveforms // 2 + + def _transform(arr): + return (arr - dc_offset) * scale_factor waveforms = loader[spike_samples[10:20]] - assert waveforms.shape == (10, n_samples, n_channels) + assert waveforms.shape == (10, n_samples_waveforms, n_channels) t = spike_samples[15] w1 = waveforms[5, ...] - w2 = traces[t - 20:t + 20, :] + w2 = _transform(traces[t - h:t + h, :]) assert np.allclose(w1, w2) + sl = SpikeLoader(loader, spike_samples) + assert np.allclose(sl[15], w2) + def test_edges(): - n_samples_trace, n_channels = 1000, 10 - n_samples = 40 + n_samples_trace, n_channels = 100, 10 + n_samples_waveforms = 20 traces = artificial_traces(n_samples_trace, n_channels) @@ -165,7 +197,7 @@ def test_edges(): # Create a loader. loader = WaveformLoader(traces, - n_samples=n_samples, + n_samples_waveforms=n_samples_waveforms, filter=lambda x: apply_filter(x, b_filter), filter_margin=filter_margin) @@ -173,57 +205,50 @@ def test_edges(): with raises(ValueError): loader._load_at(200000) - assert loader._load_at(0).shape == (n_samples, n_channels) - assert loader._load_at(5).shape == (n_samples, n_channels) - assert loader._load_at(n_samples_trace - 5).shape == (n_samples, - n_channels) - assert loader._load_at(n_samples_trace - 1).shape == (n_samples, - n_channels) + ns = n_samples_waveforms + filter_margin + assert loader._load_at(0).shape == (ns, n_channels) + assert loader._load_at(5).shape == (ns, n_channels) + assert loader._load_at(n_samples_trace - 5).shape == (ns, n_channels) + assert loader._load_at(n_samples_trace - 1).shape == (ns, n_channels) def test_loader_channels(): - n_samples_trace, n_channels = 1000, 50 - n_samples = 40 + n_samples_trace, n_channels = 1000, 10 + n_samples_waveforms = 20 traces = artificial_traces(n_samples_trace, n_channels) # Create a loader. - loader = WaveformLoader(traces, n_samples=n_samples) + loader = WaveformLoader(traces, n_samples_waveforms=n_samples_waveforms) loader.traces = traces - channels = [10, 20, 30] + channels = [2, 5, 7] loader.channels = channels assert loader.channels == channels - assert loader[500].shape == (1, n_samples, 3) - assert loader[[500, 501, 600, 300]].shape == (4, n_samples, 3) + assert loader[500].shape == (1, n_samples_waveforms, 3) + assert loader[[500, 501, 600, 300]].shape == (4, n_samples_waveforms, 3) # Test edge effects. - assert loader[3].shape == (1, n_samples, 3) - assert loader[995].shape == (1, n_samples, 3) + assert loader[3].shape == (1, n_samples_waveforms, 3) + assert loader[995].shape == (1, n_samples_waveforms, 3) with raises(NotImplementedError): loader[500:510] def test_loader_filter(): - n_samples_trace, n_channels = 1000, 100 - n_samples = 40 - n_spikes = n_samples_trace // (2 * n_samples) - - traces = artificial_traces(n_samples_trace, n_channels) - spike_samples = np.cumsum(npr.randint(low=0, high=2 * n_samples, - size=n_spikes)) + traces = np.c_[np.arange(20), np.arange(20, 40)].astype(np.int32) + n_samples_trace, n_channels = traces.shape + h = 3 - # With filter. - def my_filter(x): + def my_filter(x, axis=0): return x * x loader = WaveformLoader(traces, - n_samples=(n_samples // 2, n_samples // 2), + n_samples_waveforms=(h, h), filter=my_filter, - filter_margin=5) + filter_margin=2) - t = spike_samples[5] - waveform_filtered = loader._load_at(t) + t = 10 + waveform_filtered = loader[t] traces_filtered = my_filter(traces) - traces_filtered[t - 20:t + 20, :] - assert np.allclose(waveform_filtered, traces_filtered[t - 20:t + 20, :]) + assert np.allclose(waveform_filtered, traces_filtered[t - h:t + h, :]) diff --git a/phy/traces/waveform.py b/phy/traces/waveform.py index db96e0329..73d9d4ab8 100644 --- a/phy/traces/waveform.py +++ b/phy/traces/waveform.py @@ -227,17 +227,17 @@ def __init__(self, offset=0, filter=None, filter_margin=0, - n_samples=None, + n_samples_waveforms=None, channels=None, scale_factor=None, dc_offset=None, + dtype=None, ): - # A (possibly memmapped) array-like structure with traces. if traces is not None: self.traces = traces else: self._traces = None - self.dtype = np.float32 + self.dtype = dtype or traces.dtype # Scale factor for the loaded waveforms. self._scale_factor = scale_factor self._dc_offset = dc_offset @@ -245,14 +245,14 @@ def __init__(self, self._offset = int(offset) # List of channels to use when loading the waveforms. self._channels = channels - # A filter function that takes a (n_samples, n_channels) array as - # input. + # A filter function that takes a (n_samples_waveforms, n_channels) + # array as input. self._filter = filter # Number of samples to return, can be an int or a # tuple (before, after). - if n_samples is None: - raise ValueError("'n_samples' must be specified.") - self.n_samples_before_after = _before_after(n_samples) + if n_samples_waveforms is None: + raise ValueError("'n_samples_waveforms' must be specified.") + self.n_samples_before_after = _before_after(n_samples_waveforms) self.n_samples_waveforms = sum(self.n_samples_before_after) # Number of additional samples to use for filtering. self._filter_margin = _before_after(filter_margin) @@ -260,6 +260,18 @@ def __init__(self, self._n_samples_extract = (self.n_samples_waveforms + sum(self._filter_margin)) + @property + def offset(self): + return self._offset + + @property + def dc_offset(self): + return self._dc_offset + + @property + def scale_factor(self): + return self._scale_factor + @property def traces(self): """Raw traces.""" @@ -306,30 +318,13 @@ def _load_at(self, time): elif slice_extract.stop >= ns - 1: extract = _pad(extract, self._n_samples_extract, 'right') - assert extract.shape[0] == self._n_samples_extract - - # Filter the waveforms. - # TODO: do the filtering in a vectorized way for higher performance. - if self._filter is not None: - waveforms = self._filter(extract) - else: - waveforms = extract - - # Remove the margin. - margin_before, margin_after = self._filter_margin - if margin_after > 0: - assert margin_before >= 0 - waveforms = waveforms[margin_before:-margin_after, :] - # Make a subselection with the specified channels. if self._channels is not None: - out = waveforms[..., self._channels] - else: - out = waveforms + extract = extract[..., self._channels] - assert out.shape == (self.n_samples_waveforms, - self.n_channels_waveforms) - return out + assert extract.shape == (self._n_samples_extract, + self.n_channels_waveforms) + return extract def __getitem__(self, item): """Load waveforms.""" @@ -338,27 +333,46 @@ def __getitem__(self, item): "implemented yet.") if not hasattr(item, '__len__'): item = [item] + # Ensure a list of time samples are being requested. spikes = _as_array(item) n_spikes = len(spikes) + # Initialize the array. # TODO: int16 - shape = (n_spikes, self.n_samples_waveforms, - self.n_channels_waveforms) + shape = (n_spikes, self._n_samples_extract, self.n_channels_waveforms) + # No traces: return null arrays. if self.n_samples_trace == 0: return np.zeros(shape, dtype=self.dtype) - waveforms = np.empty(shape, dtype=self.dtype) + waveforms = np.zeros(shape, dtype=self.dtype) + # Load all spikes. for i, time in enumerate(spikes): try: waveforms[i, ...] = self._load_at(time) except ValueError as e: logger.warn("Error while loading waveform: %s", str(e)) + + # Filter the waveforms. + if self._filter is not None: + waveforms = self._filter(waveforms, axis=1) + + # Remove the margin. + margin_before, margin_after = self._filter_margin + if margin_after > 0: + assert margin_before >= 0 + waveforms = waveforms[:, margin_before:-margin_after, :] + + # Transform. if self._dc_offset: waveforms -= self._dc_offset if self._scale_factor: waveforms *= self._scale_factor + + assert waveforms.shape == (n_spikes, self.n_samples_waveforms, + self.n_channels_waveforms) + return waveforms From b3b448f8438c5909ea0cc7fd53edf0ab0a23dc11 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 22 Sep 2015 20:37:18 +0200 Subject: [PATCH 0037/1059] WIP: increase waveform coverage --- phy/traces/tests/test_waveform.py | 17 +++++++++++++---- phy/traces/waveform.py | 4 ++-- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/phy/traces/tests/test_waveform.py b/phy/traces/tests/test_waveform.py index f9155e45d..fef65f25f 100644 --- a/phy/traces/tests/test_waveform.py +++ b/phy/traces/tests/test_waveform.py @@ -54,10 +54,9 @@ def test_extract_simple(): we = WaveformExtractor(extract_before=3, extract_after=5, - thresholds={'weak': weak, - 'strong': strong}, channels_per_group=cpg, ) + we.set_thresholds(weak=weak, strong=strong) # _component() comp = we._component(component, n_samples=ns) @@ -128,7 +127,7 @@ def test_slice(): # Tests loader #------------------------------------------------------------------------------ -@yield_fixture(params=[(None, None), (-1, 0)]) +@yield_fixture(params=[(None, None), (-1, 2)]) def waveform_loader(request): scale_factor, dc_offset = request.param @@ -157,17 +156,27 @@ def waveform_loader(request): yield b +def test_loader_edge_case(): + wl = WaveformLoader(n_samples_waveforms=3) + wl.traces = np.random.rand(0, 2) + wl[0] + + def test_loader_simple(waveform_loader): b = waveform_loader spike_samples = b.spike_samples loader = b.loader traces = loader.traces - dc_offset = loader.dc_offset or 0. + dc_offset = loader.dc_offset or 0 scale_factor = loader.scale_factor or 1 n_samples_traces, n_channels = traces.shape n_samples_waveforms = b.n_samples_waveforms h = n_samples_waveforms // 2 + assert loader.offset == 0 + assert loader.dc_offset in (dc_offset, None) + assert loader.scale_factor in (scale_factor, None) + def _transform(arr): return (arr - dc_offset) * scale_factor diff --git a/phy/traces/waveform.py b/phy/traces/waveform.py index 73d9d4ab8..348ac75bb 100644 --- a/phy/traces/waveform.py +++ b/phy/traces/waveform.py @@ -237,7 +237,7 @@ def __init__(self, self.traces = traces else: self._traces = None - self.dtype = dtype or traces.dtype + self.dtype = dtype or (traces.dtype if traces is not None else None) # Scale factor for the loaded waveforms. self._scale_factor = scale_factor self._dc_offset = dc_offset @@ -351,7 +351,7 @@ def __getitem__(self, item): for i, time in enumerate(spikes): try: waveforms[i, ...] = self._load_at(time) - except ValueError as e: + except ValueError as e: # pragma: no cover logger.warn("Error while loading waveform: %s", str(e)) # Filter the waveforms. From 91240bb54546d18a28a56c54db6f9426e86fd47b Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 22 Sep 2015 20:54:22 +0200 Subject: [PATCH 0038/1059] Test get_padded() --- phy/traces/tests/test_waveform.py | 15 +++++++++++---- phy/traces/waveform.py | 6 +++--- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/phy/traces/tests/test_waveform.py b/phy/traces/tests/test_waveform.py index fef65f25f..09819aad2 100644 --- a/phy/traces/tests/test_waveform.py +++ b/phy/traces/tests/test_waveform.py @@ -6,17 +6,14 @@ # Imports #------------------------------------------------------------------------------ -from itertools import product - import numpy as np from numpy.testing import assert_array_equal as ae -import numpy.random as npr from pytest import raises, yield_fixture from phy.io.mock import artificial_traces, artificial_spike_samples from phy.utils import Bunch from ..waveform import (_slice, WaveformLoader, WaveformExtractor, - SpikeLoader, + SpikeLoader, _get_padded, ) from ..filter import bandpass_filter, apply_filter @@ -115,6 +112,16 @@ def test_extract_simple(): ae(masks_f_o, [1., 0.5, 0.]) +def test_get_padded(): + arr = np.array([1, 2, 3])[:, np.newaxis] + + with raises(RuntimeError): + ae(_get_padded(arr, -2, 5).ravel(), [1, 2, 3, 0, 0]) + ae(_get_padded(arr, 1, 2).ravel(), [2]) + ae(_get_padded(arr, 0, 5).ravel(), [1, 2, 3, 0, 0]) + ae(_get_padded(arr, -2, 3).ravel(), [0, 0, 1, 2, 3]) + + #------------------------------------------------------------------------------ # Tests utility functions #------------------------------------------------------------------------------ diff --git a/phy/traces/waveform.py b/phy/traces/waveform.py index 348ac75bb..d94c0d52b 100644 --- a/phy/traces/waveform.py +++ b/phy/traces/waveform.py @@ -27,7 +27,7 @@ def _get_padded(data, start, end): Assumes that either `start<0` or `end>len(data)` but not both. """ - if start < 0 and end >= data.shape[0]: + if start < 0 and end > data.shape[0]: raise RuntimeError() if start < 0: start_zeros = np.zeros((-start, data.shape[1]), @@ -67,7 +67,7 @@ def _component(self, component, data=None, n_samples=None): comp_s = component[:, 0] # shape: (component_size,) comp_ch = component[:, 1] # shape: (component_size,) channel = comp_ch[0] - if channel not in self._dep_channels: + if channel not in self._dep_channels: # pragma: no cover raise RuntimeError("Channel `{}` appears to be dead and should " "have been excluded from the threshold " "crossings.".format(channel)) @@ -152,7 +152,7 @@ def align(self, waveform, s_aligned): try: f = interp1d(old_s, waveform, bounds_error=True, kind='cubic', axis=0) - except ValueError: + except ValueError: # pragma: no cover logger.warn("Interpolation error at time %d", s) return waveform return f(new_s) From 9f445361c2542d84075b13f3beb5b8461584078d Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 22 Sep 2015 21:03:09 +0200 Subject: [PATCH 0039/1059] Remove function --- phy/traces/filter.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/phy/traces/filter.py b/phy/traces/filter.py index cfa7ef16e..cba7a7881 100644 --- a/phy/traces/filter.py +++ b/phy/traces/filter.py @@ -59,18 +59,6 @@ def __call__(self, data): return apply_filter(data, filter=self._filter) -def _filter_and_margin(**kwargs): - - b_filter = bandpass_filter(**kwargs) - - def filter(x, axis=0): - return apply_filter(x, b_filter, axis=axis) - - filter_margin = kwargs['order'] * 3 - - return filter, filter_margin - - #------------------------------------------------------------------------------ # Whitening #------------------------------------------------------------------------------ From 58320649e853353bd11132c289c1797c3d855591 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 22 Sep 2015 21:07:33 +0200 Subject: [PATCH 0040/1059] Increase coverage in phy.stats --- phy/stats/__init__.py | 2 -- phy/stats/ccg.py | 19 ------------------- phy/stats/tests/test_ccg.py | 3 ++- 3 files changed, 2 insertions(+), 22 deletions(-) diff --git a/phy/stats/__init__.py b/phy/stats/__init__.py index 53deb63f3..6c9d646f2 100644 --- a/phy/stats/__init__.py +++ b/phy/stats/__init__.py @@ -2,5 +2,3 @@ # flake8: noqa """Statistics functions.""" - -from .ccg import pairwise_correlograms diff --git a/phy/stats/ccg.py b/phy/stats/ccg.py index 3121df132..28baed6df 100644 --- a/phy/stats/ccg.py +++ b/phy/stats/ccg.py @@ -171,22 +171,3 @@ def _symmetrize_correlograms(correlograms): sym = np.transpose(sym, (1, 0, 2)) return np.dstack((sym, correlograms)) - - -def pairwise_correlograms(spike_samples, - spike_clusters, - binsize=None, - winsize_bins=None, - ): - """Compute all pairwise correlograms in a set of neurons. - - TODO: improve interface and documentation. - - """ - ccgs = correlograms(spike_samples, - spike_clusters, - binsize=binsize, - winsize_bins=winsize_bins, - ) - ccgs = _symmetrize_correlograms(ccgs) - return ccgs diff --git a/phy/stats/tests/test_ccg.py b/phy/stats/tests/test_ccg.py index 3f7eb2348..e5eb6eeb2 100644 --- a/phy/stats/tests/test_ccg.py +++ b/phy/stats/tests/test_ccg.py @@ -78,7 +78,8 @@ def test_ccg_0(): c_expected[0, 1, 0] = 0 # This is a peculiarity of the algorithm. c = correlograms(spike_samples, spike_clusters, - binsize=binsize, winsize_bins=winsize_bins) + binsize=binsize, winsize_bins=winsize_bins, + cluster_order=[0, 1]) ae(c, c_expected) From b902b1f05bbfc603cb8f2110f974a2d1e6dc5fb0 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 22 Sep 2015 21:49:36 +0200 Subject: [PATCH 0041/1059] Fixed bug in read_kwd() --- phy/io/tests/test_traces.py | 11 ++++++++--- phy/io/traces.py | 12 +++++++----- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/phy/io/tests/test_traces.py b/phy/io/tests/test_traces.py index bbc635f28..b20772863 100644 --- a/phy/io/tests/test_traces.py +++ b/phy/io/tests/test_traces.py @@ -11,9 +11,10 @@ import numpy as np from numpy.testing import assert_array_equal as ae from numpy.testing import assert_allclose as ac +from pytest import raises from ..h5 import open_h5 -from ..traces import read_dat, _dat_n_samples, read_kwd +from ..traces import read_dat, _dat_n_samples, read_kwd, read_ns5 from ..mock import artificial_traces @@ -52,6 +53,10 @@ def test_read_kwd(tempdir): arr[n_samples // 2:, ...].astype(np.float32)) with open_h5(path, 'r') as f: - data = read_kwd(f)[:] + data = read_kwd(f)[...] + ac(arr, data) - ac(arr, data) + +def test_read_ns5(): + with raises(NotImplementedError): + read_ns5('') diff --git a/phy/io/traces.py b/phy/io/traces.py index 990221834..24ddfa5e5 100644 --- a/phy/io/traces.py +++ b/phy/io/traces.py @@ -21,23 +21,25 @@ def read_kwd(kwd_handle): The output is a memory-mapped file. """ - import dask + import dask.array f = kwd_handle + if '/recordings' not in f: return recordings = f.children('/recordings') def _read(idx): + # The file needs to be open. + assert f.is_open() return f.read('/recordings/{}/data'.format(recordings[idx])) - dsk = {('data', idx): (_read, idx) for idx in range(len(recordings))} + dsk = {('data', idx, 0): (_read, idx) for idx in range(len(recordings))} chunks = (tuple(_read(idx).shape[0] for idx in range(len(recordings))), - tuple(_read(idx).shape[1] for idx in range(len(recordings))), + (_read(0).shape[1],) ) - - return dask.Array(dsk, 'data', chunks) + return dask.array.Array(dsk, 'data', chunks) def read_dat(filename, dtype=None, shape=None, offset=0, n_channels=None): From 5ac43a4b5fc7485cd5a9b1e252e2213c0a6eee14 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 22 Sep 2015 21:57:41 +0200 Subject: [PATCH 0042/1059] WIP: increase coverage in mock module --- phy/io/tests/test_mock.py | 10 ++++++++++ phy/io/traces.py | 30 +++++++++--------------------- 2 files changed, 19 insertions(+), 21 deletions(-) diff --git a/phy/io/tests/test_mock.py b/phy/io/tests/test_mock.py index 3476608ca..e53e18063 100644 --- a/phy/io/tests/test_mock.py +++ b/phy/io/tests/test_mock.py @@ -14,6 +14,8 @@ artificial_spike_clusters, artificial_features, artificial_masks, + artificial_spike_samples, + artificial_correlograms, ) @@ -55,6 +57,14 @@ def _test_artificial(n_spikes=None, n_clusters=None): masks = artificial_masks(n_spikes, n_channels) assert masks.shape == (n_spikes, n_channels) + # Spikes. + spikes = artificial_spike_samples(n_spikes) + assert spikes.shape == (n_spikes,) + + # CCG. + ccg = artificial_correlograms(n_clusters, 10) + assert ccg.shape == (n_clusters, n_clusters, 10) + def test_artificial(): _test_artificial(n_spikes=100, n_clusters=10) diff --git a/phy/io/traces.py b/phy/io/traces.py index 24ddfa5e5..7db9a239b 100644 --- a/phy/io/traces.py +++ b/phy/io/traces.py @@ -25,7 +25,7 @@ def read_kwd(kwd_handle): f = kwd_handle - if '/recordings' not in f: + if '/recordings' not in f: # pragma: no cover return recordings = f.children('/recordings') @@ -42,6 +42,14 @@ def _read(idx): return dask.array.Array(dsk, 'data', chunks) +def _dat_n_samples(filename, dtype=None, n_channels=None): + assert dtype is not None + item_size = np.dtype(dtype).itemsize + n_samples = op.getsize(filename) // (item_size * n_channels) + assert n_samples >= 0 + return n_samples + + def read_dat(filename, dtype=None, shape=None, offset=0, n_channels=None): """Read traces from a flat binary `.dat` file. @@ -73,26 +81,6 @@ def read_dat(filename, dtype=None, shape=None, offset=0, n_channels=None): mode='r', offset=offset) -def _dat_to_traces(dat_path, n_channels, dtype): - assert dtype is not None - assert n_channels is not None - n_samples = _dat_n_samples(dat_path, - n_channels=n_channels, - dtype=dtype, - ) - return read_dat(dat_path, - dtype=dtype, - shape=(n_samples, n_channels)) - - -def _dat_n_samples(filename, dtype=None, n_channels=None): - assert dtype is not None - item_size = np.dtype(dtype).itemsize - n_samples = op.getsize(filename) // (item_size * n_channels) - assert n_samples >= 0 - return n_samples - - def read_ns5(filename): # TODO raise NotImplementedError() From c291f9e97a2919a72d1297f7d3c16baf2801c746 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 22 Sep 2015 22:03:11 +0200 Subject: [PATCH 0043/1059] Increase coverage in io --- phy/io/h5.py | 54 ++--------------------------------------- phy/io/tests/test_h5.py | 26 -------------------- 2 files changed, 2 insertions(+), 78 deletions(-) diff --git a/phy/io/h5.py b/phy/io/h5.py index c42d5d0f7..95b357dfc 100644 --- a/phy/io/h5.py +++ b/phy/io/h5.py @@ -75,8 +75,7 @@ def is_open(self): return self._h5py_file is not None def open(self, mode=None): - if mode is not None: - self.mode = mode + self.mode = mode if mode is not None else None if not self.is_open(): self._h5py_file = h5py.File(self.filename, self.mode) @@ -140,7 +139,7 @@ def write(self, path, array=None, dtype=None, shape=None, overwrite=False): # Copy and rename #-------------------------------------------------------------------------- - def _check_move_copy(self, path, new_path): + def _check_move_copy(self, path, new_path): # pragma: no cover if not self.exists(path): raise ValueError("'{0}' doesn't exist.".format(path)) if self.exists(new_path): @@ -167,55 +166,6 @@ def delete(self, path): parent = self.read(path) del parent[name] - # Attributes - #-------------------------------------------------------------------------- - - def read_attr(self, path, attr_name): - """Read an attribute of an HDF5 group.""" - _check_hdf5_path(self._h5py_file, path) - attrs = self._h5py_file[path].attrs - if attr_name in attrs: - try: - out = attrs[attr_name] - if isinstance(out, np.ndarray) and out.dtype.kind == 'S': - if len(out) == 1: - out = out[0].decode('UTF-8') - return out - except (TypeError, IOError): - logger.debug("Unable to read attribute `%s` at `%s`.", - attr_name, path) - return - else: - raise KeyError("The attribute '{0:s}' ".format(attr_name) + - "at `{}` doesn't exist.".format(path)) - - def write_attr(self, path, attr_name, value): - """Write an attribute of an HDF5 group.""" - assert isinstance(path, string_types) - assert isinstance(attr_name, string_types) - if value is None: - value = '' - # Ensure lists of strings are converted to ASCII arrays. - if isinstance(value, list): - if not value: - value = None - if value and isinstance(value[0], string_types): - value = np.array(value, dtype='S') - # Use string arrays instead of vlen arrays (crash in h5py 2.5.0 win64). - if isinstance(value, string_types): - value = np.array([value], dtype='S') - # Idem: fix crash with boolean attributes on win64. - if isinstance(value, bool): - value = int(value) - # If the parent group doesn't already exist, create it. - if path not in self._h5py_file: - self._h5py_file.create_group(path) - try: - self._h5py_file[path].attrs[attr_name] = value - except TypeError: - logger.warn("Unable to write attribute `%s=%s` at `%s`.", - attr_name, value, path) - def attrs(self, path='/'): """Return the list of attributes at the given path.""" if path in self._h5py_file: diff --git a/phy/io/tests/test_h5.py b/phy/io/tests/test_h5.py index 72038dd0e..7696790dd 100644 --- a/phy/io/tests/test_h5.py +++ b/phy/io/tests/test_h5.py @@ -105,8 +105,6 @@ def test_h5_read(tempdir): assert f.has_attr('/mygroup', 'myattr') assert not f.has_attr('/mygroup', 'myattr_bis') assert not f.has_attr('/mygroup_bis', 'myattr_bis') - value = f.read_attr('/mygroup', 'myattr') - assert value == 123 # Check that errors are raised when the paths are invalid. with raises(Exception): @@ -168,30 +166,6 @@ def test_h5_write(tempdir): f.write('/ds3/ds4/ds5', temp_array) ae(f.read('/ds3/ds4/ds5'), temp_array) - # Write an existing attribute. - f.write_attr('/ds1', 'myattr', 456) - - with raises(KeyError): - f.read_attr('/ds1', 'nonexistingattr') - - assert f.read_attr('/ds1', 'myattr') == 456 - - # Write a new attribute in a dataset. - f.write_attr('/ds1', 'mynewattr', 789) - assert f.read_attr('/ds1', 'mynewattr') == 789 - - # Write a new attribute in a group. - f.write_attr('/mygroup', 'mynewattr', 890) - assert f.read_attr('/mygroup', 'mynewattr') == 890 - - # Write a new attribute in a nonexisting group. - f.write_attr('/nonexistinggroup', 'mynewattr', 2) - assert f.read_attr('/nonexistinggroup', 'mynewattr') == 2 - - # Write a new attribute two levels into a nonexisting group. - f.write_attr('/nonexistinggroup2/group3', 'mynewattr', 2) - assert f.read_attr('/nonexistinggroup2/group3', 'mynewattr') == 2 - def test_h5_describe(tempdir): # Create the test HDF5 file in the temporary directory. From 526c264c7112bc65334551ef3f310de71f44109f Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 22 Sep 2015 22:04:43 +0200 Subject: [PATCH 0044/1059] Flakify --- phy/io/h5.py | 2 -- phy/utils/tests/test_settings.py | 2 +- phy/utils/tests/test_testing.py | 1 - 3 files changed, 1 insertion(+), 4 deletions(-) diff --git a/phy/io/h5.py b/phy/io/h5.py index 95b357dfc..24835d0c6 100644 --- a/phy/io/h5.py +++ b/phy/io/h5.py @@ -8,9 +8,7 @@ import logging -import numpy as np import h5py -from six import string_types logger = logging.getLogger(__name__) diff --git a/phy/utils/tests/test_settings.py b/phy/utils/tests/test_settings.py index 44f6bc351..38b13cb2e 100644 --- a/phy/utils/tests/test_settings.py +++ b/phy/utils/tests/test_settings.py @@ -8,7 +8,7 @@ import os.path as op -from pytest import raises, yield_fixture, mark +from pytest import raises, yield_fixture from ..settings import (BaseSettings, Settings, diff --git a/phy/utils/tests/test_testing.py b/phy/utils/tests/test_testing.py index d4175f0e3..00e642ad8 100644 --- a/phy/utils/tests/test_testing.py +++ b/phy/utils/tests/test_testing.py @@ -12,7 +12,6 @@ import numpy as np from pytest import mark -from six.moves import builtins from vispy.app import Canvas from ..testing import (benchmark, captured_output, show_test, From b2e3891254b8e7fc04fba7d6cf21d5f4a24f058c Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 22 Sep 2015 22:14:13 +0200 Subject: [PATCH 0045/1059] Increase coverage in phy.electrode --- phy/electrode/mea.py | 16 ++-------------- phy/electrode/tests/test_mea.py | 29 +++++++++++++++++------------ 2 files changed, 19 insertions(+), 26 deletions(-) diff --git a/phy/electrode/mea.py b/phy/electrode/mea.py index 714cf7322..1837bac5e 100644 --- a/phy/electrode/mea.py +++ b/phy/electrode/mea.py @@ -25,7 +25,7 @@ def _edges_to_adjacency_list(edges): """Convert a list of edges into an adjacency list.""" adj = {} for i, j in edges: - if i in adj: + if i in adj: # pragma: no cover ni = adj[i] else: ni = adj[i] = set() @@ -120,8 +120,7 @@ def __init__(self, ): self._probe = probe self._channels = channels - if positions is not None: - assert self.n_channels == positions.shape[0] + self._check_positions(positions) self._positions = positions # This is a mapping {channel: list of neighbors}. if adjacency is None and probe is not None: @@ -133,8 +132,6 @@ def _check_positions(self, positions): if positions is None: return positions = _as_array(positions) - if self.n_channels is None: - self.n_channels = positions.shape[0] if positions.shape[0] != self.n_channels: raise ValueError("'positions' " "(shape {0:s})".format(str(positions.shape)) + @@ -147,11 +144,6 @@ def positions(self): """Channel positions in the current channel group.""" return self._positions - @positions.setter - def positions(self, value): - self._check_positions(value) - self._positions = value - @property def channels(self): """Channel ids in the current channel group.""" @@ -167,10 +159,6 @@ def adjacency(self): """Adjacency graph in the current channel group.""" return self._adjacency - @adjacency.setter - def adjacency(self, value): - self._adjacency = value - def change_channel_group(self, group): """Change the current channel group.""" assert self._probe is not None diff --git a/phy/electrode/tests/test_mea.py b/phy/electrode/tests/test_mea.py index b974d6e3c..ab04246c8 100644 --- a/phy/electrode/tests/test_mea.py +++ b/phy/electrode/tests/test_mea.py @@ -6,11 +6,14 @@ # Imports #------------------------------------------------------------------------------ +import os.path as op + from pytest import raises import numpy as np from numpy.testing import assert_array_equal as ae -from ..mea import (_probe_channels, _probe_positions, _probe_adjacency_list, +from ..mea import (_probe_channels, _probe_all_channels, + _probe_positions, _probe_adjacency_list, MEA, linear_positions, staggered_positions, load_probe, list_probes ) @@ -57,8 +60,7 @@ def test_mea(): channels = np.arange(n_channels) positions = np.random.randn(n_channels, 2) - mea = MEA(channels) - mea.positions = positions + mea = MEA(channels, positions=positions) ae(mea.positions, positions) assert mea.adjacency is None @@ -68,17 +70,11 @@ def test_mea(): mea = MEA(channels, positions=positions) assert mea.n_channels == n_channels - with raises(AssertionError): + with raises(ValueError): MEA(channels=np.arange(n_channels + 1), positions=positions) - with raises(AssertionError): - MEA(channels=channels, positions=positions[:-1, :]) - - mea = MEA(channels=channels) - assert mea.n_channels == n_channels - mea.positions = positions with raises(ValueError): - mea.positions = positions[:-1, :] + MEA(channels=channels, positions=positions[:-1, :]) def test_positions(): @@ -90,9 +86,18 @@ def test_positions(): assert probe.shape == (29, 2) -def test_library(): +def test_library(tempdir): probe = load_probe('1x32_buzsaki') assert probe assert probe['channel_groups'][0]['channels'] == list(range(32)) + assert _probe_all_channels(probe) == list(range(32)) assert '1x32_buzsaki' in list_probes() + + path = op.join(tempdir, 'test.prb') + with raises(IOError): + load_probe(path) + + with open(path, 'w') as f: + f.write('') + load_probe(path) From a615f8c5af4e99a3b79add9feb1e5260dd2f4327 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 23 Sep 2015 09:30:25 +0200 Subject: [PATCH 0046/1059] WIP: increase coverage in phy.cluster.manual --- phy/cluster/manual/tests/test_clustering.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/phy/cluster/manual/tests/test_clustering.py b/phy/cluster/manual/tests/test_clustering.py index 59a9b34c2..3b4b9aaaf 100644 --- a/phy/cluster/manual/tests/test_clustering.py +++ b/phy/cluster/manual/tests/test_clustering.py @@ -351,9 +351,6 @@ def _checkpoint(index=None): def _assert_is_checkpoint(index): ae(clustering.spike_clusters, checkpoints[index]) - def _assert_spikes(spikes): - ae(info.spike_ids, spikes) - # Checkpoint 0. _checkpoint() _assert_is_checkpoint(0) From aa5d8db969d27d9ecda8052349759c3e3de4610a Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 23 Sep 2015 09:33:34 +0200 Subject: [PATCH 0047/1059] Increase coverage in phy.cluster.manual.clustering --- phy/cluster/manual/tests/test_clustering.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/phy/cluster/manual/tests/test_clustering.py b/phy/cluster/manual/tests/test_clustering.py index 3b4b9aaaf..e0ad61c31 100644 --- a/phy/cluster/manual/tests/test_clustering.py +++ b/phy/cluster/manual/tests/test_clustering.py @@ -114,6 +114,9 @@ def test_clustering_split(): # Instantiate a Clustering instance. clustering = Clustering(spike_clusters) ae(clustering.spike_clusters, spike_clusters) + n_spikes = len(spike_clusters) + assert clustering.n_spikes == n_spikes + ae(clustering.spike_ids, np.arange(n_spikes)) splits = [[0], [1], @@ -360,6 +363,11 @@ def _assert_is_checkpoint(index): my_spikes_3 = np.unique(np.random.randint(low=0, high=n_spikes, size=1000)) my_spikes_4 = np.arange(n_spikes - 5) + # Edge cases. + clustering.assign([]) + with raises(ValueError): + clustering.merge([], 1) + # Checkpoint 1. info = clustering.split(my_spikes_1) _checkpoint() From e0a9112eb9258283ea8d0fa58fd650cac4414da4 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 23 Sep 2015 09:35:12 +0200 Subject: [PATCH 0048/1059] WIP: increase coverage in phy.cluster.manual --- phy/cluster/manual/tests/test_history.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/phy/cluster/manual/tests/test_history.py b/phy/cluster/manual/tests/test_history.py index 5fadc274d..4ccb04044 100644 --- a/phy/cluster/manual/tests/test_history.py +++ b/phy/cluster/manual/tests/test_history.py @@ -26,6 +26,9 @@ def _assert_current(item): item1 = np.ones(4) item2 = 2 * np.ones(5) + assert not history.is_first() + assert history.is_last() + history.add(item0) _assert_current(item0) From b879d2355677fb18b50c69a7585fa78e2dc8b2eb Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 23 Sep 2015 09:41:24 +0200 Subject: [PATCH 0049/1059] Increase coverage in phy.cluster.manual --- phy/cluster/manual/_utils.py | 4 ---- phy/cluster/manual/tests/test_utils.py | 16 +++++++++++++++- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/phy/cluster/manual/_utils.py b/phy/cluster/manual/_utils.py index c7af20a59..dccfdb85c 100644 --- a/phy/cluster/manual/_utils.py +++ b/phy/cluster/manual/_utils.py @@ -120,13 +120,9 @@ def _get_one(self, cluster, field): elif field in self._fields: # Call the default field function. return self._fields[field](cluster) - else: - return None else: if field in self._fields: return self._fields[field](cluster) - else: - return None def _get(self, clusters, field): if _is_list(clusters): diff --git a/phy/cluster/manual/tests/test_utils.py b/phy/cluster/manual/tests/test_utils.py index 9c0f4e3bb..8cf6da8d1 100644 --- a/phy/cluster/manual/tests/test_utils.py +++ b/phy/cluster/manual/tests/test_utils.py @@ -8,7 +8,9 @@ import logging -from .._utils import ClusterMetadata, ClusterMetadataUpdater, UpdateInfo +from .._utils import (ClusterMetadata, ClusterMetadataUpdater, UpdateInfo, + _update_cluster_selection, + ) logger = logging.getLogger(__name__) @@ -32,6 +34,9 @@ def group(cluster): def color(cluster): return 0 + assert base_meta.group(2) == 2 + assert base_meta.group([4, 2]) == [5, 2] + meta = ClusterMetadataUpdater(base_meta) # Values set in 'data'. @@ -107,8 +112,17 @@ def color(cluster): assert info is None +def test_update_cluster_selection(): + clusters = [1, 2, 3] + up = UpdateInfo(deleted=[2], added=[4, 0]) + assert _update_cluster_selection(clusters, up) == [1, 3, 4, 0] + + def test_update_info(): + logger.debug(UpdateInfo()) + logger.debug(UpdateInfo(description='hello')) logger.debug(UpdateInfo(deleted=range(5), added=[5], description='merge')) logger.debug(UpdateInfo(deleted=range(5), added=[5], description='assign')) logger.debug(UpdateInfo(deleted=range(5), added=[5], description='assign', history='undo')) + logger.debug(UpdateInfo(metadata_changed=[2, 3], description='metadata')) From 608c2dde1c102e0c6e42f7f51a31d505107c43f9 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 23 Sep 2015 09:58:30 +0200 Subject: [PATCH 0050/1059] WIP: increase coverage in wizard --- phy/cluster/manual/tests/test_wizard.py | 48 ++++++++++++++++--------- phy/cluster/manual/wizard.py | 8 ++--- 2 files changed, 35 insertions(+), 21 deletions(-) diff --git a/phy/cluster/manual/tests/test_wizard.py b/phy/cluster/manual/tests/test_wizard.py index 4a7943637..a0d81ea6b 100644 --- a/phy/cluster/manual/tests/test_wizard.py +++ b/phy/cluster/manual/tests/test_wizard.py @@ -6,7 +6,7 @@ # Imports #------------------------------------------------------------------------------ -from pytest import raises +from pytest import raises, yield_fixture from ..wizard import (_previous, _next, @@ -18,6 +18,26 @@ # Test wizard #------------------------------------------------------------------------------ +@yield_fixture +def wizard(): + groups = {2: None, 3: None, 5: 'ignored', 7: 'good'} + wizard = Wizard(groups) + + @wizard.set_quality_function + def quality(cluster): + return {2: .2, + 3: .3, + 5: .5, + 7: .7, + }[cluster] + + @wizard.set_similarity_function + def similarity(cluster, other): + return 1. + quality(cluster) - quality(other) + + yield wizard + + def test_utils(): l = [2, 3, 5, 7, 11] @@ -80,22 +100,7 @@ def similarity(cluster, other): assert wizard.most_similar_clusters(n_max=1) == [5] -def test_wizard_nav(): - - groups = {2: None, 3: None, 5: 'ignored', 7: 'good'} - wizard = Wizard(groups) - - @wizard.set_quality_function - def quality(cluster): - return {2: .2, - 3: .3, - 5: .5, - 7: .7, - }[cluster] - - @wizard.set_similarity_function - def similarity(cluster, other): - return 1. + quality(cluster) - quality(other) +def test_wizard_nav(wizard): # Loop over the best clusters. wizard.start() @@ -150,6 +155,15 @@ def similarity(cluster, other): assert wizard.best == 3 assert wizard.match == 2 + wizard.first() + assert wizard.selection == (3, 2) + wizard.last() + assert wizard.selection == (3, 5) + wizard.unpin() assert wizard.best == 3 assert wizard.match is None + + +def test_wizard_update(wizard): + pass diff --git a/phy/cluster/manual/wizard.py b/phy/cluster/manual/wizard.py index fe30b54d5..40be0dc50 100644 --- a/phy/cluster/manual/wizard.py +++ b/phy/cluster/manual/wizard.py @@ -250,14 +250,14 @@ def selection(self): """Return the current best/match cluster selection.""" b, m = self.best, self.match if b is None: - return [] + return () elif m is None: - return [b] + return (b,) else: if b == m: - return [b] + return (b,) else: - return [b, m] + return (b, m) @match.setter def match(self, value): From 873d7a98fcea090ab01181351f594f03f57ce0d4 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 23 Sep 2015 10:19:22 +0200 Subject: [PATCH 0051/1059] WIP: improve wizard --- phy/cluster/manual/tests/test_wizard.py | 46 ++++++++++++++------- phy/cluster/manual/wizard.py | 55 ++++++++----------------- 2 files changed, 49 insertions(+), 52 deletions(-) diff --git a/phy/cluster/manual/tests/test_wizard.py b/phy/cluster/manual/tests/test_wizard.py index a0d81ea6b..df467fcff 100644 --- a/phy/cluster/manual/tests/test_wizard.py +++ b/phy/cluster/manual/tests/test_wizard.py @@ -8,6 +8,7 @@ from pytest import raises, yield_fixture +from ..clustering import Clustering from ..wizard import (_previous, _next, Wizard, @@ -25,11 +26,7 @@ def wizard(): @wizard.set_quality_function def quality(cluster): - return {2: .2, - 3: .3, - 5: .5, - 7: .7, - }[cluster] + return cluster * .1 @wizard.set_similarity_function def similarity(cluster, other): @@ -44,10 +41,9 @@ def test_utils(): def func(x): return x in (2, 5) - with raises(RuntimeError): - _previous(l, 1, func) - with raises(RuntimeError): - _previous(l, 15, func) + # Error: log and do nothing. + _previous(l, 1, func) + _previous(l, 15, func) assert _previous(l, 2, func) == 2 assert _previous(l, 3, func) == 2 @@ -55,10 +51,9 @@ def func(x): assert _previous(l, 7, func) == 5 assert _previous(l, 11, func) == 5 - with raises(RuntimeError): - _next(l, 1, func) - with raises(RuntimeError): - _next(l, 15, func) + # Error: log and do nothing. + _next(l, 1, func) + _next(l, 15, func) assert _next(l, 2, func) == 5 assert _next(l, 3, func) == 5 @@ -104,6 +99,10 @@ def test_wizard_nav(wizard): # Loop over the best clusters. wizard.start() + + assert wizard.n_clusters == 4 + assert wizard.best_list == [3, 2, 7, 5] + assert wizard.best == 3 assert wizard.match is None @@ -138,6 +137,7 @@ def test_wizard_nav(wizard): wizard.pin() assert wizard.best == 3 assert wizard.match == 2 + assert wizard.match_list == [2, 7, 5] wizard.next() assert wizard.best == 3 @@ -164,6 +164,24 @@ def test_wizard_nav(wizard): assert wizard.best == 3 assert wizard.match is None + assert wizard.n_processed == 2 + def test_wizard_update(wizard): - pass + # 2: none, 3: none, 5: unknown, 7: good + wizard.start() + clustering = Clustering([2, 3, 5, 7]) + + assert wizard.best_list == [3, 2, 7, 5] + wizard.next() + wizard.pin() + assert wizard.selection == (2, 3) + + wizard.on_cluster(clustering.merge([2, 3])) + assert wizard.best_list == [8, 7, 5] + assert wizard.selection == (8, 7) + + wizard.on_cluster(clustering.undo()) + print(wizard.selection) + print(wizard.best_list) + print(wizard.match_list) diff --git a/phy/cluster/manual/wizard.py b/phy/cluster/manual/wizard.py index 40be0dc50..72daaa5a6 100644 --- a/phy/cluster/manual/wizard.py +++ b/phy/cluster/manual/wizard.py @@ -6,10 +6,13 @@ # Imports #------------------------------------------------------------------------------ +import logging from operator import itemgetter from ...utils import _is_array_like +logger = logging.getLogger(__name__) + #------------------------------------------------------------------------------ # Utility functions @@ -42,7 +45,8 @@ def _find_first(items, filter=None): def _previous(items, current, filter=None): if current not in items: - raise RuntimeError("{0} is not in {1}.".format(current, items)) + logger.debug("%d is not in %s.", current, items) + return i = items.index(current) if i == 0: return current @@ -56,7 +60,8 @@ def _next(items, current, filter=None): if not items: return current if current not in items: - raise RuntimeError("{0} is not in {1}.".format(current, items)) + logger.debug("%d is not in %s.", current, items) + return i = items.index(current) if i == len(items) - 1: return current @@ -310,9 +315,10 @@ def previous_best(self): """Select the previous best in cluster.""" if self._has_finished: return - self.best = _previous(self._best_list, - self._best, - ) + if self._best_list: + self.best = _previous(self._best_list, + self._best, + ) if self.match is not None: self._set_match_list() @@ -321,16 +327,17 @@ def next_match(self): # Handle the case where we arrive at the end of the match list. if self.match is not None and len(self._match_list) <= 1: self.next_best() - else: + elif self._match_list: self.match = _next(self._match_list, self._match, ) def previous_match(self): """Select the previous match.""" - self.match = _previous(self._match_list, - self._match, - ) + if self._match_list: + self.match = _previous(self._match_list, + self._match, + ) def next(self): """Next cluster proposition.""" @@ -444,38 +451,10 @@ def _update_state(self, up): self._delete(up.deleted) # Select the last added cluster. if self.best is not None and up.added: - self.best = up.added[-1] + self.pin(up.added[-1]) def on_cluster(self, up): if self._has_finished: return if self._best_list or self._match_list: self._update_state(up) - - # Panel - #-------------------------------------------------------------------------- - - @property - def _best_progress(self): - """Progress in the best clusters.""" - value = (self.best_list.index(self.best) - if self.best in self.best_list else 0) - maximum = len(self.best_list) - return _progress(value, maximum) - - @property - def _match_progress(self): - """Progress in the processed clusters.""" - value = self.n_processed - maximum = self.n_clusters - return _progress(value, maximum) - - def get_panel_params(self): - """Return the parameters for the HTML panel.""" - return dict(best=self.best if self.best is not None else '', - match=self.match if self.match is not None else '', - best_progress=self._best_progress, - match_progress=self._match_progress, - best_group=self._group(self.best) or 'unsorted', - match_group=self._group(self.match) or 'unsorted', - ) From aa5485f29f7e8f4a969fd8a4cf84d21c071e4181 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 23 Sep 2015 10:50:04 +0200 Subject: [PATCH 0052/1059] WIP: improve wizard --- phy/cluster/manual/tests/test_wizard.py | 69 +++++++++++++++++--- phy/cluster/manual/wizard.py | 86 ++++++++++--------------- 2 files changed, 94 insertions(+), 61 deletions(-) diff --git a/phy/cluster/manual/tests/test_wizard.py b/phy/cluster/manual/tests/test_wizard.py index df467fcff..fdb216b30 100644 --- a/phy/cluster/manual/tests/test_wizard.py +++ b/phy/cluster/manual/tests/test_wizard.py @@ -6,9 +6,10 @@ # Imports #------------------------------------------------------------------------------ -from pytest import raises, yield_fixture +from pytest import yield_fixture from ..clustering import Clustering +from .._utils import ClusterMetadata, ClusterMetadataUpdater from ..wizard import (_previous, _next, Wizard, @@ -16,13 +17,20 @@ #------------------------------------------------------------------------------ -# Test wizard +# Fixtures #------------------------------------------------------------------------------ @yield_fixture def wizard(): - groups = {2: None, 3: None, 5: 'ignored', 7: 'good'} - wizard = Wizard(groups) + + def get_cluster_ids(): + return [2, 3, 5, 7] + + wizard = Wizard(get_cluster_ids) + + @wizard.set_status_function + def cluster_status(cluster): + return {2: None, 3: None, 5: 'ignored', 7: 'good'}.get(cluster, None) @wizard.set_quality_function def quality(cluster): @@ -35,6 +43,33 @@ def similarity(cluster, other): yield wizard +@yield_fixture +def cluster_metadata(): + data = {2: {'group': 3}, + 3: {'group': 3}, + 5: {'group': 1}, + 7: {'group': 2}, + } + + base_meta = ClusterMetadata(data=data) + + @base_meta.default + def group(cluster): + return 3 + + meta = ClusterMetadataUpdater(base_meta) + yield meta + + +@yield_fixture +def clustering(): + yield Clustering([2, 3, 5, 7]) + + +#------------------------------------------------------------------------------ +# Test wizard +#------------------------------------------------------------------------------ + def test_utils(): l = [2, 3, 5, 7, 11] @@ -64,7 +99,7 @@ def func(x): def test_wizard_core(): - wizard = Wizard([2, 3, 5]) + wizard = Wizard(lambda: [2, 3, 5]) @wizard.set_quality_function def quality(cluster): @@ -167,21 +202,37 @@ def test_wizard_nav(wizard): assert wizard.n_processed == 2 -def test_wizard_update(wizard): +def test_wizard_update(wizard, clustering, cluster_metadata): # 2: none, 3: none, 5: unknown, 7: good + + # The wizard gets the cluster ids from the Clustering instance + # and the status from ClusterMetadataUpdater. + wizard._get_cluster_ids = lambda: clustering.cluster_ids + + @wizard.set_status_function + def status(cluster): + group = cluster_metadata.group(cluster) + if group <= 1: + return 'ignored' + elif group == 2: + return 'good' + elif group == 3: + return None + raise NotImplementedError() # pragma: no cover + wizard.start() - clustering = Clustering([2, 3, 5, 7]) assert wizard.best_list == [3, 2, 7, 5] wizard.next() wizard.pin() assert wizard.selection == (2, 3) + print(wizard.selection) wizard.on_cluster(clustering.merge([2, 3])) assert wizard.best_list == [8, 7, 5] assert wizard.selection == (8, 7) wizard.on_cluster(clustering.undo()) print(wizard.selection) - print(wizard.best_list) - print(wizard.match_list) + # print(wizard.best_list) + # print(wizard.match_list) diff --git a/phy/cluster/manual/wizard.py b/phy/cluster/manual/wizard.py index 72daaa5a6..71e1f3f75 100644 --- a/phy/cluster/manual/wizard.py +++ b/phy/cluster/manual/wizard.py @@ -9,8 +9,6 @@ import logging from operator import itemgetter -from ...utils import _is_array_like - logger = logging.getLogger(__name__) @@ -83,15 +81,16 @@ def _progress(value, maximum): class Wizard(object): """Propose a selection of high-quality clusters and merge candidates.""" - def __init__(self, cluster_groups=None): - self.cluster_groups = cluster_groups + def __init__(self, get_cluster_ids): + self._get_cluster_ids = get_cluster_ids + self._similarity = None + self._quality = None + self._cluster_status = lambda cluster: None self.reset() def reset(self): self._best_list = [] # This list is fixed (modulo clustering actions). self._match_list = [] # This list may often change. - self._similarity = None - self._quality = None self._best = None self._match = None @@ -99,9 +98,19 @@ def reset(self): def has_started(self): return len(self._best_list) > 0 - # Quality functions + # Quality and status functions #-------------------------------------------------------------------------- + def set_status_function(self, func): + """Register a function returning the status of a cluster: None, + 'ignored', or 'good'. + + Can be used as a decorator. + + """ + self._cluster_status = func + return func + def set_similarity_function(self, func): """Register a function returning the similarity between two clusters. @@ -123,17 +132,14 @@ def set_quality_function(self, func): # Internal methods #-------------------------------------------------------------------------- - def _group(self, cluster): - return self._cluster_groups.get(cluster, None) - - def _in_groups(self, items, groups): + def _with_status(self, items, status): """Filter out ignored clusters or pairs of clusters.""" - if not isinstance(groups, (list, tuple)): - groups = [groups] - return [item for item in items if self._group(item) in groups] + if not isinstance(status, (list, tuple)): + status = [status] + return [item for item in items if self._cluster_status(item) in status] def _is_not_ignored(self, cluster): - return self._in_groups([cluster], (None, 'good')) + return self._with_status([cluster], (None, 'good')) def _check(self): clusters = set(self.cluster_ids) @@ -147,15 +153,15 @@ def _check(self): assert self.best != self.match def _sort(self, items, mix_good_unsorted=False): - """Sort clusters according to their groups: + """Sort clusters according to their status: unsorted, good, and ignored.""" if mix_good_unsorted: - return (self._in_groups(items, (None, 'good')) + - self._in_groups(items, 'ignored')) + return (self._with_status(items, (None, 'good')) + + self._with_status(items, 'ignored')) else: - return (self._in_groups(items, None) + - self._in_groups(items, 'good') + - self._in_groups(items, 'ignored')) + return (self._with_status(items, None) + + self._with_status(items, 'good') + + self._with_status(items, 'ignored')) # Properties #-------------------------------------------------------------------------- @@ -163,24 +169,7 @@ def _sort(self, items, mix_good_unsorted=False): @property def cluster_ids(self): """Array of cluster ids in the current clustering.""" - return sorted(self._cluster_groups) - - @property - def cluster_groups(self): - """Dictionary with the groups of each cluster. - - The groups are: `None` (corresponds to unsorted), `good`, or `ignored`. - - """ - return self._cluster_groups - - @cluster_groups.setter - def cluster_groups(self, cluster_groups): - # cluster_groups is a dictionary or is converted to one. - if _is_array_like(cluster_groups): - # A group can be None (unsorted), `good`, or `ignored`. - cluster_groups = {clu: None for clu in cluster_groups} - self._cluster_groups = cluster_groups + return sorted(self._get_cluster_ids()) # Core methods #-------------------------------------------------------------------------- @@ -284,10 +273,10 @@ def match_list(self): def n_processed(self): """Numbered of processed clusters so far. - A cluster is considered processed if its group is not `None`. + A cluster is considered processed if its status is not `None`. """ - return len(self._in_groups(self._best_list, ('good', 'ignored'))) + return len(self._with_status(self._best_list, ('good', 'ignored'))) @property def n_clusters(self): @@ -397,8 +386,6 @@ def unpin(self): def _delete(self, clusters): for clu in clusters: - if clu in self._cluster_groups: - del self._cluster_groups[clu] if clu in self._best_list: self._best_list.remove(clu) if clu in self._match_list: @@ -408,12 +395,10 @@ def _delete(self, clusters): if clu == self._match: self._match = None - def _add(self, clusters, group, position=None): + def _add(self, clusters, position=None): for clu in clusters: - assert clu not in self._cluster_groups assert clu not in self._best_list assert clu not in self._match_list - self._cluster_groups[clu] = group if self.best is not None: if position is not None: self._best_list.insert(position, clu) @@ -423,13 +408,11 @@ def _add(self, clusters, group, position=None): self._match_list.append(clu) def _update_state(self, up): - # Update the cluster group. + # Update the cluster status. if up.description == 'metadata_group': cluster = up.metadata_changed[0] - group = up.metadata_value - self._cluster_groups[cluster] = group # Reorder the best list, so that the clusters moved in different - # groups go to their right place in the best list. + # status go to their right place in the best list. if (self._best is not None and self._best_list and cluster == self._best): # Find the next best after the cluster has been moved. @@ -443,10 +426,9 @@ def _update_state(self, up): # Add the child at the parent's position. parents = [x for (x, y) in up.descendants if y == clu] parent = parents[0] - group = self._group(parent) position = (self._best_list.index(parent) if self._best_list else None) - self._add([clu], group, position) + self._add([clu], position) # Delete old clusters. self._delete(up.deleted) # Select the last added cluster. From 9e54dfce4d9d7cd6567f571000f762a622db18cf Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 23 Sep 2015 11:22:16 +0200 Subject: [PATCH 0053/1059] WIP: store extra object in undo stack --- phy/cluster/manual/_utils.py | 2 +- phy/cluster/manual/clustering.py | 26 ++++++++++++++++++-------- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/phy/cluster/manual/_utils.py b/phy/cluster/manual/_utils.py index dccfdb85c..52e32ede2 100644 --- a/phy/cluster/manual/_utils.py +++ b/phy/cluster/manual/_utils.py @@ -48,7 +48,7 @@ def __init__(self, **kwargs): metadata_value=None, # new metadata value old_spikes_per_cluster={}, # only for the affected clusters new_spikes_per_cluster={}, # only for the affected clusters - selection=[], # clusters selected before the action + extra=None, # extra object ) d.update(kwargs) super(UpdateInfo, self).__init__(d) diff --git a/phy/cluster/manual/clustering.py b/phy/cluster/manual/clustering.py index f7c261a38..7b029936c 100644 --- a/phy/cluster/manual/clustering.py +++ b/phy/cluster/manual/clustering.py @@ -154,7 +154,7 @@ class Clustering(object): """ def __init__(self, spike_clusters): - self._undo_stack = History(base_item=(None, None)) + self._undo_stack = History(base_item=(None, None, None)) # Spike -> cluster mapping. self._spike_clusters = _as_array(spike_clusters) self._n_spikes = len(self._spike_clusters) @@ -225,7 +225,7 @@ def spikes_in_clusters(self, clusters): # Actions #-------------------------------------------------------------------------- - def merge(self, cluster_ids, to=None): + def merge(self, cluster_ids, to=None, extra=None): """Merge several clusters to a new cluster. Parameters @@ -235,6 +235,8 @@ def merge(self, cluster_ids, to=None): List of clusters to merge. to : integer or None The id of the new cluster. By default, this is `new_cluster_id()`. + extra : object + An object to store in the undo stack. Returns ------- @@ -288,7 +290,7 @@ def merge(self, cluster_ids, to=None): self.spike_clusters[spike_ids] = to # Add to stack. - self._undo_stack.add((spike_ids, [to])) + self._undo_stack.add((spike_ids, [to], extra)) return up @@ -331,7 +333,7 @@ def _do_assign(self, spike_ids, new_spike_clusters): return up - def assign(self, spike_ids, spike_clusters_rel=0): + def assign(self, spike_ids, spike_clusters_rel=0, extra=None): """Make new spike cluster assignements. Parameters @@ -342,6 +344,8 @@ def assign(self, spike_ids, spike_clusters_rel=0): spike_clusters_rel : array-like Relative cluster ids of the spikes in `spike_ids`. This must have the same size as `spike_ids`. + extra : object + An object to store in the undo stack. Returns ------- @@ -401,7 +405,7 @@ def assign(self, spike_ids, spike_clusters_rel=0): up = self._do_assign(spike_ids, cluster_ids) # Add the assignement to the undo stack. - self._undo_stack.add((spike_ids, cluster_ids)) + self._undo_stack.add((spike_ids, cluster_ids, extra)) return up @@ -441,13 +445,13 @@ def undo(self): up : UpdateInfo instance of the changes done by this operation. """ - self._undo_stack.back() + _, _, extra = self._undo_stack.back() # Retrieve the initial spike_cluster structure. spike_clusters_new = self._spike_clusters_base.copy() # Loop over the history (except the last item because we undo). - for spike_ids, cluster_ids in self._undo_stack: + for spike_ids, cluster_ids, _ in self._undo_stack: # We update the spike clusters accordingly. if spike_ids is not None: spike_clusters_new[spike_ids] = cluster_ids @@ -460,6 +464,8 @@ def undo(self): up = self._do_assign(changed, clusters_changed) up.history = 'undo' + # Add the extra object from the undone object. + up.extra = extra return up def redo(self): @@ -477,7 +483,11 @@ def redo(self): # No redo has been performed: abort. return - spike_ids, cluster_ids = item + # NOTE: the extra object of the redone action may not be useful, + # for example when it represents data associated to the state + # *before* the action. What might be more useful would be the + # extra object of the next item in the list (if it exists). + spike_ids, cluster_ids, extra = item assert spike_ids is not None # We apply the new assignement. From 5cb2fe44d4566fd51634e6214f0aed168bb1547d Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 23 Sep 2015 11:32:16 +0200 Subject: [PATCH 0054/1059] Rename EventEmitter.reset() to _reset() --- phy/utils/event.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/phy/utils/event.py b/phy/utils/event.py index 04e799ac9..05f492b10 100644 --- a/phy/utils/event.py +++ b/phy/utils/event.py @@ -44,9 +44,9 @@ def on_my_event(arg, key=None): """ def __init__(self): - self.reset() + self._reset() - def reset(self): + def _reset(self): """Remove all registered callbacks.""" self._callbacks = defaultdict(list) @@ -243,7 +243,6 @@ def increment(self, **kwargs): def reset(self, value_max=None): """Reset the value to 0 and the value max to a given value.""" - super(ProgressReporter, self).reset() self._value = 0 if value_max is not None: self._value_max = value_max From 2108e16d074d08c79c09ad98dd1fd69051ce8d5b Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 23 Sep 2015 11:40:13 +0200 Subject: [PATCH 0055/1059] WIP: add undo_state in history and clustering --- phy/cluster/manual/_utils.py | 3 +- phy/cluster/manual/clustering.py | 53 ++++++++++++--------- phy/cluster/manual/tests/test_clustering.py | 5 +- 3 files changed, 37 insertions(+), 24 deletions(-) diff --git a/phy/cluster/manual/_utils.py b/phy/cluster/manual/_utils.py index 52e32ede2..fa284b24a 100644 --- a/phy/cluster/manual/_utils.py +++ b/phy/cluster/manual/_utils.py @@ -48,7 +48,8 @@ def __init__(self, **kwargs): metadata_value=None, # new metadata value old_spikes_per_cluster={}, # only for the affected clusters new_spikes_per_cluster={}, # only for the affected clusters - extra=None, # extra object + undo_state=None, # returned during an undo: it contains + # information about the undone action ) d.update(kwargs) super(UpdateInfo, self).__init__(d) diff --git a/phy/cluster/manual/clustering.py b/phy/cluster/manual/clustering.py index 7b029936c..d23e71885 100644 --- a/phy/cluster/manual/clustering.py +++ b/phy/cluster/manual/clustering.py @@ -15,6 +15,7 @@ ) from ._utils import UpdateInfo from ._history import History +from phy.utils.event import EventEmitter #------------------------------------------------------------------------------ @@ -98,7 +99,7 @@ def _assign_update_info(spike_ids, return update_info -class Clustering(object): +class Clustering(EventEmitter): """Handle cluster changes in a set of spikes. Features @@ -154,6 +155,7 @@ class Clustering(object): """ def __init__(self, spike_clusters): + super(Clustering, self).__init__() self._undo_stack = History(base_item=(None, None, None)) # Spike -> cluster mapping. self._spike_clusters = _as_array(spike_clusters) @@ -225,7 +227,7 @@ def spikes_in_clusters(self, clusters): # Actions #-------------------------------------------------------------------------- - def merge(self, cluster_ids, to=None, extra=None): + def merge(self, cluster_ids, to=None, undo_state=None): """Merge several clusters to a new cluster. Parameters @@ -235,8 +237,9 @@ def merge(self, cluster_ids, to=None, extra=None): List of clusters to merge. to : integer or None The id of the new cluster. By default, this is `new_cluster_id()`. - extra : object - An object to store in the undo stack. + undo_state : object + An object to store in the undo stack with this action. It will be + returned when that action is undone. Returns ------- @@ -290,7 +293,9 @@ def merge(self, cluster_ids, to=None, extra=None): self.spike_clusters[spike_ids] = to # Add to stack. - self._undo_stack.add((spike_ids, [to], extra)) + self._undo_stack.add((spike_ids, [to], undo_state)) + + self.emit('cluster', up) return up @@ -331,9 +336,11 @@ def _do_assign(self, spike_ids, new_spike_clusters): # We make the assignements. self._spike_clusters[spike_ids] = new_spike_clusters + self.emit('cluster', up) + return up - def assign(self, spike_ids, spike_clusters_rel=0, extra=None): + def assign(self, spike_ids, spike_clusters_rel=0, undo_state=None): """Make new spike cluster assignements. Parameters @@ -344,8 +351,9 @@ def assign(self, spike_ids, spike_clusters_rel=0, extra=None): spike_clusters_rel : array-like Relative cluster ids of the spikes in `spike_ids`. This must have the same size as `spike_ids`. - extra : object - An object to store in the undo stack. + undo_state : object + An object to store in the undo stack with this action. It will be + returned when that action is undone. Returns ------- @@ -405,11 +413,11 @@ def assign(self, spike_ids, spike_clusters_rel=0, extra=None): up = self._do_assign(spike_ids, cluster_ids) # Add the assignement to the undo stack. - self._undo_stack.add((spike_ids, cluster_ids, extra)) + self._undo_stack.add((spike_ids, cluster_ids, undo_state)) return up - def split(self, spike_ids): + def split(self, spike_ids, undo_state=None): """Split a number of spikes into a new cluster. This is equivalent to an `assign()` to a single new cluster. @@ -419,6 +427,9 @@ def split(self, spike_ids): spike_ids : array-like Array of spike ids to merge. + undo_state : object + An object to store in the undo stack with this action. It will be + returned when that action is undone. Returns ------- @@ -434,7 +445,7 @@ def split(self, spike_ids): """ # self.assign() accepts relative numbers as second argument. - return self.assign(spike_ids, 0) + return self.assign(spike_ids, 0, undo_state=undo_state) def undo(self): """Undo the last cluster assignement operation. @@ -445,7 +456,7 @@ def undo(self): up : UpdateInfo instance of the changes done by this operation. """ - _, _, extra = self._undo_stack.back() + _, _, undo_state = self._undo_stack.back() # Retrieve the initial spike_cluster structure. spike_clusters_new = self._spike_clusters_base.copy() @@ -461,11 +472,10 @@ def undo(self): spike_clusters_new)[0] clusters_changed = spike_clusters_new[changed] - up = self._do_assign(changed, - clusters_changed) + up = self._do_assign(changed, clusters_changed) up.history = 'undo' - # Add the extra object from the undone object. - up.extra = extra + # Add the undo_state object from the undone object. + up.undo_state = undo_state return up def redo(self): @@ -483,15 +493,14 @@ def redo(self): # No redo has been performed: abort. return - # NOTE: the extra object of the redone action may not be useful, - # for example when it represents data associated to the state + # NOTE: the undo_state object is only returned when undoing. + # It represents data associated to the state # *before* the action. What might be more useful would be the - # extra object of the next item in the list (if it exists). - spike_ids, cluster_ids, extra = item + # undo_state object of the next item in the list (if it exists). + spike_ids, cluster_ids, undo_state = item assert spike_ids is not None # We apply the new assignement. - up = self._do_assign(spike_ids, - cluster_ids) + up = self._do_assign(spike_ids, cluster_ids) up.history = 'redo' return up diff --git a/phy/cluster/manual/tests/test_clustering.py b/phy/cluster/manual/tests/test_clustering.py index e0ad61c31..1829949bc 100644 --- a/phy/cluster/manual/tests/test_clustering.py +++ b/phy/cluster/manual/tests/test_clustering.py @@ -258,12 +258,13 @@ def _assert_spikes(clusters): _assert_is_checkpoint(1) # Checkpoint 2. - info = clustering.merge([2, 3], 12) + info = clustering.merge([2, 3], 12, undo_state='hello') _checkpoint() _assert_spikes([12]) assert info.added == [12] assert info.deleted == [2, 3] assert info.history is None + assert info.undo_state is None # undo_state is only returned when undoing. _assert_is_checkpoint(2) # Undo once. @@ -271,6 +272,7 @@ def _assert_spikes(clusters): assert info.added == [2, 3] assert info.deleted == [12] assert info.history == 'undo' + assert info.undo_state == 'hello' _assert_is_checkpoint(1) # Redo. @@ -279,6 +281,7 @@ def _assert_spikes(clusters): assert info.added == [12] assert info.deleted == [2, 3] assert info.history == 'redo' + assert info.undo_state is None _assert_is_checkpoint(2) # No redo. From 664c74ac0c06499c52465c6d20e44a328115a59f Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 23 Sep 2015 11:57:44 +0200 Subject: [PATCH 0056/1059] WIP: undo_state --- phy/cluster/manual/clustering.py | 106 +++++++++----------- phy/cluster/manual/tests/test_clustering.py | 14 ++- phy/cluster/manual/tests/test_wizard.py | 13 ++- 3 files changed, 73 insertions(+), 60 deletions(-) diff --git a/phy/cluster/manual/clustering.py b/phy/cluster/manual/clustering.py index d23e71885..5d6169cb1 100644 --- a/phy/cluster/manual/clustering.py +++ b/phy/cluster/manual/clustering.py @@ -227,7 +227,45 @@ def spikes_in_clusters(self, clusters): # Actions #-------------------------------------------------------------------------- - def merge(self, cluster_ids, to=None, undo_state=None): + def _update_all_spikes_per_cluster(self): + self._spikes_per_cluster = _spikes_per_cluster(self._spike_ids, + self._spike_clusters) + + def _do_assign(self, spike_ids, new_spike_clusters): + """Make spike-cluster assignements after the spike selection has + been extended to full clusters.""" + + # Ensure spike_clusters has the right shape. + spike_ids = _as_array(spike_ids) + if len(new_spike_clusters) == 1 and len(spike_ids) > 1: + new_spike_clusters = (np.ones(len(spike_ids), dtype=np.int64) * + new_spike_clusters[0]) + old_spike_clusters = self._spike_clusters[spike_ids] + + assert len(spike_ids) == len(old_spike_clusters) + assert len(new_spike_clusters) == len(spike_ids) + + # Update the spikes per cluster structure. + clusters = _unique(old_spike_clusters) + old_spikes_per_cluster = {cluster: self._spikes_per_cluster[cluster] + for cluster in clusters} + new_spikes_per_cluster = _spikes_per_cluster(spike_ids, + new_spike_clusters) + self._spikes_per_cluster.update(new_spikes_per_cluster) + # All old clusters are deleted. + for cluster in clusters: + del self._spikes_per_cluster[cluster] + + # We return the UpdateInfo structure. + up = _assign_update_info(spike_ids, + old_spike_clusters, old_spikes_per_cluster, + new_spike_clusters, new_spikes_per_cluster) + + # We make the assignements. + self._spike_clusters[spike_ids] = new_spike_clusters + return up + + def merge(self, cluster_ids, to=None): """Merge several clusters to a new cluster. Parameters @@ -237,9 +275,6 @@ def merge(self, cluster_ids, to=None, undo_state=None): List of clusters to merge. to : integer or None The id of the new cluster. By default, this is `new_cluster_id()`. - undo_state : object - An object to store in the undo stack with this action. It will be - returned when that action is undone. Returns ------- @@ -283,6 +318,7 @@ def merge(self, cluster_ids, to=None, undo_state=None): old_spikes_per_cluster=old_spc, new_spikes_per_cluster=new_spc, ) + undo_state = self.emit('request_undo_state', up) # Update the spikes_per_cluster structure directly. self._spikes_per_cluster[to] = spike_ids @@ -296,51 +332,9 @@ def merge(self, cluster_ids, to=None, undo_state=None): self._undo_stack.add((spike_ids, [to], undo_state)) self.emit('cluster', up) - return up - def _update_all_spikes_per_cluster(self): - self._spikes_per_cluster = _spikes_per_cluster(self._spike_ids, - self._spike_clusters) - - def _do_assign(self, spike_ids, new_spike_clusters): - """Make spike-cluster assignements after the spike selection has - been extended to full clusters.""" - - # Ensure spike_clusters has the right shape. - spike_ids = _as_array(spike_ids) - if len(new_spike_clusters) == 1 and len(spike_ids) > 1: - new_spike_clusters = (np.ones(len(spike_ids), dtype=np.int64) * - new_spike_clusters[0]) - old_spike_clusters = self._spike_clusters[spike_ids] - - assert len(spike_ids) == len(old_spike_clusters) - assert len(new_spike_clusters) == len(spike_ids) - - # Update the spikes per cluster structure. - clusters = _unique(old_spike_clusters) - old_spikes_per_cluster = {cluster: self._spikes_per_cluster[cluster] - for cluster in clusters} - new_spikes_per_cluster = _spikes_per_cluster(spike_ids, - new_spike_clusters) - self._spikes_per_cluster.update(new_spikes_per_cluster) - # All old clusters are deleted. - for cluster in clusters: - del self._spikes_per_cluster[cluster] - - # We return the UpdateInfo structure. - up = _assign_update_info(spike_ids, - old_spike_clusters, old_spikes_per_cluster, - new_spike_clusters, new_spikes_per_cluster) - - # We make the assignements. - self._spike_clusters[spike_ids] = new_spike_clusters - - self.emit('cluster', up) - - return up - - def assign(self, spike_ids, spike_clusters_rel=0, undo_state=None): + def assign(self, spike_ids, spike_clusters_rel=0): """Make new spike cluster assignements. Parameters @@ -351,9 +345,6 @@ def assign(self, spike_ids, spike_clusters_rel=0, undo_state=None): spike_clusters_rel : array-like Relative cluster ids of the spikes in `spike_ids`. This must have the same size as `spike_ids`. - undo_state : object - An object to store in the undo stack with this action. It will be - returned when that action is undone. Returns ------- @@ -411,13 +402,15 @@ def assign(self, spike_ids, spike_clusters_rel=0, undo_state=None): spike_clusters_rel) up = self._do_assign(spike_ids, cluster_ids) + undo_state = self.emit('request_undo_state', up) # Add the assignement to the undo stack. self._undo_stack.add((spike_ids, cluster_ids, undo_state)) + self.emit('cluster', up) return up - def split(self, spike_ids, undo_state=None): + def split(self, spike_ids): """Split a number of spikes into a new cluster. This is equivalent to an `assign()` to a single new cluster. @@ -427,9 +420,6 @@ def split(self, spike_ids, undo_state=None): spike_ids : array-like Array of spike ids to merge. - undo_state : object - An object to store in the undo stack with this action. It will be - returned when that action is undone. Returns ------- @@ -445,7 +435,7 @@ def split(self, spike_ids, undo_state=None): """ # self.assign() accepts relative numbers as second argument. - return self.assign(spike_ids, 0, undo_state=undo_state) + return self.assign(spike_ids, 0) def undo(self): """Undo the last cluster assignement operation. @@ -476,6 +466,8 @@ def undo(self): up.history = 'undo' # Add the undo_state object from the undone object. up.undo_state = undo_state + + self.emit('cluster', up) return up def redo(self): @@ -500,7 +492,9 @@ def redo(self): spike_ids, cluster_ids, undo_state = item assert spike_ids is not None - # We apply the new assignement. + # We apply the new assignment. up = self._do_assign(spike_ids, cluster_ids) up.history = 'redo' + + self.emit('cluster', up) return up diff --git a/phy/cluster/manual/tests/test_clustering.py b/phy/cluster/manual/tests/test_clustering.py index 1829949bc..abfa99ea6 100644 --- a/phy/cluster/manual/tests/test_clustering.py +++ b/phy/cluster/manual/tests/test_clustering.py @@ -245,6 +245,10 @@ def _assert_is_checkpoint(index): def _assert_spikes(clusters): ae(info.spike_ids, _spikes_in_clusters(spike_clusters, clusters)) + @clustering.connect + def on_request_undo_state(up): + return 'hello' + # Checkpoint 0. _checkpoint() _assert_is_checkpoint(0) @@ -258,7 +262,7 @@ def _assert_spikes(clusters): _assert_is_checkpoint(1) # Checkpoint 2. - info = clustering.merge([2, 3], 12, undo_state='hello') + info = clustering.merge([2, 3], 12) _checkpoint() _assert_spikes([12]) assert info.added == [12] @@ -272,7 +276,7 @@ def _assert_spikes(clusters): assert info.added == [2, 3] assert info.deleted == [12] assert info.history == 'undo' - assert info.undo_state == 'hello' + assert info.undo_state == ['hello'] _assert_is_checkpoint(1) # Redo. @@ -357,6 +361,10 @@ def _checkpoint(index=None): def _assert_is_checkpoint(index): ae(clustering.spike_clusters, checkpoints[index]) + @clustering.connect + def on_request_undo_state(up): + return 'hello' + # Checkpoint 0. _checkpoint() _assert_is_checkpoint(0) @@ -387,12 +395,14 @@ def _assert_is_checkpoint(index): # Checkpoint 3. info = clustering.assign(my_spikes_3) assert info.history is None + assert info.undo_state is None _checkpoint() _assert_is_checkpoint(3) # Undo checkpoint 3. info = clustering.undo() assert info.history == 'undo' + assert info.undo_state == ['hello'] _checkpoint() _assert_is_checkpoint(2) diff --git a/phy/cluster/manual/tests/test_wizard.py b/phy/cluster/manual/tests/test_wizard.py index fdb216b30..d52efd89a 100644 --- a/phy/cluster/manual/tests/test_wizard.py +++ b/phy/cluster/manual/tests/test_wizard.py @@ -220,6 +220,14 @@ def status(cluster): return None raise NotImplementedError() # pragma: no cover + @clustering.connect + def on_request_undo_state(up): + return {'selection': wizard.selection} + + @clustering.connect + def on_cluster(up): + wizard.on_cluster(up) + wizard.start() assert wizard.best_list == [3, 2, 7, 5] @@ -228,11 +236,12 @@ def status(cluster): assert wizard.selection == (2, 3) print(wizard.selection) - wizard.on_cluster(clustering.merge([2, 3])) + # Save the selection before the merge in the undo stack. + clustering.merge([2, 3]) assert wizard.best_list == [8, 7, 5] assert wizard.selection == (8, 7) - wizard.on_cluster(clustering.undo()) + clustering.undo() print(wizard.selection) # print(wizard.best_list) # print(wizard.match_list) From a3dbeeed99b15e4a566dc96897145b31ba736738 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 23 Sep 2015 13:24:18 +0200 Subject: [PATCH 0057/1059] WIP: improve wizard --- phy/cluster/manual/tests/test_wizard.py | 38 ++------- phy/cluster/manual/wizard.py | 107 ++++++++++++++++++------ 2 files changed, 89 insertions(+), 56 deletions(-) diff --git a/phy/cluster/manual/tests/test_wizard.py b/phy/cluster/manual/tests/test_wizard.py index d52efd89a..229ec9fea 100644 --- a/phy/cluster/manual/tests/test_wizard.py +++ b/phy/cluster/manual/tests/test_wizard.py @@ -26,7 +26,8 @@ def wizard(): def get_cluster_ids(): return [2, 3, 5, 7] - wizard = Wizard(get_cluster_ids) + wizard = Wizard() + wizard.set_cluster_ids_function(get_cluster_ids) @wizard.set_status_function def cluster_status(cluster): @@ -99,7 +100,8 @@ def func(x): def test_wizard_core(): - wizard = Wizard(lambda: [2, 3, 5]) + wizard = Wizard() + wizard.set_cluster_ids_function(lambda: [2, 3, 5]) @wizard.set_quality_function def quality(cluster): @@ -191,9 +193,9 @@ def test_wizard_nav(wizard): assert wizard.match == 2 wizard.first() - assert wizard.selection == (3, 2) + assert wizard.selection == [3, 2] wizard.last() - assert wizard.selection == (3, 5) + assert wizard.selection == [3, 5] wizard.unpin() assert wizard.best == 3 @@ -204,42 +206,20 @@ def test_wizard_nav(wizard): def test_wizard_update(wizard, clustering, cluster_metadata): # 2: none, 3: none, 5: unknown, 7: good - - # The wizard gets the cluster ids from the Clustering instance - # and the status from ClusterMetadataUpdater. - wizard._get_cluster_ids = lambda: clustering.cluster_ids - - @wizard.set_status_function - def status(cluster): - group = cluster_metadata.group(cluster) - if group <= 1: - return 'ignored' - elif group == 2: - return 'good' - elif group == 3: - return None - raise NotImplementedError() # pragma: no cover - - @clustering.connect - def on_request_undo_state(up): - return {'selection': wizard.selection} - - @clustering.connect - def on_cluster(up): - wizard.on_cluster(up) + wizard.attach(clustering, cluster_metadata) wizard.start() assert wizard.best_list == [3, 2, 7, 5] wizard.next() wizard.pin() - assert wizard.selection == (2, 3) + assert wizard.selection == [2, 3] print(wizard.selection) # Save the selection before the merge in the undo stack. clustering.merge([2, 3]) assert wizard.best_list == [8, 7, 5] - assert wizard.selection == (8, 7) + assert wizard.selection == [8, 7] clustering.undo() print(wizard.selection) diff --git a/phy/cluster/manual/wizard.py b/phy/cluster/manual/wizard.py index 71e1f3f75..70335295f 100644 --- a/phy/cluster/manual/wizard.py +++ b/phy/cluster/manual/wizard.py @@ -81,14 +81,15 @@ def _progress(value, maximum): class Wizard(object): """Propose a selection of high-quality clusters and merge candidates.""" - def __init__(self, get_cluster_ids): - self._get_cluster_ids = get_cluster_ids + def __init__(self): self._similarity = None self._quality = None + self._get_cluster_ids = None self._cluster_status = lambda cluster: None self.reset() def reset(self): + self._selection = [] self._best_list = [] # This list is fixed (modulo clustering actions). self._match_list = [] # This list may often change. self._best = None @@ -101,6 +102,10 @@ def has_started(self): # Quality and status functions #-------------------------------------------------------------------------- + def set_cluster_ids_function(self, func): + """Register a function giving the list of cluster ids.""" + self._get_cluster_ids = func + def set_status_function(self, func): """Register a function returning the status of a cluster: None, 'ignored', or 'good'. @@ -233,31 +238,39 @@ def best(self): def best(self, value): assert value in self._best_list self._best = value + self._selection = [value] @property def match(self): """Currently-selected closest match.""" return self._match - @property - def selection(self): - """Return the current best/match cluster selection.""" - b, m = self.best, self.match - if b is None: - return () - elif m is None: - return (b,) - else: - if b == m: - return (b,) - else: - return (b, m) - @match.setter def match(self, value): if value is not None: assert value in self._match_list self._match = value + if len(self._selection) == 1: + self._selection += [value] + elif len(self._selection) == 2: + self._selection[1] = value + + @property + def selection(self): + """Return the current best/match cluster selection.""" + return self._selection + + @selection.setter + def selection(self, value): + """Return the current best/match cluster selection.""" + assert isinstance(value, (tuple, list)) + clusters = self.cluster_ids + value = [cluster for cluster in value if cluster in clusters] + self._selection = value + if len(self._selection) >= 1: + self._best = self._selection[0] + if len(self._selection) >= 2: + self._match = self._selection[1] @property def best_list(self): @@ -415,12 +428,12 @@ def _update_state(self, up): # status go to their right place in the best list. if (self._best is not None and self._best_list and cluster == self._best): - # Find the next best after the cluster has been moved. - next_best = _next(self._best_list, self._best) + # # Find the next best after the cluster has been moved. + # next_best = _next(self._best_list, self._best) # Reorder the list. self._best_list = self._sort(self._best_list) - # Select the next best. - self._best = next_best + # # Select the next best. + # self._best = next_best # Update the wizard with new and old clusters. for clu in up.added: # Add the child at the parent's position. @@ -431,12 +444,52 @@ def _update_state(self, up): self._add([clu], position) # Delete old clusters. self._delete(up.deleted) - # Select the last added cluster. - if self.best is not None and up.added: - self.pin(up.added[-1]) + # # Select the last added cluster. + # if self.best is not None and up.added: + # self._best = up.added[-1] - def on_cluster(self, up): - if self._has_finished: + def _select_after_update(self, up): + if up.history == 'undo': + self.selection = up.undo_state[0]['selection'] return - if self._best_list or self._match_list: - self._update_state(up) + # Make as few updates as possible in the views after clustering + # actions. This allows for better before/after comparisons. + if up.added: + self.selection = up.added + elif up.description == 'metadata_group': + cluster = up.metadata_changed[0] + if cluster == self.best: + self.next_best() + elif cluster == self.match: + self.next_match() + + def attach(self, clustering, cluster_metadata): + # TODO: might be better in an independent function in another module + + # The wizard gets the cluster ids from the Clustering instance + # and the status from ClusterMetadataUpdater. + self.set_cluster_ids_function(lambda: clustering.cluster_ids) + + @self.set_status_function + def status(cluster): + group = cluster_metadata.group(cluster) + if group <= 1: + return 'ignored' + elif group == 2: + return 'good' + elif group == 3: + return None + raise NotImplementedError() # pragma: no cover + + @clustering.connect + def on_request_undo_state(up): + return {'selection': self.selection} + + @clustering.connect + def on_cluster(up): + if self._has_finished: + return + if self._best_list or self._match_list: + self._update_state(up) + if self._best is not None or self._match is not None: + self._select_after_update(up) From ec3cb648529e296b829c5904e5de96dfcd400092 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 23 Sep 2015 13:28:41 +0200 Subject: [PATCH 0058/1059] cluster.manual tests pass --- phy/cluster/manual/tests/test_wizard.py | 9 +++++++-- phy/cluster/manual/wizard.py | 4 +++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/phy/cluster/manual/tests/test_wizard.py b/phy/cluster/manual/tests/test_wizard.py index 229ec9fea..468edfd40 100644 --- a/phy/cluster/manual/tests/test_wizard.py +++ b/phy/cluster/manual/tests/test_wizard.py @@ -213,15 +213,20 @@ def test_wizard_update(wizard, clustering, cluster_metadata): assert wizard.best_list == [3, 2, 7, 5] wizard.next() wizard.pin() + assert wizard.selection == [2, 3] + assert wizard.best == 2 + assert wizard.match == 3 - print(wizard.selection) # Save the selection before the merge in the undo stack. clustering.merge([2, 3]) assert wizard.best_list == [8, 7, 5] assert wizard.selection == [8, 7] + # Undo. clustering.undo() - print(wizard.selection) + assert wizard.selection == [2, 3] + assert wizard.best == 2 + assert wizard.match == 3 # print(wizard.best_list) # print(wizard.match_list) diff --git a/phy/cluster/manual/wizard.py b/phy/cluster/manual/wizard.py index 70335295f..31b53e44e 100644 --- a/phy/cluster/manual/wizard.py +++ b/phy/cluster/manual/wizard.py @@ -456,7 +456,9 @@ def _select_after_update(self, up): # actions. This allows for better before/after comparisons. if up.added: self.selection = up.added - elif up.description == 'metadata_group': + if up.description == 'merge': + self.pin(up.added[0]) + if up.description == 'metadata_group': cluster = up.metadata_changed[0] if cluster == self.best: self.next_best() From 5a8ff72c1d2fc94cdaeb7a68de95cbbde24585ba Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 23 Sep 2015 13:36:02 +0200 Subject: [PATCH 0059/1059] WIP: improve wizard --- phy/cluster/manual/tests/test_wizard.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/phy/cluster/manual/tests/test_wizard.py b/phy/cluster/manual/tests/test_wizard.py index 468edfd40..4216ee0ae 100644 --- a/phy/cluster/manual/tests/test_wizard.py +++ b/phy/cluster/manual/tests/test_wizard.py @@ -210,23 +210,29 @@ def test_wizard_update(wizard, clustering, cluster_metadata): wizard.start() + def _check_best_match(b, m): + assert wizard.selection == [b, m] + assert wizard.best == b + assert wizard.match == m + assert wizard.best_list == [3, 2, 7, 5] wizard.next() wizard.pin() - assert wizard.selection == [2, 3] - assert wizard.best == 2 - assert wizard.match == 3 + _check_best_match(2, 3) # Save the selection before the merge in the undo stack. clustering.merge([2, 3]) assert wizard.best_list == [8, 7, 5] - assert wizard.selection == [8, 7] + _check_best_match(8, 7) # Undo. clustering.undo() - assert wizard.selection == [2, 3] - assert wizard.best == 2 - assert wizard.match == 3 - # print(wizard.best_list) - # print(wizard.match_list) + _check_best_match(2, 3) + + wizard.selection = [1, 5, 7, 8] + _check_best_match(5, 7) + + # Redo. + clustering.redo() + # _check_best_match(8, 7) From 7d0bf3819560584d2e8f920ac5d2adcc2ee66803 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 23 Sep 2015 13:40:23 +0200 Subject: [PATCH 0060/1059] Fix assignment typo --- phy/cluster/algorithms/klustakwik.py | 2 +- phy/cluster/manual/clustering.py | 24 ++++++++++----------- phy/cluster/manual/tests/test_clustering.py | 10 ++++----- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/phy/cluster/algorithms/klustakwik.py b/phy/cluster/algorithms/klustakwik.py index 4233e8e12..c61a13897 100644 --- a/phy/cluster/algorithms/klustakwik.py +++ b/phy/cluster/algorithms/klustakwik.py @@ -98,7 +98,7 @@ def cluster(self, """Run the clustering algorithm on the model, or on any features and masks. - Return the `spike_clusters` assignements. + Return the `spike_clusters` assignments. Emit the `iter` event at every KlustaKwik iteration. diff --git a/phy/cluster/manual/clustering.py b/phy/cluster/manual/clustering.py index 5d6169cb1..cf88b37e2 100644 --- a/phy/cluster/manual/clustering.py +++ b/phy/cluster/manual/clustering.py @@ -26,7 +26,7 @@ def _extend_spikes(spike_ids, spike_clusters): """Return all spikes belonging to the clusters containing the specified spikes.""" # We find the spikes belonging to modified clusters. - # What are the old clusters that are modified by the assignement? + # What are the old clusters that are modified by the assignment? old_spike_clusters = spike_clusters[spike_ids] unique_clusters = _unique(old_spike_clusters) # Now we take all spikes from these clusters. @@ -47,7 +47,7 @@ def _concatenate_spike_clusters(*pairs): return concat[:, 0].astype(np.int64), concat[:, 1].astype(np.int64) -def _extend_assignement(spike_ids, old_spike_clusters, spike_clusters_rel): +def _extend_assignment(spike_ids, old_spike_clusters, spike_clusters_rel): # 1. Add spikes that belong to modified clusters. # 2. Find new cluster ids for all changed clusters. @@ -163,7 +163,7 @@ def __init__(self, spike_clusters): self._spike_ids = np.arange(self._n_spikes).astype(np.int64) # Create the spikes per cluster structure. self._update_all_spikes_per_cluster() - # Keep a copy of the original spike clusters assignement. + # Keep a copy of the original spike clusters assignment. self._spike_clusters_base = self._spike_clusters.copy() def reset(self): @@ -232,7 +232,7 @@ def _update_all_spikes_per_cluster(self): self._spike_clusters) def _do_assign(self, spike_ids, new_spike_clusters): - """Make spike-cluster assignements after the spike selection has + """Make spike-cluster assignments after the spike selection has been extended to full clusters.""" # Ensure spike_clusters has the right shape. @@ -261,7 +261,7 @@ def _do_assign(self, spike_ids, new_spike_clusters): old_spike_clusters, old_spikes_per_cluster, new_spike_clusters, new_spikes_per_cluster) - # We make the assignements. + # We make the assignments. self._spike_clusters[spike_ids] = new_spike_clusters return up @@ -335,7 +335,7 @@ def merge(self, cluster_ids, to=None): return up def assign(self, spike_ids, spike_clusters_rel=0): - """Make new spike cluster assignements. + """Make new spike cluster assignments. Parameters ---------- @@ -392,19 +392,19 @@ def assign(self, spike_ids, spike_clusters_rel=0): assert spike_ids.min() >= 0 assert spike_ids.max() < self._n_spikes - # Normalize the spike-cluster assignement such that + # Normalize the spike-cluster assignment such that # there are only new or dead clusters, not modified clusters. # This implies that spikes not explicitely selected, but that # belong to clusters affected by the operation, will be assigned # to brand new clusters. - spike_ids, cluster_ids = _extend_assignement(spike_ids, + spike_ids, cluster_ids = _extend_assignment(spike_ids, self._spike_clusters, spike_clusters_rel) up = self._do_assign(spike_ids, cluster_ids) undo_state = self.emit('request_undo_state', up) - # Add the assignement to the undo stack. + # Add the assignment to the undo stack. self._undo_stack.add((spike_ids, cluster_ids, undo_state)) self.emit('cluster', up) @@ -438,7 +438,7 @@ def split(self, spike_ids): return self.assign(spike_ids, 0) def undo(self): - """Undo the last cluster assignement operation. + """Undo the last cluster assignment operation. Returns ------- @@ -471,7 +471,7 @@ def undo(self): return up def redo(self): - """Redo the last cluster assignement operation. + """Redo the last cluster assignment operation. Returns ------- @@ -479,7 +479,7 @@ def redo(self): up : UpdateInfo instance of the changes done by this operation. """ - # Go forward in the stack, and retrieve the new assignement. + # Go forward in the stack, and retrieve the new assignment. item = self._undo_stack.forward() if item is None: # No redo has been performed: abort. diff --git a/phy/cluster/manual/tests/test_clustering.py b/phy/cluster/manual/tests/test_clustering.py index abfa99ea6..e96bcb6e6 100644 --- a/phy/cluster/manual/tests/test_clustering.py +++ b/phy/cluster/manual/tests/test_clustering.py @@ -17,12 +17,12 @@ ) from ..clustering import (_extend_spikes, _concatenate_spike_clusters, - _extend_assignement, + _extend_assignment, Clustering) #------------------------------------------------------------------------------ -# Test assignements +# Test assignments #------------------------------------------------------------------------------ def test_extend_spikes_simple(): @@ -72,7 +72,7 @@ def test_concatenate_spike_clusters(): ae(clusters, np.arange(0, 60 + 1, 10)) -def test_extend_assignement(): +def test_extend_assignment(): spike_clusters = np.array([3, 5, 2, 9, 5, 5, 2]) spike_ids = np.array([0, 2]) @@ -84,7 +84,7 @@ def test_extend_assignement(): # This should not depend on the index chosen. for to in (123, 0, 1, 2, 3): clusters_rel = [123] * len(spike_ids) - new_spike_ids, new_cluster_ids = _extend_assignement(spike_ids, + new_spike_ids, new_cluster_ids = _extend_assignment(spike_ids, spike_clusters, clusters_rel) ae(new_spike_ids, [0, 2, 6]) @@ -92,7 +92,7 @@ def test_extend_assignement(): # Second case: we assign the spikes to different clusters. clusters_rel = [0, 1] - new_spike_ids, new_cluster_ids = _extend_assignement(spike_ids, + new_spike_ids, new_cluster_ids = _extend_assignment(spike_ids, spike_clusters, clusters_rel) ae(new_spike_ids, [0, 2, 6]) From 5e5fdb0dc9fc6c45cee58a386e3b7e1116ffe55c Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 23 Sep 2015 13:46:09 +0200 Subject: [PATCH 0061/1059] Refactor merge in Clustering --- phy/cluster/manual/clustering.py | 45 ++++++++++++++++++-------------- 1 file changed, 25 insertions(+), 20 deletions(-) diff --git a/phy/cluster/manual/clustering.py b/phy/cluster/manual/clustering.py index cf88b37e2..234686c4d 100644 --- a/phy/cluster/manual/clustering.py +++ b/phy/cluster/manual/clustering.py @@ -265,6 +265,30 @@ def _do_assign(self, spike_ids, new_spike_clusters): self._spike_clusters[spike_ids] = new_spike_clusters return up + def _do_merge(self, spike_ids, cluster_ids, to): + + # Create the UpdateInfo instance here. + descendants = [(cluster, to) for cluster in cluster_ids] + old_spc = {k: self._spikes_per_cluster[k] for k in cluster_ids} + new_spc = {to: spike_ids} + up = UpdateInfo(description='merge', + spike_ids=spike_ids, + added=[to], + deleted=cluster_ids, + descendants=descendants, + old_spikes_per_cluster=old_spc, + new_spikes_per_cluster=new_spc, + ) + + # Update the spikes_per_cluster structure directly. + self._spikes_per_cluster[to] = spike_ids + for cluster in cluster_ids: + del self._spikes_per_cluster[cluster] + + # Assign the clusters. + self.spike_clusters[spike_ids] = to + return up + def merge(self, cluster_ids, to=None): """Merge several clusters to a new cluster. @@ -306,28 +330,9 @@ def merge(self, cluster_ids, to=None): # Find all spikes in the specified clusters. spike_ids = _spikes_in_clusters(self.spike_clusters, cluster_ids) - # Create the UpdateInfo instance here. - descendants = [(cluster, to) for cluster in cluster_ids] - old_spc = {k: self._spikes_per_cluster[k] for k in cluster_ids} - new_spc = {to: spike_ids} - up = UpdateInfo(description='merge', - spike_ids=spike_ids, - added=[to], - deleted=cluster_ids, - descendants=descendants, - old_spikes_per_cluster=old_spc, - new_spikes_per_cluster=new_spc, - ) + up = self._do_merge(spike_ids, cluster_ids, to) undo_state = self.emit('request_undo_state', up) - # Update the spikes_per_cluster structure directly. - self._spikes_per_cluster[to] = spike_ids - for cluster in cluster_ids: - del self._spikes_per_cluster[cluster] - - # Assign the clusters. - self.spike_clusters[spike_ids] = to - # Add to stack. self._undo_stack.add((spike_ids, [to], undo_state)) From f7591821ee6726ef598e919816e9b015ac69510b Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 23 Sep 2015 13:47:48 +0200 Subject: [PATCH 0062/1059] WIP: refactor --- phy/cluster/manual/clustering.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/phy/cluster/manual/clustering.py b/phy/cluster/manual/clustering.py index 234686c4d..a010d4409 100644 --- a/phy/cluster/manual/clustering.py +++ b/phy/cluster/manual/clustering.py @@ -246,14 +246,14 @@ def _do_assign(self, spike_ids, new_spike_clusters): assert len(new_spike_clusters) == len(spike_ids) # Update the spikes per cluster structure. - clusters = _unique(old_spike_clusters) + old_clusters = _unique(old_spike_clusters) old_spikes_per_cluster = {cluster: self._spikes_per_cluster[cluster] - for cluster in clusters} + for cluster in old_clusters} new_spikes_per_cluster = _spikes_per_cluster(spike_ids, new_spike_clusters) self._spikes_per_cluster.update(new_spikes_per_cluster) # All old clusters are deleted. - for cluster in clusters: + for cluster in old_clusters: del self._spikes_per_cluster[cluster] # We return the UpdateInfo structure. @@ -403,8 +403,8 @@ def assign(self, spike_ids, spike_clusters_rel=0): # belong to clusters affected by the operation, will be assigned # to brand new clusters. spike_ids, cluster_ids = _extend_assignment(spike_ids, - self._spike_clusters, - spike_clusters_rel) + self._spike_clusters, + spike_clusters_rel) up = self._do_assign(spike_ids, cluster_ids) undo_state = self.emit('request_undo_state', up) From f58d23773932658e36ee5f47daeb8b173b6f324a Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 23 Sep 2015 13:52:43 +0200 Subject: [PATCH 0063/1059] Shortcut to merge in assign --- phy/cluster/manual/clustering.py | 11 ++++++++++- phy/cluster/manual/tests/test_clustering.py | 5 +++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/phy/cluster/manual/clustering.py b/phy/cluster/manual/clustering.py index a010d4409..c3fc4fb60 100644 --- a/phy/cluster/manual/clustering.py +++ b/phy/cluster/manual/clustering.py @@ -247,6 +247,15 @@ def _do_assign(self, spike_ids, new_spike_clusters): # Update the spikes per cluster structure. old_clusters = _unique(old_spike_clusters) + + # NOTE: shortcut to a merge if this assignment is effectively a merge + # i.e. if all spikes are assigned to a single cluster. + # The fact that spike selection has been previously extended to + # whole clusters is critical here. + new_clusters = _unique(new_spike_clusters) + if len(new_clusters) == 1: + return self._do_merge(spike_ids, old_clusters, new_clusters[0]) + old_spikes_per_cluster = {cluster: self._spikes_per_cluster[cluster] for cluster in old_clusters} new_spikes_per_cluster = _spikes_per_cluster(spike_ids, @@ -274,7 +283,7 @@ def _do_merge(self, spike_ids, cluster_ids, to): up = UpdateInfo(description='merge', spike_ids=spike_ids, added=[to], - deleted=cluster_ids, + deleted=list(cluster_ids), descendants=descendants, old_spikes_per_cluster=old_spc, new_spikes_per_cluster=new_spc, diff --git a/phy/cluster/manual/tests/test_clustering.py b/phy/cluster/manual/tests/test_clustering.py index e96bcb6e6..0b4be635a 100644 --- a/phy/cluster/manual/tests/test_clustering.py +++ b/phy/cluster/manual/tests/test_clustering.py @@ -382,18 +382,21 @@ def on_request_undo_state(up): # Checkpoint 1. info = clustering.split(my_spikes_1) _checkpoint() + assert info.description == 'assign' assert 10 in info.added assert info.history is None _assert_is_checkpoint(1) # Checkpoint 2. info = clustering.split(my_spikes_2) + assert info.description == 'assign' assert info.history is None _checkpoint() _assert_is_checkpoint(2) # Checkpoint 3. info = clustering.assign(my_spikes_3) + assert info.description == 'assign' assert info.history is None assert info.undo_state is None _checkpoint() @@ -401,6 +404,7 @@ def on_request_undo_state(up): # Undo checkpoint 3. info = clustering.undo() + assert info.description == 'assign' assert info.history == 'undo' assert info.undo_state == ['hello'] _checkpoint() @@ -408,6 +412,7 @@ def on_request_undo_state(up): # Checkpoint 4. info = clustering.assign(my_spikes_4) + assert info.description == 'assign' assert info.history is None _checkpoint(4) assert len(info.deleted) >= 2 From 77ec1bf8990a4319b4ff5ff010316857d2bf3911 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 23 Sep 2015 13:57:26 +0200 Subject: [PATCH 0064/1059] Fix --- phy/cluster/manual/tests/test_wizard.py | 2 +- phy/cluster/manual/wizard.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/phy/cluster/manual/tests/test_wizard.py b/phy/cluster/manual/tests/test_wizard.py index 4216ee0ae..fce8afbc3 100644 --- a/phy/cluster/manual/tests/test_wizard.py +++ b/phy/cluster/manual/tests/test_wizard.py @@ -235,4 +235,4 @@ def _check_best_match(b, m): # Redo. clustering.redo() - # _check_best_match(8, 7) + _check_best_match(8, 7) diff --git a/phy/cluster/manual/wizard.py b/phy/cluster/manual/wizard.py index 31b53e44e..2b60b1266 100644 --- a/phy/cluster/manual/wizard.py +++ b/phy/cluster/manual/wizard.py @@ -382,8 +382,7 @@ def pin(self, cluster=None): return if cluster is None: cluster = self.best - if self.match is not None and self.best == cluster: - return + logger.debug("Pin %d.", cluster) self.best = cluster self._set_match_list(cluster) self._check() @@ -391,6 +390,7 @@ def pin(self, cluster=None): def unpin(self): """Unpin the current cluster.""" if self.match is not None: + logger.debug("Unpin.") self.match = None self._match_list = [] From 4c0fae6798957765bab0ea02f7b58e930e1101a9 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 23 Sep 2015 14:26:33 +0200 Subject: [PATCH 0065/1059] More wizard tests --- phy/cluster/manual/tests/test_wizard.py | 49 ++++++++++++++++++++++++- phy/cluster/manual/wizard.py | 3 ++ 2 files changed, 50 insertions(+), 2 deletions(-) diff --git a/phy/cluster/manual/tests/test_wizard.py b/phy/cluster/manual/tests/test_wizard.py index fce8afbc3..af685b7b4 100644 --- a/phy/cluster/manual/tests/test_wizard.py +++ b/phy/cluster/manual/tests/test_wizard.py @@ -7,6 +7,7 @@ #------------------------------------------------------------------------------ from pytest import yield_fixture +from numpy.testing import assert_array_equal as ae from ..clustering import Clustering from .._utils import ClusterMetadata, ClusterMetadataUpdater @@ -221,18 +222,62 @@ def _check_best_match(b, m): _check_best_match(2, 3) + ################################ + # Save the selection before the merge in the undo stack. clustering.merge([2, 3]) assert wizard.best_list == [8, 7, 5] _check_best_match(8, 7) - # Undo. + # Undo merge. clustering.undo() _check_best_match(2, 3) wizard.selection = [1, 5, 7, 8] _check_best_match(5, 7) - # Redo. + # Redo merge. clustering.redo() _check_best_match(8, 7) + + ################################ + + # Split. + ae(clustering.spike_clusters, [8, 8, 5, 7]) + clustering.split([1, 2]) + ae(clustering.spike_clusters, [10, 9, 9, 7]) + _check_best_match(9, 10) + + # Undo split. + up = clustering.undo() + _check_best_match(8, 7) + assert up.description == 'assign' + assert up.history == 'undo' + + # Redo split. + up = clustering.redo() + _check_best_match(9, 10) + assert up.description == 'assign' + assert up.history == 'redo' + + ################################ + + # Split (=merge). + ae(clustering.spike_clusters, [10, 9, 9, 7]) + up = clustering.split([1, 2]) + ae(clustering.spike_clusters, [10, 11, 11, 7]) + _check_best_match(11, 7) + assert up.description == 'merge' + assert up.history is None + + # Undo split (=merge). + up = clustering.undo() + _check_best_match(9, 10) + assert up.description == 'merge' + assert up.history == 'undo' + + # Redo split (=merge). + up = clustering.redo() + _check_best_match(11, 7) + assert up.description == 'merge' + assert up.history == 'redo' diff --git a/phy/cluster/manual/wizard.py b/phy/cluster/manual/wizard.py index 2b60b1266..0080658cd 100644 --- a/phy/cluster/manual/wizard.py +++ b/phy/cluster/manual/wizard.py @@ -179,6 +179,9 @@ def cluster_ids(self): # Core methods #-------------------------------------------------------------------------- + def cluster_status(self, cluster): + return self._cluster_status(cluster) + def best_clusters(self, n_max=None, quality=None): """Return the list of best clusters sorted by decreasing quality. From ae1ce6c8b71d85b1c9d75d7569f8e1d00f52ca01 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 23 Sep 2015 14:30:59 +0200 Subject: [PATCH 0066/1059] WIP: wizard tests --- phy/cluster/manual/tests/test_wizard.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/phy/cluster/manual/tests/test_wizard.py b/phy/cluster/manual/tests/test_wizard.py index af685b7b4..422231cd0 100644 --- a/phy/cluster/manual/tests/test_wizard.py +++ b/phy/cluster/manual/tests/test_wizard.py @@ -206,7 +206,7 @@ def test_wizard_nav(wizard): def test_wizard_update(wizard, clustering, cluster_metadata): - # 2: none, 3: none, 5: unknown, 7: good + # 2: none, 3: none, 5: ignored, 7: good wizard.attach(clustering, cluster_metadata) wizard.start() @@ -224,21 +224,29 @@ def _check_best_match(b, m): ################################ - # Save the selection before the merge in the undo stack. + assert wizard.cluster_status(2) is None + assert wizard.cluster_status(3) is None clustering.merge([2, 3]) - assert wizard.best_list == [8, 7, 5] _check_best_match(8, 7) + assert wizard.best_list == [8, 7, 5] + assert wizard.cluster_status(8) is None + assert wizard.cluster_status(7) == 'good' # Undo merge. clustering.undo() _check_best_match(2, 3) + assert wizard.cluster_status(2) is None + assert wizard.cluster_status(3) is None + # Make a selection. wizard.selection = [1, 5, 7, 8] _check_best_match(5, 7) # Redo merge. clustering.redo() _check_best_match(8, 7) + assert wizard.cluster_status(8) is None + assert wizard.cluster_status(7) == 'good' ################################ @@ -247,6 +255,8 @@ def _check_best_match(b, m): clustering.split([1, 2]) ae(clustering.spike_clusters, [10, 9, 9, 7]) _check_best_match(9, 10) + assert wizard.cluster_status(10) is None + assert wizard.cluster_status(9) is None # Undo split. up = clustering.undo() @@ -269,6 +279,7 @@ def _check_best_match(b, m): _check_best_match(11, 7) assert up.description == 'merge' assert up.history is None + assert wizard.cluster_status(11) is None # Undo split (=merge). up = clustering.undo() From cb6d978635159c2fcdfab02cee909f4a97644143 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 23 Sep 2015 14:41:37 +0200 Subject: [PATCH 0067/1059] ClusterMetadataUpdater can now emit events --- phy/cluster/manual/_utils.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/phy/cluster/manual/_utils.py b/phy/cluster/manual/_utils.py index fa284b24a..7dc28035e 100644 --- a/phy/cluster/manual/_utils.py +++ b/phy/cluster/manual/_utils.py @@ -10,7 +10,7 @@ from collections import defaultdict from ._history import History -from phy.utils import Bunch, _as_list, _is_list +from phy.utils import Bunch, _as_list, _is_list, EventEmitter #------------------------------------------------------------------------------ @@ -155,9 +155,10 @@ def default(self, func): return func -class ClusterMetadataUpdater(object): +class ClusterMetadataUpdater(EventEmitter): """Handle cluster metadata changes.""" def __init__(self, cluster_metadata): + super(ClusterMetadataUpdater, self).__init__() self._cluster_metadata = cluster_metadata # Keep a deep copy of the original structure for the undo stack. self._data_base = deepcopy(cluster_metadata.data) @@ -183,13 +184,15 @@ def f(clusters, value): def _set(self, clusters, field, value, add_to_stack=True): self._cluster_metadata._set(clusters, field, value) clusters = _as_list(clusters) - info = UpdateInfo(description='metadata_' + field, - metadata_changed=clusters, - metadata_value=value, - ) + up = UpdateInfo(description='metadata_' + field, + metadata_changed=clusters, + metadata_value=value, + ) if add_to_stack: - self._undo_stack.add((clusters, field, value, info)) - return info + self._undo_stack.add((clusters, field, value, up)) + + self.emit('cluster', up) + return up def undo(self): """Undo the last metadata change. From 084f61c2870c33a8be8d7e09061b146446941f57 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 23 Sep 2015 14:42:09 +0200 Subject: [PATCH 0068/1059] Check groups updates in wizard --- phy/cluster/manual/tests/test_wizard.py | 10 ++++++++-- phy/cluster/manual/wizard.py | 4 +++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/phy/cluster/manual/tests/test_wizard.py b/phy/cluster/manual/tests/test_wizard.py index 422231cd0..eabee6823 100644 --- a/phy/cluster/manual/tests/test_wizard.py +++ b/phy/cluster/manual/tests/test_wizard.py @@ -221,10 +221,12 @@ def _check_best_match(b, m): wizard.pin() _check_best_match(2, 3) + cluster_metadata.set_group(2, 2) + wizard.selection = [2, 3] ################################ - assert wizard.cluster_status(2) is None + assert wizard.cluster_status(2) == 'good' assert wizard.cluster_status(3) is None clustering.merge([2, 3]) _check_best_match(8, 7) @@ -235,7 +237,7 @@ def _check_best_match(b, m): # Undo merge. clustering.undo() _check_best_match(2, 3) - assert wizard.cluster_status(2) is None + assert wizard.cluster_status(2) == 'good' assert wizard.cluster_status(3) is None # Make a selection. @@ -258,6 +260,10 @@ def _check_best_match(b, m): assert wizard.cluster_status(10) is None assert wizard.cluster_status(9) is None + # Ignore a cluster. + cluster_metadata.set_group(9, 1) + assert wizard.cluster_status(9) == 'ignored' + # Undo split. up = clustering.undo() _check_best_match(8, 7) diff --git a/phy/cluster/manual/wizard.py b/phy/cluster/manual/wizard.py index 0080658cd..8fa8162a6 100644 --- a/phy/cluster/manual/wizard.py +++ b/phy/cluster/manual/wizard.py @@ -490,7 +490,6 @@ def status(cluster): def on_request_undo_state(up): return {'selection': self.selection} - @clustering.connect def on_cluster(up): if self._has_finished: return @@ -498,3 +497,6 @@ def on_cluster(up): self._update_state(up) if self._best is not None or self._match is not None: self._select_after_update(up) + + clustering.connect(on_cluster) + cluster_metadata.connect(on_cluster) From a2c6a1a12d846bfa96aaaf5f687e5d25dafa489e Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 23 Sep 2015 14:56:35 +0200 Subject: [PATCH 0069/1059] Add set_from_descendants() method --- phy/cluster/manual/_utils.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/phy/cluster/manual/_utils.py b/phy/cluster/manual/_utils.py index 7dc28035e..465a71cd0 100644 --- a/phy/cluster/manual/_utils.py +++ b/phy/cluster/manual/_utils.py @@ -112,6 +112,10 @@ def __init__(self, data=None): def data(self): return self._data + @property + def fields(self): + return sorted(self._fields) + def _get_one(self, cluster, field): """Return the field value for a cluster, or the default value if it doesn't exist.""" @@ -154,6 +158,17 @@ def default(self, func): lambda clusters, value: self._set(clusters, field, value)) return func + def set_from_descendants(self, descendants): + """Update metadata of some clusters given the metadata of their + ascendants.""" + fields = self.fields + for old, new in descendants: + for field in fields: + old_val = self._data[old].get(field, None) + new_val = self._data[new].get(field, None) + if old_val is not None and new_val is None: + self._set_one(new, field, new_val) + class ClusterMetadataUpdater(EventEmitter): """Handle cluster metadata changes.""" @@ -194,6 +209,11 @@ def _set(self, clusters, field, value, add_to_stack=True): self.emit('cluster', up) return up + def set_from_descendants(self, descendants): + """Update metadata of some clusters given the metadata of their + ascendants.""" + self._cluster_metadata.set_from_descendants(descendants) + def undo(self): """Undo the last metadata change. From 6273c46cd1c009e4ed2a06ea887c8e7b797d6e1e Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 23 Sep 2015 15:18:31 +0200 Subject: [PATCH 0070/1059] Test set_from_descendants in ClusterMetadata --- phy/cluster/manual/_utils.py | 17 +++++++---- phy/cluster/manual/tests/test_utils.py | 39 ++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 5 deletions(-) diff --git a/phy/cluster/manual/_utils.py b/phy/cluster/manual/_utils.py index 465a71cd0..70d04a656 100644 --- a/phy/cluster/manual/_utils.py +++ b/phy/cluster/manual/_utils.py @@ -162,11 +162,18 @@ def set_from_descendants(self, descendants): """Update metadata of some clusters given the metadata of their ascendants.""" fields = self.fields - for old, new in descendants: - for field in fields: - old_val = self._data[old].get(field, None) - new_val = self._data[new].get(field, None) - if old_val is not None and new_val is None: + for field in fields: + # For each new cluster, a set of metadata values of their + # ascendants. + candidates = defaultdict(set) + for old, new in descendants: + candidates[new].add(old) + for new, vals in candidates.items(): + # Ask the field the value of the new cluster, + # as a function of the values of its ascendants. This is + # encoded in the specified default function. + new_val = self._fields[field](new, list(vals)) + if new_val is not None: self._set_one(new, field, new_val) diff --git a/phy/cluster/manual/tests/test_utils.py b/phy/cluster/manual/tests/test_utils.py index 8cf6da8d1..f37c78b8b 100644 --- a/phy/cluster/manual/tests/test_utils.py +++ b/phy/cluster/manual/tests/test_utils.py @@ -112,6 +112,45 @@ def color(cluster): assert info is None +def test_metadata_descendants(): + """Test ClusterMetadataUpdater history.""" + + data = {0: {'group': 0}, + 1: {'group': 1}, + 2: {'group': 2}, + 3: {'group': 3}, + } + + meta = ClusterMetadata(data=data) + + @meta.default + def group(cluster, ascendant_values=None): + if not ascendant_values: + return 3 + s = list(set(ascendant_values) - set([None, 3])) + # Return the default value if all ascendant values are the default. + if not s: + return 3 + # Otherwise, return good (2) if it is present, or the largest value + # among those present. + return max(s) + + meta.set_from_descendants([]) + assert meta.group(4) == 3 + + meta.set_from_descendants([(0, 4)]) + assert meta.group(4) == 0 + + meta.set_from_descendants([(1, 4)]) + assert meta.group(4) == 1 + + meta.set_from_descendants([(1, 5), (2, 5)]) + assert meta.group(5) == 2 + + meta.set_from_descendants([(2, 6), (3, 6), (10, 10)]) + assert meta.group(6) == 2 + + def test_update_cluster_selection(): clusters = [1, 2, 3] up = UpdateInfo(deleted=[2], added=[4, 0]) From f136ba55ee52b2cffde5fed5b554be1fc1c158a1 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 23 Sep 2015 15:25:49 +0200 Subject: [PATCH 0071/1059] Fix set_from_descendants --- phy/cluster/manual/_utils.py | 7 +++++++ phy/cluster/manual/tests/test_utils.py | 7 +++++++ 2 files changed, 14 insertions(+) diff --git a/phy/cluster/manual/_utils.py b/phy/cluster/manual/_utils.py index 70d04a656..70ac1c92c 100644 --- a/phy/cluster/manual/_utils.py +++ b/phy/cluster/manual/_utils.py @@ -169,6 +169,13 @@ def set_from_descendants(self, descendants): for old, new in descendants: candidates[new].add(old) for new, vals in candidates.items(): + + # Skip that new cluster if its value is already non-default. + current_val = self._get_one(new, field) + default_val = self._fields[field](new) + if current_val != default_val: + continue + # Ask the field the value of the new cluster, # as a function of the values of its ascendants. This is # encoded in the specified default function. diff --git a/phy/cluster/manual/tests/test_utils.py b/phy/cluster/manual/tests/test_utils.py index f37c78b8b..ac9fb8c23 100644 --- a/phy/cluster/manual/tests/test_utils.py +++ b/phy/cluster/manual/tests/test_utils.py @@ -141,6 +141,8 @@ def group(cluster, ascendant_values=None): meta.set_from_descendants([(0, 4)]) assert meta.group(4) == 0 + # Reset to default. + meta.set_group(4, 3) meta.set_from_descendants([(1, 4)]) assert meta.group(4) == 1 @@ -150,6 +152,11 @@ def group(cluster, ascendant_values=None): meta.set_from_descendants([(2, 6), (3, 6), (10, 10)]) assert meta.group(6) == 2 + # If the value of the new cluster is non-default, it should not + # be changed by set_from_descendants. + meta.set_from_descendants([(3, 2)]) + assert meta.group(2) == 2 + def test_update_cluster_selection(): clusters = [1, 2, 3] From 348db15fad4b732f40a81e5653b806e3a637d152 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 23 Sep 2015 15:26:28 +0200 Subject: [PATCH 0072/1059] Update wizard tests --- phy/cluster/manual/tests/test_wizard.py | 21 +++++++++++++++------ phy/cluster/manual/wizard.py | 10 +++++++--- 2 files changed, 22 insertions(+), 9 deletions(-) diff --git a/phy/cluster/manual/tests/test_wizard.py b/phy/cluster/manual/tests/test_wizard.py index eabee6823..bcaec7868 100644 --- a/phy/cluster/manual/tests/test_wizard.py +++ b/phy/cluster/manual/tests/test_wizard.py @@ -56,8 +56,16 @@ def cluster_metadata(): base_meta = ClusterMetadata(data=data) @base_meta.default - def group(cluster): - return 3 + def group(cluster, ascendant_values=None): + if not ascendant_values: + return 3 + s = list(set(ascendant_values) - set([None, 3])) + # Return the default value if all ascendant values are the default. + if not s: + return 3 + # Otherwise, return good (2) if it is present, or the largest value + # among those present. + return max(s) meta = ClusterMetadataUpdater(base_meta) yield meta @@ -228,10 +236,10 @@ def _check_best_match(b, m): assert wizard.cluster_status(2) == 'good' assert wizard.cluster_status(3) is None - clustering.merge([2, 3]) + clustering.merge([2, 3]) # => 8 _check_best_match(8, 7) assert wizard.best_list == [8, 7, 5] - assert wizard.cluster_status(8) is None + assert wizard.cluster_status(8) == 'good' assert wizard.cluster_status(7) == 'good' # Undo merge. @@ -247,14 +255,14 @@ def _check_best_match(b, m): # Redo merge. clustering.redo() _check_best_match(8, 7) - assert wizard.cluster_status(8) is None + assert wizard.cluster_status(8) == 'good' assert wizard.cluster_status(7) == 'good' ################################ # Split. ae(clustering.spike_clusters, [8, 8, 5, 7]) - clustering.split([1, 2]) + clustering.split([1, 2]) # ==> 9, 10 ae(clustering.spike_clusters, [10, 9, 9, 7]) _check_best_match(9, 10) assert wizard.cluster_status(10) is None @@ -275,6 +283,7 @@ def _check_best_match(b, m): _check_best_match(9, 10) assert up.description == 'assign' assert up.history == 'redo' + assert wizard.cluster_status(9) == 'ignored' ################################ diff --git a/phy/cluster/manual/wizard.py b/phy/cluster/manual/wizard.py index 8fa8162a6..390e3834e 100644 --- a/phy/cluster/manual/wizard.py +++ b/phy/cluster/manual/wizard.py @@ -478,13 +478,12 @@ def attach(self, clustering, cluster_metadata): @self.set_status_function def status(cluster): group = cluster_metadata.group(cluster) + if group is None: + return None if group <= 1: return 'ignored' elif group == 2: return 'good' - elif group == 3: - return None - raise NotImplementedError() # pragma: no cover @clustering.connect def on_request_undo_state(up): @@ -493,8 +492,13 @@ def on_request_undo_state(up): def on_cluster(up): if self._has_finished: return + # Set the cluster metadata of new clusters. + if up.added: + cluster_metadata.set_from_descendants(up.descendants) + # Update the wizard state. if self._best_list or self._match_list: self._update_state(up) + # Make a new selection. if self._best is not None or self._match is not None: self._select_after_update(up) From d2434b84c3b636fd85bdff3bc23fcfd2b3004e65 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 23 Sep 2015 15:29:26 +0200 Subject: [PATCH 0073/1059] WIP: test_wizard_update_group --- phy/cluster/manual/tests/test_wizard.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/phy/cluster/manual/tests/test_wizard.py b/phy/cluster/manual/tests/test_wizard.py index bcaec7868..1e5a3cc9f 100644 --- a/phy/cluster/manual/tests/test_wizard.py +++ b/phy/cluster/manual/tests/test_wizard.py @@ -213,10 +213,27 @@ def test_wizard_nav(wizard): assert wizard.n_processed == 2 -def test_wizard_update(wizard, clustering, cluster_metadata): +def test_wizard_update_group(wizard, clustering, cluster_metadata): # 2: none, 3: none, 5: ignored, 7: good wizard.attach(clustering, cluster_metadata) + wizard.start() + + def _check_best_match(b, m): + assert wizard.selection == [b, m] + assert wizard.best == b + assert wizard.match == m + wizard.pin() + _check_best_match(3, 2) + + # Ignore the currently-pinned cluster. + cluster_metadata.set_group(3, 0) + _check_best_match(5, 2) + + +def test_wizard_update_clustering(wizard, clustering, cluster_metadata): + # 2: none, 3: none, 5: ignored, 7: good + wizard.attach(clustering, cluster_metadata) wizard.start() def _check_best_match(b, m): From 22e0dabebcf2385adfb4e9da0706eeaad18d734e Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 23 Sep 2015 15:46:50 +0200 Subject: [PATCH 0074/1059] WIP: add undo_state to ClusterMetadataUpdater --- phy/cluster/manual/_utils.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/phy/cluster/manual/_utils.py b/phy/cluster/manual/_utils.py index 70ac1c92c..3086bea41 100644 --- a/phy/cluster/manual/_utils.py +++ b/phy/cluster/manual/_utils.py @@ -191,8 +191,9 @@ def __init__(self, cluster_metadata): self._cluster_metadata = cluster_metadata # Keep a deep copy of the original structure for the undo stack. self._data_base = deepcopy(cluster_metadata.data) - # The stack contains (clusters, field, value, update_info) tuples. - self._undo_stack = History((None, None, None, None)) + # The stack contains (clusters, field, value, update_info, undo_state) + # tuples. + self._undo_stack = History((None, None, None, None, None)) for field, func in self._cluster_metadata._fields.items(): @@ -211,14 +212,17 @@ def f(clusters, value): setattr(self, 'set_{0:s}'.format(field), _make_set(field)) def _set(self, clusters, field, value, add_to_stack=True): + self._cluster_metadata._set(clusters, field, value) clusters = _as_list(clusters) up = UpdateInfo(description='metadata_' + field, metadata_changed=clusters, metadata_value=value, ) + undo_state = self.emit('request_undo_state', up) + if add_to_stack: - self._undo_stack.add((clusters, field, value, up)) + self._undo_stack.add((clusters, field, value, up, undo_state)) self.emit('cluster', up) return up @@ -241,13 +245,14 @@ def undo(self): if args is None: return self._cluster_metadata._data = deepcopy(self._data_base) - for clusters, field, value, _ in self._undo_stack: + for clusters, field, value, up, undo_state in self._undo_stack: if clusters is not None: self._set(clusters, field, value, add_to_stack=False) # Return the UpdateInfo instance of the undo action. - info = args[-1] - info.history = 'undo' - return info + up, undo_state = args[-2:] + up.history = 'undo' + up.undo_state = undo_state + return up def redo(self): """Redo the next metadata change. @@ -260,8 +265,8 @@ def redo(self): args = self._undo_stack.forward() if args is None: return - clusters, field, value, info = args + clusters, field, value, up, undo_state = args self._set(clusters, field, value, add_to_stack=False) # Return the UpdateInfo instance of the redo action. - info.history = 'redo' - return info + up.history = 'redo' + return up From 781e3ec0e73c8ae2b5ac325e272948724ac1feb2 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 23 Sep 2015 16:08:21 +0200 Subject: [PATCH 0075/1059] Fix bugs --- phy/cluster/manual/_utils.py | 8 +++++++- phy/cluster/manual/tests/test_wizard.py | 9 +++++++++ phy/cluster/manual/wizard.py | 24 +++++++++++++++++------- 3 files changed, 33 insertions(+), 8 deletions(-) diff --git a/phy/cluster/manual/_utils.py b/phy/cluster/manual/_utils.py index 3086bea41..c93742286 100644 --- a/phy/cluster/manual/_utils.py +++ b/phy/cluster/manual/_utils.py @@ -223,8 +223,8 @@ def _set(self, clusters, field, value, add_to_stack=True): if add_to_stack: self._undo_stack.add((clusters, field, value, up, undo_state)) + self.emit('cluster', up) - self.emit('cluster', up) return up def set_from_descendants(self, descendants): @@ -248,10 +248,13 @@ def undo(self): for clusters, field, value, up, undo_state in self._undo_stack: if clusters is not None: self._set(clusters, field, value, add_to_stack=False) + # Return the UpdateInfo instance of the undo action. up, undo_state = args[-2:] up.history = 'undo' up.undo_state = undo_state + + self.emit('cluster', up) return up def redo(self): @@ -267,6 +270,9 @@ def redo(self): return clusters, field, value, up, undo_state = args self._set(clusters, field, value, add_to_stack=False) + # Return the UpdateInfo instance of the redo action. up.history = 'redo' + + self.emit('cluster', up) return up diff --git a/phy/cluster/manual/tests/test_wizard.py b/phy/cluster/manual/tests/test_wizard.py index 1e5a3cc9f..332622c7e 100644 --- a/phy/cluster/manual/tests/test_wizard.py +++ b/phy/cluster/manual/tests/test_wizard.py @@ -229,6 +229,15 @@ def _check_best_match(b, m): # Ignore the currently-pinned cluster. cluster_metadata.set_group(3, 0) _check_best_match(5, 2) + # 2: none, 3: ignored, 5: ignored, 7: good + + # Ignore the current match and move to next. + cluster_metadata.set_group(2, 1) + _check_best_match(5, 7) + # 2: ignored, 3: ignored, 5: ignored, 7: good + + cluster_metadata.undo() + _check_best_match(5, 2) def test_wizard_update_clustering(wizard, clustering, cluster_metadata): diff --git a/phy/cluster/manual/wizard.py b/phy/cluster/manual/wizard.py index 390e3834e..5487c6f85 100644 --- a/phy/cluster/manual/wizard.py +++ b/phy/cluster/manual/wizard.py @@ -9,6 +9,8 @@ import logging from operator import itemgetter +from phy.utils import EventEmitter + logger = logging.getLogger(__name__) @@ -79,9 +81,10 @@ def _progress(value, maximum): # Wizard #------------------------------------------------------------------------------ -class Wizard(object): +class Wizard(EventEmitter): """Propose a selection of high-quality clusters and merge candidates.""" def __init__(self): + super(Wizard, self).__init__() self._similarity = None self._quality = None self._get_cluster_ids = None @@ -240,8 +243,7 @@ def best(self): @best.setter def best(self, value): assert value in self._best_list - self._best = value - self._selection = [value] + self.selection = [value] @property def match(self): @@ -252,11 +254,10 @@ def match(self): def match(self, value): if value is not None: assert value in self._match_list - self._match = value if len(self._selection) == 1: - self._selection += [value] + self.selection = self.selection + [value] elif len(self._selection) == 2: - self._selection[1] = value + self.selection = [self.selection[0], value] @property def selection(self): @@ -270,10 +271,13 @@ def selection(self, value): clusters = self.cluster_ids value = [cluster for cluster in value if cluster in clusters] self._selection = value + if len(self._selection) == 1: + self._match = None if len(self._selection) >= 1: self._best = self._selection[0] if len(self._selection) >= 2: self._match = self._selection[1] + self.emit('select', self._selection) @property def best_list(self): @@ -464,7 +468,11 @@ def _select_after_update(self, up): if up.description == 'metadata_group': cluster = up.metadata_changed[0] if cluster == self.best: + # Pin the next best if there was a match before. + match_before = self.match is not None self.next_best() + if match_before: + self.pin() elif cluster == self.match: self.next_match() @@ -485,7 +493,6 @@ def status(cluster): elif group == 2: return 'good' - @clustering.connect def on_request_undo_state(up): return {'selection': self.selection} @@ -502,5 +509,8 @@ def on_cluster(up): if self._best is not None or self._match is not None: self._select_after_update(up) + clustering.connect(on_request_undo_state) + cluster_metadata.connect(on_request_undo_state) + clustering.connect(on_cluster) cluster_metadata.connect(on_cluster) From 759ec846c5d513638077f41319d3098b2d08559d Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 23 Sep 2015 16:11:27 +0200 Subject: [PATCH 0076/1059] More tests --- phy/cluster/manual/tests/test_wizard.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/phy/cluster/manual/tests/test_wizard.py b/phy/cluster/manual/tests/test_wizard.py index 332622c7e..0282b1a2d 100644 --- a/phy/cluster/manual/tests/test_wizard.py +++ b/phy/cluster/manual/tests/test_wizard.py @@ -239,6 +239,9 @@ def _check_best_match(b, m): cluster_metadata.undo() _check_best_match(5, 2) + cluster_metadata.redo() + _check_best_match(5, 7) + def test_wizard_update_clustering(wizard, clustering, cluster_metadata): # 2: none, 3: none, 5: ignored, 7: good From 19d78d5ea295c5d253de5fcfee7b802c4dc408c4 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 23 Sep 2015 16:12:52 +0200 Subject: [PATCH 0077/1059] WIP: increase coverage in phy.cluster.manual --- phy/cluster/manual/tests/test_utils.py | 2 +- phy/cluster/manual/tests/test_wizard.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/phy/cluster/manual/tests/test_utils.py b/phy/cluster/manual/tests/test_utils.py index ac9fb8c23..6c20c55a3 100644 --- a/phy/cluster/manual/tests/test_utils.py +++ b/phy/cluster/manual/tests/test_utils.py @@ -129,7 +129,7 @@ def group(cluster, ascendant_values=None): return 3 s = list(set(ascendant_values) - set([None, 3])) # Return the default value if all ascendant values are the default. - if not s: + if not s: # pragma: no cover return 3 # Otherwise, return good (2) if it is present, or the largest value # among those present. diff --git a/phy/cluster/manual/tests/test_wizard.py b/phy/cluster/manual/tests/test_wizard.py index 0282b1a2d..80b4578eb 100644 --- a/phy/cluster/manual/tests/test_wizard.py +++ b/phy/cluster/manual/tests/test_wizard.py @@ -61,7 +61,7 @@ def group(cluster, ascendant_values=None): return 3 s = list(set(ascendant_values) - set([None, 3])) # Return the default value if all ascendant values are the default. - if not s: + if not s: # pragma: no cover return 3 # Otherwise, return good (2) if it is present, or the largest value # among those present. From 7d10774c1bd4999f26cab3ac58ad42133f1ce698 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 23 Sep 2015 16:21:50 +0200 Subject: [PATCH 0078/1059] WIP: increase coverage in phy.cluster.manual.wizard --- phy/cluster/manual/tests/test_wizard.py | 2 +- phy/cluster/manual/wizard.py | 16 ++-------------- 2 files changed, 3 insertions(+), 15 deletions(-) diff --git a/phy/cluster/manual/tests/test_wizard.py b/phy/cluster/manual/tests/test_wizard.py index 80b4578eb..d26c87fe4 100644 --- a/phy/cluster/manual/tests/test_wizard.py +++ b/phy/cluster/manual/tests/test_wizard.py @@ -214,8 +214,8 @@ def test_wizard_nav(wizard): def test_wizard_update_group(wizard, clustering, cluster_metadata): - # 2: none, 3: none, 5: ignored, 7: good wizard.attach(clustering, cluster_metadata) + wizard.start() def _check_best_match(b, m): diff --git a/phy/cluster/manual/wizard.py b/phy/cluster/manual/wizard.py index 5487c6f85..0142a7443 100644 --- a/phy/cluster/manual/wizard.py +++ b/phy/cluster/manual/wizard.py @@ -306,14 +306,8 @@ def n_clusters(self): # Navigation #-------------------------------------------------------------------------- - @property - def _has_finished(self): - return self.best is not None and len(self._best_list) <= 1 - def next_best(self): """Select the next best cluster.""" - if self._has_finished: - return self.best = _next(self._best_list, self._best, ) @@ -322,8 +316,6 @@ def next_best(self): def previous_best(self): """Select the previous best in cluster.""" - if self._has_finished: - return if self._best_list: self.best = _previous(self._best_list, self._best, @@ -385,8 +377,6 @@ def start(self): def pin(self, cluster=None): """Pin the current best cluster and set the list of closest matches.""" - if self._has_finished: - return if cluster is None: cluster = self.best logger.debug("Pin %d.", cluster) @@ -422,7 +412,7 @@ def _add(self, clusters, position=None): if self.best is not None: if position is not None: self._best_list.insert(position, clu) - else: + else: # pragma: no cover self._best_list.append(clu) if self.match is not None: self._match_list.append(clu) @@ -486,7 +476,7 @@ def attach(self, clustering, cluster_metadata): @self.set_status_function def status(cluster): group = cluster_metadata.group(cluster) - if group is None: + if group is None: # pragma: no cover return None if group <= 1: return 'ignored' @@ -497,8 +487,6 @@ def on_request_undo_state(up): return {'selection': self.selection} def on_cluster(up): - if self._has_finished: - return # Set the cluster metadata of new clusters. if up.added: cluster_metadata.set_from_descendants(up.descendants) From 7b35f5779fd9074cb2434f8c050fbdde137e6ecf Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 23 Sep 2015 16:27:59 +0200 Subject: [PATCH 0079/1059] WIP: increase coverage in phy.cluster.manual.wizard --- phy/cluster/manual/tests/test_wizard.py | 22 ++++++++++++++++ phy/cluster/manual/wizard.py | 35 +++++++++---------------- 2 files changed, 35 insertions(+), 22 deletions(-) diff --git a/phy/cluster/manual/tests/test_wizard.py b/phy/cluster/manual/tests/test_wizard.py index d26c87fe4..1404a27fb 100644 --- a/phy/cluster/manual/tests/test_wizard.py +++ b/phy/cluster/manual/tests/test_wizard.py @@ -213,6 +213,28 @@ def test_wizard_nav(wizard): assert wizard.n_processed == 2 +def test_wizard_update_simple(wizard, clustering, cluster_metadata): + # 2: none, 3: none, 5: ignored, 7: good + wizard.attach(clustering, cluster_metadata) + + wizard.first() + wizard.last() + + wizard.start() + + wizard.first() + wizard.last() + + wizard.pin() + + wizard.first() + wizard.last() + + wizard.pin() + wizard.previous_best() + wizard.next_best() + + def test_wizard_update_group(wizard, clustering, cluster_metadata): wizard.attach(clustering, cluster_metadata) diff --git a/phy/cluster/manual/wizard.py b/phy/cluster/manual/wizard.py index 0142a7443..5097f3b40 100644 --- a/phy/cluster/manual/wizard.py +++ b/phy/cluster/manual/wizard.py @@ -308,37 +308,28 @@ def n_clusters(self): def next_best(self): """Select the next best cluster.""" - self.best = _next(self._best_list, - self._best, - ) - if self.match is not None: + boo_match = self.match is not None + self.best = _next(self._best_list, self._best) + if boo_match: self._set_match_list() def previous_best(self): """Select the previous best in cluster.""" + boo_match = self.match is not None if self._best_list: - self.best = _previous(self._best_list, - self._best, - ) - if self.match is not None: + self.best = _previous(self._best_list, self._best) + if boo_match: self._set_match_list() def next_match(self): """Select the next match.""" - # Handle the case where we arrive at the end of the match list. - if self.match is not None and len(self._match_list) <= 1: - self.next_best() - elif self._match_list: - self.match = _next(self._match_list, - self._match, - ) + if self._match_list: + self.match = _next(self._match_list, self._match) def previous_match(self): """Select the previous match.""" if self._match_list: - self.match = _previous(self._match_list, - self._match, - ) + self.match = _previous(self._match_list, self._match) def next(self): """Next cluster proposition.""" @@ -356,16 +347,16 @@ def previous(self): def first(self): """First match or first best.""" - if self.match is None: + if self.match is None and self._best_list: self.best = self._best_list[0] - else: + elif self._match_list: self.match = self._match_list[0] def last(self): """Last match or last best.""" - if self.match is None: + if self.match is None and self._best_list: self.best = self._best_list[-1] - else: + elif self.match_list: self.match = self._match_list[-1] # Control From fc3ae180576c1f0c92f1b3d6e7e38f4e4df397f0 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 23 Sep 2015 16:32:31 +0200 Subject: [PATCH 0080/1059] Increase coverage in wizard --- phy/cluster/manual/tests/test_wizard.py | 6 ++++++ phy/cluster/manual/wizard.py | 13 ------------- 2 files changed, 6 insertions(+), 13 deletions(-) diff --git a/phy/cluster/manual/tests/test_wizard.py b/phy/cluster/manual/tests/test_wizard.py index 1404a27fb..54c751293 100644 --- a/phy/cluster/manual/tests/test_wizard.py +++ b/phy/cluster/manual/tests/test_wizard.py @@ -13,6 +13,7 @@ from .._utils import ClusterMetadata, ClusterMetadataUpdater from ..wizard import (_previous, _next, + _find_first, Wizard, ) @@ -86,6 +87,10 @@ def test_utils(): def func(x): return x in (2, 5) + _find_first([], None) + + _previous([], None) + _previous([0, 1], 1, lambda x: x > 0) # Error: log and do nothing. _previous(l, 1, func) _previous(l, 15, func) @@ -96,6 +101,7 @@ def func(x): assert _previous(l, 7, func) == 5 assert _previous(l, 11, func) == 5 + _next([], None) # Error: log and do nothing. _next(l, 1, func) _next(l, 15, func) diff --git a/phy/cluster/manual/wizard.py b/phy/cluster/manual/wizard.py index 5097f3b40..8b064d3fa 100644 --- a/phy/cluster/manual/wizard.py +++ b/phy/cluster/manual/wizard.py @@ -71,12 +71,6 @@ def _next(items, current, filter=None): return current -def _progress(value, maximum): - if maximum <= 1: - return 1 - return int(100 * value / float(maximum - 1)) - - #------------------------------------------------------------------------------ # Wizard #------------------------------------------------------------------------------ @@ -98,10 +92,6 @@ def reset(self): self._best = None self._match = None - @property - def has_started(self): - return len(self._best_list) > 0 - # Quality and status functions #-------------------------------------------------------------------------- @@ -146,9 +136,6 @@ def _with_status(self, items, status): status = [status] return [item for item in items if self._cluster_status(item) in status] - def _is_not_ignored(self, cluster): - return self._with_status([cluster], (None, 'good')) - def _check(self): clusters = set(self.cluster_ids) assert set(self._best_list) <= clusters From 21b1aba6d120ab9f6ff02558bcf1721d89757d63 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 23 Sep 2015 20:37:08 +0200 Subject: [PATCH 0081/1059] Update prompt --- phy/gui/qt.py | 35 +++++++++++++++++++++++------------ phy/gui/tests/test_base.py | 2 ++ phy/gui/tests/test_qt.py | 31 +++++++++++++------------------ 3 files changed, 38 insertions(+), 30 deletions(-) diff --git a/phy/gui/qt.py b/phy/gui/qt.py index f82db6f39..5d44044a4 100644 --- a/phy/gui/qt.py +++ b/phy/gui/qt.py @@ -50,21 +50,32 @@ def _check_qt(): # Utility functions # ----------------------------------------------------------------------------- -def _prompt(parent, message, buttons=('yes', 'no'), title='Question'): - buttons = [(button, getattr(QtGui.QMessageBox, button.capitalize())) - for button in buttons] +def _button_enum_from_name(name): + return getattr(QtGui.QMessageBox, name.capitalize()) + + +def _button_name_from_enum(enum): + names = dir(QtGui.QMessageBox) + for name in names: + if getattr(QtGui.QMessageBox, name) == enum: + return name.lower() + + +def _prompt(message, buttons=('yes', 'no'), title='Question'): + buttons = [(button, _button_enum_from_name(button)) for button in buttons] arg_buttons = 0 for (_, button) in buttons: arg_buttons |= button - reply = QtGui.QMessageBox.question(parent, - title, - message, - arg_buttons, - buttons[0][1], - ) - for name, button in buttons: - if reply == button: - return name + box = QtGui.QMessageBox() + box.setWindowTitle(title) + box.setText(message) + box.setStandardButtons(arg_buttons) + box.setDefaultButton(buttons[0][1]) + return box + + +def _show_box(box): # pragma: no cover + return _button_name_from_enum(box.exec_()) def _set_qt_widget_position_size(widget, position=None, size=None): diff --git a/phy/gui/tests/test_base.py b/phy/gui/tests/test_base.py index 492fc30a5..7cc348169 100644 --- a/phy/gui/tests/test_base.py +++ b/phy/gui/tests/test_base.py @@ -164,6 +164,8 @@ def test(self): qtbot.addWidget(gui.main_window) gui.show() + gui.test() + # Test snippet mode. gui.enable_snippet_mode() diff --git a/phy/gui/tests/test_qt.py b/phy/gui/tests/test_qt.py index fa4540402..192b312ed 100644 --- a/phy/gui/tests/test_qt.py +++ b/phy/gui/tests/test_qt.py @@ -6,19 +6,14 @@ # Imports #------------------------------------------------------------------------------ -from pytest import mark - -from ..qt import (QtWebKit, QtGui, - qt_app, +from ..qt import (QtCore, QtGui, QtWebKit, _set_qt_widget_position_size, + _button_name_from_enum, + _button_enum_from_name, _prompt, ) -# Skip these tests in "make test-quick". -pytestmark = mark.long - - #------------------------------------------------------------------------------ # Tests #------------------------------------------------------------------------------ @@ -52,13 +47,13 @@ def _assert(text): view.close() -@mark.skipif() -def test_prompt(): - with qt_app(): - w = QtGui.QWidget() - w.show() - result = _prompt(w, - "How are you doing?", - buttons=['save', 'cancel', 'close'], - ) - print(result) +def test_prompt(qtbot): + + assert _button_name_from_enum(QtGui.QMessageBox.Save) == 'save' + assert _button_enum_from_name('save') == QtGui.QMessageBox.Save + + box = _prompt("How are you doing?", + buttons=['save', 'cancel', 'close'], + ) + qtbot.mouseClick(box.buttons()[0], QtCore.Qt.LeftButton) + assert 'save' in box.clickedButton().text().lower() From 742a3ee1d8dbe427c26d9fd19ff8f18e14812d7f Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 23 Sep 2015 20:47:35 +0200 Subject: [PATCH 0082/1059] Fix dock tests --- phy/gui/tests/test_dock.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/phy/gui/tests/test_dock.py b/phy/gui/tests/test_dock.py index 0a2916972..40d57f042 100644 --- a/phy/gui/tests/test_dock.py +++ b/phy/gui/tests/test_dock.py @@ -65,7 +65,7 @@ def test_dock_status_message(qtbot): def test_dock_state(qtbot): - _gs = None + _gs = [] gui = DockWindow() qtbot.addWidget(gui) @@ -79,8 +79,7 @@ def press(): @gui.connect_ def on_close_gui(): - global _gs - _gs = gui.save_geometry_state() + _gs.append(gui.save_geometry_state()) gui.show() @@ -100,9 +99,16 @@ def on_close_gui(): gui.add_view(_create_canvas(), 'view2') @gui.connect_ - def on_show(): - print(_gs) - gui.restore_geometry_state(_gs) + def on_show_gui(): + gui.restore_geometry_state(_gs[0]) + qtbot.addWidget(gui) gui.show() + + assert len(gui.list_views('view')) == 3 + assert gui.view_count() == { + 'view1': 1, + 'view2': 2, + } + gui.close() From 225eb2acd9857e56611d20f3bb2708188068b066 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 23 Sep 2015 20:54:15 +0200 Subject: [PATCH 0083/1059] Add dock tests --- phy/gui/tests/test_dock.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/phy/gui/tests/test_dock.py b/phy/gui/tests/test_dock.py index 40d57f042..26cc2511b 100644 --- a/phy/gui/tests/test_dock.py +++ b/phy/gui/tests/test_dock.py @@ -10,8 +10,9 @@ from vispy import app +from ..qt import Qt from ..dock import DockWindow -from ...utils._color import _random_color +from phy.utils._color import _random_color # Skip these tests in "make test-quick". @@ -66,12 +67,14 @@ def test_dock_status_message(qtbot): def test_dock_state(qtbot): _gs = [] - gui = DockWindow() + gui = DockWindow(size=(100, 100)) qtbot.addWidget(gui) + _press = [] + @gui.shortcut('press', 'ctrl+g') def press(): - pass + _press.append(0) gui.add_view(_create_canvas(), 'view1') gui.add_view(_create_canvas(), 'view2') @@ -82,6 +85,10 @@ def on_close_gui(): _gs.append(gui.save_geometry_state()) gui.show() + qtbot.waitForWindowShown(gui) + + qtbot.keyPress(gui, Qt.Key_G, Qt.ControlModifier) + assert _press == [0] assert len(gui.list_views('view')) == 3 assert gui.view_count() == { From 70cee576a1fff04906ea43db327ca7f5d0cd54d9 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 23 Sep 2015 20:55:41 +0200 Subject: [PATCH 0084/1059] Increase coverage in phy.gui --- phy/gui/tests/test_dock.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/phy/gui/tests/test_dock.py b/phy/gui/tests/test_dock.py index 26cc2511b..a3ccacb6e 100644 --- a/phy/gui/tests/test_dock.py +++ b/phy/gui/tests/test_dock.py @@ -32,11 +32,6 @@ def _create_canvas(): def on_draw(e): c.context.clear(c.color) - @c.connect - def on_key_press(e): - c.color = _random_color() - c.update() - return c @@ -45,13 +40,10 @@ def test_dock_1(qtbot): gui = DockWindow() qtbot.addWidget(gui) - @gui.shortcut('quit', 'ctrl+q') - def quit(): - gui.close() - gui.add_view(_create_canvas(), 'view1') gui.add_view(_create_canvas(), 'view2') gui.show() + qtbot.waitForWindowShown(gui) assert len(gui.list_views('view')) == 2 gui.close() From 699b57434dfc8a34c8d66413753ef085f7e5d734 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 23 Sep 2015 21:00:02 +0200 Subject: [PATCH 0085/1059] WIP: increase coverage in phy.gui.dock --- phy/gui/dock.py | 2 +- phy/gui/tests/test_dock.py | 10 +++++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/phy/gui/dock.py b/phy/gui/dock.py index 99a95b8ff..9e9fbb67f 100644 --- a/phy/gui/dock.py +++ b/phy/gui/dock.py @@ -170,7 +170,7 @@ def add_view(self, from vispy.app import Canvas if isinstance(view, Canvas): view = view.native - except ImportError: + except ImportError: # pragma: no cover pass class DockWidget(QtGui.QDockWidget): diff --git a/phy/gui/tests/test_dock.py b/phy/gui/tests/test_dock.py index a3ccacb6e..3e47eac78 100644 --- a/phy/gui/tests/test_dock.py +++ b/phy/gui/tests/test_dock.py @@ -29,7 +29,7 @@ def _create_canvas(): c.color = _random_color() @c.connect - def on_draw(e): + def on_draw(e): # pragma: no cover c.context.clear(c.color) return c @@ -40,10 +40,14 @@ def test_dock_1(qtbot): gui = DockWindow() qtbot.addWidget(gui) - gui.add_view(_create_canvas(), 'view1') + gui.add_action('test', lambda: None) + # Adding an action twice has no effect. + gui.add_action('test', lambda: None) + + gui.add_view(_create_canvas(), 'view1', floating=True) gui.add_view(_create_canvas(), 'view2') gui.show() - qtbot.waitForWindowShown(gui) + # qtbot.waitForWindowShown(gui) assert len(gui.list_views('view')) == 2 gui.close() From 4d804dd285812991eecfebed9394766267127367 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 23 Sep 2015 21:03:13 +0200 Subject: [PATCH 0086/1059] Increase coverage in phy.gui.dock --- phy/gui/dock.py | 11 +---------- phy/gui/tests/test_dock.py | 10 +++++++++- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/phy/gui/dock.py b/phy/gui/dock.py index 9e9fbb67f..f2a07ac4c 100644 --- a/phy/gui/dock.py +++ b/phy/gui/dock.py @@ -20,15 +20,6 @@ def _title(widget): return str(widget.windowTitle()).lower() -def _widget(dock_widget): - """Return a Qt or VisPy widget from a dock widget.""" - widget = dock_widget.widget() - if hasattr(widget, '_vispy_canvas'): - return widget._vispy_canvas - else: - return widget - - # ----------------------------------------------------------------------------- # Qt windows # ----------------------------------------------------------------------------- @@ -96,7 +87,7 @@ def closeEvent(self, e): res = self.emit('close_gui') # Discard the close event if False is returned by one of the callback # functions. - if False in res: + if False in res: # pragma: no cover e.ignore() return super(DockWindow, self).closeEvent(e) diff --git a/phy/gui/tests/test_dock.py b/phy/gui/tests/test_dock.py index 3e47eac78..ac6164c88 100644 --- a/phy/gui/tests/test_dock.py +++ b/phy/gui/tests/test_dock.py @@ -37,9 +37,17 @@ def on_draw(e): # pragma: no cover def test_dock_1(qtbot): - gui = DockWindow() + gui = DockWindow(position=(200, 100), size=(100, 100)) qtbot.addWidget(gui) + # Increase coverage. + @gui.connect_ + def on_show_gui(): + pass + gui.unconnect_(on_show_gui) + qtbot.keyPress(gui, Qt.Key_Control) + qtbot.keyRelease(gui, Qt.Key_Control) + gui.add_action('test', lambda: None) # Adding an action twice has no effect. gui.add_action('test', lambda: None) From bb9a46d8a6c472a46aa39009d150e9a7b1cf7384 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 23 Sep 2015 21:11:31 +0200 Subject: [PATCH 0087/1059] Increase coverage in phy.gui.qt --- phy/gui/__init__.py | 2 +- phy/gui/qt.py | 46 +++++++------------------------------- phy/gui/tests/test_base.py | 11 +++++---- phy/gui/tests/test_qt.py | 5 ++--- 4 files changed, 16 insertions(+), 48 deletions(-) diff --git a/phy/gui/__init__.py b/phy/gui/__init__.py index 0ed1b6ebb..a4f62e14b 100644 --- a/phy/gui/__init__.py +++ b/phy/gui/__init__.py @@ -3,7 +3,7 @@ """GUI routines.""" -from .qt import start_qt_app, run_qt_app, qt_app, enable_qt +from .qt import start_qt_app, run_qt_app, enable_qt from .dock import DockWindow from .base import (BaseViewModel, diff --git a/phy/gui/qt.py b/phy/gui/qt.py index 5d44044a4..a6ca3479b 100644 --- a/phy/gui/qt.py +++ b/phy/gui/qt.py @@ -23,29 +23,23 @@ _PYQT = False try: from PyQt4 import QtCore, QtGui, QtWebKit # noqa - from PyQt4.QtGui import QMainWindow Qt = QtCore.Qt _PYQT = True -except ImportError: +except ImportError: # pragma: no cover try: from PyQt5 import QtCore, QtGui, QtWebKit # noqa - from PyQt5.QtGui import QMainWindow _PYQT = True except ImportError: pass -def _check_qt(): +def _check_qt(): # pragma: no cover if not _PYQT: logger.warn("PyQt is not available.") return False return True -if not _check_qt(): - QMainWindow = object # noqa - - # ----------------------------------------------------------------------------- # Utility functions # ----------------------------------------------------------------------------- @@ -74,17 +68,10 @@ def _prompt(message, buttons=('yes', 'no'), title='Question'): return box -def _show_box(box): # pragma: no cover +def _show_box(box): # pragma: no cover return _button_name_from_enum(box.exec_()) -def _set_qt_widget_position_size(widget, position=None, size=None): - if position is not None: - widget.moveTo(*position) - if size is not None: - widget.resize(*size) - - # ----------------------------------------------------------------------------- # Event loop integration with IPython # ----------------------------------------------------------------------------- @@ -93,7 +80,7 @@ def _set_qt_widget_position_size(widget, position=None, size=None): _APP_RUNNING = False -def _try_enable_ipython_qt(): +def _try_enable_ipython_qt(): # pragma: no cover """Try to enable IPython Qt event loop integration. Returns True in the following cases: @@ -123,7 +110,7 @@ def _try_enable_ipython_qt(): return False -def enable_qt(): +def enable_qt(): # pragma: no cover if not _check_qt(): return try: @@ -141,7 +128,7 @@ def enable_qt(): # Qt app # ----------------------------------------------------------------------------- -def start_qt_app(): +def start_qt_app(): # pragma: no cover """Start a Qt application if necessary. If a new Qt application is created, this function returns it. @@ -169,7 +156,7 @@ def start_qt_app(): return _APP -def run_qt_app(): +def run_qt_app(): # pragma: no cover """Start the Qt application's event loop.""" global _APP_RUNNING if not _check_qt(): @@ -181,32 +168,15 @@ def run_qt_app(): _APP_RUNNING = False -@contextlib.contextmanager -def qt_app(): - """Context manager to ensure that a Qt app is running.""" - if not _check_qt(): - return - app = start_qt_app() - yield app - run_qt_app() - - # ----------------------------------------------------------------------------- # Testing utilities # ----------------------------------------------------------------------------- -def _close_qt_after(window, duration): - """Close a Qt window after a given duration.""" - def callback(): - window.close() - QtCore.QTimer.singleShot(int(1000 * duration), callback) - - _MAX_ITER = 100 _DELAY = max(0, float(os.environ.get('PHY_EVENT_LOOP_DELAY', .1))) -def _debug_trace(): +def _debug_trace(): # pragma: no cover """Set a tracepoint in the Python debugger that works with Qt.""" from PyQt4.QtCore import pyqtRemoveInputHook from pdb import set_trace diff --git a/phy/gui/tests/test_base.py b/phy/gui/tests/test_base.py index 7cc348169..742d8a5e8 100644 --- a/phy/gui/tests/test_base.py +++ b/phy/gui/tests/test_base.py @@ -15,7 +15,6 @@ BaseGUI, ) from ..qt import (QtGui, - _set_qt_widget_position_size, ) from ...utils.event import EventEmitter @@ -36,16 +35,16 @@ class MyViewModel(BaseViewModel): def _create_view(self, text='', position=None, size=None): view = QtGui.QMainWindow() view.setWindowTitle(text) - _set_qt_widget_position_size(view, - position=position, - size=size, - ) + if size: + view.resize(*size) + if position: + view.move(*position) return view size = (400, 100) text = 'hello' - vm = MyViewModel(text=text, size=size) + vm = MyViewModel(text=text, size=size, position=(100, 100)) qtbot.addWidget(vm.view) vm.show() diff --git a/phy/gui/tests/test_qt.py b/phy/gui/tests/test_qt.py index 192b312ed..9c76ecf8f 100644 --- a/phy/gui/tests/test_qt.py +++ b/phy/gui/tests/test_qt.py @@ -7,7 +7,6 @@ #------------------------------------------------------------------------------ from ..qt import (QtCore, QtGui, QtWebKit, - _set_qt_widget_position_size, _button_name_from_enum, _button_enum_from_name, _prompt, @@ -26,7 +25,7 @@ def _assert(text): html = view.page().mainFrame().toHtml() assert html == '' + text + '' - _set_qt_widget_position_size(view, size=(100, 100)) + view.resize(100, 100) view.setHtml("hello") qtbot.addWidget(view) qtbot.waitForWindowShown(view) @@ -38,7 +37,7 @@ def _assert(text): view.close() view = QtWebKit.QWebView() - _set_qt_widget_position_size(view, size=(100, 100)) + view.resize(100, 100) view.show() qtbot.addWidget(view) From 86aee2a052605442cbf95aa7eb40505448196f96 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 23 Sep 2015 21:22:05 +0200 Subject: [PATCH 0088/1059] Remove GUI base module --- phy/gui/__init__.py | 6 - phy/gui/base.py | 690 ------------------------------------ phy/gui/tests/test_base.py | 193 ---------- phy/gui/tests/test_utils.py | 17 + 4 files changed, 17 insertions(+), 889 deletions(-) delete mode 100644 phy/gui/base.py delete mode 100644 phy/gui/tests/test_base.py create mode 100644 phy/gui/tests/test_utils.py diff --git a/phy/gui/__init__.py b/phy/gui/__init__.py index a4f62e14b..5dbd6c7b7 100644 --- a/phy/gui/__init__.py +++ b/phy/gui/__init__.py @@ -5,9 +5,3 @@ from .qt import start_qt_app, run_qt_app, enable_qt from .dock import DockWindow - -from .base import (BaseViewModel, - HTMLViewModel, - WidgetCreator, - BaseGUI, - ) diff --git a/phy/gui/base.py b/phy/gui/base.py deleted file mode 100644 index cba29ddab..000000000 --- a/phy/gui/base.py +++ /dev/null @@ -1,690 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Base classes for GUIs.""" - - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -from collections import Counter -import inspect -import logging - -from six import string_types, PY3 - -from ..utils import EventEmitter -from ._utils import _read -from .dock import DockWindow - -logger = logging.getLogger(__name__) - - -#------------------------------------------------------------------------------ -# BaseViewModel -#------------------------------------------------------------------------------ - -class BaseViewModel(object): - """Interface between a view and a model. - - Events - ------ - - show_view - close_view - - """ - _view_name = '' - _imported_params = ('position', 'size',) - - def __init__(self, model=None, **kwargs): - self._model = model - self._event = EventEmitter() - - # Instantiate the underlying view. - self._view = self._create_view(**kwargs) - - # Set passed keyword arguments as attributes. - for param in self.imported_params(): - value = kwargs.get(param, None) - if value is not None: - setattr(self, param, value) - - self.on_open() - - def emit(self, *args, **kwargs): - """Emit an event.""" - return self._event.emit(*args, **kwargs) - - def connect(self, *args, **kwargs): - """Connect a callback function.""" - self._event.connect(*args, **kwargs) - - # Methods to override - #-------------------------------------------------------------------------- - - def _create_view(self, **kwargs): - """Create the view with the parameters passed to the constructor. - - Must be overriden.""" - return None - - def on_open(self): - """Initialize the view after the model has been loaded. - - May be overriden.""" - - def on_close(self): - """Called when the model is closed. - - May be overriden.""" - - # Parameters - #-------------------------------------------------------------------------- - - @classmethod - def imported_params(cls): - """All parameter names to be imported on object creation.""" - out = () - for base_class in inspect.getmro(cls): - if base_class == object: - continue - out += base_class._imported_params - return out - - def exported_params(self, save_size_pos=True): - """Return a dictionary of variables to save when the view is closed.""" - if save_size_pos and hasattr(self._view, 'pos'): - return { - 'position': (self._view.x(), self._view.y()), - 'size': (self._view.width(), self._view.height()), - } - else: - return {} - - @classmethod - def get_params(cls, settings): - """Return the parameter values for the creation of the view.""" - name = cls._view_name - param_names = cls.imported_params() - params = {key: settings[name + '_' + key] - for key in param_names - if (name + '_' + key) in settings} - return params - - # Properties - #-------------------------------------------------------------------------- - - @property - def model(self): - """The model.""" - return self._model - - @property - def name(self): - """The view model's name.""" - return self._view_name - - @property - def view(self): - """The underlying view.""" - return self._view - - # Public methods - #-------------------------------------------------------------------------- - - def close(self): - """Close the view.""" - self._view.close() - self.emit('close_view') - - def show(self): - """Show the view.""" - self._view.show() - self.emit('show_view') - - -#------------------------------------------------------------------------------ -# HTMLViewModel -#------------------------------------------------------------------------------ - -class HTMLViewModel(BaseViewModel): - """Widget with custom HTML code. - - To create a new HTML view, derive from `HTMLViewModel`, and implement - `get_html()` which returns HTML code. - - """ - - def _update(self, view, **kwargs): - html = self.get_html(**kwargs) - css = self.get_css(**kwargs) - wrapped = _read('wrap_qt.html') - html_wrapped = wrapped.replace('%CSS%', css).replace('%HTML%', html) - view.setHtml(html_wrapped) - - def _create_view(self, **kwargs): - from PyQt4.QtWebKit import QWebView - view = QWebView() - self._update(view, **kwargs) - return view - - def get_html(self, **kwargs): - """Return the non-formatted HTML contents of the view.""" - return '' - - def get_css(self, **kwargs): - """Return the view's CSS styles.""" - return '' - - def update(self, **kwargs): - """Update the widget's HTML contents.""" - self._update(self._view, **kwargs) - - def isVisible(self): - return self._view.isVisible() - - -#------------------------------------------------------------------------------ -# Widget creator (used to create views and GUIs) -#------------------------------------------------------------------------------ - -class WidgetCreator(EventEmitter): - """Manage the creation of widgets. - - A widget must implement: - - * `name` - * `show()` - * `connect` (for `close` event) - - Events - ------ - - add(widget): when a widget is added. - close(widget): when a widget is closed. - - """ - def __init__(self, widget_classes=None): - super(WidgetCreator, self).__init__() - self._widget_classes = widget_classes or {} - self._widgets = [] - - def _create_widget(self, widget_class, **kwargs): - """Create a new widget of a given class. - - May be overriden. - - """ - return widget_class(**kwargs) - - @property - def widget_classes(self): - """The registered widget classes that can be created.""" - return self._widget_classes - - def _widget_name(self, widget): - if widget.name: - return widget.name - # Fallback to the name given in widget_classes. - for name, cls in self._widget_classes.items(): - if cls == widget.__class__: - return name - - def get(self, *names): - """Return the list of widgets of a given type.""" - if not names: - return self._widgets - return [widget for widget in self._widgets - if self._widget_name(widget) in names] - - def add(self, widget_class, show=False, **kwargs): - """Add a new widget.""" - # widget_class can also be a name, but in this case it must be - # registered in self._widget_classes. - if isinstance(widget_class, string_types): - if widget_class not in self.widget_classes: - raise ValueError("Unknown widget class " - "`{}`.".format(widget_class)) - widget_class = self.widget_classes[widget_class] - widget = self._create_widget(widget_class, **kwargs) - if widget not in self._widgets: - self._widgets.append(widget) - self.emit('add', widget) - - @widget.connect - def on_close(e=None): - self.emit('close', widget) - self.remove(widget) - - if show: - widget.show() - - return widget - - def remove(self, widget): - """Remove a widget.""" - if widget in self._widgets: - logger.debug("Remove widget %s.", widget) - self._widgets.remove(widget) - else: - logger.debug("Unable to remove widget %s.", widget) - - -#------------------------------------------------------------------------------ -# Base GUI -#------------------------------------------------------------------------------ - -def _title(item): - """Default view model title.""" - if hasattr(item, 'name'): - name = item.name.capitalize() - else: - name = item.__class__.__name__.capitalize() - return name - - -def _assert_counters_equal(c_0, c_1): - c_0 = {(k, v) for (k, v) in c_0.items() if v > 0} - c_1 = {(k, v) for (k, v) in c_1.items() if v > 0} - assert c_0 == c_1 - - -def _show_shortcut(shortcut): - if isinstance(shortcut, string_types): - return shortcut - elif isinstance(shortcut, tuple): - return ', '.join(shortcut) - - -def _show_shortcuts(shortcuts, name=''): - print() - if name: - name = ' for ' + name - print('Keyboard shortcuts' + name) - for name in sorted(shortcuts): - print('{0:<40}: {1:s}'.format(name, _show_shortcut(shortcuts[name]))) - print() - - -class BaseGUI(EventEmitter): - """Base GUI. - - This object represents a main window with: - - * multiple dockable views - * user-exposed actions - * keyboard shortcuts - - Parameters - ---------- - - config : list - List of pairs `(name, kwargs)` to create default views. - vm_classes : dict - Dictionary `{name: view_model_class}`. - state : object - Default Qt GUI state. - shortcuts : dict - Dictionary `{function_name: keyboard_shortcut}`. - - Events - ------ - - add_view - close_view - reset_gui - - """ - - _default_shortcuts = { - 'exit': 'ctrl+q', - 'enable_snippet_mode': ':', - } - - def __init__(self, - model=None, - vm_classes=None, - state=None, - shortcuts=None, - snippets=None, - config=None, - settings=None, - ): - super(BaseGUI, self).__init__() - self.settings = settings or {} - if state is None: - state = {} - self.model = model - # Shortcuts. - s = self._default_shortcuts.copy() - s.update(shortcuts or {}) - self._shortcuts = s - self._snippets = snippets or {} - # GUI state and config. - self._state = state - if config is None: - config = [(name, {}) for name in (vm_classes or {})] - self._config = config - # Create the dock window. - self._dock = DockWindow(title=self.title) - self._view_creator = WidgetCreator(widget_classes=vm_classes) - self._initialize_views() - self._load_geometry_state(state) - self._set_default_shortcuts() - self._create_actions() - self._set_default_view_connections() - - def _initialize_views(self): - self._load_config(self._config, - requested_count=self._state.get('view_count', None), - ) - - #-------------------------------------------------------------------------- - # Methods to override - #-------------------------------------------------------------------------- - - @property - def title(self): - """Title of the main window. - - May be overriden. - - """ - return 'Base GUI' - - def _set_default_view_connections(self): - """Set view connections. - - May be overriden. - - Example: - - ```python - @self.main_window.connect_views('view_1', 'view_2') - def f(view_1, view_2): - # Called for every pair of views of type view_1 and view_2. - pass - ``` - - """ - pass - - def _create_actions(self): - """Create default actions in the GUI. - - The `_add_gui_shortcut()` method can be used. - - Must be overriden. - - """ - pass - - def _view_model_kwargs(self, name): - return {} - - def on_open(self): - """Callback function when the model is opened. - - Must be overriden. - - """ - pass - - #-------------------------------------------------------------------------- - # Internal methods - #-------------------------------------------------------------------------- - - def _load_config(self, config=None, - current_count=None, - requested_count=None): - """Load a GUI configuration dictionary.""" - config = config or [] - current_count = current_count or {} - requested_count = requested_count or Counter([name - for name, _ in config]) - for name, kwargs in config: - # Add the right number of views of each type. - if current_count.get(name, 0) >= requested_count.get(name, 0): - continue - logger.debug("Adding %s view in GUI.", name) - # GUI-specific keyword arguments position, size, maximized - self.add_view(name, **kwargs) - if name not in current_count: - current_count[name] = 0 - current_count[name] += 1 - _assert_counters_equal(current_count, requested_count) - - def _load_geometry_state(self, gui_state): - if gui_state: - self._dock.restore_geometry_state(gui_state) - - def _remove_actions(self): - self._dock.remove_actions() - - def _set_default_shortcuts(self): - for name, shortcut in self._default_shortcuts.items(): - self._add_gui_shortcut(name) - - def _add_gui_shortcut(self, method_name): - """Helper function to add a GUI action with a keyboard shortcut.""" - # Get the keyboard shortcut for this method. - shortcut = self._shortcuts.get(method_name, None) - - def callback(): - return getattr(self, method_name)() - - # Bind the shortcut to the method. - self._dock.add_action(method_name, - callback, - shortcut=shortcut, - ) - - #-------------------------------------------------------------------------- - # Snippet methods - #-------------------------------------------------------------------------- - - @property - def status_message(self): - """Message in the status bar.""" - return str(self._dock.status_message) - - @status_message.setter - def status_message(self, value): - self._dock.status_message = str(value) - - # HACK: Unicode characters do not appear to work on Python 2 - _snippet_message_cursor = '\u200A\u258C' if PY3 else '' - - @property - def _snippet_message(self): - """This is used to write a snippet message in the status bar. - - A cursor is appended at the end. - - """ - n = len(self.status_message) - n_cur = len(self._snippet_message_cursor) - return self.status_message[:n - n_cur] - - @_snippet_message.setter - def _snippet_message(self, value): - self.status_message = value + self._snippet_message_cursor - - def process_snippet(self, snippet): - """Processes a snippet. - - May be overriden. - - """ - assert snippet[0] == ':' - snippet = snippet[1:] - split = snippet.split(' ') - cmd = split[0] - snippet = snippet[len(cmd):].strip() - func = self._snippets.get(cmd, None) - if func is None: - logger.info("The snippet `%s` could not be found.", cmd) - return - try: - logger.info("Processing snippet `%s`.", cmd) - func(self, snippet) - except Exception as e: - logger.warn("Error when executing snippet `%s`: %s.", - cmd, str(e)) - - def _snippet_action_name(self, char): - return self._snippet_chars.index(char) - - _snippet_chars = 'abcdefghijklmnopqrstuvwxyz0123456789 ._,+*-=:()' - - def _create_snippet_actions(self): - # One action per allowed character. - for i, char in enumerate(self._snippet_chars): - - def _make_func(char): - def callback(): - self._snippet_message += char - return callback - - self._dock.add_action('snippet_{}'.format(i), - shortcut=char, - callback=_make_func(char), - ) - - def backspace(): - if self._snippet_message == ':': - return - self._snippet_message = self._snippet_message[:-1] - - def enter(): - self.process_snippet(self._snippet_message) - self.disable_snippet_mode() - - self._dock.add_action('snippet_backspace', - shortcut='backspace', - callback=backspace, - ) - self._dock.add_action('snippet_activate', - shortcut=('enter', 'return'), - callback=enter, - ) - self._dock.add_action('snippet_disable', - shortcut='escape', - callback=self.disable_snippet_mode, - ) - - def enable_snippet_mode(self): - logger.info("Snippet mode enabled, press `escape` to leave this mode.") - self._remove_actions() - self._create_snippet_actions() - self._snippet_message = ':' - - def disable_snippet_mode(self): - self.status_message = '' - # Reestablishes the shortcuts. - self._remove_actions() - self._set_default_shortcuts() - self._create_actions() - logger.info("Snippet mode disabled.") - - #-------------------------------------------------------------------------- - # Public methods - #-------------------------------------------------------------------------- - - def show(self): - """Show the GUI""" - self._dock.show() - - @property - def main_window(self): - """Main Qt window.""" - return self._dock - - def add_view(self, item, title=None, **kwargs): - """Add a new view instance to the GUI.""" - position = kwargs.pop('position', None) - # Item may be a string. - if isinstance(item, string_types): - name = item - # Default view model kwargs. - kwargs.update(self._view_model_kwargs(name)) - # View model parameters from settings. - vm_class = self._view_creator._widget_classes[name] - kwargs.update(vm_class.get_params(self.settings)) - item = self._view_creator.add(item, **kwargs) - # Set the view name if necessary. - if not item._view_name: - item._view_name = name - # Default dock title. - if title is None: - title = _title(item) - # Get the underlying view. - view = item.view if isinstance(item, BaseViewModel) else item - # Add the view to the main window. - dw = self._dock.add_view(view, title=title, position=position) - - # Dock widget close event. - @dw.connect_ - def on_close_widget(): - self._view_creator.remove(item) - self.emit('close_view', item) - - # Make sure the callback above is called when the dock widget - # is closed directly. - # View model close event. - @item.connect - def on_close_view(e=None): - dw.close() - - # Call the user callback function. - if 'on_view_open' in self.settings: - self.settings['on_view_open'](self, item) - - self.emit('add_view', item) - - def get_views(self, *names): - """Return the list of views of a given type.""" - return self._view_creator.get(*names) - - def connect_views(self, name_0, name_1): - """Decorator for a function called on every pair of views of a - given type.""" - def wrap(func): - for view_0 in self.get_views(name_0): - for view_1 in self.get_views(name_1): - func(view_0, view_1) - return wrap - - @property - def views(self): - """List of all open views.""" - return self.get_views() - - def view_count(self): - """Number of views of each type.""" - return {name: len(self.get_views(name)) - for name in self._view_creator.widget_classes.keys()} - - def reset_gui(self): - """Reset the GUI configuration.""" - count = self.view_count() - self._load_config(self._config, - current_count=count, - ) - self.emit('reset_gui') - - def show_shortcuts(self): - """Show the list of all keyboard shortcuts.""" - _show_shortcuts(self._shortcuts, name=self.__class__.__name__) - - def isVisible(self): - return self._dock.isVisible() - - def close(self): - """Close the GUI.""" - self._dock.close() - - def exit(self): - """Close the GUI.""" - self.close() diff --git a/phy/gui/tests/test_base.py b/phy/gui/tests/test_base.py deleted file mode 100644 index 742d8a5e8..000000000 --- a/phy/gui/tests/test_base.py +++ /dev/null @@ -1,193 +0,0 @@ -# -*- coding: utf-8 -*-1 -from __future__ import print_function - -"""Tests of base classes.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -from pytest import raises, mark - -from ..base import (BaseViewModel, - HTMLViewModel, - WidgetCreator, - BaseGUI, - ) -from ..qt import (QtGui, - ) -from ...utils.event import EventEmitter - - -# Skip these tests in "make test-quick". -pytestmark = mark.long() - - -#------------------------------------------------------------------------------ -# Base tests -#------------------------------------------------------------------------------ - -def test_base_view_model(qtbot): - class MyViewModel(BaseViewModel): - _view_name = 'main_window' - _imported_params = ('text',) - - def _create_view(self, text='', position=None, size=None): - view = QtGui.QMainWindow() - view.setWindowTitle(text) - if size: - view.resize(*size) - if position: - view.move(*position) - return view - - size = (400, 100) - text = 'hello' - - vm = MyViewModel(text=text, size=size, position=(100, 100)) - qtbot.addWidget(vm.view) - vm.show() - - assert vm.view.windowTitle() == text - assert vm.text == text - assert vm.size == size - assert (vm.view.width(), vm.view.height()) == size - - vm.close() - - -def test_html_view_model(qtbot): - - class MyHTMLViewModel(HTMLViewModel): - def get_html(self, **kwargs): - return 'hello world!' - - vm = MyHTMLViewModel() - vm.show() - qtbot.addWidget(vm.view) - vm.close() - - -def test_widget_creator(): - - class MyWidget(EventEmitter): - """Mock widget.""" - def __init__(self, param=None): - super(MyWidget, self).__init__() - self.name = 'my_widget' - self._shown = False - self.param = param - - @property - def shown(self): - return self._shown - - def close(self, e=None): - self.emit('close', e) - self._shown = False - - def show(self): - self._shown = True - - widget_classes = {'my_widget': MyWidget} - - wc = WidgetCreator(widget_classes=widget_classes) - assert not wc.get() - assert not wc.get('my_widget') - - with raises(ValueError): - wc.add('my_widget_bis') - - for show in (False, True): - w = wc.add('my_widget', show=show, param=show) - assert len(wc.get()) == 1 - assert len(wc.get('my_widget')) == 1 - - assert w.shown is show - assert w.param is show - w.show() - assert w.shown - - w.close() - assert not wc.get() - assert not wc.get('my_widget') - - -def test_base_gui(qtbot): - - class V1(HTMLViewModel): - def get_html(self, **kwargs): - return 'view 1' - - class V2(HTMLViewModel): - def get_html(self, **kwargs): - return 'view 2' - - class V3(HTMLViewModel): - def get_html(self, **kwargs): - return 'view 3' - - vm_classes = {'v1': V1, 'v2': V2, 'v3': V3} - - config = [('v1', {'position': 'right'}), - ('v2', {'position': 'left'}), - ('v2', {'position': 'bottom'}), - ('v3', {'position': 'left'}), - ] - - shortcuts = {'test': 't'} - - _message = [] - - def _snippet(gui, args): - _message.append(args) - - snippets = {'hello': _snippet} - - class TestGUI(BaseGUI): - def __init__(self): - super(TestGUI, self).__init__(vm_classes=vm_classes, - config=config, - shortcuts=shortcuts, - snippets=snippets, - ) - - def _create_actions(self): - self._add_gui_shortcut('test') - - def test(self): - self.show_shortcuts() - self.reset_gui() - - gui = TestGUI() - qtbot.addWidget(gui.main_window) - gui.show() - - gui.test() - - # Test snippet mode. - gui.enable_snippet_mode() - - def _keystroke(char=None): - """Simulate a keystroke.""" - i = gui._snippet_action_name(char) - getattr(gui.main_window, 'snippet_{}'.format(i))() - - gui.enable_snippet_mode() - for c in 'hello world': - _keystroke(c) - assert gui.status_message == ':hello world' + gui._snippet_message_cursor - gui.main_window.snippet_activate() - assert _message == ['world'] - - # Test views. - v2 = gui.get_views('v2') - assert len(v2) == 2 - v2[1].close() - - v3 = gui.get_views('v3') - v3[0].close() - - gui.reset_gui() - - gui.close() diff --git a/phy/gui/tests/test_utils.py b/phy/gui/tests/test_utils.py new file mode 100644 index 000000000..63db87902 --- /dev/null +++ b/phy/gui/tests/test_utils.py @@ -0,0 +1,17 @@ +# -*- coding: utf-8 -*- + +"""Test HTML/CSS utilities.""" + +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- + +from .._utils import _read + + +# ----------------------------------------------------------------------------- +# Utilities +# ----------------------------------------------------------------------------- + +def test_read(): + assert _read('wrap_qt.html') From 403d156dbe1dff9009b0ba71a3d832f8fef59a57 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 23 Sep 2015 21:32:23 +0200 Subject: [PATCH 0089/1059] WIP: increase coverage in phy.gui --- phy/gui/dock.py | 4 ---- phy/gui/tests/test_dock.py | 14 +++++++++++++- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/phy/gui/dock.py b/phy/gui/dock.py index f2a07ac4c..d922437dd 100644 --- a/phy/gui/dock.py +++ b/phy/gui/dock.py @@ -66,10 +66,6 @@ def __init__(self, self._status_bar = QtGui.QStatusBar() self.setStatusBar(self._status_bar) - def keyReleaseEvent(self, e): - self.emit('keystroke', e.key(), e.text()) - return super(DockWindow, self).keyReleaseEvent(e) - # Events # ------------------------------------------------------------------------- diff --git a/phy/gui/tests/test_dock.py b/phy/gui/tests/test_dock.py index ac6164c88..f5420c59a 100644 --- a/phy/gui/tests/test_dock.py +++ b/phy/gui/tests/test_dock.py @@ -52,12 +52,24 @@ def on_show_gui(): # Adding an action twice has no effect. gui.add_action('test', lambda: None) - gui.add_view(_create_canvas(), 'view1', floating=True) + view = gui.add_view(_create_canvas(), 'view1', floating=True) gui.add_view(_create_canvas(), 'view2') + view.setFloating(False) gui.show() # qtbot.waitForWindowShown(gui) assert len(gui.list_views('view')) == 2 + + # Check that the close_widget event is fired when the dock widget is + # closed. + _close = [] + + @view.connect_ + def on_close_widget(): + _close.append(0) + view.close() + assert _close == [0] + gui.close() From 78adc7fb6d33ef152402e59db2571da52d7fb690 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 23 Sep 2015 21:32:57 +0200 Subject: [PATCH 0090/1059] Increase coverage in phy.gui --- phy/gui/tests/test_dock.py | 1 + 1 file changed, 1 insertion(+) diff --git a/phy/gui/tests/test_dock.py b/phy/gui/tests/test_dock.py index f5420c59a..f85034bc6 100644 --- a/phy/gui/tests/test_dock.py +++ b/phy/gui/tests/test_dock.py @@ -51,6 +51,7 @@ def on_show_gui(): gui.add_action('test', lambda: None) # Adding an action twice has no effect. gui.add_action('test', lambda: None) + gui.remove_actions() view = gui.add_view(_create_canvas(), 'view1', floating=True) gui.add_view(_create_canvas(), 'view2') From a2e5a1d396439794c762a55fa40102f9fc0f4efd Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 23 Sep 2015 21:47:58 +0200 Subject: [PATCH 0091/1059] All tests pass --- .coveragerc | 1 + phy/__init__.py | 2 +- phy/cluster/manual/tests/test_clustering.py | 8 ++++---- phy/cluster/manual/wizard.py | 2 +- phy/gui/qt.py | 1 - phy/utils/settings.py | 2 ++ phy/utils/tests/test_testing.py | 2 +- 7 files changed, 10 insertions(+), 8 deletions(-) diff --git a/.coveragerc b/.coveragerc index 22e6810ed..6ce8e8f8b 100644 --- a/.coveragerc +++ b/.coveragerc @@ -4,3 +4,4 @@ source = phy omit = */phy/ext/* */phy/utils/tempdir.py + */default_settings.py diff --git a/phy/__init__.py b/phy/__init__.py index 94e0a8fe0..5acfc8b14 100644 --- a/phy/__init__.py +++ b/phy/__init__.py @@ -72,7 +72,7 @@ def string_handler(level='INFO'): logger.info("Activate DEBUG level.") -def test(): +def test(): # pragma: no cover """Run the full testing suite of phy.""" import pytest pytest.main() diff --git a/phy/cluster/manual/tests/test_clustering.py b/phy/cluster/manual/tests/test_clustering.py index 0b4be635a..602ac3fb1 100644 --- a/phy/cluster/manual/tests/test_clustering.py +++ b/phy/cluster/manual/tests/test_clustering.py @@ -85,16 +85,16 @@ def test_extend_assignment(): for to in (123, 0, 1, 2, 3): clusters_rel = [123] * len(spike_ids) new_spike_ids, new_cluster_ids = _extend_assignment(spike_ids, - spike_clusters, - clusters_rel) + spike_clusters, + clusters_rel) ae(new_spike_ids, [0, 2, 6]) ae(new_cluster_ids, [10, 10, 11]) # Second case: we assign the spikes to different clusters. clusters_rel = [0, 1] new_spike_ids, new_cluster_ids = _extend_assignment(spike_ids, - spike_clusters, - clusters_rel) + spike_clusters, + clusters_rel) ae(new_spike_ids, [0, 2, 6]) ae(new_cluster_ids, [10, 11, 12]) diff --git a/phy/cluster/manual/wizard.py b/phy/cluster/manual/wizard.py index 8b064d3fa..fc4894b5d 100644 --- a/phy/cluster/manual/wizard.py +++ b/phy/cluster/manual/wizard.py @@ -45,7 +45,7 @@ def _find_first(items, filter=None): def _previous(items, current, filter=None): if current not in items: - logger.debug("%d is not in %s.", current, items) + logger.debug("%s is not in %s.", current, items) return i = items.index(current) if i == 0: diff --git a/phy/gui/qt.py b/phy/gui/qt.py index a6ca3479b..6b342f46e 100644 --- a/phy/gui/qt.py +++ b/phy/gui/qt.py @@ -6,7 +6,6 @@ # Imports # ----------------------------------------------------------------------------- -import contextlib import logging import os import sys diff --git a/phy/utils/settings.py b/phy/utils/settings.py index 112ed397e..3aa24708c 100644 --- a/phy/utils/settings.py +++ b/phy/utils/settings.py @@ -103,8 +103,10 @@ def load(self, path): try: if op.splitext(path)[1] == '.py': self._update(_read_python(path)) + logger.debug("Read settings file %s.", path) elif op.splitext(path)[1] == '.json': self._update(_load_json(path)) + logger.debug("Read settings file %s.", path) else: logger.warn("The settings file %s must have the extension " "'.py' or '.json'.", path) diff --git a/phy/utils/tests/test_testing.py b/phy/utils/tests/test_testing.py index 00e642ad8..8381ea433 100644 --- a/phy/utils/tests/test_testing.py +++ b/phy/utils/tests/test_testing.py @@ -33,7 +33,7 @@ def test_captured_output(): def test_assert_equal(): d = {'a': {'b': np.random.rand(5), 3: 'c'}, 'b': 2.} d_bis = deepcopy(d) - d_bis['a']['b'] = d_bis['a']['b'] + 1e-8 + d_bis['a']['b'] = d_bis['a']['b'] + 1e-10 _assert_equal(d, d_bis) From 4d760b0a64fa146cab7b862d1379bc146ba0b254 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 24 Sep 2015 10:32:39 +0200 Subject: [PATCH 0092/1059] Improve DockWindow.shortcut decorator --- phy/gui/dock.py | 7 ++++--- phy/gui/tests/test_dock.py | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/phy/gui/dock.py b/phy/gui/dock.py index d922437dd..75a721baf 100644 --- a/phy/gui/dock.py +++ b/phy/gui/dock.py @@ -27,12 +27,13 @@ def _title(widget): class DockWindow(QtGui.QMainWindow): """A Qt main window holding docking Qt or VisPy widgets. + `DockWindow` derives from `QMainWindow`. + Events ------ close_gui show_gui - keystroke Note ---- @@ -133,10 +134,10 @@ def remove_actions(self): for name in names: self.remove_action(name) - def shortcut(self, name, key=None): + def shortcut(self, key=None, name=None): """Decorator to add a global keyboard shortcut.""" def wrap(func): - self.add_action(name, shortcut=key, callback=func) + self.add_action(name or func.__name__, shortcut=key, callback=func) return wrap # Views diff --git a/phy/gui/tests/test_dock.py b/phy/gui/tests/test_dock.py index f85034bc6..211b033f2 100644 --- a/phy/gui/tests/test_dock.py +++ b/phy/gui/tests/test_dock.py @@ -89,7 +89,7 @@ def test_dock_state(qtbot): _press = [] - @gui.shortcut('press', 'ctrl+g') + @gui.shortcut('ctrl+g') def press(): _press.append(0) From d7d4f78481daf3f0fa64d0e5ac1a651e954916b9 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 24 Sep 2015 10:40:52 +0200 Subject: [PATCH 0093/1059] Improve dock --- phy/gui/dock.py | 36 +++++++++++++++++++----------------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/phy/gui/dock.py b/phy/gui/dock.py index 75a721baf..05bd2d609 100644 --- a/phy/gui/dock.py +++ b/phy/gui/dock.py @@ -24,6 +24,24 @@ def _title(widget): # Qt windows # ----------------------------------------------------------------------------- +class DockWidget(QtGui.QDockWidget): + """A QDockWidget that can emit events.""" + def __init__(self, *args, **kwargs): + super(DockWidget, self).__init__(*args, **kwargs) + self._event = EventEmitter() + + def emit(self, *args, **kwargs): + return self._event.emit(*args, **kwargs) + + def connect_(self, *args, **kwargs): + self._event.connect(*args, **kwargs) + + def closeEvent(self, e): + """Qt slot when the window is closed.""" + self.emit('close_widget') + super(DockWidget, self).closeEvent(e) + + class DockWindow(QtGui.QMainWindow): """A Qt main window holding docking Qt or VisPy widgets. @@ -161,22 +179,6 @@ def add_view(self, except ImportError: # pragma: no cover pass - class DockWidget(QtGui.QDockWidget): - def __init__(self, *args, **kwargs): - super(DockWidget, self).__init__(*args, **kwargs) - self._event = EventEmitter() - - def emit(self, *args, **kwargs): - return self._event.emit(*args, **kwargs) - - def connect_(self, *args, **kwargs): - self._event.connect(*args, **kwargs) - - def closeEvent(self, e): - """Qt slot when the window is closed.""" - self.emit('close_widget') - super(DockWidget, self).closeEvent(e) - # Create the dock widget. dockwidget = DockWidget(self) dockwidget.setObjectName(title) @@ -221,7 +223,7 @@ def list_views(self, title='', is_visible=True): child.height() >= 10 ] - def view_count(self, is_visible=True): + def view_count(self): """Return the number of opened views.""" views = self.list_views() counts = defaultdict(lambda: 0) From 387039b66c50aac65216405ac9bedeb60dd8f376 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 24 Sep 2015 10:46:12 +0200 Subject: [PATCH 0094/1059] Add show_shortcuts() --- phy/gui/dock.py | 19 +++++++++++++++++++ phy/gui/tests/test_dock.py | 14 ++++++++++++-- 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/phy/gui/dock.py b/phy/gui/dock.py index 05bd2d609..5e92fb0e6 100644 --- a/phy/gui/dock.py +++ b/phy/gui/dock.py @@ -8,6 +8,8 @@ from collections import defaultdict +from six import string_types + from .qt import QtCore, QtGui from ..utils.event import EventEmitter @@ -20,6 +22,23 @@ def _title(widget): return str(widget.windowTitle()).lower() +def _show_shortcut(shortcut): + if isinstance(shortcut, string_types): + return shortcut + elif isinstance(shortcut, tuple): + return ', '.join(shortcut) + + +def _show_shortcuts(shortcuts, name=''): + print() + if name: + name = ' for ' + name + print('Keyboard shortcuts' + name) + for name in sorted(shortcuts): + print('{0:<40}: {1:s}'.format(name, _show_shortcut(shortcuts[name]))) + print() + + # ----------------------------------------------------------------------------- # Qt windows # ----------------------------------------------------------------------------- diff --git a/phy/gui/tests/test_dock.py b/phy/gui/tests/test_dock.py index 211b033f2..f88674bae 100644 --- a/phy/gui/tests/test_dock.py +++ b/phy/gui/tests/test_dock.py @@ -11,9 +11,9 @@ from vispy import app from ..qt import Qt -from ..dock import DockWindow +from ..dock import DockWindow, _show_shortcuts from phy.utils._color import _random_color - +from phy.utils.testing import captured_output # Skip these tests in "make test-quick". pytestmark = mark.long @@ -35,6 +35,16 @@ def on_draw(e): # pragma: no cover return c +def test_utils(): + shortcuts = { + 'test_1': 'ctrl+t', + 'test_2': ('ctrl+a', 'shift+b'), + } + with captured_output() as (stdout, stderr): + _show_shortcuts(shortcuts, 'test') + assert 'ctrl+a, shift+b' in stdout.getvalue() + + def test_dock_1(qtbot): gui = DockWindow(position=(200, 100), size=(100, 100)) From a9308657b7d78409b62822d0a19cb81d23c5a590 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 24 Sep 2015 11:20:56 +0200 Subject: [PATCH 0095/1059] WIP: add Actions and Snippets companion classes to DockWindow --- phy/gui/dock.py | 247 +++++++++++++++++++++++++++++-------- phy/gui/tests/test_dock.py | 73 ++++++++--- 2 files changed, 251 insertions(+), 69 deletions(-) diff --git a/phy/gui/dock.py b/phy/gui/dock.py index 5e92fb0e6..140e8ba5c 100644 --- a/phy/gui/dock.py +++ b/phy/gui/dock.py @@ -7,12 +7,15 @@ # ----------------------------------------------------------------------------- from collections import defaultdict +import logging -from six import string_types +from six import string_types, PY3 from .qt import QtCore, QtGui from ..utils.event import EventEmitter +logger = logging.getLogger(__name__) + # ----------------------------------------------------------------------------- # Qt utilities @@ -39,6 +42,197 @@ def _show_shortcuts(shortcuts, name=''): print() +# ----------------------------------------------------------------------------- +# Companion class +# ----------------------------------------------------------------------------- + +class Actions(EventEmitter): + """Handle GUI actions.""" + def __init__(self): + super(Actions, self).__init__() + self._dock = None + self._actions = {} + + def reset(self): + """Reset the actions. + + All actions should be registered here, as follows: + + ```python + @actions.connect + def on_reset(): + actions.add(...) + actions.add(...) + ... + ``` + + """ + self.remove_all() + self.emit('reset') + + def attach(self, dock): + self._dock = dock + + # Default exit action. + @self.shortcut('ctrl+q') + def exit(): + dock.close() + + def add(self, name, callback=None, shortcut=None, + checkable=False, checked=False): + """Add an action with a keyboard shortcut.""" + if name in self._actions: + return + action = QtGui.QAction(name, self._dock) + action.triggered.connect(callback) + action.setCheckable(checkable) + action.setChecked(checked) + if shortcut: + if not isinstance(shortcut, (tuple, list)): + shortcut = [shortcut] + for key in shortcut: + action.setShortcut(key) + if self._dock: + self._dock.addAction(action) + self._actions[name] = action + if callback: + setattr(self, name, callback) + return action + + def remove(self, name): + """Remove an action.""" + if self._dock: + self._dock.removeAction(self._actions[name]) + del self._actions[name] + delattr(self, name) + + def remove_all(self): + """Remove all actions.""" + names = sorted(self._actions.keys()) + for name in names: + self.remove(name) + + def shortcut(self, key=None, name=None): + """Decorator to add a global keyboard shortcut.""" + def wrap(func): + self.add(name or func.__name__, shortcut=key, callback=func) + return wrap + + +class Snippets(object): + # HACK: Unicode characters do not appear to work on Python 2 + cursor = '\u200A\u258C' if PY3 else '' + _snippet_chars = 'abcdefghijklmnopqrstuvwxyz0123456789 ._,+*-=:()' + + def __init__(self): + self._dock = None + + def attach(self, dock, actions): + self._dock = dock + self._actions = actions + + # Register snippet mode shortcut. + @actions.connect + def on_reset(): + @actions.shortcut(':') + def enable_snippet_mode(): + self.mode_on() + + @property + def command(self): + """This is used to write a snippet message in the status bar. + + A cursor is appended at the end. + + """ + n = len(self._dock.status_message) + n_cur = len(self.cursor) + return self._dock.status_message[:n - n_cur] + + @command.setter + def command(self, value): + self._dock.status_message = value + self.cursor + + def run(self, snippet): + """Executes a snippet. + + May be overriden. + + """ + assert snippet[0] == ':' + snippet = snippet[1:] + split = snippet.split(' ') + cmd = split[0] + snippet = snippet[len(cmd):].strip() + func = self._snippets.get(cmd, None) + if func is None: + logger.info("The snippet `%s` could not be found.", cmd) + return + try: + logger.info("Processing snippet `%s`.", cmd) + func(self, snippet) + except Exception as e: + logger.warn("Error when executing snippet `%s`: %s.", + cmd, str(e)) + + def _create_snippet_actions(self): + """Delete all existing actions, and add mock ones for snippet + keystrokes. + + Used to enable snippet mode. + + """ + self._actions.remove_all() + + # One action per allowed character. + for i, char in enumerate(self._snippet_chars): + + def _make_func(char): + def callback(): + self.command += char + return callback + + self._actions.add('snippet_{}'.format(i), + shortcut=char, + callback=_make_func(char), + ) + + def backspace(): + if self.command == ':': + return + self.command = self.command[:-1] + + def enter(): + self.run(self.command) + self.disable_snippet_mode() + + self._actions.add('snippet_backspace', + shortcut='backspace', + callback=backspace, + ) + self._actions.add('snippet_activate', + shortcut=('enter', 'return'), + callback=enter, + ) + self._actions.add('snippet_disable', + shortcut='escape', + callback=self.disable_snippet_mode, + ) + + def mode_on(self): + logger.info("Snippet mode enabled, press `escape` to leave this mode.") + # Remove all existing actions, and replace them by snippet keystroke + # actions. + self._create_snippet_actions() + self.command = ':' + + def mode_off(self): + self._dock.status_message = '' + # Reestablishes the shortcuts. + self._actions.reset() + logger.info("Snippet mode disabled.") + + # ----------------------------------------------------------------------------- # Qt windows # ----------------------------------------------------------------------------- @@ -84,7 +278,6 @@ def __init__(self, title=None, ): super(DockWindow, self).__init__() - self._actions = {} if title is None: title = 'phy' self.setWindowTitle(title) @@ -131,52 +324,6 @@ def show(self): self.emit('show_gui') super(DockWindow, self).show() - # Actions - # ------------------------------------------------------------------------- - - def add_action(self, - name, - callback=None, - shortcut=None, - checkable=False, - checked=False, - ): - """Add an action with a keyboard shortcut.""" - if name in self._actions: - return - action = QtGui.QAction(name, self) - action.triggered.connect(callback) - action.setCheckable(checkable) - action.setChecked(checked) - if shortcut: - if not isinstance(shortcut, (tuple, list)): - shortcut = [shortcut] - for key in shortcut: - action.setShortcut(key) - self.addAction(action) - self._actions[name] = action - if callback: - setattr(self, name, callback) - return action - - def remove_action(self, name): - """Remove an action.""" - self.removeAction(self._actions[name]) - del self._actions[name] - delattr(self, name) - - def remove_actions(self): - """Remove all actions.""" - names = sorted(self._actions.keys()) - for name in names: - self.remove_action(name) - - def shortcut(self, key=None, name=None): - """Decorator to add a global keyboard shortcut.""" - def wrap(func): - self.add_action(name or func.__name__, shortcut=key, callback=func) - return wrap - # Views # ------------------------------------------------------------------------- @@ -256,11 +403,11 @@ def view_count(self): @property def status_message(self): """The message in the status bar.""" - return self._status_bar.currentMessage() + return str(self._status_bar.currentMessage()) @status_message.setter def status_message(self, value): - self._status_bar.showMessage(value) + self._status_bar.showMessage(str(value)) # State # ------------------------------------------------------------------------- diff --git a/phy/gui/tests/test_dock.py b/phy/gui/tests/test_dock.py index f88674bae..8dc5e1096 100644 --- a/phy/gui/tests/test_dock.py +++ b/phy/gui/tests/test_dock.py @@ -6,12 +6,10 @@ # Imports #------------------------------------------------------------------------------ -from pytest import mark - -from vispy import app +from pytest import mark, yield_fixture from ..qt import Qt -from ..dock import DockWindow, _show_shortcuts +from ..dock import DockWindow, _show_shortcuts, Actions, Snippets from phy.utils._color import _random_color from phy.utils.testing import captured_output @@ -20,11 +18,12 @@ #------------------------------------------------------------------------------ -# Tests +# Utilities and fixtures #------------------------------------------------------------------------------ def _create_canvas(): """Create a VisPy canvas with a color background.""" + from vispy import app c = app.Canvas() c.color = _random_color() @@ -35,6 +34,25 @@ def on_draw(e): # pragma: no cover return c +@yield_fixture +def gui(): + yield DockWindow(position=(200, 100), size=(100, 100)) + + +@yield_fixture +def actions(): + yield Actions() + + +@yield_fixture +def snippets(): + yield Snippets() + + +#------------------------------------------------------------------------------ +# Tests +#------------------------------------------------------------------------------ + def test_utils(): shortcuts = { 'test_1': 'ctrl+t', @@ -45,6 +63,37 @@ def test_utils(): assert 'ctrl+a, shift+b' in stdout.getvalue() +def test_actions(actions): + actions.add('test', lambda: None) + # Adding an action twice has no effect. + actions.add('test', lambda: None) + actions.remove_all() + + +def test_snippets(snippets): + pass + + +def test_actions_dock(qtbot, gui, actions): + actions.attach(gui) + qtbot.addWidget(gui) + gui.show() + qtbot.waitForWindowShown(gui) + + _press = [] + + @actions.shortcut('ctrl+g') + def press(): + _press.append(0) + + qtbot.keyPress(gui, Qt.Key_G, Qt.ControlModifier) + assert _press == [0] + + +def test_snippets_dock(): + pass + + def test_dock_1(qtbot): gui = DockWindow(position=(200, 100), size=(100, 100)) @@ -58,11 +107,6 @@ def on_show_gui(): qtbot.keyPress(gui, Qt.Key_Control) qtbot.keyRelease(gui, Qt.Key_Control) - gui.add_action('test', lambda: None) - # Adding an action twice has no effect. - gui.add_action('test', lambda: None) - gui.remove_actions() - view = gui.add_view(_create_canvas(), 'view1', floating=True) gui.add_view(_create_canvas(), 'view2') view.setFloating(False) @@ -97,12 +141,6 @@ def test_dock_state(qtbot): gui = DockWindow(size=(100, 100)) qtbot.addWidget(gui) - _press = [] - - @gui.shortcut('ctrl+g') - def press(): - _press.append(0) - gui.add_view(_create_canvas(), 'view1') gui.add_view(_create_canvas(), 'view2') gui.add_view(_create_canvas(), 'view2') @@ -114,9 +152,6 @@ def on_close_gui(): gui.show() qtbot.waitForWindowShown(gui) - qtbot.keyPress(gui, Qt.Key_G, Qt.ControlModifier) - assert _press == [0] - assert len(gui.list_views('view')) == 3 assert gui.view_count() == { 'view1': 1, From eda99972fbc9d8e9df707bc90e7e2b9328a4e3cb Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 24 Sep 2015 11:34:42 +0200 Subject: [PATCH 0096/1059] Test Actions --- phy/gui/dock.py | 23 +++++++++++++++++++++-- phy/gui/tests/test_dock.py | 15 +++++++++++++++ 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/phy/gui/dock.py b/phy/gui/dock.py index 140e8ba5c..cbbe64969 100644 --- a/phy/gui/dock.py +++ b/phy/gui/dock.py @@ -28,11 +28,12 @@ def _title(widget): def _show_shortcut(shortcut): if isinstance(shortcut, string_types): return shortcut - elif isinstance(shortcut, tuple): + elif isinstance(shortcut, (tuple, list)): return ', '.join(shortcut) -def _show_shortcuts(shortcuts, name=''): +def _show_shortcuts(shortcuts, name=None): + name = name or '' print() if name: name = ' for ' + name @@ -71,6 +72,7 @@ def on_reset(): self.emit('reset') def attach(self, dock): + """Attach a DockWindow.""" self._dock = dock # Default exit action. @@ -78,6 +80,17 @@ def attach(self, dock): def exit(): dock.close() + @property + def shortcuts(self): + """A dictionary of action shortcuts.""" + return {name: action._shortcut_string + for name, action in self._actions.items()} + + def show_shortcuts(self): + """Print all shortcuts.""" + _show_shortcuts(self.shortcuts, + self._dock.title() if self._dock else None) + def add(self, name, callback=None, shortcut=None, checkable=False, checked=False): """Add an action with a keyboard shortcut.""" @@ -92,6 +105,10 @@ def add(self, name, callback=None, shortcut=None, shortcut = [shortcut] for key in shortcut: action.setShortcut(key) + # HACK: add the shortcut string to the QAction object so that + # it can be shown in show_shortcuts(). I don't manage to recover + # the key sequence string from a QAction using Qt. + action._shortcut_string = shortcut or '' if self._dock: self._dock.addAction(action) self._actions[name] = action @@ -122,6 +139,8 @@ def wrap(func): class Snippets(object): # HACK: Unicode characters do not appear to work on Python 2 cursor = '\u200A\u258C' if PY3 else '' + + # Allowed characters in snippet mode. _snippet_chars = 'abcdefghijklmnopqrstuvwxyz0123456789 ._,+*-=:()' def __init__(self): diff --git a/phy/gui/tests/test_dock.py b/phy/gui/tests/test_dock.py index 8dc5e1096..b2c82b1a5 100644 --- a/phy/gui/tests/test_dock.py +++ b/phy/gui/tests/test_dock.py @@ -67,6 +67,21 @@ def test_actions(actions): actions.add('test', lambda: None) # Adding an action twice has no effect. actions.add('test', lambda: None) + + # Create a shortcut and display it. + _captured = [] + + @actions.shortcut('h') + def show_my_shortcuts(): + with captured_output() as (stdout, stderr): + actions.show_shortcuts() + _captured.append(stdout.getvalue()) + + actions.show_my_shortcuts() + assert 'show_my_shortcuts' in _captured[0] + assert ': h' in _captured[0] + print(_captured[0]) + actions.remove_all() From 2c9f26aaf70d0715a7d3f71a6d2dfb50c97cc354 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 24 Sep 2015 11:45:15 +0200 Subject: [PATCH 0097/1059] WIP: refactor snippets --- phy/gui/dock.py | 107 +++++++++++++++++++++++------------------------- 1 file changed, 51 insertions(+), 56 deletions(-) diff --git a/phy/gui/dock.py b/phy/gui/dock.py index cbbe64969..68adfa033 100644 --- a/phy/gui/dock.py +++ b/phy/gui/dock.py @@ -80,17 +80,6 @@ def attach(self, dock): def exit(): dock.close() - @property - def shortcuts(self): - """A dictionary of action shortcuts.""" - return {name: action._shortcut_string - for name, action in self._actions.items()} - - def show_shortcuts(self): - """Print all shortcuts.""" - _show_shortcuts(self.shortcuts, - self._dock.title() if self._dock else None) - def add(self, name, callback=None, shortcut=None, checkable=False, checked=False): """Add an action with a keyboard shortcut.""" @@ -129,6 +118,17 @@ def remove_all(self): for name in names: self.remove(name) + @property + def shortcuts(self): + """A dictionary of action shortcuts.""" + return {name: action._shortcut_string + for name, action in self._actions.items()} + + def show_shortcuts(self): + """Print all shortcuts.""" + _show_shortcuts(self.shortcuts, + self._dock.title() if self._dock else None) + def shortcut(self, key=None, name=None): """Decorator to add a global keyboard shortcut.""" def wrap(func): @@ -172,8 +172,47 @@ def command(self): def command(self, value): self._dock.status_message = value + self.cursor + def _backspace(self): + """Erase the last character in the snippet command.""" + if self.command == ':': + return + self.command = self.command[:-1] + + def _enter(self): + """Disable the snippet mode and execute the command.""" + command = self.command + self.disable_snippet_mode() + self.run(command) + + def _create_snippet_actions(self): + """Delete all existing actions, and add mock ones for snippet + keystrokes. + + Used to enable snippet mode. + + """ + self._actions.remove_all() + + # One action per allowed character. + for i, char in enumerate(self._snippet_chars): + + def _make_func(char): + def callback(): + self.command += char + return callback + + self._actions.add('snippet_{}'.format(i), shortcut=char, + callback=_make_func(char)) + + self._actions.add('snippet_backspace', shortcut='backspace', + callback=self._backspace) + self._actions.add('snippet_activate', shortcut=('enter', 'return'), + callback=self._enter) + self._actions.add('snippet_disable', shortcut='escape', + callback=self.disable_snippet_mode) + def run(self, snippet): - """Executes a snippet. + """Executes a snippet command. May be overriden. @@ -194,50 +233,6 @@ def run(self, snippet): logger.warn("Error when executing snippet `%s`: %s.", cmd, str(e)) - def _create_snippet_actions(self): - """Delete all existing actions, and add mock ones for snippet - keystrokes. - - Used to enable snippet mode. - - """ - self._actions.remove_all() - - # One action per allowed character. - for i, char in enumerate(self._snippet_chars): - - def _make_func(char): - def callback(): - self.command += char - return callback - - self._actions.add('snippet_{}'.format(i), - shortcut=char, - callback=_make_func(char), - ) - - def backspace(): - if self.command == ':': - return - self.command = self.command[:-1] - - def enter(): - self.run(self.command) - self.disable_snippet_mode() - - self._actions.add('snippet_backspace', - shortcut='backspace', - callback=backspace, - ) - self._actions.add('snippet_activate', - shortcut=('enter', 'return'), - callback=enter, - ) - self._actions.add('snippet_disable', - shortcut='escape', - callback=self.disable_snippet_mode, - ) - def mode_on(self): logger.info("Snippet mode enabled, press `escape` to leave this mode.") # Remove all existing actions, and replace them by snippet keystroke From ba0dd63b9b2ef652e1eff0b59f8510abdbb8bfdd Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 24 Sep 2015 11:53:52 +0200 Subject: [PATCH 0098/1059] WIP --- phy/gui/dock.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/phy/gui/dock.py b/phy/gui/dock.py index 68adfa033..3a4e06cec 100644 --- a/phy/gui/dock.py +++ b/phy/gui/dock.py @@ -83,6 +83,7 @@ def exit(): def add(self, name, callback=None, shortcut=None, checkable=False, checked=False): """Add an action with a keyboard shortcut.""" + # TODO: add menu_name option and create menu bar if name in self._actions: return action = QtGui.QAction(name, self._dock) @@ -141,7 +142,9 @@ class Snippets(object): cursor = '\u200A\u258C' if PY3 else '' # Allowed characters in snippet mode. - _snippet_chars = 'abcdefghijklmnopqrstuvwxyz0123456789 ._,+*-=:()' + # A Qt shortcut will be created for every character. + _snippet_chars = ("abcdefghijklmnopqrstuvwxyz0123456789" + " ,.;:?!_-+~=*/\\(){}[]") def __init__(self): self._dock = None From dd4a0e02a793f4902f78ab64f96fd5857ef9e9d4 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 24 Sep 2015 12:20:17 +0200 Subject: [PATCH 0099/1059] Add snippet command parser --- phy/gui/dock.py | 28 ++++++++++++++++++++++++++++ phy/gui/tests/test_dock.py | 34 +++++++++++++++++++++++++++++++++- 2 files changed, 61 insertions(+), 1 deletion(-) diff --git a/phy/gui/dock.py b/phy/gui/dock.py index 3a4e06cec..5cd8e2e5b 100644 --- a/phy/gui/dock.py +++ b/phy/gui/dock.py @@ -43,6 +43,34 @@ def _show_shortcuts(shortcuts, name=None): print() +def _parse_arg(s): + try: + return int(s) + except ValueError: + pass + try: + return float(s) + except ValueError: + pass + return s + + +def _parse_list(s): + # Range: 'x-y' + if '-' in s: + m, M = map(_parse_arg, s.split('-')) + return tuple(range(m, M + 1)) + # List of ids: 'x,y,z' + elif ',' in s: + return tuple(map(_parse_arg, s.split(','))) + else: + return _parse_arg(s) + + +def _parse_snippet(s): + return list(map(_parse_list, s.split(' '))) + + # ----------------------------------------------------------------------------- # Companion class # ----------------------------------------------------------------------------- diff --git a/phy/gui/tests/test_dock.py b/phy/gui/tests/test_dock.py index b2c82b1a5..07fdf5678 100644 --- a/phy/gui/tests/test_dock.py +++ b/phy/gui/tests/test_dock.py @@ -9,7 +9,8 @@ from pytest import mark, yield_fixture from ..qt import Qt -from ..dock import DockWindow, _show_shortcuts, Actions, Snippets +from ..dock import (DockWindow, _show_shortcuts, Actions, Snippets, + _parse_snippet) from phy.utils._color import _random_color from phy.utils.testing import captured_output @@ -85,6 +86,37 @@ def show_my_shortcuts(): actions.remove_all() +def test_snippets_parse(): + def _check(args, expected): + snippet = 'snip ' + args + assert _parse_snippet(snippet) == ['snip'] + expected + + _check('a', ['a']) + _check('abc', ['abc']) + _check('a,b,c', [('a', 'b', 'c')]) + _check('a b,c', ['a', ('b', 'c')]) + + _check('1', [1]) + _check('10', [10]) + + _check('1.', [1.]) + _check('10.', [10.]) + _check('10.0', [10.0]) + + _check('0 1', [0, 1]) + _check('0 1.', [0, 1.]) + _check('0 1.0', [0, 1.]) + + _check('0,1', [(0, 1)]) + _check('0,10.', [(0, 10.)]) + _check('0. 1,10.', [0., (1, 10.)]) + + _check('2-7', [(2, 3, 4, 5, 6, 7)]) + _check('2 3-5', [2, (3, 4, 5)]) + + _check('a b,c d,2 3-5', ['a', ('b', 'c'), ('d', 2), (3, 4, 5)]) + + def test_snippets(snippets): pass From 8f7f2ddce6ecc57505a34aff2b6232d0952bc702 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 24 Sep 2015 12:27:06 +0200 Subject: [PATCH 0100/1059] WIP: snippets --- phy/gui/dock.py | 41 +++++++++++++++++++++++++++++------------ 1 file changed, 29 insertions(+), 12 deletions(-) diff --git a/phy/gui/dock.py b/phy/gui/dock.py index 5cd8e2e5b..c4ba4295f 100644 --- a/phy/gui/dock.py +++ b/phy/gui/dock.py @@ -43,7 +43,12 @@ def _show_shortcuts(shortcuts, name=None): print() +# ----------------------------------------------------------------------------- +# Snippet parsing utilities +# ----------------------------------------------------------------------------- + def _parse_arg(s): + """Parse a number or string.""" try: return int(s) except ValueError: @@ -56,6 +61,7 @@ def _parse_arg(s): def _parse_list(s): + """Parse a comma-separated list of values (strings or numbers).""" # Range: 'x-y' if '-' in s: m, M = map(_parse_arg, s.split('-')) @@ -68,6 +74,7 @@ def _parse_list(s): def _parse_snippet(s): + """Parse an entire snippet command.""" return list(map(_parse_list, s.split(' '))) @@ -134,6 +141,18 @@ def add(self, name, callback=None, shortcut=None, setattr(self, name, callback) return action + def get_name_from_char(self, char): + """Return an action name from its defining character. + + The symbol `&` needs to appear before that character in the action + name. + + """ + to_find = '&' + char + for name, action in self._actions.items(): + if to_find in name: + return action + def remove(self, name): """Remove an action.""" if self._dock: @@ -245,24 +264,22 @@ def callback(): def run(self, snippet): """Executes a snippet command. - May be overriden. + May be overridden. """ assert snippet[0] == ':' snippet = snippet[1:] - split = snippet.split(' ') - cmd = split[0] - snippet = snippet[len(cmd):].strip() - func = self._snippets.get(cmd, None) - if func is None: - logger.info("The snippet `%s` could not be found.", cmd) - return + snippet_args = _parse_snippet(snippet) + character = snippet_args[0] + name = self._actions.get_name_from_char(character) + if name is None: + logger.info("The snippet `%s` could not be found.", character) + func = getattr(self._actions, name) try: - logger.info("Processing snippet `%s`.", cmd) - func(self, snippet) + logger.info("Processing snippet `%s`.", snippet) + func(*snippet_args[1:]) except Exception as e: - logger.warn("Error when executing snippet `%s`: %s.", - cmd, str(e)) + logger.warn("Error when executing snippet: %s.", str(e)) def mode_on(self): logger.info("Snippet mode enabled, press `escape` to leave this mode.") From 0c8d6b3a98db5615b6e5b8fd1fb58c3049d369cf Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 24 Sep 2015 13:00:01 +0200 Subject: [PATCH 0101/1059] WIP: snippets --- phy/gui/dock.py | 25 +++++++++----------- phy/gui/tests/test_dock.py | 47 +++++++++++++++++++++++--------------- 2 files changed, 39 insertions(+), 33 deletions(-) diff --git a/phy/gui/dock.py b/phy/gui/dock.py index c4ba4295f..fd70b5f44 100644 --- a/phy/gui/dock.py +++ b/phy/gui/dock.py @@ -115,7 +115,7 @@ def attach(self, dock): def exit(): dock.close() - def add(self, name, callback=None, shortcut=None, + def add(self, name, callback=None, shortcut=None, alias=None, checkable=False, checked=False): """Add an action with a keyboard shortcut.""" # TODO: add menu_name option and create menu bar @@ -134,6 +134,9 @@ def add(self, name, callback=None, shortcut=None, # it can be shown in show_shortcuts(). I don't manage to recover # the key sequence string from a QAction using Qt. action._shortcut_string = shortcut or '' + # The alias is used in snippets. By default it is the character after & + action._alias = alias or (name[name.index('&') + 1] + if '&' in name else None) if self._dock: self._dock.addAction(action) self._actions[name] = action @@ -141,17 +144,11 @@ def add(self, name, callback=None, shortcut=None, setattr(self, name, callback) return action - def get_name_from_char(self, char): - """Return an action name from its defining character. - - The symbol `&` needs to appear before that character in the action - name. - - """ - to_find = '&' + char + def from_alias(self, alias): + """Return an action name from its alias.""" for name, action in self._actions.items(): - if to_find in name: - return action + if action._alias == alias: + return name def remove(self, name): """Remove an action.""" @@ -270,10 +267,10 @@ def run(self, snippet): assert snippet[0] == ':' snippet = snippet[1:] snippet_args = _parse_snippet(snippet) - character = snippet_args[0] - name = self._actions.get_name_from_char(character) + alias = snippet_args[0] + name = self._actions.from_alias(alias) if name is None: - logger.info("The snippet `%s` could not be found.", character) + logger.info("The snippet `%s` could not be found.", alias) func = getattr(self._actions, name) try: logger.info("Processing snippet `%s`.", snippet) diff --git a/phy/gui/tests/test_dock.py b/phy/gui/tests/test_dock.py index 07fdf5678..00848a045 100644 --- a/phy/gui/tests/test_dock.py +++ b/phy/gui/tests/test_dock.py @@ -51,10 +51,10 @@ def snippets(): #------------------------------------------------------------------------------ -# Tests +# Test actions #------------------------------------------------------------------------------ -def test_utils(): +def test_shortcuts(): shortcuts = { 'test_1': 'ctrl+t', 'test_2': ('ctrl+a', 'shift+b'), @@ -86,6 +86,26 @@ def show_my_shortcuts(): actions.remove_all() +def test_actions_dock(qtbot, gui, actions): + actions.attach(gui) + qtbot.addWidget(gui) + gui.show() + qtbot.waitForWindowShown(gui) + + _press = [] + + @actions.shortcut('ctrl+g') + def press(): + _press.append(0) + + qtbot.keyPress(gui, Qt.Key_G, Qt.ControlModifier) + assert _press == [0] + + +#------------------------------------------------------------------------------ +# Test snippets +#------------------------------------------------------------------------------ + def test_snippets_parse(): def _check(args, expected): snippet = 'snip ' + args @@ -118,29 +138,18 @@ def _check(args, expected): def test_snippets(snippets): + # TODO pass -def test_actions_dock(qtbot, gui, actions): - actions.attach(gui) - qtbot.addWidget(gui) - gui.show() - qtbot.waitForWindowShown(gui) - - _press = [] - - @actions.shortcut('ctrl+g') - def press(): - _press.append(0) - - qtbot.keyPress(gui, Qt.Key_G, Qt.ControlModifier) - assert _press == [0] - - -def test_snippets_dock(): +def test_snippets_dock(qtbot, gui, snippets): pass +#------------------------------------------------------------------------------ +# Test dock +#------------------------------------------------------------------------------ + def test_dock_1(qtbot): gui = DockWindow(position=(200, 100), size=(100, 100)) From 4d1c094ec70b5d78109be19bf87b3abb06bcb57b Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 24 Sep 2015 13:23:16 +0200 Subject: [PATCH 0102/1059] WIP: test snippets --- phy/gui/dock.py | 40 ++++++++++++++++++++++++------------ phy/gui/tests/test_dock.py | 42 ++++++++++++++++++++++++++++++-------- 2 files changed, 61 insertions(+), 21 deletions(-) diff --git a/phy/gui/dock.py b/phy/gui/dock.py index fd70b5f44..d6300cdf6 100644 --- a/phy/gui/dock.py +++ b/phy/gui/dock.py @@ -119,6 +119,10 @@ def add(self, name, callback=None, shortcut=None, alias=None, checkable=False, checked=False): """Add an action with a keyboard shortcut.""" # TODO: add menu_name option and create menu bar + # Get the alias from the character after & if it exists. + if alias is None: + alias = name[name.index('&') + 1] if '&' in name else name + name = name.replace('&', '') if name in self._actions: return action = QtGui.QAction(name, self._dock) @@ -134,20 +138,21 @@ def add(self, name, callback=None, shortcut=None, alias=None, # it can be shown in show_shortcuts(). I don't manage to recover # the key sequence string from a QAction using Qt. action._shortcut_string = shortcut or '' - # The alias is used in snippets. By default it is the character after & - action._alias = alias or (name[name.index('&') + 1] - if '&' in name else None) + # The alias is used in snippets. + action._alias = alias if self._dock: self._dock.addAction(action) self._actions[name] = action + logger.debug("Add action `%s`, alias `%s`, shortcut %s.", + name, alias, shortcut or '') if callback: setattr(self, name, callback) return action - def from_alias(self, alias): - """Return an action name from its alias.""" + def get_name(self, alias_or_name): + """Return an action name from its alias or name.""" for name, action in self._actions.items(): - if action._alias == alias: + if alias_or_name in (action._alias, name): return name def remove(self, name): @@ -174,10 +179,11 @@ def show_shortcuts(self): _show_shortcuts(self.shortcuts, self._dock.title() if self._dock else None) - def shortcut(self, key=None, name=None): + def shortcut(self, key=None, name=None, **kwargs): """Decorator to add a global keyboard shortcut.""" def wrap(func): - self.add(name or func.__name__, shortcut=key, callback=func) + self.add(name or func.__name__, shortcut=key, + callback=func, **kwargs) return wrap @@ -192,6 +198,7 @@ class Snippets(object): def __init__(self): self._dock = None + self._cmd = '' # only used when there is no dock attached def attach(self, dock, actions): self._dock = dock @@ -211,13 +218,18 @@ def command(self): A cursor is appended at the end. """ - n = len(self._dock.status_message) + msg = self._dock.status_message if self._dock else self._cmd + n = len(msg) n_cur = len(self.cursor) - return self._dock.status_message[:n - n_cur] + return msg[:n - n_cur] @command.setter def command(self, value): - self._dock.status_message = value + self.cursor + value += self.cursor + if not self._dock: + self._cmd = value + else: + self._dock.status_message = value def _backspace(self): """Erase the last character in the snippet command.""" @@ -268,9 +280,10 @@ def run(self, snippet): snippet = snippet[1:] snippet_args = _parse_snippet(snippet) alias = snippet_args[0] - name = self._actions.from_alias(alias) + name = self._actions.get_name(alias) if name is None: logger.info("The snippet `%s` could not be found.", alias) + return func = getattr(self._actions, name) try: logger.info("Processing snippet `%s`.", snippet) @@ -286,7 +299,8 @@ def mode_on(self): self.command = ':' def mode_off(self): - self._dock.status_message = '' + if self._dock: + self._dock.status_message = '' # Reestablishes the shortcuts. self._actions.reset() logger.info("Snippet mode disabled.") diff --git a/phy/gui/tests/test_dock.py b/phy/gui/tests/test_dock.py index 00848a045..c4f925653 100644 --- a/phy/gui/tests/test_dock.py +++ b/phy/gui/tests/test_dock.py @@ -6,7 +6,7 @@ # Imports #------------------------------------------------------------------------------ -from pytest import mark, yield_fixture +from pytest import mark, raises, yield_fixture from ..qt import Qt from ..dock import (DockWindow, _show_shortcuts, Actions, Snippets, @@ -64,8 +64,8 @@ def test_shortcuts(): assert 'ctrl+a, shift+b' in stdout.getvalue() -def test_actions(actions): - actions.add('test', lambda: None) +def test_actions_simple(actions): + actions.add('tes&t', lambda: None) # Adding an action twice has no effect. actions.add('test', lambda: None) @@ -81,7 +81,10 @@ def show_my_shortcuts(): actions.show_my_shortcuts() assert 'show_my_shortcuts' in _captured[0] assert ': h' in _captured[0] - print(_captured[0]) + + assert actions.get_name('e') is None + assert actions.get_name('t') == 'test' + assert actions.get_name('test') == 'test' actions.remove_all() @@ -137,12 +140,35 @@ def _check(args, expected): _check('a b,c d,2 3-5', ['a', ('b', 'c'), ('d', 2), (3, 4, 5)]) -def test_snippets(snippets): - # TODO - pass +def test_snippets_actions(actions, snippets): + snippets.attach(None, actions) + + _actions = [] + + @actions.connect + def on_reset(): + @actions.shortcut(name='my_test_1') + def test_1(*args): + _actions.append((1, args)) + + @actions.shortcut(name='my_&test_2') + def test_2(*args): + _actions.append((2, args)) + + @actions.shortcut(name='my_test_3', alias='t3') + def test_3(*args): + _actions.append((3, args)) + + assert snippets.command == '' + + snippets.run(':my_test_1') + # assert _actions == [(1, ())] + + with raises(AttributeError): + actions.snippet_m() -def test_snippets_dock(qtbot, gui, snippets): +def test_snippets_dock(qtbot, gui, actions, snippets): pass From 782f3e3e5b9ae52a7b7f694bf0507effeb56b483 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 24 Sep 2015 13:30:49 +0200 Subject: [PATCH 0103/1059] Add Actions.run() method --- phy/gui/dock.py | 8 ++++++++ phy/gui/tests/test_dock.py | 24 +++++++++++++++++++++--- 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/phy/gui/dock.py b/phy/gui/dock.py index d6300cdf6..916f2b13a 100644 --- a/phy/gui/dock.py +++ b/phy/gui/dock.py @@ -140,6 +140,7 @@ def add(self, name, callback=None, shortcut=None, alias=None, action._shortcut_string = shortcut or '' # The alias is used in snippets. action._alias = alias + action._callback = callback if self._dock: self._dock.addAction(action) self._actions[name] = action @@ -149,6 +150,13 @@ def add(self, name, callback=None, shortcut=None, alias=None, setattr(self, name, callback) return action + def run(self, action, *args): + """Run an action, specified by its name or object.""" + if isinstance(action, string_types): + name = self.get_name(action) + action = self._actions[name] + return action._callback(*args) + def get_name(self, alias_or_name): """Return an action name from its alias or name.""" for name, action in self._actions.items(): diff --git a/phy/gui/tests/test_dock.py b/phy/gui/tests/test_dock.py index c4f925653..69dfa317d 100644 --- a/phy/gui/tests/test_dock.py +++ b/phy/gui/tests/test_dock.py @@ -65,9 +65,15 @@ def test_shortcuts(): def test_actions_simple(actions): - actions.add('tes&t', lambda: None) + + _res = [] + + def _action(*args): + _res.append(args) + + actions.add('tes&t', _action) # Adding an action twice has no effect. - actions.add('test', lambda: None) + actions.add('test', _action) # Create a shortcut and display it. _captured = [] @@ -86,6 +92,9 @@ def show_my_shortcuts(): assert actions.get_name('t') == 'test' assert actions.get_name('test') == 'test' + actions.run('t', 1) + assert _res == [(1,)] + actions.remove_all() @@ -159,10 +168,19 @@ def test_2(*args): def test_3(*args): _actions.append((3, args)) + actions.reset() + assert snippets.command == '' + # Action 1. snippets.run(':my_test_1') - # assert _actions == [(1, ())] + assert _actions == [(1, ())] + + # Action 2. + snippets.run(':t 1.5 a 2-4 5,7') + assert _actions[-1] == (2, (1.5, 'a', (2, 3, 4), (5, 7))) + + # snippets.run('snippet_:') with raises(AttributeError): actions.snippet_m() From 63f579de2d357048eb86c9c26b33c8bb350ec0ec Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 24 Sep 2015 13:48:03 +0200 Subject: [PATCH 0104/1059] More snippet tests --- phy/gui/dock.py | 54 +++++++++++++++++++++++++------------- phy/gui/tests/test_dock.py | 25 +++++++++++++----- 2 files changed, 55 insertions(+), 24 deletions(-) diff --git a/phy/gui/dock.py b/phy/gui/dock.py index 916f2b13a..55226e39a 100644 --- a/phy/gui/dock.py +++ b/phy/gui/dock.py @@ -134,35 +134,50 @@ def add(self, name, callback=None, shortcut=None, alias=None, shortcut = [shortcut] for key in shortcut: action.setShortcut(key) + + # Add some attributes to the QAction instance. + # The alias is used in snippets. + action._alias = alias + action._callback = callback + action._name = name # HACK: add the shortcut string to the QAction object so that # it can be shown in show_shortcuts(). I don't manage to recover # the key sequence string from a QAction using Qt. action._shortcut_string = shortcut or '' - # The alias is used in snippets. - action._alias = alias - action._callback = callback + + # Register the action. if self._dock: self._dock.addAction(action) self._actions[name] = action - logger.debug("Add action `%s`, alias `%s`, shortcut %s.", - name, alias, shortcut or '') + + # Log the creation of the action. + if not name.startswith('_'): + logger.debug("Add action `%s`, alias `%s`, shortcut %s.", + name, alias, shortcut or '') + if callback: setattr(self, name, callback) return action + def get_name(self, alias_or_name): + """Return an action name from its alias or name.""" + for name, action in self._actions.items(): + if alias_or_name in (action._alias, name): + return name + raise ValueError("Action `{}` doesn't exist.".format(alias_or_name)) + def run(self, action, *args): """Run an action, specified by its name or object.""" if isinstance(action, string_types): name = self.get_name(action) + assert name in self._actions action = self._actions[name] + else: + name = action.name + if not name.startswith('_'): + logger.debug("Execute action `%s`.", name) return action._callback(*args) - def get_name(self, alias_or_name): - """Return an action name from its alias or name.""" - for name, action in self._actions.items(): - if alias_or_name in (action._alias, name): - return name - def remove(self, name): """Remove an action.""" if self._dock: @@ -202,7 +217,7 @@ class Snippets(object): # Allowed characters in snippet mode. # A Qt shortcut will be created for every character. _snippet_chars = ("abcdefghijklmnopqrstuvwxyz0123456789" - " ,.;:?!_-+~=*/\\(){}[]") + " ,.;?!_-+~=*/\(){}[]") def __init__(self): self._dock = None @@ -243,12 +258,14 @@ def _backspace(self): """Erase the last character in the snippet command.""" if self.command == ':': return + logger.debug("Snippet keystroke `Backspace`.") self.command = self.command[:-1] def _enter(self): """Disable the snippet mode and execute the command.""" command = self.command - self.disable_snippet_mode() + logger.debug("Snippet keystroke `Enter`.") + self.mode_off() self.run(command) def _create_snippet_actions(self): @@ -265,18 +282,19 @@ def _create_snippet_actions(self): def _make_func(char): def callback(): + logger.debug("Snippet keystroke `%s`.", char) self.command += char return callback - self._actions.add('snippet_{}'.format(i), shortcut=char, + self._actions.add('_snippet_{}'.format(i), shortcut=char, callback=_make_func(char)) - self._actions.add('snippet_backspace', shortcut='backspace', + self._actions.add('_snippet_backspace', shortcut='backspace', callback=self._backspace) - self._actions.add('snippet_activate', shortcut=('enter', 'return'), + self._actions.add('_snippet_activate', shortcut=('enter', 'return'), callback=self._enter) - self._actions.add('snippet_disable', shortcut='escape', - callback=self.disable_snippet_mode) + self._actions.add('_snippet_disable', shortcut='escape', + callback=self.mode_off) def run(self, snippet): """Executes a snippet command. diff --git a/phy/gui/tests/test_dock.py b/phy/gui/tests/test_dock.py index 69dfa317d..171e5ca67 100644 --- a/phy/gui/tests/test_dock.py +++ b/phy/gui/tests/test_dock.py @@ -88,7 +88,8 @@ def show_my_shortcuts(): assert 'show_my_shortcuts' in _captured[0] assert ': h' in _captured[0] - assert actions.get_name('e') is None + with raises(ValueError): + assert actions.get_name('e') assert actions.get_name('t') == 'test' assert actions.get_name('test') == 'test' @@ -149,7 +150,7 @@ def _check(args, expected): _check('a b,c d,2 3-5', ['a', ('b', 'c'), ('d', 2), (3, 4, 5)]) -def test_snippets_actions(actions, snippets): +def test_snippets_actions(qtbot, actions, snippets): snippets.attach(None, actions) _actions = [] @@ -180,10 +181,22 @@ def test_3(*args): snippets.run(':t 1.5 a 2-4 5,7') assert _actions[-1] == (2, (1.5, 'a', (2, 3, 4), (5, 7))) - # snippets.run('snippet_:') - - with raises(AttributeError): - actions.snippet_m() + def _run(cmd): + """Simulate keystrokes.""" + for char in cmd: + i = snippets._snippet_chars.index(char) + actions.run('_snippet_{}'.format(i)) + + # Need to activate the snippet mode first. + with raises(ValueError): + _run(':t3 hello') + + # Simulate keystrokes ':t3 hello' + snippets.mode_on() # ':' + _run('t3 hello') + actions._snippet_activate() # 'Enter' + assert _actions[-1] == (3, ('hello',)) + snippets.mode_off() def test_snippets_dock(qtbot, gui, actions, snippets): From fc8fecf58b21091d288b37f518c9ae60a2e86d4c Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 24 Sep 2015 14:14:19 +0200 Subject: [PATCH 0105/1059] Test snippets on DockWindow --- phy/gui/dock.py | 16 +++++++++++----- phy/gui/tests/test_dock.py | 32 ++++++++++++++++++++++++++++++-- 2 files changed, 41 insertions(+), 7 deletions(-) diff --git a/phy/gui/dock.py b/phy/gui/dock.py index 55226e39a..30eb99a09 100644 --- a/phy/gui/dock.py +++ b/phy/gui/dock.py @@ -110,10 +110,13 @@ def attach(self, dock): """Attach a DockWindow.""" self._dock = dock - # Default exit action. - @self.shortcut('ctrl+q') - def exit(): - dock.close() + # Register default actions. + @self.connect + def on_reset(): + # Default exit action. + @self.shortcut('ctrl+q') + def exit(): + dock.close() def add(self, name, callback=None, shortcut=None, alias=None, checkable=False, checked=False): @@ -317,6 +320,9 @@ def run(self, snippet): except Exception as e: logger.warn("Error when executing snippet: %s.", str(e)) + def is_mode_on(self): + return self.command.startswith(':') + def mode_on(self): logger.info("Snippet mode enabled, press `escape` to leave this mode.") # Remove all existing actions, and replace them by snippet keystroke @@ -327,9 +333,9 @@ def mode_on(self): def mode_off(self): if self._dock: self._dock.status_message = '' + logger.info("Snippet mode disabled.") # Reestablishes the shortcuts. self._actions.reset() - logger.info("Snippet mode disabled.") # ----------------------------------------------------------------------------- diff --git a/phy/gui/tests/test_dock.py b/phy/gui/tests/test_dock.py index 171e5ca67..c92fd827a 100644 --- a/phy/gui/tests/test_dock.py +++ b/phy/gui/tests/test_dock.py @@ -151,7 +151,6 @@ def _check(args, expected): def test_snippets_actions(qtbot, actions, snippets): - snippets.attach(None, actions) _actions = [] @@ -169,6 +168,8 @@ def test_2(*args): def test_3(*args): _actions.append((3, args)) + # Attach the GUI and register the actions. + snippets.attach(None, actions) actions.reset() assert snippets.command == '' @@ -200,7 +201,34 @@ def _run(cmd): def test_snippets_dock(qtbot, gui, actions, snippets): - pass + + qtbot.addWidget(gui) + gui.show() + qtbot.waitForWindowShown(gui) + + _actions = [] + + @actions.connect + def on_reset(): + @actions.shortcut(name='my_test_1', alias='t1') + def test(*args): + _actions.append(args) + + # Attach the GUI and register the actions. + snippets.attach(gui, actions) + actions.attach(gui) + actions.reset() + + # Simulate the following keystrokes `:t2 ^H^H1 3-5 ab,c ` + assert not snippets.is_mode_on() + qtbot.keyClicks(gui, ':t2 ') + assert snippets.is_mode_on() + qtbot.keyPress(gui, Qt.Key_Backspace) + qtbot.keyPress(gui, Qt.Key_Backspace) + qtbot.keyClicks(gui, '1 3-5 ab,c') + qtbot.keyPress(gui, Qt.Key_Return) + + assert _actions == [((3, 4, 5), ('ab', 'c'))] #------------------------------------------------------------------------------ From 9ece486df4b11726bd4bca1648eeb328465ed486 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 24 Sep 2015 14:27:41 +0200 Subject: [PATCH 0106/1059] Add captured_logging() context manager --- phy/utils/testing.py | 14 ++++++++++++++ phy/utils/tests/test_testing.py | 10 +++++++++- 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/phy/utils/testing.py b/phy/utils/testing.py index 5f2c41729..60d89c134 100644 --- a/phy/utils/testing.py +++ b/phy/utils/testing.py @@ -42,6 +42,20 @@ def captured_output(): sys.stdout, sys.stderr = old_out, old_err +@contextmanager +def captured_logging(logger): + buffer = StringIO() + handlers = logger.handlers + for handler in logger.handlers: + logger.removeHandler(handler) + handler = logging.StreamHandler(buffer) + logger.addHandler(handler) + yield buffer + logger.removeHandler(handler) + for handler in handlers: + logger.addHandler(handler) + + def _assert_equal(d_0, d_1): """Check that two objects are equal.""" # Compare arrays. diff --git a/phy/utils/tests/test_testing.py b/phy/utils/tests/test_testing.py index 8381ea433..fb13e3348 100644 --- a/phy/utils/tests/test_testing.py +++ b/phy/utils/tests/test_testing.py @@ -7,6 +7,7 @@ #------------------------------------------------------------------------------ from copy import deepcopy +import logging import os.path as op import time @@ -14,7 +15,7 @@ from pytest import mark from vispy.app import Canvas -from ..testing import (benchmark, captured_output, show_test, +from ..testing import (benchmark, captured_output, captured_logging, show_test, _assert_equal, _enable_profiler, _profile, show_colored_canvas, ) @@ -30,6 +31,13 @@ def test_captured_output(): assert out.getvalue().strip() == 'Hello world!' +def test_captured_logging(): + logger = logging.getLogger(__name__) + with captured_logging(logger) as buf: + logger.debug('Hello world!') + assert 'Hello world!' in buf.getvalue() + + def test_assert_equal(): d = {'a': {'b': np.random.rand(5), 3: 'c'}, 'b': 2.} d_bis = deepcopy(d) From bd492929be87151ff18af60fa921a30a4b158d77 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 24 Sep 2015 14:28:24 +0200 Subject: [PATCH 0107/1059] Test captured_logging() --- phy/utils/tests/test_testing.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/phy/utils/tests/test_testing.py b/phy/utils/tests/test_testing.py index fb13e3348..78b6af9bb 100644 --- a/phy/utils/tests/test_testing.py +++ b/phy/utils/tests/test_testing.py @@ -33,9 +33,11 @@ def test_captured_output(): def test_captured_logging(): logger = logging.getLogger(__name__) + handlers = logger.handlers with captured_logging(logger) as buf: logger.debug('Hello world!') assert 'Hello world!' in buf.getvalue() + assert logger.handlers == handlers def test_assert_equal(): From 8193625b81ed2b87c6edce16999b3d228987956d Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 24 Sep 2015 14:34:11 +0200 Subject: [PATCH 0108/1059] Fix bugs in phy.utils --- phy/utils/testing.py | 5 ++++- phy/utils/tests/test_datasets.py | 28 ++++++++++++++-------------- phy/utils/tests/test_testing.py | 4 ++-- 3 files changed, 20 insertions(+), 17 deletions(-) diff --git a/phy/utils/testing.py b/phy/utils/testing.py index 60d89c134..cb53064e1 100644 --- a/phy/utils/testing.py +++ b/phy/utils/testing.py @@ -43,14 +43,17 @@ def captured_output(): @contextmanager -def captured_logging(logger): +def captured_logging(name=None): buffer = StringIO() + logger = logging.getLogger(name) handlers = logger.handlers for handler in logger.handlers: logger.removeHandler(handler) handler = logging.StreamHandler(buffer) + handler.setLevel(logging.DEBUG) logger.addHandler(handler) yield buffer + buffer.flush() logger.removeHandler(handler) for handler in handlers: logger.addHandler(handler) diff --git a/phy/utils/tests/test_datasets.py b/phy/utils/tests/test_datasets.py index 7e86ab552..1c309f68e 100644 --- a/phy/utils/tests/test_datasets.py +++ b/phy/utils/tests/test_datasets.py @@ -15,7 +15,6 @@ import responses from pytest import raises, yield_fixture -from phy import string_handler from ..datasets import (download_file, download_test_data, download_sample_data, @@ -23,6 +22,7 @@ _BASE_URL, _validate_output_dir, ) +from ..testing import captured_logging logger = logging.getLogger(__name__) @@ -123,23 +123,23 @@ def test_download_not_found(tempdir): @responses.activate def test_download_already_exists_invalid(tempdir, mock_url): - buffer = string_handler() - path = op.join(tempdir, 'test') - # Create empty file. - open(path, 'a').close() - _check(_dl(path)) - assert 'redownload' in buffer.getvalue() + with captured_logging() as buf: + path = op.join(tempdir, 'test') + # Create empty file. + open(path, 'a').close() + _check(_dl(path)) + assert 'redownload' in buf.getvalue() @responses.activate def test_download_already_exists_valid(tempdir, mock_url): - buffer = string_handler() - path = op.join(tempdir, 'test') - # Create valid file. - with open(path, 'ab') as f: - f.write(_DATA.tostring()) - _check(_dl(path)) - assert 'skip' in buffer.getvalue() + with captured_logging() as buf: + path = op.join(tempdir, 'test') + # Create valid file. + with open(path, 'ab') as f: + f.write(_DATA.tostring()) + _check(_dl(path)) + assert 'skip' in buf.getvalue() @responses.activate diff --git a/phy/utils/tests/test_testing.py b/phy/utils/tests/test_testing.py index 78b6af9bb..68a90faee 100644 --- a/phy/utils/tests/test_testing.py +++ b/phy/utils/tests/test_testing.py @@ -32,9 +32,9 @@ def test_captured_output(): def test_captured_logging(): - logger = logging.getLogger(__name__) + logger = logging.getLogger() handlers = logger.handlers - with captured_logging(logger) as buf: + with captured_logging() as buf: logger.debug('Hello world!') assert 'Hello world!' in buf.getvalue() assert logger.handlers == handlers From 6875c0b33ee9d9f4583a5707ac5288e9df751b42 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 24 Sep 2015 14:38:54 +0200 Subject: [PATCH 0109/1059] WIP: increase coverage in phy.gui --- phy/__init__.py | 9 --------- phy/gui/tests/test_dock.py | 37 ++++++++++++++++++++++++++++++++++++- 2 files changed, 36 insertions(+), 10 deletions(-) diff --git a/phy/__init__.py b/phy/__init__.py index 5acfc8b14..953d17734 100644 --- a/phy/__init__.py +++ b/phy/__init__.py @@ -58,15 +58,6 @@ def add_default_handler(level='INFO'): logger.addHandler(handler) -def string_handler(level='INFO'): - buffer = StringIO() - for handler in logger.handlers: - logger.removeHandler(handler) - handler = logging.StreamHandler(buffer) - logger.addHandler(handler) - return buffer - - if '--debug' in sys.argv: # pragma: no cover add_default_handler('DEBUG') logger.info("Activate DEBUG level.") diff --git a/phy/gui/tests/test_dock.py b/phy/gui/tests/test_dock.py index c92fd827a..2058fcd18 100644 --- a/phy/gui/tests/test_dock.py +++ b/phy/gui/tests/test_dock.py @@ -12,7 +12,7 @@ from ..dock import (DockWindow, _show_shortcuts, Actions, Snippets, _parse_snippet) from phy.utils._color import _random_color -from phy.utils.testing import captured_output +from phy.utils.testing import captured_output, captured_logging # Skip these tests in "make test-quick". pytestmark = mark.long @@ -150,6 +150,41 @@ def _check(args, expected): _check('a b,c d,2 3-5', ['a', ('b', 'c'), ('d', 2), (3, 4, 5)]) +def test_snippets_errors(qtbot, actions, snippets): + + _actions = [] + + @actions.connect + def on_reset(): + @actions.shortcut(name='my_test', alias='t') + def test(arg): + # Enforce single-character argument. + assert len(str(arg)) == 1 + _actions.append(arg) + + # Attach the GUI and register the actions. + snippets.attach(None, actions) + actions.reset() + + with raises(ValueError): + snippets.run(':t1') + + with captured_logging() as buf: + snippets.run(':t') + assert 'missing 1 required positional argument' in buf.getvalue() + + with captured_logging() as buf: + snippets.run(':t 1 2') + assert 'takes 1 positional argument but 2 were given' in buf.getvalue() + + with captured_logging() as buf: + snippets.run(':t aa') + assert 'assert 2 == 1' in buf.getvalue() + + snippets.run(':t a') + assert _actions == ['a'] + + def test_snippets_actions(qtbot, actions, snippets): _actions = [] From 973175020f4086d099c5ef65cb85226603eec5a8 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 24 Sep 2015 14:43:12 +0200 Subject: [PATCH 0110/1059] Increase coverage in phy.gui --- phy/gui/dock.py | 6 ++---- phy/gui/tests/test_dock.py | 11 +++++++++++ 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/phy/gui/dock.py b/phy/gui/dock.py index 30eb99a09..0c17ea665 100644 --- a/phy/gui/dock.py +++ b/phy/gui/dock.py @@ -176,7 +176,7 @@ def run(self, action, *args): assert name in self._actions action = self._actions[name] else: - name = action.name + name = action._name if not name.startswith('_'): logger.debug("Execute action `%s`.", name) return action._callback(*args) @@ -310,9 +310,7 @@ def run(self, snippet): snippet_args = _parse_snippet(snippet) alias = snippet_args[0] name = self._actions.get_name(alias) - if name is None: - logger.info("The snippet `%s` could not be found.", alias) - return + assert name func = getattr(self._actions, name) try: logger.info("Processing snippet `%s`.", snippet) diff --git a/phy/gui/tests/test_dock.py b/phy/gui/tests/test_dock.py index 2058fcd18..137d41755 100644 --- a/phy/gui/tests/test_dock.py +++ b/phy/gui/tests/test_dock.py @@ -96,11 +96,18 @@ def show_my_shortcuts(): actions.run('t', 1) assert _res == [(1,)] + # Run an action instance. + actions.run(actions._actions['test'], 1) + actions.remove_all() def test_actions_dock(qtbot, gui, actions): actions.attach(gui) + + # Set the default actions. + actions.reset() + qtbot.addWidget(gui) gui.show() qtbot.waitForWindowShown(gui) @@ -114,6 +121,9 @@ def press(): qtbot.keyPress(gui, Qt.Key_G, Qt.ControlModifier) assert _press == [0] + # Quit the GUI. + qtbot.keyPress(gui, Qt.Key_Q, Qt.ControlModifier) + #------------------------------------------------------------------------------ # Test snippets @@ -229,6 +239,7 @@ def _run(cmd): # Simulate keystrokes ':t3 hello' snippets.mode_on() # ':' + actions._snippet_backspace() _run('t3 hello') actions._snippet_activate() # 'Enter' assert _actions[-1] == (3, ('hello',)) From 300d240cfea61d0de3d62634d87561bd4ffd1af5 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 24 Sep 2015 15:32:46 +0200 Subject: [PATCH 0111/1059] Rename dock to gui --- phy/gui/__init__.py | 2 +- phy/gui/{dock.py => gui.py} | 44 ++++++++++----------- phy/gui/tests/{test_dock.py => test_gui.py} | 4 +- 3 files changed, 25 insertions(+), 25 deletions(-) rename phy/gui/{dock.py => gui.py} (94%) rename phy/gui/tests/{test_dock.py => test_gui.py} (98%) diff --git a/phy/gui/__init__.py b/phy/gui/__init__.py index 5dbd6c7b7..e0bcdaf31 100644 --- a/phy/gui/__init__.py +++ b/phy/gui/__init__.py @@ -4,4 +4,4 @@ """GUI routines.""" from .qt import start_qt_app, run_qt_app, enable_qt -from .dock import DockWindow +from .gui import DockWindow diff --git a/phy/gui/dock.py b/phy/gui/gui.py similarity index 94% rename from phy/gui/dock.py rename to phy/gui/gui.py index 0c17ea665..3fd92ebff 100644 --- a/phy/gui/dock.py +++ b/phy/gui/gui.py @@ -86,7 +86,7 @@ class Actions(EventEmitter): """Handle GUI actions.""" def __init__(self): super(Actions, self).__init__() - self._dock = None + self._gui = None self._actions = {} def reset(self): @@ -106,9 +106,9 @@ def on_reset(): self.remove_all() self.emit('reset') - def attach(self, dock): + def attach(self, gui): """Attach a DockWindow.""" - self._dock = dock + self._gui = gui # Register default actions. @self.connect @@ -116,7 +116,7 @@ def on_reset(): # Default exit action. @self.shortcut('ctrl+q') def exit(): - dock.close() + gui.close() def add(self, name, callback=None, shortcut=None, alias=None, checkable=False, checked=False): @@ -128,7 +128,7 @@ def add(self, name, callback=None, shortcut=None, alias=None, name = name.replace('&', '') if name in self._actions: return - action = QtGui.QAction(name, self._dock) + action = QtGui.QAction(name, self._gui) action.triggered.connect(callback) action.setCheckable(checkable) action.setChecked(checked) @@ -149,8 +149,8 @@ def add(self, name, callback=None, shortcut=None, alias=None, action._shortcut_string = shortcut or '' # Register the action. - if self._dock: - self._dock.addAction(action) + if self._gui: + self._gui.addAction(action) self._actions[name] = action # Log the creation of the action. @@ -183,8 +183,8 @@ def run(self, action, *args): def remove(self, name): """Remove an action.""" - if self._dock: - self._dock.removeAction(self._actions[name]) + if self._gui: + self._gui.removeAction(self._actions[name]) del self._actions[name] delattr(self, name) @@ -203,7 +203,7 @@ def shortcuts(self): def show_shortcuts(self): """Print all shortcuts.""" _show_shortcuts(self.shortcuts, - self._dock.title() if self._dock else None) + self._gui.title() if self._gui else None) def shortcut(self, key=None, name=None, **kwargs): """Decorator to add a global keyboard shortcut.""" @@ -223,11 +223,11 @@ class Snippets(object): " ,.;?!_-+~=*/\(){}[]") def __init__(self): - self._dock = None - self._cmd = '' # only used when there is no dock attached + self._gui = None + self._cmd = '' # only used when there is no gui attached - def attach(self, dock, actions): - self._dock = dock + def attach(self, gui, actions): + self._gui = gui self._actions = actions # Register snippet mode shortcut. @@ -244,7 +244,7 @@ def command(self): A cursor is appended at the end. """ - msg = self._dock.status_message if self._dock else self._cmd + msg = self._gui.status_message if self._gui else self._cmd n = len(msg) n_cur = len(self.cursor) return msg[:n - n_cur] @@ -252,10 +252,10 @@ def command(self): @command.setter def command(self, value): value += self.cursor - if not self._dock: + if not self._gui: self._cmd = value else: - self._dock.status_message = value + self._gui.status_message = value def _backspace(self): """Erase the last character in the snippet command.""" @@ -329,8 +329,8 @@ def mode_on(self): self.command = ':' def mode_off(self): - if self._dock: - self._dock.status_message = '' + if self._gui: + self._gui.status_message = '' logger.info("Snippet mode disabled.") # Reestablishes the shortcuts. self._actions.reset() @@ -448,13 +448,13 @@ def add_view(self, except ImportError: # pragma: no cover pass - # Create the dock widget. + # Create the gui widget. dockwidget = DockWidget(self) dockwidget.setObjectName(title) dockwidget.setWindowTitle(title) dockwidget.setWidget(view) - # Set dock widget options. + # Set gui widget options. options = QtGui.QDockWidget.DockWidgetMovable if closable: options = options | QtGui.QDockWidget.DockWidgetClosable @@ -530,7 +530,7 @@ def save_geometry_state(self): def restore_geometry_state(self, gs): """Restore the position of the main window and the docks. - The dock widgets need to be recreated first. + The gui widgets need to be recreated first. This function can be called in `on_show()`. diff --git a/phy/gui/tests/test_dock.py b/phy/gui/tests/test_gui.py similarity index 98% rename from phy/gui/tests/test_dock.py rename to phy/gui/tests/test_gui.py index 137d41755..97d581b1b 100644 --- a/phy/gui/tests/test_dock.py +++ b/phy/gui/tests/test_gui.py @@ -9,8 +9,8 @@ from pytest import mark, raises, yield_fixture from ..qt import Qt -from ..dock import (DockWindow, _show_shortcuts, Actions, Snippets, - _parse_snippet) +from ..gui import (DockWindow, _show_shortcuts, Actions, Snippets, + _parse_snippet) from phy.utils._color import _random_color from phy.utils.testing import captured_output, captured_logging From 1480e53d5f0fbd307c5e0ec58a1cef115e62c2bb Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 24 Sep 2015 15:51:10 +0200 Subject: [PATCH 0112/1059] Rename DockWindow to GUI --- phy/gui/__init__.py | 2 +- phy/gui/gui.py | 12 ++++++------ phy/gui/tests/test_gui.py | 12 ++++++------ 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/phy/gui/__init__.py b/phy/gui/__init__.py index e0bcdaf31..5f4908f9e 100644 --- a/phy/gui/__init__.py +++ b/phy/gui/__init__.py @@ -4,4 +4,4 @@ """GUI routines.""" from .qt import start_qt_app, run_qt_app, enable_qt -from .gui import DockWindow +from .gui import GUI diff --git a/phy/gui/gui.py b/phy/gui/gui.py index 3fd92ebff..fd12ec6d0 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -107,7 +107,7 @@ def on_reset(): self.emit('reset') def attach(self, gui): - """Attach a DockWindow.""" + """Attach a GUI.""" self._gui = gui # Register default actions. @@ -358,10 +358,10 @@ def closeEvent(self, e): super(DockWidget, self).closeEvent(e) -class DockWindow(QtGui.QMainWindow): +class GUI(QtGui.QMainWindow): """A Qt main window holding docking Qt or VisPy widgets. - `DockWindow` derives from `QMainWindow`. + `GUI` derives from `QMainWindow`. Events ------ @@ -380,7 +380,7 @@ def __init__(self, size=None, title=None, ): - super(DockWindow, self).__init__() + super(GUI, self).__init__() if title is None: title = 'phy' self.setWindowTitle(title) @@ -420,12 +420,12 @@ def closeEvent(self, e): if False in res: # pragma: no cover e.ignore() return - super(DockWindow, self).closeEvent(e) + super(GUI, self).closeEvent(e) def show(self): """Show the window.""" self.emit('show_gui') - super(DockWindow, self).show() + super(GUI, self).show() # Views # ------------------------------------------------------------------------- diff --git a/phy/gui/tests/test_gui.py b/phy/gui/tests/test_gui.py index 97d581b1b..7e7ca3262 100644 --- a/phy/gui/tests/test_gui.py +++ b/phy/gui/tests/test_gui.py @@ -9,7 +9,7 @@ from pytest import mark, raises, yield_fixture from ..qt import Qt -from ..gui import (DockWindow, _show_shortcuts, Actions, Snippets, +from ..gui import (GUI, _show_shortcuts, Actions, Snippets, _parse_snippet) from phy.utils._color import _random_color from phy.utils.testing import captured_output, captured_logging @@ -37,7 +37,7 @@ def on_draw(e): # pragma: no cover @yield_fixture def gui(): - yield DockWindow(position=(200, 100), size=(100, 100)) + yield GUI(position=(200, 100), size=(100, 100)) @yield_fixture @@ -283,7 +283,7 @@ def test(*args): def test_dock_1(qtbot): - gui = DockWindow(position=(200, 100), size=(100, 100)) + gui = GUI(position=(200, 100), size=(100, 100)) qtbot.addWidget(gui) # Increase coverage. @@ -316,7 +316,7 @@ def on_close_widget(): def test_dock_status_message(qtbot): - gui = DockWindow() + gui = GUI() qtbot.addWidget(gui) assert gui.status_message == '' gui.status_message = ':hello world!' @@ -325,7 +325,7 @@ def test_dock_status_message(qtbot): def test_dock_state(qtbot): _gs = [] - gui = DockWindow(size=(100, 100)) + gui = GUI(size=(100, 100)) qtbot.addWidget(gui) gui.add_view(_create_canvas(), 'view1') @@ -348,7 +348,7 @@ def on_close_gui(): gui.close() # Recreate the GUI with the saved state. - gui = DockWindow() + gui = GUI() gui.add_view(_create_canvas(), 'view1') gui.add_view(_create_canvas(), 'view2') From 794db14a34344ce93de0f2c52a97e07850c3af5e Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 28 Sep 2015 13:25:00 +0200 Subject: [PATCH 0113/1059] Remove Selector --- phy/utils/selector.py | 214 ------------------------------- phy/utils/tests/test_selector.py | 112 ---------------- 2 files changed, 326 deletions(-) delete mode 100644 phy/utils/selector.py delete mode 100644 phy/utils/tests/test_selector.py diff --git a/phy/utils/selector.py b/phy/utils/selector.py deleted file mode 100644 index 891fd3d21..000000000 --- a/phy/utils/selector.py +++ /dev/null @@ -1,214 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Selector structure.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import numpy as np - -from ._types import _as_array, _as_list -from .array import (regular_subset, - get_excerpts, - _unique, - _ensure_unique, - _spikes_in_clusters, - _spikes_per_cluster, - ) - - -#------------------------------------------------------------------------------ -# Utility functions -#------------------------------------------------------------------------------ - -def _concat(l): - if not len(l): - return np.array([], dtype=np.int64) - return np.sort(np.hstack(l)) - - -#------------------------------------------------------------------------------ -# Selector class -#------------------------------------------------------------------------------ - -class Selector(object): - """Object representing a selection of spikes or clusters.""" - def __init__(self, spike_clusters, - n_spikes_max=None, - excerpt_size=None, - ): - self._spike_clusters = spike_clusters - self._n_spikes_max = n_spikes_max - self._n_spikes = (len(spike_clusters) - if spike_clusters is not None else 0) - self._excerpt_size = excerpt_size - self._selected_spikes = np.array([], dtype=np.int64) - self._selected_clusters = None - - @property - def n_spikes_max(self): - """Maximum number of spikes allowed in the selection.""" - return self._n_spikes_max - - @n_spikes_max.setter - def n_spikes_max(self, value): - """Change the maximum number of spikes allowed.""" - self._n_spikes_max = value - # Update the selected spikes accordingly. - self.selected_spikes = self.subset_spikes() - if self._n_spikes_max is not None: - assert len(self._selected_spikes) <= self._n_spikes_max - - @property - def excerpt_size(self): - """Maximum number of spikes allowed in the selection.""" - return self._excerpt_size - - @excerpt_size.setter - def excerpt_size(self, value): - """Change the excerpt size.""" - self._excerpt_size = value - # Update the selected spikes accordingly. - self.selected_spikes = self.subset_spikes() - - @_ensure_unique - def subset_spikes(self, - spikes=None, - n_spikes_max=None, - excerpt_size=None, - ): - """Prune the current selection to get at most `n_spikes_max` spikes. - - Parameters - ---------- - - spikes : array-like - Array of spike ids to subset from. By default, this is - `selector.selected_spikes`. - n_spikes_max : int or None - Maximum number of spikes allowed in the selection. - excerpt_size : int or None - If None, the method returns a regular strided selection. - Otherwise, returns a regular selection of contiguous chunks - with the specified chunk size. - - """ - # Default arguments. - if spikes is None: - spikes = self._selected_spikes - if spikes is None or len(spikes) == 0: - return spikes - if n_spikes_max is None: - n_spikes_max = self._n_spikes_max or len(spikes) - if excerpt_size is None: - excerpt_size = self._excerpt_size - # Nothing to do if there are less spikes than the maximum number. - if len(spikes) <= n_spikes_max: - return spikes - # Take a regular or chunked subset of the spikes. - if excerpt_size is None: - return regular_subset(spikes, n_spikes_max) - else: - n_excerpts = n_spikes_max // excerpt_size - return get_excerpts(spikes, - n_excerpts=n_excerpts, - excerpt_size=excerpt_size, - ) - - def subset_spikes_clusters(self, clusters, - n_spikes_max=None, - excerpt_size=None, - ): - """Take a subselection of spikes belonging to a set of clusters. - - This method ensures that the same number of spikes is chosen - for every spike. - - `n_spikes_max` is the maximum number of spikers *per cluster*. - - """ - if not len(clusters): - return {} - # Get the selection parameters. - if n_spikes_max is None: - n_spikes_max = self._n_spikes_max or self._n_spikes - if excerpt_size is None: - excerpt_size = self._excerpt_size - # Take all spikes from the selected clusters. - spikes = _spikes_in_clusters(self._spike_clusters, clusters) - if not len(spikes): - return {} - # Group the spikes per cluster. - spc = _spikes_per_cluster(spikes, self._spike_clusters[spikes]) - # Do nothing if there are less spikes than the maximum number. - if len(spikes) <= n_spikes_max: - return spc - # Take a regular or chunked subset of the spikes. - if excerpt_size is None: - return {cluster: regular_subset(spc[cluster], n_spikes_max) - for cluster in clusters} - else: - n_excerpts = n_spikes_max // excerpt_size - return {cluster: get_excerpts(spc[cluster], - n_excerpts=n_excerpts, - excerpt_size=excerpt_size, - ) - for cluster in clusters} - - @property - def selected_spikes(self): - """Ids of the selected spikes.""" - return self._selected_spikes - - @selected_spikes.setter - def selected_spikes(self, value): - """Explicitely select some spikes. - - The selection is automatically pruned to ensure that less than - `n_spikes_max` spikes are selected. - - """ - value = _as_array(value) - # Make sure there are less spikes than n_spikes_max. - self._selected_spikes = self.subset_spikes(value) - - @property - def selected_clusters(self): - """Cluster ids appearing in the current spike selection.""" - if self._selected_clusters is not None: - return self._selected_clusters - clusters = _unique(self._spike_clusters[self._selected_spikes]) - return clusters - - @selected_clusters.setter - def selected_clusters(self, value): - """Select some clusters. - - This will select less than `n_spikes_max` spikes belonging to - those clusters. - - """ - self._selected_clusters = _as_list(value) - value = _as_array(value) - # Make sure there are less spikes than n_spikes_max. - spk = self.subset_spikes_clusters(value) - self._selected_spikes = _concat(spk.values()) - - @property - def n_spikes(self): - return len(self._selected_spikes) - - @property - def n_clusters(self): - return len(self.selected_clusters) - - def on_cluster(self, up=None): - """Callback method called when the clustering has changed. - - This currently does nothing, i.e. the spike selection remains - unchanged when merges and splits occur. - - """ - # TODO - pass diff --git a/phy/utils/tests/test_selector.py b/phy/utils/tests/test_selector.py deleted file mode 100644 index 5b4af745e..000000000 --- a/phy/utils/tests/test_selector.py +++ /dev/null @@ -1,112 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Test selector.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import numpy as np -from numpy.testing import assert_array_equal as ae - -from ...io.mock import artificial_spike_clusters -from ..array import _spikes_in_clusters -from ..selector import Selector - - -#------------------------------------------------------------------------------ -# Tests -#------------------------------------------------------------------------------ - -def test_selector_spikes(): - """Test selecting spikes.""" - n_spikes = 1000 - n_clusters = 10 - spike_clusters = artificial_spike_clusters(n_spikes, n_clusters) - - selector = Selector(spike_clusters) - selector.on_cluster() - assert selector.n_spikes_max is None - selector.n_spikes_max = None - ae(selector.selected_spikes, []) - - # Select a few spikes. - my_spikes = [10, 20, 30] - selector.selected_spikes = my_spikes - ae(selector.selected_spikes, my_spikes) - - # Check selected clusters. - ae(selector.selected_clusters, np.unique(spike_clusters[my_spikes])) - - # Specify a maximum number of spikes. - selector.n_spikes_max = 3 - assert selector.n_spikes_max is 3 - my_spikes = [10, 20, 30, 40] - selector.selected_spikes = my_spikes[:3] - ae(selector.selected_spikes, my_spikes[:3]) - selector.selected_spikes = my_spikes - assert len(selector.selected_spikes) <= 3 - assert selector.n_spikes == len(selector.selected_spikes) - assert np.all(np.in1d(selector.selected_spikes, my_spikes)) - - # Check that this doesn't raise any error. - selector.selected_clusters = [100] - selector.selected_spikes = [] - - -def test_selector_clusters(): - """Test selecting clusters.""" - n_spikes = 1000 - n_clusters = 10 - spike_clusters = artificial_spike_clusters(n_spikes, n_clusters) - - selector = Selector(spike_clusters) - selector.selected_clusters = [] - ae(selector.selected_spikes, []) - - # Select 1 cluster. - selector.selected_clusters = [0] - ae(selector.selected_spikes, _spikes_in_clusters(spike_clusters, [0])) - assert np.all(spike_clusters[selector.selected_spikes] == 0) - - # Select 2 clusters. - selector.selected_clusters = [1, 3] - ae(selector.selected_spikes, _spikes_in_clusters(spike_clusters, [1, 3])) - assert np.all(np.in1d(spike_clusters[selector.selected_spikes], (1, 3))) - assert selector.n_clusters == 2 - - # Specify a maximum number of spikes. - selector.n_spikes_max = 10 - selector.selected_clusters = [4, 2] - assert len(selector.selected_spikes) <= (10 * 2) - assert selector.selected_clusters == [4, 2] - assert np.all(np.in1d(spike_clusters[selector.selected_spikes], (2, 4))) - - # Reduce the number of maximum spikes: the selection should update - # accordingly. - selector.n_spikes_max = 5 - assert len(selector.selected_spikes) <= 5 - assert np.all(np.in1d(spike_clusters[selector.selected_spikes], (2, 4))) - - -def test_selector_subset(): - n_spikes = 1000 - n_clusters = 10 - spike_clusters = artificial_spike_clusters(n_spikes, n_clusters) - - selector = Selector(spike_clusters) - selector.subset_spikes(excerpt_size=10) - selector.subset_spikes(np.arange(n_spikes), excerpt_size=10) - - -def test_selector_subset_clusters(): - n_spikes = 100 - spike_clusters = np.zeros(n_spikes, dtype=np.int32) - spike_clusters[10:15] = 1 - spike_clusters[85:90] = 1 - - selector = Selector(spike_clusters) - spc = selector.subset_spikes_clusters([0, 1], excerpt_size=10) - counts = {_: len(spc[_]) for _ in sorted(spc.keys())} - # TODO - assert counts From 305acadfab97dc9792ab9b3c1b22ac23329e9f06 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 28 Sep 2015 13:25:10 +0200 Subject: [PATCH 0114/1059] Clean up array utility module --- phy/cluster/manual/clustering.py | 8 +- phy/cluster/manual/tests/test_clustering.py | 18 +- phy/utils/array.py | 237 +------------------- phy/utils/tests/test_array.py | 113 ++-------- 4 files changed, 38 insertions(+), 338 deletions(-) diff --git a/phy/cluster/manual/clustering.py b/phy/cluster/manual/clustering.py index c3fc4fb60..406f13afd 100644 --- a/phy/cluster/manual/clustering.py +++ b/phy/cluster/manual/clustering.py @@ -228,8 +228,8 @@ def spikes_in_clusters(self, clusters): #-------------------------------------------------------------------------- def _update_all_spikes_per_cluster(self): - self._spikes_per_cluster = _spikes_per_cluster(self._spike_ids, - self._spike_clusters) + self._spikes_per_cluster = _spikes_per_cluster(self._spike_clusters, + self._spike_ids) def _do_assign(self, spike_ids, new_spike_clusters): """Make spike-cluster assignments after the spike selection has @@ -258,8 +258,8 @@ def _do_assign(self, spike_ids, new_spike_clusters): old_spikes_per_cluster = {cluster: self._spikes_per_cluster[cluster] for cluster in old_clusters} - new_spikes_per_cluster = _spikes_per_cluster(spike_ids, - new_spike_clusters) + new_spikes_per_cluster = _spikes_per_cluster(new_spike_clusters, + spike_ids) self._spikes_per_cluster.update(new_spikes_per_cluster) # All old clusters are deleted. for cluster in old_clusters: diff --git a/phy/cluster/manual/tests/test_clustering.py b/phy/cluster/manual/tests/test_clustering.py index 602ac3fb1..0502a8de6 100644 --- a/phy/cluster/manual/tests/test_clustering.py +++ b/phy/cluster/manual/tests/test_clustering.py @@ -12,9 +12,7 @@ from six import itervalues from ....io.mock import artificial_spike_clusters -from ....utils.array import (_spikes_in_clusters, - _flatten_spikes_per_cluster, - ) +from ....utils.array import (_spikes_in_clusters,) from ..clustering import (_extend_spikes, _concatenate_spike_clusters, _extend_assignment, @@ -103,6 +101,20 @@ def test_extend_assignment(): # Test clustering #------------------------------------------------------------------------------ +def _flatten_spikes_per_cluster(spikes_per_cluster): + """Convert a dictionary {cluster: list_of_spikes} to a + spike_clusters array.""" + clusters = sorted(spikes_per_cluster) + clusters_arr = np.concatenate([(cluster * + np.ones(len(spikes_per_cluster[cluster]))) + for cluster in clusters]).astype(np.int64) + spikes_arr = np.concatenate([spikes_per_cluster[cluster] + for cluster in clusters]) + spike_clusters = np.vstack((spikes_arr, clusters_arr)) + ind = np.argsort(spike_clusters[0, :]) + return spike_clusters[1, ind] + + def _check_spikes_per_cluster(clustering): ae(_flatten_spikes_per_cluster(clustering.spikes_per_cluster), clustering.spike_clusters) diff --git a/phy/utils/array.py b/phy/utils/array.py index 02fc921a3..498840895 100644 --- a/phy/utils/array.py +++ b/phy/utils/array.py @@ -69,14 +69,6 @@ def _unique(x): return np.nonzero(bc)[0] -def _ensure_unique(func): - """Apply unique() to the output of a function.""" - def wrapped(*args, **kwargs): - out = func(*args, **kwargs) - return _unique(out) - return wrapped - - def _normalize(arr, keep_ratio=False): """Normalize an array into [0, 1].""" (x_min, y_min), (x_max, y_max) = arr.min(axis=0), arr.max(axis=0) @@ -302,20 +294,17 @@ def get_excerpts(data, n_excerpts=None, excerpt_size=None): return out -def regular_subset(spikes=None, n_spikes_max=None): +def regular_subset(spikes=None, n_spikes_max=None, offset=0): """Prune the current selection to get at most n_spikes_max spikes.""" assert spikes is not None # Nothing to do if the selection already satisfies n_spikes_max. - if n_spikes_max is None or len(spikes) <= n_spikes_max: + if n_spikes_max is None or len(spikes) <= n_spikes_max: # pragma: no cover return spikes step = math.ceil(np.clip(1. / n_spikes_max * len(spikes), 1, len(spikes))) step = int(step) - # Random shift. - # start = np.random.randint(low=0, high=step) # Note: randomly-changing selections are confusing... - start = 0 - my_spikes = spikes[start::step][:n_spikes_max] + my_spikes = spikes[offset::step][:n_spikes_max] assert len(my_spikes) <= len(spikes) assert len(my_spikes) <= n_spikes_max return my_spikes @@ -329,17 +318,13 @@ def _spikes_in_clusters(spike_clusters, clusters): """Return the ids of all spikes belonging to the specified clusters.""" if len(spike_clusters) == 0 or len(clusters) == 0: return np.array([], dtype=np.int) - # spikes_per_cluster case. - if isinstance(spike_clusters, dict): - return np.sort(np.concatenate([spike_clusters[cluster] - for cluster in clusters])) return np.nonzero(np.in1d(spike_clusters, clusters))[0] -def _spikes_per_cluster(spike_ids, spike_clusters): +def _spikes_per_cluster(spike_clusters, spike_ids=None): """Return a dictionary {cluster: list_of_spikes}.""" - if not len(spike_ids): - return {} + if spike_ids is None: + spike_ids = np.arange(len(spike_clusters)) rel_spikes = np.argsort(spike_clusters) abs_spikes = spike_ids[rel_spikes] spike_clusters = spike_clusters[rel_spikes] @@ -356,213 +341,3 @@ def _spikes_per_cluster(spike_ids, spike_clusters): spikes_in_clusters[clusters[-1]] = np.sort(abs_spikes[idx[-1]:]) return spikes_in_clusters - - -def _flatten_spikes_per_cluster(spikes_per_cluster): - """Convert a dictionary {cluster: list_of_spikes} to a - spike_clusters array.""" - clusters = sorted(spikes_per_cluster) - clusters_arr = np.concatenate([(cluster * - np.ones(len(spikes_per_cluster[cluster]))) - for cluster in clusters]).astype(np.int64) - spikes_arr = np.concatenate([spikes_per_cluster[cluster] - for cluster in clusters]) - spike_clusters = np.vstack((spikes_arr, clusters_arr)) - ind = np.argsort(spike_clusters[0, :]) - return spike_clusters[1, ind] - - -def _concatenate_per_cluster_arrays(spikes_per_cluster, arrays): - """Concatenate arrays from a {cluster: array} dictionary.""" - assert set(arrays) <= set(spikes_per_cluster) - clusters = sorted(arrays) - # Check the sizes of the spikes per cluster and the arrays. - n_0 = [len(spikes_per_cluster[cluster]) for cluster in clusters] - n_1 = [len(arrays[cluster]) for cluster in clusters] - assert n_0 == n_1 - - # Concatenate all spikes to find the right insertion order. - if not len(clusters): - return np.array([]) - - spikes = np.concatenate([spikes_per_cluster[cluster] - for cluster in clusters]) - idx = np.argsort(spikes) - # NOTE: concatenate all arrays along the first axis, because we assume - # that the first axis represents the spikes. - arrays = np.concatenate([_as_array(arrays[cluster]) - for cluster in clusters]) - return arrays[idx, ...] - - -def _subset_spc(spc, clusters): - return {c: s for c, s in spc.items() - if c in clusters} - - -class PerClusterData(object): - """Store data associated to every spike. - - This class provides several data structures, with per-spike data and - per-cluster data. It also defines a `subset()` method that allows to - make a subset of the data using either spikes or clusters. - - """ - def __init__(self, - spike_ids=None, array=None, spike_clusters=None, - spc=None, arrays=None): - if (array is not None and spike_ids is not None): - # From array to per-cluster arrays. - self._spike_ids = _as_array(spike_ids) - self._array = _as_array(array) - self._spike_clusters = _as_array(spike_clusters) - self._check_array() - self._split() - self._check_dict() - elif (arrays is not None and spc is not None): - # From per-cluster arrays to array. - self._spc = spc - self._arrays = arrays - self._check_dict() - self._concatenate() - self._check_array() - else: - raise ValueError() - - @property - def spike_ids(self): - """Sorted array of all spike ids.""" - return self._spike_ids - - @property - def spike_clusters(self): - """Array with the cluster id of every spike.""" - return self._spike_clusters - - @property - def array(self): - """Data array. - - The first dimension of the array corresponds to the spikes in the - cluster. - - """ - return self._array - - @property - def arrays(self): - """Dictionary of arrays `{cluster: array}`. - - The first dimension of the arrays correspond to the spikes in the - cluster. - - """ - return self._arrays - - @property - def spc(self): - """Spikes per cluster dictionary.""" - return self._spc - - @property - def cluster_ids(self): - """Sorted list of clusters.""" - return self._cluster_ids - - @property - def n_clusters(self): - return len(self._cluster_ids) - - def _check_dict(self): - assert set(self._arrays) == set(self._spc) - clusters = sorted(self._arrays) - n_0 = [len(self._spc[cluster]) for cluster in clusters] - n_1 = [len(self._arrays[cluster]) for cluster in clusters] - assert n_0 == n_1 - - def _check_array(self): - assert len(self._array) == len(self._spike_ids) - assert len(self._spike_clusters) == len(self._spike_ids) - - def _concatenate(self): - self._cluster_ids = sorted(self._spc) - n = len(self.cluster_ids) - if n == 0: - self._array = np.array([]) - self._spike_clusters = np.array([], dtype=np.int32) - self._spike_ids = np.array([], dtype=np.int64) - elif n == 1: - c = self.cluster_ids[0] - self._array = _as_array(self._arrays[c]) - self._spike_ids = self._spc[c] - self._spike_clusters = c * np.ones(len(self._spike_ids), - dtype=np.int32) - else: - # Concatenate all spikes to find the right insertion order. - spikes = np.concatenate([self._spc[cluster] - for cluster in self.cluster_ids]) - idx = np.argsort(spikes) - self._spike_ids = np.sort(spikes) - # NOTE: concatenate all arrays along the first axis, because we - # assume that the first axis represents the spikes. - # TODO OPTIM: use ConcatenatedArray and implement custom indices. - # array = ConcatenatedArrays([_as_array(self._arrays[cluster]) - # for cluster in self.cluster_ids]) - array = np.concatenate([_as_array(self._arrays[cluster]) - for cluster in self.cluster_ids]) - self._array = array[idx] - self._spike_clusters = _flatten_spikes_per_cluster(self._spc) - - def _split(self): - self._spc = _spikes_per_cluster(self._spike_ids, - self._spike_clusters) - self._cluster_ids = sorted(self._spc) - n = len(self.cluster_ids) - # Optimization for single cluster. - if n == 0: - self._arrays = {} - elif n == 1: - c = self._cluster_ids[0] - self._arrays = {c: self._array} - else: - self._arrays = {} - for cluster in sorted(self._cluster_ids): - spk = _as_array(self._spc[cluster]) - spk_rel = _index_of(spk, self._spike_ids) - self._arrays[cluster] = self._array[spk_rel] - - def subset(self, spike_ids=None, clusters=None, spc=None): - """Return a new PerClusterData instance with a subset of the data. - - There are three ways to specify the subset: - - * With a list of spikes - * With a list of clusters - * With a dictionary of `{cluster: some_spikes}` - - """ - if spike_ids is not None: - if np.array_equal(spike_ids, self._spike_ids): - return self - assert np.all(np.in1d(spike_ids, self._spike_ids)) - spike_ids_s_rel = _index_of(spike_ids, self._spike_ids) - array_s = self._array[spike_ids_s_rel] - spike_clusters_s = self._spike_clusters[spike_ids_s_rel] - return PerClusterData(spike_ids=spike_ids, - array=array_s, - spike_clusters=spike_clusters_s, - ) - elif clusters is not None: - assert set(clusters) <= set(self._cluster_ids) - spc_s = {clu: self._spc[clu] for clu in clusters} - arrays_s = {clu: self._arrays[clu] for clu in clusters} - return PerClusterData(spc=spc_s, arrays=arrays_s) - elif spc is not None: - clusters = sorted(spc) - assert set(clusters) <= set(self._cluster_ids) - arrays_s = {} - for cluster in clusters: - spk_rel = _index_of(_as_array(spc[cluster]), - _as_array(self._spc[cluster])) - arrays_s[cluster] = _as_array(self._arrays[cluster])[spk_rel] - return PerClusterData(spc=spc, arrays=arrays_s) diff --git a/phy/utils/tests/test_array.py b/phy/utils/tests/test_array.py index af6f790ef..2194ce6f8 100644 --- a/phy/utils/tests/test_array.py +++ b/phy/utils/tests/test_array.py @@ -18,13 +18,11 @@ _in_polygon, _spikes_in_clusters, _spikes_per_cluster, - _flatten_spikes_per_cluster, - _concatenate_per_cluster_arrays, chunk_bounds, + regular_subset, excerpts, data_chunk, get_excerpts, - PerClusterData, _range_from_slice, _pad, _load_arrays, @@ -273,6 +271,17 @@ def test_get_excerpts(): subdata = get_excerpts(data, n_excerpts=1, excerpt_size=10) ae(subdata, data) + assert len(get_excerpts(data, n_excerpts=0, excerpt_size=10)) == 0 + + +def test_regular_subset(): + spikes = [2, 3, 5, 7, 11, 13, 17] + ae(regular_subset(spikes), spikes) + ae(regular_subset(spikes, 100), spikes) + ae(regular_subset(spikes, 100, offset=2), spikes) + ae(regular_subset(spikes, 3), [2, 7, 17]) + ae(regular_subset(spikes, 3, offset=1), [3, 11]) + #------------------------------------------------------------------------------ # Test spike clusters functions @@ -301,108 +310,12 @@ def test_spikes_per_cluster(): """Test _spikes_per_cluster().""" n_spikes = 1000 - spike_ids = np.arange(n_spikes).astype(np.int64) n_clusters = 10 spike_clusters = artificial_spike_clusters(n_spikes, n_clusters) - spikes_per_cluster = _spikes_per_cluster(spike_ids, spike_clusters) + spikes_per_cluster = _spikes_per_cluster(spike_clusters) assert list(spikes_per_cluster.keys()) == list(range(n_clusters)) for i in range(10): ae(spikes_per_cluster[i], np.sort(spikes_per_cluster[i])) assert np.all(spike_clusters[spikes_per_cluster[i]] == i) - - sc = _flatten_spikes_per_cluster(spikes_per_cluster) - ae(spike_clusters, sc) - - -def test_concatenate_per_cluster_arrays(): - """Test _spikes_per_cluster().""" - - def _column(arr): - out = np.zeros((len(arr), 10)) - out[:, 0] = arr - return out - - # 8, 11, 12, 13, 17, 18, 20 - spikes_per_cluster = {2: [11, 13, 17], 3: [8, 12], 5: [18, 20]} - - arrays_1d = {2: [1, 3, 7], 3: [8, 2], 5: [8, 0]} - - arrays_2d = {2: _column([1, 3, 7]), - 3: _column([8, 2]), - 5: _column([8, 0])} - - concat = _concatenate_per_cluster_arrays(spikes_per_cluster, arrays_1d) - ae(concat, [8, 1, 2, 3, 7, 8, 0]) - - concat = _concatenate_per_cluster_arrays(spikes_per_cluster, arrays_2d) - ae(concat[:, 0], [8, 1, 2, 3, 7, 8, 0]) - ae(concat[:, 1:], np.zeros((7, 9))) - - -def test_per_cluster_data(): - - spike_ids = [8, 11, 12, 13, 17, 18, 20] - spc = { - 2: [11, 13, 17], - 3: [8, 12], - 5: [18, 20], - } - spike_clusters = [3, 2, 3, 2, 2, 5, 5] - arrays = { - 2: [1, 3, 7], - 3: [8, 2], - 5: [8, 0], - } - array = [8, 1, 2, 3, 7, 8, 0] - - def _check(pcd): - ae(pcd.spike_ids, spike_ids) - ae(pcd.spike_clusters, spike_clusters) - ae(pcd.array, array) - ae(pcd.spc, spc) - ae(pcd.arrays, arrays) - - # Check subset on 1 cluster. - pcd_s = pcd.subset(clusters=[2]) - ae(pcd_s.spike_ids, [11, 13, 17]) - ae(pcd_s.spike_clusters, [2, 2, 2]) - ae(pcd_s.array, [1, 3, 7]) - ae(pcd_s.spc, {2: [11, 13, 17]}) - ae(pcd_s.arrays, {2: [1, 3, 7]}) - - # Check subset on some spikes. - pcd_s = pcd.subset(spike_ids=[11, 12, 13, 17]) - ae(pcd_s.spike_ids, [11, 12, 13, 17]) - ae(pcd_s.spike_clusters, [2, 3, 2, 2]) - ae(pcd_s.array, [1, 2, 3, 7]) - ae(pcd_s.spc, {2: [11, 13, 17], 3: [12]}) - ae(pcd_s.arrays, {2: [1, 3, 7], 3: [2]}) - - # Check subset on 2 complete clusters. - pcd_s = pcd.subset(clusters=[3, 5]) - ae(pcd_s.spike_ids, [8, 12, 18, 20]) - ae(pcd_s.spike_clusters, [3, 3, 5, 5]) - ae(pcd_s.array, [8, 2, 8, 0]) - ae(pcd_s.spc, {3: [8, 12], 5: [18, 20]}) - ae(pcd_s.arrays, {3: [8, 2], 5: [8, 0]}) - - # Check subset on 2 incomplete clusters. - pcd_s = pcd.subset(spc={3: [8, 12], 5: [20]}) - ae(pcd_s.spike_ids, [8, 12, 20]) - ae(pcd_s.spike_clusters, [3, 3, 5]) - ae(pcd_s.array, [8, 2, 0]) - ae(pcd_s.spc, {3: [8, 12], 5: [20]}) - ae(pcd_s.arrays, {3: [8, 2], 5: [0]}) - - # From flat arrays. - pcd = PerClusterData(spike_ids=spike_ids, - array=array, - spike_clusters=spike_clusters, - ) - _check(pcd) - - # From dicts. - pcd = PerClusterData(spc=spc, arrays=arrays) - _check(pcd) From 3e1839b40565af8540c188a6a1d9a3b3d94f829e Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 28 Sep 2015 14:01:37 +0200 Subject: [PATCH 0115/1059] WIP: select_spikes() function --- phy/utils/array.py | 64 ++++++++++++++++++++++++----------- phy/utils/tests/test_array.py | 18 ++++++++++ 2 files changed, 63 insertions(+), 19 deletions(-) diff --git a/phy/utils/array.py b/phy/utils/array.py index 498840895..1b3e35bf8 100644 --- a/phy/utils/array.py +++ b/phy/utils/array.py @@ -8,12 +8,12 @@ import logging import math -from math import floor +from math import floor, exp import os.path as op import numpy as np -from ._types import _as_array +from ._types import _as_array, _is_array_like logger = logging.getLogger(__name__) @@ -294,22 +294,6 @@ def get_excerpts(data, n_excerpts=None, excerpt_size=None): return out -def regular_subset(spikes=None, n_spikes_max=None, offset=0): - """Prune the current selection to get at most n_spikes_max spikes.""" - assert spikes is not None - # Nothing to do if the selection already satisfies n_spikes_max. - if n_spikes_max is None or len(spikes) <= n_spikes_max: # pragma: no cover - return spikes - step = math.ceil(np.clip(1. / n_spikes_max * len(spikes), - 1, len(spikes))) - step = int(step) - # Note: randomly-changing selections are confusing... - my_spikes = spikes[offset::step][:n_spikes_max] - assert len(my_spikes) <= len(spikes) - assert len(my_spikes) <= n_spikes_max - return my_spikes - - # ----------------------------------------------------------------------------- # Spike clusters utility functions # ----------------------------------------------------------------------------- @@ -324,7 +308,7 @@ def _spikes_in_clusters(spike_clusters, clusters): def _spikes_per_cluster(spike_clusters, spike_ids=None): """Return a dictionary {cluster: list_of_spikes}.""" if spike_ids is None: - spike_ids = np.arange(len(spike_clusters)) + spike_ids = np.arange(len(spike_clusters)).astype(np.int64) rel_spikes = np.argsort(spike_clusters) abs_spikes = spike_ids[rel_spikes] spike_clusters = spike_clusters[rel_spikes] @@ -341,3 +325,45 @@ def _spikes_per_cluster(spike_clusters, spike_ids=None): spikes_in_clusters[clusters[-1]] = np.sort(abs_spikes[idx[-1]:]) return spikes_in_clusters + + +def _flatten_per_cluster(per_cluster): + """Convert a dictionary {cluster: spikes} to a spikes array.""" + return np.sort(np.concatenate(list(per_cluster.values()))).astype(np.int64) + + +def regular_subset(spikes, n_spikes_max=None, offset=0): + """Prune the current selection to get at most n_spikes_max spikes.""" + assert spikes is not None + # Nothing to do if the selection already satisfies n_spikes_max. + if n_spikes_max is None or len(spikes) <= n_spikes_max: # pragma: no cover + return spikes + step = math.ceil(np.clip(1. / n_spikes_max * len(spikes), + 1, len(spikes))) + step = int(step) + # Note: randomly-changing selections are confusing... + my_spikes = spikes[offset::step][:n_spikes_max] + assert len(my_spikes) <= len(spikes) + assert len(my_spikes) <= n_spikes_max + return my_spikes + + +def select_spikes(cluster_ids=None, + max_n_spikes_per_cluster=None, + spikes_per_cluster=None): + """Return a selection of spikes belonging to the specified clusters.""" + assert _is_array_like(cluster_ids) + if not len(cluster_ids): + return np.array([], dtype=np.int64) + if max_n_spikes_per_cluster in (None, 0): + selection = {c: spikes_per_cluster[c] for c in cluster_ids} + else: + assert max_n_spikes_per_cluster > 0 + selection = {} + n_clusters = len(cluster_ids) + for cluster in cluster_ids: + n = max_n_spikes_per_cluster * exp(-.1 * (n_clusters - 1)) + n = int(np.clip(n, 1, n_clusters)) + spikes = spikes_per_cluster[cluster] + selection[cluster] = regular_subset(spikes, n_spikes_max=n) + return _flatten_per_cluster(selection) diff --git a/phy/utils/tests/test_array.py b/phy/utils/tests/test_array.py index 2194ce6f8..4c596760e 100644 --- a/phy/utils/tests/test_array.py +++ b/phy/utils/tests/test_array.py @@ -18,6 +18,8 @@ _in_polygon, _spikes_in_clusters, _spikes_per_cluster, + _flatten_per_cluster, + select_spikes, chunk_bounds, regular_subset, excerpts, @@ -319,3 +321,19 @@ def test_spikes_per_cluster(): for i in range(10): ae(spikes_per_cluster[i], np.sort(spikes_per_cluster[i])) assert np.all(spike_clusters[spikes_per_cluster[i]] == i) + + +def test_flatten_per_cluster(): + spc = {2: [2, 7, 11], 3: [3, 5], 5: []} + arr = _flatten_per_cluster(spc) + ae(arr, [2, 3, 5, 7, 11]) + + +def test_select_spikes(): + with raises(AssertionError): + select_spikes() + spikes = [2, 3, 5, 7, 11] + sc = [2, 3, 3, 2, 2] + spc = {2: [2, 7, 11], 3: [3, 5], 5: []} + ae(select_spikes([], spikes_per_cluster=spc), []) + ae(select_spikes([2, 3, 5], spikes_per_cluster=spc), spikes) From 745f8712ab78cf00527b0637827bd7339fd9994a Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 28 Sep 2015 14:04:17 +0200 Subject: [PATCH 0116/1059] More tests --- phy/utils/tests/test_array.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/phy/utils/tests/test_array.py b/phy/utils/tests/test_array.py index 4c596760e..73ceddf97 100644 --- a/phy/utils/tests/test_array.py +++ b/phy/utils/tests/test_array.py @@ -333,7 +333,12 @@ def test_select_spikes(): with raises(AssertionError): select_spikes() spikes = [2, 3, 5, 7, 11] - sc = [2, 3, 3, 2, 2] spc = {2: [2, 7, 11], 3: [3, 5], 5: []} ae(select_spikes([], spikes_per_cluster=spc), []) ae(select_spikes([2, 3, 5], spikes_per_cluster=spc), spikes) + ae(select_spikes([2, 5], spikes_per_cluster=spc), spc[2]) + + ae(select_spikes([2, 3, 5], 0, spikes_per_cluster=spc), spikes) + ae(select_spikes([2, 3, 5], None, spikes_per_cluster=spc), spikes) + ae(select_spikes([2, 3, 5], 1, spikes_per_cluster=spc), [2, 3]) + ae(select_spikes([2, 5], 2, spikes_per_cluster=spc), [2]) From 2394bcd0d73500a680d4f69f819e108eb58d3d95 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 28 Sep 2015 15:09:55 +0200 Subject: [PATCH 0117/1059] Comment --- phy/utils/array.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/phy/utils/array.py b/phy/utils/array.py index 1b3e35bf8..653a49bb4 100644 --- a/phy/utils/array.py +++ b/phy/utils/array.py @@ -362,6 +362,8 @@ def select_spikes(cluster_ids=None, selection = {} n_clusters = len(cluster_ids) for cluster in cluster_ids: + # Decrease the number of spikes per cluster when there + # are more clusters. n = max_n_spikes_per_cluster * exp(-.1 * (n_clusters - 1)) n = int(np.clip(n, 1, n_clusters)) spikes = spikes_per_cluster[cluster] From d047e93fd135d912bbfacd50cf05d63fbd783e01 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 28 Sep 2015 15:29:50 +0200 Subject: [PATCH 0118/1059] WIP: ChunkedArray --- phy/utils/array.py | 28 ++++++++++++++++++++++++++++ phy/utils/tests/test_array.py | 17 +++++++++++++++++ 2 files changed, 45 insertions(+) diff --git a/phy/utils/array.py b/phy/utils/array.py index 653a49bb4..b2cb1674b 100644 --- a/phy/utils/array.py +++ b/phy/utils/array.py @@ -294,6 +294,34 @@ def get_excerpts(data, n_excerpts=None, excerpt_size=None): return out +# ----------------------------------------------------------------------------- +# Chunked array +# ----------------------------------------------------------------------------- + +class ChunkedArray(object): + def __init__(self, getters=None): + self._getters = getters + + @property + def getters(self): + return self._getters + + def __repr__(self): + return ''.format(self._getters) + + +def from_dask_array(da): + from dask.core import flatten + getters = [da.dask[k] for k in flatten(da._keys())] + return ChunkedArray(getters) + + +def to_dask_array(ca, chunks, dtype=None, shape=None): + from dask.array import Array + dask = {(i, 0): ca.getters[i] for i in range(len(ca.getters))} + return Array(dask, 'arr', chunks, dtype=dtype, shape=shape) + + # ----------------------------------------------------------------------------- # Spike clusters utility functions # ----------------------------------------------------------------------------- diff --git a/phy/utils/tests/test_array.py b/phy/utils/tests/test_array.py index 73ceddf97..d773e8a86 100644 --- a/phy/utils/tests/test_array.py +++ b/phy/utils/tests/test_array.py @@ -25,6 +25,9 @@ excerpts, data_chunk, get_excerpts, + ChunkedArray, + from_dask_array, + to_dask_array, _range_from_slice, _pad, _load_arrays, @@ -285,6 +288,20 @@ def test_regular_subset(): ae(regular_subset(spikes, 3, offset=1), [3, 11]) +#------------------------------------------------------------------------------ +# Test chunked array +#------------------------------------------------------------------------------ + +def test_chunked_array(): + from dask.array import from_array + arr = np.arange(10) + chunks = ((2, 3, 5),) + da = from_array(arr, chunks) + ca = from_dask_array(da) + da_bis = to_dask_array(ca, chunks) + assert da.chunks == da_bis.chunks + + #------------------------------------------------------------------------------ # Test spike clusters functions #------------------------------------------------------------------------------ From 995a8d7d2b5a9940c2641c1ca1d09e142003ecd6 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 28 Sep 2015 15:46:46 +0200 Subject: [PATCH 0119/1059] WIP: ChunkedArray and dask.array --- phy/utils/array.py | 30 ++++++++++++++++++++++++++---- phy/utils/tests/test_array.py | 2 +- 2 files changed, 27 insertions(+), 5 deletions(-) diff --git a/phy/utils/array.py b/phy/utils/array.py index b2cb1674b..44c14823e 100644 --- a/phy/utils/array.py +++ b/phy/utils/array.py @@ -299,13 +299,34 @@ def get_excerpts(data, n_excerpts=None, excerpt_size=None): # ----------------------------------------------------------------------------- class ChunkedArray(object): - def __init__(self, getters=None): + def __init__(self, getters=None, sizes=None, dtype=None, shape=None): self._getters = getters + self._sizes = sizes + self._dtype = dtype + self._shape = shape @property def getters(self): return self._getters + @property + def dtype(self): + return self._dtype + + @property + def shape(self): + return self._shape + + @property + def chunks(self): + assert self._sizes is not None + return (self._sizes,) + + def rechunk(self): + da = to_dask_array(self) + rechunked = da.rechunk() + return from_dask_array(rechunked) + def __repr__(self): return ''.format(self._getters) @@ -313,13 +334,14 @@ def __repr__(self): def from_dask_array(da): from dask.core import flatten getters = [da.dask[k] for k in flatten(da._keys())] - return ChunkedArray(getters) + return ChunkedArray(getters, sizes=da.chunks[0], + dtype=da.dtype, shape=da.shape) -def to_dask_array(ca, chunks, dtype=None, shape=None): +def to_dask_array(ca): from dask.array import Array dask = {(i, 0): ca.getters[i] for i in range(len(ca.getters))} - return Array(dask, 'arr', chunks, dtype=dtype, shape=shape) + return Array(dask, 'arr', ca.chunks, dtype=ca.dtype, shape=ca.shape) # ----------------------------------------------------------------------------- diff --git a/phy/utils/tests/test_array.py b/phy/utils/tests/test_array.py index d773e8a86..d8a6e1b60 100644 --- a/phy/utils/tests/test_array.py +++ b/phy/utils/tests/test_array.py @@ -298,7 +298,7 @@ def test_chunked_array(): chunks = ((2, 3, 5),) da = from_array(arr, chunks) ca = from_dask_array(da) - da_bis = to_dask_array(ca, chunks) + da_bis = to_dask_array(ca) assert da.chunks == da_bis.chunks From 4dba2a9a66fa6ff0ffe83852d3dcb9b66e263668 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 28 Sep 2015 15:53:30 +0200 Subject: [PATCH 0120/1059] Fix --- phy/utils/tests/test_array.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/phy/utils/tests/test_array.py b/phy/utils/tests/test_array.py index d8a6e1b60..b1302d0f6 100644 --- a/phy/utils/tests/test_array.py +++ b/phy/utils/tests/test_array.py @@ -292,7 +292,7 @@ def test_regular_subset(): # Test chunked array #------------------------------------------------------------------------------ -def test_chunked_array(): +def test_chunked_array_dask(): from dask.array import from_array arr = np.arange(10) chunks = ((2, 3, 5),) @@ -300,6 +300,8 @@ def test_chunked_array(): ca = from_dask_array(da) da_bis = to_dask_array(ca) assert da.chunks == da_bis.chunks + assert da.dtype == da_bis.dtype + assert da.shape == da_bis.shape #------------------------------------------------------------------------------ From 5d1a2ffa357652bd69deb481b1ee246c1050496c Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 28 Sep 2015 16:19:06 +0200 Subject: [PATCH 0121/1059] Update array tests --- phy/utils/array.py | 13 +++++++++++++ phy/utils/tests/test_array.py | 21 +++++++++++++-------- 2 files changed, 26 insertions(+), 8 deletions(-) diff --git a/phy/utils/array.py b/phy/utils/array.py index 44c14823e..a20d334bb 100644 --- a/phy/utils/array.py +++ b/phy/utils/array.py @@ -298,8 +298,21 @@ def get_excerpts(data, n_excerpts=None, excerpt_size=None): # Chunked array # ----------------------------------------------------------------------------- +def _id(x): + return x + + +def _getter(g): + if isinstance(g, tuple): + assert hasattr(g[0], '__call__') + elif _is_array_like(g): + return (_id, _as_array(g)) + return g + + class ChunkedArray(object): def __init__(self, getters=None, sizes=None, dtype=None, shape=None): + getters = [_getter(g) for g in getters] self._getters = getters self._sizes = sizes self._dtype = dtype diff --git a/phy/utils/tests/test_array.py b/phy/utils/tests/test_array.py index b1302d0f6..4c7632d6a 100644 --- a/phy/utils/tests/test_array.py +++ b/phy/utils/tests/test_array.py @@ -121,8 +121,8 @@ def test_unique(): """Test _unique() function""" _unique([]) - n_spikes = 1000 - n_clusters = 10 + n_spikes = 300 + n_clusters = 3 spike_clusters = artificial_spike_clusters(n_spikes, n_clusters) ae(_unique(spike_clusters), np.arange(n_clusters)) @@ -292,6 +292,11 @@ def test_regular_subset(): # Test chunked array #------------------------------------------------------------------------------ +def test_chunked_array_simple(): + arr = ChunkedArray([[2, 3, 5]], sizes=(3,)) + assert arr.chunks == ((3,),) + + def test_chunked_array_dask(): from dask.array import from_array arr = np.arange(10) @@ -311,8 +316,8 @@ def test_chunked_array_dask(): def test_spikes_in_clusters(): """Test _spikes_in_clusters().""" - n_spikes = 1000 - n_clusters = 10 + n_spikes = 100 + n_clusters = 5 spike_clusters = artificial_spike_clusters(n_spikes, n_clusters) ae(_spikes_in_clusters(spike_clusters, []), []) @@ -321,7 +326,7 @@ def test_spikes_in_clusters(): assert np.all(spike_clusters[_spikes_in_clusters(spike_clusters, [i])] == i) - clusters = [1, 5, 9] + clusters = [1, 2, 3] assert np.all(np.in1d(spike_clusters[_spikes_in_clusters(spike_clusters, clusters)], clusters)) @@ -330,14 +335,14 @@ def test_spikes_in_clusters(): def test_spikes_per_cluster(): """Test _spikes_per_cluster().""" - n_spikes = 1000 - n_clusters = 10 + n_spikes = 100 + n_clusters = 3 spike_clusters = artificial_spike_clusters(n_spikes, n_clusters) spikes_per_cluster = _spikes_per_cluster(spike_clusters) assert list(spikes_per_cluster.keys()) == list(range(n_clusters)) - for i in range(10): + for i in range(n_clusters): ae(spikes_per_cluster[i], np.sort(spikes_per_cluster[i])) assert np.all(spike_clusters[spikes_per_cluster[i]] == i) From 656337d39225aac21e394290b7b82abd139bf65d Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 28 Sep 2015 16:56:32 +0200 Subject: [PATCH 0122/1059] Remove chunked array --- phy/utils/array.py | 63 ----------------------------------- phy/utils/tests/test_array.py | 24 ------------- 2 files changed, 87 deletions(-) diff --git a/phy/utils/array.py b/phy/utils/array.py index a20d334bb..653a49bb4 100644 --- a/phy/utils/array.py +++ b/phy/utils/array.py @@ -294,69 +294,6 @@ def get_excerpts(data, n_excerpts=None, excerpt_size=None): return out -# ----------------------------------------------------------------------------- -# Chunked array -# ----------------------------------------------------------------------------- - -def _id(x): - return x - - -def _getter(g): - if isinstance(g, tuple): - assert hasattr(g[0], '__call__') - elif _is_array_like(g): - return (_id, _as_array(g)) - return g - - -class ChunkedArray(object): - def __init__(self, getters=None, sizes=None, dtype=None, shape=None): - getters = [_getter(g) for g in getters] - self._getters = getters - self._sizes = sizes - self._dtype = dtype - self._shape = shape - - @property - def getters(self): - return self._getters - - @property - def dtype(self): - return self._dtype - - @property - def shape(self): - return self._shape - - @property - def chunks(self): - assert self._sizes is not None - return (self._sizes,) - - def rechunk(self): - da = to_dask_array(self) - rechunked = da.rechunk() - return from_dask_array(rechunked) - - def __repr__(self): - return ''.format(self._getters) - - -def from_dask_array(da): - from dask.core import flatten - getters = [da.dask[k] for k in flatten(da._keys())] - return ChunkedArray(getters, sizes=da.chunks[0], - dtype=da.dtype, shape=da.shape) - - -def to_dask_array(ca): - from dask.array import Array - dask = {(i, 0): ca.getters[i] for i in range(len(ca.getters))} - return Array(dask, 'arr', ca.chunks, dtype=ca.dtype, shape=ca.shape) - - # ----------------------------------------------------------------------------- # Spike clusters utility functions # ----------------------------------------------------------------------------- diff --git a/phy/utils/tests/test_array.py b/phy/utils/tests/test_array.py index 4c7632d6a..03aae16fe 100644 --- a/phy/utils/tests/test_array.py +++ b/phy/utils/tests/test_array.py @@ -25,9 +25,6 @@ excerpts, data_chunk, get_excerpts, - ChunkedArray, - from_dask_array, - to_dask_array, _range_from_slice, _pad, _load_arrays, @@ -288,27 +285,6 @@ def test_regular_subset(): ae(regular_subset(spikes, 3, offset=1), [3, 11]) -#------------------------------------------------------------------------------ -# Test chunked array -#------------------------------------------------------------------------------ - -def test_chunked_array_simple(): - arr = ChunkedArray([[2, 3, 5]], sizes=(3,)) - assert arr.chunks == ((3,),) - - -def test_chunked_array_dask(): - from dask.array import from_array - arr = np.arange(10) - chunks = ((2, 3, 5),) - da = from_array(arr, chunks) - ca = from_dask_array(da) - da_bis = to_dask_array(ca) - assert da.chunks == da_bis.chunks - assert da.dtype == da_bis.dtype - assert da.shape == da_bis.shape - - #------------------------------------------------------------------------------ # Test spike clusters functions #------------------------------------------------------------------------------ From 302686b5161398c28dce834d65cf199c2a5eeece Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 28 Sep 2015 16:56:55 +0200 Subject: [PATCH 0123/1059] Update gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index c6acc7356..fc5dcadfe 100644 --- a/.gitignore +++ b/.gitignore @@ -14,6 +14,7 @@ wiki __pycache__ _old *.py[cod] +*~ .coverage* *credentials From 04f9df8cbbe0852e559491f1edef433dead83e2d Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 28 Sep 2015 19:42:32 +0200 Subject: [PATCH 0124/1059] WIP: update phy.electrode --- phy/electrode/mea.py | 6 +++++- phy/electrode/tests/test_mea.py | 18 +++++++----------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/phy/electrode/mea.py b/phy/electrode/mea.py index 1837bac5e..44dbbc80d 100644 --- a/phy/electrode/mea.py +++ b/phy/electrode/mea.py @@ -87,7 +87,7 @@ def load_probe(name): path = op.join(curdir, 'probes/{}.prb'.format(name)) if not op.exists(path): raise IOError("The probe `{}` cannot be found.".format(name)) - return _read_python(path) + return MEA(probe=_read_python(path)) def list_probes(): @@ -127,6 +127,10 @@ def __init__(self, adjacency = _probe_adjacency_list(probe) self.channels_per_group = _channels_per_group(probe) self._adjacency = adjacency + if probe: + # Select the first channel group. + cg = sorted(self._probe['channel_groups'].keys())[0] + self.change_channel_group(cg) def _check_positions(self, positions): if positions is None: diff --git a/phy/electrode/tests/test_mea.py b/phy/electrode/tests/test_mea.py index ab04246c8..a798a8f93 100644 --- a/phy/electrode/tests/test_mea.py +++ b/phy/electrode/tests/test_mea.py @@ -42,16 +42,12 @@ def test_probe(): assert _probe_adjacency_list(probe) == adjacency mea = MEA(probe=probe) - assert mea.positions is None - assert mea.channels is None - assert mea.n_channels == 0 + assert mea.adjacency == adjacency assert mea.channels_per_group == {0: [0, 3, 1], 1: [7]} - - mea.change_channel_group(0) - ae(mea.positions, [(10, 10), (20, 30), (10, 20)]) assert mea.channels == [0, 3, 1] assert mea.n_channels == 3 + ae(mea.positions, [(10, 10), (20, 30), (10, 20)]) def test_mea(): @@ -87,12 +83,11 @@ def test_positions(): def test_library(tempdir): + assert '1x32_buzsaki' in list_probes() + probe = load_probe('1x32_buzsaki') assert probe - assert probe['channel_groups'][0]['channels'] == list(range(32)) - assert _probe_all_channels(probe) == list(range(32)) - - assert '1x32_buzsaki' in list_probes() + assert probe.channels == list(range(32)) path = op.join(tempdir, 'test.prb') with raises(IOError): @@ -100,4 +95,5 @@ def test_library(tempdir): with open(path, 'w') as f: f.write('') - load_probe(path) + with raises(KeyError): + load_probe(path) From 767c67b9413fd627f2a1096254be7870ae138158 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 28 Sep 2015 19:46:04 +0200 Subject: [PATCH 0125/1059] WIP: SpikeDetector class --- phy/traces/spike_detect.py | 67 +++++++++++++++++++++++++++ phy/traces/tests/test_spike_detect.py | 34 ++++++++++++++ 2 files changed, 101 insertions(+) create mode 100644 phy/traces/spike_detect.py create mode 100644 phy/traces/tests/test_spike_detect.py diff --git a/phy/traces/spike_detect.py b/phy/traces/spike_detect.py new file mode 100644 index 000000000..951dc3c49 --- /dev/null +++ b/phy/traces/spike_detect.py @@ -0,0 +1,67 @@ +# -*- coding: utf-8 -*- + +"""Spike detection.""" + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +import numpy as np +from traitlets.config.configurable import Configurable +from traitlets import Int, Float, Unicode, Bool + + +#------------------------------------------------------------------------------ +# SpikeDetector +#------------------------------------------------------------------------------ + +class SpikeDetector(Configurable): + filter_low = Float(500.) + filter_butter_order = Int(3) + chunk_size_seconds = Float(1) + chunk_overlap_seconds = Float(.015) + n_excerpts = Int(50) + excerpt_size_seconds = Float(1.) + use_single_threshold = Bool(True) + threshold_strong_std_factor = Float(4.5) + threshold_weak_std_factor = Float(2) + detect_spikes = Unicode('negative') + connected_component_join_size = Int(1) + extract_s_before = Int(10) + extract_s_after = Int(10) + weight_power = Float(2) + + def __init__(self, ctx=None): + super(SpikeDetector, self).__init__() + if not ctx or not hasattr(ctx, 'cache'): + return + self.filter = ctx.cache(self.filter) + self.extract_components = ctx.cache(self.extract_components) + self.extract_spikes = ctx.cache(self.extract_spikes) + + def set_metadata(self, probe, channel_mapping=None): + self.probe = probe + if channel_mapping is None: + channel_mapping = {c: c for c in probe.channels} + self.channel_mapping = channel_mapping + self.channels = probe.channels + self.n_channels = probe.n_channels + + def filter(self, raw_data): + pass + + def extract_components(self, filtered): + pass + + def extract_spikes(self, components): + return None, None, None + + def detect(self, raw_data, sample_rate=None): + assert sample_rate > 0 + assert raw_data.ndim == 2 + assert raw_data.shape[1] == self.n_channels + + filtered = self.filter(raw_data) + components = self.extract_components(filtered) + spike_samples, masks, waveforms = self.extract_spikes(components) + return spike_samples, masks diff --git a/phy/traces/tests/test_spike_detect.py b/phy/traces/tests/test_spike_detect.py new file mode 100644 index 000000000..bc5f36d78 --- /dev/null +++ b/phy/traces/tests/test_spike_detect.py @@ -0,0 +1,34 @@ +# -*- coding: utf-8 -*- + +"""Tests of spike detection routines.""" + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +import numpy as np +from numpy.testing import assert_array_equal as ae +from numpy.testing import assert_allclose as ac + +from phy.utils.datasets import download_test_data +from phy.electrode import load_probe +from ..spike_detect import (SpikeDetector, + ) + + +#------------------------------------------------------------------------------ +# Test spike detection +#------------------------------------------------------------------------------ + +def test_detect(): + + path = download_test_data('test-32ch-10s.dat') + traces = np.fromfile(path, dtype=np.int16).reshape((200000, 32)) + traces = traces[:45000] + n_samples, n_channels = traces.shape + sample_rate = 20000 + probe = load_probe('1x32_buzsaki') + + sd = SpikeDetector() + sd.set_metadata(probe) + spike_samples, masks = sd.detect(traces, sample_rate=sample_rate) From ec08c16223cb1e2e63b5943c078161e4b91229fd Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 29 Sep 2015 12:01:18 +0200 Subject: [PATCH 0126/1059] WIP: remove channel group support from waveform extractor --- phy/traces/tests/test_waveform.py | 18 +++++------------- phy/traces/waveform.py | 29 ++++------------------------- 2 files changed, 9 insertions(+), 38 deletions(-) diff --git a/phy/traces/tests/test_waveform.py b/phy/traces/tests/test_waveform.py index 09819aad2..5b4859374 100644 --- a/phy/traces/tests/test_waveform.py +++ b/phy/traces/tests/test_waveform.py @@ -27,9 +27,6 @@ def test_extract_simple(): strong = 2. nc = 4 ns = 20 - channels = list(range(nc)) - cpg = {0: channels} - # graph = {0: [1, 2], 1: [0, 2], 2: [0, 1], 3: []} data = np.random.uniform(size=(ns, nc), low=0., high=1.) @@ -51,7 +48,6 @@ def test_extract_simple(): we = WaveformExtractor(extract_before=3, extract_after=5, - channels_per_group=cpg, ) we.set_thresholds(weak=weak, strong=strong) @@ -60,7 +56,6 @@ def test_extract_simple(): ae(comp.comp_s, [10, 10, 11, 11, 12, 12]) ae(comp.comp_ch, [0, 1, 0, 1, 0, 1]) assert (comp.s_min, comp.s_max) == (10 - 3, 12 + 4) - ae(comp.channels, range(nc)) # _normalize() assert we._normalize(weak) == 0 @@ -83,7 +78,7 @@ def test_extract_simple(): assert 11 <= s < 12 # extract() - wave_e = we.extract(data, s, channels=channels) + wave_e = we.extract(data, s) assert wave_e.shape[1] == wave.shape[1] ae(wave[3:6, :2], wave_e[3:6, :2]) @@ -92,9 +87,8 @@ def test_extract_simple(): assert wave_a.shape == (3 + 5, nc) # Test final call. - groups, s_f, wave_f, masks_f = we(component, data=data, data_t=data) + s_f, wave_f, masks_f = we(component, data=data, data_t=data) assert s_f == s - assert np.all(groups == 0) ae(masks_f, masks) ae(wave_f, wave_a) @@ -103,13 +97,11 @@ def test_extract_simple(): extract_after=5, thresholds={'weak': weak, 'strong': strong}, - channels_per_group={0: [1, 0, 3]}, ) - groups, s_f_o, wave_f_o, masks_f_o = we(component, data=data, data_t=data) - assert np.all(groups == 0) + s_f_o, wave_f_o, masks_f_o = we(component, data=data, data_t=data) assert s_f == s_f_o - assert np.allclose(wave_f[:, [1, 0, 3]], wave_f_o) - ae(masks_f_o, [1., 0.5, 0.]) + assert np.allclose(wave_f, wave_f_o) + ae(masks_f_o, [0.5, 1., 0., 0.]) def test_get_padded(): diff --git a/phy/traces/waveform.py b/phy/traces/waveform.py index d94c0d52b..1cfa52b3c 100644 --- a/phy/traces/waveform.py +++ b/phy/traces/waveform.py @@ -48,31 +48,15 @@ def __init__(self, extract_after=None, weight_power=None, thresholds=None, - channels_per_group=None, ): self._extract_before = extract_before self._extract_after = extract_after self._weight_power = weight_power if weight_power is not None else 1. self._thresholds = thresholds or {} - self._channels_per_group = channels_per_group - # mapping channel => channels in the shank - self._dep_channels = {i: channels - for channels in channels_per_group.values() - for i in channels} - self._channel_groups = {i: g - for g, channels in channels_per_group.items() - for i in channels} def _component(self, component, data=None, n_samples=None): comp_s = component[:, 0] # shape: (component_size,) comp_ch = component[:, 1] # shape: (component_size,) - channel = comp_ch[0] - if channel not in self._dep_channels: # pragma: no cover - raise RuntimeError("Channel `{}` appears to be dead and should " - "have been excluded from the threshold " - "crossings.".format(channel)) - channels = self._dep_channels[channel] - group = self._channel_groups[comp_ch[0]] # Get the temporal window around the waveform. s_min, s_max = (comp_s.min() - 3), (comp_s.max() + 4) @@ -84,8 +68,6 @@ def _component(self, component, data=None, n_samples=None): comp_ch=comp_ch, s_min=s_min, s_max=s_max, - channels=channels, - group=group, ) def _normalize(self, x): @@ -106,7 +88,6 @@ def _comp_wave(self, data_t, comp): def masks(self, data_t, wave, comp): nc = data_t.shape[1] - channels = comp.channels comp_ch = comp.comp_ch s_min = comp.s_min @@ -122,7 +103,6 @@ def masks(self, data_t, wave, comp): # Compute the float masks. masks_float = self._normalize(peaks_values) # Keep shank channels. - masks_float = masks_float[channels] return masks_float def spike_sample_aligned(self, wave, comp): @@ -135,13 +115,13 @@ def spike_sample_aligned(self, wave, comp): s_aligned = np.sum(wave_n_p * u) / np.sum(wave_n_p) + s_min return s_aligned - def extract(self, data, s_aligned, channels=None): + def extract(self, data, s_aligned): s = int(s_aligned) # Get block of given size around peak sample. waveform = _get_padded(data, s - self._extract_before - 1, s + self._extract_after + 2) - return waveform[:, channels] # Keep shank channels. + return waveform def align(self, waveform, s_aligned): s = int(s_aligned) @@ -166,20 +146,19 @@ def __call__(self, component=None, data=None, data_t=None): data=data, n_samples=data_t.shape[0], ) - channels = comp.channels wave = self._comp_wave(data_t, comp) masks = self.masks(data_t, wave, comp) s_aligned = self.spike_sample_aligned(wave, comp) - waveform_unaligned = self.extract(data, s_aligned, channels=channels) + waveform_unaligned = self.extract(data, s_aligned) waveform_aligned = self.align(waveform_unaligned, s_aligned) assert waveform_aligned.ndim == 2 assert masks.ndim == 1 assert waveform_aligned.shape[1] == masks.shape[0] - return comp.group, s_aligned, waveform_aligned, masks + return s_aligned, waveform_aligned, masks #------------------------------------------------------------------------------ From 9f5b7839236a801241075a171196d8c6b77d955f Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 29 Sep 2015 12:10:34 +0200 Subject: [PATCH 0127/1059] WIP: traces tests pass --- phy/traces/spike_detect.py | 111 ++++++++++++++++++++++---- phy/traces/tests/test_spike_detect.py | 9 ++- phy/traces/tests/test_waveform.py | 4 +- phy/traces/waveform.py | 2 +- 4 files changed, 106 insertions(+), 20 deletions(-) diff --git a/phy/traces/spike_detect.py b/phy/traces/spike_detect.py index 951dc3c49..670068715 100644 --- a/phy/traces/spike_detect.py +++ b/phy/traces/spike_detect.py @@ -6,10 +6,20 @@ # Imports #------------------------------------------------------------------------------ +import logging + import numpy as np from traitlets.config.configurable import Configurable from traitlets import Int, Float, Unicode, Bool +from phy.electrode import MEA +from phy.utils.array import get_excerpts +from .detect import FloodFillDetector, Thresholder, compute_threshold +from .filter import Filter +from .waveform import WaveformExtractor + +logger = logging.getLogger(__name__) + #------------------------------------------------------------------------------ # SpikeDetector @@ -35,33 +45,106 @@ def __init__(self, ctx=None): super(SpikeDetector, self).__init__() if not ctx or not hasattr(ctx, 'cache'): return + self.find_thresholds = ctx.cache(self.find_thresholds) self.filter = ctx.cache(self.filter) self.extract_components = ctx.cache(self.extract_components) self.extract_spikes = ctx.cache(self.extract_spikes) - def set_metadata(self, probe, channel_mapping=None): + def set_metadata(self, probe, channel_mapping=None, sample_rate=None): + assert isinstance(probe, MEA) self.probe = probe + + assert sample_rate > 0 + self.sample_rate = sample_rate + if channel_mapping is None: channel_mapping = {c: c for c in probe.channels} self.channel_mapping = channel_mapping - self.channels = probe.channels - self.n_channels = probe.n_channels - def filter(self, raw_data): - pass + # Array of channel idx to consider. + self.channels = sorted(channel_mapping.keys()) + self.n_channels = len(self.channels) + self.n_samples_waveforms = self.extract_s_before + self.extract_s_after + + def _select_channels(self, traces): + return traces[:, self.channels] + + def find_thresholds(self, traces): + """Find weak and strong thresholds in filtered traces.""" + excerpt_size = int(self.excerpt_size_seconds * self.sample_rate) + single_threshold = self.use_single_threshold + std_factor = (self.threshold_weak_std_factor, + self.threshold_strong_std_factor) + + logger.info("Extracting some data for finding the thresholds...") + excerpt = get_excerpts(traces, n_excerpts=self.n_excerpts, + excerpt_size=excerpt_size) + + logger.info("Filtering the excerpts...") + excerpt_f = self.filter(excerpt) + + logger.info("Computing the thresholds...") + thresholds = compute_threshold(excerpt_f, + single_threshold=single_threshold, + std_factor=std_factor) + + thresholds = {'weak': thresholds[0], 'strong': thresholds[1]} + logger.info("Thresholds found: {}.".format(thresholds)) + self._thresholder = Thresholder(mode=self.detect_spikes, + thresholds=thresholds) + return thresholds + + def filter(self, traces): + f = Filter(rate=self.sample_rate, + low=self.filter_low, + high=0.95 * .5 * self.sample_rate, + order=self.filter_butter_order, + ) + return f(traces).astype(np.float32) def extract_components(self, filtered): - pass + # Transform the filtered data according to the detection mode. + traces_t = self._thresholder.transform(filtered) - def extract_spikes(self, components): - return None, None, None + # Compute the threshold crossings. + weak = self._thresholder.detect(traces_t, 'weak') + strong = self._thresholder.detect(traces_t, 'strong') - def detect(self, raw_data, sample_rate=None): - assert sample_rate > 0 - assert raw_data.ndim == 2 - assert raw_data.shape[1] == self.n_channels + # Run the detection. + join_size = self.connected_component_join_size + detector = FloodFillDetector(probe_adjacency_list=self.probe.adjacency, + join_size=join_size) + return detector(weak_crossings=weak, + strong_crossings=strong) + + def extract_spikes(self, filtered, components): + # Transform the filtered data according to the detection mode. + traces_t = self._thresholder.transform(filtered) + + # Extract all waveforms. + extractor = WaveformExtractor(extract_before=self.extract_s_before, + extract_after=self.extract_s_after, + weight_power=self.weight_power, + thresholds=self._thresholds, + ) + + s, m, w = zip(*(extractor(component, data=filtered, data_t=traces_t) + for component in components)) + s = np.array(s, dtype=np.int64) + m = np.array(m, dtype=np.float32) + w = np.array(w, dtype=np.float32) + return s, m, w + + def detect(self, traces): + assert traces.ndim == 2 + assert traces.shape[1] == self.n_channels + + # Only keep the selected channels (given shank, no dead channels, etc.) + traces = self._select_channels(traces) + self._thresholds = self.find_thresholds(traces) - filtered = self.filter(raw_data) + filtered = self.filter(traces) components = self.extract_components(filtered) - spike_samples, masks, waveforms = self.extract_spikes(components) + spike_samples, masks, waveforms = self.extract_spikes(filtered, + components) return spike_samples, masks diff --git a/phy/traces/tests/test_spike_detect.py b/phy/traces/tests/test_spike_detect.py index bc5f36d78..00e1ba15c 100644 --- a/phy/traces/tests/test_spike_detect.py +++ b/phy/traces/tests/test_spike_detect.py @@ -24,11 +24,14 @@ def test_detect(): path = download_test_data('test-32ch-10s.dat') traces = np.fromfile(path, dtype=np.int16).reshape((200000, 32)) - traces = traces[:45000] + traces = traces[:20000] n_samples, n_channels = traces.shape sample_rate = 20000 probe = load_probe('1x32_buzsaki') sd = SpikeDetector() - sd.set_metadata(probe) - spike_samples, masks = sd.detect(traces, sample_rate=sample_rate) + sd.use_single_threshold = False + sd.set_metadata(probe, sample_rate=sample_rate) + spike_samples, masks = sd.detect(traces) + print(spike_samples) + print(masks) diff --git a/phy/traces/tests/test_waveform.py b/phy/traces/tests/test_waveform.py index 5b4859374..2af203632 100644 --- a/phy/traces/tests/test_waveform.py +++ b/phy/traces/tests/test_waveform.py @@ -87,7 +87,7 @@ def test_extract_simple(): assert wave_a.shape == (3 + 5, nc) # Test final call. - s_f, wave_f, masks_f = we(component, data=data, data_t=data) + s_f, masks_f, wave_f = we(component, data=data, data_t=data) assert s_f == s ae(masks_f, masks) ae(wave_f, wave_a) @@ -98,7 +98,7 @@ def test_extract_simple(): thresholds={'weak': weak, 'strong': strong}, ) - s_f_o, wave_f_o, masks_f_o = we(component, data=data, data_t=data) + s_f_o, masks_f_o, wave_f_o = we(component, data=data, data_t=data) assert s_f == s_f_o assert np.allclose(wave_f, wave_f_o) ae(masks_f_o, [0.5, 1., 0., 0.]) diff --git a/phy/traces/waveform.py b/phy/traces/waveform.py index 1cfa52b3c..537ad4e4e 100644 --- a/phy/traces/waveform.py +++ b/phy/traces/waveform.py @@ -158,7 +158,7 @@ def __call__(self, component=None, data=None, data_t=None): assert masks.ndim == 1 assert waveform_aligned.shape[1] == masks.shape[0] - return s_aligned, waveform_aligned, masks + return s_aligned, masks, waveform_aligned #------------------------------------------------------------------------------ From 5ded62dd34c7425849c7d7bd80795355cc3c6673 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 29 Sep 2015 12:45:42 +0200 Subject: [PATCH 0128/1059] WIP: fix bug with channel mapping in component extractor --- phy/traces/spike_detect.py | 12 ++++++++++-- phy/traces/tests/test_spike_detect.py | 13 ++++++++++--- 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/phy/traces/spike_detect.py b/phy/traces/spike_detect.py index 670068715..22604b990 100644 --- a/phy/traces/spike_detect.py +++ b/phy/traces/spike_detect.py @@ -60,6 +60,7 @@ def set_metadata(self, probe, channel_mapping=None, sample_rate=None): if channel_mapping is None: channel_mapping = {c: c for c in probe.channels} self.channel_mapping = channel_mapping + self.adjacency = self.probe.adjacency # Array of channel idx to consider. self.channels = sorted(channel_mapping.keys()) @@ -112,7 +113,7 @@ def extract_components(self, filtered): # Run the detection. join_size = self.connected_component_join_size - detector = FloodFillDetector(probe_adjacency_list=self.probe.adjacency, + detector = FloodFillDetector(probe_adjacency_list=self.adjacency, join_size=join_size) return detector(weak_crossings=weak, strong_crossings=strong) @@ -137,14 +138,21 @@ def extract_spikes(self, filtered, components): def detect(self, traces): assert traces.ndim == 2 - assert traces.shape[1] == self.n_channels # Only keep the selected channels (given shank, no dead channels, etc.) traces = self._select_channels(traces) + assert traces.shape[1] == self.n_channels + + # Find the thresholds. self._thresholds = self.find_thresholds(traces) + # Apply the filter. filtered = self.filter(traces) + + # Extract the spike components. components = self.extract_components(filtered) + + # Extract the spikes, masks, waveforms. spike_samples, masks, waveforms = self.extract_spikes(filtered, components) return spike_samples, masks diff --git a/phy/traces/tests/test_spike_detect.py b/phy/traces/tests/test_spike_detect.py index 00e1ba15c..affda6187 100644 --- a/phy/traces/tests/test_spike_detect.py +++ b/phy/traces/tests/test_spike_detect.py @@ -28,10 +28,17 @@ def test_detect(): n_samples, n_channels = traces.shape sample_rate = 20000 probe = load_probe('1x32_buzsaki') + # channel_mapping = {i: i for i in range(1, 21, 2)} + channel_mapping = None sd = SpikeDetector() sd.use_single_threshold = False - sd.set_metadata(probe, sample_rate=sample_rate) + sd.set_metadata(probe, channel_mapping=channel_mapping, + sample_rate=sample_rate) spike_samples, masks = sd.detect(traces) - print(spike_samples) - print(masks) + + # from vispy.app import run + # from phy.plot import plot_traces + # plot_traces(traces, spike_samples=spike_samples, masks=masks, + # n_samples_per_spike=40) + # run() From 68ecc0ef96395ee48f5aefec155c4258690b7a83 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 29 Sep 2015 12:51:14 +0200 Subject: [PATCH 0129/1059] WIP: channel remapping --- phy/electrode/mea.py | 14 +++++++------- phy/electrode/tests/test_mea.py | 11 ++++++++++- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/phy/electrode/mea.py b/phy/electrode/mea.py index 44dbbc80d..8257673a1 100644 --- a/phy/electrode/mea.py +++ b/phy/electrode/mea.py @@ -38,6 +38,13 @@ def _edges_to_adjacency_list(edges): return adj +def _remap_adjacency(adjacency, mapping): + remapped = {} + for key, vals in adjacency.items(): + remapped[mapping[key]] = [mapping[i] for i in vals] + return remapped + + def _probe_positions(probe, group): """Return the positions of a probe channel group.""" positions = probe['channel_groups'][group]['geometry'] @@ -54,13 +61,6 @@ def _probe_channels(probe, group): return probe['channel_groups'][group]['channels'] -def _probe_all_channels(probe): - """Return the list of channels in the probe.""" - cgs = probe['channel_groups'].values() - cg_channels = [cg['channels'] for cg in cgs] - return sorted(set(itertools.chain(*cg_channels))) - - def _probe_adjacency_list(probe): """Return an adjacency list of a whole probe.""" cgs = probe['channel_groups'].values() diff --git a/phy/electrode/tests/test_mea.py b/phy/electrode/tests/test_mea.py index a798a8f93..45c56f160 100644 --- a/phy/electrode/tests/test_mea.py +++ b/phy/electrode/tests/test_mea.py @@ -12,7 +12,7 @@ import numpy as np from numpy.testing import assert_array_equal as ae -from ..mea import (_probe_channels, _probe_all_channels, +from ..mea import (_probe_channels, _remap_adjacency, _probe_positions, _probe_adjacency_list, MEA, linear_positions, staggered_positions, load_probe, list_probes @@ -23,6 +23,15 @@ # Tests #------------------------------------------------------------------------------ +def test_remap(): + adjacency = {1: [2, 3, 7], 3: [5, 11]} + mapping = {1: 3, 2: 20, 3: 30, 5: 50, 7: 70, 11: 1} + remapped = _remap_adjacency(adjacency, mapping) + assert sorted(remapped.keys()) == [3, 30] + assert remapped[3] == [20, 30, 70] + assert remapped[30] == [50, 1] + + def test_probe(): probe = {'channel_groups': { 0: {'channels': [0, 3, 1], From 25668d7a942766d0999e67734ca697b8b097bcc1 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 29 Sep 2015 12:58:17 +0200 Subject: [PATCH 0130/1059] WIP: subset adjacency --- phy/electrode/mea.py | 5 +++++ phy/electrode/tests/test_mea.py | 11 ++++++++++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/phy/electrode/mea.py b/phy/electrode/mea.py index 8257673a1..18909b18b 100644 --- a/phy/electrode/mea.py +++ b/phy/electrode/mea.py @@ -38,6 +38,11 @@ def _edges_to_adjacency_list(edges): return adj +def _adjacency_subset(adjacency, subset): + return {c: [v for v in vals if v in subset] + for (c, vals) in adjacency.items() if c in subset} + + def _remap_adjacency(adjacency, mapping): remapped = {} for key, vals in adjacency.items(): diff --git a/phy/electrode/tests/test_mea.py b/phy/electrode/tests/test_mea.py index 45c56f160..4bd580a0e 100644 --- a/phy/electrode/tests/test_mea.py +++ b/phy/electrode/tests/test_mea.py @@ -12,7 +12,7 @@ import numpy as np from numpy.testing import assert_array_equal as ae -from ..mea import (_probe_channels, _remap_adjacency, +from ..mea import (_probe_channels, _remap_adjacency, _adjacency_subset, _probe_positions, _probe_adjacency_list, MEA, linear_positions, staggered_positions, load_probe, list_probes @@ -32,6 +32,15 @@ def test_remap(): assert remapped[30] == [50, 1] +def test_adjacency_subset(): + adjacency = {1: [2, 3, 7], 3: [5, 11], 5: [1, 2, 11]} + subset = [1, 5, 32] + adjsub = _adjacency_subset(adjacency, subset) + assert sorted(adjsub.keys()) == [1, 5] + assert adjsub[1] == [] + assert adjsub[5] == [1] + + def test_probe(): probe = {'channel_groups': { 0: {'channels': [0, 3, 1], From 508b5858ffbd25401e79919dbf345630ea3be2ac Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 29 Sep 2015 13:35:23 +0200 Subject: [PATCH 0131/1059] Fix channel mapping --- phy/traces/spike_detect.py | 34 ++++++++++++++++++++------- phy/traces/tests/test_spike_detect.py | 8 ++++--- 2 files changed, 30 insertions(+), 12 deletions(-) diff --git a/phy/traces/spike_detect.py b/phy/traces/spike_detect.py index 22604b990..d7fe4c736 100644 --- a/phy/traces/spike_detect.py +++ b/phy/traces/spike_detect.py @@ -12,7 +12,7 @@ from traitlets.config.configurable import Configurable from traitlets import Int, Float, Unicode, Bool -from phy.electrode import MEA +from phy.electrode.mea import MEA, _adjacency_subset, _remap_adjacency from phy.utils.array import get_excerpts from .detect import FloodFillDetector, Thresholder, compute_threshold from .filter import Filter @@ -57,18 +57,34 @@ def set_metadata(self, probe, channel_mapping=None, sample_rate=None): assert sample_rate > 0 self.sample_rate = sample_rate + # Channel mapping. if channel_mapping is None: channel_mapping = {c: c for c in probe.channels} - self.channel_mapping = channel_mapping - self.adjacency = self.probe.adjacency + # channel mappings is {trace_col: channel_id}. + # Trace columns and channel ids to keep. + self.trace_cols = sorted(channel_mapping.keys()) + self.channel_ids = sorted(channel_mapping.values()) + # The key is the col in traces, the val is the channel id. + adj = self.probe.adjacency # Numbers are all channel ids. + # First, we subset the adjacency list with the kept channel ids. + adj = _adjacency_subset(adj, self.channel_ids) + # Then, we remap to convert from channel ids to trace columns. + # We need to inverse the mapping. + channel_mapping_inv = {v: c for (c, v) in channel_mapping.items()} + # Now, the adjacency list contains trace column numbers. + adj = _remap_adjacency(adj, channel_mapping_inv) + assert set(adj) <= set(self.trace_cols) + # Finally, we need to remap with relative column indices. + rel_mapping = {c: i for (i, c) in enumerate(self.trace_cols)} + adj = _remap_adjacency(adj, rel_mapping) + self._adjacency = adj # Array of channel idx to consider. - self.channels = sorted(channel_mapping.keys()) - self.n_channels = len(self.channels) + self.n_channels = len(self.channel_ids) self.n_samples_waveforms = self.extract_s_before + self.extract_s_after - def _select_channels(self, traces): - return traces[:, self.channels] + def subset_traces(self, traces): + return traces[:, self.trace_cols] def find_thresholds(self, traces): """Find weak and strong thresholds in filtered traces.""" @@ -113,7 +129,7 @@ def extract_components(self, filtered): # Run the detection. join_size = self.connected_component_join_size - detector = FloodFillDetector(probe_adjacency_list=self.adjacency, + detector = FloodFillDetector(probe_adjacency_list=self._adjacency, join_size=join_size) return detector(weak_crossings=weak, strong_crossings=strong) @@ -140,7 +156,7 @@ def detect(self, traces): assert traces.ndim == 2 # Only keep the selected channels (given shank, no dead channels, etc.) - traces = self._select_channels(traces) + traces = self.subset_traces(traces) assert traces.shape[1] == self.n_channels # Find the thresholds. diff --git a/phy/traces/tests/test_spike_detect.py b/phy/traces/tests/test_spike_detect.py index affda6187..dd8d955f7 100644 --- a/phy/traces/tests/test_spike_detect.py +++ b/phy/traces/tests/test_spike_detect.py @@ -28,8 +28,8 @@ def test_detect(): n_samples, n_channels = traces.shape sample_rate = 20000 probe = load_probe('1x32_buzsaki') - # channel_mapping = {i: i for i in range(1, 21, 2)} - channel_mapping = None + channel_mapping = {i: i for i in range(1, 21, 2)} + # channel_mapping = None sd = SpikeDetector() sd.use_single_threshold = False @@ -39,6 +39,8 @@ def test_detect(): # from vispy.app import run # from phy.plot import plot_traces - # plot_traces(traces, spike_samples=spike_samples, masks=masks, + # plot_traces(sd.subset_traces(traces), + # spike_samples=spike_samples, + # masks=masks, # n_samples_per_spike=40) # run() From f82c8d0daadda7d948e46d0409838289628da8ad Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 29 Sep 2015 13:47:28 +0200 Subject: [PATCH 0132/1059] Update spike detection --- phy/traces/spike_detect.py | 31 ++++++++++++++------------- phy/traces/tests/test_spike_detect.py | 3 ++- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/phy/traces/spike_detect.py b/phy/traces/spike_detect.py index d7fe4c736..8cf6aecda 100644 --- a/phy/traces/spike_detect.py +++ b/phy/traces/spike_detect.py @@ -42,6 +42,7 @@ class SpikeDetector(Configurable): weight_power = Float(2) def __init__(self, ctx=None): + self._thresholds = None super(SpikeDetector, self).__init__() if not ctx or not hasattr(ctx, 'cache'): return @@ -50,13 +51,16 @@ def __init__(self, ctx=None): self.extract_components = ctx.cache(self.extract_components) self.extract_spikes = ctx.cache(self.extract_spikes) - def set_metadata(self, probe, channel_mapping=None, sample_rate=None): + def set_metadata(self, probe, channel_mapping=None, + sample_rate=None, thresholds=None): assert isinstance(probe, MEA) self.probe = probe assert sample_rate > 0 self.sample_rate = sample_rate + self._thresholds = thresholds + # Channel mapping. if channel_mapping is None: channel_mapping = {c: c for c in probe.channels} @@ -107,8 +111,6 @@ def find_thresholds(self, traces): thresholds = {'weak': thresholds[0], 'strong': thresholds[1]} logger.info("Thresholds found: {}.".format(thresholds)) - self._thresholder = Thresholder(mode=self.detect_spikes, - thresholds=thresholds) return thresholds def filter(self, traces): @@ -119,7 +121,13 @@ def filter(self, traces): ) return f(traces).astype(np.float32) - def extract_components(self, filtered): + def extract_spikes(self, filtered, thresholds=None): + if thresholds is None: + thresholds = self._thresholds + assert thresholds is not None + self._thresholder = Thresholder(mode=self.detect_spikes, + thresholds=thresholds) + # Transform the filtered data according to the detection mode. traces_t = self._thresholder.transform(filtered) @@ -131,12 +139,8 @@ def extract_components(self, filtered): join_size = self.connected_component_join_size detector = FloodFillDetector(probe_adjacency_list=self._adjacency, join_size=join_size) - return detector(weak_crossings=weak, - strong_crossings=strong) - - def extract_spikes(self, filtered, components): - # Transform the filtered data according to the detection mode. - traces_t = self._thresholder.transform(filtered) + components = detector(weak_crossings=weak, + strong_crossings=strong) # Extract all waveforms. extractor = WaveformExtractor(extract_before=self.extract_s_before, @@ -165,10 +169,7 @@ def detect(self, traces): # Apply the filter. filtered = self.filter(traces) - # Extract the spike components. - components = self.extract_components(filtered) - # Extract the spikes, masks, waveforms. - spike_samples, masks, waveforms = self.extract_spikes(filtered, - components) + spike_samples, masks, waveforms = self.extract_spikes(filtered + ) return spike_samples, masks diff --git a/phy/traces/tests/test_spike_detect.py b/phy/traces/tests/test_spike_detect.py index dd8d955f7..36d2c4ece 100644 --- a/phy/traces/tests/test_spike_detect.py +++ b/phy/traces/tests/test_spike_detect.py @@ -33,7 +33,8 @@ def test_detect(): sd = SpikeDetector() sd.use_single_threshold = False - sd.set_metadata(probe, channel_mapping=channel_mapping, + sd.set_metadata(probe, + channel_mapping=channel_mapping, sample_rate=sample_rate) spike_samples, masks = sd.detect(traces) From 26bcb4014a68b828daecd50f026d4d999324f169 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 29 Sep 2015 14:00:38 +0200 Subject: [PATCH 0133/1059] Minor updates in spike detect --- phy/traces/spike_detect.py | 20 +++++++++----------- phy/traces/tests/test_spike_detect.py | 2 +- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/phy/traces/spike_detect.py b/phy/traces/spike_detect.py index 8cf6aecda..abc8525bb 100644 --- a/phy/traces/spike_detect.py +++ b/phy/traces/spike_detect.py @@ -48,8 +48,8 @@ def __init__(self, ctx=None): return self.find_thresholds = ctx.cache(self.find_thresholds) self.filter = ctx.cache(self.filter) - self.extract_components = ctx.cache(self.extract_components) self.extract_spikes = ctx.cache(self.extract_spikes) + self.detect = ctx.cache(self.detect) def set_metadata(self, probe, channel_mapping=None, sample_rate=None, thresholds=None): @@ -121,15 +121,18 @@ def filter(self, traces): ) return f(traces).astype(np.float32) - def extract_spikes(self, filtered, thresholds=None): + def extract_spikes(self, traces_subset, thresholds=None): if thresholds is None: thresholds = self._thresholds assert thresholds is not None self._thresholder = Thresholder(mode=self.detect_spikes, thresholds=thresholds) + # Filter the traces. + traces_f = self.filter(traces_subset) + # Transform the filtered data according to the detection mode. - traces_t = self._thresholder.transform(filtered) + traces_t = self._thresholder.transform(traces_f) # Compute the threshold crossings. weak = self._thresholder.detect(traces_t, 'weak') @@ -149,7 +152,7 @@ def extract_spikes(self, filtered, thresholds=None): thresholds=self._thresholds, ) - s, m, w = zip(*(extractor(component, data=filtered, data_t=traces_t) + s, m, w = zip(*(extractor(component, data=traces_f, data_t=traces_t) for component in components)) s = np.array(s, dtype=np.int64) m = np.array(m, dtype=np.float32) @@ -157,19 +160,14 @@ def extract_spikes(self, filtered, thresholds=None): return s, m, w def detect(self, traces): - assert traces.ndim == 2 # Only keep the selected channels (given shank, no dead channels, etc.) traces = self.subset_traces(traces) + assert traces.ndim == 2 assert traces.shape[1] == self.n_channels # Find the thresholds. self._thresholds = self.find_thresholds(traces) - # Apply the filter. - filtered = self.filter(traces) - # Extract the spikes, masks, waveforms. - spike_samples, masks, waveforms = self.extract_spikes(filtered - ) - return spike_samples, masks + return self.extract_spikes(traces) diff --git a/phy/traces/tests/test_spike_detect.py b/phy/traces/tests/test_spike_detect.py index 36d2c4ece..1400990c2 100644 --- a/phy/traces/tests/test_spike_detect.py +++ b/phy/traces/tests/test_spike_detect.py @@ -36,7 +36,7 @@ def test_detect(): sd.set_metadata(probe, channel_mapping=channel_mapping, sample_rate=sample_rate) - spike_samples, masks = sd.detect(traces) + spike_samples, masks, _ = sd.detect(traces) # from vispy.app import run # from phy.plot import plot_traces From 1603f8367bb2fe77c0513b542f6e916500d92c0b Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 29 Sep 2015 14:03:22 +0200 Subject: [PATCH 0134/1059] Minor updates in spike detect --- phy/traces/spike_detect.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/phy/traces/spike_detect.py b/phy/traces/spike_detect.py index abc8525bb..948fd19a3 100644 --- a/phy/traces/spike_detect.py +++ b/phy/traces/spike_detect.py @@ -42,7 +42,6 @@ class SpikeDetector(Configurable): weight_power = Float(2) def __init__(self, ctx=None): - self._thresholds = None super(SpikeDetector, self).__init__() if not ctx or not hasattr(ctx, 'cache'): return @@ -52,15 +51,13 @@ def __init__(self, ctx=None): self.detect = ctx.cache(self.detect) def set_metadata(self, probe, channel_mapping=None, - sample_rate=None, thresholds=None): + sample_rate=None): assert isinstance(probe, MEA) self.probe = probe assert sample_rate > 0 self.sample_rate = sample_rate - self._thresholds = thresholds - # Channel mapping. if channel_mapping is None: channel_mapping = {c: c for c in probe.channels} @@ -122,8 +119,6 @@ def filter(self, traces): return f(traces).astype(np.float32) def extract_spikes(self, traces_subset, thresholds=None): - if thresholds is None: - thresholds = self._thresholds assert thresholds is not None self._thresholder = Thresholder(mode=self.detect_spikes, thresholds=thresholds) @@ -149,7 +144,7 @@ def extract_spikes(self, traces_subset, thresholds=None): extractor = WaveformExtractor(extract_before=self.extract_s_before, extract_after=self.extract_s_after, weight_power=self.weight_power, - thresholds=self._thresholds, + thresholds=thresholds, ) s, m, w = zip(*(extractor(component, data=traces_f, data_t=traces_t) @@ -159,7 +154,7 @@ def extract_spikes(self, traces_subset, thresholds=None): w = np.array(w, dtype=np.float32) return s, m, w - def detect(self, traces): + def detect(self, traces, thresholds=None): # Only keep the selected channels (given shank, no dead channels, etc.) traces = self.subset_traces(traces) @@ -167,7 +162,8 @@ def detect(self, traces): assert traces.shape[1] == self.n_channels # Find the thresholds. - self._thresholds = self.find_thresholds(traces) + if thresholds is None: + thresholds = self.find_thresholds(traces) # Extract the spikes, masks, waveforms. - return self.extract_spikes(traces) + return self.extract_spikes(traces, thresholds=thresholds) From c157baac9c3dfa19c8a046083c25f26eb587955e Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 29 Sep 2015 14:22:15 +0200 Subject: [PATCH 0135/1059] Clean spike detection tests --- phy/traces/spike_detect.py | 5 +++- phy/traces/tests/test_spike_detect.py | 43 +++++++++++++++++++++------ 2 files changed, 38 insertions(+), 10 deletions(-) diff --git a/phy/traces/spike_detect.py b/phy/traces/spike_detect.py index 948fd19a3..42974416c 100644 --- a/phy/traces/spike_detect.py +++ b/phy/traces/spike_detect.py @@ -26,6 +26,7 @@ #------------------------------------------------------------------------------ class SpikeDetector(Configurable): + do_filter = Bool(True) filter_low = Float(500.) filter_butter_order = Int(3) chunk_size_seconds = Float(1) @@ -107,10 +108,12 @@ def find_thresholds(self, traces): std_factor=std_factor) thresholds = {'weak': thresholds[0], 'strong': thresholds[1]} - logger.info("Thresholds found: {}.".format(thresholds)) + # logger.info("Thresholds found: {}.".format(thresholds)) return thresholds def filter(self, traces): + if not self.do_filter: # pragma: no cover + return traces f = Filter(rate=self.sample_rate, low=self.filter_low, high=0.95 * .5 * self.sample_rate, diff --git a/phy/traces/tests/test_spike_detect.py b/phy/traces/tests/test_spike_detect.py index 1400990c2..34f2c61eb 100644 --- a/phy/traces/tests/test_spike_detect.py +++ b/phy/traces/tests/test_spike_detect.py @@ -7,8 +7,7 @@ #------------------------------------------------------------------------------ import numpy as np -from numpy.testing import assert_array_equal as ae -from numpy.testing import assert_allclose as ac +from pytest import yield_fixture from phy.utils.datasets import download_test_data from phy.electrode import load_probe @@ -17,27 +16,53 @@ #------------------------------------------------------------------------------ -# Test spike detection +# Fixtures #------------------------------------------------------------------------------ -def test_detect(): - +@yield_fixture +def traces(): path = download_test_data('test-32ch-10s.dat') traces = np.fromfile(path, dtype=np.int16).reshape((200000, 32)) traces = traces[:20000] - n_samples, n_channels = traces.shape - sample_rate = 20000 + + yield traces + + +@yield_fixture(params=[(True,), (False,)]) +def spike_detector(request): + remap = request.param[0] + probe = load_probe('1x32_buzsaki') - channel_mapping = {i: i for i in range(1, 21, 2)} - # channel_mapping = None + channel_mapping = {i: i for i in range(1, 21, 2)} if remap else None sd = SpikeDetector() sd.use_single_threshold = False + sample_rate = 20000 sd.set_metadata(probe, channel_mapping=channel_mapping, sample_rate=sample_rate) + + yield sd + + +#------------------------------------------------------------------------------ +# Test spike detection +#------------------------------------------------------------------------------ + +def test_detect(spike_detector, traces): + sd = spike_detector spike_samples, masks, _ = sd.detect(traces) + n_channels = sd.n_channels + n_spikes = len(spike_samples) + + assert spike_samples.dtype == np.int64 + assert spike_samples.ndim == 1 + + assert masks.dtype == np.float32 + assert masks.ndim == 2 + assert masks.shape == (n_spikes, n_channels) + # from vispy.app import run # from phy.plot import plot_traces # plot_traces(sd.subset_traces(traces), From 31ded0b24ab84979bd727ec4700bfd2117cdda82 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 29 Sep 2015 17:19:25 +0200 Subject: [PATCH 0136/1059] WIP: add Context --- phy/utils/context.py | 106 ++++++++++++++++++++++++++++++++ phy/utils/tests/test_context.py | 51 +++++++++++++++ 2 files changed, 157 insertions(+) create mode 100644 phy/utils/context.py create mode 100644 phy/utils/tests/test_context.py diff --git a/phy/utils/context.py b/phy/utils/context.py new file mode 100644 index 000000000..d077b6501 --- /dev/null +++ b/phy/utils/context.py @@ -0,0 +1,106 @@ +# -*- coding: utf-8 -*- + +"""Execution context that handles parallel processing and cacheing.""" + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +import logging +import os +import os.path as op + +import numpy as np + +logger = logging.getLogger(__name__) + + +#------------------------------------------------------------------------------ +# Utility functions +#------------------------------------------------------------------------------ + +def _iter_chunks_dask(da): + from dask.core import flatten + for i, chunk in enumerate(flatten(da._keys())): + yield i, chunk + + +def read_array(path): + return np.load(path) + + +def write_array(path, arr): + np.save(path, arr) + + +#------------------------------------------------------------------------------ +# Context +#------------------------------------------------------------------------------ + +class Context(object): + def __init__(self, cache_dir): + self.cache_dir = op.realpath(cache_dir) + if not op.exists(self.cache_dir): + os.makedirs(self.cache_dir) + try: + from joblib import Memory + joblib_cachedir = self._path('joblib') + self._memory = Memory(cachedir=joblib_cachedir, verbose=0) + except ImportError: # pragma: no cover + logger.warn("Joblib is not installed. " + "Install it with `conda install joblib`.") + self._memory = None + + def _path(self, rel_path, *args, **kwargs): + return op.join(self.cache_dir, rel_path.format(*args, **kwargs)) + + def cache(self, f): + if self._memory is None: # pragma: no cover + logger.debug("Joblib is not installed, so skipping cacheing.") + return + return self._memory.cache(f) + + def map_dask_array(self, f, da, chunks=None, name=None, + dtype=None, shape=None): + try: + from dask.array import Array + from dask.async import get_sync as get + except ImportError: # pragma: no cover + logger.warn("dask is not installed. " + "Install it with `conda install dask`.") + return + + assert isinstance(da, Array) + + name = name or f.__name__ + assert name != da.name + dtype = dtype or da.dtype + shape = shape or da.shape + chunks = chunks or da.chunks + dask = da.dask + + def wrapped(chk): + (i, chunk) = chk + # # Load the array's chunk. + # arg = chunk[0](*chunk[1:]) + arr = get(dask, chunk) + # Execute the function on the chunk. + res = f(arr) + # Save the output in the cache. + if not op.exists(self._path(name)): + os.makedirs(self._path(name)) + path = self._path('{name:s}/{i:d}.npy', name=name, i=i) + write_array(path, res) + + # Return a dask pair to load the result. + return (read_array, path) + + # Map the wrapped function normally. + mapped = self.map(wrapped, _iter_chunks_dask(da)) + + # Return the result as a dask array. + dask = {(name, i): chunk for i, chunk in enumerate(mapped)} + return Array(dask, name, chunks, dtype=dtype, shape=shape) + + def map(self, f, args): + return [f(arg) for arg in args] diff --git a/phy/utils/tests/test_context.py b/phy/utils/tests/test_context.py new file mode 100644 index 000000000..a401a0fe4 --- /dev/null +++ b/phy/utils/tests/test_context.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- + +"""Test context.""" + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +import numpy as np +from numpy.testing import assert_array_equal as ae +from pytest import yield_fixture + +from ..context import Context, _iter_chunks_dask + + +#------------------------------------------------------------------------------ +# Test context +#------------------------------------------------------------------------------ + +@yield_fixture +def context(tempdir): + ctx = Context('{}/cache/'.format(tempdir)) + yield ctx + + +def test_iter_chunks_dask(): + from dask.array import from_array + + x = np.arange(10) + da = from_array(x, chunks=(3,)) + assert len(list(_iter_chunks_dask(da))) == 4 + + +def test_context_map(context): + def f(x): + return x * x + + args = range(10) + assert context.map(f, args) == [x * x for x in range(10)] + + +def test_context_dask(context): + from dask.array import from_array + + def square(x): + return x * x + + x = np.arange(10) + da = from_array(x, chunks=(3,)) + res = context.map_dask_array(square, da) + ae(res.compute(), x ** 2) From 9f191c6446d4da2b80bc5950fad698a77621774f Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 29 Sep 2015 17:20:34 +0200 Subject: [PATCH 0137/1059] WIP --- phy/utils/context.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/phy/utils/context.py b/phy/utils/context.py index d077b6501..1e973d27c 100644 --- a/phy/utils/context.py +++ b/phy/utils/context.py @@ -56,7 +56,7 @@ def _path(self, rel_path, *args, **kwargs): def cache(self, f): if self._memory is None: # pragma: no cover - logger.debug("Joblib is not installed, so skipping cacheing.") + logger.debug("Joblib is not installed: skipping cacheing.") return return self._memory.cache(f) @@ -66,9 +66,8 @@ def map_dask_array(self, f, da, chunks=None, name=None, from dask.array import Array from dask.async import get_sync as get except ImportError: # pragma: no cover - logger.warn("dask is not installed. " - "Install it with `conda install dask`.") - return + raise Exception("dask is not installed. " + "Install it with `conda install dask`.") assert isinstance(da, Array) @@ -81,11 +80,12 @@ def map_dask_array(self, f, da, chunks=None, name=None, def wrapped(chk): (i, chunk) = chk - # # Load the array's chunk. - # arg = chunk[0](*chunk[1:]) + # Load the array's chunk. arr = get(dask, chunk) + # Execute the function on the chunk. res = f(arr) + # Save the output in the cache. if not op.exists(self._path(name)): os.makedirs(self._path(name)) From 1812471c236830ded2a40a806fe7e3e9c12ec026 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 29 Sep 2015 17:30:33 +0200 Subject: [PATCH 0138/1059] WIP: ipyparallel client in context --- phy/utils/context.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/phy/utils/context.py b/phy/utils/context.py index 1e973d27c..47de4c05b 100644 --- a/phy/utils/context.py +++ b/phy/utils/context.py @@ -38,7 +38,7 @@ def write_array(path, arr): #------------------------------------------------------------------------------ class Context(object): - def __init__(self, cache_dir): + def __init__(self, cache_dir, ipy_client=None): self.cache_dir = op.realpath(cache_dir) if not op.exists(self.cache_dir): os.makedirs(self.cache_dir) @@ -50,6 +50,7 @@ def __init__(self, cache_dir): logger.warn("Joblib is not installed. " "Install it with `conda install joblib`.") self._memory = None + self._ipy_client = ipy_client def _path(self, rel_path, *args, **kwargs): return op.join(self.cache_dir, rel_path.format(*args, **kwargs)) @@ -102,5 +103,14 @@ def wrapped(chk): dask = {(name, i): chunk for i, chunk in enumerate(mapped)} return Array(dask, name, chunks, dtype=dtype, shape=shape) - def map(self, f, args): + def _map_serial(self, f, args): return [f(arg) for arg in args] + + def _map_ipy(self, f, args): + return self._ipy_client.map(f, args) + + def map(self, f, args): + if self._ipy_client: + return self._map_ipy(f, args) + else: + return self._map_serial(f, args) From 1d263cb307c741d44b433a70f5ac7bcb669ae054 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 29 Sep 2015 21:03:19 +0200 Subject: [PATCH 0139/1059] WIP: ipyparallel fixtures --- phy/utils/tests/test_context.py | 39 +++++++++++++++++++++++++++++++-- 1 file changed, 37 insertions(+), 2 deletions(-) diff --git a/phy/utils/tests/test_context.py b/phy/utils/tests/test_context.py index a401a0fe4..26f4aa8d7 100644 --- a/phy/utils/tests/test_context.py +++ b/phy/utils/tests/test_context.py @@ -6,15 +6,18 @@ # Imports #------------------------------------------------------------------------------ +import os + import numpy as np from numpy.testing import assert_array_equal as ae -from pytest import yield_fixture +from pytest import fixture, yield_fixture, mark +from ipyparallel.tests.clienttest import ClusterTestCase, add_engines from ..context import Context, _iter_chunks_dask #------------------------------------------------------------------------------ -# Test context +# Fixtures #------------------------------------------------------------------------------ @yield_fixture @@ -23,6 +26,10 @@ def context(tempdir): yield ctx +#------------------------------------------------------------------------------ +# Test context +#------------------------------------------------------------------------------ + def test_iter_chunks_dask(): from dask.array import from_array @@ -49,3 +56,31 @@ def square(x): da = from_array(x, chunks=(3,)) res = context.map_dask_array(square, da) ae(res.compute(), x ** 2) + + +#------------------------------------------------------------------------------ +# ipyparallel tests +#------------------------------------------------------------------------------ + +@yield_fixture(scope='module') +def ipy_client(): + + def iptest_stdstreams_fileno(): + return os.open(os.devnull, os.O_WRONLY) + + # OMG-THIS-IS-UGLY-HACK: monkey-patch this global object to avoid + # using the nose iptest extension (we're using pytest). + # See https://github.com/ipython/ipython/blob/master/IPython/testing/iptest.py#L317-L319 # noqa + from ipyparallel import Client + import ipyparallel.tests + ipyparallel.tests.nose.iptest_stdstreams_fileno = iptest_stdstreams_fileno + + # Start two engines engine (one is launched by setup()). + ipyparallel.tests.setup() + ipyparallel.tests.add_engines(1) + yield Client(profile='iptest') + ipyparallel.tests.teardown() + + +def test_client(ipy_client): + print(ipy_client.ids) From ad8cb84e510dd3c56dd57a7e6d6100274b5c4b3e Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 29 Sep 2015 21:09:26 +0200 Subject: [PATCH 0140/1059] Done ipy_client fixtures --- phy/utils/tests/test_context.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/phy/utils/tests/test_context.py b/phy/utils/tests/test_context.py index 26f4aa8d7..e10ee7503 100644 --- a/phy/utils/tests/test_context.py +++ b/phy/utils/tests/test_context.py @@ -10,8 +10,7 @@ import numpy as np from numpy.testing import assert_array_equal as ae -from pytest import fixture, yield_fixture, mark -from ipyparallel.tests.clienttest import ClusterTestCase, add_engines +from pytest import yield_fixture from ..context import Context, _iter_chunks_dask @@ -82,5 +81,9 @@ def iptest_stdstreams_fileno(): ipyparallel.tests.teardown() -def test_client(ipy_client): - print(ipy_client.ids) +def test_client_1(ipy_client): + assert ipy_client.ids == [0, 1] + + +def test_client_2(ipy_client): + assert ipy_client[:].map_sync(lambda x: x * x, [1, 2, 3]) == [1, 4, 9] From d65e93d75ee1755f49e7b2bfe768181ac426ebf2 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 29 Sep 2015 21:43:32 +0200 Subject: [PATCH 0141/1059] Test context parallel map --- phy/utils/context.py | 29 ++++++++------ phy/utils/tests/test_context.py | 68 +++++++++++++++++++-------------- 2 files changed, 57 insertions(+), 40 deletions(-) diff --git a/phy/utils/context.py b/phy/utils/context.py index 47de4c05b..988f7678c 100644 --- a/phy/utils/context.py +++ b/phy/utils/context.py @@ -38,7 +38,7 @@ def write_array(path, arr): #------------------------------------------------------------------------------ class Context(object): - def __init__(self, cache_dir, ipy_client=None): + def __init__(self, cache_dir, ipy_view=None): self.cache_dir = op.realpath(cache_dir) if not op.exists(self.cache_dir): os.makedirs(self.cache_dir) @@ -50,7 +50,15 @@ def __init__(self, cache_dir, ipy_client=None): logger.warn("Joblib is not installed. " "Install it with `conda install joblib`.") self._memory = None - self._ipy_client = ipy_client + self._ipy_view = ipy_view + + @property + def ipy_view(self): + return self._ipy_view + + @ipy_view.setter + def ipy_view(self, value): + self._ipy_view = value def _path(self, rel_path, *args, **kwargs): return op.join(self.cache_dir, rel_path.format(*args, **kwargs)) @@ -103,14 +111,11 @@ def wrapped(chk): dask = {(name, i): chunk for i, chunk in enumerate(mapped)} return Array(dask, name, chunks, dtype=dtype, shape=shape) - def _map_serial(self, f, args): - return [f(arg) for arg in args] - - def _map_ipy(self, f, args): - return self._ipy_client.map(f, args) - - def map(self, f, args): - if self._ipy_client: - return self._map_ipy(f, args) + def map(self, f, args, sync=True): + if self._ipy_view: + if sync: + return self._ipy_view.map_sync(f, args) + else: + return self._ipy_view.map_async(f, args) else: - return self._map_serial(f, args) + return [f(arg) for arg in args] diff --git a/phy/utils/tests/test_context.py b/phy/utils/tests/test_context.py index e10ee7503..b35630ebb 100644 --- a/phy/utils/tests/test_context.py +++ b/phy/utils/tests/test_context.py @@ -25,6 +25,38 @@ def context(tempdir): yield ctx +@yield_fixture(scope='module') +def ipy_client(): + + def iptest_stdstreams_fileno(): + return os.open(os.devnull, os.O_WRONLY) + + # OMG-THIS-IS-UGLY-HACK: monkey-patch this global object to avoid + # using the nose iptest extension (we're using pytest). + # See https://github.com/ipython/ipython/blob/master/IPython/testing/iptest.py#L317-L319 # noqa + from ipyparallel import Client + import ipyparallel.tests + ipyparallel.tests.nose.iptest_stdstreams_fileno = iptest_stdstreams_fileno + + # Start two engines engine (one is launched by setup()). + ipyparallel.tests.setup() + ipyparallel.tests.add_engines(1) + yield Client(profile='iptest') + ipyparallel.tests.teardown() + + +#------------------------------------------------------------------------------ +# ipyparallel tests +#------------------------------------------------------------------------------ + +def test_client_1(ipy_client): + assert ipy_client.ids == [0, 1] + + +def test_client_2(ipy_client): + assert ipy_client[:].map_sync(lambda x: x * x, [1, 2, 3]) == [1, 4, 9] + + #------------------------------------------------------------------------------ # Test context #------------------------------------------------------------------------------ @@ -57,33 +89,13 @@ def square(x): ae(res.compute(), x ** 2) -#------------------------------------------------------------------------------ -# ipyparallel tests -#------------------------------------------------------------------------------ - -@yield_fixture(scope='module') -def ipy_client(): - - def iptest_stdstreams_fileno(): - return os.open(os.devnull, os.O_WRONLY) - - # OMG-THIS-IS-UGLY-HACK: monkey-patch this global object to avoid - # using the nose iptest extension (we're using pytest). - # See https://github.com/ipython/ipython/blob/master/IPython/testing/iptest.py#L317-L319 # noqa - from ipyparallel import Client - import ipyparallel.tests - ipyparallel.tests.nose.iptest_stdstreams_fileno = iptest_stdstreams_fileno - - # Start two engines engine (one is launched by setup()). - ipyparallel.tests.setup() - ipyparallel.tests.add_engines(1) - yield Client(profile='iptest') - ipyparallel.tests.teardown() - - -def test_client_1(ipy_client): - assert ipy_client.ids == [0, 1] +def test_context_parallel_map(context, ipy_client): + view = ipy_client[:] + context.ipy_view = view + assert context.ipy_view == view + def square(x): + return x * x -def test_client_2(ipy_client): - assert ipy_client[:].map_sync(lambda x: x * x, [1, 2, 3]) == [1, 4, 9] + assert context.map(square, [1, 2, 3]) == [1, 4, 9] + assert context.map(square, [1, 2, 3], sync=False).get() == [1, 4, 9] From 60cc31f0f99e507f1b54c0d716e931b020af13d6 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 29 Sep 2015 22:05:44 +0200 Subject: [PATCH 0142/1059] WIP: test map parallel on dask array --- phy/utils/context.py | 6 +++++- phy/utils/tests/test_context.py | 16 ++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/phy/utils/context.py b/phy/utils/context.py index 988f7678c..b7abdf8e7 100644 --- a/phy/utils/context.py +++ b/phy/utils/context.py @@ -50,7 +50,9 @@ def __init__(self, cache_dir, ipy_view=None): logger.warn("Joblib is not installed. " "Install it with `conda install joblib`.") self._memory = None - self._ipy_view = ipy_view + self._ipy_view = None + if ipy_view: + self.ipy_view = ipy_view @property def ipy_view(self): @@ -58,6 +60,8 @@ def ipy_view(self): @ipy_view.setter def ipy_view(self, value): + # Dill is necessary because we need to serialize closures. + value.use_dill() self._ipy_view = value def _path(self, rel_path, *args, **kwargs): diff --git a/phy/utils/tests/test_context.py b/phy/utils/tests/test_context.py index b35630ebb..0ebbaa6df 100644 --- a/phy/utils/tests/test_context.py +++ b/phy/utils/tests/test_context.py @@ -99,3 +99,19 @@ def square(x): assert context.map(square, [1, 2, 3]) == [1, 4, 9] assert context.map(square, [1, 2, 3], sync=False).get() == [1, 4, 9] + + +def test_context_parallel_dask(context, ipy_client): + from dask.array import from_array + + context.ipy_view = ipy_client[:] + + def square(x): + import os + print(os.getpid()) + return x * x + + x = np.arange(10) + da = from_array(x, chunks=(3,)) + res = context.map_dask_array(square, da) + ae(res.compute(), x ** 2) From d36d7d86077409e016127ef851a1b8e0a489770c Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 30 Sep 2015 10:50:15 +0200 Subject: [PATCH 0143/1059] WIP: fixing bugs in parallel context --- phy/utils/context.py | 81 +++++++++++++++++++++------------ phy/utils/tests/test_context.py | 2 +- 2 files changed, 52 insertions(+), 31 deletions(-) diff --git a/phy/utils/context.py b/phy/utils/context.py index b7abdf8e7..7a2669694 100644 --- a/phy/utils/context.py +++ b/phy/utils/context.py @@ -11,6 +11,11 @@ import os.path as op import numpy as np +try: + from dask.async import get_sync as get +except ImportError: # pragma: no cover + raise Exception("dask is not installed. " + "Install it with `conda install dask`.") logger = logging.getLogger(__name__) @@ -21,8 +26,8 @@ def _iter_chunks_dask(da): from dask.core import flatten - for i, chunk in enumerate(flatten(da._keys())): - yield i, chunk + for chunk in flatten(da._keys()): + yield chunk def read_array(path): @@ -37,6 +42,24 @@ def write_array(path, arr): # Context #------------------------------------------------------------------------------ +def _mapped(i, chunk, dask, func, cachedir, name): + # Load the array's chunk. + arr = get(dask, chunk) + + # Execute the function on the chunk. + res = func(arr) + + # Save the output in the cache. + dirpath = op.join(cachedir, name) + if not op.exists(dirpath): + os.makedirs(dirpath) + path = op.join(dirpath, '{}.npy'.format(i)) + write_array(path, res) + + # Return a dask pair to load the result. + return (read_array, path) + + class Context(object): def __init__(self, cache_dir, ipy_view=None): self.cache_dir = op.realpath(cache_dir) @@ -73,53 +96,51 @@ def cache(self, f): return return self._memory.cache(f) - def map_dask_array(self, f, da, chunks=None, name=None, + def map_dask_array(self, func, da, chunks=None, name=None, dtype=None, shape=None): try: from dask.array import Array - from dask.async import get_sync as get except ImportError: # pragma: no cover raise Exception("dask is not installed. " "Install it with `conda install dask`.") assert isinstance(da, Array) - name = name or f.__name__ + name = name or func.__name__ assert name != da.name dtype = dtype or da.dtype shape = shape or da.shape chunks = chunks or da.chunks dask = da.dask - def wrapped(chk): - (i, chunk) = chk - # Load the array's chunk. - arr = get(dask, chunk) - - # Execute the function on the chunk. - res = f(arr) - - # Save the output in the cache. - if not op.exists(self._path(name)): - os.makedirs(self._path(name)) - path = self._path('{name:s}/{i:d}.npy', name=name, i=i) - write_array(path, res) - - # Return a dask pair to load the result. - return (read_array, path) - - # Map the wrapped function normally. - mapped = self.map(wrapped, _iter_chunks_dask(da)) + cachedir = self.cache_dir + args_0 = list(_iter_chunks_dask(da)) + n = len(args_0) + mapped = self.map(_mapped, range(n), args_0, [dask] * n, + [func] * n, [cachedir] * n, [name] * n) # Return the result as a dask array. dask = {(name, i): chunk for i, chunk in enumerate(mapped)} return Array(dask, name, chunks, dtype=dtype, shape=shape) - def map(self, f, args, sync=True): + def _map_serial(self, f, *args): + return [f(*arg) for arg in zip(*args)] + + def _map_ipy(self, f, *args, **kwargs): + if kwargs.get('sync', True): + name = 'map_sync' + else: + name = 'map_async' + return getattr(self._ipy_view, name)(f, *args) + + def map_async(self, f, *args): + if self._ipy_view: + return self._map_ipy(f, *args, sync=False) + else: + return self._map_serial(f, *args) + + def map(self, f, *args): if self._ipy_view: - if sync: - return self._ipy_view.map_sync(f, args) - else: - return self._ipy_view.map_async(f, args) + return self._map_ipy(f, *args, sync=True) else: - return [f(arg) for arg in args] + return self._map_serial(f, *args) diff --git a/phy/utils/tests/test_context.py b/phy/utils/tests/test_context.py index 0ebbaa6df..cf5d15a5f 100644 --- a/phy/utils/tests/test_context.py +++ b/phy/utils/tests/test_context.py @@ -98,7 +98,7 @@ def square(x): return x * x assert context.map(square, [1, 2, 3]) == [1, 4, 9] - assert context.map(square, [1, 2, 3], sync=False).get() == [1, 4, 9] + assert context.map_async(square, [1, 2, 3]).get() == [1, 4, 9] def test_context_parallel_dask(context, ipy_client): From d753f524a861558745176ac9a39260505528968c Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 30 Sep 2015 10:53:26 +0200 Subject: [PATCH 0144/1059] Increase coverage in phy.utils.context --- phy/utils/context.py | 9 ++++----- phy/utils/tests/test_context.py | 1 + 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/phy/utils/context.py b/phy/utils/context.py index 7a2669694..f7ccc40f5 100644 --- a/phy/utils/context.py +++ b/phy/utils/context.py @@ -73,9 +73,7 @@ def __init__(self, cache_dir, ipy_view=None): logger.warn("Joblib is not installed. " "Install it with `conda install joblib`.") self._memory = None - self._ipy_view = None - if ipy_view: - self.ipy_view = ipy_view + self.ipy_view = ipy_view if ipy_view else None @property def ipy_view(self): @@ -83,9 +81,10 @@ def ipy_view(self): @ipy_view.setter def ipy_view(self, value): - # Dill is necessary because we need to serialize closures. - value.use_dill() self._ipy_view = value + if hasattr(value, 'use_dill'): + # Dill is necessary because we need to serialize closures. + value.use_dill() def _path(self, rel_path, *args, **kwargs): return op.join(self.cache_dir, rel_path.format(*args, **kwargs)) diff --git a/phy/utils/tests/test_context.py b/phy/utils/tests/test_context.py index cf5d15a5f..f5a095150 100644 --- a/phy/utils/tests/test_context.py +++ b/phy/utils/tests/test_context.py @@ -75,6 +75,7 @@ def f(x): args = range(10) assert context.map(f, args) == [x * x for x in range(10)] + assert context.map_async(f, args) == [x * x for x in range(10)] def test_context_dask(context): From e0e21f72b84ce5051b1baf21fb935b76248bb81b Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 30 Sep 2015 11:01:55 +0200 Subject: [PATCH 0145/1059] Add context cache test --- phy/utils/tests/test_context.py | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/phy/utils/tests/test_context.py b/phy/utils/tests/test_context.py index f5a095150..d483bf6e3 100644 --- a/phy/utils/tests/test_context.py +++ b/phy/utils/tests/test_context.py @@ -19,7 +19,7 @@ # Fixtures #------------------------------------------------------------------------------ -@yield_fixture +@yield_fixture(scope='function') def context(tempdir): ctx = Context('{}/cache/'.format(tempdir)) yield ctx @@ -69,6 +69,31 @@ def test_iter_chunks_dask(): assert len(list(_iter_chunks_dask(da))) == 4 +def test_context_cache(context): + + _res = [] + + def f(x): + _res.append(x) + return x ** 2 + + x = np.arange(5) + x2 = x * x + + ae(f(x), x2) + assert len(_res) == 1 + + f = context.cache(f) + + # Run it a first time. + ae(f(x), x2) + assert len(_res) == 2 + + # The second time, the cache is used. + ae(f(x), x2) + assert len(_res) == 2 + + def test_context_map(context): def f(x): return x * x From a32f33d48b2ffd6e367c8465b1048eb2585bfe65 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 30 Sep 2015 11:10:53 +0200 Subject: [PATCH 0146/1059] More context tests --- phy/utils/context.py | 11 +++++++---- phy/utils/tests/test_context.py | 32 ++++++++++++++------------------ 2 files changed, 21 insertions(+), 22 deletions(-) diff --git a/phy/utils/context.py b/phy/utils/context.py index f7ccc40f5..577a10212 100644 --- a/phy/utils/context.py +++ b/phy/utils/context.py @@ -11,6 +11,7 @@ import os.path as op import numpy as np +from six.moves.cPickle import dump try: from dask.async import get_sync as get except ImportError: # pragma: no cover @@ -42,7 +43,7 @@ def write_array(path, arr): # Context #------------------------------------------------------------------------------ -def _mapped(i, chunk, dask, func, cachedir, name): +def _mapped(i, chunk, dask, func, dirpath): # Load the array's chunk. arr = get(dask, chunk) @@ -50,7 +51,6 @@ def _mapped(i, chunk, dask, func, cachedir, name): res = func(arr) # Save the output in the cache. - dirpath = op.join(cachedir, name) if not op.exists(dirpath): os.makedirs(dirpath) path = op.join(dirpath, '{}.npy'.format(i)) @@ -112,11 +112,14 @@ def map_dask_array(self, func, da, chunks=None, name=None, chunks = chunks or da.chunks dask = da.dask - cachedir = self.cache_dir args_0 = list(_iter_chunks_dask(da)) n = len(args_0) + dirpath = op.join(self.cache_dir, name) mapped = self.map(_mapped, range(n), args_0, [dask] * n, - [func] * n, [cachedir] * n, [name] * n) + [func] * n, [dirpath] * n) + + with open(op.join(dirpath, 'info'), 'wb') as f: + dump({'chunks': chunks, 'dtype': dtype, 'axis': 0}, f) # Return the result as a dask array. dask = {(name, i): chunk for i, chunk in enumerate(mapped)} diff --git a/phy/utils/tests/test_context.py b/phy/utils/tests/test_context.py index d483bf6e3..ea43a6167 100644 --- a/phy/utils/tests/test_context.py +++ b/phy/utils/tests/test_context.py @@ -7,10 +7,11 @@ #------------------------------------------------------------------------------ import os +import os.path as op import numpy as np from numpy.testing import assert_array_equal as ae -from pytest import yield_fixture +from pytest import yield_fixture, mark from ..context import Context, _iter_chunks_dask @@ -103,18 +104,6 @@ def f(x): assert context.map_async(f, args) == [x * x for x in range(10)] -def test_context_dask(context): - from dask.array import from_array - - def square(x): - return x * x - - x = np.arange(10) - da = from_array(x, chunks=(3,)) - res = context.map_dask_array(square, da) - ae(res.compute(), x ** 2) - - def test_context_parallel_map(context, ipy_client): view = ipy_client[:] context.ipy_view = view @@ -127,17 +116,24 @@ def square(x): assert context.map_async(square, [1, 2, 3]).get() == [1, 4, 9] -def test_context_parallel_dask(context, ipy_client): - from dask.array import from_array +@mark.parametrize('is_parallel', [True, False]) +def test_context_dask(context, ipy_client, is_parallel): + from dask.array import from_array, from_npy_stack - context.ipy_view = ipy_client[:] + if is_parallel: + context.ipy_view = ipy_client[:] def square(x): - import os - print(os.getpid()) return x * x x = np.arange(10) da = from_array(x, chunks=(3,)) res = context.map_dask_array(square, da) ae(res.compute(), x ** 2) + + # Check that we can load the dumped dask array from disk. + # The location is in the context cache dir, in a subdirectory with the + # name of the function by default. + path = op.join(context.cache_dir, 'square') + y = from_npy_stack(path) + ae(y.compute(), x ** 2) From 28941be9c78bdf716591a21af1d39477a2ddde31 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 30 Sep 2015 11:25:05 +0200 Subject: [PATCH 0147/1059] Add documentation in py.utils.context --- phy/utils/context.py | 39 +++++++++++++++++++++++++++++++++ phy/utils/tests/test_context.py | 8 ++++++- 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/phy/utils/context.py b/phy/utils/context.py index 577a10212..f7a45f88b 100644 --- a/phy/utils/context.py +++ b/phy/utils/context.py @@ -32,10 +32,12 @@ def _iter_chunks_dask(da): def read_array(path): + """Read a .npy array.""" return np.load(path) def write_array(path, arr): + """Write an array to a .npy file.""" np.save(path, arr) @@ -44,6 +46,11 @@ def write_array(path, arr): #------------------------------------------------------------------------------ def _mapped(i, chunk, dask, func, dirpath): + """Top-level function to map. + + This function needs to be a top-level function for ipyparallel to work. + + """ # Load the array's chunk. arr = get(dask, chunk) @@ -61,10 +68,15 @@ def _mapped(i, chunk, dask, func, dirpath): class Context(object): + """Handle function cacheing and parallel map with ipyparallel.""" def __init__(self, cache_dir, ipy_view=None): + + # Make sure the cache directory exists. self.cache_dir = op.realpath(cache_dir) if not op.exists(self.cache_dir): os.makedirs(self.cache_dir) + + # Try importing joblib. try: from joblib import Memory joblib_cachedir = self._path('joblib') @@ -73,10 +85,12 @@ def __init__(self, cache_dir, ipy_view=None): logger.warn("Joblib is not installed. " "Install it with `conda install joblib`.") self._memory = None + self.ipy_view = ipy_view if ipy_view else None @property def ipy_view(self): + """ipyparallel view to parallel computing resources.""" return self._ipy_view @ipy_view.setter @@ -87,9 +101,11 @@ def ipy_view(self, value): value.use_dill() def _path(self, rel_path, *args, **kwargs): + """Get the full path to a local cache resource.""" return op.join(self.cache_dir, rel_path.format(*args, **kwargs)) def cache(self, f): + """Cache a function using the context's cache directory.""" if self._memory is None: # pragma: no cover logger.debug("Joblib is not installed: skipping cacheing.") return @@ -97,6 +113,19 @@ def cache(self, f): def map_dask_array(self, func, da, chunks=None, name=None, dtype=None, shape=None): + """Map a function on the chunks of a dask array, and return a + new dask array. + + This function works in parallel if an `ipy_view` has been set. + + Every task loads one chunk, applies the function, and saves the + result into a `.npy` file in a cache subdirectory with the specified + name (the function's name by default). The result is a new dask array + that reads data from the npy stack in the cache subdirectory. + + The metadata of the output dask array need to be specified. + + """ try: from dask.array import Array except ImportError: # pragma: no cover @@ -136,12 +165,22 @@ def _map_ipy(self, f, *args, **kwargs): return getattr(self._ipy_view, name)(f, *args) def map_async(self, f, *args): + """Map a function asynchronously. + + Use the ipyparallel resources if available. + + """ if self._ipy_view: return self._map_ipy(f, *args, sync=False) else: return self._map_serial(f, *args) def map(self, f, *args): + """Map a function synchronously. + + Use the ipyparallel resources if available. + + """ if self._ipy_view: return self._map_ipy(f, *args, sync=True) else: diff --git a/phy/utils/tests/test_context.py b/phy/utils/tests/test_context.py index ea43a6167..c334a1fb1 100644 --- a/phy/utils/tests/test_context.py +++ b/phy/utils/tests/test_context.py @@ -13,7 +13,7 @@ from numpy.testing import assert_array_equal as ae from pytest import yield_fixture, mark -from ..context import Context, _iter_chunks_dask +from ..context import Context, _iter_chunks_dask, write_array, read_array #------------------------------------------------------------------------------ @@ -70,6 +70,12 @@ def test_iter_chunks_dask(): assert len(list(_iter_chunks_dask(da))) == 4 +def test_read_write(tempdir): + x = np.arange(10) + write_array(op.join(tempdir, 'test.npy'), x) + ae(read_array(op.join(tempdir, 'test.npy')), x) + + def test_context_cache(context): _res = [] From 5f076af0d1fdb2011814511c90a61b7d04c72d60 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 30 Sep 2015 11:33:05 +0200 Subject: [PATCH 0148/1059] WIP --- phy/traces/tests/test_spike_detect.py | 18 +++++++++++------- phy/utils/context.py | 2 +- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/phy/traces/tests/test_spike_detect.py b/phy/traces/tests/test_spike_detect.py index 34f2c61eb..6152b9abd 100644 --- a/phy/traces/tests/test_spike_detect.py +++ b/phy/traces/tests/test_spike_detect.py @@ -49,6 +49,16 @@ def spike_detector(request): # Test spike detection #------------------------------------------------------------------------------ +def _plot(sd, traces, spike_samples, masks): + from vispy.app import run + from phy.plot import plot_traces + plot_traces(sd.subset_traces(traces), + spike_samples=spike_samples, + masks=masks, + n_samples_per_spike=40) + run() + + def test_detect(spike_detector, traces): sd = spike_detector spike_samples, masks, _ = sd.detect(traces) @@ -63,10 +73,4 @@ def test_detect(spike_detector, traces): assert masks.ndim == 2 assert masks.shape == (n_spikes, n_channels) - # from vispy.app import run - # from phy.plot import plot_traces - # plot_traces(sd.subset_traces(traces), - # spike_samples=spike_samples, - # masks=masks, - # n_samples_per_spike=40) - # run() + # _plot(sd, traces, spike_samples, masks) diff --git a/phy/utils/context.py b/phy/utils/context.py index f7a45f88b..3d8408f73 100644 --- a/phy/utils/context.py +++ b/phy/utils/context.py @@ -80,7 +80,7 @@ def __init__(self, cache_dir, ipy_view=None): try: from joblib import Memory joblib_cachedir = self._path('joblib') - self._memory = Memory(cachedir=joblib_cachedir, verbose=0) + self._memory = Memory(cachedir=joblib_cachedir) except ImportError: # pragma: no cover logger.warn("Joblib is not installed. " "Install it with `conda install joblib`.") From 5a3febd646d430f0317da4dd3ed2af674f67f0cd Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 30 Sep 2015 12:16:26 +0200 Subject: [PATCH 0149/1059] WIP: refactor context dask --- phy/utils/context.py | 125 ++++++++++++++++++++++++-------- phy/utils/tests/test_context.py | 20 ++--- 2 files changed, 105 insertions(+), 40 deletions(-) diff --git a/phy/utils/context.py b/phy/utils/context.py index 3d8408f73..f9fd31e72 100644 --- a/phy/utils/context.py +++ b/phy/utils/context.py @@ -12,12 +12,17 @@ import numpy as np from six.moves.cPickle import dump +from six import string_types try: + from dask.array import Array from dask.async import get_sync as get + from dask.core import flatten except ImportError: # pragma: no cover raise Exception("dask is not installed. " "Install it with `conda install dask`.") +from phy.utils import Bunch + logger = logging.getLogger(__name__) @@ -26,7 +31,6 @@ #------------------------------------------------------------------------------ def _iter_chunks_dask(da): - from dask.core import flatten for chunk in flatten(da._keys()): yield chunk @@ -38,6 +42,7 @@ def read_array(path): def write_array(path, arr): """Write an array to a .npy file.""" + logger.debug("Write array to %s.", path) np.save(path, arr) @@ -45,7 +50,7 @@ def write_array(path, arr): # Context #------------------------------------------------------------------------------ -def _mapped(i, chunk, dask, func, dirpath): +def _mapped(i, chunk, dask, func, cache_dir, name): """Top-level function to map. This function needs to be a top-level function for ipyparallel to work. @@ -55,16 +60,89 @@ def _mapped(i, chunk, dask, func, dirpath): arr = get(dask, chunk) # Execute the function on the chunk. + logger.debug("Run %s on chunk %d", name, i) res = func(arr) - # Save the output in the cache. - if not op.exists(dirpath): - os.makedirs(dirpath) + # Save the result, and return the information about what we saved. + return _save_stack_chunk(i, res, cache_dir, name) + + +def _save_stack_chunk(i, arr, cache_dir, name): + """Save an output chunk array to a npy file, and return information about + it.""" + # Handle the case where several output arrays are returned. + if isinstance(arr, tuple): + # The name is a tuple of names for the different arrays returned. + assert isinstance(name, tuple) + assert len(arr) == len(name) + + return tuple(_save_stack_chunk(i, arr_, cache_dir, name_) + for arr_, name_ in zip(arr, name)) + + assert isinstance(name, string_types) + assert isinstance(arr, np.ndarray) + + dirpath = op.join(cache_dir, name) path = op.join(dirpath, '{}.npy'.format(i)) - write_array(path, res) + write_array(path, arr) + + # Return information about what we just saved. + return Bunch(dask_tuple=(read_array, path), + shape=arr.shape, + dtype=arr.dtype, + name=name, + dirpath=dirpath, + ) + - # Return a dask pair to load the result. - return (read_array, path) +def _save_stack_info(outputs): + """Save the npy stack info, and return one or several dask arrays from + saved npy stacks. + + The argument is a list of objects returned by `_save_stack_chunk()`. + + """ + # Handle the case where several arrays are returned, i.e. outputs is a list + # of tuples of Bunch objects. + assert len(outputs) + if isinstance(outputs[0], tuple): + return tuple(_save_stack_info(output) for output in zip(*outputs)) + + # Get metadata fields common to all chunks. + assert len(outputs) + assert isinstance(outputs[0], Bunch) + name = outputs[0].name + dirpath = outputs[0].dirpath + dtype = outputs[0].dtype + trail_shape = outputs[0].shape[1:] + + # Ensure the consistency of all chunks metadata. + assert all(output.name == name for output in outputs) + assert all(output.dirpath == dirpath for output in outputs) + assert all(output.dtype == dtype for output in outputs) + assert all(output.shape[1:] == trail_shape for output in outputs) + + # Compute the output dask array chunks and shape. + chunks = (tuple(output.shape[0] for output in outputs),) + n = sum(output.shape[0] for output in outputs) + shape = (n,) + trail_shape + + # Save the info object for dask npy stack. + with open(op.join(dirpath, 'info'), 'wb') as f: + dump({'chunks': chunks, 'dtype': dtype, 'axis': 0}, f) + + # Return the result as a dask array. + dask_tuples = tuple(output.dask_tuple for output in outputs) + dask = {(name, i): chunk for i, chunk in enumerate(dask_tuples)} + return Array(dask, name, chunks, dtype=dtype, shape=shape) + + +def _ensure_cache_dirs_exist(cache_dir, name): + if isinstance(name, tuple): + return [_ensure_cache_dirs_exist(cache_dir, name_) for name_ in name] + dirpath = op.join(cache_dir, name) + if not op.exists(dirpath): + os.makedirs(dirpath) class Context(object): @@ -111,8 +189,7 @@ def cache(self, f): return return self._memory.cache(f) - def map_dask_array(self, func, da, chunks=None, name=None, - dtype=None, shape=None): + def map_dask_array(self, func, da, name=None): """Map a function on the chunks of a dask array, and return a new dask array. @@ -123,36 +200,24 @@ def map_dask_array(self, func, da, chunks=None, name=None, name (the function's name by default). The result is a new dask array that reads data from the npy stack in the cache subdirectory. - The metadata of the output dask array need to be specified. - """ - try: - from dask.array import Array - except ImportError: # pragma: no cover - raise Exception("dask is not installed. " - "Install it with `conda install dask`.") - assert isinstance(da, Array) name = name or func.__name__ assert name != da.name - dtype = dtype or da.dtype - shape = shape or da.shape - chunks = chunks or da.chunks dask = da.dask + # Ensure the directories exist. + _ensure_cache_dirs_exist(self.cache_dir, name) + args_0 = list(_iter_chunks_dask(da)) n = len(args_0) - dirpath = op.join(self.cache_dir, name) - mapped = self.map(_mapped, range(n), args_0, [dask] * n, - [func] * n, [dirpath] * n) - - with open(op.join(dirpath, 'info'), 'wb') as f: - dump({'chunks': chunks, 'dtype': dtype, 'axis': 0}, f) + output = self.map(_mapped, range(n), args_0, [dask] * n, + [func] * n, [self.cache_dir] * n, [name] * n) - # Return the result as a dask array. - dask = {(name, i): chunk for i, chunk in enumerate(mapped)} - return Array(dask, name, chunks, dtype=dtype, shape=shape) + # output contains information about the output arrays. We use this + # information to reconstruct the final dask array. + return _save_stack_info(output) def _map_serial(self, f, *args): return [f(*arg) for arg in zip(*args)] diff --git a/phy/utils/tests/test_context.py b/phy/utils/tests/test_context.py index c334a1fb1..09934ef5e 100644 --- a/phy/utils/tests/test_context.py +++ b/phy/utils/tests/test_context.py @@ -102,12 +102,12 @@ def f(x): def test_context_map(context): - def f(x): - return x * x + def f3(x): + return x * x * x args = range(10) - assert context.map(f, args) == [x * x for x in range(10)] - assert context.map_async(f, args) == [x * x for x in range(10)] + assert context.map(f3, args) == [x ** 3 for x in range(10)] + assert context.map_async(f3, args) == [x ** 3 for x in range(10)] def test_context_parallel_map(context, ipy_client): @@ -129,17 +129,17 @@ def test_context_dask(context, ipy_client, is_parallel): if is_parallel: context.ipy_view = ipy_client[:] - def square(x): - return x * x + def f4(x): + return x * x * x * x x = np.arange(10) da = from_array(x, chunks=(3,)) - res = context.map_dask_array(square, da) - ae(res.compute(), x ** 2) + res = context.map_dask_array(f4, da) + ae(res.compute(), x ** 4) # Check that we can load the dumped dask array from disk. # The location is in the context cache dir, in a subdirectory with the # name of the function by default. - path = op.join(context.cache_dir, 'square') + path = op.join(context.cache_dir, 'f4') y = from_npy_stack(path) - ae(y.compute(), x ** 2) + ae(y.compute(), x ** 4) From d706c6e9843513ceeb16caaa287fcb06ea6da5f4 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 30 Sep 2015 12:24:52 +0200 Subject: [PATCH 0150/1059] Add tests for dask map returning several arrays --- phy/utils/tests/test_context.py | 38 +++++++++++++++++++++++++-------- 1 file changed, 29 insertions(+), 9 deletions(-) diff --git a/phy/utils/tests/test_context.py b/phy/utils/tests/test_context.py index 09934ef5e..081de6f52 100644 --- a/phy/utils/tests/test_context.py +++ b/phy/utils/tests/test_context.py @@ -6,6 +6,7 @@ # Imports #------------------------------------------------------------------------------ +from itertools import product import os import os.path as op @@ -122,24 +123,43 @@ def square(x): assert context.map_async(square, [1, 2, 3]).get() == [1, 4, 9] -@mark.parametrize('is_parallel', [True, False]) -def test_context_dask(context, ipy_client, is_parallel): +@mark.parametrize('is_parallel,multiple_outputs', + product([True, False], repeat=2)) +def test_context_dask(context, ipy_client, is_parallel, multiple_outputs): + from dask.array import from_array, from_npy_stack if is_parallel: context.ipy_view = ipy_client[:] - def f4(x): - return x * x * x * x + if not multiple_outputs: + def f4(x): + return x * x * x * x + name = None + else: + def f4(x): + return x * x * x * x, x + 1 + name = ('f4', 'plus_one') x = np.arange(10) da = from_array(x, chunks=(3,)) - res = context.map_dask_array(f4, da) - ae(res.compute(), x ** 4) + res = context.map_dask_array(f4, da, name=name) + + if not multiple_outputs: + ae(res.compute(), x ** 4) + else: + ae(res[0].compute(), x ** 4) + ae(res[1].compute(), x + 1) # Check that we can load the dumped dask array from disk. # The location is in the context cache dir, in a subdirectory with the # name of the function by default. - path = op.join(context.cache_dir, 'f4') - y = from_npy_stack(path) - ae(y.compute(), x ** 4) + if not multiple_outputs: + y = from_npy_stack(op.join(context.cache_dir, 'f4')) + ae(y.compute(), x ** 4) + else: + y = from_npy_stack(op.join(context.cache_dir, 'f4')) + ae(y.compute(), x ** 4) + + y = from_npy_stack(op.join(context.cache_dir, 'plus_one')) + ae(y.compute(), x + 1) From e43dbaf7e36b31241a50463c0a437023dc1cf462 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 30 Sep 2015 12:26:58 +0200 Subject: [PATCH 0151/1059] Updates in tests --- phy/utils/tests/test_context.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/phy/utils/tests/test_context.py b/phy/utils/tests/test_context.py index 081de6f52..1d37ae1ba 100644 --- a/phy/utils/tests/test_context.py +++ b/phy/utils/tests/test_context.py @@ -139,26 +139,24 @@ def f4(x): else: def f4(x): return x * x * x * x, x + 1 - name = ('f4', 'plus_one') + name = ('power_four', 'plus_one') x = np.arange(10) da = from_array(x, chunks=(3,)) res = context.map_dask_array(f4, da, name=name) - if not multiple_outputs: - ae(res.compute(), x ** 4) - else: - ae(res[0].compute(), x ** 4) - ae(res[1].compute(), x + 1) - # Check that we can load the dumped dask array from disk. # The location is in the context cache dir, in a subdirectory with the # name of the function by default. if not multiple_outputs: + ae(res.compute(), x ** 4) + y = from_npy_stack(op.join(context.cache_dir, 'f4')) ae(y.compute(), x ** 4) else: - y = from_npy_stack(op.join(context.cache_dir, 'f4')) + ae(res[0].compute(), x ** 4) + ae(res[1].compute(), x + 1) + y = from_npy_stack(op.join(context.cache_dir, 'power_four')) ae(y.compute(), x ** 4) y = from_npy_stack(op.join(context.cache_dir, 'plus_one')) From 6bf2614dbe0112985a79111dd0aa4f2a50e08d7f Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 30 Sep 2015 13:08:33 +0200 Subject: [PATCH 0152/1059] WIP: parallel spike detection --- phy/traces/spike_detect.py | 24 +++++++++++++++++++----- phy/traces/tests/test_spike_detect.py | 26 ++++++++++++++++++++++++-- phy/utils/context.py | 4 ++-- 3 files changed, 45 insertions(+), 9 deletions(-) diff --git a/phy/traces/spike_detect.py b/phy/traces/spike_detect.py index 42974416c..f31b13397 100644 --- a/phy/traces/spike_detect.py +++ b/phy/traces/spike_detect.py @@ -44,12 +44,16 @@ class SpikeDetector(Configurable): def __init__(self, ctx=None): super(SpikeDetector, self).__init__() + self.set_context(ctx) + + def set_context(self, ctx): + self.ctx = ctx if not ctx or not hasattr(ctx, 'cache'): return - self.find_thresholds = ctx.cache(self.find_thresholds) - self.filter = ctx.cache(self.filter) - self.extract_spikes = ctx.cache(self.extract_spikes) - self.detect = ctx.cache(self.detect) + # self.find_thresholds = ctx.cache(self.find_thresholds) + # self.filter = ctx.cache(self.filter) + # self.extract_spikes = ctx.cache(self.extract_spikes) + # self.detect = ctx.cache(self.detect) def set_metadata(self, probe, channel_mapping=None, sample_rate=None): @@ -119,9 +123,11 @@ def filter(self, traces): high=0.95 * .5 * self.sample_rate, order=self.filter_butter_order, ) + logger.info("Filtering %d samples...", traces.shape[0]) return f(traces).astype(np.float32) def extract_spikes(self, traces_subset, thresholds=None): + thresholds = thresholds or self._thresholds assert thresholds is not None self._thresholder = Thresholder(mode=self.detect_spikes, thresholds=thresholds) @@ -137,6 +143,7 @@ def extract_spikes(self, traces_subset, thresholds=None): strong = self._thresholder.detect(traces_t, 'strong') # Run the detection. + logger.info("Detecting connected components...") join_size = self.connected_component_join_size detector = FloodFillDetector(probe_adjacency_list=self._adjacency, join_size=join_size) @@ -150,6 +157,7 @@ def extract_spikes(self, traces_subset, thresholds=None): thresholds=thresholds, ) + logger.info("Extracting %d spikes...", len(components)) s, m, w = zip(*(extractor(component, data=traces_f, data_t=traces_t) for component in components)) s = np.array(s, dtype=np.int64) @@ -169,4 +177,10 @@ def detect(self, traces, thresholds=None): thresholds = self.find_thresholds(traces) # Extract the spikes, masks, waveforms. - return self.extract_spikes(traces, thresholds=thresholds) + if not self.ctx: + return self.extract_spikes(traces, thresholds=thresholds) + else: + names = ('spike_samples', 'masks', 'waveforms') + self._thresholds = thresholds + return self.ctx.map_dask_array(self.extract_spikes, + traces, name=names) diff --git a/phy/traces/tests/test_spike_detect.py b/phy/traces/tests/test_spike_detect.py index 6152b9abd..fbf778c39 100644 --- a/phy/traces/tests/test_spike_detect.py +++ b/phy/traces/tests/test_spike_detect.py @@ -10,6 +10,7 @@ from pytest import yield_fixture from phy.utils.datasets import download_test_data +from phy.utils.tests.test_context import context, ipy_client from phy.electrode import load_probe from ..spike_detect import (SpikeDetector, ) @@ -49,7 +50,7 @@ def spike_detector(request): # Test spike detection #------------------------------------------------------------------------------ -def _plot(sd, traces, spike_samples, masks): +def _plot(sd, traces, spike_samples, masks): # pragma: no cover from vispy.app import run from phy.plot import plot_traces plot_traces(sd.subset_traces(traces), @@ -59,8 +60,9 @@ def _plot(sd, traces, spike_samples, masks): run() -def test_detect(spike_detector, traces): +def test_detect_simple(spike_detector, traces): sd = spike_detector + spike_samples, masks, _ = sd.detect(traces) n_channels = sd.n_channels @@ -74,3 +76,23 @@ def test_detect(spike_detector, traces): assert masks.shape == (n_spikes, n_channels) # _plot(sd, traces, spike_samples, masks) + + +def test_detect_context(spike_detector, traces, context): + sd = spike_detector + sd.set_context(context) + # context.ipy_view = ipy_client[:] + + from dask.array import from_array + traces_da = from_array(traces, chunks=(5000, traces.shape[1])) + spike_samples, masks, _ = sd.detect(traces_da) + + n_channels = sd.n_channels + n_spikes = len(spike_samples) + + assert spike_samples.dtype == np.int64 + assert spike_samples.ndim == 1 + + assert masks.dtype == np.float32 + assert masks.ndim == 2 + assert masks.shape == (n_spikes, n_channels) diff --git a/phy/utils/context.py b/phy/utils/context.py index f9fd31e72..148619c42 100644 --- a/phy/utils/context.py +++ b/phy/utils/context.py @@ -60,7 +60,7 @@ def _mapped(i, chunk, dask, func, cache_dir, name): arr = get(dask, chunk) # Execute the function on the chunk. - logger.debug("Run %s on chunk %d", name, i) + # logger.debug("Run %s on chunk %d", name, i) res = func(arr) # Save the result, and return the information about what we saved. @@ -123,7 +123,7 @@ def _save_stack_info(outputs): assert all(output.shape[1:] == trail_shape for output in outputs) # Compute the output dask array chunks and shape. - chunks = (tuple(output.shape[0] for output in outputs),) + chunks = (tuple(output.shape[0] for output in outputs),) + trail_shape n = sum(output.shape[0] for output in outputs) shape = (n,) + trail_shape From f2220c5eab0ab02fdb374afef0efcf1e44092b26 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 30 Sep 2015 13:14:42 +0200 Subject: [PATCH 0153/1059] Fix bug in context dask with ndarrays --- phy/utils/context.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/phy/utils/context.py b/phy/utils/context.py index 148619c42..59a9bd140 100644 --- a/phy/utils/context.py +++ b/phy/utils/context.py @@ -115,6 +115,7 @@ def _save_stack_info(outputs): dirpath = outputs[0].dirpath dtype = outputs[0].dtype trail_shape = outputs[0].shape[1:] + trail_ndim = len(trail_shape) # Ensure the consistency of all chunks metadata. assert all(output.name == name for output in outputs) @@ -133,7 +134,8 @@ def _save_stack_info(outputs): # Return the result as a dask array. dask_tuples = tuple(output.dask_tuple for output in outputs) - dask = {(name, i): chunk for i, chunk in enumerate(dask_tuples)} + dask = {((name, i) + (0,) * trail_ndim): chunk + for i, chunk in enumerate(dask_tuples)} return Array(dask, name, chunks, dtype=dtype, shape=shape) From da28d179ed8ce1d12bd0240787084e476227e034 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 30 Sep 2015 13:46:29 +0200 Subject: [PATCH 0154/1059] Fix bug: concatenate spike samples according to the chunks --- phy/traces/spike_detect.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/phy/traces/spike_detect.py b/phy/traces/spike_detect.py index f31b13397..df6e7d035 100644 --- a/phy/traces/spike_detect.py +++ b/phy/traces/spike_detect.py @@ -25,6 +25,16 @@ # SpikeDetector #------------------------------------------------------------------------------ +def _concat_spikes(s, m, w, chunks=None): + # TODO: overlap + def add_offset(x, block_id=None): + i = block_id[0] + return x + sum(chunks[0][:i]) + + s = s.map_blocks(add_offset) + return s, m, w + + class SpikeDetector(Configurable): do_filter = Bool(True) filter_low = Float(500.) @@ -182,5 +192,6 @@ def detect(self, traces, thresholds=None): else: names = ('spike_samples', 'masks', 'waveforms') self._thresholds = thresholds - return self.ctx.map_dask_array(self.extract_spikes, - traces, name=names) + s, m, w = self.ctx.map_dask_array(self.extract_spikes, + traces, name=names) + return _concat_spikes(s, m, w, chunks=traces.chunks) From 0a6abe605d8ca095e5c50927ee69648a740d3c65 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 30 Sep 2015 15:26:00 +0200 Subject: [PATCH 0155/1059] WIP: concat spikes --- phy/traces/spike_detect.py | 34 ++++++++++++++++++++--- phy/traces/tests/test_spike_detect.py | 40 ++++++++++++++++++++++++++- 2 files changed, 69 insertions(+), 5 deletions(-) diff --git a/phy/traces/spike_detect.py b/phy/traces/spike_detect.py index df6e7d035..65f121deb 100644 --- a/phy/traces/spike_detect.py +++ b/phy/traces/spike_detect.py @@ -25,11 +25,34 @@ # SpikeDetector #------------------------------------------------------------------------------ -def _concat_spikes(s, m, w, chunks=None): - # TODO: overlap +def _concat_spikes(s, m, w, chunks=None, depth=None): + + # Find where to trim the spikes in the overlapping bands. + def find_bounds(x, block_id=None): + n = chunks[0][block_id[0]] + i = np.searchsorted(x, depth) + j = np.searchsorted(x, n + depth) + return np.array([i, j]) + + # Trim the arrays. + ij = s.map_blocks(find_bounds, chunks=(2,)).compute() + onsets = ij[::2] + offsets = ij[1::2] + + def trim(x, block_id=None): + i = block_id[0] + on = onsets[i] + off = offsets[i] + return x[on:off, ...] + + s = s.map_blocks(trim) + m = m.map_blocks(trim) + w = w.map_blocks(trim) + + # Add the spike sample offsets. def add_offset(x, block_id=None): i = block_id[0] - return x + sum(chunks[0][:i]) + return x + sum(chunks[0][:i]) - depth s = s.map_blocks(add_offset) return s, m, w @@ -190,8 +213,11 @@ def detect(self, traces, thresholds=None): if not self.ctx: return self.extract_spikes(traces, thresholds=thresholds) else: + depth = int(self.chunk_overlap_seconds * self.sample_rate) + chunks = traces.chunks + traces = da.ghost.ghost(traces, depth={0: depth}, boundary={0: 0}) names = ('spike_samples', 'masks', 'waveforms') self._thresholds = thresholds s, m, w = self.ctx.map_dask_array(self.extract_spikes, traces, name=names) - return _concat_spikes(s, m, w, chunks=traces.chunks) + return _concat_spikes(s, m, w, chunks=chunks, depth=depth) diff --git a/phy/traces/tests/test_spike_detect.py b/phy/traces/tests/test_spike_detect.py index fbf778c39..35c419706 100644 --- a/phy/traces/tests/test_spike_detect.py +++ b/phy/traces/tests/test_spike_detect.py @@ -12,7 +12,7 @@ from phy.utils.datasets import download_test_data from phy.utils.tests.test_context import context, ipy_client from phy.electrode import load_probe -from ..spike_detect import (SpikeDetector, +from ..spike_detect import (SpikeDetector, _concat_spikes, ) @@ -60,6 +60,43 @@ def _plot(sd, traces, spike_samples, masks): # pragma: no cover run() +def test_detect_concat(): + import dask.async + from dask import set_options + from dask.array import Array, from_array + set_options(get=dask.async.get_sync) + + chunks = ((5, 5, 2), (3,)) + depth = 2 + # [ 0 1 2 3 4 | 5 6 7 8 9 | 10 11 ] + # [ 0 2 3 8 9 ] + + # Traces + # [ * * 0 1 2 3 4 * * | * * 5 6 7 8 9 * * | * * 10 11 ] + # [ ! ! ! ! ! ] + # Spikes + + dask = {('s', 0): np.array([0, 3, 6]), + ('s', 1): np.array([2, 7]), + ('s', 2): np.array([]), + } + chunks_spikes = ((3, 2, 0),) + s = Array(dask, 's', chunks_spikes, shape=(5,), dtype=np.int32) + m = from_array(np.arange(5 * 3).reshape((5, 3)), + chunks_spikes + (3,)) + w = from_array(np.arange(5 * 3 * 2).reshape((5, 3, 2)), + chunks_spikes + (3, 2)) + + sc, mc, wc = _concat_spikes(s, m, w, chunks=chunks, depth=depth) + sc = sc.compute() + mc = mc.compute() + wc = wc.compute() + + print(sc) + print(mc) + print(wc) + + def test_detect_simple(spike_detector, traces): sd = spike_detector @@ -96,3 +133,4 @@ def test_detect_context(spike_detector, traces, context): assert masks.dtype == np.float32 assert masks.ndim == 2 assert masks.shape == (n_spikes, n_channels) + # _plot(sd, traces, spike_samples.compute(), masks.compute()) From 01f3e0a341e9ae64c17b4c38c0c6931ab40ebd3b Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 30 Sep 2015 15:38:08 +0200 Subject: [PATCH 0156/1059] Use synchronous dask scheduler --- phy/__init__.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/phy/__init__.py b/phy/__init__.py index 953d17734..ed0e1b1b7 100644 --- a/phy/__init__.py +++ b/phy/__init__.py @@ -63,6 +63,16 @@ def add_default_handler(level='INFO'): logger.info("Activate DEBUG level.") +# Force dask to use the synchronous scheduler: we'll use ipyparallel +# manually for parallel processing. +try: + import dask.async + from dask import set_options + set_options(get=dask.async.get_sync) +except ImportError: + logger.debug("dask is not available.") + + def test(): # pragma: no cover """Run the full testing suite of phy.""" import pytest From 99db1225bc4b2701bdde262b3c931e8f2747a175 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 30 Sep 2015 15:55:58 +0200 Subject: [PATCH 0157/1059] WIP: spike concat tests --- phy/traces/spike_detect.py | 42 +++++++--- phy/traces/tests/test_spike_detect.py | 111 ++++++++++++++++++-------- 2 files changed, 105 insertions(+), 48 deletions(-) diff --git a/phy/traces/spike_detect.py b/phy/traces/spike_detect.py index 65f121deb..fa59b4f1c 100644 --- a/phy/traces/spike_detect.py +++ b/phy/traces/spike_detect.py @@ -25,36 +25,52 @@ # SpikeDetector #------------------------------------------------------------------------------ -def _concat_spikes(s, m, w, chunks=None, depth=None): +def _spikes_to_keep(spikes, trace_chunks, depth): + """Find the indices of the spikes to keep given a chunked trace array.""" # Find where to trim the spikes in the overlapping bands. - def find_bounds(x, block_id=None): - n = chunks[0][block_id[0]] + def _find_bounds(x, block_id=None): + n = trace_chunks[0][block_id[0]] i = np.searchsorted(x, depth) j = np.searchsorted(x, n + depth) return np.array([i, j]) # Trim the arrays. - ij = s.map_blocks(find_bounds, chunks=(2,)).compute() - onsets = ij[::2] - offsets = ij[1::2] + ij = spikes.map_blocks(_find_bounds, chunks=(2,)).compute() + return ij[::2], ij[1::2] - def trim(x, block_id=None): + +def _trim_spikes(arr, indices): + onsets, offsets = indices + + def _trim(x, block_id=None): i = block_id[0] on = onsets[i] off = offsets[i] return x[on:off, ...] - s = s.map_blocks(trim) - m = m.map_blocks(trim) - w = w.map_blocks(trim) + # Compute the trimmed chunks. + chunks = (tuple(offsets - onsets),) + arr.chunks[1:] + return arr.map_blocks(_trim, chunks=chunks) + + +def _add_chunk_offset(arr, trace_chunks, depth): # Add the spike sample offsets. - def add_offset(x, block_id=None): + def _add_offset(x, block_id=None): i = block_id[0] - return x + sum(chunks[0][:i]) - depth + return x + sum(trace_chunks[0][:i]) - depth + + return arr.map_blocks(_add_offset) + + +def _concat_spikes(s, m, w, trace_chunks=None, depth=None): + indices = _spikes_to_keep(s, trace_chunks, depth) + s = _trim_spikes(s, indices) + m = _trim_spikes(m, indices) + w = _trim_spikes(w, indices) - s = s.map_blocks(add_offset) + s = _add_chunk_offset(s, trace_chunks, depth) return s, m, w diff --git a/phy/traces/tests/test_spike_detect.py b/phy/traces/tests/test_spike_detect.py index 35c419706..a0d09a45e 100644 --- a/phy/traces/tests/test_spike_detect.py +++ b/phy/traces/tests/test_spike_detect.py @@ -7,12 +7,17 @@ #------------------------------------------------------------------------------ import numpy as np +from numpy.testing import assert_array_equal as ae from pytest import yield_fixture from phy.utils.datasets import download_test_data from phy.utils.tests.test_context import context, ipy_client from phy.electrode import load_probe -from ..spike_detect import (SpikeDetector, _concat_spikes, +from ..spike_detect import (SpikeDetector, + _spikes_to_keep, + _trim_spikes, + _add_chunk_offset, + _concat_spikes, ) @@ -60,41 +65,77 @@ def _plot(sd, traces, spike_samples, masks): # pragma: no cover run() -def test_detect_concat(): - import dask.async - from dask import set_options - from dask.array import Array, from_array - set_options(get=dask.async.get_sync) - - chunks = ((5, 5, 2), (3,)) - depth = 2 - # [ 0 1 2 3 4 | 5 6 7 8 9 | 10 11 ] - # [ 0 2 3 8 9 ] - - # Traces +class TestConcat(object): # [ * * 0 1 2 3 4 * * | * * 5 6 7 8 9 * * | * * 10 11 ] # [ ! ! ! ! ! ] - # Spikes - - dask = {('s', 0): np.array([0, 3, 6]), - ('s', 1): np.array([2, 7]), - ('s', 2): np.array([]), - } - chunks_spikes = ((3, 2, 0),) - s = Array(dask, 's', chunks_spikes, shape=(5,), dtype=np.int32) - m = from_array(np.arange(5 * 3).reshape((5, 3)), - chunks_spikes + (3,)) - w = from_array(np.arange(5 * 3 * 2).reshape((5, 3, 2)), - chunks_spikes + (3, 2)) - - sc, mc, wc = _concat_spikes(s, m, w, chunks=chunks, depth=depth) - sc = sc.compute() - mc = mc.compute() - wc = wc.compute() - - print(sc) - print(mc) - print(wc) + # spike_samples: 1, 4, 5 + + def setup(self): + from dask.array import Array, from_array + + self.trace_chunks = ((5, 5, 2), (3,)) + self.depth = 2 + + # Create the chunked spike_samples array. + dask = {('spike_samples', 0): np.array([0, 3, 6]), + ('spike_samples', 1): np.array([2, 7]), + ('spike_samples', 2): np.array([]), + } + spikes_chunks = ((3, 2, 0),) + s = Array(dask, 'spike_samples', spikes_chunks, + shape=(5,), dtype=np.int32) + self.spike_samples = s + # Indices of the spikes that are kept (outside of overlapping bands). + self.spike_indices = np.array([1, 2, 3]) + + assert len(self.spike_samples.compute()) == 5 + + self.masks = from_array(np.arange(5 * 3).reshape((5, 3)), + spikes_chunks + (3,)) + self.waveforms = from_array(np.arange(5 * 3 * 2).reshape((5, 3, 2)), + spikes_chunks + (3, 2)) + + def test_spikes_to_keep(self): + indices = _spikes_to_keep(self.spike_samples, + self.trace_chunks, + self.depth) + onsets, offsets = indices + assert list(zip(onsets, offsets)) == [(1, 3), (0, 1), (0, 0)] + + def test_trim_spikes(self): + indices = _spikes_to_keep(self.spike_samples, + self.trace_chunks, + self.depth) + + # Trim the spikes. + spikes_trimmed = _trim_spikes(self.spike_samples, indices) + ae(spikes_trimmed.compute(), [3, 6, 2]) + + def test_add_chunk_offset(self): + indices = _spikes_to_keep(self.spike_samples, + self.trace_chunks, + self.depth) + spikes_trimmed = _trim_spikes(self.spike_samples, indices) + + # Add the chunk offsets to the spike samples. + self.spikes = _add_chunk_offset(spikes_trimmed, + self.trace_chunks, self.depth) + ae(self.spikes, [1, 4, 5]) + + def test_concat(self): + sc, mc, wc = _concat_spikes(self.spike_samples, + self.masks, + self.waveforms, + trace_chunks=self.trace_chunks, + depth=self.depth, + ) + sc = sc.compute() + mc = mc.compute() + wc = wc.compute() + + ae(sc, [1, 4, 5]) + ae(mc, self.masks.compute()[self.spike_indices]) + ae(wc, self.waveforms.compute()[self.spike_indices]) def test_detect_simple(spike_detector, traces): @@ -115,7 +156,7 @@ def test_detect_simple(spike_detector, traces): # _plot(sd, traces, spike_samples, masks) -def test_detect_context(spike_detector, traces, context): +def atest_detect_context(spike_detector, traces, context): sd = spike_detector sd.set_context(context) # context.ipy_view = ipy_client[:] From 0c10c330d658264200090c44caf6add1930edf10 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 30 Sep 2015 16:08:27 +0200 Subject: [PATCH 0158/1059] WIP: parallel spike detection --- phy/traces/spike_detect.py | 25 ++++++++++++++++++++++--- phy/traces/tests/test_spike_detect.py | 6 ++---- 2 files changed, 24 insertions(+), 7 deletions(-) diff --git a/phy/traces/spike_detect.py b/phy/traces/spike_detect.py index fa59b4f1c..3e7516081 100644 --- a/phy/traces/spike_detect.py +++ b/phy/traces/spike_detect.py @@ -229,11 +229,30 @@ def detect(self, traces, thresholds=None): if not self.ctx: return self.extract_spikes(traces, thresholds=thresholds) else: + import dask.array as da + + # Chunking parameters. + chunk_size = int(self.chunk_size_seconds * self.sample_rate) depth = int(self.chunk_overlap_seconds * self.sample_rate) - chunks = traces.chunks - traces = da.ghost.ghost(traces, depth={0: depth}, boundary={0: 0}) + trace_chunks = (chunk_size, traces.shape[1]) + + # Chunk the data. traces is now a dask Array. + traces = da.from_array(traces, chunks=trace_chunks) + trace_chunks = traces.chunks + + # Add overlapping band in traces. + traces = da.ghost.ghost(traces, + depth={0: depth}, boundary={0: 0}) + names = ('spike_samples', 'masks', 'waveforms') self._thresholds = thresholds + + # Run the spike extraction procedure in parallel. s, m, w = self.ctx.map_dask_array(self.extract_spikes, traces, name=names) - return _concat_spikes(s, m, w, chunks=chunks, depth=depth) + + # Return the concatenated spike samples, masks, waveforms, as + # dask arrays reading from the cached .npy files. + return _concat_spikes(s, m, w, + trace_chunks=trace_chunks, + depth=depth) diff --git a/phy/traces/tests/test_spike_detect.py b/phy/traces/tests/test_spike_detect.py index a0d09a45e..78fe1f504 100644 --- a/phy/traces/tests/test_spike_detect.py +++ b/phy/traces/tests/test_spike_detect.py @@ -156,14 +156,12 @@ def test_detect_simple(spike_detector, traces): # _plot(sd, traces, spike_samples, masks) -def atest_detect_context(spike_detector, traces, context): +def test_detect_context(spike_detector, traces, context): sd = spike_detector sd.set_context(context) # context.ipy_view = ipy_client[:] - from dask.array import from_array - traces_da = from_array(traces, chunks=(5000, traces.shape[1])) - spike_samples, masks, _ = sd.detect(traces_da) + spike_samples, masks, _ = sd.detect(traces) n_channels = sd.n_channels n_spikes = len(spike_samples) From dffbefdf4c2fdc4c1995b58ba99a2ba611a17273 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 30 Sep 2015 16:28:09 +0200 Subject: [PATCH 0159/1059] Add support for multiple arguments in map_dask_array() --- phy/utils/context.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/phy/utils/context.py b/phy/utils/context.py index 59a9bd140..fc9aff950 100644 --- a/phy/utils/context.py +++ b/phy/utils/context.py @@ -50,7 +50,7 @@ def write_array(path, arr): # Context #------------------------------------------------------------------------------ -def _mapped(i, chunk, dask, func, cache_dir, name): +def _mapped(i, chunk, dask, func, args, cache_dir, name): """Top-level function to map. This function needs to be a top-level function for ipyparallel to work. @@ -61,7 +61,7 @@ def _mapped(i, chunk, dask, func, cache_dir, name): # Execute the function on the chunk. # logger.debug("Run %s on chunk %d", name, i) - res = func(arr) + res = func(arr, *args) # Save the result, and return the information about what we saved. return _save_stack_chunk(i, res, cache_dir, name) @@ -191,7 +191,7 @@ def cache(self, f): return return self._memory.cache(f) - def map_dask_array(self, func, da, name=None): + def map_dask_array(self, func, da, *args, **kwargs): """Map a function on the chunks of a dask array, and return a new dask array. @@ -202,10 +202,14 @@ def map_dask_array(self, func, da, name=None): name (the function's name by default). The result is a new dask array that reads data from the npy stack in the cache subdirectory. + The mapped function can return several arrays as a tuple. In this case, + `name` must also be a tuple, and the output of this function is a + tuple of dask arrays. + """ assert isinstance(da, Array) - name = name or func.__name__ + name = kwargs.get('name', None) or func.__name__ assert name != da.name dask = da.dask @@ -215,7 +219,8 @@ def map_dask_array(self, func, da, name=None): args_0 = list(_iter_chunks_dask(da)) n = len(args_0) output = self.map(_mapped, range(n), args_0, [dask] * n, - [func] * n, [self.cache_dir] * n, [name] * n) + [func] * n, [args] * n, + [self.cache_dir] * n, [name] * n) # output contains information about the output arrays. We use this # information to reconstruct the final dask array. From 8ec0c54e40a05836ddadaf52cbb1f427c0ef0321 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 30 Sep 2015 16:34:27 +0200 Subject: [PATCH 0160/1059] Update context test --- phy/utils/tests/test_context.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/phy/utils/tests/test_context.py b/phy/utils/tests/test_context.py index 1d37ae1ba..ea385b5cb 100644 --- a/phy/utils/tests/test_context.py +++ b/phy/utils/tests/test_context.py @@ -133,17 +133,17 @@ def test_context_dask(context, ipy_client, is_parallel, multiple_outputs): context.ipy_view = ipy_client[:] if not multiple_outputs: - def f4(x): + def f4(x, onset): return x * x * x * x name = None else: - def f4(x): - return x * x * x * x, x + 1 + def f4(x, onset): + return x * x * x * x + onset, x + 1 name = ('power_four', 'plus_one') x = np.arange(10) da = from_array(x, chunks=(3,)) - res = context.map_dask_array(f4, da, name=name) + res = context.map_dask_array(f4, da, 0, name=name) # Check that we can load the dumped dask array from disk. # The location is in the context cache dir, in a subdirectory with the From 3eb06869e19982903e621efa0da1b7e7c6bd46c4 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 30 Sep 2015 16:42:52 +0200 Subject: [PATCH 0161/1059] WIP: parallel spike detection --- phy/traces/spike_detect.py | 5 +++++ phy/traces/tests/test_spike_detect.py | 4 ++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/phy/traces/spike_detect.py b/phy/traces/spike_detect.py index 3e7516081..b16e0eba5 100644 --- a/phy/traces/spike_detect.py +++ b/phy/traces/spike_detect.py @@ -256,3 +256,8 @@ def detect(self, traces, thresholds=None): return _concat_spikes(s, m, w, trace_chunks=trace_chunks, depth=depth) + + def __getstate__(self): + state = self.__dict__.copy() + state['ctx'] = None + return state diff --git a/phy/traces/tests/test_spike_detect.py b/phy/traces/tests/test_spike_detect.py index 78fe1f504..6ddf45416 100644 --- a/phy/traces/tests/test_spike_detect.py +++ b/phy/traces/tests/test_spike_detect.py @@ -156,10 +156,10 @@ def test_detect_simple(spike_detector, traces): # _plot(sd, traces, spike_samples, masks) -def test_detect_context(spike_detector, traces, context): +def test_detect_context(spike_detector, traces, context, ipy_client): sd = spike_detector sd.set_context(context) - # context.ipy_view = ipy_client[:] + context.ipy_view = ipy_client[:] spike_samples, masks, _ = sd.detect(traces) From d432e2f015c5e05e97b9f669d2cf41bc1a09682c Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 30 Sep 2015 16:48:11 +0200 Subject: [PATCH 0162/1059] Refactored context tests --- phy/utils/tests/test_context.py | 70 ++++++++++++++++++++------------- 1 file changed, 43 insertions(+), 27 deletions(-) diff --git a/phy/utils/tests/test_context.py b/phy/utils/tests/test_context.py index ea385b5cb..7cba80bbd 100644 --- a/phy/utils/tests/test_context.py +++ b/phy/utils/tests/test_context.py @@ -21,12 +21,6 @@ # Fixtures #------------------------------------------------------------------------------ -@yield_fixture(scope='function') -def context(tempdir): - ctx = Context('{}/cache/'.format(tempdir)) - yield ctx - - @yield_fixture(scope='module') def ipy_client(): @@ -47,6 +41,19 @@ def iptest_stdstreams_fileno(): ipyparallel.tests.teardown() +@yield_fixture(scope='function') +def context(tempdir): + ctx = Context('{}/cache/'.format(tempdir)) + yield ctx + + +@yield_fixture(scope='function') +def parallel_context(tempdir, ipy_client): + ctx = Context('{}/cache/'.format(tempdir)) + ctx.ipy_view = ipy_client[:] + yield ctx + + #------------------------------------------------------------------------------ # ipyparallel tests #------------------------------------------------------------------------------ @@ -60,17 +67,9 @@ def test_client_2(ipy_client): #------------------------------------------------------------------------------ -# Test context +# Test utils and cache #------------------------------------------------------------------------------ -def test_iter_chunks_dask(): - from dask.array import from_array - - x = np.arange(10) - da = from_array(x, chunks=(3,)) - assert len(list(_iter_chunks_dask(da))) == 4 - - def test_read_write(tempdir): x = np.arange(10) write_array(op.join(tempdir, 'test.npy'), x) @@ -102,6 +101,10 @@ def f(x): assert len(_res) == 2 +#------------------------------------------------------------------------------ +# Test map +#------------------------------------------------------------------------------ + def test_context_map(context): def f3(x): return x * x * x @@ -111,26 +114,29 @@ def f3(x): assert context.map_async(f3, args) == [x ** 3 for x in range(10)] -def test_context_parallel_map(context, ipy_client): - view = ipy_client[:] - context.ipy_view = view - assert context.ipy_view == view +def test_context_parallel_map(parallel_context): def square(x): return x * x - assert context.map(square, [1, 2, 3]) == [1, 4, 9] - assert context.map_async(square, [1, 2, 3]).get() == [1, 4, 9] + assert parallel_context.map(square, [1, 2, 3]) == [1, 4, 9] + assert parallel_context.map_async(square, [1, 2, 3]).get() == [1, 4, 9] -@mark.parametrize('is_parallel,multiple_outputs', - product([True, False], repeat=2)) -def test_context_dask(context, ipy_client, is_parallel, multiple_outputs): +#------------------------------------------------------------------------------ +# Test context dask +#------------------------------------------------------------------------------ - from dask.array import from_array, from_npy_stack +def test_iter_chunks_dask(): + from dask.array import from_array - if is_parallel: - context.ipy_view = ipy_client[:] + x = np.arange(10) + da = from_array(x, chunks=(3,)) + assert len(list(_iter_chunks_dask(da))) == 4 + + +def _test_context_dask(context, multiple_outputs): + from dask.array import from_array, from_npy_stack if not multiple_outputs: def f4(x, onset): @@ -161,3 +167,13 @@ def f4(x, onset): y = from_npy_stack(op.join(context.cache_dir, 'plus_one')) ae(y.compute(), x + 1) + + +@mark.parametrize('multiple_outputs', [True, False]) +def test_context_dask_simple(context, multiple_outputs): + _test_context_dask(context, multiple_outputs) + + +@mark.parametrize('multiple_outputs', [True, False]) +def test_context_dask_parallel(parallel_context, multiple_outputs): + _test_context_dask(parallel_context, multiple_outputs) From 23cbc949e85aa77d23b99d597d0f9856e2cd77f0 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 30 Sep 2015 16:54:25 +0200 Subject: [PATCH 0163/1059] Minor refactoring in context --- phy/utils/context.py | 3 ++- phy/utils/tests/test_context.py | 36 ++++++++++----------------------- 2 files changed, 13 insertions(+), 26 deletions(-) diff --git a/phy/utils/context.py b/phy/utils/context.py index fc9aff950..34e3b97da 100644 --- a/phy/utils/context.py +++ b/phy/utils/context.py @@ -245,7 +245,8 @@ def map_async(self, f, *args): if self._ipy_view: return self._map_ipy(f, *args, sync=False) else: - return self._map_serial(f, *args) + raise RuntimeError("Asynchronous execution requires an " + "ipyparallel context.") def map(self, f, *args): """Map a function synchronously. diff --git a/phy/utils/tests/test_context.py b/phy/utils/tests/test_context.py index 7cba80bbd..c8b80914a 100644 --- a/phy/utils/tests/test_context.py +++ b/phy/utils/tests/test_context.py @@ -47,10 +47,12 @@ def context(tempdir): yield ctx -@yield_fixture(scope='function') -def parallel_context(tempdir, ipy_client): +@yield_fixture(scope='function', params=[False, True]) +def parallel_context(tempdir, ipy_client, request): + """Parallel and non-parallel context.""" ctx = Context('{}/cache/'.format(tempdir)) - ctx.ipy_view = ipy_client[:] + if request.param: + ctx.ipy_view = ipy_client[:] yield ctx @@ -105,22 +107,14 @@ def f(x): # Test map #------------------------------------------------------------------------------ -def test_context_map(context): - def f3(x): - return x * x * x - - args = range(10) - assert context.map(f3, args) == [x ** 3 for x in range(10)] - assert context.map_async(f3, args) == [x ** 3 for x in range(10)] - - -def test_context_parallel_map(parallel_context): +def test_context_map(parallel_context): def square(x): return x * x assert parallel_context.map(square, [1, 2, 3]) == [1, 4, 9] - assert parallel_context.map_async(square, [1, 2, 3]).get() == [1, 4, 9] + if parallel_context.ipy_view: + assert parallel_context.map_async(square, [1, 2, 3]).get() == [1, 4, 9] #------------------------------------------------------------------------------ @@ -135,8 +129,10 @@ def test_iter_chunks_dask(): assert len(list(_iter_chunks_dask(da))) == 4 -def _test_context_dask(context, multiple_outputs): +@mark.parametrize('multiple_outputs', [True, False]) +def test_context_dask(parallel_context, multiple_outputs): from dask.array import from_array, from_npy_stack + context = parallel_context if not multiple_outputs: def f4(x, onset): @@ -167,13 +163,3 @@ def f4(x, onset): y = from_npy_stack(op.join(context.cache_dir, 'plus_one')) ae(y.compute(), x + 1) - - -@mark.parametrize('multiple_outputs', [True, False]) -def test_context_dask_simple(context, multiple_outputs): - _test_context_dask(context, multiple_outputs) - - -@mark.parametrize('multiple_outputs', [True, False]) -def test_context_dask_parallel(parallel_context, multiple_outputs): - _test_context_dask(parallel_context, multiple_outputs) From 1353e389f06b7e28cb07bddb66bf09a88cc1bb79 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 30 Sep 2015 17:06:41 +0200 Subject: [PATCH 0164/1059] Picklable Context --- phy/utils/context.py | 19 +++++++++++++++++-- phy/utils/tests/test_context.py | 11 +++++++++++ 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/phy/utils/context.py b/phy/utils/context.py index 34e3b97da..77936bfe1 100644 --- a/phy/utils/context.py +++ b/phy/utils/context.py @@ -156,6 +156,10 @@ def __init__(self, cache_dir, ipy_view=None): if not op.exists(self.cache_dir): os.makedirs(self.cache_dir) + self._set_memory(cache_dir) + self.ipy_view = ipy_view if ipy_view else None + + def _set_memory(self, cache_dir): # Try importing joblib. try: from joblib import Memory @@ -166,8 +170,6 @@ def __init__(self, cache_dir, ipy_view=None): "Install it with `conda install joblib`.") self._memory = None - self.ipy_view = ipy_view if ipy_view else None - @property def ipy_view(self): """ipyparallel view to parallel computing resources.""" @@ -258,3 +260,16 @@ def map(self, f, *args): return self._map_ipy(f, *args, sync=True) else: return self._map_serial(f, *args) + + def __getstate__(self): + """Make sure that this class is picklable.""" + state = self.__dict__.copy() + state['_memory'] = None + state['_ipy_view'] = None + return state + + def __setstate__(self, state): + """Make sure that this class is picklable.""" + self.__dict__ = state + # Recreate the joblib Memory instance. + self._set_memory(state['cache_dir']) diff --git a/phy/utils/tests/test_context.py b/phy/utils/tests/test_context.py index c8b80914a..ed1808553 100644 --- a/phy/utils/tests/test_context.py +++ b/phy/utils/tests/test_context.py @@ -13,6 +13,7 @@ import numpy as np from numpy.testing import assert_array_equal as ae from pytest import yield_fixture, mark +from six.moves import cPickle from ..context import Context, _iter_chunks_dask, write_array, read_array @@ -103,6 +104,16 @@ def f(x): assert len(_res) == 2 +def test_pickle_cache(tempdir, parallel_context): + """Make sure the Context is picklable.""" + with open(op.join(tempdir, 'test.pkl'), 'wb') as f: + cPickle.dump(parallel_context, f) + with open(op.join(tempdir, 'test.pkl'), 'rb') as f: + ctx = cPickle.load(f) + assert isinstance(ctx, Context) + assert ctx.cache_dir == parallel_context.cache_dir + + #------------------------------------------------------------------------------ # Test map #------------------------------------------------------------------------------ From 8e788e6b237ef44b8f2d89198910c0aa18bdbde5 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 30 Sep 2015 17:08:13 +0200 Subject: [PATCH 0165/1059] Increase coverage --- phy/utils/tests/test_context.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/phy/utils/tests/test_context.py b/phy/utils/tests/test_context.py index ed1808553..b4f326a1c 100644 --- a/phy/utils/tests/test_context.py +++ b/phy/utils/tests/test_context.py @@ -6,13 +6,12 @@ # Imports #------------------------------------------------------------------------------ -from itertools import product import os import os.path as op import numpy as np from numpy.testing import assert_array_equal as ae -from pytest import yield_fixture, mark +from pytest import yield_fixture, mark, raises from six.moves import cPickle from ..context import Context, _iter_chunks_dask, write_array, read_array @@ -124,7 +123,10 @@ def square(x): return x * x assert parallel_context.map(square, [1, 2, 3]) == [1, 4, 9] - if parallel_context.ipy_view: + if not parallel_context.ipy_view: + with raises(RuntimeError): + parallel_context.map_async(square, [1, 2, 3]) + else: assert parallel_context.map_async(square, [1, 2, 3]).get() == [1, 4, 9] From 3d89be1d15f1ef66346b7d11e6b948cbb1f43ed4 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 30 Sep 2015 17:09:08 +0200 Subject: [PATCH 0166/1059] WIP --- phy/traces/spike_detect.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/phy/traces/spike_detect.py b/phy/traces/spike_detect.py index b16e0eba5..3e7516081 100644 --- a/phy/traces/spike_detect.py +++ b/phy/traces/spike_detect.py @@ -256,8 +256,3 @@ def detect(self, traces, thresholds=None): return _concat_spikes(s, m, w, trace_chunks=trace_chunks, depth=depth) - - def __getstate__(self): - state = self.__dict__.copy() - state['ctx'] = None - return state From a35114822a72534e059b950aed2e481ff9cce6bb Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 30 Sep 2015 17:15:58 +0200 Subject: [PATCH 0167/1059] Tests pass --- phy/__init__.py | 2 +- phy/traces/tests/test_spike_detect.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/phy/__init__.py b/phy/__init__.py index ed0e1b1b7..bc4bb4b5b 100644 --- a/phy/__init__.py +++ b/phy/__init__.py @@ -69,7 +69,7 @@ def add_default_handler(level='INFO'): import dask.async from dask import set_options set_options(get=dask.async.get_sync) -except ImportError: +except ImportError: # pragma: no cover logger.debug("dask is not available.") diff --git a/phy/traces/tests/test_spike_detect.py b/phy/traces/tests/test_spike_detect.py index 6ddf45416..b0d42661a 100644 --- a/phy/traces/tests/test_spike_detect.py +++ b/phy/traces/tests/test_spike_detect.py @@ -11,7 +11,8 @@ from pytest import yield_fixture from phy.utils.datasets import download_test_data -from phy.utils.tests.test_context import context, ipy_client +from phy.utils.tests.test_context import (ipy_client, context, # noqa + parallel_context) from phy.electrode import load_probe from ..spike_detect import (SpikeDetector, _spikes_to_keep, @@ -156,10 +157,9 @@ def test_detect_simple(spike_detector, traces): # _plot(sd, traces, spike_samples, masks) -def test_detect_context(spike_detector, traces, context, ipy_client): +def test_detect_context(spike_detector, traces, parallel_context): # noqa sd = spike_detector - sd.set_context(context) - context.ipy_view = ipy_client[:] + sd.set_context(parallel_context) spike_samples, masks, _ = sd.detect(traces) From ef8b11eef9b8f5dfe8fb9eb39370449970c2be10 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 30 Sep 2015 20:45:21 +0200 Subject: [PATCH 0168/1059] Add travis dependencies --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index a473fb4b8..32b798762 100644 --- a/.travis.yml +++ b/.travis.yml @@ -21,7 +21,7 @@ install: # Create the environment. - conda create -q -n testenv python=$TRAVIS_PYTHON_VERSION - source activate testenv - - conda install pip numpy vispy matplotlib scipy h5py pyqt pyzmq ipython requests six + - conda install pip numpy vispy matplotlib scipy h5py pyqt ipython requests six dill ipyparallel joblib dask # NOTE: cython is only temporarily needed for building KK2. # Once we have KK2 binary builds on binstar we can remove this dependency. - conda install cython && pip install klustakwik2 From 1cbeb060003eb7ef19fa8901b36fe92ee0918e1d Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 1 Oct 2015 12:24:31 +0200 Subject: [PATCH 0169/1059] Move read/write array functions --- phy/traces/spike_detect.py | 6 ++++- phy/utils/array.py | 41 ++++++++++++++--------------------- phy/utils/context.py | 12 +--------- phy/utils/tests/test_array.py | 30 +++++++------------------ 4 files changed, 30 insertions(+), 59 deletions(-) diff --git a/phy/traces/spike_detect.py b/phy/traces/spike_detect.py index 3e7516081..361416c1a 100644 --- a/phy/traces/spike_detect.py +++ b/phy/traces/spike_detect.py @@ -22,7 +22,7 @@ #------------------------------------------------------------------------------ -# SpikeDetector +# Chunking-related utility functions #------------------------------------------------------------------------------ def _spikes_to_keep(spikes, trace_chunks, depth): @@ -74,6 +74,10 @@ def _concat_spikes(s, m, w, trace_chunks=None, depth=None): return s, m, w +#------------------------------------------------------------------------------ +# SpikeDetector +#------------------------------------------------------------------------------ + class SpikeDetector(Configurable): do_filter = Bool(True) filter_low = Float(500.) diff --git a/phy/utils/array.py b/phy/utils/array.py index 653a49bb4..3f968a510 100644 --- a/phy/utils/array.py +++ b/phy/utils/array.py @@ -171,31 +171,22 @@ def _in_polygon(points, polygon): # I/O functions # ----------------------------------------------------------------------------- -def _save_arrays(path, arrays): - """Save multiple arrays in a single file by concatenating them along - the first axis. - - A second array is stored with the offsets. - - """ - assert path.endswith('.npy') - path = op.splitext(path)[0] - offsets = np.cumsum([arr.shape[0] for arr in arrays]) - if not len(arrays): - return - concat = np.concatenate(arrays, axis=0) - np.save(path + '.npy', concat) - np.save(path + '.offsets.npy', offsets) - - -def _load_arrays(path): - assert path.endswith('.npy') - if not op.exists(path): - return [] - path = op.splitext(path)[0] - concat = np.load(path + '.npy') - offsets = np.load(path + '.offsets.npy') - return np.split(concat, offsets[:-1], axis=0) +def read_array(path): + """Read a .npy array.""" + file_ext = op.splitext(path)[1] + if file_ext == '.npy': + return np.load(path) + raise NotImplementedError("The file extension `{}` ".format(file_ext) + + "is not currently supported.") + + +def write_array(path, arr): + """Write an array to a .npy file.""" + file_ext = op.splitext(path)[1] + if file_ext == '.npy': + return np.save(path, arr) + raise NotImplementedError("The file extension `{}` ".format(file_ext) + + "is not currently supported.") # ----------------------------------------------------------------------------- diff --git a/phy/utils/context.py b/phy/utils/context.py index 77936bfe1..074bdc550 100644 --- a/phy/utils/context.py +++ b/phy/utils/context.py @@ -22,6 +22,7 @@ "Install it with `conda install dask`.") from phy.utils import Bunch +from phy.utils.array import read_array, write_array logger = logging.getLogger(__name__) @@ -35,17 +36,6 @@ def _iter_chunks_dask(da): yield chunk -def read_array(path): - """Read a .npy array.""" - return np.load(path) - - -def write_array(path, arr): - """Write an array to a .npy file.""" - logger.debug("Write array to %s.", path) - np.save(path, arr) - - #------------------------------------------------------------------------------ # Context #------------------------------------------------------------------------------ diff --git a/phy/utils/tests/test_array.py b/phy/utils/tests/test_array.py index 03aae16fe..1756a854e 100644 --- a/phy/utils/tests/test_array.py +++ b/phy/utils/tests/test_array.py @@ -9,7 +9,7 @@ import os.path as op import numpy as np -from pytest import raises, mark +from pytest import raises from .._types import _as_array from ..array import (_unique, @@ -27,8 +27,8 @@ get_excerpts, _range_from_slice, _pad, - _load_arrays, - _save_arrays, + read_array, + write_array, ) from ..testing import _assert_equal as ae from ...io.mock import artificial_spike_clusters @@ -182,28 +182,14 @@ def test_in_polygon(): #------------------------------------------------------------------------------ -# Test I/O functions +# Test read/save #------------------------------------------------------------------------------ -@mark.parametrize('n', [20, 0]) -def test_load_save_arrays(tempdir, n): +def test_read_write(tempdir): + arr = np.arange(10).astype(np.float32) path = op.join(tempdir, 'test.npy') - # Random arrays. - arrays = [] - for i in range(n): - size = np.random.randint(low=3, high=50) - assert size > 0 - arr = np.random.rand(size, 3).astype(np.float32) - arrays.append(arr) - _save_arrays(path, arrays) - - arrays_loaded = _load_arrays(path) - - assert len(arrays) == len(arrays_loaded) - for arr, arr_loaded in zip(arrays, arrays_loaded): - assert arr.shape == arr_loaded.shape - assert arr.dtype == arr_loaded.dtype - ae(arr, arr_loaded) + write_array(path, arr) + ae(read_array(path), arr) #------------------------------------------------------------------------------ From 7ef1ec37067a2aa2da99484857f1f654796871ba Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 1 Oct 2015 13:09:08 +0200 Subject: [PATCH 0170/1059] Updated read_kwd() function --- phy/io/__init__.py | 1 - phy/io/h5.py | 241 --------------------------------- phy/io/tests/test_h5.py | 256 ------------------------------------ phy/io/tests/test_traces.py | 23 ++-- phy/io/traces.py | 42 +++--- 5 files changed, 37 insertions(+), 526 deletions(-) delete mode 100644 phy/io/h5.py delete mode 100644 phy/io/tests/test_h5.py diff --git a/phy/io/__init__.py b/phy/io/__init__.py index 17b52558c..d068cad04 100644 --- a/phy/io/__init__.py +++ b/phy/io/__init__.py @@ -3,5 +3,4 @@ """Input/output.""" -from .h5 import File, open_h5 from .traces import read_dat, read_kwd diff --git a/phy/io/h5.py b/phy/io/h5.py deleted file mode 100644 index 24835d0c6..000000000 --- a/phy/io/h5.py +++ /dev/null @@ -1,241 +0,0 @@ -# -*- coding: utf-8 -*- - -"""HDF5 input and output.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import logging - -import h5py - -logger = logging.getLogger(__name__) - - -#------------------------------------------------------------------------------ -# HDF5 utility functions -#------------------------------------------------------------------------------ - -def _split_hdf5_path(path): - """Return the group and dataset of the path.""" - # Make sure the path starts with a leading slash. - if not path.startswith('/'): - raise ValueError(("The HDF5 path '{0:s}' should start with a " - "leading slash '/'.").format(path)) - if '//' in path: - raise ValueError(("There should be no double slash in the HDF5 path " - "'{0:s}'.").format(path)) - # Handle the special case '/'. - if path == '/': - return '/', '' - # Temporarily remove the leading '/', we'll add it later (otherwise split - # and join will mess it up). - path = path[1:] - # We split the path by slash and we get the head and tail. - _split = path.split('/') - group_path = '/'.join(_split[:-1]) - name = _split[-1] - # Make some consistency checks. - assert not group_path.endswith('/') - assert '/' not in name - # Finally, we add the leading slash at the beginning of the group path. - return '/' + group_path, name - - -def _check_hdf5_path(h5_file, path): - """Check that an HDF5 path exists in a file.""" - if path not in h5_file: - raise ValueError("{path} doesn't exist.".format(path=path)) - - -#------------------------------------------------------------------------------ -# File class -#------------------------------------------------------------------------------ - -class File(object): - def __init__(self, filename, mode=None): - if mode is None: - mode = 'r' - self.filename = filename - self.mode = mode - self._h5py_file = None - - # Open and close - #-------------------------------------------------------------------------- - - @property - def h5py_file(self): - """Native h5py file handle.""" - return self._h5py_file - - def is_open(self): - return self._h5py_file is not None - - def open(self, mode=None): - self.mode = mode if mode is not None else None - if not self.is_open(): - self._h5py_file = h5py.File(self.filename, self.mode) - - def close(self): - if self.is_open(): - self._h5py_file.close() - self._h5py_file = None - - # Datasets - #-------------------------------------------------------------------------- - - def read(self, path): - """Read an HDF5 dataset, given its HDF5 path in the file.""" - _check_hdf5_path(self._h5py_file, path) - return self._h5py_file[path] - - def write(self, path, array=None, dtype=None, shape=None, overwrite=False): - """Write a NumPy array in the file. - - Parameters - ---------- - path : str - Full HDF5 path to the dataset to create. - array : ndarray - Array to write in the file. - dtype : dtype - If `array` is None, the dtype of the array. - shape : tuple - If `array` is None, the shape of the array. - overwrite : bool - If False, raise an error if the dataset already exists. Defaults - to False. - - """ - # Get the group path and the dataset name. - group_path, dset_name = _split_hdf5_path(path) - - # If the parent group doesn't already exist, create it. - if group_path not in self._h5py_file: - self._h5py_file.create_group(group_path) - - group = self._h5py_file[group_path] - - # Check that the dataset does not already exist. - if path in self._h5py_file: - if overwrite: - # Force rewriting the dataset if 'overwrite' is True. - del self._h5py_file[path] - else: - # Otherwise, raise an error. - raise ValueError(("The dataset '{0:s}' already exists." - ).format(path)) - - if array is not None: - return group.create_dataset(dset_name, data=array) - else: - assert dtype - assert shape - return group.create_dataset(dset_name, dtype=dtype, shape=shape) - - # Copy and rename - #-------------------------------------------------------------------------- - - def _check_move_copy(self, path, new_path): # pragma: no cover - if not self.exists(path): - raise ValueError("'{0}' doesn't exist.".format(path)) - if self.exists(new_path): - raise ValueError("'{0}' already exist.".format(new_path)) - - def move(self, path, new_path): - """Move a group or dataset to another location.""" - self._check_move_copy(path, new_path) - self._h5py_file.move(path, new_path) - - def copy(self, path, new_path): - """Copy a group or dataset to another location.""" - self._check_move_copy(path, new_path) - self._h5py_file.copy(path, new_path) - - def delete(self, path): - """Delete a group or dataset.""" - if not path.startswith('/'): - path = '/' + path - if not self.exists(path): - raise ValueError("'{0}' doesn't exist.".format(path)) - - path, name = _split_hdf5_path(path) - parent = self.read(path) - del parent[name] - - def attrs(self, path='/'): - """Return the list of attributes at the given path.""" - if path in self._h5py_file: - return sorted(self._h5py_file[path].attrs) - else: - return [] - - def has_attr(self, path, attr_name): - """Return whether an attribute exists at a given path.""" - if path not in self._h5py_file: - return False - else: - return attr_name in self._h5py_file[path].attrs - - # Children - #-------------------------------------------------------------------------- - - def children(self, path='/'): - """Return the list of children of a given node.""" - return sorted(self._h5py_file[path].keys()) - - def groups(self, path='/'): - """Return the list of groups under a given node.""" - return [key for key in self.children(path) - if isinstance(self._h5py_file[path + '/' + key], - h5py.Group)] - - def datasets(self, path='/'): - """Return the list of datasets under a given node.""" - return [key for key in self.children(path) - if isinstance(self._h5py_file[path + '/' + key], - h5py.Dataset)] - - # Miscellaneous properties - #-------------------------------------------------------------------------- - - def exists(self, path): - return path in self._h5py_file - - def _print_node_info(self, name, node): - """Print node information.""" - info = ('/' + name).ljust(50) - if isinstance(node, h5py.Group): - pass - elif isinstance(node, h5py.Dataset): - info += str(node.shape).ljust(20) - info += str(node.dtype).ljust(8) - print(info) - - def describe(self): - """Display the list of all groups and datasets in the file.""" - if not self.is_open(): - raise IOError("Cannot display file information because the file" - " '{0:s}' is not open.".format(self.filename)) - self._h5py_file['/'].visititems(self._print_node_info) - - # Context manager - #-------------------------------------------------------------------------- - - def __contains__(self, path): - return path in self._h5py_file - - def __enter__(self): - self.open() - return self - - def __exit__(self, type, value, tb): - self.close() - - -def open_h5(filename, mode=None): - """Open an HDF5 file and return a File instance.""" - file = File(filename, mode=mode) - file.open() - return file diff --git a/phy/io/tests/test_h5.py b/phy/io/tests/test_h5.py deleted file mode 100644 index 7696790dd..000000000 --- a/phy/io/tests/test_h5.py +++ /dev/null @@ -1,256 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Tests of HDF5 routines.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import os.path as op - -import numpy as np -from numpy.testing import assert_array_equal as ae -from pytest import raises - -from ...utils.testing import captured_output -from ..h5 import open_h5, _split_hdf5_path - - -#------------------------------------------------------------------------------ -# Utility test routines -#------------------------------------------------------------------------------ - -def _create_test_file(dirpath): - filename = op.join(dirpath, '_test.h5') - with open_h5(filename, 'w') as tempfile: - # Create a random dataset using h5py directly. - h5file = tempfile.h5py_file - h5file.create_dataset('ds1', (10,), dtype=np.float32) - group = h5file.create_group('/mygroup') - h5file.create_dataset('/mygroup/ds2', (10,), dtype=np.int8) - group.attrs['myattr'] = 123 - return tempfile.filename - - -#------------------------------------------------------------------------------ -# Tests -#------------------------------------------------------------------------------ - -def test_split_hdf5_path(): - # The path should always start with a leading '/'. - with raises(ValueError): - _split_hdf5_path('') - with raises(ValueError): - _split_hdf5_path('path') - - h, t = _split_hdf5_path('/') - assert (h == '/') and (t == '') - - h, t = _split_hdf5_path('/path') - assert (h == '/') and (t == 'path') - - h, t = _split_hdf5_path('/path/') - assert (h == '/path') and (t == '') - - h, t = _split_hdf5_path('/path/to') - assert (h == '/path') and (t == 'to') - - h, t = _split_hdf5_path('/path/to/') - assert (h == '/path/to') and (t == '') - - # Check that invalid paths raise errors. - with raises(ValueError): - _split_hdf5_path('path/') - with raises(ValueError): - _split_hdf5_path('/path//') - with raises(ValueError): - _split_hdf5_path('/path//to') - - -def test_h5_read(tempdir): - # Create the test HDF5 file in the temporary directory. - filename = _create_test_file(tempdir) - - # Test close() method. - f = open_h5(filename) - assert f.is_open() - f.close() - assert not f.is_open() - with raises(IOError): - f.describe() - - # Open the test HDF5 file. - with open_h5(filename) as f: - assert f.is_open() - - assert f.children() == ['ds1', 'mygroup'] - assert f.groups() == ['mygroup'] - assert f.datasets() == ['ds1'] - assert f.attrs('/mygroup') == ['myattr'] - assert f.attrs('/mygroup_nonexisting') == [] - - # Check dataset ds1. - ds1 = f.read('/ds1')[:] - assert isinstance(ds1, np.ndarray) - assert ds1.shape == (10,) - assert ds1.dtype == np.float32 - - # Check dataset ds2. - ds2 = f.read('/mygroup/ds2')[:] - assert isinstance(ds2, np.ndarray) - assert ds2.shape == (10,) - assert ds2.dtype == np.int8 - - # Check HDF5 group attribute. - assert f.has_attr('/mygroup', 'myattr') - assert not f.has_attr('/mygroup', 'myattr_bis') - assert not f.has_attr('/mygroup_bis', 'myattr_bis') - - # Check that errors are raised when the paths are invalid. - with raises(Exception): - f.read('//path') - with raises(Exception): - f.read('/path//') - with raises(ValueError): - f.read('/nonexistinggroup') - with raises(ValueError): - f.read('/nonexistinggroup/ds34') - - assert not f.is_open() - - -def test_h5_append(tempdir): - # Create the test HDF5 file in the temporary directory. - filename = _create_test_file(tempdir) - - with open_h5(filename, 'a') as f: - f.write('/ds_empty', dtype=np.float32, shape=(10, 2)) - arr = f.read('/ds_empty') - arr[:5, 0] = 1 - - with open_h5(filename, 'r') as f: - arr = f.read('/ds_empty')[...] - assert np.all(arr[:5, 0] == 1) - - -def test_h5_write(tempdir): - # Create the test HDF5 file in the temporary directory. - filename = _create_test_file(tempdir) - - # Create some array. - temp_array = np.zeros(10, dtype=np.float32) - - # Open the test HDF5 file in read-only mode (the default) and - # try to write in it. This should raise an exception. - with open_h5(filename) as f: - with raises(Exception): - f.write('/ds1', temp_array) - - # Open the test HDF5 file in read/write mode and - # try to write in an existing dataset. - with open_h5(filename, 'a') as f: - # This raises an exception because the file already exists, - # and by default this is forbidden. - with raises(ValueError): - f.write('/ds1', temp_array) - - # This works, though, because we force overwriting the dataset. - f.write('/ds1', temp_array, overwrite=True) - ae(f.read('/ds1'), temp_array) - - # Write a new array. - f.write('/ds2', temp_array) - ae(f.read('/ds2'), temp_array) - - # Write a new array in a nonexistent group. - f.write('/ds3/ds4/ds5', temp_array) - ae(f.read('/ds3/ds4/ds5'), temp_array) - - -def test_h5_describe(tempdir): - # Create the test HDF5 file in the temporary directory. - filename = _create_test_file(tempdir) - - # Open the test HDF5 file. - with open_h5(filename) as f: - with captured_output() as (out, err): - f.describe() - output = out.getvalue().strip() - output_lines = output.split('\n') - assert len(output_lines) == 3 - - -def test_h5_move(tempdir): - # Create the test HDF5 file in the temporary directory. - filename = _create_test_file(tempdir) - - with open_h5(filename, 'a') as f: - - # Test dataset move. - assert f.exists('ds1') - arr = f.read('ds1')[:] - assert len(arr) == 10 - f.move('ds1', 'ds1_new') - assert not f.exists('ds1') - assert f.exists('ds1_new') - arr_new = f.read('ds1_new')[:] - assert len(arr_new) == 10 - ae(arr, arr_new) - - # Test group move. - assert f.exists('mygroup/ds2') - arr = f.read('mygroup/ds2') - f.move('mygroup', 'g/mynewgroup') - assert not f.exists('mygroup') - assert f.exists('g/mynewgroup') - assert f.exists('g/mynewgroup/ds2') - arr_new = f.read('g/mynewgroup/ds2') - ae(arr, arr_new) - - -def test_h5_copy(tempdir): - # Create the test HDF5 file in the temporary directory. - filename = _create_test_file(tempdir) - - with open_h5(filename, 'a') as f: - - # Test dataset copy. - assert f.exists('ds1') - arr = f.read('ds1')[:] - assert len(arr) == 10 - f.copy('ds1', 'ds1_new') - assert f.exists('ds1') - assert f.exists('ds1_new') - arr_new = f.read('ds1_new')[:] - assert len(arr_new) == 10 - ae(arr, arr_new) - - # Test group copy. - assert f.exists('mygroup/ds2') - arr = f.read('mygroup/ds2') - f.copy('mygroup', 'g/mynewgroup') - assert f.exists('mygroup') - assert f.exists('g/mynewgroup') - assert f.exists('g/mynewgroup/ds2') - arr_new = f.read('g/mynewgroup/ds2') - ae(arr, arr_new) - - -def test_h5_delete(tempdir): - # Create the test HDF5 file in the temporary directory. - filename = _create_test_file(tempdir) - - with open_h5(filename, 'a') as f: - - # Test dataset delete. - assert f.exists('ds1') - with raises(ValueError): - f.delete('a') - f.delete('ds1') - assert not f.exists('ds1') - - # Test group delete. - assert f.exists('mygroup/ds2') - f.delete('mygroup') - assert not f.exists('mygroup') - assert not f.exists('mygroup/ds2') diff --git a/phy/io/tests/test_traces.py b/phy/io/tests/test_traces.py index b20772863..beef959fa 100644 --- a/phy/io/tests/test_traces.py +++ b/phy/io/tests/test_traces.py @@ -13,7 +13,6 @@ from numpy.testing import assert_allclose as ac from pytest import raises -from ..h5 import open_h5 from ..traces import read_dat, _dat_n_samples, read_kwd, read_ns5 from ..mock import artificial_traces @@ -39,22 +38,24 @@ def test_read_dat(tempdir): def test_read_kwd(tempdir): + from h5py import File + n_samples = 100 n_channels = 10 - arr = artificial_traces(n_samples, n_channels) + path = op.join(tempdir, 'test.kwd') - path = op.join(tempdir, 'test') + with File(path, 'w') as f: + g0 = f.create_group('/recordings/0') + g1 = f.create_group('/recordings/1') + + arr0 = arr[:n_samples // 2, ...] + arr1 = arr[n_samples // 2:, ...] - with open_h5(path, 'w') as f: - f.write('/recordings/0/data', - arr[:n_samples // 2, ...].astype(np.float32)) - f.write('/recordings/1/data', - arr[n_samples // 2:, ...].astype(np.float32)) + g0.create_dataset('data', data=arr0) + g1.create_dataset('data', data=arr1) - with open_h5(path, 'r') as f: - data = read_kwd(f)[...] - ac(arr, data) + ae(read_kwd(path)[...], arr) def test_read_ns5(): diff --git a/phy/io/traces.py b/phy/io/traces.py index 7db9a239b..875e5f083 100644 --- a/phy/io/traces.py +++ b/phy/io/traces.py @@ -15,31 +15,39 @@ # Raw data readers #------------------------------------------------------------------------------ -def read_kwd(kwd_handle): - """Read all traces in a  `.kwd` file. +def _read_recording(filename, rec_name): + """Open a file and return a recording dataset. - The output is a memory-mapped file. + WARNING: the file is not closed when the function returns, so that the + memory-mapped array can still be accessed from disk. """ - import dask.array + from h5py import File + f = File(filename, mode='r') + return f['/recordings/{}/data'.format(rec_name)] + - f = kwd_handle +def read_kwd(filename): + """Read all traces in a `.kwd` file.""" + from h5py import File + from dask.array import Array - if '/recordings' not in f: # pragma: no cover - return - recordings = f.children('/recordings') + with File(filename, mode='r') as f: + rec_names = sorted([name for name in f['/recordings']]) + shapes = [f['/recordings/{}/data'.format(name)].shape + for name in rec_names] - def _read(idx): - # The file needs to be open. - assert f.is_open() - return f.read('/recordings/{}/data'.format(recordings[idx])) + # Create the dask graph for all recordings from the .kwdd file. + dask = {('data', idx, 0): (_read_recording, filename, rec_name) + for (idx, rec_name) in enumerate(rec_names)} - dsk = {('data', idx, 0): (_read, idx) for idx in range(len(recordings))} + # Make sure all recordings have the same number of channels. + n_cols = shapes[0][1] + assert all(shape[1] == n_cols for shape in shapes) - chunks = (tuple(_read(idx).shape[0] for idx in range(len(recordings))), - (_read(0).shape[1],) - ) - return dask.array.Array(dsk, 'data', chunks) + # Create the dask Array. + chunks = (tuple(shape[0] for shape in shapes), (n_cols,)) + return Array(dask, 'data', chunks) def _dat_n_samples(filename, dtype=None, n_channels=None): From 038b95857285350c0ead9100fadaf4f06c1bf90b Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 1 Oct 2015 13:19:46 +0200 Subject: [PATCH 0171/1059] WIP: moving files between utils and io --- phy/io/default_settings.py | 24 ------------------------ phy/io/tests/test_traces.py | 1 - phy/utils/array.py | 2 +- phy/utils/datasets.py | 4 ++-- phy/utils/tests/test_settings.py | 7 +------ 5 files changed, 4 insertions(+), 34 deletions(-) delete mode 100644 phy/io/default_settings.py diff --git a/phy/io/default_settings.py b/phy/io/default_settings.py deleted file mode 100644 index 8dd5bf997..000000000 --- a/phy/io/default_settings.py +++ /dev/null @@ -1,24 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Default settings for I/O.""" - - -# ----------------------------------------------------------------------------- -# Traces -# ----------------------------------------------------------------------------- - -traces = { - 'raw_data_files': [], - 'n_channels': None, - 'dtype': None, - 'sample_rate': None, -} - - -# ----------------------------------------------------------------------------- -# Store settings -# ----------------------------------------------------------------------------- - -# Number of spikes to load at once from the features_masks array -# during the cluster store generation. -features_masks_chunk_size = 100000 diff --git a/phy/io/tests/test_traces.py b/phy/io/tests/test_traces.py index beef959fa..5ad397534 100644 --- a/phy/io/tests/test_traces.py +++ b/phy/io/tests/test_traces.py @@ -10,7 +10,6 @@ import numpy as np from numpy.testing import assert_array_equal as ae -from numpy.testing import assert_allclose as ac from pytest import raises from ..traces import read_dat, _dat_n_samples, read_kwd, read_ns5 diff --git a/phy/utils/array.py b/phy/utils/array.py index 3f968a510..7eb0a11f4 100644 --- a/phy/utils/array.py +++ b/phy/utils/array.py @@ -13,7 +13,7 @@ import numpy as np -from ._types import _as_array, _is_array_like +from phy.utils._types import _as_array, _is_array_like logger = logging.getLogger(__name__) diff --git a/phy/utils/datasets.py b/phy/utils/datasets.py index d5c577b97..04fd9b0fd 100644 --- a/phy/utils/datasets.py +++ b/phy/utils/datasets.py @@ -11,8 +11,8 @@ import os import os.path as op -from .settings import _phy_user_dir, _ensure_dir_exists -from .event import ProgressReporter +from phy.utils.event import ProgressReporter +from phy.utils.settings import _phy_user_dir, _ensure_dir_exists logger = logging.getLogger(__name__) diff --git a/phy/utils/tests/test_settings.py b/phy/utils/tests/test_settings.py index 38b13cb2e..846a0df36 100644 --- a/phy/utils/tests/test_settings.py +++ b/phy/utils/tests/test_settings.py @@ -51,12 +51,7 @@ def test_recursive_dirs(): def test_load_default_settings(): settings = _load_default_settings() keys = settings.keys() - # assert 'log_file_level' in keys - # assert 'on_open' in keys - # assert 'spikedetekt' in keys - # assert 'klustakwik2' in keys - assert 'traces' in keys - # assert 'cluster_manual_config' in keys + assert keys def test_base_settings(): From a751e3b860f66472cd57f96875d0e7df653d16ed Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 1 Oct 2015 13:37:44 +0200 Subject: [PATCH 0172/1059] Move files and tweak coverage --- .coveragerc | 7 +++++++ phy/__init__.py | 2 +- phy/cluster/algorithms/klustakwik.py | 4 ++-- phy/cluster/manual/clustering.py | 10 +++++----- phy/cluster/manual/tests/test_clustering.py | 4 ++-- phy/{utils => io}/array.py | 0 phy/{utils => io}/context.py | 2 +- phy/{utils => io}/datasets.py | 0 phy/{utils => io}/tests/test_array.py | 6 +++--- phy/{utils => io}/tests/test_context.py | 0 phy/{utils => io}/tests/test_datasets.py | 2 +- phy/plot/_vispy_utils.py | 4 ++-- phy/plot/features.py | 6 +++--- phy/plot/traces.py | 6 +++--- phy/plot/waveforms.py | 8 ++++---- phy/stats/ccg.py | 4 ++-- phy/traces/detect.py | 2 +- phy/traces/spike_detect.py | 2 +- phy/traces/tests/test_spike_detect.py | 6 +++--- phy/traces/waveform.py | 2 +- phy/utils/__init__.py | 1 - pytest.ini | 2 -- setup.cfg | 1 + 23 files changed, 43 insertions(+), 38 deletions(-) rename phy/{utils => io}/array.py (100%) rename phy/{utils => io}/context.py (99%) rename phy/{utils => io}/datasets.py (100%) rename phy/{utils => io}/tests/test_array.py (98%) rename phy/{utils => io}/tests/test_context.py (100%) rename phy/{utils => io}/tests/test_datasets.py (99%) delete mode 100644 pytest.ini diff --git a/.coveragerc b/.coveragerc index 6ce8e8f8b..3934caa38 100644 --- a/.coveragerc +++ b/.coveragerc @@ -3,5 +3,12 @@ branch = True source = phy omit = */phy/ext/* + */phy/plot/* */phy/utils/tempdir.py */default_settings.py + +[report] +exclude_lines = + pragma: no cover + raise AssertionError + raise NotImplementedError diff --git a/phy/__init__.py b/phy/__init__.py index bc4bb4b5b..b7324d977 100644 --- a/phy/__init__.py +++ b/phy/__init__.py @@ -14,7 +14,7 @@ from six import StringIO -from .utils.datasets import download_sample_data +from .io.datasets import download_file, download_sample_data from .utils._misc import _git_version diff --git a/phy/cluster/algorithms/klustakwik.py b/phy/cluster/algorithms/klustakwik.py index c61a13897..b640f0c23 100644 --- a/phy/cluster/algorithms/klustakwik.py +++ b/phy/cluster/algorithms/klustakwik.py @@ -9,8 +9,8 @@ import numpy as np import six -from ...utils.array import chunk_bounds -from ...utils.event import EventEmitter +from phy.io.array import chunk_bounds +from phy.utils.event import EventEmitter #------------------------------------------------------------------------------ diff --git a/phy/cluster/manual/clustering.py b/phy/cluster/manual/clustering.py index 406f13afd..851901546 100644 --- a/phy/cluster/manual/clustering.py +++ b/phy/cluster/manual/clustering.py @@ -8,11 +8,11 @@ import numpy as np -from ...utils._types import _as_array, _is_array_like -from ...utils.array import (_unique, - _spikes_in_clusters, - _spikes_per_cluster, - ) +from phy.utils._types import _as_array, _is_array_like +from phy.io.array import (_unique, + _spikes_in_clusters, + _spikes_per_cluster, + ) from ._utils import UpdateInfo from ._history import History from phy.utils.event import EventEmitter diff --git a/phy/cluster/manual/tests/test_clustering.py b/phy/cluster/manual/tests/test_clustering.py index 0502a8de6..091a3b8b7 100644 --- a/phy/cluster/manual/tests/test_clustering.py +++ b/phy/cluster/manual/tests/test_clustering.py @@ -11,8 +11,8 @@ from pytest import raises from six import itervalues -from ....io.mock import artificial_spike_clusters -from ....utils.array import (_spikes_in_clusters,) +from phy.io.mock import artificial_spike_clusters +from phy.io.array import (_spikes_in_clusters,) from ..clustering import (_extend_spikes, _concatenate_spike_clusters, _extend_assignment, diff --git a/phy/utils/array.py b/phy/io/array.py similarity index 100% rename from phy/utils/array.py rename to phy/io/array.py diff --git a/phy/utils/context.py b/phy/io/context.py similarity index 99% rename from phy/utils/context.py rename to phy/io/context.py index 074bdc550..9fe371f87 100644 --- a/phy/utils/context.py +++ b/phy/io/context.py @@ -21,8 +21,8 @@ raise Exception("dask is not installed. " "Install it with `conda install dask`.") +from .array import read_array, write_array from phy.utils import Bunch -from phy.utils.array import read_array, write_array logger = logging.getLogger(__name__) diff --git a/phy/utils/datasets.py b/phy/io/datasets.py similarity index 100% rename from phy/utils/datasets.py rename to phy/io/datasets.py diff --git a/phy/utils/tests/test_array.py b/phy/io/tests/test_array.py similarity index 98% rename from phy/utils/tests/test_array.py rename to phy/io/tests/test_array.py index 1756a854e..52debe21a 100644 --- a/phy/utils/tests/test_array.py +++ b/phy/io/tests/test_array.py @@ -11,7 +11,6 @@ import numpy as np from pytest import raises -from .._types import _as_array from ..array import (_unique, _normalize, _index_of, @@ -30,8 +29,9 @@ read_array, write_array, ) -from ..testing import _assert_equal as ae -from ...io.mock import artificial_spike_clusters +from phy.utils._types import _as_array +from phy.utils.testing import _assert_equal as ae +from ..mock import artificial_spike_clusters #------------------------------------------------------------------------------ diff --git a/phy/utils/tests/test_context.py b/phy/io/tests/test_context.py similarity index 100% rename from phy/utils/tests/test_context.py rename to phy/io/tests/test_context.py diff --git a/phy/utils/tests/test_datasets.py b/phy/io/tests/test_datasets.py similarity index 99% rename from phy/utils/tests/test_datasets.py rename to phy/io/tests/test_datasets.py index 1c309f68e..aafc4d102 100644 --- a/phy/utils/tests/test_datasets.py +++ b/phy/io/tests/test_datasets.py @@ -22,7 +22,7 @@ _BASE_URL, _validate_output_dir, ) -from ..testing import captured_logging +from phy.utils.testing import captured_logging logger = logging.getLogger(__name__) diff --git a/phy/plot/_vispy_utils.py b/phy/plot/_vispy_utils.py index fab6230a0..4e193189c 100644 --- a/phy/plot/_vispy_utils.py +++ b/phy/plot/_vispy_utils.py @@ -16,8 +16,8 @@ from vispy import app, gloo, config from vispy.util.event import Event -from ..utils._types import _as_array, _as_list -from ..utils.array import _unique, _in_polygon +from phy.utils._types import _as_array, _as_list +from phy.io.array import _unique, _in_polygon from ._panzoom import PanZoom logger = logging.getLogger(__name__) diff --git a/phy/plot/features.py b/phy/plot/features.py index 34c76b1a7..03709cb33 100644 --- a/phy/plot/features.py +++ b/phy/plot/features.py @@ -20,9 +20,9 @@ _wrap_vispy, ) from ._panzoom import PanZoomGrid -from ..utils._types import _as_array, _is_integer -from ..utils.array import _index_of, _unique -from ..utils._color import _selected_clusters_colors +from phy.utils._types import _as_array, _is_integer +from phy.io.array import _index_of, _unique +from phy.utils._color import _selected_clusters_colors #------------------------------------------------------------------------------ diff --git a/phy/plot/traces.py b/phy/plot/traces.py index 889d141a4..80b689529 100644 --- a/phy/plot/traces.py +++ b/phy/plot/traces.py @@ -15,9 +15,9 @@ BaseSpikeCanvas, _wrap_vispy, ) -from ..utils._color import _selected_clusters_colors -from ..utils._types import _as_array -from ..utils.array import _index_of, _unique +from phy.utils._color import _selected_clusters_colors +from phy.utils._types import _as_array +from phy.io.array import _index_of, _unique #------------------------------------------------------------------------------ diff --git a/phy/plot/waveforms.py b/phy/plot/waveforms.py index de9b5665a..5e72eec78 100644 --- a/phy/plot/waveforms.py +++ b/phy/plot/waveforms.py @@ -18,10 +18,10 @@ _enable_depth_mask, _wrap_vispy, ) -from ..utils._types import _as_array -from ..utils._color import _selected_clusters_colors -from ..utils.array import _index_of, _normalize, _unique -from ..electrode.mea import linear_positions +from phy.utils._types import _as_array +from phy.utils._color import _selected_clusters_colors +from phy.io.array import _index_of, _normalize, _unique +from phy.electrode.mea import linear_positions #------------------------------------------------------------------------------ diff --git a/phy/stats/ccg.py b/phy/stats/ccg.py index 28baed6df..448f862af 100644 --- a/phy/stats/ccg.py +++ b/phy/stats/ccg.py @@ -8,8 +8,8 @@ import numpy as np -from ..utils._types import _as_array -from ..utils.array import _index_of, _unique +from phy.utils._types import _as_array +from phy.io.array import _index_of, _unique #------------------------------------------------------------------------------ diff --git a/phy/traces/detect.py b/phy/traces/detect.py index 5f0970abd..00483316f 100644 --- a/phy/traces/detect.py +++ b/phy/traces/detect.py @@ -10,7 +10,7 @@ from six import string_types from six.moves import range, zip -from ..utils.array import _as_array +from phy.io.array import _as_array #------------------------------------------------------------------------------ diff --git a/phy/traces/spike_detect.py b/phy/traces/spike_detect.py index 361416c1a..a7f6fdca2 100644 --- a/phy/traces/spike_detect.py +++ b/phy/traces/spike_detect.py @@ -13,7 +13,7 @@ from traitlets import Int, Float, Unicode, Bool from phy.electrode.mea import MEA, _adjacency_subset, _remap_adjacency -from phy.utils.array import get_excerpts +from phy.io.array import get_excerpts from .detect import FloodFillDetector, Thresholder, compute_threshold from .filter import Filter from .waveform import WaveformExtractor diff --git a/phy/traces/tests/test_spike_detect.py b/phy/traces/tests/test_spike_detect.py index b0d42661a..1b9ec92fe 100644 --- a/phy/traces/tests/test_spike_detect.py +++ b/phy/traces/tests/test_spike_detect.py @@ -10,9 +10,9 @@ from numpy.testing import assert_array_equal as ae from pytest import yield_fixture -from phy.utils.datasets import download_test_data -from phy.utils.tests.test_context import (ipy_client, context, # noqa - parallel_context) +from phy.io.datasets import download_test_data +from phy.io.tests.test_context import (ipy_client, context, # noqa + parallel_context) from phy.electrode import load_probe from ..spike_detect import (SpikeDetector, _spikes_to_keep, diff --git a/phy/traces/waveform.py b/phy/traces/waveform.py index 537ad4e4e..6125fb82c 100644 --- a/phy/traces/waveform.py +++ b/phy/traces/waveform.py @@ -12,7 +12,7 @@ from scipy.interpolate import interp1d from ..utils._types import _as_array, Bunch -from ..utils.array import _pad +from phy.io.array import _pad logger = logging.getLogger(__name__) diff --git a/phy/utils/__init__.py b/phy/utils/__init__.py index 6ea22eb9d..f68479984 100644 --- a/phy/utils/__init__.py +++ b/phy/utils/__init__.py @@ -5,6 +5,5 @@ from ._types import (_is_array_like, _as_array, _as_tuple, _as_list, Bunch, _is_list) -from .datasets import download_file, download_sample_data from .event import EventEmitter, ProgressReporter from .settings import Settings, _ensure_dir_exists diff --git a/pytest.ini b/pytest.ini deleted file mode 100644 index cf6b0c3e4..000000000 --- a/pytest.ini +++ /dev/null @@ -1,2 +0,0 @@ -[pytest] -norecursedirs = experimental diff --git a/setup.cfg b/setup.cfg index 42eef79ea..755cdb951 100644 --- a/setup.cfg +++ b/setup.cfg @@ -3,6 +3,7 @@ universal = 1 [pytest] addopts = --cov-report term-missing --cov phy -s +norecursedirs = plot experimental [flake8] ignore=E265 From 690ff0a5dac0fe2f30bf5ebaaa6b4457ff2f7b66 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 1 Oct 2015 13:56:51 +0200 Subject: [PATCH 0173/1059] Save dask array chunk-by-chunk in a .npy file --- phy/io/array.py | 16 ++++++++++++++-- phy/io/tests/test_array.py | 16 ++++++++++++++++ 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/phy/io/array.py b/phy/io/array.py index 7eb0a11f4..e404cfcc0 100644 --- a/phy/io/array.py +++ b/phy/io/array.py @@ -171,11 +171,11 @@ def _in_polygon(points, polygon): # I/O functions # ----------------------------------------------------------------------------- -def read_array(path): +def read_array(path, mmap_mode=None): """Read a .npy array.""" file_ext = op.splitext(path)[1] if file_ext == '.npy': - return np.load(path) + return np.load(path, mmap_mode=mmap_mode) raise NotImplementedError("The file extension `{}` ".format(file_ext) + "is not currently supported.") @@ -184,6 +184,18 @@ def write_array(path, arr): """Write an array to a .npy file.""" file_ext = op.splitext(path)[1] if file_ext == '.npy': + try: + # Save a dask array into a .npy file chunk-by-chunk. + from dask.array import Array, store + if isinstance(arr, Array): + f = np.memmap(path, mode='w+', + dtype=arr.dtype, shape=arr.shape) + store(arr, f) + del f + except ImportError: # pragma: no cover + # We'll save the dask array normally: it works but it is less + # efficient since we need to load everything in memory. + pass return np.save(path, arr) raise NotImplementedError("The file extension `{}` ".format(file_ext) + "is not currently supported.") diff --git a/phy/io/tests/test_array.py b/phy/io/tests/test_array.py index 52debe21a..634c31720 100644 --- a/phy/io/tests/test_array.py +++ b/phy/io/tests/test_array.py @@ -187,9 +187,25 @@ def test_in_polygon(): def test_read_write(tempdir): arr = np.arange(10).astype(np.float32) + path = op.join(tempdir, 'test.npy') + write_array(path, arr) ae(read_array(path), arr) + ae(read_array(path, mmap_mode='r'), arr) + + +def test_read_write_dask(tempdir): + from dask.array import from_array + arr = np.arange(10).astype(np.float32) + + arr_da = from_array(arr, ((5, 5),)) + + path = op.join(tempdir, 'test.npy') + + write_array(path, arr_da) + ae(read_array(path), arr) + ae(read_array(path, mmap_mode='r'), arr) #------------------------------------------------------------------------------ From 44f718249fe0ab751e8a20fb231080d45e8721d1 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 1 Oct 2015 13:56:59 +0200 Subject: [PATCH 0174/1059] Tweak setup.cfg --- setup.cfg | 1 - 1 file changed, 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 755cdb951..a0583bf00 100644 --- a/setup.cfg +++ b/setup.cfg @@ -2,7 +2,6 @@ universal = 1 [pytest] -addopts = --cov-report term-missing --cov phy -s norecursedirs = plot experimental [flake8] From b2d49fb5b501edec46d222c9f0ee213d26c56486 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 1 Oct 2015 14:16:52 +0200 Subject: [PATCH 0175/1059] Add PCA test --- phy/traces/pca.py | 2 +- phy/traces/tests/test_pca.py | 12 +++++++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/phy/traces/pca.py b/phy/traces/pca.py index f7e962dd5..103c89550 100644 --- a/phy/traces/pca.py +++ b/phy/traces/pca.py @@ -87,8 +87,8 @@ def _project_pcs(x, pcs): pcs : array The PCs returned by `_compute_pcs()`. """ - # pcs: (nf, ns, nc) # x: (n, ns, nc) + # pcs: (nf, ns, nc) # out: (n, nc, nf) assert pcs.ndim == 3 assert x.ndim == 3 diff --git a/phy/traces/tests/test_pca.py b/phy/traces/tests/test_pca.py index 54788c91d..1a5f10f13 100644 --- a/phy/traces/tests/test_pca.py +++ b/phy/traces/tests/test_pca.py @@ -9,7 +9,7 @@ import numpy as np from ...io.mock import artificial_waveforms, artificial_masks -from ..pca import PCA, _compute_pcs +from ..pca import PCA, _compute_pcs, _project_pcs #------------------------------------------------------------------------------ @@ -56,3 +56,13 @@ def test_compute_pcs_3d(): assert np.linalg.norm(pcs[1, :, 0] - np.array([0., -1.])) < 1e-2 assert np.linalg.norm(pcs[0, :, 1] - np.array([0, 1.])) < 1e-2 assert np.linalg.norm(pcs[1, :, 1] - np.array([-1., 0.])) < 1e-2 + + +def test_project_pcs(): + n, ns, nc = 1000, 50, 100 + nf = 3 + arr = np.random.randn(n, ns, nc) + pcs = np.random.randn(nf, ns, nc) + + y1 = _project_pcs(arr, pcs) + assert y1.shape == (n, nc, nf) From 782cdf12d3f92605cc57cdaaba7fe84cc82d27b5 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 1 Oct 2015 14:22:59 +0200 Subject: [PATCH 0176/1059] WIP: parallelize PCA --- phy/traces/pca.py | 24 +++++++++++++++++------- phy/traces/spike_detect.py | 6 ------ phy/traces/tests/test_pca.py | 2 +- 3 files changed, 18 insertions(+), 14 deletions(-) diff --git a/phy/traces/pca.py b/phy/traces/pca.py index 103c89550..126009652 100644 --- a/phy/traces/pca.py +++ b/phy/traces/pca.py @@ -7,6 +7,8 @@ #------------------------------------------------------------------------------ import numpy as np +from traitlets.config.configurable import Configurable +from traitlets import Int from ..utils._types import _as_array @@ -102,12 +104,18 @@ def _project_pcs(x, pcs): return x_proj -class PCA(object): +class PCA(Configurable): """Apply PCA to waveforms.""" - def __init__(self, n_pcs=None): - self._n_pcs = n_pcs + n_features_per_channel = Int(3) + + def __init__(self, ctx=None): + super(PCA, self).__init__() + self.set_context(ctx) self._pcs = None + def set_context(self, ctx): + self.ctx = ctx + def fit(self, waveforms, masks=None): """Compute the PCs of waveforms. @@ -120,7 +128,10 @@ def fit(self, waveforms, masks=None): Shape: `(n_spikes, n_channels)` """ - self._pcs = _compute_pcs(waveforms, n_pcs=self._n_pcs, masks=masks) + self._pcs = _compute_pcs(waveforms, + n_pcs=self.n_features_per_channel, + masks=masks, + ) return self._pcs def transform(self, waveforms, pcs=None): @@ -135,6 +146,5 @@ def transform(self, waveforms, pcs=None): """ if pcs is None: pcs = self._pcs - # Need to call fit() if the pcs are None here. - if pcs is not None: - return _project_pcs(waveforms, pcs) + assert pcs is not None + return _project_pcs(waveforms, pcs) diff --git a/phy/traces/spike_detect.py b/phy/traces/spike_detect.py index a7f6fdca2..64d33ec6f 100644 --- a/phy/traces/spike_detect.py +++ b/phy/traces/spike_detect.py @@ -101,12 +101,6 @@ def __init__(self, ctx=None): def set_context(self, ctx): self.ctx = ctx - if not ctx or not hasattr(ctx, 'cache'): - return - # self.find_thresholds = ctx.cache(self.find_thresholds) - # self.filter = ctx.cache(self.filter) - # self.extract_spikes = ctx.cache(self.extract_spikes) - # self.detect = ctx.cache(self.detect) def set_metadata(self, probe, channel_mapping=None, sample_rate=None): diff --git a/phy/traces/tests/test_pca.py b/phy/traces/tests/test_pca.py index 1a5f10f13..237a63208 100644 --- a/phy/traces/tests/test_pca.py +++ b/phy/traces/tests/test_pca.py @@ -23,7 +23,7 @@ def test_pca(): waveforms = artificial_waveforms(n_spikes, n_samples, n_channels) masks = artificial_masks(n_spikes, n_channels) - pca = PCA(n_pcs=3) + pca = PCA() pcs = pca.fit(waveforms, masks) assert pcs.shape == (3, n_samples, n_channels) fet = pca.transform(waveforms) From 8d327dcd3f9ab59b2149f53585a6d90c804a6a03 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 1 Oct 2015 14:40:27 +0200 Subject: [PATCH 0177/1059] Remove default settings --- phy/cluster/algorithms/default_settings.py | 12 -- phy/cluster/manual/default_settings.py | 132 --------------------- phy/utils/settings.py | 26 +--- phy/utils/tests/test_settings.py | 8 -- 4 files changed, 2 insertions(+), 176 deletions(-) delete mode 100644 phy/cluster/algorithms/default_settings.py delete mode 100644 phy/cluster/manual/default_settings.py diff --git a/phy/cluster/algorithms/default_settings.py b/phy/cluster/algorithms/default_settings.py deleted file mode 100644 index c65f06ee0..000000000 --- a/phy/cluster/algorithms/default_settings.py +++ /dev/null @@ -1,12 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Default settings for clustering.""" - - -# ----------------------------------------------------------------------------- -# Clustering -# ----------------------------------------------------------------------------- - -# NOTE: the default parameters are in klustakwik2's repository. -klustakwik2 = { -} diff --git a/phy/cluster/manual/default_settings.py b/phy/cluster/manual/default_settings.py deleted file mode 100644 index 42cd55add..000000000 --- a/phy/cluster/manual/default_settings.py +++ /dev/null @@ -1,132 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Default settings for manual sorting.""" - - -# ----------------------------------------------------------------------------- -# Correlograms -# ----------------------------------------------------------------------------- - -# Number of time samples in a bin. -correlograms_binsize = 20 - -# Number of bins (odd number). -correlograms_winsize_bins = 2 * 25 + 1 - -# Maximum number of spikes for the correlograms. -# Use `None` to specify an infinite value. -correlograms_n_spikes_max = 1000000 - -# Contiguous chunks of spikes for computing the CCGs. -# Use `None` to have a regular (strided) subselection instead of a chunked -# subselection. -correlograms_excerpt_size = 100000 - - -# ----------------------------------------------------------------------------- -# Views -# ----------------------------------------------------------------------------- - -# Maximum number of spikes to display in the waveform view. -waveforms_n_spikes_max = 100 - -# Load regularly-spaced waveforms. -waveforms_excerpt_size = None - -# Maximum number of spikes to display in the feature view. -features_n_spikes_max = 2500 - -# Load a regular subselection of spikes from the cluster store. -features_excerpt_size = None - -# Maximum number of background spikes to display in the feature view. -features_n_spikes_max_bg = features_n_spikes_max - -features_grid_n_spikes_max = features_n_spikes_max -features_grid_excerpt_size = features_excerpt_size -features_grid_n_spikes_max_bg = features_n_spikes_max_bg - - -# ----------------------------------------------------------------------------- -# Clustering GUI -# ----------------------------------------------------------------------------- - -cluster_manual_shortcuts = { - 'reset_gui': 'alt+r', - 'show_shortcuts': 'ctrl+h', - 'save': 'ctrl+s', - 'exit': 'ctrl+q', - # Wizard actions. - 'reset_wizard': 'ctrl+w', - 'next': 'space', - 'previous': 'shift+space', - 'reset_wizard': 'ctrl+alt+space', - 'first': 'home', - 'last': 'end', - 'pin': 'return', - 'unpin': 'backspace', - # Clustering actions. - 'merge': 'g', - 'split': 'k', - 'undo': 'ctrl+z', - 'redo': ('ctrl+shift+z', 'ctrl+y'), - 'move_best_to_noise': 'alt+n', - 'move_best_to_mua': 'alt+m', - 'move_best_to_good': 'alt+g', - 'move_match_to_noise': 'ctrl+n', - 'move_match_to_mua': 'ctrl+m', - 'move_match_to_good': 'ctrl+g', - 'move_both_to_noise': 'ctrl+alt+n', - 'move_both_to_mua': 'ctrl+alt+m', - 'move_both_to_good': 'ctrl+alt+g', - # Views. - 'show_view_shortcuts': 'h', - 'toggle_correlogram_normalization': 'n', - 'toggle_waveforms_overlap': 'o', - 'toggle_waveforms_mean': 'm', - 'show_features_time': 't', -} - - -cluster_manual_config = [ - # The wizard panel is less useful now that there's the stats panel. - # ('wizard', {'position': 'right'}), - ('stats', {'position': 'right'}), - ('features_grid', {'position': 'left'}), - ('features', {'position': 'left'}), - ('correlograms', {'position': 'left'}), - ('waveforms', {'position': 'right'}), - ('traces', {'position': 'right'}), -] - - -def _select_clusters(gui, args): - # Range: '5-12' - if '-' in args: - m, M = map(int, args.split('-')) - # The second one should be included. - M += 1 - clusters = list(range(m, M)) - # List of ids: '5 6 9 12' - else: - clusters = list(map(int, args.split(' '))) - gui.select(clusters) - - -cluster_manual_snippets = { - 'c': _select_clusters, -} - - -# Whether to ask the user if they want to save when the GUI is closed. -prompt_save_on_exit = True - - -# ----------------------------------------------------------------------------- -# Internal settings -# ----------------------------------------------------------------------------- - -waveforms_scale_factor = .01 -features_scale_factor = .01 -features_grid_scale_factor = features_scale_factor -traces_scale_factor = .01 diff --git a/phy/utils/settings.py b/phy/utils/settings.py index 3aa24708c..c7e743cd6 100644 --- a/phy/utils/settings.py +++ b/phy/utils/settings.py @@ -44,22 +44,6 @@ def _recursive_dirs(): yield op.realpath(op.join(phy_root, root)) -def _default_settings_paths(): - return [op.join(dir, 'default_settings.py') - for dir in _recursive_dirs()] - - -def _load_default_settings(paths=None): - """Load all default settings in phy's package.""" - if paths is None: - paths = _default_settings_paths() - settings = BaseSettings() - for path in paths: - if op.exists(path): - settings.load(path) - return settings - - class BaseSettings(object): """Store key-value pairs.""" def __init__(self): @@ -127,22 +111,16 @@ def save(self, path): class Settings(object): - """Manage default, user-wide, and experiment-wide settings.""" + """Manage user-wide, and experiment-wide settings.""" - def __init__(self, phy_user_dir=None, default_paths=None): + def __init__(self, phy_user_dir=None): self.phy_user_dir = phy_user_dir if self.phy_user_dir: _ensure_dir_exists(self.phy_user_dir) - self._default_paths = default_paths or _default_settings_paths() self._bs = BaseSettings() self._load_user_settings() def _load_user_settings(self): - # Load phy's defaults. - if self._default_paths: - for path in self._default_paths: - if op.exists(path): - self._bs.load(path) if not self.phy_user_dir: return diff --git a/phy/utils/tests/test_settings.py b/phy/utils/tests/test_settings.py index 846a0df36..69b517fb3 100644 --- a/phy/utils/tests/test_settings.py +++ b/phy/utils/tests/test_settings.py @@ -12,7 +12,6 @@ from ..settings import (BaseSettings, Settings, - _load_default_settings, _recursive_dirs, _phy_user_dir, ) @@ -48,12 +47,6 @@ def test_recursive_dirs(): assert '_' not in dir -def test_load_default_settings(): - settings = _load_default_settings() - keys = settings.keys() - assert keys - - def test_base_settings(): s = BaseSettings() @@ -191,5 +184,4 @@ def test_settings_manager(tempdir, tempdir_bis): assert sm['c'] == 50 assert 'a' not in sm - assert len(sm.keys()) >= 10 assert str(sm).startswith(' Date: Thu, 1 Oct 2015 14:43:16 +0200 Subject: [PATCH 0178/1059] Create Task class --- phy/io/context.py | 14 ++++++++++++++ phy/traces/pca.py | 15 +++------------ phy/traces/spike_detect.py | 11 ++--------- 3 files changed, 19 insertions(+), 21 deletions(-) diff --git a/phy/io/context.py b/phy/io/context.py index 9fe371f87..e0403baa7 100644 --- a/phy/io/context.py +++ b/phy/io/context.py @@ -10,6 +10,7 @@ import os import os.path as op +from traitlets.config.configurable import Configurable import numpy as np from six.moves.cPickle import dump from six import string_types @@ -263,3 +264,16 @@ def __setstate__(self, state): self.__dict__ = state # Recreate the joblib Memory instance. self._set_memory(state['cache_dir']) + + +#------------------------------------------------------------------------------ +# Task +#------------------------------------------------------------------------------ + +class Task(Configurable): + def __init__(self, ctx=None): + super(Task, self).__init__() + self.set_context(ctx) + + def set_context(self, ctx): + self.ctx = ctx diff --git a/phy/traces/pca.py b/phy/traces/pca.py index 126009652..e2531dfc0 100644 --- a/phy/traces/pca.py +++ b/phy/traces/pca.py @@ -7,9 +7,9 @@ #------------------------------------------------------------------------------ import numpy as np -from traitlets.config.configurable import Configurable from traitlets import Int +from phy.io.context import Task from ..utils._types import _as_array @@ -104,18 +104,10 @@ def _project_pcs(x, pcs): return x_proj -class PCA(Configurable): +class PCA(Task): """Apply PCA to waveforms.""" n_features_per_channel = Int(3) - def __init__(self, ctx=None): - super(PCA, self).__init__() - self.set_context(ctx) - self._pcs = None - - def set_context(self, ctx): - self.ctx = ctx - def fit(self, waveforms, masks=None): """Compute the PCs of waveforms. @@ -144,7 +136,6 @@ def transform(self, waveforms, pcs=None): Shape: `(n_spikes, n_samples, n_channels)` """ - if pcs is None: - pcs = self._pcs + pcs = self._pcs if pcs is None else pcs assert pcs is not None return _project_pcs(waveforms, pcs) diff --git a/phy/traces/spike_detect.py b/phy/traces/spike_detect.py index 64d33ec6f..56f90b527 100644 --- a/phy/traces/spike_detect.py +++ b/phy/traces/spike_detect.py @@ -9,11 +9,11 @@ import logging import numpy as np -from traitlets.config.configurable import Configurable from traitlets import Int, Float, Unicode, Bool from phy.electrode.mea import MEA, _adjacency_subset, _remap_adjacency from phy.io.array import get_excerpts +from phy.io.context import Task from .detect import FloodFillDetector, Thresholder, compute_threshold from .filter import Filter from .waveform import WaveformExtractor @@ -78,7 +78,7 @@ def _concat_spikes(s, m, w, trace_chunks=None, depth=None): # SpikeDetector #------------------------------------------------------------------------------ -class SpikeDetector(Configurable): +class SpikeDetector(Task): do_filter = Bool(True) filter_low = Float(500.) filter_butter_order = Int(3) @@ -95,13 +95,6 @@ class SpikeDetector(Configurable): extract_s_after = Int(10) weight_power = Float(2) - def __init__(self, ctx=None): - super(SpikeDetector, self).__init__() - self.set_context(ctx) - - def set_context(self, ctx): - self.ctx = ctx - def set_metadata(self, probe, channel_mapping=None, sample_rate=None): assert isinstance(probe, MEA) From 3668bee79f730dd703e9890114c21bb4351b191c Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 1 Oct 2015 14:54:32 +0200 Subject: [PATCH 0179/1059] WIP: parallelize PCA --- phy/traces/pca.py | 16 +++++++++++++--- phy/traces/tests/test_pca.py | 32 ++++++++++++++++++-------------- 2 files changed, 31 insertions(+), 17 deletions(-) diff --git a/phy/traces/pca.py b/phy/traces/pca.py index e2531dfc0..0d4e10960 100644 --- a/phy/traces/pca.py +++ b/phy/traces/pca.py @@ -126,6 +126,9 @@ def fit(self, waveforms, masks=None): ) return self._pcs + def _project(self, waveforms): + return _project_pcs(waveforms, self._pcs) + def transform(self, waveforms, pcs=None): """Project waveforms on the PCs. @@ -136,6 +139,13 @@ def transform(self, waveforms, pcs=None): Shape: `(n_spikes, n_samples, n_channels)` """ - pcs = self._pcs if pcs is None else pcs - assert pcs is not None - return _project_pcs(waveforms, pcs) + self._pcs = self._pcs if pcs is None else pcs + assert self._pcs is not None + if not self.ctx: + return self._project(waveforms) + else: + import dask.array as da + assert isinstance(waveforms, da.Array) + + return self.ctx.map_dask_array(self._project, waveforms, + name='features') diff --git a/phy/traces/tests/test_pca.py b/phy/traces/tests/test_pca.py index 237a63208..5918b6916 100644 --- a/phy/traces/tests/test_pca.py +++ b/phy/traces/tests/test_pca.py @@ -16,20 +16,6 @@ # Test PCA #------------------------------------------------------------------------------ -def test_pca(): - n_spikes = 100 - n_samples = 40 - n_channels = 12 - waveforms = artificial_waveforms(n_spikes, n_samples, n_channels) - masks = artificial_masks(n_spikes, n_channels) - - pca = PCA() - pcs = pca.fit(waveforms, masks) - assert pcs.shape == (3, n_samples, n_channels) - fet = pca.transform(waveforms) - assert fet.shape == (n_spikes, n_channels, 3) - - def test_compute_pcs(): """Test PCA on a 2D array.""" # Horizontal ellipsoid. @@ -66,3 +52,21 @@ def test_project_pcs(): y1 = _project_pcs(arr, pcs) assert y1.shape == (n, nc, nf) + + +class PCATest(object): + def setup(self): + self.n_spikes = 100 + self.n_samples = 40 + self.n_channels = 12 + self.waveforms = artificial_waveforms(self.n_spikes, + self.n_samples, + self.n_channels) + self.masks = artificial_masks(self.n_spikes, self.n_channels) + + def test_serial(self): + pca = PCA() + pcs = pca.fit(self.waveforms, self.masks) + assert pcs.shape == (3, self.n_samples, self.n_channels) + fet = pca.transform(self.waveforms) + assert fet.shape == (self.n_spikes, self.n_channels, 3) From 308ae66df84fc4d1226cf66e1b062eb14197e2b1 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 1 Oct 2015 15:00:27 +0200 Subject: [PATCH 0180/1059] WIP: parallelize PCA --- phy/traces/tests/test_pca.py | 29 ++++++++++++++++++++++++++--- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/phy/traces/tests/test_pca.py b/phy/traces/tests/test_pca.py index 5918b6916..041efaa54 100644 --- a/phy/traces/tests/test_pca.py +++ b/phy/traces/tests/test_pca.py @@ -7,7 +7,10 @@ #------------------------------------------------------------------------------ import numpy as np +from numpy.testing import assert_array_equal as ae +from phy.io.tests.test_context import (ipy_client, context, # noqa + parallel_context) from ...io.mock import artificial_waveforms, artificial_masks from ..pca import PCA, _compute_pcs, _project_pcs @@ -54,7 +57,7 @@ def test_project_pcs(): assert y1.shape == (n, nc, nf) -class PCATest(object): +class TestPCA(object): def setup(self): self.n_spikes = 100 self.n_samples = 40 @@ -64,9 +67,29 @@ def setup(self): self.n_channels) self.masks = artificial_masks(self.n_spikes, self.n_channels) - def test_serial(self): + def _get_features(self): pca = PCA() pcs = pca.fit(self.waveforms, self.masks) assert pcs.shape == (3, self.n_samples, self.n_channels) - fet = pca.transform(self.waveforms) + return pca.transform(self.waveforms) + + def test_serial(self): + fet = self._get_features() assert fet.shape == (self.n_spikes, self.n_channels, 3) + + def test_parallel(self, parallel_context): + + # Chunk the waveforms array. + from dask.array import from_array + chunks = (10, self.n_samples, self.n_channels) + waveforms = from_array(self.waveforms, chunks) + + # Compute the PCs in parallel. + pca = PCA(parallel_context) + pcs = pca.fit(waveforms, self.masks) + assert pcs.shape == (3, self.n_samples, self.n_channels) + fet = pca.transform(waveforms) + assert fet.shape == (self.n_spikes, self.n_channels, 3) + + # Check that the computed features are identical. + ae(fet, self._get_features()) From 5a09122b7cd67c01e433a26e46c90f34890aaa73 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 1 Oct 2015 15:03:05 +0200 Subject: [PATCH 0181/1059] Flakify and increase coverage --- phy/traces/tests/test_pca.py | 2 +- phy/utils/tests/test_settings.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/phy/traces/tests/test_pca.py b/phy/traces/tests/test_pca.py index 041efaa54..ea9c71ce4 100644 --- a/phy/traces/tests/test_pca.py +++ b/phy/traces/tests/test_pca.py @@ -77,7 +77,7 @@ def test_serial(self): fet = self._get_features() assert fet.shape == (self.n_spikes, self.n_channels, 3) - def test_parallel(self, parallel_context): + def test_parallel(self, parallel_context): # noqa # Chunk the waveforms array. from dask.array import from_array diff --git a/phy/utils/tests/test_settings.py b/phy/utils/tests/test_settings.py index 69b517fb3..37b49df7f 100644 --- a/phy/utils/tests/test_settings.py +++ b/phy/utils/tests/test_settings.py @@ -185,3 +185,4 @@ def test_settings_manager(tempdir, tempdir_bis): assert 'a' not in sm assert str(sm).startswith(' Date: Thu, 1 Oct 2015 16:38:02 +0200 Subject: [PATCH 0182/1059] Remove dead channels --- phy/traces/spike_detect.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/phy/traces/spike_detect.py b/phy/traces/spike_detect.py index 56f90b527..0ce1f20db 100644 --- a/phy/traces/spike_detect.py +++ b/phy/traces/spike_detect.py @@ -106,6 +106,9 @@ def set_metadata(self, probe, channel_mapping=None, # Channel mapping. if channel_mapping is None: channel_mapping = {c: c for c in probe.channels} + # Remove channels mapped to None or a negative value: they are dead. + channel_mapping = {k: v for (k, v) in channel_mapping.items() + if v is not None and v >= 0} # channel mappings is {trace_col: channel_id}. # Trace columns and channel ids to keep. self.trace_cols = sorted(channel_mapping.keys()) From 8ee0c3ba8d3bb120a77d7532afc64c30b1762e35 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 1 Oct 2015 20:56:41 +0200 Subject: [PATCH 0183/1059] Add exception in GUI snippet --- phy/gui/gui.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/phy/gui/gui.py b/phy/gui/gui.py index fd12ec6d0..94fd02527 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -309,7 +309,11 @@ def run(self, snippet): snippet = snippet[1:] snippet_args = _parse_snippet(snippet) alias = snippet_args[0] - name = self._actions.get_name(alias) + try: + name = self._actions.get_name(alias) + except ValueError: + logger.warn("The action %s could not be found.", alias) + return assert name func = getattr(self._actions, name) try: From 8bfbf5fcdb12340d2682614e1e40c80f5dbedb26 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 1 Oct 2015 21:50:06 +0200 Subject: [PATCH 0184/1059] Skip some Qt tests on OS X --- phy/gui/tests/test_gui.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/phy/gui/tests/test_gui.py b/phy/gui/tests/test_gui.py index 7e7ca3262..f2afd788e 100644 --- a/phy/gui/tests/test_gui.py +++ b/phy/gui/tests/test_gui.py @@ -6,6 +6,8 @@ # Imports #------------------------------------------------------------------------------ +from sys import platform + from pytest import mark, raises, yield_fixture from ..qt import Qt @@ -17,6 +19,12 @@ # Skip these tests in "make test-quick". pytestmark = mark.long +# Skip some tests on OS X. +skip_mac = mark.skipif(platform == "darwin", + reason="Some tests don't work on OS X because of a bug " + "with QTest (qtbot) keyboard events that don't " + "trigger QAction shortcuts.") + #------------------------------------------------------------------------------ # Utilities and fixtures @@ -102,6 +110,7 @@ def show_my_shortcuts(): actions.remove_all() +@skip_mac def test_actions_dock(qtbot, gui, actions): actions.attach(gui) @@ -176,8 +185,7 @@ def test(arg): snippets.attach(None, actions) actions.reset() - with raises(ValueError): - snippets.run(':t1') + snippets.run(':t1') with captured_logging() as buf: snippets.run(':t') @@ -246,6 +254,7 @@ def _run(cmd): snippets.mode_off() +@skip_mac def test_snippets_dock(qtbot, gui, actions, snippets): qtbot.addWidget(gui) From 67fafe61c1f7a5824fd96148dbe13c623ca87321 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 1 Oct 2015 21:58:59 +0200 Subject: [PATCH 0185/1059] Add nose dependency for ipyparallel tests --- requirements-dev.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements-dev.txt b/requirements-dev.txt index d616703b5..36ea1da9f 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -8,3 +8,4 @@ coveralls responses pytest-cov pytest-qt +nose From 5ab762c065bf80430c2cb2a6ddd1e578cea667fa Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 1 Oct 2015 22:07:45 +0200 Subject: [PATCH 0186/1059] Tweak git version test --- phy/utils/tests/test_misc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/phy/utils/tests/test_misc.py b/phy/utils/tests/test_misc.py index 72cdec267..fdde63296 100644 --- a/phy/utils/tests/test_misc.py +++ b/phy/utils/tests/test_misc.py @@ -87,6 +87,6 @@ def test_git_version(): subprocess.check_output(['git', '-C', filedir, 'status'], stderr=fnull) assert v is not "", "git_version failed to return" - assert v[:6] == "-git-v", "Git version does not begin in -git-v" + assert v[:5] == "-git-", "Git version does not begin in -git-" except (OSError, subprocess.CalledProcessError): # pragma: no cover assert v == "" From bc4e382ba485f09b498dd18602fd2147fee440ed Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 1 Oct 2015 22:14:24 +0200 Subject: [PATCH 0187/1059] WIP: try some Qt tests on Travis --- phy/gui/tests/test_gui.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/phy/gui/tests/test_gui.py b/phy/gui/tests/test_gui.py index f2afd788e..be780bc44 100644 --- a/phy/gui/tests/test_gui.py +++ b/phy/gui/tests/test_gui.py @@ -6,6 +6,7 @@ # Imports #------------------------------------------------------------------------------ +import os from sys import platform from pytest import mark, raises, yield_fixture @@ -19,11 +20,12 @@ # Skip these tests in "make test-quick". pytestmark = mark.long -# Skip some tests on OS X. -skip_mac = mark.skipif(platform == "darwin", - reason="Some tests don't work on OS X because of a bug " - "with QTest (qtbot) keyboard events that don't " - "trigger QAction shortcuts.") +# Skip some tests on OS X or on CI systems (Travis). +skip = mark.skipif((platform == "darwin") or not(os.environ.get('CI', None)), + reason="Some tests don't work on OS X because of a bug " + "with QTest (qtbot) keyboard events that don't " + "trigger QAction shortcuts. On CI these tests " + "fail because the GUI is not displayed.") #------------------------------------------------------------------------------ @@ -110,7 +112,7 @@ def show_my_shortcuts(): actions.remove_all() -@skip_mac +@skip def test_actions_dock(qtbot, gui, actions): actions.attach(gui) @@ -128,10 +130,12 @@ def press(): _press.append(0) qtbot.keyPress(gui, Qt.Key_G, Qt.ControlModifier) + qtbot.waitForWindowShown(gui) assert _press == [0] # Quit the GUI. qtbot.keyPress(gui, Qt.Key_Q, Qt.ControlModifier) + qtbot.waitForWindowShown(gui) #------------------------------------------------------------------------------ @@ -254,7 +258,7 @@ def _run(cmd): snippets.mode_off() -@skip_mac +@skip def test_snippets_dock(qtbot, gui, actions, snippets): qtbot.addWidget(gui) @@ -277,11 +281,14 @@ def test(*args): # Simulate the following keystrokes `:t2 ^H^H1 3-5 ab,c ` assert not snippets.is_mode_on() qtbot.keyClicks(gui, ':t2 ') + qtbot.waitForWindowShown(gui) + assert snippets.is_mode_on() qtbot.keyPress(gui, Qt.Key_Backspace) qtbot.keyPress(gui, Qt.Key_Backspace) qtbot.keyClicks(gui, '1 3-5 ab,c') qtbot.keyPress(gui, Qt.Key_Return) + qtbot.waitForWindowShown(gui) assert _actions == [((3, 4, 5), ('ab', 'c'))] From 3ab672d476ca24a56a75707838c37a506ecaa3e7 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 1 Oct 2015 22:19:07 +0200 Subject: [PATCH 0188/1059] Disable the two failing tests on Travis --- phy/gui/tests/test_gui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/phy/gui/tests/test_gui.py b/phy/gui/tests/test_gui.py index be780bc44..45b701582 100644 --- a/phy/gui/tests/test_gui.py +++ b/phy/gui/tests/test_gui.py @@ -21,7 +21,7 @@ pytestmark = mark.long # Skip some tests on OS X or on CI systems (Travis). -skip = mark.skipif((platform == "darwin") or not(os.environ.get('CI', None)), +skip = mark.skipif((platform == "darwin") or os.environ.get('CI', None), reason="Some tests don't work on OS X because of a bug " "with QTest (qtbot) keyboard events that don't " "trigger QAction shortcuts. On CI these tests " From 04d701eb22de6622636d9d38ea2016323e93f865 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 1 Oct 2015 22:30:57 +0200 Subject: [PATCH 0189/1059] Try to downgrade coverage --- requirements-dev.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index 36ea1da9f..7b981de1e 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -3,7 +3,7 @@ git+https://github.com/pytest-dev/pytest.git flake8 -coverage +coverage=3.7.1 coveralls responses pytest-cov From bac545079ef3676ad38c73dacaae61c15f47229a Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 1 Oct 2015 22:34:46 +0200 Subject: [PATCH 0190/1059] Fix --- requirements-dev.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index 7b981de1e..06845e29a 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -3,7 +3,7 @@ git+https://github.com/pytest-dev/pytest.git flake8 -coverage=3.7.1 +coverage==3.7.1 coveralls responses pytest-cov From bad43453173b17ab00b3a880a29ad3777d531ac8 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 1 Oct 2015 22:47:32 +0200 Subject: [PATCH 0191/1059] WIP: fixing Travis --- phy/gui/tests/test_gui.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/phy/gui/tests/test_gui.py b/phy/gui/tests/test_gui.py index 45b701582..620bfd26d 100644 --- a/phy/gui/tests/test_gui.py +++ b/phy/gui/tests/test_gui.py @@ -21,11 +21,14 @@ pytestmark = mark.long # Skip some tests on OS X or on CI systems (Travis). -skip = mark.skipif((platform == "darwin") or os.environ.get('CI', None), - reason="Some tests don't work on OS X because of a bug " - "with QTest (qtbot) keyboard events that don't " - "trigger QAction shortcuts. On CI these tests " - "fail because the GUI is not displayed.") +skip_mac = mark.skipif(platform == "darwin", + reason="Some tests don't work on OS X because of a bug " + "with QTest (qtbot) keyboard events that don't " + "trigger QAction shortcuts. On CI these tests " + "fail because the GUI is not displayed.") + +skip_ci = mark.skipif(os.environ.get('CI', None) is not None, + reason="Some shortcut-related Qt tests fail on CI.") #------------------------------------------------------------------------------ @@ -112,7 +115,8 @@ def show_my_shortcuts(): actions.remove_all() -@skip +@skip_mac +@skip_ci def test_actions_dock(qtbot, gui, actions): actions.attach(gui) @@ -258,7 +262,8 @@ def _run(cmd): snippets.mode_off() -@skip +@skip_mac +@skip_ci def test_snippets_dock(qtbot, gui, actions, snippets): qtbot.addWidget(gui) From 2542ee8318b2d6f0d6b2c4c92d5446265c5b99b7 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 1 Oct 2015 23:04:26 +0200 Subject: [PATCH 0192/1059] WIP: fixing tests on Python 2 --- phy/gui/tests/test_gui.py | 4 ++-- phy/gui/tests/test_qt.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/phy/gui/tests/test_gui.py b/phy/gui/tests/test_gui.py index 620bfd26d..8e8e3ac3a 100644 --- a/phy/gui/tests/test_gui.py +++ b/phy/gui/tests/test_gui.py @@ -197,11 +197,11 @@ def test(arg): with captured_logging() as buf: snippets.run(':t') - assert 'missing 1 required positional argument' in buf.getvalue() + assert 'error' in buf.getvalue().lower() with captured_logging() as buf: snippets.run(':t 1 2') - assert 'takes 1 positional argument but 2 were given' in buf.getvalue() + assert 'error' in buf.getvalue().lower() with captured_logging() as buf: snippets.run(':t aa') diff --git a/phy/gui/tests/test_qt.py b/phy/gui/tests/test_qt.py index 9c76ecf8f..ae8deac3e 100644 --- a/phy/gui/tests/test_qt.py +++ b/phy/gui/tests/test_qt.py @@ -55,4 +55,4 @@ def test_prompt(qtbot): buttons=['save', 'cancel', 'close'], ) qtbot.mouseClick(box.buttons()[0], QtCore.Qt.LeftButton) - assert 'save' in box.clickedButton().text().lower() + assert 'save' in str(box.clickedButton().text()).lower() From 8a6d32747de64e75d9db916d2d67c8be29b8d285 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 2 Oct 2015 15:03:16 +0200 Subject: [PATCH 0193/1059] WIP: add plugin system --- phy/utils/plugin.py | 125 +++++++++++++++++++++++++++++++++ phy/utils/tests/test_plugin.py | 63 +++++++++++++++++ 2 files changed, 188 insertions(+) create mode 100644 phy/utils/plugin.py create mode 100644 phy/utils/tests/test_plugin.py diff --git a/phy/utils/plugin.py b/phy/utils/plugin.py new file mode 100644 index 000000000..494c6ceac --- /dev/null +++ b/phy/utils/plugin.py @@ -0,0 +1,125 @@ +# -*- coding: utf-8 -*- + +"""Plugin system. + +Code from http://eli.thegreenplace.net/2012/08/07/fundamental-concepts-of-plugin-infrastructures # noqa + +""" + + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +import imp +import logging +import os +import os.path as op + +logger = logging.getLogger(__name__) + + +#------------------------------------------------------------------------------ +# IPlugin interface +#------------------------------------------------------------------------------ + +class IPluginRegistry(type): + plugins = [] + + def __init__(cls, name, bases, attrs): + if name != 'IPlugin': + logger.debug("Register plugin %s.", name) + plugin_tuple = (cls, cls.file_extensions) + if plugin_tuple not in IPluginRegistry.plugins: + IPluginRegistry.plugins.append(plugin_tuple) + + +class IPlugin(object, metaclass=IPluginRegistry): + format_name = None + file_extensions = () + + def register(self, podoc): + """Called when the plugin is activated with `--plugins`.""" + raise NotImplementedError() + + def register_from(self, podoc): + """Called when the plugin is activated with `--from`.""" + raise NotImplementedError() + + def register_to(self, podoc): + """Called when the plugin is activated with `--to`.""" + raise NotImplementedError() + + +def get_plugin(name_or_ext): + """Get a plugin class from its name or file extension.""" + name_or_ext = name_or_ext.lower() + for (plugin, file_extension) in IPluginRegistry.plugins: + if (name_or_ext in plugin.__name__.lower() or + name_or_ext in file_extension): + return plugin + raise ValueError("The plugin %s cannot be found." % name_or_ext) + + +#------------------------------------------------------------------------------ +# Plugins discovery +#------------------------------------------------------------------------------ + +def discover_plugins(dirs): + """Discover the plugin classes contained in Python files. + + Parameters + ---------- + + dirs : list + List of directory names to scan. + + Returns + ------- + + plugins : list + List of plugin classes. + + """ + # Scan all subdirectories recursively. + for plugin_dir in dirs: + # logger.debug("Scanning %s", plugin_dir) + plugin_dir = op.realpath(plugin_dir) + for subdir, dirs, files in os.walk(plugin_dir): + # Skip test folders. + base = op.basename(subdir) + if 'test' in base or '__' in base: + continue + logger.debug("Scanning %s.", subdir) + for filename in files: + if (filename.startswith('__') or + not filename.endswith('.py')): + continue # pragma: no cover + logger.debug(" Found %s.", filename) + path = os.path.join(subdir, filename) + modname, ext = op.splitext(filename) + file, path, descr = imp.find_module(modname, [subdir]) + if file: + # Loading the module registers the plugin in + # IPluginRegistry + mod = imp.load_module(modname, file, path, descr) # noqa + return IPluginRegistry.plugins + + +def iter_plugins_dirs(): + """Iterate over all plugin directories.""" + curdir = op.dirname(op.realpath(__file__)) + plugins_dir = op.join(curdir, 'plugins') + # TODO: add other plugin directories (user plugins etc.) + names = [name for name in sorted(os.listdir(plugins_dir)) + if not name.startswith(('.', '_')) and + op.isdir(op.join(plugins_dir, name))] + for name in names: + yield op.join(plugins_dir, name) + + +def _load_all_native_plugins(): + """Load all native plugins when importing the library.""" + curdir = op.dirname(op.realpath(__file__)) + plugins_dir = op.join(curdir, 'plugins') + discover_plugins([plugins_dir]) diff --git a/phy/utils/tests/test_plugin.py b/phy/utils/tests/test_plugin.py new file mode 100644 index 000000000..271e90748 --- /dev/null +++ b/phy/utils/tests/test_plugin.py @@ -0,0 +1,63 @@ +# -*- coding: utf-8 -*- + +"""Test plugin system.""" + + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +import os.path as op + +from ..plugin import (IPluginRegistry, IPlugin, get_plugin, + iter_plugins_dirs, _load_all_native_plugins) + +from pytest import yield_fixture, raises + + +#------------------------------------------------------------------------------ +# Fixtures +#------------------------------------------------------------------------------ + +@yield_fixture +def no_native_plugins(): + # Save the plugins. + plugins = IPluginRegistry.plugins + IPluginRegistry.plugins = [] + yield + IPluginRegistry.plugins = plugins + + +#------------------------------------------------------------------------------ +# Tests +#------------------------------------------------------------------------------ + +def test_plugin_registration(no_native_plugins): + class MyPlugin(IPlugin): + pass + + assert IPluginRegistry.plugins == [(MyPlugin, ())] + + +def test_get_plugin(): + # assert get_plugin('jso').__name__ == 'JSON' + # assert get_plugin('JSO').__name__ == 'JSON' + # assert get_plugin('JSON').__name__ == 'JSON' + # assert get_plugin('json').__name__ == 'JSON' + # assert get_plugin('.json').__name__ == 'JSON' + + # with raises(ValueError): + # assert get_plugin('.jso') is None + # with raises(ValueError): + # assert get_plugin('jsonn') is None + pass + + +def test_iter_plugins_dirs(): + # assert 'json' in [op.basename(plugin_dir) + # for plugin_dir in iter_plugins_dirs()] + pass + + +def test_load_all_native_plugins(no_native_plugins): + _load_all_native_plugins() From d19b76ab0b1cedaf228dc083fcb400e89257aa6c Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 2 Oct 2015 15:16:42 +0200 Subject: [PATCH 0194/1059] Clean up plugin system --- phy/utils/plugin.py | 38 ++++++++-------------------------- phy/utils/tests/test_plugin.py | 34 ++++-------------------------- 2 files changed, 13 insertions(+), 59 deletions(-) diff --git a/phy/utils/plugin.py b/phy/utils/plugin.py index 494c6ceac..6cc0266a6 100644 --- a/phy/utils/plugin.py +++ b/phy/utils/plugin.py @@ -29,36 +29,23 @@ class IPluginRegistry(type): def __init__(cls, name, bases, attrs): if name != 'IPlugin': logger.debug("Register plugin %s.", name) - plugin_tuple = (cls, cls.file_extensions) + plugin_tuple = (cls,) if plugin_tuple not in IPluginRegistry.plugins: IPluginRegistry.plugins.append(plugin_tuple) class IPlugin(object, metaclass=IPluginRegistry): - format_name = None - file_extensions = () + def attach_gui(self, gui): + pass - def register(self, podoc): - """Called when the plugin is activated with `--plugins`.""" - raise NotImplementedError() - def register_from(self, podoc): - """Called when the plugin is activated with `--from`.""" - raise NotImplementedError() - - def register_to(self, podoc): - """Called when the plugin is activated with `--to`.""" - raise NotImplementedError() - - -def get_plugin(name_or_ext): - """Get a plugin class from its name or file extension.""" - name_or_ext = name_or_ext.lower() - for (plugin, file_extension) in IPluginRegistry.plugins: - if (name_or_ext in plugin.__name__.lower() or - name_or_ext in file_extension): +def get_plugin(name): + """Get a plugin class from its name.""" + name = name.lower() + for (plugin,) in IPluginRegistry.plugins: + if name in plugin.__name__.lower(): return plugin - raise ValueError("The plugin %s cannot be found." % name_or_ext) + raise ValueError("The plugin %s cannot be found." % name) #------------------------------------------------------------------------------ @@ -116,10 +103,3 @@ def iter_plugins_dirs(): op.isdir(op.join(plugins_dir, name))] for name in names: yield op.join(plugins_dir, name) - - -def _load_all_native_plugins(): - """Load all native plugins when importing the library.""" - curdir = op.dirname(op.realpath(__file__)) - plugins_dir = op.join(curdir, 'plugins') - discover_plugins([plugins_dir]) diff --git a/phy/utils/tests/test_plugin.py b/phy/utils/tests/test_plugin.py index 271e90748..c54f933f2 100644 --- a/phy/utils/tests/test_plugin.py +++ b/phy/utils/tests/test_plugin.py @@ -7,12 +7,9 @@ # Imports #------------------------------------------------------------------------------ -import os.path as op +from ..plugin import (IPluginRegistry, IPlugin, get_plugin) -from ..plugin import (IPluginRegistry, IPlugin, get_plugin, - iter_plugins_dirs, _load_all_native_plugins) - -from pytest import yield_fixture, raises +from pytest import yield_fixture #------------------------------------------------------------------------------ @@ -36,28 +33,5 @@ def test_plugin_registration(no_native_plugins): class MyPlugin(IPlugin): pass - assert IPluginRegistry.plugins == [(MyPlugin, ())] - - -def test_get_plugin(): - # assert get_plugin('jso').__name__ == 'JSON' - # assert get_plugin('JSO').__name__ == 'JSON' - # assert get_plugin('JSON').__name__ == 'JSON' - # assert get_plugin('json').__name__ == 'JSON' - # assert get_plugin('.json').__name__ == 'JSON' - - # with raises(ValueError): - # assert get_plugin('.jso') is None - # with raises(ValueError): - # assert get_plugin('jsonn') is None - pass - - -def test_iter_plugins_dirs(): - # assert 'json' in [op.basename(plugin_dir) - # for plugin_dir in iter_plugins_dirs()] - pass - - -def test_load_all_native_plugins(no_native_plugins): - _load_all_native_plugins() + assert IPluginRegistry.plugins == [(MyPlugin,)] + assert get_plugin('myplugin').__name__ == 'MyPlugin' From 2451aec9c6f755807a2f63abfe3d8e8e5ce74a58 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 2 Oct 2015 15:24:16 +0200 Subject: [PATCH 0195/1059] Update plugin system --- phy/__init__.py | 1 + phy/utils/plugin.py | 17 ++--------------- phy/utils/tests/test_plugin.py | 24 +++++++++++++++++++++--- 3 files changed, 24 insertions(+), 18 deletions(-) diff --git a/phy/__init__.py b/phy/__init__.py index b7324d977..e3813bb9b 100644 --- a/phy/__init__.py +++ b/phy/__init__.py @@ -16,6 +16,7 @@ from .io.datasets import download_file, download_sample_data from .utils._misc import _git_version +from .utils.plugin import IPlugin #------------------------------------------------------------------------------ diff --git a/phy/utils/plugin.py b/phy/utils/plugin.py index 6cc0266a6..8e631317a 100644 --- a/phy/utils/plugin.py +++ b/phy/utils/plugin.py @@ -35,7 +35,7 @@ def __init__(cls, name, bases, attrs): class IPlugin(object, metaclass=IPluginRegistry): - def attach_gui(self, gui): + def attach_gui(self, gui): # pragma: no cover pass @@ -70,12 +70,11 @@ def discover_plugins(dirs): """ # Scan all subdirectories recursively. for plugin_dir in dirs: - # logger.debug("Scanning %s", plugin_dir) plugin_dir = op.realpath(plugin_dir) for subdir, dirs, files in os.walk(plugin_dir): # Skip test folders. base = op.basename(subdir) - if 'test' in base or '__' in base: + if 'test' in base or '__' in base: # pragma: no cover continue logger.debug("Scanning %s.", subdir) for filename in files: @@ -91,15 +90,3 @@ def discover_plugins(dirs): # IPluginRegistry mod = imp.load_module(modname, file, path, descr) # noqa return IPluginRegistry.plugins - - -def iter_plugins_dirs(): - """Iterate over all plugin directories.""" - curdir = op.dirname(op.realpath(__file__)) - plugins_dir = op.join(curdir, 'plugins') - # TODO: add other plugin directories (user plugins etc.) - names = [name for name in sorted(os.listdir(plugins_dir)) - if not name.startswith(('.', '_')) and - op.isdir(op.join(plugins_dir, name))] - for name in names: - yield op.join(plugins_dir, name) diff --git a/phy/utils/tests/test_plugin.py b/phy/utils/tests/test_plugin.py index c54f933f2..15efb397f 100644 --- a/phy/utils/tests/test_plugin.py +++ b/phy/utils/tests/test_plugin.py @@ -7,9 +7,13 @@ # Imports #------------------------------------------------------------------------------ -from ..plugin import (IPluginRegistry, IPlugin, get_plugin) +import os.path as op -from pytest import yield_fixture +from ..plugin import (IPluginRegistry, IPlugin, get_plugin, + discover_plugins, + ) + +from pytest import yield_fixture, raises #------------------------------------------------------------------------------ @@ -29,9 +33,23 @@ def no_native_plugins(): # Tests #------------------------------------------------------------------------------ -def test_plugin_registration(no_native_plugins): +def test_plugin_1(no_native_plugins): class MyPlugin(IPlugin): pass assert IPluginRegistry.plugins == [(MyPlugin,)] assert get_plugin('myplugin').__name__ == 'MyPlugin' + + with raises(ValueError): + get_plugin('unknown') + + +def test_discover_plugins(tempdir, no_native_plugins): + path = op.join(tempdir, 'my_plugin.py') + contents = '''from phy import IPlugin\nclass MyPlugin(IPlugin): pass''' + with open(path, 'w') as f: + f.write(contents) + + plugins = discover_plugins([tempdir]) + assert plugins + assert plugins[0][0].__name__ == 'MyPlugin' From e3eb842c8677154999657d4443e56eb667b599db Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 2 Oct 2015 15:51:10 +0200 Subject: [PATCH 0196/1059] WIP: split gui module in gui and actions modules --- phy/gui/actions.py | 368 ++++++++++++++++++++++++++++++++++ phy/gui/gui.py | 326 +----------------------------- phy/gui/tests/test_actions.py | 216 ++++++++++++++++++++ phy/gui/tests/test_gui.py | 194 +----------------- 4 files changed, 591 insertions(+), 513 deletions(-) create mode 100644 phy/gui/actions.py create mode 100644 phy/gui/tests/test_actions.py diff --git a/phy/gui/actions.py b/phy/gui/actions.py new file mode 100644 index 000000000..d5a120cde --- /dev/null +++ b/phy/gui/actions.py @@ -0,0 +1,368 @@ +# -*- coding: utf-8 -*- + +"""Actions and snippets.""" + + +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- + +import logging + +from six import string_types, PY3 + +from .qt import QtGui +from phy.utils.event import EventEmitter + +logger = logging.getLogger(__name__) + + +# ----------------------------------------------------------------------------- +# Snippet parsing utilities +# ----------------------------------------------------------------------------- + +def _parse_arg(s): + """Parse a number or string.""" + try: + return int(s) + except ValueError: + pass + try: + return float(s) + except ValueError: + pass + return s + + +def _parse_list(s): + """Parse a comma-separated list of values (strings or numbers).""" + # Range: 'x-y' + if '-' in s: + m, M = map(_parse_arg, s.split('-')) + return tuple(range(m, M + 1)) + # List of ids: 'x,y,z' + elif ',' in s: + return tuple(map(_parse_arg, s.split(','))) + else: + return _parse_arg(s) + + +def _parse_snippet(s): + """Parse an entire snippet command.""" + return list(map(_parse_list, s.split(' '))) + + +# ----------------------------------------------------------------------------- +# Show shortcut utility functions +# ----------------------------------------------------------------------------- + +def _show_shortcut(shortcut): + if isinstance(shortcut, string_types): + return shortcut + elif isinstance(shortcut, (tuple, list)): + return ', '.join(shortcut) + + +def _show_shortcuts(shortcuts, name=None): + name = name or '' + print() + if name: + name = ' for ' + name + print('Keyboard shortcuts' + name) + for name in sorted(shortcuts): + print('{0:<40}: {1:s}'.format(name, _show_shortcut(shortcuts[name]))) + print() + + +# ----------------------------------------------------------------------------- +# Actions +# ----------------------------------------------------------------------------- + +class Actions(EventEmitter): + """Handle GUI actions. + + This class attaches to a GUI and implements the following features: + + * Add and remove actions + * Keyboard shortcuts for the actions + * Display all shortcuts + + """ + def __init__(self): + super(Actions, self).__init__() + self._gui = None + self._actions = {} + + def reset(self): + """Reset the actions. + + All actions should be registered here, as follows: + + ```python + @actions.connect + def on_reset(): + actions.add(...) + actions.add(...) + ... + ``` + + """ + self.remove_all() + self.emit('reset') + + def attach(self, gui): + """Attach a GUI.""" + self._gui = gui + + # Register default actions. + @self.connect + def on_reset(): + # Default exit action. + @self.shortcut('ctrl+q') + def exit(): + gui.close() + + def add(self, name, callback=None, shortcut=None, alias=None, + checkable=False, checked=False): + """Add an action with a keyboard shortcut.""" + # TODO: add menu_name option and create menu bar + # Get the alias from the character after & if it exists. + if alias is None: + alias = name[name.index('&') + 1] if '&' in name else name + name = name.replace('&', '') + if name in self._actions: + return + action = QtGui.QAction(name, self._gui) + action.triggered.connect(callback) + action.setCheckable(checkable) + action.setChecked(checked) + if shortcut: + if not isinstance(shortcut, (tuple, list)): + shortcut = [shortcut] + for key in shortcut: + action.setShortcut(key) + + # Add some attributes to the QAction instance. + # The alias is used in snippets. + action._alias = alias + action._callback = callback + action._name = name + # HACK: add the shortcut string to the QAction object so that + # it can be shown in show_shortcuts(). I don't manage to recover + # the key sequence string from a QAction using Qt. + action._shortcut_string = shortcut or '' + + # Register the action. + if self._gui: + self._gui.addAction(action) + self._actions[name] = action + + # Log the creation of the action. + if not name.startswith('_'): + logger.debug("Add action `%s`, alias `%s`, shortcut %s.", + name, alias, shortcut or '') + + if callback: + setattr(self, name, callback) + return action + + def get_name(self, alias_or_name): + """Return an action name from its alias or name.""" + for name, action in self._actions.items(): + if alias_or_name in (action._alias, name): + return name + raise ValueError("Action `{}` doesn't exist.".format(alias_or_name)) + + def run(self, action, *args): + """Run an action, specified by its name or object.""" + if isinstance(action, string_types): + name = self.get_name(action) + assert name in self._actions + action = self._actions[name] + else: + name = action._name + if not name.startswith('_'): + logger.debug("Execute action `%s`.", name) + return action._callback(*args) + + def remove(self, name): + """Remove an action.""" + if self._gui: + self._gui.removeAction(self._actions[name]) + del self._actions[name] + delattr(self, name) + + def remove_all(self): + """Remove all actions.""" + names = sorted(self._actions.keys()) + for name in names: + self.remove(name) + + @property + def shortcuts(self): + """A dictionary of action shortcuts.""" + return {name: action._shortcut_string + for name, action in self._actions.items()} + + def show_shortcuts(self): + """Print all shortcuts.""" + _show_shortcuts(self.shortcuts, + self._gui.title() if self._gui else None) + + def shortcut(self, key=None, name=None, **kwargs): + """Decorator to add a global keyboard shortcut.""" + def wrap(func): + self.add(name or func.__name__, shortcut=key, + callback=func, **kwargs) + return wrap + + +# ----------------------------------------------------------------------------- +# Snippets +# ----------------------------------------------------------------------------- + +class Snippets(object): + """Provide keyboard snippets to quickly execute actions from a GUI. + + This class attaches to a GUI and an `Actions` instance. To every command + is associated a snippet with the same name, or with an alias as indicated + in the action. The arguments of the action's callback functions can be + provided in the snippet's command with a simple syntax. For example, the + following command: + + ``` + :my_action string 3-6 + ``` + + corresponds to: + + ```python + my_action('string', (3, 4, 5, 6)) + ``` + + The snippet mode is activated with the `:` keyboard shortcut. A snippet + command is activated with `Enter`, and one can leave the snippet mode + with `Escape`. + + """ + + # HACK: Unicode characters do not appear to work on Python 2 + cursor = '\u200A\u258C' if PY3 else '' + + # Allowed characters in snippet mode. + # A Qt shortcut will be created for every character. + _snippet_chars = ("abcdefghijklmnopqrstuvwxyz0123456789" + " ,.;?!_-+~=*/\(){}[]") + + def __init__(self): + self._gui = None + self._cmd = '' # only used when there is no gui attached + + def attach(self, gui, actions): + self._gui = gui + self._actions = actions + + # Register snippet mode shortcut. + @actions.connect + def on_reset(): + @actions.shortcut(':') + def enable_snippet_mode(): + self.mode_on() + + @property + def command(self): + """This is used to write a snippet message in the status bar. + + A cursor is appended at the end. + + """ + msg = self._gui.status_message if self._gui else self._cmd + n = len(msg) + n_cur = len(self.cursor) + return msg[:n - n_cur] + + @command.setter + def command(self, value): + value += self.cursor + if not self._gui: + self._cmd = value + else: + self._gui.status_message = value + + def _backspace(self): + """Erase the last character in the snippet command.""" + if self.command == ':': + return + logger.debug("Snippet keystroke `Backspace`.") + self.command = self.command[:-1] + + def _enter(self): + """Disable the snippet mode and execute the command.""" + command = self.command + logger.debug("Snippet keystroke `Enter`.") + self.mode_off() + self.run(command) + + def _create_snippet_actions(self): + """Delete all existing actions, and add mock ones for snippet + keystrokes. + + Used to enable snippet mode. + + """ + self._actions.remove_all() + + # One action per allowed character. + for i, char in enumerate(self._snippet_chars): + + def _make_func(char): + def callback(): + logger.debug("Snippet keystroke `%s`.", char) + self.command += char + return callback + + self._actions.add('_snippet_{}'.format(i), shortcut=char, + callback=_make_func(char)) + + self._actions.add('_snippet_backspace', shortcut='backspace', + callback=self._backspace) + self._actions.add('_snippet_activate', shortcut=('enter', 'return'), + callback=self._enter) + self._actions.add('_snippet_disable', shortcut='escape', + callback=self.mode_off) + + def run(self, snippet): + """Executes a snippet command. + + May be overridden. + + """ + assert snippet[0] == ':' + snippet = snippet[1:] + snippet_args = _parse_snippet(snippet) + alias = snippet_args[0] + name = self._actions.get_name(alias) + assert name + func = getattr(self._actions, name) + try: + logger.info("Processing snippet `%s`.", snippet) + func(*snippet_args[1:]) + except Exception as e: + logger.warn("Error when executing snippet: %s.", str(e)) + + def is_mode_on(self): + return self.command.startswith(':') + + def mode_on(self): + logger.info("Snippet mode enabled, press `escape` to leave this mode.") + # Remove all existing actions, and replace them by snippet keystroke + # actions. + self._create_snippet_actions() + self.command = ':' + + def mode_off(self): + if self._gui: + self._gui.status_message = '' + logger.info("Snippet mode disabled.") + # Reestablishes the shortcuts. + self._actions.reset() diff --git a/phy/gui/gui.py b/phy/gui/gui.py index 94fd02527..03ce1b0c2 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -2,6 +2,7 @@ """Qt dock window.""" + # ----------------------------------------------------------------------------- # Imports # ----------------------------------------------------------------------------- @@ -9,341 +10,20 @@ from collections import defaultdict import logging -from six import string_types, PY3 - from .qt import QtCore, QtGui -from ..utils.event import EventEmitter +from phy.utils.event import EventEmitter logger = logging.getLogger(__name__) # ----------------------------------------------------------------------------- -# Qt utilities +# GUI main window # ----------------------------------------------------------------------------- def _title(widget): return str(widget.windowTitle()).lower() -def _show_shortcut(shortcut): - if isinstance(shortcut, string_types): - return shortcut - elif isinstance(shortcut, (tuple, list)): - return ', '.join(shortcut) - - -def _show_shortcuts(shortcuts, name=None): - name = name or '' - print() - if name: - name = ' for ' + name - print('Keyboard shortcuts' + name) - for name in sorted(shortcuts): - print('{0:<40}: {1:s}'.format(name, _show_shortcut(shortcuts[name]))) - print() - - -# ----------------------------------------------------------------------------- -# Snippet parsing utilities -# ----------------------------------------------------------------------------- - -def _parse_arg(s): - """Parse a number or string.""" - try: - return int(s) - except ValueError: - pass - try: - return float(s) - except ValueError: - pass - return s - - -def _parse_list(s): - """Parse a comma-separated list of values (strings or numbers).""" - # Range: 'x-y' - if '-' in s: - m, M = map(_parse_arg, s.split('-')) - return tuple(range(m, M + 1)) - # List of ids: 'x,y,z' - elif ',' in s: - return tuple(map(_parse_arg, s.split(','))) - else: - return _parse_arg(s) - - -def _parse_snippet(s): - """Parse an entire snippet command.""" - return list(map(_parse_list, s.split(' '))) - - -# ----------------------------------------------------------------------------- -# Companion class -# ----------------------------------------------------------------------------- - -class Actions(EventEmitter): - """Handle GUI actions.""" - def __init__(self): - super(Actions, self).__init__() - self._gui = None - self._actions = {} - - def reset(self): - """Reset the actions. - - All actions should be registered here, as follows: - - ```python - @actions.connect - def on_reset(): - actions.add(...) - actions.add(...) - ... - ``` - - """ - self.remove_all() - self.emit('reset') - - def attach(self, gui): - """Attach a GUI.""" - self._gui = gui - - # Register default actions. - @self.connect - def on_reset(): - # Default exit action. - @self.shortcut('ctrl+q') - def exit(): - gui.close() - - def add(self, name, callback=None, shortcut=None, alias=None, - checkable=False, checked=False): - """Add an action with a keyboard shortcut.""" - # TODO: add menu_name option and create menu bar - # Get the alias from the character after & if it exists. - if alias is None: - alias = name[name.index('&') + 1] if '&' in name else name - name = name.replace('&', '') - if name in self._actions: - return - action = QtGui.QAction(name, self._gui) - action.triggered.connect(callback) - action.setCheckable(checkable) - action.setChecked(checked) - if shortcut: - if not isinstance(shortcut, (tuple, list)): - shortcut = [shortcut] - for key in shortcut: - action.setShortcut(key) - - # Add some attributes to the QAction instance. - # The alias is used in snippets. - action._alias = alias - action._callback = callback - action._name = name - # HACK: add the shortcut string to the QAction object so that - # it can be shown in show_shortcuts(). I don't manage to recover - # the key sequence string from a QAction using Qt. - action._shortcut_string = shortcut or '' - - # Register the action. - if self._gui: - self._gui.addAction(action) - self._actions[name] = action - - # Log the creation of the action. - if not name.startswith('_'): - logger.debug("Add action `%s`, alias `%s`, shortcut %s.", - name, alias, shortcut or '') - - if callback: - setattr(self, name, callback) - return action - - def get_name(self, alias_or_name): - """Return an action name from its alias or name.""" - for name, action in self._actions.items(): - if alias_or_name in (action._alias, name): - return name - raise ValueError("Action `{}` doesn't exist.".format(alias_or_name)) - - def run(self, action, *args): - """Run an action, specified by its name or object.""" - if isinstance(action, string_types): - name = self.get_name(action) - assert name in self._actions - action = self._actions[name] - else: - name = action._name - if not name.startswith('_'): - logger.debug("Execute action `%s`.", name) - return action._callback(*args) - - def remove(self, name): - """Remove an action.""" - if self._gui: - self._gui.removeAction(self._actions[name]) - del self._actions[name] - delattr(self, name) - - def remove_all(self): - """Remove all actions.""" - names = sorted(self._actions.keys()) - for name in names: - self.remove(name) - - @property - def shortcuts(self): - """A dictionary of action shortcuts.""" - return {name: action._shortcut_string - for name, action in self._actions.items()} - - def show_shortcuts(self): - """Print all shortcuts.""" - _show_shortcuts(self.shortcuts, - self._gui.title() if self._gui else None) - - def shortcut(self, key=None, name=None, **kwargs): - """Decorator to add a global keyboard shortcut.""" - def wrap(func): - self.add(name or func.__name__, shortcut=key, - callback=func, **kwargs) - return wrap - - -class Snippets(object): - # HACK: Unicode characters do not appear to work on Python 2 - cursor = '\u200A\u258C' if PY3 else '' - - # Allowed characters in snippet mode. - # A Qt shortcut will be created for every character. - _snippet_chars = ("abcdefghijklmnopqrstuvwxyz0123456789" - " ,.;?!_-+~=*/\(){}[]") - - def __init__(self): - self._gui = None - self._cmd = '' # only used when there is no gui attached - - def attach(self, gui, actions): - self._gui = gui - self._actions = actions - - # Register snippet mode shortcut. - @actions.connect - def on_reset(): - @actions.shortcut(':') - def enable_snippet_mode(): - self.mode_on() - - @property - def command(self): - """This is used to write a snippet message in the status bar. - - A cursor is appended at the end. - - """ - msg = self._gui.status_message if self._gui else self._cmd - n = len(msg) - n_cur = len(self.cursor) - return msg[:n - n_cur] - - @command.setter - def command(self, value): - value += self.cursor - if not self._gui: - self._cmd = value - else: - self._gui.status_message = value - - def _backspace(self): - """Erase the last character in the snippet command.""" - if self.command == ':': - return - logger.debug("Snippet keystroke `Backspace`.") - self.command = self.command[:-1] - - def _enter(self): - """Disable the snippet mode and execute the command.""" - command = self.command - logger.debug("Snippet keystroke `Enter`.") - self.mode_off() - self.run(command) - - def _create_snippet_actions(self): - """Delete all existing actions, and add mock ones for snippet - keystrokes. - - Used to enable snippet mode. - - """ - self._actions.remove_all() - - # One action per allowed character. - for i, char in enumerate(self._snippet_chars): - - def _make_func(char): - def callback(): - logger.debug("Snippet keystroke `%s`.", char) - self.command += char - return callback - - self._actions.add('_snippet_{}'.format(i), shortcut=char, - callback=_make_func(char)) - - self._actions.add('_snippet_backspace', shortcut='backspace', - callback=self._backspace) - self._actions.add('_snippet_activate', shortcut=('enter', 'return'), - callback=self._enter) - self._actions.add('_snippet_disable', shortcut='escape', - callback=self.mode_off) - - def run(self, snippet): - """Executes a snippet command. - - May be overridden. - - """ - assert snippet[0] == ':' - snippet = snippet[1:] - snippet_args = _parse_snippet(snippet) - alias = snippet_args[0] - try: - name = self._actions.get_name(alias) - except ValueError: - logger.warn("The action %s could not be found.", alias) - return - assert name - func = getattr(self._actions, name) - try: - logger.info("Processing snippet `%s`.", snippet) - func(*snippet_args[1:]) - except Exception as e: - logger.warn("Error when executing snippet: %s.", str(e)) - - def is_mode_on(self): - return self.command.startswith(':') - - def mode_on(self): - logger.info("Snippet mode enabled, press `escape` to leave this mode.") - # Remove all existing actions, and replace them by snippet keystroke - # actions. - self._create_snippet_actions() - self.command = ':' - - def mode_off(self): - if self._gui: - self._gui.status_message = '' - logger.info("Snippet mode disabled.") - # Reestablishes the shortcuts. - self._actions.reset() - - -# ----------------------------------------------------------------------------- -# Qt windows -# ----------------------------------------------------------------------------- - class DockWidget(QtGui.QDockWidget): """A QDockWidget that can emit events.""" def __init__(self, *args, **kwargs): diff --git a/phy/gui/tests/test_actions.py b/phy/gui/tests/test_actions.py new file mode 100644 index 000000000..4c97d5f5d --- /dev/null +++ b/phy/gui/tests/test_actions.py @@ -0,0 +1,216 @@ +# -*- coding: utf-8 -*- + +"""Test dock.""" + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +from pytest import mark, raises, yield_fixture + +from ..actions import _show_shortcuts, Actions, Snippets, _parse_snippet +from phy.utils._color import _random_color +from phy.utils.testing import captured_output, captured_logging + +# Skip these tests in "make test-quick". +pytestmark = mark.long + + +#------------------------------------------------------------------------------ +# Utilities and fixtures +#------------------------------------------------------------------------------ + +def _create_canvas(): + """Create a VisPy canvas with a color background.""" + from vispy import app + c = app.Canvas() + c.color = _random_color() + + @c.connect + def on_draw(e): # pragma: no cover + c.context.clear(c.color) + + return c + + +@yield_fixture +def actions(): + yield Actions() + + +@yield_fixture +def snippets(): + yield Snippets() + + +#------------------------------------------------------------------------------ +# Test actions +#------------------------------------------------------------------------------ + +def test_shortcuts(): + shortcuts = { + 'test_1': 'ctrl+t', + 'test_2': ('ctrl+a', 'shift+b'), + } + with captured_output() as (stdout, stderr): + _show_shortcuts(shortcuts, 'test') + assert 'ctrl+a, shift+b' in stdout.getvalue() + + +def test_actions_simple(actions): + + _res = [] + + def _action(*args): + _res.append(args) + + actions.add('tes&t', _action) + # Adding an action twice has no effect. + actions.add('test', _action) + + # Create a shortcut and display it. + _captured = [] + + @actions.shortcut('h') + def show_my_shortcuts(): + with captured_output() as (stdout, stderr): + actions.show_shortcuts() + _captured.append(stdout.getvalue()) + + actions.show_my_shortcuts() + assert 'show_my_shortcuts' in _captured[0] + assert ': h' in _captured[0] + + with raises(ValueError): + assert actions.get_name('e') + assert actions.get_name('t') == 'test' + assert actions.get_name('test') == 'test' + + actions.run('t', 1) + assert _res == [(1,)] + + # Run an action instance. + actions.run(actions._actions['test'], 1) + + actions.remove_all() + + +#------------------------------------------------------------------------------ +# Test snippets +#------------------------------------------------------------------------------ + +def test_snippets_parse(): + def _check(args, expected): + snippet = 'snip ' + args + assert _parse_snippet(snippet) == ['snip'] + expected + + _check('a', ['a']) + _check('abc', ['abc']) + _check('a,b,c', [('a', 'b', 'c')]) + _check('a b,c', ['a', ('b', 'c')]) + + _check('1', [1]) + _check('10', [10]) + + _check('1.', [1.]) + _check('10.', [10.]) + _check('10.0', [10.0]) + + _check('0 1', [0, 1]) + _check('0 1.', [0, 1.]) + _check('0 1.0', [0, 1.]) + + _check('0,1', [(0, 1)]) + _check('0,10.', [(0, 10.)]) + _check('0. 1,10.', [0., (1, 10.)]) + + _check('2-7', [(2, 3, 4, 5, 6, 7)]) + _check('2 3-5', [2, (3, 4, 5)]) + + _check('a b,c d,2 3-5', ['a', ('b', 'c'), ('d', 2), (3, 4, 5)]) + + +def test_snippets_errors(actions, snippets): + + _actions = [] + + @actions.connect + def on_reset(): + @actions.shortcut(name='my_test', alias='t') + def test(arg): + # Enforce single-character argument. + assert len(str(arg)) == 1 + _actions.append(arg) + + # Attach the GUI and register the actions. + snippets.attach(None, actions) + actions.reset() + + with raises(ValueError): + snippets.run(':t1') + + with captured_logging() as buf: + snippets.run(':t') + assert 'missing 1 required positional argument' in buf.getvalue() + + with captured_logging() as buf: + snippets.run(':t 1 2') + assert 'takes 1 positional argument but 2 were given' in buf.getvalue() + + with captured_logging() as buf: + snippets.run(':t aa') + assert 'assert 2 == 1' in buf.getvalue() + + snippets.run(':t a') + assert _actions == ['a'] + + +def test_snippets_actions(actions, snippets): + + _actions = [] + + @actions.connect + def on_reset(): + @actions.shortcut(name='my_test_1') + def test_1(*args): + _actions.append((1, args)) + + @actions.shortcut(name='my_&test_2') + def test_2(*args): + _actions.append((2, args)) + + @actions.shortcut(name='my_test_3', alias='t3') + def test_3(*args): + _actions.append((3, args)) + + # Attach the GUI and register the actions. + snippets.attach(None, actions) + actions.reset() + + assert snippets.command == '' + + # Action 1. + snippets.run(':my_test_1') + assert _actions == [(1, ())] + + # Action 2. + snippets.run(':t 1.5 a 2-4 5,7') + assert _actions[-1] == (2, (1.5, 'a', (2, 3, 4), (5, 7))) + + def _run(cmd): + """Simulate keystrokes.""" + for char in cmd: + i = snippets._snippet_chars.index(char) + actions.run('_snippet_{}'.format(i)) + + # Need to activate the snippet mode first. + with raises(ValueError): + _run(':t3 hello') + + # Simulate keystrokes ':t3 hello' + snippets.mode_on() # ':' + actions._snippet_backspace() + _run('t3 hello') + actions._snippet_activate() # 'Enter' + assert _actions[-1] == (3, ('hello',)) + snippets.mode_off() diff --git a/phy/gui/tests/test_gui.py b/phy/gui/tests/test_gui.py index 8e8e3ac3a..9853aa3c5 100644 --- a/phy/gui/tests/test_gui.py +++ b/phy/gui/tests/test_gui.py @@ -6,16 +6,12 @@ # Imports #------------------------------------------------------------------------------ -import os -from sys import platform - -from pytest import mark, raises, yield_fixture +from pytest import mark, yield_fixture from ..qt import Qt -from ..gui import (GUI, _show_shortcuts, Actions, Snippets, - _parse_snippet) +from ..gui import GUI from phy.utils._color import _random_color -from phy.utils.testing import captured_output, captured_logging +from .test_actions import actions, snippets # Skip these tests in "make test-quick". pytestmark = mark.long @@ -53,70 +49,10 @@ def gui(): yield GUI(position=(200, 100), size=(100, 100)) -@yield_fixture -def actions(): - yield Actions() - - -@yield_fixture -def snippets(): - yield Snippets() - - #------------------------------------------------------------------------------ -# Test actions +# Test actions and snippet #------------------------------------------------------------------------------ -def test_shortcuts(): - shortcuts = { - 'test_1': 'ctrl+t', - 'test_2': ('ctrl+a', 'shift+b'), - } - with captured_output() as (stdout, stderr): - _show_shortcuts(shortcuts, 'test') - assert 'ctrl+a, shift+b' in stdout.getvalue() - - -def test_actions_simple(actions): - - _res = [] - - def _action(*args): - _res.append(args) - - actions.add('tes&t', _action) - # Adding an action twice has no effect. - actions.add('test', _action) - - # Create a shortcut and display it. - _captured = [] - - @actions.shortcut('h') - def show_my_shortcuts(): - with captured_output() as (stdout, stderr): - actions.show_shortcuts() - _captured.append(stdout.getvalue()) - - actions.show_my_shortcuts() - assert 'show_my_shortcuts' in _captured[0] - assert ': h' in _captured[0] - - with raises(ValueError): - assert actions.get_name('e') - assert actions.get_name('t') == 'test' - assert actions.get_name('test') == 'test' - - actions.run('t', 1) - assert _res == [(1,)] - - # Run an action instance. - actions.run(actions._actions['test'], 1) - - actions.remove_all() - - -@skip_mac -@skip_ci def test_actions_dock(qtbot, gui, actions): actions.attach(gui) @@ -142,128 +78,6 @@ def press(): qtbot.waitForWindowShown(gui) -#------------------------------------------------------------------------------ -# Test snippets -#------------------------------------------------------------------------------ - -def test_snippets_parse(): - def _check(args, expected): - snippet = 'snip ' + args - assert _parse_snippet(snippet) == ['snip'] + expected - - _check('a', ['a']) - _check('abc', ['abc']) - _check('a,b,c', [('a', 'b', 'c')]) - _check('a b,c', ['a', ('b', 'c')]) - - _check('1', [1]) - _check('10', [10]) - - _check('1.', [1.]) - _check('10.', [10.]) - _check('10.0', [10.0]) - - _check('0 1', [0, 1]) - _check('0 1.', [0, 1.]) - _check('0 1.0', [0, 1.]) - - _check('0,1', [(0, 1)]) - _check('0,10.', [(0, 10.)]) - _check('0. 1,10.', [0., (1, 10.)]) - - _check('2-7', [(2, 3, 4, 5, 6, 7)]) - _check('2 3-5', [2, (3, 4, 5)]) - - _check('a b,c d,2 3-5', ['a', ('b', 'c'), ('d', 2), (3, 4, 5)]) - - -def test_snippets_errors(qtbot, actions, snippets): - - _actions = [] - - @actions.connect - def on_reset(): - @actions.shortcut(name='my_test', alias='t') - def test(arg): - # Enforce single-character argument. - assert len(str(arg)) == 1 - _actions.append(arg) - - # Attach the GUI and register the actions. - snippets.attach(None, actions) - actions.reset() - - snippets.run(':t1') - - with captured_logging() as buf: - snippets.run(':t') - assert 'error' in buf.getvalue().lower() - - with captured_logging() as buf: - snippets.run(':t 1 2') - assert 'error' in buf.getvalue().lower() - - with captured_logging() as buf: - snippets.run(':t aa') - assert 'assert 2 == 1' in buf.getvalue() - - snippets.run(':t a') - assert _actions == ['a'] - - -def test_snippets_actions(qtbot, actions, snippets): - - _actions = [] - - @actions.connect - def on_reset(): - @actions.shortcut(name='my_test_1') - def test_1(*args): - _actions.append((1, args)) - - @actions.shortcut(name='my_&test_2') - def test_2(*args): - _actions.append((2, args)) - - @actions.shortcut(name='my_test_3', alias='t3') - def test_3(*args): - _actions.append((3, args)) - - # Attach the GUI and register the actions. - snippets.attach(None, actions) - actions.reset() - - assert snippets.command == '' - - # Action 1. - snippets.run(':my_test_1') - assert _actions == [(1, ())] - - # Action 2. - snippets.run(':t 1.5 a 2-4 5,7') - assert _actions[-1] == (2, (1.5, 'a', (2, 3, 4), (5, 7))) - - def _run(cmd): - """Simulate keystrokes.""" - for char in cmd: - i = snippets._snippet_chars.index(char) - actions.run('_snippet_{}'.format(i)) - - # Need to activate the snippet mode first. - with raises(ValueError): - _run(':t3 hello') - - # Simulate keystrokes ':t3 hello' - snippets.mode_on() # ':' - actions._snippet_backspace() - _run('t3 hello') - actions._snippet_activate() # 'Enter' - assert _actions[-1] == (3, ('hello',)) - snippets.mode_off() - - -@skip_mac -@skip_ci def test_snippets_dock(qtbot, gui, actions, snippets): qtbot.addWidget(gui) From b1cecbfc330f4b1648918ea49960a53c96fcb0ef Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 2 Oct 2015 15:51:46 +0200 Subject: [PATCH 0197/1059] WIP: split gui module in gui and actions modules --- phy/gui/tests/test_gui.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/phy/gui/tests/test_gui.py b/phy/gui/tests/test_gui.py index 9853aa3c5..4af871994 100644 --- a/phy/gui/tests/test_gui.py +++ b/phy/gui/tests/test_gui.py @@ -11,7 +11,7 @@ from ..qt import Qt from ..gui import GUI from phy.utils._color import _random_color -from .test_actions import actions, snippets +from .test_actions import actions, snippets # noqa # Skip these tests in "make test-quick". pytestmark = mark.long @@ -74,8 +74,7 @@ def press(): assert _press == [0] # Quit the GUI. - qtbot.keyPress(gui, Qt.Key_Q, Qt.ControlModifier) - qtbot.waitForWindowShown(gui) + qtbot.keyPress(gui, Qt.Key_Q, Qt.ControlModifier) # noqa def test_snippets_dock(qtbot, gui, actions, snippets): @@ -109,7 +108,7 @@ def test(*args): qtbot.keyPress(gui, Qt.Key_Return) qtbot.waitForWindowShown(gui) - assert _actions == [((3, 4, 5), ('ab', 'c'))] + assert _actions == [((3, 4, 5), ('ab', 'c'))] # noqa #------------------------------------------------------------------------------ From 267d3c648a36a23e5e3a9dad51b257891ce5b715 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 2 Oct 2015 15:53:11 +0200 Subject: [PATCH 0198/1059] Tests pass --- phy/gui/tests/test_actions.py | 13 ------------- phy/gui/tests/test_gui.py | 8 ++++---- 2 files changed, 4 insertions(+), 17 deletions(-) diff --git a/phy/gui/tests/test_actions.py b/phy/gui/tests/test_actions.py index 4c97d5f5d..feedb8ea2 100644 --- a/phy/gui/tests/test_actions.py +++ b/phy/gui/tests/test_actions.py @@ -20,19 +20,6 @@ # Utilities and fixtures #------------------------------------------------------------------------------ -def _create_canvas(): - """Create a VisPy canvas with a color background.""" - from vispy import app - c = app.Canvas() - c.color = _random_color() - - @c.connect - def on_draw(e): # pragma: no cover - c.context.clear(c.color) - - return c - - @yield_fixture def actions(): yield Actions() diff --git a/phy/gui/tests/test_gui.py b/phy/gui/tests/test_gui.py index 4af871994..5d9d94338 100644 --- a/phy/gui/tests/test_gui.py +++ b/phy/gui/tests/test_gui.py @@ -53,7 +53,7 @@ def gui(): # Test actions and snippet #------------------------------------------------------------------------------ -def test_actions_dock(qtbot, gui, actions): +def test_actions_dock(qtbot, gui, actions): # noqa actions.attach(gui) # Set the default actions. @@ -74,10 +74,10 @@ def press(): assert _press == [0] # Quit the GUI. - qtbot.keyPress(gui, Qt.Key_Q, Qt.ControlModifier) # noqa + qtbot.keyPress(gui, Qt.Key_Q, Qt.ControlModifier) -def test_snippets_dock(qtbot, gui, actions, snippets): +def test_snippets_dock(qtbot, gui, actions, snippets): # noqa qtbot.addWidget(gui) gui.show() @@ -108,7 +108,7 @@ def test(*args): qtbot.keyPress(gui, Qt.Key_Return) qtbot.waitForWindowShown(gui) - assert _actions == [((3, 4, 5), ('ab', 'c'))] # noqa + assert _actions == [((3, 4, 5), ('ab', 'c'))] #------------------------------------------------------------------------------ From 44ae5354ee4ac488f96ec89d8c4cb2640da9853e Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 2 Oct 2015 16:08:29 +0200 Subject: [PATCH 0199/1059] Add GUI.attach(plugin) --- phy/gui/gui.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/phy/gui/gui.py b/phy/gui/gui.py index 03ce1b0c2..7693365bf 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -12,6 +12,7 @@ from .qt import QtCore, QtGui from phy.utils.event import EventEmitter +from phy.utils.plugin import get_plugin logger = logging.getLogger(__name__) @@ -84,6 +85,10 @@ def __init__(self, self._status_bar = QtGui.QStatusBar() self.setStatusBar(self._status_bar) + def attach(self, plugin_name): + plugin = get_plugin(name) + plugin.attach_gui(self) + # Events # ------------------------------------------------------------------------- From 54123ebe8ec287ffcde874aab8e6fd2330ba744f Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 2 Oct 2015 16:59:07 +0200 Subject: [PATCH 0200/1059] Test GUI plugin --- phy/gui/gui.py | 3 ++- phy/gui/tests/test_gui.py | 27 +++++++++++++++++++-------- 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/phy/gui/gui.py b/phy/gui/gui.py index 7693365bf..42c47ab39 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -86,7 +86,8 @@ def __init__(self, self.setStatusBar(self._status_bar) def attach(self, plugin_name): - plugin = get_plugin(name) + """Attach a plugin to the GUI.""" + plugin = get_plugin(plugin_name)() plugin.attach_gui(self) # Events diff --git a/phy/gui/tests/test_gui.py b/phy/gui/tests/test_gui.py index 5d9d94338..6551c281a 100644 --- a/phy/gui/tests/test_gui.py +++ b/phy/gui/tests/test_gui.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -"""Test dock.""" +"""Test gui.""" #------------------------------------------------------------------------------ # Imports @@ -11,6 +11,7 @@ from ..qt import Qt from ..gui import GUI from phy.utils._color import _random_color +from phy.utils.plugin import IPlugin from .test_actions import actions, snippets # noqa # Skip these tests in "make test-quick". @@ -53,7 +54,7 @@ def gui(): # Test actions and snippet #------------------------------------------------------------------------------ -def test_actions_dock(qtbot, gui, actions): # noqa +def test_actions_gui(qtbot, gui, actions): # noqa actions.attach(gui) # Set the default actions. @@ -77,7 +78,7 @@ def press(): qtbot.keyPress(gui, Qt.Key_Q, Qt.ControlModifier) -def test_snippets_dock(qtbot, gui, actions, snippets): # noqa +def test_snippets_gui(qtbot, gui, actions, snippets): # noqa qtbot.addWidget(gui) gui.show() @@ -112,10 +113,10 @@ def test(*args): #------------------------------------------------------------------------------ -# Test dock +# Test gui #------------------------------------------------------------------------------ -def test_dock_1(qtbot): +def test_gui_1(qtbot): gui = GUI(position=(200, 100), size=(100, 100)) qtbot.addWidget(gui) @@ -136,7 +137,7 @@ def on_show_gui(): assert len(gui.list_views('view')) == 2 - # Check that the close_widget event is fired when the dock widget is + # Check that the close_widget event is fired when the gui widget is # closed. _close = [] @@ -149,7 +150,17 @@ def on_close_widget(): gui.close() -def test_dock_status_message(qtbot): +def test_gui_plugin(qtbot, gui): + + class TestPlugin(IPlugin): + def attach_gui(self, gui): + gui._attached = True + + gui.attach('testplugin') + assert gui._attached + + +def test_gui_status_message(qtbot): gui = GUI() qtbot.addWidget(gui) assert gui.status_message == '' @@ -157,7 +168,7 @@ def test_dock_status_message(qtbot): assert gui.status_message == ':hello world!' -def test_dock_state(qtbot): +def test_gui_state(qtbot): _gs = [] gui = GUI(size=(100, 100)) qtbot.addWidget(gui) From 38134d0bcac237664721f6ddda221c2222f8d44f Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 2 Oct 2015 17:10:31 +0200 Subject: [PATCH 0201/1059] WIP: GUI plugins --- phy/gui/gui.py | 4 ++-- phy/gui/tests/test_gui.py | 5 +++-- phy/utils/plugin.py | 2 +- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/phy/gui/gui.py b/phy/gui/gui.py index 42c47ab39..0dad8cd53 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -85,10 +85,10 @@ def __init__(self, self._status_bar = QtGui.QStatusBar() self.setStatusBar(self._status_bar) - def attach(self, plugin_name): + def attach(self, plugin_name, *args, **kwargs): """Attach a plugin to the GUI.""" plugin = get_plugin(plugin_name)() - plugin.attach_gui(self) + return plugin.attach_to_gui(self, *args, **kwargs) # Events # ------------------------------------------------------------------------- diff --git a/phy/gui/tests/test_gui.py b/phy/gui/tests/test_gui.py index 6551c281a..6666b1c32 100644 --- a/phy/gui/tests/test_gui.py +++ b/phy/gui/tests/test_gui.py @@ -153,10 +153,11 @@ def on_close_widget(): def test_gui_plugin(qtbot, gui): class TestPlugin(IPlugin): - def attach_gui(self, gui): + def attach_to_gui(self, gui): gui._attached = True + return 'attached' - gui.attach('testplugin') + assert gui.attach('testplugin') == 'attached' assert gui._attached diff --git a/phy/utils/plugin.py b/phy/utils/plugin.py index 8e631317a..f73f1aa9b 100644 --- a/phy/utils/plugin.py +++ b/phy/utils/plugin.py @@ -35,7 +35,7 @@ def __init__(cls, name, bases, attrs): class IPlugin(object, metaclass=IPluginRegistry): - def attach_gui(self, gui): # pragma: no cover + def attach_to_gui(self, gui, *args, **kwargs): # pragma: no cover pass From ebdcd3e2c96bf0efcc346c36931a7c13e2536499 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 2 Oct 2015 17:46:31 +0200 Subject: [PATCH 0202/1059] Flakify --- phy/gui/tests/test_actions.py | 1 - phy/gui/tests/test_gui.py | 11 +++++++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/phy/gui/tests/test_actions.py b/phy/gui/tests/test_actions.py index feedb8ea2..61328cdd1 100644 --- a/phy/gui/tests/test_actions.py +++ b/phy/gui/tests/test_actions.py @@ -9,7 +9,6 @@ from pytest import mark, raises, yield_fixture from ..actions import _show_shortcuts, Actions, Snippets, _parse_snippet -from phy.utils._color import _random_color from phy.utils.testing import captured_output, captured_logging # Skip these tests in "make test-quick". diff --git a/phy/gui/tests/test_gui.py b/phy/gui/tests/test_gui.py index 6666b1c32..ec6043028 100644 --- a/phy/gui/tests/test_gui.py +++ b/phy/gui/tests/test_gui.py @@ -6,6 +6,9 @@ # Imports #------------------------------------------------------------------------------ +import os +from sys import platform + from pytest import mark, yield_fixture from ..qt import Qt @@ -54,7 +57,9 @@ def gui(): # Test actions and snippet #------------------------------------------------------------------------------ -def test_actions_gui(qtbot, gui, actions): # noqa +@skip_mac # noqa +@skip_ci +def test_actions_gui(qtbot, gui, actions): actions.attach(gui) # Set the default actions. @@ -78,7 +83,9 @@ def press(): qtbot.keyPress(gui, Qt.Key_Q, Qt.ControlModifier) -def test_snippets_gui(qtbot, gui, actions, snippets): # noqa +@skip_mac # noqa +@skip_ci +def test_snippets_gui(qtbot, gui, actions, snippets): qtbot.addWidget(gui) gui.show() From 116bc8a7059c40cf6f425598861746a04ed55803 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 2 Oct 2015 18:07:48 +0200 Subject: [PATCH 0203/1059] WIP: fixing tests --- phy/gui/tests/test_actions.py | 4 ++-- phy/utils/plugin.py | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/phy/gui/tests/test_actions.py b/phy/gui/tests/test_actions.py index 61328cdd1..f078e34fe 100644 --- a/phy/gui/tests/test_actions.py +++ b/phy/gui/tests/test_actions.py @@ -137,11 +137,11 @@ def test(arg): with captured_logging() as buf: snippets.run(':t') - assert 'missing 1 required positional argument' in buf.getvalue() + assert 'error' in buf.getvalue().lower() with captured_logging() as buf: snippets.run(':t 1 2') - assert 'takes 1 positional argument but 2 were given' in buf.getvalue() + assert 'error' in buf.getvalue().lower() with captured_logging() as buf: snippets.run(':t aa') diff --git a/phy/utils/plugin.py b/phy/utils/plugin.py index f73f1aa9b..f77ce1fa6 100644 --- a/phy/utils/plugin.py +++ b/phy/utils/plugin.py @@ -16,6 +16,8 @@ import os import os.path as op +from six import with_metaclass + logger = logging.getLogger(__name__) @@ -34,7 +36,7 @@ def __init__(cls, name, bases, attrs): IPluginRegistry.plugins.append(plugin_tuple) -class IPlugin(object, metaclass=IPluginRegistry): +class IPlugin(with_metaclass(IPluginRegistry)): def attach_to_gui(self, gui, *args, **kwargs): # pragma: no cover pass From 891cef04dcfc58d6b83bb0710b4ac7979fc9ece7 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 2 Oct 2015 18:50:28 +0200 Subject: [PATCH 0204/1059] WIP: update setup --- setup.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 44a65d897..28b2b3bd9 100644 --- a/setup.py +++ b/setup.py @@ -37,7 +37,7 @@ def finalize_options(self): def run_tests(self): #import here, cause outside the eggs aren't loaded import pytest - pytest_string = '-s ' + self.pytest_args + pytest_string = self.pytest_args print("Running: py.test " + pytest_string) errno = pytest.main(pytest_string) sys.exit(errno) @@ -78,7 +78,6 @@ def _package_tree(pkgroot): }, entry_points={ 'console_scripts': [ - 'phy=phy.scripts.phy_script:main', ], }, include_package_data=True, From 49870defd8ce79eca94851ead9d3a545fe26bf41 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 2 Oct 2015 18:52:38 +0200 Subject: [PATCH 0205/1059] Remove test-quick --- Makefile | 22 ++-------------------- phy/gui/tests/test_actions.py | 5 +---- phy/gui/tests/test_gui.py | 3 --- phy/plot/tests/test_ccg.py | 6 ------ phy/plot/tests/test_features.py | 6 ------ phy/plot/tests/test_traces.py | 6 ------ phy/plot/tests/test_utils.py | 6 ------ phy/plot/tests/test_waveforms.py | 6 ------ phy/utils/tests/test_color.py | 6 ------ 9 files changed, 3 insertions(+), 63 deletions(-) diff --git a/Makefile b/Makefile index 2a1349eba..5059b040c 100644 --- a/Makefile +++ b/Makefile @@ -1,14 +1,3 @@ -help: - @echo "clean - remove all build, test, coverage and Python artifacts" - @echo "clean-build - remove build artifacts" - @echo "clean-pyc - remove Python file artifacts" - @echo "lint - check style with flake8" - @echo "test - run tests quickly with the default Python" - @echo "release - package and upload a release" - @echo "apidoc - build API doc" - -clean: clean-build clean-pyc - clean-build: rm -fr build/ rm -fr dist/ @@ -20,6 +9,8 @@ clean-pyc: find . -name '*~' -exec rm -f {} + find . -name '__pycache__' -exec rm -fr {} + +clean: clean-build clean-pyc + lint: flake8 phy @@ -29,9 +20,6 @@ test: lint coverage: coverage --html -test-quick: lint - python setup.py test -a "-m \"not long\" phy" - unit-tests: lint python setup.py test -a phy @@ -46,9 +34,3 @@ build: upload: python setup.py sdist --formats=zip upload - -release-test: - python tools/release.py release_test - -release: - python tools/release.py release diff --git a/phy/gui/tests/test_actions.py b/phy/gui/tests/test_actions.py index f078e34fe..ceaec42af 100644 --- a/phy/gui/tests/test_actions.py +++ b/phy/gui/tests/test_actions.py @@ -6,14 +6,11 @@ # Imports #------------------------------------------------------------------------------ -from pytest import mark, raises, yield_fixture +from pytest import raises, yield_fixture from ..actions import _show_shortcuts, Actions, Snippets, _parse_snippet from phy.utils.testing import captured_output, captured_logging -# Skip these tests in "make test-quick". -pytestmark = mark.long - #------------------------------------------------------------------------------ # Utilities and fixtures diff --git a/phy/gui/tests/test_gui.py b/phy/gui/tests/test_gui.py index ec6043028..1c702be5b 100644 --- a/phy/gui/tests/test_gui.py +++ b/phy/gui/tests/test_gui.py @@ -17,9 +17,6 @@ from phy.utils.plugin import IPlugin from .test_actions import actions, snippets # noqa -# Skip these tests in "make test-quick". -pytestmark = mark.long - # Skip some tests on OS X or on CI systems (Travis). skip_mac = mark.skipif(platform == "darwin", reason="Some tests don't work on OS X because of a bug " diff --git a/phy/plot/tests/test_ccg.py b/phy/plot/tests/test_ccg.py index cfeaf2eb2..18c793b6c 100644 --- a/phy/plot/tests/test_ccg.py +++ b/phy/plot/tests/test_ccg.py @@ -6,8 +6,6 @@ # Imports #------------------------------------------------------------------------------ -from pytest import mark - import numpy as np from ..ccg import _plot_ccg_mpl, CorrelogramView, plot_correlograms @@ -16,10 +14,6 @@ from ...utils.testing import show_test -# Skip these tests in "make test-quick". -pytestmark = mark.long() - - #------------------------------------------------------------------------------ # Tests matplotlib #------------------------------------------------------------------------------ diff --git a/phy/plot/tests/test_features.py b/phy/plot/tests/test_features.py index e5665a0ac..ff24c5de7 100644 --- a/phy/plot/tests/test_features.py +++ b/phy/plot/tests/test_features.py @@ -6,8 +6,6 @@ # Imports #------------------------------------------------------------------------------ -from pytest import mark - import numpy as np from ..features import FeatureView, plot_features @@ -19,10 +17,6 @@ from ...utils.testing import show_test -# Skip these tests in "make test-quick". -pytestmark = mark.long() - - #------------------------------------------------------------------------------ # Tests #------------------------------------------------------------------------------ diff --git a/phy/plot/tests/test_traces.py b/phy/plot/tests/test_traces.py index 5525b1512..b8a8df834 100644 --- a/phy/plot/tests/test_traces.py +++ b/phy/plot/tests/test_traces.py @@ -6,8 +6,6 @@ # Imports #------------------------------------------------------------------------------ -from pytest import mark - import numpy as np from ..traces import TraceView, plot_traces @@ -19,10 +17,6 @@ from ...utils.testing import show_test -# Skip these tests in "make test-quick". -pytestmark = mark.long() - - #------------------------------------------------------------------------------ # Tests VisPy #------------------------------------------------------------------------------ diff --git a/phy/plot/tests/test_utils.py b/phy/plot/tests/test_utils.py index f6fe6c985..1cd174476 100644 --- a/phy/plot/tests/test_utils.py +++ b/phy/plot/tests/test_utils.py @@ -6,8 +6,6 @@ # Imports #------------------------------------------------------------------------------ -from pytest import mark - from vispy import app from ...utils.testing import show_test @@ -15,10 +13,6 @@ from .._panzoom import PanZoom, PanZoomGrid -# Skip these tests in "make test-quick". -pytestmark = mark.long() - - #------------------------------------------------------------------------------ # Tests VisPy #------------------------------------------------------------------------------ diff --git a/phy/plot/tests/test_waveforms.py b/phy/plot/tests/test_waveforms.py index ed2dcf2b9..1a2df6ed7 100644 --- a/phy/plot/tests/test_waveforms.py +++ b/phy/plot/tests/test_waveforms.py @@ -6,8 +6,6 @@ # Imports #------------------------------------------------------------------------------ -from pytest import mark - import numpy as np from ..waveforms import WaveformView, plot_waveforms @@ -18,10 +16,6 @@ from ...utils.testing import show_test -# Skip these tests in "make test-quick". -pytestmark = mark.long() - - #------------------------------------------------------------------------------ # Tests #------------------------------------------------------------------------------ diff --git a/phy/utils/tests/test_color.py b/phy/utils/tests/test_color.py index ddbf88fc2..781616d72 100644 --- a/phy/utils/tests/test_color.py +++ b/phy/utils/tests/test_color.py @@ -6,18 +6,12 @@ # Imports #------------------------------------------------------------------------------ -from pytest import mark - from .._color import (_random_color, _is_bright, _random_bright_color, _selected_clusters_colors, ) from ..testing import show_colored_canvas -# Skip these tests in "make test-quick". -pytestmark = mark.long - - #------------------------------------------------------------------------------ # Tests #------------------------------------------------------------------------------ From 5024118771a6c26e21b67a6d63cfb8f9f7389182 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 5 Oct 2015 17:48:55 +0200 Subject: [PATCH 0206/1059] Minor updates --- phy/cluster/manual/wizard.py | 6 +++++- phy/gui/gui.py | 8 ++++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/phy/cluster/manual/wizard.py b/phy/cluster/manual/wizard.py index fc4894b5d..b60324a67 100644 --- a/phy/cluster/manual/wizard.py +++ b/phy/cluster/manual/wizard.py @@ -76,7 +76,11 @@ def _next(items, current, filter=None): #------------------------------------------------------------------------------ class Wizard(EventEmitter): - """Propose a selection of high-quality clusters and merge candidates.""" + """Propose a selection of high-quality clusters and merge candidates. + + The wizard is responsible for the selected clusters. + + """ def __init__(self): super(Wizard, self).__init__() self._similarity = None diff --git a/phy/gui/gui.py b/phy/gui/gui.py index 0dad8cd53..ff15e0d4b 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -10,6 +10,8 @@ from collections import defaultdict import logging +from six import string_types + from .qt import QtCore, QtGui from phy.utils.event import EventEmitter from phy.utils.plugin import get_plugin @@ -85,9 +87,11 @@ def __init__(self, self._status_bar = QtGui.QStatusBar() self.setStatusBar(self._status_bar) - def attach(self, plugin_name, *args, **kwargs): + def attach(self, plugin, *args, **kwargs): """Attach a plugin to the GUI.""" - plugin = get_plugin(plugin_name)() + if isinstance(plugin, string_types): + # Instantiate the plugin if the name is given. + plugin = get_plugin(plugin)() return plugin.attach_to_gui(self, *args, **kwargs) # Events From b584638f9ef6794726fcd9ba0276aaadc4276f24 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 5 Oct 2015 17:56:50 +0200 Subject: [PATCH 0207/1059] Update .gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index fc5dcadfe..a7c023ad6 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,7 @@ experimental htmlcov format wiki +.cache .idea .ipynb_checkpoints .*fuse* From 832d40f5c8b3453a8b4238d58dc1f96ae9eb3a0f Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 5 Oct 2015 18:18:52 +0200 Subject: [PATCH 0208/1059] WIP: replace Actions.shortcut() by Actions.add() --- phy/gui/actions.py | 38 ++++++++++++++++++++++------------- phy/gui/tests/test_actions.py | 14 ++++++------- 2 files changed, 31 insertions(+), 21 deletions(-) diff --git a/phy/gui/actions.py b/phy/gui/actions.py index d5a120cde..6d5a5769a 100644 --- a/phy/gui/actions.py +++ b/phy/gui/actions.py @@ -7,6 +7,7 @@ # Imports # ----------------------------------------------------------------------------- +from functools import partial import logging from six import string_types, PY3 @@ -118,20 +119,32 @@ def attach(self, gui): @self.connect def on_reset(): # Default exit action. - @self.shortcut('ctrl+q') + @self.add(shortcut='ctrl+q') def exit(): gui.close() - def add(self, name, callback=None, shortcut=None, alias=None, + def add(self, callback=None, name=None, shortcut=None, alias=None, checkable=False, checked=False): """Add an action with a keyboard shortcut.""" + if callback is None: + # Allow to use either add(func) or @add or @add(...). + return partial(self.add, name=name, shortcut=shortcut, + alias=alias, checkable=checkable, checked=checked) + # TODO: add menu_name option and create menu bar + + # Get the name from the callback function if needed. + assert callback + name = name or callback.__name__ + # Get the alias from the character after & if it exists. if alias is None: alias = name[name.index('&') + 1] if '&' in name else name name = name.replace('&', '') if name in self._actions: return + + # Create the QAction instance. action = QtGui.QAction(name, self._gui) action.triggered.connect(callback) action.setCheckable(checkable) @@ -209,13 +222,6 @@ def show_shortcuts(self): _show_shortcuts(self.shortcuts, self._gui.title() if self._gui else None) - def shortcut(self, key=None, name=None, **kwargs): - """Decorator to add a global keyboard shortcut.""" - def wrap(func): - self.add(name or func.__name__, shortcut=key, - callback=func, **kwargs) - return wrap - # ----------------------------------------------------------------------------- # Snippets @@ -265,7 +271,7 @@ def attach(self, gui, actions): # Register snippet mode shortcut. @actions.connect def on_reset(): - @actions.shortcut(':') + @actions.add(shortcut=':') def enable_snippet_mode(): self.mode_on() @@ -321,14 +327,18 @@ def callback(): self.command += char return callback - self._actions.add('_snippet_{}'.format(i), shortcut=char, + self._actions.add(name='_snippet_{}'.format(i), + shortcut=char, callback=_make_func(char)) - self._actions.add('_snippet_backspace', shortcut='backspace', + self._actions.add(name='_snippet_backspace', + shortcut='backspace', callback=self._backspace) - self._actions.add('_snippet_activate', shortcut=('enter', 'return'), + self._actions.add(name='_snippet_activate', + shortcut=('enter', 'return'), callback=self._enter) - self._actions.add('_snippet_disable', shortcut='escape', + self._actions.add(name='_snippet_disable', + shortcut='escape', callback=self.mode_off) def run(self, snippet): diff --git a/phy/gui/tests/test_actions.py b/phy/gui/tests/test_actions.py index ceaec42af..db95bfb4f 100644 --- a/phy/gui/tests/test_actions.py +++ b/phy/gui/tests/test_actions.py @@ -47,14 +47,14 @@ def test_actions_simple(actions): def _action(*args): _res.append(args) - actions.add('tes&t', _action) + actions.add(_action, 'tes&t') # Adding an action twice has no effect. - actions.add('test', _action) + actions.add(_action, 'test') # Create a shortcut and display it. _captured = [] - @actions.shortcut('h') + @actions.add(shortcut='h') def show_my_shortcuts(): with captured_output() as (stdout, stderr): actions.show_shortcuts() @@ -119,7 +119,7 @@ def test_snippets_errors(actions, snippets): @actions.connect def on_reset(): - @actions.shortcut(name='my_test', alias='t') + @actions.add(name='my_test', alias='t') def test(arg): # Enforce single-character argument. assert len(str(arg)) == 1 @@ -154,15 +154,15 @@ def test_snippets_actions(actions, snippets): @actions.connect def on_reset(): - @actions.shortcut(name='my_test_1') + @actions.add(name='my_test_1') def test_1(*args): _actions.append((1, args)) - @actions.shortcut(name='my_&test_2') + @actions.add(name='my_&test_2') def test_2(*args): _actions.append((2, args)) - @actions.shortcut(name='my_test_3', alias='t3') + @actions.add(name='my_test_3', alias='t3') def test_3(*args): _actions.append((3, args)) From 312a404a3245c0201f27d667df918bd37d03d38c Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 5 Oct 2015 18:19:35 +0200 Subject: [PATCH 0209/1059] Update GUI tests with Actions.add() --- phy/gui/tests/test_gui.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/phy/gui/tests/test_gui.py b/phy/gui/tests/test_gui.py index 1c702be5b..25b1ac83e 100644 --- a/phy/gui/tests/test_gui.py +++ b/phy/gui/tests/test_gui.py @@ -68,7 +68,7 @@ def test_actions_gui(qtbot, gui, actions): _press = [] - @actions.shortcut('ctrl+g') + @actions.add(shortcut='ctrl+g') def press(): _press.append(0) @@ -92,7 +92,7 @@ def test_snippets_gui(qtbot, gui, actions, snippets): @actions.connect def on_reset(): - @actions.shortcut(name='my_test_1', alias='t1') + @actions.add(name='my_test_1', alias='t1') def test(*args): _actions.append(args) From 6f40f134e90df873d38ab3e9ef2aada3351bc323 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 5 Oct 2015 18:34:21 +0200 Subject: [PATCH 0210/1059] WIP: add manual clustering plugin --- phy/cluster/manual/__init__.py | 1 + phy/cluster/manual/gui_plugins.py | 119 +++++++++++++++++++ phy/cluster/manual/tests/test_gui_plugins.py | 25 ++++ 3 files changed, 145 insertions(+) create mode 100644 phy/cluster/manual/gui_plugins.py create mode 100644 phy/cluster/manual/tests/test_gui_plugins.py diff --git a/phy/cluster/manual/__init__.py b/phy/cluster/manual/__init__.py index c6fa9b085..9dd61ab42 100644 --- a/phy/cluster/manual/__init__.py +++ b/phy/cluster/manual/__init__.py @@ -5,3 +5,4 @@ from .clustering import Clustering from .wizard import Wizard +from .gui_plugins import ManualClustering diff --git a/phy/cluster/manual/gui_plugins.py b/phy/cluster/manual/gui_plugins.py new file mode 100644 index 000000000..aa935f568 --- /dev/null +++ b/phy/cluster/manual/gui_plugins.py @@ -0,0 +1,119 @@ +# -*- coding: utf-8 -*- + +"""Manual clustering GUI plugins.""" + + +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- + +import logging + +from ._utils import ClusterMetadataUpdater +from .clustering import Clustering +from .wizard import Wizard +from phy.gui.actions import Actions, Snippets +from phy.io.array import select_spikes +from phy.utils.plugin import IPlugin + +logger = logging.getLogger(__name__) + + +# ----------------------------------------------------------------------------- +# Clustering GUI plugins +# ----------------------------------------------------------------------------- + +class ManualClustering(IPlugin): + def attach_to_gui(self, gui, + spike_clusters=None, + cluster_metadata=None, + n_spikes_max_per_cluster=100, + ): + # Create Clustering and ClusterMetadataUpdater. + clustering = Clustering(spike_clusters) + cluster_meta_up = ClusterMetadataUpdater(cluster_metadata) + + # Create the wizard and attach it to Clustering/ClusterMetadataUpdater. + wizard = Wizard() + wizard.attach(clustering, cluster_meta_up) + + @wizard.connect + def on_select(cluster_ids): + """When the wizard selects clusters, choose a spikes subset + and emit the `select` event on the GUI. + + The wizard is responsible for the notion of "selected clusters". + + """ + spike_ids = select_spikes(cluster_ids, + n_spikes_max_per_cluster, + clustering.spikes_per_cluster) + gui.emit('select', cluster_ids, spike_ids) + + self.create_actions(gui) + + def create_actions(self, gui): + actions = Actions() + snippets = Snippets() + + # Create the default actions for the clustering GUI. + @actions.connect + def on_reset(): + actions.add(alias='s', callback=self.select) + # TODO: other actions + + # Attach the GUI and register the actions. + snippets.attach(gui, actions) + actions.attach(gui) + actions.reset() + + def toggle_correlogram_normalization(self): + pass + + def toggle_waveforms_mean(self): + pass + + def toggle_waveforms_overlap(self): + pass + + def show_features_time(self): + pass + + def select(self, cluster_ids): + pass + + def reset_wizard(self): + pass + + def first(self): + pass + + def last(self): + pass + + def next(self): + pass + + def previous(self): + pass + + def pin(self): + pass + + def unpin(self): + pass + + def merge(self, cluster_ids=None): + pass + + def split(self, spike_ids=None): + pass + + def move(self, clusters, group): + pass + + def undo(self): + pass + + def redo(self): + pass diff --git a/phy/cluster/manual/tests/test_gui_plugins.py b/phy/cluster/manual/tests/test_gui_plugins.py new file mode 100644 index 000000000..ad3461f10 --- /dev/null +++ b/phy/cluster/manual/tests/test_gui_plugins.py @@ -0,0 +1,25 @@ +# -*- coding: utf-8 -*- + +"""Test GUI plugins.""" + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +# from pytest + +from .test_wizard import clustering, cluster_metadata, wizard # noqa +from phy.gui.tests.test_gui import gui # noqa + + +#------------------------------------------------------------------------------ +# Test GUI plugins +#------------------------------------------------------------------------------ + +def test_manual_clustering(qtbot, gui, clustering, cluster_metadata): # noqa + # TODO: refactor these fixtures + sc = clustering.spike_clusters + gui.attach('ManualClustering', + spike_clusters=sc, + cluster_metadata=cluster_metadata._cluster_metadata, + ) From 5629f3bdc5adc28b5d022afe2c9f584d112e1ef0 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 5 Oct 2015 18:49:36 +0200 Subject: [PATCH 0211/1059] Refactor cluster.manual.tests fixtures --- phy/cluster/manual/gui_plugins.py | 25 +++++- phy/cluster/manual/tests/conftest.py | 67 ++++++++++++++++ phy/cluster/manual/tests/test_gui_plugins.py | 13 +--- phy/cluster/manual/tests/test_wizard.py | 81 ++++---------------- 4 files changed, 110 insertions(+), 76 deletions(-) create mode 100644 phy/cluster/manual/tests/conftest.py diff --git a/phy/cluster/manual/gui_plugins.py b/phy/cluster/manual/gui_plugins.py index aa935f568..39786c86b 100644 --- a/phy/cluster/manual/gui_plugins.py +++ b/phy/cluster/manual/gui_plugins.py @@ -9,7 +9,7 @@ import logging -from ._utils import ClusterMetadataUpdater +from ._utils import ClusterMetadata, ClusterMetadataUpdater from .clustering import Clustering from .wizard import Wizard from phy.gui.actions import Actions, Snippets @@ -19,6 +19,29 @@ logger = logging.getLogger(__name__) +# ----------------------------------------------------------------------------- +# Clustering objects +# ----------------------------------------------------------------------------- + +def create_cluster_metadata(data): + """Return a ClusterMetadata instance with cluster group support.""" + meta = ClusterMetadata(data=data) + + @meta.default + def group(cluster, ascendant_values=None): + if not ascendant_values: + return 3 + s = list(set(ascendant_values) - set([None, 3])) + # Return the default value if all ascendant values are the default. + if not s: # pragma: no cover + return 3 + # Otherwise, return good (2) if it is present, or the largest value + # among those present. + return max(s) + + return meta + + # ----------------------------------------------------------------------------- # Clustering GUI plugins # ----------------------------------------------------------------------------- diff --git a/phy/cluster/manual/tests/conftest.py b/phy/cluster/manual/tests/conftest.py new file mode 100644 index 000000000..589f2cc7b --- /dev/null +++ b/phy/cluster/manual/tests/conftest.py @@ -0,0 +1,67 @@ +# -*- coding: utf-8 -*- + +"""Test wizard.""" + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +from pytest import yield_fixture + +from ..clustering import Clustering +from ..gui_plugins import create_cluster_metadata +from ..wizard import Wizard + + +#------------------------------------------------------------------------------ +# Fixtures +#------------------------------------------------------------------------------ + +@yield_fixture +def cluster_ids(): + yield [2, 3, 5, 7] + + +@yield_fixture +def spike_clusters(): + yield [2, 3, 5, 7] + + +@yield_fixture +def clustering(spike_clusters): + yield Clustering(spike_clusters) + + +@yield_fixture +def cluster_metadata(): + data = {2: {'group': 3}, + 3: {'group': 3}, + 5: {'group': 1}, + 7: {'group': 2}, + } + + yield create_cluster_metadata(data) + + +@yield_fixture +def wizard(): + + def get_cluster_ids(): + return [2, 3, 5, 7] + + wizard = Wizard() + wizard.set_cluster_ids_function(get_cluster_ids) + + @wizard.set_status_function + def cluster_status(cluster): + return {2: None, 3: None, 5: 'ignored', 7: 'good'}.get(cluster, None) + + @wizard.set_quality_function + def quality(cluster): + return cluster * .1 + + @wizard.set_similarity_function + def similarity(cluster, other): + return 1. + quality(cluster) - quality(other) + + yield wizard diff --git a/phy/cluster/manual/tests/test_gui_plugins.py b/phy/cluster/manual/tests/test_gui_plugins.py index ad3461f10..34fe69d1e 100644 --- a/phy/cluster/manual/tests/test_gui_plugins.py +++ b/phy/cluster/manual/tests/test_gui_plugins.py @@ -6,20 +6,15 @@ # Imports #------------------------------------------------------------------------------ -# from pytest - -from .test_wizard import clustering, cluster_metadata, wizard # noqa -from phy.gui.tests.test_gui import gui # noqa +from phy.gui.tests.test_gui import gui #------------------------------------------------------------------------------ # Test GUI plugins #------------------------------------------------------------------------------ -def test_manual_clustering(qtbot, gui, clustering, cluster_metadata): # noqa - # TODO: refactor these fixtures - sc = clustering.spike_clusters +def test_manual_clustering(qtbot, gui, spike_clusters, cluster_metadata): gui.attach('ManualClustering', - spike_clusters=sc, - cluster_metadata=cluster_metadata._cluster_metadata, + spike_clusters=spike_clusters, + cluster_metadata=cluster_metadata, ) diff --git a/phy/cluster/manual/tests/test_wizard.py b/phy/cluster/manual/tests/test_wizard.py index 54c751293..ecc792a9e 100644 --- a/phy/cluster/manual/tests/test_wizard.py +++ b/phy/cluster/manual/tests/test_wizard.py @@ -9,8 +9,7 @@ from pytest import yield_fixture from numpy.testing import assert_array_equal as ae -from ..clustering import Clustering -from .._utils import ClusterMetadata, ClusterMetadataUpdater +from .._utils import ClusterMetadataUpdater from ..wizard import (_previous, _next, _find_first, @@ -23,58 +22,8 @@ #------------------------------------------------------------------------------ @yield_fixture -def wizard(): - - def get_cluster_ids(): - return [2, 3, 5, 7] - - wizard = Wizard() - wizard.set_cluster_ids_function(get_cluster_ids) - - @wizard.set_status_function - def cluster_status(cluster): - return {2: None, 3: None, 5: 'ignored', 7: 'good'}.get(cluster, None) - - @wizard.set_quality_function - def quality(cluster): - return cluster * .1 - - @wizard.set_similarity_function - def similarity(cluster, other): - return 1. + quality(cluster) - quality(other) - - yield wizard - - -@yield_fixture -def cluster_metadata(): - data = {2: {'group': 3}, - 3: {'group': 3}, - 5: {'group': 1}, - 7: {'group': 2}, - } - - base_meta = ClusterMetadata(data=data) - - @base_meta.default - def group(cluster, ascendant_values=None): - if not ascendant_values: - return 3 - s = list(set(ascendant_values) - set([None, 3])) - # Return the default value if all ascendant values are the default. - if not s: # pragma: no cover - return 3 - # Otherwise, return good (2) if it is present, or the largest value - # among those present. - return max(s) - - meta = ClusterMetadataUpdater(base_meta) - yield meta - - -@yield_fixture -def clustering(): - yield Clustering([2, 3, 5, 7]) +def cluster_meta_up(cluster_metadata): + yield ClusterMetadataUpdater(cluster_metadata) #------------------------------------------------------------------------------ @@ -219,9 +168,9 @@ def test_wizard_nav(wizard): assert wizard.n_processed == 2 -def test_wizard_update_simple(wizard, clustering, cluster_metadata): +def test_wizard_update_simple(wizard, clustering, cluster_meta_up): # 2: none, 3: none, 5: ignored, 7: good - wizard.attach(clustering, cluster_metadata) + wizard.attach(clustering, cluster_meta_up) wizard.first() wizard.last() @@ -241,8 +190,8 @@ def test_wizard_update_simple(wizard, clustering, cluster_metadata): wizard.next_best() -def test_wizard_update_group(wizard, clustering, cluster_metadata): - wizard.attach(clustering, cluster_metadata) +def test_wizard_update_group(wizard, clustering, cluster_meta_up): + wizard.attach(clustering, cluster_meta_up) wizard.start() @@ -255,25 +204,25 @@ def _check_best_match(b, m): _check_best_match(3, 2) # Ignore the currently-pinned cluster. - cluster_metadata.set_group(3, 0) + cluster_meta_up.set_group(3, 0) _check_best_match(5, 2) # 2: none, 3: ignored, 5: ignored, 7: good # Ignore the current match and move to next. - cluster_metadata.set_group(2, 1) + cluster_meta_up.set_group(2, 1) _check_best_match(5, 7) # 2: ignored, 3: ignored, 5: ignored, 7: good - cluster_metadata.undo() + cluster_meta_up.undo() _check_best_match(5, 2) - cluster_metadata.redo() + cluster_meta_up.redo() _check_best_match(5, 7) -def test_wizard_update_clustering(wizard, clustering, cluster_metadata): +def test_wizard_update_clustering(wizard, clustering, cluster_meta_up): # 2: none, 3: none, 5: ignored, 7: good - wizard.attach(clustering, cluster_metadata) + wizard.attach(clustering, cluster_meta_up) wizard.start() def _check_best_match(b, m): @@ -286,7 +235,7 @@ def _check_best_match(b, m): wizard.pin() _check_best_match(2, 3) - cluster_metadata.set_group(2, 2) + cluster_meta_up.set_group(2, 2) wizard.selection = [2, 3] ################################ @@ -326,7 +275,7 @@ def _check_best_match(b, m): assert wizard.cluster_status(9) is None # Ignore a cluster. - cluster_metadata.set_group(9, 1) + cluster_meta_up.set_group(9, 1) assert wizard.cluster_status(9) == 'ignored' # Undo split. From 50b02d704c144d6d5abaa2fa92cf4c0b700c78d6 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 5 Oct 2015 18:50:54 +0200 Subject: [PATCH 0212/1059] Flakify --- phy/cluster/manual/tests/conftest.py | 5 ----- phy/cluster/manual/tests/test_gui_plugins.py | 5 +++-- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/phy/cluster/manual/tests/conftest.py b/phy/cluster/manual/tests/conftest.py index 589f2cc7b..d0ffffb9f 100644 --- a/phy/cluster/manual/tests/conftest.py +++ b/phy/cluster/manual/tests/conftest.py @@ -17,11 +17,6 @@ # Fixtures #------------------------------------------------------------------------------ -@yield_fixture -def cluster_ids(): - yield [2, 3, 5, 7] - - @yield_fixture def spike_clusters(): yield [2, 3, 5, 7] diff --git a/phy/cluster/manual/tests/test_gui_plugins.py b/phy/cluster/manual/tests/test_gui_plugins.py index 34fe69d1e..e01cd96e6 100644 --- a/phy/cluster/manual/tests/test_gui_plugins.py +++ b/phy/cluster/manual/tests/test_gui_plugins.py @@ -6,14 +6,15 @@ # Imports #------------------------------------------------------------------------------ -from phy.gui.tests.test_gui import gui +from phy.gui.tests.test_gui import gui # noqa #------------------------------------------------------------------------------ # Test GUI plugins #------------------------------------------------------------------------------ -def test_manual_clustering(qtbot, gui, spike_clusters, cluster_metadata): +def test_manual_clustering(qtbot, gui, spike_clusters, # noqa + cluster_metadata): gui.attach('ManualClustering', spike_clusters=spike_clusters, cluster_metadata=cluster_metadata, From 3e7491cdc89bd763e3eb88db797b575c257e5070 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 5 Oct 2015 18:57:41 +0200 Subject: [PATCH 0213/1059] WIP: tests of the manual clustering plugin --- phy/cluster/manual/gui_plugins.py | 15 +++++++++++---- phy/cluster/manual/tests/test_gui_plugins.py | 11 +++++++---- 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/phy/cluster/manual/gui_plugins.py b/phy/cluster/manual/gui_plugins.py index 39786c86b..dbeaacebc 100644 --- a/phy/cluster/manual/gui_plugins.py +++ b/phy/cluster/manual/gui_plugins.py @@ -53,12 +53,13 @@ def attach_to_gui(self, gui, n_spikes_max_per_cluster=100, ): # Create Clustering and ClusterMetadataUpdater. - clustering = Clustering(spike_clusters) + self.clustering = Clustering(spike_clusters) + self.cluster_metadata = cluster_metadata cluster_meta_up = ClusterMetadataUpdater(cluster_metadata) # Create the wizard and attach it to Clustering/ClusterMetadataUpdater. wizard = Wizard() - wizard.attach(clustering, cluster_meta_up) + wizard.attach(self.clustering, cluster_meta_up) @wizard.connect def on_select(cluster_ids): @@ -70,11 +71,17 @@ def on_select(cluster_ids): """ spike_ids = select_spikes(cluster_ids, n_spikes_max_per_cluster, - clustering.spikes_per_cluster) + self.clustering.spikes_per_cluster) gui.emit('select', cluster_ids, spike_ids) self.create_actions(gui) + return self + + @property + def cluster_ids(self): + return self.clustering.cluster_ids + def create_actions(self, gui): actions = Actions() snippets = Snippets() @@ -82,7 +89,7 @@ def create_actions(self, gui): # Create the default actions for the clustering GUI. @actions.connect def on_reset(): - actions.add(alias='s', callback=self.select) + actions.add(callback=self.select, alias='s') # TODO: other actions # Attach the GUI and register the actions. diff --git a/phy/cluster/manual/tests/test_gui_plugins.py b/phy/cluster/manual/tests/test_gui_plugins.py index e01cd96e6..ad94e51db 100644 --- a/phy/cluster/manual/tests/test_gui_plugins.py +++ b/phy/cluster/manual/tests/test_gui_plugins.py @@ -6,6 +6,8 @@ # Imports #------------------------------------------------------------------------------ +from numpy.testing import assert_array_equal as ae + from phy.gui.tests.test_gui import gui # noqa @@ -15,7 +17,8 @@ def test_manual_clustering(qtbot, gui, spike_clusters, # noqa cluster_metadata): - gui.attach('ManualClustering', - spike_clusters=spike_clusters, - cluster_metadata=cluster_metadata, - ) + mc = gui.attach('ManualClustering', + spike_clusters=spike_clusters, + cluster_metadata=cluster_metadata, + ) + ae(mc.cluster_ids, [2, 3, 5, 7]) From b85d9d645bf1cb2d24036e53d6e8fc06afa2a11b Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 5 Oct 2015 19:06:08 +0200 Subject: [PATCH 0214/1059] WIP: tests of the manual clustering plugin --- phy/cluster/manual/gui_plugins.py | 53 ++++++++++++++------ phy/cluster/manual/tests/test_gui_plugins.py | 11 ++++ 2 files changed, 49 insertions(+), 15 deletions(-) diff --git a/phy/cluster/manual/gui_plugins.py b/phy/cluster/manual/gui_plugins.py index dbeaacebc..5ee1c131b 100644 --- a/phy/cluster/manual/gui_plugins.py +++ b/phy/cluster/manual/gui_plugins.py @@ -47,6 +47,20 @@ def group(cluster, ascendant_values=None): # ----------------------------------------------------------------------------- class ManualClustering(IPlugin): + """Plugin that brings manual clustering facilities to a GUI: + + * Clustering instance: merge, split, undo, redo + * ClusterMetadataUpdater instance: change cluster metadata (e.g. group) + * Wizard + * Selection + * Many manual clustering-related actions, snippets, shortcuts, etc. + + Bring the `select` event to the GUI. This is raised when clusters are + selected by the user or by the wizard. + + Other plugins can connect to that event. + + """ def attach_to_gui(self, gui, spike_clusters=None, cluster_metadata=None, @@ -58,10 +72,10 @@ def attach_to_gui(self, gui, cluster_meta_up = ClusterMetadataUpdater(cluster_metadata) # Create the wizard and attach it to Clustering/ClusterMetadataUpdater. - wizard = Wizard() - wizard.attach(self.clustering, cluster_meta_up) + self.wizard = Wizard() + self.wizard.attach(self.clustering, cluster_meta_up) - @wizard.connect + @self.wizard.connect def on_select(cluster_ids): """When the wizard selects clusters, choose a spikes subset and emit the `select` event on the GUI. @@ -97,20 +111,11 @@ def on_reset(): actions.attach(gui) actions.reset() - def toggle_correlogram_normalization(self): - pass - - def toggle_waveforms_mean(self): - pass - - def toggle_waveforms_overlap(self): - pass - - def show_features_time(self): - pass + # Wizard-related actions + # ------------------------------------------------------------------------- def select(self, cluster_ids): - pass + self.wizard.selection = cluster_ids def reset_wizard(self): pass @@ -133,6 +138,9 @@ def pin(self): def unpin(self): pass + # Clustering actions + # ------------------------------------------------------------------------- + def merge(self, cluster_ids=None): pass @@ -147,3 +155,18 @@ def undo(self): def redo(self): pass + + # View-related actions + # ------------------------------------------------------------------------- + + def toggle_correlogram_normalization(self): + pass + + def toggle_waveforms_mean(self): + pass + + def toggle_waveforms_overlap(self): + pass + + def show_features_time(self): + pass diff --git a/phy/cluster/manual/tests/test_gui_plugins.py b/phy/cluster/manual/tests/test_gui_plugins.py index ad94e51db..5b9a04673 100644 --- a/phy/cluster/manual/tests/test_gui_plugins.py +++ b/phy/cluster/manual/tests/test_gui_plugins.py @@ -22,3 +22,14 @@ def test_manual_clustering(qtbot, gui, spike_clusters, # noqa cluster_metadata=cluster_metadata, ) ae(mc.cluster_ids, [2, 3, 5, 7]) + + # Connect to the `select` event. + _s = [] + + @gui.connect_ + def on_select(cluster_ids, spike_ids): + _s.append((cluster_ids, spike_ids)) + + mc.select([]) + ae(_s[-1][0], []) + ae(_s[-1][1], []) From 239af845467c09570649109658669349a64a0f80 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 5 Oct 2015 19:15:38 +0200 Subject: [PATCH 0215/1059] WIP: wizard actions in manual clustering GUI --- phy/cluster/manual/gui_plugins.py | 40 ++++++++++++-------- phy/cluster/manual/tests/conftest.py | 9 +++-- phy/cluster/manual/tests/test_gui_plugins.py | 11 +++++- 3 files changed, 41 insertions(+), 19 deletions(-) diff --git a/phy/cluster/manual/gui_plugins.py b/phy/cluster/manual/gui_plugins.py index 5ee1c131b..9bff7e3d4 100644 --- a/phy/cluster/manual/gui_plugins.py +++ b/phy/cluster/manual/gui_plugins.py @@ -103,7 +103,14 @@ def create_actions(self, gui): # Create the default actions for the clustering GUI. @actions.connect def on_reset(): - actions.add(callback=self.select, alias='s') + actions.add(callback=self.select, alias='c') + actions.add(callback=self.wizard.start, name='reset_wizard') + actions.add(callback=self.wizard.first) + actions.add(callback=self.wizard.last) + actions.add(callback=self.wizard.previous) + actions.add(callback=self.wizard.next) + actions.add(callback=self.wizard.pin) + actions.add(callback=self.wizard.unpin) # TODO: other actions # Attach the GUI and register the actions. @@ -111,32 +118,35 @@ def on_reset(): actions.attach(gui) actions.reset() + self.actions = actions + self.snippets = snippets + # Wizard-related actions # ------------------------------------------------------------------------- def select(self, cluster_ids): self.wizard.selection = cluster_ids - def reset_wizard(self): - pass + # def reset_wizard(self): + # self.wizard.start() - def first(self): - pass + # def first(self): + # self.wizard.first() - def last(self): - pass + # def last(self): + # self.wizard.last() - def next(self): - pass + # def next(self): + # self.wizard.next() - def previous(self): - pass + # def previous(self): + # self.wizard.previous() - def pin(self): - pass + # def pin(self): + # self.wizard.pin() - def unpin(self): - pass + # def unpin(self): + # self.wizard.unpin() # Clustering actions # ------------------------------------------------------------------------- diff --git a/phy/cluster/manual/tests/conftest.py b/phy/cluster/manual/tests/conftest.py index d0ffffb9f..eb7e3558d 100644 --- a/phy/cluster/manual/tests/conftest.py +++ b/phy/cluster/manual/tests/conftest.py @@ -38,13 +38,11 @@ def cluster_metadata(): yield create_cluster_metadata(data) -@yield_fixture -def wizard(): +def _set_test_wizard(wizard): def get_cluster_ids(): return [2, 3, 5, 7] - wizard = Wizard() wizard.set_cluster_ids_function(get_cluster_ids) @wizard.set_status_function @@ -59,4 +57,9 @@ def quality(cluster): def similarity(cluster, other): return 1. + quality(cluster) - quality(other) + +@yield_fixture +def wizard(): + wizard = Wizard() + _set_test_wizard(wizard) yield wizard diff --git a/phy/cluster/manual/tests/test_gui_plugins.py b/phy/cluster/manual/tests/test_gui_plugins.py index 5b9a04673..1c315fe43 100644 --- a/phy/cluster/manual/tests/test_gui_plugins.py +++ b/phy/cluster/manual/tests/test_gui_plugins.py @@ -8,6 +8,7 @@ from numpy.testing import assert_array_equal as ae +from .conftest import _set_test_wizard from phy.gui.tests.test_gui import gui # noqa @@ -21,6 +22,10 @@ def test_manual_clustering(qtbot, gui, spike_clusters, # noqa spike_clusters=spike_clusters, cluster_metadata=cluster_metadata, ) + _set_test_wizard(mc.wizard) + actions = mc.actions + + # Test cluster ids. ae(mc.cluster_ids, [2, 3, 5, 7]) # Connect to the `select` event. @@ -30,6 +35,10 @@ def test_manual_clustering(qtbot, gui, spike_clusters, # noqa def on_select(cluster_ids, spike_ids): _s.append((cluster_ids, spike_ids)) - mc.select([]) + # Test select actions. + actions.select([]) ae(_s[-1][0], []) ae(_s[-1][1], []) + + # Test wizard actions. + actions.reset_wizard() From 662a17f8bbcf07a0684d06a0729d0f07133e0d5b Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 5 Oct 2015 19:22:47 +0200 Subject: [PATCH 0216/1059] WIP: wizard actions in manual clustering GUI --- phy/cluster/manual/gui_plugins.py | 2 + phy/cluster/manual/tests/test_gui_plugins.py | 39 +++++++++++++++++--- 2 files changed, 36 insertions(+), 5 deletions(-) diff --git a/phy/cluster/manual/gui_plugins.py b/phy/cluster/manual/gui_plugins.py index 9bff7e3d4..a8442cd4f 100644 --- a/phy/cluster/manual/gui_plugins.py +++ b/phy/cluster/manual/gui_plugins.py @@ -66,6 +66,8 @@ def attach_to_gui(self, gui, cluster_metadata=None, n_spikes_max_per_cluster=100, ): + self.gui = gui + # Create Clustering and ClusterMetadataUpdater. self.clustering = Clustering(spike_clusters) self.cluster_metadata = cluster_metadata diff --git a/phy/cluster/manual/tests/test_gui_plugins.py b/phy/cluster/manual/tests/test_gui_plugins.py index 1c315fe43..5b834e5fb 100644 --- a/phy/cluster/manual/tests/test_gui_plugins.py +++ b/phy/cluster/manual/tests/test_gui_plugins.py @@ -6,6 +6,7 @@ # Imports #------------------------------------------------------------------------------ +from pytest import yield_fixture from numpy.testing import assert_array_equal as ae from .conftest import _set_test_wizard @@ -16,22 +17,28 @@ # Test GUI plugins #------------------------------------------------------------------------------ -def test_manual_clustering(qtbot, gui, spike_clusters, # noqa - cluster_metadata): +@yield_fixture +def manual_clustering(qtbot, gui, spike_clusters, # noqa + cluster_metadata): mc = gui.attach('ManualClustering', spike_clusters=spike_clusters, cluster_metadata=cluster_metadata, ) _set_test_wizard(mc.wizard) - actions = mc.actions + yield mc + + +def test_manual_clustering(manual_clustering): + actions = manual_clustering.actions + wizard = manual_clustering.wizard # Test cluster ids. - ae(mc.cluster_ids, [2, 3, 5, 7]) + ae(manual_clustering.cluster_ids, [2, 3, 5, 7]) # Connect to the `select` event. _s = [] - @gui.connect_ + @manual_clustering.gui.connect_ def on_select(cluster_ids, spike_ids): _s.append((cluster_ids, spike_ids)) @@ -42,3 +49,25 @@ def on_select(cluster_ids, spike_ids): # Test wizard actions. actions.reset_wizard() + assert wizard.best_list == [3, 2, 7, 5] + assert wizard.best == 3 + + actions.next() + assert wizard.best == 2 + + actions.last() + assert wizard.best == 5 + + actions.previous() + assert wizard.best == 7 + + actions.first() + assert wizard.best == 3 + + # Test pinning. + actions.pin() + assert wizard.match_list == [2, 7, 5] + assert wizard.match == 2 + wizard.next() + assert wizard.match == 7 + assert len(_s) == 9 From 4b8fa58206f0c9a39f695df896eb7098b0e29a56 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 5 Oct 2015 19:28:19 +0200 Subject: [PATCH 0217/1059] WIP: test wizard actions in manual clustering plugin --- phy/cluster/manual/gui_plugins.py | 25 +++----------------- phy/cluster/manual/tests/test_gui_plugins.py | 14 +++++++++-- 2 files changed, 15 insertions(+), 24 deletions(-) diff --git a/phy/cluster/manual/gui_plugins.py b/phy/cluster/manual/gui_plugins.py index a8442cd4f..b48cc2e82 100644 --- a/phy/cluster/manual/gui_plugins.py +++ b/phy/cluster/manual/gui_plugins.py @@ -129,32 +129,13 @@ def on_reset(): def select(self, cluster_ids): self.wizard.selection = cluster_ids - # def reset_wizard(self): - # self.wizard.start() - - # def first(self): - # self.wizard.first() - - # def last(self): - # self.wizard.last() - - # def next(self): - # self.wizard.next() - - # def previous(self): - # self.wizard.previous() - - # def pin(self): - # self.wizard.pin() - - # def unpin(self): - # self.wizard.unpin() - # Clustering actions # ------------------------------------------------------------------------- def merge(self, cluster_ids=None): - pass + if cluster_ids is None: + cluster_ids = self.wizard.selection + self.clustering.merge(cluster_ids) def split(self, spike_ids=None): pass diff --git a/phy/cluster/manual/tests/test_gui_plugins.py b/phy/cluster/manual/tests/test_gui_plugins.py index 5b834e5fb..2cf03501c 100644 --- a/phy/cluster/manual/tests/test_gui_plugins.py +++ b/phy/cluster/manual/tests/test_gui_plugins.py @@ -42,32 +42,42 @@ def test_manual_clustering(manual_clustering): def on_select(cluster_ids, spike_ids): _s.append((cluster_ids, spike_ids)) + def _assert_selection(*cluster_ids): + assert _s[-1][0] == list(cluster_ids) + # Test select actions. actions.select([]) - ae(_s[-1][0], []) - ae(_s[-1][1], []) + _assert_selection() # Test wizard actions. actions.reset_wizard() assert wizard.best_list == [3, 2, 7, 5] assert wizard.best == 3 + _assert_selection(3) actions.next() assert wizard.best == 2 + _assert_selection(2) actions.last() assert wizard.best == 5 + _assert_selection(5) actions.previous() assert wizard.best == 7 + _assert_selection(7) actions.first() assert wizard.best == 3 + _assert_selection(3) # Test pinning. actions.pin() assert wizard.match_list == [2, 7, 5] assert wizard.match == 2 + _assert_selection(3, 2) + wizard.next() assert wizard.match == 7 assert len(_s) == 9 + _assert_selection(3, 7) From 8308cef557ecbe2d2a6cae9375439aa0d95c2089 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 5 Oct 2015 21:41:35 +0200 Subject: [PATCH 0218/1059] WIP: add clustering action tests --- phy/cluster/manual/gui_plugins.py | 16 ++++-- phy/cluster/manual/tests/conftest.py | 20 ++++---- phy/cluster/manual/tests/test_gui_plugins.py | 54 ++++++++++++++------ 3 files changed, 63 insertions(+), 27 deletions(-) diff --git a/phy/cluster/manual/gui_plugins.py b/phy/cluster/manual/gui_plugins.py index b48cc2e82..0d9c0672e 100644 --- a/phy/cluster/manual/gui_plugins.py +++ b/phy/cluster/manual/gui_plugins.py @@ -105,7 +105,10 @@ def create_actions(self, gui): # Create the default actions for the clustering GUI. @actions.connect def on_reset(): + # Selection. actions.add(callback=self.select, alias='c') + + # Wizard. actions.add(callback=self.wizard.start, name='reset_wizard') actions.add(callback=self.wizard.first) actions.add(callback=self.wizard.last) @@ -113,7 +116,13 @@ def on_reset(): actions.add(callback=self.wizard.next) actions.add(callback=self.wizard.pin) actions.add(callback=self.wizard.unpin) - # TODO: other actions + + # Clustering. + actions.add(callback=self.merge) + actions.add(callback=self.split) + actions.add(callback=self.move) + actions.add(callback=self.undo) + actions.add(callback=self.redo) # Attach the GUI and register the actions. snippets.attach(gui, actions) @@ -137,10 +146,11 @@ def merge(self, cluster_ids=None): cluster_ids = self.wizard.selection self.clustering.merge(cluster_ids) - def split(self, spike_ids=None): - pass + def split(self, spike_ids): + self.clustering.split(spike_ids) def move(self, clusters, group): + # TODO pass def undo(self): diff --git a/phy/cluster/manual/tests/conftest.py b/phy/cluster/manual/tests/conftest.py index eb7e3558d..697ff5ae5 100644 --- a/phy/cluster/manual/tests/conftest.py +++ b/phy/cluster/manual/tests/conftest.py @@ -40,15 +40,6 @@ def cluster_metadata(): def _set_test_wizard(wizard): - def get_cluster_ids(): - return [2, 3, 5, 7] - - wizard.set_cluster_ids_function(get_cluster_ids) - - @wizard.set_status_function - def cluster_status(cluster): - return {2: None, 3: None, 5: 'ignored', 7: 'good'}.get(cluster, None) - @wizard.set_quality_function def quality(cluster): return cluster * .1 @@ -61,5 +52,16 @@ def similarity(cluster, other): @yield_fixture def wizard(): wizard = Wizard() + + def get_cluster_ids(): + return [2, 3, 5, 7] + + wizard.set_cluster_ids_function(get_cluster_ids) + + @wizard.set_status_function + def cluster_status(cluster): + return {2: None, 3: None, 5: 'ignored', 7: 'good'}.get(cluster, None) + _set_test_wizard(wizard) + yield wizard diff --git a/phy/cluster/manual/tests/test_gui_plugins.py b/phy/cluster/manual/tests/test_gui_plugins.py index 2cf03501c..534cd5f87 100644 --- a/phy/cluster/manual/tests/test_gui_plugins.py +++ b/phy/cluster/manual/tests/test_gui_plugins.py @@ -17,34 +17,37 @@ # Test GUI plugins #------------------------------------------------------------------------------ -@yield_fixture -def manual_clustering(qtbot, gui, spike_clusters, # noqa - cluster_metadata): +@yield_fixture # noqa +def manual_clustering(qtbot, gui, spike_clusters, cluster_metadata): mc = gui.attach('ManualClustering', spike_clusters=spike_clusters, cluster_metadata=cluster_metadata, ) _set_test_wizard(mc.wizard) - yield mc - - -def test_manual_clustering(manual_clustering): - actions = manual_clustering.actions - wizard = manual_clustering.wizard - # Test cluster ids. - ae(manual_clustering.cluster_ids, [2, 3, 5, 7]) - - # Connect to the `select` event. _s = [] - @manual_clustering.gui.connect_ + # Connect to the `select` event. + @mc.gui.connect_ def on_select(cluster_ids, spike_ids): _s.append((cluster_ids, spike_ids)) def _assert_selection(*cluster_ids): assert _s[-1][0] == list(cluster_ids) + mc._assert_selection = _assert_selection + + yield mc + + +def test_manual_clustering_wizard(manual_clustering): + actions = manual_clustering.actions + wizard = manual_clustering.wizard + _assert_selection = manual_clustering._assert_selection + + # Test cluster ids. + ae(manual_clustering.cluster_ids, [2, 3, 5, 7]) + # Test select actions. actions.select([]) _assert_selection() @@ -63,6 +66,10 @@ def _assert_selection(*cluster_ids): assert wizard.best == 5 _assert_selection(5) + actions.next() + assert wizard.best == 5 + _assert_selection(5) + actions.previous() assert wizard.best == 7 _assert_selection(7) @@ -71,6 +78,10 @@ def _assert_selection(*cluster_ids): assert wizard.best == 3 _assert_selection(3) + actions.previous() + assert wizard.best == 3 + _assert_selection(3) + # Test pinning. actions.pin() assert wizard.match_list == [2, 7, 5] @@ -79,5 +90,18 @@ def _assert_selection(*cluster_ids): wizard.next() assert wizard.match == 7 - assert len(_s) == 9 _assert_selection(3, 7) + + wizard.unpin() + _assert_selection(3) + + +def test_manual_clustering_actions(manual_clustering): + actions = manual_clustering.actions + # wizard = manual_clustering.wizard + _assert_selection = manual_clustering._assert_selection + + actions.reset_wizard() + actions.pin() + _assert_selection(3, 2) + actions.merge() From 209411750e62501086262b5086375e3f76c39323 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 5 Oct 2015 21:53:17 +0200 Subject: [PATCH 0219/1059] More tests --- phy/cluster/manual/gui_plugins.py | 4 ++-- phy/cluster/manual/tests/test_gui_plugins.py | 17 +++++++++++++++-- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/phy/cluster/manual/gui_plugins.py b/phy/cluster/manual/gui_plugins.py index 0d9c0672e..f96cc6bd2 100644 --- a/phy/cluster/manual/gui_plugins.py +++ b/phy/cluster/manual/gui_plugins.py @@ -154,10 +154,10 @@ def move(self, clusters, group): pass def undo(self): - pass + self.clustering.undo() def redo(self): - pass + self.clustering.redo() # View-related actions # ------------------------------------------------------------------------- diff --git a/phy/cluster/manual/tests/test_gui_plugins.py b/phy/cluster/manual/tests/test_gui_plugins.py index 534cd5f87..a4ffb1a6e 100644 --- a/phy/cluster/manual/tests/test_gui_plugins.py +++ b/phy/cluster/manual/tests/test_gui_plugins.py @@ -98,10 +98,23 @@ def test_manual_clustering_wizard(manual_clustering): def test_manual_clustering_actions(manual_clustering): actions = manual_clustering.actions - # wizard = manual_clustering.wizard + wizard = manual_clustering.wizard _assert_selection = manual_clustering._assert_selection + # [3 , 2 , 7 , 5] + # [None, None, 'ignored', 'good'] actions.reset_wizard() actions.pin() _assert_selection(3, 2) - actions.merge() + + actions.merge() # 3 + 2 => 8 + # [8, 7, 5] + _assert_selection(8, 7) + + wizard.next() + _assert_selection(8, 5) + + actions.undo() + _assert_selection(3, 2) + + # TODO: more tests, notably with group actions and wizard From efdfc400af6de343d32da845ca034becb97e1a60 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 5 Oct 2015 21:58:51 +0200 Subject: [PATCH 0220/1059] Increase coverage --- phy/cluster/manual/gui_plugins.py | 16 +--------------- phy/cluster/manual/tests/test_gui_plugins.py | 7 +++++++ 2 files changed, 8 insertions(+), 15 deletions(-) diff --git a/phy/cluster/manual/gui_plugins.py b/phy/cluster/manual/gui_plugins.py index f96cc6bd2..eb158f231 100644 --- a/phy/cluster/manual/gui_plugins.py +++ b/phy/cluster/manual/gui_plugins.py @@ -147,6 +147,7 @@ def merge(self, cluster_ids=None): self.clustering.merge(cluster_ids) def split(self, spike_ids): + # TODO: connect to request_split emitted by view self.clustering.split(spike_ids) def move(self, clusters, group): @@ -158,18 +159,3 @@ def undo(self): def redo(self): self.clustering.redo() - - # View-related actions - # ------------------------------------------------------------------------- - - def toggle_correlogram_normalization(self): - pass - - def toggle_waveforms_mean(self): - pass - - def toggle_waveforms_overlap(self): - pass - - def show_features_time(self): - pass diff --git a/phy/cluster/manual/tests/test_gui_plugins.py b/phy/cluster/manual/tests/test_gui_plugins.py index a4ffb1a6e..2a8b3924e 100644 --- a/phy/cluster/manual/tests/test_gui_plugins.py +++ b/phy/cluster/manual/tests/test_gui_plugins.py @@ -117,4 +117,11 @@ def test_manual_clustering_actions(manual_clustering): actions.undo() _assert_selection(3, 2) + actions.redo() + _assert_selection(8, 7) + + actions.split([2, 3]) # => 9 + _assert_selection(9, 8) + # TODO: more tests, notably with group actions and wizard + actions.move([9], 'good') From e525bb919a4fa6ee586454e18b15bb26aad3ae5b Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 6 Oct 2015 15:59:12 +0200 Subject: [PATCH 0221/1059] Refactor ClusterMeta --- phy/cluster/manual/_utils.py | 186 ++++++------------- phy/cluster/manual/gui_plugins.py | 30 +-- phy/cluster/manual/tests/conftest.py | 6 +- phy/cluster/manual/tests/test_gui_plugins.py | 4 +- phy/cluster/manual/tests/test_utils.py | 43 ++--- phy/cluster/manual/tests/test_wizard.py | 35 ++-- phy/cluster/manual/wizard.py | 10 +- 7 files changed, 116 insertions(+), 198 deletions(-) diff --git a/phy/cluster/manual/_utils.py b/phy/cluster/manual/_utils.py index c93742286..590b99586 100644 --- a/phy/cluster/manual/_utils.py +++ b/phy/cluster/manual/_utils.py @@ -81,88 +81,71 @@ def __repr__(self): # ClusterMetadataUpdater class #------------------------------------------------------------------------------ -class ClusterMetadata(object): - """Hold cluster metadata. +class ClusterMeta(EventEmitter): + """Handle cluster metadata changes.""" + def __init__(self): + super(ClusterMeta, self).__init__() + self._fields = {} + self._reset_data() - Features - -------- + def _reset_data(self): + self._data = {} + self._data_base = {} + # The stack contains (clusters, field, value, update_info, undo_state) + # tuples. + self._undo_stack = History((None, None, None, None, None)) - * New metadata fields can be easily registered - * Arbitrary functions can be used for default values + @property + def fields(self): + return sorted(self._fields.keys()) - Notes - ---- + def add_field(self, name, default_value=None, ascendants_func=None): + self._fields[name] = (default_value, ascendants_func) - If a metadata field `group` is registered, then two methods are - dynamically created: + def func(cluster): + return self.get(name, cluster) - * `group(cluster)` returns the group of a cluster, or the default value - if the cluster doesn't exist. - * `set_group(cluster, value)` sets a value for the `group` metadata field. + setattr(self, name, func) - """ - def __init__(self, data=None): - self._fields = {} - self._data = defaultdict(dict) - # Fill the existing values. - if data is not None: - self._data.update(data) + def from_dict(self, dic): + self._reset_data() + for cluster, vals in dic.items(): + for field, value in vals.items(): + self.set(field, [cluster], value, add_to_stack=False) + self._data_base = deepcopy(self._data) - @property - def data(self): - return self._data + def set(self, field, clusters, value, add_to_stack=True): + assert field in self._fields - @property - def fields(self): - return sorted(self._fields) - - def _get_one(self, cluster, field): - """Return the field value for a cluster, or the default value if it - doesn't exist.""" - if cluster in self._data: - if field in self._data[cluster]: - return self._data[cluster][field] - elif field in self._fields: - # Call the default field function. - return self._fields[field](cluster) - else: - if field in self._fields: - return self._fields[field](cluster) - - def _get(self, clusters, field): - if _is_list(clusters): - return [self._get_one(cluster, field) - for cluster in _as_list(clusters)] - else: - return self._get_one(clusters, field) - - def _set_one(self, cluster, field, value): - """Set a field value for a cluster.""" - self._data[cluster][field] = value - - def _set(self, clusters, field, value): clusters = _as_list(clusters) for cluster in clusters: - self._set_one(cluster, field, value) - - def default(self, func): - """Register a new metadata field with a function - returning the default value of a cluster.""" - field = func.__name__ - # Register the decorated function as the default field function. - self._fields[field] = func - # Create self.(clusters). - setattr(self, field, lambda clusters: self._get(clusters, field)) - # Create self.set_(clusters, value). - setattr(self, 'set_{0:s}'.format(field), - lambda clusters, value: self._set(clusters, field, value)) - return func + if cluster not in self._data: + self._data[cluster] = {} + self._data[cluster][field] = value + + up = UpdateInfo(description='metadata_' + field, + metadata_changed=clusters, + metadata_value=value, + ) + undo_state = self.emit('request_undo_state', up) + + if add_to_stack: + self._undo_stack.add((clusters, field, value, up, undo_state)) + self.emit('cluster', up) + + return up + + def get(self, field, cluster): + if _is_list(cluster): + return [self.get(field, c) for c in cluster] + assert field in self._fields + default = self._fields[field][0] + return self._data.get(cluster, {}).get(field, default) def set_from_descendants(self, descendants): """Update metadata of some clusters given the metadata of their ascendants.""" - fields = self.fields - for field in fields: + for field in self.fields: # For each new cluster, a set of metadata values of their # ascendants. candidates = defaultdict(set) @@ -171,66 +154,19 @@ def set_from_descendants(self, descendants): for new, vals in candidates.items(): # Skip that new cluster if its value is already non-default. - current_val = self._get_one(new, field) - default_val = self._fields[field](new) + current_val = self.get(field, new) + default_val = self._fields[field][0] if current_val != default_val: continue # Ask the field the value of the new cluster, # as a function of the values of its ascendants. This is # encoded in the specified default function. - new_val = self._fields[field](new, list(vals)) - if new_val is not None: - self._set_one(new, field, new_val) - - -class ClusterMetadataUpdater(EventEmitter): - """Handle cluster metadata changes.""" - def __init__(self, cluster_metadata): - super(ClusterMetadataUpdater, self).__init__() - self._cluster_metadata = cluster_metadata - # Keep a deep copy of the original structure for the undo stack. - self._data_base = deepcopy(cluster_metadata.data) - # The stack contains (clusters, field, value, update_info, undo_state) - # tuples. - self._undo_stack = History((None, None, None, None, None)) - - for field, func in self._cluster_metadata._fields.items(): - - # Create self.(clusters). - def _make_get(field): - def f(clusters): - return self._cluster_metadata._get(clusters, field) - return f - setattr(self, field, _make_get(field)) - - # Create self.set_(clusters, value). - def _make_set(field): - def f(clusters, value): - return self._set(clusters, field, value) - return f - setattr(self, 'set_{0:s}'.format(field), _make_set(field)) - - def _set(self, clusters, field, value, add_to_stack=True): - - self._cluster_metadata._set(clusters, field, value) - clusters = _as_list(clusters) - up = UpdateInfo(description='metadata_' + field, - metadata_changed=clusters, - metadata_value=value, - ) - undo_state = self.emit('request_undo_state', up) - - if add_to_stack: - self._undo_stack.add((clusters, field, value, up, undo_state)) - self.emit('cluster', up) - - return up - - def set_from_descendants(self, descendants): - """Update metadata of some clusters given the metadata of their - ascendants.""" - self._cluster_metadata.set_from_descendants(descendants) + func = self._fields[field][1] + if func: + new_val = func(list(vals)) + if new_val is not None: + self.set(field, [new], new_val) def undo(self): """Undo the last metadata change. @@ -244,10 +180,10 @@ def undo(self): args = self._undo_stack.back() if args is None: return - self._cluster_metadata._data = deepcopy(self._data_base) + self._data = deepcopy(self._data_base) for clusters, field, value, up, undo_state in self._undo_stack: if clusters is not None: - self._set(clusters, field, value, add_to_stack=False) + self.set(field, clusters, value, add_to_stack=False) # Return the UpdateInfo instance of the undo action. up, undo_state = args[-2:] @@ -269,7 +205,7 @@ def redo(self): if args is None: return clusters, field, value, up, undo_state = args - self._set(clusters, field, value, add_to_stack=False) + self.set(field, clusters, value, add_to_stack=False) # Return the UpdateInfo instance of the redo action. up.history = 'redo' diff --git a/phy/cluster/manual/gui_plugins.py b/phy/cluster/manual/gui_plugins.py index eb158f231..40eecf8b4 100644 --- a/phy/cluster/manual/gui_plugins.py +++ b/phy/cluster/manual/gui_plugins.py @@ -9,7 +9,7 @@ import logging -from ._utils import ClusterMetadata, ClusterMetadataUpdater +from ._utils import ClusterMeta from .clustering import Clustering from .wizard import Wizard from phy.gui.actions import Actions, Snippets @@ -23,13 +23,12 @@ # Clustering objects # ----------------------------------------------------------------------------- -def create_cluster_metadata(data): - """Return a ClusterMetadata instance with cluster group support.""" - meta = ClusterMetadata(data=data) +def create_cluster_meta(data): + """Return a ClusterMeta instance with cluster group support.""" + meta = ClusterMeta() - @meta.default - def group(cluster, ascendant_values=None): - if not ascendant_values: + def group(ascendant_values=None): + if not ascendant_values: # pragma: no cover return 3 s = list(set(ascendant_values) - set([None, 3])) # Return the default value if all ascendant values are the default. @@ -39,6 +38,10 @@ def group(cluster, ascendant_values=None): # among those present. return max(s) + meta.add_field('group', 3, group) + + meta.from_dict(data) + return meta @@ -50,7 +53,7 @@ class ManualClustering(IPlugin): """Plugin that brings manual clustering facilities to a GUI: * Clustering instance: merge, split, undo, redo - * ClusterMetadataUpdater instance: change cluster metadata (e.g. group) + * ClusterMeta instance: change cluster metadata (e.g. group) * Wizard * Selection * Many manual clustering-related actions, snippets, shortcuts, etc. @@ -63,19 +66,18 @@ class ManualClustering(IPlugin): """ def attach_to_gui(self, gui, spike_clusters=None, - cluster_metadata=None, + cluster_meta=None, n_spikes_max_per_cluster=100, ): self.gui = gui - # Create Clustering and ClusterMetadataUpdater. + # Create Clustering and ClusterMeta. self.clustering = Clustering(spike_clusters) - self.cluster_metadata = cluster_metadata - cluster_meta_up = ClusterMetadataUpdater(cluster_metadata) + self.cluster_meta = cluster_meta - # Create the wizard and attach it to Clustering/ClusterMetadataUpdater. + # Create the wizard and attach it to Clustering/ClusterMeta. self.wizard = Wizard() - self.wizard.attach(self.clustering, cluster_meta_up) + self.wizard.attach(self.clustering, cluster_meta) @self.wizard.connect def on_select(cluster_ids): diff --git a/phy/cluster/manual/tests/conftest.py b/phy/cluster/manual/tests/conftest.py index 697ff5ae5..71b03093c 100644 --- a/phy/cluster/manual/tests/conftest.py +++ b/phy/cluster/manual/tests/conftest.py @@ -9,7 +9,7 @@ from pytest import yield_fixture from ..clustering import Clustering -from ..gui_plugins import create_cluster_metadata +from ..gui_plugins import create_cluster_meta from ..wizard import Wizard @@ -28,14 +28,14 @@ def clustering(spike_clusters): @yield_fixture -def cluster_metadata(): +def cluster_meta(): data = {2: {'group': 3}, 3: {'group': 3}, 5: {'group': 1}, 7: {'group': 2}, } - yield create_cluster_metadata(data) + yield create_cluster_meta(data) def _set_test_wizard(wizard): diff --git a/phy/cluster/manual/tests/test_gui_plugins.py b/phy/cluster/manual/tests/test_gui_plugins.py index 2a8b3924e..8389f2dfa 100644 --- a/phy/cluster/manual/tests/test_gui_plugins.py +++ b/phy/cluster/manual/tests/test_gui_plugins.py @@ -18,10 +18,10 @@ #------------------------------------------------------------------------------ @yield_fixture # noqa -def manual_clustering(qtbot, gui, spike_clusters, cluster_metadata): +def manual_clustering(qtbot, gui, spike_clusters, cluster_meta): mc = gui.attach('ManualClustering', spike_clusters=spike_clusters, - cluster_metadata=cluster_metadata, + cluster_meta=cluster_meta, ) _set_test_wizard(mc.wizard) diff --git a/phy/cluster/manual/tests/test_utils.py b/phy/cluster/manual/tests/test_utils.py index 6c20c55a3..a6c4e1935 100644 --- a/phy/cluster/manual/tests/test_utils.py +++ b/phy/cluster/manual/tests/test_utils.py @@ -8,9 +8,7 @@ import logging -from .._utils import (ClusterMetadata, ClusterMetadataUpdater, UpdateInfo, - _update_cluster_selection, - ) +from .._utils import ClusterMeta, UpdateInfo, _update_cluster_selection logger = logging.getLogger(__name__) @@ -22,22 +20,15 @@ def test_metadata_history(): """Test ClusterMetadataUpdater history.""" - data = {2: {'group': 2, 'color': 7}, 4: {'group': 5}} - - base_meta = ClusterMetadata(data=data) - - @base_meta.default - def group(cluster): - return 3 - - @base_meta.default - def color(cluster): - return 0 + meta = ClusterMeta() + meta.add_field('group', 3) + meta.add_field('color', 0) - assert base_meta.group(2) == 2 - assert base_meta.group([4, 2]) == [5, 2] + data = {2: {'group': 2, 'color': 7}, 4: {'group': 5}} + meta.from_dict(data) - meta = ClusterMetadataUpdater(base_meta) + assert meta.group(2) == 2 + assert meta.group([4, 2]) == [5, 2] # Values set in 'data'. assert meta.group(2) == 2 @@ -56,19 +47,19 @@ def color(cluster): meta.redo() # Action 1. - info = meta.set_group(2, 20) + info = meta.set('group', 2, 20) assert meta.group(2) == 20 assert info.description == 'metadata_group' assert info.metadata_changed == [2] # Action 2. - info = meta.set_color(3, 30) + info = meta.set('color', 3, 30) assert meta.color(3) == 30 assert info.description == 'metadata_color' assert info.metadata_changed == [3] # Action 3. - info = meta.set_color(2, 40) + info = meta.set('color', 2, 40) assert meta.color(2) == 40 assert info.description == 'metadata_color' assert info.metadata_changed == [2] @@ -121,12 +112,9 @@ def test_metadata_descendants(): 3: {'group': 3}, } - meta = ClusterMetadata(data=data) + meta = ClusterMeta() - @meta.default - def group(cluster, ascendant_values=None): - if not ascendant_values: - return 3 + def group(ascendant_values): s = list(set(ascendant_values) - set([None, 3])) # Return the default value if all ascendant values are the default. if not s: # pragma: no cover @@ -135,6 +123,9 @@ def group(cluster, ascendant_values=None): # among those present. return max(s) + meta.add_field('group', 3, group) + meta.from_dict(data) + meta.set_from_descendants([]) assert meta.group(4) == 3 @@ -142,7 +133,7 @@ def group(cluster, ascendant_values=None): assert meta.group(4) == 0 # Reset to default. - meta.set_group(4, 3) + meta.set('group', 4, 3) meta.set_from_descendants([(1, 4)]) assert meta.group(4) == 1 diff --git a/phy/cluster/manual/tests/test_wizard.py b/phy/cluster/manual/tests/test_wizard.py index ecc792a9e..5bacc5384 100644 --- a/phy/cluster/manual/tests/test_wizard.py +++ b/phy/cluster/manual/tests/test_wizard.py @@ -6,10 +6,8 @@ # Imports #------------------------------------------------------------------------------ -from pytest import yield_fixture from numpy.testing import assert_array_equal as ae -from .._utils import ClusterMetadataUpdater from ..wizard import (_previous, _next, _find_first, @@ -17,15 +15,6 @@ ) -#------------------------------------------------------------------------------ -# Fixtures -#------------------------------------------------------------------------------ - -@yield_fixture -def cluster_meta_up(cluster_metadata): - yield ClusterMetadataUpdater(cluster_metadata) - - #------------------------------------------------------------------------------ # Test wizard #------------------------------------------------------------------------------ @@ -168,9 +157,9 @@ def test_wizard_nav(wizard): assert wizard.n_processed == 2 -def test_wizard_update_simple(wizard, clustering, cluster_meta_up): +def test_wizard_update_simple(wizard, clustering, cluster_meta): # 2: none, 3: none, 5: ignored, 7: good - wizard.attach(clustering, cluster_meta_up) + wizard.attach(clustering, cluster_meta) wizard.first() wizard.last() @@ -190,8 +179,8 @@ def test_wizard_update_simple(wizard, clustering, cluster_meta_up): wizard.next_best() -def test_wizard_update_group(wizard, clustering, cluster_meta_up): - wizard.attach(clustering, cluster_meta_up) +def test_wizard_update_group(wizard, clustering, cluster_meta): + wizard.attach(clustering, cluster_meta) wizard.start() @@ -204,25 +193,25 @@ def _check_best_match(b, m): _check_best_match(3, 2) # Ignore the currently-pinned cluster. - cluster_meta_up.set_group(3, 0) + cluster_meta.set('group', 3, 0) _check_best_match(5, 2) # 2: none, 3: ignored, 5: ignored, 7: good # Ignore the current match and move to next. - cluster_meta_up.set_group(2, 1) + cluster_meta.set('group', 2, 1) _check_best_match(5, 7) # 2: ignored, 3: ignored, 5: ignored, 7: good - cluster_meta_up.undo() + cluster_meta.undo() _check_best_match(5, 2) - cluster_meta_up.redo() + cluster_meta.redo() _check_best_match(5, 7) -def test_wizard_update_clustering(wizard, clustering, cluster_meta_up): +def test_wizard_update_clustering(wizard, clustering, cluster_meta): # 2: none, 3: none, 5: ignored, 7: good - wizard.attach(clustering, cluster_meta_up) + wizard.attach(clustering, cluster_meta) wizard.start() def _check_best_match(b, m): @@ -235,7 +224,7 @@ def _check_best_match(b, m): wizard.pin() _check_best_match(2, 3) - cluster_meta_up.set_group(2, 2) + cluster_meta.set('group', 2, 2) wizard.selection = [2, 3] ################################ @@ -275,7 +264,7 @@ def _check_best_match(b, m): assert wizard.cluster_status(9) is None # Ignore a cluster. - cluster_meta_up.set_group(9, 1) + cluster_meta.set('group', 9, 1) assert wizard.cluster_status(9) == 'ignored' # Undo split. diff --git a/phy/cluster/manual/wizard.py b/phy/cluster/manual/wizard.py index b60324a67..2e854d943 100644 --- a/phy/cluster/manual/wizard.py +++ b/phy/cluster/manual/wizard.py @@ -448,7 +448,7 @@ def _select_after_update(self, up): elif cluster == self.match: self.next_match() - def attach(self, clustering, cluster_metadata): + def attach(self, clustering, cluster_meta): # TODO: might be better in an independent function in another module # The wizard gets the cluster ids from the Clustering instance @@ -457,7 +457,7 @@ def attach(self, clustering, cluster_metadata): @self.set_status_function def status(cluster): - group = cluster_metadata.group(cluster) + group = cluster_meta.group(cluster) if group is None: # pragma: no cover return None if group <= 1: @@ -471,7 +471,7 @@ def on_request_undo_state(up): def on_cluster(up): # Set the cluster metadata of new clusters. if up.added: - cluster_metadata.set_from_descendants(up.descendants) + cluster_meta.set_from_descendants(up.descendants) # Update the wizard state. if self._best_list or self._match_list: self._update_state(up) @@ -480,7 +480,7 @@ def on_cluster(up): self._select_after_update(up) clustering.connect(on_request_undo_state) - cluster_metadata.connect(on_request_undo_state) + cluster_meta.connect(on_request_undo_state) clustering.connect(on_cluster) - cluster_metadata.connect(on_cluster) + cluster_meta.connect(on_cluster) From 6249911c996e4ad5a42ab76139f027795d81538a Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 6 Oct 2015 17:24:32 +0200 Subject: [PATCH 0222/1059] WIP: refactor --- phy/cluster/manual/gui_plugins.py | 9 ++--- phy/cluster/manual/tests/conftest.py | 16 +++++---- phy/cluster/manual/tests/test_gui_plugins.py | 35 +++++++++++++++++--- 3 files changed, 46 insertions(+), 14 deletions(-) diff --git a/phy/cluster/manual/gui_plugins.py b/phy/cluster/manual/gui_plugins.py index 40eecf8b4..68fb8e11a 100644 --- a/phy/cluster/manual/gui_plugins.py +++ b/phy/cluster/manual/gui_plugins.py @@ -23,7 +23,7 @@ # Clustering objects # ----------------------------------------------------------------------------- -def create_cluster_meta(data): +def create_cluster_meta(cluster_groups): """Return a ClusterMeta instance with cluster group support.""" meta = ClusterMeta() @@ -40,6 +40,7 @@ def group(ascendant_values=None): meta.add_field('group', 3, group) + data = {c: {'group': v} for c, v in cluster_groups.items()} meta.from_dict(data) return meta @@ -66,18 +67,18 @@ class ManualClustering(IPlugin): """ def attach_to_gui(self, gui, spike_clusters=None, - cluster_meta=None, + cluster_groups=None, n_spikes_max_per_cluster=100, ): self.gui = gui # Create Clustering and ClusterMeta. self.clustering = Clustering(spike_clusters) - self.cluster_meta = cluster_meta + self.cluster_meta = create_cluster_meta(cluster_groups) # Create the wizard and attach it to Clustering/ClusterMeta. self.wizard = Wizard() - self.wizard.attach(self.clustering, cluster_meta) + self.wizard.attach(self.clustering, self.cluster_meta) @self.wizard.connect def on_select(cluster_ids): diff --git a/phy/cluster/manual/tests/conftest.py b/phy/cluster/manual/tests/conftest.py index 71b03093c..e35558657 100644 --- a/phy/cluster/manual/tests/conftest.py +++ b/phy/cluster/manual/tests/conftest.py @@ -28,14 +28,18 @@ def clustering(spike_clusters): @yield_fixture -def cluster_meta(): - data = {2: {'group': 3}, - 3: {'group': 3}, - 5: {'group': 1}, - 7: {'group': 2}, +def cluster_groups(): + data = {2: 3, + 3: 3, + 5: 1, + 7: 2, } + yield data - yield create_cluster_meta(data) + +@yield_fixture +def cluster_meta(cluster_groups): + yield create_cluster_meta(cluster_groups) def _set_test_wizard(wizard): diff --git a/phy/cluster/manual/tests/test_gui_plugins.py b/phy/cluster/manual/tests/test_gui_plugins.py index 8389f2dfa..6d3388ff3 100644 --- a/phy/cluster/manual/tests/test_gui_plugins.py +++ b/phy/cluster/manual/tests/test_gui_plugins.py @@ -10,6 +10,7 @@ from numpy.testing import assert_array_equal as ae from .conftest import _set_test_wizard +from ..gui_plugins import create_cluster_meta from phy.gui.tests.test_gui import gui # noqa @@ -18,10 +19,10 @@ #------------------------------------------------------------------------------ @yield_fixture # noqa -def manual_clustering(qtbot, gui, spike_clusters, cluster_meta): +def manual_clustering(qtbot, gui, spike_clusters, cluster_groups): mc = gui.attach('ManualClustering', spike_clusters=spike_clusters, - cluster_meta=cluster_meta, + cluster_groups=cluster_groups, ) _set_test_wizard(mc.wizard) @@ -40,6 +41,20 @@ def _assert_selection(*cluster_ids): yield mc +def test_create_cluster_meta(): + cluster_groups = {2: 3, + 3: 3, + 5: 1, + 7: 2, + } + meta = create_cluster_meta(cluster_groups) + assert meta.group(2) == 3 + assert meta.group(3) == 3 + assert meta.group(5) == 1 + assert meta.group(7) == 2 + assert meta.group(8) == 3 + + def test_manual_clustering_wizard(manual_clustering): actions = manual_clustering.actions wizard = manual_clustering.wizard @@ -123,5 +138,17 @@ def test_manual_clustering_actions(manual_clustering): actions.split([2, 3]) # => 9 _assert_selection(9, 8) - # TODO: more tests, notably with group actions and wizard - actions.move([9], 'good') + +def test_manual_clustering_group(manual_clustering): + actions = manual_clustering.actions + wizard = manual_clustering.wizard + _assert_selection = manual_clustering._assert_selection + + actions.reset_wizard() + actions.pin() + _assert_selection(3, 2) + + # [3 , 2 , 7 , 5] + # [None, None, 'ignored', 'good'] + actions.move([3], 'good') + print(wizard.selection) From a287caea62be1430f95d0b24f93babbbf0819281 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 6 Oct 2015 17:53:47 +0200 Subject: [PATCH 0223/1059] Update wizard fixture --- phy/cluster/manual/tests/conftest.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/phy/cluster/manual/tests/conftest.py b/phy/cluster/manual/tests/conftest.py index e35558657..2cd109d0b 100644 --- a/phy/cluster/manual/tests/conftest.py +++ b/phy/cluster/manual/tests/conftest.py @@ -54,7 +54,7 @@ def similarity(cluster, other): @yield_fixture -def wizard(): +def wizard(cluster_groups): wizard = Wizard() def get_cluster_ids(): @@ -63,8 +63,13 @@ def get_cluster_ids(): wizard.set_cluster_ids_function(get_cluster_ids) @wizard.set_status_function - def cluster_status(cluster): - return {2: None, 3: None, 5: 'ignored', 7: 'good'}.get(cluster, None) + def status(cluster): + group = cluster_groups.get(cluster, 3) + if group <= 1: + return 'ignored' + elif group == 2: + return 'good' + return None _set_test_wizard(wizard) From edf7e464f919d968849030a8c478de0d3f4f26fa Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 6 Oct 2015 17:54:10 +0200 Subject: [PATCH 0224/1059] Add wizard test --- phy/cluster/manual/tests/test_wizard.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/phy/cluster/manual/tests/test_wizard.py b/phy/cluster/manual/tests/test_wizard.py index 5bacc5384..1dd0e01ec 100644 --- a/phy/cluster/manual/tests/test_wizard.py +++ b/phy/cluster/manual/tests/test_wizard.py @@ -208,6 +208,14 @@ def _check_best_match(b, m): cluster_meta.redo() _check_best_match(5, 7) + # Now move 3 to good. + for _ in range(5): + cluster_meta.undo() + wizard.selection = (3, 2) + _check_best_match(3, 2) + cluster_meta.set('group', 3, 2) + _check_best_match(5, 2) + def test_wizard_update_clustering(wizard, clustering, cluster_meta): # 2: none, 3: none, 5: ignored, 7: good From d4c5d14fe7fb6ee4c831571684b6a7751a02ab4f Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 6 Oct 2015 19:26:15 +0200 Subject: [PATCH 0225/1059] Update ClusterMeta descendants --- phy/cluster/manual/_utils.py | 60 +++++++++++++++++--------- phy/cluster/manual/tests/test_utils.py | 32 ++++++++------ 2 files changed, 59 insertions(+), 33 deletions(-) diff --git a/phy/cluster/manual/_utils.py b/phy/cluster/manual/_utils.py index 590b99586..e2c7237eb 100644 --- a/phy/cluster/manual/_utils.py +++ b/phy/cluster/manual/_utils.py @@ -8,10 +8,13 @@ from copy import deepcopy from collections import defaultdict +import logging from ._history import History from phy.utils import Bunch, _as_list, _is_list, EventEmitter +logger = logging.getLogger(__name__) + #------------------------------------------------------------------------------ # Utility functions @@ -29,6 +32,26 @@ def _join(clusters): return '[{}]'.format(', '.join(map(str, clusters))) +def _wizard_group(group): + group = group.lower() if group else group + if group in ('mua', 'noise'): + return 'ignored' + elif group == 'good': + return 'good' + return None + + +def create_cluster_meta(cluster_groups): + """Return a ClusterMeta instance with cluster group support.""" + meta = ClusterMeta() + meta.add_field('group') + + data = {c: {'group': v} for c, v in cluster_groups.items()} + meta.from_dict(data) + + return meta + + #------------------------------------------------------------------------------ # UpdateInfo class #------------------------------------------------------------------------------ @@ -99,8 +122,8 @@ def _reset_data(self): def fields(self): return sorted(self._fields.keys()) - def add_field(self, name, default_value=None, ascendants_func=None): - self._fields[name] = (default_value, ascendants_func) + def add_field(self, name, default_value=None): + self._fields[name] = default_value def func(cluster): return self.get(name, cluster) @@ -121,6 +144,8 @@ def set(self, field, clusters, value, add_to_stack=True): for cluster in clusters: if cluster not in self._data: self._data[cluster] = {} + if value == 8: + raise RuntimeError() self._data[cluster][field] = value up = UpdateInfo(description='metadata_' + field, @@ -139,34 +164,27 @@ def get(self, field, cluster): if _is_list(cluster): return [self.get(field, c) for c in cluster] assert field in self._fields - default = self._fields[field][0] + default = self._fields[field] return self._data.get(cluster, {}).get(field, default) def set_from_descendants(self, descendants): """Update metadata of some clusters given the metadata of their ascendants.""" for field in self.fields: - # For each new cluster, a set of metadata values of their - # ascendants. + + # This gives a set of metadata values of all the parents + # of any new cluster. candidates = defaultdict(set) for old, new in descendants: - candidates[new].add(old) - for new, vals in candidates.items(): + candidates[new].add(self.get(field, old)) - # Skip that new cluster if its value is already non-default. - current_val = self.get(field, new) - default_val = self._fields[field][0] - if current_val != default_val: - continue - - # Ask the field the value of the new cluster, - # as a function of the values of its ascendants. This is - # encoded in the specified default function. - func = self._fields[field][1] - if func: - new_val = func(list(vals)) - if new_val is not None: - self.set(field, [new], new_val) + # Loop over all new clusters. + for new, vals in candidates.items(): + # If all the parents have the same value, assign it to + # the new cluster. + if len(vals) == 1: + self.set(field, new, list(vals)[0]) + # Otherwise, the default is assumed. def undo(self): """Undo the last metadata change. diff --git a/phy/cluster/manual/tests/test_utils.py b/phy/cluster/manual/tests/test_utils.py index a6c4e1935..3a0f47e33 100644 --- a/phy/cluster/manual/tests/test_utils.py +++ b/phy/cluster/manual/tests/test_utils.py @@ -8,7 +8,8 @@ import logging -from .._utils import ClusterMeta, UpdateInfo, _update_cluster_selection +from .._utils import (ClusterMeta, UpdateInfo, + _update_cluster_selection, create_cluster_meta) logger = logging.getLogger(__name__) @@ -17,6 +18,20 @@ # Tests #------------------------------------------------------------------------------ +def test_create_cluster_meta(): + cluster_groups = {2: 3, + 3: 3, + 5: 1, + 7: 2, + } + meta = create_cluster_meta(cluster_groups) + assert meta.group(2) == 3 + assert meta.group(3) == 3 + assert meta.group(5) == 1 + assert meta.group(7) == 2 + assert meta.group(8) is None + + def test_metadata_history(): """Test ClusterMetadataUpdater history.""" @@ -114,16 +129,7 @@ def test_metadata_descendants(): meta = ClusterMeta() - def group(ascendant_values): - s = list(set(ascendant_values) - set([None, 3])) - # Return the default value if all ascendant values are the default. - if not s: # pragma: no cover - return 3 - # Otherwise, return good (2) if it is present, or the largest value - # among those present. - return max(s) - - meta.add_field('group', 3, group) + meta.add_field('group', 3) meta.from_dict(data) meta.set_from_descendants([]) @@ -138,8 +144,10 @@ def group(ascendant_values): assert meta.group(4) == 1 meta.set_from_descendants([(1, 5), (2, 5)]) - assert meta.group(5) == 2 + # This is the default value because the parents have different values. + assert meta.group(5) == 3 + meta.set('group', 3, 2) meta.set_from_descendants([(2, 6), (3, 6), (10, 10)]) assert meta.group(6) == 2 From 03a58942fdf42c567be466b732e3cda9d0f42ab6 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 6 Oct 2015 19:28:39 +0200 Subject: [PATCH 0226/1059] Add test_wizard_group --- phy/cluster/manual/tests/test_utils.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/phy/cluster/manual/tests/test_utils.py b/phy/cluster/manual/tests/test_utils.py index 3a0f47e33..a823ebfff 100644 --- a/phy/cluster/manual/tests/test_utils.py +++ b/phy/cluster/manual/tests/test_utils.py @@ -9,7 +9,8 @@ import logging from .._utils import (ClusterMeta, UpdateInfo, - _update_cluster_selection, create_cluster_meta) + _update_cluster_selection, create_cluster_meta, + _wizard_group) logger = logging.getLogger(__name__) @@ -18,6 +19,14 @@ # Tests #------------------------------------------------------------------------------ +def test_wizard_group(): + assert _wizard_group('noise') == 'ignored' + assert _wizard_group('mua') == 'ignored' + assert _wizard_group('good') == 'good' + assert _wizard_group('unknown') is None + assert _wizard_group(None) is None + + def test_create_cluster_meta(): cluster_groups = {2: 3, 3: 3, From e6119d28a62806fd20a117e88163e2d0a6729509 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 6 Oct 2015 20:48:19 +0200 Subject: [PATCH 0227/1059] Move _wizard_group --- phy/cluster/manual/_utils.py | 9 --------- phy/cluster/manual/tests/conftest.py | 20 ++++++++------------ phy/cluster/manual/tests/test_utils.py | 11 +---------- 3 files changed, 9 insertions(+), 31 deletions(-) diff --git a/phy/cluster/manual/_utils.py b/phy/cluster/manual/_utils.py index e2c7237eb..b296dfba6 100644 --- a/phy/cluster/manual/_utils.py +++ b/phy/cluster/manual/_utils.py @@ -32,15 +32,6 @@ def _join(clusters): return '[{}]'.format(', '.join(map(str, clusters))) -def _wizard_group(group): - group = group.lower() if group else group - if group in ('mua', 'noise'): - return 'ignored' - elif group == 'good': - return 'good' - return None - - def create_cluster_meta(cluster_groups): """Return a ClusterMeta instance with cluster group support.""" meta = ClusterMeta() diff --git a/phy/cluster/manual/tests/conftest.py b/phy/cluster/manual/tests/conftest.py index 2cd109d0b..8cce1ad21 100644 --- a/phy/cluster/manual/tests/conftest.py +++ b/phy/cluster/manual/tests/conftest.py @@ -9,8 +9,8 @@ from pytest import yield_fixture from ..clustering import Clustering -from ..gui_plugins import create_cluster_meta -from ..wizard import Wizard +from ..wizard import Wizard, _wizard_group +from .._utils import create_cluster_meta #------------------------------------------------------------------------------ @@ -29,10 +29,10 @@ def clustering(spike_clusters): @yield_fixture def cluster_groups(): - data = {2: 3, - 3: 3, - 5: 1, - 7: 2, + data = {2: None, + 3: None, + 5: 'mua', + 7: 'good', } yield data @@ -64,12 +64,8 @@ def get_cluster_ids(): @wizard.set_status_function def status(cluster): - group = cluster_groups.get(cluster, 3) - if group <= 1: - return 'ignored' - elif group == 2: - return 'good' - return None + group = cluster_groups.get(cluster, None) + return _wizard_group(group) _set_test_wizard(wizard) diff --git a/phy/cluster/manual/tests/test_utils.py b/phy/cluster/manual/tests/test_utils.py index a823ebfff..3a0f47e33 100644 --- a/phy/cluster/manual/tests/test_utils.py +++ b/phy/cluster/manual/tests/test_utils.py @@ -9,8 +9,7 @@ import logging from .._utils import (ClusterMeta, UpdateInfo, - _update_cluster_selection, create_cluster_meta, - _wizard_group) + _update_cluster_selection, create_cluster_meta) logger = logging.getLogger(__name__) @@ -19,14 +18,6 @@ # Tests #------------------------------------------------------------------------------ -def test_wizard_group(): - assert _wizard_group('noise') == 'ignored' - assert _wizard_group('mua') == 'ignored' - assert _wizard_group('good') == 'good' - assert _wizard_group('unknown') is None - assert _wizard_group(None) is None - - def test_create_cluster_meta(): cluster_groups = {2: 3, 3: 3, From 2a83e50cd9ee8acbdf6668d62e9b0ea257444dc6 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 6 Oct 2015 20:52:02 +0200 Subject: [PATCH 0228/1059] Add ClusterMeta test --- phy/cluster/manual/tests/test_utils.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/phy/cluster/manual/tests/test_utils.py b/phy/cluster/manual/tests/test_utils.py index 3a0f47e33..0a1fd6423 100644 --- a/phy/cluster/manual/tests/test_utils.py +++ b/phy/cluster/manual/tests/test_utils.py @@ -32,8 +32,24 @@ def test_create_cluster_meta(): assert meta.group(8) is None -def test_metadata_history(): - """Test ClusterMetadataUpdater history.""" +def test_metadata_history_simple(): + """Test ClusterMeta history.""" + + meta = ClusterMeta() + meta.add_field('group') + + meta.set('group', 2, 2) + assert meta.get('group', 2) == 2 + + meta.undo() + assert meta.get('group', 2) is None + + meta.redo() + assert meta.get('group', 2) == 2 + + +def test_metadata_history_complex(): + """Test ClusterMeta history.""" meta = ClusterMeta() meta.add_field('group', 3) @@ -119,7 +135,7 @@ def test_metadata_history(): def test_metadata_descendants(): - """Test ClusterMetadataUpdater history.""" + """Test ClusterMeta history.""" data = {0: {'group': 0}, 1: {'group': 1}, From a336b3a36d129aa691857c41bce413a8b8202e2c Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 6 Oct 2015 20:58:54 +0200 Subject: [PATCH 0229/1059] Fix wizard tests --- phy/cluster/manual/_utils.py | 8 ++++--- phy/cluster/manual/tests/test_wizard.py | 30 +++++++++++++++++-------- phy/cluster/manual/wizard.py | 16 ++++++++----- 3 files changed, 36 insertions(+), 18 deletions(-) diff --git a/phy/cluster/manual/_utils.py b/phy/cluster/manual/_utils.py index b296dfba6..6477b77c3 100644 --- a/phy/cluster/manual/_utils.py +++ b/phy/cluster/manual/_utils.py @@ -171,10 +171,12 @@ def set_from_descendants(self, descendants): # Loop over all new clusters. for new, vals in candidates.items(): + vals = list(vals) + default = self._fields[field] # If all the parents have the same value, assign it to - # the new cluster. - if len(vals) == 1: - self.set(field, new, list(vals)[0]) + # the new cluster if it is not the default. + if len(vals) == 1 and vals[0] != default: + self.set(field, new, vals[0]) # Otherwise, the default is assumed. def undo(self): diff --git a/phy/cluster/manual/tests/test_wizard.py b/phy/cluster/manual/tests/test_wizard.py index 1dd0e01ec..049fc889c 100644 --- a/phy/cluster/manual/tests/test_wizard.py +++ b/phy/cluster/manual/tests/test_wizard.py @@ -12,6 +12,7 @@ _next, _find_first, Wizard, + _wizard_group, ) @@ -19,6 +20,14 @@ # Test wizard #------------------------------------------------------------------------------ +def test_wizard_group(): + assert _wizard_group('noise') == 'ignored' + assert _wizard_group('mua') == 'ignored' + assert _wizard_group('good') == 'good' + assert _wizard_group('unknown') is None + assert _wizard_group(None) is None + + def test_utils(): l = [2, 3, 5, 7, 11] @@ -191,14 +200,16 @@ def _check_best_match(b, m): wizard.pin() _check_best_match(3, 2) + # print(wizard.best_list) # Ignore the currently-pinned cluster. - cluster_meta.set('group', 3, 0) - _check_best_match(5, 2) + cluster_meta.set('group', 3, 'noise') # 2: none, 3: ignored, 5: ignored, 7: good + _check_best_match(5, 2) + return # Ignore the current match and move to next. - cluster_meta.set('group', 2, 1) + cluster_meta.set('group', 2, 'mua') _check_best_match(5, 7) # 2: ignored, 3: ignored, 5: ignored, 7: good @@ -213,7 +224,7 @@ def _check_best_match(b, m): cluster_meta.undo() wizard.selection = (3, 2) _check_best_match(3, 2) - cluster_meta.set('group', 3, 2) + cluster_meta.set('group', 3, 'good') _check_best_match(5, 2) @@ -232,7 +243,7 @@ def _check_best_match(b, m): wizard.pin() _check_best_match(2, 3) - cluster_meta.set('group', 2, 2) + cluster_meta.set('group', 2, 'good') wizard.selection = [2, 3] ################################ @@ -242,8 +253,9 @@ def _check_best_match(b, m): clustering.merge([2, 3]) # => 8 _check_best_match(8, 7) assert wizard.best_list == [8, 7, 5] - assert wizard.cluster_status(8) == 'good' + assert wizard.cluster_status(8) is None assert wizard.cluster_status(7) == 'good' + assert wizard.cluster_status(2) == 'good' # Undo merge. clustering.undo() @@ -258,7 +270,7 @@ def _check_best_match(b, m): # Redo merge. clustering.redo() _check_best_match(8, 7) - assert wizard.cluster_status(8) == 'good' + assert wizard.cluster_status(8) is None assert wizard.cluster_status(7) == 'good' ################################ @@ -272,7 +284,7 @@ def _check_best_match(b, m): assert wizard.cluster_status(9) is None # Ignore a cluster. - cluster_meta.set('group', 9, 1) + cluster_meta.set('group', 9, 'noise') assert wizard.cluster_status(9) == 'ignored' # Undo split. @@ -297,7 +309,7 @@ def _check_best_match(b, m): _check_best_match(11, 7) assert up.description == 'merge' assert up.history is None - assert wizard.cluster_status(11) is None + assert wizard.cluster_status(11) == 'ignored' # Undo split (=merge). up = clustering.undo() diff --git a/phy/cluster/manual/wizard.py b/phy/cluster/manual/wizard.py index 2e854d943..106a841c9 100644 --- a/phy/cluster/manual/wizard.py +++ b/phy/cluster/manual/wizard.py @@ -71,6 +71,15 @@ def _next(items, current, filter=None): return current +def _wizard_group(group): + group = group.lower() if group else group + if group in ('mua', 'noise'): + return 'ignored' + elif group == 'good': + return 'good' + return None + + #------------------------------------------------------------------------------ # Wizard #------------------------------------------------------------------------------ @@ -458,12 +467,7 @@ def attach(self, clustering, cluster_meta): @self.set_status_function def status(cluster): group = cluster_meta.group(cluster) - if group is None: # pragma: no cover - return None - if group <= 1: - return 'ignored' - elif group == 2: - return 'good' + return _wizard_group(group) def on_request_undo_state(up): return {'selection': self.selection} From 00e83bded156719a6a28ce7c6b8f78c99a364a15 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 6 Oct 2015 21:04:44 +0200 Subject: [PATCH 0230/1059] WIP: fix tests --- phy/cluster/manual/gui_plugins.py | 51 +++++++++----------- phy/cluster/manual/tests/test_gui_plugins.py | 37 +++++--------- phy/cluster/manual/wizard.py | 4 ++ 3 files changed, 39 insertions(+), 53 deletions(-) diff --git a/phy/cluster/manual/gui_plugins.py b/phy/cluster/manual/gui_plugins.py index 68fb8e11a..472dc8edb 100644 --- a/phy/cluster/manual/gui_plugins.py +++ b/phy/cluster/manual/gui_plugins.py @@ -9,7 +9,8 @@ import logging -from ._utils import ClusterMeta +from ._history import GlobalHistory +from ._utils import create_cluster_meta from .clustering import Clustering from .wizard import Wizard from phy.gui.actions import Actions, Snippets @@ -20,30 +21,23 @@ # ----------------------------------------------------------------------------- -# Clustering objects +# Utility functions # ----------------------------------------------------------------------------- -def create_cluster_meta(cluster_groups): - """Return a ClusterMeta instance with cluster group support.""" - meta = ClusterMeta() - - def group(ascendant_values=None): - if not ascendant_values: # pragma: no cover - return 3 - s = list(set(ascendant_values) - set([None, 3])) - # Return the default value if all ascendant values are the default. - if not s: # pragma: no cover - return 3 - # Otherwise, return good (2) if it is present, or the largest value - # among those present. - return max(s) - - meta.add_field('group', 3, group) - - data = {c: {'group': v} for c, v in cluster_groups.items()} - meta.from_dict(data) - - return meta +def _process_ups(ups): + """This function processes the UpdateInfo instances of the two + undo stacks (clustering and cluster metadata) and concatenates them + into a single UpdateInfo instance.""" + if len(ups) == 0: + return + elif len(ups) == 1: + return ups[0] + elif len(ups) == 2: + up = ups[0] + up.update(ups[1]) + return up + else: + raise NotImplementedError() # ----------------------------------------------------------------------------- @@ -75,6 +69,7 @@ def attach_to_gui(self, gui, # Create Clustering and ClusterMeta. self.clustering = Clustering(spike_clusters) self.cluster_meta = create_cluster_meta(cluster_groups) + self._global_history = GlobalHistory(process_ups=_process_ups) # Create the wizard and attach it to Clustering/ClusterMeta. self.wizard = Wizard() @@ -148,17 +143,19 @@ def merge(self, cluster_ids=None): if cluster_ids is None: cluster_ids = self.wizard.selection self.clustering.merge(cluster_ids) + self._global_history.action(self.clustering) def split(self, spike_ids): # TODO: connect to request_split emitted by view self.clustering.split(spike_ids) + self._global_history.action(self.clustering) def move(self, clusters, group): - # TODO - pass + self.cluster_meta.set('group', clusters, group) + self._global_history.action(self.cluster_meta) def undo(self): - self.clustering.undo() + self._global_history.undo() def redo(self): - self.clustering.redo() + self._global_history.redo() diff --git a/phy/cluster/manual/tests/test_gui_plugins.py b/phy/cluster/manual/tests/test_gui_plugins.py index 6d3388ff3..96f165afc 100644 --- a/phy/cluster/manual/tests/test_gui_plugins.py +++ b/phy/cluster/manual/tests/test_gui_plugins.py @@ -10,7 +10,6 @@ from numpy.testing import assert_array_equal as ae from .conftest import _set_test_wizard -from ..gui_plugins import create_cluster_meta from phy.gui.tests.test_gui import gui # noqa @@ -35,26 +34,16 @@ def on_select(cluster_ids, spike_ids): def _assert_selection(*cluster_ids): assert _s[-1][0] == list(cluster_ids) + if len(cluster_ids) >= 1: + assert mc.wizard.best == cluster_ids[0] + elif len(cluster_ids) >= 2: + assert mc.wizard.match == cluster_ids[2] mc._assert_selection = _assert_selection yield mc -def test_create_cluster_meta(): - cluster_groups = {2: 3, - 3: 3, - 5: 1, - 7: 2, - } - meta = create_cluster_meta(cluster_groups) - assert meta.group(2) == 3 - assert meta.group(3) == 3 - assert meta.group(5) == 1 - assert meta.group(7) == 2 - assert meta.group(8) == 3 - - def test_manual_clustering_wizard(manual_clustering): actions = manual_clustering.actions wizard = manual_clustering.wizard @@ -70,41 +59,32 @@ def test_manual_clustering_wizard(manual_clustering): # Test wizard actions. actions.reset_wizard() assert wizard.best_list == [3, 2, 7, 5] - assert wizard.best == 3 _assert_selection(3) actions.next() - assert wizard.best == 2 _assert_selection(2) actions.last() - assert wizard.best == 5 _assert_selection(5) actions.next() - assert wizard.best == 5 _assert_selection(5) actions.previous() - assert wizard.best == 7 _assert_selection(7) actions.first() - assert wizard.best == 3 _assert_selection(3) actions.previous() - assert wizard.best == 3 _assert_selection(3) # Test pinning. actions.pin() assert wizard.match_list == [2, 7, 5] - assert wizard.match == 2 _assert_selection(3, 2) wizard.next() - assert wizard.match == 7 _assert_selection(3, 7) wizard.unpin() @@ -149,6 +129,11 @@ def test_manual_clustering_group(manual_clustering): _assert_selection(3, 2) # [3 , 2 , 7 , 5] - # [None, None, 'ignored', 'good'] + # [None, None, 'good', 'ignored'] actions.move([3], 'good') - print(wizard.selection) + + # ['good', None, 'good', 'ignored'] + _assert_selection(7, 2) + + actions.next() + _assert_selection(7, 3) diff --git a/phy/cluster/manual/wizard.py b/phy/cluster/manual/wizard.py index 106a841c9..5520d6357 100644 --- a/phy/cluster/manual/wizard.py +++ b/phy/cluster/manual/wizard.py @@ -9,6 +9,8 @@ import logging from operator import itemgetter +from six import string_types + from phy.utils import EventEmitter logger = logging.getLogger(__name__) @@ -72,6 +74,8 @@ def _next(items, current, filter=None): def _wizard_group(group): + # The group should be None, 'mua', 'noise', or 'good'. + assert group is None or isinstance(group, string_types) group = group.lower() if group else group if group in ('mua', 'noise'): return 'ignored' From 48717fe677f3197f40e294f4fcfc8ce590e8b061 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 6 Oct 2015 21:10:06 +0200 Subject: [PATCH 0231/1059] Increase coverage --- phy/cluster/manual/_utils.py | 2 -- phy/cluster/manual/gui_plugins.py | 2 +- phy/cluster/manual/tests/test_gui_plugins.py | 4 ++-- phy/cluster/manual/tests/test_wizard.py | 1 - 4 files changed, 3 insertions(+), 6 deletions(-) diff --git a/phy/cluster/manual/_utils.py b/phy/cluster/manual/_utils.py index 6477b77c3..b87547bc1 100644 --- a/phy/cluster/manual/_utils.py +++ b/phy/cluster/manual/_utils.py @@ -135,8 +135,6 @@ def set(self, field, clusters, value, add_to_stack=True): for cluster in clusters: if cluster not in self._data: self._data[cluster] = {} - if value == 8: - raise RuntimeError() self._data[cluster][field] = value up = UpdateInfo(description='metadata_' + field, diff --git a/phy/cluster/manual/gui_plugins.py b/phy/cluster/manual/gui_plugins.py index 472dc8edb..50af39706 100644 --- a/phy/cluster/manual/gui_plugins.py +++ b/phy/cluster/manual/gui_plugins.py @@ -24,7 +24,7 @@ # Utility functions # ----------------------------------------------------------------------------- -def _process_ups(ups): +def _process_ups(ups): # pragma: no cover """This function processes the UpdateInfo instances of the two undo stacks (clustering and cluster metadata) and concatenates them into a single UpdateInfo instance.""" diff --git a/phy/cluster/manual/tests/test_gui_plugins.py b/phy/cluster/manual/tests/test_gui_plugins.py index 96f165afc..b6699cad4 100644 --- a/phy/cluster/manual/tests/test_gui_plugins.py +++ b/phy/cluster/manual/tests/test_gui_plugins.py @@ -32,7 +32,7 @@ def manual_clustering(qtbot, gui, spike_clusters, cluster_groups): def on_select(cluster_ids, spike_ids): _s.append((cluster_ids, spike_ids)) - def _assert_selection(*cluster_ids): + def _assert_selection(*cluster_ids): # pragma: no cover assert _s[-1][0] == list(cluster_ids) if len(cluster_ids) >= 1: assert mc.wizard.best == cluster_ids[0] @@ -121,7 +121,7 @@ def test_manual_clustering_actions(manual_clustering): def test_manual_clustering_group(manual_clustering): actions = manual_clustering.actions - wizard = manual_clustering.wizard + # wizard = manual_clustering.wizard _assert_selection = manual_clustering._assert_selection actions.reset_wizard() diff --git a/phy/cluster/manual/tests/test_wizard.py b/phy/cluster/manual/tests/test_wizard.py index 049fc889c..4662c2cb9 100644 --- a/phy/cluster/manual/tests/test_wizard.py +++ b/phy/cluster/manual/tests/test_wizard.py @@ -206,7 +206,6 @@ def _check_best_match(b, m): cluster_meta.set('group', 3, 'noise') # 2: none, 3: ignored, 5: ignored, 7: good _check_best_match(5, 2) - return # Ignore the current match and move to next. cluster_meta.set('group', 2, 'mua') From 3f7244e84377aa8c39c74830b38a788e2dd92192 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 6 Oct 2015 21:17:10 +0200 Subject: [PATCH 0232/1059] Remove unused test --- phy/cluster/manual/tests/test_wizard.py | 26 +++---------------------- 1 file changed, 3 insertions(+), 23 deletions(-) diff --git a/phy/cluster/manual/tests/test_wizard.py b/phy/cluster/manual/tests/test_wizard.py index 4662c2cb9..81436a4fc 100644 --- a/phy/cluster/manual/tests/test_wizard.py +++ b/phy/cluster/manual/tests/test_wizard.py @@ -159,6 +159,9 @@ def test_wizard_nav(wizard): wizard.last() assert wizard.selection == [3, 5] + wizard.previous_best() + assert wizard.selection == [3, 2] + wizard.unpin() assert wizard.best == 3 assert wizard.match is None @@ -166,28 +169,6 @@ def test_wizard_nav(wizard): assert wizard.n_processed == 2 -def test_wizard_update_simple(wizard, clustering, cluster_meta): - # 2: none, 3: none, 5: ignored, 7: good - wizard.attach(clustering, cluster_meta) - - wizard.first() - wizard.last() - - wizard.start() - - wizard.first() - wizard.last() - - wizard.pin() - - wizard.first() - wizard.last() - - wizard.pin() - wizard.previous_best() - wizard.next_best() - - def test_wizard_update_group(wizard, clustering, cluster_meta): wizard.attach(clustering, cluster_meta) @@ -200,7 +181,6 @@ def _check_best_match(b, m): wizard.pin() _check_best_match(3, 2) - # print(wizard.best_list) # Ignore the currently-pinned cluster. cluster_meta.set('group', 3, 'noise') From 9fa726c967ed15cdb4eb97c2b67d997a5f562fcf Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 6 Oct 2015 21:30:51 +0200 Subject: [PATCH 0233/1059] Adding new wizard test --- phy/cluster/manual/tests/test_wizard.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/phy/cluster/manual/tests/test_wizard.py b/phy/cluster/manual/tests/test_wizard.py index 81436a4fc..f62e0a3ff 100644 --- a/phy/cluster/manual/tests/test_wizard.py +++ b/phy/cluster/manual/tests/test_wizard.py @@ -169,6 +169,16 @@ def test_wizard_nav(wizard): assert wizard.n_processed == 2 +def test_wizard_update_1(wizard, clustering, cluster_meta): + wizard.attach(clustering, cluster_meta) + wizard.start() + wizard.pin() + assert wizard.best_list == [3, 2, 7, 5] + assert wizard.selection == [3, 2] + cluster_meta.set('group', 3, 'noise') + # 2: None, 3: 'noise', 5: 'mua', 7: 'good' + + def test_wizard_update_group(wizard, clustering, cluster_meta): wizard.attach(clustering, cluster_meta) From d231f21e25079ef970408bf25589d0a72b4c69b3 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 7 Oct 2015 10:42:20 +0200 Subject: [PATCH 0234/1059] WIP: refactor wizard --- phy/cluster/manual/tests/test_wizard.py | 320 +++---------------- phy/cluster/manual/wizard.py | 406 +++++------------------- 2 files changed, 128 insertions(+), 598 deletions(-) diff --git a/phy/cluster/manual/tests/test_wizard.py b/phy/cluster/manual/tests/test_wizard.py index f62e0a3ff..c47b6ce92 100644 --- a/phy/cluster/manual/tests/test_wizard.py +++ b/phy/cluster/manual/tests/test_wizard.py @@ -8,9 +8,8 @@ from numpy.testing import assert_array_equal as ae -from ..wizard import (_previous, - _next, - _find_first, +from ..wizard import (_argsort, + _best_clusters, Wizard, _wizard_group, ) @@ -20,294 +19,61 @@ # Test wizard #------------------------------------------------------------------------------ -def test_wizard_group(): - assert _wizard_group('noise') == 'ignored' - assert _wizard_group('mua') == 'ignored' - assert _wizard_group('good') == 'good' - assert _wizard_group('unknown') is None - assert _wizard_group(None) is None - +def test_argsort(): + l = [(1, .1), (2, .2), (3, .3), (4, .4)] + assert _argsort(l) == [4, 3, 2, 1] -def test_utils(): - l = [2, 3, 5, 7, 11] + assert _argsort(l, n_max=0) == [4, 3, 2, 1] + assert _argsort(l, n_max=1) == [4] + assert _argsort(l, n_max=2) == [4, 3] + assert _argsort(l, n_max=10) == [4, 3, 2, 1] - def func(x): - return x in (2, 5) + assert _argsort(l, reverse=False) == [1, 2, 3, 4] - _find_first([], None) - _previous([], None) - _previous([0, 1], 1, lambda x: x > 0) - # Error: log and do nothing. - _previous(l, 1, func) - _previous(l, 15, func) +def test_best_clusters(): + quality = lambda c: c * .1 + l = list(range(1, 5)) + assert _best_clusters(l, quality) == [4, 3, 2, 1] + assert _best_clusters(l, quality, n_max=0) == [4, 3, 2, 1] + assert _best_clusters(l, quality, n_max=1) == [4] + assert _best_clusters(l, quality, n_max=2) == [4, 3] + assert _best_clusters(l, quality, n_max=10) == [4, 3, 2, 1] - assert _previous(l, 2, func) == 2 - assert _previous(l, 3, func) == 2 - assert _previous(l, 5, func) == 2 - assert _previous(l, 7, func) == 5 - assert _previous(l, 11, func) == 5 - _next([], None) - # Error: log and do nothing. - _next(l, 1, func) - _next(l, 15, func) - - assert _next(l, 2, func) == 5 - assert _next(l, 3, func) == 5 - assert _next(l, 5, func) == 5 - assert _next(l, 7, func) == 7 - assert _next(l, 11, func) == 11 +def test_wizard_group(): + assert _wizard_group('noise') == 'ignored' + assert _wizard_group('mua') == 'ignored' + assert _wizard_group('good') == 'good' + assert _wizard_group('unknown') is None + assert _wizard_group(None) is None def test_wizard_core(): wizard = Wizard() - wizard.set_cluster_ids_function(lambda: [2, 3, 5]) + wizard.set_cluster_ids_function(lambda: [1, 2, 3]) @wizard.set_quality_function def quality(cluster): - return {2: .9, - 3: .3, - 5: .6, - }[cluster] + return cluster @wizard.set_similarity_function def similarity(cluster, other): - cluster, other = min((cluster, other)), max((cluster, other)) - return {(2, 3): 1, - (2, 5): 2, - (3, 5): 3}[cluster, other] - - assert wizard.best_clusters() == [2, 5, 3] - assert wizard.best_clusters(n_max=0) == [2, 5, 3] - assert wizard.best_clusters(n_max=None) == [2, 5, 3] - assert wizard.best_clusters(n_max=2) == [2, 5] - - assert wizard.best_clusters(n_max=1) == [2] - - assert wizard.most_similar_clusters() == [5, 3] - assert wizard.most_similar_clusters(2) == [5, 3] - - assert wizard.most_similar_clusters(n_max=0) == [5, 3] - assert wizard.most_similar_clusters(n_max=None) == [5, 3] - assert wizard.most_similar_clusters(n_max=1) == [5] - - -def test_wizard_nav(wizard): - - # Loop over the best clusters. - wizard.start() - - assert wizard.n_clusters == 4 - assert wizard.best_list == [3, 2, 7, 5] - - assert wizard.best == 3 - assert wizard.match is None - - wizard.next() - assert wizard.best == 2 - - wizard.previous() - assert wizard.best == 3 - - wizard.previous_best() - assert wizard.best == 3 - - wizard.next() - assert wizard.best == 2 - - wizard.next() - assert wizard.best == 7 - - wizard.next_best() - assert wizard.best == 5 - - wizard.next() - assert wizard.best == 5 - - # Now we start again. - wizard.start() - assert wizard.best == 3 - assert wizard.match is None - - # The match are sorted by group first (unsorted, good, and ignored), - # and similarity second. - wizard.pin() - assert wizard.best == 3 - assert wizard.match == 2 - assert wizard.match_list == [2, 7, 5] - - wizard.next() - assert wizard.best == 3 - assert wizard.match == 7 - - wizard.next_match() - assert wizard.best == 3 - assert wizard.match == 5 - - wizard.previous_match() - assert wizard.best == 3 - assert wizard.match == 7 - - wizard.previous() - assert wizard.best == 3 - assert wizard.match == 2 - - wizard.first() - assert wizard.selection == [3, 2] - wizard.last() - assert wizard.selection == [3, 5] - - wizard.previous_best() - assert wizard.selection == [3, 2] - - wizard.unpin() - assert wizard.best == 3 - assert wizard.match is None - - assert wizard.n_processed == 2 - - -def test_wizard_update_1(wizard, clustering, cluster_meta): - wizard.attach(clustering, cluster_meta) - wizard.start() - wizard.pin() - assert wizard.best_list == [3, 2, 7, 5] - assert wizard.selection == [3, 2] - cluster_meta.set('group', 3, 'noise') - # 2: None, 3: 'noise', 5: 'mua', 7: 'good' - - -def test_wizard_update_group(wizard, clustering, cluster_meta): - wizard.attach(clustering, cluster_meta) - - wizard.start() - - def _check_best_match(b, m): - assert wizard.selection == [b, m] - assert wizard.best == b - assert wizard.match == m - - wizard.pin() - _check_best_match(3, 2) - - # Ignore the currently-pinned cluster. - cluster_meta.set('group', 3, 'noise') - # 2: none, 3: ignored, 5: ignored, 7: good - _check_best_match(5, 2) - - # Ignore the current match and move to next. - cluster_meta.set('group', 2, 'mua') - _check_best_match(5, 7) - # 2: ignored, 3: ignored, 5: ignored, 7: good - - cluster_meta.undo() - _check_best_match(5, 2) - - cluster_meta.redo() - _check_best_match(5, 7) - - # Now move 3 to good. - for _ in range(5): - cluster_meta.undo() - wizard.selection = (3, 2) - _check_best_match(3, 2) - cluster_meta.set('group', 3, 'good') - _check_best_match(5, 2) - - -def test_wizard_update_clustering(wizard, clustering, cluster_meta): - # 2: none, 3: none, 5: ignored, 7: good - wizard.attach(clustering, cluster_meta) - wizard.start() - - def _check_best_match(b, m): - assert wizard.selection == [b, m] - assert wizard.best == b - assert wizard.match == m - - assert wizard.best_list == [3, 2, 7, 5] - wizard.next() - wizard.pin() - - _check_best_match(2, 3) - cluster_meta.set('group', 2, 'good') - wizard.selection = [2, 3] - - ################################ - - assert wizard.cluster_status(2) == 'good' - assert wizard.cluster_status(3) is None - clustering.merge([2, 3]) # => 8 - _check_best_match(8, 7) - assert wizard.best_list == [8, 7, 5] - assert wizard.cluster_status(8) is None - assert wizard.cluster_status(7) == 'good' - assert wizard.cluster_status(2) == 'good' - - # Undo merge. - clustering.undo() - _check_best_match(2, 3) - assert wizard.cluster_status(2) == 'good' - assert wizard.cluster_status(3) is None - - # Make a selection. - wizard.selection = [1, 5, 7, 8] - _check_best_match(5, 7) - - # Redo merge. - clustering.redo() - _check_best_match(8, 7) - assert wizard.cluster_status(8) is None - assert wizard.cluster_status(7) == 'good' - - ################################ - - # Split. - ae(clustering.spike_clusters, [8, 8, 5, 7]) - clustering.split([1, 2]) # ==> 9, 10 - ae(clustering.spike_clusters, [10, 9, 9, 7]) - _check_best_match(9, 10) - assert wizard.cluster_status(10) is None - assert wizard.cluster_status(9) is None - - # Ignore a cluster. - cluster_meta.set('group', 9, 'noise') - assert wizard.cluster_status(9) == 'ignored' - - # Undo split. - up = clustering.undo() - _check_best_match(8, 7) - assert up.description == 'assign' - assert up.history == 'undo' - - # Redo split. - up = clustering.redo() - _check_best_match(9, 10) - assert up.description == 'assign' - assert up.history == 'redo' - assert wizard.cluster_status(9) == 'ignored' - - ################################ - - # Split (=merge). - ae(clustering.spike_clusters, [10, 9, 9, 7]) - up = clustering.split([1, 2]) - ae(clustering.spike_clusters, [10, 11, 11, 7]) - _check_best_match(11, 7) - assert up.description == 'merge' - assert up.history is None - assert wizard.cluster_status(11) == 'ignored' - - # Undo split (=merge). - up = clustering.undo() - _check_best_match(9, 10) - assert up.description == 'merge' - assert up.history == 'undo' - - # Redo split (=merge). - up = clustering.redo() - _check_best_match(11, 7) - assert up.description == 'merge' - assert up.history == 'redo' + return cluster + other + + assert wizard.best_clusters() == [3, 2, 1] + assert wizard.best_clusters(n_max=0) == [3, 2, 1] + assert wizard.best_clusters(n_max=None) == [3, 2, 1] + assert wizard.best_clusters(n_max=2) == [3, 2] + assert wizard.best_clusters(n_max=1) == [3] + + assert wizard.most_similar_clusters(3) == [2, 1] + assert wizard.most_similar_clusters(2) == [3, 1] + assert wizard.most_similar_clusters(1) == [3, 2] + + assert wizard.most_similar_clusters(3, n_max=0) == [2, 1] + assert wizard.most_similar_clusters(3, n_max=None) == [2, 1] + assert wizard.most_similar_clusters(3, n_max=1) == [2] + assert wizard.most_similar_clusters(3, n_max=2) == [2, 1] + assert wizard.most_similar_clusters(3, n_max=10) == [2, 1] diff --git a/phy/cluster/manual/wizard.py b/phy/cluster/manual/wizard.py index 5520d6357..c95abe3b7 100644 --- a/phy/cluster/manual/wizard.py +++ b/phy/cluster/manual/wizard.py @@ -11,6 +11,8 @@ from six import string_types +from ._history import History +from phy.utils._types import _as_list from phy.utils import EventEmitter logger = logging.getLogger(__name__) @@ -37,42 +39,6 @@ def _best_clusters(clusters, quality, n_max=None): for cluster in clusters], n_max=n_max) -def _find_first(items, filter=None): - if not items: - return None - if filter is None: - return items[0] - return next(item for item in items if filter(item)) - - -def _previous(items, current, filter=None): - if current not in items: - logger.debug("%s is not in %s.", current, items) - return - i = items.index(current) - if i == 0: - return current - try: - return _find_first(items[:i][::-1], filter) - except StopIteration: - return current - - -def _next(items, current, filter=None): - if not items: - return current - if current not in items: - logger.debug("%d is not in %s.", current, items) - return - i = items.index(current) - if i == len(items) - 1: - return current - try: - return _find_first(items[i + 1:], filter) - except StopIteration: - return current - - def _wizard_group(group): # The group should be None, 'mua', 'noise', or 'good'. assert group is None or isinstance(group, string_types) @@ -91,7 +57,16 @@ def _wizard_group(group): class Wizard(EventEmitter): """Propose a selection of high-quality clusters and merge candidates. - The wizard is responsible for the selected clusters. + * The wizard is responsible for the selected clusters. + * The wizard keeps no state about the clusters: the state is entirely + provided by functions: cluster_ids, status (group), similarity, quality. + * The wizard keeps track of the history of the selected clusters, but this + history is cleared after every action that changes the state. + * The `next()` function proposes a new selection as a function of the + current selection only. + * There are two strategies: best-quality or best-similarity strategy. + + TODO: cache expensive functions. """ def __init__(self): @@ -100,14 +75,12 @@ def __init__(self): self._quality = None self._get_cluster_ids = None self._cluster_status = lambda cluster: None + self._next = None # Strategy function. self.reset() def reset(self): self._selection = [] - self._best_list = [] # This list is fixed (modulo clustering actions). - self._match_list = [] # This list may often change. - self._best = None - self._match = None + self._history = History(()) # Quality and status functions #-------------------------------------------------------------------------- @@ -144,36 +117,38 @@ def set_quality_function(self, func): self._quality = func return func + def set_strategy_function(self, func): + """Register a function returning a new selection after the current + selection, as a function of the quality and similarity of the clusters. + """ + # func(selection, cluster_ids=None, quality=None, similarity=None) + + def wrapped(sel): + return func(self._selection, + cluster_ids=self._get_cluster_ids(), + quality=self._quality, + similarity=self._similarity, + ) + + self._next = wrapped + # Internal methods #-------------------------------------------------------------------------- - def _with_status(self, items, status): - """Filter out ignored clusters or pairs of clusters.""" - if not isinstance(status, (list, tuple)): - status = [status] - return [item for item in items if self._cluster_status(item) in status] - - def _check(self): - clusters = set(self.cluster_ids) - assert set(self._best_list) <= clusters - assert set(self._match_list) <= clusters - if self._best is not None and len(self._best_list) >= 1: - assert self._best in self._best_list - if self._match is not None and len(self._match_list) >= 1: - assert self._match in self._match_list - if None not in (self.best, self.match): - assert self.best != self.match - - def _sort(self, items, mix_good_unsorted=False): - """Sort clusters according to their status: - unsorted, good, and ignored.""" - if mix_good_unsorted: - return (self._with_status(items, (None, 'good')) + - self._with_status(items, 'ignored')) - else: - return (self._with_status(items, None) + - self._with_status(items, 'good') + - self._with_status(items, 'ignored')) + def _sort_nomix(self, cluster): + # Sort by unsorted first, good second, ignored last. + _sort_map = {None: 0, 'good': 1, 'ignored': 2} + return _sort_map.get(self._cluster_status(cluster), 0) + + def _sort_mix(self, cluster): + # Sort by unsorted/good first, ignored last. + _sort_map = {None: 0, 'good': 0, 'ignored': 2} + return _sort_map.get(self._cluster_status(cluster), 0) + + def _sort(self, clusters, mix_good_unsorted=False): + """Sort clusters according to their status.""" + key = self._sort_mix if mix_good_unsorted else self._sort_nomix + return sorted(clusters, key=key) # Properties #-------------------------------------------------------------------------- @@ -183,6 +158,11 @@ def cluster_ids(self): """Array of cluster ids in the current clustering.""" return sorted(self._get_cluster_ids()) + @property + def n_clusters(self): + """Total number of clusters.""" + return len(self.cluster_ids) + # Core methods #-------------------------------------------------------------------------- @@ -195,300 +175,84 @@ def best_clusters(self, n_max=None, quality=None): The default quality function is the registered one. """ - if quality is None: - quality = self._quality + quality = quality or self._quality best = _best_clusters(self.cluster_ids, quality, n_max=n_max) return self._sort(best) - def most_similar_clusters(self, cluster=None, n_max=None, similarity=None): + def most_similar_clusters(self, cluster, n_max=None, similarity=None): """Return the `n_max` most similar clusters to a given cluster. The default similarity function is the registered one. """ - if cluster is None: - cluster = self.best - if cluster is None: - cluster = self.best_clusters(1)[0] - if similarity is None: - similarity = self._similarity + similarity = similarity or self._similarity s = [(other, similarity(cluster, other)) for other in self.cluster_ids if other != cluster] clusters = _argsort(s, n_max=n_max) return self._sort(clusters, mix_good_unsorted=True) - # List methods + # Selection methods #-------------------------------------------------------------------------- - def _set_best_list(self, cluster=None, clusters=None): - if cluster is None: - cluster = self.best - if clusters is None: - clusters = self.best_clusters() - self._best_list = clusters - if clusters: - self.best = clusters[0] - - def _set_match_list(self, cluster=None, clusters=None): - if cluster is None: - cluster = self.best - if clusters is None: - clusters = self.most_similar_clusters(cluster) - self._match_list = clusters - if clusters: - self.match = clusters[0] - - @property - def best(self): - """Currently-selected best cluster.""" - return self._best - - @best.setter - def best(self, value): - assert value in self._best_list - self.selection = [value] - - @property - def match(self): - """Currently-selected closest match.""" - return self._match - - @match.setter - def match(self, value): - if value is not None: - assert value in self._match_list - if len(self._selection) == 1: - self.selection = self.selection + [value] - elif len(self._selection) == 2: - self.selection = [self.selection[0], value] - @property def selection(self): - """Return the current best/match cluster selection.""" + """Return the current cluster selection.""" return self._selection @selection.setter def selection(self, value): - """Return the current best/match cluster selection.""" - assert isinstance(value, (tuple, list)) + value = _as_list(value) clusters = self.cluster_ids value = [cluster for cluster in value if cluster in clusters] self._selection = value - if len(self._selection) == 1: - self._match = None - if len(self._selection) >= 1: - self._best = self._selection[0] - if len(self._selection) >= 2: - self._match = self._selection[1] self.emit('select', self._selection) @property - def best_list(self): - """Current list of best clusters, by decreasing quality.""" - return self._best_list - - @property - def match_list(self): - """Current list of closest matches, by decreasing similarity.""" - return self._match_list - - @property - def n_processed(self): - """Numbered of processed clusters so far. - - A cluster is considered processed if its status is not `None`. - - """ - return len(self._with_status(self._best_list, ('good', 'ignored'))) + def best(self): + """Currently-selected best cluster.""" + return self._selection[0] if self._selection else None @property - def n_clusters(self): - """Total number of clusters.""" - return len(self.cluster_ids) + def match(self): + """Currently-selected closest match.""" + return self._selection[1] if len(self._selection) >= 2 else None # Navigation #-------------------------------------------------------------------------- - def next_best(self): - """Select the next best cluster.""" - boo_match = self.match is not None - self.best = _next(self._best_list, self._best) - if boo_match: - self._set_match_list() - - def previous_best(self): - """Select the previous best in cluster.""" - boo_match = self.match is not None - if self._best_list: - self.best = _previous(self._best_list, self._best) - if boo_match: - self._set_match_list() - - def next_match(self): - """Select the next match.""" - if self._match_list: - self.match = _next(self._match_list, self._match) - - def previous_match(self): - """Select the previous match.""" - if self._match_list: - self.match = _previous(self._match_list, self._match) + def previous(self): + sel = self._history.back() + if sel: + self._selection = sel def next(self): - """Next cluster proposition.""" - if self.match is None: - return self.next_best() + if not self._history.is_last(): + # Go forward after a previous. + sel = self._history.forward() + if sel: + self._selection = sel else: - return self.next_match() - - def previous(self): - """Previous cluster proposition.""" - if self.match is None: - return self.previous_best() - else: - return self.previous_match() - - def first(self): - """First match or first best.""" - if self.match is None and self._best_list: - self.best = self._best_list[0] - elif self._match_list: - self.match = self._match_list[0] - - def last(self): - """Last match or last best.""" - if self.match is None and self._best_list: - self.best = self._best_list[-1] - elif self.match_list: - self.match = self._match_list[-1] - - # Control - #-------------------------------------------------------------------------- + # Or compute the next selection. + self._selection = self._next(self._selection) + self._history.add(self._selection) - def start(self): - """Start the wizard by setting the list of best clusters.""" - self._set_best_list() - - def pin(self, cluster=None): - """Pin the current best cluster and set the list of closest matches.""" - if cluster is None: - cluster = self.best - logger.debug("Pin %d.", cluster) - self.best = cluster - self._set_match_list(cluster) - self._check() - - def unpin(self): - """Unpin the current cluster.""" - if self.match is not None: - logger.debug("Unpin.") - self.match = None - self._match_list = [] - - # Actions + # Attach #-------------------------------------------------------------------------- - def _delete(self, clusters): - for clu in clusters: - if clu in self._best_list: - self._best_list.remove(clu) - if clu in self._match_list: - self._match_list.remove(clu) - if clu == self._best: - self._best = self._best_list[0] if self._best_list else None - if clu == self._match: - self._match = None - - def _add(self, clusters, position=None): - for clu in clusters: - assert clu not in self._best_list - assert clu not in self._match_list - if self.best is not None: - if position is not None: - self._best_list.insert(position, clu) - else: # pragma: no cover - self._best_list.append(clu) - if self.match is not None: - self._match_list.append(clu) - - def _update_state(self, up): - # Update the cluster status. - if up.description == 'metadata_group': - cluster = up.metadata_changed[0] - # Reorder the best list, so that the clusters moved in different - # status go to their right place in the best list. - if (self._best is not None and self._best_list and - cluster == self._best): - # # Find the next best after the cluster has been moved. - # next_best = _next(self._best_list, self._best) - # Reorder the list. - self._best_list = self._sort(self._best_list) - # # Select the next best. - # self._best = next_best - # Update the wizard with new and old clusters. - for clu in up.added: - # Add the child at the parent's position. - parents = [x for (x, y) in up.descendants if y == clu] - parent = parents[0] - position = (self._best_list.index(parent) - if self._best_list else None) - self._add([clu], position) - # Delete old clusters. - self._delete(up.deleted) - # # Select the last added cluster. - # if self.best is not None and up.added: - # self._best = up.added[-1] - - def _select_after_update(self, up): - if up.history == 'undo': - self.selection = up.undo_state[0]['selection'] - return - # Make as few updates as possible in the views after clustering - # actions. This allows for better before/after comparisons. - if up.added: - self.selection = up.added - if up.description == 'merge': - self.pin(up.added[0]) - if up.description == 'metadata_group': - cluster = up.metadata_changed[0] - if cluster == self.best: - # Pin the next best if there was a match before. - match_before = self.match is not None - self.next_best() - if match_before: - self.pin() - elif cluster == self.match: - self.next_match() - - def attach(self, clustering, cluster_meta): - # TODO: might be better in an independent function in another module - - # The wizard gets the cluster ids from the Clustering instance - # and the status from ClusterMetadataUpdater. - self.set_cluster_ids_function(lambda: clustering.cluster_ids) - - @self.set_status_function - def status(cluster): - group = cluster_meta.group(cluster) - return _wizard_group(group) + def attach(self, obj): + """Attach an actioner to the wizard.""" + # Save the current selection when an action occurs. + @obj.connect def on_request_undo_state(up): - return {'selection': self.selection} + return {'selection': self._selection} + @obj.connect def on_cluster(up): - # Set the cluster metadata of new clusters. - if up.added: - cluster_meta.set_from_descendants(up.descendants) - # Update the wizard state. - if self._best_list or self._match_list: - self._update_state(up) - # Make a new selection. - if self._best is not None or self._match is not None: - self._select_after_update(up) - - clustering.connect(on_request_undo_state) - cluster_meta.connect(on_request_undo_state) - - clustering.connect(on_cluster) - cluster_meta.connect(on_cluster) + if up.history == 'undo': + # Revert to the given selection after an undo. + self._selection = up.undo_state[0]['selection'] + else: + # Or move to the next selection after any other action. + self.next() From 738593ea38ca71067054fa11bcc615d9768a5c37 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 7 Oct 2015 10:46:35 +0200 Subject: [PATCH 0235/1059] WIP: wizard tests --- phy/cluster/manual/tests/conftest.py | 50 ++++++++++++------------- phy/cluster/manual/tests/test_wizard.py | 47 ++++++++++------------- 2 files changed, 43 insertions(+), 54 deletions(-) diff --git a/phy/cluster/manual/tests/conftest.py b/phy/cluster/manual/tests/conftest.py index 8cce1ad21..e6ad455fe 100644 --- a/phy/cluster/manual/tests/conftest.py +++ b/phy/cluster/manual/tests/conftest.py @@ -8,7 +8,6 @@ from pytest import yield_fixture -from ..clustering import Clustering from ..wizard import Wizard, _wizard_group from .._utils import create_cluster_meta @@ -17,22 +16,18 @@ # Fixtures #------------------------------------------------------------------------------ -@yield_fixture -def spike_clusters(): - yield [2, 3, 5, 7] - - -@yield_fixture -def clustering(spike_clusters): - yield Clustering(spike_clusters) - - @yield_fixture def cluster_groups(): - data = {2: None, - 3: None, - 5: 'mua', - 7: 'good', + data = {1: 'noise', + 2: 'mua', + 11: 'good', + 12: 'good', + 13: 'good', + 101: None, + 102: None, + 103: None, + 104: None, + 105: None, } yield data @@ -42,31 +37,34 @@ def cluster_meta(cluster_groups): yield create_cluster_meta(cluster_groups) -def _set_test_wizard(wizard): +@yield_fixture +def mock_wizard(): + + wizard = Wizard() + wizard.set_cluster_ids_function(lambda: [1, 2, 3]) @wizard.set_quality_function def quality(cluster): - return cluster * .1 + return cluster @wizard.set_similarity_function def similarity(cluster, other): - return 1. + quality(cluster) - quality(other) + return cluster + other + + yield wizard @yield_fixture -def wizard(cluster_groups): - wizard = Wizard() +def wizard_with_groups(mock_wizard, cluster_groups): def get_cluster_ids(): - return [2, 3, 5, 7] + return sorted(cluster_groups.keys()) - wizard.set_cluster_ids_function(get_cluster_ids) + mock_wizard.set_cluster_ids_function(get_cluster_ids) - @wizard.set_status_function + @mock_wizard.set_status_function def status(cluster): group = cluster_groups.get(cluster, None) return _wizard_group(group) - _set_test_wizard(wizard) - - yield wizard + yield mock_wizard diff --git a/phy/cluster/manual/tests/test_wizard.py b/phy/cluster/manual/tests/test_wizard.py index c47b6ce92..2c09a06a8 100644 --- a/phy/cluster/manual/tests/test_wizard.py +++ b/phy/cluster/manual/tests/test_wizard.py @@ -49,31 +49,22 @@ def test_wizard_group(): assert _wizard_group(None) is None -def test_wizard_core(): - - wizard = Wizard() - wizard.set_cluster_ids_function(lambda: [1, 2, 3]) - - @wizard.set_quality_function - def quality(cluster): - return cluster - - @wizard.set_similarity_function - def similarity(cluster, other): - return cluster + other - - assert wizard.best_clusters() == [3, 2, 1] - assert wizard.best_clusters(n_max=0) == [3, 2, 1] - assert wizard.best_clusters(n_max=None) == [3, 2, 1] - assert wizard.best_clusters(n_max=2) == [3, 2] - assert wizard.best_clusters(n_max=1) == [3] - - assert wizard.most_similar_clusters(3) == [2, 1] - assert wizard.most_similar_clusters(2) == [3, 1] - assert wizard.most_similar_clusters(1) == [3, 2] - - assert wizard.most_similar_clusters(3, n_max=0) == [2, 1] - assert wizard.most_similar_clusters(3, n_max=None) == [2, 1] - assert wizard.most_similar_clusters(3, n_max=1) == [2] - assert wizard.most_similar_clusters(3, n_max=2) == [2, 1] - assert wizard.most_similar_clusters(3, n_max=10) == [2, 1] +def test_wizard_core(mock_wizard): + + w = mock_wizard + + assert w.best_clusters() == [3, 2, 1] + assert w.best_clusters(n_max=0) == [3, 2, 1] + assert w.best_clusters(n_max=None) == [3, 2, 1] + assert w.best_clusters(n_max=2) == [3, 2] + assert w.best_clusters(n_max=1) == [3] + + assert w.most_similar_clusters(3) == [2, 1] + assert w.most_similar_clusters(2) == [3, 1] + assert w.most_similar_clusters(1) == [3, 2] + + assert w.most_similar_clusters(3, n_max=0) == [2, 1] + assert w.most_similar_clusters(3, n_max=None) == [2, 1] + assert w.most_similar_clusters(3, n_max=1) == [2] + assert w.most_similar_clusters(3, n_max=2) == [2, 1] + assert w.most_similar_clusters(3, n_max=10) == [2, 1] From d05b2b18a39b2c6cd4109e333db102f3cbe22839 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 7 Oct 2015 10:51:30 +0200 Subject: [PATCH 0236/1059] WIP: wizard tests --- phy/cluster/manual/tests/test_wizard.py | 16 +++++++++++++++- phy/cluster/manual/wizard.py | 17 ++++++++--------- 2 files changed, 23 insertions(+), 10 deletions(-) diff --git a/phy/cluster/manual/tests/test_wizard.py b/phy/cluster/manual/tests/test_wizard.py index 2c09a06a8..be1d45f3b 100644 --- a/phy/cluster/manual/tests/test_wizard.py +++ b/phy/cluster/manual/tests/test_wizard.py @@ -49,10 +49,14 @@ def test_wizard_group(): assert _wizard_group(None) is None -def test_wizard_core(mock_wizard): +def test_wizard_basic(mock_wizard): w = mock_wizard + assert w.cluster_ids == [1, 2, 3] + assert w.n_clusters == 3 + assert w.cluster_status(1) is None + assert w.best_clusters() == [3, 2, 1] assert w.best_clusters(n_max=0) == [3, 2, 1] assert w.best_clusters(n_max=None) == [3, 2, 1] @@ -68,3 +72,13 @@ def test_wizard_core(mock_wizard): assert w.most_similar_clusters(3, n_max=1) == [2] assert w.most_similar_clusters(3, n_max=2) == [2, 1] assert w.most_similar_clusters(3, n_max=10) == [2, 1] + + +def test_wizard_nav(mock_wizard): + w = mock_wizard + + assert w.selection == () + + +def test_wizard_strategy(mock_wizard): + pass diff --git a/phy/cluster/manual/wizard.py b/phy/cluster/manual/wizard.py index c95abe3b7..f7aec84a4 100644 --- a/phy/cluster/manual/wizard.py +++ b/phy/cluster/manual/wizard.py @@ -12,7 +12,7 @@ from six import string_types from ._history import History -from phy.utils._types import _as_list +from phy.utils._types import _as_list, _as_tuple from phy.utils import EventEmitter logger = logging.getLogger(__name__) @@ -79,7 +79,7 @@ def __init__(self): self.reset() def reset(self): - self._selection = [] + self._selection = () self._history = History(()) # Quality and status functions @@ -198,13 +198,12 @@ def most_similar_clusters(self, cluster, n_max=None, similarity=None): @property def selection(self): """Return the current cluster selection.""" - return self._selection + return _as_tuple(self._selection) @selection.setter def selection(self, value): - value = _as_list(value) clusters = self.cluster_ids - value = [cluster for cluster in value if cluster in clusters] + value = tuple(cluster for cluster in value if cluster in clusters) self._selection = value self.emit('select', self._selection) @@ -224,17 +223,17 @@ def match(self): def previous(self): sel = self._history.back() if sel: - self._selection = sel + self._selection = tuple(sel) def next(self): if not self._history.is_last(): # Go forward after a previous. sel = self._history.forward() if sel: - self._selection = sel + self._selection = tuple(sel) else: # Or compute the next selection. - self._selection = self._next(self._selection) + self._selection = tuple(self._next(self._selection)) self._history.add(self._selection) # Attach @@ -252,7 +251,7 @@ def on_request_undo_state(up): def on_cluster(up): if up.history == 'undo': # Revert to the given selection after an undo. - self._selection = up.undo_state[0]['selection'] + self._selection = tuple(up.undo_state[0]['selection']) else: # Or move to the next selection after any other action. self.next() From c3de46703c3601be8d39cbff797e9f216ffa08e9 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 7 Oct 2015 10:56:42 +0200 Subject: [PATCH 0237/1059] WIP: wizard tests --- phy/cluster/manual/tests/test_wizard.py | 21 +++++++++++++++++++++ phy/cluster/manual/wizard.py | 2 +- 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/phy/cluster/manual/tests/test_wizard.py b/phy/cluster/manual/tests/test_wizard.py index be1d45f3b..546b2c32b 100644 --- a/phy/cluster/manual/tests/test_wizard.py +++ b/phy/cluster/manual/tests/test_wizard.py @@ -79,6 +79,27 @@ def test_wizard_nav(mock_wizard): assert w.selection == () + ### + w.selection = [] + assert w.selection == () + + assert w.best is None + assert w.match is None + + ### + w.selection = [1] + assert w.selection == (1,) + + assert w.best == 1 + assert w.match is None + + ### + w.selection = [1, 2, 4] + assert w.selection == (1, 2) + + assert w.best == 1 + assert w.match == 2 + def test_wizard_strategy(mock_wizard): pass diff --git a/phy/cluster/manual/wizard.py b/phy/cluster/manual/wizard.py index f7aec84a4..2d206a357 100644 --- a/phy/cluster/manual/wizard.py +++ b/phy/cluster/manual/wizard.py @@ -240,7 +240,7 @@ def next(self): #-------------------------------------------------------------------------- def attach(self, obj): - """Attach an actioner to the wizard.""" + """Attach an effector to the wizard.""" # Save the current selection when an action occurs. @obj.connect From f5bdb85d0979663535418eee172c34daad445d0b Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 7 Oct 2015 11:17:56 +0200 Subject: [PATCH 0238/1059] Fix docstring in History --- phy/cluster/manual/_history.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/phy/cluster/manual/_history.py b/phy/cluster/manual/_history.py index 81a0a9fd0..813bf957a 100644 --- a/phy/cluster/manual/_history.py +++ b/phy/cluster/manual/_history.py @@ -94,7 +94,7 @@ def add(self, item): def back(self): """Go back in history if possible. - Return the current item after going back. + Return the undone item. """ if self._index <= 0: From ccf64d2a4dd3c0c58420082b81bd1debfa73e196 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 7 Oct 2015 11:21:22 +0200 Subject: [PATCH 0239/1059] WIP: wizard tests --- phy/cluster/manual/tests/test_wizard.py | 16 ++++++++++++++++ phy/cluster/manual/wizard.py | 25 +++++++++++++++++-------- 2 files changed, 33 insertions(+), 8 deletions(-) diff --git a/phy/cluster/manual/tests/test_wizard.py b/phy/cluster/manual/tests/test_wizard.py index 546b2c32b..cde65fd9e 100644 --- a/phy/cluster/manual/tests/test_wizard.py +++ b/phy/cluster/manual/tests/test_wizard.py @@ -100,6 +100,22 @@ def test_wizard_nav(mock_wizard): assert w.best == 1 assert w.match == 2 + ### + w.previous() + assert w.selection == (1,) + + for _ in range(2): + w.previous() + assert w.selection == (1,) + + ### + w.next() + assert w.selection == (1, 2) + + for _ in range(2): + w.next() + assert w.selection == (1, 2) + def test_wizard_strategy(mock_wizard): pass diff --git a/phy/cluster/manual/wizard.py b/phy/cluster/manual/wizard.py index 2d206a357..c5987dadc 100644 --- a/phy/cluster/manual/wizard.py +++ b/phy/cluster/manual/wizard.py @@ -12,7 +12,7 @@ from six import string_types from ._history import History -from phy.utils._types import _as_list, _as_tuple +from phy.utils._types import _as_tuple from phy.utils import EventEmitter logger = logging.getLogger(__name__) @@ -205,6 +205,7 @@ def selection(self, value): clusters = self.cluster_ids value = tuple(cluster for cluster in value if cluster in clusters) self._selection = value + self._history.add(self._selection) self.emit('select', self._selection) @property @@ -221,20 +222,28 @@ def match(self): #-------------------------------------------------------------------------- def previous(self): - sel = self._history.back() + if self._history.current_position <= 2: + return self._selection + self._history.back() + sel = self._history.current_item if sel: - self._selection = tuple(sel) + self._selection = sel # Not add this action to the history. + return self._selection def next(self): if not self._history.is_last(): # Go forward after a previous. - sel = self._history.forward() + self._history.forward() + sel = self._history.current_item if sel: - self._selection = tuple(sel) + self._selection = sel # Not add this action to the history. else: - # Or compute the next selection. - self._selection = tuple(self._next(self._selection)) - self._history.add(self._selection) + if self._next: + # Or compute the next selection. + self.selection = tuple(self._next(self._selection)) + else: + logger.debug("No strategy selected in the wizard.") + return self._selection # Attach #-------------------------------------------------------------------------- From 404428edee1fcffb8eeb98671a38907dcad62d4f Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 7 Oct 2015 12:01:11 +0200 Subject: [PATCH 0240/1059] WIP: wizard tests --- phy/cluster/manual/tests/test_wizard.py | 23 ++++++++++++++++++++++- phy/cluster/manual/wizard.py | 8 +++++--- 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/phy/cluster/manual/tests/test_wizard.py b/phy/cluster/manual/tests/test_wizard.py index cde65fd9e..221e61176 100644 --- a/phy/cluster/manual/tests/test_wizard.py +++ b/phy/cluster/manual/tests/test_wizard.py @@ -118,4 +118,25 @@ def test_wizard_nav(mock_wizard): def test_wizard_strategy(mock_wizard): - pass + w = mock_wizard + + w.set_status_function(lambda x: None) + + def strategy(selection, best_clusters=None, status=None, similarity=None): + """Return the next best cluster.""" + if not selection: + return best_clusters[0] + assert selection[0] in best_clusters + i = best_clusters.index(selection[0]) + if i < len(best_clusters) - 1: + return best_clusters[i + 1] + + w.set_strategy_function(strategy) + assert w.selection == () + + for i in range(3, 0, -1): + w.next() + assert w.selection == (i,) + + w.next() + assert w.selection == (1,) diff --git a/phy/cluster/manual/wizard.py b/phy/cluster/manual/wizard.py index c5987dadc..36870a2e6 100644 --- a/phy/cluster/manual/wizard.py +++ b/phy/cluster/manual/wizard.py @@ -125,8 +125,8 @@ def set_strategy_function(self, func): def wrapped(sel): return func(self._selection, - cluster_ids=self._get_cluster_ids(), - quality=self._quality, + best_clusters=self.best_clusters(), + status=self._cluster_status, similarity=self._similarity, ) @@ -202,6 +202,8 @@ def selection(self): @selection.setter def selection(self, value): + if value is None: + return clusters = self.cluster_ids value = tuple(cluster for cluster in value if cluster in clusters) self._selection = value @@ -240,7 +242,7 @@ def next(self): else: if self._next: # Or compute the next selection. - self.selection = tuple(self._next(self._selection)) + self.selection = _as_tuple(self._next(self._selection)) else: logger.debug("No strategy selected in the wizard.") return self._selection From e1734f6861b405008358d6c42a401d221880bf15 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 7 Oct 2015 12:09:55 +0200 Subject: [PATCH 0241/1059] WIP: wizard tests --- phy/cluster/manual/tests/test_wizard.py | 19 ++++++++++++------- phy/cluster/manual/wizard.py | 6 ++++++ 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/phy/cluster/manual/tests/test_wizard.py b/phy/cluster/manual/tests/test_wizard.py index 221e61176..78219c0f8 100644 --- a/phy/cluster/manual/tests/test_wizard.py +++ b/phy/cluster/manual/tests/test_wizard.py @@ -6,11 +6,9 @@ # Imports #------------------------------------------------------------------------------ -from numpy.testing import assert_array_equal as ae - from ..wizard import (_argsort, + _next_in_list, _best_clusters, - Wizard, _wizard_group, ) @@ -41,6 +39,15 @@ def test_best_clusters(): assert _best_clusters(l, quality, n_max=10) == [4, 3, 2, 1] +def test_next_in_list(): + l = [1, 2, 3] + assert _next_in_list(l, 0) == 0 + assert _next_in_list(l, 1) == 2 + assert _next_in_list(l, 2) == 3 + assert _next_in_list(l, 3) == 3 + assert _next_in_list(l, 4) == 4 + + def test_wizard_group(): assert _wizard_group('noise') == 'ignored' assert _wizard_group('mua') == 'ignored' @@ -126,10 +133,8 @@ def strategy(selection, best_clusters=None, status=None, similarity=None): """Return the next best cluster.""" if not selection: return best_clusters[0] - assert selection[0] in best_clusters - i = best_clusters.index(selection[0]) - if i < len(best_clusters) - 1: - return best_clusters[i + 1] + assert len(selection) == 1 + return _next_in_list(best_clusters, selection[0]) w.set_strategy_function(strategy) assert w.selection == () diff --git a/phy/cluster/manual/wizard.py b/phy/cluster/manual/wizard.py index 36870a2e6..0cae67860 100644 --- a/phy/cluster/manual/wizard.py +++ b/phy/cluster/manual/wizard.py @@ -50,6 +50,12 @@ def _wizard_group(group): return None +def _next_in_list(l, item): + if l and item in l and l.index(item) < len(l) - 1: + return l[l.index(item) + 1] + return item + + #------------------------------------------------------------------------------ # Wizard #------------------------------------------------------------------------------ From 5cecf6267439c232dbfd32492623084b61e95cb2 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 7 Oct 2015 12:26:36 +0200 Subject: [PATCH 0242/1059] WIP: refactor sort logic in wizard --- phy/cluster/manual/tests/test_wizard.py | 9 ++++ phy/cluster/manual/wizard.py | 60 ++++++++++++++++--------- 2 files changed, 49 insertions(+), 20 deletions(-) diff --git a/phy/cluster/manual/tests/test_wizard.py b/phy/cluster/manual/tests/test_wizard.py index 78219c0f8..955d0e9c3 100644 --- a/phy/cluster/manual/tests/test_wizard.py +++ b/phy/cluster/manual/tests/test_wizard.py @@ -7,6 +7,7 @@ #------------------------------------------------------------------------------ from ..wizard import (_argsort, + _sort, _next_in_list, _best_clusters, _wizard_group, @@ -29,6 +30,14 @@ def test_argsort(): assert _argsort(l, reverse=False) == [1, 2, 3, 4] +def test_sort(): + clusters = [10, 0, 1, 30, 2, 20] + # N, i, g, N, N, N + status = lambda c: ('ignored', 'good')[c] if c <= 1 else None + + assert _sort(clusters, status=status) == [10, 30, 2, 20, 1, 0] + + def test_best_clusters(): quality = lambda c: c * .1 l = list(range(1, 5)) diff --git a/phy/cluster/manual/wizard.py b/phy/cluster/manual/wizard.py index 0cae67860..e05b970b3 100644 --- a/phy/cluster/manual/wizard.py +++ b/phy/cluster/manual/wizard.py @@ -56,6 +56,43 @@ def _next_in_list(l, item): return item +def _sort(clusters, status=None, mix_good_unsorted=False): + """Sort clusters according to their status.""" + assert status + _sort_map = {None: 0, 'good': 1, 'ignored': 2} + if mix_good_unsorted: + _sort_map['good'] = 0 + # NOTE: sorted is "stable": it doesn't change the order of elements + # that compare equal, which ensures that the order of clusters is kept + # among any given status. + key = lambda cluster: _sort_map.get(status(cluster), 0) + return sorted(clusters, key=key) + + +#------------------------------------------------------------------------------ +# Strategy functions +#------------------------------------------------------------------------------ + +def best_quality_strategy(selection, best_clusters=None, status=None, + similarity=None): + """Two cases depending on the number of selected clusters: + + * 1: move to the next best cluster + * 2: move to the next most similar pair + * 3+: do nothing + + """ + n = len(selection) + if n == 0 or n >= 3: + return selection + if n == 1: + return _next_in_list(best_clusters, selection[0]) + elif n == 2: + best, match = selection + value = similarity(best, match) + sims = [similarity(best, other) for other in best_clusters] + + #------------------------------------------------------------------------------ # Wizard #------------------------------------------------------------------------------ @@ -138,24 +175,6 @@ def wrapped(sel): self._next = wrapped - # Internal methods - #-------------------------------------------------------------------------- - - def _sort_nomix(self, cluster): - # Sort by unsorted first, good second, ignored last. - _sort_map = {None: 0, 'good': 1, 'ignored': 2} - return _sort_map.get(self._cluster_status(cluster), 0) - - def _sort_mix(self, cluster): - # Sort by unsorted/good first, ignored last. - _sort_map = {None: 0, 'good': 0, 'ignored': 2} - return _sort_map.get(self._cluster_status(cluster), 0) - - def _sort(self, clusters, mix_good_unsorted=False): - """Sort clusters according to their status.""" - key = self._sort_mix if mix_good_unsorted else self._sort_nomix - return sorted(clusters, key=key) - # Properties #-------------------------------------------------------------------------- @@ -183,7 +202,7 @@ def best_clusters(self, n_max=None, quality=None): """ quality = quality or self._quality best = _best_clusters(self.cluster_ids, quality, n_max=n_max) - return self._sort(best) + return _sort(best, status=self._cluster_status) def most_similar_clusters(self, cluster, n_max=None, similarity=None): """Return the `n_max` most similar clusters to a given cluster. @@ -196,7 +215,8 @@ def most_similar_clusters(self, cluster, n_max=None, similarity=None): for other in self.cluster_ids if other != cluster] clusters = _argsort(s, n_max=n_max) - return self._sort(clusters, mix_good_unsorted=True) + return _sort(clusters, status=self._cluster_status, + mix_good_unsorted=True) # Selection methods #-------------------------------------------------------------------------- From daaa5700d7869c960d6c9dbd831d81e0b7e9a3c4 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 7 Oct 2015 12:57:34 +0200 Subject: [PATCH 0243/1059] WIP: wizard strategy --- phy/cluster/manual/tests/test_wizard.py | 21 ++++++++++++++++++++- phy/cluster/manual/wizard.py | 24 ++++++++++++++++++++++-- 2 files changed, 42 insertions(+), 3 deletions(-) diff --git a/phy/cluster/manual/tests/test_wizard.py b/phy/cluster/manual/tests/test_wizard.py index 955d0e9c3..325e95ba6 100644 --- a/phy/cluster/manual/tests/test_wizard.py +++ b/phy/cluster/manual/tests/test_wizard.py @@ -11,6 +11,7 @@ _next_in_list, _best_clusters, _wizard_group, + best_quality_strategy, ) @@ -57,6 +58,24 @@ def test_next_in_list(): assert _next_in_list(l, 4) == 4 +def test_best_quality_strategy(): + best_clusters = range(5, -1, -1) + status = lambda c: ('ignored', 'ignored', 'good')[c] if c <= 2 else None + similarity = lambda c, d: c + d + + def _next(selection): + return best_quality_strategy(selection, + best_clusters=best_clusters, + status=status, + similarity=similarity) + + assert not _next(None) + assert not _next(()) + + for i in range(5, -1, -1): + assert _next(i) == max(0, i - 1) + + def test_wizard_group(): assert _wizard_group('noise') == 'ignored' assert _wizard_group('mua') == 'ignored' @@ -110,7 +129,7 @@ def test_wizard_nav(mock_wizard): assert w.match is None ### - w.selection = [1, 2, 4] + w.select([1, 2, 4]) assert w.selection == (1, 2) assert w.best == 1 diff --git a/phy/cluster/manual/wizard.py b/phy/cluster/manual/wizard.py index e05b970b3..ffa76384a 100644 --- a/phy/cluster/manual/wizard.py +++ b/phy/cluster/manual/wizard.py @@ -82,15 +82,32 @@ def best_quality_strategy(selection, best_clusters=None, status=None, * 3+: do nothing """ + if selection is None: + return selection + selection = _as_tuple(selection) n = len(selection) if n == 0 or n >= 3: return selection if n == 1: + # Sort the best clusters according to their status. + best_clusters = _sort(best_clusters, status=status) return _next_in_list(best_clusters, selection[0]) elif n == 2: best, match = selection value = similarity(best, match) - sims = [similarity(best, other) for other in best_clusters] + # Find the similarity of the best cluster with every other one. + sims = [(other, similarity(best, other)) for other in best_clusters] + # Only keep the less similar clusters. + sims = [(other, s) for (other, s) in sims if s <= value] + # Sort the pairs by decreasing similarity. + sims = sorted(sims, key=itemgetter(1), reverse=True) + # Just keep the cluster ids. + sims = [c for (c, v) in sims] + # Sort the candidates according to their status. + _sort(sims, status=status, mix_good_unsorted=True) + if not sims: + return selection + return [best, sims[0][0]] #------------------------------------------------------------------------------ @@ -228,7 +245,7 @@ def selection(self): @selection.setter def selection(self, value): - if value is None: + if value is None: # pragma: no cover return clusters = self.cluster_ids value = tuple(cluster for cluster in value if cluster in clusters) @@ -236,6 +253,9 @@ def selection(self, value): self._history.add(self._selection) self.emit('select', self._selection) + def select(self, cluster_ids): + self.selection = cluster_ids + @property def best(self): """Currently-selected best cluster.""" From bf60f06b91bc0d785cfce912adab090985565caf Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 7 Oct 2015 13:40:12 +0200 Subject: [PATCH 0244/1059] WIP: wizard strategy --- phy/cluster/manual/tests/test_wizard.py | 66 +++++++------- phy/cluster/manual/wizard.py | 109 +++++++++++------------- 2 files changed, 89 insertions(+), 86 deletions(-) diff --git a/phy/cluster/manual/tests/test_wizard.py b/phy/cluster/manual/tests/test_wizard.py index 325e95ba6..d99e7f16f 100644 --- a/phy/cluster/manual/tests/test_wizard.py +++ b/phy/cluster/manual/tests/test_wizard.py @@ -10,6 +10,7 @@ _sort, _next_in_list, _best_clusters, + _most_similar_clusters, _wizard_group, best_quality_strategy, ) @@ -37,6 +38,8 @@ def test_sort(): status = lambda c: ('ignored', 'good')[c] if c <= 1 else None assert _sort(clusters, status=status) == [10, 30, 2, 20, 1, 0] + assert _sort(clusters, status=status, remove_ignored=True) == \ + [10, 30, 2, 20, 1] def test_best_clusters(): @@ -49,6 +52,25 @@ def test_best_clusters(): assert _best_clusters(l, quality, n_max=10) == [4, 3, 2, 1] +def test_most_similar_clusters(): + cluster_ids = [0, 1, 2, 3] + # i, g, N, i + similarity = lambda c, d: c + d + status = lambda c: ('ignored', 'good', None, 'ignored')[c] + + def _similar(cluster): + return _most_similar_clusters(cluster, + cluster_ids=cluster_ids, + similarity=similarity, + status=status) + + assert not _similar(None) + assert not _similar(10) + assert _similar(0) == [2, 1] + assert _similar(1) == [2] + assert _similar(2) == [1] + + def test_next_in_list(): l = [1, 2, 3] assert _next_in_list(l, 0) == 0 @@ -59,22 +81,31 @@ def test_next_in_list(): def test_best_quality_strategy(): - best_clusters = range(5, -1, -1) + cluster_ids = [0, 1, 2, 3, 4] + # i, i, g, N, N + quality = lambda c: c status = lambda c: ('ignored', 'ignored', 'good')[c] if c <= 2 else None similarity = lambda c, d: c + d def _next(selection): return best_quality_strategy(selection, - best_clusters=best_clusters, + cluster_ids=cluster_ids, + quality=quality, status=status, similarity=similarity) assert not _next(None) assert not _next(()) - for i in range(5, -1, -1): + for i in range(4, -1, -1): assert _next(i) == max(0, i - 1) + assert _next((4, 3)) == (4, 2) + assert _next((4, 2)) == (4, 2) # 1 is ignored, so it does not appear. + + assert _next((3, 2)) == (3, 2) + assert _next((2, 3)) == (2, 3) + def test_wizard_group(): assert _wizard_group('noise') == 'ignored' @@ -84,31 +115,6 @@ def test_wizard_group(): assert _wizard_group(None) is None -def test_wizard_basic(mock_wizard): - - w = mock_wizard - - assert w.cluster_ids == [1, 2, 3] - assert w.n_clusters == 3 - assert w.cluster_status(1) is None - - assert w.best_clusters() == [3, 2, 1] - assert w.best_clusters(n_max=0) == [3, 2, 1] - assert w.best_clusters(n_max=None) == [3, 2, 1] - assert w.best_clusters(n_max=2) == [3, 2] - assert w.best_clusters(n_max=1) == [3] - - assert w.most_similar_clusters(3) == [2, 1] - assert w.most_similar_clusters(2) == [3, 1] - assert w.most_similar_clusters(1) == [3, 2] - - assert w.most_similar_clusters(3, n_max=0) == [2, 1] - assert w.most_similar_clusters(3, n_max=None) == [2, 1] - assert w.most_similar_clusters(3, n_max=1) == [2] - assert w.most_similar_clusters(3, n_max=2) == [2, 1] - assert w.most_similar_clusters(3, n_max=10) == [2, 1] - - def test_wizard_nav(mock_wizard): w = mock_wizard @@ -157,8 +163,10 @@ def test_wizard_strategy(mock_wizard): w.set_status_function(lambda x: None) - def strategy(selection, best_clusters=None, status=None, similarity=None): + def strategy(selection, cluster_ids=None, quality=None, + status=None, similarity=None): """Return the next best cluster.""" + best_clusters = _best_clusters(cluster_ids, quality) if not selection: return best_clusters[0] assert len(selection) == 1 diff --git a/phy/cluster/manual/wizard.py b/phy/cluster/manual/wizard.py index ffa76384a..c514fbdc6 100644 --- a/phy/cluster/manual/wizard.py +++ b/phy/cluster/manual/wizard.py @@ -34,34 +34,18 @@ def _argsort(seq, reverse=True, n_max=None): return out[:n_max] -def _best_clusters(clusters, quality, n_max=None): - return _argsort([(cluster, quality(cluster)) - for cluster in clusters], n_max=n_max) - - -def _wizard_group(group): - # The group should be None, 'mua', 'noise', or 'good'. - assert group is None or isinstance(group, string_types) - group = group.lower() if group else group - if group in ('mua', 'noise'): - return 'ignored' - elif group == 'good': - return 'good' - return None - - def _next_in_list(l, item): if l and item in l and l.index(item) < len(l) - 1: return l[l.index(item) + 1] return item -def _sort(clusters, status=None, mix_good_unsorted=False): +def _sort(clusters, status=None, remove_ignored=False): """Sort clusters according to their status.""" assert status _sort_map = {None: 0, 'good': 1, 'ignored': 2} - if mix_good_unsorted: - _sort_map['good'] = 0 + if remove_ignored: + clusters = [c for c in clusters if status(c) != 'ignored'] # NOTE: sorted is "stable": it doesn't change the order of elements # that compare equal, which ensures that the order of clusters is kept # among any given status. @@ -69,11 +53,45 @@ def _sort(clusters, status=None, mix_good_unsorted=False): return sorted(clusters, key=key) +def _best_clusters(clusters, quality, n_max=None): + return _argsort([(cluster, quality(cluster)) + for cluster in clusters], n_max=n_max) + + +def _most_similar_clusters(cluster, cluster_ids=None, n_max=None, + similarity=None, status=None, less_than=None): + """Return the `n_max` most similar clusters to a given cluster.""" + if cluster not in cluster_ids: + return [] + s = [(other, similarity(cluster, other)) + for other in cluster_ids + if other != cluster] + # Only keep values less than a threshold. + if less_than: + s = [(c, v) for (c, v) in s if v <= less_than] + clusters = _argsort(s, n_max=n_max) + return _sort(clusters, status=status, remove_ignored=True) + + +def _wizard_group(group): + # The group should be None, 'mua', 'noise', or 'good'. + assert group is None or isinstance(group, string_types) + group = group.lower() if group else group + if group in ('mua', 'noise'): + return 'ignored' + elif group == 'good': + return 'good' + return None + + #------------------------------------------------------------------------------ # Strategy functions #------------------------------------------------------------------------------ -def best_quality_strategy(selection, best_clusters=None, status=None, +def best_quality_strategy(selection, + cluster_ids=None, + quality=None, + status=None, similarity=None): """Two cases depending on the number of selected clusters: @@ -89,25 +107,25 @@ def best_quality_strategy(selection, best_clusters=None, status=None, if n == 0 or n >= 3: return selection if n == 1: + best_clusters = _best_clusters(cluster_ids, quality) # Sort the best clusters according to their status. best_clusters = _sort(best_clusters, status=status) return _next_in_list(best_clusters, selection[0]) elif n == 2: best, match = selection value = similarity(best, match) - # Find the similarity of the best cluster with every other one. - sims = [(other, similarity(best, other)) for other in best_clusters] - # Only keep the less similar clusters. - sims = [(other, s) for (other, s) in sims if s <= value] - # Sort the pairs by decreasing similarity. - sims = sorted(sims, key=itemgetter(1), reverse=True) - # Just keep the cluster ids. - sims = [c for (c, v) in sims] - # Sort the candidates according to their status. - _sort(sims, status=status, mix_good_unsorted=True) - if not sims: + candidates = _most_similar_clusters(best, + cluster_ids=cluster_ids, + similarity=similarity, + status=status, + less_than=value) + if best in candidates: + candidates.remove(best) + if match in candidates: + candidates.remove(match) + if not candidates: return selection - return [best, sims[0][0]] + return (best, candidates[0]) #------------------------------------------------------------------------------ @@ -185,7 +203,8 @@ def set_strategy_function(self, func): def wrapped(sel): return func(self._selection, - best_clusters=self.best_clusters(), + cluster_ids=self._get_cluster_ids(), + quality=self._quality, status=self._cluster_status, similarity=self._similarity, ) @@ -211,30 +230,6 @@ def n_clusters(self): def cluster_status(self, cluster): return self._cluster_status(cluster) - def best_clusters(self, n_max=None, quality=None): - """Return the list of best clusters sorted by decreasing quality. - - The default quality function is the registered one. - - """ - quality = quality or self._quality - best = _best_clusters(self.cluster_ids, quality, n_max=n_max) - return _sort(best, status=self._cluster_status) - - def most_similar_clusters(self, cluster, n_max=None, similarity=None): - """Return the `n_max` most similar clusters to a given cluster. - - The default similarity function is the registered one. - - """ - similarity = similarity or self._similarity - s = [(other, similarity(cluster, other)) - for other in self.cluster_ids - if other != cluster] - clusters = _argsort(s, n_max=n_max) - return _sort(clusters, status=self._cluster_status, - mix_good_unsorted=True) - # Selection methods #-------------------------------------------------------------------------- From 4e50c3733d8336dce14d34870201561013aece1b Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 7 Oct 2015 13:49:45 +0200 Subject: [PATCH 0245/1059] WIP: wizard strategy --- phy/cluster/manual/tests/test_wizard.py | 8 ++++++++ phy/cluster/manual/wizard.py | 22 +++++++++++++++------- 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/phy/cluster/manual/tests/test_wizard.py b/phy/cluster/manual/tests/test_wizard.py index d99e7f16f..a0415409b 100644 --- a/phy/cluster/manual/tests/test_wizard.py +++ b/phy/cluster/manual/tests/test_wizard.py @@ -117,6 +117,8 @@ def test_wizard_group(): def test_wizard_nav(mock_wizard): w = mock_wizard + assert w.cluster_ids == [1, 2, 3] + assert w.n_clusters == 3 assert w.selection == () @@ -181,3 +183,9 @@ def strategy(selection, cluster_ids=None, quality=None, w.next() assert w.selection == (1,) + + +def test_wizard_groups(wizard_with_groups): + w = wizard_with_groups + w.next() + print(w.selection) diff --git a/phy/cluster/manual/wizard.py b/phy/cluster/manual/wizard.py index c514fbdc6..b321ab9f6 100644 --- a/phy/cluster/manual/wizard.py +++ b/phy/cluster/manual/wizard.py @@ -119,7 +119,7 @@ def best_quality_strategy(selection, similarity=similarity, status=status, less_than=value) - if best in candidates: + if best in candidates: # pragma: no cover candidates.remove(best) if match in candidates: candidates.remove(match) @@ -224,12 +224,6 @@ def n_clusters(self): """Total number of clusters.""" return len(self.cluster_ids) - # Core methods - #-------------------------------------------------------------------------- - - def cluster_status(self, cluster): - return self._cluster_status(cluster) - # Selection methods #-------------------------------------------------------------------------- @@ -261,6 +255,20 @@ def match(self): """Currently-selected closest match.""" return self._selection[1] if len(self._selection) >= 2 else None + def pin(self): + best = self.best + if best is None: + return + candidates = _most_similar_clusters(best, + cluster_ids=self.cluster_ids, + similarity=self._similarity, + status=self._cluster_status) + if not candidates: + return + if best in candidates: + candidates.remove(best) + self.select([self.best, candidates[0]]) + # Navigation #-------------------------------------------------------------------------- From 3754fdd0cd07675cb771493bd64acedd34e0b279 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 7 Oct 2015 16:33:19 +0200 Subject: [PATCH 0246/1059] WIP: wizard strategy --- phy/cluster/manual/tests/test_wizard.py | 9 +++++---- phy/cluster/manual/wizard.py | 15 +++++++++++---- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/phy/cluster/manual/tests/test_wizard.py b/phy/cluster/manual/tests/test_wizard.py index a0415409b..68cef6ee6 100644 --- a/phy/cluster/manual/tests/test_wizard.py +++ b/phy/cluster/manual/tests/test_wizard.py @@ -95,7 +95,7 @@ def _next(selection): similarity=similarity) assert not _next(None) - assert not _next(()) + assert _next(()) == 4 for i in range(4, -1, -1): assert _next(i) == max(0, i - 1) @@ -185,7 +185,8 @@ def strategy(selection, cluster_ids=None, quality=None, assert w.selection == (1,) -def test_wizard_groups(wizard_with_groups): +def test_wizard_strategy_groups(wizard_with_groups): w = wizard_with_groups - w.next() - print(w.selection) + + for i in range(105, 100, -1): + assert w.next() == (i,) diff --git a/phy/cluster/manual/wizard.py b/phy/cluster/manual/wizard.py index b321ab9f6..50b939881 100644 --- a/phy/cluster/manual/wizard.py +++ b/phy/cluster/manual/wizard.py @@ -104,13 +104,16 @@ def best_quality_strategy(selection, return selection selection = _as_tuple(selection) n = len(selection) - if n == 0 or n >= 3: - return selection - if n == 1: + if n <= 1: best_clusters = _best_clusters(cluster_ids, quality) # Sort the best clusters according to their status. best_clusters = _sort(best_clusters, status=status) - return _next_in_list(best_clusters, selection[0]) + if selection: + return _next_in_list(best_clusters, selection[0]) + elif best_clusters: + return best_clusters[0] + else: + return selection elif n == 2: best, match = selection value = similarity(best, match) @@ -269,6 +272,10 @@ def pin(self): candidates.remove(best) self.select([self.best, candidates[0]]) + def unpin(self): + if len(self._selection) == 2: + self.selection = self.selection[0] + # Navigation #-------------------------------------------------------------------------- From 9dd08b2df354a733da2a2d11184f780691bcbb47 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 7 Oct 2015 20:29:04 +0200 Subject: [PATCH 0247/1059] WIP: refactor wizard --- phy/cluster/manual/tests/test_wizard.py | 24 ++++++++-- phy/cluster/manual/wizard.py | 61 ++++++++++++++----------- 2 files changed, 55 insertions(+), 30 deletions(-) diff --git a/phy/cluster/manual/tests/test_wizard.py b/phy/cluster/manual/tests/test_wizard.py index 68cef6ee6..699affe12 100644 --- a/phy/cluster/manual/tests/test_wizard.py +++ b/phy/cluster/manual/tests/test_wizard.py @@ -95,10 +95,10 @@ def _next(selection): similarity=similarity) assert not _next(None) - assert _next(()) == 4 + assert _next(()) == (4,) for i in range(4, -1, -1): - assert _next(i) == max(0, i - 1) + assert _next((i,)) == (max(0, i - 1),) assert _next((4, 3)) == (4, 2) assert _next((4, 2)) == (4, 2) # 1 is ignored, so it does not appear. @@ -170,9 +170,9 @@ def strategy(selection, cluster_ids=None, quality=None, """Return the next best cluster.""" best_clusters = _best_clusters(cluster_ids, quality) if not selection: - return best_clusters[0] + return (best_clusters[0],) assert len(selection) == 1 - return _next_in_list(best_clusters, selection[0]) + return (_next_in_list(best_clusters, selection[0]),) w.set_strategy_function(strategy) assert w.selection == () @@ -187,6 +187,22 @@ def strategy(selection, cluster_ids=None, quality=None, def test_wizard_strategy_groups(wizard_with_groups): w = wizard_with_groups + assert 101 in w.cluster_ids + assert 105 in w.cluster_ids for i in range(105, 100, -1): assert w.next() == (i,) + + w.select([105]) + + w.pin() + assert w.selection == (105, 104) + + w.next() + assert w.selection == (105, 103) + + w.previous() + assert w.selection == (105, 104) + + w.unpin() + assert w.selection == (105,) diff --git a/phy/cluster/manual/wizard.py b/phy/cluster/manual/wizard.py index 50b939881..371fda945 100644 --- a/phy/cluster/manual/wizard.py +++ b/phy/cluster/manual/wizard.py @@ -1,9 +1,9 @@ # -*- coding: utf-8 -*- - """Wizard.""" #------------------------------------------------------------------------------ # Imports + #------------------------------------------------------------------------------ import logging @@ -12,7 +12,6 @@ from six import string_types from ._history import History -from phy.utils._types import _as_tuple from phy.utils import EventEmitter logger = logging.getLogger(__name__) @@ -65,7 +64,7 @@ def _most_similar_clusters(cluster, cluster_ids=None, n_max=None, return [] s = [(other, similarity(cluster, other)) for other in cluster_ids - if other != cluster] + if other != cluster and status(other) != 'ignored'] # Only keep values less than a threshold. if less_than: s = [(c, v) for (c, v) in s if v <= less_than] @@ -102,16 +101,16 @@ def best_quality_strategy(selection, """ if selection is None: return selection - selection = _as_tuple(selection) + selection = tuple(selection) n = len(selection) if n <= 1: best_clusters = _best_clusters(cluster_ids, quality) # Sort the best clusters according to their status. best_clusters = _sort(best_clusters, status=status) if selection: - return _next_in_list(best_clusters, selection[0]) + return (_next_in_list(best_clusters, selection[0]),) elif best_clusters: - return best_clusters[0] + return (best_clusters[0],) else: return selection elif n == 2: @@ -230,23 +229,28 @@ def n_clusters(self): # Selection methods #-------------------------------------------------------------------------- + def _selection_changed(self, sel, add_to_history=True): + if sel is None: # pragma: no cover + return + assert hasattr(sel, '__len__') + clusters = self.cluster_ids + sel = tuple(cluster for cluster in sel if cluster in clusters) + self._selection = sel + if add_to_history: + self._history.add(self._selection) + self.emit('select', self._selection) + + def select(self, cluster_ids): + self._selection_changed(cluster_ids) + @property def selection(self): """Return the current cluster selection.""" - return _as_tuple(self._selection) + return self._selection @selection.setter def selection(self, value): - if value is None: # pragma: no cover - return - clusters = self.cluster_ids - value = tuple(cluster for cluster in value if cluster in clusters) - self._selection = value - self._history.add(self._selection) - self.emit('select', self._selection) - - def select(self, cluster_ids): - self.selection = cluster_ids + self.select(value) @property def best(self): @@ -270,35 +274,37 @@ def pin(self): return if best in candidates: candidates.remove(best) - self.select([self.best, candidates[0]]) + self.select((self.best, candidates[0])) def unpin(self): if len(self._selection) == 2: - self.selection = self.selection[0] + self.selection = (self.selection[0],) # Navigation #-------------------------------------------------------------------------- + def _set_selection_from_history(self): + sel = self._history.current_item + if not sel: + return + self._selection_changed(sel, add_to_history=False) + def previous(self): if self._history.current_position <= 2: return self._selection self._history.back() - sel = self._history.current_item - if sel: - self._selection = sel # Not add this action to the history. + self._set_selection_from_history() return self._selection def next(self): if not self._history.is_last(): # Go forward after a previous. self._history.forward() - sel = self._history.current_item - if sel: - self._selection = sel # Not add this action to the history. + self._set_selection_from_history() else: if self._next: # Or compute the next selection. - self.selection = _as_tuple(self._next(self._selection)) + self.selection = self._next(self._selection) else: logger.debug("No strategy selected in the wizard.") return self._selection @@ -316,6 +322,9 @@ def on_request_undo_state(up): @obj.connect def on_cluster(up): + if not up.history: + # Reset the history after every change. + self.reset() if up.history == 'undo': # Revert to the given selection after an undo. self._selection = tuple(up.undo_state[0]['selection']) From 1b7156dadaef9a03c9b5387a5cf41e638945f4bb Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 7 Oct 2015 21:40:25 +0200 Subject: [PATCH 0248/1059] WIP: wizard tests --- phy/cluster/manual/tests/conftest.py | 4 +++- phy/cluster/manual/tests/test_wizard.py | 11 ++++++++++- phy/cluster/manual/wizard.py | 7 +++---- 3 files changed, 16 insertions(+), 6 deletions(-) diff --git a/phy/cluster/manual/tests/conftest.py b/phy/cluster/manual/tests/conftest.py index e6ad455fe..e42ee686b 100644 --- a/phy/cluster/manual/tests/conftest.py +++ b/phy/cluster/manual/tests/conftest.py @@ -8,7 +8,7 @@ from pytest import yield_fixture -from ..wizard import Wizard, _wizard_group +from ..wizard import Wizard, _wizard_group, best_quality_strategy from .._utils import create_cluster_meta @@ -67,4 +67,6 @@ def status(cluster): group = cluster_groups.get(cluster, None) return _wizard_group(group) + mock_wizard.set_strategy_function(best_quality_strategy) + yield mock_wizard diff --git a/phy/cluster/manual/tests/test_wizard.py b/phy/cluster/manual/tests/test_wizard.py index 699affe12..afad8aaf1 100644 --- a/phy/cluster/manual/tests/test_wizard.py +++ b/phy/cluster/manual/tests/test_wizard.py @@ -160,7 +160,7 @@ def test_wizard_nav(mock_wizard): assert w.selection == (1, 2) -def test_wizard_strategy(mock_wizard): +def test_wizard_strategy_1(mock_wizard): w = mock_wizard w.set_status_function(lambda x: None) @@ -190,6 +190,9 @@ def test_wizard_strategy_groups(wizard_with_groups): assert 101 in w.cluster_ids assert 105 in w.cluster_ids + w.pin() + assert w.selection == () + for i in range(105, 100, -1): assert w.next() == (i,) @@ -206,3 +209,9 @@ def test_wizard_strategy_groups(wizard_with_groups): w.unpin() assert w.selection == (105,) + + @w.set_status_function + def status(cluster): + return 'ignored' + + assert w.pin() is None diff --git a/phy/cluster/manual/wizard.py b/phy/cluster/manual/wizard.py index 371fda945..02d9ec64a 100644 --- a/phy/cluster/manual/wizard.py +++ b/phy/cluster/manual/wizard.py @@ -111,7 +111,7 @@ def best_quality_strategy(selection, return (_next_in_list(best_clusters, selection[0]),) elif best_clusters: return (best_clusters[0],) - else: + else: # pragma: no cover return selection elif n == 2: best, match = selection @@ -270,10 +270,9 @@ def pin(self): cluster_ids=self.cluster_ids, similarity=self._similarity, status=self._cluster_status) + assert best not in candidates if not candidates: return - if best in candidates: - candidates.remove(best) self.select((self.best, candidates[0])) def unpin(self): @@ -285,7 +284,7 @@ def unpin(self): def _set_selection_from_history(self): sel = self._history.current_item - if not sel: + if not sel: # pragma: no cover return self._selection_changed(sel, add_to_history=False) From 15c20744cbe3177d65b93aac3088e78c4a7f9afb Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 7 Oct 2015 22:21:56 +0200 Subject: [PATCH 0249/1059] WIP: wizard tests --- phy/cluster/manual/tests/test_wizard.py | 32 +++++++++++++++++++++++++ phy/cluster/manual/wizard.py | 3 +++ 2 files changed, 35 insertions(+) diff --git a/phy/cluster/manual/tests/test_wizard.py b/phy/cluster/manual/tests/test_wizard.py index afad8aaf1..b52b1ce3d 100644 --- a/phy/cluster/manual/tests/test_wizard.py +++ b/phy/cluster/manual/tests/test_wizard.py @@ -14,6 +14,8 @@ _wizard_group, best_quality_strategy, ) +from .._utils import UpdateInfo +from phy.utils import EventEmitter #------------------------------------------------------------------------------ @@ -215,3 +217,33 @@ def status(cluster): return 'ignored' assert w.pin() is None + + +def test_wizard_attach(mock_wizard): + w = mock_wizard + + def strategy(selection, **kwargs): + if not selection: + return (3,) + return (1 + (selection[0] % 3),) + + w.set_strategy_function(strategy) + w.select([3]) + + obj = EventEmitter() + w.attach(obj) + + def _action(**kwargs): + up = UpdateInfo(**kwargs) + return obj.emit('cluster', up) + + _action(description='merge', added=[3]) + assert w.selection == (3,) + + _action(history='undo', undo_state=[{'selection': w.selection}]) + assert w.selection == (3,) + + _action(history='redo') + assert w.selection == (1,) + + assert obj.emit('request_undo_state', {}) == [{'selection': (w.selection)}] diff --git a/phy/cluster/manual/wizard.py b/phy/cluster/manual/wizard.py index 02d9ec64a..19b385c5b 100644 --- a/phy/cluster/manual/wizard.py +++ b/phy/cluster/manual/wizard.py @@ -263,6 +263,7 @@ def match(self): return self._selection[1] if len(self._selection) >= 2 else None def pin(self): + """Select the cluster the most similar cluster to the current best.""" best = self.best if best is None: return @@ -327,6 +328,8 @@ def on_cluster(up): if up.history == 'undo': # Revert to the given selection after an undo. self._selection = tuple(up.undo_state[0]['selection']) + elif up.added: + self.select((up.added[0],)) else: # Or move to the next selection after any other action. self.next() From 248ee089ba977da0ffbc1134ff493d73fe6d66ac Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 7 Oct 2015 22:41:46 +0200 Subject: [PATCH 0250/1059] WIP: wizard tests --- phy/cluster/manual/tests/test_wizard.py | 7 +++++++ phy/cluster/manual/wizard.py | 8 +++++++- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/phy/cluster/manual/tests/test_wizard.py b/phy/cluster/manual/tests/test_wizard.py index b52b1ce3d..d8f9b7005 100644 --- a/phy/cluster/manual/tests/test_wizard.py +++ b/phy/cluster/manual/tests/test_wizard.py @@ -247,3 +247,10 @@ def _action(**kwargs): assert w.selection == (1,) assert obj.emit('request_undo_state', {}) == [{'selection': (w.selection)}] + + w.select((1, 2)) + _action(description='metadata_group', metadata_changed=[1]) + assert w.selection == (1, 3) + + _action(description='metadata_group', metadata_changed=[3]) + assert w.selection == (2,) diff --git a/phy/cluster/manual/wizard.py b/phy/cluster/manual/wizard.py index 19b385c5b..69fcf1eb4 100644 --- a/phy/cluster/manual/wizard.py +++ b/phy/cluster/manual/wizard.py @@ -156,10 +156,10 @@ def __init__(self): self._get_cluster_ids = None self._cluster_status = lambda cluster: None self._next = None # Strategy function. + self._selection = () self.reset() def reset(self): - self._selection = () self._history = History(()) # Quality and status functions @@ -330,6 +330,12 @@ def on_cluster(up): self._selection = tuple(up.undo_state[0]['selection']) elif up.added: self.select((up.added[0],)) + elif up.description == 'metadata_group': + cluster = up.metadata_changed[0] + if cluster == self.best: + self.pin() + else: + self.next() else: # Or move to the next selection after any other action. self.next() From 096d4d7f466e504e27d23bb89e906cea85d91ac2 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 7 Oct 2015 23:24:53 +0200 Subject: [PATCH 0251/1059] WIP: wizard --- phy/cluster/manual/tests/test_wizard.py | 7 +++++-- phy/cluster/manual/wizard.py | 14 ++++++++++++++ 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/phy/cluster/manual/tests/test_wizard.py b/phy/cluster/manual/tests/test_wizard.py index d8f9b7005..317e47dc4 100644 --- a/phy/cluster/manual/tests/test_wizard.py +++ b/phy/cluster/manual/tests/test_wizard.py @@ -187,8 +187,11 @@ def strategy(selection, cluster_ids=None, quality=None, assert w.selection == (1,) -def test_wizard_strategy_groups(wizard_with_groups): - w = wizard_with_groups +def test_wizard_strategy_groups(mock_wizard, cluster_groups): + w = mock_wizard + + w.attach_cluster_groups(cluster_groups) + assert 101 in w.cluster_ids assert 105 in w.cluster_ids diff --git a/phy/cluster/manual/wizard.py b/phy/cluster/manual/wizard.py index 69fcf1eb4..a228aabcc 100644 --- a/phy/cluster/manual/wizard.py +++ b/phy/cluster/manual/wizard.py @@ -339,3 +339,17 @@ def on_cluster(up): else: # Or move to the next selection after any other action. self.next() + + def attach_cluster_groups(self, cluster_groups): + + def get_cluster_ids(): + return sorted(cluster_groups.keys()) + + self.set_cluster_ids_function(get_cluster_ids) + + @self.set_status_function + def status(cluster): + group = cluster_groups.get(cluster, None) + return _wizard_group(group) + + self.set_strategy_function(best_quality_strategy) From f800d3e270a36318092f98c974ba49dfc7111b8a Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 8 Oct 2015 09:47:31 +0200 Subject: [PATCH 0252/1059] Flakify --- phy/cluster/manual/tests/test_wizard.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/phy/cluster/manual/tests/test_wizard.py b/phy/cluster/manual/tests/test_wizard.py index 317e47dc4..6b17a14be 100644 --- a/phy/cluster/manual/tests/test_wizard.py +++ b/phy/cluster/manual/tests/test_wizard.py @@ -36,7 +36,7 @@ def test_argsort(): def test_sort(): clusters = [10, 0, 1, 30, 2, 20] - # N, i, g, N, N, N + # N, i, g, N, N, N status = lambda c: ('ignored', 'good')[c] if c <= 1 else None assert _sort(clusters, status=status) == [10, 30, 2, 20, 1, 0] @@ -56,7 +56,7 @@ def test_best_clusters(): def test_most_similar_clusters(): cluster_ids = [0, 1, 2, 3] - # i, g, N, i + # i, g, N, i similarity = lambda c, d: c + d status = lambda c: ('ignored', 'good', None, 'ignored')[c] @@ -84,7 +84,7 @@ def test_next_in_list(): def test_best_quality_strategy(): cluster_ids = [0, 1, 2, 3, 4] - # i, i, g, N, N + # i, i, g, N, N quality = lambda c: c status = lambda c: ('ignored', 'ignored', 'good')[c] if c <= 2 else None similarity = lambda c, d: c + d From 0091c7aa5c36cff0c1f01d13123257f3f68ec4c7 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 8 Oct 2015 10:26:27 +0200 Subject: [PATCH 0253/1059] WIP: refactoring wizard --- phy/cluster/manual/tests/test_wizard.py | 37 ++++++++++++++++--- phy/cluster/manual/wizard.py | 49 ++++++++++++++++++++++--- 2 files changed, 74 insertions(+), 12 deletions(-) diff --git a/phy/cluster/manual/tests/test_wizard.py b/phy/cluster/manual/tests/test_wizard.py index 6b17a14be..cfddc05cd 100644 --- a/phy/cluster/manual/tests/test_wizard.py +++ b/phy/cluster/manual/tests/test_wizard.py @@ -12,7 +12,8 @@ _best_clusters, _most_similar_clusters, _wizard_group, - best_quality_strategy, + _best_quality_strategy, + _best_similarity_strategy, ) from .._utils import UpdateInfo from phy.utils import EventEmitter @@ -90,11 +91,11 @@ def test_best_quality_strategy(): similarity = lambda c, d: c + d def _next(selection): - return best_quality_strategy(selection, - cluster_ids=cluster_ids, - quality=quality, - status=status, - similarity=similarity) + return _best_quality_strategy(selection, + cluster_ids=cluster_ids, + quality=quality, + status=status, + similarity=similarity) assert not _next(None) assert _next(()) == (4,) @@ -109,6 +110,30 @@ def _next(selection): assert _next((2, 3)) == (2, 3) +def test_best_similarity_strategy(): + cluster_ids = [0, 1, 2, 3, 4] + # i, i, g, N, N + quality = lambda c: c + status = lambda c: ('ignored', 'ignored', 'good')[c] if c <= 2 else None + similarity = lambda c, d: c * 1.1 + d + + def _next(selection): + return _best_similarity_strategy(selection, + cluster_ids=cluster_ids, + quality=quality, + status=status, + similarity=similarity) + + assert not _next(None) + assert _next(()) == (4, 3) + + assert _next((4, 3)) == (4, 2) + assert _next((4, 2)) == (3, 2) + + assert _next((3, 2)) == (3, 2) + assert _next((2, 3)) == (2, 3) + + def test_wizard_group(): assert _wizard_group('noise') == 'ignored' assert _wizard_group('mua') == 'ignored' diff --git a/phy/cluster/manual/wizard.py b/phy/cluster/manual/wizard.py index a228aabcc..6b435c338 100644 --- a/phy/cluster/manual/wizard.py +++ b/phy/cluster/manual/wizard.py @@ -6,6 +6,7 @@ #------------------------------------------------------------------------------ +from itertools import product import logging from operator import itemgetter @@ -87,11 +88,11 @@ def _wizard_group(group): # Strategy functions #------------------------------------------------------------------------------ -def best_quality_strategy(selection, - cluster_ids=None, - quality=None, - status=None, - similarity=None): +def _best_quality_strategy(selection, + cluster_ids=None, + quality=None, + status=None, + similarity=None): """Two cases depending on the number of selected clusters: * 1: move to the next best cluster @@ -130,6 +131,42 @@ def best_quality_strategy(selection, return (best, candidates[0]) +def _best_similarity_strategy(selection, + cluster_ids=None, + quality=None, + status=None, + similarity=None): + if selection is None: + return selection + selection = tuple(selection) + n = len(selection) + if n >= 2: + best, match = selection + value = similarity(best, match) + else: + best, match = None, None + value = None + # We remove the current pair, the (x, x) pairs, and we ensure that + # (d, c) doesn't appear if (c, d) does. We choose the pair where + # the first cluster of the pair has the highest quality. + # Finally we remove the ignored clusters. + s = [((c, d), similarity(c, d)) + for c, d in product(cluster_ids, repeat=2) + if c != d and (c, d) != (best, match) + and quality(c) >= quality(d) + and status(c) != 'ignored' + and status(d) != 'ignored' + ] + + if value is not None: + s = [((c, d), v) for ((c, d), v) in s if v <= value] + pairs = _argsort(s) + if pairs: + return pairs[0] + else: + return selection + + #------------------------------------------------------------------------------ # Wizard #------------------------------------------------------------------------------ @@ -352,4 +389,4 @@ def status(cluster): group = cluster_groups.get(cluster, None) return _wizard_group(group) - self.set_strategy_function(best_quality_strategy) + self.set_strategy_function(_best_quality_strategy) From 7c90e756712f2355ea2cc0eee3a42578201f6885 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 8 Oct 2015 10:53:29 +0200 Subject: [PATCH 0254/1059] WIP: move wizard attach functions --- phy/cluster/manual/tests/test_wizard.py | 97 ---------------------- phy/cluster/manual/wizard.py | 103 ++++++------------------ 2 files changed, 25 insertions(+), 175 deletions(-) diff --git a/phy/cluster/manual/tests/test_wizard.py b/phy/cluster/manual/tests/test_wizard.py index cfddc05cd..8d24c6a35 100644 --- a/phy/cluster/manual/tests/test_wizard.py +++ b/phy/cluster/manual/tests/test_wizard.py @@ -185,100 +185,3 @@ def test_wizard_nav(mock_wizard): for _ in range(2): w.next() assert w.selection == (1, 2) - - -def test_wizard_strategy_1(mock_wizard): - w = mock_wizard - - w.set_status_function(lambda x: None) - - def strategy(selection, cluster_ids=None, quality=None, - status=None, similarity=None): - """Return the next best cluster.""" - best_clusters = _best_clusters(cluster_ids, quality) - if not selection: - return (best_clusters[0],) - assert len(selection) == 1 - return (_next_in_list(best_clusters, selection[0]),) - - w.set_strategy_function(strategy) - assert w.selection == () - - for i in range(3, 0, -1): - w.next() - assert w.selection == (i,) - - w.next() - assert w.selection == (1,) - - -def test_wizard_strategy_groups(mock_wizard, cluster_groups): - w = mock_wizard - - w.attach_cluster_groups(cluster_groups) - - assert 101 in w.cluster_ids - assert 105 in w.cluster_ids - - w.pin() - assert w.selection == () - - for i in range(105, 100, -1): - assert w.next() == (i,) - - w.select([105]) - - w.pin() - assert w.selection == (105, 104) - - w.next() - assert w.selection == (105, 103) - - w.previous() - assert w.selection == (105, 104) - - w.unpin() - assert w.selection == (105,) - - @w.set_status_function - def status(cluster): - return 'ignored' - - assert w.pin() is None - - -def test_wizard_attach(mock_wizard): - w = mock_wizard - - def strategy(selection, **kwargs): - if not selection: - return (3,) - return (1 + (selection[0] % 3),) - - w.set_strategy_function(strategy) - w.select([3]) - - obj = EventEmitter() - w.attach(obj) - - def _action(**kwargs): - up = UpdateInfo(**kwargs) - return obj.emit('cluster', up) - - _action(description='merge', added=[3]) - assert w.selection == (3,) - - _action(history='undo', undo_state=[{'selection': w.selection}]) - assert w.selection == (3,) - - _action(history='redo') - assert w.selection == (1,) - - assert obj.emit('request_undo_state', {}) == [{'selection': (w.selection)}] - - w.select((1, 2)) - _action(description='metadata_group', metadata_changed=[1]) - assert w.selection == (1, 3) - - _action(description='metadata_group', metadata_changed=[3]) - assert w.selection == (2,) diff --git a/phy/cluster/manual/wizard.py b/phy/cluster/manual/wizard.py index 6b435c338..0034a2669 100644 --- a/phy/cluster/manual/wizard.py +++ b/phy/cluster/manual/wizard.py @@ -192,7 +192,6 @@ def __init__(self): self._quality = None self._get_cluster_ids = None self._cluster_status = lambda cluster: None - self._next = None # Strategy function. self._selection = () self.reset() @@ -234,22 +233,6 @@ def set_quality_function(self, func): self._quality = func return func - def set_strategy_function(self, func): - """Register a function returning a new selection after the current - selection, as a function of the quality and similarity of the clusters. - """ - # func(selection, cluster_ids=None, quality=None, similarity=None) - - def wrapped(sel): - return func(self._selection, - cluster_ids=self._get_cluster_ids(), - quality=self._quality, - status=self._cluster_status, - similarity=self._similarity, - ) - - self._next = wrapped - # Properties #-------------------------------------------------------------------------- @@ -266,20 +249,18 @@ def n_clusters(self): # Selection methods #-------------------------------------------------------------------------- - def _selection_changed(self, sel, add_to_history=True): - if sel is None: # pragma: no cover + def select(self, cluster_ids, add_to_history=True): + if cluster_ids is None: # pragma: no cover return - assert hasattr(sel, '__len__') + assert hasattr(cluster_ids, '__len__') clusters = self.cluster_ids - sel = tuple(cluster for cluster in sel if cluster in clusters) - self._selection = sel + cluster_ids = tuple(cluster for cluster in cluster_ids + if cluster in clusters) + self._selection = cluster_ids if add_to_history: self._history.add(self._selection) self.emit('select', self._selection) - def select(self, cluster_ids): - self._selection_changed(cluster_ids) - @property def selection(self): """Return the current cluster selection.""" @@ -321,10 +302,10 @@ def unpin(self): #-------------------------------------------------------------------------- def _set_selection_from_history(self): - sel = self._history.current_item - if not sel: # pragma: no cover + cluster_ids = self._history.current_item + if not cluster_ids: # pragma: no cover return - self._selection_changed(sel, add_to_history=False) + self.select(cluster_ids, add_to_history=False) def previous(self): if self._history.current_position <= 2: @@ -338,55 +319,21 @@ def next(self): # Go forward after a previous. self._history.forward() self._set_selection_from_history() - else: - if self._next: - # Or compute the next selection. - self.selection = self._next(self._selection) - else: - logger.debug("No strategy selected in the wizard.") - return self._selection - # Attach - #-------------------------------------------------------------------------- + def next_by_quality(self): + self.selection = _best_quality_strategy( + self._selection, + cluster_ids=self._get_cluster_ids(), + quality=self._quality, + status=self._cluster_status, + similarity=self._similarity) + return self._selection - def attach(self, obj): - """Attach an effector to the wizard.""" - - # Save the current selection when an action occurs. - @obj.connect - def on_request_undo_state(up): - return {'selection': self._selection} - - @obj.connect - def on_cluster(up): - if not up.history: - # Reset the history after every change. - self.reset() - if up.history == 'undo': - # Revert to the given selection after an undo. - self._selection = tuple(up.undo_state[0]['selection']) - elif up.added: - self.select((up.added[0],)) - elif up.description == 'metadata_group': - cluster = up.metadata_changed[0] - if cluster == self.best: - self.pin() - else: - self.next() - else: - # Or move to the next selection after any other action. - self.next() - - def attach_cluster_groups(self, cluster_groups): - - def get_cluster_ids(): - return sorted(cluster_groups.keys()) - - self.set_cluster_ids_function(get_cluster_ids) - - @self.set_status_function - def status(cluster): - group = cluster_groups.get(cluster, None) - return _wizard_group(group) - - self.set_strategy_function(_best_quality_strategy) + def next_by_similarity(self): + self.selection = _best_similarity_strategy( + self._selection, + cluster_ids=self._get_cluster_ids(), + quality=self._quality, + status=self._cluster_status, + similarity=self._similarity) + return self._selection From 8fed91a1be287984bfacfedb6b164cf12a5d8944 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 8 Oct 2015 11:17:44 +0200 Subject: [PATCH 0255/1059] WIP: refactor wizard tests --- phy/cluster/manual/tests/test_wizard.py | 112 +++++++++++------------- phy/cluster/manual/wizard.py | 6 +- 2 files changed, 54 insertions(+), 64 deletions(-) diff --git a/phy/cluster/manual/tests/test_wizard.py b/phy/cluster/manual/tests/test_wizard.py index 8d24c6a35..cfebbc87e 100644 --- a/phy/cluster/manual/tests/test_wizard.py +++ b/phy/cluster/manual/tests/test_wizard.py @@ -7,7 +7,7 @@ #------------------------------------------------------------------------------ from ..wizard import (_argsort, - _sort, + _sort_by_status, _next_in_list, _best_clusters, _most_similar_clusters, @@ -15,12 +15,10 @@ _best_quality_strategy, _best_similarity_strategy, ) -from .._utils import UpdateInfo -from phy.utils import EventEmitter #------------------------------------------------------------------------------ -# Test wizard +# Test utility functions #------------------------------------------------------------------------------ def test_argsort(): @@ -35,18 +33,23 @@ def test_argsort(): assert _argsort(l, reverse=False) == [1, 2, 3, 4] -def test_sort(): - clusters = [10, 0, 1, 30, 2, 20] - # N, i, g, N, N, N - status = lambda c: ('ignored', 'good')[c] if c <= 1 else None +def test_sort_by_status(status): + cluster_ids = [10, 0, 1, 30, 2, 20] + assert _sort_by_status(cluster_ids, status=status) == [10, 30, 2, 20, 1, 0] + assert _sort_by_status(cluster_ids, status=status, + remove_ignored=True) == [10, 30, 2, 20, 1] + - assert _sort(clusters, status=status) == [10, 30, 2, 20, 1, 0] - assert _sort(clusters, status=status, remove_ignored=True) == \ - [10, 30, 2, 20, 1] +def test_next_in_list(): + l = [1, 2, 3] + assert _next_in_list(l, 0) == 0 + assert _next_in_list(l, 1) == 2 + assert _next_in_list(l, 2) == 3 + assert _next_in_list(l, 3) == 3 + assert _next_in_list(l, 4) == 4 -def test_best_clusters(): - quality = lambda c: c * .1 +def test_best_clusters(quality): l = list(range(1, 5)) assert _best_clusters(l, quality) == [4, 3, 2, 1] assert _best_clusters(l, quality, n_max=0) == [4, 3, 2, 1] @@ -55,11 +58,7 @@ def test_best_clusters(): assert _best_clusters(l, quality, n_max=10) == [4, 3, 2, 1] -def test_most_similar_clusters(): - cluster_ids = [0, 1, 2, 3] - # i, g, N, i - similarity = lambda c, d: c + d - status = lambda c: ('ignored', 'good', None, 'ignored')[c] +def test_most_similar_clusters(cluster_ids, similarity, status): def _similar(cluster): return _most_similar_clusters(cluster, @@ -68,27 +67,17 @@ def _similar(cluster): status=status) assert not _similar(None) - assert not _similar(10) - assert _similar(0) == [2, 1] - assert _similar(1) == [2] - assert _similar(2) == [1] + assert not _similar(100) + assert _similar(0) == [30, 20, 10, 2, 1] + assert _similar(1) == [30, 20, 10, 2] + assert _similar(2) == [30, 20, 10, 1] -def test_next_in_list(): - l = [1, 2, 3] - assert _next_in_list(l, 0) == 0 - assert _next_in_list(l, 1) == 2 - assert _next_in_list(l, 2) == 3 - assert _next_in_list(l, 3) == 3 - assert _next_in_list(l, 4) == 4 - +#------------------------------------------------------------------------------ +# Test strategy functions +#------------------------------------------------------------------------------ -def test_best_quality_strategy(): - cluster_ids = [0, 1, 2, 3, 4] - # i, i, g, N, N - quality = lambda c: c - status = lambda c: ('ignored', 'ignored', 'good')[c] if c <= 2 else None - similarity = lambda c, d: c + d +def test_best_quality_strategy(cluster_ids, quality, status, similarity): def _next(selection): return _best_quality_strategy(selection, @@ -98,24 +87,17 @@ def _next(selection): similarity=similarity) assert not _next(None) - assert _next(()) == (4,) - - for i in range(4, -1, -1): - assert _next((i,)) == (max(0, i - 1),) - - assert _next((4, 3)) == (4, 2) - assert _next((4, 2)) == (4, 2) # 1 is ignored, so it does not appear. + assert _next(()) == (30,) + assert _next((30,)) == (20,) + assert _next((20,)) == (10,) + assert _next((10,)) == (2,) - assert _next((3, 2)) == (3, 2) - assert _next((2, 3)) == (2, 3) + assert _next((30, 20)) == (30, 10) + assert _next((10, 2)) == (10, 1) + assert _next((10, 1)) == (10, 1) # 0 is ignored, so it does not appear. -def test_best_similarity_strategy(): - cluster_ids = [0, 1, 2, 3, 4] - # i, i, g, N, N - quality = lambda c: c - status = lambda c: ('ignored', 'ignored', 'good')[c] if c <= 2 else None - similarity = lambda c, d: c * 1.1 + d +def test_best_similarity_strategy(cluster_ids, quality, status, similarity): def _next(selection): return _best_similarity_strategy(selection, @@ -125,14 +107,18 @@ def _next(selection): similarity=similarity) assert not _next(None) - assert _next(()) == (4, 3) + assert _next(()) == (30, 20) + assert _next((30, 20)) == (30, 10) + assert _next((30, 10)) == (30, 2) + assert _next((20, 10)) == (20, 2) + assert _next((10, 2)) == (10, 1) + assert _next((10, 1)) == (2, 1) + assert _next((2, 1)) == (2, 1) # 0 is ignored, so it does not appear. - assert _next((4, 3)) == (4, 2) - assert _next((4, 2)) == (3, 2) - - assert _next((3, 2)) == (3, 2) - assert _next((2, 3)) == (2, 3) +#------------------------------------------------------------------------------ +# Test wizard +#------------------------------------------------------------------------------ def test_wizard_group(): assert _wizard_group('noise') == 'ignored' @@ -142,10 +128,10 @@ def test_wizard_group(): assert _wizard_group(None) is None -def test_wizard_nav(mock_wizard): - w = mock_wizard - assert w.cluster_ids == [1, 2, 3] - assert w.n_clusters == 3 +def test_wizard_nav(wizard): + w = wizard + assert w.cluster_ids == [0, 1, 2, 10, 20, 30] + assert w.n_clusters == 6 assert w.selection == () @@ -185,3 +171,7 @@ def test_wizard_nav(mock_wizard): for _ in range(2): w.next() assert w.selection == (1, 2) + + +def test_wizard_pin(wizard): + w = wizard diff --git a/phy/cluster/manual/wizard.py b/phy/cluster/manual/wizard.py index 0034a2669..1c17878f6 100644 --- a/phy/cluster/manual/wizard.py +++ b/phy/cluster/manual/wizard.py @@ -40,7 +40,7 @@ def _next_in_list(l, item): return item -def _sort(clusters, status=None, remove_ignored=False): +def _sort_by_status(clusters, status=None, remove_ignored=False): """Sort clusters according to their status.""" assert status _sort_map = {None: 0, 'good': 1, 'ignored': 2} @@ -70,7 +70,7 @@ def _most_similar_clusters(cluster, cluster_ids=None, n_max=None, if less_than: s = [(c, v) for (c, v) in s if v <= less_than] clusters = _argsort(s, n_max=n_max) - return _sort(clusters, status=status, remove_ignored=True) + return _sort_by_status(clusters, status=status, remove_ignored=True) def _wizard_group(group): @@ -107,7 +107,7 @@ def _best_quality_strategy(selection, if n <= 1: best_clusters = _best_clusters(cluster_ids, quality) # Sort the best clusters according to their status. - best_clusters = _sort(best_clusters, status=status) + best_clusters = _sort_by_status(best_clusters, status=status) if selection: return (_next_in_list(best_clusters, selection[0]),) elif best_clusters: From ddf7926dedaaeae1d8e279189ba1574704cc2cb6 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 8 Oct 2015 11:28:25 +0200 Subject: [PATCH 0256/1059] WIP: refactor wizard tests --- phy/cluster/manual/tests/conftest.py | 61 ++++++++------------- phy/cluster/manual/tests/test_wizard.py | 72 ++++++++++++++++++++++++- 2 files changed, 92 insertions(+), 41 deletions(-) diff --git a/phy/cluster/manual/tests/conftest.py b/phy/cluster/manual/tests/conftest.py index e42ee686b..22c800ff2 100644 --- a/phy/cluster/manual/tests/conftest.py +++ b/phy/cluster/manual/tests/conftest.py @@ -8,8 +8,7 @@ from pytest import yield_fixture -from ..wizard import Wizard, _wizard_group, best_quality_strategy -from .._utils import create_cluster_meta +from ..wizard import Wizard #------------------------------------------------------------------------------ @@ -17,56 +16,38 @@ #------------------------------------------------------------------------------ @yield_fixture -def cluster_groups(): - data = {1: 'noise', - 2: 'mua', - 11: 'good', - 12: 'good', - 13: 'good', - 101: None, - 102: None, - 103: None, - 104: None, - 105: None, - } - yield data +def cluster_ids(): + yield [0, 1, 2, 10, 20, 30] + # i, g, N, N, N, N @yield_fixture -def cluster_meta(cluster_groups): - yield create_cluster_meta(cluster_groups) +def get_cluster_ids(cluster_ids): + yield lambda: cluster_ids @yield_fixture -def mock_wizard(): +def status(): + yield lambda c: ('ignored', 'good')[c] if c <= 1 else None - wizard = Wizard() - wizard.set_cluster_ids_function(lambda: [1, 2, 3]) - - @wizard.set_quality_function - def quality(cluster): - return cluster - - @wizard.set_similarity_function - def similarity(cluster, other): - return cluster + other - yield wizard +@yield_fixture +def quality(): + yield lambda c: c @yield_fixture -def wizard_with_groups(mock_wizard, cluster_groups): +def similarity(): + yield lambda c, d: c * 1.01 + d - def get_cluster_ids(): - return sorted(cluster_groups.keys()) - mock_wizard.set_cluster_ids_function(get_cluster_ids) - - @mock_wizard.set_status_function - def status(cluster): - group = cluster_groups.get(cluster, None) - return _wizard_group(group) +@yield_fixture +def wizard(get_cluster_ids, status, quality, similarity): + wizard = Wizard() - mock_wizard.set_strategy_function(best_quality_strategy) + wizard.set_cluster_ids_function(get_cluster_ids) + wizard.set_status_function(status) + wizard.set_quality_function(quality) + wizard.set_similarity_function(similarity) - yield mock_wizard + yield wizard diff --git a/phy/cluster/manual/tests/test_wizard.py b/phy/cluster/manual/tests/test_wizard.py index cfebbc87e..ea02a6a5e 100644 --- a/phy/cluster/manual/tests/test_wizard.py +++ b/phy/cluster/manual/tests/test_wizard.py @@ -173,5 +173,75 @@ def test_wizard_nav(wizard): assert w.selection == (1, 2) -def test_wizard_pin(wizard): +def test_wizard_pin_by_quality(wizard): w = wizard + + w.pin() + assert w.selection == () + + w.unpin() + assert w.selection == () + + w.next_by_quality() + assert w.selection == (30,) + + w.next_by_quality() + assert w.selection == (20,) + + w.pin() + assert w.selection == (20, 30) + + w.next_by_quality() + assert w.selection == (20, 10) + + w.unpin() + assert w.selection == (20,) + + w.next_by_quality() + assert w.selection == (10,) + + w.pin() + assert w.selection == (10, 30) + + w.next_by_quality() + assert w.selection == (10, 20) + + w.next_by_quality() + assert w.selection == (10, 2) + + +def test_wizard_pin_by_similarity(wizard): + w = wizard + + w.pin() + assert w.selection == () + + w.unpin() + assert w.selection == () + + w.next_by_similarity() + assert w.selection == (30, 20) + + w.next_by_similarity() + assert w.selection == (30, 10) + + w.pin() + assert w.selection == (30, 20) + + w.next_by_similarity() + assert w.selection == (30, 10) + + w.unpin() + assert w.selection == (30,) + + w.select((20, 10)) + assert w.selection == (20, 10) + + w.next_by_similarity() + assert w.selection == (20, 2) + + w.next_by_similarity() + assert w.selection == (20, 1) + + w.next_by_similarity() + assert w.selection == (10, 2) From 015ad443aa188fe2e38ed9259798e41b76966c63 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 8 Oct 2015 11:32:31 +0200 Subject: [PATCH 0257/1059] WIP: refactor wizard tests --- phy/cluster/manual/tests/conftest.py | 9 +++++++-- phy/cluster/manual/wizard.py | 2 +- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/phy/cluster/manual/tests/conftest.py b/phy/cluster/manual/tests/conftest.py index 22c800ff2..167c08107 100644 --- a/phy/cluster/manual/tests/conftest.py +++ b/phy/cluster/manual/tests/conftest.py @@ -27,8 +27,13 @@ def get_cluster_ids(cluster_ids): @yield_fixture -def status(): - yield lambda c: ('ignored', 'good')[c] if c <= 1 else None +def cluster_groups(): + yield {0: 'ignored', 1: 'good'} + + +@yield_fixture +def status(cluster_groups): + yield lambda c: cluster_groups.get(c, None) @yield_fixture diff --git a/phy/cluster/manual/wizard.py b/phy/cluster/manual/wizard.py index 1c17878f6..f03a99690 100644 --- a/phy/cluster/manual/wizard.py +++ b/phy/cluster/manual/wizard.py @@ -290,7 +290,7 @@ def pin(self): similarity=self._similarity, status=self._cluster_status) assert best not in candidates - if not candidates: + if not candidates: # pragma: no cover return self.select((self.best, candidates[0])) From e1ada4cd35d0edacdb7ea6ef45116b9be95080fa Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 8 Oct 2015 11:42:50 +0200 Subject: [PATCH 0258/1059] WIP: manual plugin tests --- phy/cluster/manual/gui_plugins.py | 78 +++++++++++-- phy/cluster/manual/tests/test_gui_plugins.py | 110 ++----------------- phy/cluster/manual/wizard.py | 4 + 3 files changed, 80 insertions(+), 112 deletions(-) diff --git a/phy/cluster/manual/gui_plugins.py b/phy/cluster/manual/gui_plugins.py index 50af39706..f36716b39 100644 --- a/phy/cluster/manual/gui_plugins.py +++ b/phy/cluster/manual/gui_plugins.py @@ -12,7 +12,7 @@ from ._history import GlobalHistory from ._utils import create_cluster_meta from .clustering import Clustering -from .wizard import Wizard +from .wizard import Wizard, _wizard_group from phy.gui.actions import Actions, Snippets from phy.io.array import select_spikes from phy.utils.plugin import IPlugin @@ -40,6 +40,65 @@ def _process_ups(ups): # pragma: no cover raise NotImplementedError() +# ----------------------------------------------------------------------------- +# Attach wizard to effectors (clustering and cluster_meta) +# ----------------------------------------------------------------------------- + +def _attach_wizard_to_effector(wizard, effector): + + # Save the current selection when an action occurs. + @effector.connect + def on_request_undo_state(up): + return {'selection': wizard._selection} + + @effector.connect + def on_cluster(up): + if not up.history: + # Reset the history after every change. + wizard.reset() + if up.history == 'undo': + # Revert to the given selection after an undo. + wizard.select(up.undo_state[0]['selection'], add_to_history=False) + + +def _attach_wizard_to_clustering(wizard, clustering): + _attach_wizard_to_effector(wizard, clustering) + + @wizard.set_cluster_ids_function + def get_cluster_ids(): + return clustering.cluster_ids + + @clustering.connect + def on_cluster(up): + if up.added and up.history != 'undo': + wizard.select((up.added[0],)) + wizard.pin() + + +def _attach_wizard_to_cluster_meta(wizard, cluster_meta): + _attach_wizard_to_effector(wizard, cluster_meta) + + @wizard.set_status_function + def status(cluster): + group = cluster_meta.get('group', cluster) + return _wizard_group(group) + + @cluster_meta.connect + def on_cluster(up): + if up.description == 'metadata_group' and up.history != 'undo': + cluster = up.metadata_changed[0] + wizard.select((cluster,)) + wizard.pin() + + +def _attach_wizard(wizard, clustering, cluster_meta): + @clustering.connect + def on_cluster(up): + # Set the cluster metadata of new clusters. + if up.added: + cluster_meta.set_from_descendants(up.descendants) + + # ----------------------------------------------------------------------------- # Clustering GUI plugins # ----------------------------------------------------------------------------- @@ -73,7 +132,7 @@ def attach_to_gui(self, gui, # Create the wizard and attach it to Clustering/ClusterMeta. self.wizard = Wizard() - self.wizard.attach(self.clustering, self.cluster_meta) + _attach_wizard(self.wizard, self.clustering, self.cluster_meta) @self.wizard.connect def on_select(cluster_ids): @@ -97,8 +156,8 @@ def cluster_ids(self): return self.clustering.cluster_ids def create_actions(self, gui): - actions = Actions() - snippets = Snippets() + self.actions = actions = Actions() + self.snippets = snippets = Snippets() # Create the default actions for the clustering GUI. @actions.connect @@ -107,11 +166,11 @@ def on_reset(): actions.add(callback=self.select, alias='c') # Wizard. - actions.add(callback=self.wizard.start, name='reset_wizard') - actions.add(callback=self.wizard.first) - actions.add(callback=self.wizard.last) + actions.add(callback=self.wizard.restart, name='reset_wizard') actions.add(callback=self.wizard.previous) actions.add(callback=self.wizard.next) + actions.add(callback=self.wizard.next_by_quality) + actions.add(callback=self.wizard.next_by_similarity) actions.add(callback=self.wizard.pin) actions.add(callback=self.wizard.unpin) @@ -127,14 +186,11 @@ def on_reset(): actions.attach(gui) actions.reset() - self.actions = actions - self.snippets = snippets - # Wizard-related actions # ------------------------------------------------------------------------- def select(self, cluster_ids): - self.wizard.selection = cluster_ids + self.wizard.select(cluster_ids) # Clustering actions # ------------------------------------------------------------------------- diff --git a/phy/cluster/manual/tests/test_gui_plugins.py b/phy/cluster/manual/tests/test_gui_plugins.py index b6699cad4..89f1b2e5b 100644 --- a/phy/cluster/manual/tests/test_gui_plugins.py +++ b/phy/cluster/manual/tests/test_gui_plugins.py @@ -7,9 +7,8 @@ #------------------------------------------------------------------------------ from pytest import yield_fixture -from numpy.testing import assert_array_equal as ae +import numpy as np -from .conftest import _set_test_wizard from phy.gui.tests.test_gui import gui # noqa @@ -18,12 +17,13 @@ #------------------------------------------------------------------------------ @yield_fixture # noqa -def manual_clustering(qtbot, gui, spike_clusters, cluster_groups): +def manual_clustering(qtbot, gui, cluster_ids, cluster_groups): + spike_clusters = np.array(cluster_ids) + mc = gui.attach('ManualClustering', spike_clusters=spike_clusters, cluster_groups=cluster_groups, ) - _set_test_wizard(mc.wizard) _s = [] @@ -32,108 +32,16 @@ def manual_clustering(qtbot, gui, spike_clusters, cluster_groups): def on_select(cluster_ids, spike_ids): _s.append((cluster_ids, spike_ids)) - def _assert_selection(*cluster_ids): # pragma: no cover + def assert_selection(*cluster_ids): # pragma: no cover assert _s[-1][0] == list(cluster_ids) if len(cluster_ids) >= 1: assert mc.wizard.best == cluster_ids[0] elif len(cluster_ids) >= 2: assert mc.wizard.match == cluster_ids[2] - mc._assert_selection = _assert_selection - - yield mc - - -def test_manual_clustering_wizard(manual_clustering): - actions = manual_clustering.actions - wizard = manual_clustering.wizard - _assert_selection = manual_clustering._assert_selection - - # Test cluster ids. - ae(manual_clustering.cluster_ids, [2, 3, 5, 7]) - - # Test select actions. - actions.select([]) - _assert_selection() - - # Test wizard actions. - actions.reset_wizard() - assert wizard.best_list == [3, 2, 7, 5] - _assert_selection(3) - - actions.next() - _assert_selection(2) - - actions.last() - _assert_selection(5) - - actions.next() - _assert_selection(5) - - actions.previous() - _assert_selection(7) - - actions.first() - _assert_selection(3) - - actions.previous() - _assert_selection(3) - - # Test pinning. - actions.pin() - assert wizard.match_list == [2, 7, 5] - _assert_selection(3, 2) - - wizard.next() - _assert_selection(3, 7) - - wizard.unpin() - _assert_selection(3) - - -def test_manual_clustering_actions(manual_clustering): - actions = manual_clustering.actions - wizard = manual_clustering.wizard - _assert_selection = manual_clustering._assert_selection - - # [3 , 2 , 7 , 5] - # [None, None, 'ignored', 'good'] - actions.reset_wizard() - actions.pin() - _assert_selection(3, 2) - - actions.merge() # 3 + 2 => 8 - # [8, 7, 5] - _assert_selection(8, 7) - - wizard.next() - _assert_selection(8, 5) - - actions.undo() - _assert_selection(3, 2) - - actions.redo() - _assert_selection(8, 7) - - actions.split([2, 3]) # => 9 - _assert_selection(9, 8) - - -def test_manual_clustering_group(manual_clustering): - actions = manual_clustering.actions - # wizard = manual_clustering.wizard - _assert_selection = manual_clustering._assert_selection - - actions.reset_wizard() - actions.pin() - _assert_selection(3, 2) - - # [3 , 2 , 7 , 5] - # [None, None, 'good', 'ignored'] - actions.move([3], 'good') + yield mc, assert_selection - # ['good', None, 'good', 'ignored'] - _assert_selection(7, 2) - actions.next() - _assert_selection(7, 3) +def test_manual_clustering_1(manual_clustering): + mc, ae = manual_clustering + print(mc.wizard.selection) diff --git a/phy/cluster/manual/wizard.py b/phy/cluster/manual/wizard.py index f03a99690..5fb8dd8cd 100644 --- a/phy/cluster/manual/wizard.py +++ b/phy/cluster/manual/wizard.py @@ -320,6 +320,10 @@ def next(self): self._history.forward() self._set_selection_from_history() + def restart(self): + self.select(()) + self.next_by_similarity() + def next_by_quality(self): self.selection = _best_quality_strategy( self._selection, From 2dbdcfc7a99bb4afff66cdc9596caa5f77165973 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 8 Oct 2015 13:05:53 +0200 Subject: [PATCH 0259/1059] Remove wizard.selection setter --- phy/cluster/manual/tests/test_wizard.py | 4 ++-- phy/cluster/manual/wizard.py | 14 +++++--------- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/phy/cluster/manual/tests/test_wizard.py b/phy/cluster/manual/tests/test_wizard.py index ea02a6a5e..3b3e908b0 100644 --- a/phy/cluster/manual/tests/test_wizard.py +++ b/phy/cluster/manual/tests/test_wizard.py @@ -136,14 +136,14 @@ def test_wizard_nav(wizard): assert w.selection == () ### - w.selection = [] + w.select([]) assert w.selection == () assert w.best is None assert w.match is None ### - w.selection = [1] + w.select([1]) assert w.selection == (1,) assert w.best == 1 diff --git a/phy/cluster/manual/wizard.py b/phy/cluster/manual/wizard.py index 5fb8dd8cd..c3f76bbfd 100644 --- a/phy/cluster/manual/wizard.py +++ b/phy/cluster/manual/wizard.py @@ -266,10 +266,6 @@ def selection(self): """Return the current cluster selection.""" return self._selection - @selection.setter - def selection(self, value): - self.select(value) - @property def best(self): """Currently-selected best cluster.""" @@ -296,7 +292,7 @@ def pin(self): def unpin(self): if len(self._selection) == 2: - self.selection = (self.selection[0],) + self.select((self.selection[0],)) # Navigation #-------------------------------------------------------------------------- @@ -325,19 +321,19 @@ def restart(self): self.next_by_similarity() def next_by_quality(self): - self.selection = _best_quality_strategy( + self.select(_best_quality_strategy( self._selection, cluster_ids=self._get_cluster_ids(), quality=self._quality, status=self._cluster_status, - similarity=self._similarity) + similarity=self._similarity)) return self._selection def next_by_similarity(self): - self.selection = _best_similarity_strategy( + self.select(_best_similarity_strategy( self._selection, cluster_ids=self._get_cluster_ids(), quality=self._quality, status=self._cluster_status, - similarity=self._similarity) + similarity=self._similarity)) return self._selection From 76cd44d8bc6440e2c3e3fcc8fb453030457a0463 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 8 Oct 2015 13:42:31 +0200 Subject: [PATCH 0260/1059] Add save_requested event in manual clustering plugin --- phy/cluster/manual/gui_plugins.py | 19 +++++++++++++++++++ phy/cluster/manual/wizard.py | 5 ++--- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/phy/cluster/manual/gui_plugins.py b/phy/cluster/manual/gui_plugins.py index f36716b39..01bf80b78 100644 --- a/phy/cluster/manual/gui_plugins.py +++ b/phy/cluster/manual/gui_plugins.py @@ -117,6 +117,20 @@ class ManualClustering(IPlugin): Other plugins can connect to that event. + Parameters + ---------- + + gui : GUI + spike_clusters : ndarray + cluster_groups : dictionary + n_spikes_max_per_cluster : int + + Events + ------ + + select(cluster_ids, spike_ids) + save_requested(spike_clusters, cluster_groups) + """ def attach_to_gui(self, gui, spike_clusters=None, @@ -215,3 +229,8 @@ def undo(self): def redo(self): self._global_history.redo() + + def save(self): + groups = {c: self.cluster_meta.get('group', c) + for c in self.cluster_ids} + self.emit('save_requested', self.clustering.spike_clusters, groups) diff --git a/phy/cluster/manual/wizard.py b/phy/cluster/manual/wizard.py index c3f76bbfd..beed32491 100644 --- a/phy/cluster/manual/wizard.py +++ b/phy/cluster/manual/wizard.py @@ -179,9 +179,8 @@ class Wizard(EventEmitter): provided by functions: cluster_ids, status (group), similarity, quality. * The wizard keeps track of the history of the selected clusters, but this history is cleared after every action that changes the state. - * The `next()` function proposes a new selection as a function of the - current selection only. - * There are two strategies: best-quality or best-similarity strategy. + * The `next_*()` functions propose a new selection as a function of the + current selection. TODO: cache expensive functions. From b1a9702cfd85010afb22683d955590c39278fb63 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 8 Oct 2015 13:56:54 +0200 Subject: [PATCH 0261/1059] WIP: test GUI plugin --- phy/cluster/manual/gui_plugins.py | 7 ++++++- phy/cluster/manual/tests/test_gui_plugins.py | 13 ++++++++++--- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/phy/cluster/manual/gui_plugins.py b/phy/cluster/manual/gui_plugins.py index 01bf80b78..2991ae9d2 100644 --- a/phy/cluster/manual/gui_plugins.py +++ b/phy/cluster/manual/gui_plugins.py @@ -9,6 +9,8 @@ import logging +import numpy as np + from ._history import GlobalHistory from ._utils import create_cluster_meta from .clustering import Clustering @@ -98,6 +100,9 @@ def on_cluster(up): if up.added: cluster_meta.set_from_descendants(up.descendants) + _attach_wizard_to_clustering(wizard, clustering) + _attach_wizard_to_cluster_meta(wizard, cluster_meta) + # ----------------------------------------------------------------------------- # Clustering GUI plugins @@ -156,7 +161,7 @@ def on_select(cluster_ids): The wizard is responsible for the notion of "selected clusters". """ - spike_ids = select_spikes(cluster_ids, + spike_ids = select_spikes(np.array(cluster_ids), n_spikes_max_per_cluster, self.clustering.spikes_per_cluster) gui.emit('select', cluster_ids, spike_ids) diff --git a/phy/cluster/manual/tests/test_gui_plugins.py b/phy/cluster/manual/tests/test_gui_plugins.py index 89f1b2e5b..45eefae1a 100644 --- a/phy/cluster/manual/tests/test_gui_plugins.py +++ b/phy/cluster/manual/tests/test_gui_plugins.py @@ -8,6 +8,7 @@ from pytest import yield_fixture import numpy as np +from numpy.testing import assert_array_equal as ae from phy.gui.tests.test_gui import gui # noqa @@ -33,7 +34,9 @@ def on_select(cluster_ids, spike_ids): _s.append((cluster_ids, spike_ids)) def assert_selection(*cluster_ids): # pragma: no cover - assert _s[-1][0] == list(cluster_ids) + if not _s: + return + assert _s[-1][0] == tuple(cluster_ids) if len(cluster_ids) >= 1: assert mc.wizard.best == cluster_ids[0] elif len(cluster_ids) >= 2: @@ -43,5 +46,9 @@ def assert_selection(*cluster_ids): # pragma: no cover def test_manual_clustering_1(manual_clustering): - mc, ae = manual_clustering - print(mc.wizard.selection) + mc, assert_selection = manual_clustering + assert_selection() + ae(mc.cluster_ids, [0, 1, 2, 10, 20, 30]) + + mc.select([0]) + assert_selection(0) From 5040c338393199bdb6b5be98e5faf2fd0ee34735 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 8 Oct 2015 14:25:17 +0200 Subject: [PATCH 0262/1059] WIP: test GUI plugin --- phy/cluster/manual/gui_plugins.py | 19 +++++++----- phy/cluster/manual/tests/test_gui_plugins.py | 32 ++++++++++++++++++-- phy/cluster/manual/wizard.py | 16 ++++++++++ 3 files changed, 57 insertions(+), 10 deletions(-) diff --git a/phy/cluster/manual/gui_plugins.py b/phy/cluster/manual/gui_plugins.py index 2991ae9d2..cfabca99f 100644 --- a/phy/cluster/manual/gui_plugins.py +++ b/phy/cluster/manual/gui_plugins.py @@ -170,10 +170,6 @@ def on_select(cluster_ids): return self - @property - def cluster_ids(self): - return self.clustering.cluster_ids - def create_actions(self, gui): self.actions = actions = Actions() self.snippets = snippets = Snippets() @@ -217,16 +213,22 @@ def select(self, cluster_ids): def merge(self, cluster_ids=None): if cluster_ids is None: cluster_ids = self.wizard.selection + if len(cluster_ids) <= 1: + return self.clustering.merge(cluster_ids) self._global_history.action(self.clustering) def split(self, spike_ids): + if len(spike_ids) == 0: + return # TODO: connect to request_split emitted by view self.clustering.split(spike_ids) self._global_history.action(self.clustering) - def move(self, clusters, group): - self.cluster_meta.set('group', clusters, group) + def move(self, cluster_ids, group): + if len(cluster_ids) == 0: + return + self.cluster_meta.set('group', cluster_ids, group) self._global_history.action(self.cluster_meta) def undo(self): @@ -236,6 +238,7 @@ def redo(self): self._global_history.redo() def save(self): + spike_clusters = self.clustering.spike_clusters groups = {c: self.cluster_meta.get('group', c) - for c in self.cluster_ids} - self.emit('save_requested', self.clustering.spike_clusters, groups) + for c in self.clustering.cluster_ids} + self.gui.emit('save_requested', spike_clusters, groups) diff --git a/phy/cluster/manual/tests/test_gui_plugins.py b/phy/cluster/manual/tests/test_gui_plugins.py index 45eefae1a..340c7c162 100644 --- a/phy/cluster/manual/tests/test_gui_plugins.py +++ b/phy/cluster/manual/tests/test_gui_plugins.py @@ -10,6 +10,7 @@ import numpy as np from numpy.testing import assert_array_equal as ae +from ..gui_plugins import _attach_wizard from phy.gui.tests.test_gui import gui # noqa @@ -17,6 +18,11 @@ # Test GUI plugins #------------------------------------------------------------------------------ +def test_attach_wizard(): + # TODO + pass + + @yield_fixture # noqa def manual_clustering(qtbot, gui, cluster_ids, cluster_groups): spike_clusters = np.array(cluster_ids) @@ -45,10 +51,32 @@ def assert_selection(*cluster_ids): # pragma: no cover yield mc, assert_selection -def test_manual_clustering_1(manual_clustering): +def test_manual_clustering_edge_cases(manual_clustering): mc, assert_selection = manual_clustering assert_selection() - ae(mc.cluster_ids, [0, 1, 2, 10, 20, 30]) + ae(mc.clustering.cluster_ids, [0, 1, 2, 10, 20, 30]) mc.select([0]) assert_selection(0) + + mc.undo() + mc.redo() + + # Merge. + mc.merge() + assert_selection(0) + + mc.merge([]) + assert_selection(0) + + mc.merge([10]) + assert_selection(0) + + # Split. + mc.split([]) + assert_selection(0) + + # Move. + mc.move([], 'ignored') + + mc.save() diff --git a/phy/cluster/manual/wizard.py b/phy/cluster/manual/wizard.py index beed32491..9395547cf 100644 --- a/phy/cluster/manual/wizard.py +++ b/phy/cluster/manual/wizard.py @@ -280,6 +280,7 @@ def pin(self): best = self.best if best is None: return + self._check_functions() candidates = _most_similar_clusters(best, cluster_ids=self.cluster_ids, similarity=self._similarity, @@ -319,7 +320,21 @@ def restart(self): self.select(()) self.next_by_similarity() + def _check_functions(self): + if not self._get_cluster_ids: + raise RuntimeError("The cluster_ids function must be set.") + if not self._cluster_status: + logger.warn("A cluster status function has not been set.") + self._cluster_status = lambda c: None + if not self._quality: + logger.warn("A cluster quality function has not been set.") + self._quality = lambda c: 0 + if not self._similarity: + logger.warn("A cluster similarity function has not been set.") + self._similarity = lambda c, d: 0 + def next_by_quality(self): + self._check_functions() self.select(_best_quality_strategy( self._selection, cluster_ids=self._get_cluster_ids(), @@ -329,6 +344,7 @@ def next_by_quality(self): return self._selection def next_by_similarity(self): + self._check_functions() self.select(_best_similarity_strategy( self._selection, cluster_ids=self._get_cluster_ids(), From 5231c3321c39239a85ea206a39e0e85d5b577fc0 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 8 Oct 2015 14:29:24 +0200 Subject: [PATCH 0263/1059] Increase coverage in wizard --- phy/cluster/manual/tests/test_wizard.py | 13 +++++++++++++ phy/cluster/manual/wizard.py | 4 +++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/phy/cluster/manual/tests/test_wizard.py b/phy/cluster/manual/tests/test_wizard.py index 3b3e908b0..eb72a3281 100644 --- a/phy/cluster/manual/tests/test_wizard.py +++ b/phy/cluster/manual/tests/test_wizard.py @@ -6,6 +6,8 @@ # Imports #------------------------------------------------------------------------------ +from pytest import raises + from ..wizard import (_argsort, _sort_by_status, _next_in_list, @@ -14,6 +16,7 @@ _wizard_group, _best_quality_strategy, _best_similarity_strategy, + Wizard, ) @@ -120,6 +123,16 @@ def _next(selection): # Test wizard #------------------------------------------------------------------------------ +def test_wizard_empty(): + wizard = Wizard() + with raises(RuntimeError): + wizard.restart() + + wizard = Wizard() + wizard.set_cluster_ids_function(lambda: []) + wizard.restart() + + def test_wizard_group(): assert _wizard_group('noise') == 'ignored' assert _wizard_group('mua') == 'ignored' diff --git a/phy/cluster/manual/wizard.py b/phy/cluster/manual/wizard.py index 9395547cf..fcdc61e4e 100644 --- a/phy/cluster/manual/wizard.py +++ b/phy/cluster/manual/wizard.py @@ -190,7 +190,7 @@ def __init__(self): self._similarity = None self._quality = None self._get_cluster_ids = None - self._cluster_status = lambda cluster: None + self._cluster_status = None self._selection = () self.reset() @@ -238,6 +238,8 @@ def set_quality_function(self, func): @property def cluster_ids(self): """Array of cluster ids in the current clustering.""" + if not self._get_cluster_ids: + return [] return sorted(self._get_cluster_ids()) @property From e361fd7179f3d2fb10185a4cb51aad7476b4f103 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 8 Oct 2015 14:47:52 +0200 Subject: [PATCH 0264/1059] WIP: test GUI plugin --- phy/cluster/manual/gui_plugins.py | 3 ++- phy/cluster/manual/tests/test_gui_plugins.py | 10 ++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/phy/cluster/manual/gui_plugins.py b/phy/cluster/manual/gui_plugins.py index cfabca99f..efb22ee46 100644 --- a/phy/cluster/manual/gui_plugins.py +++ b/phy/cluster/manual/gui_plugins.py @@ -57,6 +57,7 @@ def on_request_undo_state(up): def on_cluster(up): if not up.history: # Reset the history after every change. + # That's because the history contains references to dead clusters. wizard.reset() if up.history == 'undo': # Revert to the given selection after an undo. @@ -215,7 +216,7 @@ def merge(self, cluster_ids=None): cluster_ids = self.wizard.selection if len(cluster_ids) <= 1: return - self.clustering.merge(cluster_ids) + self.clustering.merge(list(cluster_ids)) self._global_history.action(self.clustering) def split(self, spike_ids): diff --git a/phy/cluster/manual/tests/test_gui_plugins.py b/phy/cluster/manual/tests/test_gui_plugins.py index 340c7c162..ed6b11085 100644 --- a/phy/cluster/manual/tests/test_gui_plugins.py +++ b/phy/cluster/manual/tests/test_gui_plugins.py @@ -53,6 +53,8 @@ def assert_selection(*cluster_ids): # pragma: no cover def test_manual_clustering_edge_cases(manual_clustering): mc, assert_selection = manual_clustering + + # Empty selection at first. assert_selection() ae(mc.clustering.cluster_ids, [0, 1, 2, 10, 20, 30]) @@ -80,3 +82,11 @@ def test_manual_clustering_edge_cases(manual_clustering): mc.move([], 'ignored') mc.save() + + +def test_manual_clustering_merge(manual_clustering): + mc, assert_selection = manual_clustering + + mc.actions.select([30, 20]) + mc.actions.merge() + # assert_selection(31, 10) From 931ef72d847dc4c1bbc98ceeb6ccac4052a88b68 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 8 Oct 2015 14:56:55 +0200 Subject: [PATCH 0265/1059] WIP: test attach wizard --- phy/cluster/manual/tests/test_gui_plugins.py | 27 +++++++++++++++----- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/phy/cluster/manual/tests/test_gui_plugins.py b/phy/cluster/manual/tests/test_gui_plugins.py index ed6b11085..b9ea61958 100644 --- a/phy/cluster/manual/tests/test_gui_plugins.py +++ b/phy/cluster/manual/tests/test_gui_plugins.py @@ -10,19 +10,18 @@ import numpy as np from numpy.testing import assert_array_equal as ae -from ..gui_plugins import _attach_wizard +from ..clustering import Clustering +from ..gui_plugins import (_attach_wizard, + _attach_wizard_to_clustering, + _attach_wizard_to_cluster_meta, + ) from phy.gui.tests.test_gui import gui # noqa #------------------------------------------------------------------------------ -# Test GUI plugins +# Fixtures #------------------------------------------------------------------------------ -def test_attach_wizard(): - # TODO - pass - - @yield_fixture # noqa def manual_clustering(qtbot, gui, cluster_ids, cluster_groups): spike_clusters = np.array(cluster_ids) @@ -51,6 +50,20 @@ def assert_selection(*cluster_ids): # pragma: no cover yield mc, assert_selection +#------------------------------------------------------------------------------ +# Test GUI plugins +#------------------------------------------------------------------------------ + +def test_attach_wizard_to_clustering(wizard, cluster_ids): + clustering = Clustering(np.array(cluster_ids)) + _attach_wizard_to_clustering(wizard, clustering) + + assert wizard.selection == () + + clustering.merge([30, 20]) + assert wizard.selection == (31, 10) + + def test_manual_clustering_edge_cases(manual_clustering): mc, assert_selection = manual_clustering From 49b2be25822b61eccf9eb5b67224d58176893df3 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 8 Oct 2015 15:05:07 +0200 Subject: [PATCH 0266/1059] Move from tuples to lists in wizard.selection --- phy/cluster/manual/gui_plugins.py | 6 +- phy/cluster/manual/tests/test_gui_plugins.py | 6 +- phy/cluster/manual/tests/test_wizard.py | 90 ++++++++++---------- phy/cluster/manual/wizard.py | 28 +++--- 4 files changed, 64 insertions(+), 66 deletions(-) diff --git a/phy/cluster/manual/gui_plugins.py b/phy/cluster/manual/gui_plugins.py index efb22ee46..168c13736 100644 --- a/phy/cluster/manual/gui_plugins.py +++ b/phy/cluster/manual/gui_plugins.py @@ -74,7 +74,7 @@ def get_cluster_ids(): @clustering.connect def on_cluster(up): if up.added and up.history != 'undo': - wizard.select((up.added[0],)) + wizard.select([up.added[0]]) wizard.pin() @@ -90,7 +90,7 @@ def status(cluster): def on_cluster(up): if up.description == 'metadata_group' and up.history != 'undo': cluster = up.metadata_changed[0] - wizard.select((cluster,)) + wizard.select([cluster]) wizard.pin() @@ -216,7 +216,7 @@ def merge(self, cluster_ids=None): cluster_ids = self.wizard.selection if len(cluster_ids) <= 1: return - self.clustering.merge(list(cluster_ids)) + self.clustering.merge(cluster_ids) self._global_history.action(self.clustering) def split(self, spike_ids): diff --git a/phy/cluster/manual/tests/test_gui_plugins.py b/phy/cluster/manual/tests/test_gui_plugins.py index b9ea61958..34c610fcf 100644 --- a/phy/cluster/manual/tests/test_gui_plugins.py +++ b/phy/cluster/manual/tests/test_gui_plugins.py @@ -41,7 +41,7 @@ def on_select(cluster_ids, spike_ids): def assert_selection(*cluster_ids): # pragma: no cover if not _s: return - assert _s[-1][0] == tuple(cluster_ids) + assert _s[-1][0] == list(cluster_ids) if len(cluster_ids) >= 1: assert mc.wizard.best == cluster_ids[0] elif len(cluster_ids) >= 2: @@ -58,10 +58,10 @@ def test_attach_wizard_to_clustering(wizard, cluster_ids): clustering = Clustering(np.array(cluster_ids)) _attach_wizard_to_clustering(wizard, clustering) - assert wizard.selection == () + assert wizard.selection == [] clustering.merge([30, 20]) - assert wizard.selection == (31, 10) + assert wizard.selection == [31, 10] def test_manual_clustering_edge_cases(manual_clustering): diff --git a/phy/cluster/manual/tests/test_wizard.py b/phy/cluster/manual/tests/test_wizard.py index eb72a3281..26ca48588 100644 --- a/phy/cluster/manual/tests/test_wizard.py +++ b/phy/cluster/manual/tests/test_wizard.py @@ -90,14 +90,14 @@ def _next(selection): similarity=similarity) assert not _next(None) - assert _next(()) == (30,) - assert _next((30,)) == (20,) - assert _next((20,)) == (10,) - assert _next((10,)) == (2,) + assert _next([]) == [30] + assert _next([30]) == [20] + assert _next([20]) == [10] + assert _next([10]) == [2] - assert _next((30, 20)) == (30, 10) - assert _next((10, 2)) == (10, 1) - assert _next((10, 1)) == (10, 1) # 0 is ignored, so it does not appear. + assert _next([30, 20]) == [30, 10] + assert _next([10, 2]) == [10, 1] + assert _next([10, 1]) == [10, 1] # 0 is ignored, so it does not appear. def test_best_similarity_strategy(cluster_ids, quality, status, similarity): @@ -110,13 +110,13 @@ def _next(selection): similarity=similarity) assert not _next(None) - assert _next(()) == (30, 20) - assert _next((30, 20)) == (30, 10) - assert _next((30, 10)) == (30, 2) - assert _next((20, 10)) == (20, 2) - assert _next((10, 2)) == (10, 1) - assert _next((10, 1)) == (2, 1) - assert _next((2, 1)) == (2, 1) # 0 is ignored, so it does not appear. + assert _next([]) == [30, 20] + assert _next([30, 20]) == [30, 10] + assert _next([30, 10]) == [30, 2] + assert _next([20, 10]) == [20, 2] + assert _next([10, 2]) == [10, 1] + assert _next([10, 1]) == [2, 1] + assert _next([2, 1]) == [2, 1] # 0 is ignored, so it does not appear. #------------------------------------------------------------------------------ @@ -146,115 +146,115 @@ def test_wizard_nav(wizard): assert w.cluster_ids == [0, 1, 2, 10, 20, 30] assert w.n_clusters == 6 - assert w.selection == () + assert w.selection == [] ### w.select([]) - assert w.selection == () + assert w.selection == [] assert w.best is None assert w.match is None ### w.select([1]) - assert w.selection == (1,) + assert w.selection == [1] assert w.best == 1 assert w.match is None ### w.select([1, 2, 4]) - assert w.selection == (1, 2) + assert w.selection == [1, 2] assert w.best == 1 assert w.match == 2 ### w.previous() - assert w.selection == (1,) + assert w.selection == [1] for _ in range(2): w.previous() - assert w.selection == (1,) + assert w.selection == [1] ### w.next() - assert w.selection == (1, 2) + assert w.selection == [1, 2] for _ in range(2): w.next() - assert w.selection == (1, 2) + assert w.selection == [1, 2] def test_wizard_pin_by_quality(wizard): w = wizard w.pin() - assert w.selection == () + assert w.selection == [] w.unpin() - assert w.selection == () + assert w.selection == [] w.next_by_quality() - assert w.selection == (30,) + assert w.selection == [30] w.next_by_quality() - assert w.selection == (20,) + assert w.selection == [20] w.pin() - assert w.selection == (20, 30) + assert w.selection == [20, 30] w.next_by_quality() - assert w.selection == (20, 10) + assert w.selection == [20, 10] w.unpin() - assert w.selection == (20,) + assert w.selection == [20] w.next_by_quality() - assert w.selection == (10,) + assert w.selection == [10] w.pin() - assert w.selection == (10, 30) + assert w.selection == [10, 30] w.next_by_quality() - assert w.selection == (10, 20) + assert w.selection == [10, 20] w.next_by_quality() - assert w.selection == (10, 2) + assert w.selection == [10, 2] def test_wizard_pin_by_similarity(wizard): w = wizard w.pin() - assert w.selection == () + assert w.selection == [] w.unpin() - assert w.selection == () + assert w.selection == [] w.next_by_similarity() - assert w.selection == (30, 20) + assert w.selection == [30, 20] w.next_by_similarity() - assert w.selection == (30, 10) + assert w.selection == [30, 10] w.pin() - assert w.selection == (30, 20) + assert w.selection == [30, 20] w.next_by_similarity() - assert w.selection == (30, 10) + assert w.selection == [30, 10] w.unpin() - assert w.selection == (30,) + assert w.selection == [30] - w.select((20, 10)) - assert w.selection == (20, 10) + w.select([20, 10]) + assert w.selection == [20, 10] w.next_by_similarity() - assert w.selection == (20, 2) + assert w.selection == [20, 2] w.next_by_similarity() - assert w.selection == (20, 1) + assert w.selection == [20, 1] w.next_by_similarity() - assert w.selection == (10, 2) + assert w.selection == [10, 2] diff --git a/phy/cluster/manual/wizard.py b/phy/cluster/manual/wizard.py index fcdc61e4e..7ad126cfe 100644 --- a/phy/cluster/manual/wizard.py +++ b/phy/cluster/manual/wizard.py @@ -13,7 +13,7 @@ from six import string_types from ._history import History -from phy.utils import EventEmitter +from phy.utils import EventEmitter, _is_array_like logger = logging.getLogger(__name__) @@ -102,16 +102,15 @@ def _best_quality_strategy(selection, """ if selection is None: return selection - selection = tuple(selection) n = len(selection) if n <= 1: best_clusters = _best_clusters(cluster_ids, quality) # Sort the best clusters according to their status. best_clusters = _sort_by_status(best_clusters, status=status) if selection: - return (_next_in_list(best_clusters, selection[0]),) + return [_next_in_list(best_clusters, selection[0])] elif best_clusters: - return (best_clusters[0],) + return [best_clusters[0]] else: # pragma: no cover return selection elif n == 2: @@ -128,7 +127,7 @@ def _best_quality_strategy(selection, candidates.remove(match) if not candidates: return selection - return (best, candidates[0]) + return [best, candidates[0]] def _best_similarity_strategy(selection, @@ -138,7 +137,6 @@ def _best_similarity_strategy(selection, similarity=None): if selection is None: return selection - selection = tuple(selection) n = len(selection) if n >= 2: best, match = selection @@ -162,7 +160,7 @@ def _best_similarity_strategy(selection, s = [((c, d), v) for ((c, d), v) in s if v <= value] pairs = _argsort(s) if pairs: - return pairs[0] + return list(pairs[0]) else: return selection @@ -191,11 +189,11 @@ def __init__(self): self._quality = None self._get_cluster_ids = None self._cluster_status = None - self._selection = () + self._selection = [] self.reset() def reset(self): - self._history = History(()) + self._history = History([]) # Quality and status functions #-------------------------------------------------------------------------- @@ -253,10 +251,10 @@ def n_clusters(self): def select(self, cluster_ids, add_to_history=True): if cluster_ids is None: # pragma: no cover return - assert hasattr(cluster_ids, '__len__') + assert _is_array_like(cluster_ids) clusters = self.cluster_ids - cluster_ids = tuple(cluster for cluster in cluster_ids - if cluster in clusters) + cluster_ids = [cluster for cluster in cluster_ids + if cluster in clusters] self._selection = cluster_ids if add_to_history: self._history.add(self._selection) @@ -290,11 +288,11 @@ def pin(self): assert best not in candidates if not candidates: # pragma: no cover return - self.select((self.best, candidates[0])) + self.select([self.best, candidates[0]]) def unpin(self): if len(self._selection) == 2: - self.select((self.selection[0],)) + self.select([self.selection[0]]) # Navigation #-------------------------------------------------------------------------- @@ -319,7 +317,7 @@ def next(self): self._set_selection_from_history() def restart(self): - self.select(()) + self.select([]) self.next_by_similarity() def _check_functions(self): From db711f50ff7e5240ce44bd00dee96d90f9c3524d Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 8 Oct 2015 15:13:21 +0200 Subject: [PATCH 0267/1059] WIP: test attach wizard --- phy/cluster/manual/clustering.py | 2 +- phy/cluster/manual/gui_plugins.py | 6 ++++ phy/cluster/manual/tests/test_gui_plugins.py | 33 +++++++++++++++++++- 3 files changed, 39 insertions(+), 2 deletions(-) diff --git a/phy/cluster/manual/clustering.py b/phy/cluster/manual/clustering.py index 851901546..883f69f3e 100644 --- a/phy/cluster/manual/clustering.py +++ b/phy/cluster/manual/clustering.py @@ -404,7 +404,7 @@ def assign(self, spike_ids, spike_clusters_rel=0): return UpdateInfo() assert len(spike_ids) == len(spike_clusters_rel) assert spike_ids.min() >= 0 - assert spike_ids.max() < self._n_spikes + assert spike_ids.max() < self._n_spikes, "Some spikes don't exist." # Normalize the spike-cluster assignment such that # there are only new or dead clusters, not modified clusters. diff --git a/phy/cluster/manual/gui_plugins.py b/phy/cluster/manual/gui_plugins.py index 168c13736..a7e23ee28 100644 --- a/phy/cluster/manual/gui_plugins.py +++ b/phy/cluster/manual/gui_plugins.py @@ -75,6 +75,12 @@ def get_cluster_ids(): def on_cluster(up): if up.added and up.history != 'undo': wizard.select([up.added[0]]) + # NOTE: after a merge, select the merged one AND the most similar. + # There is an ambiguity after a merge: does the merge occurs during + # a wizard session, in which case we want to pin the merged + # cluster? If it is just a "cold" merge, then we might not want + # to pin the merged cluster. But cold merges are supposed to be + # less frequent than wizard merges. wizard.pin() diff --git a/phy/cluster/manual/tests/test_gui_plugins.py b/phy/cluster/manual/tests/test_gui_plugins.py index 34c610fcf..7e469fd7c 100644 --- a/phy/cluster/manual/tests/test_gui_plugins.py +++ b/phy/cluster/manual/tests/test_gui_plugins.py @@ -54,16 +54,47 @@ def assert_selection(*cluster_ids): # pragma: no cover # Test GUI plugins #------------------------------------------------------------------------------ -def test_attach_wizard_to_clustering(wizard, cluster_ids): +def test_attach_wizard_to_clustering_merge(wizard, cluster_ids): clustering = Clustering(np.array(cluster_ids)) _attach_wizard_to_clustering(wizard, clustering) assert wizard.selection == [] + wizard.select([30, 20, 10]) + assert wizard.selection == [30, 20, 10] + clustering.merge([30, 20]) + # Select the merged cluster along with its most similar one (=pin merged). + assert wizard.selection == [31, 10] + + # Undo: the previous selection reappears. + clustering.undo() + assert wizard.selection == [30, 20, 10] + + # Redo. + clustering.redo() assert wizard.selection == [31, 10] +def test_attach_wizard_to_clustering_split(wizard, cluster_ids): + clustering = Clustering(np.array(cluster_ids)) + _attach_wizard_to_clustering(wizard, clustering) + + wizard.select([30, 20, 10]) + assert wizard.selection == [30, 20, 10] + + clustering.split([5, 3]) + assert wizard.selection == [31, 20] + + # Undo: the previous selection reappears. + clustering.undo() + assert wizard.selection == [30, 20, 10] + + # Redo. + clustering.redo() + assert wizard.selection == [31, 20] + + def test_manual_clustering_edge_cases(manual_clustering): mc, assert_selection = manual_clustering From 95c2f0b40a86e9dfff7e791abecdf5d44385bb4d Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 8 Oct 2015 15:35:42 +0200 Subject: [PATCH 0268/1059] WIP: test wizard attach to cluster meta --- phy/cluster/manual/tests/test_gui_plugins.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/phy/cluster/manual/tests/test_gui_plugins.py b/phy/cluster/manual/tests/test_gui_plugins.py index 7e469fd7c..63d169e28 100644 --- a/phy/cluster/manual/tests/test_gui_plugins.py +++ b/phy/cluster/manual/tests/test_gui_plugins.py @@ -11,6 +11,7 @@ from numpy.testing import assert_array_equal as ae from ..clustering import Clustering +from .._utils import create_cluster_meta from ..gui_plugins import (_attach_wizard, _attach_wizard_to_clustering, _attach_wizard_to_cluster_meta, @@ -51,7 +52,7 @@ def assert_selection(*cluster_ids): # pragma: no cover #------------------------------------------------------------------------------ -# Test GUI plugins +# Test wizard attach #------------------------------------------------------------------------------ def test_attach_wizard_to_clustering_merge(wizard, cluster_ids): @@ -95,6 +96,23 @@ def test_attach_wizard_to_clustering_split(wizard, cluster_ids): assert wizard.selection == [31, 20] +def test_attach_wizard_to_cluster_meta(wizard, cluster_groups): + cluster_meta = create_cluster_meta(cluster_groups) + _attach_wizard_to_cluster_meta(wizard, cluster_meta) + + wizard.select([30]) + + wizard.select([20]) + assert wizard.selection == [20] + + cluster_meta.set('group', [20], 'noise') + # assert wizard.selection == [10] + + +#------------------------------------------------------------------------------ +# Test GUI plugins +#------------------------------------------------------------------------------ + def test_manual_clustering_edge_cases(manual_clustering): mc, assert_selection = manual_clustering From c4e0ca2bc610570bf5ada5abacb4d3bbdc24ca0b Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 8 Oct 2015 15:36:50 +0200 Subject: [PATCH 0269/1059] WIP: test attach wizard --- phy/cluster/manual/tests/test_gui_plugins.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/phy/cluster/manual/tests/test_gui_plugins.py b/phy/cluster/manual/tests/test_gui_plugins.py index 63d169e28..b1d323a5e 100644 --- a/phy/cluster/manual/tests/test_gui_plugins.py +++ b/phy/cluster/manual/tests/test_gui_plugins.py @@ -109,6 +109,12 @@ def test_attach_wizard_to_cluster_meta(wizard, cluster_groups): # assert wizard.selection == [10] +def test_attach_wizard(wizard, cluster_ids, cluster_groups): + clustering = Clustering(np.array(cluster_ids)) + cluster_meta = create_cluster_meta(cluster_groups) + _attach_wizard(wizard, clustering, cluster_meta) + + #------------------------------------------------------------------------------ # Test GUI plugins #------------------------------------------------------------------------------ From b96c2ea88c6288f02b630fcb3451764b15973579 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 8 Oct 2015 16:00:32 +0200 Subject: [PATCH 0270/1059] Refactor wizard next methods --- phy/cluster/manual/tests/test_wizard.py | 19 +++++++++++++ phy/cluster/manual/wizard.py | 37 +++++++++++++++---------- 2 files changed, 41 insertions(+), 15 deletions(-) diff --git a/phy/cluster/manual/tests/test_wizard.py b/phy/cluster/manual/tests/test_wizard.py index 26ca48588..10fe37eae 100644 --- a/phy/cluster/manual/tests/test_wizard.py +++ b/phy/cluster/manual/tests/test_wizard.py @@ -186,6 +186,25 @@ def test_wizard_nav(wizard): assert w.selection == [1, 2] +def test_wizard_next(wizard, status): + w = wizard + + assert w.next_selection([30]) == [20] + assert w.next_selection([30], ignore_group=True) == [20] + + assert w.next_selection([1]) == [0] + assert w.next_selection([1], ignore_group=True) == [0] + + @w.set_status_function + def status_bis(cluster): + if cluster == 30: + return 'ignored' + return status(cluster) + + assert w.next_selection([30]) == [0] + assert w.next_selection([30], ignore_group=True) == [20] + + def test_wizard_pin_by_quality(wizard): w = wizard diff --git a/phy/cluster/manual/wizard.py b/phy/cluster/manual/wizard.py index 7ad126cfe..4e129ded0 100644 --- a/phy/cluster/manual/wizard.py +++ b/phy/cluster/manual/wizard.py @@ -333,22 +333,29 @@ def _check_functions(self): logger.warn("A cluster similarity function has not been set.") self._similarity = lambda c, d: 0 - def next_by_quality(self): + def next_selection(self, cluster_ids=None, + strategy=None, + ignore_group=False): + cluster_ids = cluster_ids or self._selection + strategy = strategy or _best_quality_strategy + if ignore_group: + # Ignore the status of the selected clusters. + def status(cluster): + if cluster in cluster_ids: + return None + return self._cluster_status(cluster) + else: + status = self._cluster_status self._check_functions() - self.select(_best_quality_strategy( - self._selection, - cluster_ids=self._get_cluster_ids(), - quality=self._quality, - status=self._cluster_status, - similarity=self._similarity)) + self.select(strategy(cluster_ids, + cluster_ids=self._get_cluster_ids(), + quality=self._quality, + status=status, + similarity=self._similarity)) return self._selection + def next_by_quality(self): + return self.next_selection(strategy=_best_quality_strategy) + def next_by_similarity(self): - self._check_functions() - self.select(_best_similarity_strategy( - self._selection, - cluster_ids=self._get_cluster_ids(), - quality=self._quality, - status=self._cluster_status, - similarity=self._similarity)) - return self._selection + return self.next_selection(strategy=_best_similarity_strategy) From 21f4ca0902558abe232fa0c6883433e885c34611 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 8 Oct 2015 16:17:45 +0200 Subject: [PATCH 0271/1059] WIP: test attach wizard --- phy/cluster/manual/gui_plugins.py | 3 +-- phy/cluster/manual/tests/test_gui_plugins.py | 14 +++++++++++++- phy/cluster/manual/wizard.py | 4 ++-- 3 files changed, 16 insertions(+), 5 deletions(-) diff --git a/phy/cluster/manual/gui_plugins.py b/phy/cluster/manual/gui_plugins.py index a7e23ee28..080ecb8e7 100644 --- a/phy/cluster/manual/gui_plugins.py +++ b/phy/cluster/manual/gui_plugins.py @@ -96,8 +96,7 @@ def status(cluster): def on_cluster(up): if up.description == 'metadata_group' and up.history != 'undo': cluster = up.metadata_changed[0] - wizard.select([cluster]) - wizard.pin() + wizard.next_selection([cluster], ignore_group=True) def _attach_wizard(wizard, clustering, cluster_meta): diff --git a/phy/cluster/manual/tests/test_gui_plugins.py b/phy/cluster/manual/tests/test_gui_plugins.py index b1d323a5e..32ef3b06c 100644 --- a/phy/cluster/manual/tests/test_gui_plugins.py +++ b/phy/cluster/manual/tests/test_gui_plugins.py @@ -106,7 +106,19 @@ def test_attach_wizard_to_cluster_meta(wizard, cluster_groups): assert wizard.selection == [20] cluster_meta.set('group', [20], 'noise') - # assert wizard.selection == [10] + assert wizard.selection == [10] + + cluster_meta.set('group', [10], 'good') + assert wizard.selection == [2] + + # Restart. + wizard.restart() + assert wizard.selection == [30] + + # 30, 20, 10, 2, 1, 0 + # N, i, g, N, g, i + assert wizard.next_by_quality() == [2] + # assert wizard.next_by_quality() == [10] def test_attach_wizard(wizard, cluster_ids, cluster_groups): diff --git a/phy/cluster/manual/wizard.py b/phy/cluster/manual/wizard.py index 4e129ded0..b3bdb5c5c 100644 --- a/phy/cluster/manual/wizard.py +++ b/phy/cluster/manual/wizard.py @@ -318,7 +318,7 @@ def next(self): def restart(self): self.select([]) - self.next_by_similarity() + self.next_by_quality() def _check_functions(self): if not self._get_cluster_ids: @@ -336,6 +336,7 @@ def _check_functions(self): def next_selection(self, cluster_ids=None, strategy=None, ignore_group=False): + self._check_functions() cluster_ids = cluster_ids or self._selection strategy = strategy or _best_quality_strategy if ignore_group: @@ -346,7 +347,6 @@ def status(cluster): return self._cluster_status(cluster) else: status = self._cluster_status - self._check_functions() self.select(strategy(cluster_ids, cluster_ids=self._get_cluster_ids(), quality=self._quality, From f129efdf919de8044ef4b87b370217884ff98e70 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 8 Oct 2015 16:39:28 +0200 Subject: [PATCH 0272/1059] Fixture cluster groups are a bit more complex --- phy/cluster/manual/tests/conftest.py | 6 +- phy/cluster/manual/tests/test_gui_plugins.py | 10 ++-- phy/cluster/manual/tests/test_wizard.py | 63 ++++++++++++-------- 3 files changed, 46 insertions(+), 33 deletions(-) diff --git a/phy/cluster/manual/tests/conftest.py b/phy/cluster/manual/tests/conftest.py index 167c08107..ebc775394 100644 --- a/phy/cluster/manual/tests/conftest.py +++ b/phy/cluster/manual/tests/conftest.py @@ -17,8 +17,8 @@ @yield_fixture def cluster_ids(): - yield [0, 1, 2, 10, 20, 30] - # i, g, N, N, N, N + yield [0, 1, 2, 10, 11, 20, 30] + # i, g, N, i, g, N, N @yield_fixture @@ -28,7 +28,7 @@ def get_cluster_ids(cluster_ids): @yield_fixture def cluster_groups(): - yield {0: 'ignored', 1: 'good'} + yield {0: 'ignored', 1: 'good', 10: 'ignored', 11: 'good'} @yield_fixture diff --git a/phy/cluster/manual/tests/test_gui_plugins.py b/phy/cluster/manual/tests/test_gui_plugins.py index 32ef3b06c..473523ce0 100644 --- a/phy/cluster/manual/tests/test_gui_plugins.py +++ b/phy/cluster/manual/tests/test_gui_plugins.py @@ -66,7 +66,7 @@ def test_attach_wizard_to_clustering_merge(wizard, cluster_ids): clustering.merge([30, 20]) # Select the merged cluster along with its most similar one (=pin merged). - assert wizard.selection == [31, 10] + assert wizard.selection == [31, 2] # Undo: the previous selection reappears. clustering.undo() @@ -74,7 +74,7 @@ def test_attach_wizard_to_clustering_merge(wizard, cluster_ids): # Redo. clustering.redo() - assert wizard.selection == [31, 10] + assert wizard.selection == [31, 2] def test_attach_wizard_to_clustering_split(wizard, cluster_ids): @@ -85,7 +85,7 @@ def test_attach_wizard_to_clustering_split(wizard, cluster_ids): assert wizard.selection == [30, 20, 10] clustering.split([5, 3]) - assert wizard.selection == [31, 20] + assert wizard.selection == [31, 30] # Undo: the previous selection reappears. clustering.undo() @@ -93,7 +93,7 @@ def test_attach_wizard_to_clustering_split(wizard, cluster_ids): # Redo. clustering.redo() - assert wizard.selection == [31, 20] + assert wizard.selection == [31, 30] def test_attach_wizard_to_cluster_meta(wizard, cluster_groups): @@ -136,7 +136,7 @@ def test_manual_clustering_edge_cases(manual_clustering): # Empty selection at first. assert_selection() - ae(mc.clustering.cluster_ids, [0, 1, 2, 10, 20, 30]) + ae(mc.clustering.cluster_ids, [0, 1, 2, 10, 11, 20, 30]) mc.select([0]) assert_selection(0) diff --git a/phy/cluster/manual/tests/test_wizard.py b/phy/cluster/manual/tests/test_wizard.py index 10fe37eae..bbddd0da7 100644 --- a/phy/cluster/manual/tests/test_wizard.py +++ b/phy/cluster/manual/tests/test_wizard.py @@ -38,9 +38,10 @@ def test_argsort(): def test_sort_by_status(status): cluster_ids = [10, 0, 1, 30, 2, 20] - assert _sort_by_status(cluster_ids, status=status) == [10, 30, 2, 20, 1, 0] + assert _sort_by_status(cluster_ids, status=status) == \ + [30, 2, 20, 1, 10, 0] assert _sort_by_status(cluster_ids, status=status, - remove_ignored=True) == [10, 30, 2, 20, 1] + remove_ignored=True) == [30, 2, 20, 1] def test_next_in_list(): @@ -71,9 +72,9 @@ def _similar(cluster): assert not _similar(None) assert not _similar(100) - assert _similar(0) == [30, 20, 10, 2, 1] - assert _similar(1) == [30, 20, 10, 2] - assert _similar(2) == [30, 20, 10, 1] + assert _similar(0) == [30, 20, 2, 11, 1] + assert _similar(1) == [30, 20, 2, 11] + assert _similar(2) == [30, 20, 11, 1] #------------------------------------------------------------------------------ @@ -92,10 +93,10 @@ def _next(selection): assert not _next(None) assert _next([]) == [30] assert _next([30]) == [20] - assert _next([20]) == [10] - assert _next([10]) == [2] + assert _next([20]) == [2] + assert _next([2]) == [11] - assert _next([30, 20]) == [30, 10] + assert _next([30, 20]) == [30, 2] assert _next([10, 2]) == [10, 1] assert _next([10, 1]) == [10, 1] # 0 is ignored, so it does not appear. @@ -111,11 +112,10 @@ def _next(selection): assert not _next(None) assert _next([]) == [30, 20] - assert _next([30, 20]) == [30, 10] - assert _next([30, 10]) == [30, 2] + assert _next([30, 20]) == [30, 11] + assert _next([30, 11]) == [30, 2] assert _next([20, 10]) == [20, 2] - assert _next([10, 2]) == [10, 1] - assert _next([10, 1]) == [2, 1] + assert _next([10, 2]) == [2, 1] assert _next([2, 1]) == [2, 1] # 0 is ignored, so it does not appear. @@ -143,8 +143,8 @@ def test_wizard_group(): def test_wizard_nav(wizard): w = wizard - assert w.cluster_ids == [0, 1, 2, 10, 20, 30] - assert w.n_clusters == 6 + assert w.cluster_ids == [0, 1, 2, 10, 11, 20, 30] + assert w.n_clusters == 7 assert w.selection == [] @@ -192,8 +192,10 @@ def test_wizard_next(wizard, status): assert w.next_selection([30]) == [20] assert w.next_selection([30], ignore_group=True) == [20] - assert w.next_selection([1]) == [0] - assert w.next_selection([1], ignore_group=True) == [0] + # After the last good, the best ignored. + assert w.next_selection([1]) == [10] + # After the last unsorted (1's group is ignored), the best good. + assert w.next_selection([1], ignore_group=True) == [11] @w.set_status_function def status_bis(cluster): @@ -201,7 +203,7 @@ def status_bis(cluster): return 'ignored' return status(cluster) - assert w.next_selection([30]) == [0] + assert w.next_selection([30]) == [10] assert w.next_selection([30], ignore_group=True) == [20] @@ -220,26 +222,37 @@ def test_wizard_pin_by_quality(wizard): w.next_by_quality() assert w.selection == [20] + # Pin. w.pin() assert w.selection == [20, 30] w.next_by_quality() - assert w.selection == [20, 10] + assert w.selection == [20, 2] + # Unpin. w.unpin() assert w.selection == [20] w.next_by_quality() - assert w.selection == [10] + assert w.selection == [2] + # Pin. w.pin() - assert w.selection == [10, 30] + assert w.selection == [2, 30] + + w.next_by_quality() + assert w.selection == [2, 20] + + # Candidate is best among good. + w.next_by_quality() + assert w.selection == [2, 11] + # Candidate is last among good, ignored are completely ignored. w.next_by_quality() - assert w.selection == [10, 20] + assert w.selection == [2, 1] w.next_by_quality() - assert w.selection == [10, 2] + assert w.selection == [2, 1] def test_wizard_pin_by_similarity(wizard): @@ -255,13 +268,13 @@ def test_wizard_pin_by_similarity(wizard): assert w.selection == [30, 20] w.next_by_similarity() - assert w.selection == [30, 10] + assert w.selection == [30, 11] w.pin() assert w.selection == [30, 20] w.next_by_similarity() - assert w.selection == [30, 10] + assert w.selection == [30, 11] w.unpin() assert w.selection == [30] @@ -276,4 +289,4 @@ def test_wizard_pin_by_similarity(wizard): assert w.selection == [20, 1] w.next_by_similarity() - assert w.selection == [10, 2] + assert w.selection == [11, 2] From 61a4cca75986e6507a71c82ecc8653ff3f5e1aba Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 8 Oct 2015 16:49:04 +0200 Subject: [PATCH 0273/1059] Add ClusterMeta.to_dict() --- phy/cluster/manual/_utils.py | 5 +++++ phy/cluster/manual/tests/test_utils.py | 6 ++++++ 2 files changed, 11 insertions(+) diff --git a/phy/cluster/manual/_utils.py b/phy/cluster/manual/_utils.py index b87547bc1..984d970b3 100644 --- a/phy/cluster/manual/_utils.py +++ b/phy/cluster/manual/_utils.py @@ -128,6 +128,11 @@ def from_dict(self, dic): self.set(field, [cluster], value, add_to_stack=False) self._data_base = deepcopy(self._data) + def to_dict(self, field): + assert field in self._fields, "This field doesn't exist" + return {cluster: self.get(field, cluster) + for cluster in self._data.keys()} + def set(self, field, clusters, value, add_to_stack=True): assert field in self._fields diff --git a/phy/cluster/manual/tests/test_utils.py b/phy/cluster/manual/tests/test_utils.py index 0a1fd6423..651ffe3b0 100644 --- a/phy/cluster/manual/tests/test_utils.py +++ b/phy/cluster/manual/tests/test_utils.py @@ -8,6 +8,8 @@ import logging +from pytest import raises + from .._utils import (ClusterMeta, UpdateInfo, _update_cluster_selection, create_cluster_meta) @@ -47,6 +49,10 @@ def test_metadata_history_simple(): meta.redo() assert meta.get('group', 2) == 2 + with raises(AssertionError): + assert meta.to_dict('grou') is None + assert meta.to_dict('group') == {2: 2} + def test_metadata_history_complex(): """Test ClusterMeta history.""" From f9bd8738ed2e72cbb2f9cb3522cfe3eef241e3d5 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 8 Oct 2015 17:02:03 +0200 Subject: [PATCH 0274/1059] Add test --- phy/cluster/manual/tests/test_wizard.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/phy/cluster/manual/tests/test_wizard.py b/phy/cluster/manual/tests/test_wizard.py index bbddd0da7..2b0bc64b1 100644 --- a/phy/cluster/manual/tests/test_wizard.py +++ b/phy/cluster/manual/tests/test_wizard.py @@ -207,6 +207,28 @@ def status_bis(cluster): assert w.next_selection([30], ignore_group=True) == [20] +def test_wizard_next_bis(wizard): + w = wizard + + # 30, 20, 11, 10, 2, 1, 0 + # N, i, g, g, N, g, i + + @w.set_status_function + def status_bis(cluster): + return {0: 'ignored', + 1: 'good', + 2: None, + 10: 'good', + 11: 'good', + 20: 'ignored', + 30: None, + }[cluster] + + wizard.select([30]) + assert wizard.next_by_quality() == [2] + assert wizard.next_by_quality() == [11] + + def test_wizard_pin_by_quality(wizard): w = wizard From e5e12f99d207d87e7ab993d275eb0c6785bb8d0c Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 8 Oct 2015 17:03:51 +0200 Subject: [PATCH 0275/1059] Move _wizard_group() --- phy/cluster/manual/gui_plugins.py | 14 +++++++++++++- phy/cluster/manual/tests/test_gui_plugins.py | 20 ++++++++++++++++---- phy/cluster/manual/tests/test_wizard.py | 9 --------- phy/cluster/manual/wizard.py | 11 ----------- 4 files changed, 29 insertions(+), 25 deletions(-) diff --git a/phy/cluster/manual/gui_plugins.py b/phy/cluster/manual/gui_plugins.py index 080ecb8e7..e004f0d04 100644 --- a/phy/cluster/manual/gui_plugins.py +++ b/phy/cluster/manual/gui_plugins.py @@ -10,11 +10,12 @@ import logging import numpy as np +from six import string_types from ._history import GlobalHistory from ._utils import create_cluster_meta from .clustering import Clustering -from .wizard import Wizard, _wizard_group +from .wizard import Wizard from phy.gui.actions import Actions, Snippets from phy.io.array import select_spikes from phy.utils.plugin import IPlugin @@ -46,6 +47,17 @@ def _process_ups(ups): # pragma: no cover # Attach wizard to effectors (clustering and cluster_meta) # ----------------------------------------------------------------------------- +def _wizard_group(group): + # The group should be None, 'mua', 'noise', or 'good'. + assert group is None or isinstance(group, string_types) + group = group.lower() if group else group + if group in ('mua', 'noise'): + return 'ignored' + elif group == 'good': + return 'good' + return None + + def _attach_wizard_to_effector(wizard, effector): # Save the current selection when an action occurs. diff --git a/phy/cluster/manual/tests/test_gui_plugins.py b/phy/cluster/manual/tests/test_gui_plugins.py index 473523ce0..164b682cd 100644 --- a/phy/cluster/manual/tests/test_gui_plugins.py +++ b/phy/cluster/manual/tests/test_gui_plugins.py @@ -12,7 +12,8 @@ from ..clustering import Clustering from .._utils import create_cluster_meta -from ..gui_plugins import (_attach_wizard, +from ..gui_plugins import (_wizard_group, + _attach_wizard, _attach_wizard_to_clustering, _attach_wizard_to_cluster_meta, ) @@ -55,6 +56,14 @@ def assert_selection(*cluster_ids): # pragma: no cover # Test wizard attach #------------------------------------------------------------------------------ +def test_wizard_group(): + assert _wizard_group('noise') == 'ignored' + assert _wizard_group('mua') == 'ignored' + assert _wizard_group('good') == 'good' + assert _wizard_group('unknown') is None + assert _wizard_group(None) is None + + def test_attach_wizard_to_clustering_merge(wizard, cluster_ids): clustering = Clustering(np.array(cluster_ids)) _attach_wizard_to_clustering(wizard, clustering) @@ -106,6 +115,7 @@ def test_attach_wizard_to_cluster_meta(wizard, cluster_groups): assert wizard.selection == [20] cluster_meta.set('group', [20], 'noise') + assert cluster_meta.get('group', 20) == 'noise' assert wizard.selection == [10] cluster_meta.set('group', [10], 'good') @@ -115,10 +125,12 @@ def test_attach_wizard_to_cluster_meta(wizard, cluster_groups): wizard.restart() assert wizard.selection == [30] - # 30, 20, 10, 2, 1, 0 - # N, i, g, N, g, i + # 30, 20, 11, 10, 2, 1, 0 + # N, i, g, g, N, g, i assert wizard.next_by_quality() == [2] - # assert wizard.next_by_quality() == [10] + print(cluster_meta.to_dict('group')) + # TODO + # assert wizard.next_by_quality() == [11] def test_attach_wizard(wizard, cluster_ids, cluster_groups): diff --git a/phy/cluster/manual/tests/test_wizard.py b/phy/cluster/manual/tests/test_wizard.py index 2b0bc64b1..cab23db33 100644 --- a/phy/cluster/manual/tests/test_wizard.py +++ b/phy/cluster/manual/tests/test_wizard.py @@ -13,7 +13,6 @@ _next_in_list, _best_clusters, _most_similar_clusters, - _wizard_group, _best_quality_strategy, _best_similarity_strategy, Wizard, @@ -133,14 +132,6 @@ def test_wizard_empty(): wizard.restart() -def test_wizard_group(): - assert _wizard_group('noise') == 'ignored' - assert _wizard_group('mua') == 'ignored' - assert _wizard_group('good') == 'good' - assert _wizard_group('unknown') is None - assert _wizard_group(None) is None - - def test_wizard_nav(wizard): w = wizard assert w.cluster_ids == [0, 1, 2, 10, 11, 20, 30] diff --git a/phy/cluster/manual/wizard.py b/phy/cluster/manual/wizard.py index b3bdb5c5c..108e0aa0c 100644 --- a/phy/cluster/manual/wizard.py +++ b/phy/cluster/manual/wizard.py @@ -73,17 +73,6 @@ def _most_similar_clusters(cluster, cluster_ids=None, n_max=None, return _sort_by_status(clusters, status=status, remove_ignored=True) -def _wizard_group(group): - # The group should be None, 'mua', 'noise', or 'good'. - assert group is None or isinstance(group, string_types) - group = group.lower() if group else group - if group in ('mua', 'noise'): - return 'ignored' - elif group == 'good': - return 'good' - return None - - #------------------------------------------------------------------------------ # Strategy functions #------------------------------------------------------------------------------ From 6f1a312ff101ba922c6ca50110c1628a2ff27629 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 8 Oct 2015 17:28:24 +0200 Subject: [PATCH 0276/1059] Fix bugs --- phy/cluster/manual/gui_plugins.py | 18 +++++++++++------- phy/cluster/manual/tests/conftest.py | 5 +++-- phy/cluster/manual/tests/test_gui_plugins.py | 16 +++++++++------- phy/cluster/manual/wizard.py | 2 -- 4 files changed, 23 insertions(+), 18 deletions(-) diff --git a/phy/cluster/manual/gui_plugins.py b/phy/cluster/manual/gui_plugins.py index e004f0d04..10a7cd691 100644 --- a/phy/cluster/manual/gui_plugins.py +++ b/phy/cluster/manual/gui_plugins.py @@ -10,7 +10,6 @@ import logging import numpy as np -from six import string_types from ._history import GlobalHistory from ._utils import create_cluster_meta @@ -47,15 +46,17 @@ def _process_ups(ups): # pragma: no cover # Attach wizard to effectors (clustering and cluster_meta) # ----------------------------------------------------------------------------- +_wizard_group_mapping = { + 'noise': 'ignored', + 'mua': 'ignored', + 'good': 'good', +} + + def _wizard_group(group): # The group should be None, 'mua', 'noise', or 'good'. - assert group is None or isinstance(group, string_types) group = group.lower() if group else group - if group in ('mua', 'noise'): - return 'ignored' - elif group == 'good': - return 'good' - return None + return _wizard_group_mapping.get(group, None) def _attach_wizard_to_effector(wizard, effector): @@ -102,6 +103,9 @@ def _attach_wizard_to_cluster_meta(wizard, cluster_meta): @wizard.set_status_function def status(cluster): group = cluster_meta.get('group', cluster) + # TODO: remove this in order to allow for custom groups. + # For now, it serves as temporary check. + assert group is None or group in _wizard_group_mapping return _wizard_group(group) @cluster_meta.connect diff --git a/phy/cluster/manual/tests/conftest.py b/phy/cluster/manual/tests/conftest.py index ebc775394..033046411 100644 --- a/phy/cluster/manual/tests/conftest.py +++ b/phy/cluster/manual/tests/conftest.py @@ -9,6 +9,7 @@ from pytest import yield_fixture from ..wizard import Wizard +from ..gui_plugins import _wizard_group #------------------------------------------------------------------------------ @@ -28,12 +29,12 @@ def get_cluster_ids(cluster_ids): @yield_fixture def cluster_groups(): - yield {0: 'ignored', 1: 'good', 10: 'ignored', 11: 'good'} + yield {0: 'noise', 1: 'good', 10: 'mua', 11: 'good'} @yield_fixture def status(cluster_groups): - yield lambda c: cluster_groups.get(c, None) + yield lambda c: _wizard_group(cluster_groups.get(c, None)) @yield_fixture diff --git a/phy/cluster/manual/tests/test_gui_plugins.py b/phy/cluster/manual/tests/test_gui_plugins.py index 164b682cd..a7f40c512 100644 --- a/phy/cluster/manual/tests/test_gui_plugins.py +++ b/phy/cluster/manual/tests/test_gui_plugins.py @@ -116,21 +116,23 @@ def test_attach_wizard_to_cluster_meta(wizard, cluster_groups): cluster_meta.set('group', [20], 'noise') assert cluster_meta.get('group', 20) == 'noise' - assert wizard.selection == [10] - - cluster_meta.set('group', [10], 'good') assert wizard.selection == [2] + cluster_meta.set('group', [2], 'good') + assert wizard.selection == [11] + # Restart. wizard.restart() assert wizard.selection == [30] # 30, 20, 11, 10, 2, 1, 0 - # N, i, g, g, N, g, i + # N, i, g, i, g, g, i + assert wizard.next_by_quality() == [11] assert wizard.next_by_quality() == [2] - print(cluster_meta.to_dict('group')) - # TODO - # assert wizard.next_by_quality() == [11] + assert wizard.next_by_quality() == [1] + assert wizard.next_by_quality() == [20] + assert wizard.next_by_quality() == [10] + assert wizard.next_by_quality() == [0] def test_attach_wizard(wizard, cluster_ids, cluster_groups): diff --git a/phy/cluster/manual/wizard.py b/phy/cluster/manual/wizard.py index 108e0aa0c..f281a29f8 100644 --- a/phy/cluster/manual/wizard.py +++ b/phy/cluster/manual/wizard.py @@ -10,8 +10,6 @@ import logging from operator import itemgetter -from six import string_types - from ._history import History from phy.utils import EventEmitter, _is_array_like From f00ce0797df0bb41fef2ab6a1503da55dcb2142c Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 8 Oct 2015 17:33:32 +0200 Subject: [PATCH 0277/1059] Add attach wizard test --- phy/cluster/manual/tests/test_gui_plugins.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/phy/cluster/manual/tests/test_gui_plugins.py b/phy/cluster/manual/tests/test_gui_plugins.py index a7f40c512..ef6aa1ff4 100644 --- a/phy/cluster/manual/tests/test_gui_plugins.py +++ b/phy/cluster/manual/tests/test_gui_plugins.py @@ -135,10 +135,29 @@ def test_attach_wizard_to_cluster_meta(wizard, cluster_groups): assert wizard.next_by_quality() == [0] +def test_attach_wizard_to_cluster_meta_undo(wizard, cluster_groups): + cluster_meta = create_cluster_meta(cluster_groups) + _attach_wizard_to_cluster_meta(wizard, cluster_meta) + + wizard.select([20]) + + cluster_meta.set('group', [20], 'noise') + assert wizard.selection == [2] + + wizard.select([30]) + + cluster_meta.undo() + assert wizard.selection == [20] + + cluster_meta.redo() + assert wizard.selection == [2] + + def test_attach_wizard(wizard, cluster_ids, cluster_groups): clustering = Clustering(np.array(cluster_ids)) cluster_meta = create_cluster_meta(cluster_groups) _attach_wizard(wizard, clustering, cluster_meta) + # TODO #------------------------------------------------------------------------------ From a7c0df5add151646785373324c2837248e45f6d4 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 8 Oct 2015 17:45:06 +0200 Subject: [PATCH 0278/1059] WIP: wizard attach tests --- phy/cluster/manual/tests/test_gui_plugins.py | 18 ++++++++++++++++-- phy/cluster/manual/tests/test_wizard.py | 6 ++++++ 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/phy/cluster/manual/tests/test_gui_plugins.py b/phy/cluster/manual/tests/test_gui_plugins.py index ef6aa1ff4..b970349c7 100644 --- a/phy/cluster/manual/tests/test_gui_plugins.py +++ b/phy/cluster/manual/tests/test_gui_plugins.py @@ -144,7 +144,8 @@ def test_attach_wizard_to_cluster_meta_undo(wizard, cluster_groups): cluster_meta.set('group', [20], 'noise') assert wizard.selection == [2] - wizard.select([30]) + wizard.next_by_quality() + assert wizard.selection == [11] cluster_meta.undo() assert wizard.selection == [20] @@ -153,11 +154,24 @@ def test_attach_wizard_to_cluster_meta_undo(wizard, cluster_groups): assert wizard.selection == [2] -def test_attach_wizard(wizard, cluster_ids, cluster_groups): +def test_attach_wizard_1(wizard, cluster_ids, cluster_groups): clustering = Clustering(np.array(cluster_ids)) cluster_meta = create_cluster_meta(cluster_groups) _attach_wizard(wizard, clustering, cluster_meta) + + wizard.restart() + assert wizard.selection == [30] + + wizard.pin() + assert wizard.selection == [30, 20] + + clustering.merge(wizard.selection) + assert wizard.selection == [31, 2] + assert cluster_meta.get('group', 31) is None + + wizard.next_by_quality() # TODO + # assert wizard.selection == [31, 11] #------------------------------------------------------------------------------ diff --git a/phy/cluster/manual/tests/test_wizard.py b/phy/cluster/manual/tests/test_wizard.py index cab23db33..759650bc4 100644 --- a/phy/cluster/manual/tests/test_wizard.py +++ b/phy/cluster/manual/tests/test_wizard.py @@ -71,10 +71,16 @@ def _similar(cluster): assert not _similar(None) assert not _similar(100) + assert _similar(0) == [30, 20, 2, 11, 1] assert _similar(1) == [30, 20, 2, 11] assert _similar(2) == [30, 20, 11, 1] + assert _similar(10) == [30, 20, 2, 11, 1] + assert _similar(11) == [30, 20, 2, 1] + assert _similar(20) == [30, 2, 11, 1] + assert _similar(30) == [20, 2, 11, 1] + #------------------------------------------------------------------------------ # Test strategy functions From 3ebda85ca33c8a0d2561340d6d997c4d8d661cda Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 8 Oct 2015 18:14:05 +0200 Subject: [PATCH 0279/1059] Fix bug --- phy/cluster/manual/tests/test_gui_plugins.py | 4 +- phy/cluster/manual/tests/test_wizard.py | 39 ++++++++++++++++++-- phy/cluster/manual/wizard.py | 22 ++++------- 3 files changed, 45 insertions(+), 20 deletions(-) diff --git a/phy/cluster/manual/tests/test_gui_plugins.py b/phy/cluster/manual/tests/test_gui_plugins.py index b970349c7..23404ed03 100644 --- a/phy/cluster/manual/tests/test_gui_plugins.py +++ b/phy/cluster/manual/tests/test_gui_plugins.py @@ -170,8 +170,8 @@ def test_attach_wizard_1(wizard, cluster_ids, cluster_groups): assert cluster_meta.get('group', 31) is None wizard.next_by_quality() - # TODO - # assert wizard.selection == [31, 11] + print(cluster_meta.to_dict('group')) + assert wizard.selection == [31, 11] #------------------------------------------------------------------------------ diff --git a/phy/cluster/manual/tests/test_wizard.py b/phy/cluster/manual/tests/test_wizard.py index 759650bc4..00b2767f1 100644 --- a/phy/cluster/manual/tests/test_wizard.py +++ b/phy/cluster/manual/tests/test_wizard.py @@ -86,7 +86,7 @@ def _similar(cluster): # Test strategy functions #------------------------------------------------------------------------------ -def test_best_quality_strategy(cluster_ids, quality, status, similarity): +def test_best_quality_strategy_1(cluster_ids, quality, status, similarity): def _next(selection): return _best_quality_strategy(selection, @@ -102,10 +102,26 @@ def _next(selection): assert _next([2]) == [11] assert _next([30, 20]) == [30, 2] - assert _next([10, 2]) == [10, 1] + assert _next([10, 2]) == [10, 11] + assert _next([10, 11]) == [10, 1] assert _next([10, 1]) == [10, 1] # 0 is ignored, so it does not appear. +def test_best_quality_strategy_2(quality, similarity): + + def status(cluster): + return {0: 'ignored', 1: None, 2: 'good', 3: None}[cluster] + + def _next(selection): + return _best_quality_strategy(selection, + cluster_ids=list(range(4)), + quality=quality, + status=status, + similarity=similarity) + + assert _next([3, 1]) == [3, 2] + + def test_best_similarity_strategy(cluster_ids, quality, status, similarity): def _next(selection): @@ -183,7 +199,7 @@ def test_wizard_nav(wizard): assert w.selection == [1, 2] -def test_wizard_next(wizard, status): +def test_wizard_next_1(wizard, status): w = wizard assert w.next_selection([30]) == [20] @@ -204,7 +220,7 @@ def status_bis(cluster): assert w.next_selection([30], ignore_group=True) == [20] -def test_wizard_next_bis(wizard): +def test_wizard_next_2(wizard): w = wizard # 30, 20, 11, 10, 2, 1, 0 @@ -226,6 +242,21 @@ def status_bis(cluster): assert wizard.next_by_quality() == [11] +def test_wizard_next_3(wizard): + w = wizard + + @w.set_cluster_ids_function + def cluster_ids(): + return [0, 1, 2, 3] + + @w.set_status_function + def status_bis(cluster): + return {0: 'ignored', 1: None, 2: 'good', 3: None}[cluster] + + wizard.select([3, 1]) + assert wizard.next_by_quality() == [3, 2] + + def test_wizard_pin_by_quality(wizard): w = wizard diff --git a/phy/cluster/manual/wizard.py b/phy/cluster/manual/wizard.py index f281a29f8..591326a49 100644 --- a/phy/cluster/manual/wizard.py +++ b/phy/cluster/manual/wizard.py @@ -47,7 +47,7 @@ def _sort_by_status(clusters, status=None, remove_ignored=False): # NOTE: sorted is "stable": it doesn't change the order of elements # that compare equal, which ensures that the order of clusters is kept # among any given status. - key = lambda cluster: _sort_map.get(status(cluster), 0) + key = lambda cluster: _sort_map[status(cluster)] return sorted(clusters, key=key) @@ -57,18 +57,16 @@ def _best_clusters(clusters, quality, n_max=None): def _most_similar_clusters(cluster, cluster_ids=None, n_max=None, - similarity=None, status=None, less_than=None): + similarity=None, status=None): """Return the `n_max` most similar clusters to a given cluster.""" if cluster not in cluster_ids: return [] s = [(other, similarity(cluster, other)) for other in cluster_ids if other != cluster and status(other) != 'ignored'] - # Only keep values less than a threshold. - if less_than: - s = [(c, v) for (c, v) in s if v <= less_than] clusters = _argsort(s, n_max=n_max) - return _sort_by_status(clusters, status=status, remove_ignored=True) + out = _sort_by_status(clusters, status=status) + return out #------------------------------------------------------------------------------ @@ -102,19 +100,15 @@ def _best_quality_strategy(selection, return selection elif n == 2: best, match = selection - value = similarity(best, match) candidates = _most_similar_clusters(best, cluster_ids=cluster_ids, similarity=similarity, status=status, - less_than=value) - if best in candidates: # pragma: no cover - candidates.remove(best) - if match in candidates: - candidates.remove(match) - if not candidates: + ) + if not candidates: # pragma: no cover return selection - return [best, candidates[0]] + candidate = _next_in_list(candidates, match) + return [best, candidate] def _best_similarity_strategy(selection, From 4424ed0d3eefb32a63bcfc1a35f9a7638d592699 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 8 Oct 2015 18:34:15 +0200 Subject: [PATCH 0280/1059] WIP: more tests --- phy/cluster/manual/gui_plugins.py | 12 +--- phy/cluster/manual/tests/test_gui_plugins.py | 59 +++++++++++++++++++- 2 files changed, 60 insertions(+), 11 deletions(-) diff --git a/phy/cluster/manual/gui_plugins.py b/phy/cluster/manual/gui_plugins.py index 10a7cd691..e76813156 100644 --- a/phy/cluster/manual/gui_plugins.py +++ b/phy/cluster/manual/gui_plugins.py @@ -103,9 +103,6 @@ def _attach_wizard_to_cluster_meta(wizard, cluster_meta): @wizard.set_status_function def status(cluster): group = cluster_meta.get('group', cluster) - # TODO: remove this in order to allow for custom groups. - # For now, it serves as temporary check. - assert group is None or group in _wizard_group_mapping return _wizard_group(group) @cluster_meta.connect @@ -113,15 +110,12 @@ def on_cluster(up): if up.description == 'metadata_group' and up.history != 'undo': cluster = up.metadata_changed[0] wizard.next_selection([cluster], ignore_group=True) + # TODO: pin after a move? Yes if the previous selection >= 2, no + # otherwise. See similar note above. + # wizard.pin() def _attach_wizard(wizard, clustering, cluster_meta): - @clustering.connect - def on_cluster(up): - # Set the cluster metadata of new clusters. - if up.added: - cluster_meta.set_from_descendants(up.descendants) - _attach_wizard_to_clustering(wizard, clustering) _attach_wizard_to_cluster_meta(wizard, cluster_meta) diff --git a/phy/cluster/manual/tests/test_gui_plugins.py b/phy/cluster/manual/tests/test_gui_plugins.py index 23404ed03..e231ab4ac 100644 --- a/phy/cluster/manual/tests/test_gui_plugins.py +++ b/phy/cluster/manual/tests/test_gui_plugins.py @@ -170,9 +170,42 @@ def test_attach_wizard_1(wizard, cluster_ids, cluster_groups): assert cluster_meta.get('group', 31) is None wizard.next_by_quality() - print(cluster_meta.to_dict('group')) assert wizard.selection == [31, 11] + clustering.undo() + assert wizard.selection == [30, 20] + + +def test_attach_wizard_2(wizard, cluster_ids, cluster_groups): + clustering = Clustering(np.array(cluster_ids)) + cluster_meta = create_cluster_meta(cluster_groups) + _attach_wizard(wizard, clustering, cluster_meta) + + wizard.select([30, 20]) + assert wizard.selection == [30, 20] + + clustering.split([1]) + assert wizard.selection == [31, 30] + assert cluster_meta.get('group', 31) is None + + wizard.next_by_quality() + assert wizard.selection == [31, 20] + + clustering.undo() + assert wizard.selection == [30, 20] + + +def test_attach_wizard_3(wizard, cluster_ids, cluster_groups): + clustering = Clustering(np.array(cluster_ids)) + cluster_meta = create_cluster_meta(cluster_groups) + _attach_wizard(wizard, clustering, cluster_meta) + + wizard.select([30, 20]) + assert wizard.selection == [30, 20] + + cluster_meta.set('group', 30, 'noise') + assert wizard.selection == [20] + #------------------------------------------------------------------------------ # Test GUI plugins @@ -216,4 +249,26 @@ def test_manual_clustering_merge(manual_clustering): mc.actions.select([30, 20]) mc.actions.merge() - # assert_selection(31, 10) + assert_selection(31, 2) + + +def test_manual_clustering_split(manual_clustering): + mc, assert_selection = manual_clustering + + mc.actions.select([1, 2]) + mc.actions.split([1, 2]) + assert_selection(31, 20) + + +def test_manual_clustering_move(manual_clustering): + mc, assert_selection = manual_clustering + + mc.actions.select([30]) + assert_selection(30) + + # TODO: set quality and similarity functions + # mc.actions.next_by_quality() + # assert_selection(20) + + # mc.actions.move([20], 'noise') + # assert_selection(2) From 7aacfa0c2bd32b0a0d4ab841f516c628f8627913 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 9 Oct 2015 12:53:06 +0200 Subject: [PATCH 0281/1059] Rename gui_plugins to gui_plugin --- phy/cluster/manual/__init__.py | 2 +- phy/cluster/manual/{gui_plugins.py => gui_plugin.py} | 4 ++-- phy/cluster/manual/tests/conftest.py | 2 +- .../{test_gui_plugins.py => test_gui_plugin.py} | 12 ++++++------ 4 files changed, 10 insertions(+), 10 deletions(-) rename phy/cluster/manual/{gui_plugins.py => gui_plugin.py} (99%) rename phy/cluster/manual/tests/{test_gui_plugins.py => test_gui_plugin.py} (96%) diff --git a/phy/cluster/manual/__init__.py b/phy/cluster/manual/__init__.py index 9dd61ab42..8cc68c734 100644 --- a/phy/cluster/manual/__init__.py +++ b/phy/cluster/manual/__init__.py @@ -5,4 +5,4 @@ from .clustering import Clustering from .wizard import Wizard -from .gui_plugins import ManualClustering +from .gui_plugin import ManualClustering diff --git a/phy/cluster/manual/gui_plugins.py b/phy/cluster/manual/gui_plugin.py similarity index 99% rename from phy/cluster/manual/gui_plugins.py rename to phy/cluster/manual/gui_plugin.py index e76813156..29b189040 100644 --- a/phy/cluster/manual/gui_plugins.py +++ b/phy/cluster/manual/gui_plugin.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -"""Manual clustering GUI plugins.""" +"""Manual clustering GUI plugin.""" # ----------------------------------------------------------------------------- @@ -121,7 +121,7 @@ def _attach_wizard(wizard, clustering, cluster_meta): # ----------------------------------------------------------------------------- -# Clustering GUI plugins +# Clustering GUI plugin # ----------------------------------------------------------------------------- class ManualClustering(IPlugin): diff --git a/phy/cluster/manual/tests/conftest.py b/phy/cluster/manual/tests/conftest.py index 033046411..c4e3f0d5f 100644 --- a/phy/cluster/manual/tests/conftest.py +++ b/phy/cluster/manual/tests/conftest.py @@ -9,7 +9,7 @@ from pytest import yield_fixture from ..wizard import Wizard -from ..gui_plugins import _wizard_group +from ..gui_plugin import _wizard_group #------------------------------------------------------------------------------ diff --git a/phy/cluster/manual/tests/test_gui_plugins.py b/phy/cluster/manual/tests/test_gui_plugin.py similarity index 96% rename from phy/cluster/manual/tests/test_gui_plugins.py rename to phy/cluster/manual/tests/test_gui_plugin.py index e231ab4ac..ddd1e24d0 100644 --- a/phy/cluster/manual/tests/test_gui_plugins.py +++ b/phy/cluster/manual/tests/test_gui_plugin.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -"""Test GUI plugins.""" +"""Test GUI plugin.""" #------------------------------------------------------------------------------ # Imports @@ -12,11 +12,11 @@ from ..clustering import Clustering from .._utils import create_cluster_meta -from ..gui_plugins import (_wizard_group, - _attach_wizard, - _attach_wizard_to_clustering, - _attach_wizard_to_cluster_meta, - ) +from ..gui_plugin import (_wizard_group, + _attach_wizard, + _attach_wizard_to_clustering, + _attach_wizard_to_cluster_meta, + ) from phy.gui.tests.test_gui import gui # noqa From 52f4aab62736f7623bfd564d13ab0ec62fb9de03 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 9 Oct 2015 19:02:24 +0200 Subject: [PATCH 0282/1059] WIP: refactor actions --- phy/gui/actions.py | 71 +++++++++++++++++++++++++++------------------- 1 file changed, 42 insertions(+), 29 deletions(-) diff --git a/phy/gui/actions.py b/phy/gui/actions.py index 6d5a5769a..3b77dc6e4 100644 --- a/phy/gui/actions.py +++ b/phy/gui/actions.py @@ -13,6 +13,7 @@ from six import string_types, PY3 from .qt import QtGui +from phy.utils import Bunch from phy.utils.event import EventEmitter logger = logging.getLogger(__name__) @@ -79,6 +80,13 @@ def _show_shortcuts(shortcuts, name=None): # Actions # ----------------------------------------------------------------------------- +def _alias_name(name): + # Get the alias from the character after & if it exists. + alias = name[name.index('&') + 1] if '&' in name else name + name = name.replace('&', '') + return alias, name + + class Actions(EventEmitter): """Handle GUI actions. @@ -123,6 +131,28 @@ def on_reset(): def exit(): gui.close() + def _create_action_bunch(self, callback=None, name=None, shortcut=None, + alias=None, checkable=False, checked=False): + + # Create the QAction instance. + action = QtGui.QAction(name, self._gui) + action.triggered.connect(callback) + action.setCheckable(checkable) + action.setChecked(checked) + if shortcut: + if not isinstance(shortcut, (tuple, list)): + shortcut = [shortcut] + for key in shortcut: + action.setShortcut(key) + + # HACK: add the shortcut string to the QAction object so that + # it can be shown in show_shortcuts(). I don't manage to recover + # the key sequence string from a QAction using Qt. + shortcut = shortcut or '' + + return Bunch(qaction=action, name=name, alias=alias, + shortcut=shortcut, callback=callback) + def add(self, callback=None, name=None, shortcut=None, alias=None, checkable=False, checked=False): """Add an action with a keyboard shortcut.""" @@ -137,37 +167,20 @@ def add(self, callback=None, name=None, shortcut=None, alias=None, assert callback name = name or callback.__name__ - # Get the alias from the character after & if it exists. if alias is None: - alias = name[name.index('&') + 1] if '&' in name else name - name = name.replace('&', '') + alias, name = _alias_name(name) + if name in self._actions: return - # Create the QAction instance. - action = QtGui.QAction(name, self._gui) - action.triggered.connect(callback) - action.setCheckable(checkable) - action.setChecked(checked) - if shortcut: - if not isinstance(shortcut, (tuple, list)): - shortcut = [shortcut] - for key in shortcut: - action.setShortcut(key) - - # Add some attributes to the QAction instance. - # The alias is used in snippets. - action._alias = alias - action._callback = callback - action._name = name - # HACK: add the shortcut string to the QAction object so that - # it can be shown in show_shortcuts(). I don't manage to recover - # the key sequence string from a QAction using Qt. - action._shortcut_string = shortcut or '' + action = self._create_action_bunch(name=name, + alias=alias, + shortcut=shortcut, + callback=callback) # Register the action. if self._gui: - self._gui.addAction(action) + self._gui.addAction(action.qaction) self._actions[name] = action # Log the creation of the action. @@ -182,7 +195,7 @@ def add(self, callback=None, name=None, shortcut=None, alias=None, def get_name(self, alias_or_name): """Return an action name from its alias or name.""" for name, action in self._actions.items(): - if alias_or_name in (action._alias, name): + if alias_or_name in (action.alias, name): return name raise ValueError("Action `{}` doesn't exist.".format(alias_or_name)) @@ -193,15 +206,15 @@ def run(self, action, *args): assert name in self._actions action = self._actions[name] else: - name = action._name + name = action.name if not name.startswith('_'): logger.debug("Execute action `%s`.", name) - return action._callback(*args) + return action.callback(*args) def remove(self, name): """Remove an action.""" if self._gui: - self._gui.removeAction(self._actions[name]) + self._gui.removeAction(self._actions[name].qaction) del self._actions[name] delattr(self, name) @@ -214,7 +227,7 @@ def remove_all(self): @property def shortcuts(self): """A dictionary of action shortcuts.""" - return {name: action._shortcut_string + return {name: action.shortcut for name, action in self._actions.items()} def show_shortcuts(self): From ad2e2bb5e210e808ede0d197442d396b916f8a25 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 9 Oct 2015 19:14:48 +0200 Subject: [PATCH 0283/1059] Add Actions.change_shortcut() --- phy/gui/actions.py | 21 ++++++++++++++++----- phy/gui/tests/test_actions.py | 5 +++++ 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/phy/gui/actions.py b/phy/gui/actions.py index 3b77dc6e4..6f33b091b 100644 --- a/phy/gui/actions.py +++ b/phy/gui/actions.py @@ -87,6 +87,15 @@ def _alias_name(name): return alias, name +def _set_shortcut(action, shortcut): + if not shortcut: + return + if not isinstance(shortcut, (tuple, list)): + shortcut = [shortcut] + for key in shortcut: + action.setShortcut(key) + + class Actions(EventEmitter): """Handle GUI actions. @@ -139,11 +148,7 @@ def _create_action_bunch(self, callback=None, name=None, shortcut=None, action.triggered.connect(callback) action.setCheckable(checkable) action.setChecked(checked) - if shortcut: - if not isinstance(shortcut, (tuple, list)): - shortcut = [shortcut] - for key in shortcut: - action.setShortcut(key) + _set_shortcut(action, shortcut) # HACK: add the shortcut string to the QAction object so that # it can be shown in show_shortcuts(). I don't manage to recover @@ -199,6 +204,12 @@ def get_name(self, alias_or_name): return name raise ValueError("Action `{}` doesn't exist.".format(alias_or_name)) + def change_shortcut(self, name, shortcut): + assert name in self._actions, "This action doesn't exist." + action = self._actions[name] + action.shortcut = shortcut + _set_shortcut(action.qaction, shortcut) + def run(self, action, *args): """Run an action, specified by its name or object.""" if isinstance(action, string_types): diff --git a/phy/gui/tests/test_actions.py b/phy/gui/tests/test_actions.py index db95bfb4f..e3efd9115 100644 --- a/phy/gui/tests/test_actions.py +++ b/phy/gui/tests/test_actions.py @@ -64,6 +64,11 @@ def show_my_shortcuts(): assert 'show_my_shortcuts' in _captured[0] assert ': h' in _captured[0] + actions.change_shortcut('show_my_shortcuts', 'l') + actions.show_my_shortcuts() + assert 'show_my_shortcuts' in _captured[0] + assert ': l' in _captured[-1] + with raises(ValueError): assert actions.get_name('e') assert actions.get_name('t') == 'test' From 96fa459b4762a8e29e38ff7b4dee04b1203a6a67 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 11 Oct 2015 17:57:20 +0200 Subject: [PATCH 0284/1059] Increase coverage in utils --- phy/utils/tests/test_types.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/phy/utils/tests/test_types.py b/phy/utils/tests/test_types.py index 517e7a211..115785e4b 100644 --- a/phy/utils/tests/test_types.py +++ b/phy/utils/tests/test_types.py @@ -7,6 +7,7 @@ #------------------------------------------------------------------------------ import numpy as np +from pytest import raises from .._types import (Bunch, _is_integer, _is_list, _is_float, _as_list, _is_array_like, _as_array, _as_tuple, @@ -74,6 +75,10 @@ def _check(arr): _check(_as_array(3, np.float)) _check(_as_array(3., np.float)) _check(_as_array([3], np.float)) + _check(_as_array(np.array([3]))) + with raises(ValueError): + _check(_as_array(np.array([3]), dtype=np.object)) + _check(_as_array(np.array([3]), np.float)) assert _as_array(None) is None assert not _is_array_like(None) From e3bd88361f5ec74a9eece9f1abae609b52cc529c Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 11 Oct 2015 17:57:56 +0200 Subject: [PATCH 0285/1059] Increase coverage in plugin --- phy/utils/plugin.py | 2 +- phy/utils/tests/test_plugin.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/phy/utils/plugin.py b/phy/utils/plugin.py index f77ce1fa6..c5a915127 100644 --- a/phy/utils/plugin.py +++ b/phy/utils/plugin.py @@ -37,7 +37,7 @@ def __init__(cls, name, bases, attrs): class IPlugin(with_metaclass(IPluginRegistry)): - def attach_to_gui(self, gui, *args, **kwargs): # pragma: no cover + def attach_to_gui(self, gui, *args, **kwargs): pass diff --git a/phy/utils/tests/test_plugin.py b/phy/utils/tests/test_plugin.py index 15efb397f..ec1e08f74 100644 --- a/phy/utils/tests/test_plugin.py +++ b/phy/utils/tests/test_plugin.py @@ -43,6 +43,8 @@ class MyPlugin(IPlugin): with raises(ValueError): get_plugin('unknown') + get_plugin('myplugin')().attach_to_gui(None) + def test_discover_plugins(tempdir, no_native_plugins): path = op.join(tempdir, 'my_plugin.py') From 22d9ebfc461ab7d7f67f8c9662b0a3cda0298ecc Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 11 Oct 2015 18:08:30 +0200 Subject: [PATCH 0286/1059] Modified plugin interface --- phy/cluster/manual/gui_plugin.py | 59 +++++++++++++++++--------------- phy/gui/gui.py | 4 +-- phy/gui/tests/test_gui.py | 9 +++-- phy/utils/plugin.py | 2 +- 4 files changed, 41 insertions(+), 33 deletions(-) diff --git a/phy/cluster/manual/gui_plugin.py b/phy/cluster/manual/gui_plugin.py index 29b189040..760769e7d 100644 --- a/phy/cluster/manual/gui_plugin.py +++ b/phy/cluster/manual/gui_plugin.py @@ -153,12 +153,12 @@ class ManualClustering(IPlugin): save_requested(spike_clusters, cluster_groups) """ - def attach_to_gui(self, gui, - spike_clusters=None, - cluster_groups=None, - n_spikes_max_per_cluster=100, - ): - self.gui = gui + def __init__(self, spike_clusters=None, + cluster_groups=None, + n_spikes_max_per_cluster=100, + ): + + self.n_spikes_max_per_cluster = n_spikes_max_per_cluster # Create Clustering and ClusterMeta. self.clustering = Clustering(spike_clusters) @@ -169,26 +169,12 @@ def attach_to_gui(self, gui, self.wizard = Wizard() _attach_wizard(self.wizard, self.clustering, self.cluster_meta) - @self.wizard.connect - def on_select(cluster_ids): - """When the wizard selects clusters, choose a spikes subset - and emit the `select` event on the GUI. - - The wizard is responsible for the notion of "selected clusters". - - """ - spike_ids = select_spikes(np.array(cluster_ids), - n_spikes_max_per_cluster, - self.clustering.spikes_per_cluster) - gui.emit('select', cluster_ids, spike_ids) - - self.create_actions(gui) + # Create the actions. + self._create_actions() - return self - - def create_actions(self, gui): + def _create_actions(self): self.actions = actions = Actions() - self.snippets = snippets = Snippets() + self.snippets = Snippets() # Create the default actions for the clustering GUI. @actions.connect @@ -212,11 +198,30 @@ def on_reset(): actions.add(callback=self.undo) actions.add(callback=self.redo) - # Attach the GUI and register the actions. - snippets.attach(gui, actions) - actions.attach(gui) actions.reset() + def attach_to_gui(self, gui): + self.gui = gui + + @self.wizard.connect + def on_select(cluster_ids): + """When the wizard selects clusters, choose a spikes subset + and emit the `select` event on the GUI. + + The wizard is responsible for the notion of "selected clusters". + + """ + spike_ids = select_spikes(np.array(cluster_ids), + self.n_spikes_max_per_cluster, + self.clustering.spikes_per_cluster) + gui.emit('select', cluster_ids, spike_ids) + + # Attach the GUI and register the actions. + self.snippets.attach(gui, self.actions) + self.actions.attach(gui) + + return self + # Wizard-related actions # ------------------------------------------------------------------------- diff --git a/phy/gui/gui.py b/phy/gui/gui.py index ff15e0d4b..9431e6b60 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -91,8 +91,8 @@ def attach(self, plugin, *args, **kwargs): """Attach a plugin to the GUI.""" if isinstance(plugin, string_types): # Instantiate the plugin if the name is given. - plugin = get_plugin(plugin)() - return plugin.attach_to_gui(self, *args, **kwargs) + plugin = get_plugin(plugin)(*args, **kwargs) + return plugin.attach_to_gui(self) # Events # ------------------------------------------------------------------------- diff --git a/phy/gui/tests/test_gui.py b/phy/gui/tests/test_gui.py index 25b1ac83e..da85e23de 100644 --- a/phy/gui/tests/test_gui.py +++ b/phy/gui/tests/test_gui.py @@ -157,12 +157,15 @@ def on_close_widget(): def test_gui_plugin(qtbot, gui): class TestPlugin(IPlugin): + def __init__(self, arg): + self._arg = arg + def attach_to_gui(self, gui): - gui._attached = True + gui._attached = self._arg return 'attached' - assert gui.attach('testplugin') == 'attached' - assert gui._attached + assert gui.attach('testplugin', 3) == 'attached' + assert gui._attached == 3 def test_gui_status_message(qtbot): diff --git a/phy/utils/plugin.py b/phy/utils/plugin.py index c5a915127..7e2eb8628 100644 --- a/phy/utils/plugin.py +++ b/phy/utils/plugin.py @@ -37,7 +37,7 @@ def __init__(cls, name, bases, attrs): class IPlugin(with_metaclass(IPluginRegistry)): - def attach_to_gui(self, gui, *args, **kwargs): + def attach_to_gui(self, gui): pass From 31de2bc80d2abc5cfc653c215084259523a95069 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 11 Oct 2015 18:46:18 +0200 Subject: [PATCH 0287/1059] Try to change NumPy version on travis --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 32b798762..7dc2cf958 100644 --- a/.travis.yml +++ b/.travis.yml @@ -21,7 +21,7 @@ install: # Create the environment. - conda create -q -n testenv python=$TRAVIS_PYTHON_VERSION - source activate testenv - - conda install pip numpy vispy matplotlib scipy h5py pyqt ipython requests six dill ipyparallel joblib dask + - conda install pip numpy=1.9 vispy matplotlib scipy h5py pyqt ipython requests six dill ipyparallel joblib dask # NOTE: cython is only temporarily needed for building KK2. # Once we have KK2 binary builds on binstar we can remove this dependency. - conda install cython && pip install klustakwik2 From d46f93d19a7d68307d0b53b78d0704dc1467bdcb Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 15 Oct 2015 10:45:33 +0200 Subject: [PATCH 0288/1059] Rename channel_mapping to site_label_to_traces_row --- phy/traces/spike_detect.py | 20 +++++++++++--------- phy/traces/tests/test_spike_detect.py | 5 +++-- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/phy/traces/spike_detect.py b/phy/traces/spike_detect.py index 0ce1f20db..2929a4d99 100644 --- a/phy/traces/spike_detect.py +++ b/phy/traces/spike_detect.py @@ -95,7 +95,7 @@ class SpikeDetector(Task): extract_s_after = Int(10) weight_power = Float(2) - def set_metadata(self, probe, channel_mapping=None, + def set_metadata(self, probe, site_label_to_traces_row=None, sample_rate=None): assert isinstance(probe, MEA) self.probe = probe @@ -104,24 +104,26 @@ def set_metadata(self, probe, channel_mapping=None, self.sample_rate = sample_rate # Channel mapping. - if channel_mapping is None: - channel_mapping = {c: c for c in probe.channels} + if site_label_to_traces_row is None: + site_label_to_traces_row = {c: c for c in probe.channels} # Remove channels mapped to None or a negative value: they are dead. - channel_mapping = {k: v for (k, v) in channel_mapping.items() - if v is not None and v >= 0} + site_label_to_traces_row = {k: v for (k, v) in + site_label_to_traces_row.items() + if v is not None and v >= 0} # channel mappings is {trace_col: channel_id}. # Trace columns and channel ids to keep. - self.trace_cols = sorted(channel_mapping.keys()) - self.channel_ids = sorted(channel_mapping.values()) + self.trace_cols = sorted(site_label_to_traces_row.keys()) + self.channel_ids = sorted(site_label_to_traces_row.values()) # The key is the col in traces, the val is the channel id. adj = self.probe.adjacency # Numbers are all channel ids. # First, we subset the adjacency list with the kept channel ids. adj = _adjacency_subset(adj, self.channel_ids) # Then, we remap to convert from channel ids to trace columns. # We need to inverse the mapping. - channel_mapping_inv = {v: c for (c, v) in channel_mapping.items()} + site_label_to_traces_row_inv = {v: c for (c, v) in + site_label_to_traces_row.items()} # Now, the adjacency list contains trace column numbers. - adj = _remap_adjacency(adj, channel_mapping_inv) + adj = _remap_adjacency(adj, site_label_to_traces_row_inv) assert set(adj) <= set(self.trace_cols) # Finally, we need to remap with relative column indices. rel_mapping = {c: i for (i, c) in enumerate(self.trace_cols)} diff --git a/phy/traces/tests/test_spike_detect.py b/phy/traces/tests/test_spike_detect.py index 1b9ec92fe..1424fe21b 100644 --- a/phy/traces/tests/test_spike_detect.py +++ b/phy/traces/tests/test_spike_detect.py @@ -40,13 +40,14 @@ def spike_detector(request): remap = request.param[0] probe = load_probe('1x32_buzsaki') - channel_mapping = {i: i for i in range(1, 21, 2)} if remap else None + site_label_to_traces_row = ({i: i for i in range(1, 21, 2)} + if remap else None) sd = SpikeDetector() sd.use_single_threshold = False sample_rate = 20000 sd.set_metadata(probe, - channel_mapping=channel_mapping, + site_label_to_traces_row=site_label_to_traces_row, sample_rate=sample_rate) yield sd From 48aa73912107813fed340935f81acf83ddd37d3b Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 15 Oct 2015 11:42:21 +0200 Subject: [PATCH 0289/1059] Add some cluster stats --- phy/stats/clusters.py | 69 +++++++++++++++++ phy/stats/tests/test_clusters.py | 124 +++++++++++++++++++++++++++++++ 2 files changed, 193 insertions(+) create mode 100644 phy/stats/clusters.py create mode 100644 phy/stats/tests/test_clusters.py diff --git a/phy/stats/clusters.py b/phy/stats/clusters.py new file mode 100644 index 000000000..2f731bb6c --- /dev/null +++ b/phy/stats/clusters.py @@ -0,0 +1,69 @@ +# -*- coding: utf-8 -*- + +"""Cluster statistics.""" + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +import numpy as np + + +#------------------------------------------------------------------------------ +# Cluster statistics +#------------------------------------------------------------------------------ + +def mean(x): + return x.mean(axis=0) + + +def unmasked_channels(mean_masks, min_mask=.1): + return np.nonzero(mean_masks > min_mask)[0] + + +def mean_probe_position(mean_masks, site_positions): + return (np.sum(site_positions * mean_masks[:, np.newaxis], axis=0) / + max(1, np.sum(mean_masks))) + + +def sorted_main_channels(mean_masks, unmasked_channels): + # Weighted mean of the channels, weighted by the mean masks. + main_channels = np.argsort(mean_masks)[::-1] + main_channels = np.array([c for c in main_channels + if c in unmasked_channels]) + return main_channels + + +#------------------------------------------------------------------------------ +# Wizard measures +#------------------------------------------------------------------------------ + +def mean_masked_features_distance(mean_features_0, + mean_features_1, + mean_masks_0, + mean_masks_1, + n_features_per_channel=None, + ): + """Compute the distance between the mean masked features.""" + + assert n_features_per_channel > 0 + + mu_0 = mean_features_0.ravel() + mu_1 = mean_features_1.ravel() + + omeg_0 = mean_masks_0 + omeg_1 = mean_masks_1 + + omeg_0 = np.repeat(omeg_0, n_features_per_channel) + omeg_1 = np.repeat(omeg_1, n_features_per_channel) + + d_0 = mu_0 * omeg_0 + d_1 = mu_1 * omeg_1 + + return np.linalg.norm(d_0 - d_1) + + +def max_mean_masks(mean_masks): + """Return the maximum mean_masks across all channels + for a given cluster.""" + return mean_masks.max() diff --git a/phy/stats/tests/test_clusters.py b/phy/stats/tests/test_clusters.py new file mode 100644 index 000000000..9de8a6eab --- /dev/null +++ b/phy/stats/tests/test_clusters.py @@ -0,0 +1,124 @@ +# -*- coding: utf-8 -*- + +"""Tests of cluster statistics.""" + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +import numpy as np +from numpy.testing import assert_array_equal as ae +from numpy.testing import assert_allclose as ac +from pytest import yield_fixture + +from ..clusters import (mean, + unmasked_channels, + mean_probe_position, + sorted_main_channels, + mean_masked_features_distance, + max_mean_masks, + ) +from phy.electrode.mea import staggered_positions +from phy.io.mock import (artificial_features, + artificial_masks, + ) + + +#------------------------------------------------------------------------------ +# Fixtures +#------------------------------------------------------------------------------ + +@yield_fixture +def n_channels(): + yield 28 + + +@yield_fixture +def n_spikes(): + yield 50 + + +@yield_fixture +def n_features_per_channel(): + yield 4 + + +@yield_fixture +def features(n_spikes, n_channels, n_features_per_channel): + yield artificial_features(n_spikes, n_channels, n_features_per_channel) + + +@yield_fixture +def masks(n_spikes, n_channels): + yield artificial_masks(n_spikes, n_channels) + + +@yield_fixture +def site_positions(n_channels): + yield staggered_positions(n_channels) + + +#------------------------------------------------------------------------------ +# Tests +#------------------------------------------------------------------------------ + +def test_mean(features, n_channels, n_features_per_channel): + mf = mean(features) + assert mf.shape == (n_channels, n_features_per_channel) + ae(mf, features.mean(axis=0)) + + +def test_unmasked_channels(masks, n_channels): + # Mask many values in the masks array. + threshold = .05 + masks[:, 1::2] *= threshold + # Compute the mean masks. + mean_masks = mean(masks) + # Find the unmasked channels. + channels = unmasked_channels(mean_masks, threshold) + # These are 0, 2, 4, etc. + ae(channels, np.arange(0, n_channels, 2)) + + +def test_mean_probe_position(masks, site_positions): + masks[:, ::2] *= .05 + mean_masks = mean(masks) + mean_pos = mean_probe_position(mean_masks, site_positions) + assert mean_pos.shape == (2,) + assert mean_pos[0] < 0 + assert mean_pos[1] > 0 + + +def test_sorted_main_channels(masks): + masks *= .05 + masks[:, [5, 7]] *= 20 + mean_masks = mean(masks) + channels = sorted_main_channels(mean_masks, unmasked_channels(mean_masks)) + assert np.all(np.in1d(channels, [5, 7])) + + +def test_mean_masked_features_distance(features, + n_channels, + n_features_per_channel, + ): + + # Shifted feature vectors. + shift = 10. + f0 = mean(features) + f1 = mean(features) + shift + + # Only one channel is unmasked. + m0 = m1 = np.zeros(n_channels) + m0[n_channels // 2] = 1 + + # Check the distance. + d_expected = np.sqrt(n_features_per_channel) * shift + d_computed = mean_masked_features_distance(f0, f1, m0, m1, + n_features_per_channel) + ac(d_expected, d_computed) + + +def test_max_mean_masks(masks): + mean_masks = mean(masks) + mmm = max_mean_masks(mean_masks) + assert mmm > .4 From 4784959c1942f7aa2024e14b5762b23fe7bebc38 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 15 Oct 2015 11:49:22 +0200 Subject: [PATCH 0290/1059] Fix setup --- Makefile | 8 +------- setup.cfg | 3 ++- setup.py | 26 -------------------------- 3 files changed, 3 insertions(+), 34 deletions(-) diff --git a/Makefile b/Makefile index 5059b040c..57f9de57a 100644 --- a/Makefile +++ b/Makefile @@ -15,17 +15,11 @@ lint: flake8 phy test: lint - python setup.py test + py.test coverage: coverage --html -unit-tests: lint - python setup.py test -a phy - -integration-tests: lint - python setup.py test -a tests - apidoc: python tools/api.py diff --git a/setup.cfg b/setup.cfg index a0583bf00..4c0a80757 100644 --- a/setup.cfg +++ b/setup.cfg @@ -2,7 +2,8 @@ universal = 1 [pytest] -norecursedirs = plot experimental +addopts = --cov-report term-missing --cov=phy phy +norecursedirs = plot experimental _* [flake8] ignore=E265 diff --git a/setup.py b/setup.py index 28b2b3bd9..588e2950f 100644 --- a/setup.py +++ b/setup.py @@ -10,39 +10,15 @@ import os import os.path as op -import sys import re from setuptools import setup -from setuptools.command.test import test as TestCommand #------------------------------------------------------------------------------ # Setup #------------------------------------------------------------------------------ -class PyTest(TestCommand): - user_options = [('pytest-args=', 'a', - "String of arguments to pass to py.test")] - - def initialize_options(self): - TestCommand.initialize_options(self) - self.pytest_args = '--cov-report term-missing --cov=phy phy' - - def finalize_options(self): - TestCommand.finalize_options(self) - self.test_args = [] - self.test_suite = True - - def run_tests(self): - #import here, cause outside the eggs aren't loaded - import pytest - pytest_string = self.pytest_args - print("Running: py.test " + pytest_string) - errno = pytest.main(pytest_string) - sys.exit(errno) - - def _package_tree(pkgroot): path = op.dirname(__file__) subdirs = [op.relpath(i[0], path).replace(op.sep, '.') @@ -81,7 +57,6 @@ def _package_tree(pkgroot): ], }, include_package_data=True, - # zip_safe=False, keywords='phy,data analysis,electrophysiology,neuroscience', classifiers=[ 'Development Status :: 4 - Beta', @@ -94,5 +69,4 @@ def _package_tree(pkgroot): 'Programming Language :: Python :: 3', 'Programming Language :: Python :: 3.4', ], - cmdclass={'test': PyTest}, ) From 221948f9a9621bb961255316351bfd5f58521a25 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 15 Oct 2015 12:05:58 +0200 Subject: [PATCH 0291/1059] Move py.test options to Makefile --- Makefile | 2 +- setup.cfg | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/Makefile b/Makefile index 57f9de57a..1fa8cea14 100644 --- a/Makefile +++ b/Makefile @@ -15,7 +15,7 @@ lint: flake8 phy test: lint - py.test + py.test --cov-report term-missing --cov=phy phy coverage: coverage --html diff --git a/setup.cfg b/setup.cfg index 4c0a80757..c138ac336 100644 --- a/setup.cfg +++ b/setup.cfg @@ -2,7 +2,6 @@ universal = 1 [pytest] -addopts = --cov-report term-missing --cov=phy phy norecursedirs = plot experimental _* [flake8] From cbec3e17c29bc1b0c8a91a7d3df22cde55c5bdc9 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 15 Oct 2015 12:06:10 +0200 Subject: [PATCH 0292/1059] Add waveform-based cluster quality measure --- phy/stats/clusters.py | 21 ++++++++++++++------ phy/stats/tests/test_clusters.py | 33 +++++++++++++++++++++++++------- 2 files changed, 41 insertions(+), 13 deletions(-) diff --git a/phy/stats/clusters.py b/phy/stats/clusters.py index 2f731bb6c..9df52acea 100644 --- a/phy/stats/clusters.py +++ b/phy/stats/clusters.py @@ -38,6 +38,21 @@ def sorted_main_channels(mean_masks, unmasked_channels): # Wizard measures #------------------------------------------------------------------------------ +def max_waveform_amplitude(mean_masks, mean_waveforms): + + assert mean_waveforms.ndim == 2 + n_samples, n_channels = mean_waveforms.shape + + assert mean_masks.ndim == 1 + assert mean_masks.shape == (n_channels,) + + mean_waveforms = mean_masks * mean_waveforms + + # Amplitudes. + m, M = mean_waveforms.min(axis=1), mean_waveforms.max(axis=1) + return np.max(M - m) + + def mean_masked_features_distance(mean_features_0, mean_features_1, mean_masks_0, @@ -61,9 +76,3 @@ def mean_masked_features_distance(mean_features_0, d_1 = mu_1 * omeg_1 return np.linalg.norm(d_0 - d_1) - - -def max_mean_masks(mean_masks): - """Return the maximum mean_masks across all channels - for a given cluster.""" - return mean_masks.max() diff --git a/phy/stats/tests/test_clusters.py b/phy/stats/tests/test_clusters.py index 9de8a6eab..f9de58370 100644 --- a/phy/stats/tests/test_clusters.py +++ b/phy/stats/tests/test_clusters.py @@ -16,11 +16,12 @@ mean_probe_position, sorted_main_channels, mean_masked_features_distance, - max_mean_masks, + max_waveform_amplitude, ) from phy.electrode.mea import staggered_positions from phy.io.mock import (artificial_features, artificial_masks, + artificial_waveforms, ) @@ -38,6 +39,11 @@ def n_spikes(): yield 50 +@yield_fixture +def n_samples(): + yield 40 + + @yield_fixture def n_features_per_channel(): yield 4 @@ -53,6 +59,11 @@ def masks(n_spikes, n_channels): yield artificial_masks(n_spikes, n_channels) +@yield_fixture +def waveforms(n_spikes, n_samples, n_channels): + yield artificial_waveforms(n_spikes, n_samples, n_channels) + + @yield_fixture def site_positions(n_channels): yield staggered_positions(n_channels) @@ -97,6 +108,20 @@ def test_sorted_main_channels(masks): assert np.all(np.in1d(channels, [5, 7])) +def test_max_waveform_amplitude(masks, waveforms): + waveforms *= .1 + masks *= .1 + + waveforms[:, 10, :] *= 10 + masks[:, 10] *= 10 + + mean_waveforms = mean(waveforms) + mean_masks = mean(masks) + + amplitude = max_waveform_amplitude(mean_masks, mean_waveforms) + assert amplitude > 0 + + def test_mean_masked_features_distance(features, n_channels, n_features_per_channel, @@ -116,9 +141,3 @@ def test_mean_masked_features_distance(features, d_computed = mean_masked_features_distance(f0, f1, m0, m1, n_features_per_channel) ac(d_expected, d_computed) - - -def test_max_mean_masks(masks): - mean_masks = mean(masks) - mmm = max_mean_masks(mean_masks) - assert mmm > .4 From 5b0bde41eac7b3c7d14ce765c230d321d452a797 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 15 Oct 2015 12:11:33 +0200 Subject: [PATCH 0293/1059] Increase coverage in gui_plugin --- phy/cluster/manual/tests/test_gui_plugin.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/phy/cluster/manual/tests/test_gui_plugin.py b/phy/cluster/manual/tests/test_gui_plugin.py index ddd1e24d0..6957da531 100644 --- a/phy/cluster/manual/tests/test_gui_plugin.py +++ b/phy/cluster/manual/tests/test_gui_plugin.py @@ -260,15 +260,17 @@ def test_manual_clustering_split(manual_clustering): assert_selection(31, 20) -def test_manual_clustering_move(manual_clustering): +def test_manual_clustering_move(manual_clustering, quality, similarity): mc, assert_selection = manual_clustering mc.actions.select([30]) assert_selection(30) - # TODO: set quality and similarity functions - # mc.actions.next_by_quality() - # assert_selection(20) + mc.wizard.set_quality_function(quality) + mc.wizard.set_similarity_function(similarity) - # mc.actions.move([20], 'noise') - # assert_selection(2) + mc.actions.next_by_quality() + assert_selection(20) + + mc.actions.move([20], 'noise') + assert_selection(2) From a9b6ab59ad7a8c85ed05cf8ed71c5a105d2d03fd Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 15 Oct 2015 13:49:54 +0200 Subject: [PATCH 0294/1059] Add empty phy CLI tool --- phy/cli.py | 23 +++++++++++++++++++++++ phy/tests/__init__.py | 0 phy/tests/test_cli.py | 32 ++++++++++++++++++++++++++++++++ phy/utils/plugin.py | 3 +++ phy/utils/tests/test_plugin.py | 1 + setup.py | 1 + 6 files changed, 60 insertions(+) create mode 100644 phy/cli.py create mode 100644 phy/tests/__init__.py create mode 100644 phy/tests/test_cli.py diff --git a/phy/cli.py b/phy/cli.py new file mode 100644 index 000000000..0d67496f4 --- /dev/null +++ b/phy/cli.py @@ -0,0 +1,23 @@ +# -*- coding: utf-8 -*- +# flake8: noqa + +"""CLI tool.""" + + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +import click +import phy + + +#------------------------------------------------------------------------------ +# CLI tool +#------------------------------------------------------------------------------ + +@click.command() +@click.version_option(version=phy.__version_git__) +@click.help_option() +def phy(): + return 0 diff --git a/phy/tests/__init__.py b/phy/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/phy/tests/test_cli.py b/phy/tests/test_cli.py new file mode 100644 index 000000000..72d81b04a --- /dev/null +++ b/phy/tests/test_cli.py @@ -0,0 +1,32 @@ +# -*- coding: utf-8 -*- +# flake8: noqa + +"""Test CLI tool.""" + + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +from click.testing import CliRunner + +from ..cli import phy + + +#------------------------------------------------------------------------------ +# Test CLI tool +#------------------------------------------------------------------------------ + +def test_cli(): + runner = CliRunner() + + result = runner.invoke(phy, []) + assert result.exit_code == 0 + + result = runner.invoke(phy, ['--version']) + assert result.exit_code == 0 + assert result.output.startswith('phy,') + + result = runner.invoke(phy, ['--help']) + assert result.exit_code == 0 + assert result.output.startswith('Usage: phy') diff --git a/phy/utils/plugin.py b/phy/utils/plugin.py index 7e2eb8628..71c3acc02 100644 --- a/phy/utils/plugin.py +++ b/phy/utils/plugin.py @@ -40,6 +40,9 @@ class IPlugin(with_metaclass(IPluginRegistry)): def attach_to_gui(self, gui): pass + def attach_to_cli(self, cli): + pass + def get_plugin(name): """Get a plugin class from its name.""" diff --git a/phy/utils/tests/test_plugin.py b/phy/utils/tests/test_plugin.py index ec1e08f74..6e0e4954b 100644 --- a/phy/utils/tests/test_plugin.py +++ b/phy/utils/tests/test_plugin.py @@ -43,6 +43,7 @@ class MyPlugin(IPlugin): with raises(ValueError): get_plugin('unknown') + get_plugin('myplugin')().attach_to_cli(None) get_plugin('myplugin')().attach_to_gui(None) diff --git a/setup.py b/setup.py index 588e2950f..22bdbc7d5 100644 --- a/setup.py +++ b/setup.py @@ -54,6 +54,7 @@ def _package_tree(pkgroot): }, entry_points={ 'console_scripts': [ + 'phy = phy.cli:phy' ], }, include_package_data=True, From 7cad552d5fa96878626e3809d8b573a835055565 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 15 Oct 2015 14:56:57 +0200 Subject: [PATCH 0295/1059] Add traitlets config utilities --- phy/utils/_misc.py | 7 +++++++ phy/utils/tests/test_misc.py | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/phy/utils/_misc.py b/phy/utils/_misc.py index a2c4f075c..10354e9a2 100644 --- a/phy/utils/_misc.py +++ b/phy/utils/_misc.py @@ -14,6 +14,7 @@ import sys import subprocess +from traitlets.config import PyFileConfigLoader import numpy as np from six import string_types, exec_ from six.moves import builtins @@ -117,6 +118,12 @@ def _read_python(path): return metadata +def _load_config(path): + dirpath, filename = op.split(path) + config = PyFileConfigLoader(filename, dirpath).load_config() + return config + + def _is_interactive(): # pragma: no cover """Determine whether the user has requested interactive mode.""" # The Python interpreter sets sys.flags correctly, so use them! diff --git a/phy/utils/tests/test_misc.py b/phy/utils/tests/test_misc.py index fdde63296..71c89f66e 100644 --- a/phy/utils/tests/test_misc.py +++ b/phy/utils/tests/test_misc.py @@ -9,13 +9,18 @@ import os import os.path as op import subprocess +from textwrap import dedent import numpy as np from numpy.testing import assert_array_equal as ae from pytest import raises from six import string_types +from traitlets import Float +from traitlets.config import Configurable + from .._misc import (_git_version, _load_json, _save_json, + _load_config, _encode_qbytearray, _decode_qbytearray, ) @@ -77,6 +82,36 @@ def test_json_numpy(tempdir): assert d['b'] == d_bis['b'] +def test_load_config(tempdir): + path = op.join(tempdir, 'config.py') + + class MyConfigurable(Configurable): + my_var = Float(0.0, config=True) + + assert MyConfigurable().my_var == 0.0 + + # Create and load a config file. + config_contents = dedent(""" + c = get_config() + + c.MyConfigurable.my_var = 1.0 + """) + + with open(path, 'w') as f: + f.write(config_contents) + + c = _load_config(path) + assert c.MyConfigurable.my_var == 1.0 + + # Create a new MyConfigurable instance. + configurable = MyConfigurable() + assert configurable.my_var == 0.0 + + # Load the config object. + configurable.update_config(c) + assert configurable.my_var == 1.0 + + def test_git_version(): v = _git_version() From 4f4599d8ba590a41ee06e43fc4d0064b28b0423d Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 15 Oct 2015 15:21:03 +0200 Subject: [PATCH 0296/1059] Load master config --- phy/utils/_misc.py | 30 +++++++++++---- phy/utils/tests/test_misc.py | 75 +++++++++++++++++++++++++++--------- 2 files changed, 80 insertions(+), 25 deletions(-) diff --git a/phy/utils/_misc.py b/phy/utils/_misc.py index 10354e9a2..4727b7945 100644 --- a/phy/utils/_misc.py +++ b/phy/utils/_misc.py @@ -14,13 +14,15 @@ import sys import subprocess -from traitlets.config import PyFileConfigLoader +from traitlets.config import Config, PyFileConfigLoader import numpy as np from six import string_types, exec_ from six.moves import builtins from ._types import _is_integer +PHY_USER_DIR = op.expanduser('~/.phy/') + #------------------------------------------------------------------------------ # JSON utility functions @@ -103,6 +105,26 @@ def _save_json(path, data): json.dump(data, f, cls=_CustomEncoder, indent=2) +#------------------------------------------------------------------------------ +# traitlets config +#------------------------------------------------------------------------------ + +def _load_config(path): + path = op.realpath(path) + dirpath, filename = op.split(path) + config = PyFileConfigLoader(filename, dirpath).load_config() + return config + + +def load_master_config(): + """Load a master Config file from `~/.phy/phy_config.py`.""" + c = Config() + paths = [op.join(PHY_USER_DIR, 'phy_config.py')] + for path in paths: + c.update(_load_config(path)) + return c + + #------------------------------------------------------------------------------ # Various Python utility functions #------------------------------------------------------------------------------ @@ -118,12 +140,6 @@ def _read_python(path): return metadata -def _load_config(path): - dirpath, filename = op.split(path) - config = PyFileConfigLoader(filename, dirpath).load_config() - return config - - def _is_interactive(): # pragma: no cover """Determine whether the user has requested interactive mode.""" # The Python interpreter sets sys.flags correctly, so use them! diff --git a/phy/utils/tests/test_misc.py b/phy/utils/tests/test_misc.py index 71c89f66e..aa6ec0680 100644 --- a/phy/utils/tests/test_misc.py +++ b/phy/utils/tests/test_misc.py @@ -13,20 +13,32 @@ import numpy as np from numpy.testing import assert_array_equal as ae -from pytest import raises +from pytest import raises, yield_fixture from six import string_types from traitlets import Float from traitlets.config import Configurable - -from .._misc import (_git_version, _load_json, _save_json, - _load_config, +from .._misc import (_git_version, _load_json, _save_json, _read_python, + _load_config, load_master_config, _encode_qbytearray, _decode_qbytearray, ) +from .. import _misc + + +#------------------------------------------------------------------------------ +# Fixtures +#------------------------------------------------------------------------------ + +@yield_fixture +def temp_user_dir(tempdir): + user_dir = _misc.PHY_USER_DIR + _misc.PHY_USER_DIR = tempdir + yield tempdir + _misc.PHY_USER_DIR = user_dir #------------------------------------------------------------------------------ -# Tests +# Misc tests #------------------------------------------------------------------------------ def test_qbytearray(tempdir): @@ -82,8 +94,34 @@ def test_json_numpy(tempdir): assert d['b'] == d_bis['b'] +def test_read_python(tempdir): + path = op.join(tempdir, 'mock.py') + with open(path, 'w') as f: + f.write("""a = {'b': 1}""") + + assert _read_python(path) == {'a': {'b': 1}} + + +def test_git_version(): + v = _git_version() + + # If this test file is tracked by git, then _git_version() should succeed + filedir, _ = op.split(__file__) + try: + fnull = open(os.devnull, 'w') + subprocess.check_output(['git', '-C', filedir, 'status'], + stderr=fnull) + assert v is not "", "git_version failed to return" + assert v[:5] == "-git-", "Git version does not begin in -git-" + except (OSError, subprocess.CalledProcessError): # pragma: no cover + assert v == "" + + +#------------------------------------------------------------------------------ +# Config tests +#------------------------------------------------------------------------------ + def test_load_config(tempdir): - path = op.join(tempdir, 'config.py') class MyConfigurable(Configurable): my_var = Float(0.0, config=True) @@ -97,6 +135,7 @@ class MyConfigurable(Configurable): c.MyConfigurable.my_var = 1.0 """) + path = op.join(tempdir, 'config.py') with open(path, 'w') as f: f.write(config_contents) @@ -112,16 +151,16 @@ class MyConfigurable(Configurable): assert configurable.my_var == 1.0 -def test_git_version(): - v = _git_version() +def test_load_master_config(temp_user_dir): + # Create a config file in the temporary user directory. + config_contents = dedent(""" + c = get_config() - # If this test file is tracked by git, then _git_version() should succeed - filedir, _ = op.split(__file__) - try: - fnull = open(os.devnull, 'w') - subprocess.check_output(['git', '-C', filedir, 'status'], - stderr=fnull) - assert v is not "", "git_version failed to return" - assert v[:5] == "-git-", "Git version does not begin in -git-" - except (OSError, subprocess.CalledProcessError): # pragma: no cover - assert v == "" + c.MyConfigurable.my_var = 1.0 + """) + with open(op.join(temp_user_dir, 'phy_config.py'), 'w') as f: + f.write(config_contents) + + # Load the master config file. + c = load_master_config() + assert c.MyConfigurable.my_var == 1. From 7aa32e2b5299beb558f71f763f2f234b437a690c Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 15 Oct 2015 15:24:42 +0200 Subject: [PATCH 0297/1059] Move phy user dir fixture to conftest --- phy/utils/tests/conftest.py | 23 +++++++++++++++++++++++ phy/utils/tests/test_misc.py | 13 ------------- 2 files changed, 23 insertions(+), 13 deletions(-) create mode 100644 phy/utils/tests/conftest.py diff --git a/phy/utils/tests/conftest.py b/phy/utils/tests/conftest.py new file mode 100644 index 000000000..897c047dc --- /dev/null +++ b/phy/utils/tests/conftest.py @@ -0,0 +1,23 @@ +# -*- coding: utf-8 -*- + +"""py.test fixtures.""" + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +from pytest import yield_fixture + +from phy.utils import _misc + + +#------------------------------------------------------------------------------ +# Common fixtures +#------------------------------------------------------------------------------ + +@yield_fixture +def temp_user_dir(tempdir): + user_dir = _misc.PHY_USER_DIR + _misc.PHY_USER_DIR = tempdir + yield tempdir + _misc.PHY_USER_DIR = user_dir diff --git a/phy/utils/tests/test_misc.py b/phy/utils/tests/test_misc.py index aa6ec0680..d377dd958 100644 --- a/phy/utils/tests/test_misc.py +++ b/phy/utils/tests/test_misc.py @@ -22,19 +22,6 @@ _load_config, load_master_config, _encode_qbytearray, _decode_qbytearray, ) -from .. import _misc - - -#------------------------------------------------------------------------------ -# Fixtures -#------------------------------------------------------------------------------ - -@yield_fixture -def temp_user_dir(tempdir): - user_dir = _misc.PHY_USER_DIR - _misc.PHY_USER_DIR = tempdir - yield tempdir - _misc.PHY_USER_DIR = user_dir #------------------------------------------------------------------------------ From fc7e3939f26727266798f5c6fa19be0b91a6032d Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 15 Oct 2015 15:44:03 +0200 Subject: [PATCH 0298/1059] WIP: plugins and config --- phy/utils/_misc.py | 2 ++ phy/{ => utils}/cli.py | 0 phy/utils/plugin.py | 44 +++++++++++++++++++++++-- phy/{ => utils}/tests/test_cli.py | 0 phy/utils/tests/test_misc.py | 1 - phy/utils/tests/test_plugin.py | 53 +++++++++++++++++++++++++++++-- 6 files changed, 94 insertions(+), 6 deletions(-) rename phy/{ => utils}/cli.py (100%) rename phy/{ => utils}/tests/test_cli.py (100%) diff --git a/phy/utils/_misc.py b/phy/utils/_misc.py index 4727b7945..c30ff859d 100644 --- a/phy/utils/_misc.py +++ b/phy/utils/_misc.py @@ -110,6 +110,8 @@ def _save_json(path, data): #------------------------------------------------------------------------------ def _load_config(path): + if not op.exists(path): + return {} path = op.realpath(path) dirpath, filename = op.split(path) config = PyFileConfigLoader(filename, dirpath).load_config() diff --git a/phy/cli.py b/phy/utils/cli.py similarity index 100% rename from phy/cli.py rename to phy/utils/cli.py diff --git a/phy/utils/plugin.py b/phy/utils/plugin.py index 71c3acc02..28f88ddfe 100644 --- a/phy/utils/plugin.py +++ b/phy/utils/plugin.py @@ -17,6 +17,10 @@ import os.path as op from six import with_metaclass +from traitlets import List, Unicode +from traitlets.config import Configurable + +from ._misc import load_master_config, PHY_USER_DIR logger = logging.getLogger(__name__) @@ -92,6 +96,42 @@ def discover_plugins(dirs): file, path, descr = imp.find_module(modname, [subdir]) if file: # Loading the module registers the plugin in - # IPluginRegistry - mod = imp.load_module(modname, file, path, descr) # noqa + # IPluginRegistry. + try: + mod = imp.load_module(modname, file, + path, descr) # noqa + except Exception as e: + logger.exception(e) return IPluginRegistry.plugins + + +class Plugins(Configurable): + """Configure the list of user plugin directories. + + By default, there is only `~/.phy/plugins/`. + + """ + dirs = List(Unicode, + default_value=[op.expanduser(op.join(PHY_USER_DIR, + 'plugins/'))], + config=True, + ) + + +def get_all_plugins(): + """Load all builtin and user plugins.""" + + # Builtin plugins. + builtin_plugins_dir = [op.realpath(op.join(op.dirname(__file__), + '../plugins/'))] + + # Load the plugin dirs from all config files. + plugins_config = Plugins() + c = load_master_config() + plugins_config.update_config(c) + + # Add the builtin dirs. + dirs = builtin_plugins_dir + plugins_config.dirs + + # Return all loaded plugins. + return [plugin for (plugin,) in discover_plugins(dirs)] diff --git a/phy/tests/test_cli.py b/phy/utils/tests/test_cli.py similarity index 100% rename from phy/tests/test_cli.py rename to phy/utils/tests/test_cli.py diff --git a/phy/utils/tests/test_misc.py b/phy/utils/tests/test_misc.py index d377dd958..7ba965623 100644 --- a/phy/utils/tests/test_misc.py +++ b/phy/utils/tests/test_misc.py @@ -142,7 +142,6 @@ def test_load_master_config(temp_user_dir): # Create a config file in the temporary user directory. config_contents = dedent(""" c = get_config() - c.MyConfigurable.my_var = 1.0 """) with open(op.join(temp_user_dir, 'phy_config.py'), 'w') as f: diff --git a/phy/utils/tests/test_plugin.py b/phy/utils/tests/test_plugin.py index 6e0e4954b..da4e0d9b8 100644 --- a/phy/utils/tests/test_plugin.py +++ b/phy/utils/tests/test_plugin.py @@ -7,14 +7,21 @@ # Imports #------------------------------------------------------------------------------ +import os import os.path as op +from textwrap import dedent -from ..plugin import (IPluginRegistry, IPlugin, get_plugin, +from traitlets import List, Unicode +from pytest import yield_fixture, raises + +from ..plugin import (IPluginRegistry, + IPlugin, + Plugins, + get_plugin, discover_plugins, + get_all_plugins, ) -from pytest import yield_fixture, raises - #------------------------------------------------------------------------------ # Fixtures @@ -56,3 +63,43 @@ def test_discover_plugins(tempdir, no_native_plugins): plugins = discover_plugins([tempdir]) assert plugins assert plugins[0][0].__name__ == 'MyPlugin' + + +def test_get_all_plugins(temp_user_dir): + + n_builtin_plugins = 0 + + plugins = get_all_plugins() + assert len(plugins) == n_builtin_plugins + + plugin_contents = dedent(""" + from phy import IPlugin + class MyPlugin(IPlugin): + pass + """) + + # Create a plugin in some directory. + os.mkdir(op.join(temp_user_dir, 'myplugins/')) + with open(op.join(temp_user_dir, 'myplugins/myplugin.py'), 'w') as f: + f.write(plugin_contents) + + # By default, this directory has no reason to be scanned, and the + # plugin is not loaded. + plugins = get_all_plugins() + assert len(plugins) == n_builtin_plugins + + # Specify the path to the plugin in the phy config file.. + config_contents = dedent(""" + c = get_config() + c.Plugins.dirs = ['%s'] + """ % op.join(temp_user_dir, 'myplugins/')) + with open(op.join(temp_user_dir, 'phy_config.py'), 'w') as f: + f.write(config_contents) + + # Now, reload all plugins. + plugins = get_all_plugins() + + # This time, the plugin will be found. + assert len(plugins) == n_builtin_plugins + 1 + p = plugins[-1] + assert p.__name__ == 'MyPlugin' From 15b12f2bbeb50b2c21c798c8c1aeb318ed261969 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 15 Oct 2015 15:44:19 +0200 Subject: [PATCH 0299/1059] WIP --- phy/tests/__init__.py | 0 phy/utils/cli.py | 19 +++++++++++++++++++ 2 files changed, 19 insertions(+) delete mode 100644 phy/tests/__init__.py diff --git a/phy/tests/__init__.py b/phy/tests/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/phy/utils/cli.py b/phy/utils/cli.py index 0d67496f4..287dac90b 100644 --- a/phy/utils/cli.py +++ b/phy/utils/cli.py @@ -9,7 +9,9 @@ #------------------------------------------------------------------------------ import click + import phy +from phy.plugins import get_all_plugins #------------------------------------------------------------------------------ @@ -21,3 +23,20 @@ @click.help_option() def phy(): return 0 + + +#------------------------------------------------------------------------------ +# CLI plugins +#------------------------------------------------------------------------------ + +def load_cli_plugins(cli): + """Load all plugins and attach them to a CLI object.""" + plugins = get_all_plugins() + # TODO: try/except to avoid crashing if a plugin is broken. + for plugin in plugins: + # NOTE: plugin is a class, so we need to instantiate it. + plugin().attach_to_cli(cli) + + +# Load all plugins for the phy CLI. +load_cli_plugins(phy) From c907792cd0a94998ac85d27da18c99558ab4ffb4 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 15 Oct 2015 17:01:27 +0200 Subject: [PATCH 0300/1059] Add _write_text function --- phy/utils/_misc.py | 20 +++++++++++++++++--- phy/utils/tests/test_misc.py | 12 +++++++++++- 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/phy/utils/_misc.py b/phy/utils/_misc.py index c30ff859d..2bcc1d907 100644 --- a/phy/utils/_misc.py +++ b/phy/utils/_misc.py @@ -13,6 +13,7 @@ import os import sys import subprocess +from textwrap import dedent from traitlets.config import Config, PyFileConfigLoader import numpy as np @@ -21,8 +22,6 @@ from ._types import _is_integer -PHY_USER_DIR = op.expanduser('~/.phy/') - #------------------------------------------------------------------------------ # JSON utility functions @@ -121,7 +120,7 @@ def _load_config(path): def load_master_config(): """Load a master Config file from `~/.phy/phy_config.py`.""" c = Config() - paths = [op.join(PHY_USER_DIR, 'phy_config.py')] + paths = [op.join(phy_user_dir(), 'phy_config.py')] for path in paths: c.update(_load_config(path)) return c @@ -131,6 +130,10 @@ def load_master_config(): # Various Python utility functions #------------------------------------------------------------------------------ +def phy_user_dir(): + return op.expanduser('~/.phy/') + + def _read_python(path): path = op.realpath(op.expanduser(path)) assert op.exists(path) @@ -142,6 +145,17 @@ def _read_python(path): return metadata +def _write_text(path, contents, *args, **kwargs): + contents = dedent(contents.format(*args, **kwargs)) + dir_path = op.dirname(path) + if not op.exists(dir_path): + os.mkdir(dir_path) + assert op.isdir(dir_path) + assert not op.exists(path) + with open(path, 'w') as f: + f.write(contents) + + def _is_interactive(): # pragma: no cover """Determine whether the user has requested interactive mode.""" # The Python interpreter sets sys.flags correctly, so use them! diff --git a/phy/utils/tests/test_misc.py b/phy/utils/tests/test_misc.py index 7ba965623..8b91b1318 100644 --- a/phy/utils/tests/test_misc.py +++ b/phy/utils/tests/test_misc.py @@ -13,12 +13,13 @@ import numpy as np from numpy.testing import assert_array_equal as ae -from pytest import raises, yield_fixture +from pytest import raises from six import string_types from traitlets import Float from traitlets.config import Configurable from .._misc import (_git_version, _load_json, _save_json, _read_python, + _write_text, _load_config, load_master_config, _encode_qbytearray, _decode_qbytearray, ) @@ -89,6 +90,15 @@ def test_read_python(tempdir): assert _read_python(path) == {'a': {'b': 1}} +def test_write_text(tempdir): + for path in (op.join(tempdir, 'test_1'), + op.join(tempdir, 'test_dir/test_2.txt'), + ): + _write_text(path, 'hello world') + with open(path, 'r') as f: + assert f.read() == 'hello world' + + def test_git_version(): v = _git_version() From 649ece8bd79cb37f5450e33cb6ba1c8a983e9888 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 15 Oct 2015 17:08:09 +0200 Subject: [PATCH 0301/1059] WIP: phy user dir --- phy/utils/tests/conftest.py | 26 +++++++++++++++++++++----- phy/utils/tests/test_misc.py | 5 +++++ 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/phy/utils/tests/conftest.py b/phy/utils/tests/conftest.py index 897c047dc..22111e355 100644 --- a/phy/utils/tests/conftest.py +++ b/phy/utils/tests/conftest.py @@ -8,8 +8,6 @@ from pytest import yield_fixture -from phy.utils import _misc - #------------------------------------------------------------------------------ # Common fixtures @@ -17,7 +15,25 @@ @yield_fixture def temp_user_dir(tempdir): - user_dir = _misc.PHY_USER_DIR - _misc.PHY_USER_DIR = tempdir + """NOTE: the user directory should be loaded with: + + ```python + from .. import _misc + _misc.phy_user_dir() + ``` + + and not: + + ```python + from _misc import phy_user_dir + ``` + + Otherwise, the monkey patching hack in tests won't work. + + """ + from phy.utils import _misc + + user_dir = _misc.phy_user_dir + _misc.phy_user_dir = lambda: tempdir yield tempdir - _misc.PHY_USER_DIR = user_dir + _misc.phy_user_dir = user_dir diff --git a/phy/utils/tests/test_misc.py b/phy/utils/tests/test_misc.py index 8b91b1318..0fd5a53b9 100644 --- a/phy/utils/tests/test_misc.py +++ b/phy/utils/tests/test_misc.py @@ -23,6 +23,7 @@ _load_config, load_master_config, _encode_qbytearray, _decode_qbytearray, ) +from .. import _misc #------------------------------------------------------------------------------ @@ -99,6 +100,10 @@ def test_write_text(tempdir): assert f.read() == 'hello world' +def test_temp_user_dir(temp_user_dir): + assert _misc.phy_user_dir() == temp_user_dir + + def test_git_version(): v = _git_version() From b3767cebb9b1d41d7366fe596d5141e7d30f5472 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 15 Oct 2015 17:41:56 +0200 Subject: [PATCH 0302/1059] WIP: config and plugins --- phy/utils/_misc.py | 5 ++- phy/utils/plugin.py | 38 +++++----------- phy/utils/tests/test_plugin.py | 80 +++++++++++++++++++--------------- 3 files changed, 60 insertions(+), 63 deletions(-) diff --git a/phy/utils/_misc.py b/phy/utils/_misc.py index 2bcc1d907..7f973a0aa 100644 --- a/phy/utils/_misc.py +++ b/phy/utils/_misc.py @@ -117,10 +117,11 @@ def _load_config(path): return config -def load_master_config(): +def load_master_config(user_dir=None): """Load a master Config file from `~/.phy/phy_config.py`.""" + user_dir = user_dir or phy_user_dir() c = Config() - paths = [op.join(phy_user_dir(), 'phy_config.py')] + paths = [op.join(user_dir, 'phy_config.py')] for path in paths: c.update(_load_config(path)) return c diff --git a/phy/utils/plugin.py b/phy/utils/plugin.py index 28f88ddfe..46b3e6584 100644 --- a/phy/utils/plugin.py +++ b/phy/utils/plugin.py @@ -17,10 +17,8 @@ import os.path as op from six import with_metaclass -from traitlets import List, Unicode -from traitlets.config import Configurable -from ._misc import load_master_config, PHY_USER_DIR +from . import _misc logger = logging.getLogger(__name__) @@ -105,33 +103,19 @@ def discover_plugins(dirs): return IPluginRegistry.plugins -class Plugins(Configurable): - """Configure the list of user plugin directories. +def _builtin_plugins_dir(): + return op.realpath(op.join(op.dirname(__file__), '../plugins/')) - By default, there is only `~/.phy/plugins/`. - """ - dirs = List(Unicode, - default_value=[op.expanduser(op.join(PHY_USER_DIR, - 'plugins/'))], - config=True, - ) +def _user_plugins_dir(): + return op.expanduser(op.join(_misc.phy_user_dir(), 'plugins/')) -def get_all_plugins(): +def get_all_plugins(config=None): """Load all builtin and user plugins.""" - - # Builtin plugins. - builtin_plugins_dir = [op.realpath(op.join(op.dirname(__file__), - '../plugins/'))] - - # Load the plugin dirs from all config files. - plugins_config = Plugins() - c = load_master_config() - plugins_config.update_config(c) - - # Add the builtin dirs. - dirs = builtin_plugins_dir + plugins_config.dirs - - # Return all loaded plugins. + # By default, builtin and default user plugin. + dirs = [_builtin_plugins_dir(), _user_plugins_dir()] + # Add Plugins.dir from the optionally-passed config object. + if config: + dirs += config.Plugins.dirs return [plugin for (plugin,) in discover_plugins(dirs)] diff --git a/phy/utils/tests/test_plugin.py b/phy/utils/tests/test_plugin.py index da4e0d9b8..637385381 100644 --- a/phy/utils/tests/test_plugin.py +++ b/phy/utils/tests/test_plugin.py @@ -11,16 +11,15 @@ import os.path as op from textwrap import dedent -from traitlets import List, Unicode from pytest import yield_fixture, raises from ..plugin import (IPluginRegistry, IPlugin, - Plugins, get_plugin, discover_plugins, get_all_plugins, ) +from .._misc import _write_text, load_master_config #------------------------------------------------------------------------------ @@ -36,6 +35,31 @@ def no_native_plugins(): IPluginRegistry.plugins = plugins +@yield_fixture(params=[(False, 'my_plugins/plugin.py'), + (True, 'plugins/plugin.py'), + ]) +def plugin(no_native_plugins, temp_user_dir, request): + path = op.join(temp_user_dir, request.param[1]) + contents = """ + from phy import IPlugin + class MyPlugin(IPlugin): + pass + """ + _write_text(path, contents) + yield temp_user_dir, request.param[0], request.param[1] + + +def _write_my_plugins_dir_in_config(temp_user_dir): + # Now, we specify the path to the plugin in the phy config file. + config_contents = """ + c = get_config() + c.Plugins.dirs = ['{}'] + """ + _write_text(op.join(temp_user_dir, 'phy_config.py'), + config_contents, + op.join(temp_user_dir, 'my_plugins/')) + + #------------------------------------------------------------------------------ # Tests #------------------------------------------------------------------------------ @@ -57,49 +81,37 @@ class MyPlugin(IPlugin): def test_discover_plugins(tempdir, no_native_plugins): path = op.join(tempdir, 'my_plugin.py') contents = '''from phy import IPlugin\nclass MyPlugin(IPlugin): pass''' - with open(path, 'w') as f: - f.write(contents) + _write_text(path, contents) plugins = discover_plugins([tempdir]) assert plugins assert plugins[0][0].__name__ == 'MyPlugin' -def test_get_all_plugins(temp_user_dir): - +def test_get_all_plugins(plugin): + temp_user_dir, in_default_dir, path = plugin n_builtin_plugins = 0 plugins = get_all_plugins() - assert len(plugins) == n_builtin_plugins - plugin_contents = dedent(""" - from phy import IPlugin - class MyPlugin(IPlugin): - pass - """) - - # Create a plugin in some directory. - os.mkdir(op.join(temp_user_dir, 'myplugins/')) - with open(op.join(temp_user_dir, 'myplugins/myplugin.py'), 'w') as f: - f.write(plugin_contents) + def _assert_loaded(): + assert len(plugins) == n_builtin_plugins + 1 + p = plugins[-1] + assert p.__name__ == 'MyPlugin' - # By default, this directory has no reason to be scanned, and the - # plugin is not loaded. - plugins = get_all_plugins() - assert len(plugins) == n_builtin_plugins + if in_default_dir: + # Create a plugin in the default plugins directory: it will be + # discovered and automatically loaded by get_all_plugins(). + _assert_loaded() + else: + assert len(plugins) == n_builtin_plugins - # Specify the path to the plugin in the phy config file.. - config_contents = dedent(""" - c = get_config() - c.Plugins.dirs = ['%s'] - """ % op.join(temp_user_dir, 'myplugins/')) - with open(op.join(temp_user_dir, 'phy_config.py'), 'w') as f: - f.write(config_contents) + # This time, we write the custom plugins path in the config file. + _write_my_plugins_dir_in_config(temp_user_dir) - # Now, reload all plugins. - plugins = get_all_plugins() + # We reload all plugins with the master config object. + config = load_master_config() + plugins = get_all_plugins(config) - # This time, the plugin will be found. - assert len(plugins) == n_builtin_plugins + 1 - p = plugins[-1] - assert p.__name__ == 'MyPlugin' + # This time, the plugin should be found. + _assert_loaded() From c0d9a58cac8090132518a22e5d80ee72cc021946 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 15 Oct 2015 18:48:22 +0200 Subject: [PATCH 0303/1059] WIP: CLI plugins --- phy/utils/cli.py | 22 +++++++++++++++------- phy/utils/plugin.py | 2 +- phy/utils/tests/test_cli.py | 37 +++++++++++++++++++++++++++++++++++-- setup.py | 2 +- 4 files changed, 52 insertions(+), 11 deletions(-) diff --git a/phy/utils/cli.py b/phy/utils/cli.py index 287dac90b..37c95d69b 100644 --- a/phy/utils/cli.py +++ b/phy/utils/cli.py @@ -8,21 +8,25 @@ # Imports #------------------------------------------------------------------------------ +import logging + import click import phy -from phy.plugins import get_all_plugins + +logger = logging.getLogger(__name__) #------------------------------------------------------------------------------ # CLI tool #------------------------------------------------------------------------------ -@click.command() +@click.group() @click.version_option(version=phy.__version_git__) @click.help_option() -def phy(): - return 0 +@click.pass_context +def phy(ctx): + pass #------------------------------------------------------------------------------ @@ -31,12 +35,16 @@ def phy(): def load_cli_plugins(cli): """Load all plugins and attach them to a CLI object.""" - plugins = get_all_plugins() + from ._misc import load_master_config + from .plugin import get_all_plugins + + config = load_master_config() + plugins = get_all_plugins(config) + # TODO: try/except to avoid crashing if a plugin is broken. for plugin in plugins: + logger.info("Attach plugin `%s`.", plugin.__name__) # NOTE: plugin is a class, so we need to instantiate it. plugin().attach_to_cli(cli) - -# Load all plugins for the phy CLI. load_cli_plugins(phy) diff --git a/phy/utils/plugin.py b/phy/utils/plugin.py index 46b3e6584..be90178ff 100644 --- a/phy/utils/plugin.py +++ b/phy/utils/plugin.py @@ -88,7 +88,7 @@ def discover_plugins(dirs): if (filename.startswith('__') or not filename.endswith('.py')): continue # pragma: no cover - logger.debug(" Found %s.", filename) + logger.debug("Found %s.", filename) path = os.path.join(subdir, filename) modname, ext = op.splitext(filename) file, path, descr = imp.find_module(modname, [subdir]) diff --git a/phy/utils/tests/test_cli.py b/phy/utils/tests/test_cli.py index 72d81b04a..7a1e2b338 100644 --- a/phy/utils/tests/test_cli.py +++ b/phy/utils/tests/test_cli.py @@ -8,16 +8,22 @@ # Imports #------------------------------------------------------------------------------ +import os.path as op + from click.testing import CliRunner -from ..cli import phy +from .._misc import _write_text #------------------------------------------------------------------------------ # Test CLI tool #------------------------------------------------------------------------------ -def test_cli(): +def test_cli_empty(temp_user_dir): + # NOTE: make the import after the temp_user_dir fixture, to avoid + # loading any user plugin affecting the CLI. + from ..cli import phy + runner = CliRunner() result = runner.invoke(phy, []) @@ -30,3 +36,30 @@ def test_cli(): result = runner.invoke(phy, ['--help']) assert result.exit_code == 0 assert result.output.startswith('Usage: phy') + + +def test_cli_plugins(temp_user_dir): + + # Write a CLI plugin. + cli_plugin = """ + import click + from phy import IPlugin + + class MyPlugin(IPlugin): + def attach_to_cli(self, cli): + @cli.command() + def hello(): + click.echo("hello world") + """ + path = op.join(temp_user_dir, 'plugins/hello.py') + _write_text(path, cli_plugin) + + runner = CliRunner() + + # NOTE: make the import after the temp_user_dir fixture, to avoid + # loading any user plugin affecting the CLI. + from ..cli import phy + + result = runner.invoke(phy, ['--help']) + assert result.exit_code == 0 + assert 'hello' in result.output diff --git a/setup.py b/setup.py index 22bdbc7d5..0dd232d4f 100644 --- a/setup.py +++ b/setup.py @@ -54,7 +54,7 @@ def _package_tree(pkgroot): }, entry_points={ 'console_scripts': [ - 'phy = phy.cli:phy' + 'phy = phy.utils.cli:phy' ], }, include_package_data=True, From ef0511918edb411d0043d4de8712c5f5fa3435be Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 15 Oct 2015 19:05:56 +0200 Subject: [PATCH 0304/1059] Updated CLI tests --- phy/utils/tests/test_cli.py | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/phy/utils/tests/test_cli.py b/phy/utils/tests/test_cli.py index 7a1e2b338..a648f286a 100644 --- a/phy/utils/tests/test_cli.py +++ b/phy/utils/tests/test_cli.py @@ -11,6 +11,7 @@ import os.path as op from click.testing import CliRunner +from pytest import yield_fixture from .._misc import _write_text @@ -19,12 +20,17 @@ # Test CLI tool #------------------------------------------------------------------------------ -def test_cli_empty(temp_user_dir): +@yield_fixture +def runner(): + yield CliRunner() + + +def test_cli_empty(temp_user_dir, runner): + # NOTE: make the import after the temp_user_dir fixture, to avoid # loading any user plugin affecting the CLI. - from ..cli import phy - - runner = CliRunner() + from ..cli import phy, load_cli_plugins + load_cli_plugins(phy) result = runner.invoke(phy, []) assert result.exit_code == 0 @@ -38,7 +44,7 @@ def test_cli_empty(temp_user_dir): assert result.output.startswith('Usage: phy') -def test_cli_plugins(temp_user_dir): +def test_cli_plugins(temp_user_dir, runner): # Write a CLI plugin. cli_plugin = """ @@ -54,12 +60,17 @@ def hello(): path = op.join(temp_user_dir, 'plugins/hello.py') _write_text(path, cli_plugin) - runner = CliRunner() - # NOTE: make the import after the temp_user_dir fixture, to avoid # loading any user plugin affecting the CLI. - from ..cli import phy + from ..cli import phy, load_cli_plugins + load_cli_plugins(phy) + # The plugin should have added a new command. result = runner.invoke(phy, ['--help']) assert result.exit_code == 0 assert 'hello' in result.output + + # The plugin should have added a new command. + result = runner.invoke(phy, ['hello']) + assert result.exit_code == 0 + assert result.output == 'hello world\n' From 364aa003a86969b3a78ed0570d068b556976d2ec Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 15 Oct 2015 19:27:59 +0200 Subject: [PATCH 0305/1059] Bug fixes --- phy/gui/gui.py | 3 ++- phy/utils/cli.py | 2 ++ phy/utils/plugin.py | 16 +++++++++------- phy/utils/tests/test_misc.py | 4 ++++ phy/utils/tests/test_plugin.py | 5 ----- 5 files changed, 17 insertions(+), 13 deletions(-) diff --git a/phy/gui/gui.py b/phy/gui/gui.py index 9431e6b60..eee58898f 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -92,7 +92,8 @@ def attach(self, plugin, *args, **kwargs): if isinstance(plugin, string_types): # Instantiate the plugin if the name is given. plugin = get_plugin(plugin)(*args, **kwargs) - return plugin.attach_to_gui(self) + if hasattr(plugin, 'attach_to_gui'): + return plugin.attach_to_gui(self) # Events # ------------------------------------------------------------------------- diff --git a/phy/utils/cli.py b/phy/utils/cli.py index 37c95d69b..b7c2aecf4 100644 --- a/phy/utils/cli.py +++ b/phy/utils/cli.py @@ -43,6 +43,8 @@ def load_cli_plugins(cli): # TODO: try/except to avoid crashing if a plugin is broken. for plugin in plugins: + if not hasattr(plugin, 'attach_to_cli'): + continue logger.info("Attach plugin `%s`.", plugin.__name__) # NOTE: plugin is a class, so we need to instantiate it. plugin().attach_to_cli(cli) diff --git a/phy/utils/plugin.py b/phy/utils/plugin.py index be90178ff..b2d40e4d9 100644 --- a/phy/utils/plugin.py +++ b/phy/utils/plugin.py @@ -39,11 +39,13 @@ def __init__(cls, name, bases, attrs): class IPlugin(with_metaclass(IPluginRegistry)): - def attach_to_gui(self, gui): - pass + """A class deriving from IPlugin can implement the following methods: - def attach_to_cli(self, cli): - pass + * `attach_to_gui(gui)`: called when the plugin is attached to a GUI. + * `attach_to_cli(cli)`: called when the CLI is created. + + """ + pass def get_plugin(name): @@ -96,9 +98,9 @@ def discover_plugins(dirs): # Loading the module registers the plugin in # IPluginRegistry. try: - mod = imp.load_module(modname, file, - path, descr) # noqa - except Exception as e: + mod = imp.load_module(modname, file, # noqa + path, descr) + except Exception as e: # pragma: no cover logger.exception(e) return IPluginRegistry.plugins diff --git a/phy/utils/tests/test_misc.py b/phy/utils/tests/test_misc.py index 0fd5a53b9..72be5fb57 100644 --- a/phy/utils/tests/test_misc.py +++ b/phy/utils/tests/test_misc.py @@ -100,6 +100,10 @@ def test_write_text(tempdir): assert f.read() == 'hello world' +def test_phy_user_dir(): + assert _misc.phy_user_dir().endswith('.phy/') + + def test_temp_user_dir(temp_user_dir): assert _misc.phy_user_dir() == temp_user_dir diff --git a/phy/utils/tests/test_plugin.py b/phy/utils/tests/test_plugin.py index 637385381..a14a97aa2 100644 --- a/phy/utils/tests/test_plugin.py +++ b/phy/utils/tests/test_plugin.py @@ -7,9 +7,7 @@ # Imports #------------------------------------------------------------------------------ -import os import os.path as op -from textwrap import dedent from pytest import yield_fixture, raises @@ -74,9 +72,6 @@ class MyPlugin(IPlugin): with raises(ValueError): get_plugin('unknown') - get_plugin('myplugin')().attach_to_cli(None) - get_plugin('myplugin')().attach_to_gui(None) - def test_discover_plugins(tempdir, no_native_plugins): path = op.join(tempdir, 'my_plugin.py') From 74eaaed0635df64c86f09f3a7e7f4dc4ee455e1f Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 15 Oct 2015 22:44:27 +0200 Subject: [PATCH 0306/1059] Fixing travis --- .travis.yml | 2 +- requirements.txt | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) delete mode 100644 requirements.txt diff --git a/.travis.yml b/.travis.yml index 7dc2cf958..6a39c9c27 100644 --- a/.travis.yml +++ b/.travis.yml @@ -21,7 +21,7 @@ install: # Create the environment. - conda create -q -n testenv python=$TRAVIS_PYTHON_VERSION - source activate testenv - - conda install pip numpy=1.9 vispy matplotlib scipy h5py pyqt ipython requests six dill ipyparallel joblib dask + - conda install pip numpy=1.9 vispy matplotlib scipy h5py pyqt ipython requests six dill ipyparallel joblib dask click # NOTE: cython is only temporarily needed for building KK2. # Once we have KK2 binary builds on binstar we can remove this dependency. - conda install cython && pip install klustakwik2 diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index ddd42871c..000000000 --- a/requirements.txt +++ /dev/null @@ -1 +0,0 @@ -wheel==0.23.0 From 840874118e41d233a61fffbfa333e0fa0161e747 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 15 Oct 2015 23:30:02 +0200 Subject: [PATCH 0307/1059] Emit wizard_start event --- phy/cluster/manual/gui_plugin.py | 10 +++++----- phy/cluster/manual/tests/test_gui_plugin.py | 10 ++++++++++ phy/cluster/manual/wizard.py | 1 + 3 files changed, 16 insertions(+), 5 deletions(-) diff --git a/phy/cluster/manual/gui_plugin.py b/phy/cluster/manual/gui_plugin.py index 760769e7d..a321c8cbf 100644 --- a/phy/cluster/manual/gui_plugin.py +++ b/phy/cluster/manual/gui_plugin.py @@ -206,16 +206,16 @@ def attach_to_gui(self, gui): @self.wizard.connect def on_select(cluster_ids): """When the wizard selects clusters, choose a spikes subset - and emit the `select` event on the GUI. - - The wizard is responsible for the notion of "selected clusters". - - """ + and emit the `select` event on the GUI.""" spike_ids = select_spikes(np.array(cluster_ids), self.n_spikes_max_per_cluster, self.clustering.spikes_per_cluster) gui.emit('select', cluster_ids, spike_ids) + @self.wizard.connect + def on_start(): + gui.emit('wizard_start') + # Attach the GUI and register the actions. self.snippets.attach(gui, self.actions) self.actions.attach(gui) diff --git a/phy/cluster/manual/tests/test_gui_plugin.py b/phy/cluster/manual/tests/test_gui_plugin.py index 6957da531..042c12cc1 100644 --- a/phy/cluster/manual/tests/test_gui_plugin.py +++ b/phy/cluster/manual/tests/test_gui_plugin.py @@ -269,6 +269,16 @@ def test_manual_clustering_move(manual_clustering, quality, similarity): mc.wizard.set_quality_function(quality) mc.wizard.set_similarity_function(similarity) + # Check that the wizard_start event is fired. + _check = [] + + @mc.gui.connect_ + def on_wizard_start(): + _check.append('wizard') + + mc.wizard.restart() + assert _check == ['wizard'] + mc.actions.next_by_quality() assert_selection(20) diff --git a/phy/cluster/manual/wizard.py b/phy/cluster/manual/wizard.py index 591326a49..f05422fcd 100644 --- a/phy/cluster/manual/wizard.py +++ b/phy/cluster/manual/wizard.py @@ -298,6 +298,7 @@ def next(self): self._set_selection_from_history() def restart(self): + self.emit('start') self.select([]) self.next_by_quality() From f59fa4ca818a652fe1f1393a160fbfc59ba651a3 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 15 Oct 2015 23:43:22 +0200 Subject: [PATCH 0308/1059] Add wizard event tests --- phy/cluster/manual/tests/test_gui_plugin.py | 38 +++++++++++++++------ phy/cluster/manual/wizard.py | 3 +- 2 files changed, 30 insertions(+), 11 deletions(-) diff --git a/phy/cluster/manual/tests/test_gui_plugin.py b/phy/cluster/manual/tests/test_gui_plugin.py index 042c12cc1..4613d1820 100644 --- a/phy/cluster/manual/tests/test_gui_plugin.py +++ b/phy/cluster/manual/tests/test_gui_plugin.py @@ -211,6 +211,34 @@ def test_attach_wizard_3(wizard, cluster_ids, cluster_groups): # Test GUI plugins #------------------------------------------------------------------------------ +def test_wizard_start_1(manual_clustering): + mc, assert_selection = manual_clustering + + # Check that the wizard_start event is fired. + _check = [] + + @mc.gui.connect_ + def on_wizard_start(): + _check.append('wizard') + + mc.wizard.restart() + assert _check == ['wizard'] + + +def test_wizard_start_2(manual_clustering): + mc, assert_selection = manual_clustering + + # Check that the wizard_start event is fired. + _check = [] + + @mc.gui.connect_ + def on_wizard_start(): + _check.append('wizard') + + mc.wizard.select([1]) + assert _check == ['wizard'] + + def test_manual_clustering_edge_cases(manual_clustering): mc, assert_selection = manual_clustering @@ -269,16 +297,6 @@ def test_manual_clustering_move(manual_clustering, quality, similarity): mc.wizard.set_quality_function(quality) mc.wizard.set_similarity_function(similarity) - # Check that the wizard_start event is fired. - _check = [] - - @mc.gui.connect_ - def on_wizard_start(): - _check.append('wizard') - - mc.wizard.restart() - assert _check == ['wizard'] - mc.actions.next_by_quality() assert_selection(20) diff --git a/phy/cluster/manual/wizard.py b/phy/cluster/manual/wizard.py index f05422fcd..a0086ddc6 100644 --- a/phy/cluster/manual/wizard.py +++ b/phy/cluster/manual/wizard.py @@ -236,6 +236,8 @@ def select(self, cluster_ids, add_to_history=True): clusters = self.cluster_ids cluster_ids = [cluster for cluster in cluster_ids if cluster in clusters] + if not self._selection and cluster_ids: + self.emit('start') self._selection = cluster_ids if add_to_history: self._history.add(self._selection) @@ -298,7 +300,6 @@ def next(self): self._set_selection_from_history() def restart(self): - self.emit('start') self.select([]) self.next_by_quality() From 15c84162ad8f1a2c214a38f85dbd93569ab4c3a4 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 16 Oct 2015 07:42:07 +0200 Subject: [PATCH 0309/1059] WIP: move config from _misc to settings --- phy/io/datasets.py | 8 ++--- phy/utils/_misc.py | 28 --------------- phy/utils/cli.py | 2 +- phy/utils/plugin.py | 4 +-- phy/utils/settings.py | 29 +++++++++++---- phy/utils/tests/conftest.py | 14 ++++---- phy/utils/tests/test_misc.py | 61 -------------------------------- phy/utils/tests/test_plugin.py | 3 +- phy/utils/tests/test_settings.py | 61 ++++++++++++++++++++++++++++++-- 9 files changed, 97 insertions(+), 113 deletions(-) diff --git a/phy/io/datasets.py b/phy/io/datasets.py index 04fd9b0fd..3b9a47918 100644 --- a/phy/io/datasets.py +++ b/phy/io/datasets.py @@ -12,7 +12,7 @@ import os.path as op from phy.utils.event import ProgressReporter -from phy.utils.settings import _phy_user_dir, _ensure_dir_exists +from phy.utils.settings import phy_user_dir, _ensure_dir_exists logger = logging.getLogger(__name__) @@ -147,10 +147,10 @@ def download_file(url, output_path): return -def download_test_data(name, phy_user_dir=None, force=False): +def download_test_data(name, user_dir=None, force=False): """Download a test file.""" - phy_user_dir = phy_user_dir or _phy_user_dir() - dir = op.join(phy_user_dir, 'test_data') + user_dir = user_dir or phy_user_dir() + dir = op.join(user_dir, 'test_data') _ensure_dir_exists(dir) path = op.join(dir, name) if not force and op.exists(path): diff --git a/phy/utils/_misc.py b/phy/utils/_misc.py index 7f973a0aa..514817fe3 100644 --- a/phy/utils/_misc.py +++ b/phy/utils/_misc.py @@ -15,7 +15,6 @@ import subprocess from textwrap import dedent -from traitlets.config import Config, PyFileConfigLoader import numpy as np from six import string_types, exec_ from six.moves import builtins @@ -104,37 +103,10 @@ def _save_json(path, data): json.dump(data, f, cls=_CustomEncoder, indent=2) -#------------------------------------------------------------------------------ -# traitlets config -#------------------------------------------------------------------------------ - -def _load_config(path): - if not op.exists(path): - return {} - path = op.realpath(path) - dirpath, filename = op.split(path) - config = PyFileConfigLoader(filename, dirpath).load_config() - return config - - -def load_master_config(user_dir=None): - """Load a master Config file from `~/.phy/phy_config.py`.""" - user_dir = user_dir or phy_user_dir() - c = Config() - paths = [op.join(user_dir, 'phy_config.py')] - for path in paths: - c.update(_load_config(path)) - return c - - #------------------------------------------------------------------------------ # Various Python utility functions #------------------------------------------------------------------------------ -def phy_user_dir(): - return op.expanduser('~/.phy/') - - def _read_python(path): path = op.realpath(op.expanduser(path)) assert op.exists(path) diff --git a/phy/utils/cli.py b/phy/utils/cli.py index b7c2aecf4..947addec1 100644 --- a/phy/utils/cli.py +++ b/phy/utils/cli.py @@ -35,7 +35,7 @@ def phy(ctx): def load_cli_plugins(cli): """Load all plugins and attach them to a CLI object.""" - from ._misc import load_master_config + from .settings import load_master_config from .plugin import get_all_plugins config = load_master_config() diff --git a/phy/utils/plugin.py b/phy/utils/plugin.py index b2d40e4d9..c121628cc 100644 --- a/phy/utils/plugin.py +++ b/phy/utils/plugin.py @@ -18,7 +18,7 @@ from six import with_metaclass -from . import _misc +from . import settings logger = logging.getLogger(__name__) @@ -110,7 +110,7 @@ def _builtin_plugins_dir(): def _user_plugins_dir(): - return op.expanduser(op.join(_misc.phy_user_dir(), 'plugins/')) + return op.expanduser(op.join(settings.phy_user_dir(), 'plugins/')) def get_all_plugins(config=None): diff --git a/phy/utils/settings.py b/phy/utils/settings.py index c7e743cd6..2caf77a4e 100644 --- a/phy/utils/settings.py +++ b/phy/utils/settings.py @@ -11,6 +11,7 @@ import os.path as op from six import string_types +from traitlets.config import Config, PyFileConfigLoader from ._misc import _load_json, _save_json, _read_python @@ -192,16 +193,30 @@ def keys(self): # Config #------------------------------------------------------------------------------ -_PHY_USER_DIR_NAME = '.phy' - - -def _phy_user_dir(): +def phy_user_dir(): """Return the absolute path to the phy user directory.""" - home = op.expanduser("~") - path = op.realpath(op.join(home, _PHY_USER_DIR_NAME)) - return path + return op.expanduser('~/.phy/') def _ensure_dir_exists(path): if not op.exists(path): os.makedirs(path) + + +def _load_config(path): + if not op.exists(path): + return {} + path = op.realpath(path) + dirpath, filename = op.split(path) + config = PyFileConfigLoader(filename, dirpath).load_config() + return config + + +def load_master_config(user_dir=None): + """Load a master Config file from `~/.phy/phy_config.py`.""" + user_dir = user_dir or phy_user_dir() + c = Config() + paths = [op.join(user_dir, 'phy_config.py')] + for path in paths: + c.update(_load_config(path)) + return c diff --git a/phy/utils/tests/conftest.py b/phy/utils/tests/conftest.py index 22111e355..79ef0eba9 100644 --- a/phy/utils/tests/conftest.py +++ b/phy/utils/tests/conftest.py @@ -18,22 +18,22 @@ def temp_user_dir(tempdir): """NOTE: the user directory should be loaded with: ```python - from .. import _misc - _misc.phy_user_dir() + from .. import settings + settings.phy_user_dir() ``` and not: ```python - from _misc import phy_user_dir + from settings import phy_user_dir ``` Otherwise, the monkey patching hack in tests won't work. """ - from phy.utils import _misc + from phy.utils import settings - user_dir = _misc.phy_user_dir - _misc.phy_user_dir = lambda: tempdir + user_dir = settings.phy_user_dir + settings.phy_user_dir = lambda: tempdir yield tempdir - _misc.phy_user_dir = user_dir + settings.phy_user_dir = user_dir diff --git a/phy/utils/tests/test_misc.py b/phy/utils/tests/test_misc.py index 72be5fb57..25a56bdcd 100644 --- a/phy/utils/tests/test_misc.py +++ b/phy/utils/tests/test_misc.py @@ -9,21 +9,16 @@ import os import os.path as op import subprocess -from textwrap import dedent import numpy as np from numpy.testing import assert_array_equal as ae from pytest import raises from six import string_types -from traitlets import Float -from traitlets.config import Configurable from .._misc import (_git_version, _load_json, _save_json, _read_python, _write_text, - _load_config, load_master_config, _encode_qbytearray, _decode_qbytearray, ) -from .. import _misc #------------------------------------------------------------------------------ @@ -100,14 +95,6 @@ def test_write_text(tempdir): assert f.read() == 'hello world' -def test_phy_user_dir(): - assert _misc.phy_user_dir().endswith('.phy/') - - -def test_temp_user_dir(temp_user_dir): - assert _misc.phy_user_dir() == temp_user_dir - - def test_git_version(): v = _git_version() @@ -121,51 +108,3 @@ def test_git_version(): assert v[:5] == "-git-", "Git version does not begin in -git-" except (OSError, subprocess.CalledProcessError): # pragma: no cover assert v == "" - - -#------------------------------------------------------------------------------ -# Config tests -#------------------------------------------------------------------------------ - -def test_load_config(tempdir): - - class MyConfigurable(Configurable): - my_var = Float(0.0, config=True) - - assert MyConfigurable().my_var == 0.0 - - # Create and load a config file. - config_contents = dedent(""" - c = get_config() - - c.MyConfigurable.my_var = 1.0 - """) - - path = op.join(tempdir, 'config.py') - with open(path, 'w') as f: - f.write(config_contents) - - c = _load_config(path) - assert c.MyConfigurable.my_var == 1.0 - - # Create a new MyConfigurable instance. - configurable = MyConfigurable() - assert configurable.my_var == 0.0 - - # Load the config object. - configurable.update_config(c) - assert configurable.my_var == 1.0 - - -def test_load_master_config(temp_user_dir): - # Create a config file in the temporary user directory. - config_contents = dedent(""" - c = get_config() - c.MyConfigurable.my_var = 1.0 - """) - with open(op.join(temp_user_dir, 'phy_config.py'), 'w') as f: - f.write(config_contents) - - # Load the master config file. - c = load_master_config() - assert c.MyConfigurable.my_var == 1. diff --git a/phy/utils/tests/test_plugin.py b/phy/utils/tests/test_plugin.py index a14a97aa2..8f15d3036 100644 --- a/phy/utils/tests/test_plugin.py +++ b/phy/utils/tests/test_plugin.py @@ -17,7 +17,8 @@ discover_plugins, get_all_plugins, ) -from .._misc import _write_text, load_master_config +from .._misc import _write_text +from ..settings import load_master_config #------------------------------------------------------------------------------ diff --git a/phy/utils/tests/test_settings.py b/phy/utils/tests/test_settings.py index 37b49df7f..814f45a9c 100644 --- a/phy/utils/tests/test_settings.py +++ b/phy/utils/tests/test_settings.py @@ -7,13 +7,18 @@ #------------------------------------------------------------------------------ import os.path as op +from textwrap import dedent from pytest import raises, yield_fixture +from traitlets import Float +from traitlets.config import Configurable +from .. import settings as _settings from ..settings import (BaseSettings, Settings, _recursive_dirs, - _phy_user_dir, + _load_config, + load_master_config, ) @@ -34,7 +39,11 @@ def settings(request): #------------------------------------------------------------------------------ def test_phy_user_dir(): - assert op.exists(_phy_user_dir()) + assert _settings.phy_user_dir().endswith('.phy/') + + +def test_temp_user_dir(temp_user_dir): + assert _settings.phy_user_dir() == temp_user_dir def test_recursive_dirs(): @@ -186,3 +195,51 @@ def test_settings_manager(tempdir, tempdir_bis): assert str(sm).startswith(' Date: Fri, 16 Oct 2015 08:14:09 +0200 Subject: [PATCH 0310/1059] WIP: remove settings --- phy/utils/__init__.py | 2 +- phy/utils/settings.py | 171 ------------------------------- phy/utils/tests/test_settings.py | 168 +----------------------------- 3 files changed, 2 insertions(+), 339 deletions(-) diff --git a/phy/utils/__init__.py b/phy/utils/__init__.py index f68479984..4bad955ab 100644 --- a/phy/utils/__init__.py +++ b/phy/utils/__init__.py @@ -6,4 +6,4 @@ from ._types import (_is_array_like, _as_array, _as_tuple, _as_list, Bunch, _is_list) from .event import EventEmitter, ProgressReporter -from .settings import Settings, _ensure_dir_exists +from .settings import _ensure_dir_exists diff --git a/phy/utils/settings.py b/phy/utils/settings.py index 2caf77a4e..462ab4489 100644 --- a/phy/utils/settings.py +++ b/phy/utils/settings.py @@ -18,177 +18,6 @@ logger = logging.getLogger(__name__) -#------------------------------------------------------------------------------ -# Settings -#------------------------------------------------------------------------------ - -def _create_empty_settings(path): - """Create an empty Python file if the path doesn't exist.""" - # If the file exists of if it is an internal settings file, skip. - if op.exists(path) or op.splitext(path)[1] == '': - return - logger.debug("Creating empty settings file: %s.", path) - with open(path, 'a') as f: - f.write("# Settings file. Refer to phy's documentation " - "for more details.\n") - - -def _recursive_dirs(): - """Yield all subdirectories paths in phy's package.""" - phy_root = op.join(op.realpath(op.dirname(__file__)), '../') - for root, dirs, files in os.walk(phy_root): - root = op.realpath(root) - root = op.relpath(root, phy_root) - if ('.' in root or '_' in root or 'tests' in root or - 'static' in root or 'glsl' in root): - continue - yield op.realpath(op.join(phy_root, root)) - - -class BaseSettings(object): - """Store key-value pairs.""" - def __init__(self): - self._store = {} - self._to_save = {} - - def __getitem__(self, key): - return self._store[key] - - def __setitem__(self, key, value): - self._store[key] = value - self._to_save[key] = value - - def __contains__(self, key): - return key in self._store - - def __repr__(self): - return self._store.__repr__() - - def keys(self): - """List of settings keys.""" - return self._store.keys() - - def _update(self, d): - for k, v in d.items(): - if isinstance(v, dict) and k in self._store: - # Update instead of overwrite settings dictionaries. - self._store[k].update(v) - else: - self._store[k] = v - - def load(self, path): - """Load a settings file.""" - if not isinstance(path, string_types): - logger.warn("The settings file `%s` is invalid.", path) - return - path = op.realpath(path) - if not op.exists(path): - logger.debug("The settings file `%s` doesn't exist.", path) - return - try: - if op.splitext(path)[1] == '.py': - self._update(_read_python(path)) - logger.debug("Read settings file %s.", path) - elif op.splitext(path)[1] == '.json': - self._update(_load_json(path)) - logger.debug("Read settings file %s.", path) - else: - logger.warn("The settings file %s must have the extension " - "'.py' or '.json'.", path) - except Exception as e: - logger.warn("Unable to read %s. " - "Please try to delete this file. %s", path, str(e)) - - def save(self, path): - """Save the settings to a JSON file.""" - path = op.realpath(path) - try: - _save_json(path, self._to_save) - logger.debug("Saved internal settings file to `%s`.", path) - except Exception as e: # pragma: no cover - logger.warn("Unable to save the internal settings file " - "to `%s`:\n%s", path, str(e)) - self._to_save = {} - - -class Settings(object): - """Manage user-wide, and experiment-wide settings.""" - - def __init__(self, phy_user_dir=None): - self.phy_user_dir = phy_user_dir - if self.phy_user_dir: - _ensure_dir_exists(self.phy_user_dir) - self._bs = BaseSettings() - self._load_user_settings() - - def _load_user_settings(self): - - if not self.phy_user_dir: - return - - # User settings. - self.user_settings_path = op.join(self.phy_user_dir, - 'user_settings.py') - - # Create empty settings path if necessary. - _create_empty_settings(self.user_settings_path) - - self._bs.load(self.user_settings_path) - - # Load the user's internal settings. - self.internal_settings_path = op.join(self.phy_user_dir, - 'internal_settings.json') - self._bs.load(self.internal_settings_path) - - def on_open(self, path): - """Initialize settings when loading an experiment.""" - assert path is not None - # Get the experiment settings path. - path = op.realpath(op.expanduser(path)) - self.exp_path = path - self.exp_name = op.splitext(op.basename(path))[0] - self.exp_dir = op.dirname(path) - self.exp_settings_dir = op.join(self.exp_dir, self.exp_name + '.phy') - - self.exp_settings_path = op.join(self.exp_settings_dir, - 'user_settings.py') - _ensure_dir_exists(self.exp_settings_dir) - - # Create empty settings path if necessary. - _create_empty_settings(self.exp_settings_path) - - # Load experiment-wide settings. - self._load_user_settings() - self._bs.load(self.exp_settings_path) - - def save(self): - """Save settings to an internal settings file.""" - self._bs.save(self.internal_settings_path) - - def get(self, key, default=None): - """Return a settings value.""" - if key in self: - return self[key] - else: - return default - - def __getitem__(self, key): - return self._bs[key] - - def __setitem__(self, key, value): - self._bs[key] = value - - def __contains__(self, key): - return key in self._bs - - def __repr__(self): - return "".format(self._bs.__repr__()) - - def keys(self): - """Return the list of settings keys.""" - return self._bs.keys() - - #------------------------------------------------------------------------------ # Config #------------------------------------------------------------------------------ diff --git a/phy/utils/tests/test_settings.py b/phy/utils/tests/test_settings.py index 814f45a9c..ee3b46168 100644 --- a/phy/utils/tests/test_settings.py +++ b/phy/utils/tests/test_settings.py @@ -14,26 +14,11 @@ from traitlets.config import Configurable from .. import settings as _settings -from ..settings import (BaseSettings, - Settings, - _recursive_dirs, - _load_config, +from ..settings import (_load_config, load_master_config, ) -#------------------------------------------------------------------------------ -# Fixtures -#------------------------------------------------------------------------------ - -@yield_fixture(params=['py', 'json']) -def settings(request): - if request.param == 'py': - yield ('py', '''a = 4\nb = 5\nd = {'k1': 2, 'k2': '3'}''') - elif request.param == 'json': - yield ('json', '''{"a": 4, "b": 5, "d": {"k1": 2, "k2": "3"}}''') - - #------------------------------------------------------------------------------ # Test settings #------------------------------------------------------------------------------ @@ -46,157 +31,6 @@ def test_temp_user_dir(temp_user_dir): assert _settings.phy_user_dir() == temp_user_dir -def test_recursive_dirs(): - dirs = list(_recursive_dirs()) - assert len(dirs) >= 5 - root = op.join(op.realpath(op.dirname(__file__)), '../../') - for dir in dirs: - dir = op.relpath(dir, root) - assert '.' not in dir - assert '_' not in dir - - -def test_base_settings(): - s = BaseSettings() - - # Namespaces are mandatory. - with raises(KeyError): - s['a'] - - s['a'] = 3 - assert s['a'] == 3 - - -def test_base_settings_wrong_extension(tempdir): - path = op.join(tempdir, 'test') - with open(path, 'w'): - pass - s = BaseSettings() - s.load(path=path) - - -def test_base_settings_file(tempdir, settings): - ext, settings = settings - path = op.join(tempdir, 'test.' + ext) - with open(path, 'w') as f: - f.write(settings) - - s = BaseSettings() - - s['a'] = 3 - s['c'] = 6 - assert s['a'] == 3 - - # Warning: wrong path. - s.load(path=None) - - # Now, load the settings file. - s.load(path=path) - assert s['a'] == 4 - assert s['b'] == 5 - assert s['c'] == 6 - assert s['d'] == {'k1': 2, 'k2': '3'} - - s = BaseSettings() - s['d'] = {'k2': 30, 'k3': 40} - s.load(path=path) - assert s['d'] == {'k1': 2, 'k2': '3', 'k3': 40} - - -def test_base_settings_invalid(tempdir, settings): - ext, settings = settings - settings = settings[:-2] - path = op.join(tempdir, 'test.' + ext) - with open(path, 'w') as f: - f.write(settings) - - s = BaseSettings() - s.load(path) - assert 'a' not in s - - -def test_internal_settings(tempdir): - path = op.join(tempdir, 'test.json') - - s = BaseSettings() - - # Set the 'test' namespace. - s['a'] = 3 - s['c'] = 6 - assert s['a'] == 3 - assert s['c'] == 6 - - s.save(path) - assert s['a'] == 3 - assert s['c'] == 6 - - s = BaseSettings() - with raises(KeyError): - s['a'] - - s.load(path) - assert s['a'] == 3 - assert s['c'] == 6 - - -def test_settings_nodir(): - Settings() - - -def test_settings_manager(tempdir, tempdir_bis): - tempdir_exp = tempdir_bis - sm = Settings(tempdir) - - # Check paths. - assert sm.phy_user_dir == tempdir - assert sm.internal_settings_path == op.join(tempdir, - 'internal_settings.json') - assert sm.user_settings_path == op.join(tempdir, 'user_settings.py') - - # User settings. - with raises(KeyError): - sm['a'] - assert sm.get('a', None) is None - # Artificially populate the user settings. - sm._bs._store['a'] = 3 - assert sm['a'] == 3 - assert sm.get('a') == 3 - - # Internal settings. - sm['c'] = 5 - assert sm['c'] == 5 - - # Set an experiment path. - path = op.join(tempdir_exp, 'myexperiment.dat') - sm.on_open(path) - assert op.realpath(sm.exp_path) == op.realpath(path) - assert sm.exp_name == 'myexperiment' - assert (op.realpath(sm.exp_settings_dir) == - op.realpath(op.join(tempdir_exp, 'myexperiment.phy'))) - assert (op.realpath(sm.exp_settings_path) == - op.realpath(op.join(tempdir_exp, 'myexperiment.phy/' - 'user_settings.py'))) - - # User settings. - assert sm['a'] == 3 - sm._bs._store['a'] = 30 - assert sm['a'] == 30 - - # Internal settings. - sm['c'] = 50 - assert sm['c'] == 50 - - # Check persistence. - sm.save() - sm = Settings(tempdir) - sm.on_open(path) - assert sm['c'] == 50 - assert 'a' not in sm - - assert str(sm).startswith(' Date: Fri, 16 Oct 2015 08:16:58 +0200 Subject: [PATCH 0311/1059] Move settings module to config --- phy/io/datasets.py | 2 +- phy/utils/__init__.py | 2 +- phy/utils/cli.py | 2 +- phy/utils/{settings.py => config.py} | 5 +---- phy/utils/plugin.py | 4 ++-- phy/utils/testing.py | 2 +- phy/utils/tests/conftest.py | 14 +++++++------- .../tests/{test_settings.py => test_config.py} | 17 ++++++++--------- phy/utils/tests/test_plugin.py | 2 +- 9 files changed, 23 insertions(+), 27 deletions(-) rename phy/utils/{settings.py => config.py} (92%) rename phy/utils/tests/{test_settings.py => test_config.py} (84%) diff --git a/phy/io/datasets.py b/phy/io/datasets.py index 3b9a47918..4b9e8f36b 100644 --- a/phy/io/datasets.py +++ b/phy/io/datasets.py @@ -12,7 +12,7 @@ import os.path as op from phy.utils.event import ProgressReporter -from phy.utils.settings import phy_user_dir, _ensure_dir_exists +from phy.utils.config import phy_user_dir, _ensure_dir_exists logger = logging.getLogger(__name__) diff --git a/phy/utils/__init__.py b/phy/utils/__init__.py index 4bad955ab..5355a3aad 100644 --- a/phy/utils/__init__.py +++ b/phy/utils/__init__.py @@ -6,4 +6,4 @@ from ._types import (_is_array_like, _as_array, _as_tuple, _as_list, Bunch, _is_list) from .event import EventEmitter, ProgressReporter -from .settings import _ensure_dir_exists +from .config import _ensure_dir_exists diff --git a/phy/utils/cli.py b/phy/utils/cli.py index 947addec1..0b05c52ee 100644 --- a/phy/utils/cli.py +++ b/phy/utils/cli.py @@ -35,7 +35,7 @@ def phy(ctx): def load_cli_plugins(cli): """Load all plugins and attach them to a CLI object.""" - from .settings import load_master_config + from .config import load_master_config from .plugin import get_all_plugins config = load_master_config() diff --git a/phy/utils/settings.py b/phy/utils/config.py similarity index 92% rename from phy/utils/settings.py rename to phy/utils/config.py index 462ab4489..66c48a0da 100644 --- a/phy/utils/settings.py +++ b/phy/utils/config.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -"""Settings.""" +"""Config.""" #------------------------------------------------------------------------------ # Imports @@ -10,11 +10,8 @@ import os import os.path as op -from six import string_types from traitlets.config import Config, PyFileConfigLoader -from ._misc import _load_json, _save_json, _read_python - logger = logging.getLogger(__name__) diff --git a/phy/utils/plugin.py b/phy/utils/plugin.py index c121628cc..e6c994cc8 100644 --- a/phy/utils/plugin.py +++ b/phy/utils/plugin.py @@ -18,7 +18,7 @@ from six import with_metaclass -from . import settings +from . import config logger = logging.getLogger(__name__) @@ -110,7 +110,7 @@ def _builtin_plugins_dir(): def _user_plugins_dir(): - return op.expanduser(op.join(settings.phy_user_dir(), 'plugins/')) + return op.expanduser(op.join(config.phy_user_dir(), 'plugins/')) def get_all_plugins(config=None): diff --git a/phy/utils/testing.py b/phy/utils/testing.py index cb53064e1..020b4d275 100644 --- a/phy/utils/testing.py +++ b/phy/utils/testing.py @@ -22,7 +22,7 @@ from six.moves import builtins from ._types import _is_array_like -from .settings import _ensure_dir_exists +from .config import _ensure_dir_exists logger = logging.getLogger(__name__) diff --git a/phy/utils/tests/conftest.py b/phy/utils/tests/conftest.py index 79ef0eba9..7026bc466 100644 --- a/phy/utils/tests/conftest.py +++ b/phy/utils/tests/conftest.py @@ -18,22 +18,22 @@ def temp_user_dir(tempdir): """NOTE: the user directory should be loaded with: ```python - from .. import settings - settings.phy_user_dir() + from .. import config + config.phy_user_dir() ``` and not: ```python - from settings import phy_user_dir + from config import phy_user_dir ``` Otherwise, the monkey patching hack in tests won't work. """ - from phy.utils import settings + from phy.utils import config - user_dir = settings.phy_user_dir - settings.phy_user_dir = lambda: tempdir + user_dir = config.phy_user_dir + config.phy_user_dir = lambda: tempdir yield tempdir - settings.phy_user_dir = user_dir + config.phy_user_dir = user_dir diff --git a/phy/utils/tests/test_settings.py b/phy/utils/tests/test_config.py similarity index 84% rename from phy/utils/tests/test_settings.py rename to phy/utils/tests/test_config.py index ee3b46168..02edcbc36 100644 --- a/phy/utils/tests/test_settings.py +++ b/phy/utils/tests/test_config.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -"""Test settings.""" +"""Test config.""" #------------------------------------------------------------------------------ # Imports @@ -9,26 +9,25 @@ import os.path as op from textwrap import dedent -from pytest import raises, yield_fixture from traitlets import Float from traitlets.config import Configurable -from .. import settings as _settings -from ..settings import (_load_config, - load_master_config, - ) +from .. import config as _config +from ..config import (_load_config, + load_master_config, + ) #------------------------------------------------------------------------------ -# Test settings +# Test config #------------------------------------------------------------------------------ def test_phy_user_dir(): - assert _settings.phy_user_dir().endswith('.phy/') + assert _config.phy_user_dir().endswith('.phy/') def test_temp_user_dir(temp_user_dir): - assert _settings.phy_user_dir() == temp_user_dir + assert _config.phy_user_dir() == temp_user_dir #------------------------------------------------------------------------------ diff --git a/phy/utils/tests/test_plugin.py b/phy/utils/tests/test_plugin.py index 8f15d3036..13855b0ec 100644 --- a/phy/utils/tests/test_plugin.py +++ b/phy/utils/tests/test_plugin.py @@ -18,7 +18,7 @@ get_all_plugins, ) from .._misc import _write_text -from ..settings import load_master_config +from ..config import load_master_config #------------------------------------------------------------------------------ From 5b6a378b58958c1eaaa93aecd58a77bc19efaea1 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 16 Oct 2015 08:28:05 +0200 Subject: [PATCH 0312/1059] WIP: test JSON config --- phy/utils/_misc.py | 4 +-- phy/utils/config.py | 11 +++++-- phy/utils/tests/test_config.py | 52 +++++++++++++++++++++++++--------- phy/utils/tests/test_plugin.py | 5 ++-- 4 files changed, 52 insertions(+), 20 deletions(-) diff --git a/phy/utils/_misc.py b/phy/utils/_misc.py index 514817fe3..3423c639b 100644 --- a/phy/utils/_misc.py +++ b/phy/utils/_misc.py @@ -118,8 +118,8 @@ def _read_python(path): return metadata -def _write_text(path, contents, *args, **kwargs): - contents = dedent(contents.format(*args, **kwargs)) +def _write_text(path, contents): + contents = dedent(contents) dir_path = op.dirname(path) if not op.exists(dir_path): os.mkdir(dir_path) diff --git a/phy/utils/config.py b/phy/utils/config.py index 66c48a0da..1f76b808a 100644 --- a/phy/utils/config.py +++ b/phy/utils/config.py @@ -10,7 +10,10 @@ import os import os.path as op -from traitlets.config import Config, PyFileConfigLoader +from traitlets.config import (Config, + PyFileConfigLoader, + JSONFileConfigLoader, + ) logger = logging.getLogger(__name__) @@ -34,7 +37,11 @@ def _load_config(path): return {} path = op.realpath(path) dirpath, filename = op.split(path) - config = PyFileConfigLoader(filename, dirpath).load_config() + file_ext = op.splitext(path)[1] + if file_ext == '.py': + config = PyFileConfigLoader(filename, dirpath).load_config() + elif file_ext == '.json': + config = JSONFileConfigLoader(filename, dirpath).load_config() return config diff --git a/phy/utils/tests/test_config.py b/phy/utils/tests/test_config.py index 02edcbc36..c81935b64 100644 --- a/phy/utils/tests/test_config.py +++ b/phy/utils/tests/test_config.py @@ -9,10 +9,12 @@ import os.path as op from textwrap import dedent +from pytest import yield_fixture from traitlets import Float from traitlets.config import Configurable from .. import config as _config +from .._misc import _write_text from ..config import (_load_config, load_master_config, ) @@ -34,25 +36,49 @@ def test_temp_user_dir(temp_user_dir): # Config tests #------------------------------------------------------------------------------ -def test_load_config(tempdir): - - class MyConfigurable(Configurable): - my_var = Float(0.0, config=True) +@yield_fixture +def py_config(tempdir): + # Create and load a config file. + config_contents = """ + c = get_config() + c.MyConfigurable.my_var = 1.0 + """ + path = op.join(tempdir, 'config.py') + _write_text(path, config_contents) + yield path - assert MyConfigurable().my_var == 0.0 +@yield_fixture +def json_config(tempdir): # Create and load a config file. - config_contents = dedent(""" - c = get_config() + config_contents = """ + { + "MyConfigurable": { + "my_var": 1.0 + } + } + """ + path = op.join(tempdir, 'config.json') + _write_text(path, config_contents) + yield path - c.MyConfigurable.my_var = 1.0 - """) - path = op.join(tempdir, 'config.py') - with open(path, 'w') as f: - f.write(config_contents) +@yield_fixture(params=['python', 'json']) +def config(py_config, json_config, request): + if request.param == 'python': + yield py_config + elif request.param == 'json': + yield json_config + + +def test_load_config(config): + + class MyConfigurable(Configurable): + my_var = Float(0.0, config=True) + + assert MyConfigurable().my_var == 0.0 - c = _load_config(path) + c = _load_config(config) assert c.MyConfigurable.my_var == 1.0 # Create a new MyConfigurable instance. diff --git a/phy/utils/tests/test_plugin.py b/phy/utils/tests/test_plugin.py index 13855b0ec..96abdfd1c 100644 --- a/phy/utils/tests/test_plugin.py +++ b/phy/utils/tests/test_plugin.py @@ -52,11 +52,10 @@ def _write_my_plugins_dir_in_config(temp_user_dir): # Now, we specify the path to the plugin in the phy config file. config_contents = """ c = get_config() - c.Plugins.dirs = ['{}'] + c.Plugins.dirs = ['%s'] """ _write_text(op.join(temp_user_dir, 'phy_config.py'), - config_contents, - op.join(temp_user_dir, 'my_plugins/')) + config_contents % op.join(temp_user_dir, 'my_plugins/')) #------------------------------------------------------------------------------ From e308b5cb2661efa8eef7ec0e68e788bf7252409b Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 16 Oct 2015 08:51:32 +0200 Subject: [PATCH 0313/1059] Save config in JSON file --- phy/utils/cli.py | 2 +- phy/utils/config.py | 18 +++++++++++++++--- phy/utils/tests/test_config.py | 19 ++++++++++++++++++- 3 files changed, 34 insertions(+), 5 deletions(-) diff --git a/phy/utils/cli.py b/phy/utils/cli.py index 0b05c52ee..20e6d7299 100644 --- a/phy/utils/cli.py +++ b/phy/utils/cli.py @@ -43,7 +43,7 @@ def load_cli_plugins(cli): # TODO: try/except to avoid crashing if a plugin is broken. for plugin in plugins: - if not hasattr(plugin, 'attach_to_cli'): + if not hasattr(plugin, 'attach_to_cli'): # pragma: no cover continue logger.info("Attach plugin `%s`.", plugin.__name__) # NOTE: plugin is a class, so we need to instantiate it. diff --git a/phy/utils/config.py b/phy/utils/config.py index 1f76b808a..61863b06c 100644 --- a/phy/utils/config.py +++ b/phy/utils/config.py @@ -28,13 +28,16 @@ def phy_user_dir(): def _ensure_dir_exists(path): + """Ensure a directory exists.""" if not op.exists(path): os.makedirs(path) + assert op.exists(path) and op.isdir(path) def _load_config(path): + """Load a Python or JSON config file.""" if not op.exists(path): - return {} + return Config() path = op.realpath(path) dirpath, filename = op.split(path) file_ext = op.splitext(path)[1] @@ -46,10 +49,19 @@ def _load_config(path): def load_master_config(user_dir=None): - """Load a master Config file from `~/.phy/phy_config.py`.""" + """Load a master Config file from `~/.phy/phy_config.py|json`.""" user_dir = user_dir or phy_user_dir() c = Config() - paths = [op.join(user_dir, 'phy_config.py')] + paths = [op.join(user_dir, 'phy_config.json'), + op.join(user_dir, 'phy_config.py')] for path in paths: c.update(_load_config(path)) return c + + +def save_config(path, config): + """Save a config object to a JSON file.""" + import json + config['version'] = 1 + with open(path, 'w') as f: + json.dump(config, f) diff --git a/phy/utils/tests/test_config.py b/phy/utils/tests/test_config.py index c81935b64..6d8626938 100644 --- a/phy/utils/tests/test_config.py +++ b/phy/utils/tests/test_config.py @@ -15,8 +15,10 @@ from .. import config as _config from .._misc import _write_text -from ..config import (_load_config, +from ..config import (_ensure_dir_exists, + _load_config, load_master_config, + save_config, ) @@ -28,6 +30,12 @@ def test_phy_user_dir(): assert _config.phy_user_dir().endswith('.phy/') +def test_ensure_dir_exists(tempdir): + path = op.join(tempdir, 'a/b/c') + _ensure_dir_exists(path) + assert op.isdir(path) + + def test_temp_user_dir(temp_user_dir): assert _config.phy_user_dir() == temp_user_dir @@ -102,3 +110,12 @@ def test_load_master_config(temp_user_dir): # Load the master config file. c = load_master_config() assert c.MyConfigurable.my_var == 1. + + +def test_save_config(tempdir): + c = {'A': {'b': 3.}} + path = op.join(tempdir, 'config.json') + save_config(path, c) + + c1 = _load_config(path) + assert c1.A.b == 3. From fb11b36fc1bd8bc6aad3793510b41bf9b87b1e41 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 16 Oct 2015 08:53:21 +0200 Subject: [PATCH 0314/1059] Typo --- phy/utils/plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/phy/utils/plugin.py b/phy/utils/plugin.py index e6c994cc8..a6190bbdd 100644 --- a/phy/utils/plugin.py +++ b/phy/utils/plugin.py @@ -117,7 +117,7 @@ def get_all_plugins(config=None): """Load all builtin and user plugins.""" # By default, builtin and default user plugin. dirs = [_builtin_plugins_dir(), _user_plugins_dir()] - # Add Plugins.dir from the optionally-passed config object. + # Add Plugins.dirs from the optionally-passed config object. if config: dirs += config.Plugins.dirs return [plugin for (plugin,) in discover_plugins(dirs)] From 779e6f40bde3b1172df0c6b49f6d1926e1133fad Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 16 Oct 2015 08:59:02 +0200 Subject: [PATCH 0315/1059] Try adding OS X on travis --- .travis.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.travis.yml b/.travis.yml index 6a39c9c27..9a2c04cd3 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,4 +1,7 @@ language: python +os: + - linux + - osx sudo: false python: - "2.7" From 82a5c64a22ee9a059253620112e456c81ef52136 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 16 Oct 2015 09:02:41 +0200 Subject: [PATCH 0316/1059] Remove OS X --- .travis.yml | 3 --- 1 file changed, 3 deletions(-) diff --git a/.travis.yml b/.travis.yml index 9a2c04cd3..6a39c9c27 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,7 +1,4 @@ language: python -os: - - linux - - osx sudo: false python: - "2.7" From 6e3526a6edfa05ca1be239b492e8f6bd02ef83ae Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 16 Oct 2015 09:05:13 +0200 Subject: [PATCH 0317/1059] Rename GUI events --- phy/gui/gui.py | 8 ++++---- phy/gui/tests/test_gui.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/phy/gui/gui.py b/phy/gui/gui.py index eee58898f..77dda7469 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -53,8 +53,8 @@ class GUI(QtGui.QMainWindow): Events ------ - close_gui - show_gui + close + show Note ---- @@ -109,7 +109,7 @@ def unconnect_(self, *args, **kwargs): def closeEvent(self, e): """Qt slot when the window is closed.""" - res = self.emit('close_gui') + res = self.emit('close') # Discard the close event if False is returned by one of the callback # functions. if False in res: # pragma: no cover @@ -119,7 +119,7 @@ def closeEvent(self, e): def show(self): """Show the window.""" - self.emit('show_gui') + self.emit('show') super(GUI, self).show() # Views diff --git a/phy/gui/tests/test_gui.py b/phy/gui/tests/test_gui.py index da85e23de..d96bc2e48 100644 --- a/phy/gui/tests/test_gui.py +++ b/phy/gui/tests/test_gui.py @@ -127,9 +127,9 @@ def test_gui_1(qtbot): # Increase coverage. @gui.connect_ - def on_show_gui(): + def on_show(): pass - gui.unconnect_(on_show_gui) + gui.unconnect_(on_show) qtbot.keyPress(gui, Qt.Key_Control) qtbot.keyRelease(gui, Qt.Key_Control) @@ -186,7 +186,7 @@ def test_gui_state(qtbot): gui.add_view(_create_canvas(), 'view2') @gui.connect_ - def on_close_gui(): + def on_close(): _gs.append(gui.save_geometry_state()) gui.show() @@ -208,7 +208,7 @@ def on_close_gui(): gui.add_view(_create_canvas(), 'view2') @gui.connect_ - def on_show_gui(): + def on_show(): gui.restore_geometry_state(_gs[0]) qtbot.addWidget(gui) From 99bd874a5de106a7caea4ac05e36ba83565218f7 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 16 Oct 2015 11:17:07 +0200 Subject: [PATCH 0318/1059] Fix bug --- .gitignore | 1 + phy/utils/cli.py | 8 +++++--- phy/utils/plugin.py | 2 +- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index a7c023ad6..922883500 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +contrib data doc docker diff --git a/phy/utils/cli.py b/phy/utils/cli.py index 20e6d7299..9236eb411 100644 --- a/phy/utils/cli.py +++ b/phy/utils/cli.py @@ -23,10 +23,12 @@ @click.group() @click.version_option(version=phy.__version_git__) -@click.help_option() +@click.help_option('-h', '--help') +@click.option('-d', '--debug', is_flag=True) @click.pass_context -def phy(ctx): - pass +def phy(ctx, debug): + if debug: + logging.setLevel('DEBUG') #------------------------------------------------------------------------------ diff --git a/phy/utils/plugin.py b/phy/utils/plugin.py index a6190bbdd..5479cb450 100644 --- a/phy/utils/plugin.py +++ b/phy/utils/plugin.py @@ -79,7 +79,7 @@ def discover_plugins(dirs): """ # Scan all subdirectories recursively. for plugin_dir in dirs: - plugin_dir = op.realpath(plugin_dir) + plugin_dir = op.realpath(op.expanduser(plugin_dir)) for subdir, dirs, files in os.walk(plugin_dir): # Skip test folders. base = op.basename(subdir) From 40164453f2a110a2fe9f95078d50d46d453c1d31 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 16 Oct 2015 11:35:07 +0200 Subject: [PATCH 0319/1059] Fix DEBUG --- phy/__init__.py | 5 +++-- phy/utils/cli.py | 10 ++++++---- phy/utils/config.py | 6 ++++-- 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/phy/__init__.py b/phy/__init__.py index e3813bb9b..cf65f316b 100644 --- a/phy/__init__.py +++ b/phy/__init__.py @@ -59,9 +59,10 @@ def add_default_handler(level='INFO'): logger.addHandler(handler) +DEBUG = False if '--debug' in sys.argv: # pragma: no cover - add_default_handler('DEBUG') - logger.info("Activate DEBUG level.") + DEBUG = True + sys.argv.remove('--debug') # Force dask to use the synchronous scheduler: we'll use ipyparallel diff --git a/phy/utils/cli.py b/phy/utils/cli.py index 9236eb411..2f06bf4fe 100644 --- a/phy/utils/cli.py +++ b/phy/utils/cli.py @@ -13,6 +13,7 @@ import click import phy +from phy import add_default_handler, DEBUG logger = logging.getLogger(__name__) @@ -21,14 +22,15 @@ # CLI tool #------------------------------------------------------------------------------ +add_default_handler('DEBUG' if DEBUG else 'INFO') + + @click.group() @click.version_option(version=phy.__version_git__) @click.help_option('-h', '--help') -@click.option('-d', '--debug', is_flag=True) @click.pass_context -def phy(ctx, debug): - if debug: - logging.setLevel('DEBUG') +def phy(ctx): + pass #------------------------------------------------------------------------------ diff --git a/phy/utils/config.py b/phy/utils/config.py index 61863b06c..0e312cf7f 100644 --- a/phy/utils/config.py +++ b/phy/utils/config.py @@ -42,9 +42,11 @@ def _load_config(path): dirpath, filename = op.split(path) file_ext = op.splitext(path)[1] if file_ext == '.py': - config = PyFileConfigLoader(filename, dirpath).load_config() + config = PyFileConfigLoader(filename, dirpath, + log=logger).load_config() elif file_ext == '.json': - config = JSONFileConfigLoader(filename, dirpath).load_config() + config = JSONFileConfigLoader(filename, dirpath, + log=logger).load_config() return config From 3880a2a8d6eb9d3de239fbbf53c02a0396a30a08 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 16 Oct 2015 12:06:21 +0200 Subject: [PATCH 0320/1059] WIP --- phy/cluster/__init__.py | 3 +++ phy/cluster/manual/_utils.py | 1 + phy/cluster/manual/gui_plugin.py | 2 -- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/phy/cluster/__init__.py b/phy/cluster/__init__.py index ff6b165f7..78194577b 100644 --- a/phy/cluster/__init__.py +++ b/phy/cluster/__init__.py @@ -2,3 +2,6 @@ # flake8: noqa """Automatic and manual clustering facilities.""" + +from . import algorithms +from . import manual diff --git a/phy/cluster/manual/_utils.py b/phy/cluster/manual/_utils.py index 984d970b3..98ec04a1d 100644 --- a/phy/cluster/manual/_utils.py +++ b/phy/cluster/manual/_utils.py @@ -37,6 +37,7 @@ def create_cluster_meta(cluster_groups): meta = ClusterMeta() meta.add_field('group') + cluster_groups = cluster_groups or {} data = {c: {'group': v} for c, v in cluster_groups.items()} meta.from_dict(data) diff --git a/phy/cluster/manual/gui_plugin.py b/phy/cluster/manual/gui_plugin.py index a321c8cbf..e33fbab65 100644 --- a/phy/cluster/manual/gui_plugin.py +++ b/phy/cluster/manual/gui_plugin.py @@ -198,8 +198,6 @@ def on_reset(): actions.add(callback=self.undo) actions.add(callback=self.redo) - actions.reset() - def attach_to_gui(self, gui): self.gui = gui From 8608a34c10b6d2642ca455f6ad81f9f5af4acf4f Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 16 Oct 2015 12:06:44 +0200 Subject: [PATCH 0321/1059] WIP --- phy/utils/plugin.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/phy/utils/plugin.py b/phy/utils/plugin.py index 5479cb450..035c231a2 100644 --- a/phy/utils/plugin.py +++ b/phy/utils/plugin.py @@ -32,7 +32,7 @@ class IPluginRegistry(type): def __init__(cls, name, bases, attrs): if name != 'IPlugin': - logger.debug("Register plugin %s.", name) + logger.debug("Register plugin `%s`.", name) plugin_tuple = (cls,) if plugin_tuple not in IPluginRegistry.plugins: IPluginRegistry.plugins.append(plugin_tuple) @@ -85,12 +85,12 @@ def discover_plugins(dirs): base = op.basename(subdir) if 'test' in base or '__' in base: # pragma: no cover continue - logger.debug("Scanning %s.", subdir) + logger.debug("Scanning `%s`.", subdir) for filename in files: if (filename.startswith('__') or not filename.endswith('.py')): continue # pragma: no cover - logger.debug("Found %s.", filename) + logger.debug("Found plugin module `%s`.", filename) path = os.path.join(subdir, filename) modname, ext = op.splitext(filename) file, path, descr = imp.find_module(modname, [subdir]) From dfe9f8de2e0c8335e9501ae66553720802fc4b8f Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 16 Oct 2015 12:15:13 +0200 Subject: [PATCH 0322/1059] Update logging in actions --- phy/gui/actions.py | 16 ++++++++++++---- phy/gui/tests/test_actions.py | 3 ++- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/phy/gui/actions.py b/phy/gui/actions.py index 6f33b091b..afde99d7c 100644 --- a/phy/gui/actions.py +++ b/phy/gui/actions.py @@ -190,7 +190,7 @@ def add(self, callback=None, name=None, shortcut=None, alias=None, # Log the creation of the action. if not name.startswith('_'): - logger.debug("Add action `%s`, alias `%s`, shortcut %s.", + logger.debug("Add action `%s`, alias `%s`, shortcut `%s`.", name, alias, shortcut or '') if callback: @@ -208,6 +208,8 @@ def change_shortcut(self, name, shortcut): assert name in self._actions, "This action doesn't exist." action = self._actions[name] action.shortcut = shortcut + logger.debug("Change shortcut of action `%s` to shortcut `%s`.", + name, shortcut or '') _set_shortcut(action.qaction, shortcut) def run(self, action, *args): @@ -329,7 +331,9 @@ def _backspace(self): def _enter(self): """Disable the snippet mode and execute the command.""" command = self.command - logger.debug("Snippet keystroke `Enter`.") + logger.log(5, "Snippet keystroke `Enter`.") + # NOTE: we need to set back the actions (mode_off) before running + # the command. self.mode_off() self.run(command) @@ -347,7 +351,7 @@ def _create_snippet_actions(self): def _make_func(char): def callback(): - logger.debug("Snippet keystroke `%s`.", char) + logger.log(5, "Snippet keystroke `%s`.", char) self.command += char return callback @@ -375,7 +379,11 @@ def run(self, snippet): snippet = snippet[1:] snippet_args = _parse_snippet(snippet) alias = snippet_args[0] - name = self._actions.get_name(alias) + try: + name = self._actions.get_name(alias) + except ValueError: + logger.warn("Snippet `%s` cannot be found.", alias) + return assert name func = getattr(self._actions, name) try: diff --git a/phy/gui/tests/test_actions.py b/phy/gui/tests/test_actions.py index e3efd9115..d23c7fdca 100644 --- a/phy/gui/tests/test_actions.py +++ b/phy/gui/tests/test_actions.py @@ -134,8 +134,9 @@ def test(arg): snippets.attach(None, actions) actions.reset() - with raises(ValueError): + with captured_logging() as buf: snippets.run(':t1') + assert 'cannot be found' in buf.getvalue().lower() with captured_logging() as buf: snippets.run(':t') From 96d5177642f2d97fa7d9998571e82e071cbbcb91 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 16 Oct 2015 12:29:02 +0200 Subject: [PATCH 0323/1059] WIP: remove change_shortcut --- phy/gui/actions.py | 13 +++++-------- phy/gui/tests/test_actions.py | 5 ----- 2 files changed, 5 insertions(+), 13 deletions(-) diff --git a/phy/gui/actions.py b/phy/gui/actions.py index afde99d7c..b263f74f6 100644 --- a/phy/gui/actions.py +++ b/phy/gui/actions.py @@ -140,6 +140,11 @@ def on_reset(): def exit(): gui.close() + # Reset the actions when the GUI is first shown. + @gui.connect_ + def on_show(): + self.reset() + def _create_action_bunch(self, callback=None, name=None, shortcut=None, alias=None, checkable=False, checked=False): @@ -204,14 +209,6 @@ def get_name(self, alias_or_name): return name raise ValueError("Action `{}` doesn't exist.".format(alias_or_name)) - def change_shortcut(self, name, shortcut): - assert name in self._actions, "This action doesn't exist." - action = self._actions[name] - action.shortcut = shortcut - logger.debug("Change shortcut of action `%s` to shortcut `%s`.", - name, shortcut or '') - _set_shortcut(action.qaction, shortcut) - def run(self, action, *args): """Run an action, specified by its name or object.""" if isinstance(action, string_types): diff --git a/phy/gui/tests/test_actions.py b/phy/gui/tests/test_actions.py index d23c7fdca..ed449a456 100644 --- a/phy/gui/tests/test_actions.py +++ b/phy/gui/tests/test_actions.py @@ -64,11 +64,6 @@ def show_my_shortcuts(): assert 'show_my_shortcuts' in _captured[0] assert ': h' in _captured[0] - actions.change_shortcut('show_my_shortcuts', 'l') - actions.show_my_shortcuts() - assert 'show_my_shortcuts' in _captured[0] - assert ': l' in _captured[-1] - with raises(ValueError): assert actions.get_name('e') assert actions.get_name('t') == 'test' From c604c7d6c73ec675a4948ec50f746aa98cca34b7 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 16 Oct 2015 12:45:38 +0200 Subject: [PATCH 0324/1059] WIP: default shortcuts --- phy/cluster/manual/gui_plugin.py | 51 +++++++++++++++------ phy/cluster/manual/tests/test_gui_plugin.py | 2 + phy/gui/actions.py | 2 +- 3 files changed, 41 insertions(+), 14 deletions(-) diff --git a/phy/cluster/manual/gui_plugin.py b/phy/cluster/manual/gui_plugin.py index e33fbab65..f219ba979 100644 --- a/phy/cluster/manual/gui_plugin.py +++ b/phy/cluster/manual/gui_plugin.py @@ -16,6 +16,7 @@ from .clustering import Clustering from .wizard import Wizard from phy.gui.actions import Actions, Snippets +from phy.gui.qt import QtGui from phy.io.array import select_spikes from phy.utils.plugin import IPlugin @@ -153,6 +154,25 @@ class ManualClustering(IPlugin): save_requested(spike_clusters, cluster_groups) """ + + default_shortcuts = { + 'save': QtGui.QKeySequence.Save, + # Wizard actions. + 'reset_wizard': 'ctrl+w', + 'next': 'space', + 'previous': 'shift+space', + 'reset_wizard': 'ctrl+alt+space', + 'first': QtGui.QKeySequence.MoveToStartOfLine, + 'last': QtGui.QKeySequence.MoveToEndOfLine, + 'pin': 'return', + 'unpin': QtGui.QKeySequence.Back, + # Clustering actions. + 'merge': 'g', + 'split': 'k', + 'undo': QtGui.QKeySequence.Undo, + 'redo': QtGui.QKeySequence.Redo, + } + def __init__(self, spike_clusters=None, cluster_groups=None, n_spikes_max_per_cluster=100, @@ -172,6 +192,11 @@ def __init__(self, spike_clusters=None, # Create the actions. self._create_actions() + def _add_action(self, callback, name=None, alias=None): + name = name or callback.__name__ + shortcut = self.default_shortcuts.get(name, None) + self.actions.add(callback=callback, name=name, shortcut=shortcut) + def _create_actions(self): self.actions = actions = Actions() self.snippets = Snippets() @@ -180,23 +205,23 @@ def _create_actions(self): @actions.connect def on_reset(): # Selection. - actions.add(callback=self.select, alias='c') + self._add_action(self.select, alias='c') # Wizard. - actions.add(callback=self.wizard.restart, name='reset_wizard') - actions.add(callback=self.wizard.previous) - actions.add(callback=self.wizard.next) - actions.add(callback=self.wizard.next_by_quality) - actions.add(callback=self.wizard.next_by_similarity) - actions.add(callback=self.wizard.pin) - actions.add(callback=self.wizard.unpin) + self._add_action(self.wizard.restart, name='reset_wizard') + self._add_action(self.wizard.previous) + self._add_action(self.wizard.next) + self._add_action(self.wizard.next_by_quality) + self._add_action(self.wizard.next_by_similarity) + self._add_action(self.wizard.pin) + self._add_action(self.wizard.unpin) # Clustering. - actions.add(callback=self.merge) - actions.add(callback=self.split) - actions.add(callback=self.move) - actions.add(callback=self.undo) - actions.add(callback=self.redo) + self._add_action(self.merge) + self._add_action(self.split) + self._add_action(self.move) + self._add_action(self.undo) + self._add_action(self.redo) def attach_to_gui(self, gui): self.gui = gui diff --git a/phy/cluster/manual/tests/test_gui_plugin.py b/phy/cluster/manual/tests/test_gui_plugin.py index 4613d1820..0bd56f11b 100644 --- a/phy/cluster/manual/tests/test_gui_plugin.py +++ b/phy/cluster/manual/tests/test_gui_plugin.py @@ -49,6 +49,8 @@ def assert_selection(*cluster_ids): # pragma: no cover elif len(cluster_ids) >= 2: assert mc.wizard.match == cluster_ids[2] + mc.actions.reset() + yield mc, assert_selection diff --git a/phy/gui/actions.py b/phy/gui/actions.py index b263f74f6..c509fe463 100644 --- a/phy/gui/actions.py +++ b/phy/gui/actions.py @@ -136,7 +136,7 @@ def attach(self, gui): @self.connect def on_reset(): # Default exit action. - @self.add(shortcut='ctrl+q') + @self.add(shortcut=QtGui.QKeySequence.Quit) def exit(): gui.close() From 78e21db35b6df3f8e0bb58296d6ba9f938005efd Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 16 Oct 2015 12:47:10 +0200 Subject: [PATCH 0325/1059] WIP: shortcuts --- phy/cluster/manual/gui_plugin.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/phy/cluster/manual/gui_plugin.py b/phy/cluster/manual/gui_plugin.py index f219ba979..a0e30ecbc 100644 --- a/phy/cluster/manual/gui_plugin.py +++ b/phy/cluster/manual/gui_plugin.py @@ -176,10 +176,16 @@ class ManualClustering(IPlugin): def __init__(self, spike_clusters=None, cluster_groups=None, n_spikes_max_per_cluster=100, + shortcuts=None, ): self.n_spikes_max_per_cluster = n_spikes_max_per_cluster + # Load default shortcuts, and override any user shortcuts. + self.shortcuts = self.default_shortcuts.copy() + if shortcuts: + self.shortcuts.update(shortcuts) + # Create Clustering and ClusterMeta. self.clustering = Clustering(spike_clusters) self.cluster_meta = create_cluster_meta(cluster_groups) @@ -194,7 +200,7 @@ def __init__(self, spike_clusters=None, def _add_action(self, callback, name=None, alias=None): name = name or callback.__name__ - shortcut = self.default_shortcuts.get(name, None) + shortcut = self.shortcuts.get(name, None) self.actions.add(callback=callback, name=name, shortcut=shortcut) def _create_actions(self): From 1b156f6e9c3a3b895a5c7f7507890992ab0fe58a Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 16 Oct 2015 13:24:20 +0200 Subject: [PATCH 0326/1059] Refactor show shortcut --- phy/gui/actions.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/phy/gui/actions.py b/phy/gui/actions.py index c509fe463..358c51e9f 100644 --- a/phy/gui/actions.py +++ b/phy/gui/actions.py @@ -58,8 +58,10 @@ def _parse_snippet(s): # Show shortcut utility functions # ----------------------------------------------------------------------------- -def _show_shortcut(shortcut): - if isinstance(shortcut, string_types): +def _shortcut_string(shortcut): + if isinstance(shortcut, QtGui.QKeySequence.StandardKey): + return QtGui.QKeySequence(shortcut).toString().lower() + elif isinstance(shortcut, string_types): return shortcut elif isinstance(shortcut, (tuple, list)): return ', '.join(shortcut) @@ -72,7 +74,7 @@ def _show_shortcuts(shortcuts, name=None): name = ' for ' + name print('Keyboard shortcuts' + name) for name in sorted(shortcuts): - print('{0:<40}: {1:s}'.format(name, _show_shortcut(shortcuts[name]))) + print('{0:<40}: {1:s}'.format(name, _shortcut_string(shortcuts[name]))) print() @@ -195,9 +197,13 @@ def add(self, callback=None, name=None, shortcut=None, alias=None, # Log the creation of the action. if not name.startswith('_'): - logger.debug("Add action `%s`, alias `%s`, shortcut `%s`.", - name, alias, shortcut or '') - + if isinstance(shortcut, QtGui.QKeySequence.StandardKey): + shortcut = QtGui.QKeySequence(shortcut).toString().lower() + elif shortcut is None: + shortcut = '' + msg = "Add action `%s`, alias `%s`" % (name, alias) + msg += (", shortcut `%s`." % shortcut) if shortcut else '.' + logger.debug(msg) if callback: setattr(self, name, callback) return action From 12414085d3dfb12ce35fafe4c0b019c9a1cacf49 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 16 Oct 2015 13:29:37 +0200 Subject: [PATCH 0327/1059] Increase coverage --- phy/cluster/manual/tests/test_gui_plugin.py | 1 + phy/gui/tests/test_actions.py | 7 ++++++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/phy/cluster/manual/tests/test_gui_plugin.py b/phy/cluster/manual/tests/test_gui_plugin.py index 0bd56f11b..f0022fda0 100644 --- a/phy/cluster/manual/tests/test_gui_plugin.py +++ b/phy/cluster/manual/tests/test_gui_plugin.py @@ -31,6 +31,7 @@ def manual_clustering(qtbot, gui, cluster_ids, cluster_groups): mc = gui.attach('ManualClustering', spike_clusters=spike_clusters, cluster_groups=cluster_groups, + shortcuts={'undo': 'ctrl+z'}, ) _s = [] diff --git a/phy/gui/tests/test_actions.py b/phy/gui/tests/test_actions.py index ed449a456..98634d1e9 100644 --- a/phy/gui/tests/test_actions.py +++ b/phy/gui/tests/test_actions.py @@ -9,6 +9,7 @@ from pytest import raises, yield_fixture from ..actions import _show_shortcuts, Actions, Snippets, _parse_snippet +from phy.gui.qt import QtGui from phy.utils.testing import captured_output, captured_logging @@ -30,14 +31,18 @@ def snippets(): # Test actions #------------------------------------------------------------------------------ -def test_shortcuts(): +def test_shortcuts(qtbot): + # NOTE: a Qt application needs to be running so that we can use the + # KeySequence. shortcuts = { 'test_1': 'ctrl+t', 'test_2': ('ctrl+a', 'shift+b'), + 'test_3': QtGui.QKeySequence.Undo, } with captured_output() as (stdout, stderr): _show_shortcuts(shortcuts, 'test') assert 'ctrl+a, shift+b' in stdout.getvalue() + assert 'ctrl+z' in stdout.getvalue() def test_actions_simple(actions): From b8f7a5cf1be2653fd3d95fe27fff55b4603c8fa3 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 16 Oct 2015 13:34:33 +0200 Subject: [PATCH 0328/1059] Fix Python 2 bug --- phy/gui/actions.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/phy/gui/actions.py b/phy/gui/actions.py index 358c51e9f..fdad07128 100644 --- a/phy/gui/actions.py +++ b/phy/gui/actions.py @@ -60,7 +60,7 @@ def _parse_snippet(s): def _shortcut_string(shortcut): if isinstance(shortcut, QtGui.QKeySequence.StandardKey): - return QtGui.QKeySequence(shortcut).toString().lower() + return str(QtGui.QKeySequence(shortcut).toString()).lower() elif isinstance(shortcut, string_types): return shortcut elif isinstance(shortcut, (tuple, list)): @@ -197,10 +197,7 @@ def add(self, callback=None, name=None, shortcut=None, alias=None, # Log the creation of the action. if not name.startswith('_'): - if isinstance(shortcut, QtGui.QKeySequence.StandardKey): - shortcut = QtGui.QKeySequence(shortcut).toString().lower() - elif shortcut is None: - shortcut = '' + shortcut = _shortcut_string(shortcut) msg = "Add action `%s`, alias `%s`" % (name, alias) msg += (", shortcut `%s`." % shortcut) if shortcut else '.' logger.debug(msg) From 804ba09956b9b6193d61797caabab68d28a6ea19 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 16 Oct 2015 13:39:29 +0200 Subject: [PATCH 0329/1059] Bug fixes --- phy/cluster/manual/gui_plugin.py | 6 +++++- phy/gui/actions.py | 3 ++- phy/utils/cli.py | 2 +- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/phy/cluster/manual/gui_plugin.py b/phy/cluster/manual/gui_plugin.py index a0e30ecbc..9635c7ffc 100644 --- a/phy/cluster/manual/gui_plugin.py +++ b/phy/cluster/manual/gui_plugin.py @@ -10,6 +10,7 @@ import logging import numpy as np +from six import integer_types from ._history import GlobalHistory from ._utils import create_cluster_meta @@ -201,7 +202,8 @@ def __init__(self, spike_clusters=None, def _add_action(self, callback, name=None, alias=None): name = name or callback.__name__ shortcut = self.shortcuts.get(name, None) - self.actions.add(callback=callback, name=name, shortcut=shortcut) + self.actions.add(callback=callback, name=name, + shortcut=shortcut, alias=alias) def _create_actions(self): self.actions = actions = Actions() @@ -255,6 +257,8 @@ def on_start(): # ------------------------------------------------------------------------- def select(self, cluster_ids): + if isinstance(cluster_ids, integer_types): + cluster_ids = [cluster_ids] self.wizard.select(cluster_ids) # Clustering actions diff --git a/phy/gui/actions.py b/phy/gui/actions.py index fdad07128..ba76c6c79 100644 --- a/phy/gui/actions.py +++ b/phy/gui/actions.py @@ -200,7 +200,7 @@ def add(self, callback=None, name=None, shortcut=None, alias=None, shortcut = _shortcut_string(shortcut) msg = "Add action `%s`, alias `%s`" % (name, alias) msg += (", shortcut `%s`." % shortcut) if shortcut else '.' - logger.debug(msg) + logger.log(5, msg) if callback: setattr(self, name, callback) return action @@ -391,6 +391,7 @@ def run(self, snippet): func(*snippet_args[1:]) except Exception as e: logger.warn("Error when executing snippet: %s.", str(e)) + logger.exception(e) def is_mode_on(self): return self.command.startswith(':') diff --git a/phy/utils/cli.py b/phy/utils/cli.py index 2f06bf4fe..aeb9cc87c 100644 --- a/phy/utils/cli.py +++ b/phy/utils/cli.py @@ -49,7 +49,7 @@ def load_cli_plugins(cli): for plugin in plugins: if not hasattr(plugin, 'attach_to_cli'): # pragma: no cover continue - logger.info("Attach plugin `%s`.", plugin.__name__) + logger.info("Attach plugin `%s` to CLI.", plugin.__name__) # NOTE: plugin is a class, so we need to instantiate it. plugin().attach_to_cli(cli) From edc5e7e948d61520d9f486d5605dab67cc693903 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 16 Oct 2015 15:25:38 +0200 Subject: [PATCH 0330/1059] WIP: refactor actions --- phy/gui/actions.py | 86 ++++++++++++++++++++-------------------------- 1 file changed, 37 insertions(+), 49 deletions(-) diff --git a/phy/gui/actions.py b/phy/gui/actions.py index ba76c6c79..ea55a4c51 100644 --- a/phy/gui/actions.py +++ b/phy/gui/actions.py @@ -65,6 +65,7 @@ def _shortcut_string(shortcut): return shortcut elif isinstance(shortcut, (tuple, list)): return ', '.join(shortcut) + return '' def _show_shortcuts(shortcuts, name=None): @@ -82,20 +83,34 @@ def _show_shortcuts(shortcuts, name=None): # Actions # ----------------------------------------------------------------------------- -def _alias_name(name): +def _alias(name): # Get the alias from the character after & if it exists. alias = name[name.index('&') + 1] if '&' in name else name - name = name.replace('&', '') - return alias, name + return alias -def _set_shortcut(action, shortcut): - if not shortcut: - return +def _get_qt_shortcut(single_shortcut): + if (isinstance(single_shortcut, string_types) and + hasattr(QtGui.QKeySequence, single_shortcut)): + return getattr(QtGui.QKeySequence, single_shortcut) + return single_shortcut + + +def _get_qt_shortcuts(shortcut): + if shortcut is None: + return [] if not isinstance(shortcut, (tuple, list)): shortcut = [shortcut] - for key in shortcut: + return [_get_qt_shortcut(s) for s in shortcut] + + +def _create_qaction(gui, name, callback, qt_shortcuts): + # Create the QAction instance. + action = QtGui.QAction(name, gui) + action.triggered.connect(callback) + for key in qt_shortcuts: action.setShortcut(key) + return action class Actions(EventEmitter): @@ -147,63 +162,36 @@ def exit(): def on_show(): self.reset() - def _create_action_bunch(self, callback=None, name=None, shortcut=None, - alias=None, checkable=False, checked=False): - - # Create the QAction instance. - action = QtGui.QAction(name, self._gui) - action.triggered.connect(callback) - action.setCheckable(checkable) - action.setChecked(checked) - _set_shortcut(action, shortcut) - - # HACK: add the shortcut string to the QAction object so that - # it can be shown in show_shortcuts(). I don't manage to recover - # the key sequence string from a QAction using Qt. - shortcut = shortcut or '' - - return Bunch(qaction=action, name=name, alias=alias, - shortcut=shortcut, callback=callback) - - def add(self, callback=None, name=None, shortcut=None, alias=None, - checkable=False, checked=False): + def add(self, callback=None, name=None, shortcut=None, alias=None): """Add an action with a keyboard shortcut.""" + # TODO: add menu_name option and create menu bar if callback is None: # Allow to use either add(func) or @add or @add(...). return partial(self.add, name=name, shortcut=shortcut, - alias=alias, checkable=checkable, checked=checked) - - # TODO: add menu_name option and create menu bar + alias=alias) + assert callback # Get the name from the callback function if needed. - assert callback name = name or callback.__name__ + alias = alias or _alias(name) + name = name.replace('&', '') + qt_shortcuts = _get_qt_shortcuts(shortcut) - if alias is None: - alias, name = _alias_name(name) - + # Skip existing action. if name in self._actions: return - action = self._create_action_bunch(name=name, - alias=alias, - shortcut=shortcut, - callback=callback) - - # Register the action. + # Create and register the action. + action = _create_qaction(self._gui, name, callback, qt_shortcuts) + action_obj = Bunch(qaction=action, name=name, alias=alias, + shortcut=shortcut, callback=callback) if self._gui: - self._gui.addAction(action.qaction) - self._actions[name] = action + self._gui.addAction(action) + self._actions[name] = action_obj - # Log the creation of the action. - if not name.startswith('_'): - shortcut = _shortcut_string(shortcut) - msg = "Add action `%s`, alias `%s`" % (name, alias) - msg += (", shortcut `%s`." % shortcut) if shortcut else '.' - logger.log(5, msg) + # Set the callback method. if callback: setattr(self, name, callback) - return action def get_name(self, alias_or_name): """Return an action name from its alias or name.""" From 9db89e0e667d724cbfede3513438f0990b050f8f Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 16 Oct 2015 15:36:18 +0200 Subject: [PATCH 0331/1059] WIP --- phy/cluster/manual/gui_plugin.py | 33 +++++++++++++++++++++++++++++++- phy/cluster/manual/wizard.py | 3 +-- 2 files changed, 33 insertions(+), 3 deletions(-) diff --git a/phy/cluster/manual/gui_plugin.py b/phy/cluster/manual/gui_plugin.py index 9635c7ffc..e5be67aae 100644 --- a/phy/cluster/manual/gui_plugin.py +++ b/phy/cluster/manual/gui_plugin.py @@ -194,6 +194,36 @@ def __init__(self, spike_clusters=None, # Create the wizard and attach it to Clustering/ClusterMeta. self.wizard = Wizard() + + # Log the actions. + @self.clustering.connect + def on_cluster(up): + if up.history: + logger.info(up.history.title() + " cluster assign.") + elif up.description == 'merge': + logger.info("Merge clusters %s to %s.", + ', '.join(map(str, up.deleted)), + up.added[0]) + else: + # TODO: how many spikes? + logger.info("Assigned spikes.") + + @self.cluster_meta.connect # noqa + def on_cluster(up): + if up.history: + logger.info(up.history.title() + " move.") + else: + logger.info("Move clusters %s to %s.", + ', '.join(map(str, up.metadata_changed)), + up.metadata_value) + + @self.wizard.connect + def on_select(cluster_ids): + """When the wizard selects clusters, choose a spikes subset + and emit the `select` event on the GUI.""" + logger.debug("Select clusters %s.", + ', '.join(map(str, cluster_ids))) + _attach_wizard(self.wizard, self.clustering, self.cluster_meta) # Create the actions. @@ -202,7 +232,8 @@ def __init__(self, spike_clusters=None, def _add_action(self, callback, name=None, alias=None): name = name or callback.__name__ shortcut = self.shortcuts.get(name, None) - self.actions.add(callback=callback, name=name, + self.actions.add(callback=callback, + name=name, shortcut=shortcut, alias=alias) def _create_actions(self): diff --git a/phy/cluster/manual/wizard.py b/phy/cluster/manual/wizard.py index a0086ddc6..19674b25f 100644 --- a/phy/cluster/manual/wizard.py +++ b/phy/cluster/manual/wizard.py @@ -11,7 +11,7 @@ from operator import itemgetter from ._history import History -from phy.utils import EventEmitter, _is_array_like +from phy.utils import EventEmitter logger = logging.getLogger(__name__) @@ -232,7 +232,6 @@ def n_clusters(self): def select(self, cluster_ids, add_to_history=True): if cluster_ids is None: # pragma: no cover return - assert _is_array_like(cluster_ids) clusters = self.cluster_ids cluster_ids = [cluster for cluster in cluster_ids if cluster in clusters] From cd7cc93c466bd76c4bad358b28099c3209b7e69d Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 16 Oct 2015 17:08:20 +0200 Subject: [PATCH 0332/1059] WIP: refactor actions and snippets --- phy/gui/actions.py | 142 +++++++++++++++------------------- phy/gui/tests/test_actions.py | 103 ++++++++++++++++-------- phy/gui/tests/test_gui.py | 22 +++--- 3 files changed, 144 insertions(+), 123 deletions(-) diff --git a/phy/gui/actions.py b/phy/gui/actions.py index ea55a4c51..b9f687de5 100644 --- a/phy/gui/actions.py +++ b/phy/gui/actions.py @@ -14,7 +14,6 @@ from .qt import QtGui from phy.utils import Bunch -from phy.utils.event import EventEmitter logger = logging.getLogger(__name__) @@ -41,10 +40,10 @@ def _parse_list(s): # Range: 'x-y' if '-' in s: m, M = map(_parse_arg, s.split('-')) - return tuple(range(m, M + 1)) + return list(range(m, M + 1)) # List of ids: 'x,y,z' elif ',' in s: - return tuple(map(_parse_arg, s.split(','))) + return list(map(_parse_arg, s.split(','))) else: return _parse_arg(s) @@ -58,24 +57,43 @@ def _parse_snippet(s): # Show shortcut utility functions # ----------------------------------------------------------------------------- -def _shortcut_string(shortcut): - if isinstance(shortcut, QtGui.QKeySequence.StandardKey): - return str(QtGui.QKeySequence(shortcut).toString()).lower() - elif isinstance(shortcut, string_types): - return shortcut - elif isinstance(shortcut, (tuple, list)): - return ', '.join(shortcut) - return '' +def _get_shortcut_string(shortcut): + """Return a string representation of a shortcut.""" + if shortcut is None: + return '' + if isinstance(shortcut, (tuple, list)): + return ', '.join([_get_shortcut_string(s) for s in shortcut]) + if isinstance(shortcut, string_types): + return shortcut.lower() + assert isinstance(shortcut, QtGui.QKeySequence) + s = shortcut.toString() or '' + return str(s).lower() + + +def _get_qkeysequence(shortcut): + """Return a QKeySequence or list of QKeySequence from a shortcut string.""" + if shortcut is None: + return [] + if isinstance(shortcut, (tuple, list)): + return [_get_qkeysequence(s) for s in shortcut] + assert isinstance(shortcut, string_types) + if hasattr(QtGui.QKeySequence, shortcut): + return QtGui.QKeySequence(getattr(QtGui.QKeySequence, shortcut)) + sequence = QtGui.QKeySequence.fromString(shortcut) + assert not sequence.isEmpty() + return sequence def _show_shortcuts(shortcuts, name=None): + """Display shortcuts.""" name = name or '' print() if name: name = ' for ' + name print('Keyboard shortcuts' + name) for name in sorted(shortcuts): - print('{0:<40}: {1:s}'.format(name, _shortcut_string(shortcuts[name]))) + shortcut = _get_shortcut_string(shortcuts[name]) + print('{0:<40}: {1:s}'.format(name, shortcut)) print() @@ -89,31 +107,19 @@ def _alias(name): return alias -def _get_qt_shortcut(single_shortcut): - if (isinstance(single_shortcut, string_types) and - hasattr(QtGui.QKeySequence, single_shortcut)): - return getattr(QtGui.QKeySequence, single_shortcut) - return single_shortcut - - -def _get_qt_shortcuts(shortcut): - if shortcut is None: - return [] - if not isinstance(shortcut, (tuple, list)): - shortcut = [shortcut] - return [_get_qt_shortcut(s) for s in shortcut] - - -def _create_qaction(gui, name, callback, qt_shortcuts): +def _create_qaction(gui, name, callback, shortcut): # Create the QAction instance. action = QtGui.QAction(name, gui) action.triggered.connect(callback) - for key in qt_shortcuts: - action.setShortcut(key) + sequence = _get_qkeysequence(shortcut) + if not isinstance(sequence, (tuple, list)): + sequence = [sequence] + for s in sequence: + action.setShortcut(s) return action -class Actions(EventEmitter): +class Actions(object): """Handle GUI actions. This class attaches to a GUI and implements the following features: @@ -124,65 +130,40 @@ class Actions(EventEmitter): """ def __init__(self): - super(Actions, self).__init__() self._gui = None self._actions = {} - def reset(self): - """Reset the actions. - - All actions should be registered here, as follows: - - ```python - @actions.connect - def on_reset(): - actions.add(...) - actions.add(...) - ... - ``` - - """ - self.remove_all() - self.emit('reset') + def get_action_dict(self): + return self._actions.copy() def attach(self, gui): """Attach a GUI.""" self._gui = gui - # Register default actions. - @self.connect - def on_reset(): - # Default exit action. - @self.add(shortcut=QtGui.QKeySequence.Quit) - def exit(): - gui.close() - - # Reset the actions when the GUI is first shown. - @gui.connect_ - def on_show(): - self.reset() + # Default exit action. + @self.add(shortcut='Quit') + def exit(): + gui.close() def add(self, callback=None, name=None, shortcut=None, alias=None): """Add an action with a keyboard shortcut.""" # TODO: add menu_name option and create menu bar if callback is None: # Allow to use either add(func) or @add or @add(...). - return partial(self.add, name=name, shortcut=shortcut, - alias=alias) + return partial(self.add, name=name, shortcut=shortcut, alias=alias) assert callback # Get the name from the callback function if needed. name = name or callback.__name__ alias = alias or _alias(name) name = name.replace('&', '') - qt_shortcuts = _get_qt_shortcuts(shortcut) # Skip existing action. if name in self._actions: return # Create and register the action. - action = _create_qaction(self._gui, name, callback, qt_shortcuts) + action = _create_qaction(self._gui, name, callback, shortcut) action_obj = Bunch(qaction=action, name=name, alias=alias, shortcut=shortcut, callback=callback) if self._gui: @@ -234,7 +215,7 @@ def shortcuts(self): def show_shortcuts(self): """Print all shortcuts.""" _show_shortcuts(self.shortcuts, - self._gui.title() if self._gui else None) + self._gui.windowTitle() if self._gui else None) # ----------------------------------------------------------------------------- @@ -266,7 +247,7 @@ class Snippets(object): """ - # HACK: Unicode characters do not appear to work on Python 2 + # HACK: Unicode characters do not seem to work on Python 2 cursor = '\u200A\u258C' if PY3 else '' # Allowed characters in snippet mode. @@ -281,13 +262,14 @@ def __init__(self): def attach(self, gui, actions): self._gui = gui self._actions = actions + # We will keep a backup of all actions so that we can switch + # safely to the set of shortcut actions when snippet mode is on. + self._actions_backup = {} # Register snippet mode shortcut. - @actions.connect - def on_reset(): - @actions.add(shortcut=':') - def enable_snippet_mode(): - self.mode_on() + @actions.add(shortcut=':') + def enable_snippet_mode(): + self.mode_on() @property def command(self): @@ -326,14 +308,11 @@ def _enter(self): self.run(command) def _create_snippet_actions(self): - """Delete all existing actions, and add mock ones for snippet - keystrokes. + """Add mock Qt actions for snippet keystrokes. Used to enable snippet mode. """ - self._actions.remove_all() - # One action per allowed character. for i, char in enumerate(self._snippet_chars): @@ -386,8 +365,10 @@ def is_mode_on(self): def mode_on(self): logger.info("Snippet mode enabled, press `escape` to leave this mode.") - # Remove all existing actions, and replace them by snippet keystroke - # actions. + self._actions_backup = self._actions.get_action_dict() + # Remove all existing actions. + self._actions.remove_all() + # Add snippet keystroke actions. self._create_snippet_actions() self.command = ':' @@ -396,4 +377,9 @@ def mode_off(self): self._gui.status_message = '' logger.info("Snippet mode disabled.") # Reestablishes the shortcuts. - self._actions.reset() + for action_obj in self._actions_backup.values(): + self._actions.add(callback=action_obj.callback, + name=action_obj.name, + shortcut=action_obj.shortcut, + alias=action_obj.alias, + ) diff --git a/phy/gui/tests/test_actions.py b/phy/gui/tests/test_actions.py index 98634d1e9..fdde958a2 100644 --- a/phy/gui/tests/test_actions.py +++ b/phy/gui/tests/test_actions.py @@ -8,8 +8,13 @@ from pytest import raises, yield_fixture -from ..actions import _show_shortcuts, Actions, Snippets, _parse_snippet -from phy.gui.qt import QtGui +from ..actions import (_show_shortcuts, + _get_shortcut_string, + _get_qkeysequence, + _parse_snippet, + Actions, + Snippets, + ) from phy.utils.testing import captured_output, captured_logging @@ -32,12 +37,29 @@ def snippets(): #------------------------------------------------------------------------------ def test_shortcuts(qtbot): + def _assert_shortcut(name, key=None): + shortcut = _get_qkeysequence(name) + s = _get_shortcut_string(shortcut) + if key is None: + assert s == s + else: + assert key in s + + _assert_shortcut('Undo', 'z') + _assert_shortcut('Save', 's') + _assert_shortcut('q') + _assert_shortcut('ctrl+q') + _assert_shortcut(':') + _assert_shortcut(['ctrl+a', 'shift+b']) + + +def test_show_shortcuts(qtbot): # NOTE: a Qt application needs to be running so that we can use the # KeySequence. shortcuts = { 'test_1': 'ctrl+t', 'test_2': ('ctrl+a', 'shift+b'), - 'test_3': QtGui.QKeySequence.Undo, + 'test_3': 'ctrl+z', } with captured_output() as (stdout, stderr): _show_shortcuts(shortcuts, 'test') @@ -94,8 +116,8 @@ def _check(args, expected): _check('a', ['a']) _check('abc', ['abc']) - _check('a,b,c', [('a', 'b', 'c')]) - _check('a b,c', ['a', ('b', 'c')]) + _check('a,b,c', [['a', 'b', 'c']]) + _check('a b,c', ['a', ['b', 'c']]) _check('1', [1]) _check('10', [10]) @@ -108,31 +130,28 @@ def _check(args, expected): _check('0 1.', [0, 1.]) _check('0 1.0', [0, 1.]) - _check('0,1', [(0, 1)]) - _check('0,10.', [(0, 10.)]) - _check('0. 1,10.', [0., (1, 10.)]) + _check('0,1', [[0, 1]]) + _check('0,10.', [[0, 10.]]) + _check('0. 1,10.', [0., [1, 10.]]) - _check('2-7', [(2, 3, 4, 5, 6, 7)]) - _check('2 3-5', [2, (3, 4, 5)]) + _check('2-7', [[2, 3, 4, 5, 6, 7]]) + _check('2 3-5', [2, [3, 4, 5]]) - _check('a b,c d,2 3-5', ['a', ('b', 'c'), ('d', 2), (3, 4, 5)]) + _check('a b,c d,2 3-5', ['a', ['b', 'c'], ['d', 2], [3, 4, 5]]) def test_snippets_errors(actions, snippets): _actions = [] - @actions.connect - def on_reset(): - @actions.add(name='my_test', alias='t') - def test(arg): - # Enforce single-character argument. - assert len(str(arg)) == 1 - _actions.append(arg) + @actions.add(name='my_test', alias='t') + def test(arg): + # Enforce single-character argument. + assert len(str(arg)) == 1 + _actions.append(arg) # Attach the GUI and register the actions. snippets.attach(None, actions) - actions.reset() with captured_logging() as buf: snippets.run(':t1') @@ -154,27 +173,24 @@ def test(arg): assert _actions == ['a'] -def test_snippets_actions(actions, snippets): +def test_snippets_actions_1(actions, snippets): _actions = [] - @actions.connect - def on_reset(): - @actions.add(name='my_test_1') - def test_1(*args): - _actions.append((1, args)) + @actions.add(name='my_test_1') + def test_1(*args): + _actions.append((1, args)) - @actions.add(name='my_&test_2') - def test_2(*args): - _actions.append((2, args)) + @actions.add(name='my_&test_2') + def test_2(*args): + _actions.append((2, args)) - @actions.add(name='my_test_3', alias='t3') - def test_3(*args): - _actions.append((3, args)) + @actions.add(name='my_test_3', alias='t3') + def test_3(*args): + _actions.append((3, args)) # Attach the GUI and register the actions. snippets.attach(None, actions) - actions.reset() assert snippets.command == '' @@ -184,7 +200,7 @@ def test_3(*args): # Action 2. snippets.run(':t 1.5 a 2-4 5,7') - assert _actions[-1] == (2, (1.5, 'a', (2, 3, 4), (5, 7))) + assert _actions[-1] == (2, (1.5, 'a', [2, 3, 4], [5, 7])) def _run(cmd): """Simulate keystrokes.""" @@ -203,3 +219,24 @@ def _run(cmd): actions._snippet_activate() # 'Enter' assert _actions[-1] == (3, ('hello',)) snippets.mode_off() + + +def test_snippets_actions_2(actions, snippets): + + _actions = [] + + @actions.add + def test(arg): + _actions.append(arg) + + # Attach the GUI and register the actions. + snippets.attach(None, actions) + + actions.test(1) + assert _actions == [1] + + snippets.mode_on() + snippets.mode_off() + + actions.test(2) + assert _actions == [1, 2] diff --git a/phy/gui/tests/test_gui.py b/phy/gui/tests/test_gui.py index d96bc2e48..0256fb7bb 100644 --- a/phy/gui/tests/test_gui.py +++ b/phy/gui/tests/test_gui.py @@ -59,9 +59,6 @@ def gui(): def test_actions_gui(qtbot, gui, actions): actions.attach(gui) - # Set the default actions. - actions.reset() - qtbot.addWidget(gui) gui.show() qtbot.waitForWindowShown(gui) @@ -90,30 +87,31 @@ def test_snippets_gui(qtbot, gui, actions, snippets): _actions = [] - @actions.connect - def on_reset(): - @actions.add(name='my_test_1', alias='t1') - def test(*args): - _actions.append(args) + @actions.add(name='my_test_1', alias='t1') + def test(*args): + _actions.append(args) # Attach the GUI and register the actions. - snippets.attach(gui, actions) actions.attach(gui) - actions.reset() + snippets.attach(gui, actions) # Simulate the following keystrokes `:t2 ^H^H1 3-5 ab,c ` assert not snippets.is_mode_on() + # print(gui.actions()[0].shortcut().toString()) + # actions.show_shortcuts() qtbot.keyClicks(gui, ':t2 ') - qtbot.waitForWindowShown(gui) + # qtbot.stop() + # return assert snippets.is_mode_on() qtbot.keyPress(gui, Qt.Key_Backspace) qtbot.keyPress(gui, Qt.Key_Backspace) qtbot.keyClicks(gui, '1 3-5 ab,c') + # qtbot.stop() qtbot.keyPress(gui, Qt.Key_Return) qtbot.waitForWindowShown(gui) - assert _actions == [((3, 4, 5), ('ab', 'c'))] + assert _actions == [([3, 4, 5], ['ab', 'c'])] #------------------------------------------------------------------------------ From f6a74916fe0af67407a27b5cd554b9f907543b62 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 16 Oct 2015 17:41:03 +0200 Subject: [PATCH 0333/1059] WIP: refactor actions --- phy/gui/actions.py | 42 ++++++++++++++--------------------- phy/gui/tests/test_actions.py | 10 +-------- 2 files changed, 18 insertions(+), 34 deletions(-) diff --git a/phy/gui/actions.py b/phy/gui/actions.py index b9f687de5..c02b89808 100644 --- a/phy/gui/actions.py +++ b/phy/gui/actions.py @@ -132,6 +132,7 @@ class Actions(object): def __init__(self): self._gui = None self._actions = {} + self._aliases = {} def get_action_dict(self): return self._actions.copy() @@ -169,26 +170,22 @@ def add(self, callback=None, name=None, shortcut=None, alias=None): if self._gui: self._gui.addAction(action) self._actions[name] = action_obj + # Register the alias -> name mapping. + self._aliases[alias] = name # Set the callback method. if callback: setattr(self, name, callback) - def get_name(self, alias_or_name): - """Return an action name from its alias or name.""" - for name, action in self._actions.items(): - if alias_or_name in (action.alias, name): - return name - raise ValueError("Action `{}` doesn't exist.".format(alias_or_name)) - - def run(self, action, *args): - """Run an action, specified by its name or object.""" - if isinstance(action, string_types): - name = self.get_name(action) - assert name in self._actions - action = self._actions[name] - else: - name = action.name + def run(self, name, *args): + """Run an action as specified by its name.""" + assert isinstance(name, string_types) + # Resolve the alias if it is an alias. + name = self._aliases[name] + # Get the action. + action = self._actions.get(name, None) + if not action: + raise ValueError("Action `{}` doesn't exist.".format(name)) if not name.startswith('_'): logger.debug("Execute action `%s`.", name) return action.callback(*args) @@ -345,17 +342,12 @@ def run(self, snippet): assert snippet[0] == ':' snippet = snippet[1:] snippet_args = _parse_snippet(snippet) - alias = snippet_args[0] - try: - name = self._actions.get_name(alias) - except ValueError: - logger.warn("Snippet `%s` cannot be found.", alias) - return - assert name - func = getattr(self._actions, name) + name = snippet_args[0] + + logger.info("Processing snippet `%s`.", snippet) try: - logger.info("Processing snippet `%s`.", snippet) - func(*snippet_args[1:]) + # func(*snippet_args[1:]) + self._actions.run(name, *snippet_args[1:]) except Exception as e: logger.warn("Error when executing snippet: %s.", str(e)) logger.exception(e) diff --git a/phy/gui/tests/test_actions.py b/phy/gui/tests/test_actions.py index fdde958a2..d08fc6e93 100644 --- a/phy/gui/tests/test_actions.py +++ b/phy/gui/tests/test_actions.py @@ -91,17 +91,9 @@ def show_my_shortcuts(): assert 'show_my_shortcuts' in _captured[0] assert ': h' in _captured[0] - with raises(ValueError): - assert actions.get_name('e') - assert actions.get_name('t') == 'test' - assert actions.get_name('test') == 'test' - actions.run('t', 1) assert _res == [(1,)] - # Run an action instance. - actions.run(actions._actions['test'], 1) - actions.remove_all() @@ -155,7 +147,7 @@ def test(arg): with captured_logging() as buf: snippets.run(':t1') - assert 'cannot be found' in buf.getvalue().lower() + assert 'error' in buf.getvalue().lower() with captured_logging() as buf: snippets.run(':t') From 23e69a4df836bfaae76fed4d8c1f864fbd03b7fa Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 16 Oct 2015 17:46:45 +0200 Subject: [PATCH 0334/1059] Automatically attach snippets when attaching GUI to actions --- phy/gui/actions.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/phy/gui/actions.py b/phy/gui/actions.py index c02b89808..0c5e45189 100644 --- a/phy/gui/actions.py +++ b/phy/gui/actions.py @@ -137,7 +137,7 @@ def __init__(self): def get_action_dict(self): return self._actions.copy() - def attach(self, gui): + def attach(self, gui, enable_snippets=True): """Attach a GUI.""" self._gui = gui @@ -146,6 +146,11 @@ def attach(self, gui): def exit(): gui.close() + # Create and attach snippets. + if enable_snippets: + self.snippets = Snippets() + self.snippets.attach(gui, self) + def add(self, callback=None, name=None, shortcut=None, alias=None): """Add an action with a keyboard shortcut.""" # TODO: add menu_name option and create menu bar @@ -346,7 +351,6 @@ def run(self, snippet): logger.info("Processing snippet `%s`.", snippet) try: - # func(*snippet_args[1:]) self._actions.run(name, *snippet_args[1:]) except Exception as e: logger.warn("Error when executing snippet: %s.", str(e)) From 46bfc5237a5a6783b0a56628f422b273600c0cbf Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 16 Oct 2015 17:47:43 +0200 Subject: [PATCH 0335/1059] Update tests --- phy/gui/tests/test_gui.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/phy/gui/tests/test_gui.py b/phy/gui/tests/test_gui.py index 0256fb7bb..c03962e73 100644 --- a/phy/gui/tests/test_gui.py +++ b/phy/gui/tests/test_gui.py @@ -15,7 +15,7 @@ from ..gui import GUI from phy.utils._color import _random_color from phy.utils.plugin import IPlugin -from .test_actions import actions, snippets # noqa +from .test_actions import actions # noqa # Skip some tests on OS X or on CI systems (Travis). skip_mac = mark.skipif(platform == "darwin", @@ -79,7 +79,7 @@ def press(): @skip_mac # noqa @skip_ci -def test_snippets_gui(qtbot, gui, actions, snippets): +def test_snippets_gui(qtbot, gui, actions): qtbot.addWidget(gui) gui.show() @@ -93,7 +93,7 @@ def test(*args): # Attach the GUI and register the actions. actions.attach(gui) - snippets.attach(gui, actions) + snippets = actions.snippets # Simulate the following keystrokes `:t2 ^H^H1 3-5 ab,c ` assert not snippets.is_mode_on() From 8a608072092757d137bf5c78586535cf81896a27 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 16 Oct 2015 17:53:11 +0200 Subject: [PATCH 0336/1059] Fix tests --- phy/cluster/manual/gui_plugin.py | 61 +++++++++------------ phy/cluster/manual/tests/test_gui_plugin.py | 2 - 2 files changed, 27 insertions(+), 36 deletions(-) diff --git a/phy/cluster/manual/gui_plugin.py b/phy/cluster/manual/gui_plugin.py index e5be67aae..f9a604854 100644 --- a/phy/cluster/manual/gui_plugin.py +++ b/phy/cluster/manual/gui_plugin.py @@ -16,8 +16,7 @@ from ._utils import create_cluster_meta from .clustering import Clustering from .wizard import Wizard -from phy.gui.actions import Actions, Snippets -from phy.gui.qt import QtGui +from phy.gui.actions import Actions from phy.io.array import select_spikes from phy.utils.plugin import IPlugin @@ -157,21 +156,20 @@ class ManualClustering(IPlugin): """ default_shortcuts = { - 'save': QtGui.QKeySequence.Save, + 'save': 'Save', # Wizard actions. - 'reset_wizard': 'ctrl+w', 'next': 'space', 'previous': 'shift+space', 'reset_wizard': 'ctrl+alt+space', - 'first': QtGui.QKeySequence.MoveToStartOfLine, - 'last': QtGui.QKeySequence.MoveToEndOfLine, + 'first': 'MoveToStartOfLine', + 'last': 'MoveToEndOfLine', 'pin': 'return', - 'unpin': QtGui.QKeySequence.Back, + 'unpin': 'Back', # Clustering actions. 'merge': 'g', 'split': 'k', - 'undo': QtGui.QKeySequence.Undo, - 'redo': QtGui.QKeySequence.Redo, + 'undo': 'Undo', + 'redo': 'Redo', } def __init__(self, spike_clusters=None, @@ -237,30 +235,26 @@ def _add_action(self, callback, name=None, alias=None): shortcut=shortcut, alias=alias) def _create_actions(self): - self.actions = actions = Actions() - self.snippets = Snippets() - - # Create the default actions for the clustering GUI. - @actions.connect - def on_reset(): - # Selection. - self._add_action(self.select, alias='c') - - # Wizard. - self._add_action(self.wizard.restart, name='reset_wizard') - self._add_action(self.wizard.previous) - self._add_action(self.wizard.next) - self._add_action(self.wizard.next_by_quality) - self._add_action(self.wizard.next_by_similarity) - self._add_action(self.wizard.pin) - self._add_action(self.wizard.unpin) - - # Clustering. - self._add_action(self.merge) - self._add_action(self.split) - self._add_action(self.move) - self._add_action(self.undo) - self._add_action(self.redo) + self.actions = Actions() + + # Selection. + self._add_action(self.select, alias='c') + + # Wizard. + self._add_action(self.wizard.restart, name='reset_wizard') + self._add_action(self.wizard.previous) + self._add_action(self.wizard.next) + self._add_action(self.wizard.next_by_quality) + self._add_action(self.wizard.next_by_similarity) + self._add_action(self.wizard.pin) + self._add_action(self.wizard.unpin) + + # Clustering. + self._add_action(self.merge) + self._add_action(self.split) + self._add_action(self.move) + self._add_action(self.undo) + self._add_action(self.redo) def attach_to_gui(self, gui): self.gui = gui @@ -279,7 +273,6 @@ def on_start(): gui.emit('wizard_start') # Attach the GUI and register the actions. - self.snippets.attach(gui, self.actions) self.actions.attach(gui) return self diff --git a/phy/cluster/manual/tests/test_gui_plugin.py b/phy/cluster/manual/tests/test_gui_plugin.py index f0022fda0..cbb1d16eb 100644 --- a/phy/cluster/manual/tests/test_gui_plugin.py +++ b/phy/cluster/manual/tests/test_gui_plugin.py @@ -50,8 +50,6 @@ def assert_selection(*cluster_ids): # pragma: no cover elif len(cluster_ids) >= 2: assert mc.wizard.match == cluster_ids[2] - mc.actions.reset() - yield mc, assert_selection From 5ae0d9323b7165b3bd657cac0c49db9212f69915 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 16 Oct 2015 18:15:27 +0200 Subject: [PATCH 0337/1059] WIP --- phy/gui/actions.py | 8 +++++++- phy/gui/tests/test_gui.py | 5 ----- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/phy/gui/actions.py b/phy/gui/actions.py index 0c5e45189..0959a803d 100644 --- a/phy/gui/actions.py +++ b/phy/gui/actions.py @@ -110,7 +110,11 @@ def _alias(name): def _create_qaction(gui, name, callback, shortcut): # Create the QAction instance. action = QtGui.QAction(name, gui) - action.triggered.connect(callback) + + def wrapped(checked, *args, **kwargs): + return callback(*args, **kwargs) + + action.triggered.connect(wrapped) sequence = _get_qkeysequence(shortcut) if not isinstance(sequence, (tuple, list)): sequence = [sequence] @@ -371,6 +375,8 @@ def mode_on(self): def mode_off(self): if self._gui: self._gui.status_message = '' + # Remove all existing actions. + self._actions.remove_all() logger.info("Snippet mode disabled.") # Reestablishes the shortcuts. for action_obj in self._actions_backup.values(): diff --git a/phy/gui/tests/test_gui.py b/phy/gui/tests/test_gui.py index c03962e73..717433a01 100644 --- a/phy/gui/tests/test_gui.py +++ b/phy/gui/tests/test_gui.py @@ -97,17 +97,12 @@ def test(*args): # Simulate the following keystrokes `:t2 ^H^H1 3-5 ab,c ` assert not snippets.is_mode_on() - # print(gui.actions()[0].shortcut().toString()) - # actions.show_shortcuts() qtbot.keyClicks(gui, ':t2 ') - # qtbot.stop() - # return assert snippets.is_mode_on() qtbot.keyPress(gui, Qt.Key_Backspace) qtbot.keyPress(gui, Qt.Key_Backspace) qtbot.keyClicks(gui, '1 3-5 ab,c') - # qtbot.stop() qtbot.keyPress(gui, Qt.Key_Return) qtbot.waitForWindowShown(gui) From 0a5133e182f2dc461195f0fda8a49764c8922865 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 16 Oct 2015 18:19:51 +0200 Subject: [PATCH 0338/1059] WIP --- phy/gui/actions.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/phy/gui/actions.py b/phy/gui/actions.py index 0959a803d..7699e2a2e 100644 --- a/phy/gui/actions.py +++ b/phy/gui/actions.py @@ -190,7 +190,7 @@ def run(self, name, *args): """Run an action as specified by its name.""" assert isinstance(name, string_types) # Resolve the alias if it is an alias. - name = self._aliases[name] + name = self._aliases.get(name, name) # Get the action. action = self._actions.get(name, None) if not action: @@ -357,8 +357,7 @@ def run(self, snippet): try: self._actions.run(name, *snippet_args[1:]) except Exception as e: - logger.warn("Error when executing snippet: %s.", str(e)) - logger.exception(e) + logger.warn("Error when executing snippet: \"%s\".", str(e)) def is_mode_on(self): return self.command.startswith(':') From 2a4f42c11579af26913dd0e577f46ae7a5f0fff3 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 16 Oct 2015 21:49:33 +0200 Subject: [PATCH 0339/1059] WIP: increase coverage in actions --- phy/gui/actions.py | 2 +- phy/gui/tests/conftest.py | 31 ++++++++++++++ phy/gui/tests/test_actions.py | 78 ++++++++++++++++++++++++++-------- phy/gui/tests/test_gui.py | 80 ----------------------------------- 4 files changed, 93 insertions(+), 98 deletions(-) create mode 100644 phy/gui/tests/conftest.py diff --git a/phy/gui/actions.py b/phy/gui/actions.py index 7699e2a2e..bf76ac33d 100644 --- a/phy/gui/actions.py +++ b/phy/gui/actions.py @@ -111,7 +111,7 @@ def _create_qaction(gui, name, callback, shortcut): # Create the QAction instance. action = QtGui.QAction(name, gui) - def wrapped(checked, *args, **kwargs): + def wrapped(checked, *args, **kwargs): # pragma: no cover return callback(*args, **kwargs) action.triggered.connect(wrapped) diff --git a/phy/gui/tests/conftest.py b/phy/gui/tests/conftest.py new file mode 100644 index 000000000..482ee6cab --- /dev/null +++ b/phy/gui/tests/conftest.py @@ -0,0 +1,31 @@ +# -*- coding: utf-8 -*- + +"""Test gui.""" + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +from pytest import yield_fixture + +from ..actions import Actions, Snippets +from ..gui import GUI + + +#------------------------------------------------------------------------------ +# Utilities and fixtures +#------------------------------------------------------------------------------ + +@yield_fixture +def gui(): + yield GUI(position=(200, 100), size=(100, 100)) + + +@yield_fixture +def actions(): + yield Actions() + + +@yield_fixture +def snippets(): + yield Snippets() diff --git a/phy/gui/tests/test_actions.py b/phy/gui/tests/test_actions.py index d08fc6e93..c6e9dee4d 100644 --- a/phy/gui/tests/test_actions.py +++ b/phy/gui/tests/test_actions.py @@ -6,32 +6,17 @@ # Imports #------------------------------------------------------------------------------ -from pytest import raises, yield_fixture +from pytest import raises +from ..qt import Qt from ..actions import (_show_shortcuts, _get_shortcut_string, _get_qkeysequence, _parse_snippet, - Actions, - Snippets, ) from phy.utils.testing import captured_output, captured_logging -#------------------------------------------------------------------------------ -# Utilities and fixtures -#------------------------------------------------------------------------------ - -@yield_fixture -def actions(): - yield Actions() - - -@yield_fixture -def snippets(): - yield Snippets() - - #------------------------------------------------------------------------------ # Test actions #------------------------------------------------------------------------------ @@ -97,6 +82,65 @@ def show_my_shortcuts(): actions.remove_all() +#------------------------------------------------------------------------------ +# Test actions and snippet +#------------------------------------------------------------------------------ + +def test_actions_gui(qtbot, gui, actions): + actions.attach(gui) + + qtbot.addWidget(gui) + gui.show() + qtbot.waitForWindowShown(gui) + + _press = [] + + @actions.add(shortcut='g') + def press(): + _press.append(0) + + actions.press() + assert _press == [0] + + actions.exit() + + +def test_snippets_gui(qtbot, gui, actions): + + qtbot.addWidget(gui) + gui.show() + qtbot.waitForWindowShown(gui) + + _actions = [] + + @actions.add(name='my_test_1', alias='t1') + def test(*args): + _actions.append(args) + + # Attach the GUI and register the actions. + actions.attach(gui) + snippets = actions.snippets + + # Simulate the following keystrokes `:t2 ^H^H1 3-5 ab,c ` + assert not snippets.is_mode_on() + + def _run(cmd): + """Simulate keystrokes.""" + for char in cmd: + i = snippets._snippet_chars.index(char) + actions.run('_snippet_{}'.format(i)) + + actions.enable_snippet_mode() + _run('t2 ') + assert snippets.is_mode_on() + actions._snippet_backspace() + actions._snippet_backspace() + _run('1 3-5 ab,c') + actions._snippet_activate() + + assert _actions == [([3, 4, 5], ['ab', 'c'])] + + #------------------------------------------------------------------------------ # Test snippets #------------------------------------------------------------------------------ diff --git a/phy/gui/tests/test_gui.py b/phy/gui/tests/test_gui.py index 717433a01..66f0c1603 100644 --- a/phy/gui/tests/test_gui.py +++ b/phy/gui/tests/test_gui.py @@ -6,26 +6,10 @@ # Imports #------------------------------------------------------------------------------ -import os -from sys import platform - -from pytest import mark, yield_fixture - from ..qt import Qt from ..gui import GUI from phy.utils._color import _random_color from phy.utils.plugin import IPlugin -from .test_actions import actions # noqa - -# Skip some tests on OS X or on CI systems (Travis). -skip_mac = mark.skipif(platform == "darwin", - reason="Some tests don't work on OS X because of a bug " - "with QTest (qtbot) keyboard events that don't " - "trigger QAction shortcuts. On CI these tests " - "fail because the GUI is not displayed.") - -skip_ci = mark.skipif(os.environ.get('CI', None) is not None, - reason="Some shortcut-related Qt tests fail on CI.") #------------------------------------------------------------------------------ @@ -45,70 +29,6 @@ def on_draw(e): # pragma: no cover return c -@yield_fixture -def gui(): - yield GUI(position=(200, 100), size=(100, 100)) - - -#------------------------------------------------------------------------------ -# Test actions and snippet -#------------------------------------------------------------------------------ - -@skip_mac # noqa -@skip_ci -def test_actions_gui(qtbot, gui, actions): - actions.attach(gui) - - qtbot.addWidget(gui) - gui.show() - qtbot.waitForWindowShown(gui) - - _press = [] - - @actions.add(shortcut='ctrl+g') - def press(): - _press.append(0) - - qtbot.keyPress(gui, Qt.Key_G, Qt.ControlModifier) - qtbot.waitForWindowShown(gui) - assert _press == [0] - - # Quit the GUI. - qtbot.keyPress(gui, Qt.Key_Q, Qt.ControlModifier) - - -@skip_mac # noqa -@skip_ci -def test_snippets_gui(qtbot, gui, actions): - - qtbot.addWidget(gui) - gui.show() - qtbot.waitForWindowShown(gui) - - _actions = [] - - @actions.add(name='my_test_1', alias='t1') - def test(*args): - _actions.append(args) - - # Attach the GUI and register the actions. - actions.attach(gui) - snippets = actions.snippets - - # Simulate the following keystrokes `:t2 ^H^H1 3-5 ab,c ` - assert not snippets.is_mode_on() - qtbot.keyClicks(gui, ':t2 ') - - assert snippets.is_mode_on() - qtbot.keyPress(gui, Qt.Key_Backspace) - qtbot.keyPress(gui, Qt.Key_Backspace) - qtbot.keyClicks(gui, '1 3-5 ab,c') - qtbot.keyPress(gui, Qt.Key_Return) - qtbot.waitForWindowShown(gui) - - assert _actions == [([3, 4, 5], ['ab', 'c'])] - - #------------------------------------------------------------------------------ # Test gui #------------------------------------------------------------------------------ From 374f00373225749904274451cc1fcacb3ab53b34 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 16 Oct 2015 22:06:39 +0200 Subject: [PATCH 0340/1059] WIP --- phy/cluster/manual/tests/test_gui_plugin.py | 4 ++-- phy/gui/tests/conftest.py | 6 ++++-- phy/gui/tests/test_actions.py | 1 - 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/phy/cluster/manual/tests/test_gui_plugin.py b/phy/cluster/manual/tests/test_gui_plugin.py index cbb1d16eb..7f41d9b9e 100644 --- a/phy/cluster/manual/tests/test_gui_plugin.py +++ b/phy/cluster/manual/tests/test_gui_plugin.py @@ -17,7 +17,7 @@ _attach_wizard_to_clustering, _attach_wizard_to_cluster_meta, ) -from phy.gui.tests.test_gui import gui # noqa +from phy.gui.tests.conftest import gui # noqa #------------------------------------------------------------------------------ @@ -25,7 +25,7 @@ #------------------------------------------------------------------------------ @yield_fixture # noqa -def manual_clustering(qtbot, gui, cluster_ids, cluster_groups): +def manual_clustering(gui, cluster_ids, cluster_groups): spike_clusters = np.array(cluster_ids) mc = gui.attach('ManualClustering', diff --git a/phy/gui/tests/conftest.py b/phy/gui/tests/conftest.py index 482ee6cab..fd491a72b 100644 --- a/phy/gui/tests/conftest.py +++ b/phy/gui/tests/conftest.py @@ -17,8 +17,10 @@ #------------------------------------------------------------------------------ @yield_fixture -def gui(): - yield GUI(position=(200, 100), size=(100, 100)) +def gui(qapp): + gui = GUI(position=(200, 100), size=(100, 100)) + yield gui + gui.close() @yield_fixture diff --git a/phy/gui/tests/test_actions.py b/phy/gui/tests/test_actions.py index c6e9dee4d..a4da51350 100644 --- a/phy/gui/tests/test_actions.py +++ b/phy/gui/tests/test_actions.py @@ -8,7 +8,6 @@ from pytest import raises -from ..qt import Qt from ..actions import (_show_shortcuts, _get_shortcut_string, _get_qkeysequence, From 1859877c01891882a19639b216da41750be44ebe Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 17 Oct 2015 14:28:23 +0200 Subject: [PATCH 0341/1059] Refactor plugin discovery --- phy/utils/plugin.py | 59 +++++++++++++++++++++++++++------------------ 1 file changed, 35 insertions(+), 24 deletions(-) diff --git a/phy/utils/plugin.py b/phy/utils/plugin.py index 035c231a2..c369839e4 100644 --- a/phy/utils/plugin.py +++ b/phy/utils/plugin.py @@ -61,6 +61,29 @@ def get_plugin(name): # Plugins discovery #------------------------------------------------------------------------------ +def _iter_plugin_files(dirs): + for plugin_dir in dirs: + plugin_dir = op.realpath(op.expanduser(plugin_dir)) + if not op.exists(plugin_dir): + continue + # for filename in os.listdir(plugin_dir): + # path = op.join(plugin_dir, filename) + # if not op.isdir(path): + # yield path + for subdir, dirs, files in os.walk(plugin_dir): + # Skip test folders. + base = op.basename(subdir) + if 'test' in base or '__' in base: # pragma: no cover + continue + logger.debug("Scanning `%s`.", subdir) + for filename in files: + if (filename.startswith('__') or + not filename.endswith('.py')): + continue # pragma: no cover + logger.debug("Found plugin module `%s`.", filename) + yield op.join(subdir, filename) + + def discover_plugins(dirs): """Discover the plugin classes contained in Python files. @@ -78,30 +101,18 @@ def discover_plugins(dirs): """ # Scan all subdirectories recursively. - for plugin_dir in dirs: - plugin_dir = op.realpath(op.expanduser(plugin_dir)) - for subdir, dirs, files in os.walk(plugin_dir): - # Skip test folders. - base = op.basename(subdir) - if 'test' in base or '__' in base: # pragma: no cover - continue - logger.debug("Scanning `%s`.", subdir) - for filename in files: - if (filename.startswith('__') or - not filename.endswith('.py')): - continue # pragma: no cover - logger.debug("Found plugin module `%s`.", filename) - path = os.path.join(subdir, filename) - modname, ext = op.splitext(filename) - file, path, descr = imp.find_module(modname, [subdir]) - if file: - # Loading the module registers the plugin in - # IPluginRegistry. - try: - mod = imp.load_module(modname, file, # noqa - path, descr) - except Exception as e: # pragma: no cover - logger.exception(e) + for path in _iter_plugin_files(dirs): + filename = op.basename(path) + subdir = op.dirname(path) + modname, ext = op.splitext(filename) + file, path, descr = imp.find_module(modname, [subdir]) + if file: + # Loading the module registers the plugin in + # IPluginRegistry. + try: + mod = imp.load_module(modname, file, path, descr) # noqa + except Exception as e: # pragma: no cover + logger.exception(e) return IPluginRegistry.plugins From fad14d4f9bcc5f5c2fa3232ca456b6ae842d63ba Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 17 Oct 2015 16:52:46 +0200 Subject: [PATCH 0342/1059] Flakify --- phy/utils/plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/phy/utils/plugin.py b/phy/utils/plugin.py index c369839e4..f2a792043 100644 --- a/phy/utils/plugin.py +++ b/phy/utils/plugin.py @@ -110,7 +110,7 @@ def discover_plugins(dirs): # Loading the module registers the plugin in # IPluginRegistry. try: - mod = imp.load_module(modname, file, path, descr) # noqa + mod = imp.load_module(modname, file, path, descr) # noqa except Exception as e: # pragma: no cover logger.exception(e) return IPluginRegistry.plugins From a28465a01a74aa708cdebe35a6fa6073be5ddba4 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 17 Oct 2015 17:03:41 +0200 Subject: [PATCH 0343/1059] Use GUI components instead of plugins --- phy/cluster/manual/gui_plugin.py | 13 +++++-------- phy/cluster/manual/tests/test_gui_plugin.py | 17 +++++++++-------- phy/gui/gui.py | 11 ----------- phy/gui/tests/test_gui.py | 11 ++++++----- phy/utils/plugin.py | 1 - 5 files changed, 20 insertions(+), 33 deletions(-) diff --git a/phy/cluster/manual/gui_plugin.py b/phy/cluster/manual/gui_plugin.py index f9a604854..8cecce272 100644 --- a/phy/cluster/manual/gui_plugin.py +++ b/phy/cluster/manual/gui_plugin.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -"""Manual clustering GUI plugin.""" +"""Manual clustering GUI component.""" # ----------------------------------------------------------------------------- @@ -18,7 +18,6 @@ from .wizard import Wizard from phy.gui.actions import Actions from phy.io.array import select_spikes -from phy.utils.plugin import IPlugin logger = logging.getLogger(__name__) @@ -122,11 +121,11 @@ def _attach_wizard(wizard, clustering, cluster_meta): # ----------------------------------------------------------------------------- -# Clustering GUI plugin +# Clustering GUI component # ----------------------------------------------------------------------------- -class ManualClustering(IPlugin): - """Plugin that brings manual clustering facilities to a GUI: +class ManualClustering(object): + """Component that brings manual clustering facilities to a GUI: * Clustering instance: merge, split, undo, redo * ClusterMeta instance: change cluster metadata (e.g. group) @@ -137,8 +136,6 @@ class ManualClustering(IPlugin): Bring the `select` event to the GUI. This is raised when clusters are selected by the user or by the wizard. - Other plugins can connect to that event. - Parameters ---------- @@ -256,7 +253,7 @@ def _create_actions(self): self._add_action(self.undo) self._add_action(self.redo) - def attach_to_gui(self, gui): + def attach(self, gui): self.gui = gui @self.wizard.connect diff --git a/phy/cluster/manual/tests/test_gui_plugin.py b/phy/cluster/manual/tests/test_gui_plugin.py index 7f41d9b9e..8e5b45656 100644 --- a/phy/cluster/manual/tests/test_gui_plugin.py +++ b/phy/cluster/manual/tests/test_gui_plugin.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -"""Test GUI plugin.""" +"""Test GUI component.""" #------------------------------------------------------------------------------ # Imports @@ -16,6 +16,7 @@ _attach_wizard, _attach_wizard_to_clustering, _attach_wizard_to_cluster_meta, + ManualClustering, ) from phy.gui.tests.conftest import gui # noqa @@ -28,14 +29,14 @@ def manual_clustering(gui, cluster_ids, cluster_groups): spike_clusters = np.array(cluster_ids) - mc = gui.attach('ManualClustering', - spike_clusters=spike_clusters, - cluster_groups=cluster_groups, - shortcuts={'undo': 'ctrl+z'}, - ) - + mc = ManualClustering(spike_clusters=spike_clusters, + cluster_groups=cluster_groups, + shortcuts={'undo': 'ctrl+z'}, + ) _s = [] + mc.attach(gui) + # Connect to the `select` event. @mc.gui.connect_ def on_select(cluster_ids, spike_ids): @@ -209,7 +210,7 @@ def test_attach_wizard_3(wizard, cluster_ids, cluster_groups): #------------------------------------------------------------------------------ -# Test GUI plugins +# Test GUI components #------------------------------------------------------------------------------ def test_wizard_start_1(manual_clustering): diff --git a/phy/gui/gui.py b/phy/gui/gui.py index 77dda7469..ecb4abc40 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -10,11 +10,8 @@ from collections import defaultdict import logging -from six import string_types - from .qt import QtCore, QtGui from phy.utils.event import EventEmitter -from phy.utils.plugin import get_plugin logger = logging.getLogger(__name__) @@ -87,14 +84,6 @@ def __init__(self, self._status_bar = QtGui.QStatusBar() self.setStatusBar(self._status_bar) - def attach(self, plugin, *args, **kwargs): - """Attach a plugin to the GUI.""" - if isinstance(plugin, string_types): - # Instantiate the plugin if the name is given. - plugin = get_plugin(plugin)(*args, **kwargs) - if hasattr(plugin, 'attach_to_gui'): - return plugin.attach_to_gui(self) - # Events # ------------------------------------------------------------------------- diff --git a/phy/gui/tests/test_gui.py b/phy/gui/tests/test_gui.py index 66f0c1603..e8458b611 100644 --- a/phy/gui/tests/test_gui.py +++ b/phy/gui/tests/test_gui.py @@ -9,7 +9,6 @@ from ..qt import Qt from ..gui import GUI from phy.utils._color import _random_color -from phy.utils.plugin import IPlugin #------------------------------------------------------------------------------ @@ -67,17 +66,19 @@ def on_close_widget(): gui.close() -def test_gui_plugin(qtbot, gui): +def test_gui_component(qtbot, gui): - class TestPlugin(IPlugin): + class TestComponent(object): def __init__(self, arg): self._arg = arg - def attach_to_gui(self, gui): + def attach(self, gui): gui._attached = self._arg return 'attached' - assert gui.attach('testplugin', 3) == 'attached' + tc = TestComponent(3) + + assert tc.attach(gui) == 'attached' assert gui._attached == 3 diff --git a/phy/utils/plugin.py b/phy/utils/plugin.py index f2a792043..e56efcb44 100644 --- a/phy/utils/plugin.py +++ b/phy/utils/plugin.py @@ -41,7 +41,6 @@ def __init__(cls, name, bases, attrs): class IPlugin(with_metaclass(IPluginRegistry)): """A class deriving from IPlugin can implement the following methods: - * `attach_to_gui(gui)`: called when the plugin is attached to a GUI. * `attach_to_cli(cli)`: called when the CLI is created. """ From 71f668cfbde1fd6ad02eee4a0a60af4910415ee9 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 17 Oct 2015 17:05:51 +0200 Subject: [PATCH 0344/1059] Rename gui_plugin to gui_component --- phy/cluster/manual/__init__.py | 2 +- .../manual/{gui_plugin.py => gui_component.py} | 0 phy/cluster/manual/tests/conftest.py | 2 +- .../{test_gui_plugin.py => test_gui_component.py} | 12 ++++++------ 4 files changed, 8 insertions(+), 8 deletions(-) rename phy/cluster/manual/{gui_plugin.py => gui_component.py} (100%) rename phy/cluster/manual/tests/{test_gui_plugin.py => test_gui_component.py} (96%) diff --git a/phy/cluster/manual/__init__.py b/phy/cluster/manual/__init__.py index 8cc68c734..46249250d 100644 --- a/phy/cluster/manual/__init__.py +++ b/phy/cluster/manual/__init__.py @@ -5,4 +5,4 @@ from .clustering import Clustering from .wizard import Wizard -from .gui_plugin import ManualClustering +from .gui_component import ManualClustering diff --git a/phy/cluster/manual/gui_plugin.py b/phy/cluster/manual/gui_component.py similarity index 100% rename from phy/cluster/manual/gui_plugin.py rename to phy/cluster/manual/gui_component.py diff --git a/phy/cluster/manual/tests/conftest.py b/phy/cluster/manual/tests/conftest.py index c4e3f0d5f..e08d344fa 100644 --- a/phy/cluster/manual/tests/conftest.py +++ b/phy/cluster/manual/tests/conftest.py @@ -9,7 +9,7 @@ from pytest import yield_fixture from ..wizard import Wizard -from ..gui_plugin import _wizard_group +from ..gui_component import _wizard_group #------------------------------------------------------------------------------ diff --git a/phy/cluster/manual/tests/test_gui_plugin.py b/phy/cluster/manual/tests/test_gui_component.py similarity index 96% rename from phy/cluster/manual/tests/test_gui_plugin.py rename to phy/cluster/manual/tests/test_gui_component.py index 8e5b45656..dd3b80b62 100644 --- a/phy/cluster/manual/tests/test_gui_plugin.py +++ b/phy/cluster/manual/tests/test_gui_component.py @@ -12,12 +12,12 @@ from ..clustering import Clustering from .._utils import create_cluster_meta -from ..gui_plugin import (_wizard_group, - _attach_wizard, - _attach_wizard_to_clustering, - _attach_wizard_to_cluster_meta, - ManualClustering, - ) +from ..gui_component import (_wizard_group, + _attach_wizard, + _attach_wizard_to_clustering, + _attach_wizard_to_cluster_meta, + ManualClustering, + ) from phy.gui.tests.conftest import gui # noqa From 4293413ae081778e3d93ea48dc92f7988f142ed1 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 17 Oct 2015 17:09:30 +0200 Subject: [PATCH 0345/1059] WIP: increase coverage in gui_component --- phy/cluster/manual/gui_component.py | 9 ++++++--- phy/cluster/manual/tests/test_gui_component.py | 4 ++-- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 8cecce272..2ba9ec571 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -277,9 +277,12 @@ def on_start(): # Wizard-related actions # ------------------------------------------------------------------------- - def select(self, cluster_ids): - if isinstance(cluster_ids, integer_types): - cluster_ids = [cluster_ids] + def select(self, *cluster_ids): + # HACK: allow for select(1, 2, 3) in addition to select([1, 2, 3]) + # This makes it more convenient to select multiple clusters with + # the snippet: ":c 1 2 3". + if cluster_ids and isinstance(cluster_ids[0], (tuple, list)): + cluster_ids = list(cluster_ids[0]) + list(cluster_ids[1:]) self.wizard.select(cluster_ids) # Clustering actions diff --git a/phy/cluster/manual/tests/test_gui_component.py b/phy/cluster/manual/tests/test_gui_component.py index dd3b80b62..b9392478c 100644 --- a/phy/cluster/manual/tests/test_gui_component.py +++ b/phy/cluster/manual/tests/test_gui_component.py @@ -237,7 +237,7 @@ def test_wizard_start_2(manual_clustering): def on_wizard_start(): _check.append('wizard') - mc.wizard.select([1]) + mc.select([1]) assert _check == ['wizard'] @@ -277,7 +277,7 @@ def test_manual_clustering_edge_cases(manual_clustering): def test_manual_clustering_merge(manual_clustering): mc, assert_selection = manual_clustering - mc.actions.select([30, 20]) + mc.actions.select(30, 20) # NOTE: we pass multiple ints instead of a list mc.actions.merge() assert_selection(31, 2) From d5db4ded945b267e7eb08b4ad22a1a4e14795309 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 17 Oct 2015 17:13:48 +0200 Subject: [PATCH 0346/1059] WIP: increase coverage in gui_component --- phy/cluster/manual/gui_component.py | 1 - .../manual/tests/test_gui_component.py | 24 ++++++++++++++++++- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 2ba9ec571..76c3184b4 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -10,7 +10,6 @@ import logging import numpy as np -from six import integer_types from ._history import GlobalHistory from ._utils import create_cluster_meta diff --git a/phy/cluster/manual/tests/test_gui_component.py b/phy/cluster/manual/tests/test_gui_component.py index b9392478c..b63722868 100644 --- a/phy/cluster/manual/tests/test_gui_component.py +++ b/phy/cluster/manual/tests/test_gui_component.py @@ -281,13 +281,29 @@ def test_manual_clustering_merge(manual_clustering): mc.actions.merge() assert_selection(31, 2) + mc.actions.undo() + assert_selection(30, 20) + + mc.actions.redo() + assert_selection(31, 2) + def test_manual_clustering_split(manual_clustering): mc, assert_selection = manual_clustering + mc.actions.merge([1, 2, 10]) + assert_selection(31, 20) + mc.actions.select([1, 2]) mc.actions.split([1, 2]) - assert_selection(31, 20) + assert_selection(32, 20) + + mc.actions.undo() + # TODO + # assert_selection(1, 2) + + mc.actions.redo() + # assert_selection(32, 20) def test_manual_clustering_move(manual_clustering, quality, similarity): @@ -304,3 +320,9 @@ def test_manual_clustering_move(manual_clustering, quality, similarity): mc.actions.move([20], 'noise') assert_selection(2) + + mc.actions.undo() + assert_selection(20) + + mc.actions.redo() + assert_selection(2) From 6bb669ac287a103987d74d454ab117abf1255832 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 17 Oct 2015 18:01:58 +0200 Subject: [PATCH 0347/1059] WIP: increase coverage --- phy/cluster/manual/gui_component.py | 3 ++- phy/cluster/manual/tests/test_gui_component.py | 18 +++++++++++------- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 76c3184b4..8ed20491f 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -318,4 +318,5 @@ def save(self): spike_clusters = self.clustering.spike_clusters groups = {c: self.cluster_meta.get('group', c) for c in self.clustering.cluster_ids} - self.gui.emit('save_requested', spike_clusters, groups) + if self.gui: + self.gui.emit('save_requested', spike_clusters, groups) diff --git a/phy/cluster/manual/tests/test_gui_component.py b/phy/cluster/manual/tests/test_gui_component.py index b63722868..cc83d32b9 100644 --- a/phy/cluster/manual/tests/test_gui_component.py +++ b/phy/cluster/manual/tests/test_gui_component.py @@ -291,19 +291,23 @@ def test_manual_clustering_merge(manual_clustering): def test_manual_clustering_split(manual_clustering): mc, assert_selection = manual_clustering - mc.actions.merge([1, 2, 10]) - assert_selection(31, 20) - mc.actions.select([1, 2]) mc.actions.split([1, 2]) - assert_selection(32, 20) + assert_selection(31, 20) mc.actions.undo() - # TODO - # assert_selection(1, 2) + assert_selection(1, 2) mc.actions.redo() - # assert_selection(32, 20) + assert_selection(31, 20) + + +def test_manual_clustering_split_2(qapp): + spike_clusters = np.array([0, 0, 1]) + + mc = ManualClustering(spike_clusters=spike_clusters) + mc.actions.split([0, 1]) + assert mc.wizard.selection == [2, 1] def test_manual_clustering_move(manual_clustering, quality, similarity): From 2d9478047f30623d7eca3657269878a91c82bd76 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 17 Oct 2015 18:49:59 +0200 Subject: [PATCH 0348/1059] WIP: refactor qt --- phy/gui/qt.py | 151 ++++++++++----------------------------- phy/gui/tests/test_qt.py | 45 ++++++++++-- 2 files changed, 77 insertions(+), 119 deletions(-) diff --git a/phy/gui/qt.py b/phy/gui/qt.py index 6b342f46e..f711a51e9 100644 --- a/phy/gui/qt.py +++ b/phy/gui/qt.py @@ -6,12 +6,11 @@ # Imports # ----------------------------------------------------------------------------- +from contextlib import contextmanager +from functools import wraps import logging -import os import sys -from ..utils._misc import _is_interactive - logger = logging.getLogger(__name__) @@ -19,24 +18,12 @@ # PyQt import # ----------------------------------------------------------------------------- -_PYQT = False -try: - from PyQt4 import QtCore, QtGui, QtWebKit # noqa - Qt = QtCore.Qt - _PYQT = True -except ImportError: # pragma: no cover - try: - from PyQt5 import QtCore, QtGui, QtWebKit # noqa - _PYQT = True - except ImportError: - pass - - -def _check_qt(): # pragma: no cover - if not _PYQT: - logger.warn("PyQt is not available.") - return False - return True +from PyQt4.QtCore import Qt, QByteArray, QMetaObject, QSize # noqa +from PyQt4.QtGui import (QKeySequence, QAction, QStatusBar, # noqa + QMainWindow, QDockWidget, QWidget, + QMessageBox, QApplication, + ) +from PyQt4.QtWebKit import QWebView # noqa # ----------------------------------------------------------------------------- @@ -44,13 +31,13 @@ def _check_qt(): # pragma: no cover # ----------------------------------------------------------------------------- def _button_enum_from_name(name): - return getattr(QtGui.QMessageBox, name.capitalize()) + return getattr(QMessageBox, name.capitalize()) def _button_name_from_enum(enum): - names = dir(QtGui.QMessageBox) + names = dir(QMessageBox) for name in names: - if getattr(QtGui.QMessageBox, name) == enum: + if getattr(QMessageBox, name) == enum: return name.lower() @@ -59,7 +46,7 @@ def _prompt(message, buttons=('yes', 'no'), title='Question'): arg_buttons = 0 for (_, button) in buttons: arg_buttons |= button - box = QtGui.QMessageBox() + box = QMessageBox() box.setWindowTitle(title) box.setText(message) box.setStandardButtons(arg_buttons) @@ -72,109 +59,49 @@ def _show_box(box): # pragma: no cover # ----------------------------------------------------------------------------- -# Event loop integration with IPython +# Qt app # ----------------------------------------------------------------------------- -_APP = None -_APP_RUNNING = False - - -def _try_enable_ipython_qt(): # pragma: no cover - """Try to enable IPython Qt event loop integration. - - Returns True in the following cases: - - * python -i test.py - * ipython -i test.py - * ipython and %run test.py +def require_qt(func): + """Specify that a function requires a Qt application. - Returns False in the following cases: - - * python test.py - * ipython test.py + Use this decorator to specify that a function needs a running + Qt application before it can run. An error is raised if that is not + the case. """ - try: - from IPython import get_ipython - ip = get_ipython() - except ImportError: - return False - if not _is_interactive(): - return False - if ip: - ip.enable_gui('qt') - global _APP_RUNNING - _APP_RUNNING = True - return True - return False - - -def enable_qt(): # pragma: no cover - if not _check_qt(): - return - try: - from IPython import get_ipython - ip = get_ipython() - ip.enable_gui('qt') - global _APP_RUNNING - _APP_RUNNING = True - logger.info("Qt event loop activated.") - except: - logger.warn("Qt event loop not activated.") + @wraps(func) + def wrapped(*args, **kwargs): + if not QApplication.instance(): + raise RuntimeError("A Qt application must be created.") + return func(*args, **kwargs) + return wrapped -# ----------------------------------------------------------------------------- -# Qt app -# ----------------------------------------------------------------------------- +# Global variable with the current Qt application. +QT_APP = None -def start_qt_app(): # pragma: no cover - """Start a Qt application if necessary. - If a new Qt application is created, this function returns it. - If no new application is created, the function returns None. +def create_app(): + """Create a Qt application.""" + global QT_APP + QT_APP = QApplication.instance() + if QT_APP is None: # pragma: no cover + QT_APP = QApplication(sys.argv) + return QT_APP - """ - # Only start a Qt application if there is no - # IPython event loop integration. - if not _check_qt(): - return - global _APP - if _try_enable_ipython_qt(): - return - try: - from vispy import app - app.use_app("pyqt4") - except ImportError: - pass - if QtGui.QApplication.instance(): - _APP = QtGui.QApplication.instance() - return - if _APP: - return - _APP = QtGui.QApplication(sys.argv) - return _APP - - -def run_qt_app(): # pragma: no cover - """Start the Qt application's event loop.""" - global _APP_RUNNING - if not _check_qt(): - return - if _APP is not None and not _APP_RUNNING: - _APP_RUNNING = True - _APP.exec_() - if not _is_interactive(): - _APP_RUNNING = False + +@require_qt +def run_app(): # pragma: no cover + """Run the Qt application.""" + global QT_APP + return QT_APP.exit(QT_APP.exec_()) # ----------------------------------------------------------------------------- # Testing utilities # ----------------------------------------------------------------------------- -_MAX_ITER = 100 -_DELAY = max(0, float(os.environ.get('PHY_EVENT_LOOP_DELAY', .1))) - - def _debug_trace(): # pragma: no cover """Set a tracepoint in the Python debugger that works with Qt.""" from PyQt4.QtCore import pyqtRemoveInputHook diff --git a/phy/gui/tests/test_qt.py b/phy/gui/tests/test_qt.py index ae8deac3e..251122c2d 100644 --- a/phy/gui/tests/test_qt.py +++ b/phy/gui/tests/test_qt.py @@ -6,10 +6,14 @@ # Imports #------------------------------------------------------------------------------ -from ..qt import (QtCore, QtGui, QtWebKit, +from pytest import raises + +from ..qt import (QMessageBox, Qt, QWebView, _button_name_from_enum, _button_enum_from_name, _prompt, + require_qt, + create_app, ) @@ -17,9 +21,36 @@ # Tests #------------------------------------------------------------------------------ -def test_wrap(qtbot): +def test_require_qt_with_app(): + + @require_qt + def f(): + pass + + with raises(RuntimeError): + f() + + +def test_require_qt_without_app(qapp): + + @require_qt + def f(): + pass + + # This should not raise an error. + f() + + +def test_qt_app(qtbot): + create_app() + view = QWebView() + qtbot.addWidget(view) + view.close() + + +def test_web_view(qtbot): - view = QtWebKit.QWebView() + view = QWebView() def _assert(text): html = view.page().mainFrame().toHtml() @@ -36,7 +67,7 @@ def _assert(text): _assert('world') view.close() - view = QtWebKit.QWebView() + view = QWebView() view.resize(100, 100) view.show() qtbot.addWidget(view) @@ -48,11 +79,11 @@ def _assert(text): def test_prompt(qtbot): - assert _button_name_from_enum(QtGui.QMessageBox.Save) == 'save' - assert _button_enum_from_name('save') == QtGui.QMessageBox.Save + assert _button_name_from_enum(QMessageBox.Save) == 'save' + assert _button_enum_from_name('save') == QMessageBox.Save box = _prompt("How are you doing?", buttons=['save', 'cancel', 'close'], ) - qtbot.mouseClick(box.buttons()[0], QtCore.Qt.LeftButton) + qtbot.mouseClick(box.buttons()[0], Qt.LeftButton) assert 'save' in str(box.clickedButton().text()).lower() From 7c4b9acd6b7be6b4846a4213cb081fca1a2b51f3 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 17 Oct 2015 18:54:48 +0200 Subject: [PATCH 0349/1059] WIP: refactor qt --- phy/gui/__init__.py | 2 +- phy/gui/gui.py | 47 +++++++++++++++++++++------------------ phy/gui/tests/test_gui.py | 14 +++++++----- 3 files changed, 35 insertions(+), 28 deletions(-) diff --git a/phy/gui/__init__.py b/phy/gui/__init__.py index 5f4908f9e..4008a593a 100644 --- a/phy/gui/__init__.py +++ b/phy/gui/__init__.py @@ -3,5 +3,5 @@ """GUI routines.""" -from .qt import start_qt_app, run_qt_app, enable_qt +from .qt import require_qt, create_app, run_app from .gui import GUI diff --git a/phy/gui/gui.py b/phy/gui/gui.py index ecb4abc40..9155bb894 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -10,7 +10,8 @@ from collections import defaultdict import logging -from .qt import QtCore, QtGui +from .qt import (QApplication, QWidget, QDockWidget, QStatusBar, QMainWindow, + Qt, QSize, QMetaObject) from phy.utils.event import EventEmitter logger = logging.getLogger(__name__) @@ -24,7 +25,7 @@ def _title(widget): return str(widget.windowTitle()).lower() -class DockWidget(QtGui.QDockWidget): +class DockWidget(QDockWidget): """A QDockWidget that can emit events.""" def __init__(self, *args, **kwargs): super(DockWidget, self).__init__(*args, **kwargs) @@ -42,7 +43,7 @@ def closeEvent(self, e): super(DockWidget, self).closeEvent(e) -class GUI(QtGui.QMainWindow): +class GUI(QMainWindow): """A Qt main window holding docking Qt or VisPy widgets. `GUI` derives from `QMainWindow`. @@ -64,6 +65,8 @@ def __init__(self, size=None, title=None, ): + if not QApplication.instance(): + raise RuntimeError("A Qt application must be created.") super(GUI, self).__init__() if title is None: title = 'phy' @@ -71,17 +74,17 @@ def __init__(self, if position is not None: self.move(position[0], position[1]) if size is not None: - self.resize(QtCore.QSize(size[0], size[1])) + self.resize(QSize(size[0], size[1])) self.setObjectName(title) - QtCore.QMetaObject.connectSlotsByName(self) - self.setDockOptions(QtGui.QMainWindow.AllowTabbedDocks | - QtGui.QMainWindow.AllowNestedDocks | - QtGui.QMainWindow.AnimatedDocks + QMetaObject.connectSlotsByName(self) + self.setDockOptions(QMainWindow.AllowTabbedDocks | + QMainWindow.AllowNestedDocks | + QMainWindow.AnimatedDocks ) # We can derive from EventEmitter because of a conflict with connect. self._event = EventEmitter() - self._status_bar = QtGui.QStatusBar() + self._status_bar = QStatusBar() self.setStatusBar(self._status_bar) # Events @@ -139,24 +142,24 @@ def add_view(self, dockwidget.setWidget(view) # Set gui widget options. - options = QtGui.QDockWidget.DockWidgetMovable + options = QDockWidget.DockWidgetMovable if closable: - options = options | QtGui.QDockWidget.DockWidgetClosable + options = options | QDockWidget.DockWidgetClosable if floatable: - options = options | QtGui.QDockWidget.DockWidgetFloatable + options = options | QDockWidget.DockWidgetFloatable dockwidget.setFeatures(options) - dockwidget.setAllowedAreas(QtCore.Qt.LeftDockWidgetArea | - QtCore.Qt.RightDockWidgetArea | - QtCore.Qt.TopDockWidgetArea | - QtCore.Qt.BottomDockWidgetArea + dockwidget.setAllowedAreas(Qt.LeftDockWidgetArea | + Qt.RightDockWidgetArea | + Qt.TopDockWidgetArea | + Qt.BottomDockWidgetArea ) q_position = { - 'left': QtCore.Qt.LeftDockWidgetArea, - 'right': QtCore.Qt.RightDockWidgetArea, - 'top': QtCore.Qt.TopDockWidgetArea, - 'bottom': QtCore.Qt.BottomDockWidgetArea, + 'left': Qt.LeftDockWidgetArea, + 'right': Qt.RightDockWidgetArea, + 'top': Qt.TopDockWidgetArea, + 'bottom': Qt.BottomDockWidgetArea, }[position or 'right'] self.addDockWidget(q_position, dockwidget) if floating is not None: @@ -167,9 +170,9 @@ def add_view(self, def list_views(self, title='', is_visible=True): """List all views which title start with a given string.""" title = title.lower() - children = self.findChildren(QtGui.QWidget) + children = self.findChildren(QWidget) return [child for child in children - if isinstance(child, QtGui.QDockWidget) and + if isinstance(child, QDockWidget) and _title(child).startswith(title) and (child.isVisible() if is_visible else True) and child.width() >= 10 and diff --git a/phy/gui/tests/test_gui.py b/phy/gui/tests/test_gui.py index e8458b611..6439f705e 100644 --- a/phy/gui/tests/test_gui.py +++ b/phy/gui/tests/test_gui.py @@ -6,6 +6,8 @@ # Imports #------------------------------------------------------------------------------ +from pytest import raises + from ..qt import Qt from ..gui import GUI from phy.utils._color import _random_color @@ -32,6 +34,11 @@ def on_draw(e): # pragma: no cover # Test gui #------------------------------------------------------------------------------ +def test_gui_noapp(): + with raises(RuntimeError): + GUI() + + def test_gui_1(qtbot): gui = GUI(position=(200, 100), size=(100, 100)) @@ -49,7 +56,6 @@ def on_show(): gui.add_view(_create_canvas(), 'view2') view.setFloating(False) gui.show() - # qtbot.waitForWindowShown(gui) assert len(gui.list_views('view')) == 2 @@ -66,7 +72,7 @@ def on_close_widget(): gui.close() -def test_gui_component(qtbot, gui): +def test_gui_component(gui): class TestComponent(object): def __init__(self, arg): @@ -82,9 +88,7 @@ def attach(self, gui): assert gui._attached == 3 -def test_gui_status_message(qtbot): - gui = GUI() - qtbot.addWidget(gui) +def test_gui_status_message(gui): assert gui.status_message == '' gui.status_message = ':hello world!' assert gui.status_message == ':hello world!' From d6cacaf8f30a46e216dc9be1da013514fa614fd3 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 17 Oct 2015 18:55:08 +0200 Subject: [PATCH 0350/1059] WIP: refactor qt --- phy/utils/_misc.py | 26 ++------------------------ phy/utils/tests/test_misc.py | 4 ++-- 2 files changed, 4 insertions(+), 26 deletions(-) diff --git a/phy/utils/_misc.py b/phy/utils/_misc.py index 3423c639b..ffd427e32 100644 --- a/phy/utils/_misc.py +++ b/phy/utils/_misc.py @@ -11,13 +11,11 @@ import json import os.path as op import os -import sys import subprocess from textwrap import dedent import numpy as np from six import string_types, exec_ -from six.moves import builtins from ._types import _is_integer @@ -33,9 +31,9 @@ def _encode_qbytearray(arr): def _decode_qbytearray(data_b64): - from phy.gui.qt import QtCore + from phy.gui.qt import QByteArray encoded = base64.b64decode(data_b64) - out = QtCore.QByteArray.fromBase64(encoded) + out = QByteArray.fromBase64(encoded) return out @@ -129,26 +127,6 @@ def _write_text(path, contents): f.write(contents) -def _is_interactive(): # pragma: no cover - """Determine whether the user has requested interactive mode.""" - # The Python interpreter sets sys.flags correctly, so use them! - if sys.flags.interactive: - return True - - # IPython does not set sys.flags when -i is specified, so first - # check it if it is already imported. - if '__IPYTHON__' not in dir(builtins): - return False - - # Then we check the application singleton and determine based on - # a variable it sets. - try: - from IPython.config.application import Application as App - return App.initialized() and App.instance().interact - except (ImportError, AttributeError): - return False - - def _git_version(): curdir = os.getcwd() filedir, _ = op.split(__file__) diff --git a/phy/utils/tests/test_misc.py b/phy/utils/tests/test_misc.py index 25a56bdcd..8a8c0a2f3 100644 --- a/phy/utils/tests/test_misc.py +++ b/phy/utils/tests/test_misc.py @@ -27,8 +27,8 @@ def test_qbytearray(tempdir): - from phy.gui.qt import QtCore - arr = QtCore.QByteArray() + from phy.gui.qt import QByteArray + arr = QByteArray() arr.append('1') arr.append('2') arr.append('3') From 54b18c52c084a6cbdca0c7db137f4683a4df0dd0 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 17 Oct 2015 18:59:22 +0200 Subject: [PATCH 0351/1059] Tests pass --- phy/gui/actions.py | 12 ++++++------ phy/gui/qt.py | 1 - phy/gui/tests/test_gui.py | 7 ++++--- phy/gui/tests/test_qt.py | 6 ++++-- 4 files changed, 14 insertions(+), 12 deletions(-) diff --git a/phy/gui/actions.py b/phy/gui/actions.py index bf76ac33d..defb4ec75 100644 --- a/phy/gui/actions.py +++ b/phy/gui/actions.py @@ -12,7 +12,7 @@ from six import string_types, PY3 -from .qt import QtGui +from .qt import QKeySequence, QAction from phy.utils import Bunch logger = logging.getLogger(__name__) @@ -65,7 +65,7 @@ def _get_shortcut_string(shortcut): return ', '.join([_get_shortcut_string(s) for s in shortcut]) if isinstance(shortcut, string_types): return shortcut.lower() - assert isinstance(shortcut, QtGui.QKeySequence) + assert isinstance(shortcut, QKeySequence) s = shortcut.toString() or '' return str(s).lower() @@ -77,9 +77,9 @@ def _get_qkeysequence(shortcut): if isinstance(shortcut, (tuple, list)): return [_get_qkeysequence(s) for s in shortcut] assert isinstance(shortcut, string_types) - if hasattr(QtGui.QKeySequence, shortcut): - return QtGui.QKeySequence(getattr(QtGui.QKeySequence, shortcut)) - sequence = QtGui.QKeySequence.fromString(shortcut) + if hasattr(QKeySequence, shortcut): + return QKeySequence(getattr(QKeySequence, shortcut)) + sequence = QKeySequence.fromString(shortcut) assert not sequence.isEmpty() return sequence @@ -109,7 +109,7 @@ def _alias(name): def _create_qaction(gui, name, callback, shortcut): # Create the QAction instance. - action = QtGui.QAction(name, gui) + action = QAction(name, gui) def wrapped(checked, *args, **kwargs): # pragma: no cover return callback(*args, **kwargs) diff --git a/phy/gui/qt.py b/phy/gui/qt.py index f711a51e9..df5687ac1 100644 --- a/phy/gui/qt.py +++ b/phy/gui/qt.py @@ -6,7 +6,6 @@ # Imports # ----------------------------------------------------------------------------- -from contextlib import contextmanager from functools import wraps import logging import sys diff --git a/phy/gui/tests/test_gui.py b/phy/gui/tests/test_gui.py index 6439f705e..8fa71e75b 100644 --- a/phy/gui/tests/test_gui.py +++ b/phy/gui/tests/test_gui.py @@ -8,7 +8,7 @@ from pytest import raises -from ..qt import Qt +from ..qt import Qt, QApplication from ..gui import GUI from phy.utils._color import _random_color @@ -35,8 +35,9 @@ def on_draw(e): # pragma: no cover #------------------------------------------------------------------------------ def test_gui_noapp(): - with raises(RuntimeError): - GUI() + if not QApplication.instance(): + with raises(RuntimeError): + GUI() def test_gui_1(qtbot): diff --git a/phy/gui/tests/test_qt.py b/phy/gui/tests/test_qt.py index 251122c2d..c1e366466 100644 --- a/phy/gui/tests/test_qt.py +++ b/phy/gui/tests/test_qt.py @@ -14,6 +14,7 @@ _prompt, require_qt, create_app, + QApplication, ) @@ -27,8 +28,9 @@ def test_require_qt_with_app(): def f(): pass - with raises(RuntimeError): - f() + if not QApplication.instance(): + with raises(RuntimeError): + f() def test_require_qt_without_app(qapp): From 5646770c5d07e1ffa6609863c03af6f2e7805176 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 17 Oct 2015 19:01:27 +0200 Subject: [PATCH 0352/1059] Increase coverage --- phy/gui/gui.py | 2 +- phy/gui/qt.py | 2 +- phy/gui/tests/test_gui.py | 2 +- phy/gui/tests/test_qt.py | 4 ++-- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/phy/gui/gui.py b/phy/gui/gui.py index 9155bb894..f25b9edfb 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -65,7 +65,7 @@ def __init__(self, size=None, title=None, ): - if not QApplication.instance(): + if not QApplication.instance(): # pragma: no cover raise RuntimeError("A Qt application must be created.") super(GUI, self).__init__() if title is None: diff --git a/phy/gui/qt.py b/phy/gui/qt.py index df5687ac1..173a0f08f 100644 --- a/phy/gui/qt.py +++ b/phy/gui/qt.py @@ -71,7 +71,7 @@ def require_qt(func): """ @wraps(func) def wrapped(*args, **kwargs): - if not QApplication.instance(): + if not QApplication.instance(): # pragma: no cover raise RuntimeError("A Qt application must be created.") return func(*args, **kwargs) return wrapped diff --git a/phy/gui/tests/test_gui.py b/phy/gui/tests/test_gui.py index 8fa71e75b..85e6ff06e 100644 --- a/phy/gui/tests/test_gui.py +++ b/phy/gui/tests/test_gui.py @@ -36,7 +36,7 @@ def on_draw(e): # pragma: no cover def test_gui_noapp(): if not QApplication.instance(): - with raises(RuntimeError): + with raises(RuntimeError): # pragma: no cover GUI() diff --git a/phy/gui/tests/test_qt.py b/phy/gui/tests/test_qt.py index c1e366466..673bad59b 100644 --- a/phy/gui/tests/test_qt.py +++ b/phy/gui/tests/test_qt.py @@ -26,10 +26,10 @@ def test_require_qt_with_app(): @require_qt def f(): - pass + pass # pragma: no cover if not QApplication.instance(): - with raises(RuntimeError): + with raises(RuntimeError): # pragma: no cover f() From fc75a33cad7879fe772931974cc1ccd2ddaa2247 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 17 Oct 2015 19:14:58 +0200 Subject: [PATCH 0353/1059] WIP: no more attach in actions and snippets --- phy/gui/actions.py | 116 ++++++++++++++++------------------ phy/gui/tests/conftest.py | 8 +-- phy/gui/tests/test_actions.py | 13 ---- 3 files changed, 57 insertions(+), 80 deletions(-) diff --git a/phy/gui/actions.py b/phy/gui/actions.py index defb4ec75..cedfa7230 100644 --- a/phy/gui/actions.py +++ b/phy/gui/actions.py @@ -12,7 +12,8 @@ from six import string_types, PY3 -from .qt import QKeySequence, QAction +from .qt import QKeySequence, QAction, require_qt +from .gui import GUI from phy.utils import Bunch logger = logging.getLogger(__name__) @@ -107,6 +108,7 @@ def _alias(name): return alias +@require_qt def _create_qaction(gui, name, callback, shortcut): # Create the QAction instance. action = QAction(name, gui) @@ -133,17 +135,11 @@ class Actions(object): * Display all shortcuts """ - def __init__(self): - self._gui = None - self._actions = {} + def __init__(self, gui): + self._actions_dict = {} self._aliases = {} - - def get_action_dict(self): - return self._actions.copy() - - def attach(self, gui, enable_snippets=True): - """Attach a GUI.""" - self._gui = gui + assert isinstance(gui, GUI) + self.gui = gui # Default exit action. @self.add(shortcut='Quit') @@ -151,9 +147,10 @@ def exit(): gui.close() # Create and attach snippets. - if enable_snippets: - self.snippets = Snippets() - self.snippets.attach(gui, self) + self.snippets = Snippets(gui, self) + + def backup(self): + return list(self._actions_dict.values()) def add(self, callback=None, name=None, shortcut=None, alias=None): """Add an action with a keyboard shortcut.""" @@ -169,16 +166,15 @@ def add(self, callback=None, name=None, shortcut=None, alias=None): name = name.replace('&', '') # Skip existing action. - if name in self._actions: + if name in self._actions_dict: return # Create and register the action. - action = _create_qaction(self._gui, name, callback, shortcut) + action = _create_qaction(self.gui, name, callback, shortcut) action_obj = Bunch(qaction=action, name=name, alias=alias, shortcut=shortcut, callback=callback) - if self._gui: - self._gui.addAction(action) - self._actions[name] = action_obj + self.gui.addAction(action) + self._actions_dict[name] = action_obj # Register the alias -> name mapping. self._aliases[alias] = name @@ -192,7 +188,7 @@ def run(self, name, *args): # Resolve the alias if it is an alias. name = self._aliases.get(name, name) # Get the action. - action = self._actions.get(name, None) + action = self._actions_dict.get(name, None) if not action: raise ValueError("Action `{}` doesn't exist.".format(name)) if not name.startswith('_'): @@ -201,14 +197,13 @@ def run(self, name, *args): def remove(self, name): """Remove an action.""" - if self._gui: - self._gui.removeAction(self._actions[name].qaction) - del self._actions[name] + self.gui.removeAction(self._actions_dict[name].qaction) + del self._actions_dict[name] delattr(self, name) def remove_all(self): """Remove all actions.""" - names = sorted(self._actions.keys()) + names = sorted(self._actions_dict.keys()) for name in names: self.remove(name) @@ -216,12 +211,11 @@ def remove_all(self): def shortcuts(self): """A dictionary of action shortcuts.""" return {name: action.shortcut - for name, action in self._actions.items()} + for name, action in self._actions_dict.items()} def show_shortcuts(self): """Print all shortcuts.""" - _show_shortcuts(self.shortcuts, - self._gui.windowTitle() if self._gui else None) + _show_shortcuts(self.shortcuts, self.gui.windowTitle()) # ----------------------------------------------------------------------------- @@ -261,16 +255,16 @@ class Snippets(object): _snippet_chars = ("abcdefghijklmnopqrstuvwxyz0123456789" " ,.;?!_-+~=*/\(){}[]") - def __init__(self): - self._gui = None - self._cmd = '' # only used when there is no gui attached + def __init__(self, gui, actions): + assert isinstance(gui, GUI) + self.gui = gui + + assert isinstance(actions, Actions) + self.actions = actions - def attach(self, gui, actions): - self._gui = gui - self._actions = actions # We will keep a backup of all actions so that we can switch # safely to the set of shortcut actions when snippet mode is on. - self._actions_backup = {} + self._actions_backup = [] # Register snippet mode shortcut. @actions.add(shortcut=':') @@ -284,7 +278,7 @@ def command(self): A cursor is appended at the end. """ - msg = self._gui.status_message if self._gui else self._cmd + msg = self.gui.status_message n = len(msg) n_cur = len(self.cursor) return msg[:n - n_cur] @@ -292,10 +286,7 @@ def command(self): @command.setter def command(self, value): value += self.cursor - if not self._gui: - self._cmd = value - else: - self._gui.status_message = value + self.gui.status_message = value def _backspace(self): """Erase the last character in the snippet command.""" @@ -328,19 +319,19 @@ def callback(): self.command += char return callback - self._actions.add(name='_snippet_{}'.format(i), - shortcut=char, - callback=_make_func(char)) + self.actions.add(name='_snippet_{}'.format(i), + shortcut=char, + callback=_make_func(char)) - self._actions.add(name='_snippet_backspace', - shortcut='backspace', - callback=self._backspace) - self._actions.add(name='_snippet_activate', - shortcut=('enter', 'return'), - callback=self._enter) - self._actions.add(name='_snippet_disable', - shortcut='escape', - callback=self.mode_off) + self.actions.add(name='_snippet_backspace', + shortcut='backspace', + callback=self._backspace) + self.actions.add(name='_snippet_activate', + shortcut=('enter', 'return'), + callback=self._enter) + self.actions.add(name='_snippet_disable', + shortcut='escape', + callback=self.mode_off) def run(self, snippet): """Executes a snippet command. @@ -355,7 +346,7 @@ def run(self, snippet): logger.info("Processing snippet `%s`.", snippet) try: - self._actions.run(name, *snippet_args[1:]) + self.actions.run(name, *snippet_args[1:]) except Exception as e: logger.warn("Error when executing snippet: \"%s\".", str(e)) @@ -364,23 +355,22 @@ def is_mode_on(self): def mode_on(self): logger.info("Snippet mode enabled, press `escape` to leave this mode.") - self._actions_backup = self._actions.get_action_dict() + self._actions_backup = self.actions.backup() # Remove all existing actions. - self._actions.remove_all() + self.actions.remove_all() # Add snippet keystroke actions. self._create_snippet_actions() self.command = ':' def mode_off(self): - if self._gui: - self._gui.status_message = '' + self.gui.status_message = '' # Remove all existing actions. - self._actions.remove_all() + self.actions.remove_all() logger.info("Snippet mode disabled.") # Reestablishes the shortcuts. - for action_obj in self._actions_backup.values(): - self._actions.add(callback=action_obj.callback, - name=action_obj.name, - shortcut=action_obj.shortcut, - alias=action_obj.alias, - ) + for action_obj in self._actions_backup: + self.actions.add(callback=action_obj.callback, + name=action_obj.name, + shortcut=action_obj.shortcut, + alias=action_obj.alias, + ) diff --git a/phy/gui/tests/conftest.py b/phy/gui/tests/conftest.py index fd491a72b..d2cad174b 100644 --- a/phy/gui/tests/conftest.py +++ b/phy/gui/tests/conftest.py @@ -24,10 +24,10 @@ def gui(qapp): @yield_fixture -def actions(): - yield Actions() +def actions(gui): + yield Actions(gui) @yield_fixture -def snippets(): - yield Snippets() +def snippets(gui, actions): + yield Snippets(gui, actions) diff --git a/phy/gui/tests/test_actions.py b/phy/gui/tests/test_actions.py index a4da51350..172ba40ea 100644 --- a/phy/gui/tests/test_actions.py +++ b/phy/gui/tests/test_actions.py @@ -86,8 +86,6 @@ def show_my_shortcuts(): #------------------------------------------------------------------------------ def test_actions_gui(qtbot, gui, actions): - actions.attach(gui) - qtbot.addWidget(gui) gui.show() qtbot.waitForWindowShown(gui) @@ -105,7 +103,6 @@ def press(): def test_snippets_gui(qtbot, gui, actions): - qtbot.addWidget(gui) gui.show() qtbot.waitForWindowShown(gui) @@ -117,7 +114,6 @@ def test(*args): _actions.append(args) # Attach the GUI and register the actions. - actions.attach(gui) snippets = actions.snippets # Simulate the following keystrokes `:t2 ^H^H1 3-5 ab,c ` @@ -185,9 +181,6 @@ def test(arg): assert len(str(arg)) == 1 _actions.append(arg) - # Attach the GUI and register the actions. - snippets.attach(None, actions) - with captured_logging() as buf: snippets.run(':t1') assert 'error' in buf.getvalue().lower() @@ -224,9 +217,6 @@ def test_2(*args): def test_3(*args): _actions.append((3, args)) - # Attach the GUI and register the actions. - snippets.attach(None, actions) - assert snippets.command == '' # Action 1. @@ -264,9 +254,6 @@ def test_snippets_actions_2(actions, snippets): def test(arg): _actions.append(arg) - # Attach the GUI and register the actions. - snippets.attach(None, actions) - actions.test(1) assert _actions == [1] From 4134c0b77a54a3ee31d5a8d9b894cd9d967955cb Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 17 Oct 2015 19:26:45 +0200 Subject: [PATCH 0354/1059] Fix tests --- phy/cluster/manual/gui_component.py | 14 +++++--------- phy/cluster/manual/tests/test_gui_component.py | 6 ++++-- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 8ed20491f..26801be63 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -220,9 +220,6 @@ def on_select(cluster_ids): _attach_wizard(self.wizard, self.clustering, self.cluster_meta) - # Create the actions. - self._create_actions() - def _add_action(self, callback, name=None, alias=None): name = name or callback.__name__ shortcut = self.shortcuts.get(name, None) @@ -230,8 +227,8 @@ def _add_action(self, callback, name=None, alias=None): name=name, shortcut=shortcut, alias=alias) - def _create_actions(self): - self.actions = Actions() + def _create_actions(self, gui): + self.actions = Actions(gui) # Selection. self._add_action(self.select, alias='c') @@ -268,8 +265,8 @@ def on_select(cluster_ids): def on_start(): gui.emit('wizard_start') - # Attach the GUI and register the actions. - self.actions.attach(gui) + # Create the actions. + self._create_actions(gui) return self @@ -318,5 +315,4 @@ def save(self): spike_clusters = self.clustering.spike_clusters groups = {c: self.cluster_meta.get('group', c) for c in self.clustering.cluster_ids} - if self.gui: - self.gui.emit('save_requested', spike_clusters, groups) + self.gui.emit('save_requested', spike_clusters, groups) diff --git a/phy/cluster/manual/tests/test_gui_component.py b/phy/cluster/manual/tests/test_gui_component.py index cc83d32b9..9b8a4667d 100644 --- a/phy/cluster/manual/tests/test_gui_component.py +++ b/phy/cluster/manual/tests/test_gui_component.py @@ -302,11 +302,13 @@ def test_manual_clustering_split(manual_clustering): assert_selection(31, 20) -def test_manual_clustering_split_2(qapp): +def test_manual_clustering_split_2(gui): # noqa spike_clusters = np.array([0, 0, 1]) mc = ManualClustering(spike_clusters=spike_clusters) - mc.actions.split([0, 1]) + mc.attach(gui) + + mc.actions.split([0]) assert mc.wizard.selection == [2, 1] From d01136e32b4cf2ff62cbf6dbc0062f21a25611ae Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 17 Oct 2015 19:34:46 +0200 Subject: [PATCH 0355/1059] In CLI, only show traceback in debug mode --- phy/utils/cli.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/phy/utils/cli.py b/phy/utils/cli.py index aeb9cc87c..1b947f8b5 100644 --- a/phy/utils/cli.py +++ b/phy/utils/cli.py @@ -9,6 +9,7 @@ #------------------------------------------------------------------------------ import logging +import sys import click @@ -25,6 +26,15 @@ add_default_handler('DEBUG' if DEBUG else 'INFO') +# Only show traceback in debug mode (--debug). +def exceptionHandler(exception_type, exception, traceback): + logger.error("%s: %s", exception_type.__name__, exception) + + +if not DEBUG: + sys.excepthook = exceptionHandler + + @click.group() @click.version_option(version=phy.__version_git__) @click.help_option('-h', '--help') From a3c4e03549517b4c5e9a4c7eed6309e77baf6bc6 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 17 Oct 2015 19:40:57 +0200 Subject: [PATCH 0356/1059] Minor fixes --- phy/__init__.py | 4 ++-- phy/cluster/manual/gui_component.py | 3 ++- phy/cluster/manual/tests/test_gui_component.py | 4 ++-- phy/gui/actions.py | 2 +- phy/utils/cli.py | 2 +- 5 files changed, 8 insertions(+), 7 deletions(-) diff --git a/phy/__init__.py b/phy/__init__.py index cf65f316b..e29171ce9 100644 --- a/phy/__init__.py +++ b/phy/__init__.py @@ -35,7 +35,7 @@ logger.addHandler(logging.NullHandler()) -_logger_fmt = '%(asctime)s [%(levelname)s] %(caller)s %(message)s' +_logger_fmt = '%(asctime)s [%(levelname)s] %(caller)s %(message)s' _logger_date_fmt = '%H:%M:%S' @@ -44,7 +44,7 @@ def format(self, record): # Only keep the first character in the level name. record.levelname = record.levelname[0] filename = op.splitext(op.basename(record.pathname))[0] - record.caller = '{:s}:{:d}'.format(filename, record.lineno).ljust(16) + record.caller = '{:s}:{:d}'.format(filename, record.lineno).ljust(20) return super(_Formatter, self).format(record) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 26801be63..ff0cf7d22 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -168,7 +168,8 @@ class ManualClustering(object): 'redo': 'Redo', } - def __init__(self, spike_clusters=None, + def __init__(self, + spike_clusters, cluster_groups=None, n_spikes_max_per_cluster=100, shortcuts=None, diff --git a/phy/cluster/manual/tests/test_gui_component.py b/phy/cluster/manual/tests/test_gui_component.py index 9b8a4667d..8c1e27a33 100644 --- a/phy/cluster/manual/tests/test_gui_component.py +++ b/phy/cluster/manual/tests/test_gui_component.py @@ -29,7 +29,7 @@ def manual_clustering(gui, cluster_ids, cluster_groups): spike_clusters = np.array(cluster_ids) - mc = ManualClustering(spike_clusters=spike_clusters, + mc = ManualClustering(spike_clusters, cluster_groups=cluster_groups, shortcuts={'undo': 'ctrl+z'}, ) @@ -305,7 +305,7 @@ def test_manual_clustering_split(manual_clustering): def test_manual_clustering_split_2(gui): # noqa spike_clusters = np.array([0, 0, 1]) - mc = ManualClustering(spike_clusters=spike_clusters) + mc = ManualClustering(spike_clusters) mc.attach(gui) mc.actions.split([0]) diff --git a/phy/gui/actions.py b/phy/gui/actions.py index cedfa7230..e02ec9ac9 100644 --- a/phy/gui/actions.py +++ b/phy/gui/actions.py @@ -292,7 +292,7 @@ def _backspace(self): """Erase the last character in the snippet command.""" if self.command == ':': return - logger.debug("Snippet keystroke `Backspace`.") + logger.log(5, "Snippet keystroke `Backspace`.") self.command = self.command[:-1] def _enter(self): diff --git a/phy/utils/cli.py b/phy/utils/cli.py index 1b947f8b5..80e8b9910 100644 --- a/phy/utils/cli.py +++ b/phy/utils/cli.py @@ -27,7 +27,7 @@ # Only show traceback in debug mode (--debug). -def exceptionHandler(exception_type, exception, traceback): +def exceptionHandler(exception_type, exception, traceback): # pragma: no cover logger.error("%s: %s", exception_type.__name__, exception) From dd16682126d17f600dfb2a67a219dabd58ef3aa8 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 17 Oct 2015 21:16:49 +0200 Subject: [PATCH 0357/1059] WIP: default shortcuts in actions --- phy/gui/actions.py | 4 +++- phy/gui/tests/test_actions.py | 11 +++++++++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/phy/gui/actions.py b/phy/gui/actions.py index e02ec9ac9..3e21ba2b7 100644 --- a/phy/gui/actions.py +++ b/phy/gui/actions.py @@ -135,9 +135,10 @@ class Actions(object): * Display all shortcuts """ - def __init__(self, gui): + def __init__(self, gui, default_shortcuts=None): self._actions_dict = {} self._aliases = {} + self._default_shortcuts = default_shortcuts or {} assert isinstance(gui, GUI) self.gui = gui @@ -164,6 +165,7 @@ def add(self, callback=None, name=None, shortcut=None, alias=None): name = name or callback.__name__ alias = alias or _alias(name) name = name.replace('&', '') + shortcut = shortcut or self._default_shortcuts.get(name, None) # Skip existing action. if name in self._actions_dict: diff --git a/phy/gui/tests/test_actions.py b/phy/gui/tests/test_actions.py index 172ba40ea..3ad1c0f76 100644 --- a/phy/gui/tests/test_actions.py +++ b/phy/gui/tests/test_actions.py @@ -12,6 +12,7 @@ _get_shortcut_string, _get_qkeysequence, _parse_snippet, + Actions, ) from phy.utils.testing import captured_output, captured_logging @@ -20,7 +21,7 @@ # Test actions #------------------------------------------------------------------------------ -def test_shortcuts(qtbot): +def test_shortcuts(qapp): def _assert_shortcut(name, key=None): shortcut = _get_qkeysequence(name) s = _get_shortcut_string(shortcut) @@ -37,7 +38,7 @@ def _assert_shortcut(name, key=None): _assert_shortcut(['ctrl+a', 'shift+b']) -def test_show_shortcuts(qtbot): +def test_show_shortcuts(qapp): # NOTE: a Qt application needs to be running so that we can use the # KeySequence. shortcuts = { @@ -51,6 +52,12 @@ def test_show_shortcuts(qtbot): assert 'ctrl+z' in stdout.getvalue() +def test_actions_default_shortcuts(gui): + actions = Actions(gui, default_shortcuts={'my_action': 'a'}) + actions.add(lambda: None, name='my_action') + assert actions.shortcuts['my_action'] == 'a' + + def test_actions_simple(actions): _res = [] From ecad5122a092bc2209c591c0f34c7dfda979e88b Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 17 Oct 2015 21:22:57 +0200 Subject: [PATCH 0358/1059] Refactor actions --- phy/cluster/manual/gui_component.py | 38 ++++++++++++----------------- 1 file changed, 15 insertions(+), 23 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index ff0cf7d22..e3b988424 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -179,8 +179,7 @@ def __init__(self, # Load default shortcuts, and override any user shortcuts. self.shortcuts = self.default_shortcuts.copy() - if shortcuts: - self.shortcuts.update(shortcuts) + self.shortcuts.update(shortcuts or {}) # Create Clustering and ClusterMeta. self.clustering = Clustering(spike_clusters) @@ -221,34 +220,27 @@ def on_select(cluster_ids): _attach_wizard(self.wizard, self.clustering, self.cluster_meta) - def _add_action(self, callback, name=None, alias=None): - name = name or callback.__name__ - shortcut = self.shortcuts.get(name, None) - self.actions.add(callback=callback, - name=name, - shortcut=shortcut, alias=alias) - def _create_actions(self, gui): - self.actions = Actions(gui) + self.actions = Actions(gui, default_shortcuts=self.shortcuts) # Selection. - self._add_action(self.select, alias='c') + self.actions.add(self.select, alias='c') # Wizard. - self._add_action(self.wizard.restart, name='reset_wizard') - self._add_action(self.wizard.previous) - self._add_action(self.wizard.next) - self._add_action(self.wizard.next_by_quality) - self._add_action(self.wizard.next_by_similarity) - self._add_action(self.wizard.pin) - self._add_action(self.wizard.unpin) + self.actions.add(self.wizard.restart, name='reset_wizard') + self.actions.add(self.wizard.previous) + self.actions.add(self.wizard.next) + self.actions.add(self.wizard.next_by_quality) + self.actions.add(self.wizard.next_by_similarity) + self.actions.add(self.wizard.pin) + self.actions.add(self.wizard.unpin) # Clustering. - self._add_action(self.merge) - self._add_action(self.split) - self._add_action(self.move) - self._add_action(self.undo) - self._add_action(self.redo) + self.actions.add(self.merge) + self.actions.add(self.split) + self.actions.add(self.move) + self.actions.add(self.undo) + self.actions.add(self.redo) def attach(self, gui): self.gui = gui From 2c31c43bfe80ce56c4514c05b8c62fd8afe9737d Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 18 Oct 2015 18:42:23 +0200 Subject: [PATCH 0359/1059] WIP: start refactoring phy.plot --- .coveragerc | 1 - phy/plot/__init__.py | 11 - phy/plot/_mpl_utils.py | 20 - phy/plot/_panzoom.py | 788 ------------------------------- phy/plot/_vispy_utils.py | 555 ---------------------- phy/plot/base.py | 233 +++++++++ phy/plot/ccg.py | 271 ----------- phy/plot/features.py | 659 -------------------------- phy/plot/tests/test_base.py | 0 phy/plot/tests/test_ccg.py | 57 --- phy/plot/tests/test_features.py | 96 ---- phy/plot/tests/test_traces.py | 84 ---- phy/plot/tests/test_utils.py | 57 +-- phy/plot/tests/test_waveforms.py | 88 ---- phy/plot/traces.py | 325 ------------- phy/plot/utils.py | 78 +++ phy/plot/waveforms.py | 512 -------------------- 17 files changed, 317 insertions(+), 3518 deletions(-) delete mode 100644 phy/plot/_mpl_utils.py delete mode 100644 phy/plot/_panzoom.py delete mode 100644 phy/plot/_vispy_utils.py create mode 100644 phy/plot/base.py delete mode 100644 phy/plot/ccg.py delete mode 100644 phy/plot/features.py create mode 100644 phy/plot/tests/test_base.py delete mode 100644 phy/plot/tests/test_ccg.py delete mode 100644 phy/plot/tests/test_features.py delete mode 100644 phy/plot/tests/test_traces.py delete mode 100644 phy/plot/tests/test_waveforms.py delete mode 100644 phy/plot/traces.py create mode 100644 phy/plot/utils.py delete mode 100644 phy/plot/waveforms.py diff --git a/.coveragerc b/.coveragerc index 3934caa38..942e645f6 100644 --- a/.coveragerc +++ b/.coveragerc @@ -3,7 +3,6 @@ branch = True source = phy omit = */phy/ext/* - */phy/plot/* */phy/utils/tempdir.py */default_settings.py diff --git a/phy/plot/__init__.py b/phy/plot/__init__.py index 03d945c18..e69de29bb 100644 --- a/phy/plot/__init__.py +++ b/phy/plot/__init__.py @@ -1,11 +0,0 @@ -# -*- coding: utf-8 -*- -# flake8: noqa - -"""Interactive and static visualization of data.""" - -from ._panzoom import PanZoom, PanZoomGrid -from ._vispy_utils import BaseSpikeCanvas, BaseSpikeVisual -from .waveforms import WaveformView, WaveformVisual, plot_waveforms -from .features import FeatureView, FeatureVisual, plot_features -from .traces import TraceView, TraceView, plot_traces -from .ccg import CorrelogramView, CorrelogramView, plot_correlograms diff --git a/phy/plot/_mpl_utils.py b/phy/plot/_mpl_utils.py deleted file mode 100644 index ff2b38a23..000000000 --- a/phy/plot/_mpl_utils.py +++ /dev/null @@ -1,20 +0,0 @@ -# -*- coding: utf-8 -*- - -"""matplotlib utilities.""" - - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - - -#------------------------------------------------------------------------------ -# matplotlib utilities -#------------------------------------------------------------------------------ - -def _bottom_left_frame(ax): - """Only keep the bottom and left ticks in a matplotlib Axes.""" - ax.spines['right'].set_visible(False) - ax.spines['top'].set_visible(False) - ax.xaxis.set_ticks_position('bottom') - ax.yaxis.set_ticks_position('left') diff --git a/phy/plot/_panzoom.py b/phy/plot/_panzoom.py deleted file mode 100644 index 6dc56fa1b..000000000 --- a/phy/plot/_panzoom.py +++ /dev/null @@ -1,788 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Pan & zoom transform.""" - - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import math - -import numpy as np - -from ..utils._types import _as_array - - -#------------------------------------------------------------------------------ -# PanZoom class -#------------------------------------------------------------------------------ - -class PanZoom(object): - """Pan & zoom transform. - - The panzoom transform allow to translate and scale an object in the window - space coordinate (2D). This means that whatever point you grab on the - screen, it should remains under the mouse pointer. Zoom is realized using - the mouse scroll and is always centered on the mouse pointer. - - You can also control programmatically the transform using: - - * aspect: control the aspect ratio of the whole scene - * pan : translate the scene to the given 2D coordinates - * zoom : set the zoom level (centered at current pan coordinates) - * zmin : minimum zoom level - * zmax : maximum zoom level - - """ - - _default_zoom_coeff = 1.5 - _default_wheel_coeff = .1 - _arrows = ('Left', 'Right', 'Up', 'Down') - _pm = ('+', '-') - - def __init__(self, aspect=1.0, pan=(0.0, 0.0), zoom=(1.0, 1.0), - zmin=1e-5, zmax=1e5, - xmin=None, xmax=None, - ymin=None, ymax=None, - ): - """ - Initialize the transform. - - Parameters - ---------- - - aspect : float (default is None) - Indicate what is the aspect ratio of the object displayed. This is - necessary to convert pixel drag move in object space coordinates. - - pan : float, float (default is 0, 0) - Initial translation - - zoom : float, float (default is 1) - Initial zoom level - - zmin : float (default is 0.01) - Minimum zoom level - - zmax : float (default is 1000) - Maximum zoom level - """ - - self._aspect = aspect - self._zmin = zmin - self._zmax = zmax - self._xmin = xmin - self._xmax = xmax - self._ymin = ymin - self._ymax = ymax - - self._zoom_to_pointer = True - - # Canvas this transform is attached to - self._canvas = None - self._canvas_aspect = np.ones(2) - self._width = 1 - self._height = 1 - - self._create_pan_and_zoom(_as_array(pan), _as_array(zoom)) - - # Programs using this transform - self._programs = [] - - def _create_pan_and_zoom(self, pan, zoom): - self._pan = np.array(pan) - self._zoom = np.array(zoom) - self._zoom_coeff = self._default_zoom_coeff - self._wheel_coeff = self._default_wheel_coeff - - # Various properties - # ------------------------------------------------------------------------- - - @property - def zoom_to_pointer(self): - """Whether to zoom toward the pointer position.""" - return self._zoom_to_pointer - - @zoom_to_pointer.setter - def zoom_to_pointer(self, value): - self._zoom_to_pointer = value - - @property - def is_attached(self): - """Whether the transform is attached to a canvas.""" - return self._canvas is not None - - @property - def aspect(self): - """Aspect (width/height).""" - return self._aspect - - @aspect.setter - def aspect(self, value): - """Aspect (width/height).""" - self._aspect = value - - # xmin/xmax - # ------------------------------------------------------------------------- - - @property - def xmin(self): - """Minimum x allowed for pan.""" - return self._xmin - - @xmin.setter - def xmin(self, value): - if self._xmax is not None: - self._xmin = np.minimum(value, self._xmax) - else: - self._xmin = value - - @property - def xmax(self): - """Maximum x allowed for pan.""" - return self._xmax - - @xmax.setter - def xmax(self, value): - if self._xmin is not None: - self._xmax = np.maximum(value, self._xmin) - else: - self._xmax = value - - # ymin/ymax - # ------------------------------------------------------------------------- - - @property - def ymin(self): - """Minimum y allowed for pan.""" - return self._ymin - - @ymin.setter - def ymin(self, value): - if self._ymax is not None: - self._ymin = min(value, self._ymax) - else: - self._ymin = value - - @property - def ymax(self): - """Maximum y allowed for pan.""" - return self._ymax - - @ymax.setter - def ymax(self, value): - if self._ymin is not None: - self._ymax = max(value, self._ymin) - else: - self._ymax = value - - # zmin/zmax - # ------------------------------------------------------------------------- - - @property - def zmin(self): - """Minimum zoom level.""" - return self._zmin - - @zmin.setter - def zmin(self, value): - """Minimum zoom level.""" - self._zmin = min(value, self._zmax) - - @property - def zmax(self): - """Maximal zoom level.""" - return self._zmax - - @zmax.setter - def zmax(self, value): - """Maximal zoom level.""" - self._zmax = max(value, self._zmin) - - # Internal methods - # ------------------------------------------------------------------------- - - def _apply_pan_zoom(self): - zoom = self._zoom_aspect() - for program in self._programs: - program["u_pan"] = self._pan - program["u_zoom"] = zoom - - def _zoom_aspect(self, zoom=None): - if zoom is None: - zoom = self._zoom - zoom = _as_array(zoom) - if self._aspect is not None: - aspect = self._canvas_aspect * self._aspect - else: - aspect = 1. - return zoom * aspect - - def _normalize(self, x_y, restrict_to_box=True): - x_y = np.asarray(x_y, dtype=np.float32) - size = np.array([self._width, self._height], dtype=np.float32) - pos = x_y / (size / 2.) - 1 - return pos - - def _constrain_pan(self): - """Constrain bounding box.""" - if self.xmin is not None and self._xmax is not None: - p0 = self.xmin + 1. / self._zoom[0] - p1 = self.xmax - 1. / self._zoom[0] - p0, p1 = min(p0, p1), max(p0, p1) - self._pan[0] = np.clip(self._pan[0], p0, p1) - - if self.ymin is not None and self._ymax is not None: - p0 = self.ymin + 1. / self._zoom[1] - p1 = self.ymax - 1. / self._zoom[1] - p0, p1 = min(p0, p1), max(p0, p1) - self._pan[1] = np.clip(self._pan[1], p0, p1) - - def _constrain_zoom(self): - """Constrain bounding box.""" - if self.xmin is not None: - self._zoom[0] = max(self._zoom[0], - 1. / (self._pan[0] - self.xmin)) - if self.xmax is not None: - self._zoom[0] = max(self._zoom[0], - 1. / (self.xmax - self._pan[0])) - - if self.ymin is not None: - self._zoom[1] = max(self._zoom[1], - 1. / (self._pan[1] - self.ymin)) - if self.ymax is not None: - self._zoom[1] = max(self._zoom[1], - 1. / (self.ymax - self._pan[1])) - - # Pan and zoom - # ------------------------------------------------------------------------- - - @property - def pan(self): - """Pan translation.""" - return self._pan - - @pan.setter - def pan(self, value): - """Pan translation.""" - assert len(value) == 2 - self._pan[:] = value - self._constrain_pan() - self._apply_pan_zoom() - - @property - def zoom(self): - """Zoom level.""" - return self._zoom - - @zoom.setter - def zoom(self, value): - """Zoom level.""" - if isinstance(value, (int, float)): - value = (value, value) - assert len(value) == 2 - self._zoom = np.clip(value, self._zmin, self._zmax) - if not self.is_attached: - return - - # Constrain bounding box. - self._constrain_pan() - self._constrain_zoom() - - self._apply_pan_zoom() - - def _do_pan(self, d): - - dx, dy = d - - pan_x, pan_y = self.pan - zoom_x, zoom_y = self._zoom_aspect(self._zoom) - - self.pan = (pan_x + dx / zoom_x, - pan_y + dy / zoom_y) - - self._canvas.update() - - def _do_zoom(self, d, p, c=1.): - dx, dy = d - x0, y0 = p - - pan_x, pan_y = self._pan - zoom_x, zoom_y = self._zoom - zoom_x_new, zoom_y_new = (zoom_x * math.exp(c * self._zoom_coeff * dx), - zoom_y * math.exp(c * self._zoom_coeff * dy)) - - zoom_x_new = max(min(zoom_x_new, self._zmax), self._zmin) - zoom_y_new = max(min(zoom_y_new, self._zmax), self._zmin) - - self.zoom = zoom_x_new, zoom_y_new - - if self._zoom_to_pointer: - zoom_x, zoom_y = self._zoom_aspect((zoom_x, - zoom_y)) - zoom_x_new, zoom_y_new = self._zoom_aspect((zoom_x_new, - zoom_y_new)) - - self.pan = (pan_x - x0 * (1. / zoom_x - 1. / zoom_x_new), - pan_y + y0 * (1. / zoom_y - 1. / zoom_y_new)) - - self._canvas.update() - - # Event callbacks - # ------------------------------------------------------------------------- - - keyboard_shortcuts = { - 'pan': ('left click and drag', 'arrows'), - 'zoom': ('right click and drag', '+', '-'), - 'reset': 'r', - } - - def on_resize(self, event): - """Resize event.""" - - self._width = float(event.size[0]) - self._height = float(event.size[1]) - aspect = max(1., self._width / max(self._height, 1.)) - if aspect > 1.0: - self._canvas_aspect = np.array([1.0 / aspect, 1.0]) - else: - self._canvas_aspect = np.array([1.0, aspect / 1.0]) - - # Update zoom level - self.zoom = self._zoom - - def on_mouse_move(self, event): - """Pan and zoom with the mouse.""" - if event.modifiers: - return - if event.is_dragging: - x0, y0 = self._normalize(event.press_event.pos) - x1, y1 = self._normalize(event.last_event.pos, False) - x, y = self._normalize(event.pos, False) - dx, dy = x - x1, -(y - y1) - if event.button == 1: - self._do_pan((dx, dy)) - elif event.button == 2: - c = np.sqrt(self._width) * .03 - self._do_zoom((dx, dy), (x0, y0), c=c) - - def on_mouse_wheel(self, event): - """Zoom with the mouse wheel.""" - if event.modifiers: - return - dx = np.sign(event.delta[1]) * self._wheel_coeff - # Zoom toward the mouse pointer. - x0, y0 = self._normalize(event.pos) - self._do_zoom((dx, dx), (x0, y0)) - - def _zoom_keyboard(self, key): - k = .05 - if key == '-': - k = -k - self._do_zoom((k, k), (0, 0)) - - def _pan_keyboard(self, key): - k = .1 / self.zoom - if key == 'Left': - self.pan += (+k[0], +0) - elif key == 'Right': - self.pan += (-k[0], +0) - elif key == 'Down': - self.pan += (+0, +k[1]) - elif key == 'Up': - self.pan += (+0, -k[1]) - self._canvas.update() - - def _reset_keyboard(self): - self.pan = (0., 0.) - self.zoom = 1. - self._canvas.update() - - def on_key_press(self, event): - """Key press event.""" - - # Zooming with the keyboard. - key = event.key - if event.modifiers: - return - - # Pan. - if key in self._arrows: - self._pan_keyboard(key) - - # Zoom. - if key in self._pm: - self._zoom_keyboard(key) - - # Reset with 'R'. - if key == 'R': - self._reset_keyboard() - - # Canvas methods - # ------------------------------------------------------------------------- - - def add(self, programs): - """Add programs to this tranform.""" - - if not isinstance(programs, (list, tuple)): - programs = [programs] - - for program in programs: - self._programs.append(program) - - self._apply_pan_zoom() - - def attach(self, canvas): - """Attach this tranform to a canvas.""" - self._canvas = canvas - self._width = float(canvas.size[0]) - self._height = float(canvas.size[1]) - - aspect = self._width / max(1, self._height) - if aspect > 1.0: - self._canvas_aspect = np.array([1.0 / aspect, 1.0]) - else: - self._canvas_aspect = np.array([1.0, aspect / 1.0]) - - canvas.connect(self.on_resize) - canvas.connect(self.on_mouse_wheel) - canvas.connect(self.on_mouse_move) - canvas.connect(self.on_key_press) - - -class PanZoomGrid(PanZoom): - """Pan & zoom transform for a grid view. - - This is used in a grid view with independent per-subplot pan & zoom. - - The currently-active subplot depends on where the cursor was when - the mouse was clicked. - - """ - - def __init__(self, *args, **kwargs): - self._index = (0, 0) # current index of the box being pan/zoom-ed - self._n_rows = 1 - self._create_pan_and_zoom(pan=(0., 0.), zoom=(1., 1.)) - self._set_pan_zoom_coeffs() - super(PanZoomGrid, self).__init__(*args, **kwargs) - - # Grid properties - # ------------------------------------------------------------------------- - - @property - def n_rows(self): - """Number of rows.""" - assert self._n_rows is not None - return self._n_rows - - @n_rows.setter - def n_rows(self, value): - self._n_rows = int(value) - assert self._n_rows >= 1 - if self._n_rows > 16: - raise RuntimeError("There cannot be more than 16x16 subplots. " - "The limitation comes from the maximum " - "uniform array size used by panzoom. " - "But you can try to increase the number 256 in " - "'plot/glsl/grid.glsl'.") - self._set_pan_zoom_coeffs() - - # Pan and zoom - # ------------------------------------------------------------------------- - - def _create_pan_and_zoom(self, pan, zoom): - pan = _as_array(pan) - zoom = _as_array(zoom) - n = 16 # maximum number of rows - self._pan_matrix = np.empty((n, n, 2)) - self._pan_matrix[...] = pan[None, None, :] - - self._zoom_matrix = np.empty((n, n, 2)) - self._zoom_matrix[...] = zoom[None, None, :] - - def _set_pan_zoom_coeffs(self): - # The zoom coefficient for mouse zoom should be proportional - # to the subplot size. - c = 3. / np.sqrt(self._n_rows) - self._zoom_coeff = self._default_zoom_coeff * c - self._wheel_coeff = self._default_wheel_coeff - - def _set_current_box(self, pos): - self._index = self._get_box(pos) - - @property - def _box(self): - i, j = self._index - return int(i * self._n_rows + j) - - @property - def _pan(self): - i, j = self._index - return self._pan_matrix[i, j, :] - - @_pan.setter - def _pan(self, value): - i, j = self._index - self._pan_matrix[i, j, :] = value - - @property - def _zoom(self): - i, j = self._index - return self._zoom_matrix[i, j, :] - - @_zoom.setter - def _zoom(self, value): - i, j = self._index - self._zoom_matrix[i, j, :] = value - - @property - def zoom_matrix(self): - """Zoom in every subplot.""" - return self._zoom_matrix - - @property - def pan_matrix(self): - """Pan in every subplot.""" - return self._pan_matrix - - def _apply_pan_zoom(self): - pan = self._pan - zoom = self._zoom_aspect(self._zoom) - value = (pan[0], pan[1], zoom[0], zoom[1]) - for program in self._programs: - program["u_pan_zoom[{0:d}]".format(self._box)] = value - - def _map_box(self, position, inverse=False): - position = _as_array(position) - if position.ndim == 1: - position = position[None, :] - n_rows = self._n_rows - rc_x, rc_y = self._index - - rc_x += 0.5 - rc_y += 0.5 - - x = -1.0 + rc_y * (2.0 / n_rows) - y = +1.0 - rc_x * (2.0 / n_rows) - - width = 0.95 / (1.0 * n_rows) - height = 0.95 / (1.0 * n_rows) - - if not inverse: - return (x + width * position[:, 0], - y + height * position[:, 1]) - else: - return np.c_[((position[:, 0] - x) / width, - (position[:, 1] - y) / height)] - - def _map_pan_zoom(self, position, inverse=False): - position = _as_array(position) - if position.ndim == 1: - position = position[None, :] - n_rows = self._n_rows - pan = self._pan - zoom = self._zoom_aspect(self._zoom) - if not inverse: - return zoom * (position + n_rows * pan) - else: - return (position / zoom - n_rows * pan) - - # xmin/xmax - # ------------------------------------------------------------------------- - - @property - def xmin(self): - return self._local(self._xmin) - - @xmin.setter - def xmin(self, value): - if self._xmax is not None: - self._xmin = np.minimum(value, self._xmax) - else: - self._xmin = value - - @property - def xmax(self): - return self._local(self._xmax) - - @xmax.setter - def xmax(self, value): - if self._xmin is not None: - self._xmax = np.maximum(value, self._xmin) - else: - self._xmax = value - - # ymin/ymax - # ------------------------------------------------------------------------- - - @property - def ymin(self): - return self._local(self._ymin) - - @ymin.setter - def ymin(self, value): - if self._ymax is not None: - self._ymin = min(value, self._ymax) - else: - self._ymin = value - - @property - def ymax(self): - return self._local(self._ymax) - - @ymax.setter - def ymax(self, value): - if self._ymin is not None: - self._ymax = max(value, self._ymin) - else: - self._ymax = value - - # Internal methods - # ------------------------------------------------------------------------- - - def _local(self, value): - if isinstance(value, np.ndarray): - i, j = self._index - value = value[i, j] - if value == np.nan: - value = None - return value - - def _constrain_pan(self): - """Constrain bounding box.""" - if self.xmin is not None and self._xmax is not None: - p0 = (self.xmin + 1. / self._zoom[0]) / self._n_rows - p1 = (self.xmax - 1. / self._zoom[0]) / self._n_rows - p0, p1 = min(p0, p1), max(p0, p1) - self._pan[0] = np.clip(self._pan[0], p0, p1) - - if self.ymin is not None and self._ymax is not None: - p0 = (self.ymin + 1. / self._zoom[1]) / self._n_rows - p1 = (self.ymax - 1. / self._zoom[1]) / self._n_rows - p0, p1 = min(p0, p1), max(p0, p1) - self._pan[1] = np.clip(self._pan[1], p0, p1) - - def _get_box(self, x_y): - x0, y0 = x_y - - x0 /= self._width - y0 /= self._height - - x0 *= self._n_rows - y0 *= self._n_rows - - return (int(math.floor(y0)), int(math.floor(x0))) - - def _normalize(self, x_y, restrict_to_box=True): - x0, y0 = x_y - - x0 /= self._width - y0 /= self._height - - x0 *= self._n_rows - y0 *= self._n_rows - - if restrict_to_box: - x0 = x0 % 1 - y0 = y0 % 1 - - x0 = -(1 - 2 * x0) - y0 = -(1 - 2 * y0) - - x0 /= self._n_rows - y0 /= self._n_rows - - return x0, y0 - - def _initialize_pan_zoom(self): - # Initialize and set the uniform array. - # NOTE: 256 is the maximum size used for this uniform array. - # This corresponds to a hard-limit of 16x16 subplots. - self._u_pan_zoom = np.zeros((256, 4), - dtype=np.float32) - self._u_pan_zoom[:, 2:] = 1. - for program in self._programs: - program["u_pan_zoom"] = self._u_pan_zoom - - def _reset(self): - pan = (0., 0.) - zoom = (1., 1.) - - # Keep the current box index. - index = self._index - - # Update pan and zoom of all subplots. - for i in range(self._n_rows): - for j in range(self._n_rows): - self._index = (i, j) - - # Find out which axes to update. - if pan is not None: - p = list(pan) - self.pan = p - - if zoom is not None: - z = list(zoom) - self.zoom = z - - # Set back the box index. - self._index = index - - # Event callbacks - # ------------------------------------------------------------------------- - - keyboard_shortcuts = { - 'subplot_pan': ('left click and drag', 'arrows'), - 'subplot_zoom': ('right click and drag', '+ or -'), - 'global_reset': 'r', - } - - def on_mouse_move(self, event): - """Mouse move event.""" - # Set box index as a function of the press position. - if event.is_dragging: - self._set_current_box(event.press_event.pos) - super(PanZoomGrid, self).on_mouse_move(event) - - def on_mouse_press(self, event): - """Mouse press event.""" - # Set box index as a function of the press position. - self._set_current_box(event.pos) - super(PanZoomGrid, self).on_mouse_press(event) - - def on_mouse_double_click(self, event): - """Double click event.""" - # Set box index as a function of the press position. - self._set_current_box(event.pos) - super(PanZoomGrid, self).on_mouse_double_click(event) - - def on_mouse_wheel(self, event): - """Mouse wheel event.""" - # Set box index as a function of the press position. - self._set_current_box(event.pos) - super(PanZoomGrid, self).on_mouse_wheel(event) - - def on_key_press(self, event): - """Key press event.""" - super(PanZoomGrid, self).on_key_press(event) - - key = event.key - - # Reset with 'R'. - if key == 'R': - self._reset() - self._canvas.update() - - # Canvas methods - # ------------------------------------------------------------------------- - - def add(self, programs): - """Add programs to this tranform.""" - if not isinstance(programs, (list, tuple)): - programs = [programs] - for program in programs: - self._programs.append(program) - self._initialize_pan_zoom() - self._apply_pan_zoom() diff --git a/phy/plot/_vispy_utils.py b/phy/plot/_vispy_utils.py deleted file mode 100644 index 4e193189c..000000000 --- a/phy/plot/_vispy_utils.py +++ /dev/null @@ -1,555 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Plotting/VisPy utilities.""" - - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -from functools import wraps -import logging -import os.path as op - -import numpy as np - -from vispy import app, gloo, config -from vispy.util.event import Event - -from phy.utils._types import _as_array, _as_list -from phy.io.array import _unique, _in_polygon -from ._panzoom import PanZoom - -logger = logging.getLogger(__name__) - - -#------------------------------------------------------------------------------ -# Misc -#------------------------------------------------------------------------------ - -def _load_shader(filename): - """Load a shader file.""" - path = op.join(op.dirname(op.realpath(__file__)), 'glsl', filename) - with open(path, 'r') as f: - return f.read() - - -def _tesselate_histogram(hist): - assert hist.ndim == 1 - nsamples = len(hist) - dx = 2. / nsamples - - x0 = -1 + dx * np.arange(nsamples) - - x = np.zeros(5 * nsamples + 1) - y = -1.0 * np.ones(5 * nsamples + 1) - - x[0:-1:5] = x0 - x[1::5] = x0 - x[2::5] = x0 + dx - x[3::5] = x0 - x[4::5] = x0 + dx - x[-1] = 1 - - y[1::5] = y[2::5] = -1 + 2. * hist - - return np.c_[x, y] - - -def _enable_depth_mask(): - gloo.set_state(clear_color='black', - depth_test=True, - depth_range=(0., 1.), - # depth_mask='true', - depth_func='lequal', - blend=True, - blend_func=('src_alpha', 'one_minus_src_alpha')) - gloo.set_clear_depth(1.0) - - -def _wrap_vispy(f): - """Decorator for a function returning a VisPy canvas. - - Add `show=True` parameter. - - """ - @wraps(f) - def wrapped(*args, **kwargs): - show = kwargs.pop('show', True) - canvas = f(*args, **kwargs) - if show: - canvas.show() - return canvas - return wrapped - - -#------------------------------------------------------------------------------ -# Base spike visual -#------------------------------------------------------------------------------ - -class _BakeVisual(object): - _shader_name = '' - _gl_draw_mode = '' - - def __init__(self, **kwargs): - super(_BakeVisual, self).__init__(**kwargs) - self._to_bake = [] - self._empty = True - - curdir = op.dirname(op.realpath(__file__)) - config['include_path'] = [op.join(curdir, 'glsl')] - - vertex = _load_shader(self._shader_name + '.vert') - fragment = _load_shader(self._shader_name + '.frag') - self.program = gloo.Program(vertex, fragment) - - @property - def empty(self): - """Specify whether the visual is currently empty or not.""" - return self._empty - - @empty.setter - def empty(self, value): - """Specify whether the visual is currently empty or not.""" - self._empty = value - - def set_to_bake(self, *bakes): - """Mark data items to be prepared for GPU.""" - for bake in bakes: - if bake not in self._to_bake: - self._to_bake.append(bake) - - def _bake(self): - """Prepare and upload the data on the GPU. - - Return whether something has been baked or not. - - """ - if self._empty: - return - n_bake = len(self._to_bake) - # Bake what needs to be baked. - # WARNING: the bake functions are called in alphabetical order. - # Tweak the names if there are dependencies between the functions. - for bake in sorted(self._to_bake): - # Name of the private baking method. - name = '_bake_{0:s}'.format(bake) - if hasattr(self, name): - getattr(self, name)() - self._to_bake = [] - return n_bake > 0 - - def draw(self): - """Draw the waveforms.""" - # Bake what needs to be baked at this point. - self._bake() - if not self._empty: - self.program.draw(self._gl_draw_mode) - - def update(self): - self.draw() - - -class BaseSpikeVisual(_BakeVisual): - """Base class for a VisPy visual showing spike data. - - There is a notion of displayed spikes and displayed clusters. - - """ - - _transparency = True - - def __init__(self, **kwargs): - super(BaseSpikeVisual, self).__init__(**kwargs) - self.n_spikes = None - self._spike_clusters = None - self._spike_ids = None - self._cluster_ids = None - self._cluster_order = None - self._cluster_colors = None - self._update_clusters_automatically = True - - if self._transparency: - gloo.set_state(clear_color='black', blend=True, - blend_func=('src_alpha', 'one_minus_src_alpha')) - - # Data properties - # ------------------------------------------------------------------------- - - def _set_or_assert_n_spikes(self, arr): - """If n_spikes is None, set it using the array's shape. Otherwise, - check that the array has n_spikes rows.""" - # TODO: improve this - if self.n_spikes is None: - self.n_spikes = arr.shape[0] - assert arr.shape[0] == self.n_spikes - - def _update_clusters(self): - self._cluster_ids = _unique(self._spike_clusters) - - @property - def spike_clusters(self): - """The clusters assigned to the displayed spikes.""" - return self._spike_clusters - - @spike_clusters.setter - def spike_clusters(self, value): - """Set all spike clusters.""" - value = _as_array(value) - self._spike_clusters = value - if self._update_clusters_automatically: - self._update_clusters() - self.set_to_bake('spikes_clusters') - - @property - def cluster_order(self): - """List of selected clusters in display order.""" - if self._cluster_order is None: - return self._cluster_ids - else: - return self._cluster_order - - @cluster_order.setter - def cluster_order(self, value): - value = _as_array(value) - assert sorted(value.tolist()) == sorted(self._cluster_ids) - self._cluster_order = value - - @property - def masks(self): - """Masks of the displayed spikes.""" - return self._masks - - @masks.setter - def masks(self, value): - assert isinstance(value, np.ndarray) - value = _as_array(value) - if value.ndim == 1: - value = value[None, :] - self._set_or_assert_n_spikes(value) - # TODO: support sparse structures - assert value.ndim == 2 - assert value.shape == (self.n_spikes, self.n_channels) - self._masks = value - self.set_to_bake('spikes') - - @property - def spike_ids(self): - """Spike ids to display.""" - if self._spike_ids is None: - self._spike_ids = np.arange(self.n_spikes) - return self._spike_ids - - @spike_ids.setter - def spike_ids(self, value): - value = _as_array(value) - self._set_or_assert_n_spikes(value) - self._spike_ids = value - self.set_to_bake('spikes') - - @property - def cluster_ids(self): - """Cluster ids of the displayed spikes.""" - return self._cluster_ids - - @cluster_ids.setter - def cluster_ids(self, value): - """Clusters of the displayed spikes.""" - self._cluster_ids = _as_array(value) - - @property - def n_clusters(self): - """Number of displayed clusters.""" - if self._cluster_ids is None: - return None - else: - return len(self._cluster_ids) - - @property - def cluster_colors(self): - """Colors of the displayed clusters. - - The first color is the color of the smallest cluster. - - """ - return self._cluster_colors - - @cluster_colors.setter - def cluster_colors(self, value): - self._cluster_colors = _as_array(value) - assert len(self._cluster_colors) >= self.n_clusters - self.set_to_bake('cluster_color') - - # Data baking - # ------------------------------------------------------------------------- - - def _bake_cluster_color(self): - if self.n_clusters == 0: - u_cluster_color = np.zeros((0, 0, 3)) - else: - u_cluster_color = self.cluster_colors.reshape((1, - self.n_clusters, - 3)) - assert u_cluster_color.ndim == 3 - assert u_cluster_color.shape[2] == 3 - u_cluster_color = (u_cluster_color * 255).astype(np.uint8) - self.program['u_cluster_color'] = gloo.Texture2D(u_cluster_color) - - -#------------------------------------------------------------------------------ -# Axes and boxes visual -#------------------------------------------------------------------------------ - -class BoxVisual(_BakeVisual): - """Box frames in a square grid of subplots.""" - _shader_name = 'box' - _gl_draw_mode = 'lines' - - def __init__(self, **kwargs): - super(BoxVisual, self).__init__(**kwargs) - self._n_rows = None - - @property - def n_rows(self): - """Number of rows in the grid.""" - return self._n_rows - - @n_rows.setter - def n_rows(self, value): - assert value >= 0 - self._n_rows = value - self._empty = not(self._n_rows > 0) - self.set_to_bake('n_rows') - - @property - def n_boxes(self): - """Number of boxes in the grid.""" - return self._n_rows * self._n_rows - - def _bake_n_rows(self): - if not self._n_rows: - return - arr = np.array([[-1, -1], - [-1, +1], - [-1, +1], - [+1, +1], - [+1, +1], - [+1, -1], - [+1, -1], - [-1, -1]]) * .975 - arr = np.tile(arr, (self.n_boxes, 1)) - position = np.empty((8 * self.n_boxes, 3), dtype=np.float32) - position[:, :2] = arr - position[:, 2] = np.repeat(np.arange(self.n_boxes), 8) - self.program['a_position'] = position - self.program['n_rows'] = self._n_rows - - -class AxisVisual(BoxVisual): - """Subplot axes in a subplot grid.""" - _shader_name = 'ax' - - def __init__(self, **kwargs): - super(AxisVisual, self).__init__(**kwargs) - self._xs = [] - self._ys = [] - self.program['u_color'] = (.2, .2, .2, 1.) - - def _bake_n_rows(self): - self.program['n_rows'] = self._n_rows - - @property - def xs(self): - """A list of x coordinates.""" - return self._xs - - @xs.setter - def xs(self, value): - self._xs = _as_list(value) - self.set_to_bake('positions') - - @property - def ys(self): - """A list of y coordinates.""" - return self._ys - - @ys.setter - def ys(self, value): - self._ys = _as_list(value) - self.set_to_bake('positions') - - @property - def color(self): - return tuple(self.program['u_color']) - - @color.setter - def color(self, value): - self.program['u_color'] = tuple(value) - - def _bake_positions(self): - if not self._n_rows: - return - nx = len(self._xs) - ny = len(self._ys) - n = nx + ny - position = np.empty((2 * n * self.n_boxes, 4), dtype=np.float32) - c = 1. - arr = [[x, -c, x, +c] for x in self._xs] - arr += [[-c, y, +c, y] for y in self._ys] - arr = np.hstack(arr).astype(np.float32) - arr = arr.reshape((-1, 2)) - # Positions. - position[:, :2] = np.tile(arr, (self.n_boxes, 1)) - # Index. - position[:, 2] = np.repeat(np.arange(self.n_boxes), 2 * n) - # Axes. - position[:, 3] = np.tile(([0] * (2 * nx)) + ([1] * (2 * ny)), - self.n_boxes) - self.program['a_position'] = position - - -class LassoVisual(_BakeVisual): - """Lasso.""" - _shader_name = 'lasso' - _gl_draw_mode = 'line_loop' - - def __init__(self, **kwargs): - super(LassoVisual, self).__init__(**kwargs) - self._points = [] - self._n_rows = None - self.program['u_box'] = 0 - - @property - def n_rows(self): - """Number of rows in the grid.""" - return self._n_rows - - @n_rows.setter - def n_rows(self, value): - assert value >= 0 - self._n_rows = value - self.set_to_bake('n_rows') - - @property - def points(self): - """Control points.""" - return self._points - - def _update_points(self): - self._empty = len(self._points) <= 1 - self.set_to_bake('points') - - @points.setter - def points(self, value): - value = list(value) - self._points = value - self._update_points() - - def add(self, xy): - """Add a new point.""" - self._points.append((xy)) - self._update_points() - logger.debug("Add lasso point.") - - def clear(self): - """Remove all points.""" - self._points = [] - self._update_points() - logger.debug("Clear lasso.") - - def in_lasso(self, points): - """Find points within the lasso. - - Parameters - ---------- - points : array - A `(n_points, 2)` array with coordinates in `[-1, 1]`. - - """ - if self.n_points <= 1: - return - polygon = self._points - # Close the polygon. - polygon.append(polygon[0]) - return _in_polygon(points, polygon) - - @property - def n_points(self): - return len(self._points) - - @property - def box(self): - """The row and column where the lasso is to be shown.""" - u_box = int(self.program['u_box'][0]) - return (u_box // self._n_rows, u_box % self._n_rows) - - @box.setter - def box(self, value): - assert len(value) == 2 - i, j = value - assert 0 <= i < self._n_rows - assert 0 <= j < self._n_rows - u_box = i * self._n_rows + j - self.program['u_box'] = u_box - - @property - def n_boxes(self): - """Number of boxes in the grid.""" - return self._n_rows * self._n_rows - - def _bake_n_rows(self): - if not self._n_rows: - return - self.program['n_rows'] = self._n_rows - - def _bake_points(self): - if self.n_points <= 1: - return - self.program['a_position'] = np.array(self._points, dtype=np.float32) - - -#------------------------------------------------------------------------------ -# Base spike canvas -#------------------------------------------------------------------------------ - -class BaseSpikeCanvas(app.Canvas): - """Base class for a VisPy canvas with spike data. - - Display a main `BaseSpikeVisual` with pan zoom. - - """ - - _visual_class = None - _pz = None - _events = () - keyboard_shortcuts = {} - - def __init__(self, **kwargs): - super(BaseSpikeCanvas, self).__init__(**kwargs) - self._create_visuals() - self._create_pan_zoom() - self._add_events() - self.keyboard_shortcuts.update(self._pz.keyboard_shortcuts) - - def _create_visuals(self): - self.visual = self._visual_class() - - def _create_pan_zoom(self): - self._pz = PanZoom() - self._pz.add(self.visual.program) - self._pz.attach(self) - - def _add_events(self): - self.events.add(**{event: Event for event in self._events}) - - def emit(self, name, **kwargs): - return getattr(self.events, name)(**kwargs) - - def on_draw(self, event): - """Draw the main visual.""" - self.context.clear() - self.visual.draw() - - def on_resize(self, event): - """Resize the OpenGL context.""" - self.context.set_viewport(0, 0, event.size[0], event.size[1]) diff --git a/phy/plot/base.py b/phy/plot/base.py new file mode 100644 index 000000000..4fa33e79d --- /dev/null +++ b/phy/plot/base.py @@ -0,0 +1,233 @@ +# -*- coding: utf-8 -*- + +"""Plotting/VisPy utilities.""" + + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +import logging +import os.path as op + +import numpy as np + +from vispy import gloo, config + +from phy.utils._types import _as_array +from phy.io.array import _unique + +logger = logging.getLogger(__name__) + + +#------------------------------------------------------------------------------ +# Base spike visual +#------------------------------------------------------------------------------ + +class _BakeVisual(object): + _shader_name = '' + _gl_draw_mode = '' + + def __init__(self, **kwargs): + super(_BakeVisual, self).__init__(**kwargs) + self._to_bake = [] + self._empty = True + + curdir = op.dirname(op.realpath(__file__)) + config['include_path'] = [op.join(curdir, 'glsl')] + + vertex = _load_shader(self._shader_name + '.vert') + fragment = _load_shader(self._shader_name + '.frag') + self.program = gloo.Program(vertex, fragment) + + @property + def empty(self): + """Specify whether the visual is currently empty or not.""" + return self._empty + + @empty.setter + def empty(self, value): + """Specify whether the visual is currently empty or not.""" + self._empty = value + + def set_to_bake(self, *bakes): + """Mark data items to be prepared for GPU.""" + for bake in bakes: + if bake not in self._to_bake: + self._to_bake.append(bake) + + def _bake(self): + """Prepare and upload the data on the GPU. + + Return whether something has been baked or not. + + """ + if self._empty: + return + n_bake = len(self._to_bake) + # Bake what needs to be baked. + # WARNING: the bake functions are called in alphabetical order. + # Tweak the names if there are dependencies between the functions. + for bake in sorted(self._to_bake): + # Name of the private baking method. + name = '_bake_{0:s}'.format(bake) + if hasattr(self, name): + getattr(self, name)() + self._to_bake = [] + return n_bake > 0 + + def draw(self): + """Draw the waveforms.""" + # Bake what needs to be baked at this point. + self._bake() + if not self._empty: + self.program.draw(self._gl_draw_mode) + + def update(self): + self.draw() + + +class BaseSpikeVisual(_BakeVisual): + """Base class for a VisPy visual showing spike data. + + There is a notion of displayed spikes and displayed clusters. + + """ + + _transparency = True + + def __init__(self, **kwargs): + super(BaseSpikeVisual, self).__init__(**kwargs) + self.n_spikes = None + self._spike_clusters = None + self._spike_ids = None + self._cluster_ids = None + self._cluster_order = None + self._cluster_colors = None + self._update_clusters_automatically = True + + if self._transparency: + gloo.set_state(clear_color='black', blend=True, + blend_func=('src_alpha', 'one_minus_src_alpha')) + + # Data properties + # ------------------------------------------------------------------------- + + def _set_or_assert_n_spikes(self, arr): + """If n_spikes is None, set it using the array's shape. Otherwise, + check that the array has n_spikes rows.""" + # TODO: improve this + if self.n_spikes is None: + self.n_spikes = arr.shape[0] + assert arr.shape[0] == self.n_spikes + + def _update_clusters(self): + self._cluster_ids = _unique(self._spike_clusters) + + @property + def spike_clusters(self): + """The clusters assigned to the displayed spikes.""" + return self._spike_clusters + + @spike_clusters.setter + def spike_clusters(self, value): + """Set all spike clusters.""" + value = _as_array(value) + self._spike_clusters = value + if self._update_clusters_automatically: + self._update_clusters() + self.set_to_bake('spikes_clusters') + + @property + def cluster_order(self): + """List of selected clusters in display order.""" + if self._cluster_order is None: + return self._cluster_ids + else: + return self._cluster_order + + @cluster_order.setter + def cluster_order(self, value): + value = _as_array(value) + assert sorted(value.tolist()) == sorted(self._cluster_ids) + self._cluster_order = value + + @property + def masks(self): + """Masks of the displayed spikes.""" + return self._masks + + @masks.setter + def masks(self, value): + assert isinstance(value, np.ndarray) + value = _as_array(value) + if value.ndim == 1: + value = value[None, :] + self._set_or_assert_n_spikes(value) + # TODO: support sparse structures + assert value.ndim == 2 + assert value.shape == (self.n_spikes, self.n_channels) + self._masks = value + self.set_to_bake('spikes') + + @property + def spike_ids(self): + """Spike ids to display.""" + if self._spike_ids is None: + self._spike_ids = np.arange(self.n_spikes) + return self._spike_ids + + @spike_ids.setter + def spike_ids(self, value): + value = _as_array(value) + self._set_or_assert_n_spikes(value) + self._spike_ids = value + self.set_to_bake('spikes') + + @property + def cluster_ids(self): + """Cluster ids of the displayed spikes.""" + return self._cluster_ids + + @cluster_ids.setter + def cluster_ids(self, value): + """Clusters of the displayed spikes.""" + self._cluster_ids = _as_array(value) + + @property + def n_clusters(self): + """Number of displayed clusters.""" + if self._cluster_ids is None: + return None + else: + return len(self._cluster_ids) + + @property + def cluster_colors(self): + """Colors of the displayed clusters. + + The first color is the color of the smallest cluster. + + """ + return self._cluster_colors + + @cluster_colors.setter + def cluster_colors(self, value): + self._cluster_colors = _as_array(value) + assert len(self._cluster_colors) >= self.n_clusters + self.set_to_bake('cluster_color') + + # Data baking + # ------------------------------------------------------------------------- + + def _bake_cluster_color(self): + if self.n_clusters == 0: + u_cluster_color = np.zeros((0, 0, 3)) + else: + u_cluster_color = self.cluster_colors.reshape((1, + self.n_clusters, + 3)) + assert u_cluster_color.ndim == 3 + assert u_cluster_color.shape[2] == 3 + u_cluster_color = (u_cluster_color * 255).astype(np.uint8) + self.program['u_cluster_color'] = gloo.Texture2D(u_cluster_color) diff --git a/phy/plot/ccg.py b/phy/plot/ccg.py deleted file mode 100644 index 81b7bf93b..000000000 --- a/phy/plot/ccg.py +++ /dev/null @@ -1,271 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Plotting CCGs.""" - - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import numpy as np -from vispy import gloo - -from ._mpl_utils import _bottom_left_frame -from ._vispy_utils import (BaseSpikeVisual, - BaseSpikeCanvas, - BoxVisual, - AxisVisual, - _tesselate_histogram, - _wrap_vispy) -from ._panzoom import PanZoomGrid -from ..utils._types import _as_array, _as_list -from ..utils._color import _selected_clusters_colors - - -#------------------------------------------------------------------------------ -# CCG visual -#------------------------------------------------------------------------------ - -class CorrelogramVisual(BaseSpikeVisual): - """Display a grid of auto- and cross-correlograms.""" - - _shader_name = 'correlograms' - _gl_draw_mode = 'triangle_strip' - - def __init__(self, **kwargs): - super(CorrelogramVisual, self).__init__(**kwargs) - self._correlograms = None - self._cluster_ids = None - self.n_bins = None - - # Data properties - # ------------------------------------------------------------------------- - - @property - def correlograms(self): - """Displayed correlograms. - - This is a `(n_clusters, n_clusters, n_bins)` array. - - """ - return self._correlograms - - @correlograms.setter - def correlograms(self, value): - value = _as_array(value) - # WARNING: need to set cluster_ids first - assert value.ndim == 3 - if self._cluster_ids is None: - self._cluster_ids = np.arange(value.shape[0]) - assert value.shape[:2] == (self.n_clusters, self.n_clusters) - self.n_bins = value.shape[2] - self._correlograms = value - self._empty = self.n_clusters == 0 or self.n_bins == 0 - self.set_to_bake('correlograms', 'color') - - @property - def cluster_ids(self): - """Displayed cluster ids.""" - return self._cluster_ids - - @cluster_ids.setter - def cluster_ids(self, value): - self._cluster_ids = np.asarray(value, dtype=np.int32) - - @property - def n_boxes(self): - """Number of boxes in the grid view.""" - return self.n_clusters * self.n_clusters - - # Data baking - # ------------------------------------------------------------------------- - - def _bake_correlograms(self): - n_points = self.n_boxes * (5 * self.n_bins + 1) - - # index increases from top to bottom, left to right - # same as matrix indices (i, j) starting at 0 - positions = [] - boxes = [] - - for i in range(self.n_clusters): - for j in range(self.n_clusters): - index = self.n_clusters * i + j - - hist = self._correlograms[i, j, :] - pos = _tesselate_histogram(hist) - n_points_hist = pos.shape[0] - - positions.append(pos) - boxes.append(index * np.ones(n_points_hist, dtype=np.float32)) - - positions = np.vstack(positions).astype(np.float32) - boxes = np.hstack(boxes) - - assert positions.shape == (n_points, 2) - assert boxes.shape == (n_points,) - - self.program['a_position'] = positions.copy() - self.program['a_box'] = boxes - self.program['n_rows'] = self.n_clusters - - -class CorrelogramView(BaseSpikeCanvas): - """A VisPy canvas displaying correlograms.""" - - _visual_class = CorrelogramVisual - _lines = [] - - def _create_visuals(self): - super(CorrelogramView, self)._create_visuals() - self.boxes = BoxVisual() - self.axes = AxisVisual() - - def _create_pan_zoom(self): - self._pz = PanZoomGrid() - self._pz.add(self.visual.program) - self._pz.add(self.axes.program) - self._pz.attach(self) - self._pz.aspect = None - self._pz.zmin = 1. - self._pz.xmin = -1. - self._pz.xmax = +1. - self._pz.ymin = -1. - self._pz.ymax = +1. - self._pz.zoom_to_pointer = False - - def set_data(self, - correlograms=None, - colors=None, - lines=None): - - if correlograms is not None: - correlograms = np.asarray(correlograms) - else: - correlograms = self.visual.correlograms - assert correlograms.ndim == 3 - n_clusters = len(correlograms) - assert correlograms.shape[:2] == (n_clusters, n_clusters) - - if colors is None: - colors = _selected_clusters_colors(n_clusters) - - self.cluster_ids = np.arange(n_clusters) - self.visual.correlograms = correlograms - - if len(colors): - self.visual.cluster_colors = colors - - if lines is not None: - self.lines = lines - - self.update() - - @property - def cluster_ids(self): - """Displayed cluster ids.""" - return self.visual.cluster_ids - - @cluster_ids.setter - def cluster_ids(self, value): - self.visual.cluster_ids = value - self.boxes.n_rows = self.visual.n_clusters - if self.visual.n_clusters >= 1: - self._pz.n_rows = self.visual.n_clusters - self.axes.n_rows = self.visual.n_clusters - if self._lines: - self.lines = self.lines - - @property - def correlograms(self): - return self.visual.correlograms - - @correlograms.setter - def correlograms(self, value): - self.visual.correlograms = value - # Update the lines which depend on the number of bins. - self.lines = self.lines - - @property - def lines(self): - """List of x coordinates where to put vertical lines. - - This is unit of samples. - - """ - return self._lines - - @lines.setter - def lines(self, value): - self._lines = _as_list(value) - c = 2. / (float(max(1, self.visual.n_bins or 0))) - self.axes.xs = np.array(self._lines) * c - self.axes.color = (.5, .5, .5, 1.) - - @property - def lines_color(self): - return self.axes.color - - @lines_color.setter - def lines_color(self, value): - self.axes.color = value - - def on_draw(self, event): - """Draw the correlograms visual.""" - gloo.clear() - self.visual.draw() - self.boxes.draw() - if self._lines: - self.axes.draw() - - -#------------------------------------------------------------------------------ -# CCG plotting -#------------------------------------------------------------------------------ - -@_wrap_vispy -def plot_correlograms(correlograms, **kwargs): - """Plot an array of correlograms. - - Parameters - ---------- - - correlograms : array - A `(n_clusters, n_clusters, n_bins)` array. - lines : ndarray - Array of x coordinates where to put vertical lines (in number of - samples). - colors : array-like (optional) - A list of colors as RGB tuples. - - """ - c = CorrelogramView(keys='interactive') - c.set_data(correlograms, **kwargs) - return c - - -def _plot_ccg_mpl(ccg, baseline=None, bin=1., color=None, ax=None): - """Plot a CCG with matplotlib and return an Axes instance.""" - import matplotlib.pyplot as plt - if ax is None: - ax = plt.subplot(111) - assert ccg.ndim == 1 - n = ccg.shape[0] - assert n % 2 == 1 - bin = float(bin) - x_min = -n // 2 * bin - bin / 2 - x_max = (n // 2 - 1) * bin + bin / 2 - width = bin * 1.05 - left = np.linspace(x_min, x_max, n) - ax.bar(left, ccg, facecolor=color, width=width, linewidth=0) - if baseline is not None: - ax.axhline(baseline, color='k', linewidth=2, linestyle='-') - ax.axvline(color='k', linewidth=2, linestyle='--') - - ax.set_xlim(x_min, x_max + bin / 2) - ax.set_ylim(0) - - # Only keep the bottom and left ticks. - _bottom_left_frame(ax) - - return ax diff --git a/phy/plot/features.py b/phy/plot/features.py deleted file mode 100644 index 03709cb33..000000000 --- a/phy/plot/features.py +++ /dev/null @@ -1,659 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Plotting features.""" - - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import numpy as np -from six import string_types -from vispy import gloo - -from ._vispy_utils import (BaseSpikeVisual, - BaseSpikeCanvas, - BoxVisual, - AxisVisual, - LassoVisual, - _enable_depth_mask, - _wrap_vispy, - ) -from ._panzoom import PanZoomGrid -from phy.utils._types import _as_array, _is_integer -from phy.io.array import _index_of, _unique -from phy.utils._color import _selected_clusters_colors - - -#------------------------------------------------------------------------------ -# Features visual -#------------------------------------------------------------------------------ - -def _get_feature_dim(dim, features=None, extra_features=None): - if isinstance(dim, (tuple, list)): - channel, feature = dim - return features[:, channel, feature] - elif isinstance(dim, string_types) and dim in extra_features: - x, m, M = extra_features[dim] - m, M = (float(m), float(M)) - x = _as_array(x, np.float32) - # Normalize extra feature. - d = float(max(1., M - m)) - x = (-1. + 2 * (x - m) / d) * .8 - return x - - -class BaseFeatureVisual(BaseSpikeVisual): - """Display a grid of multidimensional features.""" - - _shader_name = None - _gl_draw_mode = 'points' - - def __init__(self, **kwargs): - super(BaseFeatureVisual, self).__init__(**kwargs) - - self._features = None - # Mapping {feature_name: array} where the array must have n_spikes - # element. - self._extra_features = {} - self._x_dim = np.empty((0, 0), dtype=object) - self._y_dim = np.empty((0, 0), dtype=object) - self.n_channels, self.n_features = None, None - self.n_rows = None - - _enable_depth_mask() - - def add_extra_feature(self, name, array, array_min, array_max): - assert isinstance(array, np.ndarray) - assert array.ndim == 1 - if self.n_spikes: - if array.shape != (self.n_spikes,): - msg = ("Unable to add the extra feature " - "`{}`: ".format(name) + - "there should be {} ".format(self.n_spikes) + - "elements in the specified vector, not " - "{}.".format(len(array))) - raise ValueError(msg) - - self._extra_features[name] = (array, array_min, array_max) - - @property - def extra_features(self): - return self._extra_features - - @property - def features(self): - """Displayed features. - - This is a `(n_spikes, n_features)` array. - - """ - return self._features - - @features.setter - def features(self, value): - self._set_features_to_bake(value) - - def _set_features_to_bake(self, value): - # WARNING: when setting new data, features need to be set first. - # n_spikes will be set as a function of features. - value = _as_array(value) - # TODO: support sparse structures - assert value.ndim == 3 - self.n_spikes, self.n_channels, self.n_features = value.shape - self._features = value - self._empty = self.n_spikes == 0 - self.set_to_bake('spikes',) - - def _check_dimension(self, dim): - if _is_integer(dim): - dim = (dim, 0) - if isinstance(dim, tuple): - assert len(dim) == 2 - channel, feature = dim - assert _is_integer(channel) - assert _is_integer(feature) - assert 0 <= channel < self.n_channels - assert 0 <= feature < self.n_features - elif isinstance(dim, string_types): - assert dim in self._extra_features - elif dim: - raise ValueError('{0} should be (channel, feature) '.format(dim) + - 'or one of the extra features.') - - def project(self, box, features=None, extra_features=None): - """Project data to a subplot's two-dimensional subspace. - - Parameters - ---------- - box : 2-tuple - The `(row, col)` of the box. - features : array - extra_features : dict - - Notes - ----- - - The coordinate system is always the world coordinate system, i.e. - `[-1, 1]`. - - """ - i, j = box - dim_x = self._x_dim[i, j] - dim_y = self._y_dim[i, j] - - fet_x = _get_feature_dim(dim_x, - features=features, - extra_features=extra_features, - ) - fet_y = _get_feature_dim(dim_y, - features=features, - extra_features=extra_features, - ) - return np.c_[fet_x, fet_y] - - @property - def x_dim(self): - """Dimensions in the x axis of all subplots. - - This is a matrix of items which can be: - - * tuple `(channel_id, feature_idx)` - * an extra feature name (string) - - """ - return self._x_dim - - @x_dim.setter - def x_dim(self, value): - self._x_dim = value - self._update_dimensions() - - @property - def y_dim(self): - """Dimensions in the y axis of all subplots. - - This is a matrix of items which can be: - - * tuple `(channel_id, feature_idx)` - * an extra feature name (string) - - """ - return self._y_dim - - @y_dim.setter - def y_dim(self, value): - self._y_dim = value - self._update_dimensions() - - def _update_dimensions(self): - """Update the GPU data afte the dimensions have changed.""" - self._check_dimension_matrix(self._x_dim) - self._check_dimension_matrix(self._y_dim) - self.set_to_bake('spikes',) - - def _check_dimension_matrix(self, value): - if not isinstance(value, np.ndarray): - value = np.array(value, dtype=object) - assert value.ndim == 2 - assert value.shape[0] == value.shape[1] - assert value.dtype == object - self.n_rows = len(value) - for dim in value.flat: - self._check_dimension(dim) - - def set_dimension(self, axis, box, dim): - matrix = self._x_dim if axis == 'x' else self._y_dim - matrix[box] = dim - self._update_dimensions() - - @property - def n_boxes(self): - """Number of boxes in the grid.""" - return self.n_rows * self.n_rows - - # Data baking - # ------------------------------------------------------------------------- - - def _bake_spikes(self): - n_points = self.n_boxes * self.n_spikes - - # index increases from top to bottom, left to right - # same as matrix indices (i, j) starting at 0 - positions = [] - boxes = [] - - for i in range(self.n_rows): - for j in range(self.n_rows): - pos = self.project((i, j), - features=self._features, - extra_features=self._extra_features, - ) - positions.append(pos) - index = self.n_rows * i + j - boxes.append(index * np.ones(self.n_spikes, dtype=np.float32)) - - positions = np.vstack(positions).astype(np.float32) - boxes = np.hstack(boxes) - - assert positions.shape == (n_points, 2) - assert boxes.shape == (n_points,) - - self.program['a_position'] = positions.copy() - self.program['a_box'] = boxes - self.program['n_rows'] = self.n_rows - - -class BackgroundFeatureVisual(BaseFeatureVisual): - """Display a grid of multidimensional features in the background.""" - - _shader_name = 'features_bg' - _transparency = False - - -class FeatureVisual(BaseFeatureVisual): - """Display a grid of multidimensional features.""" - - _shader_name = 'features' - - def __init__(self, **kwargs): - super(FeatureVisual, self).__init__(**kwargs) - self.program['u_size'] = 3. - - # Data properties - # ------------------------------------------------------------------------- - - def _set_features_to_bake(self, value): - super(FeatureVisual, self)._set_features_to_bake(value) - self.set_to_bake('spikes', 'spikes_clusters', 'color') - - def _get_mask_dim(self, dim): - if isinstance(dim, (tuple, list)): - channel, feature = dim - return self._masks[:, channel] - else: - return np.ones(self.n_spikes) - - def _update_dimensions(self): - super(FeatureVisual, self)._update_dimensions() - self.set_to_bake('spikes_clusters', 'color') - - # Data baking - # ------------------------------------------------------------------------- - - def _bake_spikes(self): - n_points = self.n_boxes * self.n_spikes - - # index increases from top to bottom, left to right - # same as matrix indices (i, j) starting at 0 - positions = [] - masks = [] - boxes = [] - - for i in range(self.n_rows): - for j in range(self.n_rows): - pos = self.project((i, j), - features=self._features, - extra_features=self._extra_features, - ) - positions.append(pos) - - # The mask depends on both the `x` and `y` coordinates. - mask = np.maximum(self._get_mask_dim(self._x_dim[i, j]), - self._get_mask_dim(self._y_dim[i, j])) - masks.append(mask.astype(np.float32)) - - index = self.n_rows * i + j - boxes.append(index * np.ones(self.n_spikes, dtype=np.float32)) - - positions = np.vstack(positions).astype(np.float32) - masks = np.hstack(masks) - boxes = np.hstack(boxes) - - assert positions.shape == (n_points, 2) - assert masks.shape == (n_points,) - assert boxes.shape == (n_points,) - - self.program['a_position'] = positions.copy() - self.program['a_mask'] = masks - self.program['a_box'] = boxes - - self.program['n_clusters'] = self.n_clusters - self.program['n_rows'] = self.n_rows - - def _bake_spikes_clusters(self): - # Get the spike cluster indices (between 0 and n_clusters-1). - spike_clusters_idx = self.spike_clusters - # We take the cluster order into account here. - spike_clusters_idx = _index_of(spike_clusters_idx, self.cluster_order) - a_cluster = np.tile(spike_clusters_idx, - self.n_boxes).astype(np.float32) - self.program['a_cluster'] = a_cluster - self.program['n_clusters'] = self.n_clusters - - @property - def marker_size(self): - """Marker size in pixels.""" - return float(self.program['u_size']) - - @marker_size.setter - def marker_size(self, value): - value = np.clip(value, .1, 100) - self.program['u_size'] = float(value) - self.update() - - -def _iter_dimensions(dimensions): - if isinstance(dimensions, dict): - for box, dim in dimensions.items(): - yield (box, dim) - elif isinstance(dimensions, np.ndarray): - n_rows, n_cols = dimensions.shape - for i in range(n_rows): - for j in range(n_rows): - yield (i, j), dimensions[i, j] - elif isinstance(dimensions, list): - for i in range(len(dimensions)): - l = dimensions[i] - for j in range(len(l)): - dim = l[j] - yield (i, j), dim - - -class FeatureView(BaseSpikeCanvas): - """A VisPy canvas displaying features.""" - _visual_class = FeatureVisual - _events = ('enlarge',) - - def _create_visuals(self): - self.boxes = BoxVisual() - self.axes = AxisVisual() - self.background = BackgroundFeatureVisual() - self.lasso = LassoVisual() - super(FeatureView, self)._create_visuals() - - def _create_pan_zoom(self): - self._pz = PanZoomGrid() - self._pz.add(self.visual.program) - self._pz.add(self.background.program) - self._pz.add(self.lasso.program) - self._pz.add(self.axes.program) - self._pz.aspect = None - self._pz.attach(self) - - def init_grid(self, n_rows): - """Initialize the view with a given number of rows. - - Note - ---- - - This function *must* be called before setting the attributes. - - """ - assert n_rows >= 0 - - x_dim = np.empty((n_rows, n_rows), dtype=object) - y_dim = np.empty((n_rows, n_rows), dtype=object) - x_dim.fill('time') - y_dim.fill((0, 0)) - - self.visual.n_rows = n_rows - # NOTE: update the private variable because we don't want dimension - # checking at this point nor do we want to prepare the GPU data. - self.visual._x_dim = x_dim - self.visual._y_dim = y_dim - - self.background.n_rows = n_rows - self.background._x_dim = x_dim - self.background._y_dim = y_dim - - self.boxes.n_rows = n_rows - self.lasso.n_rows = n_rows - self.axes.n_rows = n_rows - self.axes.xs = [0] - self.axes.ys = [0] - self._pz.n_rows = n_rows - - xmin = np.empty((n_rows, n_rows)) - xmax = np.empty((n_rows, n_rows)) - ymin = np.empty((n_rows, n_rows)) - ymax = np.empty((n_rows, n_rows)) - for arr in (xmin, xmax, ymin, ymax): - arr.fill(np.nan) - self._pz._xmin = xmin - self._pz._xmax = xmax - self._pz._ymin = ymin - self._pz._ymax = ymax - - for i in range(n_rows): - for j in range(n_rows): - self._pz._xmin[i, j] = -1. - self._pz._xmax[i, j] = +1. - - @property - def x_dim(self): - return self.visual.x_dim - - @property - def y_dim(self): - return self.visual.y_dim - - def _set_dimension(self, axis, box, dim): - self.background.set_dimension(axis, box, dim) - self.visual.set_dimension(axis, box, dim) - min = self._pz._xmin if axis == 'x' else self._pz._ymin - max = self._pz._xmax if axis == 'x' else self._pz._ymax - if dim == 'time': - # NOTE: the private variables are the matrices. - min[box] = -1. - max[box] = +1. - else: - min[box] = None - max[box] = None - - def set_dimensions(self, axis, dimensions): - for box, dim in _iter_dimensions(dimensions): - self._set_dimension(axis, box, dim) - self. _update_dimensions() - - def smart_dimension(self, - axis, - box, - dim, - ): - """Smartify a dimension selection by ensuring x != y.""" - if not isinstance(dim, tuple): - return dim - n_features = self.visual.n_features - # Find current dimensions. - mat = self.x_dim if axis == 'x' else self.y_dim - mat_other = self.x_dim if axis == 'y' else self.y_dim - prev_dim = mat[box] - prev_dim_other = mat_other[box] - # Select smart new dimension. - if not isinstance(prev_dim, string_types): - channel, feature = dim - prev_channel, prev_feature = prev_dim - # Scroll the feature if the channel is the same. - if prev_channel == channel: - feature = (prev_feature + 1) % n_features - # Scroll the feature if it is the same than in the other axis. - if (prev_dim_other != 'time' and - prev_dim_other == (channel, feature)): - feature = (feature + 1) % n_features - dim = (channel, feature) - return dim - - def _update_dimensions(self): - self.background._update_dimensions() - self.visual._update_dimensions() - - def set_data(self, - features=None, - n_rows=1, - x_dimensions=None, - y_dimensions=None, - masks=None, - spike_clusters=None, - extra_features=None, - background_features=None, - colors=None, - ): - if features is not None: - assert isinstance(features, np.ndarray) - if features.ndim == 2: - features = features[..., None] - assert features.ndim == 3 - else: - features = self.visual.features - n_spikes, n_channels, n_features = features.shape - - if spike_clusters is None: - spike_clusters = np.zeros(n_spikes, dtype=np.int32) - cluster_ids = _unique(spike_clusters) - n_clusters = len(cluster_ids) - - if masks is None: - masks = np.ones(features.shape[:2], dtype=np.float32) - - if colors is None: - colors = _selected_clusters_colors(n_clusters) - - self.visual.features = features - - if background_features is not None: - assert features.shape[1:] == background_features.shape[1:] - self.background.features = background_features.astype(np.float32) - else: - self.background.n_channels = self.visual.n_channels - self.background.n_features = self.visual.n_features - - if masks is not None: - self.visual.masks = masks - - self.visual.spike_clusters = spike_clusters - assert spike_clusters.shape == (n_spikes,) - - if len(colors): - self.visual.cluster_colors = colors - - # Dimensions. - self.init_grid(n_rows) - if not extra_features: - extra_features = {'time': (np.linspace(0., 1., n_spikes), 0., 1.)} - for name, (array, m, M) in (extra_features or {}).items(): - self.add_extra_feature(name, array, m, M) - self.set_dimensions('x', x_dimensions or {(0, 0): 'time'}) - self.set_dimensions('y', y_dimensions or {(0, 0): (0, 0)}) - - self.update() - - def add_extra_feature(self, name, array, array_min, array_max, - array_bg=None): - self.visual.add_extra_feature(name, array, array_min, array_max) - # Note: the background array has a different number of spikes. - if array_bg is None: - array_bg = array - self.background.add_extra_feature(name, array_bg, - array_min, array_max) - - @property - def marker_size(self): - """Marker size.""" - return self.visual.marker_size - - @marker_size.setter - def marker_size(self, value): - self.visual.marker_size = value - self.update() - - def on_draw(self, event): - """Draw the features in a grid view.""" - gloo.clear(color=True, depth=True) - self.axes.draw() - self.background.draw() - self.visual.draw() - self.lasso.draw() - self.boxes.draw() - - keyboard_shortcuts = { - 'marker_size_increase': 'alt+', - 'marker_size_decrease': 'alt-', - 'add_lasso_point': 'shift+left click', - 'clear_lasso': 'shift+right click', - } - - def on_mouse_press(self, e): - control = e.modifiers == ('Control',) - shift = e.modifiers == ('Shift',) - if shift: - if e.button == 1: - # Lasso. - n_rows = self.lasso.n_rows - - box = self._pz._get_box(e.pos) - self.lasso.box = box - - position = self._pz._normalize(e.pos) - x, y = position - x *= n_rows - y *= -n_rows - pos = (x, y) - # pos = self._pz._map_box((x, y), inverse=True) - pos = self._pz._map_pan_zoom(pos, inverse=True) - self.lasso.add(pos.ravel()) - elif e.button == 2: - self.lasso.clear() - self.update() - elif control: - # Enlarge. - box = self._pz._get_box(e.pos) - self.emit('enlarge', - box=box, - x_dim=self.x_dim[box], - y_dim=self.y_dim[box], - ) - - def on_key_press(self, event): - """Handle key press events.""" - coeff = .25 - if 'Alt' in event.modifiers: - if event.key == '+': - self.marker_size += coeff - if event.key == '-': - self.marker_size -= coeff - - -#------------------------------------------------------------------------------ -# Plotting functions -#------------------------------------------------------------------------------ - -@_wrap_vispy -def plot_features(features, **kwargs): - """Plot features. - - Parameters - ---------- - - features : ndarray - The features to plot. A `(n_spikes, n_channels, n_features)` array. - spike_clusters : ndarray (optional) - A `(n_spikes,)` int array with the spike clusters. - masks : ndarray (optional) - A `(n_spikes, n_channels)` float array with the spike masks. - n_rows : int - Number of rows (= number of columns) in the grid view. - x_dimensions : list - List of dimensions for the x axis. - y_dimensions : list - List of dimensions for the yœ axis. - extra_features : dict - A dictionary `{feature_name: array}` where `array` has - `n_spikes` elements. - background_features : ndarray - The background features. A `(n_spikes, n_channels, n_features)` array. - - """ - c = FeatureView(keys='interactive') - c.set_data(features, **kwargs) - return c diff --git a/phy/plot/tests/test_base.py b/phy/plot/tests/test_base.py new file mode 100644 index 000000000..e69de29bb diff --git a/phy/plot/tests/test_ccg.py b/phy/plot/tests/test_ccg.py deleted file mode 100644 index 18c793b6c..000000000 --- a/phy/plot/tests/test_ccg.py +++ /dev/null @@ -1,57 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Test CCG plotting.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import numpy as np - -from ..ccg import _plot_ccg_mpl, CorrelogramView, plot_correlograms -from ...utils._color import _random_color -from ...io.mock import artificial_correlograms -from ...utils.testing import show_test - - -#------------------------------------------------------------------------------ -# Tests matplotlib -#------------------------------------------------------------------------------ - -def test_plot_ccg(): - n_bins = 51 - ccg = np.random.randint(size=n_bins, low=10, high=50) - _plot_ccg_mpl(ccg, baseline=20, color='g') - - -def test_plot_correlograms(): - n_bins = 51 - ccg = np.random.uniform(size=(3, 3, n_bins)) - c = plot_correlograms(ccg, lines=[-10, 0, 20], show=False) - show_test(c) - - -#------------------------------------------------------------------------------ -# Tests VisPy -#------------------------------------------------------------------------------ - -def _test_correlograms(n_clusters=None): - n_samples = 51 - - correlograms = artificial_correlograms(n_clusters, n_samples) - - c = CorrelogramView(keys='interactive') - c.cluster_ids = np.arange(n_clusters) - c.visual.correlograms = correlograms - c.visual.cluster_colors = np.array([_random_color() - for _ in range(n_clusters)]) - c.lines = [-5, 0, 5] - show_test(c) - - -def test_correlograms_empty(): - _test_correlograms(n_clusters=0) - - -def test_correlograms_full(): - _test_correlograms(n_clusters=3) diff --git a/phy/plot/tests/test_features.py b/phy/plot/tests/test_features.py deleted file mode 100644 index ff24c5de7..000000000 --- a/phy/plot/tests/test_features.py +++ /dev/null @@ -1,96 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Test feature plotting.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import numpy as np - -from ..features import FeatureView, plot_features -from ...utils._color import _random_color -from ...io.mock import (artificial_features, - artificial_masks, - artificial_spike_clusters, - artificial_spike_samples) -from ...utils.testing import show_test - - -#------------------------------------------------------------------------------ -# Tests -#------------------------------------------------------------------------------ - -def _test_features(n_spikes=None, n_clusters=None): - n_channels = 32 - n_features = 3 - - features = artificial_features(n_spikes, n_channels, n_features) - masks = artificial_masks(n_spikes, n_channels) - spike_clusters = artificial_spike_clusters(n_spikes, n_clusters) - spike_samples = artificial_spike_samples(n_spikes).astype(np.float32) - - c = FeatureView(keys='interactive') - c.init_grid(2) - c.visual.features = features - c.background.features = features[::2] * 2. - # Useful to test depth. - # masks[n_spikes//2:, ...] = 0 - c.visual.masks = masks - M = spike_samples.max() if len(spike_samples) else 1 - c.add_extra_feature('time', spike_samples, 0, M, - array_bg=spike_samples[::2]) - c.add_extra_feature('test', - np.sin(np.linspace(-10., 10., n_spikes)), - -1., 1., - array_bg=spike_samples[::2]) - c.set_dimensions('x', [['time', (1, 0)], - [(2, 1), 'time']]) - c.set_dimensions('y', [[(0, 0), (1, 1)], - [(1, 0), 'test']]) - c.visual.spike_clusters = spike_clusters - c.visual.cluster_colors = np.array([_random_color() - for _ in range(n_clusters)]) - - show_test(c) - - -def test_features_empty(): - _test_features(n_spikes=0, n_clusters=0) - - -def test_features_full(): - _test_features(n_spikes=100, n_clusters=3) - - -def test_plot_features(): - n_spikes = 1000 - n_channels = 32 - n_features = 1 - n_clusters = 2 - - features = artificial_features(n_spikes, n_channels, n_features) - masks = artificial_masks(n_spikes, n_channels) - spike_clusters = artificial_spike_clusters(n_spikes, n_clusters) - # Unclustered spikes. - spike_clusters[::3] = -1 - - c = plot_features(features[:, :1, :], - show=False, - x_dimensions=[['time']], - y_dimensions=[[(0, 0)]], - ) - show_test(c) - - c = plot_features(features, - show=False) - show_test(c) - - c = plot_features(features, show=False) - show_test(c) - - c = plot_features(features, masks=masks, show=False) - show_test(c) - - c = plot_features(features, spike_clusters=spike_clusters, show=False) - show_test(c) diff --git a/phy/plot/tests/test_traces.py b/phy/plot/tests/test_traces.py deleted file mode 100644 index b8a8df834..000000000 --- a/phy/plot/tests/test_traces.py +++ /dev/null @@ -1,84 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Test CCG plotting.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import numpy as np - -from ..traces import TraceView, plot_traces -from ...utils._color import _random_color -from ...io.mock import (artificial_traces, - artificial_masks, - artificial_spike_clusters, - ) -from ...utils.testing import show_test - - -#------------------------------------------------------------------------------ -# Tests VisPy -#------------------------------------------------------------------------------ - -def _test_traces(n_samples=None): - n_channels = 20 - n_spikes = 50 - n_clusters = 3 - - traces = artificial_traces(n_samples, n_channels) - masks = artificial_masks(n_spikes, n_channels) - spike_clusters = artificial_spike_clusters(n_spikes, n_clusters) - spike_samples = np.linspace(50, n_samples - 50, n_spikes).astype(np.uint64) - - c = TraceView(keys='interactive') - c.visual.traces = traces - c.visual.n_samples_per_spike = 20 - c.visual.spike_samples = spike_samples - c.visual.spike_clusters = spike_clusters - c.visual.cluster_colors = np.array([_random_color() - for _ in range(n_clusters)]) - c.visual.masks = masks - c.visual.sample_rate = 20000. - c.visual.offset = 0 - - show_test(c) - - -def test_traces_empty(): - _test_traces(n_samples=0) - - -def test_traces_full(): - _test_traces(n_samples=2000) - - -def test_plot_traces(): - n_samples = 10000 - n_channels = 20 - n_spikes = 50 - n_clusters = 3 - - traces = artificial_traces(n_samples, n_channels) - masks = artificial_masks(n_spikes, n_channels) - spike_clusters = artificial_spike_clusters(n_spikes, n_clusters) - spike_samples = np.linspace(50, n_samples - 50, n_spikes).astype(np.uint64) - - c = plot_traces(traces, show=False) - show_test(c) - - c = plot_traces(traces, - spike_samples=spike_samples, - masks=masks, show=False) - show_test(c) - - c = plot_traces(traces, - spike_samples=spike_samples, - spike_clusters=spike_clusters, show=False) - show_test(c) - - c = plot_traces(traces, - spike_samples=spike_samples, - spike_clusters=spike_clusters, - show=False) - show_test(c) diff --git a/phy/plot/tests/test_utils.py b/phy/plot/tests/test_utils.py index 1cd174476..b396e8f6d 100644 --- a/phy/plot/tests/test_utils.py +++ b/phy/plot/tests/test_utils.py @@ -1,63 +1,18 @@ # -*- coding: utf-8 -*- -"""Test utils plotting.""" +"""Test plotting/VisPy utilities.""" + #------------------------------------------------------------------------------ # Imports #------------------------------------------------------------------------------ -from vispy import app - -from ...utils.testing import show_test -from .._vispy_utils import LassoVisual -from .._panzoom import PanZoom, PanZoomGrid +from ..utils import _load_shader #------------------------------------------------------------------------------ -# Tests VisPy +# Test utilities #------------------------------------------------------------------------------ -class TestCanvas(app.Canvas): - _pz = None - - def __init__(self, visual, grid=False, **kwargs): - super(TestCanvas, self).__init__(keys='interactive', **kwargs) - self.visual = visual - self._grid = grid - self._create_pan_zoom() - - def _create_pan_zoom(self): - if self._grid: - self._pz = PanZoomGrid() - self._pz.n_rows = self.visual.n_rows - else: - self._pz = PanZoom() - self._pz.add(self.visual.program) - self._pz.attach(self) - - def on_draw(self, event): - """Draw the main visual.""" - self.context.clear() - self.visual.draw() - - def on_resize(self, event): - """Resize the OpenGL context.""" - self.context.set_viewport(0, 0, event.size[0], event.size[1]) - - -def _show_visual(visual, grid=False): - view = TestCanvas(visual, grid=grid) - show_test(view) - return view - - -def test_lasso(): - lasso = LassoVisual() - lasso.n_rows = 4 - lasso.box = (1, 3) - lasso.points = [[+.8, +.8], - [-.8, +.8], - [-.8, -.8], - ] - view = _show_visual(lasso, grid=True) - view.visual.add([+.8, -.8]) +def test_load_shader(): + assert 'main()' in _load_shader('ax.vert') diff --git a/phy/plot/tests/test_waveforms.py b/phy/plot/tests/test_waveforms.py deleted file mode 100644 index 1a2df6ed7..000000000 --- a/phy/plot/tests/test_waveforms.py +++ /dev/null @@ -1,88 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Test waveform plotting.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import numpy as np - -from ..waveforms import WaveformView, plot_waveforms -from ...utils._color import _random_color -from ...io.mock import (artificial_waveforms, artificial_masks, - artificial_spike_clusters) -from ...electrode.mea import staggered_positions -from ...utils.testing import show_test - - -#------------------------------------------------------------------------------ -# Tests -#------------------------------------------------------------------------------ - - -def _test_waveforms(n_spikes=None, n_clusters=None): - n_channels = 32 - n_samples = 40 - - channel_positions = staggered_positions(n_channels) - - waveforms = artificial_waveforms(n_spikes, n_samples, n_channels) - masks = artificial_masks(n_spikes, n_channels) - spike_clusters = artificial_spike_clusters(n_spikes, n_clusters) - - c = WaveformView(keys='interactive') - c.visual.waveforms = waveforms - # Test depth. - # masks[n_spikes//2:, ...] = 0 - # Test position of masks. - # masks[:, :n_channels // 2] = 0 - # masks[:, n_channels // 2:] = 1 - c.visual.masks = masks - c.visual.spike_clusters = spike_clusters - c.visual.cluster_colors = np.array([_random_color() - for _ in range(n_clusters)]) - c.visual.channel_positions = channel_positions - c.visual.channel_order = np.arange(1, n_channels + 1) - - @c.connect - def on_channel_click(e): - print(e.channel_id, e.key) - - show_test(c) - - -def test_waveforms_empty(): - _test_waveforms(n_spikes=0, n_clusters=0) - - -def test_waveforms_full(): - _test_waveforms(n_spikes=100, n_clusters=3) - - -def test_plot_waveforms(): - n_spikes = 100 - n_clusters = 2 - n_channels = 32 - n_samples = 40 - - channel_positions = staggered_positions(n_channels) - - waveforms = artificial_waveforms(n_spikes, n_samples, n_channels) - masks = artificial_masks(n_spikes, n_channels) - spike_clusters = artificial_spike_clusters(n_spikes, n_clusters) - - c = plot_waveforms(waveforms, show=False) - show_test(c) - - c = plot_waveforms(waveforms, masks=masks, show=False) - show_test(c) - - c = plot_waveforms(waveforms, spike_clusters=spike_clusters, show=False) - show_test(c) - - c = plot_waveforms(waveforms, - spike_clusters=spike_clusters, - channel_positions=channel_positions, - show=False) - show_test(c) diff --git a/phy/plot/traces.py b/phy/plot/traces.py deleted file mode 100644 index 80b689529..000000000 --- a/phy/plot/traces.py +++ /dev/null @@ -1,325 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Plotting traces.""" - - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import numpy as np - -from vispy import gloo - -from ._vispy_utils import (BaseSpikeVisual, - BaseSpikeCanvas, - _wrap_vispy, - ) -from phy.utils._color import _selected_clusters_colors -from phy.utils._types import _as_array -from phy.io.array import _index_of, _unique - - -#------------------------------------------------------------------------------ -# CCG visual -#------------------------------------------------------------------------------ - -class TraceVisual(BaseSpikeVisual): - """Display multi-channel extracellular traces with spikes. - - The visual displays a small portion of the traces at once. There is an - optional offset. - - """ - - _shader_name = 'traces' - _gl_draw_mode = 'line_strip' - - def __init__(self, **kwargs): - super(TraceVisual, self).__init__(**kwargs) - self._traces = None - self._spike_samples = None - self._n_samples_per_spike = None - self._sample_rate = None - self._offset = None - - self.program['u_scale'] = 1. - - # Data properties - # ------------------------------------------------------------------------- - - @property - def traces(self): - """Displayed traces. - - This is a `(n_samples, n_channels)` array. - - """ - return self._traces - - @traces.setter - def traces(self, value): - value = _as_array(value) - assert value.ndim == 2 - self.n_samples, self.n_channels = value.shape - self._traces = value - self._empty = self.n_samples == 0 - self._channel_colors = .5 * np.ones((self.n_channels, 3), - dtype=np.float32) - self.set_to_bake('traces', 'channel_color') - - @property - def channel_colors(self): - """Colors of the displayed channels.""" - return self._channel_colors - - @channel_colors.setter - def channel_colors(self, value): - self._channel_colors = _as_array(value) - assert len(self._channel_colors) == self.n_channels - self.set_to_bake('channel_color') - - @property - def spike_samples(self): - """Time samples of the displayed spikes.""" - return self._spike_samples - - @spike_samples.setter - def spike_samples(self, value): - assert isinstance(value, np.ndarray) - self._set_or_assert_n_spikes(value) - self._spike_samples = value - self.set_to_bake('spikes') - - @property - def n_samples_per_spike(self): - """Number of time samples per displayed spikes.""" - return self._n_samples_per_spike - - @n_samples_per_spike.setter - def n_samples_per_spike(self, value): - self._n_samples_per_spike = int(value) - - @property - def sample_rate(self): - """Sample rate of the recording.""" - return self._sample_rate - - @sample_rate.setter - def sample_rate(self, value): - self._sample_rate = int(value) - - @property - def offset(self): - """Offset of the displayed traces (in time samples).""" - return self._offset - - @offset.setter - def offset(self, value): - self._offset = int(value) - - @property - def channel_scale(self): - """Vertical scaling of the traces.""" - return np.asscalar(self.program['u_scale']) - - @channel_scale.setter - def channel_scale(self, value): - self.program['u_scale'] = value - - # Data baking - # ------------------------------------------------------------------------- - - def _bake_traces(self): - ns, nc = self.n_samples, self.n_channels - - a_index = np.empty((nc * ns, 2), dtype=np.float32) - a_index[:, 0] = np.repeat(np.arange(nc), ns) - a_index[:, 1] = np.tile(np.arange(ns), nc) - - self.program['a_position'] = self._traces.T.ravel().astype(np.float32) - self.program['a_index'] = a_index - self.program['n_channels'] = nc - self.program['n_samples'] = ns - - def _bake_channel_color(self): - u_channel_color = self._channel_colors.reshape((1, - self.n_channels, - -1)) - u_channel_color = (u_channel_color * 255).astype(np.uint8) - self.program['u_channel_color'] = gloo.Texture2D(u_channel_color) - - def _bake_spikes(self): - # Handle the case where there are no spikes. - if self.n_spikes == 0: - a_spike = np.zeros((self.n_channels * self.n_samples, 2), - dtype=np.float32) - a_spike[:, 0] = -1. - self.program['a_spike'] = a_spike - self.program['n_clusters'] = 0 - return - - spike_clusters_idx = self.spike_clusters - spike_clusters_idx = _index_of(spike_clusters_idx, self.cluster_ids) - assert spike_clusters_idx.shape == (self.n_spikes,) - - samples = self._spike_samples - assert samples.shape == (self.n_spikes,) - - # -1 = there's no spike at this vertex - a_clusters = np.empty((self.n_channels, self.n_samples), - dtype=np.float32) - a_clusters.fill(-1.) - a_masks = np.zeros((self.n_channels, self.n_samples), - dtype=np.float32) - masks = self._masks # (n_spikes, n_channels) - - # Add all spikes, one by one. - k = self._n_samples_per_spike // 2 - for i, s in enumerate(samples): - m = masks[i, :] # masks across all channels - channels = (m > 0.) - c = spike_clusters_idx[i] # cluster idx - i = max(s - k, 0) - j = min(s + k, self.n_samples) - a_clusters[channels, i:j] = c - a_masks[channels, i:j] = m[channels, None] - - a_spike = np.empty((self.n_channels * self.n_samples, 2), - dtype=np.float32) - a_spike[:, 0] = a_clusters.ravel() - a_spike[:, 1] = a_masks.ravel() - assert a_spike.dtype == np.float32 - self.program['a_spike'] = a_spike - self.program['n_clusters'] = self.n_clusters - - def _bake_spikes_clusters(self): - self._bake_spikes() - - -class TraceView(BaseSpikeCanvas): - """A VisPy canvas displaying traces.""" - _visual_class = TraceVisual - - def _create_pan_zoom(self): - super(TraceView, self)._create_pan_zoom() - self._pz.aspect = None - self._pz.zmin = .5 - self._pz.xmin = -1. - self._pz.xmax = +1. - self._pz.ymin = -2. - self._pz.ymax = +2. - - def set_data(self, - traces=None, - spike_samples=None, - spike_clusters=None, - n_samples_per_spike=50, - masks=None, - colors=None, - ): - if traces is not None: - assert isinstance(traces, np.ndarray) - assert traces.ndim == 2 - else: - traces = self.visual.traces - # Detrend the traces. - traces = traces - traces.mean(axis=0) - s = traces.std() - if s > 0: - traces /= s - n_samples, n_channels = traces.shape - - if spike_samples is not None: - n_spikes = len(spike_samples) - else: - n_spikes = 0 - - if spike_clusters is None: - spike_clusters = np.zeros(n_spikes, dtype=np.int32) - cluster_ids = _unique(spike_clusters) - n_clusters = len(cluster_ids) - - if masks is None: - masks = np.ones((n_spikes, n_channels), dtype=np.float32) - - if colors is None: - colors = _selected_clusters_colors(n_clusters) - - self.visual.traces = traces.astype(np.float32) - - if masks is not None: - self.visual.masks = masks - - if n_samples_per_spike is not None: - self.visual.n_samples_per_spike = n_samples_per_spike - - if spike_samples is not None: - assert spike_samples.shape == (n_spikes,) - self.visual.spike_samples = spike_samples - - if spike_clusters is not None: - assert spike_clusters.shape == (n_spikes,) - self.visual.spike_clusters = spike_clusters - - if len(colors): - self.visual.cluster_colors = colors - - self.update() - - @property - def channel_scale(self): - """Vertical scale of the traces.""" - return self.visual.channel_scale - - @channel_scale.setter - def channel_scale(self, value): - self.visual.channel_scale = value - self.update() - - keyboard_shortcuts = { - 'channel_scale_increase': 'ctrl+', - 'channel_scale_decrease': 'ctrl-', - } - - def on_key_press(self, event): - """Handle key press events.""" - key = event.key - ctrl = 'Control' in event.modifiers - - # Box scale. - if ctrl and key in ('+', '-'): - coeff = 1.1 - u = self.channel_scale - if key == '-': - self.channel_scale = u / coeff - elif key == '+': - self.channel_scale = u * coeff - - -#------------------------------------------------------------------------------ -# Plotting functions -#------------------------------------------------------------------------------ - -@_wrap_vispy -def plot_traces(traces, **kwargs): - """Plot traces. - - Parameters - ---------- - - traces : ndarray - The traces to plot. A `(n_samples, n_channels)` array. - spike_samples : ndarray (optional) - A `(n_spikes,)` int array with the spike times in number of samples. - spike_clusters : ndarray (optional) - A `(n_spikes,)` int array with the spike clusters. - masks : ndarray (optional) - A `(n_spikes, n_channels)` float array with the spike masks. - n_samples_per_spike : int - Waveform size in number of samples. - - """ - c = TraceView(keys='interactive') - c.set_data(traces, **kwargs) - return c diff --git a/phy/plot/utils.py b/phy/plot/utils.py new file mode 100644 index 000000000..5d9ae4a56 --- /dev/null +++ b/phy/plot/utils.py @@ -0,0 +1,78 @@ +# -*- coding: utf-8 -*- + +"""Plotting/VisPy utilities.""" + + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +from functools import wraps +import logging +import os.path as op + +import numpy as np + +from vispy import gloo + +logger = logging.getLogger(__name__) + + +#------------------------------------------------------------------------------ +# Misc +#------------------------------------------------------------------------------ + +def _load_shader(filename): + """Load a shader file.""" + path = op.join(op.dirname(op.realpath(__file__)), 'glsl', filename) + with open(path, 'r') as f: + return f.read() + + +def _tesselate_histogram(hist): + assert hist.ndim == 1 + nsamples = len(hist) + dx = 2. / nsamples + + x0 = -1 + dx * np.arange(nsamples) + + x = np.zeros(5 * nsamples + 1) + y = -1.0 * np.ones(5 * nsamples + 1) + + x[0:-1:5] = x0 + x[1::5] = x0 + x[2::5] = x0 + dx + x[3::5] = x0 + x[4::5] = x0 + dx + x[-1] = 1 + + y[1::5] = y[2::5] = -1 + 2. * hist + + return np.c_[x, y] + + +def _enable_depth_mask(): + gloo.set_state(clear_color='black', + depth_test=True, + depth_range=(0., 1.), + # depth_mask='true', + depth_func='lequal', + blend=True, + blend_func=('src_alpha', 'one_minus_src_alpha')) + gloo.set_clear_depth(1.0) + + +def _wrap_vispy(f): + """Decorator for a function returning a VisPy canvas. + + Add `show=True` parameter. + + """ + @wraps(f) + def wrapped(*args, **kwargs): + show = kwargs.pop('show', True) + canvas = f(*args, **kwargs) + if show: + canvas.show() + return canvas + return wrapped diff --git a/phy/plot/waveforms.py b/phy/plot/waveforms.py deleted file mode 100644 index 5e72eec78..000000000 --- a/phy/plot/waveforms.py +++ /dev/null @@ -1,512 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Plotting waveforms.""" - - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import numpy as np - -from vispy import gloo -from vispy.gloo import Texture2D - -from ._panzoom import PanZoom -from ._vispy_utils import (BaseSpikeVisual, - BaseSpikeCanvas, - _enable_depth_mask, - _wrap_vispy, - ) -from phy.utils._types import _as_array -from phy.utils._color import _selected_clusters_colors -from phy.io.array import _index_of, _normalize, _unique -from phy.electrode.mea import linear_positions - - -#------------------------------------------------------------------------------ -# Waveform visual -#------------------------------------------------------------------------------ - -class WaveformVisual(BaseSpikeVisual): - """Display waveforms with probe geometry.""" - - _shader_name = 'waveforms' - _gl_draw_mode = 'line_strip' - - def __init__(self, **kwargs): - super(WaveformVisual, self).__init__(**kwargs) - - self._waveforms = None - self.n_channels, self.n_samples = None, None - self._channel_order = None - self._channel_positions = None - - self.program['u_data_scale'] = (.05, .05) - self.program['u_channel_scale'] = (1., 1.) - self.program['u_overlap'] = 0. - self.program['u_alpha'] = 0.5 - _enable_depth_mask() - - # Data properties - # ------------------------------------------------------------------------- - - @property - def waveforms(self): - """Displayed waveforms. - - This is a `(n_spikes, n_samples, n_channels)` array. - - """ - return self._waveforms - - @waveforms.setter - def waveforms(self, value): - # WARNING: when setting new data, waveforms need to be set first. - # n_spikes will be set as a function of waveforms. - value = _as_array(value) - # TODO: support sparse structures - assert value.ndim == 3 - self.n_spikes, self.n_samples, self.n_channels = value.shape - self._waveforms = value - self._empty = self.n_spikes == 0 - self.set_to_bake('spikes', 'spikes_clusters', 'color') - - @property - def channel_positions(self): - """Positions of the channels. - - This is a `(n_channels, 2)` array. - - """ - return self._channel_positions - - @channel_positions.setter - def channel_positions(self, value): - value = _as_array(value) - self._channel_positions = value - self.set_to_bake('channel_positions') - - @property - def channel_order(self): - return self._channel_order - - @channel_order.setter - def channel_order(self, value): - self._channel_order = value - - @property - def alpha(self): - """Alpha transparency (between 0 and 1).""" - return self.program['u_alpha'] - - @alpha.setter - def alpha(self, value): - self.program['u_alpha'] = value - - @property - def box_scale(self): - """Scale of the waveforms. - - This is a pair of scalars. - - """ - return tuple(self.program['u_data_scale']) - - @box_scale.setter - def box_scale(self, value): - assert len(value) == 2 - self.program['u_data_scale'] = value - - @property - def probe_scale(self): - """Scale of the probe. - - This is a pair of scalars. - - """ - return tuple(self.program['u_channel_scale']) - - @probe_scale.setter - def probe_scale(self, value): - assert len(value) == 2 - self.program['u_channel_scale'] = value - - @property - def overlap(self): - """Whether to overlap waveforms.""" - return True if self.program['u_overlap'][0] > .5 else False - - @overlap.setter - def overlap(self, value): - assert value in (True, False) - self.program['u_overlap'] = 1. if value else 0. - - def channel_hover(self, position): - """Return the channel id closest to the mouse pointer. - - Parameters - ---------- - - position : tuple - The normalized coordinates of the mouse pointer, in world - coordinates (in `[-1, 1]`). - - """ - mouse_pos = position / self.probe_scale - # Normalize channel positions. - positions = self.channel_positions.astype(np.float32) - positions = _normalize(positions, keep_ratio=True) - positions = .1 + .8 * positions - positions = 2 * positions - 1 - # Find closest channel. - d = np.sum((positions - mouse_pos[None, :]) ** 2, axis=1) - idx = np.argmin(d) - # if self.channel_order is not None: - # channel_id = self.channel_order[idx] - # WARNING: by convention this is the relative channel index. - return idx - - # Data baking - # ------------------------------------------------------------------------- - - def _bake_channel_positions(self): - # WARNING: channel_positions must be in [0,1] because we have a - # texture. - positions = self.channel_positions.astype(np.float32) - positions = _normalize(positions, keep_ratio=True) - positions = positions.reshape((1, self.n_channels, -1)) - # Rescale a bit and recenter. - positions = .1 + .8 * positions - u_channel_pos = np.dstack((positions, - np.zeros((1, self.n_channels, 1)))) - u_channel_pos = (u_channel_pos * 255).astype(np.uint8) - # TODO: more efficient to update the data from an existing texture - self.program['u_channel_pos'] = Texture2D(u_channel_pos, - wrapping='clamp_to_edge') - - def _bake_spikes(self): - - # Bake masks. - # WARNING: swap channel/time axes in the waveforms array. - waveforms = np.swapaxes(self._waveforms, 1, 2) - assert waveforms.shape == (self.n_spikes, - self.n_channels, - self.n_samples, - ) - masks = np.repeat(self._masks.ravel(), self.n_samples) - data = np.c_[waveforms.ravel(), masks.ravel()].astype(np.float32) - # TODO: more efficient to update the data from an existing VBO - self.program['a_data'] = data - - # TODO: SparseCSR, this should just be 'channel' - self._channels_per_spike = np.tile(np.arange(self.n_channels). - astype(np.float32), - self.n_spikes) - - # TODO: SparseCSR, this should be np.diff(spikes_ptr) - self._n_channels_per_spike = self.n_channels * np.ones(self.n_spikes, - dtype=np.int32) - - self._n_waveforms = np.sum(self._n_channels_per_spike) - - # TODO: precompute this with a maximum number of waveforms? - a_time = np.tile(np.linspace(-1., 1., self.n_samples), - self._n_waveforms).astype(np.float32) - - self.program['a_time'] = a_time - self.program['n_channels'] = self.n_channels - - def _bake_spikes_clusters(self): - # WARNING: needs to be called *after* _bake_spikes(). - if not hasattr(self, '_n_channels_per_spike'): - raise RuntimeError("'_bake_spikes()' needs to be called before " - "'bake_spikes_clusters().") - # Get the spike cluster indices (between 0 and n_clusters-1). - spike_clusters_idx = self.spike_clusters - # We take the cluster order into account here. - spike_clusters_idx = _index_of(spike_clusters_idx, self.cluster_order) - # Generate the box attribute. - assert len(spike_clusters_idx) == len(self._n_channels_per_spike) - a_cluster = np.repeat(spike_clusters_idx, - self._n_channels_per_spike * self.n_samples) - a_channel = np.repeat(self._channels_per_spike, self.n_samples) - a_box = np.c_[a_cluster, a_channel].astype(np.float32) - # TODO: more efficient to update the data from an existing VBO - self.program['a_box'] = a_box - self.program['n_clusters'] = self.n_clusters - - -class WaveformView(BaseSpikeCanvas): - """A VisPy canvas displaying waveforms.""" - _visual_class = WaveformVisual - _arrows = ('Left', 'Right', 'Up', 'Down') - _pm = ('+', '-') - _events = ('channel_click',) - _key_pressed = None - _show_mean = False - - def _create_visuals(self): - super(WaveformView, self)._create_visuals() - self.mean = WaveformVisual() - self.mean.alpha = 1. - - def _create_pan_zoom(self): - self._pz = PanZoom() - self._pz.add(self.visual.program) - self._pz.add(self.mean.program) - self._pz.attach(self) - - def set_data(self, - waveforms=None, - masks=None, - spike_clusters=None, - channel_positions=None, - channel_order=None, - colors=None, - **kwargs - ): - - if waveforms is not None: - assert isinstance(waveforms, np.ndarray) - if waveforms.ndim == 2: - waveforms = waveforms[None, ...] - assert waveforms.ndim == 3 - else: - waveforms = self.visual.waveforms - n_spikes, n_samples, n_channels = waveforms.shape - - if spike_clusters is None: - spike_clusters = np.zeros(n_spikes, dtype=np.int32) - spike_clusters = _as_array(spike_clusters) - cluster_ids = _unique(spike_clusters) - n_clusters = len(cluster_ids) - - if masks is None: - masks = np.ones((n_spikes, n_channels), dtype=np.float32) - - if colors is None: - colors = _selected_clusters_colors(n_clusters) - - if channel_order is None: - channel_order = self.visual.channel_order - if channel_order is None: - channel_order = np.arange(n_channels) - - if channel_positions is None: - channel_positions = self.visual.channel_positions - if channel_positions is None: - channel_positions = linear_positions(n_channels) - - self.visual.waveforms = waveforms.astype(np.float32) - - if masks is not None: - self.visual.masks = masks - - self.visual.spike_clusters = spike_clusters - assert spike_clusters.shape == (n_spikes,) - - if len(colors): - self.visual.cluster_colors = colors - - self.visual.channel_positions = channel_positions - self.visual.channel_order = channel_order - - # Extra parameters. - for k, v in kwargs.items(): - setattr(self, k, v) - - self.update() - - @property - def box_scale(self): - """Scale of the waveforms. - - This is a pair of scalars. - - """ - return self.visual.box_scale - - @box_scale.setter - def box_scale(self, value): - self.visual.box_scale = value - self.mean.box_scale = value - self.update() - - @property - def probe_scale(self): - """Scale of the probe. - - This is a pair of scalars. - - """ - return self.visual.probe_scale - - @probe_scale.setter - def probe_scale(self, value): - self.visual.probe_scale = value - self.mean.probe_scale = value - self.update() - - @property - def alpha(self): - """Opacity.""" - return self.visual.alpha - - @alpha.setter - def alpha(self, value): - self.visual.alpha = value - self.mean.alpha = value - self.update() - - @property - def overlap(self): - """Whether to overlap waveforms.""" - return self.visual.overlap - - @overlap.setter - def overlap(self, value): - self.visual.overlap = value - self.mean.overlap = value - self.update() - - @property - def show_mean(self): - """Whether to show_mean waveforms.""" - return self._show_mean - - @show_mean.setter - def show_mean(self, value): - self._show_mean = value - self.update() - - keyboard_shortcuts = { - 'waveform_scale_increase': ('ctrl+', - 'ctrl+up', - 'shift+wheel up', - ), - 'waveform_scale_decrease': ('ctrl-', - 'ctrl+down', - 'shift+wheel down', - ), - 'waveform_width_increase': ('ctrl+right', 'ctrl+wheel up'), - 'waveform_width_decrease': ('ctrl+left', 'ctrl+wheel down'), - 'probe_width_increase': ('shift+right', 'ctrl+alt+wheel up'), - 'probe_width_decrease': ('shift+left', 'ctrl+alt+wheel down'), - 'probe_height_increase': ('shift+up', 'shift+alt+wheel up'), - 'probe_height_decrease': ('shift+down', 'shift+alt+wheel down'), - 'select_channel': ('ctrl+left click', 'ctrl+right click', 'num+click'), - } - - def on_key_press(self, event): - """Handle key press events.""" - key = event.key - - self._key_pressed = key - - ctrl = 'Control' in event.modifiers - shift = 'Shift' in event.modifiers - - # Box scale. - if ctrl and key in self._arrows + self._pm: - coeff = 1.1 - u, v = self.box_scale - if key == 'Left': - self.box_scale = (u / coeff, v) - elif key == 'Right': - self.box_scale = (u * coeff, v) - elif key in ('Down', '-'): - self.box_scale = (u, v / coeff) - elif key in ('Up', '+'): - self.box_scale = (u, v * coeff) - - # Probe scale. - if shift and key in self._arrows: - coeff = 1.1 - u, v = self.probe_scale - if key == 'Left': - self.probe_scale = (u / coeff, v) - elif key == 'Right': - self.probe_scale = (u * coeff, v) - elif key == 'Down': - self.probe_scale = (u, v / coeff) - elif key == 'Up': - self.probe_scale = (u, v * coeff) - - def on_key_release(self, event): - self._key_pressed = None - - def on_mouse_wheel(self, event): - """Handle mouse wheel events.""" - ctrl = 'Control' in event.modifiers - shift = 'Shift' in event.modifiers - alt = 'Alt' in event.modifiers - coeff = 1. + .1 * event.delta[1] - - # Box scale. - if ctrl and not alt: - u, v = self.box_scale - self.box_scale = (u * coeff, v) - if shift and not alt: - u, v = self.box_scale - self.box_scale = (u, v * coeff) - - # Probe scale. - if ctrl and alt: - u, v = self.probe_scale - self.probe_scale = (u * coeff, v) - if shift and alt: - u, v = self.probe_scale - self.probe_scale = (u, v * coeff) - - def on_mouse_press(self, e): - key = self._key_pressed - if 'Control' in e.modifiers or key in map(str, range(10)): - box_idx = int(key.name) if key in map(str, range(10)) else None - # Normalise mouse position. - position = self._pz._normalize(e.pos) - position[1] = -position[1] - zoom = self._pz._zoom_aspect() - pan = self._pz.pan - mouse_pos = ((position / zoom) - pan) - # Find the channel id. - channel_idx = self.visual.channel_hover(mouse_pos) - self.emit("channel_click", - channel_idx=channel_idx, - ax='x' if e.button == 1 else 'y', - box_idx=box_idx, - ) - - def on_draw(self, event): - """Draw the visual.""" - gloo.clear(color=True, depth=True) - if self._show_mean: - self.mean.draw() - else: - self.visual.draw() - - -#------------------------------------------------------------------------------ -# Plotting functions -#------------------------------------------------------------------------------ - -@_wrap_vispy -def plot_waveforms(waveforms, **kwargs): - """Plot waveforms. - - Parameters - ---------- - - waveforms : ndarray - The waveforms to plot. A `(n_spikes, n_samples, n_channels)` array. - spike_clusters : ndarray (optional) - A `(n_spikes,)` int array with the spike clusters. - masks : ndarray (optional) - A `(n_spikes, n_channels)` float array with the spike masks. - channel_positions : ndarray - A `(n_channels, 2)` array with the channel positions. - - """ - c = WaveformView(keys='interactive') - c.set_data(waveforms, **kwargs) - return c From 193469f469c844afc3b472bcc13c7a19f0498869 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 18 Oct 2015 19:13:43 +0200 Subject: [PATCH 0360/1059] WIP: test plot utils --- phy/plot/tests/conftest.py | 22 ++++++++++++++++++++++ phy/plot/tests/test_utils.py | 29 ++++++++++++++++++++++++++++- 2 files changed, 50 insertions(+), 1 deletion(-) create mode 100644 phy/plot/tests/conftest.py diff --git a/phy/plot/tests/conftest.py b/phy/plot/tests/conftest.py new file mode 100644 index 000000000..728d9fc98 --- /dev/null +++ b/phy/plot/tests/conftest.py @@ -0,0 +1,22 @@ +# -*- coding: utf-8 -*- + +"""Test VisPy.""" + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +from vispy.app import Canvas, use_app, run +from pytest import yield_fixture + + +#------------------------------------------------------------------------------ +# Utilities and fixtures +#------------------------------------------------------------------------------ + +@yield_fixture +def canvas(qapp): + use_app('pyqt4') + c = Canvas(keys='interactive') + yield c + c.close() diff --git a/phy/plot/tests/test_utils.py b/phy/plot/tests/test_utils.py index b396e8f6d..062cb6d28 100644 --- a/phy/plot/tests/test_utils.py +++ b/phy/plot/tests/test_utils.py @@ -7,7 +7,14 @@ # Imports #------------------------------------------------------------------------------ -from ..utils import _load_shader +import numpy as np +from numpy.testing import assert_array_equal as ae +from vispy import gloo + +from ..utils import (_load_shader, + _tesselate_histogram, + _enable_depth_mask, + ) #------------------------------------------------------------------------------ @@ -16,3 +23,23 @@ def test_load_shader(): assert 'main()' in _load_shader('ax.vert') + + +def test_tesselate_histogram(): + n = 5 + hist = np.arange(n) + thist = _tesselate_histogram(hist) + assert thist.shape == (5 * n + 1, 2) + ae(thist[0], [-1, -1]) + ae(thist[-1], [1, -1]) + + +def test_enable_depth_mask(qtbot, canvas): + + @canvas.connect + def on_draw(e): + _enable_depth_mask() + + canvas.show() + qtbot.waitForWindowShown(canvas.native) + # qtbot.stop() From 4924355c808f3161e05e9f168322e4825eaf47f2 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 18 Oct 2015 19:16:37 +0200 Subject: [PATCH 0361/1059] WIP --- phy/plot/tests/conftest.py | 2 +- phy/plot/tests/test_utils.py | 1 - phy/plot/utils.py | 17 ----------------- 3 files changed, 1 insertion(+), 19 deletions(-) diff --git a/phy/plot/tests/conftest.py b/phy/plot/tests/conftest.py index 728d9fc98..9b89ad9b7 100644 --- a/phy/plot/tests/conftest.py +++ b/phy/plot/tests/conftest.py @@ -6,7 +6,7 @@ # Imports #------------------------------------------------------------------------------ -from vispy.app import Canvas, use_app, run +from vispy.app import Canvas, use_app from pytest import yield_fixture diff --git a/phy/plot/tests/test_utils.py b/phy/plot/tests/test_utils.py index 062cb6d28..df5d091f8 100644 --- a/phy/plot/tests/test_utils.py +++ b/phy/plot/tests/test_utils.py @@ -9,7 +9,6 @@ import numpy as np from numpy.testing import assert_array_equal as ae -from vispy import gloo from ..utils import (_load_shader, _tesselate_histogram, diff --git a/phy/plot/utils.py b/phy/plot/utils.py index 5d9ae4a56..765b11010 100644 --- a/phy/plot/utils.py +++ b/phy/plot/utils.py @@ -7,7 +7,6 @@ # Imports #------------------------------------------------------------------------------ -from functools import wraps import logging import os.path as op @@ -60,19 +59,3 @@ def _enable_depth_mask(): blend=True, blend_func=('src_alpha', 'one_minus_src_alpha')) gloo.set_clear_depth(1.0) - - -def _wrap_vispy(f): - """Decorator for a function returning a VisPy canvas. - - Add `show=True` parameter. - - """ - @wraps(f) - def wrapped(*args, **kwargs): - show = kwargs.pop('show', True) - canvas = f(*args, **kwargs) - if show: - canvas.show() - return canvas - return wrapped From 453ad93b5c5fae16ae200b083c0a1289cc20ab50 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 18 Oct 2015 19:21:17 +0200 Subject: [PATCH 0362/1059] Improve _load_shader() --- phy/plot/tests/test_utils.py | 8 ++++++++ phy/plot/utils.py | 8 ++++++-- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/phy/plot/tests/test_utils.py b/phy/plot/tests/test_utils.py index df5d091f8..14df595cc 100644 --- a/phy/plot/tests/test_utils.py +++ b/phy/plot/tests/test_utils.py @@ -7,8 +7,12 @@ # Imports #------------------------------------------------------------------------------ +import os +import os.path as op + import numpy as np from numpy.testing import assert_array_equal as ae +from vispy import config from ..utils import (_load_shader, _tesselate_histogram, @@ -22,6 +26,10 @@ def test_load_shader(): assert 'main()' in _load_shader('ax.vert') + assert config['include_path'] + assert op.exists(config['include_path'][0]) + assert op.isdir(config['include_path'][0]) + assert os.listdir(config['include_path'][0]) def test_tesselate_histogram(): diff --git a/phy/plot/utils.py b/phy/plot/utils.py index 765b11010..b1325bf0d 100644 --- a/phy/plot/utils.py +++ b/phy/plot/utils.py @@ -12,7 +12,7 @@ import numpy as np -from vispy import gloo +from vispy import gloo, config logger = logging.getLogger(__name__) @@ -23,7 +23,11 @@ def _load_shader(filename): """Load a shader file.""" - path = op.join(op.dirname(op.realpath(__file__)), 'glsl', filename) + curdir = op.dirname(op.realpath(__file__)) + glsl_path = op.join(curdir, 'glsl') + if not config['include_path']: + config['include_path'] = [glsl_path] + path = op.join(glsl_path, filename) with open(path, 'r') as f: return f.read() From 9853a25878cfe74e18e845ebbea0541d5f3be674 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 18 Oct 2015 19:28:34 +0200 Subject: [PATCH 0363/1059] Add _create_program() --- phy/plot/tests/test_utils.py | 7 +++++++ phy/plot/utils.py | 7 +++++++ 2 files changed, 14 insertions(+) diff --git a/phy/plot/tests/test_utils.py b/phy/plot/tests/test_utils.py index 14df595cc..0633a51c6 100644 --- a/phy/plot/tests/test_utils.py +++ b/phy/plot/tests/test_utils.py @@ -15,6 +15,7 @@ from vispy import config from ..utils import (_load_shader, + _create_program, _tesselate_histogram, _enable_depth_mask, ) @@ -32,6 +33,12 @@ def test_load_shader(): assert os.listdir(config['include_path'][0]) +def test_create_program(): + p = _create_program('box') + assert p.shaders[0] + assert p.shaders[1] + + def test_tesselate_histogram(): n = 5 hist = np.arange(n) diff --git a/phy/plot/utils.py b/phy/plot/utils.py index b1325bf0d..ff89db799 100644 --- a/phy/plot/utils.py +++ b/phy/plot/utils.py @@ -32,6 +32,13 @@ def _load_shader(filename): return f.read() +def _create_program(name): + vertex = _load_shader(name + '.vert') + fragment = _load_shader(name + '.frag') + program = gloo.Program(vertex, fragment) + return program + + def _tesselate_histogram(hist): assert hist.ndim == 1 nsamples = len(hist) From a6d07ef3df1c8f871f624e3d75aeabd3cb84657c Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 18 Oct 2015 20:12:11 +0200 Subject: [PATCH 0364/1059] Export more Qt variables --- phy/gui/qt.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/phy/gui/qt.py b/phy/gui/qt.py index 173a0f08f..e7d0cbc9c 100644 --- a/phy/gui/qt.py +++ b/phy/gui/qt.py @@ -17,7 +17,8 @@ # PyQt import # ----------------------------------------------------------------------------- -from PyQt4.QtCore import Qt, QByteArray, QMetaObject, QSize # noqa +from PyQt4.QtCore import (Qt, QByteArray, QMetaObject, QObject, # noqa + pyqtSignal, QSize) from PyQt4.QtGui import (QKeySequence, QAction, QStatusBar, # noqa QMainWindow, QDockWidget, QWidget, QMessageBox, QApplication, From 6e38b0e8d90407269e30c26f263590143785b67b Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 18 Oct 2015 20:42:18 +0200 Subject: [PATCH 0365/1059] WIP: base plot --- phy/plot/base.py | 257 ++++++----------------------------- phy/plot/tests/conftest.py | 21 ++- phy/plot/tests/test_base.py | 36 +++++ phy/plot/tests/test_utils.py | 1 - phy/utils/testing.py | 7 - 5 files changed, 99 insertions(+), 223 deletions(-) diff --git a/phy/plot/base.py b/phy/plot/base.py index 4fa33e79d..23a9642b5 100644 --- a/phy/plot/base.py +++ b/phy/plot/base.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -"""Plotting/VisPy utilities.""" +"""Base VisPy classes.""" #------------------------------------------------------------------------------ @@ -8,14 +8,8 @@ #------------------------------------------------------------------------------ import logging -import os.path as op -import numpy as np - -from vispy import gloo, config - -from phy.utils._types import _as_array -from phy.io.array import _unique +from .utils import _create_program logger = logging.getLogger(__name__) @@ -24,210 +18,45 @@ # Base spike visual #------------------------------------------------------------------------------ -class _BakeVisual(object): - _shader_name = '' - _gl_draw_mode = '' - - def __init__(self, **kwargs): - super(_BakeVisual, self).__init__(**kwargs) - self._to_bake = [] - self._empty = True - - curdir = op.dirname(op.realpath(__file__)) - config['include_path'] = [op.join(curdir, 'glsl')] - - vertex = _load_shader(self._shader_name + '.vert') - fragment = _load_shader(self._shader_name + '.frag') - self.program = gloo.Program(vertex, fragment) - - @property - def empty(self): - """Specify whether the visual is currently empty or not.""" - return self._empty - - @empty.setter - def empty(self, value): - """Specify whether the visual is currently empty or not.""" - self._empty = value - - def set_to_bake(self, *bakes): - """Mark data items to be prepared for GPU.""" - for bake in bakes: - if bake not in self._to_bake: - self._to_bake.append(bake) - - def _bake(self): - """Prepare and upload the data on the GPU. - - Return whether something has been baked or not. - - """ - if self._empty: - return - n_bake = len(self._to_bake) - # Bake what needs to be baked. - # WARNING: the bake functions are called in alphabetical order. - # Tweak the names if there are dependencies between the functions. - for bake in sorted(self._to_bake): - # Name of the private baking method. - name = '_bake_{0:s}'.format(bake) - if hasattr(self, name): - getattr(self, name)() - self._to_bake = [] - return n_bake > 0 - - def draw(self): +class BaseVisual(object): + _gl_primitive_type = None + _shader_name = None + + def __init__(self): + assert self._gl_primitive_type + assert self._shader_name + + self._data = {'a_position': None} + self._to_upload = [] # list of arrays/params to upload + + self.program = _create_program(self._shader_name) + + def is_empty(self): + """Return whether the visual is empty.""" + return self._data['a_position'] is not None + + def set_data(self): + pass + + def set_transforms(self): + pass + + def attach(self, canvas): + canvas.connect(self.on_draw) + + @canvas.connect + def on_resize(event): + """Resize the OpenGL context.""" + canvas.context.set_viewport(0, 0, event.size[0], event.size[1]) + + def on_draw(self, e): """Draw the waveforms.""" - # Bake what needs to be baked at this point. - self._bake() - if not self._empty: - self.program.draw(self._gl_draw_mode) - - def update(self): - self.draw() - - -class BaseSpikeVisual(_BakeVisual): - """Base class for a VisPy visual showing spike data. - - There is a notion of displayed spikes and displayed clusters. - - """ - - _transparency = True - - def __init__(self, **kwargs): - super(BaseSpikeVisual, self).__init__(**kwargs) - self.n_spikes = None - self._spike_clusters = None - self._spike_ids = None - self._cluster_ids = None - self._cluster_order = None - self._cluster_colors = None - self._update_clusters_automatically = True - - if self._transparency: - gloo.set_state(clear_color='black', blend=True, - blend_func=('src_alpha', 'one_minus_src_alpha')) - - # Data properties - # ------------------------------------------------------------------------- - - def _set_or_assert_n_spikes(self, arr): - """If n_spikes is None, set it using the array's shape. Otherwise, - check that the array has n_spikes rows.""" - # TODO: improve this - if self.n_spikes is None: - self.n_spikes = arr.shape[0] - assert arr.shape[0] == self.n_spikes - - def _update_clusters(self): - self._cluster_ids = _unique(self._spike_clusters) - - @property - def spike_clusters(self): - """The clusters assigned to the displayed spikes.""" - return self._spike_clusters - - @spike_clusters.setter - def spike_clusters(self, value): - """Set all spike clusters.""" - value = _as_array(value) - self._spike_clusters = value - if self._update_clusters_automatically: - self._update_clusters() - self.set_to_bake('spikes_clusters') - - @property - def cluster_order(self): - """List of selected clusters in display order.""" - if self._cluster_order is None: - return self._cluster_ids - else: - return self._cluster_order - - @cluster_order.setter - def cluster_order(self, value): - value = _as_array(value) - assert sorted(value.tolist()) == sorted(self._cluster_ids) - self._cluster_order = value - - @property - def masks(self): - """Masks of the displayed spikes.""" - return self._masks - - @masks.setter - def masks(self, value): - assert isinstance(value, np.ndarray) - value = _as_array(value) - if value.ndim == 1: - value = value[None, :] - self._set_or_assert_n_spikes(value) - # TODO: support sparse structures - assert value.ndim == 2 - assert value.shape == (self.n_spikes, self.n_channels) - self._masks = value - self.set_to_bake('spikes') - - @property - def spike_ids(self): - """Spike ids to display.""" - if self._spike_ids is None: - self._spike_ids = np.arange(self.n_spikes) - return self._spike_ids - - @spike_ids.setter - def spike_ids(self, value): - value = _as_array(value) - self._set_or_assert_n_spikes(value) - self._spike_ids = value - self.set_to_bake('spikes') - - @property - def cluster_ids(self): - """Cluster ids of the displayed spikes.""" - return self._cluster_ids - - @cluster_ids.setter - def cluster_ids(self, value): - """Clusters of the displayed spikes.""" - self._cluster_ids = _as_array(value) - - @property - def n_clusters(self): - """Number of displayed clusters.""" - if self._cluster_ids is None: - return None - else: - return len(self._cluster_ids) - - @property - def cluster_colors(self): - """Colors of the displayed clusters. - - The first color is the color of the smallest cluster. - - """ - return self._cluster_colors - - @cluster_colors.setter - def cluster_colors(self, value): - self._cluster_colors = _as_array(value) - assert len(self._cluster_colors) >= self.n_clusters - self.set_to_bake('cluster_color') - - # Data baking - # ------------------------------------------------------------------------- - - def _bake_cluster_color(self): - if self.n_clusters == 0: - u_cluster_color = np.zeros((0, 0, 3)) - else: - u_cluster_color = self.cluster_colors.reshape((1, - self.n_clusters, - 3)) - assert u_cluster_color.ndim == 3 - assert u_cluster_color.shape[2] == 3 - u_cluster_color = (u_cluster_color * 255).astype(np.uint8) - self.program['u_cluster_color'] = gloo.Texture2D(u_cluster_color) + # Upload to the GPU what needs to be uploaded. + for name in self._to_upload: + value = self._data[name] + logger.debug("Upload `%s`: %s.", name, str(value)) + self.program[name] = value + # Reset the list of objects to upload. + self._to_upload = [] + if not self.is_empty(): + self.program.draw(self._gl_primitive_type) diff --git a/phy/plot/tests/conftest.py b/phy/plot/tests/conftest.py index 9b89ad9b7..63599b54a 100644 --- a/phy/plot/tests/conftest.py +++ b/phy/plot/tests/conftest.py @@ -6,14 +6,33 @@ # Imports #------------------------------------------------------------------------------ +# import sys + from vispy.app import Canvas, use_app -from pytest import yield_fixture +from pytest import yield_fixture # , mark + +# from phy.gui.qt import QObject, pyqtSignal #------------------------------------------------------------------------------ # Utilities and fixtures #------------------------------------------------------------------------------ +# class ExceptionHandler(QObject): +# errorSignal = pyqtSignal() +# silentSignal = pyqtSignal() + +# def handler(self, exctype, value, traceback): +# self.errorSignal.emit() +# sys._excepthook(exctype, value, traceback) + + +# exceptionHandler = ExceptionHandler() +# sys._excepthook = sys.excepthook +# sys.excepthook = exceptionHandler.handler + + +# @mark.qt_no_exception_capture @yield_fixture def canvas(qapp): use_app('pyqt4') diff --git a/phy/plot/tests/test_base.py b/phy/plot/tests/test_base.py index e69de29bb..cc882d79d 100644 --- a/phy/plot/tests/test_base.py +++ b/phy/plot/tests/test_base.py @@ -0,0 +1,36 @@ +# -*- coding: utf-8 -*- + +"""Test base.""" + + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +from ..base import BaseVisual + + +#------------------------------------------------------------------------------ +# Test base +#------------------------------------------------------------------------------ + +def test_base_visual(qtbot, canvas): + class TestVisual(BaseVisual): + _shader_name = 'box' + _gl_primitive_type = 'lines' + + def set_data(self): + self._data['a_position'] = [[-1, 0, 0], [1, 0, 0]] + self._data['n_rows'] = 1 + self._to_upload = ['a_position', 'n_rows'] + + def is_empty(self): + return False + + v = TestVisual() + v.set_data() + + v.attach(canvas) + canvas.show() + + # qtbot.stop() diff --git a/phy/plot/tests/test_utils.py b/phy/plot/tests/test_utils.py index 0633a51c6..88b00c9f9 100644 --- a/phy/plot/tests/test_utils.py +++ b/phy/plot/tests/test_utils.py @@ -56,4 +56,3 @@ def on_draw(e): canvas.show() qtbot.waitForWindowShown(canvas.native) - # qtbot.stop() diff --git a/phy/utils/testing.py b/phy/utils/testing.py index 020b4d275..375b1d51a 100644 --- a/phy/utils/testing.py +++ b/phy/utils/testing.py @@ -189,13 +189,6 @@ def show_test(canvas): canvas.app.process_events() -# TODO -# def test_1(guibot): -# c = Canvas() -# guibot.add(c) -# guibot.wait(c) - - def show_colored_canvas(color): """Show a transient VisPy canvas with a uniform background color.""" from vispy import app, gloo From dd29695548a2d7d6c2db80fe82e47098d113ad9a Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 19 Oct 2015 11:14:39 +0200 Subject: [PATCH 0366/1059] WIP: base --- phy/plot/base.py | 4 ++-- phy/plot/tests/test_base.py | 7 +++++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/phy/plot/base.py b/phy/plot/base.py index 23a9642b5..76d761e85 100644 --- a/phy/plot/base.py +++ b/phy/plot/base.py @@ -42,13 +42,13 @@ def set_transforms(self): pass def attach(self, canvas): - canvas.connect(self.on_draw) - @canvas.connect def on_resize(event): """Resize the OpenGL context.""" canvas.context.set_viewport(0, 0, event.size[0], event.size[1]) + canvas.events['draw'].connect(self.on_draw, position='last') + def on_draw(self, e): """Draw the waveforms.""" # Upload to the GPU what needs to be uploaded. diff --git a/phy/plot/tests/test_base.py b/phy/plot/tests/test_base.py index cc882d79d..9ee7ecff5 100644 --- a/phy/plot/tests/test_base.py +++ b/phy/plot/tests/test_base.py @@ -7,6 +7,8 @@ # Imports #------------------------------------------------------------------------------ +from vispy import gloo + from ..base import BaseVisual @@ -27,6 +29,11 @@ def set_data(self): def is_empty(self): return False + def on_draw(e): + gloo.clear() + + canvas.events['draw'].connect(on_draw, position='last') + v = TestVisual() v.set_data() From 31c27c406524b8b0370b3c387c7829f1e41ed4b5 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 19 Oct 2015 18:23:00 +0200 Subject: [PATCH 0367/1059] WIP: BaseCanvas --- phy/plot/base.py | 58 +++++++++++++++++++++++++++---------- phy/plot/tests/conftest.py | 25 +++------------- phy/plot/tests/test_base.py | 15 +++++----- setup.cfg | 2 +- 4 files changed, 56 insertions(+), 44 deletions(-) diff --git a/phy/plot/base.py b/phy/plot/base.py index 76d761e85..abf333e90 100644 --- a/phy/plot/base.py +++ b/phy/plot/base.py @@ -9,6 +9,9 @@ import logging +from vispy import gloo +from vispy.app import Canvas + from .utils import _create_program logger = logging.getLogger(__name__) @@ -18,6 +21,21 @@ # Base spike visual #------------------------------------------------------------------------------ +class BaseCanvas(Canvas): + def __init__(self, *args, **kwargs): + super(BaseCanvas, self).__init__(*args, **kwargs) + self._visuals = [] + + def add_visual(self, visual): + self._visuals.append(visual) + visual.attach(self) + + def on_draw(self, e): + gloo.clear() + for visual in self._visuals: + visual.draw() + + class BaseVisual(object): _gl_primitive_type = None _shader_name = None @@ -26,37 +44,47 @@ def __init__(self): assert self._gl_primitive_type assert self._shader_name - self._data = {'a_position': None} - self._to_upload = [] # list of arrays/params to upload + self.size = 1, 1 + self._canvas = None + self._do_show = False self.program = _create_program(self._shader_name) - def is_empty(self): - """Return whether the visual is empty.""" - return self._data['a_position'] is not None + def show(self): + self._do_show = True + + def hide(self): + self._do_show = False def set_data(self): + """Set the data for the visual.""" pass def set_transforms(self): + """Set the list of transforms for the visual.""" pass def attach(self, canvas): + """Attach some events.""" + self._canvas = canvas + @canvas.connect def on_resize(event): """Resize the OpenGL context.""" + self.size = event.size canvas.context.set_viewport(0, 0, event.size[0], event.size[1]) - canvas.events['draw'].connect(self.on_draw, position='last') + canvas.connect(self.on_mouse_move) - def on_draw(self, e): + def on_mouse_move(self, e): + pass + + def draw(self): """Draw the waveforms.""" - # Upload to the GPU what needs to be uploaded. - for name in self._to_upload: - value = self._data[name] - logger.debug("Upload `%s`: %s.", name, str(value)) - self.program[name] = value - # Reset the list of objects to upload. - self._to_upload = [] - if not self.is_empty(): + if not self._do_show: self.program.draw(self._gl_primitive_type) + + def update(self): + """Trigger a draw event in the canvas from the visual.""" + if self._canvas: + self._canvas.update() diff --git a/phy/plot/tests/conftest.py b/phy/plot/tests/conftest.py index 63599b54a..306173b3e 100644 --- a/phy/plot/tests/conftest.py +++ b/phy/plot/tests/conftest.py @@ -6,36 +6,19 @@ # Imports #------------------------------------------------------------------------------ -# import sys +from vispy.app import use_app +from pytest import yield_fixture -from vispy.app import Canvas, use_app -from pytest import yield_fixture # , mark - -# from phy.gui.qt import QObject, pyqtSignal +from ..base import BaseCanvas #------------------------------------------------------------------------------ # Utilities and fixtures #------------------------------------------------------------------------------ -# class ExceptionHandler(QObject): -# errorSignal = pyqtSignal() -# silentSignal = pyqtSignal() - -# def handler(self, exctype, value, traceback): -# self.errorSignal.emit() -# sys._excepthook(exctype, value, traceback) - - -# exceptionHandler = ExceptionHandler() -# sys._excepthook = sys.excepthook -# sys.excepthook = exceptionHandler.handler - - -# @mark.qt_no_exception_capture @yield_fixture def canvas(qapp): use_app('pyqt4') - c = Canvas(keys='interactive') + c = BaseCanvas(keys='interactive') yield c c.close() diff --git a/phy/plot/tests/test_base.py b/phy/plot/tests/test_base.py index 9ee7ecff5..f46d95010 100644 --- a/phy/plot/tests/test_base.py +++ b/phy/plot/tests/test_base.py @@ -17,17 +17,19 @@ #------------------------------------------------------------------------------ def test_base_visual(qtbot, canvas): + class TestVisual(BaseVisual): _shader_name = 'box' _gl_primitive_type = 'lines' def set_data(self): - self._data['a_position'] = [[-1, 0, 0], [1, 0, 0]] - self._data['n_rows'] = 1 - self._to_upload = ['a_position', 'n_rows'] + self.program['a_position'] = [[-1, 0, 0], [1, 0, 0]] + self.program['n_rows'] = 1 - def is_empty(self): - return False + def on_mouse_move(self, e): + y = 1 - 2 * e.pos[1] / float(self.size[1]) + self.program['a_position'] = [[-1, y, 0], [1, y, 0]] + self.update() def on_draw(e): gloo.clear() @@ -36,8 +38,7 @@ def on_draw(e): v = TestVisual() v.set_data() + canvas.add_visual(v) - v.attach(canvas) canvas.show() - # qtbot.stop() diff --git a/setup.cfg b/setup.cfg index c138ac336..b8de731d4 100644 --- a/setup.cfg +++ b/setup.cfg @@ -2,7 +2,7 @@ universal = 1 [pytest] -norecursedirs = plot experimental _* +norecursedirs = experimental _* [flake8] ignore=E265 From bef9cc5a9a859623c8139adae776f652f53c3047 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 19 Oct 2015 19:07:12 +0200 Subject: [PATCH 0368/1059] WIP --- phy/plot/tests/test_base.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/phy/plot/tests/test_base.py b/phy/plot/tests/test_base.py index f46d95010..58d16bb86 100644 --- a/phy/plot/tests/test_base.py +++ b/phy/plot/tests/test_base.py @@ -31,11 +31,6 @@ def on_mouse_move(self, e): self.program['a_position'] = [[-1, y, 0], [1, y, 0]] self.update() - def on_draw(e): - gloo.clear() - - canvas.events['draw'].connect(on_draw, position='last') - v = TestVisual() v.set_data() canvas.add_visual(v) From 42b8f0935d0b31b9e9c798774110084175aabd00 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 19 Oct 2015 19:16:20 +0200 Subject: [PATCH 0369/1059] Increase coverage --- .coveragerc | 1 + phy/plot/base.py | 7 +++++-- phy/plot/tests/test_base.py | 11 +++++------ 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/.coveragerc b/.coveragerc index 942e645f6..c705eacb3 100644 --- a/.coveragerc +++ b/.coveragerc @@ -11,3 +11,4 @@ exclude_lines = pragma: no cover raise AssertionError raise NotImplementedError + pass diff --git a/phy/plot/base.py b/phy/plot/base.py index abf333e90..11ed937bb 100644 --- a/phy/plot/base.py +++ b/phy/plot/base.py @@ -74,14 +74,17 @@ def on_resize(event): self.size = event.size canvas.context.set_viewport(0, 0, event.size[0], event.size[1]) - canvas.connect(self.on_mouse_move) + @canvas.connect + def on_mouse_move(event): + if self._do_show: + self.on_mouse_move(event) def on_mouse_move(self, e): pass def draw(self): """Draw the waveforms.""" - if not self._do_show: + if self._do_show: self.program.draw(self._gl_primitive_type) def update(self): diff --git a/phy/plot/tests/test_base.py b/phy/plot/tests/test_base.py index 58d16bb86..36f9d1fc5 100644 --- a/phy/plot/tests/test_base.py +++ b/phy/plot/tests/test_base.py @@ -25,15 +25,14 @@ class TestVisual(BaseVisual): def set_data(self): self.program['a_position'] = [[-1, 0, 0], [1, 0, 0]] self.program['n_rows'] = 1 - - def on_mouse_move(self, e): - y = 1 - 2 * e.pos[1] / float(self.size[1]) - self.program['a_position'] = [[-1, y, 0], [1, y, 0]] - self.update() + self.show() v = TestVisual() v.set_data() canvas.add_visual(v) canvas.show() - # qtbot.stop() + v.hide() + v.show() + qtbot.waitForWindowShown(canvas.native) + v.update() From 432cd758c219d381325f7eb6520475c136cd66be Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 19 Oct 2015 19:20:34 +0200 Subject: [PATCH 0370/1059] Lint --- phy/plot/tests/test_base.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/phy/plot/tests/test_base.py b/phy/plot/tests/test_base.py index 36f9d1fc5..83f9f3ec2 100644 --- a/phy/plot/tests/test_base.py +++ b/phy/plot/tests/test_base.py @@ -7,8 +7,6 @@ # Imports #------------------------------------------------------------------------------ -from vispy import gloo - from ..base import BaseVisual From 0d931157f7b40f7645c1ec00db7a2bc392d03450 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 20 Oct 2015 06:47:06 +0200 Subject: [PATCH 0371/1059] Add mouse move test --- phy/plot/tests/test_base.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/phy/plot/tests/test_base.py b/phy/plot/tests/test_base.py index 83f9f3ec2..3acee363b 100644 --- a/phy/plot/tests/test_base.py +++ b/phy/plot/tests/test_base.py @@ -33,4 +33,8 @@ def set_data(self): v.hide() v.show() qtbot.waitForWindowShown(canvas.native) + + # Simulate a mouse move. + canvas.events.mouse_move(delta=(1., 0.)) + v.update() From 5c9798225ce2a68ca6d6e85e985def1a20d6e1c9 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 20 Oct 2015 07:22:35 +0200 Subject: [PATCH 0372/1059] WIP: transforms --- phy/plot/tests/test_transform.py | 53 +++++++++++++++++++ phy/plot/transform.py | 89 ++++++++++++++++++++++++++++++++ 2 files changed, 142 insertions(+) create mode 100644 phy/plot/tests/test_transform.py create mode 100644 phy/plot/transform.py diff --git a/phy/plot/tests/test_transform.py b/phy/plot/tests/test_transform.py new file mode 100644 index 000000000..f4c7340ef --- /dev/null +++ b/phy/plot/tests/test_transform.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- + +"""Test transform.""" + + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +from itertools import product + +import numpy as np +from numpy.testing import assert_array_equal as ae +from pytest import yield_fixture + +from ..transform import Translate, Scale, Range, Clip, GPU + + +#------------------------------------------------------------------------------ +# Test transform +#------------------------------------------------------------------------------ + +@yield_fixture(params=product([0, 1, 2], ['i', 'f'])) +def array(request): + m, t = request.param + if t == 'i': + a, b = 3, 4 + elif t == 'f': + a, b = 3., 4. + arr = [a, b] + if m == 1: + arr = [arr] + elif m == 2: + arr = np.array(arr) + elif m == 3: + arr = np.array([arr]) + elif m == 4: + arr = np.array([arr, arr, arr]) + yield arr + + +def _check(transform, array, expected): + transformed = transform.apply(array) + array = np.atleast_2d(array) + if isinstance(array, np.ndarray): + assert transformed.shape == array.shape + assert transformed.dtype == array.dtype + ae(transformed, expected) + + +def test_translate(array): + t = Translate(1, 2) + _check(t, array, [[4, 6]]) diff --git a/phy/plot/transform.py b/phy/plot/transform.py new file mode 100644 index 000000000..705668dc3 --- /dev/null +++ b/phy/plot/transform.py @@ -0,0 +1,89 @@ +# -*- coding: utf-8 -*- + +"""Transforms.""" + + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +import numpy as np + +import logging + +logger = logging.getLogger(__name__) + + +#------------------------------------------------------------------------------ +# Transforms +#------------------------------------------------------------------------------ + +def _wrap_apply(f): + def wrapped(arr): + arr = np.atleast_2d(arr) + assert arr.ndim == 2 + assert arr.shape[1] == 2 + out = f(arr) + assert out.shape == arr.shape + return out + return wrapped + + +class BaseTransform(object): + def __init__(self): + self.apply = _wrap_apply(self.apply) + + def apply(self, arr): + raise NotImplementedError() + + +class Translate(BaseTransform): + def __init__(self, tx, ty): + BaseTransform.__init__(self) + self.tx, self.ty = tx, ty + + def apply(self, arr): + return arr + np.array([[self.tx, self.ty]]) + + +class Scale(BaseTransform): + def __init__(self, sx, sy): + BaseTransform.__init__(self) + self.sx, self.sy = sx, sy + + def apply(self, arr): + return arr * np.array([[self.sx, self.sy]]) + + +class Range(BaseTransform): + def __init__(self, xmin, ymin, xmax, ymax, mode='hard'): + BaseTransform.__init__(self) + self.xmin, ymin = xmin, ymin + self.xmax, ymax = xmax, ymax + self.mode = mode + + def apply(self, arr): + if self.mode == 'hard': + xym = arr.min(axis=0) + xyM = arr.max(axis=0) + + xymin = np.array([[self.xmin, self.ymin]]) + xymax = np.array([[self.xmax, self.ymax]]) + + xymin + (xymax - xymin) * (arr - xym) / (xyM - xym) + + raise NotImplementedError() + + +class Clip(BaseTransform): + def __init__(self, xmin, ymin, xmax, ymax): + BaseTransform.__init__(self) + self.xmin, ymin = xmin, ymin + self.xmax, ymax = xmax, ymax + + def apply(self, arr): + pass + + +class GPU(BaseTransform): + pass From 637d82ef12544a073c26803f292923e97ac1da13 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 20 Oct 2015 08:05:50 +0200 Subject: [PATCH 0373/1059] WIP: test transforms --- phy/plot/tests/test_transform.py | 73 +++++++++++++++++++++----------- phy/plot/transform.py | 31 ++++++++++---- 2 files changed, 71 insertions(+), 33 deletions(-) diff --git a/phy/plot/tests/test_transform.py b/phy/plot/tests/test_transform.py index f4c7340ef..f5fbf1546 100644 --- a/phy/plot/tests/test_transform.py +++ b/phy/plot/tests/test_transform.py @@ -11,43 +11,66 @@ import numpy as np from numpy.testing import assert_array_equal as ae -from pytest import yield_fixture from ..transform import Translate, Scale, Range, Clip, GPU #------------------------------------------------------------------------------ -# Test transform +# Fixtures #------------------------------------------------------------------------------ -@yield_fixture(params=product([0, 1, 2], ['i', 'f'])) -def array(request): - m, t = request.param - if t == 'i': - a, b = 3, 4 - elif t == 'f': - a, b = 3., 4. - arr = [a, b] - if m == 1: - arr = [arr] - elif m == 2: - arr = np.array(arr) - elif m == 3: - arr = np.array([arr]) - elif m == 4: - arr = np.array([arr, arr, arr]) - yield arr - - def _check(transform, array, expected): transformed = transform.apply(array) + if array is None or not len(array): + assert transformed == array + return array = np.atleast_2d(array) if isinstance(array, np.ndarray): assert transformed.shape == array.shape - assert transformed.dtype == array.dtype - ae(transformed, expected) + assert transformed.dtype == np.float32 + assert np.all(transformed == expected) + + +#------------------------------------------------------------------------------ +# Test transform +#------------------------------------------------------------------------------ + +def test_types(): + t = Translate(1, 2) + _check(t, [], []) + + for ab in [[3, 4], [3., 4.]]: + for arr in [ab, [ab], np.array(ab), np.array([ab]), + np.array([ab, ab, ab])]: + _check(t, arr, [[4, 6]]) -def test_translate(array): +def test_translate(): t = Translate(1, 2) - _check(t, array, [[4, 6]]) + _check(t, [3, 4], [[4, 6]]) + + +def test_scale(): + t = Scale(-1, 2) + _check(t, [3, 4], [[-3, 8]]) + + +def test_range(): + t = Range(0, 1, 2, 3) + + # One element => move to the center of the window. + _check(t, [-1, -1], [[1, 2]]) + _check(t, [3, 4], [[1, 2]]) + _check(t, [0, 1], [[1, 2]]) + + # Extend the range symmetrically. + _check(t, [[-1, 0], [3, 4]], [[0, 1], [2, 3]]) + + +def test_clip(): + t = Clip(0, 1, 2, 3) + + _check(t, [-1, -1], [[0, 1]]) + _check(t, [3, 4], [[2, 3]]) + + _check(t, [[-1, 0], [3, 4]], [[0, 1], [2, 3]]) diff --git a/phy/plot/transform.py b/phy/plot/transform.py index 705668dc3..1794cc84c 100644 --- a/phy/plot/transform.py +++ b/phy/plot/transform.py @@ -20,10 +20,14 @@ def _wrap_apply(f): def wrapped(arr): + if arr is None or not len(arr): + return arr arr = np.atleast_2d(arr) + arr = arr.astype(np.float32) assert arr.ndim == 2 assert arr.shape[1] == 2 out = f(arr) + out = out.astype(np.float32) assert out.shape == arr.shape return out return wrapped @@ -58,8 +62,13 @@ def apply(self, arr): class Range(BaseTransform): def __init__(self, xmin, ymin, xmax, ymax, mode='hard'): BaseTransform.__init__(self) - self.xmin, ymin = xmin, ymin - self.xmax, ymax = xmax, ymax + self.xmin, self.ymin = xmin, ymin + self.xmax, self.ymax = xmax, ymax + + self.xymin = np.array([[self.xmin, self.ymin]]) + self.xymax = np.array([[self.xmax, self.ymax]]) + self.xymax_minus_xymin = self.xymax - self.xymin + self.mode = mode def apply(self, arr): @@ -67,10 +76,14 @@ def apply(self, arr): xym = arr.min(axis=0) xyM = arr.max(axis=0) - xymin = np.array([[self.xmin, self.ymin]]) - xymax = np.array([[self.xmax, self.ymax]]) + # Handle min=max degenerate cases. + for i in range(arr.shape[1]): + if np.allclose(xym[i], xyM[i]): + arr[:, i] += .5 + xyM[i] += 1 - xymin + (xymax - xymin) * (arr - xym) / (xyM - xym) + return self.xymin + self.xymax_minus_xymin * \ + (arr - xym) / (xyM - xym) raise NotImplementedError() @@ -78,11 +91,13 @@ def apply(self, arr): class Clip(BaseTransform): def __init__(self, xmin, ymin, xmax, ymax): BaseTransform.__init__(self) - self.xmin, ymin = xmin, ymin - self.xmax, ymax = xmax, ymax + self.xmin, self.ymin = xmin, ymin + self.xmax, self.ymax = xmax, ymax + self.xymin = np.array([self.xmin, self.ymin]) + self.xymax = np.array([self.xmax, self.ymax]) def apply(self, arr): - pass + return np.clip(arr, self.xymin, self.xymax) class GPU(BaseTransform): From 49b6a5f99de58088e7affa489e710c3a2b980f2e Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 20 Oct 2015 08:06:12 +0200 Subject: [PATCH 0374/1059] Lint --- phy/plot/tests/test_transform.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/phy/plot/tests/test_transform.py b/phy/plot/tests/test_transform.py index f5fbf1546..b11bd1353 100644 --- a/phy/plot/tests/test_transform.py +++ b/phy/plot/tests/test_transform.py @@ -7,12 +7,9 @@ # Imports #------------------------------------------------------------------------------ -from itertools import product - import numpy as np -from numpy.testing import assert_array_equal as ae -from ..transform import Translate, Scale, Range, Clip, GPU +from ..transform import Translate, Scale, Range, Clip #------------------------------------------------------------------------------ From 50bf7d86b56c7b0da270293ab7fc12061d2ce57e Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 20 Oct 2015 08:21:23 +0200 Subject: [PATCH 0375/1059] WIP: transforms --- phy/plot/tests/test_transform.py | 10 +++--- phy/plot/transform.py | 56 ++++++++++++++++++++++---------- 2 files changed, 44 insertions(+), 22 deletions(-) diff --git a/phy/plot/tests/test_transform.py b/phy/plot/tests/test_transform.py index b11bd1353..c1ffb1bc8 100644 --- a/phy/plot/tests/test_transform.py +++ b/phy/plot/tests/test_transform.py @@ -33,7 +33,7 @@ def _check(transform, array, expected): #------------------------------------------------------------------------------ def test_types(): - t = Translate(1, 2) + t = Translate([1, 2]) _check(t, [], []) for ab in [[3, 4], [3., 4.]]: @@ -43,17 +43,17 @@ def test_types(): def test_translate(): - t = Translate(1, 2) + t = Translate([1, 2]) _check(t, [3, 4], [[4, 6]]) def test_scale(): - t = Scale(-1, 2) + t = Scale([-1, 2]) _check(t, [3, 4], [[-3, 8]]) def test_range(): - t = Range(0, 1, 2, 3) + t = Range([0, 1], [2, 3]) # One element => move to the center of the window. _check(t, [-1, -1], [[1, 2]]) @@ -65,7 +65,7 @@ def test_range(): def test_clip(): - t = Clip(0, 1, 2, 3) + t = Clip([0, 1], [2, 3]) _check(t, [-1, -1], [[0, 1]]) _check(t, [3, 4], [[2, 3]]) diff --git a/phy/plot/transform.py b/phy/plot/transform.py index 1794cc84c..f733394d5 100644 --- a/phy/plot/transform.py +++ b/phy/plot/transform.py @@ -7,7 +7,10 @@ # Imports #------------------------------------------------------------------------------ +from textwrap import dedent + import numpy as np +from six import string_types import logging @@ -42,32 +45,38 @@ def apply(self, arr): class Translate(BaseTransform): - def __init__(self, tx, ty): + def __init__(self, txy): BaseTransform.__init__(self) - self.tx, self.ty = tx, ty + self.txy = np.asarray(txy) def apply(self, arr): - return arr + np.array([[self.tx, self.ty]]) + return arr + self.txy + + def glsl(self, var): + return """{} + {}""".format(var, self.txy) class Scale(BaseTransform): - def __init__(self, sx, sy): + def __init__(self, sxy): BaseTransform.__init__(self) - self.sx, self.sy = sx, sy + self.sxy = np.asarray(sxy) def apply(self, arr): - return arr * np.array([[self.sx, self.sy]]) + return arr * self.sxy + + def glsl(self, var): + return """{} * {}""".format(var, self.sxy) class Range(BaseTransform): - def __init__(self, xmin, ymin, xmax, ymax, mode='hard'): + def __init__(self, xymin, xymax, mode='hard'): BaseTransform.__init__(self) - self.xmin, self.ymin = xmin, ymin - self.xmax, self.ymax = xmax, ymax + self.xymin = np.asarray(xymin) + self.xymax = np.asarray(xymax) - self.xymin = np.array([[self.xmin, self.ymin]]) - self.xymax = np.array([[self.xmax, self.ymax]]) - self.xymax_minus_xymin = self.xymax - self.xymin + # Only if the variables are numbers, not strings. + if not isinstance(xymin, string_types): + self.xymax_minus_xymin = self.xymax - self.xymin self.mode = mode @@ -87,18 +96,31 @@ def apply(self, arr): raise NotImplementedError() + def glsl(self, var): + return TODO + class Clip(BaseTransform): - def __init__(self, xmin, ymin, xmax, ymax): + def __init__(self, xymin, xymax): BaseTransform.__init__(self) - self.xmin, self.ymin = xmin, ymin - self.xmax, self.ymax = xmax, ymax - self.xymin = np.array([self.xmin, self.ymin]) - self.xymax = np.array([self.xmax, self.ymax]) + self.xymin = np.asarray(xymin) + self.xymax = np.asarray(xymax) def apply(self, arr): return np.clip(arr, self.xymin, self.xymax) + def glsl(self, var): + return dedent(""" + if (({var}.x < {xymin}.x) | + ({var}.y < {xymin}.y) | + ({var}.x > {xymax}.x) | + ({var}.y > {xymax}.y)) { + discard; + } + """).format(xymin=self.xymin, + xymax=self.xymax, + ) + class GPU(BaseTransform): pass From 01d23119c1b460e996a7f4b9e0226703af74e2c5 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 20 Oct 2015 12:15:56 +0200 Subject: [PATCH 0376/1059] WIP: transforms --- phy/plot/tests/test_transform.py | 15 ++++---- phy/plot/transform.py | 61 ++++++++++++-------------------- 2 files changed, 29 insertions(+), 47 deletions(-) diff --git a/phy/plot/tests/test_transform.py b/phy/plot/tests/test_transform.py index c1ffb1bc8..2d0794521 100644 --- a/phy/plot/tests/test_transform.py +++ b/phy/plot/tests/test_transform.py @@ -53,19 +53,18 @@ def test_scale(): def test_range(): - t = Range([0, 1], [2, 3]) + t = Range([0, 0, 1, 1], [-1, -1, 1, 1]) - # One element => move to the center of the window. - _check(t, [-1, -1], [[1, 2]]) - _check(t, [3, 4], [[1, 2]]) - _check(t, [0, 1], [[1, 2]]) + _check(t, [-1, -1], [[-3, -3]]) + _check(t, [0, 0], [[-1, -1]]) + _check(t, [0.5, 0.5], [[0, 0]]) + _check(t, [1, 1], [[1, 1]]) - # Extend the range symmetrically. - _check(t, [[-1, 0], [3, 4]], [[0, 1], [2, 3]]) + _check(t, [[0, .5], [1.5, -.5]], [[-1, 0], [2, -2]]) def test_clip(): - t = Clip([0, 1], [2, 3]) + t = Clip([0, 1, 2, 3]) _check(t, [-1, -1], [[0, 1]]) _check(t, [3, 4], [[2, 3]]) diff --git a/phy/plot/transform.py b/phy/plot/transform.py index f733394d5..a82357236 100644 --- a/phy/plot/transform.py +++ b/phy/plot/transform.py @@ -10,7 +10,6 @@ from textwrap import dedent import numpy as np -from six import string_types import logging @@ -31,7 +30,8 @@ def wrapped(arr): assert arr.shape[1] == 2 out = f(arr) out = out.astype(np.float32) - assert out.shape == arr.shape + assert out.ndim == 2 + assert out.shape[0] == arr.shape[0] return out return wrapped @@ -69,58 +69,41 @@ def glsl(self, var): class Range(BaseTransform): - def __init__(self, xymin, xymax, mode='hard'): + def __init__(self, from_range, to_range): BaseTransform.__init__(self) - self.xymin = np.asarray(xymin) - self.xymax = np.asarray(xymax) - # Only if the variables are numbers, not strings. - if not isinstance(xymin, string_types): - self.xymax_minus_xymin = self.xymax - self.xymin - - self.mode = mode + self.f0 = np.array(from_range[:2]) + self.f1 = np.array(from_range[2:]) + self.t0 = np.array(to_range[:2]) + self.t1 = np.array(to_range[2:]) def apply(self, arr): - if self.mode == 'hard': - xym = arr.min(axis=0) - xyM = arr.max(axis=0) - - # Handle min=max degenerate cases. - for i in range(arr.shape[1]): - if np.allclose(xym[i], xyM[i]): - arr[:, i] += .5 - xyM[i] += 1 - - return self.xymin + self.xymax_minus_xymin * \ - (arr - xym) / (xyM - xym) - - raise NotImplementedError() + f0, f1, t0, t1 = self.f0, self.f1, self.t0, self.t1 + return t0 + (t1 - t0) * (arr - f0) / (f1 - f0) def glsl(self, var): - return TODO + return dedent(""" + {t0} + ({t1} - {t0}) * ({var} - {f0}) / ({f1} - {f0}) + """).format(f0=self.f0, f1=self.f1, t0=self.t0, t1=self.t1).strip() class Clip(BaseTransform): - def __init__(self, xymin, xymax): + def __init__(self, bounds): BaseTransform.__init__(self) - self.xymin = np.asarray(xymin) - self.xymax = np.asarray(xymax) + self.xymin = np.asarray(bounds[0:2]) + self.xymax = np.asarray(bounds[2:]) def apply(self, arr): return np.clip(arr, self.xymin, self.xymax) def glsl(self, var): return dedent(""" - if (({var}.x < {xymin}.x) | - ({var}.y < {xymin}.y) | - ({var}.x > {xymax}.x) | - ({var}.y > {xymax}.y)) { - discard; - } + if (({var}.x < {xymin}.x) | + ({var}.y < {xymin}.y) | + ({var}.x > {xymax}.x) | + ({var}.y > {xymax}.y)) { + discard; + } """).format(xymin=self.xymin, xymax=self.xymax, - ) - - -class GPU(BaseTransform): - pass + ).strip() From facc792aa143e36aad03eade169354dde6f01d7e Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 20 Oct 2015 12:35:41 +0200 Subject: [PATCH 0377/1059] WIP: test GLSL transform code generation --- phy/plot/tests/test_transform.py | 42 ++++++++++++++++++++++++--- phy/plot/transform.py | 49 +++++++++++++++++++++++--------- 2 files changed, 74 insertions(+), 17 deletions(-) diff --git a/phy/plot/tests/test_transform.py b/phy/plot/tests/test_transform.py index 2d0794521..4fec2a6c1 100644 --- a/phy/plot/tests/test_transform.py +++ b/phy/plot/tests/test_transform.py @@ -7,6 +7,8 @@ # Imports #------------------------------------------------------------------------------ +from textwrap import dedent + import numpy as np from ..transform import Translate, Scale, Range, Clip @@ -42,17 +44,17 @@ def test_types(): _check(t, arr, [[4, 6]]) -def test_translate(): +def test_translate_numpy(): t = Translate([1, 2]) _check(t, [3, 4], [[4, 6]]) -def test_scale(): +def test_scale_numpy(): t = Scale([-1, 2]) _check(t, [3, 4], [[-3, 8]]) -def test_range(): +def test_range_numpy(): t = Range([0, 0, 1, 1], [-1, -1, 1, 1]) _check(t, [-1, -1], [[-3, -3]]) @@ -63,10 +65,42 @@ def test_range(): _check(t, [[0, .5], [1.5, -.5]], [[-1, 0], [2, -2]]) -def test_clip(): +def test_clip_numpy(): t = Clip([0, 1, 2, 3]) _check(t, [-1, -1], [[0, 1]]) _check(t, [3, 4], [[2, 3]]) _check(t, [[-1, 0], [3, 4]], [[0, 1], [2, 3]]) + + +#------------------------------------------------------------------------------ +# Test GLSL transforms +#------------------------------------------------------------------------------ + +def test_translate_glsl(): + t = Translate('u_translate') + assert t.glsl('x') == 'x + u_translate' + + +def test_scale_glsl(): + t = Scale('u_scale') + assert t.glsl('x') == 'x * u_scale' + + +def test_range_glsl(): + t = Range(['u_from.xy', 'u_from.zw'], ['u_to.xy', 'u_to.zw']) + assert t.glsl('x') == ('u_to.xy + (u_to.zw - u_to.xy) * (x - u_from.xy) / ' + '(u_from.zw - u_from.xy)') + + +def test_clip_glsl(): + t = Clip(['xymin', 'xymax']) + assert t.glsl('x') == dedent(""" + if ((x.x < xymin.x) | + (x.y < xymin.y) | + (x.x > xymax.x) | + (x.y > xymax.y)) { + discard; + } + """).strip() diff --git a/phy/plot/transform.py b/phy/plot/transform.py index a82357236..b7c145ad4 100644 --- a/phy/plot/transform.py +++ b/phy/plot/transform.py @@ -36,9 +36,18 @@ def wrapped(arr): return wrapped +def _wrap_glsl(f): + def wrapped(var): + out = f(var) + out = dedent(out).strip() + return out + return wrapped + + class BaseTransform(object): def __init__(self): self.apply = _wrap_apply(self.apply) + self.glsl = _wrap_glsl(self.glsl) def apply(self, arr): raise NotImplementedError() @@ -72,38 +81,52 @@ class Range(BaseTransform): def __init__(self, from_range, to_range): BaseTransform.__init__(self) - self.f0 = np.array(from_range[:2]) - self.f1 = np.array(from_range[2:]) - self.t0 = np.array(to_range[:2]) - self.t1 = np.array(to_range[2:]) + self.from_range = from_range + self.to_range = to_range + + self.f0 = np.asarray(from_range[:2]) + self.f1 = np.asarray(from_range[2:]) + self.t0 = np.asarray(to_range[:2]) + self.t1 = np.asarray(to_range[2:]) def apply(self, arr): f0, f1, t0, t1 = self.f0, self.f1, self.t0, self.t1 return t0 + (t1 - t0) * (arr - f0) / (f1 - f0) def glsl(self, var): - return dedent(""" + return """ {t0} + ({t1} - {t0}) * ({var} - {f0}) / ({f1} - {f0}) - """).format(f0=self.f0, f1=self.f1, t0=self.t0, t1=self.t1).strip() + """.format(var=var, + f0=self.from_range[0], f1=self.from_range[1], + t0=self.to_range[0], t1=self.to_range[1], + ) class Clip(BaseTransform): def __init__(self, bounds): BaseTransform.__init__(self) - self.xymin = np.asarray(bounds[0:2]) + self.bounds = bounds + + self.xymin = np.asarray(bounds[:2]) self.xymax = np.asarray(bounds[2:]) def apply(self, arr): return np.clip(arr, self.xymin, self.xymax) def glsl(self, var): - return dedent(""" + return """ if (({var}.x < {xymin}.x) | ({var}.y < {xymin}.y) | ({var}.x > {xymax}.x) | - ({var}.y > {xymax}.y)) { + ({var}.y > {xymax}.y)) {{ discard; - } - """).format(xymin=self.xymin, - xymax=self.xymax, - ).strip() + }} + """.format(xymin=self.bounds[0], + xymax=self.bounds[1], + var=var, + ) + + +class Subplot(Range): + # TODO + pass From 72d5730964267a9a3af1b669f4fd1adcb36e2fcc Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 20 Oct 2015 13:05:18 +0200 Subject: [PATCH 0378/1059] Remove no cover pragma --- phy/gui/tests/test_qt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/phy/gui/tests/test_qt.py b/phy/gui/tests/test_qt.py index 673bad59b..5ebef620f 100644 --- a/phy/gui/tests/test_qt.py +++ b/phy/gui/tests/test_qt.py @@ -26,7 +26,7 @@ def test_require_qt_with_app(): @require_qt def f(): - pass # pragma: no cover + pass if not QApplication.instance(): with raises(RuntimeError): # pragma: no cover From 3232a5fe4996947ee0cd2a05675d2d39ed2f20bb Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 20 Oct 2015 13:58:51 +0200 Subject: [PATCH 0379/1059] WIP: refactoring transforms --- phy/plot/tests/test_transform.py | 65 ++++++++-------- phy/plot/transform.py | 127 ++++++++++++++++++------------- 2 files changed, 110 insertions(+), 82 deletions(-) diff --git a/phy/plot/tests/test_transform.py b/phy/plot/tests/test_transform.py index 4fec2a6c1..524438191 100644 --- a/phy/plot/tests/test_transform.py +++ b/phy/plot/tests/test_transform.py @@ -18,16 +18,19 @@ # Fixtures #------------------------------------------------------------------------------ -def _check(transform, array, expected): - transformed = transform.apply(array) +def _check(transform, array, expected, **kwargs): + transformed = transform.apply(array, **kwargs) if array is None or not len(array): assert transformed == array return array = np.atleast_2d(array) if isinstance(array, np.ndarray): - assert transformed.shape == array.shape + assert transformed.shape[1] == array.shape[1] assert transformed.dtype == np.float32 - assert np.all(transformed == expected) + if not len(transformed): + assert not len(expected) + else: + assert np.all(transformed == expected) #------------------------------------------------------------------------------ @@ -35,43 +38,44 @@ def _check(transform, array, expected): #------------------------------------------------------------------------------ def test_types(): - t = Translate([1, 2]) - _check(t, [], []) + t = Translate() + _check(t, [], [], translate=[1, 2]) for ab in [[3, 4], [3., 4.]]: for arr in [ab, [ab], np.array(ab), np.array([ab]), np.array([ab, ab, ab])]: - _check(t, arr, [[4, 6]]) + _check(t, arr, [[4, 6]], translate=[1, 2]) def test_translate_numpy(): - t = Translate([1, 2]) - _check(t, [3, 4], [[4, 6]]) + _check(Translate(), [3, 4], [[4, 6]], translate=[1, 2]) def test_scale_numpy(): - t = Scale([-1, 2]) - _check(t, [3, 4], [[-3, 8]]) + _check(Scale(), [3, 4], [[-3, 8]], scale=[-1, 2]) def test_range_numpy(): - t = Range([0, 0, 1, 1], [-1, -1, 1, 1]) + kwargs = dict(from_range=[0, 0, 1, 1], to_range=[-1, -1, 1, 1]) - _check(t, [-1, -1], [[-3, -3]]) - _check(t, [0, 0], [[-1, -1]]) - _check(t, [0.5, 0.5], [[0, 0]]) - _check(t, [1, 1], [[1, 1]]) + _check(Range(), [-1, -1], [[-3, -3]], **kwargs) + _check(Range(), [0, 0], [[-1, -1]], **kwargs) + _check(Range(), [0.5, 0.5], [[0, 0]], **kwargs) + _check(Range(), [1, 1], [[1, 1]], **kwargs) - _check(t, [[0, .5], [1.5, -.5]], [[-1, 0], [2, -2]]) + _check(Range(), [[0, .5], [1.5, -.5]], [[-1, 0], [2, -2]], **kwargs) def test_clip_numpy(): - t = Clip([0, 1, 2, 3]) + kwargs = dict(bounds=[0, 1, 2, 3]) - _check(t, [-1, -1], [[0, 1]]) - _check(t, [3, 4], [[2, 3]]) + _check(Clip(), [0, 1], [0, 1], **kwargs) + _check(Clip(), [1, 2], [1, 2], **kwargs) + _check(Clip(), [2, 3], [2, 3], **kwargs) - _check(t, [[-1, 0], [3, 4]], [[0, 1], [2, 3]]) + _check(Clip(), [-1, -1], [], **kwargs) + _check(Clip(), [3, 4], [], **kwargs) + _check(Clip(), [[-1, 0], [3, 4]], [], **kwargs) #------------------------------------------------------------------------------ @@ -79,24 +83,24 @@ def test_clip_numpy(): #------------------------------------------------------------------------------ def test_translate_glsl(): - t = Translate('u_translate') - assert t.glsl('x') == 'x + u_translate' + assert 'x = x + u_translate' in Translate().glsl('x', + translate='u_translate') def test_scale_glsl(): - t = Scale('u_scale') - assert t.glsl('x') == 'x * u_scale' + assert 'x = x * u_scale' in Scale().glsl('x', scale='u_scale') def test_range_glsl(): - t = Range(['u_from.xy', 'u_from.zw'], ['u_to.xy', 'u_to.zw']) - assert t.glsl('x') == ('u_to.xy + (u_to.zw - u_to.xy) * (x - u_from.xy) / ' - '(u_from.zw - u_from.xy)') + expected = ('u_to.xy + (u_to.zw - u_to.xy) * (x - u_from.xy) / ' + '(u_from.zw - u_from.xy)') + assert expected in Range().glsl('x', + from_range=['u_from.xy', 'u_from.zw'], + to_range=['u_to.xy', 'u_to.zw']) def test_clip_glsl(): - t = Clip(['xymin', 'xymax']) - assert t.glsl('x') == dedent(""" + expected = dedent(""" if ((x.x < xymin.x) | (x.y < xymin.y) | (x.x > xymax.x) | @@ -104,3 +108,4 @@ def test_clip_glsl(): discard; } """).strip() + assert expected in Clip().glsl('x', bounds=['xymin', 'xymax']) diff --git a/phy/plot/transform.py b/phy/plot/transform.py index b7c145ad4..5af99e643 100644 --- a/phy/plot/transform.py +++ b/phy/plot/transform.py @@ -21,24 +21,24 @@ #------------------------------------------------------------------------------ def _wrap_apply(f): - def wrapped(arr): + def wrapped(arr, **kwargs): if arr is None or not len(arr): return arr arr = np.atleast_2d(arr) arr = arr.astype(np.float32) assert arr.ndim == 2 - assert arr.shape[1] == 2 - out = f(arr) + out = f(arr, **kwargs) out = out.astype(np.float32) + out = np.atleast_2d(out) assert out.ndim == 2 - assert out.shape[0] == arr.shape[0] + assert out.shape[1] == arr.shape[1] return out return wrapped def _wrap_glsl(f): - def wrapped(var): - out = f(var) + def wrapped(var, **kwargs): + out = f(var, **kwargs) out = dedent(out).strip() return out return wrapped @@ -52,68 +52,57 @@ def __init__(self): def apply(self, arr): raise NotImplementedError() + def glsl(self, var): + raise NotImplementedError() -class Translate(BaseTransform): - def __init__(self, txy): - BaseTransform.__init__(self) - self.txy = np.asarray(txy) - def apply(self, arr): - return arr + self.txy +class Translate(BaseTransform): + def apply(self, arr, translate=None): + return arr + np.asarray(translate) - def glsl(self, var): - return """{} + {}""".format(var, self.txy) + def glsl(self, var, translate=None): + return """{var} = {var} + {translate};""".format(var=var, + translate=translate) class Scale(BaseTransform): - def __init__(self, sxy): - BaseTransform.__init__(self) - self.sxy = np.asarray(sxy) - - def apply(self, arr): - return arr * self.sxy + def apply(self, arr, scale=None): + return arr * np.asarray(scale) - def glsl(self, var): - return """{} * {}""".format(var, self.sxy) + def glsl(self, var, scale=None): + return """{var} = {var} * {scale};""".format(var=var, scale=scale) class Range(BaseTransform): - def __init__(self, from_range, to_range): - BaseTransform.__init__(self) + def apply(self, arr, from_range=None, to_range=None): - self.from_range = from_range - self.to_range = to_range + f0 = np.asarray(from_range[:2]) + f1 = np.asarray(from_range[2:]) + t0 = np.asarray(to_range[:2]) + t1 = np.asarray(to_range[2:]) - self.f0 = np.asarray(from_range[:2]) - self.f1 = np.asarray(from_range[2:]) - self.t0 = np.asarray(to_range[:2]) - self.t1 = np.asarray(to_range[2:]) - - def apply(self, arr): - f0, f1, t0, t1 = self.f0, self.f1, self.t0, self.t1 return t0 + (t1 - t0) * (arr - f0) / (f1 - f0) - def glsl(self, var): + def glsl(self, var, from_range=None, to_range=None): return """ - {t0} + ({t1} - {t0}) * ({var} - {f0}) / ({f1} - {f0}) + {var} = {t0} + ({t1} - {t0}) * ({var} - {f0}) / ({f1} - {f0}); """.format(var=var, - f0=self.from_range[0], f1=self.from_range[1], - t0=self.to_range[0], t1=self.to_range[1], + f0=from_range[0], f1=from_range[1], + t0=to_range[0], t1=to_range[1], ) class Clip(BaseTransform): - def __init__(self, bounds): - BaseTransform.__init__(self) - self.bounds = bounds - - self.xymin = np.asarray(bounds[:2]) - self.xymax = np.asarray(bounds[2:]) - - def apply(self, arr): - return np.clip(arr, self.xymin, self.xymax) - - def glsl(self, var): + def apply(self, arr, bounds=None): + xymin = np.asarray(bounds[:2]) + xymax = np.asarray(bounds[2:]) + index = ((arr[:, 0] >= xymin[0]) & + (arr[:, 1] >= xymin[1]) & + (arr[:, 0] <= xymax[0]) & + (arr[:, 1] <= xymax[1])) + return arr[index, ...] + + def glsl(self, var, bounds=None): return """ if (({var}.x < {xymin}.x) | ({var}.y < {xymin}.y) | @@ -121,12 +110,46 @@ def glsl(self, var): ({var}.y > {xymax}.y)) {{ discard; }} - """.format(xymin=self.bounds[0], - xymax=self.bounds[1], + """.format(xymin=bounds[0], + xymax=bounds[1], var=var, ) class Subplot(Range): - # TODO - pass + def apply(self, arr, shape=None, index=None): + i, j = index + n_rows, n_cols = shape + + i += 0.5 + j += 0.5 + + x = -1.0 + j * (2.0 / n_cols) + y = +1.0 - i * (2.0 / n_rows) + + width = 1.0 / (1.0 * n_cols) + height = 1.0 / (1.0 * n_rows) + + from_range = [-1, -1, 1, 1] + to_range = [x, y, x + width, y + height] + + return super(Subplot, self).apply(from_range, to_range) + + def glsl(self, var, shape=None, index=None): + n_rows, n_cols = shape + + width = 1.0 / (1.0 * n_cols) + height = 1.0 / (1.0 * n_rows) + + glsl = """ + float x = -1.0 + ({index}.y + .5) * (2.0 / {shape}.y); + float y = +1.0 - ({index}.x + .5) * (2.0 / {shape}.x); + + float width = 1. / (1.0 * n_rows); + float height = 1. / (1.0 * n_rows); + + {var} = vec2(x + {width} * {var}.x, + y + {height} * {var}.y); + """ + return glsl.format(index=index, shape=shape, var=var, + width=width, height=height) From 7d8b8674ae3ce82e1020dbc7208436b8e8e52984 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 20 Oct 2015 14:24:02 +0200 Subject: [PATCH 0380/1059] WIP: subplot transform --- phy/plot/tests/test_transform.py | 22 ++++++++++++++++-- phy/plot/transform.py | 38 ++++++++++++++++---------------- 2 files changed, 39 insertions(+), 21 deletions(-) diff --git a/phy/plot/tests/test_transform.py b/phy/plot/tests/test_transform.py index 524438191..89702b50e 100644 --- a/phy/plot/tests/test_transform.py +++ b/phy/plot/tests/test_transform.py @@ -11,7 +11,7 @@ import numpy as np -from ..transform import Translate, Scale, Range, Clip +from ..transform import Translate, Scale, Range, Clip, Subplot #------------------------------------------------------------------------------ @@ -30,7 +30,7 @@ def _check(transform, array, expected, **kwargs): if not len(transformed): assert not len(expected) else: - assert np.all(transformed == expected) + assert np.allclose(transformed, expected) #------------------------------------------------------------------------------ @@ -78,6 +78,18 @@ def test_clip_numpy(): _check(Clip(), [[-1, 0], [3, 4]], [], **kwargs) +def test_subplot_numpy(): + shape = (2, 3) + + _check(Subplot(), [-1, -1], [-1, +0], index=(0, 0), shape=shape) + _check(Subplot(), [+0, +0], [-2. / 3., .5], index=(0, 0), shape=shape) + + _check(Subplot(), [-1, -1], [-1, -1], index=(1, 0), shape=shape) + _check(Subplot(), [+1, +1], [-1. / 3, 0], index=(1, 0), shape=shape) + + _check(Subplot(), [0, 1], [0, 0], index=(1, 1), shape=shape) + + #------------------------------------------------------------------------------ # Test GLSL transforms #------------------------------------------------------------------------------ @@ -109,3 +121,9 @@ def test_clip_glsl(): } """).strip() assert expected in Clip().glsl('x', bounds=['xymin', 'xymax']) + + +def test_subplot_glsl(): + glsl = Subplot().glsl('x', shape='u_shape', index='a_index') + print(glsl) + assert 'x = ' in glsl diff --git a/phy/plot/transform.py b/phy/plot/transform.py index 5af99e643..bba790d58 100644 --- a/phy/plot/transform.py +++ b/phy/plot/transform.py @@ -117,39 +117,39 @@ def glsl(self, var, bounds=None): class Subplot(Range): + """Assume that the from range is [-1, -1, 1, 1].""" def apply(self, arr, shape=None, index=None): i, j = index n_rows, n_cols = shape - - i += 0.5 - j += 0.5 + assert 0 <= i <= n_rows - 1 + assert 0 <= j <= n_cols - 1 x = -1.0 + j * (2.0 / n_cols) y = +1.0 - i * (2.0 / n_rows) - width = 1.0 / (1.0 * n_cols) - height = 1.0 / (1.0 * n_rows) + width = 2.0 / n_cols + height = 2.0 / n_rows + + # The origin (x, y) corresponds to the lower-left corner of the + # target box. + y -= height from_range = [-1, -1, 1, 1] to_range = [x, y, x + width, y + height] - return super(Subplot, self).apply(from_range, to_range) + return super(Subplot, self).apply(arr, + from_range=from_range, + to_range=to_range) def glsl(self, var, shape=None, index=None): - n_rows, n_cols = shape - - width = 1.0 / (1.0 * n_cols) - height = 1.0 / (1.0 * n_rows) - glsl = """ - float x = -1.0 + ({index}.y + .5) * (2.0 / {shape}.y); - float y = +1.0 - ({index}.x + .5) * (2.0 / {shape}.x); + float x = -1.0 + {index}.y * 2.0 / {shape}.y; + float y = +1.0 - {index}.x * 2.0 / {shape}.x; - float width = 1. / (1.0 * n_rows); - float height = 1. / (1.0 * n_rows); + float width = 2. / {shape}.y; + float height = 2. / {shape}.x; - {var} = vec2(x + {width} * {var}.x, - y + {height} * {var}.y); + {var} = vec2(x + width * {var}.x, + y + height * {var}.y); """ - return glsl.format(index=index, shape=shape, var=var, - width=width, height=height) + return glsl.format(index=index, shape=shape, var=var) From 40b7ff206dd059bd4ccdf205491700785a21de09 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 20 Oct 2015 14:39:24 +0200 Subject: [PATCH 0381/1059] WIP --- phy/plot/tests/test_transform.py | 7 +++---- phy/plot/transform.py | 15 ++++++++++----- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/phy/plot/tests/test_transform.py b/phy/plot/tests/test_transform.py index 89702b50e..da1bcd74a 100644 --- a/phy/plot/tests/test_transform.py +++ b/phy/plot/tests/test_transform.py @@ -48,7 +48,7 @@ def test_types(): def test_translate_numpy(): - _check(Translate(), [3, 4], [[4, 6]], translate=[1, 2]) + _check(Translate(translate=[1, 2]), [3, 4], [[4, 6]]) def test_scale_numpy(): @@ -95,8 +95,8 @@ def test_subplot_numpy(): #------------------------------------------------------------------------------ def test_translate_glsl(): - assert 'x = x + u_translate' in Translate().glsl('x', - translate='u_translate') + t = Translate(translate='u_translate').glsl('x') + assert 'x = x + u_translate' in t def test_scale_glsl(): @@ -125,5 +125,4 @@ def test_clip_glsl(): def test_subplot_glsl(): glsl = Subplot().glsl('x', shape='u_shape', index='a_index') - print(glsl) assert 'x = ' in glsl diff --git a/phy/plot/transform.py b/phy/plot/transform.py index bba790d58..653c7c894 100644 --- a/phy/plot/transform.py +++ b/phy/plot/transform.py @@ -20,10 +20,12 @@ # Transforms #------------------------------------------------------------------------------ -def _wrap_apply(f): +def _wrap_apply(f, **kwargs_init): def wrapped(arr, **kwargs): if arr is None or not len(arr): return arr + # Method kwargs first, then we update with the constructor kwargs. + kwargs.update(kwargs_init) arr = np.atleast_2d(arr) arr = arr.astype(np.float32) assert arr.ndim == 2 @@ -36,8 +38,10 @@ def wrapped(arr, **kwargs): return wrapped -def _wrap_glsl(f): +def _wrap_glsl(f, **kwargs_init): def wrapped(var, **kwargs): + # Method kwargs first, then we update with the constructor kwargs. + kwargs.update(kwargs_init) out = f(var, **kwargs) out = dedent(out).strip() return out @@ -45,9 +49,10 @@ def wrapped(var, **kwargs): class BaseTransform(object): - def __init__(self): - self.apply = _wrap_apply(self.apply) - self.glsl = _wrap_glsl(self.glsl) + def __init__(self, **kwargs): + # Pass the constructor kwargs to the methods. + self.apply = _wrap_apply(self.apply, **kwargs) + self.glsl = _wrap_glsl(self.glsl, **kwargs) def apply(self, arr): raise NotImplementedError() From 2ce3155e109d24a064a6bb584d733098cbff2880 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 20 Oct 2015 15:51:36 +0200 Subject: [PATCH 0382/1059] Add TransformChain --- phy/plot/tests/test_transform.py | 86 ++++++++++++++++++++++++++++---- phy/plot/transform.py | 78 ++++++++++++++++++++++++++--- 2 files changed, 149 insertions(+), 15 deletions(-) diff --git a/phy/plot/tests/test_transform.py b/phy/plot/tests/test_transform.py index da1bcd74a..ea8626613 100644 --- a/phy/plot/tests/test_transform.py +++ b/phy/plot/tests/test_transform.py @@ -10,8 +10,12 @@ from textwrap import dedent import numpy as np +from numpy.testing import assert_equal as ae +from pytest import yield_fixture -from ..transform import Translate, Scale, Range, Clip, Subplot +from ..transform import (Translate, Scale, Range, Clip, Subplot, GPU, + TransformChain, + ) #------------------------------------------------------------------------------ @@ -47,15 +51,15 @@ def test_types(): _check(t, arr, [[4, 6]], translate=[1, 2]) -def test_translate_numpy(): +def test_translate_cpu(): _check(Translate(translate=[1, 2]), [3, 4], [[4, 6]]) -def test_scale_numpy(): +def test_scale_cpu(): _check(Scale(), [3, 4], [[-3, 8]], scale=[-1, 2]) -def test_range_numpy(): +def test_range_cpu(): kwargs = dict(from_range=[0, 0, 1, 1], to_range=[-1, -1, 1, 1]) _check(Range(), [-1, -1], [[-3, -3]], **kwargs) @@ -66,9 +70,11 @@ def test_range_numpy(): _check(Range(), [[0, .5], [1.5, -.5]], [[-1, 0], [2, -2]], **kwargs) -def test_clip_numpy(): +def test_clip_cpu(): kwargs = dict(bounds=[0, 1, 2, 3]) + _check(Clip(), [0, 0], [0, 0]) # Default bounds. + _check(Clip(), [0, 1], [0, 1], **kwargs) _check(Clip(), [1, 2], [1, 2], **kwargs) _check(Clip(), [2, 3], [2, 3], **kwargs) @@ -78,7 +84,7 @@ def test_clip_numpy(): _check(Clip(), [[-1, 0], [3, 4]], [], **kwargs) -def test_subplot_numpy(): +def test_subplot_cpu(): shape = (2, 3) _check(Subplot(), [-1, -1], [-1, +0], index=(0, 0), shape=shape) @@ -104,11 +110,13 @@ def test_scale_glsl(): def test_range_glsl(): + + assert Range(from_range=[-1, -1, 1, 1]).glsl('x') + expected = ('u_to.xy + (u_to.zw - u_to.xy) * (x - u_from.xy) / ' '(u_from.zw - u_from.xy)') - assert expected in Range().glsl('x', - from_range=['u_from.xy', 'u_from.zw'], - to_range=['u_to.xy', 'u_to.zw']) + r = Range(to_range=['u_to.xy', 'u_to.zw']) + assert expected in r.glsl('x', from_range=['u_from.xy', 'u_from.zw']) def test_clip_glsl(): @@ -126,3 +134,63 @@ def test_clip_glsl(): def test_subplot_glsl(): glsl = Subplot().glsl('x', shape='u_shape', index='a_index') assert 'x = ' in glsl + + +#------------------------------------------------------------------------------ +# Test transform chain +#------------------------------------------------------------------------------ + +@yield_fixture +def array(): + yield np.array([[-1, 0], [1, 2]]) + + +def test_transform_chain_empty(array): + t = TransformChain([]) + + assert t.cpu_transforms == [] + assert t.gpu_transforms == [] + assert t.get('GPU') is None + + ae(t.apply(array), array) + assert t.glsl('position') == '' + + +def test_transform_chain_one(array): + translate = Translate(translate=[1, 2]) + t = TransformChain([translate]) + + assert t.cpu_transforms == [translate] + assert t.gpu_transforms == [] + + ae(t.apply(array), [[0, 2], [2, 4]]) + + +def test_transform_chain_two(array): + translate = Translate(translate=[1, 2]) + scale = Scale(scale=[.5, .5]) + t = TransformChain([translate, scale]) + + assert t.cpu_transforms == [translate, scale] + assert t.gpu_transforms == [] + + assert isinstance(t.get('Translate'), Translate) + assert t.get('Unknown') is None + + ae(t.apply(array), [[0, 1], [1, 2]]) + + +def test_transform_chain_complete(array): + t = TransformChain([Scale(scale=.5), + Scale(scale=2.)]) + t.add([Range(from_range=[-3, -3, 1, 1]), + GPU(), + Clip(), + Subplot(shape='u_shape', index='a_box'), + ]) + + assert len(t.cpu_transforms) == 3 + assert len(t.gpu_transforms) == 2 + + ae(t.apply(array), [[0, .5], [1, 1.5]]) + assert 'position = ' in t.glsl('position') diff --git a/phy/plot/transform.py b/phy/plot/transform.py index 653c7c894..f2170a6f0 100644 --- a/phy/plot/transform.py +++ b/phy/plot/transform.py @@ -80,6 +80,8 @@ def glsl(self, var, scale=None): class Range(BaseTransform): def apply(self, arr, from_range=None, to_range=None): + if to_range is None: + to_range = [-1, -1, 1, 1] f0 = np.asarray(from_range[:2]) f1 = np.asarray(from_range[2:]) @@ -89,6 +91,9 @@ def apply(self, arr, from_range=None, to_range=None): return t0 + (t1 - t0) * (arr - f0) / (f1 - f0) def glsl(self, var, from_range=None, to_range=None): + if to_range is None: + to_range = [-1, -1, 1, 1] + return """ {var} = {t0} + ({t1} - {t0}) * ({var} - {f0}) / ({f1} - {f0}); """.format(var=var, @@ -99,6 +104,9 @@ def glsl(self, var, from_range=None, to_range=None): class Clip(BaseTransform): def apply(self, arr, bounds=None): + if bounds is None: + bounds = [-1, -1, 1, 1] + xymin = np.asarray(bounds[:2]) xymax = np.asarray(bounds[2:]) index = ((arr[:, 0] >= xymin[0]) & @@ -108,6 +116,9 @@ def apply(self, arr, bounds=None): return arr[index, ...] def glsl(self, var, bounds=None): + if bounds is None: + bounds = 'vec2(-1, -1)', 'vec2(1, 1)' + return """ if (({var}.x < {xymin}.x) | ({var}.y < {xymin}.y) | @@ -148,13 +159,68 @@ def apply(self, arr, shape=None, index=None): def glsl(self, var, shape=None, index=None): glsl = """ - float x = -1.0 + {index}.y * 2.0 / {shape}.y; - float y = +1.0 - {index}.x * 2.0 / {shape}.x; + float subplot_x = -1.0 + {index}.y * 2.0 / {shape}.y; + float subplot_y = +1.0 - {index}.x * 2.0 / {shape}.x; - float width = 2. / {shape}.y; - float height = 2. / {shape}.x; + float subplot_width = 2. / {shape}.y; + float subplot_height = 2. / {shape}.x; - {var} = vec2(x + width * {var}.x, - y + height * {var}.y); + {var} = vec2(subplot_x + subplot_width * {var}.x, + subplot_y + subplot_height * {var}.y); """ return glsl.format(index=index, shape=shape, var=var) + + +#------------------------------------------------------------------------------ +# Transform chains +#------------------------------------------------------------------------------ + +class GPU(object): + """Used to specify that the next transforms in the chain happen on + the GPU.""" + pass + + +class TransformChain(object): + """A linear sequence of transforms that happen on the CPU and GPU.""" + def __init__(self, transforms): + self.transforms = transforms + + def _index_of_gpu(self): + classes = [t.__class__.__name__ for t in self.transforms] + return classes.index('GPU') if 'GPU' in classes else None + + @property + def cpu_transforms(self): + """All transforms until `GPU()`.""" + i = self._index_of_gpu() + return self.transforms[:i] if i is not None else self.transforms + + @property + def gpu_transforms(self): + """All transforms after `GPU()`.""" + i = self._index_of_gpu() + return self.transforms[i + 1:] if i is not None else [] + + def add(self, transforms): + """Add some transforms.""" + self.transforms.extend(transforms) + + def get(self, class_name): + """Get a transform in the chain from its name.""" + for transform in self.transforms: + if transform.__class__.__name__ == class_name: + return transform + + def apply(self, arr): + """Apply all CPU transforms on an array.""" + for t in self.cpu_transforms: + arr = t.apply(arr) + return arr + + def glsl(self, var): + """Generate the GLSL code for the GPU transform chain.""" + glsl = "" + for t in self.gpu_transforms: + glsl += t.glsl(var) + '\n' + return glsl From 6cbc26fb6f114e3fdfe3ab035f16346fd7f585f5 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 20 Oct 2015 17:02:16 +0200 Subject: [PATCH 0383/1059] WIP: TransformChain.insert_glsl() --- phy/plot/tests/test_transform.py | 19 ++++++++- phy/plot/transform.py | 69 +++++++++++++++++++++++++++++--- 2 files changed, 80 insertions(+), 8 deletions(-) diff --git a/phy/plot/tests/test_transform.py b/phy/plot/tests/test_transform.py index ea8626613..64997c87f 100644 --- a/phy/plot/tests/test_transform.py +++ b/phy/plot/tests/test_transform.py @@ -153,7 +153,6 @@ def test_transform_chain_empty(array): assert t.get('GPU') is None ae(t.apply(array), array) - assert t.glsl('position') == '' def test_transform_chain_one(array): @@ -193,4 +192,20 @@ def test_transform_chain_complete(array): assert len(t.gpu_transforms) == 2 ae(t.apply(array), [[0, .5], [1, 1.5]]) - assert 'position = ' in t.glsl('position') + + vs = dedent(""" + attribute vec2 a_position; + void main() { + gl_Position = transform(a_position); + } + """).strip() + + fs = dedent(""" + void main() { + gl_FragColor = vec4(1., 1., 1., 1.); + } + """).strip() + vs, fs = t.insert_glsl(vs, fs) + assert 'a_box' in vs + assert 'v_a_position = a_position;' in vs + assert 'discard' in fs diff --git a/phy/plot/transform.py b/phy/plot/transform.py index f2170a6f0..a3ffa96e5 100644 --- a/phy/plot/transform.py +++ b/phy/plot/transform.py @@ -8,6 +8,7 @@ #------------------------------------------------------------------------------ from textwrap import dedent +import re import numpy as np @@ -17,7 +18,7 @@ #------------------------------------------------------------------------------ -# Transforms +# Utils #------------------------------------------------------------------------------ def _wrap_apply(f, **kwargs_init): @@ -48,6 +49,14 @@ def wrapped(var, **kwargs): return wrapped +def indent(text): + return '\n'.join(' ' + l.strip() for l in text.splitlines()) + + +#------------------------------------------------------------------------------ +# Transforms +#------------------------------------------------------------------------------ + class BaseTransform(object): def __init__(self, **kwargs): # Pass the constructor kwargs to the methods. @@ -66,6 +75,7 @@ def apply(self, arr, translate=None): return arr + np.asarray(translate) def glsl(self, var, translate=None): + assert var return """{var} = {var} + {translate};""".format(var=var, translate=translate) @@ -75,6 +85,7 @@ def apply(self, arr, scale=None): return arr * np.asarray(scale) def glsl(self, var, scale=None): + assert var return """{var} = {var} * {scale};""".format(var=var, scale=scale) @@ -91,6 +102,7 @@ def apply(self, arr, from_range=None, to_range=None): return t0 + (t1 - t0) * (arr - f0) / (f1 - f0) def glsl(self, var, from_range=None, to_range=None): + assert var if to_range is None: to_range = [-1, -1, 1, 1] @@ -116,6 +128,7 @@ def apply(self, arr, bounds=None): return arr[index, ...] def glsl(self, var, bounds=None): + assert var if bounds is None: bounds = 'vec2(-1, -1)', 'vec2(1, 1)' @@ -158,6 +171,7 @@ def apply(self, arr, shape=None, index=None): to_range=to_range) def glsl(self, var, shape=None, index=None): + assert var glsl = """ float subplot_x = -1.0 + {index}.y * 2.0 / {shape}.y; float subplot_y = +1.0 - {index}.x * 2.0 / {shape}.x; @@ -218,9 +232,52 @@ def apply(self, arr): arr = t.apply(arr) return arr - def glsl(self, var): - """Generate the GLSL code for the GPU transform chain.""" - glsl = "" + def insert_glsl(self, vertex, fragment): + """Generate the GLSL code of the transform chain.""" + + # Find the place where to insert the GLSL snippet. + # This is "gl_Position = transform(data_var_name);" where + # data_var_name is typically an attribute. + vs_regex = re.compile(r'gl_Position = transform\(([\S]+)\);') + r = vs_regex.search(vertex) + assert r, ("The vertex shader must contain the transform placeholder.") + logger.debug("Found transform placeholder in vertex code: `%s`", + r.group(0)) + + # Find the GLSL variable with the data (should be a `vec2`). + var = r.group(1) + assert var and var in vertex + + # Generate the snippet to insert in the shaders. + vs_insert = "" for t in self.gpu_transforms: - glsl += t.glsl(var) + '\n' - return glsl + if isinstance(t, Clip): + continue + vs_insert += t.glsl(var) + '\n' + vs_insert += 'gl_Position = {};\n'.format(var) + + # Clipping. + clip = self.get('Clip') + if clip: + # Varying name. + fvar = 'v_{}'.format(var) + glsl_clip = clip.glsl(fvar) + + # Prepare the fragment regex. + fs_regex = re.compile(r'(void main\(\)\s*\{)') + fs_insert = '\\1\n{}'.format(glsl_clip) + + # Add the varying declaration for clipping. + varying_decl = 'varying vec2 {};\n'.format(fvar) + vertex = varying_decl + vertex + fragment = varying_decl + fragment + + # Make the replacement in the fragment shader for clipping. + fragment = fs_regex.sub(indent(fs_insert), fragment) + # Set the varying value in the vertex shader. + vs_insert += '{} = {};\n'.format(fvar, var) + + # Insert the GLSL snippet of the transform chain in the vertex shader. + vertex = vs_regex.sub(indent(vs_insert), vertex) + + return vertex, fragment From c7da298f84b13501080bd0fc2c3f465326fbc968 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 20 Oct 2015 17:19:46 +0200 Subject: [PATCH 0384/1059] WIP: canvas and transforms --- phy/plot/base.py | 42 +++++++++++++++++++++---------------- phy/plot/tests/test_base.py | 2 +- phy/plot/transform.py | 4 ++-- 3 files changed, 27 insertions(+), 21 deletions(-) diff --git a/phy/plot/base.py b/phy/plot/base.py index 11ed937bb..b7a3a127f 100644 --- a/phy/plot/base.py +++ b/phy/plot/base.py @@ -12,7 +12,8 @@ from vispy import gloo from vispy.app import Canvas -from .utils import _create_program +# from .transform import TransformChain +from .utils import _load_shader logger = logging.getLogger(__name__) @@ -21,19 +22,14 @@ # Base spike visual #------------------------------------------------------------------------------ -class BaseCanvas(Canvas): - def __init__(self, *args, **kwargs): - super(BaseCanvas, self).__init__(*args, **kwargs) - self._visuals = [] +def _build_program(name, transform_chain): + vertex = _load_shader(name + '.vert') + fragment = _load_shader(name + '.frag') - def add_visual(self, visual): - self._visuals.append(visual) - visual.attach(self) + vertex, fragment = transform_chain.insert_glsl(vertex, fragment) - def on_draw(self, e): - gloo.clear() - for visual in self._visuals: - visual.draw() + program = gloo.Program(vertex, fragment) + return program class BaseVisual(object): @@ -47,8 +43,7 @@ def __init__(self): self.size = 1, 1 self._canvas = None self._do_show = False - - self.program = _create_program(self._shader_name) + self.transforms = [] def show(self): self._do_show = True @@ -60,10 +55,6 @@ def set_data(self): """Set the data for the visual.""" pass - def set_transforms(self): - """Set the list of transforms for the visual.""" - pass - def attach(self, canvas): """Attach some events.""" self._canvas = canvas @@ -91,3 +82,18 @@ def update(self): """Trigger a draw event in the canvas from the visual.""" if self._canvas: self._canvas.update() + + +class BaseCanvas(Canvas): + def __init__(self, *args, **kwargs): + super(BaseCanvas, self).__init__(*args, **kwargs) + self._visuals = [] + + def add_visual(self, visual): + self._visuals.append(visual) + visual.attach(self) + + def on_draw(self, e): + gloo.clear() + for visual in self._visuals: + visual.draw() diff --git a/phy/plot/tests/test_base.py b/phy/plot/tests/test_base.py index 3acee363b..c083e8713 100644 --- a/phy/plot/tests/test_base.py +++ b/phy/plot/tests/test_base.py @@ -14,7 +14,7 @@ # Test base #------------------------------------------------------------------------------ -def test_base_visual(qtbot, canvas): +def _test_base_visual(qtbot, canvas): class TestVisual(BaseVisual): _shader_name = 'box' diff --git a/phy/plot/transform.py b/phy/plot/transform.py index a3ffa96e5..b0c1301cb 100644 --- a/phy/plot/transform.py +++ b/phy/plot/transform.py @@ -197,8 +197,8 @@ class GPU(object): class TransformChain(object): """A linear sequence of transforms that happen on the CPU and GPU.""" - def __init__(self, transforms): - self.transforms = transforms + def __init__(self, transforms=None): + self.transforms = transforms or [] def _index_of_gpu(self): classes = [t.__class__.__name__ for t in self.transforms] From 6482420d805345e5dd02299dcbd46a9a02b79e0a Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 21 Oct 2015 17:38:38 +0200 Subject: [PATCH 0385/1059] WIP: interact --- phy/plot/base.py | 115 +++++++++++++++++++++++++++++++----- phy/plot/glsl/test.frag | 3 + phy/plot/glsl/test.vert | 4 ++ phy/plot/tests/test_base.py | 13 ++-- phy/plot/transform.py | 7 ++- 5 files changed, 119 insertions(+), 23 deletions(-) create mode 100644 phy/plot/glsl/test.frag create mode 100644 phy/plot/glsl/test.vert diff --git a/phy/plot/base.py b/phy/plot/base.py index b7a3a127f..062f6a7cb 100644 --- a/phy/plot/base.py +++ b/phy/plot/base.py @@ -12,8 +12,9 @@ from vispy import gloo from vispy.app import Canvas -# from .transform import TransformChain +from .transform import TransformChain from .utils import _load_shader +from phy.utils import EventEmitter logger = logging.getLogger(__name__) @@ -22,27 +23,29 @@ # Base spike visual #------------------------------------------------------------------------------ -def _build_program(name, transform_chain): +def _build_program(name, transform_chain=None): vertex = _load_shader(name + '.vert') fragment = _load_shader(name + '.frag') - vertex, fragment = transform_chain.insert_glsl(vertex, fragment) + if transform_chain: + vertex, fragment = transform_chain.insert_glsl(vertex, fragment) program = gloo.Program(vertex, fragment) return program class BaseVisual(object): - _gl_primitive_type = None - _shader_name = None + gl_primitive_type = None + shader_name = None def __init__(self): - assert self._gl_primitive_type - assert self._shader_name + assert self.gl_primitive_type + assert self.shader_name self.size = 1, 1 self._canvas = None self._do_show = False + self.program = None self.transforms = [] def show(self): @@ -55,10 +58,27 @@ def set_data(self): """Set the data for the visual.""" pass - def attach(self, canvas): + def attach(self, canvas, interact=None): """Attach some events.""" + logger.debug("Attach `%s` with interact `%s` to canvas.", + self.__class__.__name__, interact or '') self._canvas = canvas + # Used when the canvas requests all attached visuals + # for the given interact. + @canvas.connect_ + def on_get_visual_for_interact(interact_req): + if interact_req == interact: + return self + + # NOTE: this is connect_ and not connect because we're using + # phy's event system, not VisPy's. The reason is that the order + # of the callbacks is not kept by VisPy, whereas we need the order + # to draw visuals in the order they are attached. + @canvas.connect_ + def on_draw(): + self.draw() + @canvas.connect def on_resize(event): """Resize the OpenGL context.""" @@ -73,10 +93,21 @@ def on_mouse_move(event): def on_mouse_move(self, e): pass + def build_program(self, transforms=None): + transforms = transforms or [] + assert self.program is None, "The program has already been built." + + # Build the transform chain using the visuals transforms first, + # and the interact's transforms then. + transform_chain = TransformChain(self.transforms + transforms) + + logger.debug("Build the program of `%s`.", self.__class__.__name__) + self.program = _build_program(self.shader_name, transform_chain) + def draw(self): """Draw the waveforms.""" - if self._do_show: - self.program.draw(self._gl_primitive_type) + if self._do_show and self.program: + self.program.draw(self.gl_primitive_type) def update(self): """Trigger a draw event in the canvas from the visual.""" @@ -87,13 +118,65 @@ def update(self): class BaseCanvas(Canvas): def __init__(self, *args, **kwargs): super(BaseCanvas, self).__init__(*args, **kwargs) - self._visuals = [] + self._events = EventEmitter() + + def connect_(self, *args, **kwargs): + return self._events.connect(*args, **kwargs) - def add_visual(self, visual): - self._visuals.append(visual) - visual.attach(self) + def emit_(self, *args, **kwargs): + return self._events.emit(*args, **kwargs) def on_draw(self, e): gloo.clear() - for visual in self._visuals: - visual.draw() + self._events.emit('draw') + + +class BaseInteract(object): + """Implement interactions for a set of attached visuals in a canvas. + + Derived classes must: + + * Define a unique `self.name` + * Define a list of transforms + + """ + name = None + transforms = None + + def __init__(self): + self._canvas = None + + def attach(self, canvas): + """Attach the interact to a canvas.""" + self._canvas = canvas + + @canvas.connect_ + def on_draw(): + # The programs are built only once per visual. + self.build_programs() + + canvas.connect(self.on_mouse_move) + canvas.connect(self.on_key_press) + + def iter_attached_visuals(self): + """Yield all visuals attached to that interact in the canvas.""" + for visual in self._canvas.emit('get_visual_for_interact', self.name): + if visual: + yield visual + + def build_programs(self): + """Build the programs of all attached visuals. + + The transform chain of the interact must have been built before. + + """ + for visual in self.iter_attached_visuals(): + if not visual.program: + assert self.transforms + visual.build_program(self.transforms) + + def on_mouse_move(self, event): + pass + + def on_key_press(self, event): + pass diff --git a/phy/plot/glsl/test.frag b/phy/plot/glsl/test.frag new file mode 100644 index 000000000..194449a1d --- /dev/null +++ b/phy/plot/glsl/test.frag @@ -0,0 +1,3 @@ +void main() { + gl_FragColor = vec4(1, 1, 1, 1); +} diff --git a/phy/plot/glsl/test.vert b/phy/plot/glsl/test.vert new file mode 100644 index 000000000..5e46f738b --- /dev/null +++ b/phy/plot/glsl/test.vert @@ -0,0 +1,4 @@ +attribute vec2 a_position; +void main() { + gl_Position = vec4(a_position.xy, 0, 1); +} diff --git a/phy/plot/tests/test_base.py b/phy/plot/tests/test_base.py index c083e8713..34f63e5cd 100644 --- a/phy/plot/tests/test_base.py +++ b/phy/plot/tests/test_base.py @@ -14,25 +14,26 @@ # Test base #------------------------------------------------------------------------------ -def _test_base_visual(qtbot, canvas): +def test_base_visual(qtbot, canvas): class TestVisual(BaseVisual): - _shader_name = 'box' - _gl_primitive_type = 'lines' + shader_name = 'test' + gl_primitive_type = 'lines' def set_data(self): - self.program['a_position'] = [[-1, 0, 0], [1, 0, 0]] - self.program['n_rows'] = 1 + self.build_program() + self.program['a_position'] = [[-1, 0], [1, 0]] self.show() v = TestVisual() v.set_data() - canvas.add_visual(v) + v.attach(canvas) canvas.show() v.hide() v.show() qtbot.waitForWindowShown(canvas.native) + # qtbot.stop() # Simulate a mouse move. canvas.events.mouse_move(delta=(1., 0.)) diff --git a/phy/plot/transform.py b/phy/plot/transform.py index b0c1301cb..2cc1ef601 100644 --- a/phy/plot/transform.py +++ b/phy/plot/transform.py @@ -240,7 +240,12 @@ def insert_glsl(self, vertex, fragment): # data_var_name is typically an attribute. vs_regex = re.compile(r'gl_Position = transform\(([\S]+)\);') r = vs_regex.search(vertex) - assert r, ("The vertex shader must contain the transform placeholder.") + if not r: + logger.debug("The vertex shader doesn't contain the transform " + "placeholder: skipping the transform chain " + "GLSL insertion.") + return vertex, fragment + assert r logger.debug("Found transform placeholder in vertex code: `%s`", r.group(0)) From d99ee537ef55508d180058557f7ba4bf152b2438 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 21 Oct 2015 17:59:42 +0200 Subject: [PATCH 0386/1059] WIP: update BaseVisual --- phy/plot/base.py | 18 ++++++++++++++---- phy/plot/tests/test_base.py | 17 ++++++++++++++--- 2 files changed, 28 insertions(+), 7 deletions(-) diff --git a/phy/plot/base.py b/phy/plot/base.py index 062f6a7cb..124f5bbae 100644 --- a/phy/plot/base.py +++ b/phy/plot/base.py @@ -44,7 +44,9 @@ def __init__(self): self.size = 1, 1 self._canvas = None - self._do_show = False + # Not taken into account when the program has not been built. + self._do_show = True + self.data = {} # Data to set on the program when possible. self.program = None self.transforms = [] @@ -105,8 +107,15 @@ def build_program(self, transforms=None): self.program = _build_program(self.shader_name, transform_chain) def draw(self): - """Draw the waveforms.""" + """Draw the visual.""" + # Skip the drawing if the program hasn't been built yet. + # The program is built by the attached interact. if self._do_show and self.program: + # Upload the data if necessary. + for name, value in self.data.items(): + self.program[name] = value + self.data.clear() + # Finally, draw the program. self.program.draw(self.gl_primitive_type) def update(self): @@ -152,7 +161,8 @@ def attach(self, canvas): @canvas.connect_ def on_draw(): - # The programs are built only once per visual. + # Build the programs of all attached visuals. + # Programs that are already built are skipped. self.build_programs() canvas.connect(self.on_mouse_move) @@ -160,7 +170,7 @@ def on_draw(): def iter_attached_visuals(self): """Yield all visuals attached to that interact in the canvas.""" - for visual in self._canvas.emit('get_visual_for_interact', self.name): + for visual in self._canvas.emit_('get_visual_for_interact', self.name): if visual: yield visual diff --git a/phy/plot/tests/test_base.py b/phy/plot/tests/test_base.py index 34f63e5cd..3ef4b295d 100644 --- a/phy/plot/tests/test_base.py +++ b/phy/plot/tests/test_base.py @@ -21,13 +21,13 @@ class TestVisual(BaseVisual): gl_primitive_type = 'lines' def set_data(self): - self.build_program() - self.program['a_position'] = [[-1, 0], [1, 0]] - self.show() + self.data['a_position'] = [[-1, 0], [1, 0]] v = TestVisual() v.set_data() + # We need to build the program explicitly when there is no interact. v.attach(canvas) + v.build_program() canvas.show() v.hide() @@ -39,3 +39,14 @@ def set_data(self): canvas.events.mouse_move(delta=(1., 0.)) v.update() + + +def test_base_interact(qtbot, canvas): + + class TestVisual(BaseVisual): + shader_name = 'test' + gl_primitive_type = 'lines' + + def set_data(self): + self.program['a_position'] = [[-1, 0], [1, 0]] + self.show() From 6b3219030468e279665abc681af1a1689bc50889 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 21 Oct 2015 18:22:20 +0200 Subject: [PATCH 0387/1059] WIP --- phy/plot/base.py | 43 ++++++++++++++++++++++++++++++------- phy/plot/tests/test_base.py | 23 +++++++++++++++++--- phy/plot/transform.py | 2 ++ 3 files changed, 57 insertions(+), 11 deletions(-) diff --git a/phy/plot/base.py b/phy/plot/base.py index 124f5bbae..dd258a49e 100644 --- a/phy/plot/base.py +++ b/phy/plot/base.py @@ -49,6 +49,10 @@ def __init__(self): self.data = {} # Data to set on the program when possible. self.program = None self.transforms = [] + # Combine the visual's transforms and the interact transforms. + # The interact triggers the creation of the transform chain in + # self.build_program(). + self.transform_chain = None def show(self): self._do_show = True @@ -60,7 +64,7 @@ def set_data(self): """Set the data for the visual.""" pass - def attach(self, canvas, interact=None): + def attach(self, canvas, interact='base'): """Attach some events.""" logger.debug("Attach `%s` with interact `%s` to canvas.", self.__class__.__name__, interact or '') @@ -96,25 +100,48 @@ def on_mouse_move(self, e): pass def build_program(self, transforms=None): + """Create the gloo program by specifying the transforms + given by the optionally-attached interact. + + This function also uploads all variables set in `self.data` in + `self.set_data()`. + + This function is called by the interact's `build_programs()` method + during the draw event (only effective the first time necessary). + + """ transforms = transforms or [] assert self.program is None, "The program has already been built." # Build the transform chain using the visuals transforms first, # and the interact's transforms then. - transform_chain = TransformChain(self.transforms + transforms) + self.transform_chain = TransformChain(self.transforms + transforms) logger.debug("Build the program of `%s`.", self.__class__.__name__) - self.program = _build_program(self.shader_name, transform_chain) + self.program = _build_program(self.shader_name, self.transform_chain) + + # Get the name of the variable that needs to be transformed. + # This variable (typically a_position) comes from the vertex shader + # which contains the string `gl_Position = transform(the_name);`. + var = self.transform_chain.transformed_var_name + if not var: + logger.debug("No transformed variable has been found.") + + # Upload the data if necessary. + logger.debug("Upload program objects %s.", + ', '.join(self.data.keys())) + for name, value in self.data.items(): + # Normalize the value that needs to be transformed. + if name == var: + value = self.transform_chain.apply(value) + self.program[name] = value + self.data.clear() def draw(self): """Draw the visual.""" # Skip the drawing if the program hasn't been built yet. # The program is built by the attached interact. if self._do_show and self.program: - # Upload the data if necessary. - for name, value in self.data.items(): - self.program[name] = value - self.data.clear() # Finally, draw the program. self.program.draw(self.gl_primitive_type) @@ -149,7 +176,7 @@ class BaseInteract(object): * Define a list of transforms """ - name = None + name = 'base' transforms = None def __init__(self): diff --git a/phy/plot/tests/test_base.py b/phy/plot/tests/test_base.py index 3ef4b295d..da8880551 100644 --- a/phy/plot/tests/test_base.py +++ b/phy/plot/tests/test_base.py @@ -7,7 +7,8 @@ # Imports #------------------------------------------------------------------------------ -from ..base import BaseVisual +from ..base import BaseVisual, BaseInteract +from ..transform import Scale #------------------------------------------------------------------------------ @@ -47,6 +48,22 @@ class TestVisual(BaseVisual): shader_name = 'test' gl_primitive_type = 'lines' + def __init__(self): + super(TestVisual, self).__init__() + self.set_data() + def set_data(self): - self.program['a_position'] = [[-1, 0], [1, 0]] - self.show() + self.data['a_position'] = [[-1, 0], [1, 0]] + self.transforms = [Scale((.5, 1))] + + v = TestVisual() + v.attach(canvas) + + interact = BaseInteract() + interact.attach(canvas) + + canvas.show() + qtbot.waitForWindowShown(canvas.native) + # qtbot.stop() + + v.update() diff --git a/phy/plot/transform.py b/phy/plot/transform.py index 2cc1ef601..c0320f8b7 100644 --- a/phy/plot/transform.py +++ b/phy/plot/transform.py @@ -198,6 +198,7 @@ class GPU(object): class TransformChain(object): """A linear sequence of transforms that happen on the CPU and GPU.""" def __init__(self, transforms=None): + self.transformed_var_name = None self.transforms = transforms or [] def _index_of_gpu(self): @@ -251,6 +252,7 @@ def insert_glsl(self, vertex, fragment): # Find the GLSL variable with the data (should be a `vec2`). var = r.group(1) + self.transformed_var_name = var assert var and var in vertex # Generate the snippet to insert in the shaders. From 706137e2afb0120ba7436ee671a9df0fa4840a94 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 21 Oct 2015 18:46:23 +0200 Subject: [PATCH 0388/1059] WIP: interact and transform tests --- phy/plot/base.py | 27 +++++++++----------- phy/plot/tests/test_base.py | 49 +++++++++++++++++++++++++++++++++---- phy/plot/transform.py | 2 +- 3 files changed, 57 insertions(+), 21 deletions(-) diff --git a/phy/plot/base.py b/phy/plot/base.py index dd258a49e..20ef15808 100644 --- a/phy/plot/base.py +++ b/phy/plot/base.py @@ -23,24 +23,19 @@ # Base spike visual #------------------------------------------------------------------------------ -def _build_program(name, transform_chain=None): - vertex = _load_shader(name + '.vert') - fragment = _load_shader(name + '.frag') - - if transform_chain: - vertex, fragment = transform_chain.insert_glsl(vertex, fragment) - - program = gloo.Program(vertex, fragment) - return program - - class BaseVisual(object): gl_primitive_type = None - shader_name = None + vertex = None + fragment = None + shader_name = None # Use this to load shaders from the glsl/ library. def __init__(self): + if self.shader_name: + self.vertex = _load_shader(self.shader_name + '.vert') + self.fragment = _load_shader(self.shader_name + '.frag') + assert self.vertex + assert self.fragment assert self.gl_primitive_type - assert self.shader_name self.size = 1, 1 self._canvas = None @@ -118,7 +113,10 @@ def build_program(self, transforms=None): self.transform_chain = TransformChain(self.transforms + transforms) logger.debug("Build the program of `%s`.", self.__class__.__name__) - self.program = _build_program(self.shader_name, self.transform_chain) + if self.transform_chain: + self.vertex, self.fragment = self.transform_chain.insert_glsl( + self.vertex, self.fragment) + self.program = gloo.Program(self.vertex, self.fragment) # Get the name of the variable that needs to be transformed. # This variable (typically a_position) comes from the vertex shader @@ -209,7 +207,6 @@ def build_programs(self): """ for visual in self.iter_attached_visuals(): if not visual.program: - assert self.transforms visual.build_program(self.transforms) def on_mouse_move(self, event): diff --git a/phy/plot/tests/test_base.py b/phy/plot/tests/test_base.py index da8880551..04597d15c 100644 --- a/phy/plot/tests/test_base.py +++ b/phy/plot/tests/test_base.py @@ -15,10 +15,41 @@ # Test base #------------------------------------------------------------------------------ +def test_visual_shader_name(qtbot, canvas): + + class TestVisual(BaseVisual): + shader_name = 'box' + gl_primitive_type = 'lines' + + def set_data(self): + self.data['a_position'] = [[-1, 0, 0], [1, 0, 0]] + self.data['n_rows'] = 1 + + v = TestVisual() + v.set_data() + # We need to build the program explicitly when there is no interact. + v.attach(canvas) + v.build_program() + + canvas.show() + qtbot.waitForWindowShown(canvas.native) + # qtbot.stop() + + def test_base_visual(qtbot, canvas): class TestVisual(BaseVisual): - shader_name = 'test' + vertex = """ + attribute vec2 a_position; + void main() { + gl_Position = vec4(a_position.xy, 0, 1); + } + """ + fragment = """ + void main() { + gl_FragColor = vec4(1, 1, 1, 1); + } + """ gl_primitive_type = 'lines' def set_data(self): @@ -45,7 +76,17 @@ def set_data(self): def test_base_interact(qtbot, canvas): class TestVisual(BaseVisual): - shader_name = 'test' + vertex = """ + attribute vec2 a_position; + void main() { + gl_Position = transform(a_position); + } + """ + fragment = """ + void main() { + gl_FragColor = vec4(1, 1, 1, 1); + } + """ gl_primitive_type = 'lines' def __init__(self): @@ -54,7 +95,7 @@ def __init__(self): def set_data(self): self.data['a_position'] = [[-1, 0], [1, 0]] - self.transforms = [Scale((.5, 1))] + self.transforms = [Scale(scale=(.5, 1))] v = TestVisual() v.attach(canvas) @@ -65,5 +106,3 @@ def set_data(self): canvas.show() qtbot.waitForWindowShown(canvas.native) # qtbot.stop() - - v.update() diff --git a/phy/plot/transform.py b/phy/plot/transform.py index c0320f8b7..404844f59 100644 --- a/phy/plot/transform.py +++ b/phy/plot/transform.py @@ -261,7 +261,7 @@ def insert_glsl(self, vertex, fragment): if isinstance(t, Clip): continue vs_insert += t.glsl(var) + '\n' - vs_insert += 'gl_Position = {};\n'.format(var) + vs_insert += 'gl_Position = vec4({}, 0., 1.);\n'.format(var) # Clipping. clip = self.get('Clip') From 1c44d2cd98176d5b70af22f9b7d164d3c0060f19 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 21 Oct 2015 19:02:04 +0200 Subject: [PATCH 0389/1059] Add comments --- phy/plot/base.py | 55 +++++++++++++++++++++++++++++++++---- phy/plot/tests/test_base.py | 8 ++++-- 2 files changed, 55 insertions(+), 8 deletions(-) diff --git a/phy/plot/base.py b/phy/plot/base.py index 20ef15808..43e63912a 100644 --- a/phy/plot/base.py +++ b/phy/plot/base.py @@ -24,6 +24,26 @@ #------------------------------------------------------------------------------ class BaseVisual(object): + """A Visual represents one object (or homogeneous set of objects). + + It is rendered with a single pass of a single gloo program with a single + type of GL primitive. + + Derived classes must implement: + + * `gl_primitive_type`: `lines`, `points`, etc. + * `vertex` and `fragment`, or `shader_name`: the GLSL code, or the name of + the GLSL files to load from the `glsl/` subdirectory. + `shader_name` + * `data`: a dictionary acting as a proxy for the gloo Program. + This is because the Program is built later, once the interact has been + attached. The interact is responsible for the creation of the program, + since it implements a part of the transform chain. + * `transforms`: a list of `Transform` instances, which can act on the CPU + or the GPU. The interact's transforms will be appended to that list + when the visual is attached to the canvas. + + """ gl_primitive_type = None vertex = None fragment = None @@ -39,11 +59,14 @@ def __init__(self): self.size = 1, 1 self._canvas = None + self.program = None # Not taken into account when the program has not been built. self._do_show = True + + # To set in `set_data()`. self.data = {} # Data to set on the program when possible. - self.program = None self.transforms = [] + # Combine the visual's transforms and the interact transforms. # The interact triggers the creation of the transform chain in # self.build_program(). @@ -56,11 +79,21 @@ def hide(self): self._do_show = False def set_data(self): - """Set the data for the visual.""" + """Set the data for the visual. + + Derived classes can add data to the `self.data` dictionary and + set transforms in the `self.transforms` list. + + """ pass def attach(self, canvas, interact='base'): - """Attach some events.""" + """Attach the visual to a canvas. + + The interact's name can be specified. The interact's transforms + will be appended to the visual's transforms. + + """ logger.debug("Attach `%s` with interact `%s` to canvas.", self.__class__.__name__, interact or '') self._canvas = canvas @@ -91,9 +124,17 @@ def on_mouse_move(event): if self._do_show: self.on_mouse_move(event) + @canvas.connect + def on_key_press(event): + if self._do_show: + self.on_key_press(event) + def on_mouse_move(self, e): pass + def on_key_press(self, e): + pass + def build_program(self, transforms=None): """Create the gloo program by specifying the transforms given by the optionally-attached interact. @@ -150,6 +191,7 @@ def update(self): class BaseCanvas(Canvas): + """A blank VisPy canvas with a custom event system that keeps the order.""" def __init__(self, *args, **kwargs): super(BaseCanvas, self).__init__(*args, **kwargs) self._events = EventEmitter() @@ -170,8 +212,8 @@ class BaseInteract(object): Derived classes must: - * Define a unique `self.name` - * Define a list of transforms + * Define a unique `name` + * Define a list of `transforms` """ name = 'base' @@ -202,7 +244,8 @@ def iter_attached_visuals(self): def build_programs(self): """Build the programs of all attached visuals. - The transform chain of the interact must have been built before. + The list of transforms of the interact should have been set before + calling this function. """ for visual in self.iter_attached_visuals(): diff --git a/phy/plot/tests/test_base.py b/phy/plot/tests/test_base.py index 04597d15c..3f1e83e85 100644 --- a/phy/plot/tests/test_base.py +++ b/phy/plot/tests/test_base.py @@ -16,7 +16,7 @@ #------------------------------------------------------------------------------ def test_visual_shader_name(qtbot, canvas): - + """Test a BaseVisual with a shader name.""" class TestVisual(BaseVisual): shader_name = 'box' gl_primitive_type = 'lines' @@ -37,6 +37,7 @@ def set_data(self): def test_base_visual(qtbot, canvas): + """Test a BaseVisual with custom shaders.""" class TestVisual(BaseVisual): vertex = """ @@ -69,12 +70,13 @@ def set_data(self): # Simulate a mouse move. canvas.events.mouse_move(delta=(1., 0.)) + canvas.events.key_press(text='a') v.update() def test_base_interact(qtbot, canvas): - + """Test a BaseVisual with a CPU transform and a blank interact.""" class TestVisual(BaseVisual): vertex = """ attribute vec2 a_position; @@ -97,9 +99,11 @@ def set_data(self): self.data['a_position'] = [[-1, 0], [1, 0]] self.transforms = [Scale(scale=(.5, 1))] + # We attach the visual to the canvas. By default, a BaseInteract is used. v = TestVisual() v.attach(canvas) + # Base interact (no transform). interact = BaseInteract() interact.attach(canvas) From c93fdc62771d49786404236fe266ffd0d1897c48 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 21 Oct 2015 20:27:50 +0200 Subject: [PATCH 0390/1059] WIP: tests --- phy/plot/base.py | 3 ++ phy/plot/tests/test_base.py | 57 +++++++++++++++++++++++++++++++- phy/plot/tests/test_transform.py | 13 +++++++- phy/plot/transform.py | 45 +++++++++++++++++++++---- 4 files changed, 110 insertions(+), 8 deletions(-) diff --git a/phy/plot/base.py b/phy/plot/base.py index 43e63912a..f8c39dc5a 100644 --- a/phy/plot/base.py +++ b/phy/plot/base.py @@ -155,8 +155,11 @@ def build_program(self, transforms=None): logger.debug("Build the program of `%s`.", self.__class__.__name__) if self.transform_chain: + # Insert the interact's GLSL into the shaders. self.vertex, self.fragment = self.transform_chain.insert_glsl( self.vertex, self.fragment) + logger.debug("Vertex shader: \n%s", self.vertex) + logger.debug("Fragment shader: \n%s", self.fragment) self.program = gloo.Program(self.vertex, self.fragment) # Get the name of the variable that needs to be transformed. diff --git a/phy/plot/tests/test_base.py b/phy/plot/tests/test_base.py index 3f1e83e85..7888c1405 100644 --- a/phy/plot/tests/test_base.py +++ b/phy/plot/tests/test_base.py @@ -7,8 +7,10 @@ # Imports #------------------------------------------------------------------------------ +import numpy as np + from ..base import BaseVisual, BaseInteract -from ..transform import Scale +from ..transform import Translate, Scale, Range, Clip, Subplot, GPU #------------------------------------------------------------------------------ @@ -110,3 +112,56 @@ def set_data(self): canvas.show() qtbot.waitForWindowShown(canvas.native) # qtbot.stop() + + +def test_interact(qtbot, canvas): + """Test a BaseVisual with multiple CPU and GPU transforms and a + non-blank interact.""" + + class TestVisual(BaseVisual): + vertex = """ + attribute vec2 a_position; + void main() { + gl_Position = transform(a_position); + gl_PointSize = 2.0; + } + """ + fragment = """ + void main() { + gl_FragColor = vec4(1, 1, 1, 1); + } + """ + gl_primitive_type = 'points' + + def __init__(self): + super(TestVisual, self).__init__() + self.set_data() + + def set_data(self): + self.data['a_position'] = np.random.uniform(0, 20, (100, 2)) + self.transforms = [Scale(scale=(.1, .1)), + Translate(translate=(-1, -1)), + GPU(), + Range(from_range=(-1, -1, 1, 1), + to_range=(-.9, -.9, .9, .9), + ), + ] + + class TestInteract(BaseInteract): + name = 'test' + + def __init__(self): + super(TestInteract, self).__init__() + self.transforms = [Subplot(shape=(2, 3), index=(1, 2))] + + # We attach the visual to the canvas. By default, a BaseInteract is used. + v = TestVisual() + v.attach(canvas, 'test') + + # Base interact (no transform). + interact = TestInteract() + interact.attach(canvas) + + canvas.show() + qtbot.waitForWindowShown(canvas.native) + # qtbot.stop() diff --git a/phy/plot/tests/test_transform.py b/phy/plot/tests/test_transform.py index 64997c87f..ffa772550 100644 --- a/phy/plot/tests/test_transform.py +++ b/phy/plot/tests/test_transform.py @@ -13,7 +13,8 @@ from numpy.testing import assert_equal as ae from pytest import yield_fixture -from ..transform import (Translate, Scale, Range, Clip, Subplot, GPU, +from ..transform import (_glslify_range, + Translate, Scale, Range, Clip, Subplot, GPU, TransformChain, ) @@ -37,6 +38,16 @@ def _check(transform, array, expected, **kwargs): assert np.allclose(transformed, expected) +#------------------------------------------------------------------------------ +# Test utils +#------------------------------------------------------------------------------ + +def test_glslify_range(): + assert _glslify_range(('a', 'b')) == ('a', 'b') + assert _glslify_range((1, 2, 3, 4)) == ('vec2(1.000, 2.000)', + 'vec2(3.000, 4.000)') + + #------------------------------------------------------------------------------ # Test transform #------------------------------------------------------------------------------ diff --git a/phy/plot/transform.py b/phy/plot/transform.py index 404844f59..542edc8a1 100644 --- a/phy/plot/transform.py +++ b/phy/plot/transform.py @@ -11,6 +11,7 @@ import re import numpy as np +from six import string_types import logging @@ -53,6 +54,25 @@ def indent(text): return '\n'.join(' ' + l.strip() for l in text.splitlines()) +def _glslify_pair(p): + if isinstance(p, string_types): + return p + elif len(p) == 2: + s = 'vec2({:.3f}, {:.3f})' + return s.format(*p) + raise ValueError() # pragma: no cover + + +def _glslify_range(r): + if len(r) == 2: + assert isinstance(r[0], string_types) + assert isinstance(r[1], string_types) + return r + elif len(r) == 4: + return _glslify_pair(r[:2]), _glslify_pair(r[2:]) + raise ValueError() # pragma: no cover + + #------------------------------------------------------------------------------ # Transforms #------------------------------------------------------------------------------ @@ -106,6 +126,9 @@ def glsl(self, var, from_range=None, to_range=None): if to_range is None: to_range = [-1, -1, 1, 1] + from_range = _glslify_range(from_range) + to_range = _glslify_range(to_range) + return """ {var} = {t0} + ({t1} - {t0}) * ({var} - {f0}) / ({f1} - {f0}); """.format(var=var, @@ -147,7 +170,8 @@ def glsl(self, var, bounds=None): class Subplot(Range): """Assume that the from range is [-1, -1, 1, 1].""" - def apply(self, arr, shape=None, index=None): + + def get_range(self, shape=None, index=None): i, j = index n_rows, n_cols = shape assert 0 <= i <= n_rows - 1 @@ -166,13 +190,17 @@ def apply(self, arr, shape=None, index=None): from_range = [-1, -1, 1, 1] to_range = [x, y, x + width, y + height] + return from_range, to_range + + def apply(self, arr, shape=None, index=None): + from_range, to_range = self.get_range(shape=shape, index=index) return super(Subplot, self).apply(arr, from_range=from_range, to_range=to_range) def glsl(self, var, shape=None, index=None): assert var - glsl = """ + snippet = """ float subplot_x = -1.0 + {index}.y * 2.0 / {shape}.y; float subplot_y = +1.0 - {index}.x * 2.0 / {shape}.x; @@ -182,7 +210,11 @@ def glsl(self, var, shape=None, index=None): {var} = vec2(subplot_x + subplot_width * {var}.x, subplot_y + subplot_height * {var}.y); """ - return glsl.format(index=index, shape=shape, var=var) + index = _glslify_pair(index) + shape = _glslify_pair(shape) + + snippet = snippet.format(index=index, shape=shape, var=var) + return snippet #------------------------------------------------------------------------------ @@ -256,12 +288,13 @@ def insert_glsl(self, vertex, fragment): assert var and var in vertex # Generate the snippet to insert in the shaders. - vs_insert = "" + temp_var = '_transformed_data' + vs_insert = "vec2 {} = {};\n".format(temp_var, var) for t in self.gpu_transforms: if isinstance(t, Clip): continue - vs_insert += t.glsl(var) + '\n' - vs_insert += 'gl_Position = vec4({}, 0., 1.);\n'.format(var) + vs_insert += t.glsl(temp_var) + '\n' + vs_insert += 'gl_Position = vec4({}, 0., 1.);\n'.format(temp_var) # Clipping. clip = self.get('Clip') From e9d6e33a7ec0ddbe3210335d39b25d4ac8061783 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 21 Oct 2015 22:05:45 +0200 Subject: [PATCH 0391/1059] Fix bugs --- phy/plot/tests/test_base.py | 9 +++-- phy/plot/tests/test_transform.py | 12 ++++--- phy/plot/transform.py | 59 +++++++++++++++++++++----------- 3 files changed, 53 insertions(+), 27 deletions(-) diff --git a/phy/plot/tests/test_base.py b/phy/plot/tests/test_base.py index 7888c1405..f875513a6 100644 --- a/phy/plot/tests/test_base.py +++ b/phy/plot/tests/test_base.py @@ -138,12 +138,12 @@ def __init__(self): self.set_data() def set_data(self): - self.data['a_position'] = np.random.uniform(0, 20, (100, 2)) + self.data['a_position'] = np.random.uniform(0, 20, (100000, 2)) self.transforms = [Scale(scale=(.1, .1)), Translate(translate=(-1, -1)), GPU(), Range(from_range=(-1, -1, 1, 1), - to_range=(-.9, -.9, .9, .9), + to_range=(-1.5, -1.5, 1.5, 1.5), ), ] @@ -152,7 +152,10 @@ class TestInteract(BaseInteract): def __init__(self): super(TestInteract, self).__init__() - self.transforms = [Subplot(shape=(2, 3), index=(1, 2))] + bounds = Subplot().get_range(shape=(2, 3), index=(1, 2))[1] + self.transforms = [Subplot(shape=(2, 3), index=(1, 2)), + Clip(bounds=bounds), + ] # We attach the visual to the canvas. By default, a BaseInteract is used. v = TestVisual() diff --git a/phy/plot/tests/test_transform.py b/phy/plot/tests/test_transform.py index ffa772550..3d64ed9e1 100644 --- a/phy/plot/tests/test_transform.py +++ b/phy/plot/tests/test_transform.py @@ -132,9 +132,9 @@ def test_range_glsl(): def test_clip_glsl(): expected = dedent(""" - if ((x.x < xymin.x) | - (x.y < xymin.y) | - (x.x > xymax.x) | + if ((x.x < xymin.x) || + (x.y < xymin.y) || + (x.x > xymax.x) || (x.y > xymax.y)) { discard; } @@ -218,5 +218,9 @@ def test_transform_chain_complete(array): """).strip() vs, fs = t.insert_glsl(vs, fs) assert 'a_box' in vs - assert 'v_a_position = a_position;' in vs + assert 'v_' in vs + assert 'v_' in fs assert 'discard' in fs + + # Increase coverage. + t.insert_glsl(vs.replace('transform', ''), fs) diff --git a/phy/plot/transform.py b/phy/plot/transform.py index 542edc8a1..dccea5d81 100644 --- a/phy/plot/transform.py +++ b/phy/plot/transform.py @@ -50,6 +50,14 @@ def wrapped(var, **kwargs): return wrapped +def _wrap_prepost(f, **kwargs_init): + def wrapped(*args, **kwargs): + # Method kwargs first, then we update with the constructor kwargs. + kwargs.update(kwargs_init) + return f(*args, **kwargs) + return wrapped + + def indent(text): return '\n'.join(' ' + l.strip() for l in text.splitlines()) @@ -82,6 +90,14 @@ def __init__(self, **kwargs): # Pass the constructor kwargs to the methods. self.apply = _wrap_apply(self.apply, **kwargs) self.glsl = _wrap_glsl(self.glsl, **kwargs) + self.pre_transforms = _wrap_prepost(self.pre_transforms, **kwargs) + self.post_transforms = _wrap_prepost(self.post_transforms, **kwargs) + + def pre_transforms(self): + return [] + + def post_transforms(self): + return [] def apply(self, arr): raise NotImplementedError() @@ -154,11 +170,12 @@ def glsl(self, var, bounds=None): assert var if bounds is None: bounds = 'vec2(-1, -1)', 'vec2(1, 1)' + bounds = _glslify_range(bounds) return """ - if (({var}.x < {xymin}.x) | - ({var}.y < {xymin}.y) | - ({var}.x > {xymax}.x) | + if (({var}.x < {xymin}.x) || + ({var}.y < {xymin}.y) || + ({var}.x > {xymax}.x) || ({var}.y > {xymax}.y)) {{ discard; }} @@ -177,15 +194,11 @@ def get_range(self, shape=None, index=None): assert 0 <= i <= n_rows - 1 assert 0 <= j <= n_cols - 1 - x = -1.0 + j * (2.0 / n_cols) - y = +1.0 - i * (2.0 / n_rows) - width = 2.0 / n_cols height = 2.0 / n_rows - # The origin (x, y) corresponds to the lower-left corner of the - # target box. - y -= height + x = -1.0 + j * width + y = +1.0 - (i + 1) * height from_range = [-1, -1, 1, 1] to_range = [x, y, x + width, y + height] @@ -200,18 +213,20 @@ def apply(self, arr, shape=None, index=None): def glsl(self, var, shape=None, index=None): assert var - snippet = """ - float subplot_x = -1.0 + {index}.y * 2.0 / {shape}.y; - float subplot_y = +1.0 - {index}.x * 2.0 / {shape}.x; + index = _glslify_pair(index) + shape = _glslify_pair(shape) + + snippet = """ float subplot_width = 2. / {shape}.y; float subplot_height = 2. / {shape}.x; - {var} = vec2(subplot_x + subplot_width * {var}.x, - subplot_y + subplot_height * {var}.y); - """ - index = _glslify_pair(index) - shape = _glslify_pair(shape) + float subplot_x = -1.0 + {index}.y * subplot_width; + float subplot_y = +1.0 - ({index}.x + 1) * subplot_height; + + {var} = vec2(subplot_x + subplot_width * ({var}.x + 1) * .5, + subplot_y + subplot_height * ({var}.y + 1) * .5); + """.format(index=index, shape=shape, var=var) snippet = snippet.format(index=index, shape=shape, var=var) return snippet @@ -259,6 +274,10 @@ def get(self, class_name): if transform.__class__.__name__ == class_name: return transform + def _extend(self): + """Insert pre- and post- transforms into the chain.""" + # TODO + def apply(self, arr): """Apply all CPU transforms on an array.""" for t in self.cpu_transforms: @@ -288,7 +307,7 @@ def insert_glsl(self, vertex, fragment): assert var and var in vertex # Generate the snippet to insert in the shaders. - temp_var = '_transformed_data' + temp_var = 'temp_pos_tr' vs_insert = "vec2 {} = {};\n".format(temp_var, var) for t in self.gpu_transforms: if isinstance(t, Clip): @@ -300,7 +319,7 @@ def insert_glsl(self, vertex, fragment): clip = self.get('Clip') if clip: # Varying name. - fvar = 'v_{}'.format(var) + fvar = 'v_{}'.format(temp_var) glsl_clip = clip.glsl(fvar) # Prepare the fragment regex. @@ -315,7 +334,7 @@ def insert_glsl(self, vertex, fragment): # Make the replacement in the fragment shader for clipping. fragment = fs_regex.sub(indent(fs_insert), fragment) # Set the varying value in the vertex shader. - vs_insert += '{} = {};\n'.format(fvar, var) + vs_insert += '{} = {};\n'.format(fvar, temp_var) # Insert the GLSL snippet of the transform chain in the vertex shader. vertex = vs_regex.sub(indent(vs_insert), vertex) From 92a9f70a012a9d584caf1ce2b66000ffb1df5aab Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 21 Oct 2015 22:23:13 +0200 Subject: [PATCH 0392/1059] WIP: refactor transforms --- phy/plot/tests/test_base.py | 4 +-- phy/plot/tests/test_transform.py | 10 +++--- phy/plot/transform.py | 60 ++++++++++++++------------------ 3 files changed, 33 insertions(+), 41 deletions(-) diff --git a/phy/plot/tests/test_base.py b/phy/plot/tests/test_base.py index f875513a6..a01e48de0 100644 --- a/phy/plot/tests/test_base.py +++ b/phy/plot/tests/test_base.py @@ -142,8 +142,8 @@ def set_data(self): self.transforms = [Scale(scale=(.1, .1)), Translate(translate=(-1, -1)), GPU(), - Range(from_range=(-1, -1, 1, 1), - to_range=(-1.5, -1.5, 1.5, 1.5), + Range(from_bounds=(-1, -1, 1, 1), + to_bounds=(-1.5, -1.5, 1.5, 1.5), ), ] diff --git a/phy/plot/tests/test_transform.py b/phy/plot/tests/test_transform.py index 3d64ed9e1..0390c2763 100644 --- a/phy/plot/tests/test_transform.py +++ b/phy/plot/tests/test_transform.py @@ -71,7 +71,7 @@ def test_scale_cpu(): def test_range_cpu(): - kwargs = dict(from_range=[0, 0, 1, 1], to_range=[-1, -1, 1, 1]) + kwargs = dict(from_bounds=[0, 0, 1, 1], to_bounds=[-1, -1, 1, 1]) _check(Range(), [-1, -1], [[-3, -3]], **kwargs) _check(Range(), [0, 0], [[-1, -1]], **kwargs) @@ -122,12 +122,12 @@ def test_scale_glsl(): def test_range_glsl(): - assert Range(from_range=[-1, -1, 1, 1]).glsl('x') + assert Range(from_bounds=[-1, -1, 1, 1]).glsl('x') expected = ('u_to.xy + (u_to.zw - u_to.xy) * (x - u_from.xy) / ' '(u_from.zw - u_from.xy)') - r = Range(to_range=['u_to.xy', 'u_to.zw']) - assert expected in r.glsl('x', from_range=['u_from.xy', 'u_from.zw']) + r = Range(to_bounds=['u_to.xy', 'u_to.zw']) + assert expected in r.glsl('x', from_bounds=['u_from.xy', 'u_from.zw']) def test_clip_glsl(): @@ -193,7 +193,7 @@ def test_transform_chain_two(array): def test_transform_chain_complete(array): t = TransformChain([Scale(scale=.5), Scale(scale=2.)]) - t.add([Range(from_range=[-3, -3, 1, 1]), + t.add([Range(from_bounds=[-3, -3, 1, 1]), GPU(), Clip(), Subplot(shape='u_shape', index='a_box'), diff --git a/phy/plot/transform.py b/phy/plot/transform.py index dccea5d81..a57373f33 100644 --- a/phy/plot/transform.py +++ b/phy/plot/transform.py @@ -63,6 +63,7 @@ def indent(text): def _glslify_pair(p): + """GLSL-ify either a string identifier (vec2) or a pair of numbers.""" if isinstance(p, string_types): return p elif len(p) == 2: @@ -72,6 +73,8 @@ def _glslify_pair(p): def _glslify_range(r): + """GLSL-ify either a pair of string identifiers (vec2) or a pair of + pairs of numbers.""" if len(r) == 2: assert isinstance(r[0], string_types) assert isinstance(r[1], string_types) @@ -126,50 +129,39 @@ def glsl(self, var, scale=None): class Range(BaseTransform): - def apply(self, arr, from_range=None, to_range=None): - if to_range is None: - to_range = [-1, -1, 1, 1] - - f0 = np.asarray(from_range[:2]) - f1 = np.asarray(from_range[2:]) - t0 = np.asarray(to_range[:2]) - t1 = np.asarray(to_range[2:]) + def apply(self, arr, from_bounds=None, to_bounds=(-1, -1, 1, 1)): + f0 = np.asarray(from_bounds[:2]) + f1 = np.asarray(from_bounds[2:]) + t0 = np.asarray(to_bounds[:2]) + t1 = np.asarray(to_bounds[2:]) return t0 + (t1 - t0) * (arr - f0) / (f1 - f0) - def glsl(self, var, from_range=None, to_range=None): + def glsl(self, var, from_bounds=None, to_bounds=(-1, -1, 1, 1)): assert var - if to_range is None: - to_range = [-1, -1, 1, 1] - from_range = _glslify_range(from_range) - to_range = _glslify_range(to_range) + from_bounds = _glslify_range(from_bounds) + to_bounds = _glslify_range(to_bounds) return """ {var} = {t0} + ({t1} - {t0}) * ({var} - {f0}) / ({f1} - {f0}); """.format(var=var, - f0=from_range[0], f1=from_range[1], - t0=to_range[0], t1=to_range[1], + f0=from_bounds[0], f1=from_bounds[1], + t0=to_bounds[0], t1=to_bounds[1], ) class Clip(BaseTransform): - def apply(self, arr, bounds=None): - if bounds is None: - bounds = [-1, -1, 1, 1] - - xymin = np.asarray(bounds[:2]) - xymax = np.asarray(bounds[2:]) - index = ((arr[:, 0] >= xymin[0]) & - (arr[:, 1] >= xymin[1]) & - (arr[:, 0] <= xymax[0]) & - (arr[:, 1] <= xymax[1])) + def apply(self, arr, bounds=(-1, -1, 1, 1)): + index = ((arr[:, 0] >= bounds[0]) & + (arr[:, 1] >= bounds[1]) & + (arr[:, 0] <= bounds[2]) & + (arr[:, 1] <= bounds[3])) return arr[index, ...] - def glsl(self, var, bounds=None): + def glsl(self, var, bounds=(-1, -1, 1, 1)): assert var - if bounds is None: - bounds = 'vec2(-1, -1)', 'vec2(1, 1)' + bounds = _glslify_range(bounds) return """ @@ -200,16 +192,16 @@ def get_range(self, shape=None, index=None): x = -1.0 + j * width y = +1.0 - (i + 1) * height - from_range = [-1, -1, 1, 1] - to_range = [x, y, x + width, y + height] + from_bounds = [-1, -1, 1, 1] + to_bounds = [x, y, x + width, y + height] - return from_range, to_range + return from_bounds, to_bounds def apply(self, arr, shape=None, index=None): - from_range, to_range = self.get_range(shape=shape, index=index) + from_bounds, to_bounds = self.get_range(shape=shape, index=index) return super(Subplot, self).apply(arr, - from_range=from_range, - to_range=to_range) + from_bounds=from_bounds, + to_bounds=to_bounds) def glsl(self, var, shape=None, index=None): assert var From 896ce017aaccf6606bdbd77b3cd9ebf5242bed5f Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 21 Oct 2015 22:40:24 +0200 Subject: [PATCH 0393/1059] WIP: refactor transforms --- phy/plot/base.py | 4 +- phy/plot/tests/test_base.py | 5 +- phy/plot/tests/test_transform.py | 24 ++++----- phy/plot/transform.py | 93 +++++++++++++------------------- 4 files changed, 54 insertions(+), 72 deletions(-) diff --git a/phy/plot/base.py b/phy/plot/base.py index f8c39dc5a..c1dda2a7a 100644 --- a/phy/plot/base.py +++ b/phy/plot/base.py @@ -158,8 +158,8 @@ def build_program(self, transforms=None): # Insert the interact's GLSL into the shaders. self.vertex, self.fragment = self.transform_chain.insert_glsl( self.vertex, self.fragment) - logger.debug("Vertex shader: \n%s", self.vertex) - logger.debug("Fragment shader: \n%s", self.fragment) + logger.log(5, "Vertex shader: \n%s", self.vertex) + logger.log(5, "Fragment shader: \n%s", self.fragment) self.program = gloo.Program(self.vertex, self.fragment) # Get the name of the variable that needs to be transformed. diff --git a/phy/plot/tests/test_base.py b/phy/plot/tests/test_base.py index a01e48de0..dbeb9623d 100644 --- a/phy/plot/tests/test_base.py +++ b/phy/plot/tests/test_base.py @@ -10,7 +10,8 @@ import numpy as np from ..base import BaseVisual, BaseInteract -from ..transform import Translate, Scale, Range, Clip, Subplot, GPU +from ..transform import (subplot_range, Translate, Scale, Range, + Clip, Subplot, GPU) #------------------------------------------------------------------------------ @@ -152,7 +153,7 @@ class TestInteract(BaseInteract): def __init__(self): super(TestInteract, self).__init__() - bounds = Subplot().get_range(shape=(2, 3), index=(1, 2))[1] + bounds = subplot_range(shape=(2, 3), index=(1, 2)) self.transforms = [Subplot(shape=(2, 3), index=(1, 2)), Clip(bounds=bounds), ] diff --git a/phy/plot/tests/test_transform.py b/phy/plot/tests/test_transform.py index 0390c2763..235bbb26c 100644 --- a/phy/plot/tests/test_transform.py +++ b/phy/plot/tests/test_transform.py @@ -13,7 +13,7 @@ from numpy.testing import assert_equal as ae from pytest import yield_fixture -from ..transform import (_glslify_range, +from ..transform import (_glslify, Translate, Scale, Range, Clip, Subplot, GPU, TransformChain, ) @@ -42,10 +42,10 @@ def _check(transform, array, expected, **kwargs): # Test utils #------------------------------------------------------------------------------ -def test_glslify_range(): - assert _glslify_range(('a', 'b')) == ('a', 'b') - assert _glslify_range((1, 2, 3, 4)) == ('vec2(1.000, 2.000)', - 'vec2(3.000, 4.000)') +def test_glslify(): + assert _glslify('a') == 'a', 'b' + assert _glslify((1, 2, 3, 4)) == 'vec4(1, 2, 3, 4)' + assert _glslify((1., 2.)) == 'vec2(1.0, 2.0)' #------------------------------------------------------------------------------ @@ -126,20 +126,20 @@ def test_range_glsl(): expected = ('u_to.xy + (u_to.zw - u_to.xy) * (x - u_from.xy) / ' '(u_from.zw - u_from.xy)') - r = Range(to_bounds=['u_to.xy', 'u_to.zw']) - assert expected in r.glsl('x', from_bounds=['u_from.xy', 'u_from.zw']) + r = Range(to_bounds='u_to') + assert expected in r.glsl('x', from_bounds='u_from') def test_clip_glsl(): expected = dedent(""" - if ((x.x < xymin.x) || - (x.y < xymin.y) || - (x.x > xymax.x) || - (x.y > xymax.y)) { + if ((x.x < b.x) || + (x.y < b.y) || + (x.x > b.z) || + (x.y > b.w)) { discard; } """).strip() - assert expected in Clip().glsl('x', bounds=['xymin', 'xymax']) + assert expected in Clip().glsl('x', bounds='b') def test_subplot_glsl(): diff --git a/phy/plot/transform.py b/phy/plot/transform.py index a57373f33..05ab17f9d 100644 --- a/phy/plot/transform.py +++ b/phy/plot/transform.py @@ -62,26 +62,29 @@ def indent(text): return '\n'.join(' ' + l.strip() for l in text.splitlines()) -def _glslify_pair(p): - """GLSL-ify either a string identifier (vec2) or a pair of numbers.""" - if isinstance(p, string_types): - return p - elif len(p) == 2: - s = 'vec2({:.3f}, {:.3f})' - return s.format(*p) - raise ValueError() # pragma: no cover - - -def _glslify_range(r): - """GLSL-ify either a pair of string identifiers (vec2) or a pair of - pairs of numbers.""" - if len(r) == 2: - assert isinstance(r[0], string_types) - assert isinstance(r[1], string_types) +def _glslify(r): + """Transform a string or a n-tuple to a valid GLSL expression.""" + if isinstance(r, string_types): return r - elif len(r) == 4: - return _glslify_pair(r[:2]), _glslify_pair(r[2:]) - raise ValueError() # pragma: no cover + else: + assert 2 <= len(r) <= 4 + return 'vec{}({})'.format(len(r), ', '.join(map(str, r))) + + +def subplot_range(shape=None, index=None): + i, j = index + n_rows, n_cols = shape + + assert 0 <= i <= n_rows - 1 + assert 0 <= j <= n_cols - 1 + + width = 2.0 / n_cols + height = 2.0 / n_rows + + x = -1.0 + j * width + y = +1.0 - (i + 1) * height + + return [x, y, x + width, y + height] #------------------------------------------------------------------------------ @@ -140,15 +143,12 @@ def apply(self, arr, from_bounds=None, to_bounds=(-1, -1, 1, 1)): def glsl(self, var, from_bounds=None, to_bounds=(-1, -1, 1, 1)): assert var - from_bounds = _glslify_range(from_bounds) - to_bounds = _glslify_range(to_bounds) + from_bounds = _glslify(from_bounds) + to_bounds = _glslify(to_bounds) - return """ - {var} = {t0} + ({t1} - {t0}) * ({var} - {f0}) / ({f1} - {f0}); - """.format(var=var, - f0=from_bounds[0], f1=from_bounds[1], - t0=to_bounds[0], t1=to_bounds[1], - ) + return ("{var} = {t}.xy + ({t}.zw - {t}.xy) * " + "({var} - {f}.xy) / ({f}.zw - {f}.xy);" + "").format(var=var, f=from_bounds, t=to_bounds) class Clip(BaseTransform): @@ -162,43 +162,24 @@ def apply(self, arr, bounds=(-1, -1, 1, 1)): def glsl(self, var, bounds=(-1, -1, 1, 1)): assert var - bounds = _glslify_range(bounds) + bounds = _glslify(bounds) return """ - if (({var}.x < {xymin}.x) || - ({var}.y < {xymin}.y) || - ({var}.x > {xymax}.x) || - ({var}.y > {xymax}.y)) {{ + if (({var}.x < {bounds}.x) || + ({var}.y < {bounds}.y) || + ({var}.x > {bounds}.z) || + ({var}.y > {bounds}.w)) {{ discard; }} - """.format(xymin=bounds[0], - xymax=bounds[1], - var=var, - ) + """.format(bounds=bounds, var=var) class Subplot(Range): """Assume that the from range is [-1, -1, 1, 1].""" - def get_range(self, shape=None, index=None): - i, j = index - n_rows, n_cols = shape - assert 0 <= i <= n_rows - 1 - assert 0 <= j <= n_cols - 1 - - width = 2.0 / n_cols - height = 2.0 / n_rows - - x = -1.0 + j * width - y = +1.0 - (i + 1) * height - - from_bounds = [-1, -1, 1, 1] - to_bounds = [x, y, x + width, y + height] - - return from_bounds, to_bounds - def apply(self, arr, shape=None, index=None): - from_bounds, to_bounds = self.get_range(shape=shape, index=index) + from_bounds = (-1, -1, 1, 1) + to_bounds = subplot_range(shape=shape, index=index) return super(Subplot, self).apply(arr, from_bounds=from_bounds, to_bounds=to_bounds) @@ -206,8 +187,8 @@ def apply(self, arr, shape=None, index=None): def glsl(self, var, shape=None, index=None): assert var - index = _glslify_pair(index) - shape = _glslify_pair(shape) + index = _glslify(index) + shape = _glslify(shape) snippet = """ float subplot_width = 2. / {shape}.y; From 598cef93711b3cd63ed2fabcf15376472f22e198 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 21 Oct 2015 22:51:51 +0200 Subject: [PATCH 0394/1059] Pre and post transforms --- phy/plot/tests/test_transform.py | 15 ++++++++++++++- phy/plot/transform.py | 21 +++++++++++++-------- 2 files changed, 27 insertions(+), 9 deletions(-) diff --git a/phy/plot/tests/test_transform.py b/phy/plot/tests/test_transform.py index 235bbb26c..190cba160 100644 --- a/phy/plot/tests/test_transform.py +++ b/phy/plot/tests/test_transform.py @@ -13,7 +13,7 @@ from numpy.testing import assert_equal as ae from pytest import yield_fixture -from ..transform import (_glslify, +from ..transform import (_glslify, BaseTransform, Translate, Scale, Range, Clip, Subplot, GPU, TransformChain, ) @@ -166,6 +166,19 @@ def test_transform_chain_empty(array): ae(t.apply(array), array) +def test_transform_chain_pre_post(array): + class MyTransform(BaseTransform): + def pre_transforms(self, key=None): + return [MyTransform(key=key - 1)] + + def post_transforms(self, key=None): + return [MyTransform(key=key + 1), MyTransform(key=key + 2)] + + t = TransformChain([Translate(), MyTransform(key=0), Scale()]) + expected = [None, -1, 0, 1, 2, None] + assert [getattr(p, 'key', None) for p in t.cpu_transforms] == expected + + def test_transform_chain_one(array): translate = Translate(translate=[1, 2]) t = TransformChain([translate]) diff --git a/phy/plot/transform.py b/phy/plot/transform.py index 05ab17f9d..74bce2b47 100644 --- a/phy/plot/transform.py +++ b/phy/plot/transform.py @@ -93,16 +93,17 @@ def subplot_range(shape=None, index=None): class BaseTransform(object): def __init__(self, **kwargs): + self.__dict__.update(kwargs) # Pass the constructor kwargs to the methods. self.apply = _wrap_apply(self.apply, **kwargs) self.glsl = _wrap_glsl(self.glsl, **kwargs) self.pre_transforms = _wrap_prepost(self.pre_transforms, **kwargs) self.post_transforms = _wrap_prepost(self.post_transforms, **kwargs) - def pre_transforms(self): + def pre_transforms(self, **kwargs): return [] - def post_transforms(self): + def post_transforms(self, **kwargs): return [] def apply(self, arr): @@ -219,7 +220,8 @@ class TransformChain(object): """A linear sequence of transforms that happen on the CPU and GPU.""" def __init__(self, transforms=None): self.transformed_var_name = None - self.transforms = transforms or [] + self.transforms = [] + self.add(transforms) def _index_of_gpu(self): classes = [t.__class__.__name__ for t in self.transforms] @@ -239,7 +241,14 @@ def gpu_transforms(self): def add(self, transforms): """Add some transforms.""" - self.transforms.extend(transforms) + for t in transforms: + if hasattr(t, 'pre_transforms'): + for p in t.pre_transforms(): + self.transforms.append(p) + self.transforms.append(t) + if hasattr(t, 'post_transforms'): + for p in t.post_transforms(): + self.transforms.append(p) def get(self, class_name): """Get a transform in the chain from its name.""" @@ -247,10 +256,6 @@ def get(self, class_name): if transform.__class__.__name__ == class_name: return transform - def _extend(self): - """Insert pre- and post- transforms into the chain.""" - # TODO - def apply(self, arr): """Apply all CPU transforms on an array.""" for t in self.cpu_transforms: From d59e60f31216e933bb5011a1fcd6690ad5493143 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 22 Oct 2015 08:56:54 +0200 Subject: [PATCH 0395/1059] Add some events in BaseInteract --- phy/plot/base.py | 12 ++++++++++++ phy/plot/tests/test_base.py | 1 + 2 files changed, 13 insertions(+) diff --git a/phy/plot/base.py b/phy/plot/base.py index c1dda2a7a..643d660ca 100644 --- a/phy/plot/base.py +++ b/phy/plot/base.py @@ -225,6 +225,10 @@ class BaseInteract(object): def __init__(self): self._canvas = None + @property + def size(self): + return self._canvas.size if self._canvas else None + def attach(self, canvas): """Attach the interact to a canvas.""" self._canvas = canvas @@ -235,7 +239,9 @@ def on_draw(): # Programs that are already built are skipped. self.build_programs() + canvas.connect(self.on_resize) canvas.connect(self.on_mouse_move) + canvas.connect(self.on_mouse_wheel) canvas.connect(self.on_key_press) def iter_attached_visuals(self): @@ -255,8 +261,14 @@ def build_programs(self): if not visual.program: visual.build_program(self.transforms) + def on_resize(self, event): + pass + def on_mouse_move(self, event): pass + def on_mouse_wheel(self, event): + pass + def on_key_press(self, event): pass diff --git a/phy/plot/tests/test_base.py b/phy/plot/tests/test_base.py index dbeb9623d..d5c7f8d17 100644 --- a/phy/plot/tests/test_base.py +++ b/phy/plot/tests/test_base.py @@ -111,6 +111,7 @@ def set_data(self): interact.attach(canvas) canvas.show() + assert interact.size[0] >= 1 qtbot.waitForWindowShown(canvas.native) # qtbot.stop() From fe7e40b70dd7c3c6693e49cb1af44d8b12b5134c Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 22 Oct 2015 09:07:41 +0200 Subject: [PATCH 0396/1059] WIP: PanZoom --- phy/plot/base.py | 8 +- phy/plot/panzoom.py | 399 +++++++++++++++++++++++++++++++++ phy/plot/tests/test_panzoom.py | 54 +++++ 3 files changed, 458 insertions(+), 3 deletions(-) create mode 100644 phy/plot/panzoom.py create mode 100644 phy/plot/tests/test_panzoom.py diff --git a/phy/plot/base.py b/phy/plot/base.py index 643d660ca..60fb45348 100644 --- a/phy/plot/base.py +++ b/phy/plot/base.py @@ -246,9 +246,11 @@ def on_draw(): def iter_attached_visuals(self): """Yield all visuals attached to that interact in the canvas.""" - for visual in self._canvas.emit_('get_visual_for_interact', self.name): - if visual: - yield visual + if self._canvas: + for visual in self._canvas.emit_('get_visual_for_interact', + self.name): + if visual: + yield visual def build_programs(self): """Build the programs of all attached visuals. diff --git a/phy/plot/panzoom.py b/phy/plot/panzoom.py new file mode 100644 index 000000000..7f6431f2d --- /dev/null +++ b/phy/plot/panzoom.py @@ -0,0 +1,399 @@ +# -*- coding: utf-8 -*- + +"""Pan & zoom transform.""" + + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +import math + +import numpy as np + +from .base import BaseInteract +from phy.utils._types import _as_array + + +#------------------------------------------------------------------------------ +# PanZoom class +#------------------------------------------------------------------------------ + +class PanZoom(BaseInteract): + """Pan and zoom interact.""" + + _default_zoom_coeff = 1.5 + _default_wheel_coeff = .1 + _arrows = ('Left', 'Right', 'Up', 'Down') + _pm = ('+', '-') + + def __init__(self, + aspect=1.0, + pan=(0.0, 0.0), zoom=(1.0, 1.0), + zmin=1e-5, zmax=1e5, + xmin=None, xmax=None, + ymin=None, ymax=None, + ): + """ + Initialize the transform. + + Parameters + ---------- + + aspect : float (default is None) + Indicate what is the aspect ratio of the object displayed. This is + necessary to convert pixel drag move in object space coordinates. + + pan : float, float (default is 0, 0) + Initial translation + + zoom : float, float (default is 1) + Initial zoom level + + zmin : float (default is 0.01) + Minimum zoom level + + zmax : float (default is 1000) + Maximum zoom level + """ + super(PanZoom, self).__init__() + + self._aspect = aspect + + self._zmin = zmin + self._zmax = zmax + self._xmin = xmin + self._xmax = xmax + self._ymin = ymin + self._ymax = ymax + + self._pan = np.array(pan) + self._zoom = np.array(zoom) + + self._zoom_coeff = self._default_zoom_coeff + self._wheel_coeff = self._default_wheel_coeff + + self._zoom_to_pointer = True + self._canvas_aspect = np.ones(2) + + # Various properties + # ------------------------------------------------------------------------- + + def is_attached(self): + """Whether the transform is attached to a canvas.""" + return self._canvas is not None + + @property + def aspect(self): + """Aspect (width/height).""" + return self._aspect + + @aspect.setter + def aspect(self, value): + self._aspect = value + + # xmin/xmax + # ------------------------------------------------------------------------- + + @property + def xmin(self): + """Minimum x allowed for pan.""" + return self._xmin + + @xmin.setter + def xmin(self, value): + if self._xmax is not None: + self._xmin = np.minimum(value, self._xmax) + else: + self._xmin = value + + @property + def xmax(self): + """Maximum x allowed for pan.""" + return self._xmax + + @xmax.setter + def xmax(self, value): + if self._xmin is not None: + self._xmax = np.maximum(value, self._xmin) + else: + self._xmax = value + + # ymin/ymax + # ------------------------------------------------------------------------- + + @property + def ymin(self): + """Minimum y allowed for pan.""" + return self._ymin + + @ymin.setter + def ymin(self, value): + if self._ymax is not None: + self._ymin = min(value, self._ymax) + else: + self._ymin = value + + @property + def ymax(self): + """Maximum y allowed for pan.""" + return self._ymax + + @ymax.setter + def ymax(self, value): + if self._ymin is not None: + self._ymax = max(value, self._ymin) + else: + self._ymax = value + + # zmin/zmax + # ------------------------------------------------------------------------- + + @property + def zmin(self): + """Minimum zoom level.""" + return self._zmin + + @zmin.setter + def zmin(self, value): + self._zmin = min(value, self._zmax) + + @property + def zmax(self): + """Maximal zoom level.""" + return self._zmax + + @zmax.setter + def zmax(self, value): + self._zmax = max(value, self._zmin) + + # Internal methods + # ------------------------------------------------------------------------- + + def _iter_programs(self): + for visual in self.iter_attached_visuals(): + yield visual.program + + def _apply_pan_zoom(self): + zoom = self._zoom_aspect() + for program in self._iter_programs(): + program['u_pan'] = self._pan + program['u_zoom'] = zoom + + def _zoom_aspect(self, zoom=None): + if zoom is None: + zoom = self._zoom + zoom = _as_array(zoom) + if self._aspect is not None: + aspect = self._canvas_aspect * self._aspect + else: + aspect = 1. + return zoom * aspect + + def _normalize(self, x_y, restrict_to_box=True): + x_y = np.asarray(x_y, dtype=np.float32) + size = np.asarray(self.size, dtype=np.float32) + pos = x_y / (size / 2.) - 1 + return pos + + def _constrain_pan(self): + """Constrain bounding box.""" + if self.xmin is not None and self._xmax is not None: + p0 = self.xmin + 1. / self._zoom[0] + p1 = self.xmax - 1. / self._zoom[0] + p0, p1 = min(p0, p1), max(p0, p1) + self._pan[0] = np.clip(self._pan[0], p0, p1) + + if self.ymin is not None and self._ymax is not None: + p0 = self.ymin + 1. / self._zoom[1] + p1 = self.ymax - 1. / self._zoom[1] + p0, p1 = min(p0, p1), max(p0, p1) + self._pan[1] = np.clip(self._pan[1], p0, p1) + + def _constrain_zoom(self): + """Constrain bounding box.""" + if self.xmin is not None: + self._zoom[0] = max(self._zoom[0], + 1. / (self._pan[0] - self.xmin)) + if self.xmax is not None: + self._zoom[0] = max(self._zoom[0], + 1. / (self.xmax - self._pan[0])) + + if self.ymin is not None: + self._zoom[1] = max(self._zoom[1], + 1. / (self._pan[1] - self.ymin)) + if self.ymax is not None: + self._zoom[1] = max(self._zoom[1], + 1. / (self.ymax - self._pan[1])) + + # Pan and zoom + # ------------------------------------------------------------------------- + + @property + def pan(self): + """Pan translation.""" + return list(self._pan) + + @pan.setter + def pan(self, value): + """Pan translation.""" + assert len(value) == 2 + self._pan[:] = value + self._constrain_pan() + self._apply_pan_zoom() + + @property + def zoom(self): + """Zoom level.""" + return list(self._zoom) + + @zoom.setter + def zoom(self, value): + """Zoom level.""" + if isinstance(value, (int, float)): + value = (value, value) + assert len(value) == 2 + self._zoom = np.clip(value, self._zmin, self._zmax) + if not self.is_attached: + return + + # Constrain bounding box. + self._constrain_pan() + self._constrain_zoom() + + self._apply_pan_zoom() + + def pan_delta(self, d): + + dx, dy = d + + pan_x, pan_y = self.pan + zoom_x, zoom_y = self._zoom_aspect(self._zoom) + + self.pan = (pan_x + dx / zoom_x, + pan_y + dy / zoom_y) + + self._canvas.update() + + def zoom_delta(self, d, p, c=1.): + dx, dy = d + x0, y0 = p + + pan_x, pan_y = self._pan + zoom_x, zoom_y = self._zoom + zoom_x_new, zoom_y_new = (zoom_x * math.exp(c * self._zoom_coeff * dx), + zoom_y * math.exp(c * self._zoom_coeff * dy)) + + zoom_x_new = max(min(zoom_x_new, self._zmax), self._zmin) + zoom_y_new = max(min(zoom_y_new, self._zmax), self._zmin) + + self.zoom = zoom_x_new, zoom_y_new + + if self._zoom_to_pointer: + zoom_x, zoom_y = self._zoom_aspect((zoom_x, + zoom_y)) + zoom_x_new, zoom_y_new = self._zoom_aspect((zoom_x_new, + zoom_y_new)) + + self.pan = (pan_x - x0 * (1. / zoom_x - 1. / zoom_x_new), + pan_y + y0 * (1. / zoom_y - 1. / zoom_y_new)) + + self._canvas.update() + + # Event callbacks + # ------------------------------------------------------------------------- + + keyboard_shortcuts = { + 'pan': ('left click and drag', 'arrows'), + 'zoom': ('right click and drag', '+', '-'), + 'reset': 'r', + } + + def _set_canvas_aspect(self): + w, h = self.size + aspect = max(1., w / max(float(h), 1.)) + if aspect > 1.0: + self._canvas_aspect = np.array([1.0 / aspect, 1.0]) + else: + self._canvas_aspect = np.array([1.0, aspect / 1.0]) + + def on_resize(self, event): + """Resize event.""" + self._set_canvas_aspect() + # Update zoom level + self.zoom = self._zoom + + def on_mouse_move(self, event): + """Pan and zoom with the mouse.""" + if event.modifiers: + return + if event.is_dragging: + x0, y0 = self._normalize(event.press_event.pos) + x1, y1 = self._normalize(event.last_event.pos, False) + x, y = self._normalize(event.pos, False) + dx, dy = x - x1, -(y - y1) + if event.button == 1: + self.pan_delta((dx, dy)) + elif event.button == 2: + c = np.sqrt(self.size[0]) * .03 + self.zoom_delta((dx, dy), (x0, y0), c=c) + + def on_mouse_wheel(self, event): + """Zoom with the mouse wheel.""" + if event.modifiers: + return + dx = np.sign(event.delta[1]) * self._wheel_coeff + # Zoom toward the mouse pointer. + x0, y0 = self._normalize(event.pos) + self.zoom_delta((dx, dx), (x0, y0)) + + def _zoom_keyboard(self, key): + k = .05 + if key == '-': + k = -k + self.zoom_delta((k, k), (0, 0)) + + def _pan_keyboard(self, key): + k = .1 / self.zoom + if key == 'Left': + self.pan += (+k[0], +0) + elif key == 'Right': + self.pan += (-k[0], +0) + elif key == 'Down': + self.pan += (+0, +k[1]) + elif key == 'Up': + self.pan += (+0, -k[1]) + self._canvas.update() + + def _reset_keyboard(self): + self.pan = (0., 0.) + self.zoom = 1. + self._canvas.update() + + def on_key_press(self, event): + """Key press event.""" + + # Zooming with the keyboard. + key = event.key + if event.modifiers: + return + + # Pan. + if key in self._arrows: + self._pan_keyboard(key) + + # Zoom. + if key in self._pm: + self._zoom_keyboard(key) + + # Reset with 'R'. + if key == 'R': + self._reset_keyboard() + + # Canvas methods + # ------------------------------------------------------------------------- + + def attach(self, canvas): + """Attach this tranform to a canvas.""" + super(PanZoom, self).attach(canvas) + self._set_canvas_aspect() diff --git a/phy/plot/tests/test_panzoom.py b/phy/plot/tests/test_panzoom.py new file mode 100644 index 000000000..c1e57fd80 --- /dev/null +++ b/phy/plot/tests/test_panzoom.py @@ -0,0 +1,54 @@ +# -*- coding: utf-8 -*- + +"""Test panzoom.""" + + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +import numpy as np + +from ..panzoom import PanZoom + + +#------------------------------------------------------------------------------ +# Test panzoom +#------------------------------------------------------------------------------ + +def test_panzoom_basic_attrs(): + panzoom = PanZoom() + + assert not panzoom.is_attached() + + # Aspect. + assert panzoom.aspect == 1. + panzoom.aspect = 2. + assert panzoom.aspect == 2. + + # Constraints. + for name in ('xmin', 'xmax', 'ymin', 'ymax'): + assert getattr(panzoom, name) is None + setattr(panzoom, name, 1.) + assert getattr(panzoom, name) == 1. + + for name, v in (('zmin', 1e-5), ('zmax', 1e5)): + assert getattr(panzoom, name) == v + setattr(panzoom, name, v * 2) + assert getattr(panzoom, name) == v * 2 + + assert list(panzoom.iter_attached_visuals()) == [] + + +def test_panzoom_basic_pan_zoom(): + panzoom = PanZoom() + + # Pan. + assert panzoom.pan == [0., 0.] + panzoom.pan = (1., -1.) + assert panzoom.pan == [1., -1.] + + # Zoom. + assert panzoom.zoom == [1., 1.] + panzoom.zoom = (2., .5) + assert panzoom.zoom == [2., .5] From 16aae015de5ee715660909f04e9d75125c9e3cd5 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 22 Oct 2015 09:13:40 +0200 Subject: [PATCH 0397/1059] WIP: test panzoom --- phy/plot/panzoom.py | 11 +++++++---- phy/plot/tests/test_panzoom.py | 16 ++++++++++++++++ 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/phy/plot/panzoom.py b/phy/plot/panzoom.py index 7f6431f2d..802f041d2 100644 --- a/phy/plot/panzoom.py +++ b/phy/plot/panzoom.py @@ -226,6 +226,10 @@ def _constrain_zoom(self): self._zoom[1] = max(self._zoom[1], 1. / (self.ymax - self._pan[1])) + def update(self): + if self.is_attached(): + self._canvas.update() + # Pan and zoom # ------------------------------------------------------------------------- @@ -272,10 +276,9 @@ def pan_delta(self, d): self.pan = (pan_x + dx / zoom_x, pan_y + dy / zoom_y) + self.update() - self._canvas.update() - - def zoom_delta(self, d, p, c=1.): + def zoom_delta(self, d, p=(0., 0.), c=1.): dx, dy = d x0, y0 = p @@ -298,7 +301,7 @@ def zoom_delta(self, d, p, c=1.): self.pan = (pan_x - x0 * (1. / zoom_x - 1. / zoom_x_new), pan_y + y0 * (1. / zoom_y - 1. / zoom_y_new)) - self._canvas.update() + self.update() # Event callbacks # ------------------------------------------------------------------------- diff --git a/phy/plot/tests/test_panzoom.py b/phy/plot/tests/test_panzoom.py index c1e57fd80..3e974c253 100644 --- a/phy/plot/tests/test_panzoom.py +++ b/phy/plot/tests/test_panzoom.py @@ -52,3 +52,19 @@ def test_panzoom_basic_pan_zoom(): assert panzoom.zoom == [1., 1.] panzoom.zoom = (2., .5) assert panzoom.zoom == [2., .5] + panzoom.zoom = (1., 1.) + + # Pan delta. + panzoom.pan_delta((-1., 1.)) + assert panzoom.pan == [0., 0.] + + # Zoom delta. + panzoom.zoom_delta((1., 1.)) + assert panzoom.zoom[0] > 2 + assert panzoom.zoom[0] == panzoom.zoom[1] + panzoom.zoom = (1., 1.) + + # Zoom delta. + panzoom.zoom_delta((2., 3.), (.5, .5)) + assert panzoom.zoom[0] > 2 + assert panzoom.zoom[1] > 3 * panzoom.zoom[0] From 24459d3f9eca9bc9b821f420a30fd53ad6c88c0d Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 22 Oct 2015 09:24:48 +0200 Subject: [PATCH 0398/1059] WIP: test panzoom --- phy/plot/panzoom.py | 53 ++++++++-------- phy/plot/tests/test_base.py | 5 +- phy/plot/tests/test_panzoom.py | 109 +++++++++++++++++++++++---------- 3 files changed, 110 insertions(+), 57 deletions(-) diff --git a/phy/plot/panzoom.py b/phy/plot/panzoom.py index 802f041d2..fddea8a7b 100644 --- a/phy/plot/panzoom.py +++ b/phy/plot/panzoom.py @@ -172,7 +172,8 @@ def zmax(self, value): def _iter_programs(self): for visual in self.iter_attached_visuals(): - yield visual.program + if visual.program: + yield visual.program def _apply_pan_zoom(self): zoom = self._zoom_aspect() @@ -320,14 +321,39 @@ def _set_canvas_aspect(self): else: self._canvas_aspect = np.array([1.0, aspect / 1.0]) + def _zoom_keyboard(self, key): + k = .05 + if key == '-': + k = -k + self.zoom_delta((k, k), (0, 0)) + + def _pan_keyboard(self, key): + k = .1 / self.zoom + if key == 'Left': + self.pan += (+k[0], +0) + elif key == 'Right': + self.pan += (-k[0], +0) + elif key == 'Down': + self.pan += (+0, +k[1]) + elif key == 'Up': + self.pan += (+0, -k[1]) + self._canvas.update() + + def _reset_keyboard(self): + self.pan = (0., 0.) + self.zoom = 1. + self._canvas.update() + def on_resize(self, event): """Resize event.""" + super(PanZoom, self).on_resize(event) self._set_canvas_aspect() # Update zoom level self.zoom = self._zoom def on_mouse_move(self, event): """Pan and zoom with the mouse.""" + super(PanZoom, self).on_mouse_move(event) if event.modifiers: return if event.is_dragging: @@ -343,6 +369,7 @@ def on_mouse_move(self, event): def on_mouse_wheel(self, event): """Zoom with the mouse wheel.""" + super(PanZoom, self).on_mouse_wheel(event) if event.modifiers: return dx = np.sign(event.delta[1]) * self._wheel_coeff @@ -350,31 +377,9 @@ def on_mouse_wheel(self, event): x0, y0 = self._normalize(event.pos) self.zoom_delta((dx, dx), (x0, y0)) - def _zoom_keyboard(self, key): - k = .05 - if key == '-': - k = -k - self.zoom_delta((k, k), (0, 0)) - - def _pan_keyboard(self, key): - k = .1 / self.zoom - if key == 'Left': - self.pan += (+k[0], +0) - elif key == 'Right': - self.pan += (-k[0], +0) - elif key == 'Down': - self.pan += (+0, +k[1]) - elif key == 'Up': - self.pan += (+0, -k[1]) - self._canvas.update() - - def _reset_keyboard(self): - self.pan = (0., 0.) - self.zoom = 1. - self._canvas.update() - def on_key_press(self, event): """Key press event.""" + super(PanZoom, self).on_key_press(event) # Zooming with the keyboard. key = event.key diff --git a/phy/plot/tests/test_base.py b/phy/plot/tests/test_base.py index d5c7f8d17..b86d7d258 100644 --- a/phy/plot/tests/test_base.py +++ b/phy/plot/tests/test_base.py @@ -56,11 +56,14 @@ class TestVisual(BaseVisual): """ gl_primitive_type = 'lines' + def __init__(self): + super(TestVisual, self).__init__() + self.set_data() + def set_data(self): self.data['a_position'] = [[-1, 0], [1, 0]] v = TestVisual() - v.set_data() # We need to build the program explicitly when there is no interact. v.attach(canvas) v.build_program() diff --git a/phy/plot/tests/test_panzoom.py b/phy/plot/tests/test_panzoom.py index 3e974c253..fd81d59eb 100644 --- a/phy/plot/tests/test_panzoom.py +++ b/phy/plot/tests/test_panzoom.py @@ -7,64 +7,109 @@ # Imports #------------------------------------------------------------------------------ -import numpy as np +# import numpy as np +from pytest import yield_fixture +from ..base import BaseVisual from ..panzoom import PanZoom +#------------------------------------------------------------------------------ +# Fixtures +#------------------------------------------------------------------------------ + +class MyTestVisual(BaseVisual): + vertex = """ + attribute vec2 a_position; + void main() { + gl_Position = transform(a_position); + } + """ + fragment = """ + void main() { + gl_FragColor = vec4(1, 1, 1, 1); + } + """ + gl_primitive_type = 'lines' + + def __init__(self): + super(MyTestVisual, self).__init__() + self.set_data() + + def set_data(self): + self.data['a_position'] = [[-1, 0], [1, 0]] + + +@yield_fixture +def visual(): + yield MyTestVisual() + + #------------------------------------------------------------------------------ # Test panzoom #------------------------------------------------------------------------------ -def test_panzoom_basic_attrs(): - panzoom = PanZoom() +def test_pz_basic_attrs(): + pz = PanZoom() - assert not panzoom.is_attached() + assert not pz.is_attached() # Aspect. - assert panzoom.aspect == 1. - panzoom.aspect = 2. - assert panzoom.aspect == 2. + assert pz.aspect == 1. + pz.aspect = 2. + assert pz.aspect == 2. # Constraints. for name in ('xmin', 'xmax', 'ymin', 'ymax'): - assert getattr(panzoom, name) is None - setattr(panzoom, name, 1.) - assert getattr(panzoom, name) == 1. + assert getattr(pz, name) is None + setattr(pz, name, 1.) + assert getattr(pz, name) == 1. for name, v in (('zmin', 1e-5), ('zmax', 1e5)): - assert getattr(panzoom, name) == v - setattr(panzoom, name, v * 2) - assert getattr(panzoom, name) == v * 2 + assert getattr(pz, name) == v + setattr(pz, name, v * 2) + assert getattr(pz, name) == v * 2 - assert list(panzoom.iter_attached_visuals()) == [] + assert list(pz.iter_attached_visuals()) == [] -def test_panzoom_basic_pan_zoom(): - panzoom = PanZoom() +def test_pz_basic_pan_zoom(): + pz = PanZoom() # Pan. - assert panzoom.pan == [0., 0.] - panzoom.pan = (1., -1.) - assert panzoom.pan == [1., -1.] + assert pz.pan == [0., 0.] + pz.pan = (1., -1.) + assert pz.pan == [1., -1.] # Zoom. - assert panzoom.zoom == [1., 1.] - panzoom.zoom = (2., .5) - assert panzoom.zoom == [2., .5] - panzoom.zoom = (1., 1.) + assert pz.zoom == [1., 1.] + pz.zoom = (2., .5) + assert pz.zoom == [2., .5] + pz.zoom = (1., 1.) # Pan delta. - panzoom.pan_delta((-1., 1.)) - assert panzoom.pan == [0., 0.] + pz.pan_delta((-1., 1.)) + assert pz.pan == [0., 0.] # Zoom delta. - panzoom.zoom_delta((1., 1.)) - assert panzoom.zoom[0] > 2 - assert panzoom.zoom[0] == panzoom.zoom[1] - panzoom.zoom = (1., 1.) + pz.zoom_delta((1., 1.)) + assert pz.zoom[0] > 2 + assert pz.zoom[0] == pz.zoom[1] + pz.zoom = (1., 1.) # Zoom delta. - panzoom.zoom_delta((2., 3.), (.5, .5)) - assert panzoom.zoom[0] > 2 - assert panzoom.zoom[1] > 3 * panzoom.zoom[0] + pz.zoom_delta((2., 3.), (.5, .5)) + assert pz.zoom[0] > 2 + assert pz.zoom[1] > 3 * pz.zoom[0] + + +def test_pz_attached(qtbot, canvas, visual): + + visual.attach(canvas) + + pz = PanZoom() + pz.attach(canvas) + + canvas.show() + qtbot.waitForWindowShown(canvas.native) + # qtbot.stop() From e3070280d82bad686d42420b3d96dab9211673c1 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 22 Oct 2015 09:31:17 +0200 Subject: [PATCH 0399/1059] WIP --- phy/plot/base.py | 6 ++---- phy/plot/panzoom.py | 1 + phy/plot/tests/test_base.py | 4 +--- 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/phy/plot/base.py b/phy/plot/base.py index 60fb45348..6f55e7353 100644 --- a/phy/plot/base.py +++ b/phy/plot/base.py @@ -87,7 +87,7 @@ def set_data(self): """ pass - def attach(self, canvas, interact='base'): + def attach(self, canvas, interact='BaseInteract'): """Attach the visual to a canvas. The interact's name can be specified. The interact's transforms @@ -215,11 +215,9 @@ class BaseInteract(object): Derived classes must: - * Define a unique `name` * Define a list of `transforms` """ - name = 'base' transforms = None def __init__(self): @@ -248,7 +246,7 @@ def iter_attached_visuals(self): """Yield all visuals attached to that interact in the canvas.""" if self._canvas: for visual in self._canvas.emit_('get_visual_for_interact', - self.name): + self.__class__.__name__): if visual: yield visual diff --git a/phy/plot/panzoom.py b/phy/plot/panzoom.py index fddea8a7b..7b309ff0f 100644 --- a/phy/plot/panzoom.py +++ b/phy/plot/panzoom.py @@ -22,6 +22,7 @@ class PanZoom(BaseInteract): """Pan and zoom interact.""" + name = 'panzoom' _default_zoom_coeff = 1.5 _default_wheel_coeff = .1 _arrows = ('Left', 'Right', 'Up', 'Down') diff --git a/phy/plot/tests/test_base.py b/phy/plot/tests/test_base.py index b86d7d258..d797c8dcc 100644 --- a/phy/plot/tests/test_base.py +++ b/phy/plot/tests/test_base.py @@ -153,8 +153,6 @@ def set_data(self): ] class TestInteract(BaseInteract): - name = 'test' - def __init__(self): super(TestInteract, self).__init__() bounds = subplot_range(shape=(2, 3), index=(1, 2)) @@ -164,7 +162,7 @@ def __init__(self): # We attach the visual to the canvas. By default, a BaseInteract is used. v = TestVisual() - v.attach(canvas, 'test') + v.attach(canvas, 'TestInteract') # Base interact (no transform). interact = TestInteract() From ca7cb66fdf10e52d5333efe1f2eb688af5f319e1 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 22 Oct 2015 11:31:48 +0200 Subject: [PATCH 0400/1059] WIP: attach pan zoom --- phy/plot/base.py | 12 ++++++++++-- phy/plot/panzoom.py | 5 +++++ phy/plot/tests/test_panzoom.py | 5 ++++- phy/plot/transform.py | 1 + 4 files changed, 20 insertions(+), 3 deletions(-) diff --git a/phy/plot/base.py b/phy/plot/base.py index 6f55e7353..620e483cd 100644 --- a/phy/plot/base.py +++ b/phy/plot/base.py @@ -135,7 +135,7 @@ def on_mouse_move(self, e): def on_key_press(self, e): pass - def build_program(self, transforms=None): + def build_program(self, transforms=None, vertex_decl='', frag_decl=''): """Create the gloo program by specifying the transforms given by the optionally-attached interact. @@ -158,6 +158,9 @@ def build_program(self, transforms=None): # Insert the interact's GLSL into the shaders. self.vertex, self.fragment = self.transform_chain.insert_glsl( self.vertex, self.fragment) + # Insert shader declarations. + self.vertex = vertex_decl + '\n' + self.vertex + self.fragment = frag_decl + '\n' + self.fragment logger.log(5, "Vertex shader: \n%s", self.vertex) logger.log(5, "Fragment shader: \n%s", self.fragment) self.program = gloo.Program(self.vertex, self.fragment) @@ -219,6 +222,8 @@ class BaseInteract(object): """ transforms = None + vertex_decl = '' + frag_decl = '' def __init__(self): self._canvas = None @@ -259,7 +264,10 @@ def build_programs(self): """ for visual in self.iter_attached_visuals(): if not visual.program: - visual.build_program(self.transforms) + visual.build_program(self.transforms, + vertex_decl=self.vertex_decl, + frag_decl=self.frag_decl, + ) def on_resize(self, event): pass diff --git a/phy/plot/panzoom.py b/phy/plot/panzoom.py index 7b309ff0f..e854b5ddb 100644 --- a/phy/plot/panzoom.py +++ b/phy/plot/panzoom.py @@ -12,6 +12,7 @@ import numpy as np from .base import BaseInteract +from .transform import Translate, Scale from phy.utils._types import _as_array @@ -77,6 +78,10 @@ def __init__(self, self._zoom_to_pointer = True self._canvas_aspect = np.ones(2) + self.transforms = [Translate(translate='u_pan'), + Scale(scale='u_zoom')] + self.vertex_decl = 'uniform vec2 u_pan;\nuniform vec2 u_zoom;\n' + # Various properties # ------------------------------------------------------------------------- diff --git a/phy/plot/tests/test_panzoom.py b/phy/plot/tests/test_panzoom.py index fd81d59eb..eda84fdf4 100644 --- a/phy/plot/tests/test_panzoom.py +++ b/phy/plot/tests/test_panzoom.py @@ -12,6 +12,7 @@ from ..base import BaseVisual from ..panzoom import PanZoom +from ..transform import GPU #------------------------------------------------------------------------------ @@ -34,6 +35,7 @@ class MyTestVisual(BaseVisual): def __init__(self): super(MyTestVisual, self).__init__() + self.transforms = [GPU()] self.set_data() def set_data(self): @@ -105,7 +107,8 @@ def test_pz_basic_pan_zoom(): def test_pz_attached(qtbot, canvas, visual): - visual.attach(canvas) + visual.attach(canvas, 'PanZoom') + visual.show() pz = PanZoom() pz.attach(canvas) diff --git a/phy/plot/transform.py b/phy/plot/transform.py index 74bce2b47..43b160059 100644 --- a/phy/plot/transform.py +++ b/phy/plot/transform.py @@ -115,6 +115,7 @@ def glsl(self, var): class Translate(BaseTransform): def apply(self, arr, translate=None): + assert isinstance(arr, np.ndarray) return arr + np.asarray(translate) def glsl(self, var, translate=None): From f0c83cfb25199699ba0f793b3806d8ec87f33a6e Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 22 Oct 2015 13:20:27 +0200 Subject: [PATCH 0401/1059] WIP: test pan zoom --- phy/plot/base.py | 19 ++++++++++++++----- phy/plot/panzoom.py | 11 +++-------- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/phy/plot/base.py b/phy/plot/base.py index 620e483cd..b7165dbb3 100644 --- a/phy/plot/base.py +++ b/phy/plot/base.py @@ -165,16 +165,23 @@ def build_program(self, transforms=None, vertex_decl='', frag_decl=''): logger.log(5, "Fragment shader: \n%s", self.fragment) self.program = gloo.Program(self.vertex, self.fragment) + if not self.transform_chain.transformed_var_name: + logger.debug("No transformed variable has been found.") + # Upload the data if necessary. + self._upload_data() + + def _upload_data(self): + """Upload pending data (attributes and uniforms) before drawing.""" + if not self.data: + return + # Get the name of the variable that needs to be transformed. # This variable (typically a_position) comes from the vertex shader # which contains the string `gl_Position = transform(the_name);`. var = self.transform_chain.transformed_var_name - if not var: - logger.debug("No transformed variable has been found.") - # Upload the data if necessary. - logger.debug("Upload program objects %s.", - ', '.join(self.data.keys())) + logger.log(5, "Upload program objects %s.", + ', '.join(self.data.keys())) for name, value in self.data.items(): # Normalize the value that needs to be transformed. if name == var: @@ -187,6 +194,8 @@ def draw(self): # Skip the drawing if the program hasn't been built yet. # The program is built by the attached interact. if self._do_show and self.program: + # Upload pending data. + self._upload_data() # Finally, draw the program. self.program.draw(self.gl_primitive_type) diff --git a/phy/plot/panzoom.py b/phy/plot/panzoom.py index e854b5ddb..16b4ab597 100644 --- a/phy/plot/panzoom.py +++ b/phy/plot/panzoom.py @@ -176,16 +176,11 @@ def zmax(self, value): # Internal methods # ------------------------------------------------------------------------- - def _iter_programs(self): - for visual in self.iter_attached_visuals(): - if visual.program: - yield visual.program - def _apply_pan_zoom(self): zoom = self._zoom_aspect() - for program in self._iter_programs(): - program['u_pan'] = self._pan - program['u_zoom'] = zoom + for visual in self.iter_attached_visuals(): + visual.data['u_pan'] = self._pan + visual.data['u_zoom'] = zoom def _zoom_aspect(self, zoom=None): if zoom is None: From 922195c3bf9aee25ef161639baf75f57a292bbec Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 22 Oct 2015 14:12:32 +0200 Subject: [PATCH 0402/1059] WIP: test panzoom interactions --- phy/plot/panzoom.py | 22 +++++----- phy/plot/tests/test_base.py | 2 +- phy/plot/tests/test_panzoom.py | 74 ++++++++++++++++++++++++++++++---- 3 files changed, 79 insertions(+), 19 deletions(-) diff --git a/phy/plot/panzoom.py b/phy/plot/panzoom.py index 16b4ab597..17354f42a 100644 --- a/phy/plot/panzoom.py +++ b/phy/plot/panzoom.py @@ -270,14 +270,12 @@ def zoom(self, value): self._apply_pan_zoom() def pan_delta(self, d): - dx, dy = d pan_x, pan_y = self.pan zoom_x, zoom_y = self._zoom_aspect(self._zoom) - self.pan = (pan_x + dx / zoom_x, - pan_y + dy / zoom_y) + self.pan = (pan_x + dx / zoom_x, pan_y + dy / zoom_y) self.update() def zoom_delta(self, d, p=(0., 0.), c=1.): @@ -329,18 +327,22 @@ def _zoom_keyboard(self, key): self.zoom_delta((k, k), (0, 0)) def _pan_keyboard(self, key): - k = .1 / self.zoom + k = .1 / np.asarray(self.zoom) if key == 'Left': - self.pan += (+k[0], +0) + # self.pan += (+k[0], +0) + self.pan_delta((+k[0], +0)) elif key == 'Right': - self.pan += (-k[0], +0) + # self.pan += (-k[0], +0) + self.pan_delta((-k[0], +0)) elif key == 'Down': - self.pan += (+0, +k[1]) + self.pan_delta((+0, +k[1])) + # self.pan += (+0, +k[1]) elif key == 'Up': - self.pan += (+0, -k[1]) + self.pan_delta((+0, -k[1])) + # self.pan += (+0, -k[1]) self._canvas.update() - def _reset_keyboard(self): + def reset(self): self.pan = (0., 0.) self.zoom = 1. self._canvas.update() @@ -397,7 +399,7 @@ def on_key_press(self, event): # Reset with 'R'. if key == 'R': - self._reset_keyboard() + self.reset() # Canvas methods # ------------------------------------------------------------------------- diff --git a/phy/plot/tests/test_base.py b/phy/plot/tests/test_base.py index d797c8dcc..590c64741 100644 --- a/phy/plot/tests/test_base.py +++ b/phy/plot/tests/test_base.py @@ -75,7 +75,7 @@ def set_data(self): # qtbot.stop() # Simulate a mouse move. - canvas.events.mouse_move(delta=(1., 0.)) + canvas.events.mouse_move(pos=(0., 0.)) canvas.events.key_press(text='a') v.update() diff --git a/phy/plot/tests/test_panzoom.py b/phy/plot/tests/test_panzoom.py index eda84fdf4..3414b8310 100644 --- a/phy/plot/tests/test_panzoom.py +++ b/phy/plot/tests/test_panzoom.py @@ -7,7 +7,9 @@ # Imports #------------------------------------------------------------------------------ -# import numpy as np +import numpy as np +from vispy.app import MouseEvent +from vispy.util import keys from pytest import yield_fixture from ..base import BaseVisual @@ -47,6 +49,20 @@ def visual(): yield MyTestVisual() +@yield_fixture +def panzoom(qtbot, canvas, visual): + visual.attach(canvas, 'PanZoom') + visual.show() + + pz = PanZoom() + pz.attach(canvas) + + canvas.show() + qtbot.waitForWindowShown(canvas.native) + + yield pz + + #------------------------------------------------------------------------------ # Test panzoom #------------------------------------------------------------------------------ @@ -105,14 +121,56 @@ def test_pz_basic_pan_zoom(): assert pz.zoom[1] > 3 * pz.zoom[0] -def test_pz_attached(qtbot, canvas, visual): +def test_pz_pan(qtbot, canvas, panzoom): + pz = panzoom - visual.attach(canvas, 'PanZoom') - visual.show() + # Pan with mouse. + press = MouseEvent(type='mouse_press', pos=(0, 0)) + canvas.events.mouse_move(pos=(10., 0.), button=1, + last_event=press, press_event=press) + assert pz.pan[0] > 0 + assert pz.pan[1] == 0 + pz.pan = (0, 0) - pz = PanZoom() - pz.attach(canvas) + # Pan with keyboard. + canvas.events.key_press(key=keys.UP) + assert pz.pan[0] == 0 + assert pz.pan[1] < 0 + pz.pan = (0, 0) + + # Reset with R. + canvas.events.key_press(text='r') + assert pz.pan == [0, 0] + + # qtbot.stop() + + +def test_pz_zoom(qtbot, canvas, panzoom): + pz = panzoom + + # Zoom with mouse. + press = MouseEvent(type='mouse_press', pos=(50., 50.)) + canvas.events.mouse_move(pos=(0., 0.), button=2, + last_event=press, press_event=press) + assert pz.pan[0] < 0 + assert pz.pan[1] < 0 + assert pz.zoom[0] < 1 + assert pz.zoom[1] > 1 + pz.reset() + + # Zoom with mouse. + size = np.asarray(canvas.size) + canvas.events.mouse_wheel(pos=size / 2., delta=(0., 1.)) + assert pz.pan == [0, 0] + assert pz.zoom[0] > 1 + assert pz.zoom[1] > 1 + pz.reset() + + # Zoom with keyboard. + canvas.events.key_press(key=keys.Key('+')) + assert pz.pan == [0, 0] + assert pz.zoom[0] > 1 + assert pz.zoom[1] > 1 + pz.reset() - canvas.show() - qtbot.waitForWindowShown(canvas.native) # qtbot.stop() From 0341a28ed84b7a0313166c5dabc5c7f52e1688a3 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 22 Oct 2015 14:20:02 +0200 Subject: [PATCH 0403/1059] WIP: increase panzoom coverage --- phy/plot/tests/test_panzoom.py | 46 ++++++++++++++++++++++++++++------ 1 file changed, 39 insertions(+), 7 deletions(-) diff --git a/phy/plot/tests/test_panzoom.py b/phy/plot/tests/test_panzoom.py index 3414b8310..0660e537f 100644 --- a/phy/plot/tests/test_panzoom.py +++ b/phy/plot/tests/test_panzoom.py @@ -121,7 +121,7 @@ def test_pz_basic_pan_zoom(): assert pz.zoom[1] > 3 * pz.zoom[0] -def test_pz_pan(qtbot, canvas, panzoom): +def test_pz_pan_mouse(qtbot, canvas, panzoom): pz = panzoom # Pan with mouse. @@ -132,20 +132,39 @@ def test_pz_pan(qtbot, canvas, panzoom): assert pz.pan[1] == 0 pz.pan = (0, 0) + # Panning with a modifier should not pan. + press = MouseEvent(type='mouse_press', pos=(0, 0)) + canvas.events.mouse_move(pos=(10., 0.), button=1, + last_event=press, press_event=press, + modifiers=(keys.CONTROL,)) + assert pz.pan == [0, 0] + + +def test_pz_pan_keyboard(qtbot, canvas, panzoom): + pz = panzoom + # Pan with keyboard. canvas.events.key_press(key=keys.UP) assert pz.pan[0] == 0 assert pz.pan[1] < 0 - pz.pan = (0, 0) + + # All panning movements with keys. + canvas.events.key_press(key=keys.LEFT) + canvas.events.key_press(key=keys.DOWN) + canvas.events.key_press(key=keys.RIGHT) + assert pz.pan == [0, 0] # Reset with R. - canvas.events.key_press(text='r') + canvas.events.key_press(key=keys.RIGHT) + canvas.events.key_press(key=keys.Key('r')) assert pz.pan == [0, 0] - # qtbot.stop() + # Using modifiers should not pan. + canvas.events.key_press(key=keys.UP, modifiers=(keys.CONTROL,)) + assert pz.pan == [0, 0] -def test_pz_zoom(qtbot, canvas, panzoom): +def test_pz_zoom_mouse(qtbot, canvas, panzoom): pz = panzoom # Zoom with mouse. @@ -166,11 +185,24 @@ def test_pz_zoom(qtbot, canvas, panzoom): assert pz.zoom[1] > 1 pz.reset() + # Using modifiers with the wheel should not zoom. + canvas.events.mouse_wheel(pos=(0., 0.), delta=(0., 1.), + modifiers=(keys.CONTROL,)) + assert pz.pan == [0, 0] + assert pz.zoom == [1, 1] + pz.reset() + + +def test_pz_zoom_keyboard(qtbot, canvas, panzoom): + pz = panzoom + # Zoom with keyboard. canvas.events.key_press(key=keys.Key('+')) assert pz.pan == [0, 0] assert pz.zoom[0] > 1 assert pz.zoom[1] > 1 - pz.reset() - # qtbot.stop() + # Unzoom with keyboard. + canvas.events.key_press(key=keys.Key('-')) + assert pz.pan == [0, 0] + assert pz.zoom == [1, 1] From bf2a59899cc9faaed3bfb1766253107e3eb05b25 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 22 Oct 2015 14:27:38 +0200 Subject: [PATCH 0404/1059] WIP: increase panzoom coverage --- phy/plot/panzoom.py | 2 +- phy/plot/tests/test_panzoom.py | 19 +++++++++++++------ 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/phy/plot/panzoom.py b/phy/plot/panzoom.py index 17354f42a..c53a0096f 100644 --- a/phy/plot/panzoom.py +++ b/phy/plot/panzoom.py @@ -314,7 +314,7 @@ def zoom_delta(self, d, p=(0., 0.), c=1.): def _set_canvas_aspect(self): w, h = self.size - aspect = max(1., w / max(float(h), 1.)) + aspect = w / max(float(h), 1.) if aspect > 1.0: self._canvas_aspect = np.array([1.0 / aspect, 1.0]) else: diff --git a/phy/plot/tests/test_panzoom.py b/phy/plot/tests/test_panzoom.py index 0660e537f..a1f2014dc 100644 --- a/phy/plot/tests/test_panzoom.py +++ b/phy/plot/tests/test_panzoom.py @@ -67,7 +67,7 @@ def panzoom(qtbot, canvas, visual): # Test panzoom #------------------------------------------------------------------------------ -def test_pz_basic_attrs(): +def test_panzoom_basic_attrs(): pz = PanZoom() assert not pz.is_attached() @@ -91,7 +91,7 @@ def test_pz_basic_attrs(): assert list(pz.iter_attached_visuals()) == [] -def test_pz_basic_pan_zoom(): +def test_panzoom_basic_pan_zoom(): pz = PanZoom() # Pan. @@ -121,7 +121,7 @@ def test_pz_basic_pan_zoom(): assert pz.zoom[1] > 3 * pz.zoom[0] -def test_pz_pan_mouse(qtbot, canvas, panzoom): +def test_panzoom_pan_mouse(qtbot, canvas, panzoom): pz = panzoom # Pan with mouse. @@ -140,7 +140,7 @@ def test_pz_pan_mouse(qtbot, canvas, panzoom): assert pz.pan == [0, 0] -def test_pz_pan_keyboard(qtbot, canvas, panzoom): +def test_panzoom_pan_keyboard(qtbot, canvas, panzoom): pz = panzoom # Pan with keyboard. @@ -164,7 +164,7 @@ def test_pz_pan_keyboard(qtbot, canvas, panzoom): assert pz.pan == [0, 0] -def test_pz_zoom_mouse(qtbot, canvas, panzoom): +def test_panzoom_zoom_mouse(qtbot, canvas, panzoom): pz = panzoom # Zoom with mouse. @@ -193,7 +193,7 @@ def test_pz_zoom_mouse(qtbot, canvas, panzoom): pz.reset() -def test_pz_zoom_keyboard(qtbot, canvas, panzoom): +def test_panzoom_zoom_keyboard(qtbot, canvas, panzoom): pz = panzoom # Zoom with keyboard. @@ -206,3 +206,10 @@ def test_pz_zoom_keyboard(qtbot, canvas, panzoom): canvas.events.key_press(key=keys.Key('-')) assert pz.pan == [0, 0] assert pz.zoom == [1, 1] + + +def test_panzoom_resize(qtbot, canvas, panzoom): + # Increase coverage with different aspect ratio. + canvas.native.resize(400, 600) + # canvas.events.resize(size=(100, 1000)) + assert list(panzoom._canvas_aspect) == [1., 2. / 3] From 1fa1d9591ca9f58bb2a045e0ae609dfcfe2a7929 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 22 Oct 2015 14:50:51 +0200 Subject: [PATCH 0405/1059] WIP: increase panzoom coverage --- phy/plot/panzoom.py | 43 ++++++++++++-------------------- phy/plot/tests/test_panzoom.py | 45 ++++++++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+), 27 deletions(-) diff --git a/phy/plot/panzoom.py b/phy/plot/panzoom.py index c53a0096f..34f17ede6 100644 --- a/phy/plot/panzoom.py +++ b/phy/plot/panzoom.py @@ -78,6 +78,7 @@ def __init__(self, self._zoom_to_pointer = True self._canvas_aspect = np.ones(2) + # We define the GLSL uniforms used for pan and zoom. self.transforms = [Translate(translate='u_pan'), Scale(scale='u_zoom')] self.vertex_decl = 'uniform vec2 u_pan;\nuniform vec2 u_zoom;\n' @@ -108,10 +109,8 @@ def xmin(self): @xmin.setter def xmin(self, value): - if self._xmax is not None: - self._xmin = np.minimum(value, self._xmax) - else: - self._xmin = value + self._xmin = (np.minimum(value, self._xmax) + if self._xmax is not None else value) @property def xmax(self): @@ -120,10 +119,8 @@ def xmax(self): @xmax.setter def xmax(self, value): - if self._xmin is not None: - self._xmax = np.maximum(value, self._xmin) - else: - self._xmax = value + self._xmax = (np.maximum(value, self._xmin) + if self._xmin is not None else value) # ymin/ymax # ------------------------------------------------------------------------- @@ -135,10 +132,8 @@ def ymin(self): @ymin.setter def ymin(self, value): - if self._ymax is not None: - self._ymin = min(value, self._ymax) - else: - self._ymin = value + self._ymin = (min(value, self._ymax) + if self._ymax is not None else value) @property def ymax(self): @@ -147,10 +142,8 @@ def ymax(self): @ymax.setter def ymax(self, value): - if self._ymin is not None: - self._ymax = max(value, self._ymin) - else: - self._ymax = value + self._ymax = (max(value, self._ymin) + if self._ymin is not None else value) # zmin/zmax # ------------------------------------------------------------------------- @@ -183,13 +176,10 @@ def _apply_pan_zoom(self): visual.data['u_zoom'] = zoom def _zoom_aspect(self, zoom=None): - if zoom is None: - zoom = self._zoom + zoom = zoom if zoom is not None else self._zoom zoom = _as_array(zoom) - if self._aspect is not None: - aspect = self._canvas_aspect * self._aspect - else: - aspect = 1. + aspect = (self._canvas_aspect * self._aspect + if self._aspect is not None else 1.) return zoom * aspect def _normalize(self, x_y, restrict_to_box=True): @@ -200,13 +190,13 @@ def _normalize(self, x_y, restrict_to_box=True): def _constrain_pan(self): """Constrain bounding box.""" - if self.xmin is not None and self._xmax is not None: + if self.xmin is not None and self.xmax is not None: p0 = self.xmin + 1. / self._zoom[0] p1 = self.xmax - 1. / self._zoom[0] p0, p1 = min(p0, p1), max(p0, p1) self._pan[0] = np.clip(self._pan[0], p0, p1) - if self.ymin is not None and self._ymax is not None: + if self.ymin is not None and self.ymax is not None: p0 = self.ymin + 1. / self._zoom[1] p1 = self.ymax - 1. / self._zoom[1] p0, p1 = min(p0, p1), max(p0, p1) @@ -260,8 +250,6 @@ def zoom(self, value): value = (value, value) assert len(value) == 2 self._zoom = np.clip(value, self._zmin, self._zmax) - if not self.is_attached: - return # Constrain bounding box. self._constrain_pan() @@ -345,7 +333,8 @@ def _pan_keyboard(self, key): def reset(self): self.pan = (0., 0.) self.zoom = 1. - self._canvas.update() + if self._canvas: + self._canvas.update() def on_resize(self, event): """Resize event.""" diff --git a/phy/plot/tests/test_panzoom.py b/phy/plot/tests/test_panzoom.py index a1f2014dc..c6c5a381d 100644 --- a/phy/plot/tests/test_panzoom.py +++ b/phy/plot/tests/test_panzoom.py @@ -121,6 +121,51 @@ def test_panzoom_basic_pan_zoom(): assert pz.zoom[1] > 3 * pz.zoom[0] +def test_panzoom_constraints_x(): + pz = PanZoom() + pz.xmin, pz.xmax = -2, 2 + + # Pan beyond the bounds. + pz.pan_delta((-2, 2)) + assert pz.pan == [-1, 2] + pz.reset() + + # Zoom beyond the bounds. + pz.zoom_delta((-1, -2)) + assert pz.pan == [0, 0] + assert pz.zoom[0] == .5 + assert pz.zoom[1] < .5 + + +def test_panzoom_constraints_y(): + pz = PanZoom() + pz.ymin, pz.ymax = -2, 2 + + # Pan beyond the bounds. + pz.pan_delta((2, -2)) + assert pz.pan == [2, -1] + pz.reset() + + # Zoom beyond the bounds. + pz.zoom_delta((-2, -1)) + assert pz.pan == [0, 0] + assert pz.zoom[0] < .5 + assert pz.zoom[1] == .5 + + +def test_panzoom_constraints_z(): + pz = PanZoom() + pz.zmin, pz.zmax = .5, 2 + + # Zoom beyond the bounds. + pz.zoom_delta((-10, -10)) + assert pz.zoom == [.5, .5] + pz.reset() + + pz.zoom_delta((10, 10)) + assert pz.zoom == [2, 2] + + def test_panzoom_pan_mouse(qtbot, canvas, panzoom): pz = panzoom From a0d59f03fa750801fac0354c62a4eab06d788151 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 22 Oct 2015 15:23:08 +0200 Subject: [PATCH 0406/1059] Add comments --- phy/plot/panzoom.py | 30 +++++++++++++++++++++++++++--- phy/plot/tests/test_panzoom.py | 1 - 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/phy/plot/panzoom.py b/phy/plot/panzoom.py index 34f17ede6..68182818f 100644 --- a/phy/plot/panzoom.py +++ b/phy/plot/panzoom.py @@ -21,7 +21,27 @@ #------------------------------------------------------------------------------ class PanZoom(BaseInteract): - """Pan and zoom interact.""" + """Pan and zoom interact. + + To use it: + + ```python + # Create a visual. + visual = MyVisual(...) + visual.set_data(...) + + # Attach the visual to the canvas. + canvas = BaseCanvas() + visual.attach(canvas, 'PanZoom') + + # Create and attach the PanZoom interact. + pz = PanZoom() + pz.attach(canvas) + + canvas.show() + ``` + + """ name = 'panzoom' _default_zoom_coeff = 1.5 @@ -219,6 +239,7 @@ def _constrain_zoom(self): 1. / (self.ymax - self._pan[1])) def update(self): + """Update the attached canvas if it exists.""" if self.is_attached(): self._canvas.update() @@ -258,6 +279,7 @@ def zoom(self, value): self._apply_pan_zoom() def pan_delta(self, d): + """Pan the view by a given amount.""" dx, dy = d pan_x, pan_y = self.pan @@ -267,6 +289,7 @@ def pan_delta(self, d): self.update() def zoom_delta(self, d, p=(0., 0.), c=1.): + """Zoom the view by a given amount.""" dx, dy = d x0, y0 = p @@ -331,6 +354,7 @@ def _pan_keyboard(self, key): self._canvas.update() def reset(self): + """Reset the view.""" self.pan = (0., 0.) self.zoom = 1. if self._canvas: @@ -370,7 +394,7 @@ def on_mouse_wheel(self, event): self.zoom_delta((dx, dx), (x0, y0)) def on_key_press(self, event): - """Key press event.""" + """Pan and zoom with the keyboard.""" super(PanZoom, self).on_key_press(event) # Zooming with the keyboard. @@ -394,6 +418,6 @@ def on_key_press(self, event): # ------------------------------------------------------------------------- def attach(self, canvas): - """Attach this tranform to a canvas.""" + """Attach this interact to a canvas.""" super(PanZoom, self).attach(canvas) self._set_canvas_aspect() diff --git a/phy/plot/tests/test_panzoom.py b/phy/plot/tests/test_panzoom.py index c6c5a381d..c9b43f6ab 100644 --- a/phy/plot/tests/test_panzoom.py +++ b/phy/plot/tests/test_panzoom.py @@ -52,7 +52,6 @@ def visual(): @yield_fixture def panzoom(qtbot, canvas, visual): visual.attach(canvas, 'PanZoom') - visual.show() pz = PanZoom() pz.attach(canvas) From 7b531523554c1d784a3788460a25ae7362b73566 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 22 Oct 2015 15:33:14 +0200 Subject: [PATCH 0407/1059] Add pixels_to_ndc() and flip y in normalization --- phy/plot/panzoom.py | 17 +++++++---------- phy/plot/tests/test_transform.py | 7 ++++++- phy/plot/transform.py | 11 +++++++++++ 3 files changed, 24 insertions(+), 11 deletions(-) diff --git a/phy/plot/panzoom.py b/phy/plot/panzoom.py index 68182818f..bdbc223e2 100644 --- a/phy/plot/panzoom.py +++ b/phy/plot/panzoom.py @@ -12,7 +12,7 @@ import numpy as np from .base import BaseInteract -from .transform import Translate, Scale +from .transform import Translate, Scale, pixels_to_ndc from phy.utils._types import _as_array @@ -202,11 +202,8 @@ def _zoom_aspect(self, zoom=None): if self._aspect is not None else 1.) return zoom * aspect - def _normalize(self, x_y, restrict_to_box=True): - x_y = np.asarray(x_y, dtype=np.float32) - size = np.asarray(self.size, dtype=np.float32) - pos = x_y / (size / 2.) - 1 - return pos + def _normalize(self, pos): + return pixels_to_ndc(pos, size=self.size) def _constrain_pan(self): """Constrain bounding box.""" @@ -310,7 +307,7 @@ def zoom_delta(self, d, p=(0., 0.), c=1.): zoom_y_new)) self.pan = (pan_x - x0 * (1. / zoom_x - 1. / zoom_x_new), - pan_y + y0 * (1. / zoom_y - 1. / zoom_y_new)) + pan_y - y0 * (1. / zoom_y - 1. / zoom_y_new)) self.update() @@ -374,9 +371,9 @@ def on_mouse_move(self, event): return if event.is_dragging: x0, y0 = self._normalize(event.press_event.pos) - x1, y1 = self._normalize(event.last_event.pos, False) - x, y = self._normalize(event.pos, False) - dx, dy = x - x1, -(y - y1) + x1, y1 = self._normalize(event.last_event.pos) + x, y = self._normalize(event.pos) + dx, dy = x - x1, y - y1 if event.button == 1: self.pan_delta((dx, dy)) elif event.button == 2: diff --git a/phy/plot/tests/test_transform.py b/phy/plot/tests/test_transform.py index 190cba160..0652e3657 100644 --- a/phy/plot/tests/test_transform.py +++ b/phy/plot/tests/test_transform.py @@ -13,7 +13,8 @@ from numpy.testing import assert_equal as ae from pytest import yield_fixture -from ..transform import (_glslify, BaseTransform, +from ..transform import (_glslify, pixels_to_ndc, + BaseTransform, Translate, Scale, Range, Clip, Subplot, GPU, TransformChain, ) @@ -48,6 +49,10 @@ def test_glslify(): assert _glslify((1., 2.)) == 'vec2(1.0, 2.0)' +def test_pixels_to_ndc(): + assert list(pixels_to_ndc((0, 0), size=(10, 10))) == [-1, 1] + + #------------------------------------------------------------------------------ # Test transform #------------------------------------------------------------------------------ diff --git a/phy/plot/transform.py b/phy/plot/transform.py index 43b160059..55e670011 100644 --- a/phy/plot/transform.py +++ b/phy/plot/transform.py @@ -87,6 +87,17 @@ def subplot_range(shape=None, index=None): return [x, y, x + width, y + height] +def pixels_to_ndc(pos, size=None): + """Convert from pixels to normalized device coordinates (in [-1, 1]).""" + pos = np.asarray(pos, dtype=np.float32) + size = np.asarray(size, dtype=np.float32) + pos = pos / (size / 2.) - 1 + # Flip y, because the origin in pixels is at the top left corner of the + # window. + pos[1] = -pos[1] + return pos + + #------------------------------------------------------------------------------ # Transforms #------------------------------------------------------------------------------ From 71452275fa14313948bd99285b1dad3acceb1485 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 22 Oct 2015 15:39:40 +0200 Subject: [PATCH 0408/1059] Remove name --- phy/plot/panzoom.py | 1 - 1 file changed, 1 deletion(-) diff --git a/phy/plot/panzoom.py b/phy/plot/panzoom.py index bdbc223e2..a82c617b2 100644 --- a/phy/plot/panzoom.py +++ b/phy/plot/panzoom.py @@ -43,7 +43,6 @@ class PanZoom(BaseInteract): """ - name = 'panzoom' _default_zoom_coeff = 1.5 _default_wheel_coeff = .1 _arrows = ('Left', 'Right', 'Up', 'Down') From 7b84454c764ac2086c5cff43faf44853011d7a42 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 22 Oct 2015 16:06:20 +0200 Subject: [PATCH 0409/1059] WIP --- phy/plot/tests/test_base.py | 4 ++-- phy/plot/transform.py | 19 +++++++++++++------ 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/phy/plot/tests/test_base.py b/phy/plot/tests/test_base.py index 590c64741..55bf6077d 100644 --- a/phy/plot/tests/test_base.py +++ b/phy/plot/tests/test_base.py @@ -10,7 +10,7 @@ import numpy as np from ..base import BaseVisual, BaseInteract -from ..transform import (subplot_range, Translate, Scale, Range, +from ..transform import (subplot_bounds, Translate, Scale, Range, Clip, Subplot, GPU) @@ -155,7 +155,7 @@ def set_data(self): class TestInteract(BaseInteract): def __init__(self): super(TestInteract, self).__init__() - bounds = subplot_range(shape=(2, 3), index=(1, 2)) + bounds = subplot_bounds(shape=(2, 3), index=(1, 2)) self.transforms = [Subplot(shape=(2, 3), index=(1, 2)), Clip(bounds=bounds), ] diff --git a/phy/plot/transform.py b/phy/plot/transform.py index 55e670011..d57fff267 100644 --- a/phy/plot/transform.py +++ b/phy/plot/transform.py @@ -50,7 +50,7 @@ def wrapped(var, **kwargs): return wrapped -def _wrap_prepost(f, **kwargs_init): +def _wrap(f, **kwargs_init): def wrapped(*args, **kwargs): # Method kwargs first, then we update with the constructor kwargs. kwargs.update(kwargs_init) @@ -71,7 +71,7 @@ def _glslify(r): return 'vec{}({})'.format(len(r), ', '.join(map(str, r))) -def subplot_range(shape=None, index=None): +def subplot_bounds(shape=None, index=None): i, j = index n_rows, n_cols = shape @@ -108,8 +108,8 @@ def __init__(self, **kwargs): # Pass the constructor kwargs to the methods. self.apply = _wrap_apply(self.apply, **kwargs) self.glsl = _wrap_glsl(self.glsl, **kwargs) - self.pre_transforms = _wrap_prepost(self.pre_transforms, **kwargs) - self.post_transforms = _wrap_prepost(self.post_transforms, **kwargs) + self.pre_transforms = _wrap(self.pre_transforms, **kwargs) + self.post_transforms = _wrap(self.post_transforms, **kwargs) def pre_transforms(self, **kwargs): return [] @@ -188,11 +188,18 @@ def glsl(self, var, bounds=(-1, -1, 1, 1)): class Subplot(Range): - """Assume that the from range is [-1, -1, 1, 1].""" + """Assume that the from_bounds is [-1, -1, 1, 1].""" + + def __init__(self, **kwargs): + super(Subplot, self).__init__(**kwargs) + self.get_bounds = _wrap(self.get_bounds) + + def get_bounds(self, shape=None, index=None): + return subplot_bounds(shape=shape, index=index) def apply(self, arr, shape=None, index=None): from_bounds = (-1, -1, 1, 1) - to_bounds = subplot_range(shape=shape, index=index) + to_bounds = self.get_bounds(shape=shape, index=index) return super(Subplot, self).apply(arr, from_bounds=from_bounds, to_bounds=to_bounds) From ee090bad610e587adb22372811c3bf13184bf69d Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 22 Oct 2015 17:03:40 +0200 Subject: [PATCH 0410/1059] WIP: the interact can set default data --- phy/plot/base.py | 12 +++++++ phy/plot/grid.py | 39 +++++++++++++++++++++ phy/plot/tests/test_grid.py | 68 +++++++++++++++++++++++++++++++++++++ 3 files changed, 119 insertions(+) create mode 100644 phy/plot/grid.py create mode 100644 phy/plot/tests/test_grid.py diff --git a/phy/plot/base.py b/phy/plot/base.py index b7165dbb3..4580883c5 100644 --- a/phy/plot/base.py +++ b/phy/plot/base.py @@ -236,6 +236,7 @@ class BaseInteract(object): def __init__(self): self._canvas = None + self.data = {} @property def size(self): @@ -256,6 +257,10 @@ def on_draw(): canvas.connect(self.on_mouse_wheel) canvas.connect(self.on_key_press) + def is_attached(self): + """Whether the transform is attached to a canvas.""" + return self._canvas is not None + def iter_attached_visuals(self): """Yield all visuals attached to that interact in the canvas.""" if self._canvas: @@ -273,6 +278,8 @@ def build_programs(self): """ for visual in self.iter_attached_visuals(): if not visual.program: + # Use the interact's data. + visual.data.update(self.data) visual.build_program(self.transforms, vertex_decl=self.vertex_decl, frag_decl=self.frag_decl, @@ -289,3 +296,8 @@ def on_mouse_wheel(self, event): def on_key_press(self, event): pass + + def update(self): + """Update the attached canvas if it exists.""" + if self.is_attached(): + self._canvas.update() diff --git a/phy/plot/grid.py b/phy/plot/grid.py new file mode 100644 index 000000000..13797a9f4 --- /dev/null +++ b/phy/plot/grid.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- + +"""Grid interact for subplots.""" + + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +import numpy as np + +from .base import BaseInteract +from .transform import Scale, Subplot, Clip, pixels_to_ndc + + +#------------------------------------------------------------------------------ +# Grid class +#------------------------------------------------------------------------------ + +class Grid(BaseInteract): + """Grid interact.""" + + def __init__(self, shape, box_var=None): + """ + """ + super(Grid, self).__init__() + self.box_var = box_var or 'a_box' + self.shape = shape + assert len(shape) == 2 + assert shape[0] >= 1 + assert shape[1] >= 1 + + # Define the grid transform and clipping. + m = 1. - .05 + self.transforms = [Scale(scale=(m, m)), + Clip(bounds=[-m, -m, m, m]), + Subplot(shape=shape, index='a_box'), + ] + self.vertex_decl = 'attribute vec2 a_box;\n' diff --git a/phy/plot/tests/test_grid.py b/phy/plot/tests/test_grid.py new file mode 100644 index 000000000..938deefe7 --- /dev/null +++ b/phy/plot/tests/test_grid.py @@ -0,0 +1,68 @@ +# -*- coding: utf-8 -*- + +"""Test grid.""" + + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +import numpy as np +from pytest import yield_fixture + +from ..base import BaseVisual +from ..grid import Grid +from ..transform import GPU + + +#------------------------------------------------------------------------------ +# Fixtures +#------------------------------------------------------------------------------ + +class MyTestVisual(BaseVisual): + vertex = """ + attribute vec2 a_position; + void main() { + gl_Position = transform(a_position); + } + """ + fragment = """ + void main() { + gl_FragColor = vec4(1, 1, 1, 1); + } + """ + gl_primitive_type = 'lines' + + def __init__(self): + super(MyTestVisual, self).__init__() + self.transforms = [GPU()] + self.set_data() + + def set_data(self): + self.data['a_position'] = [[-1, 0], [1, 0]] + + +@yield_fixture +def visual(): + yield MyTestVisual() + + +@yield_fixture +def grid(qtbot, canvas, visual): + visual.attach(canvas, 'Grid') + + grid = Grid(shape=(2, 3)) + grid.attach(canvas) + + canvas.show() + qtbot.waitForWindowShown(canvas.native) + + yield grid + + +#------------------------------------------------------------------------------ +# Test grid +#------------------------------------------------------------------------------ + +def test_grid_1(qtbot, visual, grid): + qtbot.stop() From 183bac96c470f160fdc9f45e6e39edc708711c2d Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 22 Oct 2015 17:07:31 +0200 Subject: [PATCH 0411/1059] WIP: test grid --- phy/plot/grid.py | 55 +++++++++++++++++++++++++++++++------ phy/plot/panzoom.py | 20 +++----------- phy/plot/tests/test_grid.py | 22 +++++++++++++-- 3 files changed, 70 insertions(+), 27 deletions(-) diff --git a/phy/plot/grid.py b/phy/plot/grid.py index 13797a9f4..30f1f1902 100644 --- a/phy/plot/grid.py +++ b/phy/plot/grid.py @@ -7,10 +7,10 @@ # Imports #------------------------------------------------------------------------------ -import numpy as np +import math from .base import BaseInteract -from .transform import Scale, Subplot, Clip, pixels_to_ndc +from .transform import Scale, Subplot, Clip #------------------------------------------------------------------------------ @@ -18,22 +18,61 @@ #------------------------------------------------------------------------------ class Grid(BaseInteract): - """Grid interact.""" + """Grid interact. + + NOTE: to be used in a grid, a visual must define `a_box`. + TODO: improve this, so that a visual doesn't have to be aware of the grid. + + """ def __init__(self, shape, box_var=None): - """ - """ super(Grid, self).__init__() + self._zoom = 1. + + # Name of the variable with the box index. self.box_var = box_var or 'a_box' + self.shape = shape assert len(shape) == 2 assert shape[0] >= 1 assert shape[1] >= 1 # Define the grid transform and clipping. - m = 1. - .05 - self.transforms = [Scale(scale=(m, m)), + m = 1. - .05 # Margin. + self.transforms = [Scale(scale='u_scale'), + Scale(scale=(m, m)), Clip(bounds=[-m, -m, m, m]), Subplot(shape=shape, index='a_box'), ] - self.vertex_decl = 'attribute vec2 a_box;\n' + self.vertex_decl = 'attribute vec2 a_box;\nuniform float u_scale;\n' + self.data['u_scale'] = self._zoom + + @property + def zoom(self): + """Zoom level.""" + return self._zoom + + @zoom.setter + def zoom(self, value): + """Zoom level.""" + self._zoom = value + for visual in self.iter_attached_visuals(): + visual.data['u_scale'] = value + + def on_key_press(self, event): + """Pan and zoom with the keyboard.""" + super(Grid, self).on_key_press(event) + if event.modifiers: + return + key = event.key + + # Zoom. + if key in ('-', '+'): + k = .05 if key == '+' else -.05 + self.zoom *= math.exp(1.5 * k) + self.update() + + # Reset with 'R'. + if key == 'R': + self.zoom = 1. + self.update() diff --git a/phy/plot/panzoom.py b/phy/plot/panzoom.py index a82c617b2..8a2f6e527 100644 --- a/phy/plot/panzoom.py +++ b/phy/plot/panzoom.py @@ -101,14 +101,12 @@ def __init__(self, self.transforms = [Translate(translate='u_pan'), Scale(scale='u_zoom')] self.vertex_decl = 'uniform vec2 u_pan;\nuniform vec2 u_zoom;\n' + self.data['u_pan'] = pan + self.data['u_zoom'] = zoom # Various properties # ------------------------------------------------------------------------- - def is_attached(self): - """Whether the transform is attached to a canvas.""" - return self._canvas is not None - @property def aspect(self): """Aspect (width/height).""" @@ -234,11 +232,6 @@ def _constrain_zoom(self): self._zoom[1] = max(self._zoom[1], 1. / (self.ymax - self._pan[1])) - def update(self): - """Update the attached canvas if it exists.""" - if self.is_attached(): - self._canvas.update() - # Pan and zoom # ------------------------------------------------------------------------- @@ -336,25 +329,20 @@ def _zoom_keyboard(self, key): def _pan_keyboard(self, key): k = .1 / np.asarray(self.zoom) if key == 'Left': - # self.pan += (+k[0], +0) self.pan_delta((+k[0], +0)) elif key == 'Right': - # self.pan += (-k[0], +0) self.pan_delta((-k[0], +0)) elif key == 'Down': self.pan_delta((+0, +k[1])) - # self.pan += (+0, +k[1]) elif key == 'Up': self.pan_delta((+0, -k[1])) - # self.pan += (+0, -k[1]) - self._canvas.update() + self.update() def reset(self): """Reset the view.""" self.pan = (0., 0.) self.zoom = 1. - if self._canvas: - self._canvas.update() + self.update() def on_resize(self, event): """Resize event.""" diff --git a/phy/plot/tests/test_grid.py b/phy/plot/tests/test_grid.py index 938deefe7..3b9bc395f 100644 --- a/phy/plot/tests/test_grid.py +++ b/phy/plot/tests/test_grid.py @@ -7,6 +7,8 @@ # Imports #------------------------------------------------------------------------------ +from itertools import product + import numpy as np from pytest import yield_fixture @@ -24,6 +26,7 @@ class MyTestVisual(BaseVisual): attribute vec2 a_position; void main() { gl_Position = transform(a_position); + gl_PointSize = 2.; } """ fragment = """ @@ -31,7 +34,7 @@ class MyTestVisual(BaseVisual): gl_FragColor = vec4(1, 1, 1, 1); } """ - gl_primitive_type = 'lines' + gl_primitive_type = 'points' def __init__(self): super(MyTestVisual, self).__init__() @@ -39,7 +42,19 @@ def __init__(self): self.set_data() def set_data(self): - self.data['a_position'] = [[-1, 0], [1, 0]] + n = 1000 + + box = [[i, j] for i, j in product(range(2), range(3))] + box = np.repeat(box, n, axis=0) + + coeff = [(1 + i + j) for i, j in product(range(2), range(3))] + coeff = np.repeat(coeff, n) + coeff = coeff[:, None] + + position = .1 * coeff * np.random.randn(2 * 3 * n, 2) + + self.data['a_position'] = position.astype(np.float32) + self.data['a_box'] = box.astype(np.float32) @yield_fixture @@ -65,4 +80,5 @@ def grid(qtbot, canvas, visual): #------------------------------------------------------------------------------ def test_grid_1(qtbot, visual, grid): - qtbot.stop() + pass + # qtbot.stop() From e72eaf3e2f0e286c52da0bbd712a4b65b9418cf6 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 22 Oct 2015 17:10:16 +0200 Subject: [PATCH 0412/1059] Fix bugs --- phy/plot/transform.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/phy/plot/transform.py b/phy/plot/transform.py index d57fff267..3a23b18c7 100644 --- a/phy/plot/transform.py +++ b/phy/plot/transform.py @@ -305,9 +305,13 @@ def insert_glsl(self, vertex, fragment): # Generate the snippet to insert in the shaders. temp_var = 'temp_pos_tr' + # Name for the (eventual) varying. + fvar = 'v_{}'.format(temp_var) vs_insert = "vec2 {} = {};\n".format(temp_var, var) for t in self.gpu_transforms: if isinstance(t, Clip): + # Set the varying value in the vertex shader. + vs_insert += '{} = {};\n'.format(fvar, temp_var) continue vs_insert += t.glsl(temp_var) + '\n' vs_insert += 'gl_Position = vec4({}, 0., 1.);\n'.format(temp_var) @@ -316,7 +320,6 @@ def insert_glsl(self, vertex, fragment): clip = self.get('Clip') if clip: # Varying name. - fvar = 'v_{}'.format(temp_var) glsl_clip = clip.glsl(fvar) # Prepare the fragment regex. @@ -330,8 +333,6 @@ def insert_glsl(self, vertex, fragment): # Make the replacement in the fragment shader for clipping. fragment = fs_regex.sub(indent(fs_insert), fragment) - # Set the varying value in the vertex shader. - vs_insert += '{} = {};\n'.format(fvar, temp_var) # Insert the GLSL snippet of the transform chain in the vertex shader. vertex = vs_regex.sub(indent(vs_insert), vertex) From 9be941273d3020acb1a3ef86c114da6f670bc6b8 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 22 Oct 2015 17:24:26 +0200 Subject: [PATCH 0413/1059] Increase coverage --- phy/plot/base.py | 6 ++++-- phy/plot/grid.py | 10 +++++----- phy/plot/tests/test_grid.py | 25 ++++++++++++++++++++++--- 3 files changed, 31 insertions(+), 10 deletions(-) diff --git a/phy/plot/base.py b/phy/plot/base.py index 4580883c5..d9a018dd1 100644 --- a/phy/plot/base.py +++ b/phy/plot/base.py @@ -278,8 +278,10 @@ def build_programs(self): """ for visual in self.iter_attached_visuals(): if not visual.program: - # Use the interact's data. - visual.data.update(self.data) + # Use the interact's data by default. + for n, v in self.data.items(): + if n not in visual.data: + visual.data[n] = v visual.build_program(self.transforms, vertex_decl=self.vertex_decl, frag_decl=self.frag_decl, diff --git a/phy/plot/grid.py b/phy/plot/grid.py index 30f1f1902..d834e7b9f 100644 --- a/phy/plot/grid.py +++ b/phy/plot/grid.py @@ -21,7 +21,6 @@ class Grid(BaseInteract): """Grid interact. NOTE: to be used in a grid, a visual must define `a_box`. - TODO: improve this, so that a visual doesn't have to be aware of the grid. """ @@ -39,13 +38,14 @@ def __init__(self, shape, box_var=None): # Define the grid transform and clipping. m = 1. - .05 # Margin. - self.transforms = [Scale(scale='u_scale'), + self.transforms = [Scale(scale='u_zoom'), Scale(scale=(m, m)), Clip(bounds=[-m, -m, m, m]), Subplot(shape=shape, index='a_box'), ] - self.vertex_decl = 'attribute vec2 a_box;\nuniform float u_scale;\n' - self.data['u_scale'] = self._zoom + self.vertex_decl = 'attribute vec2 a_box;\nuniform float u_zoom;\n' + self.data['u_zoom'] = self._zoom + self.data['a_box'] = (0, 0) @property def zoom(self): @@ -57,7 +57,7 @@ def zoom(self, value): """Zoom level.""" self._zoom = value for visual in self.iter_attached_visuals(): - visual.data['u_scale'] = value + visual.data['u_zoom'] = value def on_key_press(self, event): """Pan and zoom with the keyboard.""" diff --git a/phy/plot/tests/test_grid.py b/phy/plot/tests/test_grid.py index 3b9bc395f..55ec498b0 100644 --- a/phy/plot/tests/test_grid.py +++ b/phy/plot/tests/test_grid.py @@ -10,6 +10,7 @@ from itertools import product import numpy as np +from vispy.util import keys from pytest import yield_fixture from ..base import BaseVisual @@ -79,6 +80,24 @@ def grid(qtbot, canvas, visual): # Test grid #------------------------------------------------------------------------------ -def test_grid_1(qtbot, visual, grid): - pass - # qtbot.stop() +def test_grid_1(qtbot, canvas, visual, grid): + + # Zoom with the keyboard. + canvas.events.key_press(key=keys.Key('+')) + assert grid.zoom > 1 + + # Unzoom with the keyboard. + canvas.events.key_press(key=keys.Key('-')) + assert grid.zoom == 1. + + # Set the zoom explicitly. + grid.zoom = 2 + assert grid.zoom == 2. + + # No effect with modifiers. + canvas.events.key_press(key=keys.Key('r'), modifiers=(keys.CONTROL,)) + assert grid.zoom == 2. + + # Press 'R'. + canvas.events.key_press(key=keys.Key('r')) + assert grid.zoom == 1. From 3f50fd32a842378e36f7fa751055e9be4aef3ef3 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 22 Oct 2015 17:35:03 +0200 Subject: [PATCH 0414/1059] Rename a_box to a_box_index --- phy/plot/grid.py | 11 ++++++----- phy/plot/tests/test_grid.py | 2 +- phy/plot/tests/test_transform.py | 4 ++-- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/phy/plot/grid.py b/phy/plot/grid.py index d834e7b9f..e16f0abab 100644 --- a/phy/plot/grid.py +++ b/phy/plot/grid.py @@ -20,7 +20,7 @@ class Grid(BaseInteract): """Grid interact. - NOTE: to be used in a grid, a visual must define `a_box`. + NOTE: to be used in a grid, a visual must define `a_box_index`. """ @@ -29,7 +29,7 @@ def __init__(self, shape, box_var=None): self._zoom = 1. # Name of the variable with the box index. - self.box_var = box_var or 'a_box' + self.box_var = box_var or 'a_box_index' self.shape = shape assert len(shape) == 2 @@ -41,11 +41,12 @@ def __init__(self, shape, box_var=None): self.transforms = [Scale(scale='u_zoom'), Scale(scale=(m, m)), Clip(bounds=[-m, -m, m, m]), - Subplot(shape=shape, index='a_box'), + Subplot(shape=shape, index='a_box_index'), ] - self.vertex_decl = 'attribute vec2 a_box;\nuniform float u_zoom;\n' + self.vertex_decl = ('attribute vec2 a_box_index;\n' + 'uniform float u_zoom;\n') self.data['u_zoom'] = self._zoom - self.data['a_box'] = (0, 0) + self.data['a_box_index'] = (0, 0) @property def zoom(self): diff --git a/phy/plot/tests/test_grid.py b/phy/plot/tests/test_grid.py index 55ec498b0..8a774ab2a 100644 --- a/phy/plot/tests/test_grid.py +++ b/phy/plot/tests/test_grid.py @@ -55,7 +55,7 @@ def set_data(self): position = .1 * coeff * np.random.randn(2 * 3 * n, 2) self.data['a_position'] = position.astype(np.float32) - self.data['a_box'] = box.astype(np.float32) + self.data['a_box_index'] = box.astype(np.float32) @yield_fixture diff --git a/phy/plot/tests/test_transform.py b/phy/plot/tests/test_transform.py index 0652e3657..bbbc88b29 100644 --- a/phy/plot/tests/test_transform.py +++ b/phy/plot/tests/test_transform.py @@ -214,7 +214,7 @@ def test_transform_chain_complete(array): t.add([Range(from_bounds=[-3, -3, 1, 1]), GPU(), Clip(), - Subplot(shape='u_shape', index='a_box'), + Subplot(shape='u_shape', index='a_box_index'), ]) assert len(t.cpu_transforms) == 3 @@ -235,7 +235,7 @@ def test_transform_chain_complete(array): } """).strip() vs, fs = t.insert_glsl(vs, fs) - assert 'a_box' in vs + assert 'a_box_index' in vs assert 'v_' in vs assert 'v_' in fs assert 'discard' in fs From 8686156a23fa690a97c0f303e4a23ff127b073d3 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 22 Oct 2015 21:40:10 +0200 Subject: [PATCH 0415/1059] Add on_mouse_wheel --- phy/plot/base.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/phy/plot/base.py b/phy/plot/base.py index d9a018dd1..c49521bbd 100644 --- a/phy/plot/base.py +++ b/phy/plot/base.py @@ -119,6 +119,11 @@ def on_resize(event): self.size = event.size canvas.context.set_viewport(0, 0, event.size[0], event.size[1]) + @canvas.connect + def on_mouse_wheel(event): + if self._do_show: + self.on_mouse_wheel(event) + @canvas.connect def on_mouse_move(event): if self._do_show: @@ -132,6 +137,9 @@ def on_key_press(event): def on_mouse_move(self, e): pass + def on_mouse_wheel(self, e): + pass + def on_key_press(self, e): pass From 094111dd198a0ffd426c0d161b092f8f104f045b Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 23 Oct 2015 20:10:11 +0200 Subject: [PATCH 0416/1059] Try a new API with visuals and interact --- phy/plot/base.py | 283 +++++++++++++-------------------- phy/plot/grid.py | 30 ++-- phy/plot/panzoom.py | 27 ++-- phy/plot/tests/test_base.py | 82 +++++----- phy/plot/tests/test_grid.py | 34 ++-- phy/plot/tests/test_panzoom.py | 32 ++-- 6 files changed, 213 insertions(+), 275 deletions(-) diff --git a/phy/plot/base.py b/phy/plot/base.py index c49521bbd..380c6110e 100644 --- a/phy/plot/base.py +++ b/phy/plot/base.py @@ -12,7 +12,7 @@ from vispy import gloo from vispy.app import Canvas -from .transform import TransformChain +from .transform import TransformChain, GPU from .utils import _load_shader from phy.utils import EventEmitter @@ -45,65 +45,50 @@ class BaseVisual(object): """ gl_primitive_type = None - vertex = None - fragment = None - shader_name = None # Use this to load shaders from the glsl/ library. + shader_name = None def __init__(self): - if self.shader_name: - self.vertex = _load_shader(self.shader_name + '.vert') - self.fragment = _load_shader(self.shader_name + '.frag') - assert self.vertex - assert self.fragment - assert self.gl_primitive_type - - self.size = 1, 1 - self._canvas = None + # This will be set by attach(). self.program = None - # Not taken into account when the program has not been built. - self._do_show = True - - # To set in `set_data()`. - self.data = {} # Data to set on the program when possible. - self.transforms = [] - # Combine the visual's transforms and the interact transforms. - # The interact triggers the creation of the transform chain in - # self.build_program(). - self.transform_chain = None + # To override + # ------------------------------------------------------------------------- - def show(self): - self._do_show = True + def get_shaders(self): + assert self.shader_name + return (_load_shader(self.shader_name + '.vert'), + _load_shader(self.shader_name + '.frag')) - def hide(self): - self._do_show = False + def get_transforms(self): + return [GPU()] def set_data(self): - """Set the data for the visual. + """Set data to the program. - Derived classes can add data to the `self.data` dictionary and - set transforms in the `self.transforms` list. + Must be called *after* attach(canvas), because the program is built + when the visual is attached to the canvas. """ - pass + raise NotImplementedError() + + # Public methods + # ------------------------------------------------------------------------- + + def apply_cpu_transforms(self, data): + return TransformChain(self.get_transforms()).apply(data) - def attach(self, canvas, interact='BaseInteract'): + def attach(self, canvas): """Attach the visual to a canvas. - The interact's name can be specified. The interact's transforms - will be appended to the visual's transforms. + After calling this method, the following properties are available: + + * self.program """ - logger.debug("Attach `%s` with interact `%s` to canvas.", - self.__class__.__name__, interact or '') - self._canvas = canvas + logger.debug("Attach `%s` to canvas.", self.__class__.__name__) - # Used when the canvas requests all attached visuals - # for the given interact. - @canvas.connect_ - def on_get_visual_for_interact(interact_req): - if interact_req == interact: - return self + self.program = canvas.interact.build_program(self) + # self.transform_chain = canvas.interact.transform_chain # NOTE: this is connect_ and not connect because we're using # phy's event system, not VisPy's. The reason is that the order @@ -111,28 +96,19 @@ def on_get_visual_for_interact(interact_req): # to draw visuals in the order they are attached. @canvas.connect_ def on_draw(): - self.draw() + self.on_draw() @canvas.connect def on_resize(event): """Resize the OpenGL context.""" - self.size = event.size canvas.context.set_viewport(0, 0, event.size[0], event.size[1]) - @canvas.connect - def on_mouse_wheel(event): - if self._do_show: - self.on_mouse_wheel(event) - - @canvas.connect - def on_mouse_move(event): - if self._do_show: - self.on_mouse_move(event) + canvas.connect(self.on_mouse_wheel) + canvas.connect(self.on_mouse_move) + canvas.connect(self.on_key_press) - @canvas.connect - def on_key_press(event): - if self._do_show: - self.on_key_press(event) + # HACK: allow a visual to update the canvas it is attached to. + self.update = canvas.update def on_mouse_move(self, e): pass @@ -143,92 +119,14 @@ def on_mouse_wheel(self, e): def on_key_press(self, e): pass - def build_program(self, transforms=None, vertex_decl='', frag_decl=''): - """Create the gloo program by specifying the transforms - given by the optionally-attached interact. - - This function also uploads all variables set in `self.data` in - `self.set_data()`. - - This function is called by the interact's `build_programs()` method - during the draw event (only effective the first time necessary). - - """ - transforms = transforms or [] - assert self.program is None, "The program has already been built." - - # Build the transform chain using the visuals transforms first, - # and the interact's transforms then. - self.transform_chain = TransformChain(self.transforms + transforms) - - logger.debug("Build the program of `%s`.", self.__class__.__name__) - if self.transform_chain: - # Insert the interact's GLSL into the shaders. - self.vertex, self.fragment = self.transform_chain.insert_glsl( - self.vertex, self.fragment) - # Insert shader declarations. - self.vertex = vertex_decl + '\n' + self.vertex - self.fragment = frag_decl + '\n' + self.fragment - logger.log(5, "Vertex shader: \n%s", self.vertex) - logger.log(5, "Fragment shader: \n%s", self.fragment) - self.program = gloo.Program(self.vertex, self.fragment) - - if not self.transform_chain.transformed_var_name: - logger.debug("No transformed variable has been found.") - # Upload the data if necessary. - self._upload_data() - - def _upload_data(self): - """Upload pending data (attributes and uniforms) before drawing.""" - if not self.data: - return - - # Get the name of the variable that needs to be transformed. - # This variable (typically a_position) comes from the vertex shader - # which contains the string `gl_Position = transform(the_name);`. - var = self.transform_chain.transformed_var_name - - logger.log(5, "Upload program objects %s.", - ', '.join(self.data.keys())) - for name, value in self.data.items(): - # Normalize the value that needs to be transformed. - if name == var: - value = self.transform_chain.apply(value) - self.program[name] = value - self.data.clear() - - def draw(self): + def on_draw(self): """Draw the visual.""" # Skip the drawing if the program hasn't been built yet. - # The program is built by the attached interact. - if self._do_show and self.program: - # Upload pending data. - self._upload_data() - # Finally, draw the program. + # The program is built by the interact. + if self.program: + # Draw the program. self.program.draw(self.gl_primitive_type) - def update(self): - """Trigger a draw event in the canvas from the visual.""" - if self._canvas: - self._canvas.update() - - -class BaseCanvas(Canvas): - """A blank VisPy canvas with a custom event system that keeps the order.""" - def __init__(self, *args, **kwargs): - super(BaseCanvas, self).__init__(*args, **kwargs) - self._events = EventEmitter() - - def connect_(self, *args, **kwargs): - return self._events.connect(*args, **kwargs) - - def emit_(self, *args, **kwargs): - return self._events.emit(*args, **kwargs) - - def on_draw(self, e): - gloo.clear() - self._events.emit('draw') - class BaseInteract(object): """Implement interactions for a set of attached visuals in a canvas. @@ -238,13 +136,28 @@ class BaseInteract(object): * Define a list of `transforms` """ - transforms = None - vertex_decl = '' - frag_decl = '' - def __init__(self): self._canvas = None - self.data = {} + # List of attached visuals. + self.visuals = [] + + # To override + # ------------------------------------------------------------------------- + + def get_shader_declarations(self): + return '', '' + + def get_transforms(self): + return [] + + def update_program(self, program): + pass + + # Public methods + # ------------------------------------------------------------------------- + + def get_visuals(self): + return self.visuals @property def size(self): @@ -254,12 +167,6 @@ def attach(self, canvas): """Attach the interact to a canvas.""" self._canvas = canvas - @canvas.connect_ - def on_draw(): - # Build the programs of all attached visuals. - # Programs that are already built are skipped. - self.build_programs() - canvas.connect(self.on_resize) canvas.connect(self.on_mouse_move) canvas.connect(self.on_mouse_wheel) @@ -269,31 +176,37 @@ def is_attached(self): """Whether the transform is attached to a canvas.""" return self._canvas is not None - def iter_attached_visuals(self): - """Yield all visuals attached to that interact in the canvas.""" - if self._canvas: - for visual in self._canvas.emit_('get_visual_for_interact', - self.__class__.__name__): - if visual: - yield visual + def build_program(self, visual): + """Create the gloo program of a visual using the interact's + transforms. - def build_programs(self): - """Build the programs of all attached visuals. - - The list of transforms of the interact should have been set before - calling this function. + This method is called when a visual is attached to the canvas. """ - for visual in self.iter_attached_visuals(): - if not visual.program: - # Use the interact's data by default. - for n, v in self.data.items(): - if n not in visual.data: - visual.data[n] = v - visual.build_program(self.transforms, - vertex_decl=self.vertex_decl, - frag_decl=self.frag_decl, - ) + assert visual.program is None, "The program has already been built." + assert visual not in self.visuals + self.visuals.append(visual) + + # Build the transform chain using the visuals transforms first, + # then the interact's transforms. + transform_chain = TransformChain(visual.get_transforms() + + self.get_transforms()) + + logger.debug("Build the program of `%s`.", self.__class__.__name__) + # Insert the interact's GLSL into the shaders. + vertex, fragment = visual.get_shaders() + vertex, fragment = transform_chain.insert_glsl(vertex, fragment) + + # Insert shader declarations. + vertex_decl, frag_decl = self.get_shader_declarations() + vertex = vertex_decl + '\n' + vertex + fragment = frag_decl + '\n' + fragment + logger.log(5, "Vertex shader: \n%s", vertex) + logger.log(5, "Fragment shader: \n%s", fragment) + + program = gloo.Program(vertex, fragment) + self.update_program(program) + return program def on_resize(self, event): pass @@ -308,6 +221,28 @@ def on_key_press(self, event): pass def update(self): - """Update the attached canvas if it exists.""" + """Update the attached canvas and all attached programs.""" if self.is_attached(): + for visual in self.get_visuals(): + self.update_program(visual.program) self._canvas.update() + + +class BaseCanvas(Canvas): + """A blank VisPy canvas with a custom event system that keeps the order.""" + def __init__(self, *args, **kwargs): + # Set the interact. + self.interact = kwargs.pop('interact', BaseInteract()) + super(BaseCanvas, self).__init__(*args, **kwargs) + self._events = EventEmitter() + self.interact.attach(self) + + def connect_(self, *args, **kwargs): + return self._events.connect(*args, **kwargs) + + def emit_(self, *args, **kwargs): + return self._events.emit(*args, **kwargs) + + def on_draw(self, e): + gloo.clear() + self._events.emit('draw') diff --git a/phy/plot/grid.py b/phy/plot/grid.py index e16f0abab..8d6a20c06 100644 --- a/phy/plot/grid.py +++ b/phy/plot/grid.py @@ -36,17 +36,26 @@ def __init__(self, shape, box_var=None): assert shape[0] >= 1 assert shape[1] >= 1 + def get_shader_declarations(self): + return ('attribute vec2 a_box_index;\n' + 'uniform float u_zoom;\n', '') + + def get_transforms(self): # Define the grid transform and clipping. m = 1. - .05 # Margin. - self.transforms = [Scale(scale='u_zoom'), - Scale(scale=(m, m)), - Clip(bounds=[-m, -m, m, m]), - Subplot(shape=shape, index='a_box_index'), - ] - self.vertex_decl = ('attribute vec2 a_box_index;\n' - 'uniform float u_zoom;\n') - self.data['u_zoom'] = self._zoom - self.data['a_box_index'] = (0, 0) + return [Scale(scale='u_zoom'), + Scale(scale=(m, m)), + Clip(bounds=[-m, -m, m, m]), + Subplot(shape=self.shape, index='a_box_index'), + ] + + def update_program(self, program): + program['u_zoom'] = self._zoom + # Only set the default box index if necessary. + try: + program['a_box_index'] + except KeyError: + program['a_box_index'] = (0, 0) @property def zoom(self): @@ -57,8 +66,7 @@ def zoom(self): def zoom(self, value): """Zoom level.""" self._zoom = value - for visual in self.iter_attached_visuals(): - visual.data['u_zoom'] = value + self.update() def on_key_press(self, event): """Pan and zoom with the keyboard.""" diff --git a/phy/plot/panzoom.py b/phy/plot/panzoom.py index 8a2f6e527..a06f485ac 100644 --- a/phy/plot/panzoom.py +++ b/phy/plot/panzoom.py @@ -97,12 +97,17 @@ def __init__(self, self._zoom_to_pointer = True self._canvas_aspect = np.ones(2) - # We define the GLSL uniforms used for pan and zoom. - self.transforms = [Translate(translate='u_pan'), - Scale(scale='u_zoom')] - self.vertex_decl = 'uniform vec2 u_pan;\nuniform vec2 u_zoom;\n' - self.data['u_pan'] = pan - self.data['u_zoom'] = zoom + def get_shader_declarations(self): + return 'uniform vec2 u_pan;\nuniform vec2 u_zoom;\n', '' + + def get_transforms(self): + return [Translate(translate='u_pan'), + Scale(scale='u_zoom')] + + def update_program(self, program): + zoom = self._zoom_aspect() + program['u_pan'] = self._pan + program['u_zoom'] = zoom # Various properties # ------------------------------------------------------------------------- @@ -186,12 +191,6 @@ def zmax(self, value): # Internal methods # ------------------------------------------------------------------------- - def _apply_pan_zoom(self): - zoom = self._zoom_aspect() - for visual in self.iter_attached_visuals(): - visual.data['u_pan'] = self._pan - visual.data['u_zoom'] = zoom - def _zoom_aspect(self, zoom=None): zoom = zoom if zoom is not None else self._zoom zoom = _as_array(zoom) @@ -246,7 +245,7 @@ def pan(self, value): assert len(value) == 2 self._pan[:] = value self._constrain_pan() - self._apply_pan_zoom() + self.update() @property def zoom(self): @@ -265,7 +264,7 @@ def zoom(self, value): self._constrain_pan() self._constrain_zoom() - self._apply_pan_zoom() + self.update() def pan_delta(self, d): """Pan the view by a given amount.""" diff --git a/phy/plot/tests/test_base.py b/phy/plot/tests/test_base.py index 55bf6077d..164a9c1e7 100644 --- a/phy/plot/tests/test_base.py +++ b/phy/plot/tests/test_base.py @@ -9,7 +9,7 @@ import numpy as np -from ..base import BaseVisual, BaseInteract +from ..base import BaseCanvas, BaseVisual, BaseInteract from ..transform import (subplot_bounds, Translate, Scale, Range, Clip, Subplot, GPU) @@ -25,14 +25,14 @@ class TestVisual(BaseVisual): gl_primitive_type = 'lines' def set_data(self): - self.data['a_position'] = [[-1, 0, 0], [1, 0, 0]] - self.data['n_rows'] = 1 + self.program['a_position'] = [[-1, 0, 0], [1, 0, 0]] + self.program['n_rows'] = 1 v = TestVisual() - v.set_data() # We need to build the program explicitly when there is no interact. v.attach(canvas) - v.build_program() + # Must be called *after* attach(). + v.set_data() canvas.show() qtbot.waitForWindowShown(canvas.native) @@ -56,21 +56,18 @@ class TestVisual(BaseVisual): """ gl_primitive_type = 'lines' - def __init__(self): - super(TestVisual, self).__init__() - self.set_data() + def get_shaders(self): + return self.vertex, self.fragment def set_data(self): - self.data['a_position'] = [[-1, 0], [1, 0]] + self.program['a_position'] = [[-1, 0], [1, 0]] v = TestVisual() # We need to build the program explicitly when there is no interact. v.attach(canvas) - v.build_program() + v.set_data() canvas.show() - v.hide() - v.show() qtbot.waitForWindowShown(canvas.native) # qtbot.stop() @@ -97,24 +94,22 @@ class TestVisual(BaseVisual): """ gl_primitive_type = 'lines' - def __init__(self): - super(TestVisual, self).__init__() - self.set_data() + def get_shaders(self): + return self.vertex, self.fragment + + def get_transforms(self): + return [Scale(scale=(.5, 1))] def set_data(self): - self.data['a_position'] = [[-1, 0], [1, 0]] - self.transforms = [Scale(scale=(.5, 1))] + self.program['a_position'] = [[-1, 0], [1, 0]] # We attach the visual to the canvas. By default, a BaseInteract is used. v = TestVisual() v.attach(canvas) - - # Base interact (no transform). - interact = BaseInteract() - interact.attach(canvas) + v.set_data() canvas.show() - assert interact.size[0] >= 1 + assert canvas.interact.size[0] >= 1 qtbot.waitForWindowShown(canvas.native) # qtbot.stop() @@ -138,36 +133,37 @@ class TestVisual(BaseVisual): """ gl_primitive_type = 'points' - def __init__(self): - super(TestVisual, self).__init__() - self.set_data() + def get_shaders(self): + return self.vertex, self.fragment + + def get_transforms(self): + return [Scale(scale=(.1, .1)), + Translate(translate=(-1, -1)), + GPU(), + Range(from_bounds=(-1, -1, 1, 1), + to_bounds=(-1.5, -1.5, 1.5, 1.5), + ), + ] def set_data(self): - self.data['a_position'] = np.random.uniform(0, 20, (100000, 2)) - self.transforms = [Scale(scale=(.1, .1)), - Translate(translate=(-1, -1)), - GPU(), - Range(from_bounds=(-1, -1, 1, 1), - to_bounds=(-1.5, -1.5, 1.5, 1.5), - ), - ] + data = np.random.uniform(0, 20, (1000, 2)).astype(np.float32) + self.program['a_position'] = self.apply_cpu_transforms(data) class TestInteract(BaseInteract): - def __init__(self): - super(TestInteract, self).__init__() + def get_transforms(self): bounds = subplot_bounds(shape=(2, 3), index=(1, 2)) - self.transforms = [Subplot(shape=(2, 3), index=(1, 2)), - Clip(bounds=bounds), - ] + return [Subplot(shape=(2, 3), index=(1, 2)), + Clip(bounds=bounds), + ] + + canvas = BaseCanvas(keys='interactive', interact=TestInteract()) # We attach the visual to the canvas. By default, a BaseInteract is used. v = TestVisual() - v.attach(canvas, 'TestInteract') - - # Base interact (no transform). - interact = TestInteract() - interact.attach(canvas) + v.attach(canvas) + v.set_data() canvas.show() qtbot.waitForWindowShown(canvas.native) # qtbot.stop() + canvas.close() diff --git a/phy/plot/tests/test_grid.py b/phy/plot/tests/test_grid.py index 8a774ab2a..d814a53df 100644 --- a/phy/plot/tests/test_grid.py +++ b/phy/plot/tests/test_grid.py @@ -13,9 +13,8 @@ from vispy.util import keys from pytest import yield_fixture -from ..base import BaseVisual +from ..base import BaseVisual, BaseCanvas from ..grid import Grid -from ..transform import GPU #------------------------------------------------------------------------------ @@ -37,10 +36,8 @@ class MyTestVisual(BaseVisual): """ gl_primitive_type = 'points' - def __init__(self): - super(MyTestVisual, self).__init__() - self.transforms = [GPU()] - self.set_data() + def get_shaders(self): + return self.vertex, self.fragment def set_data(self): n = 1000 @@ -54,33 +51,34 @@ def set_data(self): position = .1 * coeff * np.random.randn(2 * 3 * n, 2) - self.data['a_position'] = position.astype(np.float32) - self.data['a_box_index'] = box.astype(np.float32) + self.program['a_position'] = position.astype(np.float32) + self.program['a_box_index'] = box.astype(np.float32) @yield_fixture -def visual(): - yield MyTestVisual() +def canvas(qapp): + c = BaseCanvas(keys='interactive', interact=Grid(shape=(2, 3))) + yield c + c.close() @yield_fixture -def grid(qtbot, canvas, visual): - visual.attach(canvas, 'Grid') - - grid = Grid(shape=(2, 3)) - grid.attach(canvas) +def grid(qtbot, canvas): + visual = MyTestVisual() + visual.attach(canvas) + visual.set_data() canvas.show() qtbot.waitForWindowShown(canvas.native) - yield grid + yield canvas.interact #------------------------------------------------------------------------------ # Test grid #------------------------------------------------------------------------------ -def test_grid_1(qtbot, canvas, visual, grid): +def test_grid_1(qtbot, canvas, grid): # Zoom with the keyboard. canvas.events.key_press(key=keys.Key('+')) @@ -101,3 +99,5 @@ def test_grid_1(qtbot, canvas, visual, grid): # Press 'R'. canvas.events.key_press(key=keys.Key('r')) assert grid.zoom == 1. + + # qtbot.stop() diff --git a/phy/plot/tests/test_panzoom.py b/phy/plot/tests/test_panzoom.py index c9b43f6ab..d13231fd9 100644 --- a/phy/plot/tests/test_panzoom.py +++ b/phy/plot/tests/test_panzoom.py @@ -12,9 +12,8 @@ from vispy.util import keys from pytest import yield_fixture -from ..base import BaseVisual +from ..base import BaseVisual, BaseCanvas from ..panzoom import PanZoom -from ..transform import GPU #------------------------------------------------------------------------------ @@ -35,31 +34,30 @@ class MyTestVisual(BaseVisual): """ gl_primitive_type = 'lines' - def __init__(self): - super(MyTestVisual, self).__init__() - self.transforms = [GPU()] - self.set_data() + def get_shaders(self): + return self.vertex, self.fragment def set_data(self): - self.data['a_position'] = [[-1, 0], [1, 0]] + self.program['a_position'] = [[-1, 0], [1, 0]] @yield_fixture -def visual(): - yield MyTestVisual() +def canvas(qapp): + c = BaseCanvas(keys='interactive', interact=PanZoom()) + yield c + c.close() @yield_fixture -def panzoom(qtbot, canvas, visual): - visual.attach(canvas, 'PanZoom') - - pz = PanZoom() - pz.attach(canvas) +def panzoom(qtbot, canvas): + visual = MyTestVisual() + visual.attach(canvas) + visual.set_data() canvas.show() qtbot.waitForWindowShown(canvas.native) - yield pz + yield canvas.interact #------------------------------------------------------------------------------ @@ -87,7 +85,7 @@ def test_panzoom_basic_attrs(): setattr(pz, name, v * 2) assert getattr(pz, name) == v * 2 - assert list(pz.iter_attached_visuals()) == [] + assert list(pz.get_visuals()) == [] def test_panzoom_basic_pan_zoom(): @@ -183,6 +181,8 @@ def test_panzoom_pan_mouse(qtbot, canvas, panzoom): modifiers=(keys.CONTROL,)) assert pz.pan == [0, 0] + # qtbot.stop() + def test_panzoom_pan_keyboard(qtbot, canvas, panzoom): pz = panzoom From b4d3684f060a18085f2e1f8dd6fd534c3b222490 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 23 Oct 2015 20:11:28 +0200 Subject: [PATCH 0417/1059] Increase coverage --- phy/plot/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/phy/plot/base.py b/phy/plot/base.py index 380c6110e..8ae9bbc5d 100644 --- a/phy/plot/base.py +++ b/phy/plot/base.py @@ -240,7 +240,7 @@ def __init__(self, *args, **kwargs): def connect_(self, *args, **kwargs): return self._events.connect(*args, **kwargs) - def emit_(self, *args, **kwargs): + def emit_(self, *args, **kwargs): # pragma: no cover return self._events.emit(*args, **kwargs) def on_draw(self, e): From 98af90b208a283a011a4c064eff98989df87b046 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 24 Oct 2015 12:14:18 +0200 Subject: [PATCH 0418/1059] Add canvas_pz fixture --- phy/plot/base.py | 17 +++++++---------- phy/plot/tests/conftest.py | 9 +++++++++ 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/phy/plot/base.py b/phy/plot/base.py index 8ae9bbc5d..e27f60a25 100644 --- a/phy/plot/base.py +++ b/phy/plot/base.py @@ -32,16 +32,13 @@ class BaseVisual(object): Derived classes must implement: * `gl_primitive_type`: `lines`, `points`, etc. - * `vertex` and `fragment`, or `shader_name`: the GLSL code, or the name of - the GLSL files to load from the `glsl/` subdirectory. - `shader_name` - * `data`: a dictionary acting as a proxy for the gloo Program. - This is because the Program is built later, once the interact has been - attached. The interact is responsible for the creation of the program, - since it implements a part of the transform chain. - * `transforms`: a list of `Transform` instances, which can act on the CPU - or the GPU. The interact's transforms will be appended to that list - when the visual is attached to the canvas. + * `get_shaders()`: return the vertex and fragment shaders, or just + `shader_name` for built-in shaders + * `get_transforms()`: return a list of `Transform` instances, which + can act on the CPU or the GPU. The interact's transforms will be + appended to that list when the visual is attached to the canvas. + * `set_data()`: has access to `self.program`. Must be called after + `attach()`. """ gl_primitive_type = None diff --git a/phy/plot/tests/conftest.py b/phy/plot/tests/conftest.py index 306173b3e..89d9fc31b 100644 --- a/phy/plot/tests/conftest.py +++ b/phy/plot/tests/conftest.py @@ -10,6 +10,7 @@ from pytest import yield_fixture from ..base import BaseCanvas +from ..panzoom import PanZoom #------------------------------------------------------------------------------ @@ -22,3 +23,11 @@ def canvas(qapp): c = BaseCanvas(keys='interactive') yield c c.close() + + +@yield_fixture +def canvas_pz(qapp): + use_app('pyqt4') + c = BaseCanvas(keys='interactive', interact=PanZoom()) + yield c + c.close() From a1fe33d87f8a82b5a0f48af505d12b01d82f3509 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 24 Oct 2015 12:31:08 +0200 Subject: [PATCH 0419/1059] Start scatter visual --- phy/plot/glsl/scatter.frag | 14 +++ phy/plot/glsl/scatter.vert | 19 +++++ phy/plot/tests/test_visuals.py | 59 +++++++++++++ phy/plot/visuals.py | 151 +++++++++++++++++++++++++++++++++ 4 files changed, 243 insertions(+) create mode 100644 phy/plot/glsl/scatter.frag create mode 100644 phy/plot/glsl/scatter.vert create mode 100644 phy/plot/tests/test_visuals.py create mode 100644 phy/plot/visuals.py diff --git a/phy/plot/glsl/scatter.frag b/phy/plot/glsl/scatter.frag new file mode 100644 index 000000000..f994e80d6 --- /dev/null +++ b/phy/plot/glsl/scatter.frag @@ -0,0 +1,14 @@ +#include "markers/%MARKER_TYPE.glsl" +#include "filled_antialias.glsl" +#include "grid.glsl" + +varying vec4 v_color; +varying float v_size; + +void main() +{ + vec2 P = gl_PointCoord.xy - vec2(0.5,0.5); + float point_size = v_size + 2. * (1.0 + 1.5*1.0); + float distance = marker_%MARKER_TYPE(P*point_size, v_size); + gl_FragColor = filled(distance, 1.0, 1.0, v_color); +} diff --git a/phy/plot/glsl/scatter.vert b/phy/plot/glsl/scatter.vert new file mode 100644 index 000000000..c241fca42 --- /dev/null +++ b/phy/plot/glsl/scatter.vert @@ -0,0 +1,19 @@ +attribute vec3 a_position; +attribute vec4 a_color; +attribute float a_size; + +varying vec4 v_color; +varying float v_size; + +void main() { + vec2 xy = a_position.xy; + gl_Position = transform(xy); + gl_Position.z = a_position.z; + + // Point size as a function of the marker size and antialiasing. + gl_PointSize = a_size + 5.0; + + // Set the varyings. + v_color = a_color; + v_size = a_size; +} diff --git a/phy/plot/tests/test_visuals.py b/phy/plot/tests/test_visuals.py new file mode 100644 index 000000000..8e0b6e528 --- /dev/null +++ b/phy/plot/tests/test_visuals.py @@ -0,0 +1,59 @@ +# -*- coding: utf-8 -*- + +"""Test visuals.""" + + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +import numpy as np +from pytest import mark + +from ..visuals import ScatterVisual + + +#------------------------------------------------------------------------------ +# Fixtures +#------------------------------------------------------------------------------ + +#------------------------------------------------------------------------------ +# Test visuals +#------------------------------------------------------------------------------ + +@mark.parametrize('marker_type', ScatterVisual._supported_marker_types) +def test_scatter_markers(qtbot, canvas_pz, marker_type): + + # Try all marker types. + v = ScatterVisual(marker_type=marker_type) + v.attach(canvas_pz) + + n = 100 + pos = .2 * np.random.randn(n, 2) + v.set_data(pos=pos) + + canvas_pz.show() + # qtbot.stop() + + +def test_scatter_custom(qtbot, canvas_pz): + + v = ScatterVisual() + v.attach(canvas_pz) + + n = 100 + + # Random position. + pos = .2 * np.random.randn(n, 2) + + # Random colors. + c = np.random.uniform(.4, .7, size=(n, 4)) + c[:, -1] = .5 + + # Random sizes + s = 5 + 20 * np.random.rand(n) + + v.set_data(pos=pos, colors=c, size=s) + + canvas_pz.show() + # qtbot.stop() diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py new file mode 100644 index 000000000..4dc40f789 --- /dev/null +++ b/phy/plot/visuals.py @@ -0,0 +1,151 @@ +# -*- coding: utf-8 -*- + +"""Common visuals.""" + + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +import numpy as np + +from .base import BaseVisual +from .transform import Range, GPU +from .utils import _enable_depth_mask + + +#------------------------------------------------------------------------------ +# Visuals +#------------------------------------------------------------------------------ + +class ScatterVisual(BaseVisual): + shader_name = 'scatter' + gl_primitive_type = 'points' + _default_marker_size = 10. + _supported_marker_types = ( + 'arrow', + 'asterisk', + 'chevron', + 'clover', + 'club', + 'cross', + 'diamond', + 'disc', + 'ellipse', + 'hbar', + 'heart', + 'infinity', + 'pin', + 'ring', + 'spade', + 'square', + 'tag', + 'triangle', + 'vbar', + ) + + def __init__(self, marker_type=None): + super(ScatterVisual, self).__init__() + # Default bounds. + self.data_bounds = [-1, -1, 1, 1] + self.n_points = None + + # Set the marker type. + self.marker_type = marker_type or 'disc' + assert self.marker_type in self._supported_marker_types + + # Enable transparency. + _enable_depth_mask() + + def get_shaders(self): + v, f = super(ScatterVisual, self).get_shaders() + # Replace the marker type in the shader. + f = f.replace('%MARKER_TYPE', self.marker_type) + return v, f + + def get_transforms(self): + return [Range(from_bounds=self.data_bounds), GPU()] + + def set_data(self, + pos=None, + depth=None, + colors=None, + marker_type=None, + size=None, + data_bounds=None, + ): + assert pos is not None + pos = np.asarray(pos) + assert pos.ndim == 2 + assert pos.shape[1] == 2 + n = pos.shape[0] + + # Set the data bounds from the data. + if data_bounds is None: + m, M = pos.min(axis=0), pos.max(axis=0) + data_bounds = [m[0], m[1], M[0], M[1]] + assert len(data_bounds) == 4 + assert data_bounds[0] < data_bounds[2] + assert data_bounds[1] < data_bounds[3] + + # Set the transformed position. + pos_tr = self.apply_cpu_transforms(pos) + pos_tr = np.asarray(pos_tr, dtype=np.float32) + assert pos_tr.shape == (n, 2) + + # Set the depth. + if depth is None: + depth = np.zeros(n, dtype=np.float32) + depth = np.asarray(depth, dtype=np.float32) + assert depth.shape == (n,) + + pos_depth = np.empty((n, 3), dtype=np.float32) + pos_depth[:, :2] = pos_tr + pos_depth[:, 2] = depth + self.program['a_position'] = pos_depth + + # Set the marker size. + if size is None: + size = self._default_marker_size * np.ones(n, dtype=np.float32) + size = np.asarray(size, dtype=np.float32) + self.program['a_size'] = size + + # Set the group colors. + if colors is None: + colors = np.ones((n, 4), dtype=np.float32) + colors = np.asarray(colors, dtype=np.float32) + assert colors.shape == (n, 4) + self.program['a_color'] = colors + + +class PlotVisual(BaseVisual): + shader_name = 'plot' + gl_primitive_type = 'lines' + + def get_transforms(self): + pass + + def set_data(self): + pass + + +class HistogramVisual(BaseVisual): + shader_name = 'plot' + gl_primitive_type = 'triangles' + + def get_transforms(self): + pass + + def set_data(self): + pass + + +class TextVisual(BaseVisual): + shader_name = 'text' + gl_primitive_type = 'points' + + def get_transforms(self): + pass + + def set_data(self): + pass From 81487e39b71d4602ab506daced1c1416479b5b4a Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 24 Oct 2015 13:23:24 +0200 Subject: [PATCH 0420/1059] Add scatter empty test --- phy/plot/tests/test_visuals.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/phy/plot/tests/test_visuals.py b/phy/plot/tests/test_visuals.py index 8e0b6e528..8e7605405 100644 --- a/phy/plot/tests/test_visuals.py +++ b/phy/plot/tests/test_visuals.py @@ -21,6 +21,19 @@ # Test visuals #------------------------------------------------------------------------------ +def test_scatter_empty(qtbot, canvas): + + v = ScatterVisual() + v.attach(canvas) + + n = 0 + pos = np.zeros((n, 2)) + v.set_data(pos=pos) + + canvas.show() + qtbot.stop() + + @mark.parametrize('marker_type', ScatterVisual._supported_marker_types) def test_scatter_markers(qtbot, canvas_pz, marker_type): From c8176d30cee3785d2101c2a2f99df8edeb9db8fc Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 24 Oct 2015 14:25:46 +0200 Subject: [PATCH 0421/1059] Add plot visual --- phy/plot/glsl/plot.frag | 11 +++ phy/plot/glsl/plot.vert | 26 +++++++ phy/plot/glsl/utils.glsl | 4 + phy/plot/tests/test_visuals.py | 73 +++++++++++++++++- phy/plot/visuals.py | 133 ++++++++++++++++++++++++++++++--- 5 files changed, 234 insertions(+), 13 deletions(-) create mode 100644 phy/plot/glsl/plot.frag create mode 100644 phy/plot/glsl/plot.vert create mode 100644 phy/plot/glsl/utils.glsl diff --git a/phy/plot/glsl/plot.frag b/phy/plot/glsl/plot.frag new file mode 100644 index 000000000..09faf3bde --- /dev/null +++ b/phy/plot/glsl/plot.frag @@ -0,0 +1,11 @@ +varying vec4 v_color; +varying float v_signal_index; + +void main() { + + // Discard pixels between signals. + if (fract(v_signal_index) > 0.) + discard; + + gl_FragColor = v_color; +} diff --git a/phy/plot/glsl/plot.vert b/phy/plot/glsl/plot.vert new file mode 100644 index 000000000..090e96ff5 --- /dev/null +++ b/phy/plot/glsl/plot.vert @@ -0,0 +1,26 @@ +#include "utils.glsl" + +attribute vec3 a_position; +attribute float a_signal_index; // 0..n_signals-1 + +uniform sampler2D u_signal_bounds; +uniform sampler2D u_signal_colors; +uniform float n_signals; + +varying vec4 v_color; +varying float v_signal_index; + +void main() { + // Will be used by the transform. + vec4 signal_bounds = fetch_texture(a_signal_index, + u_signal_bounds, + n_signals); + signal_bounds = (2 * signal_bounds - 1); // See hack in Python. + + vec2 xy = a_position.xy; + gl_Position = transform(xy); + gl_Position.z = a_position.z; + + v_color = fetch_texture(a_signal_index, u_signal_colors, n_signals); + v_signal_index = a_signal_index; +} diff --git a/phy/plot/glsl/utils.glsl b/phy/plot/glsl/utils.glsl new file mode 100644 index 000000000..5ba113336 --- /dev/null +++ b/phy/plot/glsl/utils.glsl @@ -0,0 +1,4 @@ + +vec4 fetch_texture(float index, sampler2D texture, float size) { + return texture2D(texture, vec2(index / (size - 1.), .5)); +} diff --git a/phy/plot/tests/test_visuals.py b/phy/plot/tests/test_visuals.py index 8e7605405..69d4f6610 100644 --- a/phy/plot/tests/test_visuals.py +++ b/phy/plot/tests/test_visuals.py @@ -10,15 +10,16 @@ import numpy as np from pytest import mark -from ..visuals import ScatterVisual +from ..visuals import ScatterVisual, PlotVisual #------------------------------------------------------------------------------ # Fixtures #------------------------------------------------------------------------------ + #------------------------------------------------------------------------------ -# Test visuals +# Test scatter visual #------------------------------------------------------------------------------ def test_scatter_empty(qtbot, canvas): @@ -31,7 +32,7 @@ def test_scatter_empty(qtbot, canvas): v.set_data(pos=pos) canvas.show() - qtbot.stop() + # qtbot.stop() @mark.parametrize('marker_type', ScatterVisual._supported_marker_types) @@ -70,3 +71,69 @@ def test_scatter_custom(qtbot, canvas_pz): canvas_pz.show() # qtbot.stop() + + +#------------------------------------------------------------------------------ +# Test plot visual +#------------------------------------------------------------------------------ + +def test_plot_empty(qtbot, canvas): + + v = PlotVisual() + v.attach(canvas) + + data = np.zeros((1, 0)) + v.set_data(data=data) + + canvas.show() + # qtbot.stop() + + +def test_plot_0(qtbot, canvas_pz): + + v = PlotVisual() + v.attach(canvas_pz) + + data = np.zeros((1, 10)) + v.set_data(data=data) + + canvas_pz.show() + # qtbot.stop() + + +def test_plot_1(qtbot, canvas_pz): + + v = PlotVisual() + v.attach(canvas_pz) + + data = .2 * np.random.randn(1, 10) + v.set_data(data=data) + + canvas_pz.show() + # qtbot.stop() + + +def test_plot_2(qtbot, canvas_pz): + + v = PlotVisual() + v.attach(canvas_pz) + + n_signals = 50 + data = 20 * np.random.randn(n_signals, 10) + + # Signal bounds. + b = np.zeros((n_signals, 4)) + b[:, 0] = -1 + b[:, 1] = np.linspace(-1, 1 - 2. / n_signals, n_signals) + b[:, 2] = 1 + b[:, 3] = np.linspace(-1 + 2. / n_signals, 1., n_signals) + + # Signal colors. + c = np.random.uniform(.5, 1, size=(n_signals, 4)) + c[:, 3] = .5 + + v.set_data(data=data, data_bounds=[-10, 10], + signal_bounds=b, signal_colors=c) + + canvas_pz.show() + # qtbot.stop() diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index 4dc40f789..feebb0e21 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -8,12 +8,33 @@ #------------------------------------------------------------------------------ import numpy as np +from vispy.gloo import Texture2D from .base import BaseVisual from .transform import Range, GPU from .utils import _enable_depth_mask +#------------------------------------------------------------------------------ +# Utils +#------------------------------------------------------------------------------ + +def _check_data_bounds(data_bounds): + assert len(data_bounds) == 4 + assert data_bounds[0] < data_bounds[2] + assert data_bounds[1] < data_bounds[3] + + +def _get_data_bounds(data_bounds, pos): + if not len(pos): + return data_bounds or [-1, -1, 1, 1] + if data_bounds is None: + m, M = pos.min(axis=0), pos.max(axis=0) + data_bounds = [m[0], m[1], M[0], M[1]] + _check_data_bounds(data_bounds) + return data_bounds + + #------------------------------------------------------------------------------ # Visuals #------------------------------------------------------------------------------ @@ -81,12 +102,7 @@ def set_data(self, n = pos.shape[0] # Set the data bounds from the data. - if data_bounds is None: - m, M = pos.min(axis=0), pos.max(axis=0) - data_bounds = [m[0], m[1], M[0], M[1]] - assert len(data_bounds) == 4 - assert data_bounds[0] < data_bounds[2] - assert data_bounds[1] < data_bounds[3] + self.data_bounds = _get_data_bounds(data_bounds, pos) # Set the transformed position. pos_tr = self.apply_cpu_transforms(pos) @@ -99,6 +115,7 @@ def set_data(self, depth = np.asarray(depth, dtype=np.float32) assert depth.shape == (n,) + # Set the a_position attribute. pos_depth = np.empty((n, 3), dtype=np.float32) pos_depth[:, :2] = pos_tr pos_depth[:, 2] = depth @@ -120,13 +137,109 @@ def set_data(self, class PlotVisual(BaseVisual): shader_name = 'plot' - gl_primitive_type = 'lines' + gl_primitive_type = 'line_strip' + + def __init__(self): + super(PlotVisual, self).__init__() + self.data_bounds = [-1, -1, 1, 1] + _enable_depth_mask() def get_transforms(self): - pass + return [Range(from_bounds=self.data_bounds), + GPU(), + Range(from_bounds=(-1, -1, 1, 1), + to_bounds='signal_bounds'), + ] - def set_data(self): - pass + def set_data(self, + data=None, + depth=None, + data_bounds=None, + signal_bounds=None, + signal_colors=None, + ): + assert data is not None + data = np.asarray(data) + assert data.ndim == 2 + n_signals, n_samples = data.shape + n = n_signals * n_samples + + # Generate the x coordinates. + x = np.linspace(-1., 1., n_samples) + x = np.tile(x, n_signals) + assert x.shape == (n,) + + # Generate the signal index. + signal_index = np.arange(n_signals) + signal_index = np.repeat(signal_index, n_samples) + signal_index = signal_index.astype(np.float32) + + # Generate the (n, 2) pos array. + pos = np.empty((n, 2), dtype=np.float32) + pos[:, 0] = x + pos[:, 1] = data.ravel() + + # Generate the complete data_bounds 4-tuple from the specified 2-tuple. + if data_bounds is None: + data_bounds = [data.min(), data.max()] if data.size else [-1, 1] + assert len(data_bounds) == 2 + # Ensure that the data bounds are not degenerate. + if data_bounds[0] == data_bounds[1]: + data_bounds = [data_bounds[0] - 1, data_bounds[0] + 1] + ymin, ymax = data_bounds + data_bounds = [-1, ymin, 1, ymax] + _check_data_bounds(data_bounds) + self.data_bounds = data_bounds + + # Set the transformed position. + pos_tr = self.apply_cpu_transforms(pos) + pos_tr = np.asarray(pos_tr, dtype=np.float32) + assert pos_tr.shape == (n, 2) + + # Set the depth. + if depth is None: + depth = np.zeros(n, dtype=np.float32) + depth = np.asarray(depth, dtype=np.float32) + assert depth.shape == (n,) + + # Set the a_position attribute. + pos_depth = np.empty((n, 3), dtype=np.float32) + pos_depth[:, :2] = pos_tr + pos_depth[:, 2] = depth + self.program['a_position'] = pos_depth + + # Signal index. + self.program['a_signal_index'] = signal_index + + # Signal bounds (positions). + if signal_bounds is None: + signal_bounds = np.tile([-1, -1, 1, 1], (n_signals, 1)) + assert signal_bounds.shape == (n_signals, 4) + # Convert to 3D texture. + signal_bounds = signal_bounds[np.newaxis, ...].astype(np.float32) + assert signal_bounds.shape == (1, n_signals, 4) + # NOTE: we need to cast the texture to [0, 255] (uint8). + # This is easy as soon as we assume that the signal bounds are in + # [-1, 1]. + assert np.all(signal_bounds >= -1) + assert np.all(signal_bounds <= 1) + signal_bounds = 127 * (signal_bounds + 1) + assert np.all(signal_bounds >= 0) + assert np.all(signal_bounds <= 255) + signal_bounds = signal_bounds.astype(np.uint8) + self.program['u_signal_bounds'] = Texture2D(signal_bounds) + + # Signal colors. + if signal_colors is None: + signal_colors = np.ones((n_signals, 4), dtype=np.float32) + assert signal_colors.shape == (n_signals, 4) + # Convert to 3D texture. + signal_colors = signal_colors[np.newaxis, ...].astype(np.float32) + assert signal_colors.shape == (1, n_signals, 4) + self.program['u_signal_colors'] = Texture2D(signal_colors) + + # Number of signals. + self.program['n_signals'] = n_signals class HistogramVisual(BaseVisual): From 44826197dcd0555ef0d9e8891ff3eab84eb5cf52 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 24 Oct 2015 14:59:40 +0200 Subject: [PATCH 0422/1059] Update _tesselate_histogram() --- phy/plot/tests/test_utils.py | 11 ++++++----- phy/plot/utils.py | 27 ++++++++++++++++++--------- 2 files changed, 24 insertions(+), 14 deletions(-) diff --git a/phy/plot/tests/test_utils.py b/phy/plot/tests/test_utils.py index 88b00c9f9..4f059b79c 100644 --- a/phy/plot/tests/test_utils.py +++ b/phy/plot/tests/test_utils.py @@ -11,7 +11,7 @@ import os.path as op import numpy as np -from numpy.testing import assert_array_equal as ae +from numpy.testing import assert_allclose as ac from vispy import config from ..utils import (_load_shader, @@ -40,12 +40,13 @@ def test_create_program(): def test_tesselate_histogram(): - n = 5 + n = 7 hist = np.arange(n) thist = _tesselate_histogram(hist) - assert thist.shape == (5 * n + 1, 2) - ae(thist[0], [-1, -1]) - ae(thist[-1], [1, -1]) + assert thist.shape == (6 * n, 2) + ac(thist[0], [-1., 0]) + ac(thist[-3], [1., n - 1]) + ac(thist[-1], [1., 0]) def test_enable_depth_mask(qtbot, canvas): diff --git a/phy/plot/utils.py b/phy/plot/utils.py index ff89db799..992a88291 100644 --- a/phy/plot/utils.py +++ b/phy/plot/utils.py @@ -40,23 +40,32 @@ def _create_program(name): def _tesselate_histogram(hist): + """ + + 2/4 3 + ____ + |\ | + | \ | + | \ | + |___\| + + 0 1/5 + + """ assert hist.ndim == 1 nsamples = len(hist) dx = 2. / nsamples x0 = -1 + dx * np.arange(nsamples) - x = np.zeros(5 * nsamples + 1) - y = -1.0 * np.ones(5 * nsamples + 1) + x = np.zeros(6 * nsamples) + y = np.zeros(6 * nsamples) - x[0:-1:5] = x0 - x[1::5] = x0 - x[2::5] = x0 + dx - x[3::5] = x0 - x[4::5] = x0 + dx - x[-1] = 1 + x[0::2] = np.repeat(x0, 3) + x[1::2] = x[0::2] + dx - y[1::5] = y[2::5] = -1 + 2. * hist + # y[0::6] = y[1::6] = y[5::6] = -1 + y[2::6] = y[3::6] = y[4::6] = hist return np.c_[x, y] From 7691366cef3ce6b20581405e279c563f0b69c885 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 24 Oct 2015 15:31:16 +0200 Subject: [PATCH 0423/1059] Update histogram tesselation --- phy/plot/tests/test_utils.py | 6 +++--- phy/plot/utils.py | 6 ++---- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/phy/plot/tests/test_utils.py b/phy/plot/tests/test_utils.py index 4f059b79c..dc07bd8c0 100644 --- a/phy/plot/tests/test_utils.py +++ b/phy/plot/tests/test_utils.py @@ -44,9 +44,9 @@ def test_tesselate_histogram(): hist = np.arange(n) thist = _tesselate_histogram(hist) assert thist.shape == (6 * n, 2) - ac(thist[0], [-1., 0]) - ac(thist[-3], [1., n - 1]) - ac(thist[-1], [1., 0]) + ac(thist[0], [0, 0]) + ac(thist[-3], [n, n - 1]) + ac(thist[-1], [n, 0]) def test_enable_depth_mask(qtbot, canvas): diff --git a/phy/plot/utils.py b/phy/plot/utils.py index 992a88291..294fe83ad 100644 --- a/phy/plot/utils.py +++ b/phy/plot/utils.py @@ -54,17 +54,15 @@ def _tesselate_histogram(hist): """ assert hist.ndim == 1 nsamples = len(hist) - dx = 2. / nsamples - x0 = -1 + dx * np.arange(nsamples) + x0 = np.arange(nsamples) x = np.zeros(6 * nsamples) y = np.zeros(6 * nsamples) x[0::2] = np.repeat(x0, 3) - x[1::2] = x[0::2] + dx + x[1::2] = x[0::2] + 1 - # y[0::6] = y[1::6] = y[5::6] = -1 y[2::6] = y[3::6] = y[4::6] = hist return np.c_[x, y] From 90fb3b18f415e3491d4a6d3dc51014e606e766ad Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 24 Oct 2015 15:53:13 +0200 Subject: [PATCH 0424/1059] Add histogram visual --- phy/plot/glsl/histogram.frag | 5 ++ phy/plot/glsl/histogram.vert | 21 ++++++++ phy/plot/tests/test_visuals.py | 60 ++++++++++++++++++++++- phy/plot/visuals.py | 89 +++++++++++++++++++++++++++++++--- 4 files changed, 166 insertions(+), 9 deletions(-) create mode 100644 phy/plot/glsl/histogram.frag create mode 100644 phy/plot/glsl/histogram.vert diff --git a/phy/plot/glsl/histogram.frag b/phy/plot/glsl/histogram.frag new file mode 100644 index 000000000..7c31f83aa --- /dev/null +++ b/phy/plot/glsl/histogram.frag @@ -0,0 +1,5 @@ +varying vec4 v_color; + +void main() { + gl_FragColor = v_color; +} diff --git a/phy/plot/glsl/histogram.vert b/phy/plot/glsl/histogram.vert new file mode 100644 index 000000000..130ace589 --- /dev/null +++ b/phy/plot/glsl/histogram.vert @@ -0,0 +1,21 @@ +#include "utils.glsl" + +attribute vec2 a_position; +attribute float a_hist_index; // 0..n_hists-1 + +uniform sampler2D u_hist_colors; +uniform sampler2D u_hist_bounds; +uniform float n_hists; + +varying vec4 v_color; +varying float v_hist_index; + +void main() { + vec4 hist_bounds = fetch_texture(a_hist_index, + u_hist_bounds, + n_hists) * 10.; // avoid texture clipping + gl_Position = transform(a_position); + + v_color = fetch_texture(a_hist_index, u_hist_colors, n_hists); + v_hist_index = a_hist_index; +} diff --git a/phy/plot/tests/test_visuals.py b/phy/plot/tests/test_visuals.py index 69d4f6610..68649f710 100644 --- a/phy/plot/tests/test_visuals.py +++ b/phy/plot/tests/test_visuals.py @@ -10,7 +10,7 @@ import numpy as np from pytest import mark -from ..visuals import ScatterVisual, PlotVisual +from ..visuals import ScatterVisual, PlotVisual, HistogramVisual #------------------------------------------------------------------------------ @@ -137,3 +137,61 @@ def test_plot_2(qtbot, canvas_pz): canvas_pz.show() # qtbot.stop() + + +#------------------------------------------------------------------------------ +# Test histogram visual +#------------------------------------------------------------------------------ + +def test_histogram_empty(qtbot, canvas): + + v = HistogramVisual() + v.attach(canvas) + + hist = np.zeros((1, 0)) + v.set_data(hist=hist) + + canvas.show() + # qtbot.stop() + + +def test_histogram_0(qtbot, canvas_pz): + + v = HistogramVisual() + v.attach(canvas_pz) + + hist = np.zeros((1, 10)) + v.set_data(hist=hist) + + canvas_pz.show() + # qtbot.stop() + + +def test_histogram_1(qtbot, canvas_pz): + + v = HistogramVisual() + v.attach(canvas_pz) + + hist = np.random.rand(1, 10) + v.set_data(hist=hist) + + canvas_pz.show() + # qtbot.stop() + + +def test_histogram_2(qtbot, canvas_pz): + + v = HistogramVisual() + v.attach(canvas_pz) + + n_hists = 5 + hist = np.random.rand(n_hists, 21) + + # Histogram colors. + c = np.random.uniform(.3, .6, size=(n_hists, 4)) + c[:, 3] = 1 + + v.set_data(hist=hist, hist_colors=c, hist_lims=2 * np.ones(n_hists)) + + canvas_pz.show() + # qtbot.stop() diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index feebb0e21..16c53b934 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -12,7 +12,7 @@ from .base import BaseVisual from .transform import Range, GPU -from .utils import _enable_depth_mask +from .utils import _enable_depth_mask, _tesselate_histogram #------------------------------------------------------------------------------ @@ -173,6 +173,7 @@ def set_data(self, signal_index = np.arange(n_signals) signal_index = np.repeat(signal_index, n_samples) signal_index = signal_index.astype(np.float32) + self.program['a_signal_index'] = signal_index # Generate the (n, 2) pos array. pos = np.empty((n, 2), dtype=np.float32) @@ -208,9 +209,6 @@ def set_data(self, pos_depth[:, 2] = depth self.program['a_position'] = pos_depth - # Signal index. - self.program['a_signal_index'] = signal_index - # Signal bounds (positions). if signal_bounds is None: signal_bounds = np.tile([-1, -1, 1, 1], (n_signals, 1)) @@ -243,14 +241,89 @@ def set_data(self, class HistogramVisual(BaseVisual): - shader_name = 'plot' + shader_name = 'histogram' gl_primitive_type = 'triangles' + def __init__(self): + super(HistogramVisual, self).__init__() + self.n_bins = 0 + self.hist_max = 1 + def get_transforms(self): - pass + return [Range(from_bounds=[0, 0, self.n_bins, self.hist_max], + to_bounds=[0, 0, 1, 1]), + GPU(), + Range(from_bounds='hist_bounds', # (0, 0, 1, v) + to_bounds=(-1, -1, 1, 1)), + ] - def set_data(self): - pass + def set_data(self, + hist=None, + hist_lims=None, + hist_colors=None, + ): + assert hist is not None + hist = np.atleast_2d(hist) + assert hist.ndim == 2 + n_hists, n_bins = hist.shape + n = 6 * n_hists * n_bins + self.n_bins = n_bins + + # Generate hist_max. + hist_max = hist.max() if hist.size else 1. + hist_max = float(hist_max) + hist_max = hist_max if hist_max > 0 else 1. + assert hist_max > 0 + self.hist_max = hist_max + + # Concatenate all histograms. + pos = np.vstack(_tesselate_histogram(row) for row in hist) + assert pos.shape == (n, 2) + + # Set the transformed position. + pos_tr = self.apply_cpu_transforms(pos) + pos_tr = np.asarray(pos_tr, dtype=np.float32) + assert pos_tr.shape == (n, 2) + self.program['a_position'] = pos_tr + + # Generate the hist index. + hist_index = np.arange(n_hists) + # 6 * n_bins vertices per histogram. + hist_index = np.repeat(hist_index, n_bins * 6) + hist_index = hist_index.astype(np.float32) + assert hist_index.shape == (n,) + self.program['a_hist_index'] = hist_index + + # Hist colors. + if hist_colors is None: + hist_colors = np.ones((n_hists, 4), dtype=np.float32) + assert hist_colors.shape == (n_hists, 4) + # Convert to 3D texture. + hist_colors = hist_colors[np.newaxis, ...].astype(np.float32) + assert hist_colors.shape == (1, n_hists, 4) + self.program['u_hist_colors'] = Texture2D(hist_colors) + + # Hist bounds. + if hist_lims is None: + hist_lims = hist_max * np.ones(n_hists) + hist_lims = np.asarray(hist_lims, dtype=np.float32) + # NOTE: hist_lims is now relative to hist_max (what is on the GPU). + hist_lims = hist_lims / hist_max + assert hist_lims.shape == (n_hists,) + # Now, we create the 4-tuples for the bounds: [0, 0, 1, hists_lim]. + hist_bounds = np.zeros((n_hists, 4), dtype=np.float32) + hist_bounds[:, 2] = 1 + hist_bounds[:, 3] = hist_lims + # Convert to 3D texture. + hist_bounds = hist_bounds[np.newaxis, ...].astype(np.float32) + assert hist_bounds.shape == (1, n_hists, 4) + assert np.all(hist_bounds >= 0) + assert np.all(hist_bounds <= 10) + # NOTE: necessary because VisPy silently clips textures to [0, 1]. + hist_bounds /= 10. + self.program['u_hist_bounds'] = Texture2D(hist_bounds) + + self.program['n_hists'] = n_hists class TextVisual(BaseVisual): From 1086af108256efa2caca803a6894fde94c18179d Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 24 Oct 2015 16:35:17 +0200 Subject: [PATCH 0425/1059] Refactor visual tests --- phy/plot/glsl/histogram.vert | 3 +- phy/plot/tests/test_visuals.py | 122 +++++++++------------------------ phy/plot/transform.py | 4 +- 3 files changed, 37 insertions(+), 92 deletions(-) diff --git a/phy/plot/glsl/histogram.vert b/phy/plot/glsl/histogram.vert index 130ace589..c4055ebbd 100644 --- a/phy/plot/glsl/histogram.vert +++ b/phy/plot/glsl/histogram.vert @@ -13,7 +13,8 @@ varying float v_hist_index; void main() { vec4 hist_bounds = fetch_texture(a_hist_index, u_hist_bounds, - n_hists) * 10.; // avoid texture clipping + n_hists); + hist_bounds = hist_bounds * 10.; // NOTE: avoid texture clipping gl_Position = transform(a_position); v_color = fetch_texture(a_hist_index, u_hist_colors, n_hists); diff --git a/phy/plot/tests/test_visuals.py b/phy/plot/tests/test_visuals.py index 68649f710..6168c54a4 100644 --- a/phy/plot/tests/test_visuals.py +++ b/phy/plot/tests/test_visuals.py @@ -18,43 +18,34 @@ #------------------------------------------------------------------------------ +def _test_visual(qtbot, c, v, stop=False, **kwargs): + v.attach(c) + v.set_data(**kwargs) + c.show() + if stop: # pragma: no cover + qtbot.stop() + + #------------------------------------------------------------------------------ # Test scatter visual #------------------------------------------------------------------------------ def test_scatter_empty(qtbot, canvas): - - v = ScatterVisual() - v.attach(canvas) - - n = 0 - pos = np.zeros((n, 2)) - v.set_data(pos=pos) - - canvas.show() - # qtbot.stop() + pos = np.zeros((0, 2)) + _test_visual(qtbot, canvas, ScatterVisual(), pos=pos) @mark.parametrize('marker_type', ScatterVisual._supported_marker_types) def test_scatter_markers(qtbot, canvas_pz, marker_type): - - # Try all marker types. - v = ScatterVisual(marker_type=marker_type) - v.attach(canvas_pz) - n = 100 pos = .2 * np.random.randn(n, 2) - v.set_data(pos=pos) - - canvas_pz.show() - # qtbot.stop() + _test_visual(qtbot, canvas_pz, + ScatterVisual(marker_type=marker_type), + pos=pos) def test_scatter_custom(qtbot, canvas_pz): - v = ScatterVisual() - v.attach(canvas_pz) - n = 100 # Random position. @@ -67,10 +58,8 @@ def test_scatter_custom(qtbot, canvas_pz): # Random sizes s = 5 + 20 * np.random.rand(n) - v.set_data(pos=pos, colors=c, size=s) - - canvas_pz.show() - # qtbot.stop() + _test_visual(qtbot, canvas_pz, ScatterVisual(), + pos=pos, colors=c, size=s) #------------------------------------------------------------------------------ @@ -78,46 +67,25 @@ def test_scatter_custom(qtbot, canvas_pz): #------------------------------------------------------------------------------ def test_plot_empty(qtbot, canvas): - - v = PlotVisual() - v.attach(canvas) - data = np.zeros((1, 0)) - v.set_data(data=data) - - canvas.show() - # qtbot.stop() + _test_visual(qtbot, canvas, PlotVisual(), + data=data) def test_plot_0(qtbot, canvas_pz): - - v = PlotVisual() - v.attach(canvas_pz) - data = np.zeros((1, 10)) - v.set_data(data=data) - - canvas_pz.show() - # qtbot.stop() + _test_visual(qtbot, canvas_pz, PlotVisual(), + data=data) def test_plot_1(qtbot, canvas_pz): - - v = PlotVisual() - v.attach(canvas_pz) - data = .2 * np.random.randn(1, 10) - v.set_data(data=data) - - canvas_pz.show() - # qtbot.stop() + _test_visual(qtbot, canvas_pz, PlotVisual(), + data=data) def test_plot_2(qtbot, canvas_pz): - v = PlotVisual() - v.attach(canvas_pz) - n_signals = 50 data = 20 * np.random.randn(n_signals, 10) @@ -132,11 +100,10 @@ def test_plot_2(qtbot, canvas_pz): c = np.random.uniform(.5, 1, size=(n_signals, 4)) c[:, 3] = .5 - v.set_data(data=data, data_bounds=[-10, 10], - signal_bounds=b, signal_colors=c) - - canvas_pz.show() - # qtbot.stop() + _test_visual(qtbot, canvas_pz, PlotVisual(), + data=data, data_bounds=[-10, 10], + signal_bounds=b, signal_colors=c, + stop=True) #------------------------------------------------------------------------------ @@ -144,46 +111,25 @@ def test_plot_2(qtbot, canvas_pz): #------------------------------------------------------------------------------ def test_histogram_empty(qtbot, canvas): - - v = HistogramVisual() - v.attach(canvas) - hist = np.zeros((1, 0)) - v.set_data(hist=hist) - - canvas.show() - # qtbot.stop() + _test_visual(qtbot, canvas, HistogramVisual(), + hist=hist) def test_histogram_0(qtbot, canvas_pz): - - v = HistogramVisual() - v.attach(canvas_pz) - hist = np.zeros((1, 10)) - v.set_data(hist=hist) - - canvas_pz.show() - # qtbot.stop() + _test_visual(qtbot, canvas_pz, HistogramVisual(), + hist=hist) def test_histogram_1(qtbot, canvas_pz): - - v = HistogramVisual() - v.attach(canvas_pz) - hist = np.random.rand(1, 10) - v.set_data(hist=hist) - - canvas_pz.show() - # qtbot.stop() + _test_visual(qtbot, canvas_pz, HistogramVisual(), + hist=hist) def test_histogram_2(qtbot, canvas_pz): - v = HistogramVisual() - v.attach(canvas_pz) - n_hists = 5 hist = np.random.rand(n_hists, 21) @@ -191,7 +137,5 @@ def test_histogram_2(qtbot, canvas_pz): c = np.random.uniform(.3, .6, size=(n_hists, 4)) c[:, 3] = 1 - v.set_data(hist=hist, hist_colors=c, hist_lims=2 * np.ones(n_hists)) - - canvas_pz.show() - # qtbot.stop() + _test_visual(qtbot, canvas_pz, HistogramVisual(), + hist=hist, hist_colors=c, hist_lims=2 * np.ones(n_hists)) diff --git a/phy/plot/transform.py b/phy/plot/transform.py index 3a23b18c7..476c312b2 100644 --- a/phy/plot/transform.py +++ b/phy/plot/transform.py @@ -295,8 +295,8 @@ def insert_glsl(self, vertex, fragment): "GLSL insertion.") return vertex, fragment assert r - logger.debug("Found transform placeholder in vertex code: `%s`", - r.group(0)) + logger.log(5, "Found transform placeholder in vertex code: `%s`", + r.group(0)) # Find the GLSL variable with the data (should be a `vec2`). var = r.group(1) From 1042c910615f852c46f24b921d18724a073aee52 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 24 Oct 2015 16:41:58 +0200 Subject: [PATCH 0426/1059] Use global NDC variable --- phy/plot/tests/test_visuals.py | 3 +-- phy/plot/transform.py | 14 +++++++++----- phy/plot/visuals.py | 14 +++++++------- 3 files changed, 17 insertions(+), 14 deletions(-) diff --git a/phy/plot/tests/test_visuals.py b/phy/plot/tests/test_visuals.py index 6168c54a4..0b4f40e04 100644 --- a/phy/plot/tests/test_visuals.py +++ b/phy/plot/tests/test_visuals.py @@ -102,8 +102,7 @@ def test_plot_2(qtbot, canvas_pz): _test_visual(qtbot, canvas_pz, PlotVisual(), data=data, data_bounds=[-10, 10], - signal_bounds=b, signal_colors=c, - stop=True) + signal_bounds=b, signal_colors=c) #------------------------------------------------------------------------------ diff --git a/phy/plot/transform.py b/phy/plot/transform.py index 476c312b2..8808f3e45 100644 --- a/phy/plot/transform.py +++ b/phy/plot/transform.py @@ -98,6 +98,10 @@ def pixels_to_ndc(pos, size=None): return pos +"""Bounds in Normalized Device Coordinates (NDC).""" +NDC = (-1.0, -1.0, +1.0, +1.0) + + #------------------------------------------------------------------------------ # Transforms #------------------------------------------------------------------------------ @@ -145,7 +149,7 @@ def glsl(self, var, scale=None): class Range(BaseTransform): - def apply(self, arr, from_bounds=None, to_bounds=(-1, -1, 1, 1)): + def apply(self, arr, from_bounds=None, to_bounds=NDC): f0 = np.asarray(from_bounds[:2]) f1 = np.asarray(from_bounds[2:]) t0 = np.asarray(to_bounds[:2]) @@ -153,7 +157,7 @@ def apply(self, arr, from_bounds=None, to_bounds=(-1, -1, 1, 1)): return t0 + (t1 - t0) * (arr - f0) / (f1 - f0) - def glsl(self, var, from_bounds=None, to_bounds=(-1, -1, 1, 1)): + def glsl(self, var, from_bounds=None, to_bounds=NDC): assert var from_bounds = _glslify(from_bounds) @@ -165,14 +169,14 @@ def glsl(self, var, from_bounds=None, to_bounds=(-1, -1, 1, 1)): class Clip(BaseTransform): - def apply(self, arr, bounds=(-1, -1, 1, 1)): + def apply(self, arr, bounds=NDC): index = ((arr[:, 0] >= bounds[0]) & (arr[:, 1] >= bounds[1]) & (arr[:, 0] <= bounds[2]) & (arr[:, 1] <= bounds[3])) return arr[index, ...] - def glsl(self, var, bounds=(-1, -1, 1, 1)): + def glsl(self, var, bounds=NDC): assert var bounds = _glslify(bounds) @@ -198,7 +202,7 @@ def get_bounds(self, shape=None, index=None): return subplot_bounds(shape=shape, index=index) def apply(self, arr, shape=None, index=None): - from_bounds = (-1, -1, 1, 1) + from_bounds = NDC to_bounds = self.get_bounds(shape=shape, index=index) return super(Subplot, self).apply(arr, from_bounds=from_bounds, diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index 16c53b934..d48829aa4 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -11,7 +11,7 @@ from vispy.gloo import Texture2D from .base import BaseVisual -from .transform import Range, GPU +from .transform import Range, GPU, NDC from .utils import _enable_depth_mask, _tesselate_histogram @@ -27,7 +27,7 @@ def _check_data_bounds(data_bounds): def _get_data_bounds(data_bounds, pos): if not len(pos): - return data_bounds or [-1, -1, 1, 1] + return data_bounds or NDC if data_bounds is None: m, M = pos.min(axis=0), pos.max(axis=0) data_bounds = [m[0], m[1], M[0], M[1]] @@ -68,7 +68,7 @@ class ScatterVisual(BaseVisual): def __init__(self, marker_type=None): super(ScatterVisual, self).__init__() # Default bounds. - self.data_bounds = [-1, -1, 1, 1] + self.data_bounds = NDC self.n_points = None # Set the marker type. @@ -141,13 +141,13 @@ class PlotVisual(BaseVisual): def __init__(self): super(PlotVisual, self).__init__() - self.data_bounds = [-1, -1, 1, 1] + self.data_bounds = NDC _enable_depth_mask() def get_transforms(self): return [Range(from_bounds=self.data_bounds), GPU(), - Range(from_bounds=(-1, -1, 1, 1), + Range(from_bounds=NDC, to_bounds='signal_bounds'), ] @@ -211,7 +211,7 @@ def set_data(self, # Signal bounds (positions). if signal_bounds is None: - signal_bounds = np.tile([-1, -1, 1, 1], (n_signals, 1)) + signal_bounds = np.tile(NDC, (n_signals, 1)) assert signal_bounds.shape == (n_signals, 4) # Convert to 3D texture. signal_bounds = signal_bounds[np.newaxis, ...].astype(np.float32) @@ -254,7 +254,7 @@ def get_transforms(self): to_bounds=[0, 0, 1, 1]), GPU(), Range(from_bounds='hist_bounds', # (0, 0, 1, v) - to_bounds=(-1, -1, 1, 1)), + to_bounds=NDC), ] def set_data(self, From 0bc59553458d14b82e7cb5ee7f30d582552d7a4a Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 24 Oct 2015 16:46:57 +0200 Subject: [PATCH 0427/1059] WIP: refactor visuals set_data() --- phy/plot/visuals.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index d48829aa4..5dd8c4076 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -35,6 +35,13 @@ def _get_data_bounds(data_bounds, pos): return data_bounds +def _check_pos_2D(pos): + assert pos is not None + pos = np.asarray(pos) + assert pos.ndim == 2 + return pos + + #------------------------------------------------------------------------------ # Visuals #------------------------------------------------------------------------------ @@ -95,11 +102,9 @@ def set_data(self, size=None, data_bounds=None, ): - assert pos is not None - pos = np.asarray(pos) - assert pos.ndim == 2 - assert pos.shape[1] == 2 + pos = _check_pos_2D(pos) n = pos.shape[0] + assert pos.shape == (n, 2) # Set the data bounds from the data. self.data_bounds = _get_data_bounds(data_bounds, pos) @@ -158,9 +163,7 @@ def set_data(self, signal_bounds=None, signal_colors=None, ): - assert data is not None - data = np.asarray(data) - assert data.ndim == 2 + pos = _check_pos_2D(data) n_signals, n_samples = data.shape n = n_signals * n_samples @@ -262,11 +265,10 @@ def set_data(self, hist_lims=None, hist_colors=None, ): - assert hist is not None - hist = np.atleast_2d(hist) - assert hist.ndim == 2 + hist = _check_pos_2D(hist) n_hists, n_bins = hist.shape n = 6 * n_hists * n_bins + # Store n_bins for get_transforms(). self.n_bins = n_bins # Generate hist_max. From 98f42f89da4ae238a465eb3e77f496aa102b5c64 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 24 Oct 2015 16:51:50 +0200 Subject: [PATCH 0428/1059] WIP: refactor visuals set_data() --- phy/plot/visuals.py | 65 +++++++++++++++++++++------------------------ 1 file changed, 31 insertions(+), 34 deletions(-) diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index 5dd8c4076..5144bb030 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -42,6 +42,33 @@ def _check_pos_2D(pos): return pos +def _get_pos_depth(pos_tr, depth): + n = pos_tr.shape[0] + pos_tr = np.asarray(pos_tr, dtype=np.float32) + assert pos_tr.shape == (n, 2) + + # Set the depth. + if depth is None: + depth = np.zeros(n, dtype=np.float32) + depth = np.asarray(depth, dtype=np.float32) + assert depth.shape == (n,) + + # Set the a_position attribute. + pos_depth = np.empty((n, 3), dtype=np.float32) + pos_depth[:, :2] = pos_tr + pos_depth[:, 2] = depth + + return pos_depth + + +def _get_colors(colors, n): + if colors is None: + colors = np.ones((n, 4), dtype=np.float32) + colors = np.asarray(colors, dtype=np.float32) + assert colors.shape == (n, 4) + return colors + + #------------------------------------------------------------------------------ # Visuals #------------------------------------------------------------------------------ @@ -111,20 +138,7 @@ def set_data(self, # Set the transformed position. pos_tr = self.apply_cpu_transforms(pos) - pos_tr = np.asarray(pos_tr, dtype=np.float32) - assert pos_tr.shape == (n, 2) - - # Set the depth. - if depth is None: - depth = np.zeros(n, dtype=np.float32) - depth = np.asarray(depth, dtype=np.float32) - assert depth.shape == (n,) - - # Set the a_position attribute. - pos_depth = np.empty((n, 3), dtype=np.float32) - pos_depth[:, :2] = pos_tr - pos_depth[:, 2] = depth - self.program['a_position'] = pos_depth + self.program['a_position'] = _get_pos_depth(pos_tr, depth) # Set the marker size. if size is None: @@ -132,12 +146,8 @@ def set_data(self, size = np.asarray(size, dtype=np.float32) self.program['a_size'] = size - # Set the group colors. - if colors is None: - colors = np.ones((n, 4), dtype=np.float32) - colors = np.asarray(colors, dtype=np.float32) - assert colors.shape == (n, 4) - self.program['a_color'] = colors + # Set the colors. + self.program['a_color'] = _get_colors(colors, n) class PlotVisual(BaseVisual): @@ -197,20 +207,7 @@ def set_data(self, # Set the transformed position. pos_tr = self.apply_cpu_transforms(pos) - pos_tr = np.asarray(pos_tr, dtype=np.float32) - assert pos_tr.shape == (n, 2) - - # Set the depth. - if depth is None: - depth = np.zeros(n, dtype=np.float32) - depth = np.asarray(depth, dtype=np.float32) - assert depth.shape == (n,) - - # Set the a_position attribute. - pos_depth = np.empty((n, 3), dtype=np.float32) - pos_depth[:, :2] = pos_tr - pos_depth[:, 2] = depth - self.program['a_position'] = pos_depth + self.program['a_position'] = _get_pos_depth(pos_tr, depth) # Signal bounds (positions). if signal_bounds is None: From 50e16e70a4ada7c1f1029f7f8c593b1975cdc10e Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 24 Oct 2015 16:57:58 +0200 Subject: [PATCH 0429/1059] WIP: refactor visuals set_data() --- phy/plot/visuals.py | 28 ++++++++++++---------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index 5144bb030..a13d3a365 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -61,12 +61,16 @@ def _get_pos_depth(pos_tr, depth): return pos_depth -def _get_colors(colors, n): - if colors is None: - colors = np.ones((n, 4), dtype=np.float32) - colors = np.asarray(colors, dtype=np.float32) - assert colors.shape == (n, 4) - return colors +def _get_attr(attr, n, default): + if not hasattr(default, '__len__'): + default = [default] + if attr is None: + attr = np.tile(default, (n, 1)) + attr = np.asarray(attr, dtype=np.float32) + if attr.ndim == 1: + attr = attr[:, np.newaxis] + assert attr.shape == (n, len(default)) + return attr #------------------------------------------------------------------------------ @@ -136,18 +140,10 @@ def set_data(self, # Set the data bounds from the data. self.data_bounds = _get_data_bounds(data_bounds, pos) - # Set the transformed position. pos_tr = self.apply_cpu_transforms(pos) self.program['a_position'] = _get_pos_depth(pos_tr, depth) - - # Set the marker size. - if size is None: - size = self._default_marker_size * np.ones(n, dtype=np.float32) - size = np.asarray(size, dtype=np.float32) - self.program['a_size'] = size - - # Set the colors. - self.program['a_color'] = _get_colors(colors, n) + self.program['a_size'] = _get_attr(size, n, self._default_marker_size) + self.program['a_color'] = _get_attr(colors, n, (1, 1, 1, 1)) class PlotVisual(BaseVisual): From cd07055d5ac66c88b7da32e25bd94f150678f79e Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 24 Oct 2015 17:03:48 +0200 Subject: [PATCH 0430/1059] WIP: refactor visuals set_data() --- phy/plot/visuals.py | 31 ++++++++++++++++++------------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index a13d3a365..1688f841e 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -26,6 +26,7 @@ def _check_data_bounds(data_bounds): def _get_data_bounds(data_bounds, pos): + """"Prepare data bounds, possibly using min/max of the data.""" if not len(pos): return data_bounds or NDC if data_bounds is None: @@ -36,6 +37,7 @@ def _get_data_bounds(data_bounds, pos): def _check_pos_2D(pos): + """Check position data before GPU uploading.""" assert pos is not None pos = np.asarray(pos) assert pos.ndim == 2 @@ -43,6 +45,7 @@ def _check_pos_2D(pos): def _get_pos_depth(pos_tr, depth): + """Prepare a (N, 3) position-depth array for GPU uploading.""" n = pos_tr.shape[0] pos_tr = np.asarray(pos_tr, dtype=np.float32) assert pos_tr.shape == (n, 2) @@ -61,7 +64,8 @@ def _get_pos_depth(pos_tr, depth): return pos_depth -def _get_attr(attr, n, default): +def _get_attr(attr, default, n): + """Prepare an attribute for GPU uploading.""" if not hasattr(default, '__len__'): default = [default] if attr is None: @@ -73,6 +77,15 @@ def _get_attr(attr, n, default): return attr +def _get_index(n_items, item_size, n): + """Prepare an index attribute for GPU uploading.""" + index = np.arange(n_items) + index = np.repeat(index, item_size) + index = index.astype(np.float32) + assert index.shape == (n,) + return index + + #------------------------------------------------------------------------------ # Visuals #------------------------------------------------------------------------------ @@ -142,8 +155,8 @@ def set_data(self, pos_tr = self.apply_cpu_transforms(pos) self.program['a_position'] = _get_pos_depth(pos_tr, depth) - self.program['a_size'] = _get_attr(size, n, self._default_marker_size) - self.program['a_color'] = _get_attr(colors, n, (1, 1, 1, 1)) + self.program['a_size'] = _get_attr(size, self._default_marker_size, n) + self.program['a_color'] = _get_attr(colors, (1, 1, 1, 1), n) class PlotVisual(BaseVisual): @@ -179,10 +192,7 @@ def set_data(self, assert x.shape == (n,) # Generate the signal index. - signal_index = np.arange(n_signals) - signal_index = np.repeat(signal_index, n_samples) - signal_index = signal_index.astype(np.float32) - self.program['a_signal_index'] = signal_index + self.program['a_signal_index'] = _get_index(n_signals, n_samples, n) # Generate the (n, 2) pos array. pos = np.empty((n, 2), dtype=np.float32) @@ -282,12 +292,7 @@ def set_data(self, self.program['a_position'] = pos_tr # Generate the hist index. - hist_index = np.arange(n_hists) - # 6 * n_bins vertices per histogram. - hist_index = np.repeat(hist_index, n_bins * 6) - hist_index = hist_index.astype(np.float32) - assert hist_index.shape == (n,) - self.program['a_hist_index'] = hist_index + self.program['a_hist_index'] = _get_index(n_hists, n_bins * 6, n) # Hist colors. if hist_colors is None: From 43419703a712b0e10c269ee51a55d69c12e9bcff Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 24 Oct 2015 17:21:48 +0200 Subject: [PATCH 0431/1059] WIP: refactor visuals set_data() --- phy/plot/visuals.py | 86 ++++++++++++++++++++++++++------------------- 1 file changed, 50 insertions(+), 36 deletions(-) diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index 1688f841e..90c4fd566 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -36,6 +36,20 @@ def _get_data_bounds(data_bounds, pos): return data_bounds +def _get_data_bounds_1D(data_bounds, data): + """Generate the complete data_bounds 4-tuple from the specified 2-tuple.""" + if data_bounds is None: + data_bounds = [data.min(), data.max()] if data.size else [-1, 1] + assert len(data_bounds) == 2 + # Ensure that the data bounds are not degenerate. + if data_bounds[0] == data_bounds[1]: + data_bounds = [data_bounds[0] - 1, data_bounds[0] + 1] + ymin, ymax = data_bounds + data_bounds = [-1, ymin, 1, ymax] + _check_data_bounds(data_bounds) + return data_bounds + + def _check_pos_2D(pos): """Check position data before GPU uploading.""" assert pos is not None @@ -86,6 +100,33 @@ def _get_index(n_items, item_size, n): return index +def _get_texture(arr, default, n_items, from_bounds): + """Prepare data to be uploaded as a texture, with casting to uint8. + The from_bounds must be specified. + """ + if not hasattr(default, '__len__'): + default = [default] + n_cols = len(default) + if arr is None: + arr = np.tile(default, (n_items, 1)) + assert arr.shape == (n_items, n_cols) + # Convert to 3D texture. + arr = arr[np.newaxis, ...].astype(np.float32) + assert arr.shape == (1, n_items, n_cols) + # NOTE: we need to cast the texture to [0, 255] (uint8). + # This is easy as soon as we assume that the signal bounds are in + # [-1, 1]. + assert len(from_bounds) == 2 + m, M = map(float, from_bounds) + assert np.all(arr >= m) + assert np.all(arr <= M) + arr = 255 * (arr - m) / (M - m) + assert np.all(arr >= 0) + assert np.all(arr <= 255) + arr = arr.astype(np.uint8) + return arr + + #------------------------------------------------------------------------------ # Visuals #------------------------------------------------------------------------------ @@ -191,55 +232,28 @@ def set_data(self, x = np.tile(x, n_signals) assert x.shape == (n,) - # Generate the signal index. - self.program['a_signal_index'] = _get_index(n_signals, n_samples, n) - # Generate the (n, 2) pos array. pos = np.empty((n, 2), dtype=np.float32) pos[:, 0] = x pos[:, 1] = data.ravel() - # Generate the complete data_bounds 4-tuple from the specified 2-tuple. - if data_bounds is None: - data_bounds = [data.min(), data.max()] if data.size else [-1, 1] - assert len(data_bounds) == 2 - # Ensure that the data bounds are not degenerate. - if data_bounds[0] == data_bounds[1]: - data_bounds = [data_bounds[0] - 1, data_bounds[0] + 1] - ymin, ymax = data_bounds - data_bounds = [-1, ymin, 1, ymax] - _check_data_bounds(data_bounds) - self.data_bounds = data_bounds - # Set the transformed position. pos_tr = self.apply_cpu_transforms(pos) self.program['a_position'] = _get_pos_depth(pos_tr, depth) + # Generate the signal index. + self.program['a_signal_index'] = _get_index(n_signals, n_samples, n) + + # Generate the complete data_bounds 4-tuple from the specified 2-tuple. + self.data_bounds = _get_data_bounds_1D(data_bounds, data) + # Signal bounds (positions). - if signal_bounds is None: - signal_bounds = np.tile(NDC, (n_signals, 1)) - assert signal_bounds.shape == (n_signals, 4) - # Convert to 3D texture. - signal_bounds = signal_bounds[np.newaxis, ...].astype(np.float32) - assert signal_bounds.shape == (1, n_signals, 4) - # NOTE: we need to cast the texture to [0, 255] (uint8). - # This is easy as soon as we assume that the signal bounds are in - # [-1, 1]. - assert np.all(signal_bounds >= -1) - assert np.all(signal_bounds <= 1) - signal_bounds = 127 * (signal_bounds + 1) - assert np.all(signal_bounds >= 0) - assert np.all(signal_bounds <= 255) - signal_bounds = signal_bounds.astype(np.uint8) + signal_bounds = _get_texture(signal_bounds, NDC, n_signals, [-1, 1]) self.program['u_signal_bounds'] = Texture2D(signal_bounds) # Signal colors. - if signal_colors is None: - signal_colors = np.ones((n_signals, 4), dtype=np.float32) - assert signal_colors.shape == (n_signals, 4) - # Convert to 3D texture. - signal_colors = signal_colors[np.newaxis, ...].astype(np.float32) - assert signal_colors.shape == (1, n_signals, 4) + signal_colors = _get_texture(signal_colors, (1,) * 4, + n_signals, [0, 1]) self.program['u_signal_colors'] = Texture2D(signal_colors) # Number of signals. From babba7d7c3b76ea156b60333695d825f6b4ded30 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 24 Oct 2015 18:00:28 +0200 Subject: [PATCH 0432/1059] WIP: refactor visuals set_data() --- phy/plot/visuals.py | 47 ++++++++++++++++----------------------------- 1 file changed, 17 insertions(+), 30 deletions(-) diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index 90c4fd566..9a42ebe28 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -78,6 +78,14 @@ def _get_pos_depth(pos_tr, depth): return pos_depth +def _get_hist_max(hist): + hist_max = hist.max() if hist.size else 1. + hist_max = float(hist_max) + hist_max = hist_max if hist_max > 0 else 1. + assert hist_max > 0 + return hist_max + + def _get_attr(attr, default, n): """Prepare an attribute for GPU uploading.""" if not hasattr(default, '__len__'): @@ -289,11 +297,7 @@ def set_data(self, self.n_bins = n_bins # Generate hist_max. - hist_max = hist.max() if hist.size else 1. - hist_max = float(hist_max) - hist_max = hist_max if hist_max > 0 else 1. - assert hist_max > 0 - self.hist_max = hist_max + self.hist_max = _get_hist_max(hist) # Concatenate all histograms. pos = np.vstack(_tesselate_histogram(row) for row in hist) @@ -309,34 +313,17 @@ def set_data(self, self.program['a_hist_index'] = _get_index(n_hists, n_bins * 6, n) # Hist colors. - if hist_colors is None: - hist_colors = np.ones((n_hists, 4), dtype=np.float32) - assert hist_colors.shape == (n_hists, 4) - # Convert to 3D texture. - hist_colors = hist_colors[np.newaxis, ...].astype(np.float32) - assert hist_colors.shape == (1, n_hists, 4) - self.program['u_hist_colors'] = Texture2D(hist_colors) + self.program['u_hist_colors'] = _get_texture(hist_colors, + (1, 1, 1, 1), + n_hists, [0, 1]) # Hist bounds. - if hist_lims is None: - hist_lims = hist_max * np.ones(n_hists) - hist_lims = np.asarray(hist_lims, dtype=np.float32) - # NOTE: hist_lims is now relative to hist_max (what is on the GPU). - hist_lims = hist_lims / hist_max - assert hist_lims.shape == (n_hists,) - # Now, we create the 4-tuples for the bounds: [0, 0, 1, hists_lim]. - hist_bounds = np.zeros((n_hists, 4), dtype=np.float32) - hist_bounds[:, 2] = 1 - hist_bounds[:, 3] = hist_lims - # Convert to 3D texture. - hist_bounds = hist_bounds[np.newaxis, ...].astype(np.float32) - assert hist_bounds.shape == (1, n_hists, 4) - assert np.all(hist_bounds >= 0) - assert np.all(hist_bounds <= 10) - # NOTE: necessary because VisPy silently clips textures to [0, 1]. - hist_bounds /= 10. + hist_bounds = np.c_[np.zeros((n_hists, 2)), + np.ones(n_hists), + hist_lims] if hist_lims is not None else None + hist_bounds = _get_texture(hist_bounds, [0, 0, 1, self.hist_max], + n_hists, [0, 10]) self.program['u_hist_bounds'] = Texture2D(hist_bounds) - self.program['n_hists'] = n_hists From 60b372a5ce6076c769d72bf047e6b8a3fc9ed041 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 24 Oct 2015 18:02:35 +0200 Subject: [PATCH 0433/1059] Increase coverage --- phy/plot/visuals.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index 9a42ebe28..a6a0e9287 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -112,7 +112,7 @@ def _get_texture(arr, default, n_items, from_bounds): """Prepare data to be uploaded as a texture, with casting to uint8. The from_bounds must be specified. """ - if not hasattr(default, '__len__'): + if not hasattr(default, '__len__'): # pragma: no cover default = [default] n_cols = len(default) if arr is None: @@ -299,11 +299,8 @@ def set_data(self, # Generate hist_max. self.hist_max = _get_hist_max(hist) - # Concatenate all histograms. - pos = np.vstack(_tesselate_histogram(row) for row in hist) - assert pos.shape == (n, 2) - # Set the transformed position. + pos = np.vstack(_tesselate_histogram(row) for row in hist) pos_tr = self.apply_cpu_transforms(pos) pos_tr = np.asarray(pos_tr, dtype=np.float32) assert pos_tr.shape == (n, 2) From e1ee2d21c67631969c10b17a4229aea99a22bc40 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 24 Oct 2015 18:37:27 +0200 Subject: [PATCH 0434/1059] Add box and axes visuals --- phy/plot/glsl/ax.frag | 9 ------- phy/plot/glsl/ax.vert | 30 ---------------------- phy/plot/glsl/simple.frag | 5 ++++ phy/plot/glsl/simple.vert | 6 +++++ phy/plot/tests/test_utils.py | 2 +- phy/plot/tests/test_visuals.py | 38 ++++++++++++++++++++++++++- phy/plot/visuals.py | 47 ++++++++++++++++++++++++++++++++++ 7 files changed, 96 insertions(+), 41 deletions(-) delete mode 100644 phy/plot/glsl/ax.frag delete mode 100644 phy/plot/glsl/ax.vert create mode 100644 phy/plot/glsl/simple.frag create mode 100644 phy/plot/glsl/simple.vert diff --git a/phy/plot/glsl/ax.frag b/phy/plot/glsl/ax.frag deleted file mode 100644 index b3d4f845e..000000000 --- a/phy/plot/glsl/ax.frag +++ /dev/null @@ -1,9 +0,0 @@ -#include "grid.glsl" -uniform vec4 u_color; - -void main() { - // Clipping. - if (grid_clip(v_position, .975)) discard; - - gl_FragColor = u_color; -} diff --git a/phy/plot/glsl/ax.vert b/phy/plot/glsl/ax.vert deleted file mode 100644 index b87b20bd9..000000000 --- a/phy/plot/glsl/ax.vert +++ /dev/null @@ -1,30 +0,0 @@ -// xy is the vertex position in NDC -// index is the box index -// ax is 0 for x, 1 for y -attribute vec4 a_position; // xy, index, ax - -#include "grid.glsl" - -vec2 pan_zoom(vec2 position, float index, float ax) -{ - vec4 pz = fetch_pan_zoom(index); - vec2 pan = pz.xy; - vec2 zoom = pz.zw; - - if (ax < 0.5) - return vec2(zoom.x * (position.x + n_rows * pan.x), position.y); - else - return vec2(position.x, zoom.y * (position.y + n_rows * pan.y)); -} - -void main() { - - vec2 pos = a_position.xy; - vec2 position = pan_zoom(pos, a_position.z, a_position.w); - - gl_Position = vec4(to_box(position, a_position.z), - 0.0, 1.0); - - // Used for clipping. - v_position = position; -} diff --git a/phy/plot/glsl/simple.frag b/phy/plot/glsl/simple.frag new file mode 100644 index 000000000..8df552c17 --- /dev/null +++ b/phy/plot/glsl/simple.frag @@ -0,0 +1,5 @@ +uniform vec4 u_color; + +void main() { + gl_FragColor = u_color; +} diff --git a/phy/plot/glsl/simple.vert b/phy/plot/glsl/simple.vert new file mode 100644 index 000000000..2a59757e8 --- /dev/null +++ b/phy/plot/glsl/simple.vert @@ -0,0 +1,6 @@ +attribute vec2 a_position; +uniform vec4 u_color; + +void main() { + gl_Position = transform(a_position); +} diff --git a/phy/plot/tests/test_utils.py b/phy/plot/tests/test_utils.py index dc07bd8c0..647d61a3e 100644 --- a/phy/plot/tests/test_utils.py +++ b/phy/plot/tests/test_utils.py @@ -26,7 +26,7 @@ #------------------------------------------------------------------------------ def test_load_shader(): - assert 'main()' in _load_shader('ax.vert') + assert 'main()' in _load_shader('simple.vert') assert config['include_path'] assert op.exists(config['include_path'][0]) assert op.isdir(config['include_path'][0]) diff --git a/phy/plot/tests/test_visuals.py b/phy/plot/tests/test_visuals.py index 0b4f40e04..d3cd28382 100644 --- a/phy/plot/tests/test_visuals.py +++ b/phy/plot/tests/test_visuals.py @@ -10,7 +10,8 @@ import numpy as np from pytest import mark -from ..visuals import ScatterVisual, PlotVisual, HistogramVisual +from ..visuals import (ScatterVisual, PlotVisual, HistogramVisual, + BoxVisual, AxesVisual,) #------------------------------------------------------------------------------ @@ -138,3 +139,38 @@ def test_histogram_2(qtbot, canvas_pz): _test_visual(qtbot, canvas_pz, HistogramVisual(), hist=hist, hist_colors=c, hist_lims=2 * np.ones(n_hists)) + + +#------------------------------------------------------------------------------ +# Test box visual +#------------------------------------------------------------------------------ + +def test_box_empty(qtbot, canvas): + _test_visual(qtbot, canvas, BoxVisual()) + + +def test_box_0(qtbot, canvas_pz): + _test_visual(qtbot, canvas_pz, BoxVisual(), + bounds=(-.5, -.5, 0., 0.), + color=(1., 0., 0., .5)) + + +#------------------------------------------------------------------------------ +# Test axes visual +#------------------------------------------------------------------------------ + +def test_axes_empty(qtbot, canvas): + _test_visual(qtbot, canvas, AxesVisual()) + + +def test_axes_0(qtbot, canvas_pz): + _test_visual(qtbot, canvas_pz, AxesVisual(), + xs=[0]) + + +def test_axes_1(qtbot, canvas_pz): + _test_visual(qtbot, canvas_pz, AxesVisual(), + xs=[-.25, -.1], + ys=[-.15], + bounds=(-.5, -.5, 0., 0.), + color=(0., 1., 0., .5)) diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index a6a0e9287..e706c0c0e 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -135,6 +135,13 @@ def _get_texture(arr, default, n_items, from_bounds): return arr +def _get_color(color, default): + if color is None: + color = default + assert len(color) == 4 + return color + + #------------------------------------------------------------------------------ # Visuals #------------------------------------------------------------------------------ @@ -333,3 +340,43 @@ def get_transforms(self): def set_data(self): pass + + +class BoxVisual(BaseVisual): + shader_name = 'simple' + gl_primitive_type = 'lines' + _default_color = (.35, .35, .35, 1.) + + def set_data(self, bounds=NDC, color=None): + # Set the position. + x0, y0, x1, y1 = bounds + arr = np.array([[x0, y0], + [x0, y1], + [x0, y1], + [x1, y1], + [x1, y1], + [x1, y0], + [x1, y0], + [x0, y0]], dtype=np.float32) + self.program['a_position'] = self.apply_cpu_transforms(arr) + + # Set the color + self.program['u_color'] = _get_color(color, self._default_color) + + +class AxesVisual(BaseVisual): + shader_name = 'simple' + gl_primitive_type = 'lines' + _default_color = (.2, .2, .2, 1.) + + def set_data(self, xs=(), ys=(), bounds=NDC, color=None): + # Set the position. + arr = [[x, bounds[1], x, bounds[3]] for x in xs] + arr += [[bounds[0], y, bounds[2], y] for y in ys] + arr = np.hstack(arr or [[]]).astype(np.float32) + arr = arr.reshape((-1, 2)).astype(np.float32) + position = self.apply_cpu_transforms(arr) + self.program['a_position'] = position + + # Set the color + self.program['u_color'] = _get_color(color, self._default_color) From 91583abf69e0c3a28a26da75cd24e7bd13edca1e Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 24 Oct 2015 18:41:54 +0200 Subject: [PATCH 0435/1059] Use simple shader name in some plot tests --- phy/plot/tests/test_base.py | 16 ++-------------- phy/plot/tests/test_panzoom.py | 16 ++-------------- 2 files changed, 4 insertions(+), 28 deletions(-) diff --git a/phy/plot/tests/test_base.py b/phy/plot/tests/test_base.py index 164a9c1e7..7ea33388a 100644 --- a/phy/plot/tests/test_base.py +++ b/phy/plot/tests/test_base.py @@ -81,27 +81,15 @@ def set_data(self): def test_base_interact(qtbot, canvas): """Test a BaseVisual with a CPU transform and a blank interact.""" class TestVisual(BaseVisual): - vertex = """ - attribute vec2 a_position; - void main() { - gl_Position = transform(a_position); - } - """ - fragment = """ - void main() { - gl_FragColor = vec4(1, 1, 1, 1); - } - """ + shader_name = 'simple' gl_primitive_type = 'lines' - def get_shaders(self): - return self.vertex, self.fragment - def get_transforms(self): return [Scale(scale=(.5, 1))] def set_data(self): self.program['a_position'] = [[-1, 0], [1, 0]] + self.program['u_color'] = [1, 1, 1, 1] # We attach the visual to the canvas. By default, a BaseInteract is used. v = TestVisual() diff --git a/phy/plot/tests/test_panzoom.py b/phy/plot/tests/test_panzoom.py index d13231fd9..fd553da3e 100644 --- a/phy/plot/tests/test_panzoom.py +++ b/phy/plot/tests/test_panzoom.py @@ -21,24 +21,12 @@ #------------------------------------------------------------------------------ class MyTestVisual(BaseVisual): - vertex = """ - attribute vec2 a_position; - void main() { - gl_Position = transform(a_position); - } - """ - fragment = """ - void main() { - gl_FragColor = vec4(1, 1, 1, 1); - } - """ + shader_name = 'simple' gl_primitive_type = 'lines' - def get_shaders(self): - return self.vertex, self.fragment - def set_data(self): self.program['a_position'] = [[-1, 0], [1, 0]] + self.program['u_color'] = [1, 1, 1, 1] @yield_fixture From 1d6be5db3e456eabb7af93f2ae1515decd97ff34 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 24 Oct 2015 18:53:33 +0200 Subject: [PATCH 0436/1059] Remove old GLSL code --- phy/plot/glsl/__init__.py | 0 phy/plot/glsl/box.frag | 3 -- phy/plot/glsl/box.vert | 8 ---- phy/plot/glsl/color.glsl | 14 ------- phy/plot/glsl/correlograms.frag | 13 ------- phy/plot/glsl/correlograms.vert | 34 ----------------- phy/plot/glsl/depth_mask.glsl | 8 ---- phy/plot/glsl/features.frag | 17 --------- phy/plot/glsl/features.vert | 37 ------------------ phy/plot/glsl/features_bg.frag | 9 ----- phy/plot/glsl/features_bg.vert | 16 -------- phy/plot/glsl/filled_antialias.glsl | 28 -------------- phy/plot/glsl/grid.glsl | 46 ---------------------- phy/plot/glsl/lasso.frag | 8 ---- phy/plot/glsl/lasso.vert | 16 -------- phy/plot/glsl/pan_zoom.glsl | 7 ---- phy/plot/glsl/scatter.frag | 2 - phy/plot/glsl/test.frag | 3 -- phy/plot/glsl/test.vert | 4 -- phy/plot/glsl/traces.frag | 20 ---------- phy/plot/glsl/traces.vert | 59 ----------------------------- phy/plot/glsl/waveforms.frag | 8 ---- phy/plot/glsl/waveforms.vert | 53 -------------------------- phy/plot/tests/test_base.py | 6 +-- phy/plot/tests/test_utils.py | 7 ---- phy/plot/utils.py | 7 ---- 26 files changed, 3 insertions(+), 430 deletions(-) delete mode 100644 phy/plot/glsl/__init__.py delete mode 100644 phy/plot/glsl/box.frag delete mode 100644 phy/plot/glsl/box.vert delete mode 100644 phy/plot/glsl/color.glsl delete mode 100644 phy/plot/glsl/correlograms.frag delete mode 100644 phy/plot/glsl/correlograms.vert delete mode 100644 phy/plot/glsl/depth_mask.glsl delete mode 100644 phy/plot/glsl/features.frag delete mode 100644 phy/plot/glsl/features.vert delete mode 100644 phy/plot/glsl/features_bg.frag delete mode 100644 phy/plot/glsl/features_bg.vert delete mode 100644 phy/plot/glsl/filled_antialias.glsl delete mode 100644 phy/plot/glsl/grid.glsl delete mode 100644 phy/plot/glsl/lasso.frag delete mode 100644 phy/plot/glsl/lasso.vert delete mode 100644 phy/plot/glsl/pan_zoom.glsl delete mode 100644 phy/plot/glsl/test.frag delete mode 100644 phy/plot/glsl/test.vert delete mode 100644 phy/plot/glsl/traces.frag delete mode 100644 phy/plot/glsl/traces.vert delete mode 100644 phy/plot/glsl/waveforms.frag delete mode 100644 phy/plot/glsl/waveforms.vert diff --git a/phy/plot/glsl/__init__.py b/phy/plot/glsl/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/phy/plot/glsl/box.frag b/phy/plot/glsl/box.frag deleted file mode 100644 index 4f60c17ad..000000000 --- a/phy/plot/glsl/box.frag +++ /dev/null @@ -1,3 +0,0 @@ -void main() { - gl_FragColor = vec4(.35, .35, .35, 1.); -} diff --git a/phy/plot/glsl/box.vert b/phy/plot/glsl/box.vert deleted file mode 100644 index ea72ad645..000000000 --- a/phy/plot/glsl/box.vert +++ /dev/null @@ -1,8 +0,0 @@ -#include "grid.glsl" - -attribute vec3 a_position; // x, y, index - -void main() { - gl_Position = vec4(to_box(a_position.xy, a_position.z), - 0.0, 1.0); -} diff --git a/phy/plot/glsl/color.glsl b/phy/plot/glsl/color.glsl deleted file mode 100644 index 4d0005597..000000000 --- a/phy/plot/glsl/color.glsl +++ /dev/null @@ -1,14 +0,0 @@ - -vec3 get_color(float cluster, sampler2D texture, float n_clusters) { - if (cluster < 0) - return vec3(.5, .5, .5); - return texture2D(texture, vec2(cluster / (n_clusters - 1.), .5)).xyz; -} - -vec3 color_mask(vec3 color, float mask) { - vec3 hsv = rgb_to_hsv(color); - // Change the saturation and value as a function of the mask. - hsv.y *= mask; - hsv.z *= .5 * (1. + mask); - return hsv_to_rgb(hsv); -} diff --git a/phy/plot/glsl/correlograms.frag b/phy/plot/glsl/correlograms.frag deleted file mode 100644 index 4afa3c773..000000000 --- a/phy/plot/glsl/correlograms.frag +++ /dev/null @@ -1,13 +0,0 @@ -#include "grid.glsl" - -varying vec4 v_color; -varying float v_box; - -void main() -{ - // Clipping. - if (grid_clip(v_position)) discard; - if (fract(v_box) > 0.) discard; - - gl_FragColor = v_color; -} diff --git a/phy/plot/glsl/correlograms.vert b/phy/plot/glsl/correlograms.vert deleted file mode 100644 index d00a3f407..000000000 --- a/phy/plot/glsl/correlograms.vert +++ /dev/null @@ -1,34 +0,0 @@ -#include "colormaps/color-space.glsl" -#include "color.glsl" -#include "grid.glsl" - -attribute vec2 a_position; -attribute float a_box; // (from 0 to n_rows**2-1) - -uniform sampler2D u_cluster_color; - -varying vec4 v_color; -varying float v_box; - -void main (void) -{ - // ACG/CCG color. - vec2 rc = row_col(a_box, n_rows); - if (abs(rc.x - rc.y) < .1) { - v_color.rgb = get_color(rc.x, - u_cluster_color, - n_rows); - } - else { - v_color.rgb = vec3(1., 1., 1.); - } - v_color.a = 1.; - - vec2 position = pan_zoom_grid(a_position, a_box); - vec2 box_position = to_box(position, a_box); - gl_Position = vec4(box_position, 0., 1.); - - // Used for clipping. - v_position = position; - v_box = a_box; -} diff --git a/phy/plot/glsl/depth_mask.glsl b/phy/plot/glsl/depth_mask.glsl deleted file mode 100644 index 298a67ff2..000000000 --- a/phy/plot/glsl/depth_mask.glsl +++ /dev/null @@ -1,8 +0,0 @@ -float depth_mask(float cluster, float mask, float n_clusters) { - // Depth and mask. - float depth = 0.0; - if (mask > 0.25) { - depth = -.1 - (cluster + mask) / (n_clusters + 10.); - } - return depth; -} diff --git a/phy/plot/glsl/features.frag b/phy/plot/glsl/features.frag deleted file mode 100644 index beed3e6cb..000000000 --- a/phy/plot/glsl/features.frag +++ /dev/null @@ -1,17 +0,0 @@ -#include "markers/disc.glsl" -#include "filled_antialias.glsl" -#include "grid.glsl" - -varying float v_size; -varying vec4 v_color; - -void main() -{ - // Clipping. - if (grid_clip(v_position)) discard; - - vec2 P = gl_PointCoord.xy - vec2(0.5,0.5); - float point_size = v_size + 2. * (1.0 + 1.5*1.0); - float distance = marker_disc(P*point_size, v_size); - gl_FragColor = filled(distance, 1.0, 1.0, v_color); -} diff --git a/phy/plot/glsl/features.vert b/phy/plot/glsl/features.vert deleted file mode 100644 index 5e3891fc2..000000000 --- a/phy/plot/glsl/features.vert +++ /dev/null @@ -1,37 +0,0 @@ -#include "colormaps/color-space.glsl" -#include "color.glsl" -#include "grid.glsl" -#include "depth_mask.glsl" - -attribute vec2 a_position; -attribute float a_mask; -attribute float a_cluster; // cluster idx -attribute float a_box; // (from 0 to n_rows**2-1) - -uniform float u_size; -uniform float n_clusters; -uniform sampler2D u_cluster_color; - -varying vec4 v_color; -varying float v_size; - -void main (void) -{ - v_size = u_size; - - v_color.rgb = color_mask(get_color(a_cluster, u_cluster_color, n_clusters), - a_mask); - v_color.a = .5; - - vec2 position = pan_zoom_grid(a_position, a_box); - vec2 box_position = to_box(position, a_box); - - // Depth as a function of the mask and cluster index. - float depth = depth_mask(mod(a_box, n_rows), a_mask, n_clusters); - - gl_Position = vec4(box_position, depth, 1.); - gl_PointSize = u_size + 2.0 * (1.0 + 1.5 * 1.0); - - // Used for clipping. - v_position = position; -} diff --git a/phy/plot/glsl/features_bg.frag b/phy/plot/glsl/features_bg.frag deleted file mode 100644 index 602dcd437..000000000 --- a/phy/plot/glsl/features_bg.frag +++ /dev/null @@ -1,9 +0,0 @@ -#include "grid.glsl" - -void main() -{ - // Clipping. - if (grid_clip(v_position)) discard; - - gl_FragColor = vec4(.5, .5, .5, .25); -} diff --git a/phy/plot/glsl/features_bg.vert b/phy/plot/glsl/features_bg.vert deleted file mode 100644 index 5973df500..000000000 --- a/phy/plot/glsl/features_bg.vert +++ /dev/null @@ -1,16 +0,0 @@ -#include "grid.glsl" - -attribute vec2 a_position; -attribute float a_box; // (from 0 to n_rows**2-1) - -void main (void) -{ - vec2 position = pan_zoom_grid(a_position, a_box); - vec2 box_position = to_box(position, a_box); - - gl_Position = vec4(box_position, 0., 1.); - gl_PointSize = 3.0; - - // Used for clipping. - v_position = position; -} diff --git a/phy/plot/glsl/filled_antialias.glsl b/phy/plot/glsl/filled_antialias.glsl deleted file mode 100644 index 7e95e7ee8..000000000 --- a/phy/plot/glsl/filled_antialias.glsl +++ /dev/null @@ -1,28 +0,0 @@ -vec4 filled(float distance, float linewidth, float antialias, vec4 bg_color) -{ - vec4 frag_color; - float t = linewidth/2.0 - antialias; - float signed_distance = distance; - float border_distance = abs(signed_distance) - t; - float alpha = border_distance/antialias; - alpha = exp(-alpha*alpha); - - if (border_distance < 0.0) - frag_color = bg_color; - else if (signed_distance < 0.0) - frag_color = bg_color; - else { - if (abs(signed_distance) < (linewidth/2.0 + antialias)) { - frag_color = vec4(bg_color.rgb, alpha * bg_color.a); - } - else { - discard; - } - } - return frag_color; -} - -vec4 filled(float distance, float linewidth, float antialias, vec4 fg_color, vec4 bg_color) -{ - return filled(distance, linewidth, antialias, fg_color); -} diff --git a/phy/plot/glsl/grid.glsl b/phy/plot/glsl/grid.glsl deleted file mode 100644 index 1bc154b77..000000000 --- a/phy/plot/glsl/grid.glsl +++ /dev/null @@ -1,46 +0,0 @@ - -uniform float n_rows; -uniform vec4 u_pan_zoom[256]; // maximum grid size: 16 x 16 - -varying vec2 v_position; - -vec2 row_col(float index, float n_rows) { - float row = floor(index / n_rows); - float col = mod(index, n_rows); - return vec2(row, col); -} - -vec4 fetch_pan_zoom(float index) { - vec4 pz = u_pan_zoom[int(index)]; - vec2 pan = pz.xy; - vec2 zoom = pz.zw; - return vec4(pan, zoom); -} - -vec2 to_box(vec2 position, float index) { - vec2 rc = row_col(index, n_rows) + 0.5; - - float x = -1.0 + rc.y * (2.0 / n_rows); - float y = +1.0 - rc.x * (2.0 / n_rows); - - float width = 0.95 / (1.0 * n_rows); - float height = 0.95 / (1.0 * n_rows); - - return vec2(x + width * position.x, - y + height * position.y); -} - -bool grid_clip(vec2 position, float lim) { - return ((position.x < -lim) || (position.x > +lim) || - (position.y < -lim) || (position.y > +lim)); -} - -bool grid_clip(vec2 position) { - return grid_clip(position, .95); -} - -vec2 pan_zoom_grid(vec2 position, float index) -{ - vec4 pz = fetch_pan_zoom(index); // (pan_x, pan_y, zoom_x, zoom_y) - return pz.zw * (position + n_rows * pz.xy); -} diff --git a/phy/plot/glsl/lasso.frag b/phy/plot/glsl/lasso.frag deleted file mode 100644 index a4ec5fa9e..000000000 --- a/phy/plot/glsl/lasso.frag +++ /dev/null @@ -1,8 +0,0 @@ -#include "grid.glsl" - -void main() { - // Clipping. - if (grid_clip(v_position, .975)) discard; - - gl_FragColor = vec4(.5, .5, .5, 1.); -} diff --git a/phy/plot/glsl/lasso.vert b/phy/plot/glsl/lasso.vert deleted file mode 100644 index 6b5595e52..000000000 --- a/phy/plot/glsl/lasso.vert +++ /dev/null @@ -1,16 +0,0 @@ -#include "grid.glsl" - -attribute vec2 a_position; - -uniform float u_box; - -void main() { - - vec2 pos = a_position; - vec2 position = pan_zoom_grid(pos, u_box); - - gl_Position = vec4(to_box(position, u_box), -1.0, 1.0); - - // Used for clipping. - v_position = position; -} diff --git a/phy/plot/glsl/pan_zoom.glsl b/phy/plot/glsl/pan_zoom.glsl deleted file mode 100644 index 05dd679cf..000000000 --- a/phy/plot/glsl/pan_zoom.glsl +++ /dev/null @@ -1,7 +0,0 @@ -uniform vec2 u_zoom; -uniform vec2 u_pan; - -vec2 pan_zoom(vec2 position) -{ - return u_zoom * (position + u_pan); -} diff --git a/phy/plot/glsl/scatter.frag b/phy/plot/glsl/scatter.frag index f994e80d6..76703d7dc 100644 --- a/phy/plot/glsl/scatter.frag +++ b/phy/plot/glsl/scatter.frag @@ -1,6 +1,4 @@ #include "markers/%MARKER_TYPE.glsl" -#include "filled_antialias.glsl" -#include "grid.glsl" varying vec4 v_color; varying float v_size; diff --git a/phy/plot/glsl/test.frag b/phy/plot/glsl/test.frag deleted file mode 100644 index 194449a1d..000000000 --- a/phy/plot/glsl/test.frag +++ /dev/null @@ -1,3 +0,0 @@ -void main() { - gl_FragColor = vec4(1, 1, 1, 1); -} diff --git a/phy/plot/glsl/test.vert b/phy/plot/glsl/test.vert deleted file mode 100644 index 5e46f738b..000000000 --- a/phy/plot/glsl/test.vert +++ /dev/null @@ -1,4 +0,0 @@ -attribute vec2 a_position; -void main() { - gl_Position = vec4(a_position.xy, 0, 1); -} diff --git a/phy/plot/glsl/traces.frag b/phy/plot/glsl/traces.frag deleted file mode 100644 index b388fe09d..000000000 --- a/phy/plot/glsl/traces.frag +++ /dev/null @@ -1,20 +0,0 @@ - -varying vec3 v_index; // (channel, cluster, mask) -varying vec3 v_color_channel; -varying vec3 v_color_spike; - -void main() { - vec3 color; - - // Discard vertices between two channels. - if ((fract(v_index.x) > 0.)) - discard; - - // Avoid color interpolation at spike boundaries. - if ((v_index.y >= 0) && (fract(v_index.y) == 0.) && (v_index.z > 0.)) - color = v_color_spike; - else - color = v_color_channel; - - gl_FragColor = vec4(color, .85); -} diff --git a/phy/plot/glsl/traces.vert b/phy/plot/glsl/traces.vert deleted file mode 100644 index 24f8a1b11..000000000 --- a/phy/plot/glsl/traces.vert +++ /dev/null @@ -1,59 +0,0 @@ -#include "pan_zoom.glsl" -#include "colormaps/color-space.glsl" -#include "color.glsl" - -attribute float a_position; -attribute vec2 a_index; // (channel_idx, t) -attribute vec2 a_spike; // (cluster_idx, mask) - -uniform sampler2D u_channel_color; -uniform sampler2D u_cluster_color; - -uniform float n_channels; -uniform float n_clusters; -uniform float n_samples; -uniform float u_scale; - -varying vec3 v_index; // (channel, cluster, mask) -varying vec3 v_color_channel; -varying vec3 v_color_spike; - -float get_x(float x_index) { - // 'x_index' is between 0 and nsamples. - return -1. + 2. * x_index / (float(n_samples) - 1.); -} - -float get_y(float y_index, float sample) { - // 'y_index' is between 0 and n_channels. - float a = float(u_scale) / float(n_channels); - float b = -1. + 2. * (y_index + .5) / float(n_channels); - return a * sample + .9 * b; -} - -void main() { - float channel = a_index.x; - - float x = get_x(a_index.y); - float y = get_y(channel, a_position); - vec2 position = vec2(x, y); - - gl_Position = vec4(pan_zoom(position), 0.0, 1.0); - - // Spike color as a function of the cluster and mask. - v_color_spike = color_mask(get_color(a_spike.x, // cluster_id - u_cluster_color, - n_clusters), - a_spike.y // mask - ); - - // Channel color. - v_color_channel = get_color(channel, - u_channel_color, - n_channels); - - // The fragment shader needs to know: - // * the channel (to discard fragments between channels) - // * the cluster (for the color) - // * the mask (for the alpha channel) - v_index = vec3(channel, a_spike.x, a_spike.y); -} diff --git a/phy/plot/glsl/waveforms.frag b/phy/plot/glsl/waveforms.frag deleted file mode 100644 index f10c808e7..000000000 --- a/phy/plot/glsl/waveforms.frag +++ /dev/null @@ -1,8 +0,0 @@ -varying vec4 v_color; -varying vec2 v_box; - -void main() { - if ((fract(v_box.x) > 0.) || (fract(v_box.y) > 0.)) - discard; - gl_FragColor = v_color; -} diff --git a/phy/plot/glsl/waveforms.vert b/phy/plot/glsl/waveforms.vert deleted file mode 100644 index 8d619ca3b..000000000 --- a/phy/plot/glsl/waveforms.vert +++ /dev/null @@ -1,53 +0,0 @@ -#include "colormaps/color-space.glsl" -#include "color.glsl" -#include "pan_zoom.glsl" -#include "depth_mask.glsl" - -attribute vec2 a_data; // position (-1..1), mask -attribute float a_time; // -1..1 -attribute vec2 a_box; // 0..(n_clusters-1, n_channels-1) - -uniform float n_clusters; -uniform float n_channels; -uniform vec2 u_data_scale; -uniform vec2 u_channel_scale; -uniform sampler2D u_channel_pos; -uniform sampler2D u_cluster_color; -uniform float u_overlap; -uniform float u_alpha; - -varying vec4 v_color; -varying vec2 v_box; - -vec2 get_box_pos(vec2 box) { // box = (cluster, channel) - vec2 box_pos = texture2D(u_channel_pos, - vec2(box.y / (n_channels - 1.), .5)).xy; - box_pos = 2. * box_pos - 1.; - box_pos = box_pos * u_channel_scale; - // Spacing between cluster boxes. - float h = 2.5 * u_data_scale.x; - if (u_overlap < 0.5) - box_pos.x += h * (box.x - .5 * (n_clusters - 1.)) / n_clusters; - return box_pos; -} - -void main() { - vec2 pos = u_data_scale * vec2(a_time, a_data.x); // -1..1 - vec2 box_pos = get_box_pos(a_box); - v_box = a_box; - - // Depth as a function of the mask and cluster index. - float depth = depth_mask(a_box.x, a_data.y, n_clusters); - - vec2 x_coeff = vec2(1. / max(n_clusters, 1.), 1.); - if (u_overlap > 0.5) - x_coeff.x = 1.; - // The z coordinate is the depth: it depends on the mask. - gl_Position = vec4(pan_zoom(x_coeff * pos + box_pos), depth, 1.); - - // Compute the waveform color as a function of the cluster color - // and the mask. - v_color.rgb = color_mask(get_color(a_box.x, u_cluster_color, n_clusters), - a_data.y); - v_color.a = u_alpha; -} diff --git a/phy/plot/tests/test_base.py b/phy/plot/tests/test_base.py index 7ea33388a..90a9f4e1c 100644 --- a/phy/plot/tests/test_base.py +++ b/phy/plot/tests/test_base.py @@ -21,12 +21,12 @@ def test_visual_shader_name(qtbot, canvas): """Test a BaseVisual with a shader name.""" class TestVisual(BaseVisual): - shader_name = 'box' + shader_name = 'simple' gl_primitive_type = 'lines' def set_data(self): - self.program['a_position'] = [[-1, 0, 0], [1, 0, 0]] - self.program['n_rows'] = 1 + self.program['a_position'] = [[-1, 0], [1, 0]] + self.program['u_color'] = [1, 1, 1, 1] v = TestVisual() # We need to build the program explicitly when there is no interact. diff --git a/phy/plot/tests/test_utils.py b/phy/plot/tests/test_utils.py index 647d61a3e..f9851ae6a 100644 --- a/phy/plot/tests/test_utils.py +++ b/phy/plot/tests/test_utils.py @@ -15,7 +15,6 @@ from vispy import config from ..utils import (_load_shader, - _create_program, _tesselate_histogram, _enable_depth_mask, ) @@ -33,12 +32,6 @@ def test_load_shader(): assert os.listdir(config['include_path'][0]) -def test_create_program(): - p = _create_program('box') - assert p.shaders[0] - assert p.shaders[1] - - def test_tesselate_histogram(): n = 7 hist = np.arange(n) diff --git a/phy/plot/utils.py b/phy/plot/utils.py index 294fe83ad..27f174ecf 100644 --- a/phy/plot/utils.py +++ b/phy/plot/utils.py @@ -32,13 +32,6 @@ def _load_shader(filename): return f.read() -def _create_program(name): - vertex = _load_shader(name + '.vert') - fragment = _load_shader(name + '.frag') - program = gloo.Program(vertex, fragment) - return program - - def _tesselate_histogram(hist): """ From 7c91432c159adcbbf0cb93decf8a83c2470ab0dc Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 25 Oct 2015 10:19:28 +0100 Subject: [PATCH 0437/1059] WIP: remove box placement from plot visual --- phy/plot/glsl/plot.vert | 7 ------- phy/plot/tests/test_visuals.py | 11 ++--------- phy/plot/visuals.py | 13 +++---------- 3 files changed, 5 insertions(+), 26 deletions(-) diff --git a/phy/plot/glsl/plot.vert b/phy/plot/glsl/plot.vert index 090e96ff5..3992106d5 100644 --- a/phy/plot/glsl/plot.vert +++ b/phy/plot/glsl/plot.vert @@ -3,7 +3,6 @@ attribute vec3 a_position; attribute float a_signal_index; // 0..n_signals-1 -uniform sampler2D u_signal_bounds; uniform sampler2D u_signal_colors; uniform float n_signals; @@ -11,12 +10,6 @@ varying vec4 v_color; varying float v_signal_index; void main() { - // Will be used by the transform. - vec4 signal_bounds = fetch_texture(a_signal_index, - u_signal_bounds, - n_signals); - signal_bounds = (2 * signal_bounds - 1); // See hack in Python. - vec2 xy = a_position.xy; gl_Position = transform(xy); gl_Position.z = a_position.z; diff --git a/phy/plot/tests/test_visuals.py b/phy/plot/tests/test_visuals.py index d3cd28382..707c5d6a4 100644 --- a/phy/plot/tests/test_visuals.py +++ b/phy/plot/tests/test_visuals.py @@ -90,20 +90,13 @@ def test_plot_2(qtbot, canvas_pz): n_signals = 50 data = 20 * np.random.randn(n_signals, 10) - # Signal bounds. - b = np.zeros((n_signals, 4)) - b[:, 0] = -1 - b[:, 1] = np.linspace(-1, 1 - 2. / n_signals, n_signals) - b[:, 2] = 1 - b[:, 3] = np.linspace(-1 + 2. / n_signals, 1., n_signals) - # Signal colors. c = np.random.uniform(.5, 1, size=(n_signals, 4)) c[:, 3] = .5 _test_visual(qtbot, canvas_pz, PlotVisual(), - data=data, data_bounds=[-10, 10], - signal_bounds=b, signal_colors=c) + data=data, data_bounds=[-50, 50], + signal_colors=c, stop=True) #------------------------------------------------------------------------------ diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index e706c0c0e..4b47808c6 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -227,15 +227,12 @@ def __init__(self): def get_transforms(self): return [Range(from_bounds=self.data_bounds), GPU(), - Range(from_bounds=NDC, - to_bounds='signal_bounds'), ] def set_data(self, data=None, depth=None, data_bounds=None, - signal_bounds=None, signal_colors=None, ): pos = _check_pos_2D(data) @@ -252,6 +249,9 @@ def set_data(self, pos[:, 0] = x pos[:, 1] = data.ravel() + # Generate the complete data_bounds 4-tuple from the specified 2-tuple. + self.data_bounds = _get_data_bounds_1D(data_bounds, data) + # Set the transformed position. pos_tr = self.apply_cpu_transforms(pos) self.program['a_position'] = _get_pos_depth(pos_tr, depth) @@ -259,13 +259,6 @@ def set_data(self, # Generate the signal index. self.program['a_signal_index'] = _get_index(n_signals, n_samples, n) - # Generate the complete data_bounds 4-tuple from the specified 2-tuple. - self.data_bounds = _get_data_bounds_1D(data_bounds, data) - - # Signal bounds (positions). - signal_bounds = _get_texture(signal_bounds, NDC, n_signals, [-1, 1]) - self.program['u_signal_bounds'] = Texture2D(signal_bounds) - # Signal colors. signal_colors = _get_texture(signal_colors, (1,) * 4, n_signals, [0, 1]) From 371d6cfdea7514b66c5b44c7903a86b012d6b422 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 25 Oct 2015 10:21:26 +0100 Subject: [PATCH 0438/1059] WIP: new interact module --- phy/plot/{grid.py => interact.py} | 2 +- phy/plot/tests/{test_grid.py => test_interact.py} | 4 ++-- phy/plot/tests/test_visuals.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) rename phy/plot/{grid.py => interact.py} (98%) rename phy/plot/tests/{test_grid.py => test_interact.py} (98%) diff --git a/phy/plot/grid.py b/phy/plot/interact.py similarity index 98% rename from phy/plot/grid.py rename to phy/plot/interact.py index 8d6a20c06..32d2f34f7 100644 --- a/phy/plot/grid.py +++ b/phy/plot/interact.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -"""Grid interact for subplots.""" +"""Common interacts.""" #------------------------------------------------------------------------------ diff --git a/phy/plot/tests/test_grid.py b/phy/plot/tests/test_interact.py similarity index 98% rename from phy/plot/tests/test_grid.py rename to phy/plot/tests/test_interact.py index d814a53df..467bdaa0a 100644 --- a/phy/plot/tests/test_grid.py +++ b/phy/plot/tests/test_interact.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -"""Test grid.""" +"""Test interact.""" #------------------------------------------------------------------------------ @@ -14,7 +14,7 @@ from pytest import yield_fixture from ..base import BaseVisual, BaseCanvas -from ..grid import Grid +from ..interact import Grid #------------------------------------------------------------------------------ diff --git a/phy/plot/tests/test_visuals.py b/phy/plot/tests/test_visuals.py index 707c5d6a4..ed7dad74d 100644 --- a/phy/plot/tests/test_visuals.py +++ b/phy/plot/tests/test_visuals.py @@ -96,7 +96,7 @@ def test_plot_2(qtbot, canvas_pz): _test_visual(qtbot, canvas_pz, PlotVisual(), data=data, data_bounds=[-50, 50], - signal_colors=c, stop=True) + signal_colors=c) #------------------------------------------------------------------------------ From 54d97023ff432e37b9cf206ac492149d3e30b97a Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 25 Oct 2015 10:39:41 +0100 Subject: [PATCH 0439/1059] Move _get_texture() --- phy/plot/utils.py | 27 +++++++++++++++++++++++++++ phy/plot/visuals.py | 29 +---------------------------- 2 files changed, 28 insertions(+), 28 deletions(-) diff --git a/phy/plot/utils.py b/phy/plot/utils.py index 27f174ecf..5ecbe96f8 100644 --- a/phy/plot/utils.py +++ b/phy/plot/utils.py @@ -70,3 +70,30 @@ def _enable_depth_mask(): blend=True, blend_func=('src_alpha', 'one_minus_src_alpha')) gloo.set_clear_depth(1.0) + + +def _get_texture(arr, default, n_items, from_bounds): + """Prepare data to be uploaded as a texture, with casting to uint8. + The from_bounds must be specified. + """ + if not hasattr(default, '__len__'): # pragma: no cover + default = [default] + n_cols = len(default) + if arr is None: + arr = np.tile(default, (n_items, 1)) + assert arr.shape == (n_items, n_cols) + # Convert to 3D texture. + arr = arr[np.newaxis, ...].astype(np.float32) + assert arr.shape == (1, n_items, n_cols) + # NOTE: we need to cast the texture to [0, 255] (uint8). + # This is easy as soon as we assume that the signal bounds are in + # [-1, 1]. + assert len(from_bounds) == 2 + m, M = map(float, from_bounds) + assert np.all(arr >= m) + assert np.all(arr <= M) + arr = 255 * (arr - m) / (M - m) + assert np.all(arr >= 0) + assert np.all(arr <= 255) + arr = arr.astype(np.uint8) + return arr diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index 4b47808c6..e0da2445d 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -12,7 +12,7 @@ from .base import BaseVisual from .transform import Range, GPU, NDC -from .utils import _enable_depth_mask, _tesselate_histogram +from .utils import _enable_depth_mask, _tesselate_histogram, _get_texture #------------------------------------------------------------------------------ @@ -108,33 +108,6 @@ def _get_index(n_items, item_size, n): return index -def _get_texture(arr, default, n_items, from_bounds): - """Prepare data to be uploaded as a texture, with casting to uint8. - The from_bounds must be specified. - """ - if not hasattr(default, '__len__'): # pragma: no cover - default = [default] - n_cols = len(default) - if arr is None: - arr = np.tile(default, (n_items, 1)) - assert arr.shape == (n_items, n_cols) - # Convert to 3D texture. - arr = arr[np.newaxis, ...].astype(np.float32) - assert arr.shape == (1, n_items, n_cols) - # NOTE: we need to cast the texture to [0, 255] (uint8). - # This is easy as soon as we assume that the signal bounds are in - # [-1, 1]. - assert len(from_bounds) == 2 - m, M = map(float, from_bounds) - assert np.all(arr >= m) - assert np.all(arr <= M) - arr = 255 * (arr - m) / (M - m) - assert np.all(arr >= 0) - assert np.all(arr <= 255) - arr = arr.astype(np.uint8) - return arr - - def _get_color(color, default): if color is None: color = default From 83e776ac51b504c8a12dd55418801ac2e74120a6 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 25 Oct 2015 11:14:48 +0100 Subject: [PATCH 0440/1059] Add glsl/ path during import of phy.plot --- phy/plot/__init__.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/phy/plot/__init__.py b/phy/plot/__init__.py index e69de29bb..10d217e71 100644 --- a/phy/plot/__init__.py +++ b/phy/plot/__init__.py @@ -0,0 +1,23 @@ +# -*- coding: utf-8 -*- +# flake8: noqa + +"""VisPy plotting.""" + + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +import os.path as op + +from vispy import config + + +#------------------------------------------------------------------------------ +# Add the `glsl/ path` for shader include +#------------------------------------------------------------------------------ + +curdir = op.dirname(op.realpath(__file__)) +glsl_path = op.join(curdir, 'glsl') +if not config['include_path']: + config['include_path'] = [glsl_path] From dc8e7efb31bbf9165b37cb4d79414bd216de3922 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 25 Oct 2015 11:16:05 +0100 Subject: [PATCH 0441/1059] Add get_pre_transforms() in BaseInteract --- phy/plot/base.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/phy/plot/base.py b/phy/plot/base.py index e27f60a25..41c24c0ea 100644 --- a/phy/plot/base.py +++ b/phy/plot/base.py @@ -142,12 +142,20 @@ def __init__(self): # ------------------------------------------------------------------------- def get_shader_declarations(self): + """Return extra declarations for the vertex and fragment shaders.""" return '', '' + def get_pre_transforms(self): + """Return an optional GLSL snippet to insert into the vertex shader + before the transforms.""" + return '' + def get_transforms(self): + """Return the list of transforms.""" return [] def update_program(self, program): + """Update a program during an interaction event.""" pass # Public methods @@ -192,7 +200,9 @@ def build_program(self, visual): logger.debug("Build the program of `%s`.", self.__class__.__name__) # Insert the interact's GLSL into the shaders. vertex, fragment = visual.get_shaders() - vertex, fragment = transform_chain.insert_glsl(vertex, fragment) + # Get the GLSL snippet to insert before the transformations. + pre = self.get_pre_transforms() + vertex, fragment = transform_chain.insert_glsl(vertex, fragment, pre) # Insert shader declarations. vertex_decl, frag_decl = self.get_shader_declarations() From 09987bacb9200758b309283f0a41e7c793374911 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 25 Oct 2015 11:16:28 +0100 Subject: [PATCH 0442/1059] Remove config include path in plot.utils --- phy/plot/utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/phy/plot/utils.py b/phy/plot/utils.py index 5ecbe96f8..77fa74415 100644 --- a/phy/plot/utils.py +++ b/phy/plot/utils.py @@ -25,8 +25,6 @@ def _load_shader(filename): """Load a shader file.""" curdir = op.dirname(op.realpath(__file__)) glsl_path = op.join(curdir, 'glsl') - if not config['include_path']: - config['include_path'] = [glsl_path] path = op.join(glsl_path, filename) with open(path, 'r') as f: return f.read() From 8e7d90cd4c49ab5df34aa7476d6281390e6b1065 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 25 Oct 2015 11:16:41 +0100 Subject: [PATCH 0443/1059] Remove config include path in plot.utils --- phy/plot/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/phy/plot/utils.py b/phy/plot/utils.py index 77fa74415..1417501eb 100644 --- a/phy/plot/utils.py +++ b/phy/plot/utils.py @@ -12,7 +12,7 @@ import numpy as np -from vispy import gloo, config +from vispy import gloo logger = logging.getLogger(__name__) From aabc6fed325b6ddaa00a185c12b92bd322fda61a Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 25 Oct 2015 11:17:03 +0100 Subject: [PATCH 0444/1059] Add pre_transforms argument in TransformChain.insert_glsl() --- phy/plot/transform.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/phy/plot/transform.py b/phy/plot/transform.py index 8808f3e45..df015fe2c 100644 --- a/phy/plot/transform.py +++ b/phy/plot/transform.py @@ -285,7 +285,7 @@ def apply(self, arr): arr = t.apply(arr) return arr - def insert_glsl(self, vertex, fragment): + def insert_glsl(self, vertex, fragment, pre_transforms=''): """Generate the GLSL code of the transform chain.""" # Find the place where to insert the GLSL snippet. @@ -311,7 +311,10 @@ def insert_glsl(self, vertex, fragment): temp_var = 'temp_pos_tr' # Name for the (eventual) varying. fvar = 'v_{}'.format(temp_var) - vs_insert = "vec2 {} = {};\n".format(temp_var, var) + vs_insert = '' + # Insert the pre-transforms. + vs_insert += pre_transforms + '\n' + vs_insert += "vec2 {} = {};\n".format(temp_var, var) for t in self.gpu_transforms: if isinstance(t, Clip): # Set the varying value in the vertex shader. From 417ba68b7a8dc5c6db6668be999aec495be71bd1 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 25 Oct 2015 11:22:59 +0100 Subject: [PATCH 0445/1059] WIP: boxed interact --- phy/plot/interact.py | 70 ++++++++++++++++++++++++++++++++- phy/plot/tests/test_interact.py | 64 ++++++++++++++++++++++-------- 2 files changed, 115 insertions(+), 19 deletions(-) diff --git a/phy/plot/interact.py b/phy/plot/interact.py index 32d2f34f7..950608d9b 100644 --- a/phy/plot/interact.py +++ b/phy/plot/interact.py @@ -9,12 +9,16 @@ import math +import numpy as np +from vispy.gloo import Texture2D + from .base import BaseInteract -from .transform import Scale, Subplot, Clip +from .transform import Scale, Range, Subplot, Clip, NDC +from .utils import _get_texture #------------------------------------------------------------------------------ -# Grid class +# Grid interact #------------------------------------------------------------------------------ class Grid(BaseInteract): @@ -85,3 +89,65 @@ def on_key_press(self, event): if key == 'R': self.zoom = 1. self.update() + + +#------------------------------------------------------------------------------ +# Boxed interact +#------------------------------------------------------------------------------ + +class Boxed(BaseInteract): + """Boxed interact. + + NOTE: to be used in a boxed, a visual must define `a_box_index`. + + """ + def __init__(self, box_bounds, box_var=None): + super(Boxed, self).__init__() + + # Name of the variable with the box index. + self.box_var = box_var or 'a_box_index' + + self.box_bounds = np.atleast_2d(box_bounds) + assert self.box_bounds.shape[1] == 4 + self.n_boxes = len(self.box_bounds) + + def get_shader_declarations(self): + return ('#include "utils.glsl"\n\n' + 'attribute float {};\n'.format(self.box_var) + + 'uniform sampler2D u_box_bounds;\n' + 'uniform float n_boxes;', '') + + def get_pre_transforms(self): + return """ + // Fetch the box bounds for the current box (`box_var`). + vec4 box_bounds = fetch_texture({}, + u_box_bounds, + n_boxes); + box_bounds = (2 * box_bounds - 1); // See hack in Python. + """.format(self.box_var) + + def get_transforms(self): + return [Range(from_bounds=NDC, + to_bounds='box_bounds'), + ] + + def update_program(self, program): + # Signal bounds (positions). + box_bounds = _get_texture(self.box_bounds, NDC, self.n_boxes, [-1, 1]) + program['u_box_bounds'] = Texture2D(box_bounds) + program['n_boxes'] = self.n_boxes + + +class Stacked(BaseInteract): + """Stacked interact. + + NOTE: to be used in a stacked, a visual must define `a_box_index`. + + """ + + # # Signal bounds. + # b = np.zeros((n_signals, 4)) + # b[:, 0] = -1 + # b[:, 1] = np.linspace(-1, 1 - 2. / n_signals, n_signals) + # b[:, 2] = 1 + # b[:, 3] = np.linspace(-1 + 2. / n_signals, 1., n_signals) diff --git a/phy/plot/tests/test_interact.py b/phy/plot/tests/test_interact.py index 467bdaa0a..cbec432af 100644 --- a/phy/plot/tests/test_interact.py +++ b/phy/plot/tests/test_interact.py @@ -14,7 +14,7 @@ from pytest import yield_fixture from ..base import BaseVisual, BaseCanvas -from ..interact import Grid +from ..interact import Grid, Boxed #------------------------------------------------------------------------------ @@ -42,9 +42,6 @@ def get_shaders(self): def set_data(self): n = 1000 - box = [[i, j] for i, j in product(range(2), range(3))] - box = np.repeat(box, n, axis=0) - coeff = [(1 + i + j) for i, j in product(range(2), range(3))] coeff = np.repeat(coeff, n) coeff = coeff[:, None] @@ -52,40 +49,62 @@ def set_data(self): position = .1 * coeff * np.random.randn(2 * 3 * n, 2) self.program['a_position'] = position.astype(np.float32) - self.program['a_box_index'] = box.astype(np.float32) @yield_fixture -def canvas(qapp): +def canvas_grid(qapp): c = BaseCanvas(keys='interactive', interact=Grid(shape=(2, 3))) yield c c.close() @yield_fixture -def grid(qtbot, canvas): +def canvas_boxed(qapp): + n = 6 + b = np.zeros((n, 4)) + + b[:, 0] = b[:, 1] = np.linspace(-1., 1. - 1. / 3., n) + b[:, 2] = b[:, 3] = np.linspace(-1. + 1. / 3., 1., n) + + c = BaseCanvas(keys='interactive', interact=Boxed(box_bounds=b)) + yield c + c.close() + + +def get_interact(qtbot, canvas, box_index): + c = canvas + visual = MyTestVisual() - visual.attach(canvas) + visual.attach(c) visual.set_data() - canvas.show() - qtbot.waitForWindowShown(canvas.native) + visual.program['a_box_index'] = box_index.astype(np.float32) - yield canvas.interact + c.show() + qtbot.waitForWindowShown(c.native) + + return c.interact #------------------------------------------------------------------------------ # Test grid #------------------------------------------------------------------------------ -def test_grid_1(qtbot, canvas, grid): +def test_grid_1(qtbot, canvas_grid): + c = canvas_grid + n = 1000 + + box_index = [[i, j] for i, j in product(range(2), range(3))] + box_index = np.repeat(box_index, n, axis=0) + + grid = get_interact(qtbot, canvas_grid, box_index) # Zoom with the keyboard. - canvas.events.key_press(key=keys.Key('+')) + c.events.key_press(key=keys.Key('+')) assert grid.zoom > 1 # Unzoom with the keyboard. - canvas.events.key_press(key=keys.Key('-')) + c.events.key_press(key=keys.Key('-')) assert grid.zoom == 1. # Set the zoom explicitly. @@ -93,11 +112,22 @@ def test_grid_1(qtbot, canvas, grid): assert grid.zoom == 2. # No effect with modifiers. - canvas.events.key_press(key=keys.Key('r'), modifiers=(keys.CONTROL,)) + c.events.key_press(key=keys.Key('r'), modifiers=(keys.CONTROL,)) assert grid.zoom == 2. # Press 'R'. - canvas.events.key_press(key=keys.Key('r')) + c.events.key_press(key=keys.Key('r')) assert grid.zoom == 1. - # qtbot.stop() + qtbot.stop() + + +def test_boxed_1(qtbot, canvas_boxed): + c = canvas_boxed + + n = 1000 + box_index = np.repeat(np.arange(6), n, axis=0) + + boxed = get_interact(qtbot, canvas_boxed, box_index) + + qtbot.stop() From e8f7d172f37ef1a53fc892b490ce04c17bc02c04 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 25 Oct 2015 11:42:29 +0100 Subject: [PATCH 0446/1059] WIP: refactor interacts --- phy/plot/base.py | 99 ++++++++++++++++++++++--------------- phy/plot/tests/conftest.py | 9 ---- phy/plot/tests/test_base.py | 14 ++++-- 3 files changed, 69 insertions(+), 53 deletions(-) diff --git a/phy/plot/base.py b/phy/plot/base.py index 41c24c0ea..e0fddfa74 100644 --- a/phy/plot/base.py +++ b/phy/plot/base.py @@ -84,8 +84,7 @@ def attach(self, canvas): """ logger.debug("Attach `%s` to canvas.", self.__class__.__name__) - self.program = canvas.interact.build_program(self) - # self.transform_chain = canvas.interact.transform_chain + self.program = build_program(self, canvas.interacts) # NOTE: this is connect_ and not connect because we're using # phy's event system, not VisPy's. The reason is that the order @@ -125,6 +124,10 @@ def on_draw(self): self.program.draw(self.gl_primitive_type) +#------------------------------------------------------------------------------ +# Base interact +#------------------------------------------------------------------------------ + class BaseInteract(object): """Implement interactions for a set of attached visuals in a canvas. @@ -172,6 +175,9 @@ def attach(self, canvas): """Attach the interact to a canvas.""" self._canvas = canvas + # This might be improved. + canvas.interacts.append(self) + canvas.connect(self.on_resize) canvas.connect(self.on_mouse_move) canvas.connect(self.on_mouse_wheel) @@ -181,40 +187,6 @@ def is_attached(self): """Whether the transform is attached to a canvas.""" return self._canvas is not None - def build_program(self, visual): - """Create the gloo program of a visual using the interact's - transforms. - - This method is called when a visual is attached to the canvas. - - """ - assert visual.program is None, "The program has already been built." - assert visual not in self.visuals - self.visuals.append(visual) - - # Build the transform chain using the visuals transforms first, - # then the interact's transforms. - transform_chain = TransformChain(visual.get_transforms() + - self.get_transforms()) - - logger.debug("Build the program of `%s`.", self.__class__.__name__) - # Insert the interact's GLSL into the shaders. - vertex, fragment = visual.get_shaders() - # Get the GLSL snippet to insert before the transformations. - pre = self.get_pre_transforms() - vertex, fragment = transform_chain.insert_glsl(vertex, fragment, pre) - - # Insert shader declarations. - vertex_decl, frag_decl = self.get_shader_declarations() - vertex = vertex_decl + '\n' + vertex - fragment = frag_decl + '\n' + fragment - logger.log(5, "Vertex shader: \n%s", vertex) - logger.log(5, "Fragment shader: \n%s", fragment) - - program = gloo.Program(vertex, fragment) - self.update_program(program) - return program - def on_resize(self, event): pass @@ -235,14 +207,16 @@ def update(self): self._canvas.update() +#------------------------------------------------------------------------------ +# Base canvas +#------------------------------------------------------------------------------ + class BaseCanvas(Canvas): """A blank VisPy canvas with a custom event system that keeps the order.""" def __init__(self, *args, **kwargs): - # Set the interact. - self.interact = kwargs.pop('interact', BaseInteract()) super(BaseCanvas, self).__init__(*args, **kwargs) self._events = EventEmitter() - self.interact.attach(self) + self.interacts = [] def connect_(self, *args, **kwargs): return self._events.connect(*args, **kwargs) @@ -253,3 +227,50 @@ def emit_(self, *args, **kwargs): # pragma: no cover def on_draw(self, e): gloo.clear() self._events.emit('draw') + + +#------------------------------------------------------------------------------ +# Build program with interacts +#------------------------------------------------------------------------------ + +def build_program(visual, interacts=()): + """Create the gloo program of a visual using the interacts + transforms. + + This method is called when a visual is attached to the canvas. + + """ + assert visual.program is None, "The program has already been built." + + # Build the transform chain using the visuals transforms first, + # then the interact's transforms. + transforms = visual.get_transforms() + for interact in interacts: + transforms.extend(interact.get_transforms()) + transform_chain = TransformChain(transforms) + + logger.debug("Build the program of `%s`.", visual.__class__.__name__) + # Insert the interact's GLSL into the shaders. + vertex, fragment = visual.get_shaders() + # Get the GLSL snippet to insert before the transformations. + pre = '\n'.join(interact.get_pre_transforms() for interact in interacts) + vertex, fragment = transform_chain.insert_glsl(vertex, fragment, pre) + + # Insert shader declarations using the interacts (if any). + if interacts: + vertex_decls, frag_decls = zip(*(interact.get_shader_declarations() + for interact in interacts)) + + vertex = '\n'.join(vertex_decls) + '\n' + vertex + fragment = '\n'.join(frag_decls) + '\n' + fragment + + logger.log(5, "Vertex shader: \n%s", vertex) + logger.log(5, "Fragment shader: \n%s", fragment) + + program = gloo.Program(vertex, fragment) + + # Update the program with all interacts. + for interact in interacts: + interact.update_program(program) + + return program diff --git a/phy/plot/tests/conftest.py b/phy/plot/tests/conftest.py index 89d9fc31b..306173b3e 100644 --- a/phy/plot/tests/conftest.py +++ b/phy/plot/tests/conftest.py @@ -10,7 +10,6 @@ from pytest import yield_fixture from ..base import BaseCanvas -from ..panzoom import PanZoom #------------------------------------------------------------------------------ @@ -23,11 +22,3 @@ def canvas(qapp): c = BaseCanvas(keys='interactive') yield c c.close() - - -@yield_fixture -def canvas_pz(qapp): - use_app('pyqt4') - c = BaseCanvas(keys='interactive', interact=PanZoom()) - yield c - c.close() diff --git a/phy/plot/tests/test_base.py b/phy/plot/tests/test_base.py index 90a9f4e1c..31a3687ca 100644 --- a/phy/plot/tests/test_base.py +++ b/phy/plot/tests/test_base.py @@ -79,7 +79,7 @@ def set_data(self): def test_base_interact(qtbot, canvas): - """Test a BaseVisual with a CPU transform and a blank interact.""" + """Test a BaseVisual with a CPU transform and no interact.""" class TestVisual(BaseVisual): shader_name = 'simple' gl_primitive_type = 'lines' @@ -97,14 +97,18 @@ def set_data(self): v.set_data() canvas.show() - assert canvas.interact.size[0] >= 1 + assert not canvas.interacts qtbot.waitForWindowShown(canvas.native) # qtbot.stop() def test_interact(qtbot, canvas): """Test a BaseVisual with multiple CPU and GPU transforms and a - non-blank interact.""" + non-blank interact. + + There should be points filling the entire lower (2, 3) subplot. + + """ class TestVisual(BaseVisual): vertex = """ @@ -144,7 +148,7 @@ def get_transforms(self): Clip(bounds=bounds), ] - canvas = BaseCanvas(keys='interactive', interact=TestInteract()) + TestInteract().attach(canvas) # We attach the visual to the canvas. By default, a BaseInteract is used. v = TestVisual() @@ -152,6 +156,6 @@ def get_transforms(self): v.set_data() canvas.show() + assert len(canvas.interacts) == 1 qtbot.waitForWindowShown(canvas.native) # qtbot.stop() - canvas.close() From 249f7e1c502659ec9a02508eb328b5e92b570d3c Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 25 Oct 2015 11:55:50 +0100 Subject: [PATCH 0447/1059] WIP: fix bugs --- phy/plot/glsl/scatter.frag | 1 + phy/plot/tests/test_base.py | 2 +- phy/plot/tests/test_visuals.py | 9 ++++++++- 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/phy/plot/glsl/scatter.frag b/phy/plot/glsl/scatter.frag index 76703d7dc..fc092a7e9 100644 --- a/phy/plot/glsl/scatter.frag +++ b/phy/plot/glsl/scatter.frag @@ -1,3 +1,4 @@ +#include "antialias/filled.glsl" #include "markers/%MARKER_TYPE.glsl" varying vec4 v_color; diff --git a/phy/plot/tests/test_base.py b/phy/plot/tests/test_base.py index 31a3687ca..5ad17b46b 100644 --- a/phy/plot/tests/test_base.py +++ b/phy/plot/tests/test_base.py @@ -9,7 +9,7 @@ import numpy as np -from ..base import BaseCanvas, BaseVisual, BaseInteract +from ..base import BaseVisual, BaseInteract from ..transform import (subplot_bounds, Translate, Scale, Range, Clip, Subplot, GPU) diff --git a/phy/plot/tests/test_visuals.py b/phy/plot/tests/test_visuals.py index ed7dad74d..48b116f82 100644 --- a/phy/plot/tests/test_visuals.py +++ b/phy/plot/tests/test_visuals.py @@ -8,8 +8,9 @@ #------------------------------------------------------------------------------ import numpy as np -from pytest import mark +from pytest import mark, yield_fixture +from ..panzoom import PanZoom from ..visuals import (ScatterVisual, PlotVisual, HistogramVisual, BoxVisual, AxesVisual,) @@ -18,11 +19,17 @@ # Fixtures #------------------------------------------------------------------------------ +@yield_fixture +def canvas_pz(canvas): + PanZoom().attach(canvas) + yield canvas + def _test_visual(qtbot, c, v, stop=False, **kwargs): v.attach(c) v.set_data(**kwargs) c.show() + qtbot.waitForWindowShown(c.native) if stop: # pragma: no cover qtbot.stop() From 7eba5ad20da0433d5d4ddf17032da0d316d2bd92 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 25 Oct 2015 12:15:59 +0100 Subject: [PATCH 0448/1059] WIP: refactor --- phy/plot/base.py | 12 ++--- phy/plot/tests/conftest.py | 7 +++ phy/plot/tests/test_panzoom.py | 86 +++++++++++++++++----------------- phy/plot/tests/test_visuals.py | 11 +---- 4 files changed, 58 insertions(+), 58 deletions(-) diff --git a/phy/plot/base.py b/phy/plot/base.py index e0fddfa74..2f598d3a8 100644 --- a/phy/plot/base.py +++ b/phy/plot/base.py @@ -103,6 +103,8 @@ def on_resize(event): canvas.connect(self.on_mouse_move) canvas.connect(self.on_key_press) + # NOTE: this might be improved. + canvas.visuals.append(self) # HACK: allow a visual to update the canvas it is attached to. self.update = canvas.update @@ -138,8 +140,6 @@ class BaseInteract(object): """ def __init__(self): self._canvas = None - # List of attached visuals. - self.visuals = [] # To override # ------------------------------------------------------------------------- @@ -164,9 +164,6 @@ def update_program(self, program): # Public methods # ------------------------------------------------------------------------- - def get_visuals(self): - return self.visuals - @property def size(self): return self._canvas.size if self._canvas else None @@ -175,7 +172,7 @@ def attach(self, canvas): """Attach the interact to a canvas.""" self._canvas = canvas - # This might be improved. + # NOTE: this might be improved. canvas.interacts.append(self) canvas.connect(self.on_resize) @@ -202,7 +199,7 @@ def on_key_press(self, event): def update(self): """Update the attached canvas and all attached programs.""" if self.is_attached(): - for visual in self.get_visuals(): + for visual in self._canvas.visuals: self.update_program(visual.program) self._canvas.update() @@ -217,6 +214,7 @@ def __init__(self, *args, **kwargs): super(BaseCanvas, self).__init__(*args, **kwargs) self._events = EventEmitter() self.interacts = [] + self.visuals = [] def connect_(self, *args, **kwargs): return self._events.connect(*args, **kwargs) diff --git a/phy/plot/tests/conftest.py b/phy/plot/tests/conftest.py index 306173b3e..a330ac4a2 100644 --- a/phy/plot/tests/conftest.py +++ b/phy/plot/tests/conftest.py @@ -10,6 +10,7 @@ from pytest import yield_fixture from ..base import BaseCanvas +from ..panzoom import PanZoom #------------------------------------------------------------------------------ @@ -22,3 +23,9 @@ def canvas(qapp): c = BaseCanvas(keys='interactive') yield c c.close() + + +@yield_fixture +def canvas_pz(canvas): + PanZoom().attach(canvas) + yield canvas diff --git a/phy/plot/tests/test_panzoom.py b/phy/plot/tests/test_panzoom.py index fd553da3e..7d6ab1bf7 100644 --- a/phy/plot/tests/test_panzoom.py +++ b/phy/plot/tests/test_panzoom.py @@ -8,11 +8,11 @@ #------------------------------------------------------------------------------ import numpy as np +from pytest import yield_fixture from vispy.app import MouseEvent from vispy.util import keys -from pytest import yield_fixture -from ..base import BaseVisual, BaseCanvas +from ..base import BaseVisual from ..panzoom import PanZoom @@ -30,22 +30,16 @@ def set_data(self): @yield_fixture -def canvas(qapp): - c = BaseCanvas(keys='interactive', interact=PanZoom()) - yield c - c.close() - - -@yield_fixture -def panzoom(qtbot, canvas): +def panzoom(qtbot, canvas_pz): + c = canvas_pz visual = MyTestVisual() - visual.attach(canvas) + visual.attach(c) visual.set_data() - canvas.show() - qtbot.waitForWindowShown(canvas.native) + c.show() + qtbot.waitForWindowShown(c.native) - yield canvas.interact + yield c.interacts[0] #------------------------------------------------------------------------------ @@ -151,58 +145,61 @@ def test_panzoom_constraints_z(): assert pz.zoom == [2, 2] -def test_panzoom_pan_mouse(qtbot, canvas, panzoom): +def test_panzoom_pan_mouse(qtbot, canvas_pz, panzoom): + c = canvas_pz pz = panzoom # Pan with mouse. press = MouseEvent(type='mouse_press', pos=(0, 0)) - canvas.events.mouse_move(pos=(10., 0.), button=1, - last_event=press, press_event=press) + c.events.mouse_move(pos=(10., 0.), button=1, + last_event=press, press_event=press) assert pz.pan[0] > 0 assert pz.pan[1] == 0 pz.pan = (0, 0) # Panning with a modifier should not pan. press = MouseEvent(type='mouse_press', pos=(0, 0)) - canvas.events.mouse_move(pos=(10., 0.), button=1, - last_event=press, press_event=press, - modifiers=(keys.CONTROL,)) + c.events.mouse_move(pos=(10., 0.), button=1, + last_event=press, press_event=press, + modifiers=(keys.CONTROL,)) assert pz.pan == [0, 0] # qtbot.stop() -def test_panzoom_pan_keyboard(qtbot, canvas, panzoom): +def test_panzoom_pan_keyboard(qtbot, canvas_pz, panzoom): + c = canvas_pz pz = panzoom # Pan with keyboard. - canvas.events.key_press(key=keys.UP) + c.events.key_press(key=keys.UP) assert pz.pan[0] == 0 assert pz.pan[1] < 0 # All panning movements with keys. - canvas.events.key_press(key=keys.LEFT) - canvas.events.key_press(key=keys.DOWN) - canvas.events.key_press(key=keys.RIGHT) + c.events.key_press(key=keys.LEFT) + c.events.key_press(key=keys.DOWN) + c.events.key_press(key=keys.RIGHT) assert pz.pan == [0, 0] # Reset with R. - canvas.events.key_press(key=keys.RIGHT) - canvas.events.key_press(key=keys.Key('r')) + c.events.key_press(key=keys.RIGHT) + c.events.key_press(key=keys.Key('r')) assert pz.pan == [0, 0] # Using modifiers should not pan. - canvas.events.key_press(key=keys.UP, modifiers=(keys.CONTROL,)) + c.events.key_press(key=keys.UP, modifiers=(keys.CONTROL,)) assert pz.pan == [0, 0] -def test_panzoom_zoom_mouse(qtbot, canvas, panzoom): +def test_panzoom_zoom_mouse(qtbot, canvas_pz, panzoom): + c = canvas_pz pz = panzoom # Zoom with mouse. press = MouseEvent(type='mouse_press', pos=(50., 50.)) - canvas.events.mouse_move(pos=(0., 0.), button=2, - last_event=press, press_event=press) + c.events.mouse_move(pos=(0., 0.), button=2, + last_event=press, press_event=press) assert pz.pan[0] < 0 assert pz.pan[1] < 0 assert pz.zoom[0] < 1 @@ -210,38 +207,43 @@ def test_panzoom_zoom_mouse(qtbot, canvas, panzoom): pz.reset() # Zoom with mouse. - size = np.asarray(canvas.size) - canvas.events.mouse_wheel(pos=size / 2., delta=(0., 1.)) + size = np.asarray(c.size) + c.events.mouse_wheel(pos=size / 2., delta=(0., 1.)) assert pz.pan == [0, 0] assert pz.zoom[0] > 1 assert pz.zoom[1] > 1 pz.reset() # Using modifiers with the wheel should not zoom. - canvas.events.mouse_wheel(pos=(0., 0.), delta=(0., 1.), - modifiers=(keys.CONTROL,)) + c.events.mouse_wheel(pos=(0., 0.), delta=(0., 1.), + modifiers=(keys.CONTROL,)) assert pz.pan == [0, 0] assert pz.zoom == [1, 1] pz.reset() -def test_panzoom_zoom_keyboard(qtbot, canvas, panzoom): +def test_panzoom_zoom_keyboard(qtbot, canvas_pz, panzoom): + c = canvas_pz pz = panzoom # Zoom with keyboard. - canvas.events.key_press(key=keys.Key('+')) + c.events.key_press(key=keys.Key('+')) assert pz.pan == [0, 0] assert pz.zoom[0] > 1 assert pz.zoom[1] > 1 # Unzoom with keyboard. - canvas.events.key_press(key=keys.Key('-')) + c.events.key_press(key=keys.Key('-')) assert pz.pan == [0, 0] assert pz.zoom == [1, 1] -def test_panzoom_resize(qtbot, canvas, panzoom): +def test_panzoom_resize(qtbot, canvas_pz, panzoom): + c = canvas_pz + pz = panzoom + # Increase coverage with different aspect ratio. - canvas.native.resize(400, 600) - # canvas.events.resize(size=(100, 1000)) - assert list(panzoom._canvas_aspect) == [1., 2. / 3] + c.native.resize(400, 600) + # qtbot.stop() + # c.events.resize(size=(100, 1000)) + assert list(pz._canvas_aspect) == [1., 2. / 3] diff --git a/phy/plot/tests/test_visuals.py b/phy/plot/tests/test_visuals.py index 48b116f82..23315d145 100644 --- a/phy/plot/tests/test_visuals.py +++ b/phy/plot/tests/test_visuals.py @@ -8,9 +8,8 @@ #------------------------------------------------------------------------------ import numpy as np -from pytest import mark, yield_fixture +from pytest import mark -from ..panzoom import PanZoom from ..visuals import (ScatterVisual, PlotVisual, HistogramVisual, BoxVisual, AxesVisual,) @@ -19,12 +18,6 @@ # Fixtures #------------------------------------------------------------------------------ -@yield_fixture -def canvas_pz(canvas): - PanZoom().attach(canvas) - yield canvas - - def _test_visual(qtbot, c, v, stop=False, **kwargs): v.attach(c) v.set_data(**kwargs) @@ -36,8 +29,8 @@ def _test_visual(qtbot, c, v, stop=False, **kwargs): #------------------------------------------------------------------------------ # Test scatter visual -#------------------------------------------------------------------------------ +#------------------------------------------------------------------------------ def test_scatter_empty(qtbot, canvas): pos = np.zeros((0, 2)) _test_visual(qtbot, canvas, ScatterVisual(), pos=pos) From f827a6b35eb2c5f718afdab6fba4af49cb8c0f29 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 25 Oct 2015 12:21:07 +0100 Subject: [PATCH 0449/1059] WIP: fix interact --- phy/plot/tests/test_interact.py | 49 ++++++++++++--------------------- 1 file changed, 18 insertions(+), 31 deletions(-) diff --git a/phy/plot/tests/test_interact.py b/phy/plot/tests/test_interact.py index cbec432af..d39fa133e 100644 --- a/phy/plot/tests/test_interact.py +++ b/phy/plot/tests/test_interact.py @@ -11,9 +11,8 @@ import numpy as np from vispy.util import keys -from pytest import yield_fixture -from ..base import BaseVisual, BaseCanvas +from ..base import BaseVisual from ..interact import Grid, Boxed @@ -51,29 +50,11 @@ def set_data(self): self.program['a_position'] = position.astype(np.float32) -@yield_fixture -def canvas_grid(qapp): - c = BaseCanvas(keys='interactive', interact=Grid(shape=(2, 3))) - yield c - c.close() - - -@yield_fixture -def canvas_boxed(qapp): - n = 6 - b = np.zeros((n, 4)) - - b[:, 0] = b[:, 1] = np.linspace(-1., 1. - 1. / 3., n) - b[:, 2] = b[:, 3] = np.linspace(-1. + 1. / 3., 1., n) - - c = BaseCanvas(keys='interactive', interact=Boxed(box_bounds=b)) - yield c - c.close() - - -def get_interact(qtbot, canvas, box_index): +def _create_visual(qtbot, canvas, interact, box_index): c = canvas + interact.attach(c) + visual = MyTestVisual() visual.attach(c) visual.set_data() @@ -83,21 +64,21 @@ def get_interact(qtbot, canvas, box_index): c.show() qtbot.waitForWindowShown(c.native) - return c.interact - #------------------------------------------------------------------------------ # Test grid #------------------------------------------------------------------------------ -def test_grid_1(qtbot, canvas_grid): - c = canvas_grid +def test_grid_1(qtbot, canvas): + + c = canvas n = 1000 box_index = [[i, j] for i, j in product(range(2), range(3))] box_index = np.repeat(box_index, n, axis=0) - grid = get_interact(qtbot, canvas_grid, box_index) + grid = Grid(shape=(2, 3)) + _create_visual(qtbot, canvas, grid, box_index) # Zoom with the keyboard. c.events.key_press(key=keys.Key('+')) @@ -122,12 +103,18 @@ def test_grid_1(qtbot, canvas_grid): qtbot.stop() -def test_boxed_1(qtbot, canvas_boxed): - c = canvas_boxed +def test_boxed_1(qtbot, canvas): + + n = 6 + b = np.zeros((n, 4)) + + b[:, 0] = b[:, 1] = np.linspace(-1., 1. - 1. / 3., n) + b[:, 2] = b[:, 3] = np.linspace(-1. + 1. / 3., 1., n) n = 1000 box_index = np.repeat(np.arange(6), n, axis=0) - boxed = get_interact(qtbot, canvas_boxed, box_index) + boxed = Boxed(box_bounds=b) + _create_visual(qtbot, canvas, boxed, box_index) qtbot.stop() From 41fe7756f402173fb40a8ef6e703be32d12c3be9 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 25 Oct 2015 12:48:27 +0100 Subject: [PATCH 0450/1059] Fix bugs --- phy/plot/tests/test_base.py | 10 +++++++++- phy/plot/tests/test_interact.py | 4 ++-- phy/plot/tests/test_panzoom.py | 2 -- phy/plot/tests/test_visuals.py | 18 ++++++++++++------ 4 files changed, 23 insertions(+), 11 deletions(-) diff --git a/phy/plot/tests/test_base.py b/phy/plot/tests/test_base.py index 5ad17b46b..be4d51449 100644 --- a/phy/plot/tests/test_base.py +++ b/phy/plot/tests/test_base.py @@ -78,7 +78,15 @@ def set_data(self): v.update() -def test_base_interact(qtbot, canvas): +def test_base_interact(): + interact = BaseInteract() + assert interact.get_shader_declarations() == ('', '') + assert interact.get_pre_transforms() == '' + assert interact.get_transforms() == [] + interact.update_program(None) + + +def test_no_interact(qtbot, canvas): """Test a BaseVisual with a CPU transform and no interact.""" class TestVisual(BaseVisual): shader_name = 'simple' diff --git a/phy/plot/tests/test_interact.py b/phy/plot/tests/test_interact.py index d39fa133e..adf36dfbf 100644 --- a/phy/plot/tests/test_interact.py +++ b/phy/plot/tests/test_interact.py @@ -100,7 +100,7 @@ def test_grid_1(qtbot, canvas): c.events.key_press(key=keys.Key('r')) assert grid.zoom == 1. - qtbot.stop() + # qtbot.stop() def test_boxed_1(qtbot, canvas): @@ -117,4 +117,4 @@ def test_boxed_1(qtbot, canvas): boxed = Boxed(box_bounds=b) _create_visual(qtbot, canvas, boxed, box_index) - qtbot.stop() + # qtbot.stop() diff --git a/phy/plot/tests/test_panzoom.py b/phy/plot/tests/test_panzoom.py index 7d6ab1bf7..5c464f15d 100644 --- a/phy/plot/tests/test_panzoom.py +++ b/phy/plot/tests/test_panzoom.py @@ -67,8 +67,6 @@ def test_panzoom_basic_attrs(): setattr(pz, name, v * 2) assert getattr(pz, name) == v * 2 - assert list(pz.get_visuals()) == [] - def test_panzoom_basic_pan_zoom(): pz = PanZoom() diff --git a/phy/plot/tests/test_visuals.py b/phy/plot/tests/test_visuals.py index 23315d145..c77e8daf3 100644 --- a/phy/plot/tests/test_visuals.py +++ b/phy/plot/tests/test_visuals.py @@ -8,7 +8,6 @@ #------------------------------------------------------------------------------ import numpy as np -from pytest import mark from ..visuals import (ScatterVisual, PlotVisual, HistogramVisual, BoxVisual, AxesVisual,) @@ -36,13 +35,20 @@ def test_scatter_empty(qtbot, canvas): _test_visual(qtbot, canvas, ScatterVisual(), pos=pos) -@mark.parametrize('marker_type', ScatterVisual._supported_marker_types) -def test_scatter_markers(qtbot, canvas_pz, marker_type): +def test_scatter_markers(qtbot, canvas_pz): + c = canvas_pz + n = 100 pos = .2 * np.random.randn(n, 2) - _test_visual(qtbot, canvas_pz, - ScatterVisual(marker_type=marker_type), - pos=pos) + + v = ScatterVisual(marker_type='vbar') + v.attach(c) + v.set_data(pos=pos) + + c.show() + qtbot.waitForWindowShown(c.native) + + # qtbot.stop() def test_scatter_custom(qtbot, canvas_pz): From 431c93c2e51c3868a416ae611943020d095d384e Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 25 Oct 2015 12:55:11 +0100 Subject: [PATCH 0451/1059] Updated interact --- phy/plot/interact.py | 6 +++--- phy/plot/tests/test_interact.py | 3 +++ 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/phy/plot/interact.py b/phy/plot/interact.py index 950608d9b..8a2aba8be 100644 --- a/phy/plot/interact.py +++ b/phy/plot/interact.py @@ -42,19 +42,19 @@ def __init__(self, shape, box_var=None): def get_shader_declarations(self): return ('attribute vec2 a_box_index;\n' - 'uniform float u_zoom;\n', '') + 'uniform float u_grid_zoom;\n', '') def get_transforms(self): # Define the grid transform and clipping. m = 1. - .05 # Margin. - return [Scale(scale='u_zoom'), + return [Scale(scale='u_grid_zoom'), Scale(scale=(m, m)), Clip(bounds=[-m, -m, m, m]), Subplot(shape=self.shape, index='a_box_index'), ] def update_program(self, program): - program['u_zoom'] = self._zoom + program['u_grid_zoom'] = self._zoom # Only set the default box index if necessary. try: program['a_box_index'] diff --git a/phy/plot/tests/test_interact.py b/phy/plot/tests/test_interact.py index adf36dfbf..8a43dbc4b 100644 --- a/phy/plot/tests/test_interact.py +++ b/phy/plot/tests/test_interact.py @@ -14,6 +14,7 @@ from ..base import BaseVisual from ..interact import Grid, Boxed +from ..panzoom import PanZoom #------------------------------------------------------------------------------ @@ -53,7 +54,9 @@ def set_data(self): def _create_visual(qtbot, canvas, interact, box_index): c = canvas + # Attach the interact *and* PanZoom. The order matters! interact.attach(c) + PanZoom().attach(c) visual = MyTestVisual() visual.attach(c) From 837a766f69d0ee48d58a94f7a898a25affbbc5da Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 25 Oct 2015 12:57:46 +0100 Subject: [PATCH 0452/1059] Parametrize names of uniform variables in PanZoom --- phy/plot/panzoom.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/phy/plot/panzoom.py b/phy/plot/panzoom.py index a06f485ac..9978ad97b 100644 --- a/phy/plot/panzoom.py +++ b/phy/plot/panzoom.py @@ -54,6 +54,8 @@ def __init__(self, zmin=1e-5, zmax=1e5, xmin=None, xmax=None, ymin=None, ymax=None, + pan_var_name='u_pan', + zoom_var_name='u_zoom', ): """ Initialize the transform. @@ -79,6 +81,9 @@ def __init__(self, """ super(PanZoom, self).__init__() + self.pan_var_name = pan_var_name + self.zoom_var_name = zoom_var_name + self._aspect = aspect self._zmin = zmin @@ -98,16 +103,17 @@ def __init__(self, self._canvas_aspect = np.ones(2) def get_shader_declarations(self): - return 'uniform vec2 u_pan;\nuniform vec2 u_zoom;\n', '' + return ('uniform vec2 {};\n'.format(self.pan_var_name) + + 'uniform vec2 {};\n'.format(self.zoom_var_name)), '' def get_transforms(self): - return [Translate(translate='u_pan'), - Scale(scale='u_zoom')] + return [Translate(translate=self.pan_var_name), + Scale(scale=self.zoom_var_name)] def update_program(self, program): zoom = self._zoom_aspect() - program['u_pan'] = self._pan - program['u_zoom'] = zoom + program[self.pan_var_name] = self._pan + program[self.zoom_var_name] = zoom # Various properties # ------------------------------------------------------------------------- From eff7667777b181a6f9123a9562a0696580840d5a Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 25 Oct 2015 13:13:42 +0100 Subject: [PATCH 0453/1059] Add Stacked interact --- phy/plot/interact.py | 71 ++++++++++++++++++++++++++------- phy/plot/panzoom.py | 30 ++++---------- phy/plot/tests/test_interact.py | 18 +++++++-- 3 files changed, 79 insertions(+), 40 deletions(-) diff --git a/phy/plot/interact.py b/phy/plot/interact.py index 8a2aba8be..7a1ac5e83 100644 --- a/phy/plot/interact.py +++ b/phy/plot/interact.py @@ -24,21 +24,32 @@ class Grid(BaseInteract): """Grid interact. - NOTE: to be used in a grid, a visual must define `a_box_index`. + NOTE: to be used in a grid, a visual must define `a_box_index` + (by default) or another GLSL variable specified in `box_var`. + + Parameters + ---------- + + n_rows : int + Number of rows in the grid. + n_cols : int + Number of cols in the grid. + box_var : str + Name of the GLSL variable with the box index. """ - def __init__(self, shape, box_var=None): + def __init__(self, n_rows, n_cols, box_var=None): super(Grid, self).__init__() self._zoom = 1. # Name of the variable with the box index. self.box_var = box_var or 'a_box_index' - self.shape = shape - assert len(shape) == 2 - assert shape[0] >= 1 - assert shape[1] >= 1 + self.shape = (n_rows, n_cols) + assert len(self.shape) == 2 + assert self.shape[0] >= 1 + assert self.shape[1] >= 1 def get_shader_declarations(self): return ('attribute vec2 a_box_index;\n' @@ -98,7 +109,17 @@ def on_key_press(self, event): class Boxed(BaseInteract): """Boxed interact. - NOTE: to be used in a boxed, a visual must define `a_box_index`. + NOTE: to be used in a boxed, a visual must define `a_box_index` + (by default) or another GLSL variable specified in `box_var`. + + Parameters + ---------- + + box_bounds : array-like + A (n, 4) array where each row contains the `(xmin, ymin, xmax, ymax)` + bounds of every box, in normalized device coordinates. + box_var : str + Name of the GLSL variable with the box index. """ def __init__(self, box_bounds, box_var=None): @@ -138,16 +159,36 @@ def update_program(self, program): program['n_boxes'] = self.n_boxes -class Stacked(BaseInteract): +class Stacked(Boxed): """Stacked interact. - NOTE: to be used in a stacked, a visual must define `a_box_index`. + NOTE: to be used in a stacked, a visual must define `a_box_index` + (by default) or another GLSL variable specified in `box_var`. + + Parameters + ---------- + + n_boxes : int + Number of boxes to stack vertically. + margin : int (0 by default) + The margin between the stacked subplots. Can be negative. Must be + between -1 and 1. The unit is relative to each box's size. + box_var : str + Name of the GLSL variable with the box index. """ + def __init__(self, n_boxes, margin=0, box_var=None): + + # The margin must be in [-1, 1] + margin = np.clip(margin, -1, 1) + # Normalize the margin. + margin = 2. * margin / float(n_boxes) + + # Signal bounds. + b = np.zeros((n_boxes, 4)) + b[:, 0] = -1 + b[:, 1] = np.linspace(-1, 1 - 2. / n_boxes + margin, n_boxes) + b[:, 2] = 1 + b[:, 3] = np.linspace(-1 + 2. / n_boxes - margin, 1., n_boxes) - # # Signal bounds. - # b = np.zeros((n_signals, 4)) - # b[:, 0] = -1 - # b[:, 1] = np.linspace(-1, 1 - 2. / n_signals, n_signals) - # b[:, 2] = 1 - # b[:, 3] = np.linspace(-1 + 2. / n_signals, 1., n_signals) + super(Stacked, self).__init__(b, box_var=box_var) diff --git a/phy/plot/panzoom.py b/phy/plot/panzoom.py index 9978ad97b..e863947af 100644 --- a/phy/plot/panzoom.py +++ b/phy/plot/panzoom.py @@ -54,33 +54,19 @@ def __init__(self, zmin=1e-5, zmax=1e5, xmin=None, xmax=None, ymin=None, ymax=None, + constrain_bounds=None, pan_var_name='u_pan', zoom_var_name='u_zoom', ): - """ - Initialize the transform. - - Parameters - ---------- - - aspect : float (default is None) - Indicate what is the aspect ratio of the object displayed. This is - necessary to convert pixel drag move in object space coordinates. - - pan : float, float (default is 0, 0) - Initial translation - - zoom : float, float (default is 1) - Initial zoom level - - zmin : float (default is 0.01) - Minimum zoom level - - zmax : float (default is 1000) - Maximum zoom level - """ super(PanZoom, self).__init__() + if constrain_bounds: + assert xmin is None + assert ymin is None + assert xmax is None + assert ymax is None + xmin, ymin, xmax, ymax = constrain_bounds + self.pan_var_name = pan_var_name self.zoom_var_name = zoom_var_name diff --git a/phy/plot/tests/test_interact.py b/phy/plot/tests/test_interact.py index 8a43dbc4b..bacb03c03 100644 --- a/phy/plot/tests/test_interact.py +++ b/phy/plot/tests/test_interact.py @@ -13,8 +13,9 @@ from vispy.util import keys from ..base import BaseVisual -from ..interact import Grid, Boxed +from ..interact import Grid, Boxed, Stacked from ..panzoom import PanZoom +from ..transform import NDC #------------------------------------------------------------------------------ @@ -56,7 +57,7 @@ def _create_visual(qtbot, canvas, interact, box_index): # Attach the interact *and* PanZoom. The order matters! interact.attach(c) - PanZoom().attach(c) + PanZoom(constrain_bounds=NDC).attach(c) visual = MyTestVisual() visual.attach(c) @@ -80,7 +81,7 @@ def test_grid_1(qtbot, canvas): box_index = [[i, j] for i, j in product(range(2), range(3))] box_index = np.repeat(box_index, n, axis=0) - grid = Grid(shape=(2, 3)) + grid = Grid(2, 3) _create_visual(qtbot, canvas, grid, box_index) # Zoom with the keyboard. @@ -121,3 +122,14 @@ def test_boxed_1(qtbot, canvas): _create_visual(qtbot, canvas, boxed, box_index) # qtbot.stop() + + +def test_stacked_1(qtbot, canvas): + + n = 1000 + box_index = np.repeat(np.arange(6), n, axis=0) + + stacked = Stacked(n_boxes=6, margin=-10) + _create_visual(qtbot, canvas, stacked, box_index) + + # qtbot.stop() From cfb8ebacc87be43d09bd3380c7a4533550a6e8c5 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 25 Oct 2015 13:24:50 +0100 Subject: [PATCH 0454/1059] Allow arbitrary x coordinates in PlotVisual --- phy/plot/tests/test_visuals.py | 19 ++++++++++--------- phy/plot/visuals.py | 32 +++++++++++++++++++++----------- 2 files changed, 31 insertions(+), 20 deletions(-) diff --git a/phy/plot/tests/test_visuals.py b/phy/plot/tests/test_visuals.py index c77e8daf3..99a5586d4 100644 --- a/phy/plot/tests/test_visuals.py +++ b/phy/plot/tests/test_visuals.py @@ -10,7 +10,8 @@ import numpy as np from ..visuals import (ScatterVisual, PlotVisual, HistogramVisual, - BoxVisual, AxesVisual,) + BoxVisual, AxesVisual, + ) #------------------------------------------------------------------------------ @@ -74,34 +75,34 @@ def test_scatter_custom(qtbot, canvas_pz): #------------------------------------------------------------------------------ def test_plot_empty(qtbot, canvas): - data = np.zeros((1, 0)) + y = np.zeros((1, 0)) _test_visual(qtbot, canvas, PlotVisual(), - data=data) + y=y) def test_plot_0(qtbot, canvas_pz): - data = np.zeros((1, 10)) + y = np.zeros((1, 10)) _test_visual(qtbot, canvas_pz, PlotVisual(), - data=data) + y=y) def test_plot_1(qtbot, canvas_pz): - data = .2 * np.random.randn(1, 10) + y = .2 * np.random.randn(1, 10) _test_visual(qtbot, canvas_pz, PlotVisual(), - data=data) + y=y) def test_plot_2(qtbot, canvas_pz): n_signals = 50 - data = 20 * np.random.randn(n_signals, 10) + y = 20 * np.random.randn(n_signals, 10) # Signal colors. c = np.random.uniform(.5, 1, size=(n_signals, 4)) c[:, 3] = .5 _test_visual(qtbot, canvas_pz, PlotVisual(), - data=data, data_bounds=[-50, 50], + y=y, data_bounds=[-50, 50], signal_colors=c) diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index e0da2445d..b2cd5f4bf 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -115,6 +115,10 @@ def _get_color(color, default): return color +def _get_linear_x(n_signals, n_samples): + return np.tile(np.linspace(-1., 1., n_samples), (n_signals, 1)) + + #------------------------------------------------------------------------------ # Visuals #------------------------------------------------------------------------------ @@ -203,27 +207,33 @@ def get_transforms(self): ] def set_data(self, - data=None, + x=None, + y=None, depth=None, data_bounds=None, signal_colors=None, ): - pos = _check_pos_2D(data) - n_signals, n_samples = data.shape - n = n_signals * n_samples - # Generate the x coordinates. - x = np.linspace(-1., 1., n_samples) - x = np.tile(x, n_signals) - assert x.shape == (n,) + # Default x coordinates. + if x is None: + assert y is not None + x = _get_linear_x(*y.shape) + + assert x is not None + assert y is not None + assert x.ndim == 2 + assert x.shape == y.shape + n_signals, n_samples = x.shape + n = n_signals * n_samples # Generate the (n, 2) pos array. pos = np.empty((n, 2), dtype=np.float32) - pos[:, 0] = x - pos[:, 1] = data.ravel() + pos[:, 0] = x.ravel() + pos[:, 1] = y.ravel() + pos = _check_pos_2D(pos) # Generate the complete data_bounds 4-tuple from the specified 2-tuple. - self.data_bounds = _get_data_bounds_1D(data_bounds, data) + self.data_bounds = _get_data_bounds_1D(data_bounds, y) # Set the transformed position. pos_tr = self.apply_cpu_transforms(pos) From 18327d24dd8733b4c5bf3c63e8067624e88c33e8 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 25 Oct 2015 14:43:24 +0100 Subject: [PATCH 0455/1059] Minor updates in visuals --- phy/plot/visuals.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index b2cd5f4bf..1d02b08b2 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -53,7 +53,7 @@ def _get_data_bounds_1D(data_bounds, data): def _check_pos_2D(pos): """Check position data before GPU uploading.""" assert pos is not None - pos = np.asarray(pos) + pos = np.asarray(pos, dtype=np.float32) assert pos.ndim == 2 return pos @@ -126,7 +126,10 @@ def _get_linear_x(n_signals, n_samples): class ScatterVisual(BaseVisual): shader_name = 'scatter' gl_primitive_type = 'points' + _default_marker_size = 10. + _default_marker_type = 'disc' + _default_color = (1, 1, 1, 1) _supported_marker_types = ( 'arrow', 'asterisk', @@ -156,7 +159,7 @@ def __init__(self, marker_type=None): self.n_points = None # Set the marker type. - self.marker_type = marker_type or 'disc' + self.marker_type = marker_type or self._default_marker_type assert self.marker_type in self._supported_marker_types # Enable transparency. @@ -189,7 +192,7 @@ def set_data(self, pos_tr = self.apply_cpu_transforms(pos) self.program['a_position'] = _get_pos_depth(pos_tr, depth) self.program['a_size'] = _get_attr(size, self._default_marker_size, n) - self.program['a_color'] = _get_attr(colors, (1, 1, 1, 1), n) + self.program['a_color'] = _get_attr(colors, self._default_color, n) class PlotVisual(BaseVisual): From f6fd0ab923f814321ad1db2d12e4ddad81e129ff Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 25 Oct 2015 15:00:47 +0100 Subject: [PATCH 0456/1059] BaseCanvas is interactive --- phy/plot/base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/phy/plot/base.py b/phy/plot/base.py index 2f598d3a8..4b6c92a96 100644 --- a/phy/plot/base.py +++ b/phy/plot/base.py @@ -211,6 +211,7 @@ def update(self): class BaseCanvas(Canvas): """A blank VisPy canvas with a custom event system that keeps the order.""" def __init__(self, *args, **kwargs): + kwargs['keys'] = 'interactive' super(BaseCanvas, self).__init__(*args, **kwargs) self._events = EventEmitter() self.interacts = [] From 3b6c4f55be3172b22a290c616d6705ba1756a7b7 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 25 Oct 2015 15:02:21 +0100 Subject: [PATCH 0457/1059] WIP: start plot module --- phy/plot/plot.py | 185 ++++++++++++++++++++++++++++++++++++ phy/plot/tests/test_plot.py | 40 ++++++++ 2 files changed, 225 insertions(+) create mode 100644 phy/plot/plot.py create mode 100644 phy/plot/tests/test_plot.py diff --git a/phy/plot/plot.py b/phy/plot/plot.py new file mode 100644 index 000000000..71a92ac9c --- /dev/null +++ b/phy/plot/plot.py @@ -0,0 +1,185 @@ +# -*- coding: utf-8 -*- + +"""Plotting interface.""" + + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +from itertools import groupby +from collections import defaultdict + +import numpy as np + +from .base import BaseCanvas +from .interact import Grid, Boxed, Stacked +from .visuals import ScatterVisual, PlotVisual, HistogramVisual + + +#------------------------------------------------------------------------------ +# Utils +#------------------------------------------------------------------------------ + +class Accumulator(object): + def __init__(self): + self._size = defaultdict(int) + self._data = defaultdict(list) + + @property + def size(self): + return self._size[list(self._size.keys())[0]] + + @property + def data(self): + return {name: self[name] for name in self._data} + + def __setitem__(self, name, val): + self._size[name] += len(val) + self._data[name].append(val) + + def __getitem__(self, name): + size = self.size + assert all(s == size for s in self._size.values()) + return np.vstack(self._data[name]).astype(np.float32) + + +#------------------------------------------------------------------------------ +# Base plotting interface +#------------------------------------------------------------------------------ + +class SubView(object): + def __init__(self, idx): + self.spec = {'idx': idx} + + @property + def visual_class(self): + return self.spec.get('visual_class', None) + + def _set(self, visual_class, loc): + self.spec['visual_class'] = visual_class + self.spec.update(loc) + + def __getattr__(self, name): + return self.spec[name] + + def scatter(self, x, y, color=None, size=None, marker=None): + # Validate x and y. + assert x.ndim == y.ndim == 1 + assert x.shape == y.shape + n = x.shape[0] + # Default color. + if color is None: + color = np.ones((n, 4), dtype=np.float32) + # Default size. + if size is None: + ds = ScatterVisual._default_marker_size + size = ds * np.ones((n, 1), dtype=np.float32) + # Default marker. + if marker is None: + marker = ScatterVisual._default_marker_type + # Set the spec. + loc = dict(x=x, y=y, color=color, size=size, marker=marker) + return self._set(ScatterVisual, loc) + + def plot(self, x, y, color=None): + loc = locals() + return self._set(PlotVisual, loc) + + def hist(self, hist, color=None): + loc = locals() + return self._set(HistogramVisual, loc) + + def __repr__(self): + return str(self.spec) + + +class BaseView(BaseCanvas): + def __init__(self, interacts): + super(BaseView, self).__init__() + # Attach the passed interacts to the current canvas. + for interact in interacts: + interact.attach(self) + self.subviews = {} + + # To override + # ------------------------------------------------------------------------- + + def get_box_ndim(self): + raise NotImplementedError() + + def iter_index(self): + raise NotImplementedError() + + # Internal methods + # ------------------------------------------------------------------------- + + def iter_subviews(self): + for idx in self.iter_index(): + sv = self.subviews.get(idx, None) + if sv: + yield sv + + def __getitem__(self, idx): + sv = SubView(idx) + self.subviews[idx] = sv + return sv + + def _build_scatter(self, subviews, marker): + """Build all scatter subviews with the same marker type.""" + + ac = Accumulator() + for sv in subviews: + assert sv.marker == marker + n = len(sv.x) + ac['pos'] = np.c_[sv.x, sv.y] + ac['color'] = sv.color + ac['size'] = sv.size + ac['box_index'] = np.tile(sv.idx, (n, 1)) + + v = ScatterVisual() + v.attach(self) + v.set_data(pos=ac['pos'], colors=ac['color'], size=ac['size']) + v.program['a_box_index'] = ac['box_index'] + + def build(self): + """Build all visuals.""" + for visual_class, subviews in groupby(self.iter_subviews(), + lambda sv: sv.visual_class): + if visual_class == ScatterVisual: + for marker, subviews_scatter in groupby(subviews, + lambda sv: sv.marker): + self._build_scatter(subviews_scatter, marker) + elif visual_class == PlotVisual: + self._build_plot(subviews) + elif visual_class == HistogramVisual: + self._build_histogram(subviews) + + +#------------------------------------------------------------------------------ +# Plotting interface +#------------------------------------------------------------------------------ + +class GridView(BaseView): + def __init__(self, n_rows, n_cols): + self.n_rows, self.n_cols = n_rows, n_cols + interacts = [Grid(n_rows, n_cols)] + super(GridView, self).__init__(interacts) + + def get_box_ndim(self): + return 2 + + def iter_index(self): + for i in range(self.n_rows): + for j in range(self.n_cols): + yield (i, j) + + +class StackedView(BaseView): + def __init__(self, n_plots): + super(StackedView, self).__init__() + + +class BoxedView(BaseView): + def __init__(self, box_positions): + super(BoxedView, self).__init__() diff --git a/phy/plot/tests/test_plot.py b/phy/plot/tests/test_plot.py new file mode 100644 index 000000000..065c4e673 --- /dev/null +++ b/phy/plot/tests/test_plot.py @@ -0,0 +1,40 @@ +# -*- coding: utf-8 -*- + +"""Test plotting interface.""" + + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +import numpy as np + +from ..plot import GridView + + +#------------------------------------------------------------------------------ +# Utils +#------------------------------------------------------------------------------ + +def _show(qtbot, view, stop=False): + view.build() + view.show() + qtbot.waitForWindowShown(view.native) + if stop: # pragma: no cover + qtbot.stop() + view.close() + + +#------------------------------------------------------------------------------ +# Test plotting interface +#------------------------------------------------------------------------------ + +def test_subplot_view(qtbot): + view = GridView(2, 3) + n = 1000 + + x = np.random.randn(n) + y = np.random.randn(n) + view[1, 1].scatter(x, y) + + _show(qtbot, view) From a4dc893a62cefacd53709713e24a30b03404e87e Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 25 Oct 2015 15:08:34 +0100 Subject: [PATCH 0458/1059] Default color in visuals --- phy/plot/visuals.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index 1d02b08b2..943c0ad5c 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -119,6 +119,9 @@ def _get_linear_x(n_signals, n_samples): return np.tile(np.linspace(-1., 1., n_samples), (n_signals, 1)) +DEFAULT_COLOR = (0.03, 0.57, 0.98, .75) + + #------------------------------------------------------------------------------ # Visuals #------------------------------------------------------------------------------ @@ -129,7 +132,7 @@ class ScatterVisual(BaseVisual): _default_marker_size = 10. _default_marker_type = 'disc' - _default_color = (1, 1, 1, 1) + _default_color = DEFAULT_COLOR _supported_marker_types = ( 'arrow', 'asterisk', @@ -198,6 +201,7 @@ def set_data(self, class PlotVisual(BaseVisual): shader_name = 'plot' gl_primitive_type = 'line_strip' + _default_color = DEFAULT_COLOR def __init__(self): super(PlotVisual, self).__init__() @@ -246,7 +250,7 @@ def set_data(self, self.program['a_signal_index'] = _get_index(n_signals, n_samples, n) # Signal colors. - signal_colors = _get_texture(signal_colors, (1,) * 4, + signal_colors = _get_texture(signal_colors, self._default_color, n_signals, [0, 1]) self.program['u_signal_colors'] = Texture2D(signal_colors) @@ -257,6 +261,7 @@ def set_data(self, class HistogramVisual(BaseVisual): shader_name = 'histogram' gl_primitive_type = 'triangles' + _default_color = DEFAULT_COLOR def __init__(self): super(HistogramVisual, self).__init__() @@ -297,7 +302,7 @@ def set_data(self, # Hist colors. self.program['u_hist_colors'] = _get_texture(hist_colors, - (1, 1, 1, 1), + self._default_color, n_hists, [0, 1]) # Hist bounds. From b37b91354e55d80063c445b6b80d63902f9c8af1 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 25 Oct 2015 15:18:59 +0100 Subject: [PATCH 0459/1059] Bug fixes --- phy/plot/glsl/scatter.frag | 4 ++-- phy/plot/plot.py | 23 +++++++++++++++++------ phy/plot/tests/test_plot.py | 9 ++++++++- phy/plot/tests/test_visuals.py | 2 +- phy/plot/visuals.py | 16 ++++++++-------- 5 files changed, 36 insertions(+), 18 deletions(-) diff --git a/phy/plot/glsl/scatter.frag b/phy/plot/glsl/scatter.frag index fc092a7e9..9a6782b4c 100644 --- a/phy/plot/glsl/scatter.frag +++ b/phy/plot/glsl/scatter.frag @@ -1,5 +1,5 @@ #include "antialias/filled.glsl" -#include "markers/%MARKER_TYPE.glsl" +#include "markers/%MARKER.glsl" varying vec4 v_color; varying float v_size; @@ -8,6 +8,6 @@ void main() { vec2 P = gl_PointCoord.xy - vec2(0.5,0.5); float point_size = v_size + 2. * (1.0 + 1.5*1.0); - float distance = marker_%MARKER_TYPE(P*point_size, v_size); + float distance = marker_%MARKER(P*point_size, v_size); gl_FragColor = filled(distance, 1.0, 1.0, v_color); } diff --git a/phy/plot/plot.py b/phy/plot/plot.py index 71a92ac9c..f0cd1eb5f 100644 --- a/phy/plot/plot.py +++ b/phy/plot/plot.py @@ -13,7 +13,7 @@ import numpy as np from .base import BaseCanvas -from .interact import Grid, Boxed, Stacked +from .interact import Grid # Boxed, Stacked from .visuals import ScatterVisual, PlotVisual, HistogramVisual @@ -70,14 +70,25 @@ def scatter(self, x, y, color=None, size=None, marker=None): n = x.shape[0] # Default color. if color is None: - color = np.ones((n, 4), dtype=np.float32) + color = ScatterVisual._default_color + color = np.asarray(color) + if color.ndim == 1: + color = np.tile(color, (n, 1)).astype(np.float32) + assert color.shape == (n, 4) # Default size. if size is None: - ds = ScatterVisual._default_marker_size - size = ds * np.ones((n, 1), dtype=np.float32) + size = ScatterVisual._default_marker_size + if not hasattr(size, '__len__'): + size = [size] + size = np.asarray(size) + if size.size == 1: + size = np.tile(size, (n, 1)).astype(np.float32) + if size.ndim == 1: + size = size[:, np.newaxis] + assert size.shape == (n, 1) # Default marker. if marker is None: - marker = ScatterVisual._default_marker_type + marker = ScatterVisual._default_marker # Set the spec. loc = dict(x=x, y=y, color=color, size=size, marker=marker) return self._set(ScatterVisual, loc) @@ -137,7 +148,7 @@ def _build_scatter(self, subviews, marker): ac['size'] = sv.size ac['box_index'] = np.tile(sv.idx, (n, 1)) - v = ScatterVisual() + v = ScatterVisual(marker=marker) v.attach(self) v.set_data(pos=ac['pos'], colors=ac['color'], size=ac['size']) v.program['a_box_index'] = ac['box_index'] diff --git a/phy/plot/tests/test_plot.py b/phy/plot/tests/test_plot.py index 065c4e673..c4d913942 100644 --- a/phy/plot/tests/test_plot.py +++ b/phy/plot/tests/test_plot.py @@ -35,6 +35,13 @@ def test_subplot_view(qtbot): x = np.random.randn(n) y = np.random.randn(n) - view[1, 1].scatter(x, y) + + view[0, 1].scatter(x, y) + view[0, 2].scatter(x, y, color=np.random.uniform(.5, .8, size=(n, 4))) + + view[1, 0].scatter(x, y, size=np.random.uniform(5, 20, size=n)) + view[1, 1] + view[1, 2].scatter(x[::5], y[::5], marker='heart', + color=(1, 0, 0, .25), size=20) _show(qtbot, view) diff --git a/phy/plot/tests/test_visuals.py b/phy/plot/tests/test_visuals.py index 99a5586d4..1847b5213 100644 --- a/phy/plot/tests/test_visuals.py +++ b/phy/plot/tests/test_visuals.py @@ -42,7 +42,7 @@ def test_scatter_markers(qtbot, canvas_pz): n = 100 pos = .2 * np.random.randn(n, 2) - v = ScatterVisual(marker_type='vbar') + v = ScatterVisual(marker='vbar') v.attach(c) v.set_data(pos=pos) diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index 943c0ad5c..14417548d 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -119,7 +119,7 @@ def _get_linear_x(n_signals, n_samples): return np.tile(np.linspace(-1., 1., n_samples), (n_signals, 1)) -DEFAULT_COLOR = (0.03, 0.57, 0.98, .75) +DEFAULT_COLOR = (0.03, 0.57, 0.98, .75) #------------------------------------------------------------------------------ @@ -131,9 +131,9 @@ class ScatterVisual(BaseVisual): gl_primitive_type = 'points' _default_marker_size = 10. - _default_marker_type = 'disc' + _default_marker = 'disc' _default_color = DEFAULT_COLOR - _supported_marker_types = ( + _supported_markers = ( 'arrow', 'asterisk', 'chevron', @@ -155,15 +155,15 @@ class ScatterVisual(BaseVisual): 'vbar', ) - def __init__(self, marker_type=None): + def __init__(self, marker=None): super(ScatterVisual, self).__init__() # Default bounds. self.data_bounds = NDC self.n_points = None # Set the marker type. - self.marker_type = marker_type or self._default_marker_type - assert self.marker_type in self._supported_marker_types + self.marker = marker or self._default_marker + assert self.marker in self._supported_markers # Enable transparency. _enable_depth_mask() @@ -171,7 +171,7 @@ def __init__(self, marker_type=None): def get_shaders(self): v, f = super(ScatterVisual, self).get_shaders() # Replace the marker type in the shader. - f = f.replace('%MARKER_TYPE', self.marker_type) + f = f.replace('%MARKER', self.marker) return v, f def get_transforms(self): @@ -181,7 +181,7 @@ def set_data(self, pos=None, depth=None, colors=None, - marker_type=None, + marker=None, size=None, data_bounds=None, ): From cf15b0679be38ddfd79ab257f0f67d4370333de9 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 25 Oct 2015 15:58:08 +0100 Subject: [PATCH 0460/1059] WIP: refactor array utilities in plot --- phy/plot/plot.py | 29 +++++-------------- phy/plot/tests/test_visuals.py | 2 +- phy/plot/visuals.py | 52 ++++++++++++++-------------------- 3 files changed, 29 insertions(+), 54 deletions(-) diff --git a/phy/plot/plot.py b/phy/plot/plot.py index f0cd1eb5f..be609362b 100644 --- a/phy/plot/plot.py +++ b/phy/plot/plot.py @@ -14,7 +14,7 @@ from .base import BaseCanvas from .interact import Grid # Boxed, Stacked -from .visuals import ScatterVisual, PlotVisual, HistogramVisual +from .visuals import _get_array, ScatterVisual, PlotVisual, HistogramVisual #------------------------------------------------------------------------------ @@ -22,6 +22,7 @@ #------------------------------------------------------------------------------ class Accumulator(object): + """Accumulate arrays for concatenation.""" def __init__(self): self._size = defaultdict(int) self._data = defaultdict(list) @@ -68,27 +69,11 @@ def scatter(self, x, y, color=None, size=None, marker=None): assert x.ndim == y.ndim == 1 assert x.shape == y.shape n = x.shape[0] - # Default color. - if color is None: - color = ScatterVisual._default_color - color = np.asarray(color) - if color.ndim == 1: - color = np.tile(color, (n, 1)).astype(np.float32) - assert color.shape == (n, 4) - # Default size. - if size is None: - size = ScatterVisual._default_marker_size - if not hasattr(size, '__len__'): - size = [size] - size = np.asarray(size) - if size.size == 1: - size = np.tile(size, (n, 1)).astype(np.float32) - if size.ndim == 1: - size = size[:, np.newaxis] - assert size.shape == (n, 1) + # Set the color and size. + color = _get_array(color, (n, 4), ScatterVisual._default_color) + size = _get_array(size, (n, 1), ScatterVisual._default_marker_size) # Default marker. - if marker is None: - marker = ScatterVisual._default_marker + marker = marker or ScatterVisual._default_marker # Set the spec. loc = dict(x=x, y=y, color=color, size=size, marker=marker) return self._set(ScatterVisual, loc) @@ -150,7 +135,7 @@ def _build_scatter(self, subviews, marker): v = ScatterVisual(marker=marker) v.attach(self) - v.set_data(pos=ac['pos'], colors=ac['color'], size=ac['size']) + v.set_data(pos=ac['pos'], color=ac['color'], size=ac['size']) v.program['a_box_index'] = ac['box_index'] def build(self): diff --git a/phy/plot/tests/test_visuals.py b/phy/plot/tests/test_visuals.py index 1847b5213..980457e12 100644 --- a/phy/plot/tests/test_visuals.py +++ b/phy/plot/tests/test_visuals.py @@ -67,7 +67,7 @@ def test_scatter_custom(qtbot, canvas_pz): s = 5 + 20 * np.random.rand(n) _test_visual(qtbot, canvas_pz, ScatterVisual(), - pos=pos, colors=c, size=s) + pos=pos, color=c, size=s) #------------------------------------------------------------------------------ diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index 14417548d..346d72507 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -50,6 +50,19 @@ def _get_data_bounds_1D(data_bounds, data): return data_bounds +def _get_array(val, shape, default=None): + """Ensure an object is an array with the specified shape.""" + assert val is not None or default is not None + out = np.zeros(shape, dtype=np.float32) + # This solves `ValueError: could not broadcast input array from shape (n) + # into shape (n, 1)`. + if val is not None and isinstance(val, np.ndarray) and out.ndim > val.ndim: + val = val[:, np.newaxis] + out[...] = val if val is not None else default + assert out.shape == shape + return out + + def _check_pos_2D(pos): """Check position data before GPU uploading.""" assert pos is not None @@ -61,21 +74,9 @@ def _check_pos_2D(pos): def _get_pos_depth(pos_tr, depth): """Prepare a (N, 3) position-depth array for GPU uploading.""" n = pos_tr.shape[0] - pos_tr = np.asarray(pos_tr, dtype=np.float32) - assert pos_tr.shape == (n, 2) - - # Set the depth. - if depth is None: - depth = np.zeros(n, dtype=np.float32) - depth = np.asarray(depth, dtype=np.float32) - assert depth.shape == (n,) - - # Set the a_position attribute. - pos_depth = np.empty((n, 3), dtype=np.float32) - pos_depth[:, :2] = pos_tr - pos_depth[:, 2] = depth - - return pos_depth + pos_tr = _get_array(pos_tr, (n, 2)) + depth = _get_array(depth, (n, 1), 0) + return np.c_[pos_tr, depth] def _get_hist_max(hist): @@ -86,19 +87,6 @@ def _get_hist_max(hist): return hist_max -def _get_attr(attr, default, n): - """Prepare an attribute for GPU uploading.""" - if not hasattr(default, '__len__'): - default = [default] - if attr is None: - attr = np.tile(default, (n, 1)) - attr = np.asarray(attr, dtype=np.float32) - if attr.ndim == 1: - attr = attr[:, np.newaxis] - assert attr.shape == (n, len(default)) - return attr - - def _get_index(n_items, item_size, n): """Prepare an index attribute for GPU uploading.""" index = np.arange(n_items) @@ -180,7 +168,7 @@ def get_transforms(self): def set_data(self, pos=None, depth=None, - colors=None, + color=None, marker=None, size=None, data_bounds=None, @@ -194,8 +182,10 @@ def set_data(self, pos_tr = self.apply_cpu_transforms(pos) self.program['a_position'] = _get_pos_depth(pos_tr, depth) - self.program['a_size'] = _get_attr(size, self._default_marker_size, n) - self.program['a_color'] = _get_attr(colors, self._default_color, n) + self.program['a_size'] = _get_array(size, (n, 1), + self._default_marker_size) + self.program['a_color'] = _get_array(color, (n, 4), + self._default_color) class PlotVisual(BaseVisual): From 1ef043930eaab53bdd513d54dd38d78cf1f50412 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 25 Oct 2015 16:05:43 +0100 Subject: [PATCH 0461/1059] Change grid interact keyboard shortcuts --- phy/plot/interact.py | 4 +--- phy/plot/tests/test_interact.py | 14 +++++++------- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/phy/plot/interact.py b/phy/plot/interact.py index 7a1ac5e83..0b3b9c9f6 100644 --- a/phy/plot/interact.py +++ b/phy/plot/interact.py @@ -86,12 +86,10 @@ def zoom(self, value): def on_key_press(self, event): """Pan and zoom with the keyboard.""" super(Grid, self).on_key_press(event) - if event.modifiers: - return key = event.key # Zoom. - if key in ('-', '+'): + if key in ('-', '+') and event.modifiers == ('Control',): k = .05 if key == '+' else -.05 self.zoom *= math.exp(1.5 * k) self.update() diff --git a/phy/plot/tests/test_interact.py b/phy/plot/tests/test_interact.py index bacb03c03..f4f680bae 100644 --- a/phy/plot/tests/test_interact.py +++ b/phy/plot/tests/test_interact.py @@ -57,7 +57,7 @@ def _create_visual(qtbot, canvas, interact, box_index): # Attach the interact *and* PanZoom. The order matters! interact.attach(c) - PanZoom(constrain_bounds=NDC).attach(c) + PanZoom(aspect=None, constrain_bounds=NDC).attach(c) visual = MyTestVisual() visual.attach(c) @@ -84,22 +84,22 @@ def test_grid_1(qtbot, canvas): grid = Grid(2, 3) _create_visual(qtbot, canvas, grid, box_index) - # Zoom with the keyboard. + # No effect without modifiers. c.events.key_press(key=keys.Key('+')) + assert grid.zoom == 1. + + # Zoom with the keyboard. + c.events.key_press(key=keys.Key('+'), modifiers=(keys.CONTROL,)) assert grid.zoom > 1 # Unzoom with the keyboard. - c.events.key_press(key=keys.Key('-')) + c.events.key_press(key=keys.Key('-'), modifiers=(keys.CONTROL,)) assert grid.zoom == 1. # Set the zoom explicitly. grid.zoom = 2 assert grid.zoom == 2. - # No effect with modifiers. - c.events.key_press(key=keys.Key('r'), modifiers=(keys.CONTROL,)) - assert grid.zoom == 2. - # Press 'R'. c.events.key_press(key=keys.Key('r')) assert grid.zoom == 1. From 784d903b9125f1926efeae7871c80240f608de62 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 25 Oct 2015 16:21:51 +0100 Subject: [PATCH 0462/1059] WIP: plot function --- phy/plot/glsl/plot.vert | 4 ++-- phy/plot/plot.py | 37 +++++++++++++++++++++++++++++----- phy/plot/tests/test_plot.py | 16 ++++++++++++++- phy/plot/tests/test_visuals.py | 2 +- phy/plot/visuals.py | 8 ++++---- 5 files changed, 54 insertions(+), 13 deletions(-) diff --git a/phy/plot/glsl/plot.vert b/phy/plot/glsl/plot.vert index 3992106d5..a7a0dbf00 100644 --- a/phy/plot/glsl/plot.vert +++ b/phy/plot/glsl/plot.vert @@ -3,7 +3,7 @@ attribute vec3 a_position; attribute float a_signal_index; // 0..n_signals-1 -uniform sampler2D u_signal_colors; +uniform sampler2D u_plot_colors; uniform float n_signals; varying vec4 v_color; @@ -14,6 +14,6 @@ void main() { gl_Position = transform(xy); gl_Position.z = a_position.z; - v_color = fetch_texture(a_signal_index, u_signal_colors, n_signals); + v_color = fetch_texture(a_signal_index, u_plot_colors, n_signals); v_signal_index = a_signal_index; } diff --git a/phy/plot/plot.py b/phy/plot/plot.py index be609362b..c6fd5031c 100644 --- a/phy/plot/plot.py +++ b/phy/plot/plot.py @@ -14,6 +14,8 @@ from .base import BaseCanvas from .interact import Grid # Boxed, Stacked +from .panzoom import PanZoom +from .transform import NDC from .visuals import _get_array, ScatterVisual, PlotVisual, HistogramVisual @@ -40,8 +42,8 @@ def __setitem__(self, name, val): self._data[name].append(val) def __getitem__(self, name): - size = self.size - assert all(s == size for s in self._size.values()) + # size = self.size + # assert all(s == size for s in self._size.values()) return np.vstack(self._data[name]).astype(np.float32) @@ -79,7 +81,15 @@ def scatter(self, x, y, color=None, size=None, marker=None): return self._set(ScatterVisual, loc) def plot(self, x, y, color=None): - loc = locals() + x = np.atleast_2d(x) + y = np.atleast_2d(y) + # Validate x and y. + assert x.ndim == y.ndim == 2 + assert x.shape == y.shape + n_plots, n_samples = x.shape + color = _get_array(color, (n_plots, 4), PlotVisual._default_color) + # Set the spec. + loc = dict(x=x, y=y, color=color) return self._set(PlotVisual, loc) def hist(self, hist, color=None): @@ -87,7 +97,7 @@ def hist(self, hist, color=None): return self._set(HistogramVisual, loc) def __repr__(self): - return str(self.spec) + return str(self.spec) # pragma: no cover class BaseView(BaseCanvas): @@ -138,6 +148,22 @@ def _build_scatter(self, subviews, marker): v.set_data(pos=ac['pos'], color=ac['color'], size=ac['size']) v.program['a_box_index'] = ac['box_index'] + def _build_plot(self, subviews): + """Build all plot subviews.""" + + ac = Accumulator() + for sv in subviews: + n = sv.x.size + ac['x'] = sv.x + ac['y'] = sv.y + ac['plot_colors'] = sv.color + ac['box_index'] = np.tile(sv.idx, (n, 1)) + + v = PlotVisual() + v.attach(self) + v.set_data(x=ac['x'], y=ac['y'], plot_colors=ac['plot_colors']) + v.program['a_box_index'] = ac['box_index'] + def build(self): """Build all visuals.""" for visual_class, subviews in groupby(self.iter_subviews(), @@ -159,7 +185,8 @@ def build(self): class GridView(BaseView): def __init__(self, n_rows, n_cols): self.n_rows, self.n_cols = n_rows, n_cols - interacts = [Grid(n_rows, n_cols)] + pz = PanZoom(aspect=None, constrain_bounds=NDC) + interacts = [Grid(n_rows, n_cols), pz] super(GridView, self).__init__(interacts) def get_box_ndim(self): diff --git a/phy/plot/tests/test_plot.py b/phy/plot/tests/test_plot.py index c4d913942..39853d5ac 100644 --- a/phy/plot/tests/test_plot.py +++ b/phy/plot/tests/test_plot.py @@ -10,6 +10,7 @@ import numpy as np from ..plot import GridView +from ..visuals import _get_linear_x #------------------------------------------------------------------------------ @@ -29,7 +30,7 @@ def _show(qtbot, view, stop=False): # Test plotting interface #------------------------------------------------------------------------------ -def test_subplot_view(qtbot): +def test_grid_scatter(qtbot): view = GridView(2, 3) n = 1000 @@ -45,3 +46,16 @@ def test_subplot_view(qtbot): color=(1, 0, 0, .25), size=20) _show(qtbot, view) + + +def test_grid_plot(qtbot): + view = GridView(1, 2) + n_plots, n_samples = 10, 50 + + x = _get_linear_x(n_plots, n_samples) + y = np.random.randn(n_plots, n_samples) + + view[0, 0].plot(x, y) + view[0, 1].plot(x, y, color=np.random.uniform(.5, .8, size=(n_plots, 4))) + + _show(qtbot, view) diff --git a/phy/plot/tests/test_visuals.py b/phy/plot/tests/test_visuals.py index 980457e12..2e2456618 100644 --- a/phy/plot/tests/test_visuals.py +++ b/phy/plot/tests/test_visuals.py @@ -103,7 +103,7 @@ def test_plot_2(qtbot, canvas_pz): _test_visual(qtbot, canvas_pz, PlotVisual(), y=y, data_bounds=[-50, 50], - signal_colors=c) + plot_colors=c) #------------------------------------------------------------------------------ diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index 346d72507..f9aae2986 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -208,7 +208,7 @@ def set_data(self, y=None, depth=None, data_bounds=None, - signal_colors=None, + plot_colors=None, ): # Default x coordinates. @@ -240,9 +240,9 @@ def set_data(self, self.program['a_signal_index'] = _get_index(n_signals, n_samples, n) # Signal colors. - signal_colors = _get_texture(signal_colors, self._default_color, - n_signals, [0, 1]) - self.program['u_signal_colors'] = Texture2D(signal_colors) + plot_colors = _get_texture(plot_colors, self._default_color, + n_signals, [0, 1]) + self.program['u_plot_colors'] = Texture2D(plot_colors) # Number of signals. self.program['n_signals'] = n_signals From da1b8d90be44e4eecfe23369b73ad9d956d9e0e4 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 25 Oct 2015 16:47:05 +0100 Subject: [PATCH 0463/1059] WIP: boxed and stacked views --- phy/plot/interact.py | 1 + phy/plot/plot.py | 74 +++++++++++++++++++++++++++++-------- phy/plot/tests/test_plot.py | 41 +++++++++++++++++++- phy/plot/visuals.py | 5 ++- 4 files changed, 103 insertions(+), 18 deletions(-) diff --git a/phy/plot/interact.py b/phy/plot/interact.py index 0b3b9c9f6..73d351504 100644 --- a/phy/plot/interact.py +++ b/phy/plot/interact.py @@ -188,5 +188,6 @@ def __init__(self, n_boxes, margin=0, box_var=None): b[:, 1] = np.linspace(-1, 1 - 2. / n_boxes + margin, n_boxes) b[:, 2] = 1 b[:, 3] = np.linspace(-1 + 2. / n_boxes - margin, 1., n_boxes) + b = b[::-1, :] super(Stacked, self).__init__(b, box_var=box_var) diff --git a/phy/plot/plot.py b/phy/plot/plot.py index c6fd5031c..228df4e17 100644 --- a/phy/plot/plot.py +++ b/phy/plot/plot.py @@ -13,7 +13,7 @@ import numpy as np from .base import BaseCanvas -from .interact import Grid # Boxed, Stacked +from .interact import Grid, Boxed, Stacked from .panzoom import PanZoom from .transform import NDC from .visuals import _get_array, ScatterVisual, PlotVisual, HistogramVisual @@ -59,9 +59,9 @@ def __init__(self, idx): def visual_class(self): return self.spec.get('visual_class', None) - def _set(self, visual_class, loc): + def _set(self, visual_class, spec): self.spec['visual_class'] = visual_class - self.spec.update(loc) + self.spec.update(spec) def __getattr__(self, name): return self.spec[name] @@ -77,8 +77,8 @@ def scatter(self, x, y, color=None, size=None, marker=None): # Default marker. marker = marker or ScatterVisual._default_marker # Set the spec. - loc = dict(x=x, y=y, color=color, size=size, marker=marker) - return self._set(ScatterVisual, loc) + spec = dict(x=x, y=y, color=color, size=size, marker=marker) + return self._set(ScatterVisual, spec) def plot(self, x, y, color=None): x = np.atleast_2d(x) @@ -87,14 +87,22 @@ def plot(self, x, y, color=None): assert x.ndim == y.ndim == 2 assert x.shape == y.shape n_plots, n_samples = x.shape + # Get the colors. color = _get_array(color, (n_plots, 4), PlotVisual._default_color) # Set the spec. - loc = dict(x=x, y=y, color=color) - return self._set(PlotVisual, loc) - - def hist(self, hist, color=None): - loc = locals() - return self._set(HistogramVisual, loc) + spec = dict(x=x, y=y, color=color) + return self._set(PlotVisual, spec) + + def hist(self, data, color=None): + # Validate data. + if data.ndim == 1: + data = data[np.newaxis, :] + assert data.ndim == 2 + n_hists, n_samples = data.shape + # Get the colors. + color = _get_array(color, (n_hists, 4), HistogramVisual._default_color) + spec = dict(data=data, color=color) + return self._set(HistogramVisual, spec) def __repr__(self): return str(self.spec) # pragma: no cover @@ -164,6 +172,22 @@ def _build_plot(self, subviews): v.set_data(x=ac['x'], y=ac['y'], plot_colors=ac['plot_colors']) v.program['a_box_index'] = ac['box_index'] + def _build_histogram(self, subviews): + """Build all histogram subviews.""" + + ac = Accumulator() + for sv in subviews: + n = sv.data.size + ac['data'] = sv.data + ac['hist_colors'] = sv.color + # NOTE: the `6 * ` comes from the histogram tesselation. + ac['box_index'] = np.tile(sv.idx, (6 * n, 1)) + + v = HistogramVisual() + v.attach(self) + v.set_data(hist=ac['data'], hist_colors=ac['hist_colors']) + v.program['a_box_index'] = ac['box_index'] + def build(self): """Build all visuals.""" for visual_class, subviews in groupby(self.iter_subviews(), @@ -198,11 +222,31 @@ def iter_index(self): yield (i, j) +class BoxedView(BaseView): + def __init__(self, box_bounds): + self.n_plots = len(box_bounds) + pz = PanZoom(aspect=None, constrain_bounds=NDC) + interacts = [Boxed(box_bounds), pz] + super(BoxedView, self).__init__(interacts) + + def get_box_ndim(self): + return 1 + + def iter_index(self): + for i in range(self.n_plots): + yield i + + class StackedView(BaseView): def __init__(self, n_plots): - super(StackedView, self).__init__() + self.n_plots = n_plots + pz = PanZoom(aspect=None, constrain_bounds=NDC) + interacts = [Stacked(n_plots, margin=.1), pz] + super(StackedView, self).__init__(interacts) + def get_box_ndim(self): + return 1 -class BoxedView(BaseView): - def __init__(self, box_positions): - super(BoxedView, self).__init__() + def iter_index(self): + for i in range(self.n_plots): + yield i diff --git a/phy/plot/tests/test_plot.py b/phy/plot/tests/test_plot.py index 39853d5ac..e4d696820 100644 --- a/phy/plot/tests/test_plot.py +++ b/phy/plot/tests/test_plot.py @@ -9,7 +9,7 @@ import numpy as np -from ..plot import GridView +from ..plot import GridView, BoxedView, StackedView from ..visuals import _get_linear_x @@ -59,3 +59,42 @@ def test_grid_plot(qtbot): view[0, 1].plot(x, y, color=np.random.uniform(.5, .8, size=(n_plots, 4))) _show(qtbot, view) + + +def test_grid_hist(qtbot): + view = GridView(3, 3) + + hist = np.random.rand(3, 3, 20) + + for i in range(3): + for j in range(3): + view[i, j].hist(hist[i, j, :], + color=np.random.uniform(.5, .8, size=4)) + + _show(qtbot, view) + + +def test_grid_complete(qtbot): + view = GridView(2, 2) + t = _get_linear_x(1, 1000).ravel() + + view[0, 0].scatter(*np.random.randn(2, 100)) + view[0, 1].plot(t, np.sin(20 * t), color=(1, 0, 0, 1)) + + view[1, 1].hist(np.random.rand(5, 10), + color=np.random.uniform(.4, .9, size=(5, 4))) + + _show(qtbot, view) + + +def test_stacked_complete(qtbot): + view = StackedView(4) + t = _get_linear_x(1, 1000).ravel() + + view[0].scatter(*np.random.randn(2, 100)) + view[1].plot(t, np.sin(20 * t), color=(1, 0, 0, 1)) + + view[2].hist(np.random.rand(5, 10), + color=np.random.uniform(.4, .9, size=(5, 4))) + + _show(qtbot, view) diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index f9aae2986..6dbe84a29 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -56,8 +56,9 @@ def _get_array(val, shape, default=None): out = np.zeros(shape, dtype=np.float32) # This solves `ValueError: could not broadcast input array from shape (n) # into shape (n, 1)`. - if val is not None and isinstance(val, np.ndarray) and out.ndim > val.ndim: - val = val[:, np.newaxis] + if val is not None and isinstance(val, np.ndarray): + if val.size == out.size: + val = val.reshape(out.shape) out[...] = val if val is not None else default assert out.shape == shape return out From 424619873d4d0bd4db434ca7d93531941c37bd9c Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 25 Oct 2015 16:51:44 +0100 Subject: [PATCH 0464/1059] Increase coverage --- phy/plot/plot.py | 24 ------------------------ phy/plot/tests/test_plot.py | 17 ++++++++++++++++- 2 files changed, 16 insertions(+), 25 deletions(-) diff --git a/phy/plot/plot.py b/phy/plot/plot.py index 228df4e17..8d1454716 100644 --- a/phy/plot/plot.py +++ b/phy/plot/plot.py @@ -26,24 +26,12 @@ class Accumulator(object): """Accumulate arrays for concatenation.""" def __init__(self): - self._size = defaultdict(int) self._data = defaultdict(list) - @property - def size(self): - return self._size[list(self._size.keys())[0]] - - @property - def data(self): - return {name: self[name] for name in self._data} - def __setitem__(self, name, val): - self._size[name] += len(val) self._data[name].append(val) def __getitem__(self, name): - # size = self.size - # assert all(s == size for s in self._size.values()) return np.vstack(self._data[name]).astype(np.float32) @@ -119,9 +107,6 @@ def __init__(self, interacts): # To override # ------------------------------------------------------------------------- - def get_box_ndim(self): - raise NotImplementedError() - def iter_index(self): raise NotImplementedError() @@ -213,9 +198,6 @@ def __init__(self, n_rows, n_cols): interacts = [Grid(n_rows, n_cols), pz] super(GridView, self).__init__(interacts) - def get_box_ndim(self): - return 2 - def iter_index(self): for i in range(self.n_rows): for j in range(self.n_cols): @@ -229,9 +211,6 @@ def __init__(self, box_bounds): interacts = [Boxed(box_bounds), pz] super(BoxedView, self).__init__(interacts) - def get_box_ndim(self): - return 1 - def iter_index(self): for i in range(self.n_plots): yield i @@ -244,9 +223,6 @@ def __init__(self, n_plots): interacts = [Stacked(n_plots, margin=.1), pz] super(StackedView, self).__init__(interacts) - def get_box_ndim(self): - return 1 - def iter_index(self): for i in range(self.n_plots): yield i diff --git a/phy/plot/tests/test_plot.py b/phy/plot/tests/test_plot.py index e4d696820..31266d1d0 100644 --- a/phy/plot/tests/test_plot.py +++ b/phy/plot/tests/test_plot.py @@ -89,11 +89,26 @@ def test_grid_complete(qtbot): def test_stacked_complete(qtbot): view = StackedView(4) - t = _get_linear_x(1, 1000).ravel() + t = _get_linear_x(1, 1000).ravel() view[0].scatter(*np.random.randn(2, 100)) view[1].plot(t, np.sin(20 * t), color=(1, 0, 0, 1)) + view[2].hist(np.random.rand(5, 10), + color=np.random.uniform(.4, .9, size=(5, 4))) + + _show(qtbot, view) + +def test_boxed_complete(qtbot): + n = 3 + b = np.zeros((n, 4)) + b[:, 0] = b[:, 1] = np.linspace(-1., 1. - 2. / 3., n) + b[:, 2] = b[:, 3] = np.linspace(-1. + 2. / 3., 1., n) + view = BoxedView(b) + + t = _get_linear_x(1, 1000).ravel() + view[0].scatter(*np.random.randn(2, 100)) + view[1].plot(t, np.sin(20 * t), color=(1, 0, 0, 1)) view[2].hist(np.random.rand(5, 10), color=np.random.uniform(.4, .9, size=(5, 4))) From c1d1c6a69052429dfe70bfe868e14d61e8220668 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 26 Oct 2015 09:42:14 +0100 Subject: [PATCH 0465/1059] Export some objects in plot --- phy/plot/__init__.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/phy/plot/__init__.py b/phy/plot/__init__.py index 10d217e71..10b1e489b 100644 --- a/phy/plot/__init__.py +++ b/phy/plot/__init__.py @@ -12,6 +12,9 @@ from vispy import config +from .plot import GridView, BoxedView, StackedView # noqa +from.visuals import _get_linear_x + #------------------------------------------------------------------------------ # Add the `glsl/ path` for shader include From 96563169231e1bb2ff9e0e8bb7e0343f3ee0686b Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 26 Oct 2015 09:42:25 +0100 Subject: [PATCH 0466/1059] WIP: start WaveformView --- phy/cluster/manual/tests/test_views.py | 57 +++++++++ phy/cluster/manual/views.py | 153 +++++++++++++++++++++++++ 2 files changed, 210 insertions(+) create mode 100644 phy/cluster/manual/tests/test_views.py create mode 100644 phy/cluster/manual/views.py diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py new file mode 100644 index 000000000..4e8d1237f --- /dev/null +++ b/phy/cluster/manual/tests/test_views.py @@ -0,0 +1,57 @@ +# -*- coding: utf-8 -*- + +"""Test views.""" + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +import numpy as np + +from phy.io.mock import (artificial_waveforms, + artificial_spike_clusters, + artificial_masks, + ) +from phy.electrode.mea import staggered_positions +from ..views import WaveformView + + +#------------------------------------------------------------------------------ +# Utils +#------------------------------------------------------------------------------ + +def _show(qtbot, view, stop=False): + view.show() + qtbot.waitForWindowShown(view.native) + if stop: # pragma: no cover + qtbot.stop() + view.close() + + +#------------------------------------------------------------------------------ +# Test views +#------------------------------------------------------------------------------ + +def test_waveform_view(qtbot): + n_spikes = 20 + n_samples = 30 + n_channels = 10 + n_clusters = 3 + + waveforms = artificial_waveforms(n_spikes, n_samples, n_channels) + masks = artificial_masks(n_spikes, n_channels) + spike_clusters = artificial_spike_clusters(n_spikes, n_clusters) + channel_positions = staggered_positions(n_channels) + + v = WaveformView(waveforms=waveforms, + masks=masks, + spike_clusters=spike_clusters, + channel_positions=channel_positions, + ) + + spike_ids = np.arange(10) + cluster_ids = np.unique(spike_clusters[spike_ids]) + + v.on_select(cluster_ids, spike_ids) + + _show(qtbot, v, True) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py new file mode 100644 index 000000000..21eb28bc0 --- /dev/null +++ b/phy/cluster/manual/views.py @@ -0,0 +1,153 @@ +# -*- coding: utf-8 -*- + +"""Manual clustering views.""" + + +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- + +import logging + +import numpy as np + +from phy.io.array import _index_of +from phy.electrode.mea import linear_positions +from phy.plot import BoxedView, _get_linear_x +from phy.plot.visuals import _get_data_bounds +from phy.plot.transform import Range + +logger = logging.getLogger(__name__) + + +# ----------------------------------------------------------------------------- +# Utils +# ----------------------------------------------------------------------------- + + +# Default color map for the selected clusters. +_COLORMAP = np.array([[8, 146, 252], + [255, 2, 2], + [240, 253, 2], + [228, 31, 228], + [2, 217, 2], + [255, 147, 2], + [212, 150, 70], + [205, 131, 201], + [201, 172, 36], + [150, 179, 62], + [95, 188, 122], + [129, 173, 190], + [231, 107, 119], + ]) + + +def _selected_clusters_colors(n_clusters=None): + if n_clusters is None: + n_clusters = _COLORMAP.shape[0] + if n_clusters > _COLORMAP.shape[0]: + colors = np.tile(_COLORMAP, (1 + n_clusters // _COLORMAP.shape[0], 1)) + else: + colors = _COLORMAP + return colors[:n_clusters, ...] / 255. + + +# ----------------------------------------------------------------------------- +# Views +# ----------------------------------------------------------------------------- + +class WaveformView(BoxedView): + def __init__(self, + waveforms=None, + masks=None, + spike_clusters=None, + channel_positions=None, + ): + """ + + The channel order in waveforms needs to correspond to the one + in channel_positions. + + """ + # Initialize the view. + + if channel_positions is None: + channel_positions = linear_positions(self.n_channels) + bounds = _get_data_bounds(None, channel_positions) + channel_positions = Range(from_bounds=bounds).apply(channel_positions) + channel_positions *= .75 + + self.box_size = (.1, .1) + bs = np.array([self.box_size]) + box_bounds = np.c_[channel_positions - bs / 2., + channel_positions + bs / 2., + ] + super(WaveformView, self).__init__(box_bounds) + + # Waveforms. + assert waveforms.ndim == 3 + self.n_spikes, self.n_samples, self.n_channels = waveforms.shape + self.waveforms = waveforms + + # Masks. + if masks is None: + masks = np.ones((self.n_spikes, self.n_channels), dtype=np.float32) + assert masks.ndim == 2 + assert masks.shape == (self.n_spikes, self.n_channels) + self.masks = masks + + # Spike clusters. + assert spike_clusters.shape == (self.n_spikes,) + self.spike_clusters = spike_clusters + + # Channel positions. + assert channel_positions.shape == (self.n_channels, 2) + self.channel_positions = channel_positions + + def on_select(self, cluster_ids, spike_ids): + n_clusters = len(cluster_ids) + n_spikes = len(spike_ids) + if n_spikes == 0: + return + + # Relative spike clusters. + # NOTE: the order of the clusters in cluster_ids matters. + # It will influence the relative index of the clusters, which + # in return influence the depth. + spike_clusters = self.spike_clusters[spike_ids] + assert np.all(np.in1d(spike_clusters, cluster_ids)) + spike_clusters_rel = _index_of(spike_clusters, cluster_ids) + + # Fetch the waveforms. + w = self.waveforms[spike_ids] + colors = _selected_clusters_colors(n_clusters) + t = _get_linear_x(n_spikes, self.n_samples) + + # Get the colors. + color = colors[spike_clusters_rel[spike_ids]] + # Alpha channel. + color = np.c_[color, np.ones((n_spikes, 1))] + # TODO: depth + + # Plot all waveforms. + for ch in range(self.n_channels): + self[ch].plot(x=t, y=w[:, :, ch], color=color) + + self.build() + self.update() + + def on_cluster(self, up): + pass + + def on_mouse_move(self, e): + pass + + def on_key_press(self, e): + pass + + def attach_to_gui(self, gui): + gui.add_view(self) + + # TODO: make sure the GUI emits these events + gui.connect(self.on_select) + gui.connect(self.on_cluster) From 13e8b07e22c6e2edf4a87225f41d9eeb6059d41d Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 26 Oct 2015 10:07:54 +0100 Subject: [PATCH 0467/1059] WIP --- phy/cluster/manual/views.py | 2 +- phy/plot/interact.py | 5 +++++ phy/plot/plot.py | 5 +++-- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 21eb28bc0..fa4004601 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -69,8 +69,8 @@ def __init__(self, in channel_positions. """ - # Initialize the view. + # Initialize the view. if channel_positions is None: channel_positions = linear_positions(self.n_channels) bounds = _get_data_bounds(None, channel_positions) diff --git a/phy/plot/interact.py b/phy/plot/interact.py index 73d351504..b27b3dfa7 100644 --- a/phy/plot/interact.py +++ b/phy/plot/interact.py @@ -116,6 +116,11 @@ class Boxed(BaseInteract): box_bounds : array-like A (n, 4) array where each row contains the `(xmin, ymin, xmax, ymax)` bounds of every box, in normalized device coordinates. + + NOTE: the box bounds need to be contained within [-1, 1] at all times, + otherwise an error will be raised. This is to prevent silent clipping + of the values when they are passed to a gloo Texture2D. + box_var : str Name of the GLSL variable with the box index. diff --git a/phy/plot/plot.py b/phy/plot/plot.py index 8d1454716..7af2a67c9 100644 --- a/phy/plot/plot.py +++ b/phy/plot/plot.py @@ -207,8 +207,9 @@ def iter_index(self): class BoxedView(BaseView): def __init__(self, box_bounds): self.n_plots = len(box_bounds) - pz = PanZoom(aspect=None, constrain_bounds=NDC) - interacts = [Boxed(box_bounds), pz] + self._boxed = Boxed(box_bounds) + self._pz = PanZoom(aspect=None, constrain_bounds=NDC) + interacts = [self._boxed, self._pz] super(BoxedView, self).__init__(interacts) def iter_index(self): From 5255199effe94844dfc67554e5f5c08966aa0329 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 26 Oct 2015 13:27:28 +0100 Subject: [PATCH 0468/1059] WIP: box placement --- phy/plot/tests/test_utils.py | 34 ++++++++++++++++ phy/plot/utils.py | 76 +++++++++++++++++++++++++++++++++++- 2 files changed, 109 insertions(+), 1 deletion(-) diff --git a/phy/plot/tests/test_utils.py b/phy/plot/tests/test_utils.py index f9851ae6a..632cb1e7d 100644 --- a/phy/plot/tests/test_utils.py +++ b/phy/plot/tests/test_utils.py @@ -14,9 +14,12 @@ from numpy.testing import assert_allclose as ac from vispy import config +from phy.electrode.mea import linear_positions from ..utils import (_load_shader, _tesselate_histogram, _enable_depth_mask, + _boxes_overlap, + _get_boxes, ) @@ -50,3 +53,34 @@ def on_draw(e): canvas.show() qtbot.waitForWindowShown(canvas.native) + + +def test_boxes_overlap(): + + def _get_args(boxes): + x0, y0, x1, y1 = np.array(boxes).T + x0 = x0[:, np.newaxis] + x1 = x1[:, np.newaxis] + y0 = y0[:, np.newaxis] + y1 = y1[:, np.newaxis] + return x0, y0, x1, y1 + + boxes = [[-1, -1, 0, 0], [0.01, 0.01, 1, 1]] + x0, y0, x1, y1 = _get_args(boxes) + assert not _boxes_overlap(x0, y0, x1, y1) + + boxes = [[-1, -1, 0.1, 0.1], [0, 0, 1, 1]] + x0, y0, x1, y1 = _get_args(boxes) + assert _boxes_overlap(x0, y0, x1, y1) + + +def test_get_boxes(): + positions = [[-1, -1], [1., 1.]] + x0, y0, x1, y1 = _get_boxes(positions) + assert np.all(x1 - x0 >= .4) + assert np.all(y1 - y0 >= .4) + assert not _boxes_overlap(x0, y0, x1, y1) + + positions = linear_positions(4) + x0, y0, x1, y1 = _get_boxes(positions) + assert not _boxes_overlap(x0, y0, x1, y1) diff --git a/phy/plot/utils.py b/phy/plot/utils.py index 1417501eb..ab837465f 100644 --- a/phy/plot/utils.py +++ b/phy/plot/utils.py @@ -11,9 +11,10 @@ import os.path as op import numpy as np - from vispy import gloo +from .transform import Range + logger = logging.getLogger(__name__) @@ -95,3 +96,76 @@ def _get_texture(arr, default, n_items, from_bounds): assert np.all(arr <= 255) arr = arr.astype(np.uint8) return arr + + +def _boxes_overlap(x0, y0, x1, y1): + n = len(x0) + overlap_matrix = ((x0 < x1.T) & (x1 > x0.T) & (y0 < y1.T) & (y1 > y0.T)) + overlap_matrix[np.arange(n), np.arange(n)] = False + return np.any(overlap_matrix.ravel()) + + +def _rescale_positions(pos, size): + """Rescale positions so that the boxes fit in NDC.""" + a, b = size + + # Get x, y. + pos = np.asarray(pos, dtype=np.float32) + x, y = pos.T + x = x[:, np.newaxis] + y = y[:, np.newaxis] + + xmin, xmax = x.min(), x.max() + ymin, ymax = y.min(), y.max() + + # Renormalize into [-1, 1]. + pos = Range(from_bounds=(xmin, ymin, xmax, ymax), + to_bounds=(-1, -1, 1, 1)).apply(pos) + + # Rescale the positions so that everything fits in the box. + alpha = 1. + if xmin != 0: + alpha = min(alpha, (-1 + a) / xmin) + if xmax != 0: + alpha = min(alpha, (+1 - a) / xmax) + + beta = 1. + if ymin != 0: + beta = min(beta, (-1 + b) / ymin) + if ymax != 0: + beta = min(beta, (+1 - b) / ymax) + + # Get xy01. + x0, y0 = alpha * x - a, beta * y - b + x1, y1 = alpha * x + a, beta * y + b + + return x0, y0, x1, y1 + + +def _get_boxes(pos): + """Generate non-overlapping boxes in NDC from a set of positions.""" + + # Find a box_size such that the boxes are non-overlapping. + def f(size): + a, b = size + x0, y0, x1, y1 = _rescale_positions(pos, size) + + if _boxes_overlap(x0, y0, x1, y1): + return 0. + + return -(2 * a + b) + + cons = [{'type': 'ineq', 'fun': lambda s: s[0]}, + {'type': 'ineq', 'fun': lambda s: s[1]}, + {'type': 'ineq', 'fun': lambda s: 1 - s[0]}, + {'type': 'ineq', 'fun': lambda s: 1 - s[1]}, + ] + + from scipy.optimize import minimize + res = minimize(f, (.05, .01), + constraints=cons, + ) + w, h = res.x + assert f((w, h)) < 0 + + return _rescale_positions(pos, (w, h)) From 7e29979304fb4f5030b670fa8dc8ff135590d281 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 26 Oct 2015 18:47:41 +0100 Subject: [PATCH 0469/1059] Fix --- phy/cluster/manual/tests/test_views.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index 4e8d1237f..0d7744ea0 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -54,4 +54,4 @@ def test_waveform_view(qtbot): v.on_select(cluster_ids, spike_ids) - _show(qtbot, v, True) + _show(qtbot, v) From 28d7e9feebb4c562b3bfcbf96f811c13049996d6 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 26 Oct 2015 21:37:14 +0100 Subject: [PATCH 0470/1059] Improve _get_boxes() --- phy/plot/tests/test_utils.py | 34 ++++++++--- phy/plot/utils.py | 111 +++++++++++++++++------------------ 2 files changed, 81 insertions(+), 64 deletions(-) diff --git a/phy/plot/tests/test_utils.py b/phy/plot/tests/test_utils.py index 632cb1e7d..045b51640 100644 --- a/phy/plot/tests/test_utils.py +++ b/phy/plot/tests/test_utils.py @@ -11,14 +11,16 @@ import os.path as op import numpy as np +from numpy.testing import assert_array_equal as ae from numpy.testing import assert_allclose as ac from vispy import config -from phy.electrode.mea import linear_positions +from phy.electrode.mea import linear_positions, staggered_positions from ..utils import (_load_shader, _tesselate_histogram, _enable_depth_mask, _boxes_overlap, + _binary_search, _get_boxes, ) @@ -74,13 +76,29 @@ def _get_args(boxes): assert _boxes_overlap(x0, y0, x1, y1) +def test_binary_search(): + def f(x): + return x < .4 + ac(_binary_search(f, 0, 1), .4) + ac(_binary_search(f, 0, .3), .3) + ac(_binary_search(f, .5, 1), .5) + + def test_get_boxes(): - positions = [[-1, -1], [1., 1.]] - x0, y0, x1, y1 = _get_boxes(positions) - assert np.all(x1 - x0 >= .4) - assert np.all(y1 - y0 >= .4) - assert not _boxes_overlap(x0, y0, x1, y1) + positions = [[-1, 0], [1, 0]] + boxes = _get_boxes(positions) + ac(boxes, [[-1, -.25, 0, .25], + [0, -.25, 1, .25]], atol=1e-4) positions = linear_positions(4) - x0, y0, x1, y1 = _get_boxes(positions) - assert not _boxes_overlap(x0, y0, x1, y1) + boxes = _get_boxes(positions) + ac(boxes, [[-0.5, -1.0, +0.5, -0.5], + [-0.5, -0.5, +0.5, +0.0], + [-0.5, +0.0, +0.5, +0.5], + [-0.5, +0.5, +0.5, +1.0], + ], atol=1e-4) + + positions = staggered_positions(8) + boxes = _get_boxes(positions) + ae(boxes[:, 1], np.arange(.75, -1.1, -.25)) + ae(boxes[:, 3], np.arange(1, -.76, -.25)) diff --git a/phy/plot/utils.py b/phy/plot/utils.py index ab837465f..59fa9b50c 100644 --- a/phy/plot/utils.py +++ b/phy/plot/utils.py @@ -105,9 +105,25 @@ def _boxes_overlap(x0, y0, x1, y1): return np.any(overlap_matrix.ravel()) -def _rescale_positions(pos, size): - """Rescale positions so that the boxes fit in NDC.""" - a, b = size +def _binary_search(f, xmin, xmax, eps=1e-9): + """Return the largest x such f(x) is True.""" + middle = (xmax + xmin) / 2. + while xmax - xmin > eps: + assert xmin < xmax + middle = (xmax + xmin) / 2. + if f(xmax): + return xmax + if not f(xmin): + return xmin + if f(middle): + xmin = middle + else: + xmax = middle + return middle + + +def _get_boxes(pos, margin=0): + """Generate non-overlapping boxes in NDC from a set of positions.""" # Get x, y. pos = np.asarray(pos, dtype=np.float32) @@ -115,57 +131,40 @@ def _rescale_positions(pos, size): x = x[:, np.newaxis] y = y[:, np.newaxis] + # Deal with degenerate x case. xmin, xmax = x.min(), x.max() - ymin, ymax = y.min(), y.max() - - # Renormalize into [-1, 1]. - pos = Range(from_bounds=(xmin, ymin, xmax, ymax), - to_bounds=(-1, -1, 1, 1)).apply(pos) - - # Rescale the positions so that everything fits in the box. - alpha = 1. - if xmin != 0: - alpha = min(alpha, (-1 + a) / xmin) - if xmax != 0: - alpha = min(alpha, (+1 - a) / xmax) - - beta = 1. - if ymin != 0: - beta = min(beta, (-1 + b) / ymin) - if ymax != 0: - beta = min(beta, (+1 - b) / ymax) - - # Get xy01. - x0, y0 = alpha * x - a, beta * y - b - x1, y1 = alpha * x + a, beta * y + b - - return x0, y0, x1, y1 - - -def _get_boxes(pos): - """Generate non-overlapping boxes in NDC from a set of positions.""" - - # Find a box_size such that the boxes are non-overlapping. - def f(size): - a, b = size - x0, y0, x1, y1 = _rescale_positions(pos, size) - - if _boxes_overlap(x0, y0, x1, y1): - return 0. - - return -(2 * a + b) - - cons = [{'type': 'ineq', 'fun': lambda s: s[0]}, - {'type': 'ineq', 'fun': lambda s: s[1]}, - {'type': 'ineq', 'fun': lambda s: 1 - s[0]}, - {'type': 'ineq', 'fun': lambda s: 1 - s[1]}, - ] - - from scipy.optimize import minimize - res = minimize(f, (.05, .01), - constraints=cons, - ) - w, h = res.x - assert f((w, h)) < 0 - - return _rescale_positions(pos, (w, h)) + if xmin == xmax: + wmax = 1. + else: + wmax = xmax - xmin + + ar = .5 + + def f1(w): + """Return true if the configuration with the current box size + is non-overlapping.""" + h = w * ar # fixed aspect ratio + return not _boxes_overlap(x - w, y - h, x + w, y + h) + + # Find the largest box size leading to non-overlapping boxes. + w = _binary_search(f1, 0, wmax) + w = w * (1 - margin) # margin + h = w * ar # aspect ratio + + x0, y0 = x - w, y - h + x1, y1 = x + w, y + h + + # Renormalize the whole thing by keeping the aspect ratio. + x0min, y0min, x1max, y1max = x0.min(), y0.min(), x1.max(), y1.max() + dx = x1max - x0min + dy = y1max - y0min + if dx > dy: + b = (x0min, (y1max + y0min) / 2. - dx / 2., + x1max, (y1max + y0min) / 2. + dx / 2.) + else: + b = ((x1max + x0min) / 2. - dy / 2., y0min, + (x1max + x0min) / 2. + dy / 2., y1max) + + r = Range(from_bounds=b, + to_bounds=(-1, -1, 1, 1)) + return np.c_[r.apply(np.c_[x0, y0]), r.apply(np.c_[x1, y1])] From 0dc22763267422ee63f72e78d98bd8c75c421fae Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 26 Oct 2015 21:47:50 +0100 Subject: [PATCH 0471/1059] Default box positions in waveform view --- phy/cluster/manual/tests/test_views.py | 2 +- phy/cluster/manual/views.py | 13 ++----------- 2 files changed, 3 insertions(+), 12 deletions(-) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index 0d7744ea0..b4060c9cf 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -35,7 +35,7 @@ def _show(qtbot, view, stop=False): def test_waveform_view(qtbot): n_spikes = 20 n_samples = 30 - n_channels = 10 + n_channels = 40 n_clusters = 3 waveforms = artificial_waveforms(n_spikes, n_samples, n_channels) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index fa4004601..732e8b895 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -14,8 +14,7 @@ from phy.io.array import _index_of from phy.electrode.mea import linear_positions from phy.plot import BoxedView, _get_linear_x -from phy.plot.visuals import _get_data_bounds -from phy.plot.transform import Range +from phy.plot.utils import _get_boxes logger = logging.getLogger(__name__) @@ -73,15 +72,7 @@ def __init__(self, # Initialize the view. if channel_positions is None: channel_positions = linear_positions(self.n_channels) - bounds = _get_data_bounds(None, channel_positions) - channel_positions = Range(from_bounds=bounds).apply(channel_positions) - channel_positions *= .75 - - self.box_size = (.1, .1) - bs = np.array([self.box_size]) - box_bounds = np.c_[channel_positions - bs / 2., - channel_positions + bs / 2., - ] + box_bounds = _get_boxes(channel_positions) super(WaveformView, self).__init__(box_bounds) # Waveforms. From 33aaa342bf9d3264462ade36019ba9414e248678 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 26 Oct 2015 22:00:08 +0100 Subject: [PATCH 0472/1059] WIP --- phy/cluster/manual/views.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 732e8b895..baf56a351 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -80,6 +80,7 @@ def __init__(self, self.n_spikes, self.n_samples, self.n_channels = waveforms.shape self.waveforms = waveforms + # TODO: refactor with _get_array # Masks. if masks is None: masks = np.ones((self.n_spikes, self.n_channels), dtype=np.float32) @@ -124,6 +125,8 @@ def on_select(self, cluster_ids, spike_ids): for ch in range(self.n_channels): self[ch].plot(x=t, y=w[:, :, ch], color=color) + # TODO: build only once, then just set data (don't recreate visuals) + # TODO: more interactions in boxed interact self.build() self.update() @@ -142,3 +145,32 @@ def attach_to_gui(self, gui): # TODO: make sure the GUI emits these events gui.connect(self.on_select) gui.connect(self.on_cluster) + + +class TraceView(BoxedView): + def __init__(self, + traces=None, + spike_times=None, + spike_clusters=None,): + pass + + +class FeatureView(BoxedView): + def __init__(self, + features=None, + dimensions=None, + extra_features=None, + ): + pass + + +class CorrelogramView(BoxedView): + def __init__(self, + spike_samples=None, + spike_times=None, + bin_size=None, + window_size=None, + excerpt_size=None, + n_excerpts=None, + ): + pass From d95535c06ea84f17a3c3b00de7afeffe256f25a8 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 27 Oct 2015 09:27:16 +0100 Subject: [PATCH 0473/1059] WIP: refactor --- phy/cluster/manual/views.py | 9 ++------- phy/plot/utils.py | 14 ++++++++++++++ phy/plot/visuals.py | 18 +++--------------- 3 files changed, 19 insertions(+), 22 deletions(-) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index baf56a351..30832883a 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -14,7 +14,7 @@ from phy.io.array import _index_of from phy.electrode.mea import linear_positions from phy.plot import BoxedView, _get_linear_x -from phy.plot.utils import _get_boxes +from phy.plot.utils import _get_boxes, _get_array logger = logging.getLogger(__name__) @@ -80,13 +80,8 @@ def __init__(self, self.n_spikes, self.n_samples, self.n_channels = waveforms.shape self.waveforms = waveforms - # TODO: refactor with _get_array # Masks. - if masks is None: - masks = np.ones((self.n_spikes, self.n_channels), dtype=np.float32) - assert masks.ndim == 2 - assert masks.shape == (self.n_spikes, self.n_channels) - self.masks = masks + self.masks = _get_array(masks, (self.n_spikes, self.n_channels), 1) # Spike clusters. assert spike_clusters.shape == (self.n_spikes,) diff --git a/phy/plot/utils.py b/phy/plot/utils.py index 59fa9b50c..97850aca4 100644 --- a/phy/plot/utils.py +++ b/phy/plot/utils.py @@ -98,6 +98,20 @@ def _get_texture(arr, default, n_items, from_bounds): return arr +def _get_array(val, shape, default=None): + """Ensure an object is an array with the specified shape.""" + assert val is not None or default is not None + out = np.zeros(shape, dtype=np.float32) + # This solves `ValueError: could not broadcast input array from shape (n) + # into shape (n, 1)`. + if val is not None and isinstance(val, np.ndarray): + if val.size == out.size: + val = val.reshape(out.shape) + out[...] = val if val is not None else default + assert out.shape == shape + return out + + def _boxes_overlap(x0, y0, x1, y1): n = len(x0) overlap_matrix = ((x0 < x1.T) & (x1 > x0.T) & (y0 < y1.T) & (y1 > y0.T)) diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index 6dbe84a29..8cf0a436b 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -12,7 +12,9 @@ from .base import BaseVisual from .transform import Range, GPU, NDC -from .utils import _enable_depth_mask, _tesselate_histogram, _get_texture +from .utils import (_enable_depth_mask, _tesselate_histogram, + _get_texture, _get_array, + ) #------------------------------------------------------------------------------ @@ -50,20 +52,6 @@ def _get_data_bounds_1D(data_bounds, data): return data_bounds -def _get_array(val, shape, default=None): - """Ensure an object is an array with the specified shape.""" - assert val is not None or default is not None - out = np.zeros(shape, dtype=np.float32) - # This solves `ValueError: could not broadcast input array from shape (n) - # into shape (n, 1)`. - if val is not None and isinstance(val, np.ndarray): - if val.size == out.size: - val = val.reshape(out.shape) - out[...] = val if val is not None else default - assert out.shape == shape - return out - - def _check_pos_2D(pos): """Check position data before GPU uploading.""" assert pos is not None From 543ed5409fcaffb08a1b8fe3479d84e8fdcf16fd Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 27 Oct 2015 10:00:38 +0100 Subject: [PATCH 0474/1059] Avoid recompilation of visual after a new call to set_data() --- phy/plot/plot.py | 28 ++++++++++++++++++++++------ phy/plot/tests/test_plot.py | 14 +++++++++++++- 2 files changed, 35 insertions(+), 7 deletions(-) diff --git a/phy/plot/plot.py b/phy/plot/plot.py index 7af2a67c9..2618f770e 100644 --- a/phy/plot/plot.py +++ b/phy/plot/plot.py @@ -103,6 +103,7 @@ def __init__(self, interacts): for interact in interacts: interact.attach(self) self.subviews = {} + self._visuals = {} # To override # ------------------------------------------------------------------------- @@ -124,6 +125,21 @@ def __getitem__(self, idx): self.subviews[idx] = sv return sv + def _get_visual(self, key): + if key not in self._visuals: + # Create the visual. + if isinstance(key, tuple): + # Case of the scatter plot, where the visual depends on the + # marker. + v = key[0](key[1]) + else: + v = key() + # Attach the visual to the view. + v.attach(self) + # Store the visual for reuse. + self._visuals[key] = v + return self._visuals[key] + def _build_scatter(self, subviews, marker): """Build all scatter subviews with the same marker type.""" @@ -136,8 +152,7 @@ def _build_scatter(self, subviews, marker): ac['size'] = sv.size ac['box_index'] = np.tile(sv.idx, (n, 1)) - v = ScatterVisual(marker=marker) - v.attach(self) + v = self._get_visual((ScatterVisual, marker)) v.set_data(pos=ac['pos'], color=ac['color'], size=ac['size']) v.program['a_box_index'] = ac['box_index'] @@ -152,8 +167,7 @@ def _build_plot(self, subviews): ac['plot_colors'] = sv.color ac['box_index'] = np.tile(sv.idx, (n, 1)) - v = PlotVisual() - v.attach(self) + v = self._get_visual(PlotVisual) v.set_data(x=ac['x'], y=ac['y'], plot_colors=ac['plot_colors']) v.program['a_box_index'] = ac['box_index'] @@ -168,13 +182,15 @@ def _build_histogram(self, subviews): # NOTE: the `6 * ` comes from the histogram tesselation. ac['box_index'] = np.tile(sv.idx, (6 * n, 1)) - v = HistogramVisual() - v.attach(self) + v = self._get_visual(HistogramVisual) v.set_data(hist=ac['data'], hist_colors=ac['hist_colors']) v.program['a_box_index'] = ac['box_index'] def build(self): """Build all visuals.""" + # TODO: fix a bug where an old subplot is not deleted if it + # is changed to another type, and there is no longer any subplot + # of the old type. The old visual should be delete or hidden. for visual_class, subviews in groupby(self.iter_subviews(), lambda sv: sv.visual_class): if visual_class == ScatterVisual: diff --git a/phy/plot/tests/test_plot.py b/phy/plot/tests/test_plot.py index 31266d1d0..4298df410 100644 --- a/phy/plot/tests/test_plot.py +++ b/phy/plot/tests/test_plot.py @@ -112,4 +112,16 @@ def test_boxed_complete(qtbot): view[2].hist(np.random.rand(5, 10), color=np.random.uniform(.4, .9, size=(5, 4))) - _show(qtbot, view) + # Build and show. + view.build() + view.show() + + # Change a subplot. + view[2].hist(np.random.rand(5, 10), + color=np.random.uniform(.4, .9, size=(5, 4))) + + # Rebuild and show. + view.build() + qtbot.waitForWindowShown(view.native) + + view.close() From 0bd9a57ffd07026d626869dcbe0d01bbd37591d4 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 27 Oct 2015 10:08:21 +0100 Subject: [PATCH 0475/1059] Fix bug with depth in PlotVisual --- phy/plot/tests/test_visuals.py | 6 +++++- phy/plot/visuals.py | 4 ++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/phy/plot/tests/test_visuals.py b/phy/plot/tests/test_visuals.py index 2e2456618..a0b276900 100644 --- a/phy/plot/tests/test_visuals.py +++ b/phy/plot/tests/test_visuals.py @@ -101,8 +101,12 @@ def test_plot_2(qtbot, canvas_pz): c = np.random.uniform(.5, 1, size=(n_signals, 4)) c[:, 3] = .5 + # Depth. + depth = np.linspace(0., -1., n_signals) + _test_visual(qtbot, canvas_pz, PlotVisual(), - y=y, data_bounds=[-50, 50], + y=y, depth=depth, + data_bounds=[-50, 50], plot_colors=c) diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index 8cf0a436b..8f06b41ab 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -223,6 +223,10 @@ def set_data(self, # Set the transformed position. pos_tr = self.apply_cpu_transforms(pos) + + # Depth. + depth = _get_array(depth, (n_signals,), 0) + depth = np.repeat(depth, n_samples) self.program['a_position'] = _get_pos_depth(pos_tr, depth) # Generate the signal index. From 687bdff98bc746878481c6a6c9badf822aa3215b Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 27 Oct 2015 10:15:56 +0100 Subject: [PATCH 0476/1059] Done depth in waveform view --- phy/cluster/manual/views.py | 13 ++++++++++--- phy/plot/plot.py | 6 ++++-- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 30832883a..d22ece973 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -114,13 +114,20 @@ def on_select(self, cluster_ids, spike_ids): color = colors[spike_clusters_rel[spike_ids]] # Alpha channel. color = np.c_[color, np.ones((n_spikes, 1))] - # TODO: depth + + # Depth as a function of the cluster index and masks. + m = self.masks[spike_ids, :] + depth = -0.1 - (spike_clusters_rel[:, np.newaxis] + m) + assert depth.shape == (n_spikes, self.n_channels) + depth = depth / float(n_clusters + 10.) + depth[m <= 0.25] = 0 # Plot all waveforms. for ch in range(self.n_channels): - self[ch].plot(x=t, y=w[:, :, ch], color=color) + self[ch].plot(x=t, y=w[:, :, ch], + color=color, + depth=depth[:, ch]) - # TODO: build only once, then just set data (don't recreate visuals) # TODO: more interactions in boxed interact self.build() self.update() diff --git a/phy/plot/plot.py b/phy/plot/plot.py index 2618f770e..4d706ffde 100644 --- a/phy/plot/plot.py +++ b/phy/plot/plot.py @@ -68,7 +68,7 @@ def scatter(self, x, y, color=None, size=None, marker=None): spec = dict(x=x, y=y, color=color, size=size, marker=marker) return self._set(ScatterVisual, spec) - def plot(self, x, y, color=None): + def plot(self, x, y, color=None, depth=None): x = np.atleast_2d(x) y = np.atleast_2d(y) # Validate x and y. @@ -77,8 +77,10 @@ def plot(self, x, y, color=None): n_plots, n_samples = x.shape # Get the colors. color = _get_array(color, (n_plots, 4), PlotVisual._default_color) + # Get the depth. + depth = _get_array(depth, (n_plots,), 0) # Set the spec. - spec = dict(x=x, y=y, color=color) + spec = dict(x=x, y=y, color=color, depth=depth) return self._set(PlotVisual, spec) def hist(self, data, color=None): From 488b98e58cdbcf9335fc2a8daf2cdc79f5d2d118 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 27 Oct 2015 10:52:12 +0100 Subject: [PATCH 0477/1059] Add test for _get_box_pos_size() --- phy/plot/tests/test_utils.py | 11 ++++++++++- phy/plot/utils.py | 38 ++++++++++++++++++++++++++---------- 2 files changed, 38 insertions(+), 11 deletions(-) diff --git a/phy/plot/tests/test_utils.py b/phy/plot/tests/test_utils.py index 045b51640..fb0498df5 100644 --- a/phy/plot/tests/test_utils.py +++ b/phy/plot/tests/test_utils.py @@ -22,6 +22,7 @@ _boxes_overlap, _binary_search, _get_boxes, + _get_box_pos_size, ) @@ -88,7 +89,7 @@ def test_get_boxes(): positions = [[-1, 0], [1, 0]] boxes = _get_boxes(positions) ac(boxes, [[-1, -.25, 0, .25], - [0, -.25, 1, .25]], atol=1e-4) + [+0, -.25, 1, .25]], atol=1e-4) positions = linear_positions(4) boxes = _get_boxes(positions) @@ -102,3 +103,11 @@ def test_get_boxes(): boxes = _get_boxes(positions) ae(boxes[:, 1], np.arange(.75, -1.1, -.25)) ae(boxes[:, 3], np.arange(1, -.76, -.25)) + + +def test_get_box_pos_size(): + bounds = [[-1, -.25, 0, .25], + [+0, -.25, 1, .25]] + pos, size = _get_box_pos_size(bounds) + ae(pos, [[-.5, 0], [.5, 0]]) + assert size == (.5, .25) diff --git a/phy/plot/utils.py b/phy/plot/utils.py index 97850aca4..9cda07276 100644 --- a/phy/plot/utils.py +++ b/phy/plot/utils.py @@ -136,14 +136,7 @@ def _binary_search(f, xmin, xmax, eps=1e-9): return middle -def _get_boxes(pos, margin=0): - """Generate non-overlapping boxes in NDC from a set of positions.""" - - # Get x, y. - pos = np.asarray(pos, dtype=np.float32) - x, y = pos.T - x = x[:, np.newaxis] - y = y[:, np.newaxis] +def _get_box_size(x, y, ar=.5, margin=0): # Deal with degenerate x case. xmin, xmax = x.min(), x.max() @@ -152,8 +145,6 @@ def _get_boxes(pos, margin=0): else: wmax = xmax - xmin - ar = .5 - def f1(w): """Return true if the configuration with the current box size is non-overlapping.""" @@ -165,6 +156,20 @@ def f1(w): w = w * (1 - margin) # margin h = w * ar # aspect ratio + return w, h + + +def _get_boxes(pos, size=None, margin=0): + """Generate non-overlapping boxes in NDC from a set of positions.""" + + # Get x, y. + pos = np.asarray(pos, dtype=np.float32) + x, y = pos.T + x = x[:, np.newaxis] + y = y[:, np.newaxis] + + w, h = size if size is not None else _get_box_size(x, y) + x0, y0 = x - w, y - h x1, y1 = x + w, y + h @@ -182,3 +187,16 @@ def f1(w): r = Range(from_bounds=b, to_bounds=(-1, -1, 1, 1)) return np.c_[r.apply(np.c_[x0, y0]), r.apply(np.c_[x1, y1])] + + +def _get_box_pos_size(box_bounds): + box_bounds = np.asarray(box_bounds) + x0, y0, x1, y1 = box_bounds.T + w = (x1 - x0) * .5 + h = (y1 - y0) * .5 + # All boxes must have the same size. + if not np.all(w == w[0]) or not np.all(h == h[0]): + raise ValueError("All boxes don't have the same size.") + x = (x0 + x1) * .5 + y = (y0 + y1) * .5 + return np.c_[x, y], (w[0], h[0]) From ea1eebc6b58cfaabc74d63c7e6c2ee606a7add1a Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 27 Oct 2015 11:01:47 +0100 Subject: [PATCH 0478/1059] Box bounds, pos, size can be changed in Boxed interact --- phy/plot/interact.py | 57 +++++++++++++++++++++++++++++---- phy/plot/tests/test_interact.py | 20 ++++++++++++ phy/plot/tests/test_utils.py | 6 ++++ phy/plot/utils.py | 2 +- 4 files changed, 78 insertions(+), 7 deletions(-) diff --git a/phy/plot/interact.py b/phy/plot/interact.py index b27b3dfa7..883b5cf92 100644 --- a/phy/plot/interact.py +++ b/phy/plot/interact.py @@ -14,7 +14,7 @@ from .base import BaseInteract from .transform import Scale, Range, Subplot, Clip, NDC -from .utils import _get_texture +from .utils import _get_texture, _get_boxes, _get_box_pos_size #------------------------------------------------------------------------------ @@ -125,15 +125,27 @@ class Boxed(BaseInteract): Name of the GLSL variable with the box index. """ - def __init__(self, box_bounds, box_var=None): + def __init__(self, + box_bounds=None, + box_pos=None, + box_size=None, + box_var=None): super(Boxed, self).__init__() # Name of the variable with the box index. self.box_var = box_var or 'a_box_index' - self.box_bounds = np.atleast_2d(box_bounds) - assert self.box_bounds.shape[1] == 4 - self.n_boxes = len(self.box_bounds) + # Find the box bounds if only the box positions are passed. + if box_bounds is None: + assert box_pos is not None + # This will find a good box size automatically if it is not + # specified. + box_bounds = _get_boxes(box_pos, size=box_size) + + self._box_bounds = np.atleast_2d(box_bounds) + assert self._box_bounds.shape[1] == 4 + + self.n_boxes = len(self._box_bounds) def get_shader_declarations(self): return ('#include "utils.glsl"\n\n' @@ -157,10 +169,43 @@ def get_transforms(self): def update_program(self, program): # Signal bounds (positions). - box_bounds = _get_texture(self.box_bounds, NDC, self.n_boxes, [-1, 1]) + box_bounds = _get_texture(self._box_bounds, NDC, self.n_boxes, [-1, 1]) program['u_box_bounds'] = Texture2D(box_bounds) program['n_boxes'] = self.n_boxes + # Change the box bounds, positions, or size + #-------------------------------------------------------------------------- + + @property + def box_bounds(self): + return self._box_bounds + + @box_bounds.setter + def box_bounds(self, val): + assert val.shape == (self.n_boxes, 4) + self._box_bounds = val + self.update() + + @property + def box_pos(self): + box_pos, _ = _get_box_pos_size(self._box_bounds) + return box_pos + + @box_pos.setter + def box_pos(self, val): + assert val.shape == (self.n_boxes, 2) + self.box_bounds = _get_boxes(val, size=self.box_size) + + @property + def box_size(self): + _, box_size = _get_box_pos_size(self._box_bounds) + return box_size + + @box_size.setter + def box_size(self, val): + assert len(val) == 2 + self.box_bounds = _get_boxes(self.box_pos, size=val) + class Stacked(Boxed): """Stacked interact. diff --git a/phy/plot/tests/test_interact.py b/phy/plot/tests/test_interact.py index f4f680bae..081880802 100644 --- a/phy/plot/tests/test_interact.py +++ b/phy/plot/tests/test_interact.py @@ -10,6 +10,7 @@ from itertools import product import numpy as np +from numpy.testing import assert_equal as ae from vispy.util import keys from ..base import BaseVisual @@ -121,6 +122,25 @@ def test_boxed_1(qtbot, canvas): boxed = Boxed(box_bounds=b) _create_visual(qtbot, canvas, boxed, box_index) + ae(boxed.box_bounds, b) + boxed.box_bounds = b + + # qtbot.stop() + + +def test_boxed_2(qtbot, canvas): + """Test setting the box position and size dynamically.""" + + n = 1000 + pos = np.c_[np.zeros(6), np.linspace(-1., 1., 6)] + box_index = np.repeat(np.arange(6), n, axis=0) + + boxed = Boxed(box_pos=pos) + _create_visual(qtbot, canvas, boxed, box_index) + + boxed.box_pos *= .25 + boxed.box_size = [1, .1] + # qtbot.stop() diff --git a/phy/plot/tests/test_utils.py b/phy/plot/tests/test_utils.py index fb0498df5..0a38e6872 100644 --- a/phy/plot/tests/test_utils.py +++ b/phy/plot/tests/test_utils.py @@ -13,6 +13,7 @@ import numpy as np from numpy.testing import assert_array_equal as ae from numpy.testing import assert_allclose as ac +from pytest import raises from vispy import config from phy.electrode.mea import linear_positions, staggered_positions @@ -111,3 +112,8 @@ def test_get_box_pos_size(): pos, size = _get_box_pos_size(bounds) ae(pos, [[-.5, 0], [.5, 0]]) assert size == (.5, .25) + + with raises(ValueError): + bounds = [[-1, -.25, 0, .25], + [+0, -.25, 1, .5]] + _get_box_pos_size(bounds) diff --git a/phy/plot/utils.py b/phy/plot/utils.py index 9cda07276..8b8a81840 100644 --- a/phy/plot/utils.py +++ b/phy/plot/utils.py @@ -195,7 +195,7 @@ def _get_box_pos_size(box_bounds): w = (x1 - x0) * .5 h = (y1 - y0) * .5 # All boxes must have the same size. - if not np.all(w == w[0]) or not np.all(h == h[0]): + if not np.allclose(w, w[0]) or not np.allclose(h, h[0]): raise ValueError("All boxes don't have the same size.") x = (x0 + x1) * .5 y = (y0 + y1) * .5 From c218060bfbead42a14d8e09288e3af97c38f6bc2 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 27 Oct 2015 11:24:58 +0100 Subject: [PATCH 0479/1059] Key interactions in Boxed interact --- phy/plot/interact.py | 47 +++++++++++++++++++++++++++++++++ phy/plot/tests/test_interact.py | 33 +++++++++++++++++++++++ phy/plot/tests/test_utils.py | 6 ----- phy/plot/utils.py | 5 +--- 4 files changed, 81 insertions(+), 10 deletions(-) diff --git a/phy/plot/interact.py b/phy/plot/interact.py index 883b5cf92..e23bc1e67 100644 --- a/phy/plot/interact.py +++ b/phy/plot/interact.py @@ -131,6 +131,7 @@ def __init__(self, box_size=None, box_var=None): super(Boxed, self).__init__() + self._key_pressed = None # Name of the variable with the box index. self.box_var = box_var or 'a_box_index' @@ -206,6 +207,52 @@ def box_size(self, val): assert len(val) == 2 self.box_bounds = _get_boxes(self.box_pos, size=val) + # Interaction event callbacks + #-------------------------------------------------------------------------- + + _arrows = ('Left', 'Right', 'Up', 'Down') + _pm = ('+', '-') + + def on_key_press(self, event): + """Handle key press events.""" + key = event.key + + self._key_pressed = key + + ctrl = 'Control' in event.modifiers + shift = 'Shift' in event.modifiers + + # Box scale. + if ctrl and key in self._arrows + self._pm: + coeff = 1.1 + box_size = np.array(self.box_size) + if key == 'Left': + box_size[0] /= coeff + elif key == 'Right': + box_size[0] *= coeff + elif key in ('Down', '-'): + box_size[1] /= coeff + elif key in ('Up', '+'): + box_size[1] *= coeff + self.box_size = box_size + + # Probe scale. + if shift and key in self._arrows: + coeff = 1.1 + box_pos = self.box_pos + if key == 'Left': + box_pos[:, 0] /= coeff + elif key == 'Right': + box_pos[:, 0] *= coeff + elif key == 'Down': + box_pos[:, 1] /= coeff + elif key == 'Up': + box_pos[:, 1] *= coeff + self.box_pos = box_pos + + def on_key_release(self, event): + self._key_pressed = None # pragma: no cover + class Stacked(Boxed): """Stacked interact. diff --git a/phy/plot/tests/test_interact.py b/phy/plot/tests/test_interact.py index 081880802..e92f984cc 100644 --- a/phy/plot/tests/test_interact.py +++ b/phy/plot/tests/test_interact.py @@ -11,6 +11,7 @@ import numpy as np from numpy.testing import assert_equal as ae +from numpy.testing import assert_allclose as ac from vispy.util import keys from ..base import BaseVisual @@ -125,6 +126,38 @@ def test_boxed_1(qtbot, canvas): ae(boxed.box_bounds, b) boxed.box_bounds = b + # Change box vertical size. + bs = boxed.box_size + for k in (('+', '-'), ('Up', 'Down')): + canvas.events.key_press(key=keys.Key(k[0]), modifiers=(keys.CONTROL,)) + assert boxed.box_size[1] > bs[1] + canvas.events.key_press(key=keys.Key(k[1]), modifiers=(keys.CONTROL,)) + ac(boxed.box_size[1], bs[1], atol=1e-3) + + # Change box horizontal size. + bs = boxed.box_size + canvas.events.key_press(key=keys.Key('Left'), modifiers=(keys.CONTROL,)) + assert boxed.box_size[0] < bs[0] + canvas.events.key_press(key=keys.Key('Right'), modifiers=(keys.CONTROL,)) + ac(boxed.box_size[0], bs[0], atol=1e-3) + + # Change box vertical positions. + bp = boxed.box_pos + canvas.events.key_press(key=keys.Key('Up'), modifiers=(keys.SHIFT,)) + assert np.all(np.abs(boxed.box_pos[:, 1]) > np.abs(bp[:, 1])) + canvas.events.key_press(key=keys.Key('Down'), modifiers=(keys.SHIFT,)) + ac(boxed.box_pos, bp, atol=1e-3) + + # Change box horizontal positions. + bp = boxed.box_pos + canvas.events.key_press(key=keys.Key('Left'), modifiers=(keys.SHIFT,)) + assert np.all(np.abs(boxed.box_pos[:, 0]) < np.abs(bp[:, 0])) + canvas.events.key_press(key=keys.Key('Right'), modifiers=(keys.SHIFT,)) + ac(boxed.box_pos, bp, atol=1e-3) + + # Release a key. + canvas.events.key_release(key=keys.Key('Right'), modifiers=(keys.SHIFT,)) + # qtbot.stop() diff --git a/phy/plot/tests/test_utils.py b/phy/plot/tests/test_utils.py index 0a38e6872..fb0498df5 100644 --- a/phy/plot/tests/test_utils.py +++ b/phy/plot/tests/test_utils.py @@ -13,7 +13,6 @@ import numpy as np from numpy.testing import assert_array_equal as ae from numpy.testing import assert_allclose as ac -from pytest import raises from vispy import config from phy.electrode.mea import linear_positions, staggered_positions @@ -112,8 +111,3 @@ def test_get_box_pos_size(): pos, size = _get_box_pos_size(bounds) ae(pos, [[-.5, 0], [.5, 0]]) assert size == (.5, .25) - - with raises(ValueError): - bounds = [[-1, -.25, 0, .25], - [+0, -.25, 1, .5]] - _get_box_pos_size(bounds) diff --git a/phy/plot/utils.py b/phy/plot/utils.py index 8b8a81840..77ddbaa48 100644 --- a/phy/plot/utils.py +++ b/phy/plot/utils.py @@ -194,9 +194,6 @@ def _get_box_pos_size(box_bounds): x0, y0, x1, y1 = box_bounds.T w = (x1 - x0) * .5 h = (y1 - y0) * .5 - # All boxes must have the same size. - if not np.allclose(w, w[0]) or not np.allclose(h, h[0]): - raise ValueError("All boxes don't have the same size.") x = (x0 + x1) * .5 y = (y0 + y1) * .5 - return np.c_[x, y], (w[0], h[0]) + return np.c_[x, y], (w.mean(), h.mean()) From 223c2a5163a1737dfc0d42ad34cf2fe52a3c0f57 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 27 Oct 2015 11:38:26 +0100 Subject: [PATCH 0480/1059] Test multiple on_select() in waveform view --- phy/cluster/manual/tests/test_views.py | 15 +++++++++++++-- phy/cluster/manual/views.py | 3 +-- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index b4060c9cf..ff172b96f 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -43,15 +43,26 @@ def test_waveform_view(qtbot): spike_clusters = artificial_spike_clusters(n_spikes, n_clusters) channel_positions = staggered_positions(n_channels) + # Create the view. v = WaveformView(waveforms=waveforms, masks=masks, spike_clusters=spike_clusters, channel_positions=channel_positions, ) - spike_ids = np.arange(10) + # Select some spikes. + spike_ids = np.arange(5) cluster_ids = np.unique(spike_clusters[spike_ids]) + v.on_select(cluster_ids, spike_ids) + + # Show the view. + v.show() + qtbot.waitForWindowShown(v.native) + # Select other spikes. + spike_ids = np.arange(2, 10) + cluster_ids = np.unique(spike_clusters[spike_ids]) v.on_select(cluster_ids, spike_ids) - _show(qtbot, v) + # qtbot.stop() + v.close() diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index d22ece973..4eeec5652 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -111,7 +111,7 @@ def on_select(self, cluster_ids, spike_ids): t = _get_linear_x(n_spikes, self.n_samples) # Get the colors. - color = colors[spike_clusters_rel[spike_ids]] + color = colors[spike_clusters_rel] # Alpha channel. color = np.c_[color, np.ones((n_spikes, 1))] @@ -128,7 +128,6 @@ def on_select(self, cluster_ids, spike_ids): color=color, depth=depth[:, ch]) - # TODO: more interactions in boxed interact self.build() self.update() From b0675933d471cfae4aec1b8556b770beffc0be51 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 27 Oct 2015 13:41:16 +0100 Subject: [PATCH 0481/1059] WIP: trace view --- phy/cluster/manual/tests/test_views.py | 16 ++- phy/cluster/manual/views.py | 137 +++++++++++++++++++++++-- 2 files changed, 145 insertions(+), 8 deletions(-) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index ff172b96f..df88ec744 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -11,9 +11,10 @@ from phy.io.mock import (artificial_waveforms, artificial_spike_clusters, artificial_masks, + artificial_traces, ) from phy.electrode.mea import staggered_positions -from ..views import WaveformView +from ..views import WaveformView, TraceView #------------------------------------------------------------------------------ @@ -49,7 +50,6 @@ def test_waveform_view(qtbot): spike_clusters=spike_clusters, channel_positions=channel_positions, ) - # Select some spikes. spike_ids = np.arange(5) cluster_ids = np.unique(spike_clusters[spike_ids]) @@ -66,3 +66,15 @@ def test_waveform_view(qtbot): # qtbot.stop() v.close() + + +def test_trace_view_no_spikes(qtbot): + n_samples = 5000 + n_channels = 12 + sample_rate = 10000. + + traces = artificial_traces(n_samples, n_channels) + + # Create the view. + v = TraceView(traces=traces, sample_rate=sample_rate) + _show(qtbot, v) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 4eeec5652..860eaf5a8 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -13,7 +13,7 @@ from phy.io.array import _index_of from phy.electrode.mea import linear_positions -from phy.plot import BoxedView, _get_linear_x +from phy.plot import BoxedView, StackedView, GridView, _get_linear_x from phy.plot.utils import _get_boxes, _get_array logger = logging.getLogger(__name__) @@ -148,15 +148,140 @@ def attach_to_gui(self, gui): gui.connect(self.on_cluster) -class TraceView(BoxedView): +class TraceView(StackedView): def __init__(self, traces=None, + sample_rate=None, spike_times=None, - spike_clusters=None,): - pass + spike_clusters=None, + masks=None, + n_samples_per_spike=None, + ): + + # Sample rate. + assert sample_rate > 0 + self.sample_rate = sample_rate + + # Traces. + assert traces.ndim == 2 + self.n_samples, self.n_channels = traces.shape + self.traces = traces + + # Number of samples per spike. + self.n_samples_per_spike = (n_samples_per_spike or + int(.002 * sample_rate)) + + # Spike times. + if spike_times is not None: + self.n_spikes = len(spike_times) + assert spike_times.shape == (self.n_spikes,) + self.spike_times = spike_times + + # Spike clusters. + if spike_clusters is None: + spike_clusters = np.zeros(self.n_spikes) + assert spike_clusters.shape == (self.n_spikes,) + self.spike_clusters = spike_clusters + + # Masks. + masks = _get_array(masks, (self.n_spikes, self.n_channels), 1) + assert masks.shape == (self.n_spikes, self.n_channels) + self.masks = masks + else: + self.spike_times = self.spike_clusters = self.masks = None + + # Initialize the view. + super(TraceView, self).__init__(self.n_channels) + + # TODO: choose the interval. + self.set_interval((0., .25)) + + def _load_traces(self, interval): + """Load traces in an interval (in seconds).""" + + start, end = interval + + i, j = int(self.sample_rate * start), int(self.sample_rate * end) + traces = self.traces[i:j, :] + # Detrend the traces. + m = np.mean(traces[::10, :], axis=0) + traces -= m + + # Create the plots. + return traces + + def _load_spikes(self, interval): + assert self.spike_times is not None + # Keep the spikes in the interval. + a, b = self.spike_times.searchsorted(interval) + return self.spike_times[a:b], self.spike_clusters[a:b], self.masks[a:b] + + def set_interval(self, interval): + + color = (.5, .5, .5, 1) + + # Load traces. + traces = self._load_traces(interval) + assert traces.shape[1] == self.n_channels + + # Generate the trace plots. + # TODO OPTIM: avoid the loop and generate all channel traces in + # one pass with NumPy (but need to set a_box_index manually too). + t = _get_linear_x(1, traces.shape[0]) + for ch in range(self.n_channels): + self[ch].plot(t, traces[:, ch], color=color) + + # if self.spike_times is not None: + # spike_times, spike_clusters, masks = self._load_spikes(interval) + + self.build() + self.update() -class FeatureView(BoxedView): + # # Keep the spikes in the interval. + # spikes = self.spike_ids + # spike_samples = self.model.spike_samples[spikes] + # a, b = spike_samples.searchsorted(interval) + # spikes = spikes[a:b] + # self.view.visual.n_spikes = len(spikes) + # self.view.visual.spike_ids = spikes + + + # def on_select(self, cluster_ids, spike_ids): + # n_clusters = len(cluster_ids) + # n_spikes = len(spike_ids) + # if n_spikes == 0: + # return + + # # Relative spike clusters. + # # NOTE: the order of the clusters in cluster_ids matters. + # # It will influence the relative index of the clusters, which + # # in return influence the depth. + # spike_clusters = self.spike_clusters[spike_ids] + # assert np.all(np.in1d(spike_clusters, cluster_ids)) + # spike_clusters_rel = _index_of(spike_clusters, cluster_ids) + + # # Fetch the waveforms. + # w = self.waveforms[spike_ids] + # colors = _selected_clusters_colors(n_clusters) + # t = _get_linear_x(n_spikes, self.n_samples) + + # # Get the colors. + # color = colors[spike_clusters_rel] + # # Alpha channel. + # color = np.c_[color, np.ones((n_spikes, 1))] + + # # Plot all waveforms. + # for ch in range(self.n_channels): + # self[ch].plot(x=t, y=w[:, :, ch], + # color=color, + # depth=depth[:, ch]) + + # self.build() + # self.update() + + +class FeatureView(GridView): def __init__(self, features=None, dimensions=None, @@ -165,7 +290,7 @@ def __init__(self, pass -class CorrelogramView(BoxedView): +class CorrelogramView(GridView): def __init__(self, spike_samples=None, spike_times=None, From ca5ef791f5cc60d60ae45e841144ec44656130e3 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 27 Oct 2015 13:58:40 +0100 Subject: [PATCH 0482/1059] WIP: fix plot data bounds --- phy/plot/tests/test_visuals.py | 2 +- phy/plot/visuals.py | 24 ++++++++---------------- 2 files changed, 9 insertions(+), 17 deletions(-) diff --git a/phy/plot/tests/test_visuals.py b/phy/plot/tests/test_visuals.py index a0b276900..38bae2a40 100644 --- a/phy/plot/tests/test_visuals.py +++ b/phy/plot/tests/test_visuals.py @@ -106,7 +106,7 @@ def test_plot_2(qtbot, canvas_pz): _test_visual(qtbot, canvas_pz, PlotVisual(), y=y, depth=depth, - data_bounds=[-50, 50], + data_bounds=[-1, -50, 1, 50], plot_colors=c) diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index 8f06b41ab..ae9b1fa46 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -34,20 +34,13 @@ def _get_data_bounds(data_bounds, pos): if data_bounds is None: m, M = pos.min(axis=0), pos.max(axis=0) data_bounds = [m[0], m[1], M[0], M[1]] - _check_data_bounds(data_bounds) - return data_bounds - - -def _get_data_bounds_1D(data_bounds, data): - """Generate the complete data_bounds 4-tuple from the specified 2-tuple.""" - if data_bounds is None: - data_bounds = [data.min(), data.max()] if data.size else [-1, 1] - assert len(data_bounds) == 2 - # Ensure that the data bounds are not degenerate. - if data_bounds[0] == data_bounds[1]: - data_bounds = [data_bounds[0] - 1, data_bounds[0] + 1] - ymin, ymax = data_bounds - data_bounds = [-1, ymin, 1, ymax] + data_bounds = list(data_bounds) + if data_bounds[0] == data_bounds[2]: # pragma: no cover + data_bounds[0] -= 1 + data_bounds[2] += 1 + if data_bounds[1] == data_bounds[3]: + data_bounds[1] -= 1 + data_bounds[3] += 1 _check_data_bounds(data_bounds) return data_bounds @@ -218,8 +211,7 @@ def set_data(self, pos[:, 1] = y.ravel() pos = _check_pos_2D(pos) - # Generate the complete data_bounds 4-tuple from the specified 2-tuple. - self.data_bounds = _get_data_bounds_1D(data_bounds, y) + self.data_bounds = _get_data_bounds(data_bounds, pos) # Set the transformed position. pos_tr = self.apply_cpu_transforms(pos) From 34596c04b0329eebb142aa702c3cd49286433d9a Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 27 Oct 2015 15:54:44 +0100 Subject: [PATCH 0483/1059] Move _get_padded() --- phy/io/array.py | 20 ++++++++++++++++++++ phy/io/tests/test_array.py | 11 +++++++++++ phy/traces/tests/test_waveform.py | 16 ++++------------ phy/traces/waveform.py | 22 +--------------------- 4 files changed, 36 insertions(+), 33 deletions(-) diff --git a/phy/io/array.py b/phy/io/array.py index e404cfcc0..b1942e3eb 100644 --- a/phy/io/array.py +++ b/phy/io/array.py @@ -156,6 +156,26 @@ def _pad(arr, n, dir='right'): return out +def _get_padded(data, start, end): + """Return `data[start:end]` filling in with zeros outside array bounds + + Assumes that either `start<0` or `end>len(data)` but not both. + + """ + if start < 0 and end > data.shape[0]: + raise RuntimeError() + if start < 0: + start_zeros = np.zeros((-start, data.shape[1]), + dtype=data.dtype) + return np.vstack((start_zeros, data[:end])) + elif end > data.shape[0]: + end_zeros = np.zeros((end - data.shape[0], data.shape[1]), + dtype=data.dtype) + return np.vstack((data[start:], end_zeros)) + else: + return data[start:end] + + def _in_polygon(points, polygon): """Return the points that are inside a polygon.""" from matplotlib.path import Path diff --git a/phy/io/tests/test_array.py b/phy/io/tests/test_array.py index 634c31720..8cd47d600 100644 --- a/phy/io/tests/test_array.py +++ b/phy/io/tests/test_array.py @@ -26,6 +26,7 @@ get_excerpts, _range_from_slice, _pad, + _get_padded, read_array, write_array, ) @@ -114,6 +115,16 @@ def test_pad(): _pad(arr, -1) +def test_get_padded(): + arr = np.array([1, 2, 3])[:, np.newaxis] + + with raises(RuntimeError): + ae(_get_padded(arr, -2, 5).ravel(), [1, 2, 3, 0, 0]) + ae(_get_padded(arr, 1, 2).ravel(), [2]) + ae(_get_padded(arr, 0, 5).ravel(), [1, 2, 3, 0, 0]) + ae(_get_padded(arr, -2, 3).ravel(), [0, 0, 1, 2, 3]) + + def test_unique(): """Test _unique() function""" _unique([]) diff --git a/phy/traces/tests/test_waveform.py b/phy/traces/tests/test_waveform.py index 2af203632..836aa80bd 100644 --- a/phy/traces/tests/test_waveform.py +++ b/phy/traces/tests/test_waveform.py @@ -12,8 +12,10 @@ from phy.io.mock import artificial_traces, artificial_spike_samples from phy.utils import Bunch -from ..waveform import (_slice, WaveformLoader, WaveformExtractor, - SpikeLoader, _get_padded, +from ..waveform import (_slice, + WaveformLoader, + WaveformExtractor, + SpikeLoader, ) from ..filter import bandpass_filter, apply_filter @@ -104,16 +106,6 @@ def test_extract_simple(): ae(masks_f_o, [0.5, 1., 0., 0.]) -def test_get_padded(): - arr = np.array([1, 2, 3])[:, np.newaxis] - - with raises(RuntimeError): - ae(_get_padded(arr, -2, 5).ravel(), [1, 2, 3, 0, 0]) - ae(_get_padded(arr, 1, 2).ravel(), [2]) - ae(_get_padded(arr, 0, 5).ravel(), [1, 2, 3, 0, 0]) - ae(_get_padded(arr, -2, 3).ravel(), [0, 0, 1, 2, 3]) - - #------------------------------------------------------------------------------ # Tests utility functions #------------------------------------------------------------------------------ diff --git a/phy/traces/waveform.py b/phy/traces/waveform.py index 6125fb82c..1259c75e4 100644 --- a/phy/traces/waveform.py +++ b/phy/traces/waveform.py @@ -12,7 +12,7 @@ from scipy.interpolate import interp1d from ..utils._types import _as_array, Bunch -from phy.io.array import _pad +from phy.io.array import _pad, _get_padded logger = logging.getLogger(__name__) @@ -21,26 +21,6 @@ # Waveform extractor from a connected component #------------------------------------------------------------------------------ -def _get_padded(data, start, end): - """Return `data[start:end]` filling in with zeros outside array bounds - - Assumes that either `start<0` or `end>len(data)` but not both. - - """ - if start < 0 and end > data.shape[0]: - raise RuntimeError() - if start < 0: - start_zeros = np.zeros((-start, data.shape[1]), - dtype=data.dtype) - return np.vstack((start_zeros, data[:end])) - elif end > data.shape[0]: - end_zeros = np.zeros((end - data.shape[0], data.shape[1]), - dtype=data.dtype) - return np.vstack((data[start:], end_zeros)) - else: - return data[start:end] - - class WaveformExtractor(object): """Extract waveforms after data filtering and spike detection.""" def __init__(self, From 3130dcc75670fb2d59df0216f579ea8d3a833e74 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 27 Oct 2015 15:57:51 +0100 Subject: [PATCH 0484/1059] Move utils --- phy/plot/utils.py | 70 ++++++++++++++++++++++++++++++++++++++- phy/plot/visuals.py | 81 ++++++--------------------------------------- 2 files changed, 80 insertions(+), 71 deletions(-) diff --git a/phy/plot/utils.py b/phy/plot/utils.py index 77ddbaa48..e4f133151 100644 --- a/phy/plot/utils.py +++ b/phy/plot/utils.py @@ -13,7 +13,7 @@ import numpy as np from vispy import gloo -from .transform import Range +from .transform import Range, NDC logger = logging.getLogger(__name__) @@ -197,3 +197,71 @@ def _get_box_pos_size(box_bounds): x = (x0 + x1) * .5 y = (y0 + y1) * .5 return np.c_[x, y], (w.mean(), h.mean()) + + +def _check_data_bounds(data_bounds): + assert len(data_bounds) == 4 + assert data_bounds[0] < data_bounds[2] + assert data_bounds[1] < data_bounds[3] + + +def _get_data_bounds(data_bounds, pos): + """"Prepare data bounds, possibly using min/max of the data.""" + if not len(pos): + return data_bounds or NDC + if data_bounds is None: + m, M = pos.min(axis=0), pos.max(axis=0) + data_bounds = [m[0], m[1], M[0], M[1]] + data_bounds = list(data_bounds) + if data_bounds[0] == data_bounds[2]: # pragma: no cover + data_bounds[0] -= 1 + data_bounds[2] += 1 + if data_bounds[1] == data_bounds[3]: + data_bounds[1] -= 1 + data_bounds[3] += 1 + _check_data_bounds(data_bounds) + return data_bounds + + +def _check_pos_2D(pos): + """Check position data before GPU uploading.""" + assert pos is not None + pos = np.asarray(pos, dtype=np.float32) + assert pos.ndim == 2 + return pos + + +def _get_pos_depth(pos_tr, depth): + """Prepare a (N, 3) position-depth array for GPU uploading.""" + n = pos_tr.shape[0] + pos_tr = _get_array(pos_tr, (n, 2)) + depth = _get_array(depth, (n, 1), 0) + return np.c_[pos_tr, depth] + + +def _get_hist_max(hist): + hist_max = hist.max() if hist.size else 1. + hist_max = float(hist_max) + hist_max = hist_max if hist_max > 0 else 1. + assert hist_max > 0 + return hist_max + + +def _get_index(n_items, item_size, n): + """Prepare an index attribute for GPU uploading.""" + index = np.arange(n_items) + index = np.repeat(index, item_size) + index = index.astype(np.float32) + assert index.shape == (n,) + return index + + +def _get_color(color, default): + if color is None: + color = default + assert len(color) == 4 + return color + + +def _get_linear_x(n_signals, n_samples): + return np.tile(np.linspace(-1., 1., n_samples), (n_signals, 1)) diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index ae9b1fa46..cad0bef33 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -12,8 +12,17 @@ from .base import BaseVisual from .transform import Range, GPU, NDC -from .utils import (_enable_depth_mask, _tesselate_histogram, - _get_texture, _get_array, +from .utils import (_enable_depth_mask, + _tesselate_histogram, + _get_texture, + _get_array, + _get_data_bounds, + _get_pos_depth, + _check_pos_2D, + _get_index, + _get_linear_x, + _get_hist_max, + _get_color, ) @@ -21,74 +30,6 @@ # Utils #------------------------------------------------------------------------------ -def _check_data_bounds(data_bounds): - assert len(data_bounds) == 4 - assert data_bounds[0] < data_bounds[2] - assert data_bounds[1] < data_bounds[3] - - -def _get_data_bounds(data_bounds, pos): - """"Prepare data bounds, possibly using min/max of the data.""" - if not len(pos): - return data_bounds or NDC - if data_bounds is None: - m, M = pos.min(axis=0), pos.max(axis=0) - data_bounds = [m[0], m[1], M[0], M[1]] - data_bounds = list(data_bounds) - if data_bounds[0] == data_bounds[2]: # pragma: no cover - data_bounds[0] -= 1 - data_bounds[2] += 1 - if data_bounds[1] == data_bounds[3]: - data_bounds[1] -= 1 - data_bounds[3] += 1 - _check_data_bounds(data_bounds) - return data_bounds - - -def _check_pos_2D(pos): - """Check position data before GPU uploading.""" - assert pos is not None - pos = np.asarray(pos, dtype=np.float32) - assert pos.ndim == 2 - return pos - - -def _get_pos_depth(pos_tr, depth): - """Prepare a (N, 3) position-depth array for GPU uploading.""" - n = pos_tr.shape[0] - pos_tr = _get_array(pos_tr, (n, 2)) - depth = _get_array(depth, (n, 1), 0) - return np.c_[pos_tr, depth] - - -def _get_hist_max(hist): - hist_max = hist.max() if hist.size else 1. - hist_max = float(hist_max) - hist_max = hist_max if hist_max > 0 else 1. - assert hist_max > 0 - return hist_max - - -def _get_index(n_items, item_size, n): - """Prepare an index attribute for GPU uploading.""" - index = np.arange(n_items) - index = np.repeat(index, item_size) - index = index.astype(np.float32) - assert index.shape == (n,) - return index - - -def _get_color(color, default): - if color is None: - color = default - assert len(color) == 4 - return color - - -def _get_linear_x(n_signals, n_samples): - return np.tile(np.linspace(-1., 1., n_samples), (n_signals, 1)) - - DEFAULT_COLOR = (0.03, 0.57, 0.98, .75) From 55bf664d8dc7b3771597817a0224c3868a343a19 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 27 Oct 2015 15:58:55 +0100 Subject: [PATCH 0485/1059] Move utils --- phy/plot/plot.py | 3 ++- phy/plot/tests/test_plot.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/phy/plot/plot.py b/phy/plot/plot.py index 4d706ffde..426cf526e 100644 --- a/phy/plot/plot.py +++ b/phy/plot/plot.py @@ -16,7 +16,8 @@ from .interact import Grid, Boxed, Stacked from .panzoom import PanZoom from .transform import NDC -from .visuals import _get_array, ScatterVisual, PlotVisual, HistogramVisual +from .utils import _get_array +from .visuals import ScatterVisual, PlotVisual, HistogramVisual #------------------------------------------------------------------------------ diff --git a/phy/plot/tests/test_plot.py b/phy/plot/tests/test_plot.py index 4298df410..8330ded4f 100644 --- a/phy/plot/tests/test_plot.py +++ b/phy/plot/tests/test_plot.py @@ -10,7 +10,7 @@ import numpy as np from ..plot import GridView, BoxedView, StackedView -from ..visuals import _get_linear_x +from ..utils import _get_linear_x #------------------------------------------------------------------------------ From be6b8e301aa6a26fce14dc9b8afd426852250c8d Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 27 Oct 2015 17:35:38 +0100 Subject: [PATCH 0486/1059] Refactor plot interface --- phy/plot/base.py | 2 +- phy/plot/plot.py | 338 +++++++++++++++++++++--------------- phy/plot/tests/test_plot.py | 15 +- 3 files changed, 211 insertions(+), 144 deletions(-) diff --git a/phy/plot/base.py b/phy/plot/base.py index 4b6c92a96..3291b2450 100644 --- a/phy/plot/base.py +++ b/phy/plot/base.py @@ -181,7 +181,7 @@ def attach(self, canvas): canvas.connect(self.on_key_press) def is_attached(self): - """Whether the transform is attached to a canvas.""" + """Whether the interact is attached to a canvas.""" return self._canvas is not None def on_resize(self, event): diff --git a/phy/plot/plot.py b/phy/plot/plot.py index 426cf526e..dba405eff 100644 --- a/phy/plot/plot.py +++ b/phy/plot/plot.py @@ -12,6 +12,7 @@ import numpy as np +from phy.utils import Bunch, _is_array_like from .base import BaseCanvas from .interact import Grid, Boxed, Stacked from .panzoom import PanZoom @@ -40,63 +41,111 @@ def __getitem__(self, name): # Base plotting interface #------------------------------------------------------------------------------ -class SubView(object): - def __init__(self, idx): - self.spec = {'idx': idx} - - @property - def visual_class(self): - return self.spec.get('visual_class', None) - - def _set(self, visual_class, spec): - self.spec['visual_class'] = visual_class - self.spec.update(spec) - - def __getattr__(self, name): - return self.spec[name] - - def scatter(self, x, y, color=None, size=None, marker=None): - # Validate x and y. - assert x.ndim == y.ndim == 1 - assert x.shape == y.shape - n = x.shape[0] - # Set the color and size. - color = _get_array(color, (n, 4), ScatterVisual._default_color) - size = _get_array(size, (n, 1), ScatterVisual._default_marker_size) - # Default marker. - marker = marker or ScatterVisual._default_marker - # Set the spec. - spec = dict(x=x, y=y, color=color, size=size, marker=marker) - return self._set(ScatterVisual, spec) - - def plot(self, x, y, color=None, depth=None): - x = np.atleast_2d(x) - y = np.atleast_2d(y) - # Validate x and y. - assert x.ndim == y.ndim == 2 - assert x.shape == y.shape - n_plots, n_samples = x.shape - # Get the colors. - color = _get_array(color, (n_plots, 4), PlotVisual._default_color) - # Get the depth. - depth = _get_array(depth, (n_plots,), 0) - # Set the spec. - spec = dict(x=x, y=y, color=color, depth=depth) - return self._set(PlotVisual, spec) - - def hist(self, data, color=None): - # Validate data. - if data.ndim == 1: - data = data[np.newaxis, :] - assert data.ndim == 2 - n_hists, n_samples = data.shape - # Get the colors. - color = _get_array(color, (n_hists, 4), HistogramVisual._default_color) - spec = dict(data=data, color=color) - return self._set(HistogramVisual, spec) - - def __repr__(self): - return str(self.spec) # pragma: no cover +def _prepare_scatter(x, y, color=None, size=None, marker=None): + # Validate x and y. + assert x.ndim == y.ndim == 1 + assert x.shape == y.shape + n = x.shape[0] + # Set the color and size. + color = _get_array(color, (n, 4), ScatterVisual._default_color) + size = _get_array(size, (n, 1), ScatterVisual._default_marker_size) + # Default marker. + marker = marker or ScatterVisual._default_marker + return dict(x=x, y=y, color=color, size=size, marker=marker) + + +def _prepare_plot(x, y, color=None, depth=None): + x = np.atleast_2d(x) + y = np.atleast_2d(y) + # Validate x and y. + assert x.ndim == y.ndim == 2 + assert x.shape == y.shape + n_plots, n_samples = x.shape + # Get the colors. + color = _get_array(color, (n_plots, 4), PlotVisual._default_color) + # Get the depth. + depth = _get_array(depth, (n_plots,), 0) + return dict(x=x, y=y, color=color, depth=depth) + + +def _prepare_hist(data, color=None): + # Validate data. + if data.ndim == 1: + data = data[np.newaxis, :] + assert data.ndim == 2 + n_hists, n_samples = data.shape + # Get the colors. + color = _get_array(color, (n_hists, 4), HistogramVisual._default_color) + return dict(data=data, color=color) + + +def _prepare_box_index(box_index, n): + if not _is_array_like(box_index): + box_index = np.tile(box_index, (n, 1)) + box_index = np.asarray(box_index, dtype=np.int32) + assert box_index.ndim == 2 + assert box_index.shape[0] == n + return box_index + + +def _build_scatter(items): + """Build scatter items and return parameters for `set_data()`.""" + + ac = Accumulator() + for item in items: + # The item data has already been prepared. + n = len(item.data.x) + ac['pos'] = np.c_[item.data.x, item.data.y] + ac['color'] = item.data.color + ac['size'] = item.data.size + ac['box_index'] = _prepare_box_index(item.box_index, n) + + return (dict(pos=ac['pos'], color=ac['color'], size=ac['size']), + ac['box_index']) + + +def _build_plot(items): + """Build all plot items and return parameters for `set_data()`.""" + + ac = Accumulator() + for item in items: + n = item.data.x.size + ac['x'] = item.data.x + ac['y'] = item.data.y + ac['plot_colors'] = item.data.color + ac['box_index'] = _prepare_box_index(item.box_index, n) + + return (dict(x=ac['x'], y=ac['y'], plot_colors=ac['plot_colors']), + ac['box_index']) + + +def _build_histogram(items): + """Build all histogram items and return parameters for `set_data()`.""" + + ac = Accumulator() + for item in items: + n = item.data.data.size + ac['data'] = item.data.data + ac['hist_colors'] = item.data.color + # NOTE: the `6 * ` comes from the histogram tesselation. + ac['box_index'] = _prepare_box_index(item.box_index, 6 * n) + + return (dict(hist=ac['data'], hist_colors=ac['hist_colors']), + ac['box_index']) + + +class ViewItem(Bunch): + def __init__(self, base, visual_class=None, data=None, box_index=None): + super(ViewItem, self).__init__(visual_class=visual_class, + data=Bunch(data), + box_index=box_index, + to_build=True, + ) + self._base = base + + def set_data(self, **kwargs): + self.data.update(kwargs) + self.to_build = True class BaseView(BaseCanvas): @@ -105,28 +154,39 @@ def __init__(self, interacts): # Attach the passed interacts to the current canvas. for interact in interacts: interact.attach(self) - self.subviews = {} + self._items = [] # List of view items instances. self._visuals = {} # To override # ------------------------------------------------------------------------- - def iter_index(self): - raise NotImplementedError() + def __getitem__(self, idx): + class _Proxy(object): + def scatter(s, *args, **kwargs): + kwargs['box_index'] = idx + return self.scatter(*args, **kwargs) - # Internal methods - # ------------------------------------------------------------------------- + def plot(s, *args, **kwargs): + kwargs['box_index'] = idx + return self.plot(*args, **kwargs) - def iter_subviews(self): - for idx in self.iter_index(): - sv = self.subviews.get(idx, None) - if sv: - yield sv + def hist(s, *args, **kwargs): + kwargs['box_index'] = idx + return self.hist(*args, **kwargs) - def __getitem__(self, idx): - sv = SubView(idx) - self.subviews[idx] = sv - return sv + return _Proxy() + + def _iter_items(self): + """Iterate over all items.""" + for item in self._items: + yield item + + def _visuals_to_build(self): + visual_classes = set() + for item in self._items: + if item.to_build: + visual_classes.add(item.visual_class) + return visual_classes def _get_visual(self, key): if key not in self._visuals: @@ -143,67 +203,76 @@ def _get_visual(self, key): self._visuals[key] = v return self._visuals[key] - def _build_scatter(self, subviews, marker): - """Build all scatter subviews with the same marker type.""" - - ac = Accumulator() - for sv in subviews: - assert sv.marker == marker - n = len(sv.x) - ac['pos'] = np.c_[sv.x, sv.y] - ac['color'] = sv.color - ac['size'] = sv.size - ac['box_index'] = np.tile(sv.idx, (n, 1)) - - v = self._get_visual((ScatterVisual, marker)) - v.set_data(pos=ac['pos'], color=ac['color'], size=ac['size']) - v.program['a_box_index'] = ac['box_index'] - - def _build_plot(self, subviews): - """Build all plot subviews.""" - - ac = Accumulator() - for sv in subviews: - n = sv.x.size - ac['x'] = sv.x - ac['y'] = sv.y - ac['plot_colors'] = sv.color - ac['box_index'] = np.tile(sv.idx, (n, 1)) - - v = self._get_visual(PlotVisual) - v.set_data(x=ac['x'], y=ac['y'], plot_colors=ac['plot_colors']) - v.program['a_box_index'] = ac['box_index'] - - def _build_histogram(self, subviews): - """Build all histogram subviews.""" - - ac = Accumulator() - for sv in subviews: - n = sv.data.size - ac['data'] = sv.data - ac['hist_colors'] = sv.color - # NOTE: the `6 * ` comes from the histogram tesselation. - ac['box_index'] = np.tile(sv.idx, (6 * n, 1)) - - v = self._get_visual(HistogramVisual) - v.set_data(hist=ac['data'], hist_colors=ac['hist_colors']) - v.program['a_box_index'] = ac['box_index'] + # Public methods + # ------------------------------------------------------------------------- + + def plot(self, *args, **kwargs): + box_index = kwargs.pop('box_index', None) + data = _prepare_plot(*args, **kwargs) + item = ViewItem(self, visual_class=PlotVisual, + data=data, box_index=box_index) + self._items.append(item) + return item + + def scatter(self, *args, **kwargs): + box_index = kwargs.pop('box_index', None) + data = _prepare_scatter(*args, **kwargs) + item = ViewItem(self, visual_class=ScatterVisual, + data=data, box_index=box_index) + self._items.append(item) + return item + + def hist(self, *args, **kwargs): + box_index = kwargs.pop('box_index', None) + data = _prepare_hist(*args, **kwargs) + item = ViewItem(self, visual_class=HistogramVisual, + data=data, box_index=box_index) + self._items.append(item) + return item def build(self): """Build all visuals.""" - # TODO: fix a bug where an old subplot is not deleted if it - # is changed to another type, and there is no longer any subplot - # of the old type. The old visual should be delete or hidden. - for visual_class, subviews in groupby(self.iter_subviews(), - lambda sv: sv.visual_class): + visuals_to_build = self._visuals_to_build() + + for visual_class, items in groupby(self._iter_items(), + lambda item: item.visual_class): + items = list(items) + + # Skip visuals that do not need to be built. + if visual_class not in visuals_to_build: + continue + + # Histogram. + if visual_class == HistogramVisual: + data, box_index = _build_histogram(items) + v = self._get_visual(HistogramVisual) + v.set_data(**data) + v.program['a_box_index'] = box_index + for item in items: + item.to_build = False + + # Scatter. if visual_class == ScatterVisual: - for marker, subviews_scatter in groupby(subviews, - lambda sv: sv.marker): - self._build_scatter(subviews_scatter, marker) - elif visual_class == PlotVisual: - self._build_plot(subviews) - elif visual_class == HistogramVisual: - self._build_histogram(subviews) + items_grouped = groupby(items, lambda item: item.data.marker) + for marker, items_scatter in items_grouped: + items_scatter = list(items_scatter) + data, box_index = _build_scatter(items_scatter) + v = self._get_visual((ScatterVisual, marker)) + v.set_data(**data) + v.program['a_box_index'] = box_index + for item in items_scatter: + item.to_build = False + + # Plot. + if visual_class == PlotVisual: + data, box_index = _build_plot(items) + v = self._get_visual(PlotVisual) + v.set_data(**data) + v.program['a_box_index'] = box_index + for item in items: + item.to_build = False + + self.update() #------------------------------------------------------------------------------ @@ -217,11 +286,6 @@ def __init__(self, n_rows, n_cols): interacts = [Grid(n_rows, n_cols), pz] super(GridView, self).__init__(interacts) - def iter_index(self): - for i in range(self.n_rows): - for j in range(self.n_cols): - yield (i, j) - class BoxedView(BaseView): def __init__(self, box_bounds): @@ -231,10 +295,6 @@ def __init__(self, box_bounds): interacts = [self._boxed, self._pz] super(BoxedView, self).__init__(interacts) - def iter_index(self): - for i in range(self.n_plots): - yield i - class StackedView(BaseView): def __init__(self, n_plots): @@ -242,7 +302,3 @@ def __init__(self, n_plots): pz = PanZoom(aspect=None, constrain_bounds=NDC) interacts = [Stacked(n_plots, margin=.1), pz] super(StackedView, self).__init__(interacts) - - def iter_index(self): - for i in range(self.n_plots): - yield i diff --git a/phy/plot/tests/test_plot.py b/phy/plot/tests/test_plot.py index 8330ded4f..16df6a79a 100644 --- a/phy/plot/tests/test_plot.py +++ b/phy/plot/tests/test_plot.py @@ -42,8 +42,14 @@ def test_grid_scatter(qtbot): view[1, 0].scatter(x, y, size=np.random.uniform(5, 20, size=n)) view[1, 1] + + # Multiple scatters in the same subplot. + view[1, 2].scatter(x[2::6], y[2::6], marker='asterisk', + color=(0, 1, 0, .25), size=20) view[1, 2].scatter(x[::5], y[::5], marker='heart', - color=(1, 0, 0, .25), size=20) + color=(1, 0, 0, .35), size=50) + view[1, 2].scatter(x[1::3], y[1::3], marker='heart', + color=(1, 0, 1, .35), size=30) _show(qtbot, view) @@ -92,9 +98,14 @@ def test_stacked_complete(qtbot): t = _get_linear_x(1, 1000).ravel() view[0].scatter(*np.random.randn(2, 100)) - view[1].plot(t, np.sin(20 * t), color=(1, 0, 0, 1)) + + # Different types of visuals in the same subplot. view[2].hist(np.random.rand(5, 10), color=np.random.uniform(.4, .9, size=(5, 4))) + view[2].plot(t, np.sin(20 * t), color=(1, 0, 0, 1)) + + v = view[1].plot(t, np.sin(20 * t), color=(1, 0, 0, 1)) + v.set_data(color=(0, 1, 0, 1)) _show(qtbot, view) From 913ca83f0a166673668467f4173076aab2ee2e95 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 27 Oct 2015 17:43:02 +0100 Subject: [PATCH 0487/1059] Add comments --- phy/plot/plot.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/phy/plot/plot.py b/phy/plot/plot.py index dba405eff..b074a3d5c 100644 --- a/phy/plot/plot.py +++ b/phy/plot/plot.py @@ -135,6 +135,8 @@ def _build_histogram(items): class ViewItem(Bunch): + """A visual item that will be rendered in batch with other view items + of the same type.""" def __init__(self, base, visual_class=None, data=None, box_index=None): super(ViewItem, self).__init__(visual_class=visual_class, data=Bunch(data), @@ -149,6 +151,8 @@ def set_data(self, **kwargs): class BaseView(BaseCanvas): + """High-level plotting canvas.""" + def __init__(self, interacts): super(BaseView, self).__init__() # Attach the passed interacts to the current canvas. @@ -177,11 +181,12 @@ def hist(s, *args, **kwargs): return _Proxy() def _iter_items(self): - """Iterate over all items.""" + """Iterate over all view items.""" for item in self._items: yield item def _visuals_to_build(self): + """Return the set of visual classes that need to be rebuilt.""" visual_classes = set() for item in self._items: if item.to_build: @@ -189,6 +194,7 @@ def _visuals_to_build(self): return visual_classes def _get_visual(self, key): + """Create or return a visual from its class or tuple (class, param).""" if key not in self._visuals: # Create the visual. if isinstance(key, tuple): @@ -207,6 +213,7 @@ def _get_visual(self, key): # ------------------------------------------------------------------------- def plot(self, *args, **kwargs): + """Add a line plot.""" box_index = kwargs.pop('box_index', None) data = _prepare_plot(*args, **kwargs) item = ViewItem(self, visual_class=PlotVisual, @@ -215,6 +222,7 @@ def plot(self, *args, **kwargs): return item def scatter(self, *args, **kwargs): + """Add a scatter plot.""" box_index = kwargs.pop('box_index', None) data = _prepare_scatter(*args, **kwargs) item = ViewItem(self, visual_class=ScatterVisual, @@ -223,6 +231,7 @@ def scatter(self, *args, **kwargs): return item def hist(self, *args, **kwargs): + """Add a histogram plot.""" box_index = kwargs.pop('box_index', None) data = _prepare_hist(*args, **kwargs) item = ViewItem(self, visual_class=HistogramVisual, @@ -243,6 +252,7 @@ def build(self): continue # Histogram. + # TODO: refactor this (DRY). if visual_class == HistogramVisual: data, box_index = _build_histogram(items) v = self._get_visual(HistogramVisual) @@ -280,6 +290,7 @@ def build(self): #------------------------------------------------------------------------------ class GridView(BaseView): + """A 2D grid with clipping.""" def __init__(self, n_rows, n_cols): self.n_rows, self.n_cols = n_rows, n_cols pz = PanZoom(aspect=None, constrain_bounds=NDC) @@ -288,6 +299,7 @@ def __init__(self, n_rows, n_cols): class BoxedView(BaseView): + """Subplots at arbitrary positions""" def __init__(self, box_bounds): self.n_plots = len(box_bounds) self._boxed = Boxed(box_bounds) @@ -297,6 +309,7 @@ def __init__(self, box_bounds): class StackedView(BaseView): + """Stacked subplots""" def __init__(self, n_plots): self.n_plots = n_plots pz = PanZoom(aspect=None, constrain_bounds=NDC) From b1576c0469bf065dea5de7dccd1b3f9d92978b83 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 27 Oct 2015 18:02:05 +0100 Subject: [PATCH 0488/1059] Support plots with different numbers of samples --- phy/plot/plot.py | 20 ++++++++++++++------ phy/plot/tests/test_plot.py | 2 +- phy/plot/visuals.py | 5 ++++- 3 files changed, 19 insertions(+), 8 deletions(-) diff --git a/phy/plot/plot.py b/phy/plot/plot.py index b074a3d5c..2b9889939 100644 --- a/phy/plot/plot.py +++ b/phy/plot/plot.py @@ -264,6 +264,7 @@ def build(self): # Scatter. if visual_class == ScatterVisual: items_grouped = groupby(items, lambda item: item.data.marker) + # One visual per marker type. for marker, items_scatter in items_grouped: items_scatter = list(items_scatter) data, box_index = _build_scatter(items_scatter) @@ -275,12 +276,19 @@ def build(self): # Plot. if visual_class == PlotVisual: - data, box_index = _build_plot(items) - v = self._get_visual(PlotVisual) - v.set_data(**data) - v.program['a_box_index'] = box_index - for item in items: - item.to_build = False + items_grouped = groupby(items, + lambda item: item.data.x.shape[1]) + # HACK: one visual per number of samples, because currently + # a PlotVisual only accepts a regular (n_plots, n_samples) + # array as input. + for n_samples, items_plot in items_grouped: + items_plot = list(items_plot) + data, box_index = _build_plot(items_plot) + v = self._get_visual((PlotVisual, n_samples)) + v.set_data(**data) + v.program['a_box_index'] = box_index + for item in items_plot: + item.to_build = False self.update() diff --git a/phy/plot/tests/test_plot.py b/phy/plot/tests/test_plot.py index 16df6a79a..a04a77f3b 100644 --- a/phy/plot/tests/test_plot.py +++ b/phy/plot/tests/test_plot.py @@ -104,7 +104,7 @@ def test_stacked_complete(qtbot): color=np.random.uniform(.4, .9, size=(5, 4))) view[2].plot(t, np.sin(20 * t), color=(1, 0, 0, 1)) - v = view[1].plot(t, np.sin(20 * t), color=(1, 0, 0, 1)) + v = view[1].plot(t[::2], np.sin(20 * t[::2]), color=(1, 0, 0, 1)) v.set_data(color=(0, 1, 0, 1)) _show(qtbot, view) diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index cad0bef33..4662f47ae 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -116,9 +116,10 @@ class PlotVisual(BaseVisual): gl_primitive_type = 'line_strip' _default_color = DEFAULT_COLOR - def __init__(self): + def __init__(self, n_samples=None): super(PlotVisual, self).__init__() self.data_bounds = NDC + self.n_samples = n_samples _enable_depth_mask() def get_transforms(self): @@ -144,6 +145,8 @@ def set_data(self, assert x.ndim == 2 assert x.shape == y.shape n_signals, n_samples = x.shape + if self.n_samples: + assert n_samples == self.n_samples n = n_signals * n_samples # Generate the (n, 2) pos array. From 1eeb787bf45bcb36bee3e4a2d318ed32a9ffe16b Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 27 Oct 2015 18:19:26 +0100 Subject: [PATCH 0489/1059] WIP: trace view --- phy/cluster/manual/tests/test_views.py | 61 +++++++++++++- phy/cluster/manual/views.py | 109 +++++++++++++------------ phy/plot/plot.py | 9 +- 3 files changed, 123 insertions(+), 56 deletions(-) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index df88ec744..e7691becc 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -7,14 +7,17 @@ #------------------------------------------------------------------------------ import numpy as np +from numpy.testing import assert_equal as ae +from pytest import raises from phy.io.mock import (artificial_waveforms, artificial_spike_clusters, + artificial_spike_samples, artificial_masks, artificial_traces, ) from phy.electrode.mea import staggered_positions -from ..views import WaveformView, TraceView +from ..views import WaveformView, TraceView, _extract_wave #------------------------------------------------------------------------------ @@ -29,6 +32,34 @@ def _show(qtbot, view, stop=False): view.close() +#------------------------------------------------------------------------------ +# Test utils +#------------------------------------------------------------------------------ + +def test_extract_wave(): + traces = np.arange(30).reshape((6, 5)) + mask = np.array([0, 1, 1, .5, 0]) + wave_len = 4 + + with raises(ValueError): + _extract_wave(traces, -1, mask, wave_len) + + with raises(ValueError): + _extract_wave(traces, 20, mask, wave_len) + + ae(_extract_wave(traces, 0, mask, wave_len)[0], + [[0, 0, 0], [0, 0, 0], [1, 2, 3], [6, 7, 8]]) + + ae(_extract_wave(traces, 1, mask, wave_len)[0], + [[0, 0, 0], [1, 2, 3], [6, 7, 8], [11, 12, 13]]) + + ae(_extract_wave(traces, 2, mask, wave_len)[0], + [[1, 2, 3], [6, 7, 8], [11, 12, 13], [16, 17, 18]]) + + ae(_extract_wave(traces, 5, mask, wave_len)[0], + [[16, 17, 18], [21, 22, 23], [0, 0, 0], [0, 0, 0]]) + + #------------------------------------------------------------------------------ # Test views #------------------------------------------------------------------------------ @@ -69,12 +100,36 @@ def test_waveform_view(qtbot): def test_trace_view_no_spikes(qtbot): - n_samples = 5000 + n_samples = 1000 n_channels = 12 - sample_rate = 10000. + sample_rate = 2000. traces = artificial_traces(n_samples, n_channels) # Create the view. v = TraceView(traces=traces, sample_rate=sample_rate) _show(qtbot, v) + + +def test_trace_view_spikes(qtbot): + n_samples = 1000 + n_channels = 12 + sample_rate = 2000. + n_spikes = 20 + n_clusters = 3 + + traces = artificial_traces(n_samples, n_channels) + spike_times = artificial_spike_samples(n_spikes) / sample_rate + # spike_times = [.1, .2] + spike_clusters = artificial_spike_clusters(n_spikes, n_clusters) + masks = artificial_masks(n_spikes, n_channels) + + # Create the view. + v = TraceView(traces=traces, + sample_rate=sample_rate, + spike_times=spike_times, + spike_clusters=spike_clusters, + masks=masks, + n_samples_per_spike=6, + ) + _show(qtbot, v) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 860eaf5a8..af2fb7fbb 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -11,9 +11,10 @@ import numpy as np -from phy.io.array import _index_of +from phy.io.array import _index_of, _get_padded from phy.electrode.mea import linear_positions -from phy.plot import BoxedView, StackedView, GridView, _get_linear_x +from phy.plot import (BoxedView, StackedView, GridView, + _get_linear_x) from phy.plot.utils import _get_boxes, _get_array logger = logging.getLogger(__name__) @@ -23,7 +24,6 @@ # Utils # ----------------------------------------------------------------------------- - # Default color map for the selected clusters. _COLORMAP = np.array([[8, 146, 252], [255, 2, 2], @@ -51,6 +51,24 @@ def _selected_clusters_colors(n_clusters=None): return colors[:n_clusters, ...] / 255. +def _extract_wave(traces, spk, mask, wave_len=None): + n_samples, n_channels = traces.shape + if not (0 <= spk < n_samples): + raise ValueError() + assert mask.shape == (n_channels,) + channels = np.nonzero(mask > .1)[0] + # There should be at least one non-masked channel. + if not len(channels): + return + i = spk - wave_len // 2 + j = spk + wave_len // 2 + a, b = max(0, i), min(j, n_samples - 1) + data = traces[a:b, channels] + data = _get_padded(data, i - a, i - a + wave_len) + assert data.shape == (wave_len, len(channels)) + return data, channels + + # ----------------------------------------------------------------------------- # Views # ----------------------------------------------------------------------------- @@ -123,6 +141,7 @@ def on_select(self, cluster_ids, spike_ids): depth[m <= 0.25] = 0 # Plot all waveforms. + # TODO: optim: avoid the loop. for ch in range(self.n_channels): self[ch].plot(x=t, y=w[:, :, ch], color=color, @@ -173,6 +192,7 @@ def __init__(self, # Spike times. if spike_times is not None: + spike_times = np.asarray(spike_times) self.n_spikes = len(spike_times) assert spike_times.shape == (self.n_spikes,) self.spike_times = spike_times @@ -219,67 +239,56 @@ def _load_spikes(self, interval): def set_interval(self, interval): + start, end = interval color = (.5, .5, .5, 1) + dt = 1. / self.sample_rate + # Load traces. traces = self._load_traces(interval) + n_samples = traces.shape[0] assert traces.shape[1] == self.n_channels + m, M = traces.min(), traces.max() + data_bounds = [start, m, end, M] + # Generate the trace plots. # TODO OPTIM: avoid the loop and generate all channel traces in # one pass with NumPy (but need to set a_box_index manually too). - t = _get_linear_x(1, traces.shape[0]) + # t = _get_linear_x(1, traces.shape[0]) + t = start + np.arange(n_samples) * dt for ch in range(self.n_channels): - self[ch].plot(t, traces[:, ch], color=color) - - # if self.spike_times is not None: - # spike_times, spike_clusters, masks = self._load_spikes(interval) + self[ch].plot(t, traces[:, ch], color=color, + data_bounds=data_bounds) + + # Display the spikes. + if self.spike_times is not None: + wave_len = self.n_samples_per_spike + spike_times, spike_clusters, masks = self._load_spikes(interval) + n_spikes = len(spike_times) + dt = 1. / float(self.sample_rate) + dur_spike = wave_len * dt + trace_start = int(self.sample_rate * start) + + # ac = Accumulator() + for i in range(n_spikes): + sample_rel = (int(spike_times[i] * self.sample_rate) - + trace_start) + mask = self.masks[i] + # clu = spike_clusters[i] + w, ch = _extract_wave(traces, sample_rel, mask, wave_len) + n_ch = len(ch) + t0 = spike_times[i] - dur_spike / 2. + color = (1, 0, 0, 1) + box_index = np.repeat(ch[:, np.newaxis], wave_len, axis=0) + t = t0 + dt * np.arange(wave_len) + t = np.tile(t, (n_ch, 1)) + self.plot(t, w.T, color=color, box_index=box_index, + data_bounds=data_bounds) self.build() self.update() - # # Keep the spikes in the interval. - # spikes = self.spike_ids - # spike_samples = self.model.spike_samples[spikes] - # a, b = spike_samples.searchsorted(interval) - # spikes = spikes[a:b] - # self.view.visual.n_spikes = len(spikes) - # self.view.visual.spike_ids = spikes - - - # def on_select(self, cluster_ids, spike_ids): - # n_clusters = len(cluster_ids) - # n_spikes = len(spike_ids) - # if n_spikes == 0: - # return - - # # Relative spike clusters. - # # NOTE: the order of the clusters in cluster_ids matters. - # # It will influence the relative index of the clusters, which - # # in return influence the depth. - # spike_clusters = self.spike_clusters[spike_ids] - # assert np.all(np.in1d(spike_clusters, cluster_ids)) - # spike_clusters_rel = _index_of(spike_clusters, cluster_ids) - - # # Fetch the waveforms. - # w = self.waveforms[spike_ids] - # colors = _selected_clusters_colors(n_clusters) - # t = _get_linear_x(n_spikes, self.n_samples) - - # # Get the colors. - # color = colors[spike_clusters_rel] - # # Alpha channel. - # color = np.c_[color, np.ones((n_spikes, 1))] - - # # Plot all waveforms. - # for ch in range(self.n_channels): - # self[ch].plot(x=t, y=w[:, :, ch], - # color=color, - # depth=depth[:, ch]) - - # self.build() - # self.update() - class FeatureView(GridView): def __init__(self, diff --git a/phy/plot/plot.py b/phy/plot/plot.py index 2b9889939..d49486385 100644 --- a/phy/plot/plot.py +++ b/phy/plot/plot.py @@ -54,7 +54,7 @@ def _prepare_scatter(x, y, color=None, size=None, marker=None): return dict(x=x, y=y, color=color, size=size, marker=marker) -def _prepare_plot(x, y, color=None, depth=None): +def _prepare_plot(x, y, color=None, depth=None, data_bounds=None): x = np.atleast_2d(x) y = np.atleast_2d(y) # Validate x and y. @@ -65,7 +65,7 @@ def _prepare_plot(x, y, color=None, depth=None): color = _get_array(color, (n_plots, 4), PlotVisual._default_color) # Get the depth. depth = _get_array(depth, (n_plots,), 0) - return dict(x=x, y=y, color=color, depth=depth) + return dict(x=x, y=y, color=color, depth=depth, data_bounds=data_bounds) def _prepare_hist(data, color=None): @@ -115,7 +115,10 @@ def _build_plot(items): ac['plot_colors'] = item.data.color ac['box_index'] = _prepare_box_index(item.box_index, n) - return (dict(x=ac['x'], y=ac['y'], plot_colors=ac['plot_colors']), + return (dict(x=ac['x'], y=ac['y'], + plot_colors=ac['plot_colors'], + data_bounds=item.data.data_bounds, + ), ac['box_index']) From e353cf1af80aed9d52ec574506e39facec2a757d Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 28 Oct 2015 09:00:30 +0100 Subject: [PATCH 0490/1059] Add get_post_transforms() --- phy/plot/base.py | 20 +++++++++++++++++--- phy/plot/tests/test_base.py | 5 +++++ phy/plot/transform.py | 6 +++++- 3 files changed, 27 insertions(+), 4 deletions(-) diff --git a/phy/plot/base.py b/phy/plot/base.py index 3291b2450..7fdde335d 100644 --- a/phy/plot/base.py +++ b/phy/plot/base.py @@ -35,8 +35,8 @@ class BaseVisual(object): * `get_shaders()`: return the vertex and fragment shaders, or just `shader_name` for built-in shaders * `get_transforms()`: return a list of `Transform` instances, which - can act on the CPU or the GPU. The interact's transforms will be - appended to that list when the visual is attached to the canvas. + * `get_post_transforms()`: return a GLSL snippet to insert after + all transforms in the vertex shader. * `set_data()`: has access to `self.program`. Must be called after `attach()`. @@ -52,13 +52,24 @@ def __init__(self): # ------------------------------------------------------------------------- def get_shaders(self): + """Return the vertex and fragment shader code.""" assert self.shader_name return (_load_shader(self.shader_name + '.vert'), _load_shader(self.shader_name + '.frag')) def get_transforms(self): + """Return the list of transforms for the visual. + + There needs to be one and exactly one instance of `GPU()`. + + """ return [GPU()] + def get_post_transforms(self): + """Return a GLSL snippet to insert after all transforms in the + vertex shader.""" + return '' + def set_data(self): """Set data to the program. @@ -253,7 +264,10 @@ def build_program(visual, interacts=()): vertex, fragment = visual.get_shaders() # Get the GLSL snippet to insert before the transformations. pre = '\n'.join(interact.get_pre_transforms() for interact in interacts) - vertex, fragment = transform_chain.insert_glsl(vertex, fragment, pre) + # GLSL snippet to insert after all transformations. + post = visual.get_post_transforms() + vertex, fragment = transform_chain.insert_glsl(vertex, fragment, + pre, post) # Insert shader declarations using the interacts (if any). if interacts: diff --git a/phy/plot/tests/test_base.py b/phy/plot/tests/test_base.py index be4d51449..1f2f12b35 100644 --- a/phy/plot/tests/test_base.py +++ b/phy/plot/tests/test_base.py @@ -145,6 +145,11 @@ def get_transforms(self): ), ] + def get_post_transforms(self): + return """ + gl_Position.y += 1; + """ + def set_data(self): data = np.random.uniform(0, 20, (1000, 2)).astype(np.float32) self.program['a_position'] = self.apply_cpu_transforms(data) diff --git a/phy/plot/transform.py b/phy/plot/transform.py index df015fe2c..6c01fcdc0 100644 --- a/phy/plot/transform.py +++ b/phy/plot/transform.py @@ -285,9 +285,12 @@ def apply(self, arr): arr = t.apply(arr) return arr - def insert_glsl(self, vertex, fragment, pre_transforms=''): + def insert_glsl(self, vertex, fragment, + pre_transforms='', post_transforms=''): """Generate the GLSL code of the transform chain.""" + # TODO: move this to base.py + # Find the place where to insert the GLSL snippet. # This is "gl_Position = transform(data_var_name);" where # data_var_name is typically an attribute. @@ -322,6 +325,7 @@ def insert_glsl(self, vertex, fragment, pre_transforms=''): continue vs_insert += t.glsl(temp_var) + '\n' vs_insert += 'gl_Position = vec4({}, 0., 1.);\n'.format(temp_var) + vs_insert += post_transforms + '\n' # Clipping. clip = self.get('Clip') From cd2ccf1c4a78d1c954f0c58e6239502c0ddbb576 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 28 Oct 2015 09:08:08 +0100 Subject: [PATCH 0491/1059] Move insert_glsl() function to plot.base module --- phy/plot/base.py | 81 ++++++++++++++++++++++++++++++-- phy/plot/tests/test_base.py | 37 ++++++++++++++- phy/plot/tests/test_transform.py | 21 --------- phy/plot/transform.py | 70 --------------------------- 4 files changed, 113 insertions(+), 96 deletions(-) diff --git a/phy/plot/base.py b/phy/plot/base.py index 7fdde335d..b28baec49 100644 --- a/phy/plot/base.py +++ b/phy/plot/base.py @@ -8,17 +8,26 @@ #------------------------------------------------------------------------------ import logging +import re from vispy import gloo from vispy.app import Canvas -from .transform import TransformChain, GPU +from .transform import TransformChain, GPU, Clip from .utils import _load_shader from phy.utils import EventEmitter logger = logging.getLogger(__name__) +#------------------------------------------------------------------------------ +# Utils +#------------------------------------------------------------------------------ + +def indent(text): + return '\n'.join(' ' + l.strip() for l in text.splitlines()) + + #------------------------------------------------------------------------------ # Base spike visual #------------------------------------------------------------------------------ @@ -243,6 +252,72 @@ def on_draw(self, e): # Build program with interacts #------------------------------------------------------------------------------ +def insert_glsl(transform_chain, vertex, fragment, + pre_transforms='', post_transforms=''): + """Generate the GLSL code of the transform chain.""" + + # TODO: move this to base.py + + # Find the place where to insert the GLSL snippet. + # This is "gl_Position = transform(data_var_name);" where + # data_var_name is typically an attribute. + vs_regex = re.compile(r'gl_Position = transform\(([\S]+)\);') + r = vs_regex.search(vertex) + if not r: + logger.debug("The vertex shader doesn't contain the transform " + "placeholder: skipping the transform chain " + "GLSL insertion.") + return vertex, fragment + assert r + logger.log(5, "Found transform placeholder in vertex code: `%s`", + r.group(0)) + + # Find the GLSL variable with the data (should be a `vec2`). + var = r.group(1) + transform_chain.transformed_var_name = var + assert var and var in vertex + + # Generate the snippet to insert in the shaders. + temp_var = 'temp_pos_tr' + # Name for the (eventual) varying. + fvar = 'v_{}'.format(temp_var) + vs_insert = '' + # Insert the pre-transforms. + vs_insert += pre_transforms + '\n' + vs_insert += "vec2 {} = {};\n".format(temp_var, var) + for t in transform_chain.gpu_transforms: + if isinstance(t, Clip): + # Set the varying value in the vertex shader. + vs_insert += '{} = {};\n'.format(fvar, temp_var) + continue + vs_insert += t.glsl(temp_var) + '\n' + vs_insert += 'gl_Position = vec4({}, 0., 1.);\n'.format(temp_var) + vs_insert += post_transforms + '\n' + + # Clipping. + clip = transform_chain.get('Clip') + if clip: + # Varying name. + glsl_clip = clip.glsl(fvar) + + # Prepare the fragment regex. + fs_regex = re.compile(r'(void main\(\)\s*\{)') + fs_insert = '\\1\n{}'.format(glsl_clip) + + # Add the varying declaration for clipping. + varying_decl = 'varying vec2 {};\n'.format(fvar) + vertex = varying_decl + vertex + fragment = varying_decl + fragment + + # Make the replacement in the fragment shader for clipping. + fragment = fs_regex.sub(indent(fs_insert), fragment) + + # Insert the GLSL snippet of the transform chain in the vertex shader. + vertex = vs_regex.sub(indent(vs_insert), vertex) + + return vertex, fragment + + def build_program(visual, interacts=()): """Create the gloo program of a visual using the interacts transforms. @@ -266,8 +341,8 @@ def build_program(visual, interacts=()): pre = '\n'.join(interact.get_pre_transforms() for interact in interacts) # GLSL snippet to insert after all transformations. post = visual.get_post_transforms() - vertex, fragment = transform_chain.insert_glsl(vertex, fragment, - pre, post) + vertex, fragment = insert_glsl(transform_chain, vertex, fragment, + pre, post) # Insert shader declarations using the interacts (if any). if interacts: diff --git a/phy/plot/tests/test_base.py b/phy/plot/tests/test_base.py index 1f2f12b35..6b33b23da 100644 --- a/phy/plot/tests/test_base.py +++ b/phy/plot/tests/test_base.py @@ -7,11 +7,13 @@ # Imports #------------------------------------------------------------------------------ +from textwrap import dedent + import numpy as np -from ..base import BaseVisual, BaseInteract +from ..base import BaseVisual, BaseInteract, insert_glsl from ..transform import (subplot_bounds, Translate, Scale, Range, - Clip, Subplot, GPU) + Clip, Subplot, GPU, TransformChain) #------------------------------------------------------------------------------ @@ -172,3 +174,34 @@ def get_transforms(self): assert len(canvas.interacts) == 1 qtbot.waitForWindowShown(canvas.native) # qtbot.stop() + + +def test_transform_chain_complete(): + t = TransformChain([Scale(scale=.5), + Scale(scale=2.)]) + t.add([Range(from_bounds=[-3, -3, 1, 1]), + GPU(), + Clip(), + Subplot(shape='u_shape', index='a_box_index'), + ]) + + vs = dedent(""" + attribute vec2 a_position; + void main() { + gl_Position = transform(a_position); + } + """).strip() + + fs = dedent(""" + void main() { + gl_FragColor = vec4(1., 1., 1., 1.); + } + """).strip() + vs, fs = insert_glsl(t, vs, fs) + assert 'a_box_index' in vs + assert 'v_' in vs + assert 'v_' in fs + assert 'discard' in fs + + # Increase coverage. + insert_glsl(t, vs.replace('transform', ''), fs) diff --git a/phy/plot/tests/test_transform.py b/phy/plot/tests/test_transform.py index bbbc88b29..f26b719d6 100644 --- a/phy/plot/tests/test_transform.py +++ b/phy/plot/tests/test_transform.py @@ -221,24 +221,3 @@ def test_transform_chain_complete(array): assert len(t.gpu_transforms) == 2 ae(t.apply(array), [[0, .5], [1, 1.5]]) - - vs = dedent(""" - attribute vec2 a_position; - void main() { - gl_Position = transform(a_position); - } - """).strip() - - fs = dedent(""" - void main() { - gl_FragColor = vec4(1., 1., 1., 1.); - } - """).strip() - vs, fs = t.insert_glsl(vs, fs) - assert 'a_box_index' in vs - assert 'v_' in vs - assert 'v_' in fs - assert 'discard' in fs - - # Increase coverage. - t.insert_glsl(vs.replace('transform', ''), fs) diff --git a/phy/plot/transform.py b/phy/plot/transform.py index 6c01fcdc0..d97048d06 100644 --- a/phy/plot/transform.py +++ b/phy/plot/transform.py @@ -8,7 +8,6 @@ #------------------------------------------------------------------------------ from textwrap import dedent -import re import numpy as np from six import string_types @@ -58,10 +57,6 @@ def wrapped(*args, **kwargs): return wrapped -def indent(text): - return '\n'.join(' ' + l.strip() for l in text.splitlines()) - - def _glslify(r): """Transform a string or a n-tuple to a valid GLSL expression.""" if isinstance(r, string_types): @@ -284,68 +279,3 @@ def apply(self, arr): for t in self.cpu_transforms: arr = t.apply(arr) return arr - - def insert_glsl(self, vertex, fragment, - pre_transforms='', post_transforms=''): - """Generate the GLSL code of the transform chain.""" - - # TODO: move this to base.py - - # Find the place where to insert the GLSL snippet. - # This is "gl_Position = transform(data_var_name);" where - # data_var_name is typically an attribute. - vs_regex = re.compile(r'gl_Position = transform\(([\S]+)\);') - r = vs_regex.search(vertex) - if not r: - logger.debug("The vertex shader doesn't contain the transform " - "placeholder: skipping the transform chain " - "GLSL insertion.") - return vertex, fragment - assert r - logger.log(5, "Found transform placeholder in vertex code: `%s`", - r.group(0)) - - # Find the GLSL variable with the data (should be a `vec2`). - var = r.group(1) - self.transformed_var_name = var - assert var and var in vertex - - # Generate the snippet to insert in the shaders. - temp_var = 'temp_pos_tr' - # Name for the (eventual) varying. - fvar = 'v_{}'.format(temp_var) - vs_insert = '' - # Insert the pre-transforms. - vs_insert += pre_transforms + '\n' - vs_insert += "vec2 {} = {};\n".format(temp_var, var) - for t in self.gpu_transforms: - if isinstance(t, Clip): - # Set the varying value in the vertex shader. - vs_insert += '{} = {};\n'.format(fvar, temp_var) - continue - vs_insert += t.glsl(temp_var) + '\n' - vs_insert += 'gl_Position = vec4({}, 0., 1.);\n'.format(temp_var) - vs_insert += post_transforms + '\n' - - # Clipping. - clip = self.get('Clip') - if clip: - # Varying name. - glsl_clip = clip.glsl(fvar) - - # Prepare the fragment regex. - fs_regex = re.compile(r'(void main\(\)\s*\{)') - fs_insert = '\\1\n{}'.format(glsl_clip) - - # Add the varying declaration for clipping. - varying_decl = 'varying vec2 {};\n'.format(fvar) - vertex = varying_decl + vertex - fragment = varying_decl + fragment - - # Make the replacement in the fragment shader for clipping. - fragment = fs_regex.sub(indent(fs_insert), fragment) - - # Insert the GLSL snippet of the transform chain in the vertex shader. - vertex = vs_regex.sub(indent(vs_insert), vertex) - - return vertex, fragment From 1af0bba91b6d110a7aa5863265746a910d78c8db Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 28 Oct 2015 18:22:49 +0100 Subject: [PATCH 0492/1059] Export ClusterMeta --- phy/cluster/manual/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/phy/cluster/manual/__init__.py b/phy/cluster/manual/__init__.py index 46249250d..8cd29d948 100644 --- a/phy/cluster/manual/__init__.py +++ b/phy/cluster/manual/__init__.py @@ -3,6 +3,7 @@ """Manual clustering facilities.""" +from ._utils import ClusterMeta from .clustering import Clustering from .wizard import Wizard from .gui_component import ManualClustering From fc4c1a94ed048fdd0c6db130fa41b64f388b2abd Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 28 Oct 2015 18:32:30 +0100 Subject: [PATCH 0493/1059] Bug fix --- phy/io/array.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/phy/io/array.py b/phy/io/array.py index b1942e3eb..e757f9004 100644 --- a/phy/io/array.py +++ b/phy/io/array.py @@ -330,6 +330,8 @@ def _spikes_in_clusters(spike_clusters, clusters): def _spikes_per_cluster(spike_clusters, spike_ids=None): """Return a dictionary {cluster: list_of_spikes}.""" + if not len(spike_clusters): + return {} if spike_ids is None: spike_ids = np.arange(len(spike_clusters)).astype(np.int64) rel_spikes = np.argsort(spike_clusters) From 01bad81cae4036a1ccc29aa7a764dfa9494e2389 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 29 Oct 2015 10:38:56 +0100 Subject: [PATCH 0494/1059] Remove old commented code --- phy/utils/plugin.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/phy/utils/plugin.py b/phy/utils/plugin.py index e56efcb44..b7702c368 100644 --- a/phy/utils/plugin.py +++ b/phy/utils/plugin.py @@ -65,10 +65,6 @@ def _iter_plugin_files(dirs): plugin_dir = op.realpath(op.expanduser(plugin_dir)) if not op.exists(plugin_dir): continue - # for filename in os.listdir(plugin_dir): - # path = op.join(plugin_dir, filename) - # if not op.isdir(path): - # yield path for subdir, dirs, files in os.walk(plugin_dir): # Skip test folders. base = op.basename(subdir) From 9ec5fc031d7c120a595f3461ba262e32a53c7b3d Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 29 Oct 2015 11:11:03 +0100 Subject: [PATCH 0495/1059] Log when creating an action --- phy/gui/__init__.py | 1 + phy/gui/actions.py | 3 +++ 2 files changed, 4 insertions(+) diff --git a/phy/gui/__init__.py b/phy/gui/__init__.py index 4008a593a..e3a05c5f4 100644 --- a/phy/gui/__init__.py +++ b/phy/gui/__init__.py @@ -5,3 +5,4 @@ from .qt import require_qt, create_app, run_app from .gui import GUI +from .actions import Actions diff --git a/phy/gui/actions.py b/phy/gui/actions.py index 3e21ba2b7..420011ab1 100644 --- a/phy/gui/actions.py +++ b/phy/gui/actions.py @@ -175,6 +175,9 @@ def add(self, callback=None, name=None, shortcut=None, alias=None): action = _create_qaction(self.gui, name, callback, shortcut) action_obj = Bunch(qaction=action, name=name, alias=alias, shortcut=shortcut, callback=callback) + if not name.startswith('_'): + logger.debug("Add action `%s` (%s).", name, + _get_shortcut_string(action.shortcut())) self.gui.addAction(action) self._actions_dict[name] = action_obj # Register the alias -> name mapping. From b9d07b8e180e0231a104f5777782a11593969e0b Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 29 Oct 2015 11:30:18 +0100 Subject: [PATCH 0496/1059] WIP --- phy/__init__.py | 3 ++- phy/utils/plugin.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/phy/__init__.py b/phy/__init__.py index e29171ce9..73127b8aa 100644 --- a/phy/__init__.py +++ b/phy/__init__.py @@ -15,8 +15,9 @@ from six import StringIO from .io.datasets import download_file, download_sample_data +from .utils.config import load_master_config from .utils._misc import _git_version -from .utils.plugin import IPlugin +from .utils.plugin import IPlugin, get_plugin, get_all_plugins #------------------------------------------------------------------------------ diff --git a/phy/utils/plugin.py b/phy/utils/plugin.py index b7702c368..f931f2660 100644 --- a/phy/utils/plugin.py +++ b/phy/utils/plugin.py @@ -124,6 +124,7 @@ def get_all_plugins(config=None): # By default, builtin and default user plugin. dirs = [_builtin_plugins_dir(), _user_plugins_dir()] # Add Plugins.dirs from the optionally-passed config object. - if config: + if config and isinstance(config.Plugins.dirs, list): dirs += config.Plugins.dirs + logger.debug("Discovering plugins in: %s.", ', '.join(dirs)) return [plugin for (plugin,) in discover_plugins(dirs)] From ecb9d0532984244b7cd0ff8c05505696a1b96ef4 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 29 Oct 2015 11:57:12 +0100 Subject: [PATCH 0497/1059] Update logging in CLI --- phy/utils/cli.py | 38 ++++++++++++++++++++++++++++++++------ 1 file changed, 32 insertions(+), 6 deletions(-) diff --git a/phy/utils/cli.py b/phy/utils/cli.py index 80e8b9910..688560e51 100644 --- a/phy/utils/cli.py +++ b/phy/utils/cli.py @@ -9,7 +9,10 @@ #------------------------------------------------------------------------------ import logging +import os +import os.path as op import sys +from traceback import format_exception import click @@ -20,26 +23,47 @@ #------------------------------------------------------------------------------ -# CLI tool +# Set up logging with the CLI tool #------------------------------------------------------------------------------ -add_default_handler('DEBUG' if DEBUG else 'INFO') +add_default_handler(level='DEBUG' if DEBUG else 'INFO') -# Only show traceback in debug mode (--debug). def exceptionHandler(exception_type, exception, traceback): # pragma: no cover - logger.error("%s: %s", exception_type.__name__, exception) + logger.error("An error has occurred (%s): %s", + exception_type.__name__, exception) + logger.debug('\n'.join(format_exception(exception_type, + exception, + traceback))) + +# Only show traceback in debug mode (--debug). +# if not DEBUG: +sys.excepthook = exceptionHandler + +# Create a `phy.log` log file with DEBUG level in the current directory. +def _add_log_file(filename): + handler = logging.FileHandler(filename) + handler.setLevel(logging.DEBUG) + formatter = phy._Formatter(fmt=phy._logger_fmt, + datefmt='%Y-%m-%d %H:%M:%S') + handler.setFormatter(formatter) + logger.addHandler(handler) -if not DEBUG: - sys.excepthook = exceptionHandler +_add_log_file(op.join(os.getcwd(), 'phy.log')) +#------------------------------------------------------------------------------ +# CLI tool +#------------------------------------------------------------------------------ + @click.group() @click.version_option(version=phy.__version_git__) @click.help_option('-h', '--help') @click.pass_context def phy(ctx): + """By default, the `phy` command does nothing. Add subcommands with plugins + using `attach_to_cli()` and the `click` library.""" pass @@ -63,4 +87,6 @@ def load_cli_plugins(cli): # NOTE: plugin is a class, so we need to instantiate it. plugin().attach_to_cli(cli) + +# Load all plugins when importing this module. load_cli_plugins(phy) From 4c2449329a6b3d9a4a7984ff1ff06bb7690b930b Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 29 Oct 2015 13:28:29 +0100 Subject: [PATCH 0498/1059] Add verbose option in Actions.add() --- phy/gui/actions.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/phy/gui/actions.py b/phy/gui/actions.py index 420011ab1..bcb356e5b 100644 --- a/phy/gui/actions.py +++ b/phy/gui/actions.py @@ -153,7 +153,8 @@ def exit(): def backup(self): return list(self._actions_dict.values()) - def add(self, callback=None, name=None, shortcut=None, alias=None): + def add(self, callback=None, name=None, shortcut=None, alias=None, + verbose=True): """Add an action with a keyboard shortcut.""" # TODO: add menu_name option and create menu bar if callback is None: @@ -175,7 +176,7 @@ def add(self, callback=None, name=None, shortcut=None, alias=None): action = _create_qaction(self.gui, name, callback, shortcut) action_obj = Bunch(qaction=action, name=name, alias=alias, shortcut=shortcut, callback=callback) - if not name.startswith('_'): + if verbose and not name.startswith('_'): logger.debug("Add action `%s` (%s).", name, _get_shortcut_string(action.shortcut())) self.gui.addAction(action) @@ -378,4 +379,5 @@ def mode_off(self): name=action_obj.name, shortcut=action_obj.shortcut, alias=action_obj.alias, + verbose=False, ) From b3307be568d615ded19d1805051d94b01fe0da04 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 29 Oct 2015 14:09:49 +0100 Subject: [PATCH 0499/1059] Bug fixes --- phy/cluster/manual/__init__.py | 3 ++- phy/cluster/manual/gui_component.py | 20 ++++++++++++++++++-- phy/cluster/manual/views.py | 4 +++- phy/traces/waveform.py | 1 + 4 files changed, 24 insertions(+), 4 deletions(-) diff --git a/phy/cluster/manual/__init__.py b/phy/cluster/manual/__init__.py index 8cd29d948..6689a86d1 100644 --- a/phy/cluster/manual/__init__.py +++ b/phy/cluster/manual/__init__.py @@ -5,5 +5,6 @@ from ._utils import ClusterMeta from .clustering import Clustering -from .wizard import Wizard from .gui_component import ManualClustering +from .views import WaveformView, TraceView, FeatureView, CorrelogramView +from .wizard import Wizard diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index e3b988424..646e8946d 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -147,7 +147,13 @@ class ManualClustering(object): ------ select(cluster_ids, spike_ids) + when clusters are selected + on_cluster(up) + when a merge or split happens + wizard_start() + when the wizard (re)starts save_requested(spike_clusters, cluster_groups) + when a save is requested by the user """ @@ -175,6 +181,7 @@ def __init__(self, shortcuts=None, ): + self.gui = None self.n_spikes_max_per_cluster = n_spikes_max_per_cluster # Load default shortcuts, and override any user shortcuts. @@ -202,6 +209,9 @@ def on_cluster(up): # TODO: how many spikes? logger.info("Assigned spikes.") + if self.gui: + self.gui.emit('on_cluster', up) + @self.cluster_meta.connect # noqa def on_cluster(up): if up.history: @@ -211,6 +221,9 @@ def on_cluster(up): ', '.join(map(str, up.metadata_changed)), up.metadata_value) + if self.gui: + self.gui.emit('on_cluster', up) + @self.wizard.connect def on_select(cluster_ids): """When the wizard selects clusters, choose a spikes subset @@ -252,11 +265,14 @@ def on_select(cluster_ids): spike_ids = select_spikes(np.array(cluster_ids), self.n_spikes_max_per_cluster, self.clustering.spikes_per_cluster) - gui.emit('select', cluster_ids, spike_ids) + + if self.gui: + self.gui.emit('select', cluster_ids, spike_ids) @self.wizard.connect def on_start(): - gui.emit('wizard_start') + if self.gui: + gui.emit('wizard_start') # Create the actions. self._create_actions(gui) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index af2fb7fbb..d72831747 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -159,7 +159,9 @@ def on_mouse_move(self, e): def on_key_press(self, e): pass - def attach_to_gui(self, gui): + def attach(self, gui): + """Attach the view to the GUI.""" + gui.add_view(self) # TODO: make sure the GUI emits these events diff --git a/phy/traces/waveform.py b/phy/traces/waveform.py index 1259c75e4..0ea3ad589 100644 --- a/phy/traces/waveform.py +++ b/phy/traces/waveform.py @@ -346,6 +346,7 @@ def __init__(self, waveforms, spike_samples): self.shape = (len(spike_samples), waveforms.n_samples_waveforms, waveforms.n_channels_waveforms) + self.ndim = len(self.shape) def __getitem__(self, item): times = self._spike_samples[item] From ba99bfa061b67779aebaeb77fe58b3a65f387df0 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 29 Oct 2015 14:30:13 +0100 Subject: [PATCH 0500/1059] Bug fixes --- phy/cluster/manual/views.py | 15 ++++++++------- phy/gui/actions.py | 3 +++ phy/io/array.py | 1 + phy/io/tests/test_array.py | 2 ++ phy/utils/cli.py | 6 +++--- 5 files changed, 17 insertions(+), 10 deletions(-) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index d72831747..4d5818b0d 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -15,7 +15,7 @@ from phy.electrode.mea import linear_positions from phy.plot import (BoxedView, StackedView, GridView, _get_linear_x) -from phy.plot.utils import _get_boxes, _get_array +from phy.plot.utils import _get_boxes logger = logging.getLogger(__name__) @@ -99,7 +99,7 @@ def __init__(self, self.waveforms = waveforms # Masks. - self.masks = _get_array(masks, (self.n_spikes, self.n_channels), 1) + self.masks = masks # Spike clusters. assert spike_clusters.shape == (self.n_spikes,) @@ -134,8 +134,11 @@ def on_select(self, cluster_ids, spike_ids): color = np.c_[color, np.ones((n_spikes, 1))] # Depth as a function of the cluster index and masks. - m = self.masks[spike_ids, :] + m = self.masks[spike_ids] + m = np.atleast_2d(m) + assert m.ndim == 2 depth = -0.1 - (spike_clusters_rel[:, np.newaxis] + m) + assert m.shape == (n_spikes, self.n_channels) assert depth.shape == (n_spikes, self.n_channels) depth = depth / float(n_clusters + 10.) depth[m <= 0.25] = 0 @@ -164,9 +167,8 @@ def attach(self, gui): gui.add_view(self) - # TODO: make sure the GUI emits these events - gui.connect(self.on_select) - gui.connect(self.on_cluster) + gui.connect_(self.on_select) + gui.connect_(self.on_cluster) class TraceView(StackedView): @@ -206,7 +208,6 @@ def __init__(self, self.spike_clusters = spike_clusters # Masks. - masks = _get_array(masks, (self.n_spikes, self.n_channels), 1) assert masks.shape == (self.n_spikes, self.n_channels) self.masks = masks else: diff --git a/phy/gui/actions.py b/phy/gui/actions.py index bcb356e5b..4061c8309 100644 --- a/phy/gui/actions.py +++ b/phy/gui/actions.py @@ -9,6 +9,8 @@ from functools import partial import logging +import sys +import traceback from six import string_types, PY3 @@ -355,6 +357,7 @@ def run(self, snippet): self.actions.run(name, *snippet_args[1:]) except Exception as e: logger.warn("Error when executing snippet: \"%s\".", str(e)) + logger.debug(''.join(traceback.format_exception(*sys.exc_info()))) def is_mode_on(self): return self.command.startswith(':') diff --git a/phy/io/array.py b/phy/io/array.py index e757f9004..0fdc2110c 100644 --- a/phy/io/array.py +++ b/phy/io/array.py @@ -108,6 +108,7 @@ def _index_of(arr, lookup): # TODO: assertions to disable in production for performance reasons. # TODO: np.searchsorted(lookup, arr) is faster on small arrays with large # values + lookup = np.asarray(lookup, dtype=np.int32) m = (lookup.max() if len(lookup) else 0) + 1 tmp = np.zeros(m + 1, dtype=np.int) # Ensure that -1 values are kept. diff --git a/phy/io/tests/test_array.py b/phy/io/tests/test_array.py index 8cd47d600..634bf3c2d 100644 --- a/phy/io/tests/test_array.py +++ b/phy/io/tests/test_array.py @@ -328,6 +328,8 @@ def test_spikes_per_cluster(): n_clusters = 3 spike_clusters = artificial_spike_clusters(n_spikes, n_clusters) + assert not _spikes_per_cluster([]) + spikes_per_cluster = _spikes_per_cluster(spike_clusters) assert list(spikes_per_cluster.keys()) == list(range(n_clusters)) diff --git a/phy/utils/cli.py b/phy/utils/cli.py index 688560e51..8147112ba 100644 --- a/phy/utils/cli.py +++ b/phy/utils/cli.py @@ -32,9 +32,9 @@ def exceptionHandler(exception_type, exception, traceback): # pragma: no cover logger.error("An error has occurred (%s): %s", exception_type.__name__, exception) - logger.debug('\n'.join(format_exception(exception_type, - exception, - traceback))) + logger.debug(''.join(format_exception(exception_type, + exception, + traceback))) # Only show traceback in debug mode (--debug). # if not DEBUG: From 8cfa406fa58e57a3e5e47a06ad745001bf35ee6c Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 29 Oct 2015 14:33:07 +0100 Subject: [PATCH 0501/1059] Only create when running the phy CLI --- .gitignore | 1 + phy/utils/cli.py | 15 ++++++++------- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/.gitignore b/.gitignore index 922883500..7f2da8ac9 100644 --- a/.gitignore +++ b/.gitignore @@ -11,6 +11,7 @@ wiki .ipynb_checkpoints .*fuse* *.orig +*.log .eggs .profile __pycache__ diff --git a/phy/utils/cli.py b/phy/utils/cli.py index 8147112ba..b6a03e84c 100644 --- a/phy/utils/cli.py +++ b/phy/utils/cli.py @@ -17,7 +17,7 @@ import click import phy -from phy import add_default_handler, DEBUG +from phy import add_default_handler, DEBUG, _Formatter, _logger_fmt logger = logging.getLogger(__name__) @@ -41,17 +41,16 @@ def exceptionHandler(exception_type, exception, traceback): # pragma: no cover sys.excepthook = exceptionHandler -# Create a `phy.log` log file with DEBUG level in the current directory. def _add_log_file(filename): + """Create a `phy.log` log file with DEBUG level in the + current directory.""" handler = logging.FileHandler(filename) handler.setLevel(logging.DEBUG) - formatter = phy._Formatter(fmt=phy._logger_fmt, - datefmt='%Y-%m-%d %H:%M:%S') + formatter = _Formatter(fmt=_logger_fmt, + datefmt='%Y-%m-%d %H:%M:%S') handler.setFormatter(formatter) logger.addHandler(handler) -_add_log_file(op.join(os.getcwd(), 'phy.log')) - #------------------------------------------------------------------------------ # CLI tool @@ -64,7 +63,9 @@ def _add_log_file(filename): def phy(ctx): """By default, the `phy` command does nothing. Add subcommands with plugins using `attach_to_cli()` and the `click` library.""" - pass + + # Create a `phy.log` log file with DEBUG level in the current directory. + _add_log_file(op.join(os.getcwd(), 'phy.log')) #------------------------------------------------------------------------------ From 834a1c317af51dc36e089d549dcba7a81cbc8068 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 29 Oct 2015 15:04:15 +0100 Subject: [PATCH 0502/1059] WIP: cluster manual views --- phy/cluster/manual/views.py | 15 +++++++++++---- phy/plot/base.py | 1 - phy/plot/plot.py | 16 ++++++++-------- 3 files changed, 19 insertions(+), 13 deletions(-) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 4d5818b0d..a0d80ed5a 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -79,6 +79,7 @@ def __init__(self, masks=None, spike_clusters=None, channel_positions=None, + keys='interactive', ): """ @@ -91,7 +92,7 @@ def __init__(self, if channel_positions is None: channel_positions = linear_positions(self.n_channels) box_bounds = _get_boxes(channel_positions) - super(WaveformView, self).__init__(box_bounds) + super(WaveformView, self).__init__(box_bounds, keys=keys) # Waveforms. assert waveforms.ndim == 3 @@ -109,6 +110,12 @@ def __init__(self, assert channel_positions.shape == (self.n_channels, 2) self.channel_positions = channel_positions + # Initialize the subplots. + self._plots = {ch: self[ch].plot(x=[], y=[]) + for ch in range(self.n_channels)} + self.build() + self.update() + def on_select(self, cluster_ids, spike_ids): n_clusters = len(cluster_ids) n_spikes = len(spike_ids) @@ -146,9 +153,9 @@ def on_select(self, cluster_ids, spike_ids): # Plot all waveforms. # TODO: optim: avoid the loop. for ch in range(self.n_channels): - self[ch].plot(x=t, y=w[:, :, ch], - color=color, - depth=depth[:, ch]) + self._plots[ch].set_data(x=t, y=w[:, :, ch], + color=color, + depth=depth[:, ch]) self.build() self.update() diff --git a/phy/plot/base.py b/phy/plot/base.py index b28baec49..8ab993ae7 100644 --- a/phy/plot/base.py +++ b/phy/plot/base.py @@ -231,7 +231,6 @@ def update(self): class BaseCanvas(Canvas): """A blank VisPy canvas with a custom event system that keeps the order.""" def __init__(self, *args, **kwargs): - kwargs['keys'] = 'interactive' super(BaseCanvas, self).__init__(*args, **kwargs) self._events = EventEmitter() self.interacts = [] diff --git a/phy/plot/plot.py b/phy/plot/plot.py index d49486385..e98b7d742 100644 --- a/phy/plot/plot.py +++ b/phy/plot/plot.py @@ -156,8 +156,8 @@ def set_data(self, **kwargs): class BaseView(BaseCanvas): """High-level plotting canvas.""" - def __init__(self, interacts): - super(BaseView, self).__init__() + def __init__(self, interacts, **kwargs): + super(BaseView, self).__init__(**kwargs) # Attach the passed interacts to the current canvas. for interact in interacts: interact.attach(self) @@ -302,27 +302,27 @@ def build(self): class GridView(BaseView): """A 2D grid with clipping.""" - def __init__(self, n_rows, n_cols): + def __init__(self, n_rows, n_cols, **kwargs): self.n_rows, self.n_cols = n_rows, n_cols pz = PanZoom(aspect=None, constrain_bounds=NDC) interacts = [Grid(n_rows, n_cols), pz] - super(GridView, self).__init__(interacts) + super(GridView, self).__init__(interacts, **kwargs) class BoxedView(BaseView): """Subplots at arbitrary positions""" - def __init__(self, box_bounds): + def __init__(self, box_bounds, **kwargs): self.n_plots = len(box_bounds) self._boxed = Boxed(box_bounds) self._pz = PanZoom(aspect=None, constrain_bounds=NDC) interacts = [self._boxed, self._pz] - super(BoxedView, self).__init__(interacts) + super(BoxedView, self).__init__(interacts, **kwargs) class StackedView(BaseView): """Stacked subplots""" - def __init__(self, n_plots): + def __init__(self, n_plots, **kwargs): self.n_plots = n_plots pz = PanZoom(aspect=None, constrain_bounds=NDC) interacts = [Stacked(n_plots, margin=.1), pz] - super(StackedView, self).__init__(interacts) + super(StackedView, self).__init__(interacts, **kwargs) From 7478a67f6c84af48d6b581f6dcf6c45e97be8922 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 29 Oct 2015 15:24:47 +0100 Subject: [PATCH 0503/1059] Bug fixes --- phy/cluster/manual/gui_component.py | 2 ++ phy/gui/gui.py | 7 +++++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 646e8946d..768724e3a 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -10,6 +10,7 @@ import logging import numpy as np +from six import string_types from ._history import GlobalHistory from ._utils import create_cluster_meta @@ -54,6 +55,7 @@ def _process_ups(ups): # pragma: no cover def _wizard_group(group): # The group should be None, 'mua', 'noise', or 'good'. + assert group is None or isinstance(group, string_types) group = group.lower() if group else group return _wizard_group_mapping.get(group, None) diff --git a/phy/gui/gui.py b/phy/gui/gui.py index f25b9edfb..549d0af22 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -69,7 +69,7 @@ def __init__(self, raise RuntimeError("A Qt application must be created.") super(GUI, self).__init__() if title is None: - title = 'phy' + title = self.__class__.__name__ self.setWindowTitle(title) if position is not None: self.move(position[0], position[1]) @@ -119,7 +119,7 @@ def show(self): def add_view(self, view, - title='view', + title=None, position=None, closable=True, floatable=True, @@ -131,10 +131,13 @@ def add_view(self, try: from vispy.app import Canvas if isinstance(view, Canvas): + title = title or view.__class__.__name__ view = view.native except ImportError: # pragma: no cover pass + title = title or view.__class__.__name__ + # Create the gui widget. dockwidget = DockWidget(self) dockwidget.setObjectName(title) From 9309c3e2f30c4fc40b0f6609f34a0391ff28cb91 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 29 Oct 2015 21:32:38 +0100 Subject: [PATCH 0504/1059] Bug fixes --- phy/cluster/manual/gui_component.py | 9 ++------- phy/io/array.py | 2 +- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 768724e3a..6a7e44f36 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -226,13 +226,6 @@ def on_cluster(up): if self.gui: self.gui.emit('on_cluster', up) - @self.wizard.connect - def on_select(cluster_ids): - """When the wizard selects clusters, choose a spikes subset - and emit the `select` event on the GUI.""" - logger.debug("Select clusters %s.", - ', '.join(map(str, cluster_ids))) - _attach_wizard(self.wizard, self.clustering, self.cluster_meta) def _create_actions(self, gui): @@ -267,6 +260,8 @@ def on_select(cluster_ids): spike_ids = select_spikes(np.array(cluster_ids), self.n_spikes_max_per_cluster, self.clustering.spikes_per_cluster) + logger.debug("Select clusters: %s (%d spikes).", + ', '.join(map(str, cluster_ids)), len(spike_ids)) if self.gui: self.gui.emit('select', cluster_ids, spike_ids) diff --git a/phy/io/array.py b/phy/io/array.py index 0fdc2110c..a94f4fedf 100644 --- a/phy/io/array.py +++ b/phy/io/array.py @@ -391,7 +391,7 @@ def select_spikes(cluster_ids=None, # Decrease the number of spikes per cluster when there # are more clusters. n = max_n_spikes_per_cluster * exp(-.1 * (n_clusters - 1)) - n = int(np.clip(n, 1, n_clusters)) + n = int(max(1, n)) spikes = spikes_per_cluster[cluster] selection[cluster] = regular_subset(spikes, n_spikes_max=n) return _flatten_per_cluster(selection) From 03f7c5f8819f1805a079228f002222ec64b310f7 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 29 Oct 2015 22:04:20 +0100 Subject: [PATCH 0505/1059] Use float textures for better precision in views --- phy/plot/utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/phy/plot/utils.py b/phy/plot/utils.py index e4f133151..7cf1c5cdd 100644 --- a/phy/plot/utils.py +++ b/phy/plot/utils.py @@ -84,17 +84,17 @@ def _get_texture(arr, default, n_items, from_bounds): # Convert to 3D texture. arr = arr[np.newaxis, ...].astype(np.float32) assert arr.shape == (1, n_items, n_cols) - # NOTE: we need to cast the texture to [0, 255] (uint8). + # NOTE: we need to cast the texture to [0., 1.] (float texture). # This is easy as soon as we assume that the signal bounds are in # [-1, 1]. assert len(from_bounds) == 2 m, M = map(float, from_bounds) assert np.all(arr >= m) assert np.all(arr <= M) - arr = 255 * (arr - m) / (M - m) + arr = 1. * (arr - m) / (M - m) assert np.all(arr >= 0) - assert np.all(arr <= 255) - arr = arr.astype(np.uint8) + assert np.all(arr <= 1.) + arr = arr.astype(np.float32) return arr From b381a5d8e2574d543f76985271069bded81ba86d Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 29 Oct 2015 22:15:13 +0100 Subject: [PATCH 0506/1059] Fix bug with depth in plot --- phy/plot/plot.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/phy/plot/plot.py b/phy/plot/plot.py index e98b7d742..eb03ea527 100644 --- a/phy/plot/plot.py +++ b/phy/plot/plot.py @@ -112,11 +112,13 @@ def _build_plot(items): n = item.data.x.size ac['x'] = item.data.x ac['y'] = item.data.y + ac['depth'] = item.data.depth ac['plot_colors'] = item.data.color ac['box_index'] = _prepare_box_index(item.box_index, n) return (dict(x=ac['x'], y=ac['y'], plot_colors=ac['plot_colors'], + depth=ac['depth'], data_bounds=item.data.data_bounds, ), ac['box_index']) From 043491249e2f394987eb365e10be982451433436 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 29 Oct 2015 22:26:28 +0100 Subject: [PATCH 0507/1059] Add mask color in waveform view --- phy/cluster/manual/views.py | 22 ++++++++++++++-------- phy/plot/plot.py | 2 +- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index a0d80ed5a..0a05846d0 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -10,6 +10,7 @@ import logging import numpy as np +from matplotlib.colors import hsv_to_rgb, rgb_to_hsv from phy.io.array import _index_of, _get_padded from phy.electrode.mea import linear_positions @@ -135,24 +136,29 @@ def on_select(self, cluster_ids, spike_ids): colors = _selected_clusters_colors(n_clusters) t = _get_linear_x(n_spikes, self.n_samples) - # Get the colors. - color = colors[spike_clusters_rel] - # Alpha channel. - color = np.c_[color, np.ones((n_spikes, 1))] - # Depth as a function of the cluster index and masks. m = self.masks[spike_ids] m = np.atleast_2d(m) assert m.ndim == 2 - depth = -0.1 - (spike_clusters_rel[:, np.newaxis] + m) + depth = (-0.1 - (spike_clusters_rel[:, np.newaxis] + m) / + float(n_clusters + 10.)) + depth[m <= 0.25] = 0 assert m.shape == (n_spikes, self.n_channels) assert depth.shape == (n_spikes, self.n_channels) - depth = depth / float(n_clusters + 10.) - depth[m <= 0.25] = 0 # Plot all waveforms. # TODO: optim: avoid the loop. for ch in range(self.n_channels): + + # Color as a function of the mask. + color = colors[spike_clusters_rel] + hsv = rgb_to_hsv(color[:, :3]) + # Change the saturation and value as a function of the mask. + hsv[:, 1] *= m[:, ch] + hsv[:, 2] *= .5 * (1. + m[:, ch]) + color = hsv_to_rgb(hsv) + color = np.c_[color, .5 * np.ones((n_spikes, 1))] + self._plots[ch].set_data(x=t, y=w[:, :, ch], color=color, depth=depth[:, ch]) diff --git a/phy/plot/plot.py b/phy/plot/plot.py index eb03ea527..66e0f0d80 100644 --- a/phy/plot/plot.py +++ b/phy/plot/plot.py @@ -64,7 +64,7 @@ def _prepare_plot(x, y, color=None, depth=None, data_bounds=None): # Get the colors. color = _get_array(color, (n_plots, 4), PlotVisual._default_color) # Get the depth. - depth = _get_array(depth, (n_plots,), 0) + depth = _get_array(depth, (n_plots, 1), 0) return dict(x=x, y=y, color=color, depth=depth, data_bounds=data_bounds) From 870d903cab448424564e23595bde5a79e89b2bf3 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 30 Oct 2015 08:25:38 +0100 Subject: [PATCH 0508/1059] Bug fix --- phy/io/array.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/phy/io/array.py b/phy/io/array.py index a94f4fedf..42c2628af 100644 --- a/phy/io/array.py +++ b/phy/io/array.py @@ -331,7 +331,7 @@ def _spikes_in_clusters(spike_clusters, clusters): def _spikes_per_cluster(spike_clusters, spike_ids=None): """Return a dictionary {cluster: list_of_spikes}.""" - if not len(spike_clusters): + if spike_clusters is None or not len(spike_clusters): return {} if spike_ids is None: spike_ids = np.arange(len(spike_clusters)).astype(np.int64) @@ -390,8 +390,8 @@ def select_spikes(cluster_ids=None, for cluster in cluster_ids: # Decrease the number of spikes per cluster when there # are more clusters. - n = max_n_spikes_per_cluster * exp(-.1 * (n_clusters - 1)) - n = int(max(1, n)) + n = int(max_n_spikes_per_cluster * exp(-.1 * (n_clusters - 1))) + n = max(1, n) spikes = spikes_per_cluster[cluster] selection[cluster] = regular_subset(spikes, n_spikes_max=n) return _flatten_per_cluster(selection) From c32ee76df7a278ad485693b658934077383368d3 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 30 Oct 2015 08:26:45 +0100 Subject: [PATCH 0509/1059] Bug fixes in wizard --- phy/cluster/manual/gui_component.py | 4 ++-- phy/cluster/manual/tests/test_wizard.py | 2 ++ phy/cluster/manual/wizard.py | 16 +++++++++++----- 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 6a7e44f36..3adcb8771 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -162,7 +162,7 @@ class ManualClustering(object): default_shortcuts = { 'save': 'Save', # Wizard actions. - 'next': 'space', + 'next_by_quality': 'space', 'previous': 'shift+space', 'reset_wizard': 'ctrl+alt+space', 'first': 'MoveToStartOfLine', @@ -237,9 +237,9 @@ def _create_actions(self, gui): # Wizard. self.actions.add(self.wizard.restart, name='reset_wizard') self.actions.add(self.wizard.previous) - self.actions.add(self.wizard.next) self.actions.add(self.wizard.next_by_quality) self.actions.add(self.wizard.next_by_similarity) + self.actions.add(self.wizard.next) # no shortcut self.actions.add(self.wizard.pin) self.actions.add(self.wizard.unpin) diff --git a/phy/cluster/manual/tests/test_wizard.py b/phy/cluster/manual/tests/test_wizard.py index 00b2767f1..01a24825c 100644 --- a/phy/cluster/manual/tests/test_wizard.py +++ b/phy/cluster/manual/tests/test_wizard.py @@ -203,6 +203,8 @@ def test_wizard_next_1(wizard, status): w = wizard assert w.next_selection([30]) == [20] + + w.reset() assert w.next_selection([30], ignore_group=True) == [20] # After the last good, the best ignored. diff --git a/phy/cluster/manual/wizard.py b/phy/cluster/manual/wizard.py index 19674b25f..2ff6f03e2 100644 --- a/phy/cluster/manual/wizard.py +++ b/phy/cluster/manual/wizard.py @@ -174,6 +174,7 @@ def __init__(self): self.reset() def reset(self): + self._selection = [] self._history = History([]) # Quality and status functions @@ -318,6 +319,7 @@ def _check_functions(self): def next_selection(self, cluster_ids=None, strategy=None, ignore_group=False): + """Make a new cluster selection according to a given strategy.""" self._check_functions() cluster_ids = cluster_ids or self._selection strategy = strategy or _best_quality_strategy @@ -329,11 +331,15 @@ def status(cluster): return self._cluster_status(cluster) else: status = self._cluster_status - self.select(strategy(cluster_ids, - cluster_ids=self._get_cluster_ids(), - quality=self._quality, - status=status, - similarity=self._similarity)) + new_selection = strategy(cluster_ids, + cluster_ids=self._get_cluster_ids(), + quality=self._quality, + status=status, + similarity=self._similarity) + # Skip new selection if it is the same. + if new_selection == self._selection: + return + self.select(new_selection) return self._selection def next_by_quality(self): From a1e2ff557b83608daef6f9b27439d949baf02e0d Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 30 Oct 2015 09:15:26 +0100 Subject: [PATCH 0510/1059] Change new_cluster_id behavior In the case of A1-undo-A2, any new cluster ids generated during A1 are discarded in A2, and brand new clusters ids are generated for A2. This permits memoization of functions that take cluster ids as arguments, and ensures that the cached functions are effectively invalidated in this case. --- phy/cluster/manual/clustering.py | 29 ++++++++++++++++----- phy/cluster/manual/tests/test_clustering.py | 12 ++++++--- 2 files changed, 32 insertions(+), 9 deletions(-) diff --git a/phy/cluster/manual/clustering.py b/phy/cluster/manual/clustering.py index 883f69f3e..267bcc4a9 100644 --- a/phy/cluster/manual/clustering.py +++ b/phy/cluster/manual/clustering.py @@ -47,7 +47,11 @@ def _concatenate_spike_clusters(*pairs): return concat[:, 0].astype(np.int64), concat[:, 1].astype(np.int64) -def _extend_assignment(spike_ids, old_spike_clusters, spike_clusters_rel): +def _extend_assignment(spike_ids, + old_spike_clusters, + spike_clusters_rel, + new_cluster_id, + ): # 1. Add spikes that belong to modified clusters. # 2. Find new cluster ids for all changed clusters. @@ -59,7 +63,6 @@ def _extend_assignment(spike_ids, old_spike_clusters, spike_clusters_rel): assert spike_clusters_rel.min() >= 0 # We renumber the new cluster indices. - new_cluster_id = old_spike_clusters.max() + 1 new_spike_clusters = (spike_clusters_rel + (new_cluster_id - spike_clusters_rel.min())) @@ -200,10 +203,13 @@ def cluster_counts(self): def new_cluster_id(self): """Generate a brand new cluster id. - This is `maximum cluster id + 1`. + NOTE: This new id strictly increases after an undo + new action, + meaning that old cluster ids are *not* reused. This ensures that + any cluster_id-based cache will always be valid even after undo + operations (i.e. no need for explicit cache invalidation in this case). """ - return int(np.max(self.cluster_ids)) + 1 + return self._new_cluster_id @property def n_clusters(self): @@ -228,6 +234,9 @@ def spikes_in_clusters(self, clusters): #-------------------------------------------------------------------------- def _update_all_spikes_per_cluster(self): + # Reset the new cluster id. + self._new_cluster_id = self._spike_clusters.max() + 1 + # Update the spikes_per_cluster dict. self._spikes_per_cluster = _spikes_per_cluster(self._spike_clusters, self._spike_ids) @@ -270,6 +279,9 @@ def _do_assign(self, spike_ids, new_spike_clusters): old_spike_clusters, old_spikes_per_cluster, new_spike_clusters, new_spikes_per_cluster) + # We update the new cluster id (strictly increasing during a session). + self._new_cluster_id = max(self._new_cluster_id, max(up.added) + 1) + # We make the assignments. self._spike_clusters[spike_ids] = new_spike_clusters return up @@ -294,6 +306,9 @@ def _do_merge(self, spike_ids, cluster_ids, to): for cluster in cluster_ids: del self._spikes_per_cluster[cluster] + # We update the new cluster id (strictly increasing during a session). + self._new_cluster_id = max(max(up.added) + 1, self._new_cluster_id) + # Assign the clusters. self.spike_clusters[spike_ids] = to return up @@ -408,12 +423,14 @@ def assign(self, spike_ids, spike_clusters_rel=0): # Normalize the spike-cluster assignment such that # there are only new or dead clusters, not modified clusters. - # This implies that spikes not explicitely selected, but that + # This implies that spikes not explicitly selected, but that # belong to clusters affected by the operation, will be assigned # to brand new clusters. spike_ids, cluster_ids = _extend_assignment(spike_ids, self._spike_clusters, - spike_clusters_rel) + spike_clusters_rel, + self.new_cluster_id(), + ) up = self._do_assign(spike_ids, cluster_ids) undo_state = self.emit('request_undo_state', up) diff --git a/phy/cluster/manual/tests/test_clustering.py b/phy/cluster/manual/tests/test_clustering.py index 091a3b8b7..c81c88607 100644 --- a/phy/cluster/manual/tests/test_clustering.py +++ b/phy/cluster/manual/tests/test_clustering.py @@ -84,7 +84,9 @@ def test_extend_assignment(): clusters_rel = [123] * len(spike_ids) new_spike_ids, new_cluster_ids = _extend_assignment(spike_ids, spike_clusters, - clusters_rel) + clusters_rel, + 10, + ) ae(new_spike_ids, [0, 2, 6]) ae(new_cluster_ids, [10, 10, 11]) @@ -92,7 +94,9 @@ def test_extend_assignment(): clusters_rel = [0, 1] new_spike_ids, new_cluster_ids = _extend_assignment(spike_ids, spike_clusters, - clusters_rel) + clusters_rel, + 10, + ) ae(new_spike_ids, [0, 2, 6]) ae(new_cluster_ids, [10, 11, 12]) @@ -330,7 +334,9 @@ def on_request_undo_state(up): _assert_is_checkpoint(3) # We merge again. - assert clustering.new_cluster_id() == 14 + # NOTE: 14 has been wasted, move to 15: necessary to avoid explicit cache + # invalidation when caching clusterid-based functions. + assert clustering.new_cluster_id() == 15 assert any(clustering.spike_clusters == 13) assert all(clustering.spike_clusters != 14) info = clustering.merge([8, 7], 15) From 9a414f9fee50b2060c3cf570a50f578ae15fa9e0 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 30 Oct 2015 12:55:42 +0100 Subject: [PATCH 0511/1059] WIP: fix bug in Context --- phy/cluster/manual/wizard.py | 2 -- phy/io/context.py | 7 +++++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/phy/cluster/manual/wizard.py b/phy/cluster/manual/wizard.py index 2ff6f03e2..78eea7b37 100644 --- a/phy/cluster/manual/wizard.py +++ b/phy/cluster/manual/wizard.py @@ -161,8 +161,6 @@ class Wizard(EventEmitter): * The `next_*()` functions propose a new selection as a function of the current selection. - TODO: cache expensive functions. - """ def __init__(self): super(Wizard, self).__init__() diff --git a/phy/io/context.py b/phy/io/context.py index e0403baa7..c743a8d9f 100644 --- a/phy/io/context.py +++ b/phy/io/context.py @@ -143,11 +143,12 @@ class Context(object): def __init__(self, cache_dir, ipy_view=None): # Make sure the cache directory exists. - self.cache_dir = op.realpath(cache_dir) + self.cache_dir = op.realpath(op.expanduser(cache_dir)) if not op.exists(self.cache_dir): + logger.debug("Create cache directory `%s`.", self.cache_dir) os.makedirs(self.cache_dir) - self._set_memory(cache_dir) + self._set_memory(self.cache_dir) self.ipy_view = ipy_view if ipy_view else None def _set_memory(self, cache_dir): @@ -156,6 +157,8 @@ def _set_memory(self, cache_dir): from joblib import Memory joblib_cachedir = self._path('joblib') self._memory = Memory(cachedir=joblib_cachedir) + logger.debug("Initialize joblib cache dir at `%s`.", + joblib_cachedir) except ImportError: # pragma: no cover logger.warn("Joblib is not installed. " "Install it with `conda install joblib`.") From 2d63696f6b93096cec647560e96ffa3e9ee6f093 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 30 Oct 2015 14:07:27 +0100 Subject: [PATCH 0512/1059] Waveform normalization --- phy/cluster/manual/views.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 0a05846d0..c6d95b276 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -100,6 +100,12 @@ def __init__(self, self.n_spikes, self.n_samples, self.n_channels = waveforms.shape self.waveforms = waveforms + # Waveform normalization. + n = waveforms.shape[0] + k = max(1, n // 1000) + m = np.abs(waveforms[::k]).max() + self.data_bounds = [-1, -m, +1, +m] + # Masks. self.masks = masks @@ -161,7 +167,9 @@ def on_select(self, cluster_ids, spike_ids): self._plots[ch].set_data(x=t, y=w[:, :, ch], color=color, - depth=depth[:, ch]) + depth=depth[:, ch], + data_bounds=self.data_bounds, + ) self.build() self.update() From 28c3ffdb454a30dcbbf83b7591030e0a1edb00ae Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 30 Oct 2015 14:14:21 +0100 Subject: [PATCH 0513/1059] Update joblib location --- phy/io/context.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/phy/io/context.py b/phy/io/context.py index c743a8d9f..f41832327 100644 --- a/phy/io/context.py +++ b/phy/io/context.py @@ -155,10 +155,9 @@ def _set_memory(self, cache_dir): # Try importing joblib. try: from joblib import Memory - joblib_cachedir = self._path('joblib') - self._memory = Memory(cachedir=joblib_cachedir) + self._memory = Memory(cachedir=self.cache_dir) logger.debug("Initialize joblib cache dir at `%s`.", - joblib_cachedir) + self.cache_dir) except ImportError: # pragma: no cover logger.warn("Joblib is not installed. " "Install it with `conda install joblib`.") @@ -176,7 +175,7 @@ def ipy_view(self, value): # Dill is necessary because we need to serialize closures. value.use_dill() - def _path(self, rel_path, *args, **kwargs): + def _path(self, rel_path='', *args, **kwargs): """Get the full path to a local cache resource.""" return op.join(self.cache_dir, rel_path.format(*args, **kwargs)) From 5541b19c34507fef38e13bee92b5a35acd07bd7d Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 30 Oct 2015 14:59:12 +0100 Subject: [PATCH 0514/1059] Percentile waveform normalization --- phy/cluster/manual/gui_component.py | 2 +- phy/cluster/manual/views.py | 11 +++++++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 3adcb8771..f1599056a 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -168,7 +168,7 @@ class ManualClustering(object): 'first': 'MoveToStartOfLine', 'last': 'MoveToEndOfLine', 'pin': 'return', - 'unpin': 'Back', + 'unpin': 'Backspace', # Clustering actions. 'merge': 'g', 'split': 'k', diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index c6d95b276..5591743c7 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -75,6 +75,9 @@ def _extract_wave(traces, spk, mask, wave_len=None): # ----------------------------------------------------------------------------- class WaveformView(BoxedView): + normalization_percentile = .95 + normalization_n_spikes = 1000 + def __init__(self, waveforms=None, masks=None, @@ -102,8 +105,12 @@ def __init__(self, # Waveform normalization. n = waveforms.shape[0] - k = max(1, n // 1000) - m = np.abs(waveforms[::k]).max() + k = max(1, n // self.normalization_n_spikes) + w = np.abs(waveforms[::k]) + n = w.shape[0] + w = w.reshape((n, -1)) + w = w.max(axis=1) + m = np.percentile(w, self.normalization_percentile) self.data_bounds = [-1, -m, +1, +m] # Masks. From 629128574005b2caad5ea473933f61ab58ada88f Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 30 Oct 2015 15:29:26 +0100 Subject: [PATCH 0515/1059] WIP: wizard navigation --- phy/cluster/manual/gui_component.py | 2 +- phy/cluster/manual/wizard.py | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index f1599056a..91ae9ade7 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -168,7 +168,7 @@ class ManualClustering(object): 'first': 'MoveToStartOfLine', 'last': 'MoveToEndOfLine', 'pin': 'return', - 'unpin': 'Backspace', + 'unpin': 'backspace', # Clustering actions. 'merge': 'g', 'split': 'k', diff --git a/phy/cluster/manual/wizard.py b/phy/cluster/manual/wizard.py index 78eea7b37..d4f6c2bb0 100644 --- a/phy/cluster/manual/wizard.py +++ b/phy/cluster/manual/wizard.py @@ -270,10 +270,16 @@ def pin(self): if not candidates: # pragma: no cover return self.select([self.best, candidates[0]]) + # Clear the navigation history when pinning, such that `previous` + # keeps the pinned cluster selected. + self._history.clear() def unpin(self): if len(self._selection) == 2: self.select([self.selection[0]]) + # Clear the navigation history when unpinning, such that `previous` + # keeps the pinned cluster selected. + self._history.clear() # Navigation #-------------------------------------------------------------------------- From d2d68262a085c6e4aa0a81dd3318f84dc3fee285 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 31 Oct 2015 16:19:29 +0100 Subject: [PATCH 0516/1059] Remove static files and utils --- phy/cluster/manual/static/__init__.py | 0 phy/cluster/manual/static/styles.css | 56 --------------------------- phy/cluster/manual/static/wizard.html | 14 ------- phy/gui/_utils.py | 23 ----------- phy/gui/static/__init__.py | 0 phy/gui/static/styles.css | 0 phy/gui/static/wrap_qt.html | 23 ----------- phy/gui/tests/test_utils.py | 17 -------- 8 files changed, 133 deletions(-) delete mode 100644 phy/cluster/manual/static/__init__.py delete mode 100644 phy/cluster/manual/static/styles.css delete mode 100644 phy/cluster/manual/static/wizard.html delete mode 100644 phy/gui/_utils.py delete mode 100644 phy/gui/static/__init__.py delete mode 100644 phy/gui/static/styles.css delete mode 100644 phy/gui/static/wrap_qt.html delete mode 100644 phy/gui/tests/test_utils.py diff --git a/phy/cluster/manual/static/__init__.py b/phy/cluster/manual/static/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/phy/cluster/manual/static/styles.css b/phy/cluster/manual/static/styles.css deleted file mode 100644 index 66ae9081d..000000000 --- a/phy/cluster/manual/static/styles.css +++ /dev/null @@ -1,56 +0,0 @@ -/* Wizard panel */ - -.control-panel { - margin: 0; - font-weight: bold; - font-size: 24pt; - padding: 10px; - text-align: center -} - -.control-panel > div { - display: inline-block; - margin: 0 auto; -} - -.control-panel .best { - margin-right: 20px; -} - -.control-panel .match { -} - -.control-panel > div .id { - margin: 10px 0 20px 0; - height: 40px; - text-align: center; - vertical-align: middle; - padding: 5px 0 10px 0; -} - -/* Progress bar */ -.control-panel progress[value] { - width: 200px; -} - -/* Cluster group */ -.control-panel .unsorted { -} - -.control-panel .good { - background-color: #243318; -} - -.control-panel .ignored { - background-color: #331d12; -} - -/* Stats panel */ - -.stats td, .stats th { - padding: 0 20px 0 0; -} - -.stats td { - text-align: right; -} diff --git a/phy/cluster/manual/static/wizard.html b/phy/cluster/manual/static/wizard.html deleted file mode 100644 index 6a38170d7..000000000 --- a/phy/cluster/manual/static/wizard.html +++ /dev/null @@ -1,14 +0,0 @@ -
-
-
{best}
-
- -
-
-
-
{match}
-
- -
-
-
diff --git a/phy/gui/_utils.py b/phy/gui/_utils.py deleted file mode 100644 index ca7ae0c64..000000000 --- a/phy/gui/_utils.py +++ /dev/null @@ -1,23 +0,0 @@ -# -*- coding: utf-8 -*- - -"""HTML/CSS utilities.""" - -# ----------------------------------------------------------------------------- -# Imports -# ----------------------------------------------------------------------------- - -import os.path as op - - -# ----------------------------------------------------------------------------- -# Utilities -# ----------------------------------------------------------------------------- - -def _read(fn, static_path=None): - """Read a file in a static directory. - - By default, this is `./static/`.""" - if static_path is None: - static_path = op.join(op.dirname(op.realpath(__file__)), 'static') - with open(op.join(static_path, fn), 'r') as f: - return f.read() diff --git a/phy/gui/static/__init__.py b/phy/gui/static/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/phy/gui/static/styles.css b/phy/gui/static/styles.css deleted file mode 100644 index e69de29bb..000000000 diff --git a/phy/gui/static/wrap_qt.html b/phy/gui/static/wrap_qt.html deleted file mode 100644 index fcad69ee7..000000000 --- a/phy/gui/static/wrap_qt.html +++ /dev/null @@ -1,23 +0,0 @@ - - - - - - - - - -%HTML% - - - diff --git a/phy/gui/tests/test_utils.py b/phy/gui/tests/test_utils.py deleted file mode 100644 index 63db87902..000000000 --- a/phy/gui/tests/test_utils.py +++ /dev/null @@ -1,17 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Test HTML/CSS utilities.""" - -# ----------------------------------------------------------------------------- -# Imports -# ----------------------------------------------------------------------------- - -from .._utils import _read - - -# ----------------------------------------------------------------------------- -# Utilities -# ----------------------------------------------------------------------------- - -def test_read(): - assert _read('wrap_qt.html') From 0cfcf8feeb8af9b04afdb3c3363cbdadd7f09588 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 31 Oct 2015 16:40:45 +0100 Subject: [PATCH 0517/1059] WIP: add widgets module --- phy/gui/tests/test_widgets.py | 20 +++++++++ phy/gui/widgets.py | 84 +++++++++++++++++++++++++++++++++++ 2 files changed, 104 insertions(+) create mode 100644 phy/gui/tests/test_widgets.py create mode 100644 phy/gui/widgets.py diff --git a/phy/gui/tests/test_widgets.py b/phy/gui/tests/test_widgets.py new file mode 100644 index 000000000..9db76f817 --- /dev/null +++ b/phy/gui/tests/test_widgets.py @@ -0,0 +1,20 @@ +# -*- coding: utf-8 -*- + +"""Test widgets.""" + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +from ..widgets import HTMLWidget + + +#------------------------------------------------------------------------------ +# Test actions +#------------------------------------------------------------------------------ + +def test_widget(qtbot): + widget = HTMLWidget() + widget.show() + qtbot.waitForWindowShown(widget) + # qtbot.stop() diff --git a/phy/gui/widgets.py b/phy/gui/widgets.py new file mode 100644 index 000000000..5b042ffed --- /dev/null +++ b/phy/gui/widgets.py @@ -0,0 +1,84 @@ +# -*- coding: utf-8 -*- + +"""HTML widgets for GUIs.""" + + +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- + +import logging + +from .qt import QWebView + +logger = logging.getLogger(__name__) + + +# ----------------------------------------------------------------------------- +# Table +# ----------------------------------------------------------------------------- + +_DEFAULT_STYLES = """ + html, body, table { + background-color: black; + color: white; + font-family: sans-serif; + font-size: 18pt; + margin: 5px 10px; + } +""" + + +_PAGE_TEMPLATE = """ + + + {title:s} + + + + + +{body:s} + + + +""" + + +class HTMLWidget(QWebView): + title = 'Widget' + body = '' + + def __init__(self): + super(HTMLWidget, self).__init__() + self._styles = [_DEFAULT_STYLES] + self._scripts = [] + + def html(self): + return self.page().mainFrame().toHtml() + + def add_styles(self, s): + self._styles.append(s) + + def add_scripts(self, s): + self._scripts.append(s) + + def build(self): + styles = '\n\n'.join(self._styles) + scripts = '\n\n'.join(self._scripts) + html = _PAGE_TEMPLATE.format(title=self.title, + styles=styles, + scripts=scripts, + body=self.body, + ) + self.setHtml(html) + + def show(self): + # Build if no HTML has been set. + if self.html() == '': + self.build() + return super(HTMLWidget, self).show() From 1d543d3cef5f06a4384ec5915fb19691f7f828ab Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 31 Oct 2015 16:45:53 +0100 Subject: [PATCH 0518/1059] WIP: widgets --- phy/gui/tests/test_widgets.py | 12 +++++++++++- phy/gui/widgets.py | 19 ++++++++++--------- 2 files changed, 21 insertions(+), 10 deletions(-) diff --git a/phy/gui/tests/test_widgets.py b/phy/gui/tests/test_widgets.py index 9db76f817..f8feda05c 100644 --- a/phy/gui/tests/test_widgets.py +++ b/phy/gui/tests/test_widgets.py @@ -13,8 +13,18 @@ # Test actions #------------------------------------------------------------------------------ -def test_widget(qtbot): +def test_widget_empty(qtbot): widget = HTMLWidget() widget.show() qtbot.waitForWindowShown(widget) # qtbot.stop() + + +def test_widget_html(qtbot): + widget = HTMLWidget() + widget.add_styles('html, body, p {background-color: purple;}') + widget.add_header('') + widget.set_body('Hello world!') + widget.show() + qtbot.waitForWindowShown(widget) + # qtbot.stop() diff --git a/phy/gui/widgets.py b/phy/gui/widgets.py index 5b042ffed..3b00b2427 100644 --- a/phy/gui/widgets.py +++ b/phy/gui/widgets.py @@ -36,9 +36,7 @@ - + {header:s} @@ -56,7 +54,8 @@ class HTMLWidget(QWebView): def __init__(self): super(HTMLWidget, self).__init__() self._styles = [_DEFAULT_STYLES] - self._scripts = [] + self._header = '' + self._body = '' def html(self): return self.page().mainFrame().toHtml() @@ -64,16 +63,18 @@ def html(self): def add_styles(self, s): self._styles.append(s) - def add_scripts(self, s): - self._scripts.append(s) + def add_header(self, h): + self._header += (h + '\n') + + def set_body(self, s): + self._body = s def build(self): styles = '\n\n'.join(self._styles) - scripts = '\n\n'.join(self._scripts) html = _PAGE_TEMPLATE.format(title=self.title, styles=styles, - scripts=scripts, - body=self.body, + header=self._header, + body=self._body, ) self.setHtml(html) From 2bd71db071f67a472a7b6d56f0d0dbea92373eeb Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 31 Oct 2015 17:22:20 +0100 Subject: [PATCH 0519/1059] WIP: sort table --- phy/gui/qt.py | 4 +-- phy/gui/static/tablesort.min.js | 5 +++ phy/gui/tests/test_widgets.py | 30 ++++++++++++++++- phy/gui/widgets.py | 60 +++++++++++++++++++++++++++++++-- 4 files changed, 93 insertions(+), 6 deletions(-) create mode 100644 phy/gui/static/tablesort.min.js diff --git a/phy/gui/qt.py b/phy/gui/qt.py index e7d0cbc9c..734c6679e 100644 --- a/phy/gui/qt.py +++ b/phy/gui/qt.py @@ -18,12 +18,12 @@ # ----------------------------------------------------------------------------- from PyQt4.QtCore import (Qt, QByteArray, QMetaObject, QObject, # noqa - pyqtSignal, QSize) + pyqtSignal, QSize, QUrl) from PyQt4.QtGui import (QKeySequence, QAction, QStatusBar, # noqa QMainWindow, QDockWidget, QWidget, QMessageBox, QApplication, ) -from PyQt4.QtWebKit import QWebView # noqa +from PyQt4.QtWebKit import QWebView, QWebSettings # noqa # ----------------------------------------------------------------------------- diff --git a/phy/gui/static/tablesort.min.js b/phy/gui/static/tablesort.min.js new file mode 100644 index 000000000..33ded4fbb --- /dev/null +++ b/phy/gui/static/tablesort.min.js @@ -0,0 +1,5 @@ +/*! + * tablesort v3.1.0 (2015-07-03) + * http://tristen.ca/tablesort/demo/ + * Copyright (c) 2015 ; Licensed MIT +*/!function(){function a(b,c){if(!(this instanceof a))return new a(b,c);if(!b||"TABLE"!==b.tagName)throw new Error("Element must be a table");this.init(b,c||{})}var b=[],c=function(a){var b;return window.CustomEvent&&"function"==typeof window.CustomEvent?b=new CustomEvent(a):(b=document.createEvent("CustomEvent"),b.initCustomEvent(a,!1,!1,void 0)),b},d=function(a){return a.getAttribute("data-sort")||a.textContent||a.innerText||""},e=function(a,b){return a=a.toLowerCase(),b=b.toLowerCase(),a===b?0:b>a?1:-1},f=function(a,b){return function(c,d){var e=a(c.td,d.td);return 0===e?b?d.index-c.index:c.index-d.index:e}};a.extend=function(a,c,d){if("function"!=typeof c||"function"!=typeof d)throw new Error("Pattern and sort must be a function");b.push({name:a,pattern:c,sort:d})},a.prototype={init:function(a,b){var c,d,e,f,g=this;if(g.table=a,g.thead=!1,g.options=b,a.rows&&a.rows.length>0&&(a.tHead&&a.tHead.rows.length>0?(c=a.tHead.rows[a.tHead.rows.length-1],g.thead=!0):c=a.rows[0]),c){var h=function(){g.current&&g.current!==this&&(g.current.classList.remove("sort-up"),g.current.classList.remove("sort-down")),g.current=this,g.sortTable(this)};for(e=0;e0&&m.push(l),n++;if(!m)return}for(n=0;nn;n++)r[n]?(l=r[n],t++):l=q[n-t].tr,i.table.tBodies[0].appendChild(l);i.table.dispatchEvent(c("afterSort"))}},refresh:function(){void 0!==this.current&&this.sortTable(this.current,!0)}},"undefined"!=typeof module&&module.exports?module.exports=a:window.Tablesort=a}(); \ No newline at end of file diff --git a/phy/gui/tests/test_widgets.py b/phy/gui/tests/test_widgets.py index f8feda05c..7f5f70e13 100644 --- a/phy/gui/tests/test_widgets.py +++ b/phy/gui/tests/test_widgets.py @@ -6,7 +6,7 @@ # Imports #------------------------------------------------------------------------------ -from ..widgets import HTMLWidget +from ..widgets import HTMLWidget, Table #------------------------------------------------------------------------------ @@ -25,6 +25,34 @@ def test_widget_html(qtbot): widget.add_styles('html, body, p {background-color: purple;}') widget.add_header('') widget.set_body('Hello world!') + widget.show() qtbot.waitForWindowShown(widget) # qtbot.stop() + + +def test_table(qtbot): + table = Table() + + table.set_body(""" + + + + + + + + + + +
idcount
120
210
330
+ + + + """) + + table.show() + qtbot.waitForWindowShown(table) + # qtbot.stop() diff --git a/phy/gui/widgets.py b/phy/gui/widgets.py index 3b00b2427..6c95c6af7 100644 --- a/phy/gui/widgets.py +++ b/phy/gui/widgets.py @@ -8,14 +8,15 @@ # ----------------------------------------------------------------------------- import logging +import os.path as op -from .qt import QWebView +from .qt import QWebView, QUrl, QWebSettings logger = logging.getLogger(__name__) # ----------------------------------------------------------------------------- -# Table +# HTML widget # ----------------------------------------------------------------------------- _DEFAULT_STYLES = """ @@ -53,6 +54,8 @@ class HTMLWidget(QWebView): def __init__(self): super(HTMLWidget, self).__init__() + self.settings().setAttribute( + QWebSettings.LocalContentCanAccessRemoteUrls, True) self._styles = [_DEFAULT_STYLES] self._header = '' self._body = '' @@ -63,6 +66,9 @@ def html(self): def add_styles(self, s): self._styles.append(s) + def add_script_src(self, filename): + self.add_header(''.format(filename)) + def add_header(self, h): self._header += (h + '\n') @@ -76,10 +82,58 @@ def build(self): header=self._header, body=self._body, ) - self.setHtml(html) + logger.log(5, "Set HTML: %s", html) + static_dir = op.join(op.realpath(op.dirname(__file__)), 'static/') + base_url = QUrl().fromLocalFile(static_dir) + self.setHtml(html, base_url) def show(self): # Build if no HTML has been set. if self.html() == '': self.build() return super(HTMLWidget, self).show() + + +# ----------------------------------------------------------------------------- +# HTML table +# ----------------------------------------------------------------------------- + +_TABLE_STYLE = r""" + +th.sort-header::-moz-selection { background:transparent; } +th.sort-header::selection { background:transparent; } +th.sort-header { cursor:pointer; } + +table th.sort-header:after { + content: "\25B2"; + float: right; + margin-left: 5px; + margin-right: 5px; + visibility: hidden; +} + +table th.sort-header:hover:after { + visibility: visible; +} + +table th.sort-up:after { + content: "\25BC"; +} +table th.sort-down:after { + content: "\25B2"; +} + +table th.sort-up:after, +table th.sort-down:after, +table th.sort-down:hover:after { + visibility: visible; +} + +""" + + +class Table(HTMLWidget): + def __init__(self): + super(Table, self).__init__() + self.add_styles(_TABLE_STYLE) + self.add_script_src('tablesort.min.js') From 0a6da0a2432cd6026b601c8dba68789cbf014a8d Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 1 Nov 2015 05:33:18 +0100 Subject: [PATCH 0520/1059] Update table styles --- phy/gui/widgets.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/phy/gui/widgets.py b/phy/gui/widgets.py index 6c95c6af7..87a125b32 100644 --- a/phy/gui/widgets.py +++ b/phy/gui/widgets.py @@ -101,19 +101,19 @@ def show(self): _TABLE_STYLE = r""" th.sort-header::-moz-selection { background:transparent; } -th.sort-header::selection { background:transparent; } +th.sort-header::selection { background:transparent; } th.sort-header { cursor:pointer; } table th.sort-header:after { - content: "\25B2"; - float: right; - margin-left: 5px; - margin-right: 5px; - visibility: hidden; + content: "\25B2"; + float: right; + margin-left: 3px; + margin-right: 15px; + visibility: hidden; } table th.sort-header:hover:after { - visibility: visible; + visibility: visible; } table th.sort-up:after { @@ -126,7 +126,7 @@ def show(self): table th.sort-up:after, table th.sort-down:after, table th.sort-down:hover:after { - visibility: visible; + visibility: visible; } """ From 9e17fa2f7a9b400a8843dd510b87555d15a85fc8 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 1 Nov 2015 06:09:35 +0100 Subject: [PATCH 0521/1059] Add Python-Javascript communication in HTML widget --- phy/gui/qt.py | 2 +- phy/gui/tests/test_widgets.py | 2 +- phy/gui/widgets.py | 21 ++++++++++++++++++++- 3 files changed, 22 insertions(+), 3 deletions(-) diff --git a/phy/gui/qt.py b/phy/gui/qt.py index 734c6679e..c83f30ef5 100644 --- a/phy/gui/qt.py +++ b/phy/gui/qt.py @@ -18,7 +18,7 @@ # ----------------------------------------------------------------------------- from PyQt4.QtCore import (Qt, QByteArray, QMetaObject, QObject, # noqa - pyqtSignal, QSize, QUrl) + pyqtSignal, pyqtSlot, QSize, QUrl) from PyQt4.QtGui import (QKeySequence, QAction, QStatusBar, # noqa QMainWindow, QDockWidget, QWidget, QMessageBox, QApplication, diff --git a/phy/gui/tests/test_widgets.py b/phy/gui/tests/test_widgets.py index 7f5f70e13..0e0068cdf 100644 --- a/phy/gui/tests/test_widgets.py +++ b/phy/gui/tests/test_widgets.py @@ -25,7 +25,7 @@ def test_widget_html(qtbot): widget.add_styles('html, body, p {background-color: purple;}') widget.add_header('') widget.set_body('Hello world!') - + widget.eval_js('widget.set_body("Hello from Javascript!");') widget.show() qtbot.waitForWindowShown(widget) # qtbot.stop() diff --git a/phy/gui/widgets.py b/phy/gui/widgets.py index 87a125b32..e6608c59b 100644 --- a/phy/gui/widgets.py +++ b/phy/gui/widgets.py @@ -10,7 +10,7 @@ import logging import os.path as op -from .qt import QWebView, QUrl, QWebSettings +from .qt import QWebView, QUrl, QWebSettings, pyqtSlot logger = logging.getLogger(__name__) @@ -49,6 +49,13 @@ class HTMLWidget(QWebView): + """An HTML widget that is displayed with Qt. + + Python methods can be called from Javascript with `widget.the_method()`. + They must be decorated with `pyqtSlot(str)` or similar, depending on + the parameters. + + """ title = 'Widget' body = '' @@ -59,6 +66,17 @@ def __init__(self): self._styles = [_DEFAULT_STYLES] self._header = '' self._body = '' + self.add_to_js('widget', self) + + def add_to_js(self, name, var): + """Add an object to Javascript.""" + frame = self.page().mainFrame() + frame.addToJavaScriptWindowObject(name, var) + + def eval_js(self, expr): + """Evaluate a Javascript expression.""" + frame = self.page().mainFrame() + frame.evaluateJavaScript(expr) def html(self): return self.page().mainFrame().toHtml() @@ -72,6 +90,7 @@ def add_script_src(self, filename): def add_header(self, h): self._header += (h + '\n') + @pyqtSlot(str) def set_body(self, s): self._body = s From d348b5818ba8d4ead31b6787a041b08a0cb073a6 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 2 Nov 2015 12:26:56 +0100 Subject: [PATCH 0522/1059] WIP: table widget --- MANIFEST.in | 2 +- phy/gui/qt.py | 2 +- phy/gui/static/table.css | 59 +++++++++ phy/gui/static/table.js | 222 ++++++++++++++++++++++++++++++++++ phy/gui/tests/test_widgets.py | 32 +++-- phy/gui/widgets.py | 148 ++++++++++++++++------- requirements-dev.txt | 4 +- setup.py | 3 +- 8 files changed, 404 insertions(+), 68 deletions(-) create mode 100644 phy/gui/static/table.css create mode 100644 phy/gui/static/table.js diff --git a/MANIFEST.in b/MANIFEST.in index 2531a5bd8..f42e27c56 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -5,6 +5,6 @@ recursive-include tests * recursive-include phy/electrode/probes *.prb recursive-include phy/plot/glsl *.vert *.frag *.glsl recursive-include phy/cluster/manual/static *.html *.css -recursive-include phy/gui/static *.html *.css +recursive-include phy/gui/static *.html *.css *.js recursive-exclude * __pycache__ recursive-exclude * *.py[co] diff --git a/phy/gui/qt.py b/phy/gui/qt.py index c83f30ef5..40c58ef26 100644 --- a/phy/gui/qt.py +++ b/phy/gui/qt.py @@ -23,7 +23,7 @@ QMainWindow, QDockWidget, QWidget, QMessageBox, QApplication, ) -from PyQt4.QtWebKit import QWebView, QWebSettings # noqa +from PyQt4.QtWebKit import QWebView, QWebPage, QWebSettings # noqa # ----------------------------------------------------------------------------- diff --git a/phy/gui/static/table.css b/phy/gui/static/table.css new file mode 100644 index 000000000..001a20e88 --- /dev/null +++ b/phy/gui/static/table.css @@ -0,0 +1,59 @@ + +th.sort-header::-moz-selection { background:transparent; } +th.sort-header::selection { background:transparent; } +th.sort-header { cursor:pointer; } + +table { + border-collapse: collapse; +} + +table td { + padding: 5px 10px; + margin: 0; +} + +table th.sort-header:after { + content: "\25B2"; + margin-left: 5px; + margin-right: 15px; + visibility: hidden; +} + +table th.sort-header:hover:after { + visibility: visible; +} + +table th.sort-up:after { + content: "\25BC"; +} +table th.sort-down:after { + content: "\25B2"; +} + +table th.sort-up:after, +table th.sort-down:after, +table th.sort-down:hover:after { + visibility: visible; +} + +table tr { cursor:pointer; } + +table tr:hover { + background-color: #222; +} + +table tr:hover th { + background-color: #000; +} + +table tr.selected td { + background-color: #444; +} + +table tr.pinned { + background-color: #888; +} + +table tr[data-skip='true'] { + color: #888; +} diff --git a/phy/gui/static/table.js b/phy/gui/static/table.js new file mode 100644 index 000000000..db5a53208 --- /dev/null +++ b/phy/gui/static/table.js @@ -0,0 +1,222 @@ + +var Table = function (el) { + this.el = el; + this.state = { + sortCol: null, + sortDir: null, + selected: [], + pinned: [], + } + this.headers = {}; // {name: th} mapping + this.rows = {}; // {id: tr} mapping + this.tablesort = null; +}; + +Table.prototype.setData = function(data) { + if (data.items.length == 0) return; + var that = this; + var keys = data.cols; + + var thead = document.createElement("thead"); + var tbody = document.createElement("tbody"); + + // Header. + var tr = document.createElement("tr"); + for (var j = 0; j < keys.length; j++) { + var key = keys[j]; + var th = document.createElement("th"); + th.appendChild(document.createTextNode(key)); + tr.appendChild(th); + this.headers[key] = th; + } + thead.appendChild(tr); + + // Data rows. + for (var i = 0; i < data.items.length; i++) { + tr = document.createElement("tr"); + var row = data.items[i]; + for (var j = 0; j < keys.length; j++) { + var key = keys[j]; + var value = row[key]; + var td = document.createElement("td"); + td.appendChild(document.createTextNode(value)); + tr.appendChild(td); + } + + // Set the data values on the row. + for (var key in row) { + tr.dataset[key] = row[key]; + } + + tr.onclick = function(e) { + var selected = [parseInt(this.dataset.id)]; + + var evt = e ? e:window.event; + if (evt.ctrlKey || evt.metaKey) { + selected = that.state.selected.concat(selected); + } + that.select(selected); + } + + tbody.appendChild(tr); + this.rows[data.items[i].id] = tr; + } + + this.el.appendChild(thead); + this.el.appendChild(tbody); + + // Enable the tablesort plugin. + this.tablesort = new Tablesort(this.el); + + // Synchronize the state. + var that = this; + this.el.addEventListener('afterSort', function() { + for (var header in that.headers) { + if (that.headers[header].classList.contains('sort-up')) { + that.state.sortCol = header; + that.state.sortDir = 'desc'; + break; + } + if (that.headers[header].classList.contains('sort-down')) { + that.state.sortCol = header; + that.state.sortDir = 'asc'; + break; + } + } + }); +}; + +Table.prototype.setState = function(state) { + + // Make sure both sortCol and sortDir are specified. + if (!('sortCol' in state) && ('sortDir' in state)) { + state.sortCol = this.state.sortCol; + } + if (!('sortDir' in state) && ('sortCol' in state)) { + state.sortDir = this.state.sortDir; + } + + if ('sortCol' in state) { + + // Update the state. + this.state.sortCol = state.sortCol; + this.state.sortDir = state.sortDir; + + // Remove all sorts. + for (var h in this.headers) { + this.headers[h].classList.remove('sort-up'); + this.headers[h].classList.remove('sort-down'); + } + + // Set the sort direction in the header class. + var header = this.headers[state.sortCol]; + header.classList.add(state.sortDir === 'desc' ? + 'sort-down' : 'sort-up'); + + // Sort the table. + this.tablesort.sortTable(header); + } + if ('selected' in state) { + this.setRowClass('selected', state.selected); + this.state.selected = state.selected.map(parseInt); + } + if ('pinned' in state) { + this.setRowClass('pinned', state.pinned); + this.state.pinned = state.pinned.map(parseInt); + } +}; + +Table.prototype.setRowClass = function(name, ids) { + // Remove the class on all rows. + for (var i = 0; i < this.state[name].length; i++) { + var id = this.state[name][i]; + var row = this.rows[id]; + row.classList.remove(name); + } + + // Add the class. + for (var i = 0; i < ids.length; i++) { + var id = ids[i]; + this.rows[id].classList.add(name); + } +}; + +Table.prototype.stateUpdated = function() { + // TODO: call widget.setState(this.state); +}; + +Table.prototype.getState = function() { + return this.state; +} + +Table.prototype.clear = function() { + this.setState({ + selected: [], + pinned: [], + }); +}; + +Table.prototype.select = function(items) { + this.setState({ + selected: items, + }); +}; + +Table.prototype.pin = function(items) { + this.setState({ + pinned: items, + }); +}; + +Table.prototype.unpin = function() { + this.setState({ + pinned: [], + }); +}; + +Table.prototype.next = function() { + if (this.state.selected.length != 1) return; + var id = this.state.selected[0]; + var row = this.rows[id]; + var i0 = row.rowIndex + 1; + var items = []; + + for (var i = i0; i < this.el.rows.length; i++) { + row = this.el.rows[i]; + if (!(row.dataset.skip)) { + items.push(row.dataset.id); + break; + } + } + + if (!(items.length)) return; + + // TODO: keep the pinned + this.setState({ + selected: items, + }); +}; + +Table.prototype.previous = function() { + if (this.state.selected.length != 1) return; + var id = this.state.selected[0]; + var row = this.rows[id]; + var i0 = row.rowIndex - 1; + var items = []; + + // NOTE: i >= 1 because we skip the header column. + for (var i = i0; i >= 1; i--) { + row = this.el.rows[i]; + if (!(row.dataset.skip)) { + items.push(row.dataset.id); + break; + } + } + + if (!(items.length)) return; + + // TODO: keep the pinned + this.setState({ + selected: items, + }); +}; diff --git a/phy/gui/tests/test_widgets.py b/phy/gui/tests/test_widgets.py index 0e0068cdf..0a9f9c875 100644 --- a/phy/gui/tests/test_widgets.py +++ b/phy/gui/tests/test_widgets.py @@ -34,25 +34,21 @@ def test_widget_html(qtbot): def test_table(qtbot): table = Table() - table.set_body(""" + table.show() + qtbot.waitForWindowShown(table) - - - - - - - - - -
idcount
120
210
330
+ items = [{'id': i, 'count': 10 * i} for i in range(10)] + items[4]['skip'] = True - + table.set_data(cols=['id', 'count'], + items=items, + ) + table.select([4]) - """) + assert table.pinned == [] - table.show() - qtbot.waitForWindowShown(table) - # qtbot.stop() + table.next() + assert table.selected == [5] + + table.previous() + assert table.selected == [3] diff --git a/phy/gui/widgets.py b/phy/gui/widgets.py index e6608c59b..eec2dfb48 100644 --- a/phy/gui/widgets.py +++ b/phy/gui/widgets.py @@ -7,10 +7,11 @@ # Imports # ----------------------------------------------------------------------------- +import json import logging import os.path as op -from .qt import QWebView, QUrl, QWebSettings, pyqtSlot +from .qt import QWebView, QWebPage, QUrl, QWebSettings, pyqtSlot logger = logging.getLogger(__name__) @@ -48,6 +49,11 @@ """ +class WebPage(QWebPage): + def javaScriptConsoleMessage(self, msg, line, source): + logger.debug(msg) # pragma: no cover + + class HTMLWidget(QWebView): """An HTML widget that is displayed with Qt. @@ -63,37 +69,44 @@ def __init__(self): super(HTMLWidget, self).__init__() self.settings().setAttribute( QWebSettings.LocalContentCanAccessRemoteUrls, True) + self.settings().setAttribute( + QWebSettings.DeveloperExtrasEnabled, True) + self.setPage(WebPage()) + self._obj = None self._styles = [_DEFAULT_STYLES] self._header = '' self._body = '' self.add_to_js('widget', self) - def add_to_js(self, name, var): - """Add an object to Javascript.""" - frame = self.page().mainFrame() - frame.addToJavaScriptWindowObject(name, var) - - def eval_js(self, expr): - """Evaluate a Javascript expression.""" - frame = self.page().mainFrame() - frame.evaluateJavaScript(expr) - - def html(self): - return self.page().mainFrame().toHtml() + # Headers + # ------------------------------------------------------------------------- def add_styles(self, s): self._styles.append(s) + def add_style_src(self, filename): + self.add_header(('').format(filename)) + def add_script_src(self, filename): self.add_header(''.format(filename)) def add_header(self, h): self._header += (h + '\n') + # HTML methods + # ------------------------------------------------------------------------- + @pyqtSlot(str) def set_body(self, s): self._body = s + def add_body(self, s): + self._body += '\n' + s + '\n' + + def html(self): + return self.page().mainFrame().toHtml() + def build(self): styles = '\n\n'.join(self._styles) html = _PAGE_TEMPLATE.format(title=self.title, @@ -106,6 +119,29 @@ def build(self): base_url = QUrl().fromLocalFile(static_dir) self.setHtml(html, base_url) + # Javascript methods + # ------------------------------------------------------------------------- + + def add_to_js(self, name, var): + """Add an object to Javascript.""" + frame = self.page().mainFrame() + frame.addToJavaScriptWindowObject(name, var) + + def eval_js(self, expr): + """Evaluate a Javascript expression.""" + frame = self.page().mainFrame() + frame.evaluateJavaScript(expr) + + @pyqtSlot(str) + def _set_from_js(self, obj): + self._obj = json.loads(obj) + + def get_js(self, expr): + self.eval_js('widget._set_from_js(JSON.stringify({}));'.format(expr)) + obj = self._obj + self._obj = None + return obj + def show(self): # Build if no HTML has been set. if self.html() == '': @@ -117,42 +153,62 @@ def show(self): # HTML table # ----------------------------------------------------------------------------- -_TABLE_STYLE = r""" - -th.sort-header::-moz-selection { background:transparent; } -th.sort-header::selection { background:transparent; } -th.sort-header { cursor:pointer; } - -table th.sort-header:after { - content: "\25B2"; - float: right; - margin-left: 3px; - margin-right: 15px; - visibility: hidden; -} - -table th.sort-header:hover:after { - visibility: visible; -} - -table th.sort-up:after { - content: "\25BC"; -} -table th.sort-down:after { - content: "\25B2"; -} - -table th.sort-up:after, -table th.sort-down:after, -table th.sort-down:hover:after { - visibility: visible; -} - -""" +def _create_json_dict(**kwargs): + d = {} + for k, v in kwargs.items(): + if v is not None: + d[k] = v + return json.dumps(d) class Table(HTMLWidget): + + _table_id = 'the-table' + def __init__(self): super(Table, self).__init__() - self.add_styles(_TABLE_STYLE) + self.add_style_src('table.css') self.add_script_src('tablesort.min.js') + self.add_script_src('table.js') + self.set_body('
'.format( + self._table_id)) + self.add_body(''''''.format(self._table_id)) + self.build() + + def set_data(self, items, cols): + data = _create_json_dict(items=items, + cols=cols, + ) + self.eval_js('table.setData({});'.format(data)) + + def set_state(self, selected=None, pinned=None, + sort_col=None, sort_dir=None): + state = _create_json_dict(selected=selected, + pinned=pinned, + sortCol=sort_col, + sortDir=sort_dir, + ) + self.eval_js('table.setState({});'.format(state)) + + def next(self): + self.eval_js('table.next();') + + def previous(self): + self.eval_js('table.previous();') + + def select(self, ids): + self.set_state(selected=ids) + + @property + def state(self): + return self.get_js("table.getState()") + + @property + def selected(self): + return self.state.get('selected', []) + + @property + def pinned(self): + return self.state.get('pinned', []) diff --git a/requirements-dev.txt b/requirements-dev.txt index 06845e29a..d4e0c5384 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -2,10 +2,12 @@ # https://github.com/pytest-dev/pytest/issues/744 git+https://github.com/pytest-dev/pytest.git +# Need development version of pytest-qt for qtbot.wait() method +git+https://github.com/pytest-dev/pytest-qt.git + flake8 coverage==3.7.1 coveralls responses pytest-cov -pytest-qt nose diff --git a/setup.py b/setup.py index 0dd232d4f..be69360bb 100644 --- a/setup.py +++ b/setup.py @@ -50,7 +50,8 @@ def _package_tree(pkgroot): packages=_package_tree('phy'), package_dir={'phy': 'phy'}, package_data={ - 'phy': ['*.vert', '*.frag', '*.glsl', '*.html', '*.css', '*.prb'], + 'phy': ['*.vert', '*.frag', '*.glsl', + '*.html', '*.css', '*.js', '*.prb'], }, entry_points={ 'console_scripts': [ From 632aefc34d73c8c8e9c56e94065fd2d905e3fb8b Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 2 Nov 2015 12:37:14 +0100 Subject: [PATCH 0523/1059] Add comments --- phy/gui/tests/test_widgets.py | 9 +++++++++ phy/gui/widgets.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+) diff --git a/phy/gui/tests/test_widgets.py b/phy/gui/tests/test_widgets.py index 0a9f9c875..8ebf7dda5 100644 --- a/phy/gui/tests/test_widgets.py +++ b/phy/gui/tests/test_widgets.py @@ -31,6 +31,15 @@ def test_widget_html(qtbot): # qtbot.stop() +def test_widget_javascript(qtbot): + widget = HTMLWidget() + widget.show() + qtbot.waitForWindowShown(widget) + widget.eval_js('number = 1;') + assert widget.get_js('number') == 1 + # qtbot.stop() + + def test_table(qtbot): table = Table() diff --git a/phy/gui/widgets.py b/phy/gui/widgets.py index eec2dfb48..148bf1a2e 100644 --- a/phy/gui/widgets.py +++ b/phy/gui/widgets.py @@ -82,16 +82,20 @@ def __init__(self): # ------------------------------------------------------------------------- def add_styles(self, s): + """Add CSS styles.""" self._styles.append(s) def add_style_src(self, filename): + """Link a CSS file.""" self.add_header(('').format(filename)) def add_script_src(self, filename): + """Link a JS script.""" self.add_header(''.format(filename)) def add_header(self, h): + """Add HTML code to the header.""" self._header += (h + '\n') # HTML methods @@ -99,15 +103,19 @@ def add_header(self, h): @pyqtSlot(str) def set_body(self, s): + """Set the HTML body.""" self._body = s def add_body(self, s): + """Add HTML code to the body.""" self._body += '\n' + s + '\n' def html(self): + """Return the full HTML source of the widget.""" return self.page().mainFrame().toHtml() def build(self): + """Build the full HTML source.""" styles = '\n\n'.join(self._styles) html = _PAGE_TEMPLATE.format(title=self.title, styles=styles, @@ -134,15 +142,26 @@ def eval_js(self, expr): @pyqtSlot(str) def _set_from_js(self, obj): + """Called from Javascript to pass any object to Python through JSON.""" self._obj = json.loads(obj) def get_js(self, expr): + """Evaluate a Javascript expression and get a Python object. + + This uses JSON serialization/deserialization under the hood. + + """ self.eval_js('widget._set_from_js(JSON.stringify({}));'.format(expr)) obj = self._obj self._obj = None return obj def show(self): + """Show the widget. + + A build is triggered if necessary. + + """ # Build if no HTML has been set. if self.html() == '': self.build() @@ -162,6 +181,7 @@ def _create_json_dict(**kwargs): class Table(HTMLWidget): + """A sortable table with support for selection and pinning.""" _table_id = 'the-table' @@ -178,6 +198,11 @@ def __init__(self): self.build() def set_data(self, items, cols): + """Set the rows and cols of the table. + + TODO: rename items to rows. + + """ data = _create_json_dict(items=items, cols=cols, ) @@ -185,6 +210,7 @@ def set_data(self, items, cols): def set_state(self, selected=None, pinned=None, sort_col=None, sort_dir=None): + """Set the state of the widget.""" state = _create_json_dict(selected=selected, pinned=pinned, sortCol=sort_col, @@ -193,22 +219,28 @@ def set_state(self, selected=None, pinned=None, self.eval_js('table.setState({});'.format(state)) def next(self): + """Select the next non-skip row.""" self.eval_js('table.next();') def previous(self): + """Select the previous non-skip row.""" self.eval_js('table.previous();') def select(self, ids): + """Select some rows.""" self.set_state(selected=ids) @property def state(self): + """Current state.""" return self.get_js("table.getState()") @property def selected(self): + """Currently selected rows.""" return self.state.get('selected', []) @property def pinned(self): + """Currently pinned rows.""" return self.state.get('pinned', []) From a7d73d112018fb84beec8b5ec256f18438543e4b Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 2 Nov 2015 13:01:16 +0100 Subject: [PATCH 0524/1059] Python 2 fix --- phy/gui/widgets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/phy/gui/widgets.py b/phy/gui/widgets.py index 148bf1a2e..4d675655f 100644 --- a/phy/gui/widgets.py +++ b/phy/gui/widgets.py @@ -143,7 +143,7 @@ def eval_js(self, expr): @pyqtSlot(str) def _set_from_js(self, obj): """Called from Javascript to pass any object to Python through JSON.""" - self._obj = json.loads(obj) + self._obj = json.loads(str(obj)) def get_js(self, expr): """Evaluate a Javascript expression and get a Python object. From 0d81d96e99e39679713fe81b3e0d35cf8e703f6a Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 2 Nov 2015 14:01:14 +0100 Subject: [PATCH 0525/1059] Bug fix in table --- phy/gui/static/table.js | 24 ++---------------------- phy/gui/tests/test_widgets.py | 4 ++-- phy/gui/widgets.py | 21 +++++---------------- 3 files changed, 9 insertions(+), 40 deletions(-) diff --git a/phy/gui/static/table.js b/phy/gui/static/table.js index db5a53208..edec844fb 100644 --- a/phy/gui/static/table.js +++ b/phy/gui/static/table.js @@ -5,7 +5,6 @@ var Table = function (el) { sortCol: null, sortDir: null, selected: [], - pinned: [], } this.headers = {}; // {name: th} mapping this.rows = {}; // {id: tr} mapping @@ -49,7 +48,7 @@ Table.prototype.setData = function(data) { } tr.onclick = function(e) { - var selected = [parseInt(this.dataset.id)]; + var selected = [this.dataset.id]; var evt = e ? e:window.event; if (evt.ctrlKey || evt.metaKey) { @@ -118,11 +117,7 @@ Table.prototype.setState = function(state) { } if ('selected' in state) { this.setRowClass('selected', state.selected); - this.state.selected = state.selected.map(parseInt); - } - if ('pinned' in state) { - this.setRowClass('pinned', state.pinned); - this.state.pinned = state.pinned.map(parseInt); + this.state.selected = state.selected; } }; @@ -152,7 +147,6 @@ Table.prototype.getState = function() { Table.prototype.clear = function() { this.setState({ selected: [], - pinned: [], }); }; @@ -162,18 +156,6 @@ Table.prototype.select = function(items) { }); }; -Table.prototype.pin = function(items) { - this.setState({ - pinned: items, - }); -}; - -Table.prototype.unpin = function() { - this.setState({ - pinned: [], - }); -}; - Table.prototype.next = function() { if (this.state.selected.length != 1) return; var id = this.state.selected[0]; @@ -191,7 +173,6 @@ Table.prototype.next = function() { if (!(items.length)) return; - // TODO: keep the pinned this.setState({ selected: items, }); @@ -215,7 +196,6 @@ Table.prototype.previous = function() { if (!(items.length)) return; - // TODO: keep the pinned this.setState({ selected: items, }); diff --git a/phy/gui/tests/test_widgets.py b/phy/gui/tests/test_widgets.py index 8ebf7dda5..81d76156a 100644 --- a/phy/gui/tests/test_widgets.py +++ b/phy/gui/tests/test_widgets.py @@ -54,10 +54,10 @@ def test_table(qtbot): ) table.select([4]) - assert table.pinned == [] - table.next() assert table.selected == [5] table.previous() assert table.selected == [3] + + # qtbot.stop() diff --git a/phy/gui/widgets.py b/phy/gui/widgets.py index 4d675655f..700957aec 100644 --- a/phy/gui/widgets.py +++ b/phy/gui/widgets.py @@ -181,7 +181,7 @@ def _create_json_dict(**kwargs): class Table(HTMLWidget): - """A sortable table with support for selection and pinning.""" + """A sortable table with support for selection.""" _table_id = 'the-table' @@ -198,21 +198,15 @@ def __init__(self): self.build() def set_data(self, items, cols): - """Set the rows and cols of the table. - - TODO: rename items to rows. - - """ + """Set the rows and cols of the table.""" data = _create_json_dict(items=items, cols=cols, ) self.eval_js('table.setData({});'.format(data)) - def set_state(self, selected=None, pinned=None, - sort_col=None, sort_dir=None): + def set_state(self, selected=None, sort_col=None, sort_dir=None): """Set the state of the widget.""" - state = _create_json_dict(selected=selected, - pinned=pinned, + state = _create_json_dict(selected=[int(_) for _ in selected], sortCol=sort_col, sortDir=sort_dir, ) @@ -238,9 +232,4 @@ def state(self): @property def selected(self): """Currently selected rows.""" - return self.state.get('selected', []) - - @property - def pinned(self): - """Currently pinned rows.""" - return self.state.get('pinned', []) + return [int(_) for _ in self.state.get('selected', [])] From 09543fb6f51d1a81e14c150990a45d54efebd676 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 2 Nov 2015 14:41:12 +0100 Subject: [PATCH 0526/1059] Remove table state --- phy/gui/static/table.js | 106 ++++++---------------------------------- phy/gui/widgets.py | 17 +------ 2 files changed, 18 insertions(+), 105 deletions(-) diff --git a/phy/gui/static/table.js b/phy/gui/static/table.js index edec844fb..5207d69d8 100644 --- a/phy/gui/static/table.js +++ b/phy/gui/static/table.js @@ -1,11 +1,7 @@ var Table = function (el) { this.el = el; - this.state = { - sortCol: null, - sortDir: null, - selected: [], - } + this.selected = []; this.headers = {}; // {name: th} mapping this.rows = {}; // {id: tr} mapping this.tablesort = null; @@ -52,7 +48,7 @@ Table.prototype.setData = function(data) { var evt = e ? e:window.event; if (evt.ctrlKey || evt.metaKey) { - selected = that.state.selected.concat(selected); + selected = that.selected.concat(selected); } that.select(selected); } @@ -66,99 +62,33 @@ Table.prototype.setData = function(data) { // Enable the tablesort plugin. this.tablesort = new Tablesort(this.el); - - // Synchronize the state. - var that = this; - this.el.addEventListener('afterSort', function() { - for (var header in that.headers) { - if (that.headers[header].classList.contains('sort-up')) { - that.state.sortCol = header; - that.state.sortDir = 'desc'; - break; - } - if (that.headers[header].classList.contains('sort-down')) { - that.state.sortCol = header; - that.state.sortDir = 'asc'; - break; - } - } - }); }; -Table.prototype.setState = function(state) { - - // Make sure both sortCol and sortDir are specified. - if (!('sortCol' in state) && ('sortDir' in state)) { - state.sortCol = this.state.sortCol; - } - if (!('sortDir' in state) && ('sortCol' in state)) { - state.sortDir = this.state.sortDir; - } - - if ('sortCol' in state) { - - // Update the state. - this.state.sortCol = state.sortCol; - this.state.sortDir = state.sortDir; +Table.prototype.select = function(ids) { - // Remove all sorts. - for (var h in this.headers) { - this.headers[h].classList.remove('sort-up'); - this.headers[h].classList.remove('sort-down'); - } - - // Set the sort direction in the header class. - var header = this.headers[state.sortCol]; - header.classList.add(state.sortDir === 'desc' ? - 'sort-down' : 'sort-up'); - - // Sort the table. - this.tablesort.sortTable(header); - } - if ('selected' in state) { - this.setRowClass('selected', state.selected); - this.state.selected = state.selected; - } -}; - -Table.prototype.setRowClass = function(name, ids) { // Remove the class on all rows. - for (var i = 0; i < this.state[name].length; i++) { - var id = this.state[name][i]; + for (var i = 0; i < this.selected.length; i++) { + var id = this.selected[i]; var row = this.rows[id]; - row.classList.remove(name); + row.classList.remove('selected'); } // Add the class. for (var i = 0; i < ids.length; i++) { - var id = ids[i]; - this.rows[id].classList.add(name); + var id = parseInt(String(ids[i])); + this.rows[id].classList.add('selected'); } -}; -Table.prototype.stateUpdated = function() { - // TODO: call widget.setState(this.state); + this.selected = ids; }; -Table.prototype.getState = function() { - return this.state; -} - Table.prototype.clear = function() { - this.setState({ - selected: [], - }); -}; - -Table.prototype.select = function(items) { - this.setState({ - selected: items, - }); + this.selected = []; }; Table.prototype.next = function() { - if (this.state.selected.length != 1) return; - var id = this.state.selected[0]; + if (this.selected.length != 1) return; + var id = this.selected[0]; var row = this.rows[id]; var i0 = row.rowIndex + 1; var items = []; @@ -173,14 +103,12 @@ Table.prototype.next = function() { if (!(items.length)) return; - this.setState({ - selected: items, - }); + this.select(items); }; Table.prototype.previous = function() { - if (this.state.selected.length != 1) return; - var id = this.state.selected[0]; + if (this.selected.length != 1) return; + var id = this.selected[0]; var row = this.rows[id]; var i0 = row.rowIndex - 1; var items = []; @@ -196,7 +124,5 @@ Table.prototype.previous = function() { if (!(items.length)) return; - this.setState({ - selected: items, - }); + this.select(items); }; diff --git a/phy/gui/widgets.py b/phy/gui/widgets.py index 700957aec..0edbbe2ed 100644 --- a/phy/gui/widgets.py +++ b/phy/gui/widgets.py @@ -204,14 +204,6 @@ def set_data(self, items, cols): ) self.eval_js('table.setData({});'.format(data)) - def set_state(self, selected=None, sort_col=None, sort_dir=None): - """Set the state of the widget.""" - state = _create_json_dict(selected=[int(_) for _ in selected], - sortCol=sort_col, - sortDir=sort_dir, - ) - self.eval_js('table.setState({});'.format(state)) - def next(self): """Select the next non-skip row.""" self.eval_js('table.next();') @@ -222,14 +214,9 @@ def previous(self): def select(self, ids): """Select some rows.""" - self.set_state(selected=ids) - - @property - def state(self): - """Current state.""" - return self.get_js("table.getState()") + self.eval_js('table.select({});'.format(json.dumps(ids))) @property def selected(self): """Currently selected rows.""" - return [int(_) for _ in self.state.get('selected', [])] + return [int(_) for _ in self.get_js('table.selected')] From 2e5ddfc60346e8325a63b99136bc7585447f2b8f Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 2 Nov 2015 15:04:16 +0100 Subject: [PATCH 0527/1059] Widgets can raise Python events from Javascript --- phy/gui/static/table.js | 7 ++++++- phy/gui/tests/test_widgets.py | 23 +++++++++++++++++++++++ phy/gui/widgets.py | 25 ++++++++++++++++++++++++- 3 files changed, 53 insertions(+), 2 deletions(-) diff --git a/phy/gui/static/table.js b/phy/gui/static/table.js index 5207d69d8..ec39c1cb7 100644 --- a/phy/gui/static/table.js +++ b/phy/gui/static/table.js @@ -64,7 +64,8 @@ Table.prototype.setData = function(data) { this.tablesort = new Tablesort(this.el); }; -Table.prototype.select = function(ids) { +Table.prototype.select = function(ids, raise_event) { + raise_event = typeof raise_event !== 'undefined' ? false : true; // Remove the class on all rows. for (var i = 0; i < this.selected.length; i++) { @@ -80,6 +81,10 @@ Table.prototype.select = function(ids) { } this.selected = ids; + + if (raise_event) { + emit("select", ids); + } }; Table.prototype.clear = function() { diff --git a/phy/gui/tests/test_widgets.py b/phy/gui/tests/test_widgets.py index 81d76156a..00a902d88 100644 --- a/phy/gui/tests/test_widgets.py +++ b/phy/gui/tests/test_widgets.py @@ -37,6 +37,18 @@ def test_widget_javascript(qtbot): qtbot.waitForWindowShown(widget) widget.eval_js('number = 1;') assert widget.get_js('number') == 1 + + _out = [] + + @widget.connect_ + def on_test(arg): + _out.append(arg) + + widget.eval_js('emit("test", [1, 2]);') + assert _out == [[1, 2]] + + widget.unconnect_(on_test) + # qtbot.stop() @@ -60,4 +72,15 @@ def test_table(qtbot): table.previous() assert table.selected == [3] + _sel = [] + + @table.connect_ + def on_select(items): + _sel.append(items) + + table.eval_js('table.select([1]);') + assert _sel == [[1]] + + assert table.selected == [1] + # qtbot.stop() diff --git a/phy/gui/widgets.py b/phy/gui/widgets.py index 0edbbe2ed..c4ef4e155 100644 --- a/phy/gui/widgets.py +++ b/phy/gui/widgets.py @@ -12,6 +12,7 @@ import os.path as op from .qt import QWebView, QWebPage, QUrl, QWebSettings, pyqtSlot +from phy.utils import EventEmitter logger = logging.getLogger(__name__) @@ -77,6 +78,24 @@ def __init__(self): self._header = '' self._body = '' self.add_to_js('widget', self) + self._event = EventEmitter() + self.add_header('''''') + + # Events + # ------------------------------------------------------------------------- + + def emit(self, *args, **kwargs): + return self._event.emit(*args, **kwargs) + + def connect_(self, *args, **kwargs): + self._event.connect(*args, **kwargs) + + def unconnect_(self, *args, **kwargs): + self._event.unconnect(*args, **kwargs) # Headers # ------------------------------------------------------------------------- @@ -145,6 +164,10 @@ def _set_from_js(self, obj): """Called from Javascript to pass any object to Python through JSON.""" self._obj = json.loads(str(obj)) + @pyqtSlot(str, str) + def _emit_from_js(self, name, arg_json): + self.emit(name, json.loads(str(arg_json))) + def get_js(self, expr): """Evaluate a Javascript expression and get a Python object. @@ -214,7 +237,7 @@ def previous(self): def select(self, ids): """Select some rows.""" - self.eval_js('table.select({});'.format(json.dumps(ids))) + self.eval_js('table.select({}, false);'.format(json.dumps(ids))) @property def selected(self): From 7c9cfa9ed7628900d62a6fdbfbb8a3d6f121346d Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 2 Nov 2015 17:15:33 +0100 Subject: [PATCH 0528/1059] Fix Python 2 issue with QString --- phy/gui/widgets.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/phy/gui/widgets.py b/phy/gui/widgets.py index c4ef4e155..aef67a076 100644 --- a/phy/gui/widgets.py +++ b/phy/gui/widgets.py @@ -11,6 +11,8 @@ import logging import os.path as op +from six import text_type + from .qt import QWebView, QWebPage, QUrl, QWebSettings, pyqtSlot from phy.utils import EventEmitter @@ -162,11 +164,11 @@ def eval_js(self, expr): @pyqtSlot(str) def _set_from_js(self, obj): """Called from Javascript to pass any object to Python through JSON.""" - self._obj = json.loads(str(obj)) + self._obj = json.loads(text_type(obj)) @pyqtSlot(str, str) def _emit_from_js(self, name, arg_json): - self.emit(name, json.loads(str(arg_json))) + self.emit(text_type(name), json.loads(text_type(arg_json))) def get_js(self, expr): """Evaluate a Javascript expression and get a Python object. From 12d4eba03699a38eeec686a20129f960713e7828 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 2 Nov 2015 20:23:16 +0100 Subject: [PATCH 0529/1059] Fix --- phy/gui/static/table.js | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/phy/gui/static/table.js b/phy/gui/static/table.js index ec39c1cb7..dd2853c2e 100644 --- a/phy/gui/static/table.js +++ b/phy/gui/static/table.js @@ -76,8 +76,8 @@ Table.prototype.select = function(ids, raise_event) { // Add the class. for (var i = 0; i < ids.length; i++) { - var id = parseInt(String(ids[i])); - this.rows[id].classList.add('selected'); + ids[i] = parseInt(String(ids[i])); + this.rows[ids[i]].classList.add('selected'); } this.selected = ids; From 5652b61e56553d3dcb05102f120418556110c8eb Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 4 Nov 2015 10:28:36 +0100 Subject: [PATCH 0530/1059] Add mock test for Task --- phy/io/context.py | 4 ---- phy/io/tests/test_context.py | 8 +++++++- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/phy/io/context.py b/phy/io/context.py index f41832327..3ad507b04 100644 --- a/phy/io/context.py +++ b/phy/io/context.py @@ -175,10 +175,6 @@ def ipy_view(self, value): # Dill is necessary because we need to serialize closures. value.use_dill() - def _path(self, rel_path='', *args, **kwargs): - """Get the full path to a local cache resource.""" - return op.join(self.cache_dir, rel_path.format(*args, **kwargs)) - def cache(self, f): """Cache a function using the context's cache directory.""" if self._memory is None: # pragma: no cover diff --git a/phy/io/tests/test_context.py b/phy/io/tests/test_context.py index b4f326a1c..497f130e7 100644 --- a/phy/io/tests/test_context.py +++ b/phy/io/tests/test_context.py @@ -14,7 +14,9 @@ from pytest import yield_fixture, mark, raises from six.moves import cPickle -from ..context import Context, _iter_chunks_dask, write_array, read_array +from ..context import (Context, Task, + _iter_chunks_dask, write_array, read_array, + ) #------------------------------------------------------------------------------ @@ -130,6 +132,10 @@ def square(x): assert parallel_context.map_async(square, [1, 2, 3]).get() == [1, 4, 9] +def test_task(): + task = Task(ctx=None) + + #------------------------------------------------------------------------------ # Test context dask #------------------------------------------------------------------------------ From 379468778da6e98df41afdda3867b19fdd6af3b5 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 4 Nov 2015 11:08:04 +0100 Subject: [PATCH 0531/1059] Delay eval_js() calls after the page has loaded --- phy/gui/tests/test_widgets.py | 3 ++- phy/gui/widgets.py | 24 +++++++++++++++++++++--- 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/phy/gui/tests/test_widgets.py b/phy/gui/tests/test_widgets.py index 00a902d88..21689faab 100644 --- a/phy/gui/tests/test_widgets.py +++ b/phy/gui/tests/test_widgets.py @@ -28,7 +28,8 @@ def test_widget_html(qtbot): widget.eval_js('widget.set_body("Hello from Javascript!");') widget.show() qtbot.waitForWindowShown(widget) - # qtbot.stop() + widget.build() + assert 'Javascript' in widget.html() def test_widget_javascript(qtbot): diff --git a/phy/gui/widgets.py b/phy/gui/widgets.py index aef67a076..0fc1f9aba 100644 --- a/phy/gui/widgets.py +++ b/phy/gui/widgets.py @@ -86,6 +86,10 @@ def __init__(self): widget._emit_from_js(name, JSON.stringify(arg)); }; ''') + # Pending eval_js to call *after* the page has been built and loaded. + # Use for calls to `eval_js()` before the page is loaded. + self._pending_eval_js = [] + self.loadFinished.connect(self._load_finished) # Events # ------------------------------------------------------------------------- @@ -148,9 +152,18 @@ def build(self): base_url = QUrl().fromLocalFile(static_dir) self.setHtml(html, base_url) + def is_built(self): + return self.html() != '' + # Javascript methods # ------------------------------------------------------------------------- + def _load_finished(self): + assert self.is_built() + for expr in self._pending_eval_js: + self.eval_js(expr) + self._pending_eval_js = [] + def add_to_js(self, name, var): """Add an object to Javascript.""" frame = self.page().mainFrame() @@ -158,8 +171,13 @@ def add_to_js(self, name, var): def eval_js(self, expr): """Evaluate a Javascript expression.""" - frame = self.page().mainFrame() - frame.evaluateJavaScript(expr) + if not self.is_built(): + # If the page is not built yet, postpone the evaluation of the JS + # to after the page is loaded. + self._pending_eval_js.append(expr) + return + logger.debug("Evaluate Javascript: `%s`.", expr) + self.page().mainFrame().evaluateJavaScript(expr) @pyqtSlot(str) def _set_from_js(self, obj): @@ -188,7 +206,7 @@ def show(self): """ # Build if no HTML has been set. - if self.html() == '': + if not self.is_built(): self.build() return super(HTMLWidget, self).show() From 2522eed30bf9535bc4989d4e28177a75a4cc193b Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 4 Nov 2015 11:26:43 +0100 Subject: [PATCH 0532/1059] WIP: remove wizard and add cluster view --- phy/cluster/manual/__init__.py | 1 - phy/cluster/manual/gui_component.py | 198 +++++----- phy/cluster/manual/tests/conftest.py | 20 - .../manual/tests/test_gui_component.py | 220 ++--------- phy/cluster/manual/tests/test_wizard.py | 344 ----------------- phy/cluster/manual/wizard.py | 353 ------------------ phy/gui/widgets.py | 4 +- phy/io/tests/test_context.py | 1 + 8 files changed, 116 insertions(+), 1025 deletions(-) delete mode 100644 phy/cluster/manual/tests/test_wizard.py delete mode 100644 phy/cluster/manual/wizard.py diff --git a/phy/cluster/manual/__init__.py b/phy/cluster/manual/__init__.py index 6689a86d1..d6b871021 100644 --- a/phy/cluster/manual/__init__.py +++ b/phy/cluster/manual/__init__.py @@ -7,4 +7,3 @@ from .clustering import Clustering from .gui_component import ManualClustering from .views import WaveformView, TraceView, FeatureView, CorrelogramView -from .wizard import Wizard diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 91ae9ade7..8ef5ac8b7 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -10,13 +10,16 @@ import logging import numpy as np -from six import string_types from ._history import GlobalHistory from ._utils import create_cluster_meta from .clustering import Clustering -from .wizard import Wizard +from phy.stats.clusters import (mean, + max_waveform_amplitude, + mean_masked_features_distance, + ) from phy.gui.actions import Actions +from phy.gui.widgets import Table from phy.io.array import select_spikes logger = logging.getLogger(__name__) @@ -42,83 +45,61 @@ def _process_ups(ups): # pragma: no cover raise NotImplementedError() -# ----------------------------------------------------------------------------- -# Attach wizard to effectors (clustering and cluster_meta) -# ----------------------------------------------------------------------------- - -_wizard_group_mapping = { - 'noise': 'ignored', - 'mua': 'ignored', - 'good': 'good', -} - - -def _wizard_group(group): - # The group should be None, 'mua', 'noise', or 'good'. - assert group is None or isinstance(group, string_types) - group = group.lower() if group else group - return _wizard_group_mapping.get(group, None) - - -def _attach_wizard_to_effector(wizard, effector): - - # Save the current selection when an action occurs. - @effector.connect - def on_request_undo_state(up): - return {'selection': wizard._selection} - - @effector.connect - def on_cluster(up): - if not up.history: - # Reset the history after every change. - # That's because the history contains references to dead clusters. - wizard.reset() - if up.history == 'undo': - # Revert to the given selection after an undo. - wizard.select(up.undo_state[0]['selection'], add_to_history=False) - - -def _attach_wizard_to_clustering(wizard, clustering): - _attach_wizard_to_effector(wizard, clustering) - - @wizard.set_cluster_ids_function - def get_cluster_ids(): - return clustering.cluster_ids - - @clustering.connect - def on_cluster(up): - if up.added and up.history != 'undo': - wizard.select([up.added[0]]) - # NOTE: after a merge, select the merged one AND the most similar. - # There is an ambiguity after a merge: does the merge occurs during - # a wizard session, in which case we want to pin the merged - # cluster? If it is just a "cold" merge, then we might not want - # to pin the merged cluster. But cold merges are supposed to be - # less frequent than wizard merges. - wizard.pin() - - -def _attach_wizard_to_cluster_meta(wizard, cluster_meta): - _attach_wizard_to_effector(wizard, cluster_meta) - - @wizard.set_status_function - def status(cluster): - group = cluster_meta.get('group', cluster) - return _wizard_group(group) - - @cluster_meta.connect - def on_cluster(up): - if up.description == 'metadata_group' and up.history != 'undo': - cluster = up.metadata_changed[0] - wizard.next_selection([cluster], ignore_group=True) - # TODO: pin after a move? Yes if the previous selection >= 2, no - # otherwise. See similar note above. - # wizard.pin() - - -def _attach_wizard(wizard, clustering, cluster_meta): - _attach_wizard_to_clustering(wizard, clustering) - _attach_wizard_to_cluster_meta(wizard, cluster_meta) +def _make_wizard_default_functions(waveforms=None, + features=None, + masks=None, + n_features_per_channel=None, + spikes_per_cluster=None, + ): + spc = spikes_per_cluster + nfc = n_features_per_channel + + def max_waveform_amplitude_quality(cluster): + spike_ids = select_spikes(cluster_ids=[cluster], + max_n_spikes_per_cluster=100, + spikes_per_cluster=spc, + ) + m = np.atleast_2d(masks[spike_ids]) + w = np.atleast_3d(waveforms[spike_ids]) + mean_masks = mean(m) + mean_waveforms = mean(w) + q = max_waveform_amplitude(mean_masks, mean_waveforms) + logger.debug("Computed cluster quality for %d: %.3f.", + cluster, q) + return q + + def mean_masked_features_similarity(c0, c1): + s0 = select_spikes(cluster_ids=[c0], + max_n_spikes_per_cluster=100, + spikes_per_cluster=spc, + ) + s1 = select_spikes(cluster_ids=[c1], + max_n_spikes_per_cluster=100, + spikes_per_cluster=spc, + ) + + f0 = features[s0] + m0 = np.atleast_2d(masks[s0]) + + f1 = features[s1] + m1 = np.atleast_2d(masks[s1]) + + mf0 = mean(f0) + mm0 = mean(m0) + + mf1 = mean(f1) + mm1 = mean(m1) + + d = mean_masked_features_distance(mf0, mf1, mm0, mm1, + n_features_per_channel=nfc, + ) + + logger.debug("Computed cluster similarity for (%d, %d): %.3f.", + c0, c1, d) + return -d # NOTE: convert distance to score + + return (max_waveform_amplitude_quality, + mean_masked_features_similarity) # ----------------------------------------------------------------------------- @@ -130,7 +111,6 @@ class ManualClustering(object): * Clustering instance: merge, split, undo, redo * ClusterMeta instance: change cluster metadata (e.g. group) - * Wizard * Selection * Many manual clustering-related actions, snippets, shortcuts, etc. @@ -162,13 +142,9 @@ class ManualClustering(object): default_shortcuts = { 'save': 'Save', # Wizard actions. - 'next_by_quality': 'space', + 'next': 'space', 'previous': 'shift+space', 'reset_wizard': 'ctrl+alt+space', - 'first': 'MoveToStartOfLine', - 'last': 'MoveToEndOfLine', - 'pin': 'return', - 'unpin': 'backspace', # Clustering actions. 'merge': 'g', 'split': 'k', @@ -181,6 +157,8 @@ def __init__(self, cluster_groups=None, n_spikes_max_per_cluster=100, shortcuts=None, + quality_func=None, + similarity_func=None, ): self.gui = None @@ -195,8 +173,9 @@ def __init__(self, self.cluster_meta = create_cluster_meta(cluster_groups) self._global_history = GlobalHistory(process_ups=_process_ups) - # Create the wizard and attach it to Clustering/ClusterMeta. - self.wizard = Wizard() + # Wizard functions. + self.quality_func = quality_func or (lambda c: 0) + self.similarity_func = similarity_func or (lambda c, d: 0) # Log the actions. @self.clustering.connect @@ -226,7 +205,7 @@ def on_cluster(up): if self.gui: self.gui.emit('on_cluster', up) - _attach_wizard(self.wizard, self.clustering, self.cluster_meta) + # _attach_wizard(self.wizard, self.clustering, self.cluster_meta) def _create_actions(self, gui): self.actions = Actions(gui, default_shortcuts=self.shortcuts) @@ -235,13 +214,13 @@ def _create_actions(self, gui): self.actions.add(self.select, alias='c') # Wizard. - self.actions.add(self.wizard.restart, name='reset_wizard') - self.actions.add(self.wizard.previous) - self.actions.add(self.wizard.next_by_quality) - self.actions.add(self.wizard.next_by_similarity) - self.actions.add(self.wizard.next) # no shortcut - self.actions.add(self.wizard.pin) - self.actions.add(self.wizard.unpin) + # self.actions.add(self.wizard.restart, name='reset_wizard') + # self.actions.add(self.wizard.previous) + # self.actions.add(self.wizard.next_by_quality) + # self.actions.add(self.wizard.next_by_similarity) + # self.actions.add(self.wizard.next) # no shortcut + # self.actions.add(self.wizard.pin) + # self.actions.add(self.wizard.unpin) # Clustering. self.actions.add(self.merge) @@ -250,10 +229,22 @@ def _create_actions(self, gui): self.actions.add(self.undo) self.actions.add(self.redo) + def _create_cluster_view(self): + table = Table() + cols = ['id', 'quality'] + items = [{'id': int(clu), 'quality': self.quality_func(clu)} + for clu in self.clustering.cluster_ids] + table.set_data(items, cols) + table.build() + return table + def attach(self, gui): self.gui = gui - @self.wizard.connect + self.cluster_view = self._create_cluster_view() + gui.add_view(self.cluster_view, title='ClusterView') + + @self.cluster_view.connect_ def on_select(cluster_ids): """When the wizard selects clusters, choose a spikes subset and emit the `select` event on the GUI.""" @@ -266,17 +257,12 @@ def on_select(cluster_ids): if self.gui: self.gui.emit('select', cluster_ids, spike_ids) - @self.wizard.connect - def on_start(): - if self.gui: - gui.emit('wizard_start') - # Create the actions. self._create_actions(gui) return self - # Wizard-related actions + # Selection actions # ------------------------------------------------------------------------- def select(self, *cluster_ids): @@ -285,15 +271,15 @@ def select(self, *cluster_ids): # the snippet: ":c 1 2 3". if cluster_ids and isinstance(cluster_ids[0], (tuple, list)): cluster_ids = list(cluster_ids[0]) + list(cluster_ids[1:]) - self.wizard.select(cluster_ids) + # self.wizard.select(cluster_ids) # Clustering actions # ------------------------------------------------------------------------- def merge(self, cluster_ids=None): - if cluster_ids is None: - cluster_ids = self.wizard.selection - if len(cluster_ids) <= 1: + # if cluster_ids is None: + # cluster_ids = self.wizard.selection + if len(cluster_ids or []) <= 1: return self.clustering.merge(cluster_ids) self._global_history.action(self.clustering) diff --git a/phy/cluster/manual/tests/conftest.py b/phy/cluster/manual/tests/conftest.py index e08d344fa..e5b84d81b 100644 --- a/phy/cluster/manual/tests/conftest.py +++ b/phy/cluster/manual/tests/conftest.py @@ -8,9 +8,6 @@ from pytest import yield_fixture -from ..wizard import Wizard -from ..gui_component import _wizard_group - #------------------------------------------------------------------------------ # Fixtures @@ -32,11 +29,6 @@ def cluster_groups(): yield {0: 'noise', 1: 'good', 10: 'mua', 11: 'good'} -@yield_fixture -def status(cluster_groups): - yield lambda c: _wizard_group(cluster_groups.get(c, None)) - - @yield_fixture def quality(): yield lambda c: c @@ -45,15 +37,3 @@ def quality(): @yield_fixture def similarity(): yield lambda c, d: c * 1.01 + d - - -@yield_fixture -def wizard(get_cluster_ids, status, quality, similarity): - wizard = Wizard() - - wizard.set_cluster_ids_function(get_cluster_ids) - wizard.set_status_function(status) - wizard.set_quality_function(quality) - wizard.set_similarity_function(similarity) - - yield wizard diff --git a/phy/cluster/manual/tests/test_gui_component.py b/phy/cluster/manual/tests/test_gui_component.py index 8c1e27a33..8d1a44d4e 100644 --- a/phy/cluster/manual/tests/test_gui_component.py +++ b/phy/cluster/manual/tests/test_gui_component.py @@ -10,15 +10,8 @@ import numpy as np from numpy.testing import assert_array_equal as ae -from ..clustering import Clustering -from .._utils import create_cluster_meta -from ..gui_component import (_wizard_group, - _attach_wizard, - _attach_wizard_to_clustering, - _attach_wizard_to_cluster_meta, - ManualClustering, - ) -from phy.gui.tests.conftest import gui # noqa +from ..gui_component import ManualClustering +from phy.gui import GUI #------------------------------------------------------------------------------ @@ -46,201 +39,21 @@ def assert_selection(*cluster_ids): # pragma: no cover if not _s: return assert _s[-1][0] == list(cluster_ids) - if len(cluster_ids) >= 1: - assert mc.wizard.best == cluster_ids[0] - elif len(cluster_ids) >= 2: - assert mc.wizard.match == cluster_ids[2] yield mc, assert_selection -#------------------------------------------------------------------------------ -# Test wizard attach -#------------------------------------------------------------------------------ - -def test_wizard_group(): - assert _wizard_group('noise') == 'ignored' - assert _wizard_group('mua') == 'ignored' - assert _wizard_group('good') == 'good' - assert _wizard_group('unknown') is None - assert _wizard_group(None) is None - - -def test_attach_wizard_to_clustering_merge(wizard, cluster_ids): - clustering = Clustering(np.array(cluster_ids)) - _attach_wizard_to_clustering(wizard, clustering) - - assert wizard.selection == [] - - wizard.select([30, 20, 10]) - assert wizard.selection == [30, 20, 10] - - clustering.merge([30, 20]) - # Select the merged cluster along with its most similar one (=pin merged). - assert wizard.selection == [31, 2] - - # Undo: the previous selection reappears. - clustering.undo() - assert wizard.selection == [30, 20, 10] - - # Redo. - clustering.redo() - assert wizard.selection == [31, 2] - - -def test_attach_wizard_to_clustering_split(wizard, cluster_ids): - clustering = Clustering(np.array(cluster_ids)) - _attach_wizard_to_clustering(wizard, clustering) - - wizard.select([30, 20, 10]) - assert wizard.selection == [30, 20, 10] - - clustering.split([5, 3]) - assert wizard.selection == [31, 30] - - # Undo: the previous selection reappears. - clustering.undo() - assert wizard.selection == [30, 20, 10] - - # Redo. - clustering.redo() - assert wizard.selection == [31, 30] - - -def test_attach_wizard_to_cluster_meta(wizard, cluster_groups): - cluster_meta = create_cluster_meta(cluster_groups) - _attach_wizard_to_cluster_meta(wizard, cluster_meta) - - wizard.select([30]) - - wizard.select([20]) - assert wizard.selection == [20] - - cluster_meta.set('group', [20], 'noise') - assert cluster_meta.get('group', 20) == 'noise' - assert wizard.selection == [2] - - cluster_meta.set('group', [2], 'good') - assert wizard.selection == [11] - - # Restart. - wizard.restart() - assert wizard.selection == [30] - - # 30, 20, 11, 10, 2, 1, 0 - # N, i, g, i, g, g, i - assert wizard.next_by_quality() == [11] - assert wizard.next_by_quality() == [2] - assert wizard.next_by_quality() == [1] - assert wizard.next_by_quality() == [20] - assert wizard.next_by_quality() == [10] - assert wizard.next_by_quality() == [0] - - -def test_attach_wizard_to_cluster_meta_undo(wizard, cluster_groups): - cluster_meta = create_cluster_meta(cluster_groups) - _attach_wizard_to_cluster_meta(wizard, cluster_meta) - - wizard.select([20]) - - cluster_meta.set('group', [20], 'noise') - assert wizard.selection == [2] - - wizard.next_by_quality() - assert wizard.selection == [11] - - cluster_meta.undo() - assert wizard.selection == [20] - - cluster_meta.redo() - assert wizard.selection == [2] - - -def test_attach_wizard_1(wizard, cluster_ids, cluster_groups): - clustering = Clustering(np.array(cluster_ids)) - cluster_meta = create_cluster_meta(cluster_groups) - _attach_wizard(wizard, clustering, cluster_meta) - - wizard.restart() - assert wizard.selection == [30] - - wizard.pin() - assert wizard.selection == [30, 20] - - clustering.merge(wizard.selection) - assert wizard.selection == [31, 2] - assert cluster_meta.get('group', 31) is None - - wizard.next_by_quality() - assert wizard.selection == [31, 11] - - clustering.undo() - assert wizard.selection == [30, 20] - - -def test_attach_wizard_2(wizard, cluster_ids, cluster_groups): - clustering = Clustering(np.array(cluster_ids)) - cluster_meta = create_cluster_meta(cluster_groups) - _attach_wizard(wizard, clustering, cluster_meta) - - wizard.select([30, 20]) - assert wizard.selection == [30, 20] - - clustering.split([1]) - assert wizard.selection == [31, 30] - assert cluster_meta.get('group', 31) is None - - wizard.next_by_quality() - assert wizard.selection == [31, 20] - - clustering.undo() - assert wizard.selection == [30, 20] - - -def test_attach_wizard_3(wizard, cluster_ids, cluster_groups): - clustering = Clustering(np.array(cluster_ids)) - cluster_meta = create_cluster_meta(cluster_groups) - _attach_wizard(wizard, clustering, cluster_meta) - - wizard.select([30, 20]) - assert wizard.selection == [30, 20] - - cluster_meta.set('group', 30, 'noise') - assert wizard.selection == [20] +@yield_fixture +def gui(qapp): + gui = GUI(position=(200, 100), size=(500, 500)) + yield gui + gui.close() #------------------------------------------------------------------------------ # Test GUI components #------------------------------------------------------------------------------ -def test_wizard_start_1(manual_clustering): - mc, assert_selection = manual_clustering - - # Check that the wizard_start event is fired. - _check = [] - - @mc.gui.connect_ - def on_wizard_start(): - _check.append('wizard') - - mc.wizard.restart() - assert _check == ['wizard'] - - -def test_wizard_start_2(manual_clustering): - mc, assert_selection = manual_clustering - - # Check that the wizard_start event is fired. - _check = [] - - @mc.gui.connect_ - def on_wizard_start(): - _check.append('wizard') - - mc.select([1]) - assert _check == ['wizard'] - - def test_manual_clustering_edge_cases(manual_clustering): mc, assert_selection = manual_clustering @@ -309,7 +122,16 @@ def test_manual_clustering_split_2(gui): # noqa mc.attach(gui) mc.actions.split([0]) - assert mc.wizard.selection == [2, 1] + # assert mc.wizard.selection == [2, 1] + + +def test_manual_clustering_show(qtbot, gui): # noqa + spike_clusters = np.array([0, 0, 1, 2, 0, 1]) + + mc = ManualClustering(spike_clusters) + mc.attach(gui) + gui.show() + # qtbot.stop() def test_manual_clustering_move(manual_clustering, quality, similarity): @@ -318,11 +140,11 @@ def test_manual_clustering_move(manual_clustering, quality, similarity): mc.actions.select([30]) assert_selection(30) - mc.wizard.set_quality_function(quality) - mc.wizard.set_similarity_function(similarity) + # mc.wizard.set_quality_function(quality) + # mc.wizard.set_similarity_function(similarity) - mc.actions.next_by_quality() - assert_selection(20) + # mc.actions.next_by_quality() + # assert_selection(20) mc.actions.move([20], 'noise') assert_selection(2) diff --git a/phy/cluster/manual/tests/test_wizard.py b/phy/cluster/manual/tests/test_wizard.py deleted file mode 100644 index 01a24825c..000000000 --- a/phy/cluster/manual/tests/test_wizard.py +++ /dev/null @@ -1,344 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Test wizard.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -from pytest import raises - -from ..wizard import (_argsort, - _sort_by_status, - _next_in_list, - _best_clusters, - _most_similar_clusters, - _best_quality_strategy, - _best_similarity_strategy, - Wizard, - ) - - -#------------------------------------------------------------------------------ -# Test utility functions -#------------------------------------------------------------------------------ - -def test_argsort(): - l = [(1, .1), (2, .2), (3, .3), (4, .4)] - assert _argsort(l) == [4, 3, 2, 1] - - assert _argsort(l, n_max=0) == [4, 3, 2, 1] - assert _argsort(l, n_max=1) == [4] - assert _argsort(l, n_max=2) == [4, 3] - assert _argsort(l, n_max=10) == [4, 3, 2, 1] - - assert _argsort(l, reverse=False) == [1, 2, 3, 4] - - -def test_sort_by_status(status): - cluster_ids = [10, 0, 1, 30, 2, 20] - assert _sort_by_status(cluster_ids, status=status) == \ - [30, 2, 20, 1, 10, 0] - assert _sort_by_status(cluster_ids, status=status, - remove_ignored=True) == [30, 2, 20, 1] - - -def test_next_in_list(): - l = [1, 2, 3] - assert _next_in_list(l, 0) == 0 - assert _next_in_list(l, 1) == 2 - assert _next_in_list(l, 2) == 3 - assert _next_in_list(l, 3) == 3 - assert _next_in_list(l, 4) == 4 - - -def test_best_clusters(quality): - l = list(range(1, 5)) - assert _best_clusters(l, quality) == [4, 3, 2, 1] - assert _best_clusters(l, quality, n_max=0) == [4, 3, 2, 1] - assert _best_clusters(l, quality, n_max=1) == [4] - assert _best_clusters(l, quality, n_max=2) == [4, 3] - assert _best_clusters(l, quality, n_max=10) == [4, 3, 2, 1] - - -def test_most_similar_clusters(cluster_ids, similarity, status): - - def _similar(cluster): - return _most_similar_clusters(cluster, - cluster_ids=cluster_ids, - similarity=similarity, - status=status) - - assert not _similar(None) - assert not _similar(100) - - assert _similar(0) == [30, 20, 2, 11, 1] - assert _similar(1) == [30, 20, 2, 11] - assert _similar(2) == [30, 20, 11, 1] - - assert _similar(10) == [30, 20, 2, 11, 1] - assert _similar(11) == [30, 20, 2, 1] - assert _similar(20) == [30, 2, 11, 1] - assert _similar(30) == [20, 2, 11, 1] - - -#------------------------------------------------------------------------------ -# Test strategy functions -#------------------------------------------------------------------------------ - -def test_best_quality_strategy_1(cluster_ids, quality, status, similarity): - - def _next(selection): - return _best_quality_strategy(selection, - cluster_ids=cluster_ids, - quality=quality, - status=status, - similarity=similarity) - - assert not _next(None) - assert _next([]) == [30] - assert _next([30]) == [20] - assert _next([20]) == [2] - assert _next([2]) == [11] - - assert _next([30, 20]) == [30, 2] - assert _next([10, 2]) == [10, 11] - assert _next([10, 11]) == [10, 1] - assert _next([10, 1]) == [10, 1] # 0 is ignored, so it does not appear. - - -def test_best_quality_strategy_2(quality, similarity): - - def status(cluster): - return {0: 'ignored', 1: None, 2: 'good', 3: None}[cluster] - - def _next(selection): - return _best_quality_strategy(selection, - cluster_ids=list(range(4)), - quality=quality, - status=status, - similarity=similarity) - - assert _next([3, 1]) == [3, 2] - - -def test_best_similarity_strategy(cluster_ids, quality, status, similarity): - - def _next(selection): - return _best_similarity_strategy(selection, - cluster_ids=cluster_ids, - quality=quality, - status=status, - similarity=similarity) - - assert not _next(None) - assert _next([]) == [30, 20] - assert _next([30, 20]) == [30, 11] - assert _next([30, 11]) == [30, 2] - assert _next([20, 10]) == [20, 2] - assert _next([10, 2]) == [2, 1] - assert _next([2, 1]) == [2, 1] # 0 is ignored, so it does not appear. - - -#------------------------------------------------------------------------------ -# Test wizard -#------------------------------------------------------------------------------ - -def test_wizard_empty(): - wizard = Wizard() - with raises(RuntimeError): - wizard.restart() - - wizard = Wizard() - wizard.set_cluster_ids_function(lambda: []) - wizard.restart() - - -def test_wizard_nav(wizard): - w = wizard - assert w.cluster_ids == [0, 1, 2, 10, 11, 20, 30] - assert w.n_clusters == 7 - - assert w.selection == [] - - ### - w.select([]) - assert w.selection == [] - - assert w.best is None - assert w.match is None - - ### - w.select([1]) - assert w.selection == [1] - - assert w.best == 1 - assert w.match is None - - ### - w.select([1, 2, 4]) - assert w.selection == [1, 2] - - assert w.best == 1 - assert w.match == 2 - - ### - w.previous() - assert w.selection == [1] - - for _ in range(2): - w.previous() - assert w.selection == [1] - - ### - w.next() - assert w.selection == [1, 2] - - for _ in range(2): - w.next() - assert w.selection == [1, 2] - - -def test_wizard_next_1(wizard, status): - w = wizard - - assert w.next_selection([30]) == [20] - - w.reset() - assert w.next_selection([30], ignore_group=True) == [20] - - # After the last good, the best ignored. - assert w.next_selection([1]) == [10] - # After the last unsorted (1's group is ignored), the best good. - assert w.next_selection([1], ignore_group=True) == [11] - - @w.set_status_function - def status_bis(cluster): - if cluster == 30: - return 'ignored' - return status(cluster) - - assert w.next_selection([30]) == [10] - assert w.next_selection([30], ignore_group=True) == [20] - - -def test_wizard_next_2(wizard): - w = wizard - - # 30, 20, 11, 10, 2, 1, 0 - # N, i, g, g, N, g, i - - @w.set_status_function - def status_bis(cluster): - return {0: 'ignored', - 1: 'good', - 2: None, - 10: 'good', - 11: 'good', - 20: 'ignored', - 30: None, - }[cluster] - - wizard.select([30]) - assert wizard.next_by_quality() == [2] - assert wizard.next_by_quality() == [11] - - -def test_wizard_next_3(wizard): - w = wizard - - @w.set_cluster_ids_function - def cluster_ids(): - return [0, 1, 2, 3] - - @w.set_status_function - def status_bis(cluster): - return {0: 'ignored', 1: None, 2: 'good', 3: None}[cluster] - - wizard.select([3, 1]) - assert wizard.next_by_quality() == [3, 2] - - -def test_wizard_pin_by_quality(wizard): - w = wizard - - w.pin() - assert w.selection == [] - - w.unpin() - assert w.selection == [] - - w.next_by_quality() - assert w.selection == [30] - - w.next_by_quality() - assert w.selection == [20] - - # Pin. - w.pin() - assert w.selection == [20, 30] - - w.next_by_quality() - assert w.selection == [20, 2] - - # Unpin. - w.unpin() - assert w.selection == [20] - - w.next_by_quality() - assert w.selection == [2] - - # Pin. - w.pin() - assert w.selection == [2, 30] - - w.next_by_quality() - assert w.selection == [2, 20] - - # Candidate is best among good. - w.next_by_quality() - assert w.selection == [2, 11] - - # Candidate is last among good, ignored are completely ignored. - w.next_by_quality() - assert w.selection == [2, 1] - - w.next_by_quality() - assert w.selection == [2, 1] - - -def test_wizard_pin_by_similarity(wizard): - w = wizard - - w.pin() - assert w.selection == [] - - w.unpin() - assert w.selection == [] - - w.next_by_similarity() - assert w.selection == [30, 20] - - w.next_by_similarity() - assert w.selection == [30, 11] - - w.pin() - assert w.selection == [30, 20] - - w.next_by_similarity() - assert w.selection == [30, 11] - - w.unpin() - assert w.selection == [30] - - w.select([20, 10]) - assert w.selection == [20, 10] - - w.next_by_similarity() - assert w.selection == [20, 2] - - w.next_by_similarity() - assert w.selection == [20, 1] - - w.next_by_similarity() - assert w.selection == [11, 2] diff --git a/phy/cluster/manual/wizard.py b/phy/cluster/manual/wizard.py deleted file mode 100644 index d4f6c2bb0..000000000 --- a/phy/cluster/manual/wizard.py +++ /dev/null @@ -1,353 +0,0 @@ -# -*- coding: utf-8 -*- -"""Wizard.""" - -#------------------------------------------------------------------------------ -# Imports - -#------------------------------------------------------------------------------ - -from itertools import product -import logging -from operator import itemgetter - -from ._history import History -from phy.utils import EventEmitter - -logger = logging.getLogger(__name__) - - -#------------------------------------------------------------------------------ -# Utility functions -#------------------------------------------------------------------------------ - -def _argsort(seq, reverse=True, n_max=None): - """Return the list of clusters in decreasing order of value from - a list of tuples (cluster, value).""" - out = [cl for (cl, v) in sorted(seq, - key=itemgetter(1), - reverse=reverse)] - if n_max in (None, 0): - return out - else: - return out[:n_max] - - -def _next_in_list(l, item): - if l and item in l and l.index(item) < len(l) - 1: - return l[l.index(item) + 1] - return item - - -def _sort_by_status(clusters, status=None, remove_ignored=False): - """Sort clusters according to their status.""" - assert status - _sort_map = {None: 0, 'good': 1, 'ignored': 2} - if remove_ignored: - clusters = [c for c in clusters if status(c) != 'ignored'] - # NOTE: sorted is "stable": it doesn't change the order of elements - # that compare equal, which ensures that the order of clusters is kept - # among any given status. - key = lambda cluster: _sort_map[status(cluster)] - return sorted(clusters, key=key) - - -def _best_clusters(clusters, quality, n_max=None): - return _argsort([(cluster, quality(cluster)) - for cluster in clusters], n_max=n_max) - - -def _most_similar_clusters(cluster, cluster_ids=None, n_max=None, - similarity=None, status=None): - """Return the `n_max` most similar clusters to a given cluster.""" - if cluster not in cluster_ids: - return [] - s = [(other, similarity(cluster, other)) - for other in cluster_ids - if other != cluster and status(other) != 'ignored'] - clusters = _argsort(s, n_max=n_max) - out = _sort_by_status(clusters, status=status) - return out - - -#------------------------------------------------------------------------------ -# Strategy functions -#------------------------------------------------------------------------------ - -def _best_quality_strategy(selection, - cluster_ids=None, - quality=None, - status=None, - similarity=None): - """Two cases depending on the number of selected clusters: - - * 1: move to the next best cluster - * 2: move to the next most similar pair - * 3+: do nothing - - """ - if selection is None: - return selection - n = len(selection) - if n <= 1: - best_clusters = _best_clusters(cluster_ids, quality) - # Sort the best clusters according to their status. - best_clusters = _sort_by_status(best_clusters, status=status) - if selection: - return [_next_in_list(best_clusters, selection[0])] - elif best_clusters: - return [best_clusters[0]] - else: # pragma: no cover - return selection - elif n == 2: - best, match = selection - candidates = _most_similar_clusters(best, - cluster_ids=cluster_ids, - similarity=similarity, - status=status, - ) - if not candidates: # pragma: no cover - return selection - candidate = _next_in_list(candidates, match) - return [best, candidate] - - -def _best_similarity_strategy(selection, - cluster_ids=None, - quality=None, - status=None, - similarity=None): - if selection is None: - return selection - n = len(selection) - if n >= 2: - best, match = selection - value = similarity(best, match) - else: - best, match = None, None - value = None - # We remove the current pair, the (x, x) pairs, and we ensure that - # (d, c) doesn't appear if (c, d) does. We choose the pair where - # the first cluster of the pair has the highest quality. - # Finally we remove the ignored clusters. - s = [((c, d), similarity(c, d)) - for c, d in product(cluster_ids, repeat=2) - if c != d and (c, d) != (best, match) - and quality(c) >= quality(d) - and status(c) != 'ignored' - and status(d) != 'ignored' - ] - - if value is not None: - s = [((c, d), v) for ((c, d), v) in s if v <= value] - pairs = _argsort(s) - if pairs: - return list(pairs[0]) - else: - return selection - - -#------------------------------------------------------------------------------ -# Wizard -#------------------------------------------------------------------------------ - -class Wizard(EventEmitter): - """Propose a selection of high-quality clusters and merge candidates. - - * The wizard is responsible for the selected clusters. - * The wizard keeps no state about the clusters: the state is entirely - provided by functions: cluster_ids, status (group), similarity, quality. - * The wizard keeps track of the history of the selected clusters, but this - history is cleared after every action that changes the state. - * The `next_*()` functions propose a new selection as a function of the - current selection. - - """ - def __init__(self): - super(Wizard, self).__init__() - self._similarity = None - self._quality = None - self._get_cluster_ids = None - self._cluster_status = None - self._selection = [] - self.reset() - - def reset(self): - self._selection = [] - self._history = History([]) - - # Quality and status functions - #-------------------------------------------------------------------------- - - def set_cluster_ids_function(self, func): - """Register a function giving the list of cluster ids.""" - self._get_cluster_ids = func - - def set_status_function(self, func): - """Register a function returning the status of a cluster: None, - 'ignored', or 'good'. - - Can be used as a decorator. - - """ - self._cluster_status = func - return func - - def set_similarity_function(self, func): - """Register a function returning the similarity between two clusters. - - Can be used as a decorator. - - """ - self._similarity = func - return func - - def set_quality_function(self, func): - """Register a function returning the quality of a cluster. - - Can be used as a decorator. - - """ - self._quality = func - return func - - # Properties - #-------------------------------------------------------------------------- - - @property - def cluster_ids(self): - """Array of cluster ids in the current clustering.""" - if not self._get_cluster_ids: - return [] - return sorted(self._get_cluster_ids()) - - @property - def n_clusters(self): - """Total number of clusters.""" - return len(self.cluster_ids) - - # Selection methods - #-------------------------------------------------------------------------- - - def select(self, cluster_ids, add_to_history=True): - if cluster_ids is None: # pragma: no cover - return - clusters = self.cluster_ids - cluster_ids = [cluster for cluster in cluster_ids - if cluster in clusters] - if not self._selection and cluster_ids: - self.emit('start') - self._selection = cluster_ids - if add_to_history: - self._history.add(self._selection) - self.emit('select', self._selection) - - @property - def selection(self): - """Return the current cluster selection.""" - return self._selection - - @property - def best(self): - """Currently-selected best cluster.""" - return self._selection[0] if self._selection else None - - @property - def match(self): - """Currently-selected closest match.""" - return self._selection[1] if len(self._selection) >= 2 else None - - def pin(self): - """Select the cluster the most similar cluster to the current best.""" - best = self.best - if best is None: - return - self._check_functions() - candidates = _most_similar_clusters(best, - cluster_ids=self.cluster_ids, - similarity=self._similarity, - status=self._cluster_status) - assert best not in candidates - if not candidates: # pragma: no cover - return - self.select([self.best, candidates[0]]) - # Clear the navigation history when pinning, such that `previous` - # keeps the pinned cluster selected. - self._history.clear() - - def unpin(self): - if len(self._selection) == 2: - self.select([self.selection[0]]) - # Clear the navigation history when unpinning, such that `previous` - # keeps the pinned cluster selected. - self._history.clear() - - # Navigation - #-------------------------------------------------------------------------- - - def _set_selection_from_history(self): - cluster_ids = self._history.current_item - if not cluster_ids: # pragma: no cover - return - self.select(cluster_ids, add_to_history=False) - - def previous(self): - if self._history.current_position <= 2: - return self._selection - self._history.back() - self._set_selection_from_history() - return self._selection - - def next(self): - if not self._history.is_last(): - # Go forward after a previous. - self._history.forward() - self._set_selection_from_history() - - def restart(self): - self.select([]) - self.next_by_quality() - - def _check_functions(self): - if not self._get_cluster_ids: - raise RuntimeError("The cluster_ids function must be set.") - if not self._cluster_status: - logger.warn("A cluster status function has not been set.") - self._cluster_status = lambda c: None - if not self._quality: - logger.warn("A cluster quality function has not been set.") - self._quality = lambda c: 0 - if not self._similarity: - logger.warn("A cluster similarity function has not been set.") - self._similarity = lambda c, d: 0 - - def next_selection(self, cluster_ids=None, - strategy=None, - ignore_group=False): - """Make a new cluster selection according to a given strategy.""" - self._check_functions() - cluster_ids = cluster_ids or self._selection - strategy = strategy or _best_quality_strategy - if ignore_group: - # Ignore the status of the selected clusters. - def status(cluster): - if cluster in cluster_ids: - return None - return self._cluster_status(cluster) - else: - status = self._cluster_status - new_selection = strategy(cluster_ids, - cluster_ids=self._get_cluster_ids(), - quality=self._quality, - status=status, - similarity=self._similarity) - # Skip new selection if it is the same. - if new_selection == self._selection: - return - self.select(new_selection) - return self._selection - - def next_by_quality(self): - return self.next_selection(strategy=_best_quality_strategy) - - def next_by_similarity(self): - return self.next_selection(strategy=_best_similarity_strategy) diff --git a/phy/gui/widgets.py b/phy/gui/widgets.py index 0fc1f9aba..c12c8215e 100644 --- a/phy/gui/widgets.py +++ b/phy/gui/widgets.py @@ -174,9 +174,10 @@ def eval_js(self, expr): if not self.is_built(): # If the page is not built yet, postpone the evaluation of the JS # to after the page is loaded. + logger.log(5, "Postpone evaluation of `%s`.", expr) self._pending_eval_js.append(expr) return - logger.debug("Evaluate Javascript: `%s`.", expr) + logger.log(5, "Evaluate Javascript: `%s`.", expr) self.page().mainFrame().evaluateJavaScript(expr) @pyqtSlot(str) @@ -238,7 +239,6 @@ def __init__(self): self.add_body(''''''.format(self._table_id)) - self.build() def set_data(self, items, cols): """Set the rows and cols of the table.""" diff --git a/phy/io/tests/test_context.py b/phy/io/tests/test_context.py index 497f130e7..d1caf4b42 100644 --- a/phy/io/tests/test_context.py +++ b/phy/io/tests/test_context.py @@ -134,6 +134,7 @@ def square(x): def test_task(): task = Task(ctx=None) + assert task #------------------------------------------------------------------------------ From 5f48119a237601b0588bedf66cc8e96c896a841a Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 4 Nov 2015 11:46:32 +0100 Subject: [PATCH 0533/1059] WIP: add sort_by() method in table --- phy/gui/static/table.js | 9 +++++++++ phy/gui/widgets.py | 4 ++++ 2 files changed, 13 insertions(+) diff --git a/phy/gui/static/table.js b/phy/gui/static/table.js index dd2853c2e..01d817b97 100644 --- a/phy/gui/static/table.js +++ b/phy/gui/static/table.js @@ -12,6 +12,11 @@ Table.prototype.setData = function(data) { var that = this; var keys = data.cols; + // Clear the table. + while (this.el.firstChild) { + this.el.removeChild(this.el.firstChild); + } + var thead = document.createElement("thead"); var tbody = document.createElement("tbody"); @@ -64,6 +69,10 @@ Table.prototype.setData = function(data) { this.tablesort = new Tablesort(this.el); }; +Table.prototype.sortBy = function(header) { + this.tablesort.sortTable(this.headers[header]); +}; + Table.prototype.select = function(ids, raise_event) { raise_event = typeof raise_event !== 'undefined' ? false : true; diff --git a/phy/gui/widgets.py b/phy/gui/widgets.py index c12c8215e..229059a7a 100644 --- a/phy/gui/widgets.py +++ b/phy/gui/widgets.py @@ -247,6 +247,10 @@ def set_data(self, items, cols): ) self.eval_js('table.setData({});'.format(data)) + def sort_by(self, header): + """Sort by a given variable.""" + self.eval_js('table.sortBy("{}");'.format(header)) + def next(self): """Select the next non-skip row.""" self.eval_js('table.next();') From 3f7aa89715b0d82ad0287e13312aeccfaa64d433 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 4 Nov 2015 11:46:50 +0100 Subject: [PATCH 0534/1059] WIP: selection in manual clustering component --- phy/cluster/manual/gui_component.py | 40 +++++++++++++++++-- .../manual/tests/test_gui_component.py | 25 +++++++----- 2 files changed, 50 insertions(+), 15 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 8ef5ac8b7..52026864e 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -232,22 +232,42 @@ def _create_actions(self, gui): def _create_cluster_view(self): table = Table() cols = ['id', 'quality'] + # TODO: skip items = [{'id': int(clu), 'quality': self.quality_func(clu)} for clu in self.clustering.cluster_ids] table.set_data(items, cols) table.build() return table + def _create_similarity_view(self): + table = Table() + table.build() + return table + def attach(self, gui): self.gui = gui + # Cluster view. self.cluster_view = self._create_cluster_view() gui.add_view(self.cluster_view, title='ClusterView') - @self.cluster_view.connect_ - def on_select(cluster_ids): - """When the wizard selects clusters, choose a spikes subset - and emit the `select` event on the GUI.""" + # Similarity view. + self.similarity_view = self._create_similarity_view() + gui.add_view(self.similarity_view, title='SimilarityView') + + def _update_similarity_view(cluster_ids): + if len(cluster_ids) == 1: + sel = int(cluster_ids[0]) + cols = ['id', 'similarity'] + # TODO: skip + items = [{'id': int(clu), + 'similarity': self.similarity_func(sel, clu)} + for clu in self.clustering.cluster_ids] + self.similarity_view.set_data(items, cols) + self.similarity_view.sort_by('similarity') + self.similarity_view.sort_by('similarity') + + def _select(cluster_ids): spike_ids = select_spikes(np.array(cluster_ids), self.n_spikes_max_per_cluster, self.clustering.spikes_per_cluster) @@ -257,6 +277,18 @@ def on_select(cluster_ids): if self.gui: self.gui.emit('select', cluster_ids, spike_ids) + def on_select1(cluster_ids): + # Update the similarity view when the selection changes in + # the cluster view. + _update_similarity_view(cluster_ids) + _select(cluster_ids) + self.cluster_view.connect_(on_select1, event='select') + + def on_select2(cluster_ids): + # TODO: prepend the clusters selected in the cluster view + _select(cluster_ids) + self.similarity_view.connect_(on_select2, event='select') # noqa + # Create the actions. self._create_actions(gui) diff --git a/phy/cluster/manual/tests/test_gui_component.py b/phy/cluster/manual/tests/test_gui_component.py index 8d1a44d4e..4ffe23cda 100644 --- a/phy/cluster/manual/tests/test_gui_component.py +++ b/phy/cluster/manual/tests/test_gui_component.py @@ -18,7 +18,7 @@ # Fixtures #------------------------------------------------------------------------------ -@yield_fixture # noqa +@yield_fixture def manual_clustering(gui, cluster_ids, cluster_groups): spike_clusters = np.array(cluster_ids) @@ -115,7 +115,7 @@ def test_manual_clustering_split(manual_clustering): assert_selection(31, 20) -def test_manual_clustering_split_2(gui): # noqa +def test_manual_clustering_split_2(gui): spike_clusters = np.array([0, 0, 1]) mc = ManualClustering(spike_clusters) @@ -125,15 +125,6 @@ def test_manual_clustering_split_2(gui): # noqa # assert mc.wizard.selection == [2, 1] -def test_manual_clustering_show(qtbot, gui): # noqa - spike_clusters = np.array([0, 0, 1, 2, 0, 1]) - - mc = ManualClustering(spike_clusters) - mc.attach(gui) - gui.show() - # qtbot.stop() - - def test_manual_clustering_move(manual_clustering, quality, similarity): mc, assert_selection = manual_clustering @@ -154,3 +145,15 @@ def test_manual_clustering_move(manual_clustering, quality, similarity): mc.actions.redo() assert_selection(2) + + +def test_manual_clustering_show(qtbot, gui): + spike_clusters = np.array([0, 0, 1, 2, 0, 1]) + + def sf(c, d): + return float(c + d) + + mc = ManualClustering(spike_clusters, similarity_func=sf) + mc.attach(gui) + gui.show() + qtbot.stop() From 63c099eb716a42c32d1422a9a3c1e3b5824c2064 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 4 Nov 2015 15:33:55 +0100 Subject: [PATCH 0535/1059] Fix bug with sorting numbers in HTML table --- phy/gui/static/tablesort.number.js | 26 ++++++++++++++++++++++++++ phy/gui/tests/test_widgets.py | 10 +++++++++- phy/gui/widgets.py | 1 + 3 files changed, 36 insertions(+), 1 deletion(-) create mode 100644 phy/gui/static/tablesort.number.js diff --git a/phy/gui/static/tablesort.number.js b/phy/gui/static/tablesort.number.js new file mode 100644 index 000000000..2e11b462a --- /dev/null +++ b/phy/gui/static/tablesort.number.js @@ -0,0 +1,26 @@ +(function(){ + var cleanNumber = function(i) { + return i.replace(/[^\-?0-9.]/g, ''); + }, + + compareNumber = function(a, b) { + a = parseFloat(a); + b = parseFloat(b); + + a = isNaN(a) ? 0 : a; + b = isNaN(b) ? 0 : b; + + return a - b; + }; + + Tablesort.extend('number', function(item) { + return item.match(/^-?[£\x24Û¢´€]?\d+\s*([,\.]\d{0,2})/) || // Prefixed currency + item.match(/^-?\d+\s*([,\.]\d{0,2})?[£\x24Û¢´€]/) || // Suffixed currency + item.match(/^-?(\d)*-?([,\.]){0,1}-?(\d)+([E,e][\-+][\d]+)?%?$/); // Number + }, function(a, b) { + a = cleanNumber(a); + b = cleanNumber(b); + + return compareNumber(b, a); + }); +}()); diff --git a/phy/gui/tests/test_widgets.py b/phy/gui/tests/test_widgets.py index 21689faab..f4d8f553e 100644 --- a/phy/gui/tests/test_widgets.py +++ b/phy/gui/tests/test_widgets.py @@ -59,7 +59,7 @@ def test_table(qtbot): table.show() qtbot.waitForWindowShown(table) - items = [{'id': i, 'count': 10 * i} for i in range(10)] + items = [{'id': i, 'count': 100 - 10 * i} for i in range(10)] items[4]['skip'] = True table.set_data(cols=['id', 'count'], @@ -84,4 +84,12 @@ def on_select(items): assert table.selected == [1] + # Sort by count decreasing, and check that 0 (count 100) comes before + # 1 (count 90). This checks that sorting works with number (need to + # import tablesort.number.js). + table.sort_by('count') + table.sort_by('count') + table.previous() + assert table.selected == [0] + # qtbot.stop() diff --git a/phy/gui/widgets.py b/phy/gui/widgets.py index 229059a7a..2a71e52f1 100644 --- a/phy/gui/widgets.py +++ b/phy/gui/widgets.py @@ -233,6 +233,7 @@ def __init__(self): super(Table, self).__init__() self.add_style_src('table.css') self.add_script_src('tablesort.min.js') + self.add_script_src('tablesort.number.js') self.add_script_src('table.js') self.set_body('
'.format( self._table_id)) From ea6f20d44f185d31b855c196f3030c8e20864ca9 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 4 Nov 2015 16:05:03 +0100 Subject: [PATCH 0536/1059] WIP: update manual clustering component with cluster views --- phy/cluster/manual/gui_component.py | 150 ++++++++++++++-------------- 1 file changed, 76 insertions(+), 74 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 52026864e..ffef1efbf 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -114,26 +114,26 @@ class ManualClustering(object): * Selection * Many manual clustering-related actions, snippets, shortcuts, etc. - Bring the `select` event to the GUI. This is raised when clusters are - selected by the user or by the wizard. - Parameters ---------- - gui : GUI spike_clusters : ndarray cluster_groups : dictionary n_spikes_max_per_cluster : int + shortcuts : dict + quality_func : function + similarity_func : function + + GUI events + ---------- - Events - ------ + When this component is attached to a GUI, the GUI emits the following + events: select(cluster_ids, spike_ids) when clusters are selected - on_cluster(up) + cluster(up) when a merge or split happens - wizard_start() - when the wizard (re)starts save_requested(spike_clusters, cluster_groups) when a save is requested by the user @@ -191,7 +191,7 @@ def on_cluster(up): logger.info("Assigned spikes.") if self.gui: - self.gui.emit('on_cluster', up) + self.gui.emit('cluster', up) @self.cluster_meta.connect # noqa def on_cluster(up): @@ -203,9 +203,10 @@ def on_cluster(up): up.metadata_value) if self.gui: - self.gui.emit('on_cluster', up) + self.gui.emit('cluster', up) - # _attach_wizard(self.wizard, self.clustering, self.cluster_meta) + # Internal methods + # ------------------------------------------------------------------------- def _create_actions(self, gui): self.actions = Actions(gui, default_shortcuts=self.shortcuts) @@ -213,15 +214,6 @@ def _create_actions(self, gui): # Selection. self.actions.add(self.select, alias='c') - # Wizard. - # self.actions.add(self.wizard.restart, name='reset_wizard') - # self.actions.add(self.wizard.previous) - # self.actions.add(self.wizard.next_by_quality) - # self.actions.add(self.wizard.next_by_similarity) - # self.actions.add(self.wizard.next) # no shortcut - # self.actions.add(self.wizard.pin) - # self.actions.add(self.wizard.unpin) - # Clustering. self.actions.add(self.merge) self.actions.add(self.split) @@ -229,65 +221,48 @@ def _create_actions(self, gui): self.actions.add(self.undo) self.actions.add(self.redo) - def _create_cluster_view(self): - table = Table() + def _create_cluster_views(self, gui): + # Create the cluster view. + self.cluster_view = cluster_view = Table() cols = ['id', 'quality'] # TODO: skip items = [{'id': int(clu), 'quality': self.quality_func(clu)} for clu in self.clustering.cluster_ids] - table.set_data(items, cols) - table.build() - return table + # TODO: custom measures + cluster_view.set_data(items, cols) + cluster_view.build() + gui.add_view(cluster_view, title='ClusterView') + + # Create the similarity view. + self.similarity_view = similarity_view = Table() + similarity_view.build() + gui.add_view(similarity_view, title='SimilarityView') + + @self.cluster_view.connect_ + def on_select(cluster_ids): + self.select(cluster_ids) + self.pin(cluster_ids) + + @self.similarity_view.connect_ # noqa + def on_select(cluster_ids): + # Select the clusters from both views. + cluster_ids = cluster_view.selected + cluster_ids + self.select(cluster_ids) + + # Public methods + # ------------------------------------------------------------------------- - def _create_similarity_view(self): - table = Table() - table.build() - return table + def set_quality_func(self, f): + self.quality_func = f + + def set_similarity_func(self, f): + self.similarity_func = f def attach(self, gui): self.gui = gui - # Cluster view. - self.cluster_view = self._create_cluster_view() - gui.add_view(self.cluster_view, title='ClusterView') - - # Similarity view. - self.similarity_view = self._create_similarity_view() - gui.add_view(self.similarity_view, title='SimilarityView') - - def _update_similarity_view(cluster_ids): - if len(cluster_ids) == 1: - sel = int(cluster_ids[0]) - cols = ['id', 'similarity'] - # TODO: skip - items = [{'id': int(clu), - 'similarity': self.similarity_func(sel, clu)} - for clu in self.clustering.cluster_ids] - self.similarity_view.set_data(items, cols) - self.similarity_view.sort_by('similarity') - self.similarity_view.sort_by('similarity') - - def _select(cluster_ids): - spike_ids = select_spikes(np.array(cluster_ids), - self.n_spikes_max_per_cluster, - self.clustering.spikes_per_cluster) - logger.debug("Select clusters: %s (%d spikes).", - ', '.join(map(str, cluster_ids)), len(spike_ids)) - - if self.gui: - self.gui.emit('select', cluster_ids, spike_ids) - - def on_select1(cluster_ids): - # Update the similarity view when the selection changes in - # the cluster view. - _update_similarity_view(cluster_ids) - _select(cluster_ids) - self.cluster_view.connect_(on_select1, event='select') - - def on_select2(cluster_ids): - # TODO: prepend the clusters selected in the cluster view - _select(cluster_ids) - self.similarity_view.connect_(on_select2, event='select') # noqa + # Create the cluster views. + self._create_cluster_views(gui) # Create the actions. self._create_actions(gui) @@ -298,12 +273,39 @@ def on_select2(cluster_ids): # ------------------------------------------------------------------------- def select(self, *cluster_ids): - # HACK: allow for select(1, 2, 3) in addition to select([1, 2, 3]) + """Choose spikes from the specified clusters and emit the + `select` event on the GUI.""" + + # HACK: allow for `select(1, 2, 3)` in addition to `select([1, 2, 3])` # This makes it more convenient to select multiple clusters with - # the snippet: ":c 1 2 3". + # the snippet: `:c 1 2 3` instead of `:c 1,2,3`. if cluster_ids and isinstance(cluster_ids[0], (tuple, list)): cluster_ids = list(cluster_ids[0]) + list(cluster_ids[1:]) - # self.wizard.select(cluster_ids) + + # Choose a spike subset. + spike_ids = select_spikes(np.array(cluster_ids), + self.n_spikes_max_per_cluster, + self.clustering.spikes_per_cluster) + logger.debug("Select clusters: %s (%d spikes).", + ', '.join(map(str, cluster_ids)), len(spike_ids)) + if self.gui: + self.gui.emit('select', cluster_ids, spike_ids) + + def pin(self, cluster_ids): + """Update the similarity view with matches for the specified + clusters.""" + # TODO: similarity wrt several clusters + sel = int(cluster_ids[0]) + cols = ['id', 'similarity'] + # TODO: skip + items = [{'id': int(clu), + 'similarity': self.similarity_func(sel, clu)} + for clu in self.clustering.cluster_ids] + self.similarity_view.set_data(items, cols) + + # NOTE: sort twice to get decreasing order. + self.similarity_view.sort_by('similarity') + self.similarity_view.sort_by('similarity') # Clustering actions # ------------------------------------------------------------------------- From 78c6bedf0079926ab203215c1c282f061f11361e Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 4 Nov 2015 16:23:46 +0100 Subject: [PATCH 0537/1059] Add table.current_sort property --- phy/gui/static/table.js | 12 +++++++++ phy/gui/tests/test_widgets.py | 47 ++++++++++++++++++++++++++--------- phy/gui/widgets.py | 4 +++ 3 files changed, 51 insertions(+), 12 deletions(-) diff --git a/phy/gui/static/table.js b/phy/gui/static/table.js index 01d817b97..3eb1ddbf3 100644 --- a/phy/gui/static/table.js +++ b/phy/gui/static/table.js @@ -73,6 +73,18 @@ Table.prototype.sortBy = function(header) { this.tablesort.sortTable(this.headers[header]); }; +Table.prototype.currentSort = function() { + for (var header in this.headers) { + if (this.headers[header].classList.contains('sort-up')) { + return [header, 'desc']; + } + if (this.headers[header].classList.contains('sort-down')) { + return [header, 'asc']; + } + } + return [null, null]; +} + Table.prototype.select = function(ids, raise_event) { raise_event = typeof raise_event !== 'undefined' ? false : true; diff --git a/phy/gui/tests/test_widgets.py b/phy/gui/tests/test_widgets.py index f4d8f553e..df3eca421 100644 --- a/phy/gui/tests/test_widgets.py +++ b/phy/gui/tests/test_widgets.py @@ -6,9 +6,34 @@ # Imports #------------------------------------------------------------------------------ +from pytest import yield_fixture + from ..widgets import HTMLWidget, Table +#------------------------------------------------------------------------------ +# Fixtures +#------------------------------------------------------------------------------ + +@yield_fixture +def table(qtbot): + table = Table() + + table.show() + qtbot.waitForWindowShown(table) + + items = [{'id': i, 'count': 100 - 10 * i} for i in range(10)] + items[4]['skip'] = True + + table.set_data(cols=['id', 'count'], + items=items, + ) + + yield table + + table.close() + + #------------------------------------------------------------------------------ # Test actions #------------------------------------------------------------------------------ @@ -53,18 +78,7 @@ def on_test(arg): # qtbot.stop() -def test_table(qtbot): - table = Table() - - table.show() - qtbot.waitForWindowShown(table) - - items = [{'id': i, 'count': 100 - 10 * i} for i in range(10)] - items[4]['skip'] = True - - table.set_data(cols=['id', 'count'], - items=items, - ) +def test_table_nav(qtbot, table): table.select([4]) table.next() @@ -84,12 +98,21 @@ def on_select(items): assert table.selected == [1] + # qtbot.stop() + + +def test_table_sort(qtbot, table): + table.select([1]) + # Sort by count decreasing, and check that 0 (count 100) comes before # 1 (count 90). This checks that sorting works with number (need to # import tablesort.number.js). table.sort_by('count') table.sort_by('count') + table.previous() assert table.selected == [0] + assert table.current_sort == ('count', 'desc') + # qtbot.stop() diff --git a/phy/gui/widgets.py b/phy/gui/widgets.py index 2a71e52f1..6a0a0f3ee 100644 --- a/phy/gui/widgets.py +++ b/phy/gui/widgets.py @@ -268,3 +268,7 @@ def select(self, ids): def selected(self): """Currently selected rows.""" return [int(_) for _ in self.get_js('table.selected')] + + @property + def current_sort(self): + return tuple(self.get_js('table.currentSort()')) From dfb0946e0bc45d3009f72ea1da3a1dc7df6ca456 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 4 Nov 2015 16:33:39 +0100 Subject: [PATCH 0538/1059] WIP: select after actions --- phy/cluster/manual/gui_component.py | 38 +++++++++++++++---- .../manual/tests/test_gui_component.py | 2 +- phy/gui/static/table.js | 4 +- 3 files changed, 33 insertions(+), 11 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index ffef1efbf..8943cbf16 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -224,12 +224,9 @@ def _create_actions(self, gui): def _create_cluster_views(self, gui): # Create the cluster view. self.cluster_view = cluster_view = Table() - cols = ['id', 'quality'] - # TODO: skip - items = [{'id': int(clu), 'quality': self.quality_func(clu)} - for clu in self.clustering.cluster_ids] - # TODO: custom measures - cluster_view.set_data(items, cols) + # NOTE: table.setData() must be called *before* build() so that + # the JS call is deferred to after the HTML widget is fully loaded. + self._update_cluster_view(cluster_view) cluster_view.build() gui.add_view(cluster_view, title='ClusterView') @@ -238,17 +235,42 @@ def _create_cluster_views(self, gui): similarity_view.build() gui.add_view(similarity_view, title='SimilarityView') - @self.cluster_view.connect_ + # Selection in the cluster view. + @cluster_view.connect_ def on_select(cluster_ids): + # Emit GUI.select when the selection changes in the cluster view. self.select(cluster_ids) + # Pin the clusters and update the similarity view. self.pin(cluster_ids) - @self.similarity_view.connect_ # noqa + # Selection in the similarity view. + @similarity_view.connect_ # noqa def on_select(cluster_ids): # Select the clusters from both views. cluster_ids = cluster_view.selected + cluster_ids self.select(cluster_ids) + # Update the cluster views and selection when a cluster event occurs. + @self.gui.connect_ + def on_cluster(up): + # Get the current sort of the cluster view. + sort = cluster_view.current_sort + # Reinitialize the cluster view. + self._update_cluster_view(cluster_view) + # Select all new clusters in view 1. + if up.added: + cluster_view.select(up.added) + else: + cluster_view.next() + + def _update_cluster_view(self, cluster_view): + cols = ['id', 'quality'] + # TODO: skip + items = [{'id': int(clu), 'quality': self.quality_func(clu)} + for clu in self.clustering.cluster_ids] + # TODO: custom measures + cluster_view.set_data(items, cols) + # Public methods # ------------------------------------------------------------------------- diff --git a/phy/cluster/manual/tests/test_gui_component.py b/phy/cluster/manual/tests/test_gui_component.py index 4ffe23cda..a741e156b 100644 --- a/phy/cluster/manual/tests/test_gui_component.py +++ b/phy/cluster/manual/tests/test_gui_component.py @@ -156,4 +156,4 @@ def sf(c, d): mc = ManualClustering(spike_clusters, similarity_func=sf) mc.attach(gui) gui.show() - qtbot.stop() + # qtbot.stop() diff --git a/phy/gui/static/table.js b/phy/gui/static/table.js index 3eb1ddbf3..bd34653a3 100644 --- a/phy/gui/static/table.js +++ b/phy/gui/static/table.js @@ -113,7 +113,7 @@ Table.prototype.clear = function() { }; Table.prototype.next = function() { - if (this.selected.length != 1) return; + // TODO: what to do when doing next() while several items are selected. var id = this.selected[0]; var row = this.rows[id]; var i0 = row.rowIndex + 1; @@ -133,7 +133,7 @@ Table.prototype.next = function() { }; Table.prototype.previous = function() { - if (this.selected.length != 1) return; + // TODO: what to do when doing next() while several items are selected. var id = this.selected[0]; var row = this.rows[id]; var i0 = row.rowIndex - 1; From 1621fca7ab46db9bca0b394a6978e6606c41ac60 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 4 Nov 2015 16:55:25 +0100 Subject: [PATCH 0539/1059] Tests pass --- phy/cluster/manual/gui_component.py | 48 +++++++++++-------- .../manual/tests/test_gui_component.py | 8 ++-- phy/gui/widgets.py | 12 ++++- 3 files changed, 43 insertions(+), 25 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 8943cbf16..b78884b9f 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -239,7 +239,7 @@ def _create_cluster_views(self, gui): @cluster_view.connect_ def on_select(cluster_ids): # Emit GUI.select when the selection changes in the cluster view. - self.select(cluster_ids) + self._emit_select(cluster_ids) # Pin the clusters and update the similarity view. self.pin(cluster_ids) @@ -248,7 +248,7 @@ def on_select(cluster_ids): def on_select(cluster_ids): # Select the clusters from both views. cluster_ids = cluster_view.selected + cluster_ids - self.select(cluster_ids) + self._emit_select(cluster_ids) # Update the cluster views and selection when a cluster event occurs. @self.gui.connect_ @@ -257,20 +257,36 @@ def on_cluster(up): sort = cluster_view.current_sort # Reinitialize the cluster view. self._update_cluster_view(cluster_view) + # Reset the previous sort options. + if sort[0]: + self.cluster_view.sort_by(sort[0]) + # TODO: second time for desc # Select all new clusters in view 1. if up.added: - cluster_view.select(up.added) + self.select(up.added) else: cluster_view.next() def _update_cluster_view(self, cluster_view): cols = ['id', 'quality'] # TODO: skip - items = [{'id': int(clu), 'quality': self.quality_func(clu)} + items = [{'id': clu, 'quality': self.quality_func(clu)} for clu in self.clustering.cluster_ids] # TODO: custom measures cluster_view.set_data(items, cols) + def _emit_select(self, cluster_ids): + """Choose spikes from the specified clusters and emit the + `select` event on the GUI.""" + # Choose a spike subset. + spike_ids = select_spikes(np.array(cluster_ids), + self.n_spikes_max_per_cluster, + self.clustering.spikes_per_cluster) + logger.debug("Select clusters: %s (%d spikes).", + ', '.join(map(str, cluster_ids)), len(spike_ids)) + if self.gui: + self.gui.emit('select', cluster_ids, spike_ids) + # Public methods # ------------------------------------------------------------------------- @@ -295,32 +311,23 @@ def attach(self, gui): # ------------------------------------------------------------------------- def select(self, *cluster_ids): - """Choose spikes from the specified clusters and emit the - `select` event on the GUI.""" - + """Select action: select clusters in the cluster view.""" # HACK: allow for `select(1, 2, 3)` in addition to `select([1, 2, 3])` # This makes it more convenient to select multiple clusters with # the snippet: `:c 1 2 3` instead of `:c 1,2,3`. if cluster_ids and isinstance(cluster_ids[0], (tuple, list)): cluster_ids = list(cluster_ids[0]) + list(cluster_ids[1:]) - - # Choose a spike subset. - spike_ids = select_spikes(np.array(cluster_ids), - self.n_spikes_max_per_cluster, - self.clustering.spikes_per_cluster) - logger.debug("Select clusters: %s (%d spikes).", - ', '.join(map(str, cluster_ids)), len(spike_ids)) - if self.gui: - self.gui.emit('select', cluster_ids, spike_ids) + # Update the cluster view selection. + self.cluster_view.select(cluster_ids) def pin(self, cluster_ids): """Update the similarity view with matches for the specified clusters.""" # TODO: similarity wrt several clusters - sel = int(cluster_ids[0]) + sel = cluster_ids[0] cols = ['id', 'similarity'] # TODO: skip - items = [{'id': int(clu), + items = [{'id': clu, 'similarity': self.similarity_func(sel, clu)} for clu in self.clustering.cluster_ids] self.similarity_view.set_data(items, cols) @@ -333,8 +340,9 @@ def pin(self, cluster_ids): # ------------------------------------------------------------------------- def merge(self, cluster_ids=None): - # if cluster_ids is None: - # cluster_ids = self.wizard.selection + if cluster_ids is None: + cluster_ids = (self.cluster_view.selected + + self.similarity_view.selected) if len(cluster_ids or []) <= 1: return self.clustering.merge(cluster_ids) diff --git a/phy/cluster/manual/tests/test_gui_component.py b/phy/cluster/manual/tests/test_gui_component.py index a741e156b..0d82f8242 100644 --- a/phy/cluster/manual/tests/test_gui_component.py +++ b/phy/cluster/manual/tests/test_gui_component.py @@ -19,12 +19,14 @@ #------------------------------------------------------------------------------ @yield_fixture -def manual_clustering(gui, cluster_ids, cluster_groups): +def manual_clustering(gui, cluster_ids, cluster_groups, quality, similarity): spike_clusters = np.array(cluster_ids) mc = ManualClustering(spike_clusters, cluster_groups=cluster_groups, shortcuts={'undo': 'ctrl+z'}, + quality_func=quality, + similarity_func=similarity, ) _s = [] @@ -33,12 +35,12 @@ def manual_clustering(gui, cluster_ids, cluster_groups): # Connect to the `select` event. @mc.gui.connect_ def on_select(cluster_ids, spike_ids): - _s.append((cluster_ids, spike_ids)) + _s.append(cluster_ids) def assert_selection(*cluster_ids): # pragma: no cover if not _s: return - assert _s[-1][0] == list(cluster_ids) + assert _s[-1] == list(cluster_ids) yield mc, assert_selection diff --git a/phy/gui/widgets.py b/phy/gui/widgets.py index 6a0a0f3ee..e9b82f574 100644 --- a/phy/gui/widgets.py +++ b/phy/gui/widgets.py @@ -15,6 +15,7 @@ from .qt import QWebView, QWebPage, QUrl, QWebSettings, pyqtSlot from phy.utils import EventEmitter +from phy.utils._misc import _CustomEncoder logger = logging.getLogger(__name__) @@ -216,12 +217,19 @@ def show(self): # HTML table # ----------------------------------------------------------------------------- +def dumps(o): + return json.dumps(o, cls=_CustomEncoder) + + def _create_json_dict(**kwargs): d = {} + # Remove None elements. for k, v in kwargs.items(): if v is not None: d[k] = v - return json.dumps(d) + # The custom encoder serves for NumPy scalars that are non + # JSON-serializable (!!). + return dumps(d) class Table(HTMLWidget): @@ -262,7 +270,7 @@ def previous(self): def select(self, ids): """Select some rows.""" - self.eval_js('table.select({}, false);'.format(json.dumps(ids))) + self.eval_js('table.select({}, false);'.format(dumps(ids))) @property def selected(self): From f4c793e8f6821583bfeae5709ce51936216975b0 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 4 Nov 2015 16:56:25 +0100 Subject: [PATCH 0540/1059] Update tests --- phy/cluster/manual/tests/test_gui_component.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/phy/cluster/manual/tests/test_gui_component.py b/phy/cluster/manual/tests/test_gui_component.py index 0d82f8242..3e769244f 100644 --- a/phy/cluster/manual/tests/test_gui_component.py +++ b/phy/cluster/manual/tests/test_gui_component.py @@ -133,11 +133,8 @@ def test_manual_clustering_move(manual_clustering, quality, similarity): mc.actions.select([30]) assert_selection(30) - # mc.wizard.set_quality_function(quality) - # mc.wizard.set_similarity_function(similarity) - - # mc.actions.next_by_quality() - # assert_selection(20) + mc.cluster_view.next() + assert_selection(20) mc.actions.move([20], 'noise') assert_selection(2) From d996844780f15f790b0b91a99d6a986aa42bd788 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 4 Nov 2015 17:17:14 +0100 Subject: [PATCH 0541/1059] Add raise_event parameter to Table.select() --- phy/gui/static/table.js | 3 ++- phy/gui/widgets.py | 6 ++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/phy/gui/static/table.js b/phy/gui/static/table.js index bd34653a3..05951f907 100644 --- a/phy/gui/static/table.js +++ b/phy/gui/static/table.js @@ -83,9 +83,10 @@ Table.prototype.currentSort = function() { } } return [null, null]; -} +}; Table.prototype.select = function(ids, raise_event) { + // The default is true. raise_event = typeof raise_event !== 'undefined' ? false : true; // Remove the class on all rows. diff --git a/phy/gui/widgets.py b/phy/gui/widgets.py index e9b82f574..9abc103f1 100644 --- a/phy/gui/widgets.py +++ b/phy/gui/widgets.py @@ -268,9 +268,11 @@ def previous(self): """Select the previous non-skip row.""" self.eval_js('table.previous();') - def select(self, ids): + def select(self, ids, raise_event=False): """Select some rows.""" - self.eval_js('table.select({}, false);'.format(dumps(ids))) + raise_event = text_type(raise_event).lower() + self.eval_js('table.select({}, {});'.format(dumps(ids), + raise_event)) @property def selected(self): From ca1343d3cd12810c10026713ade7ce7eb47aed9a Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 4 Nov 2015 18:48:44 +0100 Subject: [PATCH 0542/1059] WIP --- phy/cluster/manual/tests/conftest.py | 5 --- .../manual/tests/test_gui_component.py | 40 +++++++------------ phy/gui/static/table.js | 8 +--- phy/gui/widgets.py | 6 +-- 4 files changed, 19 insertions(+), 40 deletions(-) diff --git a/phy/cluster/manual/tests/conftest.py b/phy/cluster/manual/tests/conftest.py index e5b84d81b..7a48df7a6 100644 --- a/phy/cluster/manual/tests/conftest.py +++ b/phy/cluster/manual/tests/conftest.py @@ -19,11 +19,6 @@ def cluster_ids(): # i, g, N, i, g, N, N -@yield_fixture -def get_cluster_ids(cluster_ids): - yield lambda: cluster_ids - - @yield_fixture def cluster_groups(): yield {0: 'noise', 1: 'good', 10: 'mua', 11: 'good'} diff --git a/phy/cluster/manual/tests/test_gui_component.py b/phy/cluster/manual/tests/test_gui_component.py index 3e769244f..82165bc63 100644 --- a/phy/cluster/manual/tests/test_gui_component.py +++ b/phy/cluster/manual/tests/test_gui_component.py @@ -38,8 +38,6 @@ def on_select(cluster_ids, spike_ids): _s.append(cluster_ids) def assert_selection(*cluster_ids): # pragma: no cover - if not _s: - return assert _s[-1] == list(cluster_ids) yield mc, assert_selection @@ -92,28 +90,28 @@ def test_manual_clustering_edge_cases(manual_clustering): def test_manual_clustering_merge(manual_clustering): mc, assert_selection = manual_clustering - mc.actions.select(30, 20) # NOTE: we pass multiple ints instead of a list - mc.actions.merge() + mc.select(30, 20) # NOTE: we pass multiple ints instead of a list + mc.merge() assert_selection(31, 2) - mc.actions.undo() + mc.undo() assert_selection(30, 20) - mc.actions.redo() + mc.redo() assert_selection(31, 2) def test_manual_clustering_split(manual_clustering): mc, assert_selection = manual_clustering - mc.actions.select([1, 2]) - mc.actions.split([1, 2]) + mc.select([1, 2]) + mc.split([1, 2]) assert_selection(31, 20) - mc.actions.undo() + mc.undo() assert_selection(1, 2) - mc.actions.redo() + mc.redo() assert_selection(31, 20) @@ -123,36 +121,28 @@ def test_manual_clustering_split_2(gui): mc = ManualClustering(spike_clusters) mc.attach(gui) - mc.actions.split([0]) + mc.split([0]) # assert mc.wizard.selection == [2, 1] def test_manual_clustering_move(manual_clustering, quality, similarity): mc, assert_selection = manual_clustering - mc.actions.select([30]) + mc.select([30]) assert_selection(30) mc.cluster_view.next() assert_selection(20) - mc.actions.move([20], 'noise') + mc.move([20], 'noise') assert_selection(2) - mc.actions.undo() + mc.undo() assert_selection(20) - mc.actions.redo() + mc.redo() assert_selection(2) -def test_manual_clustering_show(qtbot, gui): - spike_clusters = np.array([0, 0, 1, 2, 0, 1]) - - def sf(c, d): - return float(c + d) - - mc = ManualClustering(spike_clusters, similarity_func=sf) - mc.attach(gui) - gui.show() - # qtbot.stop() +# def test_manual_clustering_show(qtbot, gui): +# mc, assert_selection = manual_clustering diff --git a/phy/gui/static/table.js b/phy/gui/static/table.js index 05951f907..58f031239 100644 --- a/phy/gui/static/table.js +++ b/phy/gui/static/table.js @@ -85,9 +85,7 @@ Table.prototype.currentSort = function() { return [null, null]; }; -Table.prototype.select = function(ids, raise_event) { - // The default is true. - raise_event = typeof raise_event !== 'undefined' ? false : true; +Table.prototype.select = function(ids) { // Remove the class on all rows. for (var i = 0; i < this.selected.length; i++) { @@ -104,9 +102,7 @@ Table.prototype.select = function(ids, raise_event) { this.selected = ids; - if (raise_event) { - emit("select", ids); - } + emit("select", ids); }; Table.prototype.clear = function() { diff --git a/phy/gui/widgets.py b/phy/gui/widgets.py index 9abc103f1..0727f0824 100644 --- a/phy/gui/widgets.py +++ b/phy/gui/widgets.py @@ -268,11 +268,9 @@ def previous(self): """Select the previous non-skip row.""" self.eval_js('table.previous();') - def select(self, ids, raise_event=False): + def select(self, ids): """Select some rows.""" - raise_event = text_type(raise_event).lower() - self.eval_js('table.select({}, {});'.format(dumps(ids), - raise_event)) + self.eval_js('table.select({});'.format(dumps(ids))) @property def selected(self): From 50b935c865c05aa97887af78298480986e7510d6 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 9 Nov 2015 21:57:46 +0100 Subject: [PATCH 0543/1059] WIP: update Python-JS communication in Qt widgets --- phy/gui/tests/test_widgets.py | 16 ++++++++++++---- phy/gui/widgets.py | 32 ++++---------------------------- 2 files changed, 16 insertions(+), 32 deletions(-) diff --git a/phy/gui/tests/test_widgets.py b/phy/gui/tests/test_widgets.py index df3eca421..aac84ed61 100644 --- a/phy/gui/tests/test_widgets.py +++ b/phy/gui/tests/test_widgets.py @@ -40,6 +40,7 @@ def table(qtbot): def test_widget_empty(qtbot): widget = HTMLWidget() + widget.build() widget.show() qtbot.waitForWindowShown(widget) # qtbot.stop() @@ -50,20 +51,27 @@ def test_widget_html(qtbot): widget.add_styles('html, body, p {background-color: purple;}') widget.add_header('') widget.set_body('Hello world!') - widget.eval_js('widget.set_body("Hello from Javascript!");') + widget.build() widget.show() qtbot.waitForWindowShown(widget) - widget.build() - assert 'Javascript' in widget.html() + assert 'Hello world!' in widget.html() -def test_widget_javascript(qtbot): +def test_widget_javascript_1(qtbot): widget = HTMLWidget() + widget.build() widget.show() qtbot.waitForWindowShown(widget) + widget.eval_js('number = 1;') assert widget.get_js('number') == 1 + +def test_widget_javascript(qtbot): + widget = HTMLWidget() + widget.build() + widget.show() + qtbot.waitForWindowShown(widget) _out = [] @widget.connect_ diff --git a/phy/gui/widgets.py b/phy/gui/widgets.py index 0727f0824..c2532376f 100644 --- a/phy/gui/widgets.py +++ b/phy/gui/widgets.py @@ -87,10 +87,7 @@ def __init__(self): widget._emit_from_js(name, JSON.stringify(arg)); }; ''') - # Pending eval_js to call *after* the page has been built and loaded. - # Use for calls to `eval_js()` before the page is loaded. - self._pending_eval_js = [] - self.loadFinished.connect(self._load_finished) + self.loadFinished.connect(lambda: self.emit('load')) # Events # ------------------------------------------------------------------------- @@ -127,7 +124,6 @@ def add_header(self, h): # HTML methods # ------------------------------------------------------------------------- - @pyqtSlot(str) def set_body(self, s): """Set the HTML body.""" self._body = s @@ -159,12 +155,6 @@ def is_built(self): # Javascript methods # ------------------------------------------------------------------------- - def _load_finished(self): - assert self.is_built() - for expr in self._pending_eval_js: - self.eval_js(expr) - self._pending_eval_js = [] - def add_to_js(self, name, var): """Add an object to Javascript.""" frame = self.page().mainFrame() @@ -172,12 +162,8 @@ def add_to_js(self, name, var): def eval_js(self, expr): """Evaluate a Javascript expression.""" - if not self.is_built(): - # If the page is not built yet, postpone the evaluation of the JS - # to after the page is loaded. - logger.log(5, "Postpone evaluation of `%s`.", expr) - self._pending_eval_js.append(expr) - return + if not self.is_built(): # pragma: no cover + raise RuntimeError("The page isn't built.") logger.log(5, "Evaluate Javascript: `%s`.", expr) self.page().mainFrame().evaluateJavaScript(expr) @@ -201,17 +187,6 @@ def get_js(self, expr): self._obj = None return obj - def show(self): - """Show the widget. - - A build is triggered if necessary. - - """ - # Build if no HTML has been set. - if not self.is_built(): - self.build() - return super(HTMLWidget, self).show() - # ----------------------------------------------------------------------------- # HTML table @@ -248,6 +223,7 @@ def __init__(self): self.add_body(''''''.format(self._table_id)) + self.build() def set_data(self, items, cols): """Set the rows and cols of the table.""" From 22b3564b78ec7adf736fd01b7e932e83195c321d Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 9 Nov 2015 22:30:31 +0100 Subject: [PATCH 0544/1059] WIP: fix manual clustering component with new cluster view wizard --- phy/cluster/manual/gui_component.py | 19 ++++-- .../manual/tests/test_gui_component.py | 58 ++++++++----------- phy/gui/tests/test_widgets.py | 2 +- 3 files changed, 38 insertions(+), 41 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index b78884b9f..4617813a3 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -224,16 +224,18 @@ def _create_actions(self, gui): def _create_cluster_views(self, gui): # Create the cluster view. self.cluster_view = cluster_view = Table() - # NOTE: table.setData() must be called *before* build() so that - # the JS call is deferred to after the HTML widget is fully loaded. - self._update_cluster_view(cluster_view) - cluster_view.build() + + @cluster_view.connect_ + def on_load(): + self._update_cluster_view(cluster_view) + gui.add_view(cluster_view, title='ClusterView') + cluster_view.show() # Create the similarity view. self.similarity_view = similarity_view = Table() - similarity_view.build() gui.add_view(similarity_view, title='SimilarityView') + similarity_view.show() # Selection in the cluster view. @cluster_view.connect_ @@ -310,6 +312,10 @@ def attach(self, gui): # Selection actions # ------------------------------------------------------------------------- + @property + def selected(self): + return self.cluster_view.selected + self.similarity_view.selected + def select(self, *cluster_ids): """Select action: select clusters in the cluster view.""" # HACK: allow for `select(1, 2, 3)` in addition to `select([1, 2, 3])` @@ -329,7 +335,8 @@ def pin(self, cluster_ids): # TODO: skip items = [{'id': clu, 'similarity': self.similarity_func(sel, clu)} - for clu in self.clustering.cluster_ids] + for clu in self.clustering.cluster_ids + if clu not in cluster_ids] self.similarity_view.set_data(items, cols) # NOTE: sort twice to get decreasing order. diff --git a/phy/cluster/manual/tests/test_gui_component.py b/phy/cluster/manual/tests/test_gui_component.py index 82165bc63..758be25a8 100644 --- a/phy/cluster/manual/tests/test_gui_component.py +++ b/phy/cluster/manual/tests/test_gui_component.py @@ -19,7 +19,8 @@ #------------------------------------------------------------------------------ @yield_fixture -def manual_clustering(gui, cluster_ids, cluster_groups, quality, similarity): +def manual_clustering(qtbot, gui, cluster_ids, cluster_groups, + quality, similarity): spike_clusters = np.array(cluster_ids) mc = ManualClustering(spike_clusters, @@ -28,19 +29,9 @@ def manual_clustering(gui, cluster_ids, cluster_groups, quality, similarity): quality_func=quality, similarity_func=similarity, ) - _s = [] - mc.attach(gui) - - # Connect to the `select` event. - @mc.gui.connect_ - def on_select(cluster_ids, spike_ids): - _s.append(cluster_ids) - - def assert_selection(*cluster_ids): # pragma: no cover - assert _s[-1] == list(cluster_ids) - - yield mc, assert_selection + qtbot.waitForWindowShown(mc.cluster_view) + yield mc @yield_fixture @@ -55,31 +46,30 @@ def gui(qapp): #------------------------------------------------------------------------------ def test_manual_clustering_edge_cases(manual_clustering): - mc, assert_selection = manual_clustering + mc = manual_clustering # Empty selection at first. - assert_selection() ae(mc.clustering.cluster_ids, [0, 1, 2, 10, 11, 20, 30]) mc.select([0]) - assert_selection(0) + assert mc.selected == [0] mc.undo() mc.redo() # Merge. mc.merge() - assert_selection(0) + assert mc.selected == [0] mc.merge([]) - assert_selection(0) + assert mc.selected == [0] mc.merge([10]) - assert_selection(0) + assert mc.selected == [0] # Split. mc.split([]) - assert_selection(0) + assert mc.selected == [0] # Move. mc.move([], 'ignored') @@ -88,31 +78,31 @@ def test_manual_clustering_edge_cases(manual_clustering): def test_manual_clustering_merge(manual_clustering): - mc, assert_selection = manual_clustering + mc = manual_clustering mc.select(30, 20) # NOTE: we pass multiple ints instead of a list mc.merge() - assert_selection(31, 2) + assert mc.selected == [31, 2] mc.undo() - assert_selection(30, 20) + assert mc.selected == [30, 20] mc.redo() - assert_selection(31, 2) + assert mc.selected == [31, 2] def test_manual_clustering_split(manual_clustering): - mc, assert_selection = manual_clustering + mc = manual_clustering mc.select([1, 2]) mc.split([1, 2]) - assert_selection(31, 20) + assert mc.selected == [31, 20] mc.undo() - assert_selection(1, 2) + assert mc.selected == [1, 2] mc.redo() - assert_selection(31, 20) + assert mc.selected == [31, 20] def test_manual_clustering_split_2(gui): @@ -126,22 +116,22 @@ def test_manual_clustering_split_2(gui): def test_manual_clustering_move(manual_clustering, quality, similarity): - mc, assert_selection = manual_clustering + mc = manual_clustering mc.select([30]) - assert_selection(30) + assert mc.selected == [30] mc.cluster_view.next() - assert_selection(20) + assert mc.selected == [20] mc.move([20], 'noise') - assert_selection(2) + assert mc.selected == [2] mc.undo() - assert_selection(20) + assert mc.selected == [20] mc.redo() - assert_selection(2) + assert mc.selected == [2] # def test_manual_clustering_show(qtbot, gui): diff --git a/phy/gui/tests/test_widgets.py b/phy/gui/tests/test_widgets.py index aac84ed61..4d9bdfded 100644 --- a/phy/gui/tests/test_widgets.py +++ b/phy/gui/tests/test_widgets.py @@ -67,7 +67,7 @@ def test_widget_javascript_1(qtbot): assert widget.get_js('number') == 1 -def test_widget_javascript(qtbot): +def test_widget_javascript_2(qtbot): widget = HTMLWidget() widget.build() widget.show() From 32988b9580ce6f0acc0d29c84ab9f1b4fc342fdf Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 9 Nov 2015 23:37:46 +0100 Subject: [PATCH 0545/1059] Make table.next() work when there is no current selection --- phy/gui/static/table.js | 39 ++++++++++++++++++----------------- phy/gui/tests/test_widgets.py | 5 +++++ 2 files changed, 25 insertions(+), 19 deletions(-) diff --git a/phy/gui/static/table.js b/phy/gui/static/table.js index 58f031239..c3cf1cd7e 100644 --- a/phy/gui/static/table.js +++ b/phy/gui/static/table.js @@ -112,40 +112,41 @@ Table.prototype.clear = function() { Table.prototype.next = function() { // TODO: what to do when doing next() while several items are selected. var id = this.selected[0]; - var row = this.rows[id]; - var i0 = row.rowIndex + 1; - var items = []; - + if (id === undefined) { + var row = null; + var i0 = 1; // 1, not 0, because we skip the header. + } + else { + var row = this.rows[id]; + var i0 = row.rowIndex + 1; + } for (var i = i0; i < this.el.rows.length; i++) { row = this.el.rows[i]; if (!(row.dataset.skip)) { - items.push(row.dataset.id); - break; + this.select([row.dataset.id]); + return; } } - - if (!(items.length)) return; - - this.select(items); }; Table.prototype.previous = function() { // TODO: what to do when doing next() while several items are selected. var id = this.selected[0]; - var row = this.rows[id]; - var i0 = row.rowIndex - 1; - var items = []; + if (id === undefined) { + var row = null; + var i0 = this.rows.length - 1; + } + else { + var row = this.rows[id]; + var i0 = row.rowIndex - 1; + } // NOTE: i >= 1 because we skip the header column. for (var i = i0; i >= 1; i--) { row = this.el.rows[i]; if (!(row.dataset.skip)) { - items.push(row.dataset.id); - break; + this.select([row.dataset.id]); + return; } } - - if (!(items.length)) return; - - this.select(items); }; diff --git a/phy/gui/tests/test_widgets.py b/phy/gui/tests/test_widgets.py index 4d9bdfded..3d45789fa 100644 --- a/phy/gui/tests/test_widgets.py +++ b/phy/gui/tests/test_widgets.py @@ -86,6 +86,11 @@ def on_test(arg): # qtbot.stop() +def test_table_nav_first(qtbot, table): + table.next() + assert table.selected == [0] + + def test_table_nav(qtbot, table): table.select([4]) From 9ae21c9b121406ce84e66c6a37b6cdf3377270e8 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 9 Nov 2015 23:40:24 +0100 Subject: [PATCH 0546/1059] WIP: making tests pass in cluster/manual --- phy/cluster/manual/gui_component.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 4617813a3..09ef782ba 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -265,8 +265,11 @@ def on_cluster(up): # TODO: second time for desc # Select all new clusters in view 1. if up.added: + # TODO: self.select(sel1, sel2) for both views. self.select(up.added) + self.pin(up.added) else: + # TODO: move in the sim view if the moved cluster were there cluster_view.next() def _update_cluster_view(self, cluster_view): From d1386e2dfd56ae4e61d2e553a6db8c9d0249e9a9 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 11 Nov 2015 00:26:30 +0100 Subject: [PATCH 0547/1059] WIP: update gui_component --- phy/cluster/manual/gui_component.py | 28 +++++++++++++---- .../manual/tests/test_gui_component.py | 30 +++++++++---------- phy/gui/widgets.py | 1 + 3 files changed, 38 insertions(+), 21 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 09ef782ba..c25b30b24 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -252,6 +252,14 @@ def on_select(cluster_ids): cluster_ids = cluster_view.selected + cluster_ids self._emit_select(cluster_ids) + # Save the current selection when an action occurs. + def on_request_undo_state(up): + return {'selection': (self.cluster_view.selected, + self.similarity_view.selected)} + + self.clustering.connect(on_request_undo_state) + self.cluster_meta.connect(on_request_undo_state) + # Update the cluster views and selection when a cluster event occurs. @self.gui.connect_ def on_cluster(up): @@ -264,10 +272,18 @@ def on_cluster(up): self.cluster_view.sort_by(sort[0]) # TODO: second time for desc # Select all new clusters in view 1. - if up.added: + if up.history == 'undo': + # Select the clusters that were selected before the undone + # action. + clusters_0, clusters_1 = up.undo_state[0]['selection'] + self.cluster_view.select(clusters_0) + self.similarity_view.select(clusters_1) + elif up.added: # TODO: self.select(sel1, sel2) for both views. self.select(up.added) self.pin(up.added) + # TODO: only if similarity selection non empty + similarity_view.next() else: # TODO: move in the sim view if the moved cluster were there cluster_view.next() @@ -315,10 +331,6 @@ def attach(self, gui): # Selection actions # ------------------------------------------------------------------------- - @property - def selected(self): - return self.cluster_view.selected + self.similarity_view.selected - def select(self, *cluster_ids): """Select action: select clusters in the cluster view.""" # HACK: allow for `select(1, 2, 3)` in addition to `select([1, 2, 3])` @@ -332,6 +344,8 @@ def select(self, *cluster_ids): def pin(self, cluster_ids): """Update the similarity view with matches for the specified clusters.""" + if not len(cluster_ids): + return # TODO: similarity wrt several clusters sel = cluster_ids[0] cols = ['id', 'similarity'] @@ -346,6 +360,10 @@ def pin(self, cluster_ids): self.similarity_view.sort_by('similarity') self.similarity_view.sort_by('similarity') + @property + def selected(self): + return self.cluster_view.selected + self.similarity_view.selected + # Clustering actions # ------------------------------------------------------------------------- diff --git a/phy/cluster/manual/tests/test_gui_component.py b/phy/cluster/manual/tests/test_gui_component.py index 758be25a8..04836df43 100644 --- a/phy/cluster/manual/tests/test_gui_component.py +++ b/phy/cluster/manual/tests/test_gui_component.py @@ -35,8 +35,10 @@ def manual_clustering(qtbot, gui, cluster_ids, cluster_groups, @yield_fixture -def gui(qapp): +def gui(qtbot): gui = GUI(position=(200, 100), size=(500, 500)) + gui.show() + qtbot.waitForWindowShown(gui) yield gui gui.close() @@ -82,13 +84,13 @@ def test_manual_clustering_merge(manual_clustering): mc.select(30, 20) # NOTE: we pass multiple ints instead of a list mc.merge() - assert mc.selected == [31, 2] + assert mc.selected == [31, 11] mc.undo() assert mc.selected == [30, 20] mc.redo() - assert mc.selected == [31, 2] + assert mc.selected == [31, 11] def test_manual_clustering_split(manual_clustering): @@ -96,13 +98,13 @@ def test_manual_clustering_split(manual_clustering): mc.select([1, 2]) mc.split([1, 2]) - assert mc.selected == [31, 20] + assert mc.selected == [31, 30] mc.undo() assert mc.selected == [1, 2] mc.redo() - assert mc.selected == [31, 20] + assert mc.selected == [31, 30] def test_manual_clustering_split_2(gui): @@ -112,27 +114,23 @@ def test_manual_clustering_split_2(gui): mc.attach(gui) mc.split([0]) - # assert mc.wizard.selection == [2, 1] + assert mc.selected == [2, 3, 1] def test_manual_clustering_move(manual_clustering, quality, similarity): mc = manual_clustering + mc.cluster_view.sort_by('quality') + # TODO: desc + # mc.cluster_view.sort_by('quality') - mc.select([30]) - assert mc.selected == [30] - - mc.cluster_view.next() + mc.select([20]) assert mc.selected == [20] mc.move([20], 'noise') - assert mc.selected == [2] + assert mc.selected == [30] mc.undo() assert mc.selected == [20] mc.redo() - assert mc.selected == [2] - - -# def test_manual_clustering_show(qtbot, gui): -# mc, assert_selection = manual_clustering + assert mc.selected == [30] diff --git a/phy/gui/widgets.py b/phy/gui/widgets.py index c2532376f..1c803a60f 100644 --- a/phy/gui/widgets.py +++ b/phy/gui/widgets.py @@ -234,6 +234,7 @@ def set_data(self, items, cols): def sort_by(self, header): """Sort by a given variable.""" + # TODO: asc or desc self.eval_js('table.sortBy("{}");'.format(header)) def next(self): From 2cc4fac19f9fce9143487c2dba6c622e1516bba6 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 11 Nov 2015 00:29:01 +0100 Subject: [PATCH 0548/1059] Fix widget.eval_js() --- phy/gui/tests/test_widgets.py | 2 +- phy/gui/widgets.py | 22 +++------------------- 2 files changed, 4 insertions(+), 20 deletions(-) diff --git a/phy/gui/tests/test_widgets.py b/phy/gui/tests/test_widgets.py index 3d45789fa..0cd053d13 100644 --- a/phy/gui/tests/test_widgets.py +++ b/phy/gui/tests/test_widgets.py @@ -64,7 +64,7 @@ def test_widget_javascript_1(qtbot): qtbot.waitForWindowShown(widget) widget.eval_js('number = 1;') - assert widget.get_js('number') == 1 + assert widget.eval_js('number') == 1 def test_widget_javascript_2(qtbot): diff --git a/phy/gui/widgets.py b/phy/gui/widgets.py index 1c803a60f..ffbbc5a2c 100644 --- a/phy/gui/widgets.py +++ b/phy/gui/widgets.py @@ -165,28 +165,12 @@ def eval_js(self, expr): if not self.is_built(): # pragma: no cover raise RuntimeError("The page isn't built.") logger.log(5, "Evaluate Javascript: `%s`.", expr) - self.page().mainFrame().evaluateJavaScript(expr) - - @pyqtSlot(str) - def _set_from_js(self, obj): - """Called from Javascript to pass any object to Python through JSON.""" - self._obj = json.loads(text_type(obj)) + return self.page().mainFrame().evaluateJavaScript(expr) @pyqtSlot(str, str) def _emit_from_js(self, name, arg_json): self.emit(text_type(name), json.loads(text_type(arg_json))) - def get_js(self, expr): - """Evaluate a Javascript expression and get a Python object. - - This uses JSON serialization/deserialization under the hood. - - """ - self.eval_js('widget._set_from_js(JSON.stringify({}));'.format(expr)) - obj = self._obj - self._obj = None - return obj - # ----------------------------------------------------------------------------- # HTML table @@ -252,8 +236,8 @@ def select(self, ids): @property def selected(self): """Currently selected rows.""" - return [int(_) for _ in self.get_js('table.selected')] + return [int(_) for _ in self.eval_js('table.selected')] @property def current_sort(self): - return tuple(self.get_js('table.currentSort()')) + return tuple(self.eval_js('table.currentSort()')) From 4e083c8416a4da8a7a08855fefa470e3be92464f Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 11 Nov 2015 01:07:16 +0100 Subject: [PATCH 0549/1059] WIP --- phy/cluster/manual/tests/test_gui_component.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/phy/cluster/manual/tests/test_gui_component.py b/phy/cluster/manual/tests/test_gui_component.py index 04836df43..2ab60f5e6 100644 --- a/phy/cluster/manual/tests/test_gui_component.py +++ b/phy/cluster/manual/tests/test_gui_component.py @@ -30,7 +30,7 @@ def manual_clustering(qtbot, gui, cluster_ids, cluster_groups, similarity_func=similarity, ) mc.attach(gui) - qtbot.waitForWindowShown(mc.cluster_view) + qtbot.waitForWindowShown(gui) yield mc From 1125ffbb8206ed387054a31dc3d106187dba17db Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 11 Nov 2015 11:29:46 +0100 Subject: [PATCH 0550/1059] Fix table build bug --- phy/cluster/manual/gui_component.py | 2 ++ phy/cluster/manual/tests/test_gui_component.py | 1 - phy/gui/actions.py | 4 ++-- phy/gui/tests/test_widgets.py | 2 +- phy/gui/widgets.py | 10 ++++++++-- 5 files changed, 13 insertions(+), 6 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index c25b30b24..b391a98a4 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -230,11 +230,13 @@ def on_load(): self._update_cluster_view(cluster_view) gui.add_view(cluster_view, title='ClusterView') + cluster_view.build() cluster_view.show() # Create the similarity view. self.similarity_view = similarity_view = Table() gui.add_view(similarity_view, title='SimilarityView') + similarity_view.build() similarity_view.show() # Selection in the cluster view. diff --git a/phy/cluster/manual/tests/test_gui_component.py b/phy/cluster/manual/tests/test_gui_component.py index 2ab60f5e6..9b4f99533 100644 --- a/phy/cluster/manual/tests/test_gui_component.py +++ b/phy/cluster/manual/tests/test_gui_component.py @@ -30,7 +30,6 @@ def manual_clustering(qtbot, gui, cluster_ids, cluster_groups, similarity_func=similarity, ) mc.attach(gui) - qtbot.waitForWindowShown(gui) yield mc diff --git a/phy/gui/actions.py b/phy/gui/actions.py index 4061c8309..43479d5b6 100644 --- a/phy/gui/actions.py +++ b/phy/gui/actions.py @@ -179,8 +179,8 @@ def add(self, callback=None, name=None, shortcut=None, alias=None, action_obj = Bunch(qaction=action, name=name, alias=alias, shortcut=shortcut, callback=callback) if verbose and not name.startswith('_'): - logger.debug("Add action `%s` (%s).", name, - _get_shortcut_string(action.shortcut())) + logger.log(5, "Add action `%s` (%s).", name, + _get_shortcut_string(action.shortcut())) self.gui.addAction(action) self._actions_dict[name] = action_obj # Register the alias -> name mapping. diff --git a/phy/gui/tests/test_widgets.py b/phy/gui/tests/test_widgets.py index 0cd053d13..2f86a2dac 100644 --- a/phy/gui/tests/test_widgets.py +++ b/phy/gui/tests/test_widgets.py @@ -18,7 +18,7 @@ @yield_fixture def table(qtbot): table = Table() - + table.build() table.show() qtbot.waitForWindowShown(table) diff --git a/phy/gui/widgets.py b/phy/gui/widgets.py index ffbbc5a2c..a9454c2e0 100644 --- a/phy/gui/widgets.py +++ b/phy/gui/widgets.py @@ -87,11 +87,13 @@ def __init__(self): widget._emit_from_js(name, JSON.stringify(arg)); }; ''') - self.loadFinished.connect(lambda: self.emit('load')) # Events # ------------------------------------------------------------------------- + def _load_finished(self, boo): + self.emit('load') + def emit(self, *args, **kwargs): return self._event.emit(*args, **kwargs) @@ -147,6 +149,7 @@ def build(self): logger.log(5, "Set HTML: %s", html) static_dir = op.join(op.realpath(op.dirname(__file__)), 'static/') base_url = QUrl().fromLocalFile(static_dir) + self.loadFinished.connect(self._load_finished) self.setHtml(html, base_url) def is_built(self): @@ -207,7 +210,10 @@ def __init__(self): self.add_body(''''''.format(self._table_id)) - self.build() + # NOTE: the table should *not* be built at initialization, because + # we may need to connect the load event before the table is built. + # This is why this line is commented. + # self.build() def set_data(self, items, cols): """Set the rows and cols of the table.""" From 47f9d8851f1ee91bc474b18e638d79ff6df999d8 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 11 Nov 2015 11:54:11 +0100 Subject: [PATCH 0551/1059] Fix Python 2 error with QVariant --- phy/gui/qt.py | 1 + phy/gui/widgets.py | 5 +++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/phy/gui/qt.py b/phy/gui/qt.py index 40c58ef26..07c9a1d31 100644 --- a/phy/gui/qt.py +++ b/phy/gui/qt.py @@ -18,6 +18,7 @@ # ----------------------------------------------------------------------------- from PyQt4.QtCore import (Qt, QByteArray, QMetaObject, QObject, # noqa + QVariant, pyqtSignal, pyqtSlot, QSize, QUrl) from PyQt4.QtGui import (QKeySequence, QAction, QStatusBar, # noqa QMainWindow, QDockWidget, QWidget, diff --git a/phy/gui/widgets.py b/phy/gui/widgets.py index a9454c2e0..82772c0e2 100644 --- a/phy/gui/widgets.py +++ b/phy/gui/widgets.py @@ -13,7 +13,7 @@ from six import text_type -from .qt import QWebView, QWebPage, QUrl, QWebSettings, pyqtSlot +from .qt import QWebView, QWebPage, QUrl, QWebSettings, QVariant, pyqtSlot from phy.utils import EventEmitter from phy.utils._misc import _CustomEncoder @@ -168,7 +168,8 @@ def eval_js(self, expr): if not self.is_built(): # pragma: no cover raise RuntimeError("The page isn't built.") logger.log(5, "Evaluate Javascript: `%s`.", expr) - return self.page().mainFrame().evaluateJavaScript(expr) + out = self.page().mainFrame().evaluateJavaScript(expr) + return out.toPyObject() if isinstance(out, QVariant) else out @pyqtSlot(str, str) def _emit_from_js(self, name, arg_json): From 02d4969f0dcdd8a5faa88cca2f8b67347057de4f Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 11 Nov 2015 21:58:24 +0100 Subject: [PATCH 0552/1059] Rename default wizard function --- phy/cluster/manual/__init__.py | 2 +- phy/cluster/manual/gui_component.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/phy/cluster/manual/__init__.py b/phy/cluster/manual/__init__.py index d6b871021..b9a3246ba 100644 --- a/phy/cluster/manual/__init__.py +++ b/phy/cluster/manual/__init__.py @@ -5,5 +5,5 @@ from ._utils import ClusterMeta from .clustering import Clustering -from .gui_component import ManualClustering +from .gui_component import ManualClustering, default_wizard_functions from .views import WaveformView, TraceView, FeatureView, CorrelogramView diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index b391a98a4..6e3b31fcb 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -45,12 +45,12 @@ def _process_ups(ups): # pragma: no cover raise NotImplementedError() -def _make_wizard_default_functions(waveforms=None, - features=None, - masks=None, - n_features_per_channel=None, - spikes_per_cluster=None, - ): +def default_wizard_functions(waveforms=None, + features=None, + masks=None, + n_features_per_channel=None, + spikes_per_cluster=None, + ): spc = spikes_per_cluster nfc = n_features_per_channel From 1a59ee6a3b02325eb7b274d3c76b9ff27bc5445e Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 11 Nov 2015 22:41:49 +0100 Subject: [PATCH 0553/1059] WIP: fix selection in table widget --- phy/gui/static/table.js | 30 ++++++++++++++++++++++++++---- phy/gui/tests/test_widgets.py | 6 ++++++ 2 files changed, 32 insertions(+), 4 deletions(-) diff --git a/phy/gui/static/table.js b/phy/gui/static/table.js index c3cf1cd7e..efe506644 100644 --- a/phy/gui/static/table.js +++ b/phy/gui/static/table.js @@ -1,4 +1,12 @@ +function uniq(a) { + var seen = {}; + return a.filter(function(item) { + return seen.hasOwnProperty(item) ? false : (seen[item] = true); + }); +} + + var Table = function (el) { this.el = el; this.selected = []; @@ -49,13 +57,26 @@ Table.prototype.setData = function(data) { } tr.onclick = function(e) { - var selected = [this.dataset.id]; - + var id = parseInt(String(this.dataset.id)); var evt = e ? e:window.event; + // Control pressed: toggle selected. if (evt.ctrlKey || evt.metaKey) { - selected = that.selected.concat(selected); + var index = that.selected.indexOf(id); + // If this item is already selected, deselect it. + if (index != -1) { + var selected = that.selected.slice(); + selected.splice(index, 1); + that.select(selected); + } + // Otherwise, select it. + else { + that.select(that.selected.concat([id])); + } + } + // Otherwise, select just that item. + else { + that.select([id]); } - that.select(selected); } tbody.appendChild(tr); @@ -86,6 +107,7 @@ Table.prototype.currentSort = function() { }; Table.prototype.select = function(ids) { + ids = uniq(ids); // Remove the class on all rows. for (var i = 0; i < this.selected.length; i++) { diff --git a/phy/gui/tests/test_widgets.py b/phy/gui/tests/test_widgets.py index 2f86a2dac..8a25edbd8 100644 --- a/phy/gui/tests/test_widgets.py +++ b/phy/gui/tests/test_widgets.py @@ -86,6 +86,12 @@ def on_test(arg): # qtbot.stop() +def test_table_duplicates(qtbot, table): + table.select([1, 1]) + assert table.selected == [1] + # qtbot.stop() + + def test_table_nav_first(qtbot, table): table.next() assert table.selected == [0] From 5a3c8d659d52434ab0c614c65c3f76a4c13e6623 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 11 Nov 2015 22:57:44 +0100 Subject: [PATCH 0554/1059] Format numbers in table --- phy/gui/static/table.js | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/phy/gui/static/table.js b/phy/gui/static/table.js index efe506644..274ebabe0 100644 --- a/phy/gui/static/table.js +++ b/phy/gui/static/table.js @@ -1,3 +1,4 @@ +// Utils. function uniq(a) { var seen = {}; @@ -6,6 +7,12 @@ function uniq(a) { }); } +function isFloat(n) { + return n === Number(n) && n % 1 !== 0; +} + + +// Table class. var Table = function (el) { this.el = el; @@ -46,6 +53,9 @@ Table.prototype.setData = function(data) { for (var j = 0; j < keys.length; j++) { var key = keys[j]; var value = row[key]; + // Format numbers. + if (isFloat(value)) + value = value.toPrecision(3); var td = document.createElement("td"); td.appendChild(document.createTextNode(value)); tr.appendChild(td); From 3470b77723487914d375fc613f4533585c162aa1 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 11 Nov 2015 23:05:29 +0100 Subject: [PATCH 0555/1059] Fix bug with number sort in table --- phy/gui/static/tablesort.number.js | 2 +- phy/gui/tests/test_widgets.py | 8 ++++++-- phy/gui/widgets.py | 2 +- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/phy/gui/static/tablesort.number.js b/phy/gui/static/tablesort.number.js index 2e11b462a..d43405fb8 100644 --- a/phy/gui/static/tablesort.number.js +++ b/phy/gui/static/tablesort.number.js @@ -1,6 +1,6 @@ (function(){ var cleanNumber = function(i) { - return i.replace(/[^\-?0-9.]/g, ''); + return i.replace(/[^\-\+eE\,?0-9\.]/g, ''); }, compareNumber = function(a, b) { diff --git a/phy/gui/tests/test_widgets.py b/phy/gui/tests/test_widgets.py index 8a25edbd8..60b289f2b 100644 --- a/phy/gui/tests/test_widgets.py +++ b/phy/gui/tests/test_widgets.py @@ -22,7 +22,7 @@ def table(qtbot): table.show() qtbot.waitForWindowShown(table) - items = [{'id': i, 'count': 100 - 10 * i} for i in range(10)] + items = [{'id': i, 'count': 10000.5 - 10 * i} for i in range(10)] items[4]['skip'] = True table.set_data(cols=['id', 'count'], @@ -35,7 +35,7 @@ def table(qtbot): #------------------------------------------------------------------------------ -# Test actions +# Test widgets #------------------------------------------------------------------------------ def test_widget_empty(qtbot): @@ -86,6 +86,10 @@ def on_test(arg): # qtbot.stop() +#------------------------------------------------------------------------------ +# Test table +#------------------------------------------------------------------------------ + def test_table_duplicates(qtbot, table): table.select([1, 1]) assert table.selected == [1] diff --git a/phy/gui/widgets.py b/phy/gui/widgets.py index 82772c0e2..8d27d1090 100644 --- a/phy/gui/widgets.py +++ b/phy/gui/widgets.py @@ -29,7 +29,7 @@ background-color: black; color: white; font-family: sans-serif; - font-size: 18pt; + font-size: 14pt; margin: 5px 10px; } """ From aea3a63e11ebb8cf6599bfedbdc49002b4803a5a Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 12 Nov 2015 14:46:24 +0100 Subject: [PATCH 0556/1059] Add _wait_signal() function --- phy/gui/qt.py | 12 +++++++++++- phy/gui/tests/test_qt.py | 21 ++++++++++++++++++++- 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/phy/gui/qt.py b/phy/gui/qt.py index 07c9a1d31..0cd4b519d 100644 --- a/phy/gui/qt.py +++ b/phy/gui/qt.py @@ -18,7 +18,7 @@ # ----------------------------------------------------------------------------- from PyQt4.QtCore import (Qt, QByteArray, QMetaObject, QObject, # noqa - QVariant, + QVariant, QEventLoop, QTimer, pyqtSignal, pyqtSlot, QSize, QUrl) from PyQt4.QtGui import (QKeySequence, QAction, QStatusBar, # noqa QMainWindow, QDockWidget, QWidget, @@ -59,6 +59,16 @@ def _show_box(box): # pragma: no cover return _button_name_from_enum(box.exec_()) +def _wait_signal(signal, timeout=None): + """Block loop until signal emitted, or timeout (ms) elapses.""" + # http://jdreaver.com/posts/2014-07-03-waiting-for-signals-pyside-pyqt.html + loop = QEventLoop() + signal.connect(loop.quit) + if timeout is not None: + QTimer.singleShot(timeout, loop.quit) + loop.exec_() + + # ----------------------------------------------------------------------------- # Qt app # ----------------------------------------------------------------------------- diff --git a/phy/gui/tests/test_qt.py b/phy/gui/tests/test_qt.py index 5ebef620f..fdb1628e8 100644 --- a/phy/gui/tests/test_qt.py +++ b/phy/gui/tests/test_qt.py @@ -8,10 +8,11 @@ from pytest import raises -from ..qt import (QMessageBox, Qt, QWebView, +from ..qt import (QMessageBox, Qt, QWebView, QTimer, _button_name_from_enum, _button_enum_from_name, _prompt, + _wait_signal, require_qt, create_app, QApplication, @@ -50,6 +51,24 @@ def test_qt_app(qtbot): view.close() +def test_wait_signal(qtbot): + x = [] + + def f(): + x.append(0) + + timer = QTimer() + timer.setInterval(100) + timer.setSingleShot(True) + timer.timeout.connect(f) + timer.start() + + assert x == [] + + _wait_signal(timer.timeout) + assert x == [0] + + def test_web_view(qtbot): view = QWebView() From fc370b4e0bbd49de64316902a0c2f4db60157d8f Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 12 Nov 2015 14:54:06 +0100 Subject: [PATCH 0557/1059] Use context manager in _wait_signal() --- phy/gui/qt.py | 5 +++++ phy/gui/tests/test_qt.py | 3 ++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/phy/gui/qt.py b/phy/gui/qt.py index 0cd4b519d..f8d4d6a75 100644 --- a/phy/gui/qt.py +++ b/phy/gui/qt.py @@ -6,6 +6,7 @@ # Imports # ----------------------------------------------------------------------------- +from contextlib import contextmanager from functools import wraps import logging import sys @@ -59,11 +60,15 @@ def _show_box(box): # pragma: no cover return _button_name_from_enum(box.exec_()) +@contextmanager def _wait_signal(signal, timeout=None): """Block loop until signal emitted, or timeout (ms) elapses.""" # http://jdreaver.com/posts/2014-07-03-waiting-for-signals-pyside-pyqt.html loop = QEventLoop() signal.connect(loop.quit) + + yield + if timeout is not None: QTimer.singleShot(timeout, loop.quit) loop.exec_() diff --git a/phy/gui/tests/test_qt.py b/phy/gui/tests/test_qt.py index fdb1628e8..3e12053ad 100644 --- a/phy/gui/tests/test_qt.py +++ b/phy/gui/tests/test_qt.py @@ -65,7 +65,8 @@ def f(): assert x == [] - _wait_signal(timer.timeout) + with _wait_signal(timer.timeout): + pass assert x == [0] From 5c18948ecc441710dd6efb4db9359b56918f83c3 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 12 Nov 2015 15:07:55 +0100 Subject: [PATCH 0558/1059] Fix async widgets --- phy/gui/tests/test_widgets.py | 7 +------ phy/gui/widgets.py | 31 +++++++++++++++++++------------ 2 files changed, 20 insertions(+), 18 deletions(-) diff --git a/phy/gui/tests/test_widgets.py b/phy/gui/tests/test_widgets.py index 60b289f2b..a5b765f67 100644 --- a/phy/gui/tests/test_widgets.py +++ b/phy/gui/tests/test_widgets.py @@ -18,7 +18,6 @@ @yield_fixture def table(qtbot): table = Table() - table.build() table.show() qtbot.waitForWindowShown(table) @@ -40,7 +39,6 @@ def table(qtbot): def test_widget_empty(qtbot): widget = HTMLWidget() - widget.build() widget.show() qtbot.waitForWindowShown(widget) # qtbot.stop() @@ -51,7 +49,6 @@ def test_widget_html(qtbot): widget.add_styles('html, body, p {background-color: purple;}') widget.add_header('') widget.set_body('Hello world!') - widget.build() widget.show() qtbot.waitForWindowShown(widget) assert 'Hello world!' in widget.html() @@ -59,17 +56,15 @@ def test_widget_html(qtbot): def test_widget_javascript_1(qtbot): widget = HTMLWidget() - widget.build() + widget.eval_js('number = 1;') widget.show() qtbot.waitForWindowShown(widget) - widget.eval_js('number = 1;') assert widget.eval_js('number') == 1 def test_widget_javascript_2(qtbot): widget = HTMLWidget() - widget.build() widget.show() qtbot.waitForWindowShown(widget) _out = [] diff --git a/phy/gui/widgets.py b/phy/gui/widgets.py index 8d27d1090..84bc291ae 100644 --- a/phy/gui/widgets.py +++ b/phy/gui/widgets.py @@ -13,7 +13,10 @@ from six import text_type -from .qt import QWebView, QWebPage, QUrl, QWebSettings, QVariant, pyqtSlot +from .qt import (QWebView, QWebPage, QUrl, QWebSettings, QVariant, + pyqtSlot, + _wait_signal, + ) from phy.utils import EventEmitter from phy.utils._misc import _CustomEncoder @@ -87,13 +90,11 @@ def __init__(self): widget._emit_from_js(name, JSON.stringify(arg)); }; ''') + self._pending_js_eval = [] # Events # ------------------------------------------------------------------------- - def _load_finished(self, boo): - self.emit('load') - def emit(self, *args, **kwargs): return self._event.emit(*args, **kwargs) @@ -138,7 +139,7 @@ def html(self): """Return the full HTML source of the widget.""" return self.page().mainFrame().toHtml() - def build(self): + def _build(self): """Build the full HTML source.""" styles = '\n\n'.join(self._styles) html = _PAGE_TEMPLATE.format(title=self.title, @@ -149,7 +150,6 @@ def build(self): logger.log(5, "Set HTML: %s", html) static_dir = op.join(op.realpath(op.dirname(__file__)), 'static/') base_url = QUrl().fromLocalFile(static_dir) - self.loadFinished.connect(self._load_finished) self.setHtml(html, base_url) def is_built(self): @@ -165,8 +165,9 @@ def add_to_js(self, name, var): def eval_js(self, expr): """Evaluate a Javascript expression.""" - if not self.is_built(): # pragma: no cover - raise RuntimeError("The page isn't built.") + if not self.is_built(): + self._pending_js_eval.append(expr) + return logger.log(5, "Evaluate Javascript: `%s`.", expr) out = self.page().mainFrame().evaluateJavaScript(expr) return out.toPyObject() if isinstance(out, QVariant) else out @@ -175,6 +176,16 @@ def eval_js(self, expr): def _emit_from_js(self, name, arg_json): self.emit(text_type(name), json.loads(text_type(arg_json))) + def show(self): + with _wait_signal(self.loadFinished, 100): + self._build() + super(HTMLWidget, self).show() + # Call the pending JS eval calls after the page has been built. + assert self.is_built() + for expr in self._pending_js_eval: + self.eval_js(expr) + self._pending_js_eval = [] + # ----------------------------------------------------------------------------- # HTML table @@ -211,10 +222,6 @@ def __init__(self): self.add_body(''''''.format(self._table_id)) - # NOTE: the table should *not* be built at initialization, because - # we may need to connect the load event before the table is built. - # This is why this line is commented. - # self.build() def set_data(self, items, cols): """Set the rows and cols of the table.""" From e50a08c5500e65a2f25a7750c56b3d3272349705 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 12 Nov 2015 15:09:50 +0100 Subject: [PATCH 0559/1059] Fix manual clustering component --- phy/cluster/manual/gui_component.py | 99 +++++++++---------- .../manual/tests/test_gui_component.py | 61 +++++++++++- 2 files changed, 104 insertions(+), 56 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 6e3b31fcb..70221d193 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -93,10 +93,10 @@ def mean_masked_features_similarity(c0, c1): d = mean_masked_features_distance(mf0, mf1, mm0, mm1, n_features_per_channel=nfc, ) - + d = 1. / max(1e-10, d) # From distance to similarity. logger.debug("Computed cluster similarity for (%d, %d): %.3f.", c0, c1, d) - return -d # NOTE: convert distance to score + return d return (max_waveform_amplitude_quality, mean_masked_features_similarity) @@ -157,8 +157,6 @@ def __init__(self, cluster_groups=None, n_spikes_max_per_cluster=100, shortcuts=None, - quality_func=None, - similarity_func=None, ): self.gui = None @@ -173,10 +171,6 @@ def __init__(self, self.cluster_meta = create_cluster_meta(cluster_groups) self._global_history = GlobalHistory(process_ups=_process_ups) - # Wizard functions. - self.quality_func = quality_func or (lambda c: 0) - self.similarity_func = similarity_func or (lambda c, d: 0) - # Log the actions. @self.clustering.connect def on_cluster(up): @@ -205,6 +199,9 @@ def on_cluster(up): if self.gui: self.gui.emit('cluster', up) + # Create the cluster views. + self._create_cluster_views() + # Internal methods # ------------------------------------------------------------------------- @@ -221,22 +218,13 @@ def _create_actions(self, gui): self.actions.add(self.undo) self.actions.add(self.redo) - def _create_cluster_views(self, gui): + def _create_cluster_views(self): # Create the cluster view. self.cluster_view = cluster_view = Table() - - @cluster_view.connect_ - def on_load(): - self._update_cluster_view(cluster_view) - - gui.add_view(cluster_view, title='ClusterView') - cluster_view.build() cluster_view.show() # Create the similarity view. self.similarity_view = similarity_view = Table() - gui.add_view(similarity_view, title='SimilarityView') - similarity_view.build() similarity_view.show() # Selection in the cluster view. @@ -262,35 +250,8 @@ def on_request_undo_state(up): self.clustering.connect(on_request_undo_state) self.cluster_meta.connect(on_request_undo_state) - # Update the cluster views and selection when a cluster event occurs. - @self.gui.connect_ - def on_cluster(up): - # Get the current sort of the cluster view. - sort = cluster_view.current_sort - # Reinitialize the cluster view. - self._update_cluster_view(cluster_view) - # Reset the previous sort options. - if sort[0]: - self.cluster_view.sort_by(sort[0]) - # TODO: second time for desc - # Select all new clusters in view 1. - if up.history == 'undo': - # Select the clusters that were selected before the undone - # action. - clusters_0, clusters_1 = up.undo_state[0]['selection'] - self.cluster_view.select(clusters_0) - self.similarity_view.select(clusters_1) - elif up.added: - # TODO: self.select(sel1, sel2) for both views. - self.select(up.added) - self.pin(up.added) - # TODO: only if similarity selection non empty - similarity_view.next() - else: - # TODO: move in the sim view if the moved cluster were there - cluster_view.next() - def _update_cluster_view(self, cluster_view): + assert self.quality_func cols = ['id', 'quality'] # TODO: skip items = [{'id': clu, 'quality': self.quality_func(clu)} @@ -316,18 +277,51 @@ def _emit_select(self, cluster_ids): def set_quality_func(self, f): self.quality_func = f + self._update_cluster_view(self.cluster_view) + self.cluster_view.sort_by('quality') + self.cluster_view.sort_by('quality') + def set_similarity_func(self, f): self.similarity_func = f def attach(self, gui): self.gui = gui - # Create the cluster views. - self._create_cluster_views(gui) - # Create the actions. self._create_actions(gui) + # Add the cluster views. + gui.add_view(self.cluster_view, title='ClusterView') + gui.add_view(self.similarity_view, title='SimilarityView') + + # Update the cluster views and selection when a cluster event occurs. + @self.gui.connect_ + def on_cluster(up): + # Get the current sort of the cluster view. + sort = self.cluster_view.current_sort + # Reinitialize the cluster view. + self._update_cluster_view(self.cluster_view) + # Reset the previous sort options. + if sort[0]: + self.cluster_view.sort_by(sort[0]) + # TODO: second time for desc + # Select all new clusters in view 1. + if up.history == 'undo': + # Select the clusters that were selected before the undone + # action. + clusters_0, clusters_1 = up.undo_state[0]['selection'] + self.cluster_view.select(clusters_0) + self.similarity_view.select(clusters_1) + elif up.added: + # TODO: self.select(sel1, sel2) for both views. + self.select(up.added) + self.pin(up.added) + # TODO: only if similarity selection non empty + self.similarity_view.next() + else: + # TODO: move in the sim view if the moved cluster were there + self.cluster_view.next() + return self # Selection actions @@ -343,11 +337,12 @@ def select(self, *cluster_ids): # Update the cluster view selection. self.cluster_view.select(cluster_ids) - def pin(self, cluster_ids): + def pin(self, cluster_ids=None): """Update the similarity view with matches for the specified clusters.""" - if not len(cluster_ids): - return + assert self.similarity_func + if cluster_ids is None or not len(cluster_ids): + cluster_ids = self.cluster_view.selected # TODO: similarity wrt several clusters sel = cluster_ids[0] cols = ['id', 'similarity'] diff --git a/phy/cluster/manual/tests/test_gui_component.py b/phy/cluster/manual/tests/test_gui_component.py index 9b4f99533..a3739083c 100644 --- a/phy/cluster/manual/tests/test_gui_component.py +++ b/phy/cluster/manual/tests/test_gui_component.py @@ -10,8 +10,14 @@ import numpy as np from numpy.testing import assert_array_equal as ae -from ..gui_component import ManualClustering +from ..gui_component import ManualClustering, default_wizard_functions from phy.gui import GUI +from phy.io.array import _spikes_per_cluster +from phy.io.mock import (artificial_waveforms, + artificial_masks, + artificial_features, + artificial_spike_clusters, + ) #------------------------------------------------------------------------------ @@ -26,10 +32,11 @@ def manual_clustering(qtbot, gui, cluster_ids, cluster_groups, mc = ManualClustering(spike_clusters, cluster_groups=cluster_groups, shortcuts={'undo': 'ctrl+z'}, - quality_func=quality, - similarity_func=similarity, ) mc.attach(gui) + mc.set_quality_func(quality) + mc.set_similarity_func(similarity) + yield mc @@ -58,6 +65,9 @@ def test_manual_clustering_edge_cases(manual_clustering): mc.undo() mc.redo() + # Pin. + mc.pin([]) + # Merge. mc.merge() assert mc.selected == [0] @@ -78,6 +88,46 @@ def test_manual_clustering_edge_cases(manual_clustering): mc.save() +def test_manual_clustering_1(qtbot, gui): + + n_spikes = 10 + n_samples = 4 + n_channels = 7 + n_clusters = 3 + npc = 2 + + sc = artificial_spike_clusters(n_spikes, n_clusters) + spc = _spikes_per_cluster(sc) + + waveforms = artificial_waveforms(n_spikes, n_samples, n_channels) + features = artificial_features(n_spikes, n_channels, npc) + masks = artificial_masks(n_spikes, n_channels) + + mc = ManualClustering(sc) + + q, s = default_wizard_functions(waveforms=waveforms, + features=features, + masks=masks, + n_features_per_channel=npc, + spikes_per_cluster=spc, + ) + mc.set_quality_func(q) + mc.set_similarity_func(s) + + mc.attach(gui) + gui.show() + qtbot.waitForWindowShown(gui) + + mc.cluster_view.next() + assert mc.cluster_view.selected == [1] + + mc.pin() + mc.similarity_view.next() + + assert mc.similarity_view.selected == [2] + assert mc.selected == [1, 2] + + def test_manual_clustering_merge(manual_clustering): mc = manual_clustering @@ -106,12 +156,15 @@ def test_manual_clustering_split(manual_clustering): assert mc.selected == [31, 30] -def test_manual_clustering_split_2(gui): +def test_manual_clustering_split_2(gui, quality, similarity): spike_clusters = np.array([0, 0, 1]) mc = ManualClustering(spike_clusters) mc.attach(gui) + mc.set_quality_func(quality) + mc.set_similarity_func(similarity) + mc.split([0]) assert mc.selected == [2, 3, 1] From 677bc22e4d9f4844f396adf38ee6e2fa964fbf98 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 12 Nov 2015 15:16:36 +0100 Subject: [PATCH 0560/1059] Fix random test --- phy/cluster/manual/tests/test_gui_component.py | 14 +++++++++++--- phy/gui/widgets.py | 2 +- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/phy/cluster/manual/tests/test_gui_component.py b/phy/cluster/manual/tests/test_gui_component.py index a3739083c..c5ebdd605 100644 --- a/phy/cluster/manual/tests/test_gui_component.py +++ b/phy/cluster/manual/tests/test_gui_component.py @@ -6,6 +6,8 @@ # Imports #------------------------------------------------------------------------------ +from operator import itemgetter + from pytest import yield_fixture import numpy as np from numpy.testing import assert_array_equal as ae @@ -114,18 +116,24 @@ def test_manual_clustering_1(qtbot, gui): mc.set_quality_func(q) mc.set_similarity_func(s) + quality = [(c, q(c)) for c in spc] + best = sorted(quality, key=itemgetter(1))[-1][0] + + similarity = [(d, s(best, d)) for d in spc if d != best] + match = sorted(similarity, key=itemgetter(1))[-1][0] + mc.attach(gui) gui.show() qtbot.waitForWindowShown(gui) mc.cluster_view.next() - assert mc.cluster_view.selected == [1] + assert mc.cluster_view.selected == [best] mc.pin() mc.similarity_view.next() - assert mc.similarity_view.selected == [2] - assert mc.selected == [1, 2] + assert mc.similarity_view.selected == [match] + assert mc.selected == [best, match] def test_manual_clustering_merge(manual_clustering): diff --git a/phy/gui/widgets.py b/phy/gui/widgets.py index 84bc291ae..754aea878 100644 --- a/phy/gui/widgets.py +++ b/phy/gui/widgets.py @@ -177,7 +177,7 @@ def _emit_from_js(self, name, arg_json): self.emit(text_type(name), json.loads(text_type(arg_json))) def show(self): - with _wait_signal(self.loadFinished, 100): + with _wait_signal(self.loadFinished, 50): self._build() super(HTMLWidget, self).show() # Call the pending JS eval calls after the page has been built. From 7e13291572e11790ea89b1de7823c0d30fb470a7 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 12 Nov 2015 15:44:38 +0100 Subject: [PATCH 0561/1059] Silent joblib cache --- phy/io/context.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/phy/io/context.py b/phy/io/context.py index 3ad507b04..5c6eb1fa5 100644 --- a/phy/io/context.py +++ b/phy/io/context.py @@ -155,7 +155,7 @@ def _set_memory(self, cache_dir): # Try importing joblib. try: from joblib import Memory - self._memory = Memory(cachedir=self.cache_dir) + self._memory = Memory(cachedir=self.cache_dir, verbose=0) logger.debug("Initialize joblib cache dir at `%s`.", self.cache_dir) except ImportError: # pragma: no cover From fb8abc2f0d804bb13f5b4eb1a065b64820a58c31 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 12 Nov 2015 21:11:24 +0100 Subject: [PATCH 0562/1059] Add dir argument in table.sort_by() --- phy/cluster/manual/gui_component.py | 9 ++------- phy/cluster/manual/tests/test_gui_component.py | 8 +++----- phy/gui/static/table.js | 5 ++++- phy/gui/tests/test_widgets.py | 3 +-- phy/gui/widgets.py | 5 ++--- 5 files changed, 12 insertions(+), 18 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 70221d193..174faa1e6 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -303,8 +303,7 @@ def on_cluster(up): self._update_cluster_view(self.cluster_view) # Reset the previous sort options. if sort[0]: - self.cluster_view.sort_by(sort[0]) - # TODO: second time for desc + self.cluster_view.sort_by(*sort) # Select all new clusters in view 1. if up.history == 'undo': # Select the clusters that were selected before the undone @@ -313,7 +312,6 @@ def on_cluster(up): self.cluster_view.select(clusters_0) self.similarity_view.select(clusters_1) elif up.added: - # TODO: self.select(sel1, sel2) for both views. self.select(up.added) self.pin(up.added) # TODO: only if similarity selection non empty @@ -352,10 +350,7 @@ def pin(self, cluster_ids=None): for clu in self.clustering.cluster_ids if clu not in cluster_ids] self.similarity_view.set_data(items, cols) - - # NOTE: sort twice to get decreasing order. - self.similarity_view.sort_by('similarity') - self.similarity_view.sort_by('similarity') + self.similarity_view.sort_by('similarity', 'desc') @property def selected(self): diff --git a/phy/cluster/manual/tests/test_gui_component.py b/phy/cluster/manual/tests/test_gui_component.py index c5ebdd605..5505d6252 100644 --- a/phy/cluster/manual/tests/test_gui_component.py +++ b/phy/cluster/manual/tests/test_gui_component.py @@ -179,18 +179,16 @@ def test_manual_clustering_split_2(gui, quality, similarity): def test_manual_clustering_move(manual_clustering, quality, similarity): mc = manual_clustering - mc.cluster_view.sort_by('quality') - # TODO: desc - # mc.cluster_view.sort_by('quality') + mc.cluster_view.sort_by('quality', 'desc') mc.select([20]) assert mc.selected == [20] mc.move([20], 'noise') - assert mc.selected == [30] + assert mc.selected == [11] mc.undo() assert mc.selected == [20] mc.redo() - assert mc.selected == [30] + assert mc.selected == [11] diff --git a/phy/gui/static/table.js b/phy/gui/static/table.js index 274ebabe0..2b2e5160f 100644 --- a/phy/gui/static/table.js +++ b/phy/gui/static/table.js @@ -100,8 +100,11 @@ Table.prototype.setData = function(data) { this.tablesort = new Tablesort(this.el); }; -Table.prototype.sortBy = function(header) { +Table.prototype.sortBy = function(header, dir) { + dir = typeof dir !== 'undefined' ? dir : 'asc'; this.tablesort.sortTable(this.headers[header]); + if (dir == 'desc') + this.tablesort.sortTable(this.headers[header]); }; Table.prototype.currentSort = function() { diff --git a/phy/gui/tests/test_widgets.py b/phy/gui/tests/test_widgets.py index a5b765f67..4abe63438 100644 --- a/phy/gui/tests/test_widgets.py +++ b/phy/gui/tests/test_widgets.py @@ -125,8 +125,7 @@ def test_table_sort(qtbot, table): # Sort by count decreasing, and check that 0 (count 100) comes before # 1 (count 90). This checks that sorting works with number (need to # import tablesort.number.js). - table.sort_by('count') - table.sort_by('count') + table.sort_by('count', 'desc') table.previous() assert table.selected == [0] diff --git a/phy/gui/widgets.py b/phy/gui/widgets.py index 754aea878..933dfa84c 100644 --- a/phy/gui/widgets.py +++ b/phy/gui/widgets.py @@ -230,10 +230,9 @@ def set_data(self, items, cols): ) self.eval_js('table.setData({});'.format(data)) - def sort_by(self, header): + def sort_by(self, header, dir='asc'): """Sort by a given variable.""" - # TODO: asc or desc - self.eval_js('table.sortBy("{}");'.format(header)) + self.eval_js('table.sortBy("{}", "{}");'.format(header, dir)) def next(self): """Select the next non-skip row.""" From 4437b24998aca0473daea26de8dbb497851c31ed Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 12 Nov 2015 23:14:50 +0100 Subject: [PATCH 0563/1059] Implement skipping in wizard --- phy/cluster/manual/gui_component.py | 50 +++++++++++-------- .../manual/tests/test_gui_component.py | 26 +++++++--- phy/gui/static/table.js | 10 +++- 3 files changed, 58 insertions(+), 28 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 174faa1e6..374b95311 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -181,8 +181,7 @@ def on_cluster(up): ', '.join(map(str, up.deleted)), up.added[0]) else: - # TODO: how many spikes? - logger.info("Assigned spikes.") + logger.info("Assigned %s spikes.", len(up.spike_ids)) if self.gui: self.gui.emit('cluster', up) @@ -250,14 +249,17 @@ def on_request_undo_state(up): self.clustering.connect(on_request_undo_state) self.cluster_meta.connect(on_request_undo_state) - def _update_cluster_view(self, cluster_view): + def _update_cluster_view(self): assert self.quality_func cols = ['id', 'quality'] - # TODO: skip - items = [{'id': clu, 'quality': self.quality_func(clu)} + items = [{'id': clu, + 'quality': self.quality_func(clu), + 'skip': self.cluster_meta.get('group', clu) in + ('noise', 'mua'), + } for clu in self.clustering.cluster_ids] # TODO: custom measures - cluster_view.set_data(items, cols) + self.cluster_view.set_data(items, cols) def _emit_select(self, cluster_ids): """Choose spikes from the specified clusters and emit the @@ -277,9 +279,8 @@ def _emit_select(self, cluster_ids): def set_quality_func(self, f): self.quality_func = f - self._update_cluster_view(self.cluster_view) - self.cluster_view.sort_by('quality') - self.cluster_view.sort_by('quality') + self._update_cluster_view() + self.cluster_view.sort_by('quality', 'desc') def set_similarity_func(self, f): self.similarity_func = f @@ -299,11 +300,15 @@ def attach(self, gui): def on_cluster(up): # Get the current sort of the cluster view. sort = self.cluster_view.current_sort - # Reinitialize the cluster view. - self._update_cluster_view(self.cluster_view) - # Reset the previous sort options. - if sort[0]: - self.cluster_view.sort_by(*sort) + sel_1 = self.similarity_view.selected + + if up.added: + # Reinitialize the cluster view. + self._update_cluster_view() + # Reset the previous sort options. + if sort[0]: + self.cluster_view.sort_by(*sort) + # Select all new clusters in view 1. if up.history == 'undo': # Select the clusters that were selected before the undone @@ -314,12 +319,17 @@ def on_cluster(up): elif up.added: self.select(up.added) self.pin(up.added) - # TODO: only if similarity selection non empty - self.similarity_view.next() + if sel_1: + self.similarity_view.next() + elif up.metadata_changed: + # Select next in similarity view if all moved are in that view. + if set(up.metadata_changed) <= set(sel_1): + self.similarity_view.next() + # Otherwise, select next in cluster view. + else: + self.cluster_view.next() else: - # TODO: move in the sim view if the moved cluster were there self.cluster_view.next() - return self # Selection actions @@ -338,6 +348,7 @@ def select(self, *cluster_ids): def pin(self, cluster_ids=None): """Update the similarity view with matches for the specified clusters.""" + # TODO: rename into _update_similarity_view() assert self.similarity_func if cluster_ids is None or not len(cluster_ids): cluster_ids = self.cluster_view.selected @@ -361,8 +372,7 @@ def selected(self): def merge(self, cluster_ids=None): if cluster_ids is None: - cluster_ids = (self.cluster_view.selected + - self.similarity_view.selected) + cluster_ids = self.selected if len(cluster_ids or []) <= 1: return self.clustering.merge(cluster_ids) diff --git a/phy/cluster/manual/tests/test_gui_component.py b/phy/cluster/manual/tests/test_gui_component.py index 5505d6252..7fab4bc34 100644 --- a/phy/cluster/manual/tests/test_gui_component.py +++ b/phy/cluster/manual/tests/test_gui_component.py @@ -90,7 +90,7 @@ def test_manual_clustering_edge_cases(manual_clustering): mc.save() -def test_manual_clustering_1(qtbot, gui): +def test_manual_clustering_default_metrics(qtbot, gui): n_spikes = 10 n_samples = 4 @@ -136,10 +136,25 @@ def test_manual_clustering_1(qtbot, gui): assert mc.selected == [best, match] +def test_manual_clustering_skip(qtbot, gui, manual_clustering): + mc = manual_clustering + + # yield [0, 1, 2, 10, 11, 20, 30] + # # i, g, N, i, g, N, N + expected = [30, 20, 11, 2, 1] + + for clu in expected: + mc.cluster_view.next() + assert mc.selected == [clu] + + def test_manual_clustering_merge(manual_clustering): mc = manual_clustering - mc.select(30, 20) # NOTE: we pass multiple ints instead of a list + mc.cluster_view.select([30]) + mc.similarity_view.select([20]) + assert mc.selected == [30, 20] + mc.merge() assert mc.selected == [31, 11] @@ -155,13 +170,13 @@ def test_manual_clustering_split(manual_clustering): mc.select([1, 2]) mc.split([1, 2]) - assert mc.selected == [31, 30] + assert mc.selected == [31] mc.undo() assert mc.selected == [1, 2] mc.redo() - assert mc.selected == [31, 30] + assert mc.selected == [31] def test_manual_clustering_split_2(gui, quality, similarity): @@ -174,12 +189,11 @@ def test_manual_clustering_split_2(gui, quality, similarity): mc.set_similarity_func(similarity) mc.split([0]) - assert mc.selected == [2, 3, 1] + assert mc.selected == [2, 3] def test_manual_clustering_move(manual_clustering, quality, similarity): mc = manual_clustering - mc.cluster_view.sort_by('quality', 'desc') mc.select([20]) assert mc.selected == [20] diff --git a/phy/gui/static/table.js b/phy/gui/static/table.js index 2b2e5160f..2a4065f67 100644 --- a/phy/gui/static/table.js +++ b/phy/gui/static/table.js @@ -24,6 +24,11 @@ var Table = function (el) { Table.prototype.setData = function(data) { if (data.items.length == 0) return; + + // Reinitialize the state. + this.selected = []; + this.rows = {}; + var that = this; var keys = data.cols; @@ -103,8 +108,9 @@ Table.prototype.setData = function(data) { Table.prototype.sortBy = function(header, dir) { dir = typeof dir !== 'undefined' ? dir : 'asc'; this.tablesort.sortTable(this.headers[header]); - if (dir == 'desc') + if (dir == 'desc') { this.tablesort.sortTable(this.headers[header]); + } }; Table.prototype.currentSort = function() { @@ -157,7 +163,7 @@ Table.prototype.next = function() { } for (var i = i0; i < this.el.rows.length; i++) { row = this.el.rows[i]; - if (!(row.dataset.skip)) { + if (row.dataset.skip != 'true') { this.select([row.dataset.id]); return; } From 80a2dbb0cf7bd7ab09b112fc8852914babc01988 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 13 Nov 2015 21:34:07 +0100 Subject: [PATCH 0564/1059] Rename pin() --- phy/cluster/manual/gui_component.py | 8 +++----- phy/cluster/manual/tests/test_gui_component.py | 4 ---- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 374b95311..ab9277238 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -232,7 +232,7 @@ def on_select(cluster_ids): # Emit GUI.select when the selection changes in the cluster view. self._emit_select(cluster_ids) # Pin the clusters and update the similarity view. - self.pin(cluster_ids) + self._update_similarity_view(cluster_ids) # Selection in the similarity view. @similarity_view.connect_ # noqa @@ -318,7 +318,6 @@ def on_cluster(up): self.similarity_view.select(clusters_1) elif up.added: self.select(up.added) - self.pin(up.added) if sel_1: self.similarity_view.next() elif up.metadata_changed: @@ -345,14 +344,13 @@ def select(self, *cluster_ids): # Update the cluster view selection. self.cluster_view.select(cluster_ids) - def pin(self, cluster_ids=None): + def _update_similarity_view(self, cluster_ids=None): """Update the similarity view with matches for the specified clusters.""" - # TODO: rename into _update_similarity_view() assert self.similarity_func if cluster_ids is None or not len(cluster_ids): cluster_ids = self.cluster_view.selected - # TODO: similarity wrt several clusters + # NOTE: we can also implement similarity wrt several clusters sel = cluster_ids[0] cols = ['id', 'similarity'] # TODO: skip diff --git a/phy/cluster/manual/tests/test_gui_component.py b/phy/cluster/manual/tests/test_gui_component.py index 7fab4bc34..6e6b99af2 100644 --- a/phy/cluster/manual/tests/test_gui_component.py +++ b/phy/cluster/manual/tests/test_gui_component.py @@ -67,9 +67,6 @@ def test_manual_clustering_edge_cases(manual_clustering): mc.undo() mc.redo() - # Pin. - mc.pin([]) - # Merge. mc.merge() assert mc.selected == [0] @@ -129,7 +126,6 @@ def test_manual_clustering_default_metrics(qtbot, gui): mc.cluster_view.next() assert mc.cluster_view.selected == [best] - mc.pin() mc.similarity_view.next() assert mc.similarity_view.selected == [match] From 47f5567e2910bf716af114e3e47f1bf5d262e717 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 13 Nov 2015 22:30:17 +0100 Subject: [PATCH 0565/1059] Add custom columns in cluster view --- phy/cluster/manual/gui_component.py | 83 ++++++++++++++++++----------- 1 file changed, 51 insertions(+), 32 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index ab9277238..edb980c04 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -121,8 +121,6 @@ class ManualClustering(object): cluster_groups : dictionary n_spikes_max_per_cluster : int shortcuts : dict - quality_func : function - similarity_func : function GUI events ---------- @@ -200,6 +198,9 @@ def on_cluster(up): # Create the cluster views. self._create_cluster_views() + self._default_sort = None + self._columns = [] + self.add_column('skip', self._do_skip) # Internal methods # ------------------------------------------------------------------------- @@ -249,17 +250,47 @@ def on_request_undo_state(up): self.clustering.connect(on_request_undo_state) self.cluster_meta.connect(on_request_undo_state) + @property + def _column_names(self): + """Name of the columns.""" + return [name for (name, func) in self._columns] + + def _do_skip(self, cluster_id): + """Whether to skip that cluster.""" + return self.cluster_meta.get('group', cluster_id) in ('noise', 'mua') + + def _get_cluster_info(self, cluster_id, extra_columns=None): + """Return the data dictionary for a cluster row.""" + extra_columns = extra_columns or [] + info = {'id': cluster_id} + info.update({name: func(cluster_id) + for (name, func) in (self._columns + extra_columns)}) + return info + def _update_cluster_view(self): - assert self.quality_func - cols = ['id', 'quality'] - items = [{'id': clu, - 'quality': self.quality_func(clu), - 'skip': self.cluster_meta.get('group', clu) in - ('noise', 'mua'), - } - for clu in self.clustering.cluster_ids] - # TODO: custom measures - self.cluster_view.set_data(items, cols) + """Initialize the cluster view with cluster data.""" + items = [self._get_cluster_info(cluster_id) + for cluster_id in self.clustering.cluster_ids] + self.cluster_view.set_data(items, self._column_names) + if self._default_sort: + self.cluster_view.sort_by(self._default_sort, 'desc') + + def _update_similarity_view(self, cluster_ids=None): + """Update the similarity view with matches for the specified + clusters.""" + assert self.similarity_func + if cluster_ids is None or not len(cluster_ids): + cluster_ids = self.cluster_view.selected + cluster_id = cluster_ids[0] + # Similarity wrt the first cluster. + sim = lambda c: self.similarity_func(cluster_id, c) + items = [self._get_cluster_info(clu, [('similarity', sim)]) + for clu in self.clustering.cluster_ids + if clu not in cluster_ids + ] + cols = self._column_names + ['similarity'] + self.similarity_view.set_data(items, cols) + self.similarity_view.sort_by('similarity', 'desc') def _emit_select(self, cluster_ids): """Choose spikes from the specified clusters and emit the @@ -277,14 +308,19 @@ def _emit_select(self, cluster_ids): # ------------------------------------------------------------------------- def set_quality_func(self, f): - self.quality_func = f - + self.add_column('quality', f, True) self._update_cluster_view() - self.cluster_view.sort_by('quality', 'desc') def set_similarity_func(self, f): + """Set the similarity function.""" self.similarity_func = f + def add_column(self, name, func, is_default_sort=False): + """Add a new column in the cluster views.""" + self._columns.append((name, func)) + if is_default_sort: + self._default_sort = name + def attach(self, gui): self.gui = gui @@ -344,23 +380,6 @@ def select(self, *cluster_ids): # Update the cluster view selection. self.cluster_view.select(cluster_ids) - def _update_similarity_view(self, cluster_ids=None): - """Update the similarity view with matches for the specified - clusters.""" - assert self.similarity_func - if cluster_ids is None or not len(cluster_ids): - cluster_ids = self.cluster_view.selected - # NOTE: we can also implement similarity wrt several clusters - sel = cluster_ids[0] - cols = ['id', 'similarity'] - # TODO: skip - items = [{'id': clu, - 'similarity': self.similarity_func(sel, clu)} - for clu in self.clustering.cluster_ids - if clu not in cluster_ids] - self.similarity_view.set_data(items, cols) - self.similarity_view.sort_by('similarity', 'desc') - @property def selected(self): return self.cluster_view.selected + self.similarity_view.selected From fca866838f074b1cc86827c745a453db70e8d2c7 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 13 Nov 2015 22:39:37 +0100 Subject: [PATCH 0566/1059] Fix --- phy/cluster/manual/gui_component.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index edb980c04..d1dc2d35e 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -200,7 +200,9 @@ def on_cluster(up): self._create_cluster_views() self._default_sort = None self._columns = [] - self.add_column('skip', self._do_skip) + # Default columns. + self.add_column('id', lambda clu: clu) + self.add_column('skip', self._do_skip, show=False) # Internal methods # ------------------------------------------------------------------------- @@ -253,7 +255,7 @@ def on_request_undo_state(up): @property def _column_names(self): """Name of the columns.""" - return [name for (name, func) in self._columns] + return [name for (name, func, show) in self._columns if show] def _do_skip(self, cluster_id): """Whether to skip that cluster.""" @@ -262,10 +264,8 @@ def _do_skip(self, cluster_id): def _get_cluster_info(self, cluster_id, extra_columns=None): """Return the data dictionary for a cluster row.""" extra_columns = extra_columns or [] - info = {'id': cluster_id} - info.update({name: func(cluster_id) - for (name, func) in (self._columns + extra_columns)}) - return info + return {name: func(cluster_id) + for (name, func, show) in (self._columns + extra_columns)} def _update_cluster_view(self): """Initialize the cluster view with cluster data.""" @@ -284,7 +284,7 @@ def _update_similarity_view(self, cluster_ids=None): cluster_id = cluster_ids[0] # Similarity wrt the first cluster. sim = lambda c: self.similarity_func(cluster_id, c) - items = [self._get_cluster_info(clu, [('similarity', sim)]) + items = [self._get_cluster_info(clu, [('similarity', sim, True)]) for clu in self.clustering.cluster_ids if clu not in cluster_ids ] @@ -315,9 +315,9 @@ def set_similarity_func(self, f): """Set the similarity function.""" self.similarity_func = f - def add_column(self, name, func, is_default_sort=False): + def add_column(self, name, func, is_default_sort=False, show=True): """Add a new column in the cluster views.""" - self._columns.append((name, func)) + self._columns.append((name, func, show)) if is_default_sort: self._default_sort = name From c0762a32d897abff6c4229677d2f8fdf9a68d562 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 14 Nov 2015 12:50:40 +0100 Subject: [PATCH 0567/1059] WIP --- phy/cluster/manual/gui_component.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index d1dc2d35e..d3243aa54 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -235,7 +235,7 @@ def on_select(cluster_ids): # Emit GUI.select when the selection changes in the cluster view. self._emit_select(cluster_ids) # Pin the clusters and update the similarity view. - self._update_similarity_view(cluster_ids) + self._update_similarity_view() # Selection in the similarity view. @similarity_view.connect_ # noqa @@ -275,18 +275,17 @@ def _update_cluster_view(self): if self._default_sort: self.cluster_view.sort_by(self._default_sort, 'desc') - def _update_similarity_view(self, cluster_ids=None): + def _update_similarity_view(self): """Update the similarity view with matches for the specified clusters.""" assert self.similarity_func - if cluster_ids is None or not len(cluster_ids): - cluster_ids = self.cluster_view.selected - cluster_id = cluster_ids[0] + selection = self.cluster_view.selected + cluster_id = self.cluster_view.selected[0] # Similarity wrt the first cluster. sim = lambda c: self.similarity_func(cluster_id, c) items = [self._get_cluster_info(clu, [('similarity', sim, True)]) for clu in self.clustering.cluster_ids - if clu not in cluster_ids + if clu not in selection ] cols = self._column_names + ['similarity'] self.similarity_view.set_data(items, cols) From fa557c709ac563abe08b4726cd29145497033b55 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 14 Nov 2015 12:56:46 +0100 Subject: [PATCH 0568/1059] Increase coverage --- phy/cluster/manual/gui_component.py | 2 -- .../manual/tests/test_gui_component.py | 20 ++++++++++++++++++- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index d3243aa54..84f05ce09 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -362,8 +362,6 @@ def on_cluster(up): # Otherwise, select next in cluster view. else: self.cluster_view.next() - else: - self.cluster_view.next() return self # Selection actions diff --git a/phy/cluster/manual/tests/test_gui_component.py b/phy/cluster/manual/tests/test_gui_component.py index 6e6b99af2..89251e99a 100644 --- a/phy/cluster/manual/tests/test_gui_component.py +++ b/phy/cluster/manual/tests/test_gui_component.py @@ -188,7 +188,7 @@ def test_manual_clustering_split_2(gui, quality, similarity): assert mc.selected == [2, 3] -def test_manual_clustering_move(manual_clustering, quality, similarity): +def test_manual_clustering_move_1(manual_clustering, quality, similarity): mc = manual_clustering mc.select([20]) @@ -202,3 +202,21 @@ def test_manual_clustering_move(manual_clustering, quality, similarity): mc.redo() assert mc.selected == [11] + + +def test_manual_clustering_move_2(manual_clustering, quality, similarity): + mc = manual_clustering + + mc.select([20]) + mc.similarity_view.select([10]) + + assert mc.selected == [20, 10] + + mc.move([10], 'noise') + assert mc.selected == [20, 2] + + mc.undo() + assert mc.selected == [20, 10] + + mc.redo() + assert mc.selected == [20, 2] From 731506cc8e80293caba0ce94cd54c11a0a9207fe Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 14 Nov 2015 17:18:45 +0100 Subject: [PATCH 0569/1059] WIP: wizard actions --- phy/cluster/manual/gui_component.py | 82 +++++++++++++++++-- .../manual/tests/test_gui_component.py | 45 ++++++++-- 2 files changed, 114 insertions(+), 13 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 84f05ce09..a9a438601 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -7,6 +7,7 @@ # Imports # ----------------------------------------------------------------------------- +from functools import partial import logging import numpy as np @@ -138,14 +139,30 @@ class ManualClustering(object): """ default_shortcuts = { - 'save': 'Save', - # Wizard actions. - 'next': 'space', - 'previous': 'shift+space', - 'reset_wizard': 'ctrl+alt+space', - # Clustering actions. + # Clustering. 'merge': 'g', 'split': 'k', + + # Move. + 'move_best_to_noise': 'alt+n', + 'move_best_to_mua': 'alt+m', + 'move_best_to_good': 'alt+g', + + 'move_similar_to_noise': 'ctrl+n', + 'move_similar_to_mua': 'ctrl+m', + 'move_similar_to_good': 'ctrl+g', + + 'move_all_to_noise': 'ctrl+alt+n', + 'move_all_to_mua': 'ctrl+alt+m', + 'move_all_to_good': 'ctrl+alt+g', + + # Wizard. + 'reset': 'ctrl+alt+space', + 'next': 'space', + 'previous': 'shift+space', + + # Misc. + 'save': 'Save', 'undo': 'Undo', 'redo': 'Redo', } @@ -214,11 +231,29 @@ def _create_actions(self, gui): self.actions.add(self.select, alias='c') # Clustering. - self.actions.add(self.merge) - self.actions.add(self.split) + self.actions.add(self.merge, alias='g') + self.actions.add(self.split, alias='k') + + # Move. self.actions.add(self.move) + + for group in ('noise', 'mua', 'good'): + self.actions.add(partial(self.move_best, group), + name='move_best_to_' + group) + self.actions.add(partial(self.move_similar, group), + name='move_similar_to_' + group) + self.actions.add(partial(self.move_all, group), + name='move_all_to_' + group) + + # Wizard. + self.actions.add(self.reset) + self.actions.add(self.next) + self.actions.add(self.previous) + + # Others. self.actions.add(self.undo) self.actions.add(self.redo) + self.actions.add(self.save) def _create_cluster_views(self): # Create the cluster view. @@ -399,12 +434,43 @@ def split(self, spike_ids): self.clustering.split(spike_ids) self._global_history.action(self.clustering) + # Move actions + # ------------------------------------------------------------------------- + def move(self, cluster_ids, group): if len(cluster_ids) == 0: return self.cluster_meta.set('group', cluster_ids, group) self._global_history.action(self.cluster_meta) + def move_best(self, group): + self.move(self.cluster_view.selected, group) + + def move_similar(self, group): + self.move(self.similarity_view.selected, group) + + def move_all(self, group): + self.move(self.selected, group) + + # Wizard actions + # ------------------------------------------------------------------------- + + def reset(self): + self._update_cluster_view() + self.cluster_view.next() + + def next(self): + if not self.selected: + self.cluster_view.next() + else: + self.similarity_view.next() + + def previous(self): + self.similarity_view.previous() + + # Other actions + # ------------------------------------------------------------------------- + def undo(self): self._global_history.undo() diff --git a/phy/cluster/manual/tests/test_gui_component.py b/phy/cluster/manual/tests/test_gui_component.py index 89251e99a..89fe2221d 100644 --- a/phy/cluster/manual/tests/test_gui_component.py +++ b/phy/cluster/manual/tests/test_gui_component.py @@ -40,6 +40,7 @@ def manual_clustering(qtbot, gui, cluster_ids, cluster_groups, mc.set_similarity_func(similarity) yield mc + del mc @yield_fixture @@ -49,10 +50,12 @@ def gui(qtbot): qtbot.waitForWindowShown(gui) yield gui gui.close() + del gui + qtbot.wait(10) #------------------------------------------------------------------------------ -# Test GUI components +# Test GUI component #------------------------------------------------------------------------------ def test_manual_clustering_edge_cases(manual_clustering): @@ -120,8 +123,6 @@ def test_manual_clustering_default_metrics(qtbot, gui): match = sorted(similarity, key=itemgetter(1))[-1][0] mc.attach(gui) - gui.show() - qtbot.waitForWindowShown(gui) mc.cluster_view.next() assert mc.cluster_view.selected == [best] @@ -188,7 +189,7 @@ def test_manual_clustering_split_2(gui, quality, similarity): assert mc.selected == [2, 3] -def test_manual_clustering_move_1(manual_clustering, quality, similarity): +def test_manual_clustering_move_1(manual_clustering): mc = manual_clustering mc.select([20]) @@ -204,7 +205,7 @@ def test_manual_clustering_move_1(manual_clustering, quality, similarity): assert mc.selected == [11] -def test_manual_clustering_move_2(manual_clustering, quality, similarity): +def test_manual_clustering_move_2(manual_clustering): mc = manual_clustering mc.select([20]) @@ -220,3 +221,37 @@ def test_manual_clustering_move_2(manual_clustering, quality, similarity): mc.redo() assert mc.selected == [20, 2] + + +#------------------------------------------------------------------------------ +# Test shortcuts +#------------------------------------------------------------------------------ + +def test_manual_clustering_action_reset(manual_clustering): + mc = manual_clustering + + mc.actions.select([10, 11]) + + mc.actions.reset() + assert mc.selected == [30] + + +def test_manual_clustering_action_move(manual_clustering): + mc = manual_clustering + + mc.actions.next() + + assert mc.selected == [30] + mc.actions.move_best_to_noise() + + assert mc.selected == [20] + mc.actions.move_best_to_mua() + + assert mc.selected == [11] + mc.actions.move_best_to_good() + + assert mc.selected == [2] + + mc.cluster_meta.get('group', 30) == 'noise' + mc.cluster_meta.get('group', 20) == 'mua' + mc.cluster_meta.get('group', 12) == 'good' From 9ad76a07c2392ba74c87bdc8250e372118c59a31 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 14 Nov 2015 18:07:14 +0100 Subject: [PATCH 0570/1059] Fix bug in table.previous() --- phy/gui/static/table.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/phy/gui/static/table.js b/phy/gui/static/table.js index 2a4065f67..ef20ab7e0 100644 --- a/phy/gui/static/table.js +++ b/phy/gui/static/table.js @@ -185,7 +185,7 @@ Table.prototype.previous = function() { // NOTE: i >= 1 because we skip the header column. for (var i = i0; i >= 1; i--) { row = this.el.rows[i]; - if (!(row.dataset.skip)) { + if (row.dataset.skip != 'true') { this.select([row.dataset.id]); return; } From ae742a8fc6157835659d46b474f3af6b42160244 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 14 Nov 2015 18:09:58 +0100 Subject: [PATCH 0571/1059] WIP: more manual clustering tests --- phy/cluster/manual/gui_component.py | 67 ++++++++++--------- .../manual/tests/test_gui_component.py | 64 +++++++++++++++++- 2 files changed, 97 insertions(+), 34 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index a9a438601..0dbe8d894 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -326,6 +326,41 @@ def _update_similarity_view(self): self.similarity_view.set_data(items, cols) self.similarity_view.sort_by('similarity', 'desc') + def on_cluster(self, up): + """Update the cluster views after clustering actions.""" + + # Get the current sort of the cluster view. + sort = self.cluster_view.current_sort + sel_1 = self.similarity_view.selected + + if up.added: + # Reinitialize the cluster view. + self._update_cluster_view() + # Reset the previous sort options. + if sort[0]: + self.cluster_view.sort_by(*sort) + + # Select all new clusters in view 1. + if up.history == 'undo': + # Select the clusters that were selected before the undone + # action. + clusters_0, clusters_1 = up.undo_state[0]['selection'] + self.cluster_view.select(clusters_0) + self.similarity_view.select(clusters_1) + elif up.added: + self.select(up.added) + if sel_1: + self.similarity_view.next() + elif up.metadata_changed: + # Select next in similarity view if all moved are in that view. + if set(up.metadata_changed) <= set(sel_1): + self.similarity_view.next() + # Otherwise, select next in cluster view. + else: + self.cluster_view.next() + if sel_1: + self.similarity_view.next() + def _emit_select(self, cluster_ids): """Choose spikes from the specified clusters and emit the `select` event on the GUI.""" @@ -366,37 +401,7 @@ def attach(self, gui): gui.add_view(self.similarity_view, title='SimilarityView') # Update the cluster views and selection when a cluster event occurs. - @self.gui.connect_ - def on_cluster(up): - # Get the current sort of the cluster view. - sort = self.cluster_view.current_sort - sel_1 = self.similarity_view.selected - - if up.added: - # Reinitialize the cluster view. - self._update_cluster_view() - # Reset the previous sort options. - if sort[0]: - self.cluster_view.sort_by(*sort) - - # Select all new clusters in view 1. - if up.history == 'undo': - # Select the clusters that were selected before the undone - # action. - clusters_0, clusters_1 = up.undo_state[0]['selection'] - self.cluster_view.select(clusters_0) - self.similarity_view.select(clusters_1) - elif up.added: - self.select(up.added) - if sel_1: - self.similarity_view.next() - elif up.metadata_changed: - # Select next in similarity view if all moved are in that view. - if set(up.metadata_changed) <= set(sel_1): - self.similarity_view.next() - # Otherwise, select next in cluster view. - else: - self.cluster_view.next() + self.gui.connect_(self.on_cluster) return self # Selection actions diff --git a/phy/cluster/manual/tests/test_gui_component.py b/phy/cluster/manual/tests/test_gui_component.py index 89fe2221d..4e106a24a 100644 --- a/phy/cluster/manual/tests/test_gui_component.py +++ b/phy/cluster/manual/tests/test_gui_component.py @@ -227,7 +227,7 @@ def test_manual_clustering_move_2(manual_clustering): # Test shortcuts #------------------------------------------------------------------------------ -def test_manual_clustering_action_reset(manual_clustering): +def test_manual_clustering_action_reset(qtbot, manual_clustering): mc = manual_clustering mc.actions.select([10, 11]) @@ -235,8 +235,17 @@ def test_manual_clustering_action_reset(manual_clustering): mc.actions.reset() assert mc.selected == [30] + mc.actions.next() + assert mc.selected == [30, 20] + + mc.actions.next() + assert mc.selected == [30, 11] -def test_manual_clustering_action_move(manual_clustering): + mc.actions.previous() + assert mc.selected == [30, 20] + + +def test_manual_clustering_action_move_1(manual_clustering): mc = manual_clustering mc.actions.next() @@ -254,4 +263,53 @@ def test_manual_clustering_action_move(manual_clustering): mc.cluster_meta.get('group', 30) == 'noise' mc.cluster_meta.get('group', 20) == 'mua' - mc.cluster_meta.get('group', 12) == 'good' + mc.cluster_meta.get('group', 11) == 'good' + + +def test_manual_clustering_action_move_2(manual_clustering): + mc = manual_clustering + + mc.select([30]) + mc.similarity_view.select([20]) + + assert mc.selected == [30, 20] + mc.actions.move_similar_to_noise() + + assert mc.selected == [30, 11] + mc.actions.move_similar_to_mua() + + assert mc.selected == [30, 2] + mc.actions.move_similar_to_good() + + assert mc.selected == [30, 1] + + mc.cluster_meta.get('group', 20) == 'noise' + mc.cluster_meta.get('group', 11) == 'mua' + mc.cluster_meta.get('group', 2) == 'good' + + +def _test_manual_clustering_action_move_3(manual_clustering): + mc = manual_clustering + + mc.select([30]) + mc.similarity_view.select([20]) + + assert mc.selected == [30, 20] + mc.actions.move_all_to_noise() + + assert mc.selected == [11, 10] + mc.actions.move_all_to_mua() + + assert mc.selected == [2, 1] + mc.actions.move_all_to_good() + + assert mc.selected == [30, 1] + + mc.cluster_meta.get('group', 30) == 'noise' + mc.cluster_meta.get('group', 20) == 'noise' + + mc.cluster_meta.get('group', 11) == 'mua' + mc.cluster_meta.get('group', 10) == 'mua' + + mc.cluster_meta.get('group', 2) == 'good' + mc.cluster_meta.get('group', 1) == 'good' From 794638dfba3e032ba1b3467ec45b334e28b3d094 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 14 Nov 2015 18:37:42 +0100 Subject: [PATCH 0572/1059] Manual clustering tests pass --- phy/cluster/manual/gui_component.py | 34 ++++++++++++------- .../manual/tests/test_gui_component.py | 8 ++--- 2 files changed, 26 insertions(+), 16 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 0dbe8d894..c99799750 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -304,11 +304,19 @@ def _get_cluster_info(self, cluster_id, extra_columns=None): def _update_cluster_view(self): """Initialize the cluster view with cluster data.""" + + # Get the current sort of the cluster view. + sort_col, sort_dir = self.cluster_view.current_sort + + # Update the cluster view rows. items = [self._get_cluster_info(cluster_id) for cluster_id in self.clustering.cluster_ids] self.cluster_view.set_data(items, self._column_names) - if self._default_sort: - self.cluster_view.sort_by(self._default_sort, 'desc') + + # Sort with the previous sort or the default one. + sort_col = sort_col or self._default_sort + sort_dir = sort_dir or 'desc' + self.cluster_view.sort_by(sort_col, sort_dir) def _update_similarity_view(self): """Update the similarity view with matches for the specified @@ -329,16 +337,11 @@ def _update_similarity_view(self): def on_cluster(self, up): """Update the cluster views after clustering actions.""" - # Get the current sort of the cluster view. - sort = self.cluster_view.current_sort - sel_1 = self.similarity_view.selected + similar = self.similarity_view.selected + # Reinitialize the cluster view if clusters have changed. if up.added: - # Reinitialize the cluster view. self._update_cluster_view() - # Reset the previous sort options. - if sort[0]: - self.cluster_view.sort_by(*sort) # Select all new clusters in view 1. if up.history == 'undo': @@ -349,16 +352,23 @@ def on_cluster(self, up): self.similarity_view.select(clusters_1) elif up.added: self.select(up.added) - if sel_1: + if similar: self.similarity_view.next() elif up.metadata_changed: # Select next in similarity view if all moved are in that view. - if set(up.metadata_changed) <= set(sel_1): + if set(up.metadata_changed) <= set(similar): self.similarity_view.next() # Otherwise, select next in cluster view. else: + # Update the cluster view, and select the clusters that + # were selected before the action. + selected = self.cluster_view.selected + self._update_cluster_view() + self.select(selected) + + # Select the next cluster in the view. self.cluster_view.next() - if sel_1: + if similar: self.similarity_view.next() def _emit_select(self, cluster_ids): diff --git a/phy/cluster/manual/tests/test_gui_component.py b/phy/cluster/manual/tests/test_gui_component.py index 4e106a24a..4c0c777ee 100644 --- a/phy/cluster/manual/tests/test_gui_component.py +++ b/phy/cluster/manual/tests/test_gui_component.py @@ -288,7 +288,7 @@ def test_manual_clustering_action_move_2(manual_clustering): mc.cluster_meta.get('group', 2) == 'good' -def _test_manual_clustering_action_move_3(manual_clustering): +def test_manual_clustering_action_move_3(manual_clustering): mc = manual_clustering mc.select([30]) @@ -297,13 +297,13 @@ def _test_manual_clustering_action_move_3(manual_clustering): assert mc.selected == [30, 20] mc.actions.move_all_to_noise() - assert mc.selected == [11, 10] + assert mc.selected == [11, 2] mc.actions.move_all_to_mua() - assert mc.selected == [2, 1] + assert mc.selected == [1] mc.actions.move_all_to_good() - assert mc.selected == [30, 1] + assert mc.selected == [1] mc.cluster_meta.get('group', 30) == 'noise' mc.cluster_meta.get('group', 20) == 'noise' From 8bfbd1ce3510c4a147f05c8ef7e00ba40eea8fd5 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 14 Nov 2015 19:50:57 +0100 Subject: [PATCH 0573/1059] Use float textures for box bounds in Boxed interact --- phy/plot/interact.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/phy/plot/interact.py b/phy/plot/interact.py index e23bc1e67..7171d7d71 100644 --- a/phy/plot/interact.py +++ b/phy/plot/interact.py @@ -171,7 +171,8 @@ def get_transforms(self): def update_program(self, program): # Signal bounds (positions). box_bounds = _get_texture(self._box_bounds, NDC, self.n_boxes, [-1, 1]) - program['u_box_bounds'] = Texture2D(box_bounds) + program['u_box_bounds'] = Texture2D(box_bounds, + internalformat='rgba32f') program['n_boxes'] = self.n_boxes # Change the box bounds, positions, or size From 4d330ddba5b459c5affe0fffae361bd3b25d247c Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 14 Nov 2015 19:53:19 +0100 Subject: [PATCH 0574/1059] Update similarity view after move --- phy/cluster/manual/gui_component.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index c99799750..67175d8fb 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -357,6 +357,13 @@ def on_cluster(self, up): elif up.metadata_changed: # Select next in similarity view if all moved are in that view. if set(up.metadata_changed) <= set(similar): + + # Update the cluster view, and select the clusters that + # were selected before the action. + selected = self.similarity_view.selected + self._update_similarity_view() + self.similarity_view.select(selected) + self.similarity_view.next() # Otherwise, select next in cluster view. else: From af27b6f76c7993fceae1a1c2b061e593ce3202da Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 15 Nov 2015 11:42:22 +0100 Subject: [PATCH 0575/1059] Show good clusters in green --- phy/cluster/manual/gui_component.py | 33 ++++++++++++++----- .../manual/tests/test_gui_component.py | 4 ++- 2 files changed, 27 insertions(+), 10 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 67175d8fb..2ef99350a 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -107,6 +107,16 @@ def mean_masked_features_similarity(c0, c1): # Clustering GUI component # ----------------------------------------------------------------------------- +class ClusterView(Table): + def __init__(self): + super(ClusterView, self).__init__() + self.add_styles(''' + table tr[data-good='true'] { + color: #B4DEA6; + } + ''') + + class ManualClustering(object): """Component that brings manual clustering facilities to a GUI: @@ -219,7 +229,8 @@ def on_cluster(up): self._columns = [] # Default columns. self.add_column('id', lambda clu: clu) - self.add_column('skip', self._do_skip, show=False) + self.add_column('skip', self._skip_col, show=False) + self.add_column('good', self._good_col, show=False) # Internal methods # ------------------------------------------------------------------------- @@ -257,15 +268,15 @@ def _create_actions(self, gui): def _create_cluster_views(self): # Create the cluster view. - self.cluster_view = cluster_view = Table() - cluster_view.show() + self.cluster_view = ClusterView() + self.cluster_view.show() # Create the similarity view. - self.similarity_view = similarity_view = Table() - similarity_view.show() + self.similarity_view = ClusterView() + self.similarity_view.show() # Selection in the cluster view. - @cluster_view.connect_ + @self.cluster_view.connect_ def on_select(cluster_ids): # Emit GUI.select when the selection changes in the cluster view. self._emit_select(cluster_ids) @@ -273,10 +284,10 @@ def on_select(cluster_ids): self._update_similarity_view() # Selection in the similarity view. - @similarity_view.connect_ # noqa + @self.similarity_view.connect_ # noqa def on_select(cluster_ids): # Select the clusters from both views. - cluster_ids = cluster_view.selected + cluster_ids + cluster_ids = self.cluster_view.selected + cluster_ids self._emit_select(cluster_ids) # Save the current selection when an action occurs. @@ -292,10 +303,14 @@ def _column_names(self): """Name of the columns.""" return [name for (name, func, show) in self._columns if show] - def _do_skip(self, cluster_id): + def _skip_col(self, cluster_id): """Whether to skip that cluster.""" return self.cluster_meta.get('group', cluster_id) in ('noise', 'mua') + def _good_col(self, cluster_id): + """Good column for color.""" + return self.cluster_meta.get('group', cluster_id) == 'good' + def _get_cluster_info(self, cluster_id, extra_columns=None): """Return the data dictionary for a cluster row.""" extra_columns = extra_columns or [] diff --git a/phy/cluster/manual/tests/test_gui_component.py b/phy/cluster/manual/tests/test_gui_component.py index 4c0c777ee..5d275ecef 100644 --- a/phy/cluster/manual/tests/test_gui_component.py +++ b/phy/cluster/manual/tests/test_gui_component.py @@ -245,7 +245,7 @@ def test_manual_clustering_action_reset(qtbot, manual_clustering): assert mc.selected == [30, 20] -def test_manual_clustering_action_move_1(manual_clustering): +def test_manual_clustering_action_move_1(qtbot, manual_clustering): mc = manual_clustering mc.actions.next() @@ -265,6 +265,8 @@ def test_manual_clustering_action_move_1(manual_clustering): mc.cluster_meta.get('group', 20) == 'mua' mc.cluster_meta.get('group', 11) == 'good' + # qtbot.stop() + def test_manual_clustering_action_move_2(manual_clustering): mc = manual_clustering From b2d6411929ea219aad515bc64299b0068c5c0995 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 15 Nov 2015 14:31:57 +0100 Subject: [PATCH 0576/1059] Scroll to selected element in table --- phy/gui/static/table.js | 2 ++ 1 file changed, 2 insertions(+) diff --git a/phy/gui/static/table.js b/phy/gui/static/table.js index ef20ab7e0..e9d7cef09 100644 --- a/phy/gui/static/table.js +++ b/phy/gui/static/table.js @@ -165,6 +165,7 @@ Table.prototype.next = function() { row = this.el.rows[i]; if (row.dataset.skip != 'true') { this.select([row.dataset.id]); + row.scrollIntoView(false); return; } } @@ -187,6 +188,7 @@ Table.prototype.previous = function() { row = this.el.rows[i]; if (row.dataset.skip != 'true') { this.select([row.dataset.id]); + // row.scrollIntoView(false); return; } } From 282e371a422557446e5e2417e9291e5c48eb5564 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 15 Nov 2015 14:42:42 +0100 Subject: [PATCH 0577/1059] Add parameter to disable emit in table.select() --- phy/cluster/manual/gui_component.py | 7 ++----- phy/gui/static/table.js | 7 +++++-- phy/gui/widgets.py | 5 +++-- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 2ef99350a..7bec1ccbe 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -377,8 +377,7 @@ def on_cluster(self, up): # were selected before the action. selected = self.similarity_view.selected self._update_similarity_view() - self.similarity_view.select(selected) - + self.similarity_view.select(selected, do_emit=False) self.similarity_view.next() # Otherwise, select next in cluster view. else: @@ -386,9 +385,7 @@ def on_cluster(self, up): # were selected before the action. selected = self.cluster_view.selected self._update_cluster_view() - self.select(selected) - - # Select the next cluster in the view. + self.cluster_view.select(selected, do_emit=False) self.cluster_view.next() if similar: self.similarity_view.next() diff --git a/phy/gui/static/table.js b/phy/gui/static/table.js index e9d7cef09..668aa5ddb 100644 --- a/phy/gui/static/table.js +++ b/phy/gui/static/table.js @@ -125,7 +125,9 @@ Table.prototype.currentSort = function() { return [null, null]; }; -Table.prototype.select = function(ids) { +Table.prototype.select = function(ids, do_emit) { + do_emit = typeof do_emit !== 'undefined' ? do_emit : true; + ids = uniq(ids); // Remove the class on all rows. @@ -143,7 +145,8 @@ Table.prototype.select = function(ids) { this.selected = ids; - emit("select", ids); + if (do_emit) + emit("select", ids); }; Table.prototype.clear = function() { diff --git a/phy/gui/widgets.py b/phy/gui/widgets.py index 933dfa84c..3f569873a 100644 --- a/phy/gui/widgets.py +++ b/phy/gui/widgets.py @@ -242,9 +242,10 @@ def previous(self): """Select the previous non-skip row.""" self.eval_js('table.previous();') - def select(self, ids): + def select(self, ids, do_emit=True): """Select some rows.""" - self.eval_js('table.select({});'.format(dumps(ids))) + do_emit = str(do_emit).lower() + self.eval_js('table.select({}, {});'.format(dumps(ids), do_emit)) @property def selected(self): From e566ab37e905c4b07d92e7134ac26db02fd68f5a Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 15 Nov 2015 19:53:31 +0100 Subject: [PATCH 0578/1059] WIP: refactor table --- phy/gui/static/table.js | 4 ++++ phy/gui/tests/test_widgets.py | 13 ++++++++----- phy/gui/widgets.py | 25 ++++++++++++++++++++++--- 3 files changed, 34 insertions(+), 8 deletions(-) diff --git a/phy/gui/static/table.js b/phy/gui/static/table.js index 668aa5ddb..d62430eec 100644 --- a/phy/gui/static/table.js +++ b/phy/gui/static/table.js @@ -23,6 +23,10 @@ var Table = function (el) { }; Table.prototype.setData = function(data) { + /* + data.cols: list of column names + data.items: list of rows (each row is an object {col: value}) + */ if (data.items.length == 0) return; // Reinitialize the state. diff --git a/phy/gui/tests/test_widgets.py b/phy/gui/tests/test_widgets.py index 4abe63438..63576df4e 100644 --- a/phy/gui/tests/test_widgets.py +++ b/phy/gui/tests/test_widgets.py @@ -21,12 +21,15 @@ def table(qtbot): table.show() qtbot.waitForWindowShown(table) - items = [{'id': i, 'count': 10000.5 - 10 * i} for i in range(10)] - items[4]['skip'] = True + @table.add_column + def count(id): + return 10000.5 - 10 * id - table.set_data(cols=['id', 'count'], - items=items, - ) + @table.add_column + def skip(id): + return id == 4 + + table.set_rows(range(10)) yield table diff --git a/phy/gui/widgets.py b/phy/gui/widgets.py index 3f569873a..bed59c6ec 100644 --- a/phy/gui/widgets.py +++ b/phy/gui/widgets.py @@ -222,11 +222,30 @@ def __init__(self): self.add_body(''''''.format(self._table_id)) + self._columns = [('id', (lambda _: _), {})] - def set_data(self, items, cols): - """Set the rows and cols of the table.""" + def add_column(self, func=None, options=None): + """Add a column function, takes an id as argument, return a value.""" + if func is None: + return lambda f: self.add_column_func(f, options=options) + + options = options or {} + self._columns.append((func.__name__, func, options)) + + @property + def column_names(self): + return [name for (name, func, options) in self._columns + if options.get('show', True)] + + def _get_row(self, id): + """Create a row dictionary for a given object id.""" + return {name: func(id) for (name, func, options) in self._columns} + + def set_rows(self, ids): + """Set the rows of the table.""" + items = [self._get_row(id) for id in ids] data = _create_json_dict(items=items, - cols=cols, + cols=self.column_names, ) self.eval_js('table.setData({});'.format(data)) From fde8242509d29cabaf304bb07d29323e42caaed0 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 15 Nov 2015 19:55:21 +0100 Subject: [PATCH 0579/1059] Increase coverage --- phy/gui/tests/test_widgets.py | 2 +- phy/gui/widgets.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/phy/gui/tests/test_widgets.py b/phy/gui/tests/test_widgets.py index 63576df4e..7b2ea554b 100644 --- a/phy/gui/tests/test_widgets.py +++ b/phy/gui/tests/test_widgets.py @@ -21,7 +21,7 @@ def table(qtbot): table.show() qtbot.waitForWindowShown(table) - @table.add_column + @table.add_column(options={'show': True}) def count(id): return 10000.5 - 10 * id diff --git a/phy/gui/widgets.py b/phy/gui/widgets.py index bed59c6ec..ef8a7b9f5 100644 --- a/phy/gui/widgets.py +++ b/phy/gui/widgets.py @@ -227,7 +227,7 @@ def __init__(self): def add_column(self, func=None, options=None): """Add a column function, takes an id as argument, return a value.""" if func is None: - return lambda f: self.add_column_func(f, options=options) + return lambda f: self.add_column(f, options=options) options = options or {} self._columns.append((func.__name__, func, options)) From 84806c133211f1d67b1cc799f602ed6fe929e34e Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 15 Nov 2015 22:56:02 +0100 Subject: [PATCH 0580/1059] WIP: refactor cluster views --- phy/cluster/manual/gui_component.py | 76 +++++++++---------- .../manual/tests/test_gui_component.py | 2 +- phy/gui/static/table.js | 2 + phy/gui/tests/test_widgets.py | 4 +- phy/gui/widgets.py | 20 +++-- 5 files changed, 52 insertions(+), 52 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 7bec1ccbe..3eb28b85f 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -7,7 +7,7 @@ # Imports # ----------------------------------------------------------------------------- -from functools import partial +from functools import partial, wraps import logging import numpy as np @@ -224,13 +224,26 @@ def on_cluster(up): self.gui.emit('cluster', up) # Create the cluster views. - self._create_cluster_views() self._default_sort = None - self._columns = [] + self._create_cluster_views() + # Default columns. - self.add_column('id', lambda clu: clu) - self.add_column('skip', self._skip_col, show=False) - self.add_column('good', self._good_col, show=False) + def skip(cluster_id): + """Whether to skip that cluster.""" + return (self.cluster_meta.get('group', cluster_id) + in ('noise', 'mua')) + self.add_column(skip, options={'show': False}) + + def good(cluster_id): + """Good column for color.""" + return self.cluster_meta.get('group', cluster_id) == 'good' + self.add_column(good, options={'show': False}) + + self._best = None + + def similarity(cluster_id): + return self.similarity_func(cluster_id, self._best) + self.similarity_view.add_column(similarity) # Internal methods # ------------------------------------------------------------------------- @@ -298,24 +311,14 @@ def on_request_undo_state(up): self.clustering.connect(on_request_undo_state) self.cluster_meta.connect(on_request_undo_state) - @property - def _column_names(self): - """Name of the columns.""" - return [name for (name, func, show) in self._columns if show] - - def _skip_col(self, cluster_id): - """Whether to skip that cluster.""" - return self.cluster_meta.get('group', cluster_id) in ('noise', 'mua') - - def _good_col(self, cluster_id): - """Good column for color.""" - return self.cluster_meta.get('group', cluster_id) == 'good' - - def _get_cluster_info(self, cluster_id, extra_columns=None): - """Return the data dictionary for a cluster row.""" - extra_columns = extra_columns or [] - return {name: func(cluster_id) - for (name, func, show) in (self._columns + extra_columns)} + def add_column(self, func=None, name=None, options=None): + options = options or {} + name = name or func.__name__ + assert name + if options.get('is_default_sort', False): + self._default_sort = name + self.cluster_view.add_column(func=func, name=name, options=options) + self.similarity_view.add_column(func=func, name=name, options=options) def _update_cluster_view(self): """Initialize the cluster view with cluster data.""" @@ -324,12 +327,11 @@ def _update_cluster_view(self): sort_col, sort_dir = self.cluster_view.current_sort # Update the cluster view rows. - items = [self._get_cluster_info(cluster_id) - for cluster_id in self.clustering.cluster_ids] - self.cluster_view.set_data(items, self._column_names) + self.cluster_view.set_rows(self.clustering.cluster_ids) # Sort with the previous sort or the default one. sort_col = sort_col or self._default_sort + assert sort_col sort_dir = sort_dir or 'desc' self.cluster_view.sort_by(sort_col, sort_dir) @@ -339,14 +341,9 @@ def _update_similarity_view(self): assert self.similarity_func selection = self.cluster_view.selected cluster_id = self.cluster_view.selected[0] - # Similarity wrt the first cluster. - sim = lambda c: self.similarity_func(cluster_id, c) - items = [self._get_cluster_info(clu, [('similarity', sim, True)]) - for clu in self.clustering.cluster_ids - if clu not in selection - ] - cols = self._column_names + ['similarity'] - self.similarity_view.set_data(items, cols) + self._best = cluster_id + self.similarity_view.set_rows([c for c in self.clustering.cluster_ids + if c not in selection]) self.similarity_view.sort_by('similarity', 'desc') def on_cluster(self, up): @@ -406,19 +403,14 @@ def _emit_select(self, cluster_ids): # ------------------------------------------------------------------------- def set_quality_func(self, f): - self.add_column('quality', f, True) + self.add_column(func=f, name='quality', + options={'show': True, 'is_default_sort': True}) self._update_cluster_view() def set_similarity_func(self, f): """Set the similarity function.""" self.similarity_func = f - def add_column(self, name, func, is_default_sort=False, show=True): - """Add a new column in the cluster views.""" - self._columns.append((name, func, show)) - if is_default_sort: - self._default_sort = name - def attach(self, gui): self.gui = gui diff --git a/phy/cluster/manual/tests/test_gui_component.py b/phy/cluster/manual/tests/test_gui_component.py index 5d275ecef..d8c2a413c 100644 --- a/phy/cluster/manual/tests/test_gui_component.py +++ b/phy/cluster/manual/tests/test_gui_component.py @@ -51,7 +51,7 @@ def gui(qtbot): yield gui gui.close() del gui - qtbot.wait(10) + qtbot.wait(5) #------------------------------------------------------------------------------ diff --git a/phy/gui/static/table.js b/phy/gui/static/table.js index d62430eec..6f0b9e6c3 100644 --- a/phy/gui/static/table.js +++ b/phy/gui/static/table.js @@ -111,6 +111,8 @@ Table.prototype.setData = function(data) { Table.prototype.sortBy = function(header, dir) { dir = typeof dir !== 'undefined' ? dir : 'asc'; + if (this.headers[header] == undefined) + throw "The column `" + header + "` doesn't exist." this.tablesort.sortTable(this.headers[header]); if (dir == 'desc') { this.tablesort.sortTable(this.headers[header]); diff --git a/phy/gui/tests/test_widgets.py b/phy/gui/tests/test_widgets.py index 7b2ea554b..5293092e6 100644 --- a/phy/gui/tests/test_widgets.py +++ b/phy/gui/tests/test_widgets.py @@ -21,13 +21,13 @@ def table(qtbot): table.show() qtbot.waitForWindowShown(table) - @table.add_column(options={'show': True}) def count(id): return 10000.5 - 10 * id + table.add_column(count, options={'show': True}) - @table.add_column def skip(id): return id == 4 + table.add_column(skip) table.set_rows(range(10)) diff --git a/phy/gui/widgets.py b/phy/gui/widgets.py index ef8a7b9f5..10a71934f 100644 --- a/phy/gui/widgets.py +++ b/phy/gui/widgets.py @@ -177,7 +177,7 @@ def _emit_from_js(self, name, arg_json): self.emit(text_type(name), json.loads(text_type(arg_json))) def show(self): - with _wait_signal(self.loadFinished, 50): + with _wait_signal(self.loadFinished, 20): self._build() super(HTMLWidget, self).show() # Call the pending JS eval calls after the page has been built. @@ -224,13 +224,19 @@ def __init__(self): '''.format(self._table_id)) self._columns = [('id', (lambda _: _), {})] - def add_column(self, func=None, options=None): - """Add a column function, takes an id as argument, return a value.""" - if func is None: - return lambda f: self.add_column(f, options=options) - + def add_column(self, func, name=None, options=None): + """Add a column function which takes an id as argument and + returns a value.""" + assert func + name = name or func.__name__ options = options or {} - self._columns.append((func.__name__, func, options)) + self._columns.append([name, func, options]) + return func + + def get_column(self, name): + for col in self._columns: + if col[0] == name: + return col @property def column_names(self): From 99b7087202748d9af22e7459c2ab207e179b1c65 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 15 Nov 2015 22:57:57 +0100 Subject: [PATCH 0581/1059] Fix --- phy/cluster/manual/gui_component.py | 2 +- phy/gui/widgets.py | 7 +------ 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 3eb28b85f..d2c000fc4 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -7,7 +7,7 @@ # Imports # ----------------------------------------------------------------------------- -from functools import partial, wraps +from functools import partial import logging import numpy as np diff --git a/phy/gui/widgets.py b/phy/gui/widgets.py index 10a71934f..dcd4d373d 100644 --- a/phy/gui/widgets.py +++ b/phy/gui/widgets.py @@ -230,14 +230,9 @@ def add_column(self, func, name=None, options=None): assert func name = name or func.__name__ options = options or {} - self._columns.append([name, func, options]) + self._columns.append((name, func, options)) return func - def get_column(self, name): - for col in self._columns: - if col[0] == name: - return col - @property def column_names(self): return [name for (name, func, options) in self._columns From 9fc8d50a00239931217e7872e4960ae4d2abf54d Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 18 Nov 2015 12:32:34 +0100 Subject: [PATCH 0582/1059] WIP: refactor table --- phy/gui/tests/test_widgets.py | 2 +- phy/gui/widgets.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/phy/gui/tests/test_widgets.py b/phy/gui/tests/test_widgets.py index 5293092e6..30cf00283 100644 --- a/phy/gui/tests/test_widgets.py +++ b/phy/gui/tests/test_widgets.py @@ -23,7 +23,7 @@ def table(qtbot): def count(id): return 10000.5 - 10 * id - table.add_column(count, options={'show': True}) + table.add_column(count, show=True) def skip(id): return id == 4 diff --git a/phy/gui/widgets.py b/phy/gui/widgets.py index dcd4d373d..cc83842ec 100644 --- a/phy/gui/widgets.py +++ b/phy/gui/widgets.py @@ -224,12 +224,14 @@ def __init__(self): '''.format(self._table_id)) self._columns = [('id', (lambda _: _), {})] - def add_column(self, func, name=None, options=None): + def add_column(self, func, name=None, show=True, default_sort=False): """Add a column function which takes an id as argument and returns a value.""" assert func name = name or func.__name__ - options = options or {} + options = {'show': show, + 'default_sort': default_sort, + } self._columns.append((name, func, options)) return func From a2a5aee25c32fa04d93fec11512cee159f3827e7 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 18 Nov 2015 12:40:27 +0100 Subject: [PATCH 0583/1059] WIP: default sort --- phy/gui/tests/test_widgets.py | 17 +++++++++++++++++ phy/gui/widgets.py | 8 ++++++++ 2 files changed, 25 insertions(+) diff --git a/phy/gui/tests/test_widgets.py b/phy/gui/tests/test_widgets.py index 30cf00283..1621e217e 100644 --- a/phy/gui/tests/test_widgets.py +++ b/phy/gui/tests/test_widgets.py @@ -88,7 +88,24 @@ def on_test(arg): # Test table #------------------------------------------------------------------------------ +def test_table_default_sort(qtbot): + table = Table() + table.show() + qtbot.waitForWindowShown(table) + + def count(id): + return 10000.5 - 10 * id + table.add_column(count, default_sort=True) + table.set_rows(range(10)) + + assert table.default_sort == 'count' + + table.close() + + def test_table_duplicates(qtbot, table): + assert table.default_sort is None + table.select([1, 1]) assert table.selected == [1] # qtbot.stop() diff --git a/phy/gui/widgets.py b/phy/gui/widgets.py index cc83842ec..85d58ee4a 100644 --- a/phy/gui/widgets.py +++ b/phy/gui/widgets.py @@ -269,6 +269,13 @@ def select(self, ids, do_emit=True): do_emit = str(do_emit).lower() self.eval_js('table.select({}, {});'.format(dumps(ids), do_emit)) + @property + def default_sort(self): + """Name of the first column that acts as default sort.""" + for (name, func, options) in self._columns: + if options.get('default_sort', False): + return name + @property def selected(self): """Currently selected rows.""" @@ -276,4 +283,5 @@ def selected(self): @property def current_sort(self): + """Current sort: a tuple `(name, dir)`.""" return tuple(self.eval_js('table.currentSort()')) From 3f3a324f79e6f5ce4c388e804c5247232e60cc67 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 18 Nov 2015 12:52:09 +0100 Subject: [PATCH 0584/1059] WIP: default sort --- phy/gui/tests/test_widgets.py | 13 ++++++++++--- phy/gui/widgets.py | 23 +++++++++++++++++++---- 2 files changed, 29 insertions(+), 7 deletions(-) diff --git a/phy/gui/tests/test_widgets.py b/phy/gui/tests/test_widgets.py index 1621e217e..09beff0a5 100644 --- a/phy/gui/tests/test_widgets.py +++ b/phy/gui/tests/test_widgets.py @@ -95,16 +95,23 @@ def test_table_default_sort(qtbot): def count(id): return 10000.5 - 10 * id - table.add_column(count, default_sort=True) + table.add_column(count, default_sort='asc') table.set_rows(range(10)) - assert table.default_sort == 'count' + assert table.default_sort == ('count', 'asc') + table.next() + assert table.selected == [9] + + table.sort_by('id', 'desc') + table.set_rows(range(11)) + table.next() + assert table.selected == [10] table.close() def test_table_duplicates(qtbot, table): - assert table.default_sort is None + assert table.default_sort == (None, None) table.select([1, 1]) assert table.selected == [1] diff --git a/phy/gui/widgets.py b/phy/gui/widgets.py index 85d58ee4a..7ffed326b 100644 --- a/phy/gui/widgets.py +++ b/phy/gui/widgets.py @@ -224,7 +224,7 @@ def __init__(self): '''.format(self._table_id)) self._columns = [('id', (lambda _: _), {})] - def add_column(self, func, name=None, show=True, default_sort=False): + def add_column(self, func, name=None, show=True, default_sort=None): """Add a column function which takes an id as argument and returns a value.""" assert func @@ -246,14 +246,28 @@ def _get_row(self, id): def set_rows(self, ids): """Set the rows of the table.""" + # Determine the sort column and dir to set after the rows. + sort_col, sort_dir = self.current_sort + default_sort_col, default_sort_dir = self.default_sort + + sort_col = sort_col or default_sort_col + sort_dir = sort_dir or default_sort_dir or 'desc' + + # Set the rows. + logger.debug("Set %d rows.", len(ids)) items = [self._get_row(id) for id in ids] data = _create_json_dict(items=items, cols=self.column_names, ) self.eval_js('table.setData({});'.format(data)) + # Sort. + if sort_col: + self.sort_by(sort_col, sort_dir) + def sort_by(self, header, dir='asc'): """Sort by a given variable.""" + logger.debug("Sort by `%s` %s.", header, dir) self.eval_js('table.sortBy("{}", "{}");'.format(header, dir)) def next(self): @@ -271,10 +285,11 @@ def select(self, ids, do_emit=True): @property def default_sort(self): - """Name of the first column that acts as default sort.""" + """Default sort as a pair `(name, dir)`.""" for (name, func, options) in self._columns: - if options.get('default_sort', False): - return name + if options.get('default_sort', None): + return name, options['default_sort'] + return None, None @property def selected(self): From 26809d54c79755c5f110192e0947af2861041587 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 18 Nov 2015 12:54:07 +0100 Subject: [PATCH 0585/1059] Comment waitForWindowShown() --- phy/gui/tests/test_widgets.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/phy/gui/tests/test_widgets.py b/phy/gui/tests/test_widgets.py index 09beff0a5..04fc70fa8 100644 --- a/phy/gui/tests/test_widgets.py +++ b/phy/gui/tests/test_widgets.py @@ -19,7 +19,7 @@ def table(qtbot): table = Table() table.show() - qtbot.waitForWindowShown(table) + # qtbot.waitForWindowShown(table) def count(id): return 10000.5 - 10 * id @@ -43,7 +43,7 @@ def skip(id): def test_widget_empty(qtbot): widget = HTMLWidget() widget.show() - qtbot.waitForWindowShown(widget) + # qtbot.waitForWindowShown(widget) # qtbot.stop() @@ -53,7 +53,7 @@ def test_widget_html(qtbot): widget.add_header('') widget.set_body('Hello world!') widget.show() - qtbot.waitForWindowShown(widget) + # qtbot.waitForWindowShown(widget) assert 'Hello world!' in widget.html() @@ -61,7 +61,7 @@ def test_widget_javascript_1(qtbot): widget = HTMLWidget() widget.eval_js('number = 1;') widget.show() - qtbot.waitForWindowShown(widget) + # qtbot.waitForWindowShown(widget) assert widget.eval_js('number') == 1 @@ -69,7 +69,7 @@ def test_widget_javascript_1(qtbot): def test_widget_javascript_2(qtbot): widget = HTMLWidget() widget.show() - qtbot.waitForWindowShown(widget) + # qtbot.waitForWindowShown(widget) _out = [] @widget.connect_ @@ -91,7 +91,7 @@ def on_test(arg): def test_table_default_sort(qtbot): table = Table() table.show() - qtbot.waitForWindowShown(table) + # qtbot.waitForWindowShown(table) def count(id): return 10000.5 - 10 * id From 61fc940563714fee8b58e4276c1289d1fd2ae7a8 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 18 Nov 2015 12:57:46 +0100 Subject: [PATCH 0586/1059] Update manual clustering component --- phy/cluster/manual/gui_component.py | 30 ++++++----------------------- 1 file changed, 6 insertions(+), 24 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index d2c000fc4..1a971c61a 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -224,7 +224,6 @@ def on_cluster(up): self.gui.emit('cluster', up) # Create the cluster views. - self._default_sort = None self._create_cluster_views() # Default columns. @@ -232,12 +231,12 @@ def skip(cluster_id): """Whether to skip that cluster.""" return (self.cluster_meta.get('group', cluster_id) in ('noise', 'mua')) - self.add_column(skip, options={'show': False}) + self.add_column(skip, show=False) def good(cluster_id): """Good column for color.""" return self.cluster_meta.get('group', cluster_id) == 'good' - self.add_column(good, options={'show': False}) + self.add_column(good, show=False) self._best = None @@ -311,30 +310,14 @@ def on_request_undo_state(up): self.clustering.connect(on_request_undo_state) self.cluster_meta.connect(on_request_undo_state) - def add_column(self, func=None, name=None, options=None): - options = options or {} - name = name or func.__name__ - assert name - if options.get('is_default_sort', False): - self._default_sort = name - self.cluster_view.add_column(func=func, name=name, options=options) - self.similarity_view.add_column(func=func, name=name, options=options) + def add_column(self, *args, **kwargs): + self.cluster_view.add_column(*args, **kwargs) + self.similarity_view.add_column(*args, **kwargs) def _update_cluster_view(self): """Initialize the cluster view with cluster data.""" - - # Get the current sort of the cluster view. - sort_col, sort_dir = self.cluster_view.current_sort - - # Update the cluster view rows. self.cluster_view.set_rows(self.clustering.cluster_ids) - # Sort with the previous sort or the default one. - sort_col = sort_col or self._default_sort - assert sort_col - sort_dir = sort_dir or 'desc' - self.cluster_view.sort_by(sort_col, sort_dir) - def _update_similarity_view(self): """Update the similarity view with matches for the specified clusters.""" @@ -403,8 +386,7 @@ def _emit_select(self, cluster_ids): # ------------------------------------------------------------------------- def set_quality_func(self, f): - self.add_column(func=f, name='quality', - options={'show': True, 'is_default_sort': True}) + self.add_column(func=f, name='quality', show=True, default_sort='desc') self._update_cluster_view() def set_similarity_func(self, f): From 6ac1b65da2f932ce8621238c1059c87550b46e56 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 18 Nov 2015 14:32:09 +0100 Subject: [PATCH 0587/1059] Load and save data in Context --- phy/io/context.py | 12 +++++++++++- phy/io/tests/test_context.py | 5 +++++ phy/utils/__init__.py | 1 + 3 files changed, 17 insertions(+), 1 deletion(-) diff --git a/phy/io/context.py b/phy/io/context.py index 5c6eb1fa5..d412e7351 100644 --- a/phy/io/context.py +++ b/phy/io/context.py @@ -23,7 +23,7 @@ "Install it with `conda install dask`.") from .array import read_array, write_array -from phy.utils import Bunch +from phy.utils import Bunch, _save_json, _load_json logger = logging.getLogger(__name__) @@ -250,6 +250,16 @@ def map(self, f, *args): else: return self._map_serial(f, *args) + def save(self, name, data): + """Save a dictionary in a JSON file within the cache directory.""" + path = op.join(self.cache_dir, name + '.json') + _save_json(path, data) + + def load(self, name): + """Load saved data from the cache directory.""" + path = op.join(self.cache_dir, name + '.json') + return _load_json(path) + def __getstate__(self): """Make sure that this class is picklable.""" state = self.__dict__.copy() diff --git a/phy/io/tests/test_context.py b/phy/io/tests/test_context.py index d1caf4b42..ab189c85a 100644 --- a/phy/io/tests/test_context.py +++ b/phy/io/tests/test_context.py @@ -80,6 +80,11 @@ def test_read_write(tempdir): ae(read_array(op.join(tempdir, 'test.npy')), x) +def test_context_load_save(context): + context.save('hello', {'text': 'world'}) + assert context.load('hello')['text'] == 'world' + + def test_context_cache(context): _res = [] diff --git a/phy/utils/__init__.py b/phy/utils/__init__.py index 5355a3aad..f64073d02 100644 --- a/phy/utils/__init__.py +++ b/phy/utils/__init__.py @@ -3,6 +3,7 @@ """Utilities.""" +from ._misc import _load_json, _save_json from ._types import (_is_array_like, _as_array, _as_tuple, _as_list, Bunch, _is_list) from .event import EventEmitter, ProgressReporter From 3f237b217662cc1e7887118755bd509ff1af6e54 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 18 Nov 2015 14:45:35 +0100 Subject: [PATCH 0588/1059] WIP --- phy/gui/gui.py | 1 - phy/io/context.py | 3 ++- phy/io/tests/test_context.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/phy/gui/gui.py b/phy/gui/gui.py index 549d0af22..7ab73230f 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -124,7 +124,6 @@ def add_view(self, closable=True, floatable=True, floating=None, - # parent=None, # object to pass in the raised events **kwargs): """Add a widget to the main window.""" diff --git a/phy/io/context.py b/phy/io/context.py index d412e7351..5d0a1c48e 100644 --- a/phy/io/context.py +++ b/phy/io/context.py @@ -23,7 +23,7 @@ "Install it with `conda install dask`.") from .array import read_array, write_array -from phy.utils import Bunch, _save_json, _load_json +from phy.utils import Bunch, _save_json, _load_json, _ensure_dir_exists logger = logging.getLogger(__name__) @@ -253,6 +253,7 @@ def map(self, f, *args): def save(self, name, data): """Save a dictionary in a JSON file within the cache directory.""" path = op.join(self.cache_dir, name + '.json') + _ensure_dir_exists(op.dirname(path)) _save_json(path, data) def load(self, name): diff --git a/phy/io/tests/test_context.py b/phy/io/tests/test_context.py index ab189c85a..2397716ce 100644 --- a/phy/io/tests/test_context.py +++ b/phy/io/tests/test_context.py @@ -81,8 +81,8 @@ def test_read_write(tempdir): def test_context_load_save(context): - context.save('hello', {'text': 'world'}) - assert context.load('hello')['text'] == 'world' + context.save('a/hello', {'text': 'world'}) + assert context.load('a/hello')['text'] == 'world' def test_context_cache(context): From d5610c655634665535359342db0b59f2d16be7de Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 18 Nov 2015 14:48:49 +0100 Subject: [PATCH 0589/1059] Add GUI.name property --- phy/gui/gui.py | 4 ++++ phy/gui/tests/test_gui.py | 2 ++ 2 files changed, 6 insertions(+) diff --git a/phy/gui/gui.py b/phy/gui/gui.py index 7ab73230f..ca72b2987 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -192,6 +192,10 @@ def view_count(self): # Status bar # ------------------------------------------------------------------------- + @property + def name(self): + return str(self.windowTitle()) + @property def status_message(self): """The message in the status bar.""" diff --git a/phy/gui/tests/test_gui.py b/phy/gui/tests/test_gui.py index 85e6ff06e..7950b7c54 100644 --- a/phy/gui/tests/test_gui.py +++ b/phy/gui/tests/test_gui.py @@ -45,6 +45,8 @@ def test_gui_1(qtbot): gui = GUI(position=(200, 100), size=(100, 100)) qtbot.addWidget(gui) + assert gui.name == 'GUI' + # Increase coverage. @gui.connect_ def on_show(): From 3f5ebb5ad66593f7a0e075440f7437196a064bd7 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 18 Nov 2015 14:52:56 +0100 Subject: [PATCH 0590/1059] Fixes --- phy/gui/gui.py | 2 ++ phy/gui/tests/test_gui.py | 2 ++ phy/io/context.py | 3 +++ 3 files changed, 7 insertions(+) diff --git a/phy/gui/gui.py b/phy/gui/gui.py index ca72b2987..6e916b0bc 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -228,6 +228,8 @@ def restore_geometry_state(self, gs): This function can be called in `on_show()`. """ + if not gs: + return if gs.get('geometry', None): self.restoreGeometry((gs['geometry'])) if gs.get('state', None): diff --git a/phy/gui/tests/test_gui.py b/phy/gui/tests/test_gui.py index 7950b7c54..b800092a8 100644 --- a/phy/gui/tests/test_gui.py +++ b/phy/gui/tests/test_gui.py @@ -132,6 +132,8 @@ def on_close(): def on_show(): gui.restore_geometry_state(_gs[0]) + assert gui.restore_geometry_state(None) is None + qtbot.addWidget(gui) gui.show() diff --git a/phy/io/context.py b/phy/io/context.py index 5d0a1c48e..b4a70aba9 100644 --- a/phy/io/context.py +++ b/phy/io/context.py @@ -259,6 +259,9 @@ def save(self, name, data): def load(self, name): """Load saved data from the cache directory.""" path = op.join(self.cache_dir, name + '.json') + if not op.exists(path): + logger.debug("The file `%s` doesn't exist.", path) + return return _load_json(path) def __getstate__(self): From 0f16eb433fba5a1ce1e118846f91332bb1c74630 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 18 Nov 2015 15:02:29 +0100 Subject: [PATCH 0591/1059] Ensure that GUI.closeEvent() is called only once --- phy/gui/gui.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/phy/gui/gui.py b/phy/gui/gui.py index 6e916b0bc..d40ec7c6e 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -65,6 +65,9 @@ def __init__(self, size=None, title=None, ): + # HACK to ensure that closeEvent is called only twice (seems like a + # Qt bug). + self._closed = False if not QApplication.instance(): # pragma: no cover raise RuntimeError("A Qt application must be created.") super(GUI, self).__init__() @@ -101,6 +104,9 @@ def unconnect_(self, *args, **kwargs): def closeEvent(self, e): """Qt slot when the window is closed.""" + if self._closed: + return + self._closed = True res = self.emit('close') # Discard the close event if False is returned by one of the callback # functions. From 099fa72edeb69279af0649288c99c91018dc09ae Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 18 Nov 2015 15:09:29 +0100 Subject: [PATCH 0592/1059] Add load_gui_plugins() function --- phy/gui/gui.py | 28 ++++++++++++++++++++++++++++ phy/gui/tests/test_gui.py | 18 +++++++++++++++++- phy/utils/__init__.py | 3 ++- 3 files changed, 47 insertions(+), 2 deletions(-) diff --git a/phy/gui/gui.py b/phy/gui/gui.py index d40ec7c6e..b1f60174d 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -13,6 +13,8 @@ from .qt import (QApplication, QWidget, QDockWidget, QStatusBar, QMainWindow, Qt, QSize, QMetaObject) from phy.utils.event import EventEmitter +from phy.utils import load_master_config +from phy.utils.plugin import get_plugin logger = logging.getLogger(__name__) @@ -25,6 +27,32 @@ def _title(widget): return str(widget.windowTitle()).lower() +def load_gui_plugins(gui, plugins=None, session=None): + """Attach a list of plugins to a GUI. + + By default, the list of plugins is taken from the `c.TheGUI.plugins` + parameter, where `TheGUI` is the name of the GUI class. + + """ + session = session or {} + + # GUI name. + name = gui.name + + # If no plugins are specified, load the master config and + # get the list of user plugins to attach to the GUI. + if plugins is None: + config = load_master_config() + plugins = config[name].plugins + if not isinstance(plugins, list): + plugins = [] + + # Attach the plugins to the GUI. + for plugin in plugins: + logger.info("Attach plugin `%s` to %s.", plugin, name) + get_plugin(plugin)().attach_to_gui(gui, session) + + class DockWidget(QDockWidget): """A QDockWidget that can emit events.""" def __init__(self, *args, **kwargs): diff --git a/phy/gui/tests/test_gui.py b/phy/gui/tests/test_gui.py index b800092a8..0466bc78e 100644 --- a/phy/gui/tests/test_gui.py +++ b/phy/gui/tests/test_gui.py @@ -9,7 +9,8 @@ from pytest import raises from ..qt import Qt, QApplication -from ..gui import GUI +from ..gui import GUI, load_gui_plugins +from phy.utils import IPlugin from phy.utils._color import _random_color @@ -75,6 +76,21 @@ def on_close_widget(): gui.close() +def test_load_gui_plugins(gui, tempdir): + + load_gui_plugins(gui) + + _tmp = [] + + class MyPlugin(IPlugin): + def attach_to_gui(self, gui, session): + _tmp.append(session) + + load_gui_plugins(gui, plugins=['MyPlugin'], session='hello') + + assert _tmp == ['hello'] + + def test_gui_component(gui): class TestComponent(object): diff --git a/phy/utils/__init__.py b/phy/utils/__init__.py index f64073d02..187ae593a 100644 --- a/phy/utils/__init__.py +++ b/phy/utils/__init__.py @@ -7,4 +7,5 @@ from ._types import (_is_array_like, _as_array, _as_tuple, _as_list, Bunch, _is_list) from .event import EventEmitter, ProgressReporter -from .config import _ensure_dir_exists +from .plugin import IPlugin, get_plugin, get_all_plugins +from .config import _ensure_dir_exists, load_master_config From caac8729d8cfeb2d98a5a2b0479cbd7da813b324 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 18 Nov 2015 15:10:13 +0100 Subject: [PATCH 0593/1059] Export load_gui_plugins --- phy/gui/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/phy/gui/__init__.py b/phy/gui/__init__.py index e3a05c5f4..d9aa2be96 100644 --- a/phy/gui/__init__.py +++ b/phy/gui/__init__.py @@ -4,5 +4,5 @@ """GUI routines.""" from .qt import require_qt, create_app, run_app -from .gui import GUI +from .gui import GUI, load_gui_plugins from .actions import Actions From 06a541d083b82ef4f7bfb18e883fd39a6d8bce13 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 18 Nov 2015 15:11:18 +0100 Subject: [PATCH 0594/1059] Update log --- phy/gui/widgets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/phy/gui/widgets.py b/phy/gui/widgets.py index 7ffed326b..d3159361a 100644 --- a/phy/gui/widgets.py +++ b/phy/gui/widgets.py @@ -254,7 +254,7 @@ def set_rows(self, ids): sort_dir = sort_dir or default_sort_dir or 'desc' # Set the rows. - logger.debug("Set %d rows.", len(ids)) + logger.debug("Set %d rows in the cluster view.", len(ids)) items = [self._get_row(id) for id in ids] data = _create_json_dict(items=items, cols=self.column_names, From bcec975cdcab1cb130573c4f32856cf8e00ae4ea Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 18 Nov 2015 15:23:08 +0100 Subject: [PATCH 0595/1059] Enable manual build in HTML widget --- phy/cluster/manual/gui_component.py | 4 ++-- phy/gui/widgets.py | 30 +++++++++++++++-------------- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 1a971c61a..18226f5d7 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -281,11 +281,11 @@ def _create_actions(self, gui): def _create_cluster_views(self): # Create the cluster view. self.cluster_view = ClusterView() - self.cluster_view.show() + self.cluster_view.build() # Create the similarity view. self.similarity_view = ClusterView() - self.similarity_view.show() + self.similarity_view.build() # Selection in the cluster view. @self.cluster_view.connect_ diff --git a/phy/gui/widgets.py b/phy/gui/widgets.py index d3159361a..980b90b6f 100644 --- a/phy/gui/widgets.py +++ b/phy/gui/widgets.py @@ -139,18 +139,21 @@ def html(self): """Return the full HTML source of the widget.""" return self.page().mainFrame().toHtml() - def _build(self): + def build(self): """Build the full HTML source.""" - styles = '\n\n'.join(self._styles) - html = _PAGE_TEMPLATE.format(title=self.title, - styles=styles, - header=self._header, - body=self._body, - ) - logger.log(5, "Set HTML: %s", html) - static_dir = op.join(op.realpath(op.dirname(__file__)), 'static/') - base_url = QUrl().fromLocalFile(static_dir) - self.setHtml(html, base_url) + if self.is_built(): # pragma: no cover + return + with _wait_signal(self.loadFinished, 20): + styles = '\n\n'.join(self._styles) + html = _PAGE_TEMPLATE.format(title=self.title, + styles=styles, + header=self._header, + body=self._body, + ) + logger.log(5, "Set HTML: %s", html) + static_dir = op.join(op.realpath(op.dirname(__file__)), 'static/') + base_url = QUrl().fromLocalFile(static_dir) + self.setHtml(html, base_url) def is_built(self): return self.html() != '' @@ -177,9 +180,8 @@ def _emit_from_js(self, name, arg_json): self.emit(text_type(name), json.loads(text_type(arg_json))) def show(self): - with _wait_signal(self.loadFinished, 20): - self._build() - super(HTMLWidget, self).show() + self.build() + super(HTMLWidget, self).show() # Call the pending JS eval calls after the page has been built. assert self.is_built() for expr in self._pending_js_eval: From 855f010627de41b333ccd6a59e3a5b5814044490 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 18 Nov 2015 15:49:45 +0100 Subject: [PATCH 0596/1059] Add PanZoom.enable_keyboard_pan property --- phy/plot/panzoom.py | 3 ++- phy/plot/plot.py | 7 +++++++ phy/plot/tests/test_panzoom.py | 5 +++++ phy/plot/tests/test_plot.py | 3 +++ 4 files changed, 17 insertions(+), 1 deletion(-) diff --git a/phy/plot/panzoom.py b/phy/plot/panzoom.py index e863947af..969c96129 100644 --- a/phy/plot/panzoom.py +++ b/phy/plot/panzoom.py @@ -85,6 +85,7 @@ def __init__(self, self._zoom_coeff = self._default_zoom_coeff self._wheel_coeff = self._default_wheel_coeff + self.enable_keyboard_pan = True self._zoom_to_pointer = True self._canvas_aspect = np.ones(2) @@ -378,7 +379,7 @@ def on_key_press(self, event): return # Pan. - if key in self._arrows: + if self.enable_keyboard_pan and key in self._arrows: self._pan_keyboard(key) # Zoom. diff --git a/phy/plot/plot.py b/phy/plot/plot.py index 66e0f0d80..8fbbb28cd 100644 --- a/phy/plot/plot.py +++ b/phy/plot/plot.py @@ -166,6 +166,13 @@ def __init__(self, interacts, **kwargs): self._items = [] # List of view items instances. self._visuals = {} + @property + def panzoom(self): + """PanZoom instance from the interact list, if it exists.""" + for interact in self.interacts: + if isinstance(interact, PanZoom): + return interact + # To override # ------------------------------------------------------------------------- diff --git a/phy/plot/tests/test_panzoom.py b/phy/plot/tests/test_panzoom.py index 5c464f15d..70d11ea1d 100644 --- a/phy/plot/tests/test_panzoom.py +++ b/phy/plot/tests/test_panzoom.py @@ -189,6 +189,11 @@ def test_panzoom_pan_keyboard(qtbot, canvas_pz, panzoom): c.events.key_press(key=keys.UP, modifiers=(keys.CONTROL,)) assert pz.pan == [0, 0] + # Disable keyboard pan. + pz.enable_keyboard_pan = False + c.events.key_press(key=keys.UP, modifiers=(keys.CONTROL,)) + assert pz.pan == [0, 0] + def test_panzoom_zoom_mouse(qtbot, canvas_pz, panzoom): c = canvas_pz diff --git a/phy/plot/tests/test_plot.py b/phy/plot/tests/test_plot.py index a04a77f3b..06353335c 100644 --- a/phy/plot/tests/test_plot.py +++ b/phy/plot/tests/test_plot.py @@ -9,6 +9,7 @@ import numpy as np +from ..panzoom import PanZoom from ..plot import GridView, BoxedView, StackedView from ..utils import _get_linear_x @@ -34,6 +35,8 @@ def test_grid_scatter(qtbot): view = GridView(2, 3) n = 1000 + assert isinstance(view.panzoom, PanZoom) + x = np.random.randn(n) y = np.random.randn(n) From d8d0bd438c42ac59b8d08a61c41e248367177a2e Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 18 Nov 2015 15:59:04 +0100 Subject: [PATCH 0597/1059] Add keyboard shortcuts to move in the cluster view --- phy/cluster/manual/gui_component.py | 10 ++++++++++ phy/cluster/manual/tests/test_gui_component.py | 13 +++++++++++++ phy/cluster/manual/views.py | 4 ++++ phy/gui/widgets.py | 2 +- 4 files changed, 28 insertions(+), 1 deletion(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 18226f5d7..72fba2667 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -170,6 +170,8 @@ class ManualClustering(object): 'reset': 'ctrl+alt+space', 'next': 'space', 'previous': 'shift+space', + 'next_best': 'down', + 'previous_best': 'up', # Misc. 'save': 'Save', @@ -272,6 +274,8 @@ def _create_actions(self, gui): self.actions.add(self.reset) self.actions.add(self.next) self.actions.add(self.previous) + self.actions.add(self.next_best) + self.actions.add(self.previous_best) # Others. self.actions.add(self.undo) @@ -467,6 +471,12 @@ def reset(self): self._update_cluster_view() self.cluster_view.next() + def next_best(self): + self.cluster_view.next() + + def previous_best(self): + self.cluster_view.previous() + def next(self): if not self.selected: self.cluster_view.next() diff --git a/phy/cluster/manual/tests/test_gui_component.py b/phy/cluster/manual/tests/test_gui_component.py index d8c2a413c..8051076d3 100644 --- a/phy/cluster/manual/tests/test_gui_component.py +++ b/phy/cluster/manual/tests/test_gui_component.py @@ -245,6 +245,19 @@ def test_manual_clustering_action_reset(qtbot, manual_clustering): assert mc.selected == [30, 20] +def test_manual_clustering_action_nav(qtbot, manual_clustering): + mc = manual_clustering + + mc.actions.reset() + assert mc.selected == [30] + + mc.actions.next_best() + assert mc.selected == [20] + + mc.actions.previous_best() + assert mc.selected == [30] + + def test_manual_clustering_action_move_1(qtbot, manual_clustering): mc = manual_clustering diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 5591743c7..3b2d67922 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -193,6 +193,10 @@ def on_key_press(self, e): def attach(self, gui): """Attach the view to the GUI.""" + # Disable keyboard pan so that we can use arrows as global shortcuts + # in the GUI. + self.panzoom.enable_keyboard_pan = False + gui.add_view(self) gui.connect_(self.on_select) diff --git a/phy/gui/widgets.py b/phy/gui/widgets.py index 980b90b6f..da19dff54 100644 --- a/phy/gui/widgets.py +++ b/phy/gui/widgets.py @@ -256,7 +256,7 @@ def set_rows(self, ids): sort_dir = sort_dir or default_sort_dir or 'desc' # Set the rows. - logger.debug("Set %d rows in the cluster view.", len(ids)) + logger.debug("Set %d rows in the table.", len(ids)) items = [self._get_row(id) for id in ids] data = _create_json_dict(items=items, cols=self.column_names, From 6c8469febea1042c311b04f3fb2e847593c94937 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 18 Nov 2015 16:42:01 +0100 Subject: [PATCH 0598/1059] Increase coverage --- phy/io/tests/test_context.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/phy/io/tests/test_context.py b/phy/io/tests/test_context.py index 2397716ce..5578a3156 100644 --- a/phy/io/tests/test_context.py +++ b/phy/io/tests/test_context.py @@ -81,6 +81,8 @@ def test_read_write(tempdir): def test_context_load_save(context): + assert context.load('unexisting') is None + context.save('a/hello', {'text': 'world'}) assert context.load('a/hello')['text'] == 'world' From b096f0a72c5ab44b2363d1d9084bd8ef9e62418f Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 18 Nov 2015 18:42:16 +0100 Subject: [PATCH 0599/1059] Updates --- phy/cluster/manual/gui_component.py | 4 ++-- phy/utils/cli.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 72fba2667..835d397bb 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -95,8 +95,8 @@ def mean_masked_features_similarity(c0, c1): n_features_per_channel=nfc, ) d = 1. / max(1e-10, d) # From distance to similarity. - logger.debug("Computed cluster similarity for (%d, %d): %.3f.", - c0, c1, d) + logger.log(5, "Computed cluster similarity for (%d, %d): %.3f.", + c0, c1, d) return d return (max_waveform_amplitude_quality, diff --git a/phy/utils/cli.py b/phy/utils/cli.py index b6a03e84c..bec6dc9c2 100644 --- a/phy/utils/cli.py +++ b/phy/utils/cli.py @@ -16,8 +16,8 @@ import click -import phy -from phy import add_default_handler, DEBUG, _Formatter, _logger_fmt +from phy import (add_default_handler, DEBUG, _Formatter, _logger_fmt, + __version_git__) logger = logging.getLogger(__name__) @@ -49,7 +49,7 @@ def _add_log_file(filename): formatter = _Formatter(fmt=_logger_fmt, datefmt='%Y-%m-%d %H:%M:%S') handler.setFormatter(formatter) - logger.addHandler(handler) + logging.getLogger().addHandler(handler) #------------------------------------------------------------------------------ @@ -57,7 +57,7 @@ def _add_log_file(filename): #------------------------------------------------------------------------------ @click.group() -@click.version_option(version=phy.__version_git__) +@click.version_option(version=__version_git__) @click.help_option('-h', '--help') @click.pass_context def phy(ctx): From d83067975f4ac45bf866798722e62e811578ba17 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 18 Nov 2015 19:07:25 +0100 Subject: [PATCH 0600/1059] Refactor default sort in HTMLTable --- phy/gui/tests/test_widgets.py | 3 ++- phy/gui/widgets.py | 29 ++++++++++++++++------------- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/phy/gui/tests/test_widgets.py b/phy/gui/tests/test_widgets.py index 04fc70fa8..c12521cae 100644 --- a/phy/gui/tests/test_widgets.py +++ b/phy/gui/tests/test_widgets.py @@ -95,7 +95,8 @@ def test_table_default_sort(qtbot): def count(id): return 10000.5 - 10 * id - table.add_column(count, default_sort='asc') + table.add_column(count) + table.set_default_sort('count', 'asc') table.set_rows(range(10)) assert table.default_sort == ('count', 'asc') diff --git a/phy/gui/widgets.py b/phy/gui/widgets.py index da19dff54..413a2e626 100644 --- a/phy/gui/widgets.py +++ b/phy/gui/widgets.py @@ -7,6 +7,7 @@ # Imports # ----------------------------------------------------------------------------- +from collections import OrderedDict import json import logging import os.path as op @@ -224,27 +225,29 @@ def __init__(self): self.add_body(''''''.format(self._table_id)) - self._columns = [('id', (lambda _: _), {})] + self._columns = OrderedDict() + self._default_sort = (None, None) + self.add_column(lambda _: _, name='id') - def add_column(self, func, name=None, show=True, default_sort=None): + def add_column(self, func, name=None, show=True): """Add a column function which takes an id as argument and returns a value.""" assert func name = name or func.__name__ - options = {'show': show, - 'default_sort': default_sort, - } - self._columns.append((name, func, options)) + d = {'func': func, + 'show': show, + } + self._columns[name] = d return func @property def column_names(self): - return [name for (name, func, options) in self._columns - if options.get('show', True)] + return [name for (name, d) in self._columns.items() + if d.get('show', True)] def _get_row(self, id): """Create a row dictionary for a given object id.""" - return {name: func(id) for (name, func, options) in self._columns} + return {name: d['func'](id) for (name, d) in self._columns.items()} def set_rows(self, ids): """Set the rows of the table.""" @@ -288,10 +291,10 @@ def select(self, ids, do_emit=True): @property def default_sort(self): """Default sort as a pair `(name, dir)`.""" - for (name, func, options) in self._columns: - if options.get('default_sort', None): - return name, options['default_sort'] - return None, None + return self._default_sort + + def set_default_sort(self, name, sort_dir='desc'): + self._default_sort = name, sort_dir @property def selected(self): From 220c24c0fed6a90e6495602516bd17b3c069304a Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 18 Nov 2015 19:23:14 +0100 Subject: [PATCH 0601/1059] Forbid lambda column functions without explicit names --- phy/gui/tests/test_widgets.py | 5 ++++- phy/gui/widgets.py | 3 +++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/phy/gui/tests/test_widgets.py b/phy/gui/tests/test_widgets.py index c12521cae..6c43b7fc0 100644 --- a/phy/gui/tests/test_widgets.py +++ b/phy/gui/tests/test_widgets.py @@ -6,7 +6,7 @@ # Imports #------------------------------------------------------------------------------ -from pytest import yield_fixture +from pytest import yield_fixture, raises from ..widgets import HTMLWidget, Table @@ -93,6 +93,9 @@ def test_table_default_sort(qtbot): table.show() # qtbot.waitForWindowShown(table) + with raises(ValueError): + table.add_column(lambda _: _) + def count(id): return 10000.5 - 10 * id table.add_column(count) diff --git a/phy/gui/widgets.py b/phy/gui/widgets.py index 413a2e626..afc90e32d 100644 --- a/phy/gui/widgets.py +++ b/phy/gui/widgets.py @@ -234,6 +234,9 @@ def add_column(self, func, name=None, show=True): returns a value.""" assert func name = name or func.__name__ + if name == '': + raise ValueError("Please provide a valid name for " + + name) d = {'func': func, 'show': show, } From 3cd1832d57acceaddbeef3d65dc16264d64f25a3 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 18 Nov 2015 19:28:15 +0100 Subject: [PATCH 0602/1059] Update default sort in manual clustering component --- phy/cluster/manual/gui_component.py | 12 +++++++----- phy/cluster/manual/tests/test_gui_component.py | 17 ++++++++++++----- 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 835d397bb..067ed8cdd 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -314,9 +314,11 @@ def on_request_undo_state(up): self.clustering.connect(on_request_undo_state) self.cluster_meta.connect(on_request_undo_state) - def add_column(self, *args, **kwargs): - self.cluster_view.add_column(*args, **kwargs) - self.similarity_view.add_column(*args, **kwargs) + def add_column(self, func=None, name=None, show=True): + if func is None: + return lambda f: self.add_column(f, name=name, show=show) + self.cluster_view.add_column(func, name=name, show=show) + self.similarity_view.add_column(func, name=name, show=show) def _update_cluster_view(self): """Initialize the cluster view with cluster data.""" @@ -389,8 +391,8 @@ def _emit_select(self, cluster_ids): # Public methods # ------------------------------------------------------------------------- - def set_quality_func(self, f): - self.add_column(func=f, name='quality', show=True, default_sort='desc') + def set_default_sort(self, name, dir): + self.cluster_view.set_default_sort(name, dir) self._update_cluster_view() def set_similarity_func(self, f): diff --git a/phy/cluster/manual/tests/test_gui_component.py b/phy/cluster/manual/tests/test_gui_component.py index 8051076d3..b15163eb0 100644 --- a/phy/cluster/manual/tests/test_gui_component.py +++ b/phy/cluster/manual/tests/test_gui_component.py @@ -36,7 +36,9 @@ def manual_clustering(qtbot, gui, cluster_ids, cluster_groups, shortcuts={'undo': 'ctrl+z'}, ) mc.attach(gui) - mc.set_quality_func(quality) + + mc.add_column(quality, name='quality') + mc.set_default_sort('quality', 'desc') mc.set_similarity_func(similarity) yield mc @@ -113,11 +115,15 @@ def test_manual_clustering_default_metrics(qtbot, gui): n_features_per_channel=npc, spikes_per_cluster=spc, ) - mc.set_quality_func(q) + + @mc.add_column() + def quality(cluster): + return q(cluster) + + mc.set_default_sort('quality', 'desc') mc.set_similarity_func(s) - quality = [(c, q(c)) for c in spc] - best = sorted(quality, key=itemgetter(1))[-1][0] + best = sorted([(c, q(c)) for c in spc], key=itemgetter(1))[-1][0] similarity = [(d, s(best, d)) for d in spc if d != best] match = sorted(similarity, key=itemgetter(1))[-1][0] @@ -182,7 +188,8 @@ def test_manual_clustering_split_2(gui, quality, similarity): mc = ManualClustering(spike_clusters) mc.attach(gui) - mc.set_quality_func(quality) + mc.add_column(quality, name='quality') + mc.set_default_sort('quality', 'desc') mc.set_similarity_func(similarity) mc.split([0]) From 1ff696f0386b7d33fa00c8399088d72ab23f1598 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 18 Nov 2015 19:34:22 +0100 Subject: [PATCH 0603/1059] Fix --- phy/cluster/manual/gui_component.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 067ed8cdd..4f511f497 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -391,8 +391,8 @@ def _emit_select(self, cluster_ids): # Public methods # ------------------------------------------------------------------------- - def set_default_sort(self, name, dir): - self.cluster_view.set_default_sort(name, dir) + def set_default_sort(self, name, sort_dir='desc'): + self.cluster_view.set_default_sort(name, sort_dir) self._update_cluster_view() def set_similarity_func(self, f): From decc6802f00ab94eb7003e9fe25e0c6250e0fae3 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 18 Nov 2015 19:43:55 +0100 Subject: [PATCH 0604/1059] Add message --- phy/cluster/manual/gui_component.py | 1 + 1 file changed, 1 insertion(+) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 4f511f497..5b7b56721 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -392,6 +392,7 @@ def _emit_select(self, cluster_ids): # ------------------------------------------------------------------------- def set_default_sort(self, name, sort_dir='desc'): + logger.debug("Set default sort `%s` %s.", name, sort_dir) self.cluster_view.set_default_sort(name, sort_dir) self._update_cluster_view() From 8a8ef2d4ffebdff7f0a16278849c362eddf3ae36 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 18 Nov 2015 22:49:48 +0100 Subject: [PATCH 0605/1059] Fix GUI plugins --- phy/gui/gui.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/phy/gui/gui.py b/phy/gui/gui.py index b1f60174d..2ce871d07 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -35,17 +35,17 @@ def load_gui_plugins(gui, plugins=None, session=None): """ session = session or {} + plugins = plugins or [] # GUI name. name = gui.name # If no plugins are specified, load the master config and # get the list of user plugins to attach to the GUI. - if plugins is None: - config = load_master_config() - plugins = config[name].plugins - if not isinstance(plugins, list): - plugins = [] + config = load_master_config() + plugins_conf = config[name].plugins + plugins_conf = plugins_conf if isinstance(plugins_conf, list) else [] + plugins.extend(plugins_conf) # Attach the plugins to the GUI. for plugin in plugins: From 772dd4f3b69d5a22c03afc01239243aba45fef7d Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 18 Nov 2015 23:01:36 +0100 Subject: [PATCH 0606/1059] Bug fixes --- phy/cluster/manual/gui_component.py | 4 ++++ phy/gui/widgets.py | 6 +++--- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 5b7b56721..f0df3f4b3 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -393,8 +393,12 @@ def _emit_select(self, cluster_ids): def set_default_sort(self, name, sort_dir='desc'): logger.debug("Set default sort `%s` %s.", name, sort_dir) + # Set the default sort. self.cluster_view.set_default_sort(name, sort_dir) + # Reset the cluster view. self._update_cluster_view() + # Sort by the default sort. + self.cluster_view.sort_by(name, sort_dir) def set_similarity_func(self, f): """Set the similarity function.""" diff --git a/phy/gui/widgets.py b/phy/gui/widgets.py index afc90e32d..440e9022b 100644 --- a/phy/gui/widgets.py +++ b/phy/gui/widgets.py @@ -273,10 +273,10 @@ def set_rows(self, ids): if sort_col: self.sort_by(sort_col, sort_dir) - def sort_by(self, header, dir='asc'): + def sort_by(self, name, sort_dir='asc'): """Sort by a given variable.""" - logger.debug("Sort by `%s` %s.", header, dir) - self.eval_js('table.sortBy("{}", "{}");'.format(header, dir)) + logger.debug("Sort by `%s` %s.", name, sort_dir) + self.eval_js('table.sortBy("{}", "{}");'.format(name, sort_dir)) def next(self): """Select the next non-skip row.""" From deaeb2ff1c7fce1498ab1e63a587d801b4322caf Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 18 Nov 2015 23:30:06 +0100 Subject: [PATCH 0607/1059] WIP: show shortcuts --- phy/cluster/manual/gui_component.py | 70 +++++++++++++++-------------- 1 file changed, 36 insertions(+), 34 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index f0df3f4b3..c92d8eb6c 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -175,6 +175,7 @@ class ManualClustering(object): # Misc. 'save': 'Save', + 'show_shortcuts': 'shift+h', 'undo': 'Undo', 'redo': 'Redo', } @@ -278,6 +279,7 @@ def _create_actions(self, gui): self.actions.add(self.previous_best) # Others. + self.actions.add(self.actions.show_shortcuts) self.actions.add(self.undo) self.actions.add(self.redo) self.actions.add(self.save) @@ -314,12 +316,6 @@ def on_request_undo_state(up): self.clustering.connect(on_request_undo_state) self.cluster_meta.connect(on_request_undo_state) - def add_column(self, func=None, name=None, show=True): - if func is None: - return lambda f: self.add_column(f, name=name, show=show) - self.cluster_view.add_column(func, name=name, show=show) - self.similarity_view.add_column(func, name=name, show=show) - def _update_cluster_view(self): """Initialize the cluster view with cluster data.""" self.cluster_view.set_rows(self.clustering.cluster_ids) @@ -335,6 +331,40 @@ def _update_similarity_view(self): if c not in selection]) self.similarity_view.sort_by('similarity', 'desc') + def _emit_select(self, cluster_ids): + """Choose spikes from the specified clusters and emit the + `select` event on the GUI.""" + # Choose a spike subset. + spike_ids = select_spikes(np.array(cluster_ids), + self.n_spikes_max_per_cluster, + self.clustering.spikes_per_cluster) + logger.debug("Select clusters: %s (%d spikes).", + ', '.join(map(str, cluster_ids)), len(spike_ids)) + if self.gui: + self.gui.emit('select', cluster_ids, spike_ids) + + # Public methods + # ------------------------------------------------------------------------- + + def add_column(self, func=None, name=None, show=True): + if func is None: + return lambda f: self.add_column(f, name=name, show=show) + self.cluster_view.add_column(func, name=name, show=show) + self.similarity_view.add_column(func, name=name, show=show) + + def set_default_sort(self, name, sort_dir='desc'): + logger.debug("Set default sort `%s` %s.", name, sort_dir) + # Set the default sort. + self.cluster_view.set_default_sort(name, sort_dir) + # Reset the cluster view. + self._update_cluster_view() + # Sort by the default sort. + self.cluster_view.sort_by(name, sort_dir) + + def set_similarity_func(self, f): + """Set the similarity function.""" + self.similarity_func = f + def on_cluster(self, up): """Update the cluster views after clustering actions.""" @@ -376,34 +406,6 @@ def on_cluster(self, up): if similar: self.similarity_view.next() - def _emit_select(self, cluster_ids): - """Choose spikes from the specified clusters and emit the - `select` event on the GUI.""" - # Choose a spike subset. - spike_ids = select_spikes(np.array(cluster_ids), - self.n_spikes_max_per_cluster, - self.clustering.spikes_per_cluster) - logger.debug("Select clusters: %s (%d spikes).", - ', '.join(map(str, cluster_ids)), len(spike_ids)) - if self.gui: - self.gui.emit('select', cluster_ids, spike_ids) - - # Public methods - # ------------------------------------------------------------------------- - - def set_default_sort(self, name, sort_dir='desc'): - logger.debug("Set default sort `%s` %s.", name, sort_dir) - # Set the default sort. - self.cluster_view.set_default_sort(name, sort_dir) - # Reset the cluster view. - self._update_cluster_view() - # Sort by the default sort. - self.cluster_view.sort_by(name, sort_dir) - - def set_similarity_func(self, f): - """Set the similarity function.""" - self.similarity_func = f - def attach(self, gui): self.gui = gui From 41ececd74498cf721eda783a694aa3b72eca5b54 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 19 Nov 2015 22:03:50 +0100 Subject: [PATCH 0608/1059] Remove obsolete pre/post-transforms in Transforms --- phy/plot/tests/test_transform.py | 13 ------------- phy/plot/transform.py | 19 +++++-------------- 2 files changed, 5 insertions(+), 27 deletions(-) diff --git a/phy/plot/tests/test_transform.py b/phy/plot/tests/test_transform.py index f26b719d6..e990ee332 100644 --- a/phy/plot/tests/test_transform.py +++ b/phy/plot/tests/test_transform.py @@ -171,19 +171,6 @@ def test_transform_chain_empty(array): ae(t.apply(array), array) -def test_transform_chain_pre_post(array): - class MyTransform(BaseTransform): - def pre_transforms(self, key=None): - return [MyTransform(key=key - 1)] - - def post_transforms(self, key=None): - return [MyTransform(key=key + 1), MyTransform(key=key + 2)] - - t = TransformChain([Translate(), MyTransform(key=0), Scale()]) - expected = [None, -1, 0, 1, 2, None] - assert [getattr(p, 'key', None) for p in t.cpu_transforms] == expected - - def test_transform_chain_one(array): translate = Translate(translate=[1, 2]) t = TransformChain([translate]) diff --git a/phy/plot/transform.py b/phy/plot/transform.py index d97048d06..f80f250c5 100644 --- a/phy/plot/transform.py +++ b/phy/plot/transform.py @@ -50,6 +50,11 @@ def wrapped(var, **kwargs): def _wrap(f, **kwargs_init): + """Pass extra keyword arguments to a function. + + Used to pass constructor arguments to class methods in transforms. + + """ def wrapped(*args, **kwargs): # Method kwargs first, then we update with the constructor kwargs. kwargs.update(kwargs_init) @@ -107,14 +112,6 @@ def __init__(self, **kwargs): # Pass the constructor kwargs to the methods. self.apply = _wrap_apply(self.apply, **kwargs) self.glsl = _wrap_glsl(self.glsl, **kwargs) - self.pre_transforms = _wrap(self.pre_transforms, **kwargs) - self.post_transforms = _wrap(self.post_transforms, **kwargs) - - def pre_transforms(self, **kwargs): - return [] - - def post_transforms(self, **kwargs): - return [] def apply(self, arr): raise NotImplementedError() @@ -260,13 +257,7 @@ def gpu_transforms(self): def add(self, transforms): """Add some transforms.""" for t in transforms: - if hasattr(t, 'pre_transforms'): - for p in t.pre_transforms(): - self.transforms.append(p) self.transforms.append(t) - if hasattr(t, 'post_transforms'): - for p in t.post_transforms(): - self.transforms.append(p) def get(self, class_name): """Get a transform in the chain from its name.""" From c86583abd47ec5dceb57a8e0c47e3df80df8635d Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 19 Nov 2015 22:04:06 +0100 Subject: [PATCH 0609/1059] Flakify --- phy/plot/tests/test_transform.py | 1 - 1 file changed, 1 deletion(-) diff --git a/phy/plot/tests/test_transform.py b/phy/plot/tests/test_transform.py index e990ee332..e54b50b78 100644 --- a/phy/plot/tests/test_transform.py +++ b/phy/plot/tests/test_transform.py @@ -14,7 +14,6 @@ from pytest import yield_fixture from ..transform import (_glslify, pixels_to_ndc, - BaseTransform, Translate, Scale, Range, Clip, Subplot, GPU, TransformChain, ) From 88681ba72e7a18956b09453950484897f7f37350 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 19 Nov 2015 22:49:45 +0100 Subject: [PATCH 0610/1059] WIP: waveform overlap --- phy/cluster/manual/views.py | 35 +++++++++++++++++++++++++---------- phy/plot/base.py | 6 +++++- 2 files changed, 30 insertions(+), 11 deletions(-) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 3b2d67922..46d063ea0 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -14,6 +14,7 @@ from phy.io.array import _index_of, _get_padded from phy.electrode.mea import linear_positions +from phy.gui import Actions from phy.plot import (BoxedView, StackedView, GridView, _get_linear_x) from phy.plot.utils import _get_boxes @@ -77,6 +78,11 @@ def _extract_wave(traces, spk, mask, wave_len=None): class WaveformView(BoxedView): normalization_percentile = .95 normalization_n_spikes = 1000 + overlap = True + + default_shortcuts = { + 'toggle_waveform_overlap': 'o', + } def __init__(self, waveforms=None, @@ -92,6 +98,9 @@ def __init__(self, """ + self._cluster_ids = None + self._spike_ids = None + # Initialize the view. if channel_positions is None: channel_positions = linear_positions(self.n_channels) @@ -136,6 +145,9 @@ def on_select(self, cluster_ids, spike_ids): if n_spikes == 0: return + self._cluster_ids = cluster_ids + self._spike_ids = spike_ids + # Relative spike clusters. # NOTE: the order of the clusters in cluster_ids matters. # It will influence the relative index of the clusters, which @@ -148,6 +160,10 @@ def on_select(self, cluster_ids, spike_ids): w = self.waveforms[spike_ids] colors = _selected_clusters_colors(n_clusters) t = _get_linear_x(n_spikes, self.n_samples) + # Overlap. + if self.overlap: + t = t + 2.5 * (spike_clusters_rel[:, np.newaxis] - + (n_clusters - 1) / 2.) # Depth as a function of the cluster index and masks. m = self.masks[spike_ids] @@ -181,15 +197,6 @@ def on_select(self, cluster_ids, spike_ids): self.build() self.update() - def on_cluster(self, up): - pass - - def on_mouse_move(self, e): - pass - - def on_key_press(self, e): - pass - def attach(self, gui): """Attach the view to the GUI.""" @@ -200,7 +207,15 @@ def attach(self, gui): gui.add_view(self) gui.connect_(self.on_select) - gui.connect_(self.on_cluster) + # gui.connect_(self.on_cluster) + + # TODO: customizable shortcut too + self.actions = Actions(gui, default_shortcuts=self.default_shortcuts) + self.actions.add(self.toggle_waveform_overlap, alias='o') + + def toggle_waveform_overlap(self): + self.overlap = not self.overlap + self.on_select(self._cluster_ids, self._spike_ids) class TraceView(StackedView): diff --git a/phy/plot/base.py b/phy/plot/base.py index 8ab993ae7..2154ac87c 100644 --- a/phy/plot/base.py +++ b/phy/plot/base.py @@ -76,7 +76,11 @@ def get_transforms(self): def get_post_transforms(self): """Return a GLSL snippet to insert after all transforms in the - vertex shader.""" + vertex shader. + + The snippet should modify `gl_Position`. + + """ return '' def set_data(self): From d762fd122087b16bc04526531c66494aa42434f0 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 20 Nov 2015 14:05:07 +0100 Subject: [PATCH 0611/1059] Default actions in GUI --- phy/gui/actions.py | 8 -------- phy/gui/gui.py | 8 ++++++++ phy/gui/tests/test_actions.py | 2 -- phy/gui/tests/test_gui.py | 2 +- 4 files changed, 9 insertions(+), 11 deletions(-) diff --git a/phy/gui/actions.py b/phy/gui/actions.py index 43479d5b6..b87a1ce19 100644 --- a/phy/gui/actions.py +++ b/phy/gui/actions.py @@ -15,7 +15,6 @@ from six import string_types, PY3 from .qt import QKeySequence, QAction, require_qt -from .gui import GUI from phy.utils import Bunch logger = logging.getLogger(__name__) @@ -141,14 +140,8 @@ def __init__(self, gui, default_shortcuts=None): self._actions_dict = {} self._aliases = {} self._default_shortcuts = default_shortcuts or {} - assert isinstance(gui, GUI) self.gui = gui - # Default exit action. - @self.add(shortcut='Quit') - def exit(): - gui.close() - # Create and attach snippets. self.snippets = Snippets(gui, self) @@ -264,7 +257,6 @@ class Snippets(object): " ,.;?!_-+~=*/\(){}[]") def __init__(self, gui, actions): - assert isinstance(gui, GUI) self.gui = gui assert isinstance(actions, Actions) diff --git a/phy/gui/gui.py b/phy/gui/gui.py index 2ce871d07..956dff335 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -12,6 +12,7 @@ from .qt import (QApplication, QWidget, QDockWidget, QStatusBar, QMainWindow, Qt, QSize, QMetaObject) +from .actions import Actions from phy.utils.event import EventEmitter from phy.utils import load_master_config from phy.utils.plugin import get_plugin @@ -118,6 +119,13 @@ def __init__(self, self._status_bar = QStatusBar() self.setStatusBar(self._status_bar) + # Default exit action. + self.default_actions = Actions(self) + + @self.default_actions.add(shortcut='Quit') + def exit(): + self.close() + # Events # ------------------------------------------------------------------------- diff --git a/phy/gui/tests/test_actions.py b/phy/gui/tests/test_actions.py index 3ad1c0f76..0dcb98745 100644 --- a/phy/gui/tests/test_actions.py +++ b/phy/gui/tests/test_actions.py @@ -106,8 +106,6 @@ def press(): actions.press() assert _press == [0] - actions.exit() - def test_snippets_gui(qtbot, gui, actions): qtbot.addWidget(gui) diff --git a/phy/gui/tests/test_gui.py b/phy/gui/tests/test_gui.py index 0466bc78e..9a13cff2e 100644 --- a/phy/gui/tests/test_gui.py +++ b/phy/gui/tests/test_gui.py @@ -73,7 +73,7 @@ def on_close_widget(): view.close() assert _close == [0] - gui.close() + gui.default_actions.exit() def test_load_gui_plugins(gui, tempdir): From b188521a83448bc11c2b5896922f14d468463817 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 20 Nov 2015 15:04:23 +0100 Subject: [PATCH 0612/1059] Minor bug fixes in phy.plot --- phy/plot/plot.py | 2 ++ phy/plot/utils.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/phy/plot/plot.py b/phy/plot/plot.py index 8fbbb28cd..6126730ca 100644 --- a/phy/plot/plot.py +++ b/phy/plot/plot.py @@ -42,6 +42,8 @@ def __getitem__(self, name): #------------------------------------------------------------------------------ def _prepare_scatter(x, y, color=None, size=None, marker=None): + x = np.asarray(x) + y = np.asarray(y) # Validate x and y. assert x.ndim == y.ndim == 1 assert x.shape == y.shape diff --git a/phy/plot/utils.py b/phy/plot/utils.py index 7cf1c5cdd..deb82b4da 100644 --- a/phy/plot/utils.py +++ b/phy/plot/utils.py @@ -101,6 +101,8 @@ def _get_texture(arr, default, n_items, from_bounds): def _get_array(val, shape, default=None): """Ensure an object is an array with the specified shape.""" assert val is not None or default is not None + if hasattr(val, '__len__') and len(val) == 0: + val = None out = np.zeros(shape, dtype=np.float32) # This solves `ValueError: could not broadcast input array from shape (n) # into shape (n, 1)`. From 1aa8e5debd7f8d8025009590b335bbb528f7aa0d Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 20 Nov 2015 15:49:26 +0100 Subject: [PATCH 0613/1059] WIP: feature view --- phy/cluster/manual/tests/test_views.py | 48 +++++- phy/cluster/manual/views.py | 222 +++++++++++++++++++++++-- 2 files changed, 256 insertions(+), 14 deletions(-) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index e7691becc..c9c7a3e4e 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -11,13 +11,14 @@ from pytest import raises from phy.io.mock import (artificial_waveforms, + artificial_features, artificial_spike_clusters, artificial_spike_samples, artificial_masks, artificial_traces, ) from phy.electrode.mea import staggered_positions -from ..views import WaveformView, TraceView, _extract_wave +from ..views import WaveformView, FeatureView, TraceView, _extract_wave #------------------------------------------------------------------------------ @@ -61,7 +62,7 @@ def test_extract_wave(): #------------------------------------------------------------------------------ -# Test views +# Test waveform view #------------------------------------------------------------------------------ def test_waveform_view(qtbot): @@ -99,6 +100,10 @@ def test_waveform_view(qtbot): v.close() +#------------------------------------------------------------------------------ +# Test trace view +#------------------------------------------------------------------------------ + def test_trace_view_no_spikes(qtbot): n_samples = 1000 n_channels = 12 @@ -133,3 +138,42 @@ def test_trace_view_spikes(qtbot): n_samples_per_spike=6, ) _show(qtbot, v) + + +#------------------------------------------------------------------------------ +# Test feature view +#------------------------------------------------------------------------------ + +def test_feature_view(qtbot): + n_spikes = 50 + n_channels = 5 + n_clusters = 2 + n_features = 4 + + features = artificial_features(n_spikes, n_channels, n_features) + masks = artificial_masks(n_spikes, n_channels) + spike_clusters = artificial_spike_clusters(n_spikes, n_clusters) + spike_times = artificial_spike_samples(n_spikes) / 20000. + + # Create the view. + v = FeatureView(features=features, + masks=masks, + spike_times=spike_times, + spike_clusters=spike_clusters, + ) + # Select some spikes. + spike_ids = np.arange(n_spikes) + cluster_ids = np.unique(spike_clusters[spike_ids]) + v.on_select(cluster_ids, spike_ids) + + # Show the view. + v.show() + qtbot.waitForWindowShown(v.native) + + # Select other spikes. + spike_ids = np.arange(2, 10) + cluster_ids = np.unique(spike_clusters[spike_ids]) + v.on_select(cluster_ids, spike_ids) + + # qtbot.stop() + v.close() diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 46d063ea0..bf9c52e22 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -71,8 +71,19 @@ def _extract_wave(traces, spk, mask, wave_len=None): return data, channels +def _get_data_bounds(arr, n_spikes=None, percentile=None): + n = arr.shape[0] + k = max(1, n // n_spikes) + w = np.abs(arr[::k]) + n = w.shape[0] + w = w.reshape((n, -1)) + w = w.max(axis=1) + m = np.percentile(w, percentile) + return [-1, -m, +1, +m] + + # ----------------------------------------------------------------------------- -# Views +# Waveform view # ----------------------------------------------------------------------------- class WaveformView(BoxedView): @@ -113,14 +124,9 @@ def __init__(self, self.waveforms = waveforms # Waveform normalization. - n = waveforms.shape[0] - k = max(1, n // self.normalization_n_spikes) - w = np.abs(waveforms[::k]) - n = w.shape[0] - w = w.reshape((n, -1)) - w = w.max(axis=1) - m = np.percentile(w, self.normalization_percentile) - self.data_bounds = [-1, -m, +1, +m] + self.data_bounds = _get_data_bounds(waveforms, + self.normalization_n_spikes, + self.normalization_percentile) # Masks. self.masks = masks @@ -218,6 +224,10 @@ def toggle_waveform_overlap(self): self.on_select(self._cluster_ids, self._spike_ids) +# ----------------------------------------------------------------------------- +# Trace view +# ----------------------------------------------------------------------------- + class TraceView(StackedView): def __init__(self, traces=None, @@ -340,13 +350,201 @@ def set_interval(self, interval): self.update() +# ----------------------------------------------------------------------------- +# Feature view +# ----------------------------------------------------------------------------- + +def _dimensions(x_channels, y_channels): + """Default dimensions matrix.""" + # time, depth time, (x, 0) time, (y, 0) time, (z, 0) + # time, (x', 0) (x', 0), (x, 0) (x', 1), (y, 0) (x', 2), (z, 0) + # time, (y', 0) (y', 0), (x, 1) (y', 1), (y, 1) (y', 2), (z, 1) + # time, (z', 0) (z', 0), (x, 2) (z', 1), (y, 2) (z', 2), (z, 2) + + n = len(x_channels) + assert len(y_channels) == n + y_dim = {} + x_dim = {} + # TODO: depth + x_dim[0, 0] = 'time' + y_dim[0, 0] = 'time' + + # Time in first column and first row. + for i in range(1, n + 1): + x_dim[0, i] = 'time' + y_dim[0, i] = (x_channels[i - 1], 0) + x_dim[i, 0] = 'time' + y_dim[i, 0] = (y_channels[i - 1], 0) + + for i in range(1, n + 1): + for j in range(1, n + 1): + x_dim[i, j] = (x_channels[i - 1], j - 1) + y_dim[i, j] = (y_channels[j - 1], i - 1) + + return x_dim, y_dim + + +def _get_spike_clusters_rel(spike_clusters, spike_ids, cluster_ids): + # Relative spike clusters. + # NOTE: the order of the clusters in cluster_ids matters. + # It will influence the relative index of the clusters, which + # in return influence the depth. + spike_clusters = spike_clusters[spike_ids] + assert np.all(np.in1d(spike_clusters, cluster_ids)) + spike_clusters_rel = _index_of(spike_clusters, cluster_ids) + return spike_clusters_rel + + +def _get_depth(masks, + spike_clusters_rel=None, + n_clusters=None, + ): + n_spikes, n_channels = masks.shape + masks = np.atleast_2d(masks) + assert masks.ndim == 2 + depth = (-0.1 - (spike_clusters_rel[:, np.newaxis] + masks) / + float(n_clusters + 10.)) + depth[masks <= 0.25] = 0 + assert depth.shape == (n_spikes, n_channels) + return depth + + +def _get_color(masks, spike_clusters_rel=None, n_clusters=None): + n_spikes = len(masks) + assert masks.shape == (n_spikes,) + assert spike_clusters_rel.shape == (n_spikes,) + + # Fetch the features. + colors = _selected_clusters_colors(n_clusters) + + # Color as a function of the mask. + color = colors[spike_clusters_rel] + hsv = rgb_to_hsv(color[:, :3]) + # Change the saturation and value as a function of the mask. + hsv[:, 1] *= masks + hsv[:, 2] *= .5 * (1. + masks) + color = hsv_to_rgb(hsv) + color = np.c_[color, .5 * np.ones((n_spikes, 1))] + return color + + +def _project_mask_depth(dim, masks, depth): + n_spikes = masks.shape[0] + if dim != 'time': + ch, fet = dim + m = masks[:, ch] + d = depth[:, ch] + else: + m = np.ones(n_spikes) + d = np.zeros(n_spikes) + return m, d + + class FeatureView(GridView): + normalization_percentile = .95 + normalization_n_spikes = 1000 + def __init__(self, features=None, - dimensions=None, - extra_features=None, + masks=None, + spike_times=None, + spike_clusters=None, + keys='interactive', ): - pass + + assert features.ndim == 3 + self.n_spikes, self.n_channels, self.n_features = features.shape + self.n_cols = self.n_features + 1 + self.features = features + + # Initialize the view. + super(FeatureView, self).__init__(self.n_cols, self.n_cols, keys=keys) + + # Feature normalization. + self.data_bounds = _get_data_bounds(features, + self.normalization_n_spikes, + self.normalization_percentile) + + # Masks. + self.masks = masks + + # Spike clusters. + assert spike_clusters.shape == (self.n_spikes,) + self.spike_clusters = spike_clusters + + # Spike times. + assert spike_times.shape == (self.n_spikes,) + self.spike_times = spike_times + + # Initialize the subplots. + self._plots = {(i, j): self[i, j].scatter(x=[], y=[], size=[]) + for i in range(self.n_cols) + for j in range(self.n_cols) + } + self.build() + self.update() + + def _get_feature(self, dim, spike_ids=None): + f = self.features[spike_ids] + assert f.ndim == 3 + + if dim == 'time': + t = self.spike_times[spike_ids] + t0, t1 = self.spike_times[0], self.spike_times[-1] + t = -1 + 2 * (t - t0) / float(t1 - t0) + return .9 * t + else: + assert len(dim) == 2 + ch, fet = dim + # TODO: normalization of features + return f[:, ch, fet] + + def on_select(self, cluster_ids, spike_ids): + n_clusters = len(cluster_ids) + n_spikes = len(spike_ids) + if n_spikes == 0: + return + + spike_clusters_rel = _get_spike_clusters_rel(self.spike_clusters, + spike_ids, + cluster_ids) + + masks = self.masks[spike_ids] + depth = _get_depth(masks, + spike_clusters_rel=spike_clusters_rel, + n_clusters=n_clusters) + + x_dim, y_dim = _dimensions(range(self.n_cols), + range(self.n_cols)) + + # Plot all features. + # TODO: optim: avoid the loop. + for i in range(self.n_cols): + for j in range(self.n_cols): + + x = self._get_feature(x_dim[i, j], spike_ids) + y = self._get_feature(y_dim[i, j], spike_ids) + + mx, dx = _project_mask_depth(x_dim[i, j], masks, depth) + my, dy = _project_mask_depth(y_dim[i, j], masks, depth) + + d = np.maximum(dx, dy) + m = np.maximum(mx, my) + + color = _get_color(m, + spike_clusters_rel=spike_clusters_rel, + n_clusters=n_clusters) + + self._plots[i, j].set_data(x=x, + y=y, + color=color, + depth=d, + data_bounds=self.data_bounds, + size=5 * np.ones(n_spikes), + ) + + self.build() + self.update() class CorrelogramView(GridView): From abc580d070a1f6a2fa26245d5de583fbea65635a Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 20 Nov 2015 16:23:33 +0100 Subject: [PATCH 0614/1059] Refactor waveform view --- phy/cluster/manual/views.py | 125 ++++++++++++++++-------------------- 1 file changed, 55 insertions(+), 70 deletions(-) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index bf9c52e22..84946238f 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -82,6 +82,50 @@ def _get_data_bounds(arr, n_spikes=None, percentile=None): return [-1, -m, +1, +m] +def _get_spike_clusters_rel(spike_clusters, spike_ids, cluster_ids): + # Relative spike clusters. + # NOTE: the order of the clusters in cluster_ids matters. + # It will influence the relative index of the clusters, which + # in return influence the depth. + spike_clusters = spike_clusters[spike_ids] + assert np.all(np.in1d(spike_clusters, cluster_ids)) + spike_clusters_rel = _index_of(spike_clusters, cluster_ids) + return spike_clusters_rel + + +def _get_depth(masks, + spike_clusters_rel=None, + n_clusters=None, + ): + n_spikes, n_channels = masks.shape + masks = np.atleast_2d(masks) + assert masks.ndim == 2 + depth = (-0.1 - (spike_clusters_rel[:, np.newaxis] + masks) / + float(n_clusters + 10.)) + depth[masks <= 0.25] = 0 + assert depth.shape == (n_spikes, n_channels) + return depth + + +def _get_color(masks, spike_clusters_rel=None, n_clusters=None): + n_spikes = len(masks) + assert masks.shape == (n_spikes,) + assert spike_clusters_rel.shape == (n_spikes,) + + # Fetch the features. + colors = _selected_clusters_colors(n_clusters) + + # Color as a function of the mask. + color = colors[spike_clusters_rel] + hsv = rgb_to_hsv(color[:, :3]) + # Change the saturation and value as a function of the mask. + hsv[:, 1] *= masks + hsv[:, 2] *= .5 * (1. + masks) + color = hsv_to_rgb(hsv) + color = np.c_[color, .5 * np.ones((n_spikes, 1))] + return color + + # ----------------------------------------------------------------------------- # Waveform view # ----------------------------------------------------------------------------- @@ -155,16 +199,12 @@ def on_select(self, cluster_ids, spike_ids): self._spike_ids = spike_ids # Relative spike clusters. - # NOTE: the order of the clusters in cluster_ids matters. - # It will influence the relative index of the clusters, which - # in return influence the depth. - spike_clusters = self.spike_clusters[spike_ids] - assert np.all(np.in1d(spike_clusters, cluster_ids)) - spike_clusters_rel = _index_of(spike_clusters, cluster_ids) + spike_clusters_rel = _get_spike_clusters_rel(self.spike_clusters, + spike_ids, + cluster_ids) # Fetch the waveforms. w = self.waveforms[spike_ids] - colors = _selected_clusters_colors(n_clusters) t = _get_linear_x(n_spikes, self.n_samples) # Overlap. if self.overlap: @@ -172,28 +212,17 @@ def on_select(self, cluster_ids, spike_ids): (n_clusters - 1) / 2.) # Depth as a function of the cluster index and masks. - m = self.masks[spike_ids] - m = np.atleast_2d(m) - assert m.ndim == 2 - depth = (-0.1 - (spike_clusters_rel[:, np.newaxis] + m) / - float(n_clusters + 10.)) - depth[m <= 0.25] = 0 - assert m.shape == (n_spikes, self.n_channels) - assert depth.shape == (n_spikes, self.n_channels) + masks = self.masks[spike_ids] + depth = _get_depth(masks, + spike_clusters_rel=spike_clusters_rel, + n_clusters=n_clusters) # Plot all waveforms. - # TODO: optim: avoid the loop. + # OPTIM: avoid the loop. for ch in range(self.n_channels): - - # Color as a function of the mask. - color = colors[spike_clusters_rel] - hsv = rgb_to_hsv(color[:, :3]) - # Change the saturation and value as a function of the mask. - hsv[:, 1] *= m[:, ch] - hsv[:, 2] *= .5 * (1. + m[:, ch]) - color = hsv_to_rgb(hsv) - color = np.c_[color, .5 * np.ones((n_spikes, 1))] - + color = _get_color(masks[:, ch], + spike_clusters_rel=spike_clusters_rel, + n_clusters=n_clusters) self._plots[ch].set_data(x=t, y=w[:, :, ch], color=color, depth=depth[:, ch], @@ -384,50 +413,6 @@ def _dimensions(x_channels, y_channels): return x_dim, y_dim -def _get_spike_clusters_rel(spike_clusters, spike_ids, cluster_ids): - # Relative spike clusters. - # NOTE: the order of the clusters in cluster_ids matters. - # It will influence the relative index of the clusters, which - # in return influence the depth. - spike_clusters = spike_clusters[spike_ids] - assert np.all(np.in1d(spike_clusters, cluster_ids)) - spike_clusters_rel = _index_of(spike_clusters, cluster_ids) - return spike_clusters_rel - - -def _get_depth(masks, - spike_clusters_rel=None, - n_clusters=None, - ): - n_spikes, n_channels = masks.shape - masks = np.atleast_2d(masks) - assert masks.ndim == 2 - depth = (-0.1 - (spike_clusters_rel[:, np.newaxis] + masks) / - float(n_clusters + 10.)) - depth[masks <= 0.25] = 0 - assert depth.shape == (n_spikes, n_channels) - return depth - - -def _get_color(masks, spike_clusters_rel=None, n_clusters=None): - n_spikes = len(masks) - assert masks.shape == (n_spikes,) - assert spike_clusters_rel.shape == (n_spikes,) - - # Fetch the features. - colors = _selected_clusters_colors(n_clusters) - - # Color as a function of the mask. - color = colors[spike_clusters_rel] - hsv = rgb_to_hsv(color[:, :3]) - # Change the saturation and value as a function of the mask. - hsv[:, 1] *= masks - hsv[:, 2] *= .5 * (1. + masks) - color = hsv_to_rgb(hsv) - color = np.c_[color, .5 * np.ones((n_spikes, 1))] - return color - - def _project_mask_depth(dim, masks, depth): n_spikes = masks.shape[0] if dim != 'time': From 921680054b27e9190146046c776020735d70374b Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 20 Nov 2015 16:35:29 +0100 Subject: [PATCH 0615/1059] WIP: refactor views --- phy/cluster/manual/views.py | 59 +++++++++++++++++++------------------ 1 file changed, 30 insertions(+), 29 deletions(-) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 84946238f..96dbff213 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -93,28 +93,26 @@ def _get_spike_clusters_rel(spike_clusters, spike_ids, cluster_ids): return spike_clusters_rel -def _get_depth(masks, - spike_clusters_rel=None, - n_clusters=None, - ): - n_spikes, n_channels = masks.shape - masks = np.atleast_2d(masks) - assert masks.ndim == 2 - depth = (-0.1 - (spike_clusters_rel[:, np.newaxis] + masks) / +def _get_depth(masks, spike_clusters_rel=None, n_clusters=None): + """Return the OpenGL z-depth of vertices as a function of the + mask and cluster index.""" + n_spikes = len(masks) + assert masks.shape == (n_spikes,) + depth = (-0.1 - (spike_clusters_rel + masks) / float(n_clusters + 10.)) depth[masks <= 0.25] = 0 - assert depth.shape == (n_spikes, n_channels) + assert depth.shape == (n_spikes,) return depth def _get_color(masks, spike_clusters_rel=None, n_clusters=None): + """Return the color of vertices as a function of the mask and + cluster index.""" n_spikes = len(masks) assert masks.shape == (n_spikes,) assert spike_clusters_rel.shape == (n_spikes,) - - # Fetch the features. + # Generate the colors. colors = _selected_clusters_colors(n_clusters) - # Color as a function of the mask. color = colors[spike_clusters_rel] hsv = rgb_to_hsv(color[:, :3]) @@ -213,19 +211,20 @@ def on_select(self, cluster_ids, spike_ids): # Depth as a function of the cluster index and masks. masks = self.masks[spike_ids] - depth = _get_depth(masks, - spike_clusters_rel=spike_clusters_rel, - n_clusters=n_clusters) # Plot all waveforms. # OPTIM: avoid the loop. for ch in range(self.n_channels): - color = _get_color(masks[:, ch], + m = masks[:, ch] + depth = _get_depth(m, + spike_clusters_rel=spike_clusters_rel, + n_clusters=n_clusters) + color = _get_color(m, spike_clusters_rel=spike_clusters_rel, n_clusters=n_clusters) self._plots[ch].set_data(x=t, y=w[:, :, ch], color=color, - depth=depth[:, ch], + depth=depth, data_bounds=self.data_bounds, ) @@ -413,12 +412,14 @@ def _dimensions(x_channels, y_channels): return x_dim, y_dim -def _project_mask_depth(dim, masks, depth): +def _project_mask_depth(dim, masks, spike_clusters_rel=None, n_clusters=None): n_spikes = masks.shape[0] if dim != 'time': ch, fet = dim m = masks[:, ch] - d = depth[:, ch] + d = _get_depth(m, + spike_clusters_rel=spike_clusters_rel, + n_clusters=n_clusters) else: m = np.ones(n_spikes) d = np.zeros(n_spikes) @@ -490,14 +491,10 @@ def on_select(self, cluster_ids, spike_ids): if n_spikes == 0: return - spike_clusters_rel = _get_spike_clusters_rel(self.spike_clusters, - spike_ids, - cluster_ids) - masks = self.masks[spike_ids] - depth = _get_depth(masks, - spike_clusters_rel=spike_clusters_rel, - n_clusters=n_clusters) + sc = _get_spike_clusters_rel(self.spike_clusters, + spike_ids, + cluster_ids) x_dim, y_dim = _dimensions(range(self.n_cols), range(self.n_cols)) @@ -510,14 +507,18 @@ def on_select(self, cluster_ids, spike_ids): x = self._get_feature(x_dim[i, j], spike_ids) y = self._get_feature(y_dim[i, j], spike_ids) - mx, dx = _project_mask_depth(x_dim[i, j], masks, depth) - my, dy = _project_mask_depth(y_dim[i, j], masks, depth) + mx, dx = _project_mask_depth(x_dim[i, j], masks, + spike_clusters_rel=sc, + n_clusters=n_clusters) + my, dy = _project_mask_depth(y_dim[i, j], masks, + spike_clusters_rel=sc, + n_clusters=n_clusters) d = np.maximum(dx, dy) m = np.maximum(mx, my) color = _get_color(m, - spike_clusters_rel=spike_clusters_rel, + spike_clusters_rel=sc, n_clusters=n_clusters) self._plots[i, j].set_data(x=x, From f5cc30c8de042a1a8574ac1ff956025d1742158e Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 20 Nov 2015 16:57:42 +0100 Subject: [PATCH 0616/1059] Default show_shortcuts() action in GUI --- phy/cluster/manual/gui_component.py | 4 +--- phy/cluster/manual/views.py | 10 +++++++--- phy/gui/actions.py | 6 +++++- phy/gui/gui.py | 14 ++++++++++++-- phy/gui/tests/test_actions.py | 6 ++++++ 5 files changed, 31 insertions(+), 9 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index c92d8eb6c..99290a63c 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -175,7 +175,6 @@ class ManualClustering(object): # Misc. 'save': 'Save', - 'show_shortcuts': 'shift+h', 'undo': 'Undo', 'redo': 'Redo', } @@ -190,7 +189,7 @@ def __init__(self, self.gui = None self.n_spikes_max_per_cluster = n_spikes_max_per_cluster - # Load default shortcuts, and override any user shortcuts. + # Load default shortcuts, and override with any user shortcuts. self.shortcuts = self.default_shortcuts.copy() self.shortcuts.update(shortcuts or {}) @@ -279,7 +278,6 @@ def _create_actions(self, gui): self.actions.add(self.previous_best) # Others. - self.actions.add(self.actions.show_shortcuts) self.actions.add(self.undo) self.actions.add(self.redo) self.actions.add(self.save) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 96dbff213..0296ae78e 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -142,6 +142,7 @@ def __init__(self, masks=None, spike_clusters=None, channel_positions=None, + shortcuts=None, keys='interactive', ): """ @@ -151,6 +152,10 @@ def __init__(self, """ + # Load default shortcuts, and override with any user shortcuts. + self.shortcuts = self.default_shortcuts.copy() + self.shortcuts.update(shortcuts or {}) + self._cluster_ids = None self._spike_ids = None @@ -243,9 +248,8 @@ def attach(self, gui): gui.connect_(self.on_select) # gui.connect_(self.on_cluster) - # TODO: customizable shortcut too - self.actions = Actions(gui, default_shortcuts=self.default_shortcuts) - self.actions.add(self.toggle_waveform_overlap, alias='o') + self.actions = Actions(gui, default_shortcuts=self.shortcuts) + self.actions.add(self.toggle_waveform_overlap) def toggle_waveform_overlap(self): self.overlap = not self.overlap diff --git a/phy/gui/actions.py b/phy/gui/actions.py index b87a1ce19..064651670 100644 --- a/phy/gui/actions.py +++ b/phy/gui/actions.py @@ -66,7 +66,10 @@ def _get_shortcut_string(shortcut): if isinstance(shortcut, (tuple, list)): return ', '.join([_get_shortcut_string(s) for s in shortcut]) if isinstance(shortcut, string_types): - return shortcut.lower() + if hasattr(QKeySequence, shortcut): + shortcut = QKeySequence(getattr(QKeySequence, shortcut)) + else: + return shortcut.lower() assert isinstance(shortcut, QKeySequence) s = shortcut.toString() or '' return str(s).lower() @@ -141,6 +144,7 @@ def __init__(self, gui, default_shortcuts=None): self._aliases = {} self._default_shortcuts = default_shortcuts or {} self.gui = gui + gui.actions.append(self) # Create and attach snippets. self.snippets = Snippets(gui, self) diff --git a/phy/gui/gui.py b/phy/gui/gui.py index 956dff335..962479c6e 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -12,7 +12,7 @@ from .qt import (QApplication, QWidget, QDockWidget, QStatusBar, QMainWindow, Qt, QSize, QMetaObject) -from .actions import Actions +from .actions import Actions, _show_shortcuts from phy.utils.event import EventEmitter from phy.utils import load_master_config from phy.utils.plugin import get_plugin @@ -119,9 +119,19 @@ def __init__(self, self._status_bar = QStatusBar() self.setStatusBar(self._status_bar) - # Default exit action. + # List of attached Actions instances. + self.actions = [] + + # Default actions. self.default_actions = Actions(self) + @self.default_actions.add(shortcut='HelpContents') + def show_shortcuts(): + shortcuts = {} + for actions in self.actions: + shortcuts.update(actions.shortcuts) + _show_shortcuts(shortcuts, self.name) + @self.default_actions.add(shortcut='Quit') def exit(): self.close() diff --git a/phy/gui/tests/test_actions.py b/phy/gui/tests/test_actions.py index 0dcb98745..cac426772 100644 --- a/phy/gui/tests/test_actions.py +++ b/phy/gui/tests/test_actions.py @@ -22,6 +22,8 @@ #------------------------------------------------------------------------------ def test_shortcuts(qapp): + assert 'z' in _get_shortcut_string('Undo') + def _assert_shortcut(name, key=None): shortcut = _get_qkeysequence(name) s = _get_shortcut_string(shortcut) @@ -106,6 +108,10 @@ def press(): actions.press() assert _press == [0] + with captured_output() as (stdout, stderr): + gui.default_actions.show_shortcuts() + assert 'g\n' in stdout.getvalue() + def test_snippets_gui(qtbot, gui, actions): qtbot.addWidget(gui) From 5d6349ee3394421c2b4b01b45cd0a14446727957 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 20 Nov 2015 17:18:59 +0100 Subject: [PATCH 0617/1059] WIP: add channel selection in feature view --- phy/cluster/manual/views.py | 70 ++++++++++++++++++++++++++++++++++--- 1 file changed, 65 insertions(+), 5 deletions(-) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 0296ae78e..111103f39 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -11,6 +11,7 @@ import numpy as np from matplotlib.colors import hsv_to_rgb, rgb_to_hsv +from six import string_types from phy.io.array import _index_of, _get_padded from phy.electrode.mea import linear_positions @@ -18,6 +19,7 @@ from phy.plot import (BoxedView, StackedView, GridView, _get_linear_x) from phy.plot.utils import _get_boxes +from phy.utils._types import _is_integer logger = logging.getLogger(__name__) @@ -386,8 +388,26 @@ def set_interval(self, interval): # Feature view # ----------------------------------------------------------------------------- -def _dimensions(x_channels, y_channels): - """Default dimensions matrix.""" +def _check_dimension(dim, n_channels, n_features): + """Check that a dimension is valid.""" + if _is_integer(dim): + dim = (dim, 0) + if isinstance(dim, tuple): + assert len(dim) == 2 + channel, feature = dim + assert _is_integer(channel) + assert _is_integer(feature) + assert 0 <= channel < n_channels + assert 0 <= feature < n_features + elif isinstance(dim, string_types): + assert dim == 'time' + elif dim: + raise ValueError('{0} should be (channel, feature) '.format(dim) + + 'or one of the extra features.') + + +def _dimensions_matrix(x_channels, y_channels): + """Dimensions matrix.""" # time, depth time, (x, 0) time, (y, 0) time, (z, 0) # time, (x', 0) (x', 0), (x, 0) (x', 1), (y, 0) (x', 2), (z, 0) # time, (y', 0) (y', 0), (x, 1) (y', 1), (y, 1) (y', 2), (z, 1) @@ -397,7 +417,7 @@ def _dimensions(x_channels, y_channels): assert len(y_channels) == n y_dim = {} x_dim = {} - # TODO: depth + # TODO: extra feature like probe depth x_dim[0, 0] = 'time' y_dim[0, 0] = 'time' @@ -416,7 +436,41 @@ def _dimensions(x_channels, y_channels): return x_dim, y_dim +def _dimensions_for_clusters(cluster_ids, n_cols=None, + best_channels_func=None): + """Return the dimension matrix for the selected clusters.""" + n = len(cluster_ids) + if not n: + return {}, {} + best_channels_func = best_channels_func or (lambda _: range(n_cols)) + x_channels = best_channels_func(cluster_ids[min(1, n - 1)]) + y_channels = best_channels_func(cluster_ids[0]) + y_channels = y_channels[:n_cols - 1] + # For the x axis, remove the channels that already are in + # the y axis. + x_channels = [c for c in x_channels if c not in y_channels] + # Now, select the right number of channels in the x axis. + x_channels = x_channels[:n_cols - 1] + if len(x_channels) < n_cols - 1: + x_channels = y_channels + return _dimensions_matrix(x_channels, y_channels) + + +def _smart_dim(dim, n_features=None, prev_dim=None, prev_dim_other=None): + channel, feature = dim + prev_channel, prev_feature = prev_dim + # Scroll the feature if the channel is the same. + if prev_channel == channel: + feature = (prev_feature + 1) % n_features + # Scroll the feature if it is the same than in the other axis. + if (prev_dim_other != 'time' and + prev_dim_other == (channel, feature)): + feature = (feature + 1) % n_features + dim = (channel, feature) + + def _project_mask_depth(dim, masks, spike_clusters_rel=None, n_clusters=None): + """Return the mask and depth vectors for a given dimension.""" n_spikes = masks.shape[0] if dim != 'time': ch, fet = dim @@ -500,8 +554,10 @@ def on_select(self, cluster_ids, spike_ids): spike_ids, cluster_ids) - x_dim, y_dim = _dimensions(range(self.n_cols), - range(self.n_cols)) + x_dim, y_dim = _dimensions_for_clusters(cluster_ids, + n_cols=self.n_cols, + # TODO + best_channels_func=None) # Plot all features. # TODO: optim: avoid the loop. @@ -537,6 +593,10 @@ def on_select(self, cluster_ids, spike_ids): self.update() +# ----------------------------------------------------------------------------- +# Correlogram view +# ----------------------------------------------------------------------------- + class CorrelogramView(GridView): def __init__(self, spike_samples=None, From 12c89f30721ca87885191240cef859b1df750c2a Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 20 Nov 2015 19:20:00 +0100 Subject: [PATCH 0618/1059] Improve correlograms() interface --- phy/stats/ccg.py | 106 ++++++++++++++++++------------------ phy/stats/tests/test_ccg.py | 41 +++++++------- 2 files changed, 73 insertions(+), 74 deletions(-) diff --git a/phy/stats/ccg.py b/phy/stats/ccg.py index 448f862af..98281a9f6 100644 --- a/phy/stats/ccg.py +++ b/phy/stats/ccg.py @@ -36,26 +36,49 @@ def _create_correlograms_array(n_clusters, winsize_bins): dtype=np.int32) -def correlograms(spike_samples, spike_clusters, - cluster_order=None, - binsize=None, winsize_bins=None): +def _symmetrize_correlograms(correlograms): + """Return the symmetrized version of the CCG arrays.""" + + n_clusters, _, n_bins = correlograms.shape + assert n_clusters == _ + + # We symmetrize c[i, j, 0]. + # This is necessary because the algorithm in correlograms() + # is sensitive to the order of identical spikes. + correlograms[..., 0] = np.maximum(correlograms[..., 0], + correlograms[..., 0].T) + + sym = correlograms[..., 1:][..., ::-1] + sym = np.transpose(sym, (1, 0, 2)) + + return np.dstack((sym, correlograms)) + + +def correlograms(spike_times, + spike_clusters, + cluster_ids=None, + sample_rate=1., + bin_size=None, + window_size=None, + symmetrize=True, + ): """Compute all pairwise cross-correlograms among the clusters appearing in `spike_clusters`. Parameters ---------- - spike_samples : array-like - Spike times in samples (integers). + spike_times : array-like + Spike times in seconds. spike_clusters : array-like Spike-cluster mapping. - cluster_order : array-like + cluster_ids : array-like The list of unique clusters, in any order. That order will be used in the output array. - binsize : int - Number of time samples in one bin. - winsize_bins : int (odd number) - Number of bins in the window. + bin_size : float + Size of the bin, in seconds. + window_size : float + Size of the window, in seconds. Returns ------- @@ -64,37 +87,35 @@ def correlograms(spike_samples, spike_clusters, A `(n_clusters, n_clusters, winsize_samples)` array with all pairwise CCGs. - Notes - ----- - - If winsize_samples is the (odd) number of time samples in the window - then: - - winsize_bins = 2 * ((winsize_samples // 2) // binsize) + 1 - assert winsize_bins % 2 == 1 - - For performance reasons, it is recommended to compute the CCGs on a subset - with only a few thousands or tens of thousands of spikes. - """ + assert sample_rate > 0. - spike_clusters = _as_array(spike_clusters) - spike_samples = _as_array(spike_samples) - if spike_samples.dtype in (np.int32, np.int64): - spike_samples = spike_samples.astype(np.uint64) + # Get the spike samples. + spike_times = np.asarray(spike_times, dtype=np.float64) + spike_samples = (spike_times * sample_rate).astype(np.int64) - assert spike_samples.dtype == np.uint64 + spike_clusters = _as_array(spike_clusters) assert spike_samples.ndim == 1 assert spike_samples.shape == spike_clusters.shape + # Find `binsize`. + bin_size = np.clip(bin_size, 1e-5, 1e5) # in seconds + binsize = int(sample_rate * bin_size) # in samples + assert binsize >= 1 + + # Find `winsize_bins`. + window_size = np.clip(window_size, 1e-5, 1e5) # in seconds + winsize_bins = 2 * int(.5 * window_size / bin_size) + 1 + + assert winsize_bins >= 1 assert winsize_bins % 2 == 1 # Take the cluster oder into account. - if cluster_order is None: + if cluster_ids is None: clusters = _unique(spike_clusters) else: - clusters = _as_array(cluster_order) + clusters = _as_array(cluster_ids) n_clusters = len(clusters) # Like spike_clusters, but with 0..n_clusters-1 indices. @@ -148,26 +169,7 @@ def correlograms(spike_samples, spike_clusters, np.arange(n_clusters), 0] = 0 - return correlograms - - -#------------------------------------------------------------------------------ -# Helper functions for CCG data structures -#------------------------------------------------------------------------------ - -def _symmetrize_correlograms(correlograms): - """Return the symmetrized version of the CCG arrays.""" - - n_clusters, _, n_bins = correlograms.shape - assert n_clusters == _ - - # We symmetrize c[i, j, 0]. - # This is necessary because the algorithm in correlograms() - # is sensitive to the order of identical spikes. - correlograms[..., 0] = np.maximum(correlograms[..., 0], - correlograms[..., 0].T) - - sym = correlograms[..., 1:][..., ::-1] - sym = np.transpose(sym, (1, 0, 2)) - - return np.dstack((sym, correlograms)) + if symmetrize: + return _symmetrize_correlograms(correlograms) + else: + return correlograms diff --git a/phy/stats/tests/test_ccg.py b/phy/stats/tests/test_ccg.py index e5eb6eeb2..a4dd8ccdd 100644 --- a/phy/stats/tests/test_ccg.py +++ b/phy/stats/tests/test_ccg.py @@ -11,8 +11,8 @@ from ..ccg import (_increment, _diff_shifted, - correlograms, _symmetrize_correlograms, + correlograms, ) @@ -30,15 +30,7 @@ def _random_data(max_cluster): def _ccg_params(): - # window = 50 ms - winsize_samples = 2 * (25 * 20) + 1 - # bin = 1 ms - binsize = 1 * 20 - # 51 bins - winsize_bins = 2 * ((winsize_samples // 2) // binsize) + 1 - assert winsize_bins % 2 == 1 - - return binsize, winsize_bins + return .001, .05 def test_utils(): @@ -78,8 +70,8 @@ def test_ccg_0(): c_expected[0, 1, 0] = 0 # This is a peculiarity of the algorithm. c = correlograms(spike_samples, spike_clusters, - binsize=binsize, winsize_bins=winsize_bins, - cluster_order=[0, 1]) + bin_size=binsize, window_size=winsize_bins, + cluster_ids=[0, 1], symmetrize=False) ae(c, c_expected) @@ -95,7 +87,8 @@ def test_ccg_1(): c_expected[0, 0, 2] = 1 c = correlograms(spike_samples, spike_clusters, - binsize=binsize, winsize_bins=winsize_bins) + bin_size=binsize, window_size=winsize_bins, + symmetrize=False) ae(c, c_expected) @@ -106,7 +99,8 @@ def test_ccg_2(): binsize, winsize_bins = _ccg_params() c = correlograms(spike_samples, spike_clusters, - binsize=binsize, winsize_bins=winsize_bins) + bin_size=binsize, window_size=winsize_bins, + sample_rate=20000, symmetrize=False) assert c.shape == (max_cluster, max_cluster, 26) @@ -118,14 +112,16 @@ def test_ccg_symmetry_time(): binsize, winsize_bins = _ccg_params() c0 = correlograms(spike_samples, spike_clusters, - binsize=binsize, winsize_bins=winsize_bins) + bin_size=binsize, window_size=winsize_bins, + sample_rate=20000, symmetrize=False) spike_samples_1 = np.cumsum(np.r_[np.arange(1), np.diff(spike_samples)[::-1]]) spike_samples_1 = spike_samples_1.astype(np.uint64) spike_clusters_1 = spike_clusters[::-1] c1 = correlograms(spike_samples_1, spike_clusters_1, - binsize=binsize, winsize_bins=winsize_bins) + bin_size=binsize, window_size=winsize_bins, + sample_rate=20000, symmetrize=False) # The ACGs are identical. ae(c0[0, 0], c1[0, 0]) @@ -143,11 +139,13 @@ def test_ccg_symmetry_clusters(): binsize, winsize_bins = _ccg_params() c0 = correlograms(spike_samples, spike_clusters, - binsize=binsize, winsize_bins=winsize_bins) + bin_size=binsize, window_size=winsize_bins, + sample_rate=20000, symmetrize=False) spike_clusters_1 = 1 - spike_clusters c1 = correlograms(spike_samples, spike_clusters_1, - binsize=binsize, winsize_bins=winsize_bins) + bin_size=binsize, window_size=winsize_bins, + sample_rate=20000, symmetrize=False) # The ACGs are identical. ae(c0[0, 0], c1[1, 1]) @@ -162,10 +160,9 @@ def test_symmetrize_correlograms(): spike_samples, spike_clusters = _random_data(3) binsize, winsize_bins = _ccg_params() - c = correlograms(spike_samples, spike_clusters, - binsize=binsize, winsize_bins=winsize_bins) - - sym = _symmetrize_correlograms(c) + sym = correlograms(spike_samples, spike_clusters, + bin_size=binsize, window_size=winsize_bins, + sample_rate=20000) assert sym.shape == (3, 3, 51) # The ACG are reversed. From aa12b37be5618c51ea7bcb9c96b97e37e6646e9d Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 20 Nov 2015 19:20:16 +0100 Subject: [PATCH 0619/1059] Flakify --- phy/stats/tests/test_ccg.py | 1 - 1 file changed, 1 deletion(-) diff --git a/phy/stats/tests/test_ccg.py b/phy/stats/tests/test_ccg.py index a4dd8ccdd..2c5affe7c 100644 --- a/phy/stats/tests/test_ccg.py +++ b/phy/stats/tests/test_ccg.py @@ -11,7 +11,6 @@ from ..ccg import (_increment, _diff_shifted, - _symmetrize_correlograms, correlograms, ) From 51138774a0f8218444c3cc1f8568ed5b2d555439 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 20 Nov 2015 20:13:58 +0100 Subject: [PATCH 0620/1059] Export functions in stats --- phy/stats/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/phy/stats/__init__.py b/phy/stats/__init__.py index 6c9d646f2..67522338b 100644 --- a/phy/stats/__init__.py +++ b/phy/stats/__init__.py @@ -2,3 +2,5 @@ # flake8: noqa """Statistics functions.""" + +from .ccg import correlograms From f3630aa737439f01aa35dc4bbb696ec2090635ca Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 20 Nov 2015 22:11:37 +0100 Subject: [PATCH 0621/1059] Correlogram view --- phy/cluster/manual/tests/test_views.py | 45 ++++++++++++++++- phy/cluster/manual/views.py | 68 +++++++++++++++++++++++++- phy/plot/glsl/histogram.vert | 4 +- phy/plot/plot.py | 33 ++++++++----- phy/plot/tests/test_visuals.py | 2 +- phy/plot/utils.py | 6 ++- phy/plot/visuals.py | 20 ++++---- 7 files changed, 150 insertions(+), 28 deletions(-) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index c9c7a3e4e..75d175d62 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -18,7 +18,8 @@ artificial_traces, ) from phy.electrode.mea import staggered_positions -from ..views import WaveformView, FeatureView, TraceView, _extract_wave +from ..views import (WaveformView, FeatureView, CorrelogramView, TraceView, + _extract_wave) #------------------------------------------------------------------------------ @@ -177,3 +178,45 @@ def test_feature_view(qtbot): # qtbot.stop() v.close() + + +#------------------------------------------------------------------------------ +# Test correlogram view +#------------------------------------------------------------------------------ + +def test_correlogram_view(qtbot): + n_spikes = 50 + n_clusters = 2 + sample_rate = 20000. + bin_size = 1e-3 + window_size = 50e-3 + + spike_clusters = artificial_spike_clusters(n_spikes, n_clusters) + spike_times = artificial_spike_samples(n_spikes) / sample_rate + + # Create the view. + v = CorrelogramView(spike_times=spike_times, + spike_clusters=spike_clusters, + sample_rate=sample_rate, + bin_size=bin_size, + window_size=window_size, + excerpt_size=None, + n_excerpts=None, + ) + + # Select some spikes. + spike_ids = np.arange(n_spikes) + cluster_ids = np.unique(spike_clusters[spike_ids]) + v.on_select(cluster_ids, spike_ids) + + # Show the view. + v.show() + qtbot.waitForWindowShown(v.native) + + # Select other spikes. + spike_ids = np.arange(2, 10) + cluster_ids = np.unique(spike_clusters[spike_ids]) + v.on_select(cluster_ids, spike_ids) + + # qtbot.stop() + v.close() diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 111103f39..a6f7459f9 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -19,6 +19,7 @@ from phy.plot import (BoxedView, StackedView, GridView, _get_linear_x) from phy.plot.utils import _get_boxes +from phy.stats import correlograms from phy.utils._types import _is_integer logger = logging.getLogger(__name__) @@ -599,11 +600,74 @@ def on_select(self, cluster_ids, spike_ids): class CorrelogramView(GridView): def __init__(self, - spike_samples=None, spike_times=None, + spike_clusters=None, + sample_rate=None, bin_size=None, window_size=None, excerpt_size=None, n_excerpts=None, + keys='interactive', ): - pass + + assert sample_rate > 0 + self.sample_rate = sample_rate + + assert bin_size > 0 + self.bin_size = bin_size + + assert window_size > 0 + self.window_size = window_size + + # TODO: excerpt + + self.spike_times = np.asarray(spike_times) + self.n_spikes, = self.spike_times.shape + + # Initialize the view. + self.n_cols = 2 # TODO: dynamic grid shape in interact + super(CorrelogramView, self).__init__(self.n_cols, self.n_cols, + keys=keys) + + # Spike clusters. + assert spike_clusters.shape == (self.n_spikes,) + self.spike_clusters = spike_clusters + + # Initialize the subplots. + self._plots = {(i, j): self[i, j].hist(hist=[]) + for i in range(self.n_cols) + for j in range(self.n_cols) + } + self.build() + self.update() + + def on_select(self, cluster_ids, spike_ids): + n_clusters = len(cluster_ids) + n_spikes = len(spike_ids) + if n_spikes == 0: + return + + ccg = correlograms(self.spike_times, + self.spike_clusters, + cluster_ids=cluster_ids, + sample_rate=self.sample_rate, + bin_size=self.bin_size, + window_size=self.window_size, + ) + + lim = ccg.max() + + colors = _selected_clusters_colors(n_clusters) + + for i in range(n_clusters): + for j in range(n_clusters): + hist = ccg[i, j, :] + color = colors[i] if i == j else np.ones(3) + color = np.hstack((color, [1])) + self._plots[i, j].set_data(hist=hist, + color=color, + ylim=[lim], + ) + + self.build() + self.update() diff --git a/phy/plot/glsl/histogram.vert b/phy/plot/glsl/histogram.vert index c4055ebbd..790471553 100644 --- a/phy/plot/glsl/histogram.vert +++ b/phy/plot/glsl/histogram.vert @@ -3,7 +3,7 @@ attribute vec2 a_position; attribute float a_hist_index; // 0..n_hists-1 -uniform sampler2D u_hist_colors; +uniform sampler2D u_color; uniform sampler2D u_hist_bounds; uniform float n_hists; @@ -17,6 +17,6 @@ void main() { hist_bounds = hist_bounds * 10.; // NOTE: avoid texture clipping gl_Position = transform(a_position); - v_color = fetch_texture(a_hist_index, u_hist_colors, n_hists); + v_color = fetch_texture(a_hist_index, u_color, n_hists); v_hist_index = a_hist_index; } diff --git a/phy/plot/plot.py b/phy/plot/plot.py index 6126730ca..7c7f128fd 100644 --- a/phy/plot/plot.py +++ b/phy/plot/plot.py @@ -70,15 +70,25 @@ def _prepare_plot(x, y, color=None, depth=None, data_bounds=None): return dict(x=x, y=y, color=color, depth=depth, data_bounds=data_bounds) -def _prepare_hist(data, color=None): - # Validate data. - if data.ndim == 1: - data = data[np.newaxis, :] - assert data.ndim == 2 - n_hists, n_samples = data.shape +def _prepare_hist(hist, ylim=None, color=None): + hist = np.asarray(hist) + # Validate hist. + if hist.ndim == 1: + hist = hist[np.newaxis, :] + assert hist.ndim == 2 + n_hists, n_samples = hist.shape + # y-limit + if ylim is None: + ylim = hist.max() if hist.size else 1. + if not hasattr(ylim, '__len__'): + ylim = [ylim] + ylim = np.atleast_2d(ylim) + if len(ylim) == 1: + ylim = np.tile(ylim, (n_hists, 1)) + assert len(ylim) == n_hists # Get the colors. color = _get_array(color, (n_hists, 4), HistogramVisual._default_color) - return dict(data=data, color=color) + return dict(hist=hist, ylim=ylim, color=color) def _prepare_box_index(box_index, n): @@ -131,13 +141,14 @@ def _build_histogram(items): ac = Accumulator() for item in items: - n = item.data.data.size - ac['data'] = item.data.data - ac['hist_colors'] = item.data.color + n = item.data.hist.size + ac['hist'] = item.data.hist + ac['color'] = item.data.color + ac['ylim'] = item.data.ylim # NOTE: the `6 * ` comes from the histogram tesselation. ac['box_index'] = _prepare_box_index(item.box_index, 6 * n) - return (dict(hist=ac['data'], hist_colors=ac['hist_colors']), + return (dict(hist=ac['hist'], ylim=ac['ylim'], color=ac['color']), ac['box_index']) diff --git a/phy/plot/tests/test_visuals.py b/phy/plot/tests/test_visuals.py index 38bae2a40..bea0f45ca 100644 --- a/phy/plot/tests/test_visuals.py +++ b/phy/plot/tests/test_visuals.py @@ -142,7 +142,7 @@ def test_histogram_2(qtbot, canvas_pz): c[:, 3] = 1 _test_visual(qtbot, canvas_pz, HistogramVisual(), - hist=hist, hist_colors=c, hist_lims=2 * np.ones(n_hists)) + hist=hist, color=c, ylim=2 * np.ones(n_hists)) #------------------------------------------------------------------------------ diff --git a/phy/plot/utils.py b/phy/plot/utils.py index deb82b4da..c5ef290f6 100644 --- a/phy/plot/utils.py +++ b/phy/plot/utils.py @@ -72,8 +72,10 @@ def _enable_depth_mask(): def _get_texture(arr, default, n_items, from_bounds): - """Prepare data to be uploaded as a texture, with casting to uint8. + """Prepare data to be uploaded as a texture. + The from_bounds must be specified. + """ if not hasattr(default, '__len__'): # pragma: no cover default = [default] @@ -91,7 +93,7 @@ def _get_texture(arr, default, n_items, from_bounds): m, M = map(float, from_bounds) assert np.all(arr >= m) assert np.all(arr <= M) - arr = 1. * (arr - m) / (M - m) + arr = (arr - m) / (M - m) assert np.all(arr >= 0) assert np.all(arr <= 1.) arr = arr.astype(np.float32) diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index 4662f47ae..ac2a52157 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -197,8 +197,8 @@ def get_transforms(self): def set_data(self, hist=None, - hist_lims=None, - hist_colors=None, + ylim=None, + color=None, ): hist = _check_pos_2D(hist) n_hists, n_bins = hist.shape @@ -206,7 +206,8 @@ def set_data(self, # Store n_bins for get_transforms(). self.n_bins = n_bins - # Generate hist_max. + # NOTE: this must be set *before* `apply_cpu_transforms` such + # that the histogram is correctly normalized. self.hist_max = _get_hist_max(hist) # Set the transformed position. @@ -220,15 +221,16 @@ def set_data(self, self.program['a_hist_index'] = _get_index(n_hists, n_bins * 6, n) # Hist colors. - self.program['u_hist_colors'] = _get_texture(hist_colors, - self._default_color, - n_hists, [0, 1]) + self.program['u_color'] = _get_texture(color, + self._default_color, + n_hists, [0, 1]) # Hist bounds. + assert ylim is None or len(ylim) == n_hists hist_bounds = np.c_[np.zeros((n_hists, 2)), - np.ones(n_hists), - hist_lims] if hist_lims is not None else None - hist_bounds = _get_texture(hist_bounds, [0, 0, 1, self.hist_max], + np.ones((n_hists, 1)), + ylim / self.hist_max] if ylim is not None else None + hist_bounds = _get_texture(hist_bounds, [0, 0, 1, 1], n_hists, [0, 10]) self.program['u_hist_bounds'] = Texture2D(hist_bounds) self.program['n_hists'] = n_hists From 031f526446d9d8a252cab9455f09212a05635280 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 20 Nov 2015 22:28:34 +0100 Subject: [PATCH 0622/1059] WIP: grid with dynamic shape --- phy/cluster/manual/views.py | 7 ++++--- phy/plot/interact.py | 17 ++++++++--------- phy/plot/plot.py | 6 +++--- phy/plot/tests/test_interact.py | 2 +- 4 files changed, 16 insertions(+), 16 deletions(-) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index a6f7459f9..053daed4d 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -500,10 +500,11 @@ def __init__(self, assert features.ndim == 3 self.n_spikes, self.n_channels, self.n_features = features.shape self.n_cols = self.n_features + 1 + self.shape = (self.n_cols, self.n_cols) self.features = features # Initialize the view. - super(FeatureView, self).__init__(self.n_cols, self.n_cols, keys=keys) + super(FeatureView, self).__init__(self.shape, keys=keys) # Feature normalization. self.data_bounds = _get_data_bounds(features, @@ -626,8 +627,8 @@ def __init__(self, # Initialize the view. self.n_cols = 2 # TODO: dynamic grid shape in interact - super(CorrelogramView, self).__init__(self.n_cols, self.n_cols, - keys=keys) + self.shape = (self.n_cols, self.n_cols) + super(CorrelogramView, self).__init__(self.shape, keys=keys) # Spike clusters. assert spike_clusters.shape == (self.n_spikes,) diff --git a/phy/plot/interact.py b/phy/plot/interact.py index 7171d7d71..7fcae5c4a 100644 --- a/phy/plot/interact.py +++ b/phy/plot/interact.py @@ -30,26 +30,25 @@ class Grid(BaseInteract): Parameters ---------- - n_rows : int - Number of rows in the grid. - n_cols : int - Number of cols in the grid. + shape : tuple or str + Number of rows, cols in the grid. box_var : str Name of the GLSL variable with the box index. """ - def __init__(self, n_rows, n_cols, box_var=None): + def __init__(self, shape, box_var=None): super(Grid, self).__init__() self._zoom = 1. # Name of the variable with the box index. self.box_var = box_var or 'a_box_index' - self.shape = (n_rows, n_cols) - assert len(self.shape) == 2 - assert self.shape[0] >= 1 - assert self.shape[1] >= 1 + self.shape = shape + if isinstance(self.shape, tuple): + assert len(self.shape) == 2 + assert self.shape[0] >= 1 + assert self.shape[1] >= 1 def get_shader_declarations(self): return ('attribute vec2 a_box_index;\n' diff --git a/phy/plot/plot.py b/phy/plot/plot.py index 7c7f128fd..8832a613b 100644 --- a/phy/plot/plot.py +++ b/phy/plot/plot.py @@ -324,10 +324,10 @@ def build(self): class GridView(BaseView): """A 2D grid with clipping.""" - def __init__(self, n_rows, n_cols, **kwargs): - self.n_rows, self.n_cols = n_rows, n_cols + def __init__(self, shape, **kwargs): + self.n_rows, self.n_cols = shape pz = PanZoom(aspect=None, constrain_bounds=NDC) - interacts = [Grid(n_rows, n_cols), pz] + interacts = [Grid(shape), pz] super(GridView, self).__init__(interacts, **kwargs) diff --git a/phy/plot/tests/test_interact.py b/phy/plot/tests/test_interact.py index e92f984cc..0988eedf9 100644 --- a/phy/plot/tests/test_interact.py +++ b/phy/plot/tests/test_interact.py @@ -83,7 +83,7 @@ def test_grid_1(qtbot, canvas): box_index = [[i, j] for i, j in product(range(2), range(3))] box_index = np.repeat(box_index, n, axis=0) - grid = Grid(2, 3) + grid = Grid((2, 3)) _create_visual(qtbot, canvas, grid, box_index) # No effect without modifiers. From ae1c3460557d0ca4f38d469ccb81b95568436c2c Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 20 Nov 2015 22:35:29 +0100 Subject: [PATCH 0623/1059] Add tests for dynamic grid interact --- phy/plot/tests/test_interact.py | 17 +++++++++++++++++ phy/plot/tests/test_plot.py | 8 ++++---- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/phy/plot/tests/test_interact.py b/phy/plot/tests/test_interact.py index 0988eedf9..56400f410 100644 --- a/phy/plot/tests/test_interact.py +++ b/phy/plot/tests/test_interact.py @@ -109,6 +109,23 @@ def test_grid_1(qtbot, canvas): # qtbot.stop() +def test_grid_2(qtbot, canvas): + + n = 1000 + + box_index = [[i, j] for i, j in product(range(2), range(3))] + box_index = np.repeat(box_index, n, axis=0) + + class MyGrid(Grid): + def get_pre_transforms(self): + return 'vec2 u_shape = vec2(3, 3);' + + grid = MyGrid('u_shape') + _create_visual(qtbot, canvas, grid, box_index) + + # qtbot.stop() + + def test_boxed_1(qtbot, canvas): n = 6 diff --git a/phy/plot/tests/test_plot.py b/phy/plot/tests/test_plot.py index 06353335c..1ca1334c3 100644 --- a/phy/plot/tests/test_plot.py +++ b/phy/plot/tests/test_plot.py @@ -32,7 +32,7 @@ def _show(qtbot, view, stop=False): #------------------------------------------------------------------------------ def test_grid_scatter(qtbot): - view = GridView(2, 3) + view = GridView((2, 3)) n = 1000 assert isinstance(view.panzoom, PanZoom) @@ -58,7 +58,7 @@ def test_grid_scatter(qtbot): def test_grid_plot(qtbot): - view = GridView(1, 2) + view = GridView((1, 2)) n_plots, n_samples = 10, 50 x = _get_linear_x(n_plots, n_samples) @@ -71,7 +71,7 @@ def test_grid_plot(qtbot): def test_grid_hist(qtbot): - view = GridView(3, 3) + view = GridView((3, 3)) hist = np.random.rand(3, 3, 20) @@ -84,7 +84,7 @@ def test_grid_hist(qtbot): def test_grid_complete(qtbot): - view = GridView(2, 2) + view = GridView((2, 2)) t = _get_linear_x(1, 1000).ravel() view[0, 0].scatter(*np.random.randn(2, 100)) From 09ced0c037f7b5ac9696961f6f0bceee59fba4f6 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 20 Nov 2015 22:47:57 +0100 Subject: [PATCH 0624/1059] Export objects in phy.plot --- phy/plot/__init__.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/phy/plot/__init__.py b/phy/plot/__init__.py index 10b1e489b..88e2dbe71 100644 --- a/phy/plot/__init__.py +++ b/phy/plot/__init__.py @@ -12,7 +12,10 @@ from vispy import config +from .interact import Grid, Stacked, Boxed from .plot import GridView, BoxedView, StackedView # noqa +from .transform import Translate, Scale, Range, Subplot, NDC +from .panzoom import PanZoom from.visuals import _get_linear_x From aae81a9383ad61cf8c503f5fbed70e08a0e9ded0 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 23 Nov 2015 17:56:01 +0100 Subject: [PATCH 0625/1059] WIP: CCG view --- phy/cluster/manual/views.py | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 053daed4d..77cf14af4 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -626,7 +626,7 @@ def __init__(self, self.n_spikes, = self.spike_times.shape # Initialize the view. - self.n_cols = 2 # TODO: dynamic grid shape in interact + self.n_cols = 1 # TODO: dynamic grid shape in interact self.shape = (self.n_cols, self.n_cols) super(CorrelogramView, self).__init__(self.shape, keys=keys) @@ -648,8 +648,12 @@ def on_select(self, cluster_ids, spike_ids): if n_spikes == 0: return - ccg = correlograms(self.spike_times, - self.spike_clusters, + # TODO: excerpt + ind = np.in1d(self.spike_clusters, cluster_ids) + st = self.spike_times[ind] + sc = self.spike_clusters[ind] + + ccg = correlograms(st, sc, cluster_ids=cluster_ids, sample_rate=self.sample_rate, bin_size=self.bin_size, @@ -672,3 +676,17 @@ def on_select(self, cluster_ids, spike_ids): self.build() self.update() + + def attach(self, gui): + """Attach the view to the GUI.""" + + # Disable keyboard pan so that we can use arrows as global shortcuts + # in the GUI. + self.panzoom.enable_keyboard_pan = False + + gui.add_view(self) + + gui.connect_(self.on_select) + # gui.connect_(self.on_cluster) + + # self.actions = Actions(gui, default_shortcuts=self.shortcuts) From a9786c3cd895ff8e86ae28764b937485fa02c5cb Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 23 Nov 2015 17:56:53 +0100 Subject: [PATCH 0626/1059] Fix --- phy/cluster/manual/tests/test_views.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index 75d175d62..c4e803925 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -186,7 +186,7 @@ def test_feature_view(qtbot): def test_correlogram_view(qtbot): n_spikes = 50 - n_clusters = 2 + n_clusters = 1 sample_rate = 20000. bin_size = 1e-3 window_size = 50e-3 From 2fdeddbcedf2e14497cbf14b854776cfb02cff86 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 25 Nov 2015 15:31:09 +0100 Subject: [PATCH 0627/1059] WIP: simplify TransformChain implementation --- phy/plot/tests/test_transform.py | 14 +++++------ phy/plot/transform.py | 43 ++++++++++---------------------- 2 files changed, 19 insertions(+), 38 deletions(-) diff --git a/phy/plot/tests/test_transform.py b/phy/plot/tests/test_transform.py index e54b50b78..7a32cc0a6 100644 --- a/phy/plot/tests/test_transform.py +++ b/phy/plot/tests/test_transform.py @@ -14,7 +14,7 @@ from pytest import yield_fixture from ..transform import (_glslify, pixels_to_ndc, - Translate, Scale, Range, Clip, Subplot, GPU, + Translate, Scale, Range, Clip, Subplot, TransformChain, ) @@ -161,11 +161,10 @@ def array(): def test_transform_chain_empty(array): - t = TransformChain([]) + t = TransformChain() assert t.cpu_transforms == [] assert t.gpu_transforms == [] - assert t.get('GPU') is None ae(t.apply(array), array) @@ -197,11 +196,10 @@ def test_transform_chain_two(array): def test_transform_chain_complete(array): t = TransformChain([Scale(scale=.5), Scale(scale=2.)]) - t.add([Range(from_bounds=[-3, -3, 1, 1]), - GPU(), - Clip(), - Subplot(shape='u_shape', index='a_box_index'), - ]) + t.add_cpu_transforms([Range(from_bounds=[-3, -3, 1, 1])]) + t.add_gpu_transforms([Clip(), + Subplot(shape='u_shape', index='a_box_index'), + ]) assert len(t.cpu_transforms) == 3 assert len(t.gpu_transforms) == 2 diff --git a/phy/plot/transform.py b/phy/plot/transform.py index f80f250c5..514be39b3 100644 --- a/phy/plot/transform.py +++ b/phy/plot/transform.py @@ -225,43 +225,26 @@ def glsl(self, var, shape=None, index=None): # Transform chains #------------------------------------------------------------------------------ -class GPU(object): - """Used to specify that the next transforms in the chain happen on - the GPU.""" - pass - - class TransformChain(object): """A linear sequence of transforms that happen on the CPU and GPU.""" - def __init__(self, transforms=None): + def __init__(self, cpu_transforms=None, gpu_transforms=None): self.transformed_var_name = None - self.transforms = [] - self.add(transforms) - - def _index_of_gpu(self): - classes = [t.__class__.__name__ for t in self.transforms] - return classes.index('GPU') if 'GPU' in classes else None - - @property - def cpu_transforms(self): - """All transforms until `GPU()`.""" - i = self._index_of_gpu() - return self.transforms[:i] if i is not None else self.transforms - - @property - def gpu_transforms(self): - """All transforms after `GPU()`.""" - i = self._index_of_gpu() - return self.transforms[i + 1:] if i is not None else [] - - def add(self, transforms): + self.cpu_transforms = [] + self.gpu_transforms = [] + self.add_cpu_transforms(cpu_transforms or []) + self.add_gpu_transforms(gpu_transforms or []) + + def add_cpu_transforms(self, transforms): + """Add some transforms.""" + self.cpu_transforms.extend(transforms or []) + + def add_gpu_transforms(self, transforms): """Add some transforms.""" - for t in transforms: - self.transforms.append(t) + self.gpu_transforms.extend(transforms or []) def get(self, class_name): """Get a transform in the chain from its name.""" - for transform in self.transforms: + for transform in self.cpu_transforms + self.gpu_transforms: if transform.__class__.__name__ == class_name: return transform From dbf11a804fbcf9b29b0e600eb34a831cc2a726a9 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 25 Nov 2015 15:42:48 +0100 Subject: [PATCH 0628/1059] Change name of TransformChain method --- phy/plot/tests/test_transform.py | 4 ++-- phy/plot/transform.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/phy/plot/tests/test_transform.py b/phy/plot/tests/test_transform.py index 7a32cc0a6..79eba6e75 100644 --- a/phy/plot/tests/test_transform.py +++ b/phy/plot/tests/test_transform.py @@ -196,8 +196,8 @@ def test_transform_chain_two(array): def test_transform_chain_complete(array): t = TransformChain([Scale(scale=.5), Scale(scale=2.)]) - t.add_cpu_transforms([Range(from_bounds=[-3, -3, 1, 1])]) - t.add_gpu_transforms([Clip(), + t.add_on_cpu([Range(from_bounds=[-3, -3, 1, 1])]) + t.add_on_gpu([Clip(), Subplot(shape='u_shape', index='a_box_index'), ]) diff --git a/phy/plot/transform.py b/phy/plot/transform.py index 514be39b3..e5db0f7be 100644 --- a/phy/plot/transform.py +++ b/phy/plot/transform.py @@ -231,14 +231,14 @@ def __init__(self, cpu_transforms=None, gpu_transforms=None): self.transformed_var_name = None self.cpu_transforms = [] self.gpu_transforms = [] - self.add_cpu_transforms(cpu_transforms or []) - self.add_gpu_transforms(gpu_transforms or []) + self.add_on_cpu(cpu_transforms or []) + self.add_on_gpu(gpu_transforms or []) - def add_cpu_transforms(self, transforms): + def add_on_cpu(self, transforms): """Add some transforms.""" self.cpu_transforms.extend(transforms or []) - def add_gpu_transforms(self, transforms): + def add_on_gpu(self, transforms): """Add some transforms.""" self.gpu_transforms.extend(transforms or []) From 8990c82b4a4ccea4a34b0541b6f2dfe716adb1f4 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 25 Nov 2015 15:51:33 +0100 Subject: [PATCH 0629/1059] WIP: refactor visuals --- phy/plot/base.py | 78 ++++++++++++--------------- phy/plot/tests/test_base.py | 105 +++++++++++++++++------------------- phy/plot/transform.py | 11 ++++ phy/plot/visuals.py | 12 ++--- 4 files changed, 101 insertions(+), 105 deletions(-) diff --git a/phy/plot/base.py b/phy/plot/base.py index 2154ac87c..aea2131e7 100644 --- a/phy/plot/base.py +++ b/phy/plot/base.py @@ -7,13 +7,14 @@ # Imports #------------------------------------------------------------------------------ +from collections import defaultdict import logging import re from vispy import gloo from vispy.app import Canvas -from .transform import TransformChain, GPU, Clip +from .transform import TransformChain, Clip from .utils import _load_shader from phy.utils import EventEmitter @@ -38,50 +39,47 @@ class BaseVisual(object): It is rendered with a single pass of a single gloo program with a single type of GL primitive. - Derived classes must implement: - - * `gl_primitive_type`: `lines`, `points`, etc. - * `get_shaders()`: return the vertex and fragment shaders, or just - `shader_name` for built-in shaders - * `get_transforms()`: return a list of `Transform` instances, which - * `get_post_transforms()`: return a GLSL snippet to insert after - all transforms in the vertex shader. - * `set_data()`: has access to `self.program`. Must be called after - `attach()`. - """ - gl_primitive_type = None - shader_name = None - def __init__(self): + self.gl_primitive_type = None + self.transforms = TransformChain() + self._to_insert = defaultdict(list) # This will be set by attach(). self.program = None - # To override + # Visual definition # ------------------------------------------------------------------------- - def get_shaders(self): - """Return the vertex and fragment shader code.""" - assert self.shader_name - return (_load_shader(self.shader_name + '.vert'), - _load_shader(self.shader_name + '.frag')) + def set_shader(self, name): + self.vertex_shader = _load_shader(name + '.vert') + self.fragment_shader = _load_shader(name + '.frag') - def get_transforms(self): - """Return the list of transforms for the visual. + def set_primitive_type(self, primitive_type): + self.gl_primitive_type = primitive_type - There needs to be one and exactly one instance of `GPU()`. + # Shader insertion + # ------------------------------------------------------------------------- - """ - return [GPU()] + def _insert(self, shader_type, glsl, location): + assert location in ( + 'header', + 'before_transforms', + 'transforms', + 'after_transforms', + ) + self._to_insert[shader_type, location].append(glsl) - def get_post_transforms(self): - """Return a GLSL snippet to insert after all transforms in the - vertex shader. + def insert_vert(self, glsl, location): + self._insert('vert', glsl, location) - The snippet should modify `gl_Position`. + def insert_frag(self, glsl, location): + self._insert('frag', glsl, location) - """ - return '' + def get_inserts(self, shader_type, location): + return '\n'.join(self._to_insert[shader_type, location]) + + # To override + # ------------------------------------------------------------------------- def set_data(self): """Set data to the program. @@ -95,9 +93,6 @@ def set_data(self): # Public methods # ------------------------------------------------------------------------- - def apply_cpu_transforms(self, data): - return TransformChain(self.get_transforms()).apply(data) - def attach(self, canvas): """Attach the visual to a canvas. @@ -259,8 +254,6 @@ def insert_glsl(transform_chain, vertex, fragment, pre_transforms='', post_transforms=''): """Generate the GLSL code of the transform chain.""" - # TODO: move this to base.py - # Find the place where to insert the GLSL snippet. # This is "gl_Position = transform(data_var_name);" where # data_var_name is typically an attribute. @@ -332,19 +325,18 @@ def build_program(visual, interacts=()): # Build the transform chain using the visuals transforms first, # then the interact's transforms. - transforms = visual.get_transforms() + transforms = visual.transforms for interact in interacts: - transforms.extend(interact.get_transforms()) - transform_chain = TransformChain(transforms) + transforms += TransformChain(interact.get_transforms()) logger.debug("Build the program of `%s`.", visual.__class__.__name__) # Insert the interact's GLSL into the shaders. - vertex, fragment = visual.get_shaders() + vertex, fragment = visual.vertex_shader, visual.fragment_shader # Get the GLSL snippet to insert before the transformations. pre = '\n'.join(interact.get_pre_transforms() for interact in interacts) # GLSL snippet to insert after all transformations. - post = visual.get_post_transforms() - vertex, fragment = insert_glsl(transform_chain, vertex, fragment, + post = visual.get_inserts('vert', 'after_transforms') + vertex, fragment = insert_glsl(transforms, vertex, fragment, pre, post) # Insert shader declarations using the interacts (if any). diff --git a/phy/plot/tests/test_base.py b/phy/plot/tests/test_base.py index 6b33b23da..1543ad31e 100644 --- a/phy/plot/tests/test_base.py +++ b/phy/plot/tests/test_base.py @@ -13,7 +13,7 @@ from ..base import BaseVisual, BaseInteract, insert_glsl from ..transform import (subplot_bounds, Translate, Scale, Range, - Clip, Subplot, GPU, TransformChain) + Clip, Subplot, TransformChain) #------------------------------------------------------------------------------ @@ -23,8 +23,11 @@ def test_visual_shader_name(qtbot, canvas): """Test a BaseVisual with a shader name.""" class TestVisual(BaseVisual): - shader_name = 'simple' - gl_primitive_type = 'lines' + + def __init__(self): + super(TestVisual, self).__init__() + self.set_shader('simple') + self.set_primitive_type('lines') def set_data(self): self.program['a_position'] = [[-1, 0], [1, 0]] @@ -45,21 +48,21 @@ def test_base_visual(qtbot, canvas): """Test a BaseVisual with custom shaders.""" class TestVisual(BaseVisual): - vertex = """ - attribute vec2 a_position; - void main() { - gl_Position = vec4(a_position.xy, 0, 1); - } - """ - fragment = """ - void main() { - gl_FragColor = vec4(1, 1, 1, 1); - } - """ - gl_primitive_type = 'lines' - def get_shaders(self): - return self.vertex, self.fragment + def __init__(self): + super(TestVisual, self).__init__() + self.vertex_shader = """ + attribute vec2 a_position; + void main() { + gl_Position = vec4(a_position.xy, 0, 1); + } + """ + self.fragment_shader = """ + void main() { + gl_FragColor = vec4(1, 1, 1, 1); + } + """ + self.set_primitive_type('lines') def set_data(self): self.program['a_position'] = [[-1, 0], [1, 0]] @@ -91,11 +94,11 @@ def test_base_interact(): def test_no_interact(qtbot, canvas): """Test a BaseVisual with a CPU transform and no interact.""" class TestVisual(BaseVisual): - shader_name = 'simple' - gl_primitive_type = 'lines' - - def get_transforms(self): - return [Scale(scale=(.5, 1))] + def __init__(self): + super(TestVisual, self).__init__() + self.set_shader('simple') + self.set_primitive_type('lines') + self.transforms.add_on_cpu(Scale(scale=(.5, 1))) def set_data(self): self.program['a_position'] = [[-1, 0], [1, 0]] @@ -121,40 +124,31 @@ def test_interact(qtbot, canvas): """ class TestVisual(BaseVisual): - vertex = """ - attribute vec2 a_position; - void main() { - gl_Position = transform(a_position); - gl_PointSize = 2.0; - } + def __init__(self): + super(TestVisual, self).__init__() + self.vertex_shader = """ + attribute vec2 a_position; + void main() { + gl_Position = transform(a_position); + gl_PointSize = 2.0; + } """ - fragment = """ - void main() { - gl_FragColor = vec4(1, 1, 1, 1); - } - """ - gl_primitive_type = 'points' - - def get_shaders(self): - return self.vertex, self.fragment - - def get_transforms(self): - return [Scale(scale=(.1, .1)), - Translate(translate=(-1, -1)), - GPU(), - Range(from_bounds=(-1, -1, 1, 1), - to_bounds=(-1.5, -1.5, 1.5, 1.5), - ), - ] - - def get_post_transforms(self): - return """ - gl_Position.y += 1; + self.fragment_shader = """ + void main() { + gl_FragColor = vec4(1, 1, 1, 1); + } """ + self.set_primitive_type('points') + self.transforms.add_on_cpu(Scale(scale=(.1, .1))) + self.transforms.add_on_cpu(Translate(translate=(-1, -1))) + self.transforms.add_on_cpu(Range(from_bounds=(-1, -1, 1, 1), + to_bounds=(-1.5, -1.5, 1.5, 1.5), + )) + self.insert_vert("""gl_Position.y += 1;""", 'after_transforms') def set_data(self): data = np.random.uniform(0, 20, (1000, 2)).astype(np.float32) - self.program['a_position'] = self.apply_cpu_transforms(data) + self.program['a_position'] = self.transforms.apply(data) class TestInteract(BaseInteract): def get_transforms(self): @@ -179,11 +173,10 @@ def get_transforms(self): def test_transform_chain_complete(): t = TransformChain([Scale(scale=.5), Scale(scale=2.)]) - t.add([Range(from_bounds=[-3, -3, 1, 1]), - GPU(), - Clip(), - Subplot(shape='u_shape', index='a_box_index'), - ]) + t.add_on_cpu([Range(from_bounds=[-3, -3, 1, 1])]) + t.add_on_gpu([Clip(), + Subplot(shape='u_shape', index='a_box_index'), + ]) vs = dedent(""" attribute vec2 a_position; diff --git a/phy/plot/transform.py b/phy/plot/transform.py index e5db0f7be..22b927f35 100644 --- a/phy/plot/transform.py +++ b/phy/plot/transform.py @@ -236,10 +236,14 @@ def __init__(self, cpu_transforms=None, gpu_transforms=None): def add_on_cpu(self, transforms): """Add some transforms.""" + if not isinstance(transforms, list): + transforms = [transforms] self.cpu_transforms.extend(transforms or []) def add_on_gpu(self, transforms): """Add some transforms.""" + if not isinstance(transforms, list): + transforms = [transforms] self.gpu_transforms.extend(transforms or []) def get(self, class_name): @@ -253,3 +257,10 @@ def apply(self, arr): for t in self.cpu_transforms: arr = t.apply(arr) return arr + + def __add__(self, tc): + assert isinstance(tc, TransformChain) + assert tc.transformed_var_name == self.transformed_var_name + self.cpu_transforms.extend(tc.cpu_transforms) + self.gpu_transforms.extend(tc.gpu_transforms) + return self diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index ac2a52157..25a38420c 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -11,7 +11,7 @@ from vispy.gloo import Texture2D from .base import BaseVisual -from .transform import Range, GPU, NDC +from .transform import Range, NDC from .utils import (_enable_depth_mask, _tesselate_histogram, _get_texture, @@ -103,7 +103,7 @@ def set_data(self, # Set the data bounds from the data. self.data_bounds = _get_data_bounds(data_bounds, pos) - pos_tr = self.apply_cpu_transforms(pos) + pos_tr = self.transforms.apply(pos) self.program['a_position'] = _get_pos_depth(pos_tr, depth) self.program['a_size'] = _get_array(size, (n, 1), self._default_marker_size) @@ -158,7 +158,7 @@ def set_data(self, self.data_bounds = _get_data_bounds(data_bounds, pos) # Set the transformed position. - pos_tr = self.apply_cpu_transforms(pos) + pos_tr = self.transforms.apply(pos) # Depth. depth = _get_array(depth, (n_signals,), 0) @@ -212,7 +212,7 @@ def set_data(self, # Set the transformed position. pos = np.vstack(_tesselate_histogram(row) for row in hist) - pos_tr = self.apply_cpu_transforms(pos) + pos_tr = self.transforms.apply(pos) pos_tr = np.asarray(pos_tr, dtype=np.float32) assert pos_tr.shape == (n, 2) self.program['a_position'] = pos_tr @@ -263,7 +263,7 @@ def set_data(self, bounds=NDC, color=None): [x1, y0], [x1, y0], [x0, y0]], dtype=np.float32) - self.program['a_position'] = self.apply_cpu_transforms(arr) + self.program['a_position'] = self.transforms.apply(arr) # Set the color self.program['u_color'] = _get_color(color, self._default_color) @@ -280,7 +280,7 @@ def set_data(self, xs=(), ys=(), bounds=NDC, color=None): arr += [[bounds[0], y, bounds[2], y] for y in ys] arr = np.hstack(arr or [[]]).astype(np.float32) arr = arr.reshape((-1, 2)).astype(np.float32) - position = self.apply_cpu_transforms(arr) + position = self.transforms.apply(arr) self.program['a_position'] = position # Set the color From d3c19fb6086d9ef118296c56d5eb47fd46146de7 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 25 Nov 2015 16:22:49 +0100 Subject: [PATCH 0630/1059] Increase transform coverage --- phy/plot/tests/test_transform.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/phy/plot/tests/test_transform.py b/phy/plot/tests/test_transform.py index 79eba6e75..ad9e8a1ec 100644 --- a/phy/plot/tests/test_transform.py +++ b/phy/plot/tests/test_transform.py @@ -196,12 +196,17 @@ def test_transform_chain_two(array): def test_transform_chain_complete(array): t = TransformChain([Scale(scale=.5), Scale(scale=2.)]) - t.add_on_cpu([Range(from_bounds=[-3, -3, 1, 1])]) - t.add_on_gpu([Clip(), - Subplot(shape='u_shape', index='a_box_index'), - ]) + t.add_on_cpu(Range(from_bounds=[-3, -3, 1, 1])) + t.add_on_gpu(Clip()) + t.add_on_gpu([Subplot(shape='u_shape', index='a_box_index')]) assert len(t.cpu_transforms) == 3 assert len(t.gpu_transforms) == 2 ae(t.apply(array), [[0, .5], [1, 1.5]]) + + +def test_transform_chain_add(): + tc = TransformChain([Scale(scale=.5)]) + tc += TransformChain([Scale(scale=2)]) + ae(tc.apply([3]), [[3]]) From f90fbe90883343eabcf7a3d1a47b291ed3861382 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 25 Nov 2015 19:12:53 +0100 Subject: [PATCH 0631/1059] WIP: refactor plot.base, remove interacts --- phy/plot/base.py | 373 ++++++++++++------------------------ phy/plot/tests/test_base.py | 198 +++++++------------ 2 files changed, 187 insertions(+), 384 deletions(-) diff --git a/phy/plot/base.py b/phy/plot/base.py index aea2131e7..1038e2c34 100644 --- a/phy/plot/base.py +++ b/phy/plot/base.py @@ -16,7 +16,6 @@ from .transform import TransformChain, Clip from .utils import _load_shader -from phy.utils import EventEmitter logger = logging.getLogger(__name__) @@ -43,8 +42,9 @@ class BaseVisual(object): def __init__(self): self.gl_primitive_type = None self.transforms = TransformChain() - self._to_insert = defaultdict(list) - # This will be set by attach(). + self.inserter = GLSLInserter() + # The program will be set by the canvas when the visual is + # added to the canvas. self.program = None # Visual definition @@ -57,26 +57,16 @@ def set_shader(self, name): def set_primitive_type(self, primitive_type): self.gl_primitive_type = primitive_type - # Shader insertion - # ------------------------------------------------------------------------- - - def _insert(self, shader_type, glsl, location): - assert location in ( - 'header', - 'before_transforms', - 'transforms', - 'after_transforms', - ) - self._to_insert[shader_type, location].append(glsl) - - def insert_vert(self, glsl, location): - self._insert('vert', glsl, location) - - def insert_frag(self, glsl, location): - self._insert('frag', glsl, location) - - def get_inserts(self, shader_type, location): - return '\n'.join(self._to_insert[shader_type, location]) + def on_draw(self): + """Draw the visual.""" + # Skip the drawing if the program hasn't been built yet. + # The program is built by the interact. + if self.program: + # Draw the program. + self.program.draw(self.gl_primitive_type) + else: # pragma: no cover + logger.debug("Skipping drawing visual `%s` because the program " + "has not been built yet.", self) # To override # ------------------------------------------------------------------------- @@ -90,137 +80,105 @@ def set_data(self): """ raise NotImplementedError() - # Public methods - # ------------------------------------------------------------------------- - - def attach(self, canvas): - """Attach the visual to a canvas. - - After calling this method, the following properties are available: - - * self.program - - """ - logger.debug("Attach `%s` to canvas.", self.__class__.__name__) - - self.program = build_program(self, canvas.interacts) - - # NOTE: this is connect_ and not connect because we're using - # phy's event system, not VisPy's. The reason is that the order - # of the callbacks is not kept by VisPy, whereas we need the order - # to draw visuals in the order they are attached. - @canvas.connect_ - def on_draw(): - self.on_draw() - - @canvas.connect - def on_resize(event): - """Resize the OpenGL context.""" - canvas.context.set_viewport(0, 0, event.size[0], event.size[1]) - - canvas.connect(self.on_mouse_wheel) - canvas.connect(self.on_mouse_move) - canvas.connect(self.on_key_press) - - # NOTE: this might be improved. - canvas.visuals.append(self) - # HACK: allow a visual to update the canvas it is attached to. - self.update = canvas.update - - def on_mouse_move(self, e): - pass - - def on_mouse_wheel(self, e): - pass - - def on_key_press(self, e): - pass - - def on_draw(self): - """Draw the visual.""" - # Skip the drawing if the program hasn't been built yet. - # The program is built by the interact. - if self.program: - # Draw the program. - self.program.draw(self.gl_primitive_type) - #------------------------------------------------------------------------------ -# Base interact +# Build program with interacts #------------------------------------------------------------------------------ -class BaseInteract(object): - """Implement interactions for a set of attached visuals in a canvas. +def _insert_glsl(vertex, fragment, to_insert): + """Insert snippets in a shader. - Derived classes must: + to_insert is a dict `{(shader_type, location): snippet}`. - * Define a list of `transforms` + Snippets can contain `{{ var }}` placeholders for the transformed variable + name. """ - def __init__(self): - self._canvas = None - - # To override - # ------------------------------------------------------------------------- - - def get_shader_declarations(self): - """Return extra declarations for the vertex and fragment shaders.""" - return '', '' - - def get_pre_transforms(self): - """Return an optional GLSL snippet to insert into the vertex shader - before the transforms.""" - return '' + # Find the place where to insert the GLSL snippet. + # This is "gl_Position = transform(data_var_name);" where + # data_var_name is typically an attribute. + vs_regex = re.compile(r'gl_Position = transform\(([\S]+)\);') + r = vs_regex.search(vertex) + if not r: + logger.debug("The vertex shader doesn't contain the transform " + "placeholder: skipping the transform chain " + "GLSL insertion.") + return vertex, fragment + assert r + logger.log(5, "Found transform placeholder in vertex code: `%s`", + r.group(0)) - def get_transforms(self): - """Return the list of transforms.""" - return [] + # Find the GLSL variable with the data (should be a `vec2`). + var = r.group(1) + assert var and var in vertex - def update_program(self, program): - """Update a program during an interaction event.""" - pass + # Headers. + vertex = to_insert['vert', 'header'] + '\n\n' + vertex + fragment = to_insert['frag', 'header'] + '\n\n' + fragment - # Public methods - # ------------------------------------------------------------------------- + # Get the pre and post transforms. + vs_insert = to_insert['vert', 'before_transforms'] + vs_insert += to_insert['vert', 'transforms'] + vs_insert += to_insert['vert', 'after_transforms'] - @property - def size(self): - return self._canvas.size if self._canvas else None + # Insert the GLSL snippet in the vertex shader. + vertex = vs_regex.sub(indent(vs_insert), vertex) - def attach(self, canvas): - """Attach the interact to a canvas.""" - self._canvas = canvas + # Now, we make the replacements in the fragment shader. + fs_regex = re.compile(r'(void main\(\)\s*\{)') + # NOTE: we add the `void main(){` that was removed by the regex. + fs_insert = '\\1\n' + to_insert['frag', 'before_transforms'] + fragment = fs_regex.sub(indent(fs_insert), fragment) - # NOTE: this might be improved. - canvas.interacts.append(self) + # Replace the transformed variable placeholder by its name. + vertex = vertex.replace('{{ var }}', var) - canvas.connect(self.on_resize) - canvas.connect(self.on_mouse_move) - canvas.connect(self.on_mouse_wheel) - canvas.connect(self.on_key_press) + return vertex, fragment - def is_attached(self): - """Whether the interact is attached to a canvas.""" - return self._canvas is not None - def on_resize(self, event): - pass +class GLSLInserter(object): + def __init__(self): + self._to_insert = defaultdict(list) + self.insert_vert('vec2 temp_pos_tr = {{ var }};', + 'before_transforms') + self.insert_vert('gl_Position = vec4(temp_pos_tr, 0., 1.);', + 'after_transforms') + self.insert_vert('varying vec2 v_temp_pos_tr;\n', 'header') + self.insert_frag('varying vec2 v_temp_pos_tr;\n', 'header') - def on_mouse_move(self, event): - pass + def _insert(self, shader_type, glsl, location): + assert location in ( + 'header', + 'before_transforms', + 'transforms', + 'after_transforms', + ) + self._to_insert[shader_type, location].append(glsl) - def on_mouse_wheel(self, event): - pass + def insert_vert(self, glsl, location='transforms'): + self._insert('vert', glsl, location) - def on_key_press(self, event): - pass + def insert_frag(self, glsl, location=None): + self._insert('frag', glsl, location) - def update(self): - """Update the attached canvas and all attached programs.""" - if self.is_attached(): - for visual in self._canvas.visuals: - self.update_program(visual.program) - self._canvas.update() + def add_transform_chain(self, tc): + # Generate the transforms snippet. + for t in tc.gpu_transforms: + if isinstance(t, Clip): + # Set the varying value in the vertex shader. + self.insert_vert('v_temp_pos_tr = temp_pos_tr;') + continue + self.insert_vert(t.glsl('temp_pos_tr')) + # Clipping. + clip = tc.get('Clip') + if clip: + self.insert_frag(clip.glsl('v_temp_pos_tr'), 'before_transforms') + + def insert_into_shaders(self, vertex, fragment): + to_insert = defaultdict(str) + to_insert.update({key: '\n'.join(self._to_insert[key]) + for key in self._to_insert}) + return _insert_glsl(vertex, fragment, to_insert) #------------------------------------------------------------------------------ @@ -231,129 +189,38 @@ class BaseCanvas(Canvas): """A blank VisPy canvas with a custom event system that keeps the order.""" def __init__(self, *args, **kwargs): super(BaseCanvas, self).__init__(*args, **kwargs) - self._events = EventEmitter() - self.interacts = [] + self.transforms = TransformChain() self.visuals = [] - def connect_(self, *args, **kwargs): - return self._events.connect(*args, **kwargs) + def add_visual(self, visual): + """Add a visual to the canvas, and build its program by the same + occasion. - def emit_(self, *args, **kwargs): # pragma: no cover - return self._events.emit(*args, **kwargs) + We can't build the visual's program before, because we need the canvas' + transforms first. - def on_draw(self, e): - gloo.clear() - self._events.emit('draw') - - -#------------------------------------------------------------------------------ -# Build program with interacts -#------------------------------------------------------------------------------ - -def insert_glsl(transform_chain, vertex, fragment, - pre_transforms='', post_transforms=''): - """Generate the GLSL code of the transform chain.""" - - # Find the place where to insert the GLSL snippet. - # This is "gl_Position = transform(data_var_name);" where - # data_var_name is typically an attribute. - vs_regex = re.compile(r'gl_Position = transform\(([\S]+)\);') - r = vs_regex.search(vertex) - if not r: - logger.debug("The vertex shader doesn't contain the transform " - "placeholder: skipping the transform chain " - "GLSL insertion.") - return vertex, fragment - assert r - logger.log(5, "Found transform placeholder in vertex code: `%s`", - r.group(0)) - - # Find the GLSL variable with the data (should be a `vec2`). - var = r.group(1) - transform_chain.transformed_var_name = var - assert var and var in vertex - - # Generate the snippet to insert in the shaders. - temp_var = 'temp_pos_tr' - # Name for the (eventual) varying. - fvar = 'v_{}'.format(temp_var) - vs_insert = '' - # Insert the pre-transforms. - vs_insert += pre_transforms + '\n' - vs_insert += "vec2 {} = {};\n".format(temp_var, var) - for t in transform_chain.gpu_transforms: - if isinstance(t, Clip): - # Set the varying value in the vertex shader. - vs_insert += '{} = {};\n'.format(fvar, temp_var) - continue - vs_insert += t.glsl(temp_var) + '\n' - vs_insert += 'gl_Position = vec4({}, 0., 1.);\n'.format(temp_var) - vs_insert += post_transforms + '\n' - - # Clipping. - clip = transform_chain.get('Clip') - if clip: - # Varying name. - glsl_clip = clip.glsl(fvar) - - # Prepare the fragment regex. - fs_regex = re.compile(r'(void main\(\)\s*\{)') - fs_insert = '\\1\n{}'.format(glsl_clip) - - # Add the varying declaration for clipping. - varying_decl = 'varying vec2 {};\n'.format(fvar) - vertex = varying_decl + vertex - fragment = varying_decl + fragment - - # Make the replacement in the fragment shader for clipping. - fragment = fs_regex.sub(indent(fs_insert), fragment) - - # Insert the GLSL snippet of the transform chain in the vertex shader. - vertex = vs_regex.sub(indent(vs_insert), vertex) - - return vertex, fragment - - -def build_program(visual, interacts=()): - """Create the gloo program of a visual using the interacts - transforms. + """ + # Retrieve the visual's GLSL inserter. + inserter = visual.inserter + # Add the visual's transforms. + inserter.add_transform_chain(visual.transforms) + # Then, add the canvas' transforms. + inserter.add_transform_chain(self.transforms) + # Now, we insert the transforms GLSL into the shaders. + vs, fs = visual.vertex_shader, visual.fragment_shader + vs, fs = inserter.insert_into_shaders(vs, fs) + # Finally, we create the visual's program. + visual.program = gloo.Program(vs, fs) + logger.log(5, "Vertex shader: %s", vs) + logger.log(5, "Fragment shader: %s", fs) + # Register the visual in the list of visuals in the canvas. + self.visuals.append(visual) - This method is called when a visual is attached to the canvas. + def on_resize(self, event): + """Resize the OpenGL context.""" + self.context.set_viewport(0, 0, event.size[0], event.size[1]) - """ - assert visual.program is None, "The program has already been built." - - # Build the transform chain using the visuals transforms first, - # then the interact's transforms. - transforms = visual.transforms - for interact in interacts: - transforms += TransformChain(interact.get_transforms()) - - logger.debug("Build the program of `%s`.", visual.__class__.__name__) - # Insert the interact's GLSL into the shaders. - vertex, fragment = visual.vertex_shader, visual.fragment_shader - # Get the GLSL snippet to insert before the transformations. - pre = '\n'.join(interact.get_pre_transforms() for interact in interacts) - # GLSL snippet to insert after all transformations. - post = visual.get_inserts('vert', 'after_transforms') - vertex, fragment = insert_glsl(transforms, vertex, fragment, - pre, post) - - # Insert shader declarations using the interacts (if any). - if interacts: - vertex_decls, frag_decls = zip(*(interact.get_shader_declarations() - for interact in interacts)) - - vertex = '\n'.join(vertex_decls) + '\n' + vertex - fragment = '\n'.join(frag_decls) + '\n' + fragment - - logger.log(5, "Vertex shader: \n%s", vertex) - logger.log(5, "Fragment shader: \n%s", fragment) - - program = gloo.Program(vertex, fragment) - - # Update the program with all interacts. - for interact in interacts: - interact.update_program(program) - - return program + def on_draw(self, e): + gloo.clear() + for visual in self.visuals: + visual.on_draw() diff --git a/phy/plot/tests/test_base.py b/phy/plot/tests/test_base.py index 1543ad31e..7c29ca27b 100644 --- a/phy/plot/tests/test_base.py +++ b/phy/plot/tests/test_base.py @@ -7,194 +7,130 @@ # Imports #------------------------------------------------------------------------------ -from textwrap import dedent - import numpy as np +from pytest import yield_fixture -from ..base import BaseVisual, BaseInteract, insert_glsl +from ..base import BaseVisual, GLSLInserter from ..transform import (subplot_bounds, Translate, Scale, Range, Clip, Subplot, TransformChain) #------------------------------------------------------------------------------ -# Test base +# Fixtures #------------------------------------------------------------------------------ -def test_visual_shader_name(qtbot, canvas): - """Test a BaseVisual with a shader name.""" - class TestVisual(BaseVisual): - - def __init__(self): - super(TestVisual, self).__init__() - self.set_shader('simple') - self.set_primitive_type('lines') - - def set_data(self): - self.program['a_position'] = [[-1, 0], [1, 0]] - self.program['u_color'] = [1, 1, 1, 1] - - v = TestVisual() - # We need to build the program explicitly when there is no interact. - v.attach(canvas) - # Must be called *after* attach(). - v.set_data() +@yield_fixture +def vertex_shader_nohook(): + yield """ + attribute vec2 a_position; + void main() { + gl_Position = vec4(a_position.xy, 0, 1); + } + """ - canvas.show() - qtbot.waitForWindowShown(canvas.native) - # qtbot.stop() +@yield_fixture +def vertex_shader(): + yield """ + attribute vec2 a_position; + void main() { + gl_Position = transform(a_position.xy); + gl_PointSize = 2.0; + } + """ -def test_base_visual(qtbot, canvas): - """Test a BaseVisual with custom shaders.""" - class TestVisual(BaseVisual): +@yield_fixture +def fragment_shader(): + yield """ + void main() { + gl_FragColor = vec4(1, 1, 1, 1); + } + """ - def __init__(self): - super(TestVisual, self).__init__() - self.vertex_shader = """ - attribute vec2 a_position; - void main() { - gl_Position = vec4(a_position.xy, 0, 1); - } - """ - self.fragment_shader = """ - void main() { - gl_FragColor = vec4(1, 1, 1, 1); - } - """ - self.set_primitive_type('lines') - - def set_data(self): - self.program['a_position'] = [[-1, 0], [1, 0]] - - v = TestVisual() - # We need to build the program explicitly when there is no interact. - v.attach(canvas) - v.set_data() - - canvas.show() - qtbot.waitForWindowShown(canvas.native) - # qtbot.stop() - - # Simulate a mouse move. - canvas.events.mouse_move(pos=(0., 0.)) - canvas.events.key_press(text='a') - - v.update() - - -def test_base_interact(): - interact = BaseInteract() - assert interact.get_shader_declarations() == ('', '') - assert interact.get_pre_transforms() == '' - assert interact.get_transforms() == [] - interact.update_program(None) +#------------------------------------------------------------------------------ +# Test base +#------------------------------------------------------------------------------ -def test_no_interact(qtbot, canvas): - """Test a BaseVisual with a CPU transform and no interact.""" +def test_glsl_inserter_nohook(vertex_shader_nohook, fragment_shader): + vertex_shader = vertex_shader_nohook + inserter = GLSLInserter() + inserter.insert_vert('uniform float boo;', 'header') + inserter.insert_frag('// In fragment shader.', 'before_transforms') + vs, fs = inserter.insert_into_shaders(vertex_shader, fragment_shader) + assert vs == vertex_shader + assert fs == fragment_shader + + +def test_glsl_inserter_hook(vertex_shader, fragment_shader): + inserter = GLSLInserter() + inserter.insert_vert('uniform float boo;', 'header') + inserter.insert_frag('// In fragment shader.', 'before_transforms') + tc = TransformChain(gpu_transforms=[Scale(scale=.5)]) + inserter.add_transform_chain(tc) + vs, fs = inserter.insert_into_shaders(vertex_shader, fragment_shader) + assert 'temp_pos_tr = temp_pos_tr * 0.5;' in vs + assert 'uniform float boo;' in vs + assert '// In fragment shader.' in fs + + +def test_visual_1(qtbot, canvas): class TestVisual(BaseVisual): def __init__(self): super(TestVisual, self).__init__() self.set_shader('simple') self.set_primitive_type('lines') - self.transforms.add_on_cpu(Scale(scale=(.5, 1))) def set_data(self): self.program['a_position'] = [[-1, 0], [1, 0]] self.program['u_color'] = [1, 1, 1, 1] - # We attach the visual to the canvas. By default, a BaseInteract is used. v = TestVisual() - v.attach(canvas) + canvas.add_visual(v) + # Must be called *after* add_visual(). v.set_data() canvas.show() - assert not canvas.interacts qtbot.waitForWindowShown(canvas.native) # qtbot.stop() -def test_interact(qtbot, canvas): - """Test a BaseVisual with multiple CPU and GPU transforms and a - non-blank interact. +def test_visual_2(qtbot, canvas, vertex_shader, fragment_shader): + """Test a BaseVisual with multiple CPU and GPU transforms. - There should be points filling the entire lower (2, 3) subplot. + There should be points filling the entire right upper (2, 3) subplot. """ class TestVisual(BaseVisual): def __init__(self): super(TestVisual, self).__init__() - self.vertex_shader = """ - attribute vec2 a_position; - void main() { - gl_Position = transform(a_position); - gl_PointSize = 2.0; - } - """ - self.fragment_shader = """ - void main() { - gl_FragColor = vec4(1, 1, 1, 1); - } - """ + self.vertex_shader = vertex_shader + self.fragment_shader = fragment_shader self.set_primitive_type('points') self.transforms.add_on_cpu(Scale(scale=(.1, .1))) self.transforms.add_on_cpu(Translate(translate=(-1, -1))) self.transforms.add_on_cpu(Range(from_bounds=(-1, -1, 1, 1), to_bounds=(-1.5, -1.5, 1.5, 1.5), )) - self.insert_vert("""gl_Position.y += 1;""", 'after_transforms') + self.inserter.insert_vert('gl_Position.y += 1;', + 'after_transforms') def set_data(self): data = np.random.uniform(0, 20, (1000, 2)).astype(np.float32) self.program['a_position'] = self.transforms.apply(data) - class TestInteract(BaseInteract): - def get_transforms(self): - bounds = subplot_bounds(shape=(2, 3), index=(1, 2)) - return [Subplot(shape=(2, 3), index=(1, 2)), - Clip(bounds=bounds), - ] - - TestInteract().attach(canvas) + bounds = subplot_bounds(shape=(2, 3), index=(1, 2)) + canvas.transforms.add_on_gpu([Subplot(shape=(2, 3), index=(1, 2)), + Clip(bounds=bounds), + ]) # We attach the visual to the canvas. By default, a BaseInteract is used. v = TestVisual() - v.attach(canvas) + canvas.add_visual(v) v.set_data() canvas.show() - assert len(canvas.interacts) == 1 qtbot.waitForWindowShown(canvas.native) # qtbot.stop() - - -def test_transform_chain_complete(): - t = TransformChain([Scale(scale=.5), - Scale(scale=2.)]) - t.add_on_cpu([Range(from_bounds=[-3, -3, 1, 1])]) - t.add_on_gpu([Clip(), - Subplot(shape='u_shape', index='a_box_index'), - ]) - - vs = dedent(""" - attribute vec2 a_position; - void main() { - gl_Position = transform(a_position); - } - """).strip() - - fs = dedent(""" - void main() { - gl_FragColor = vec4(1., 1., 1., 1.); - } - """).strip() - vs, fs = insert_glsl(t, vs, fs) - assert 'a_box_index' in vs - assert 'v_' in vs - assert 'v_' in fs - assert 'discard' in fs - - # Increase coverage. - insert_glsl(t, vs.replace('transform', ''), fs) From 02e491368232d24e36cfe5b2d8d9b544b14d2bf0 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 25 Nov 2015 21:23:20 +0100 Subject: [PATCH 0632/1059] WIP: update visuals --- phy/plot/__init__.py | 4 +-- phy/plot/base.py | 7 ++++++ phy/plot/panzoom.py | 46 +++++++++++++++++++++------------- phy/plot/tests/test_panzoom.py | 12 +++++---- phy/plot/tests/test_visuals.py | 4 +-- phy/plot/visuals.py | 16 ++++-------- 6 files changed, 52 insertions(+), 37 deletions(-) diff --git a/phy/plot/__init__.py b/phy/plot/__init__.py index 88e2dbe71..0029ba8df 100644 --- a/phy/plot/__init__.py +++ b/phy/plot/__init__.py @@ -12,8 +12,8 @@ from vispy import config -from .interact import Grid, Stacked, Boxed -from .plot import GridView, BoxedView, StackedView # noqa +# from .interact import Grid, Stacked, Boxed +# from .plot import GridView, BoxedView, StackedView # noqa from .transform import Translate, Scale, Range, Subplot, NDC from .panzoom import PanZoom from.visuals import _get_linear_x diff --git a/phy/plot/base.py b/phy/plot/base.py index 1038e2c34..099adc41d 100644 --- a/phy/plot/base.py +++ b/phy/plot/base.py @@ -180,6 +180,10 @@ def insert_into_shaders(self, vertex, fragment): for key in self._to_insert}) return _insert_glsl(vertex, fragment, to_insert) + def __add__(self, inserter): + self._to_insert.update(inserter._to_insert) + return self + #------------------------------------------------------------------------------ # Base canvas @@ -190,6 +194,7 @@ class BaseCanvas(Canvas): def __init__(self, *args, **kwargs): super(BaseCanvas, self).__init__(*args, **kwargs) self.transforms = TransformChain() + self.inserter = GLSLInserter() self.visuals = [] def add_visual(self, visual): @@ -206,6 +211,8 @@ def add_visual(self, visual): inserter.add_transform_chain(visual.transforms) # Then, add the canvas' transforms. inserter.add_transform_chain(self.transforms) + # Also, add the canvas' inserter. + inserter += self.inserter # Now, we insert the transforms GLSL into the shaders. vs, fs = visual.vertex_shader, visual.fragment_shader vs, fs = inserter.insert_into_shaders(vs, fs) diff --git a/phy/plot/panzoom.py b/phy/plot/panzoom.py index 969c96129..c82ef835f 100644 --- a/phy/plot/panzoom.py +++ b/phy/plot/panzoom.py @@ -11,7 +11,6 @@ import numpy as np -from .base import BaseInteract from .transform import Translate, Scale, pixels_to_ndc from phy.utils._types import _as_array @@ -20,7 +19,7 @@ # PanZoom class #------------------------------------------------------------------------------ -class PanZoom(BaseInteract): +class PanZoom(object): """Pan and zoom interact. To use it: @@ -58,8 +57,6 @@ def __init__(self, pan_var_name='u_pan', zoom_var_name='u_zoom', ): - super(PanZoom, self).__init__() - if constrain_bounds: assert xmin is None assert ymin is None @@ -89,13 +86,8 @@ def __init__(self, self._zoom_to_pointer = True self._canvas_aspect = np.ones(2) - def get_shader_declarations(self): - return ('uniform vec2 {};\n'.format(self.pan_var_name) + - 'uniform vec2 {};\n'.format(self.zoom_var_name)), '' - - def get_transforms(self): - return [Translate(translate=self.pan_var_name), - Scale(scale=self.zoom_var_name)] + # Will be set when attached to a canvas. + self.canvas = None def update_program(self, program): zoom = self._zoom_aspect() @@ -338,14 +330,12 @@ def reset(self): def on_resize(self, event): """Resize event.""" - super(PanZoom, self).on_resize(event) self._set_canvas_aspect() # Update zoom level self.zoom = self._zoom def on_mouse_move(self, event): """Pan and zoom with the mouse.""" - super(PanZoom, self).on_mouse_move(event) if event.modifiers: return if event.is_dragging: @@ -361,7 +351,6 @@ def on_mouse_move(self, event): def on_mouse_wheel(self, event): """Zoom with the mouse wheel.""" - super(PanZoom, self).on_mouse_wheel(event) if event.modifiers: return dx = np.sign(event.delta[1]) * self._wheel_coeff @@ -371,8 +360,6 @@ def on_mouse_wheel(self, event): def on_key_press(self, event): """Pan and zoom with the keyboard.""" - super(PanZoom, self).on_key_press(event) - # Zooming with the keyboard. key = event.key if event.modifiers: @@ -393,7 +380,32 @@ def on_key_press(self, event): # Canvas methods # ------------------------------------------------------------------------- + @property + def size(self): + if self.canvas: + return self.canvas.size + def attach(self, canvas): """Attach this interact to a canvas.""" - super(PanZoom, self).attach(canvas) + self.canvas = canvas + canvas.panzoom = self + + canvas.transforms.add_on_gpu([Translate(translate=self.pan_var_name), + Scale(scale=self.zoom_var_name)]) + # Add the variable declarations. + vs = ('uniform vec2 {};\n'.format(self.pan_var_name) + + 'uniform vec2 {};\n'.format(self.zoom_var_name)) + canvas.inserter.insert_vert(vs, 'header') + + canvas.connect(self.on_resize) + canvas.connect(self.on_mouse_move) + canvas.connect(self.on_mouse_wheel) + canvas.connect(self.on_key_press) + self._set_canvas_aspect() + + def update(self): + if not self.canvas: + return + for visual in self.canvas.visuals: + self.update_program(visual.program) diff --git a/phy/plot/tests/test_panzoom.py b/phy/plot/tests/test_panzoom.py index 70d11ea1d..913cc9534 100644 --- a/phy/plot/tests/test_panzoom.py +++ b/phy/plot/tests/test_panzoom.py @@ -21,8 +21,10 @@ #------------------------------------------------------------------------------ class MyTestVisual(BaseVisual): - shader_name = 'simple' - gl_primitive_type = 'lines' + def __init__(self): + super(MyTestVisual, self).__init__() + self.set_shader('simple') + self.set_primitive_type('lines') def set_data(self): self.program['a_position'] = [[-1, 0], [1, 0]] @@ -33,13 +35,13 @@ def set_data(self): def panzoom(qtbot, canvas_pz): c = canvas_pz visual = MyTestVisual() - visual.attach(c) + c.add_visual(visual) visual.set_data() c.show() qtbot.waitForWindowShown(c.native) - yield c.interacts[0] + yield c.panzoom #------------------------------------------------------------------------------ @@ -49,7 +51,7 @@ def panzoom(qtbot, canvas_pz): def test_panzoom_basic_attrs(): pz = PanZoom() - assert not pz.is_attached() + # assert not pz.is_attached() # Aspect. assert pz.aspect == 1. diff --git a/phy/plot/tests/test_visuals.py b/phy/plot/tests/test_visuals.py index bea0f45ca..0d31f783c 100644 --- a/phy/plot/tests/test_visuals.py +++ b/phy/plot/tests/test_visuals.py @@ -19,7 +19,7 @@ #------------------------------------------------------------------------------ def _test_visual(qtbot, c, v, stop=False, **kwargs): - v.attach(c) + c.add_visual(v) v.set_data(**kwargs) c.show() qtbot.waitForWindowShown(c.native) @@ -43,7 +43,7 @@ def test_scatter_markers(qtbot, canvas_pz): pos = .2 * np.random.randn(n, 2) v = ScatterVisual(marker='vbar') - v.attach(c) + c.add_visual(v) v.set_data(pos=pos) c.show() diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index 25a38420c..7f3f3a2f6 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -38,9 +38,6 @@ #------------------------------------------------------------------------------ class ScatterVisual(BaseVisual): - shader_name = 'scatter' - gl_primitive_type = 'points' - _default_marker_size = 10. _default_marker = 'disc' _default_color = DEFAULT_COLOR @@ -79,14 +76,11 @@ def __init__(self, marker=None): # Enable transparency. _enable_depth_mask() - def get_shaders(self): - v, f = super(ScatterVisual, self).get_shaders() - # Replace the marker type in the shader. - f = f.replace('%MARKER', self.marker) - return v, f - - def get_transforms(self): - return [Range(from_bounds=self.data_bounds), GPU()] + self.set_shader('scatter') + self.fragment_shader = self.fragment_shader.replace('%MARKER', + self.marker) + self.set_primitive_type('points') + self.transforms.add_on_cpu(Range(from_bounds=self.data_bounds)) def set_data(self, pos=None, From 64aabb3a8f277636346bbb5ed0f15a6b9d6b96b9 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 25 Nov 2015 22:09:09 +0100 Subject: [PATCH 0633/1059] WIP: update visuals --- phy/plot/interact.py | 8 ++------ phy/plot/panzoom.py | 1 + phy/plot/visuals.py | 49 ++++++++++++++++++++++---------------------- 3 files changed, 28 insertions(+), 30 deletions(-) diff --git a/phy/plot/interact.py b/phy/plot/interact.py index 7fcae5c4a..c1069a29a 100644 --- a/phy/plot/interact.py +++ b/phy/plot/interact.py @@ -12,7 +12,6 @@ import numpy as np from vispy.gloo import Texture2D -from .base import BaseInteract from .transform import Scale, Range, Subplot, Clip, NDC from .utils import _get_texture, _get_boxes, _get_box_pos_size @@ -21,7 +20,7 @@ # Grid interact #------------------------------------------------------------------------------ -class Grid(BaseInteract): +class Grid(object): """Grid interact. NOTE: to be used in a grid, a visual must define `a_box_index` @@ -38,7 +37,6 @@ class Grid(BaseInteract): """ def __init__(self, shape, box_var=None): - super(Grid, self).__init__() self._zoom = 1. # Name of the variable with the box index. @@ -84,7 +82,6 @@ def zoom(self, value): def on_key_press(self, event): """Pan and zoom with the keyboard.""" - super(Grid, self).on_key_press(event) key = event.key # Zoom. @@ -103,7 +100,7 @@ def on_key_press(self, event): # Boxed interact #------------------------------------------------------------------------------ -class Boxed(BaseInteract): +class Boxed(object): """Boxed interact. NOTE: to be used in a boxed, a visual must define `a_box_index` @@ -129,7 +126,6 @@ def __init__(self, box_pos=None, box_size=None, box_var=None): - super(Boxed, self).__init__() self._key_pressed = None # Name of the variable with the box index. diff --git a/phy/plot/panzoom.py b/phy/plot/panzoom.py index c82ef835f..fdf927db5 100644 --- a/phy/plot/panzoom.py +++ b/phy/plot/panzoom.py @@ -409,3 +409,4 @@ def update(self): return for visual in self.canvas.visuals: self.update_program(visual.program) + self.canvas.update() diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index 7f3f3a2f6..531dee884 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -106,8 +106,6 @@ def set_data(self, class PlotVisual(BaseVisual): - shader_name = 'plot' - gl_primitive_type = 'line_strip' _default_color = DEFAULT_COLOR def __init__(self, n_samples=None): @@ -116,10 +114,9 @@ def __init__(self, n_samples=None): self.n_samples = n_samples _enable_depth_mask() - def get_transforms(self): - return [Range(from_bounds=self.data_bounds), - GPU(), - ] + self.set_shader('plot') + self.set_primitive_type('line_strip') + self.transforms.add_on_cpu(Range(from_bounds=self.data_bounds)) def set_data(self, x=None, @@ -172,8 +169,6 @@ def set_data(self, class HistogramVisual(BaseVisual): - shader_name = 'histogram' - gl_primitive_type = 'triangles' _default_color = DEFAULT_COLOR def __init__(self): @@ -181,13 +176,14 @@ def __init__(self): self.n_bins = 0 self.hist_max = 1 - def get_transforms(self): - return [Range(from_bounds=[0, 0, self.n_bins, self.hist_max], - to_bounds=[0, 0, 1, 1]), - GPU(), - Range(from_bounds='hist_bounds', # (0, 0, 1, v) - to_bounds=NDC), - ] + self.set_shader('histogram') + self.set_primitive_type('triangles') + self.transforms.add_on_cpu(Range(from_bounds=[0, 0, self.n_bins, + self.hist_max], + to_bounds=[0, 0, 1, 1])) + # (0, 0, 1, v) + self.transforms.add_on_gpu(Range(from_bounds='hist_bounds', + to_bounds=NDC)) def set_data(self, hist=None, @@ -231,21 +227,23 @@ def set_data(self, class TextVisual(BaseVisual): - shader_name = 'text' - gl_primitive_type = 'points' - - def get_transforms(self): - pass + def __init__(self): + super(TextVisual, self).__init__() + self.set_shader('text') + self.set_primitive_type('points') def set_data(self): pass class BoxVisual(BaseVisual): - shader_name = 'simple' - gl_primitive_type = 'lines' _default_color = (.35, .35, .35, 1.) + def __init__(self): + super(BoxVisual, self).__init__() + self.set_shader('simple') + self.set_primitive_type('lines') + def set_data(self, bounds=NDC, color=None): # Set the position. x0, y0, x1, y1 = bounds @@ -264,10 +262,13 @@ def set_data(self, bounds=NDC, color=None): class AxesVisual(BaseVisual): - shader_name = 'simple' - gl_primitive_type = 'lines' _default_color = (.2, .2, .2, 1.) + def __init__(self): + super(AxesVisual, self).__init__() + self.set_shader('simple') + self.set_primitive_type('lines') + def set_data(self, xs=(), ys=(), bounds=NDC, color=None): # Set the position. arr = [[x, bounds[1], x, bounds[3]] for x in xs] From a1d3799f0048af28604a9fbf0ea6d089d7a30811 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 26 Nov 2015 10:36:37 +0100 Subject: [PATCH 0634/1059] WIP: refactor transform --- phy/plot/tests/test_transform.py | 94 +++++++++++++------------ phy/plot/transform.py | 116 +++++++++++++------------------ 2 files changed, 97 insertions(+), 113 deletions(-) diff --git a/phy/plot/tests/test_transform.py b/phy/plot/tests/test_transform.py index ad9e8a1ec..771db03cd 100644 --- a/phy/plot/tests/test_transform.py +++ b/phy/plot/tests/test_transform.py @@ -23,8 +23,8 @@ # Fixtures #------------------------------------------------------------------------------ -def _check(transform, array, expected, **kwargs): - transformed = transform.apply(array, **kwargs) +def _check(transform, array, expected): + transformed = transform.apply(array) if array is None or not len(array): assert transformed == array return @@ -57,58 +57,54 @@ def test_pixels_to_ndc(): #------------------------------------------------------------------------------ def test_types(): - t = Translate() - _check(t, [], [], translate=[1, 2]) + _check(Translate([1, 2]), [], []) for ab in [[3, 4], [3., 4.]]: for arr in [ab, [ab], np.array(ab), np.array([ab]), np.array([ab, ab, ab])]: - _check(t, arr, [[4, 6]], translate=[1, 2]) + _check(Translate([1, 2]), arr, [[4, 6]]) def test_translate_cpu(): - _check(Translate(translate=[1, 2]), [3, 4], [[4, 6]]) + _check(Translate([1, 2]), [3, 4], [[4, 6]]) def test_scale_cpu(): - _check(Scale(), [3, 4], [[-3, 8]], scale=[-1, 2]) + _check(Scale([-1, 2]), [3, 4], [[-3, 8]]) def test_range_cpu(): - kwargs = dict(from_bounds=[0, 0, 1, 1], to_bounds=[-1, -1, 1, 1]) + _check(Range([0, 0, 1, 1], [-1, -1, 1, 1]), [-1, -1], [[-3, -3]]) + _check(Range([0, 0, 1, 1], [-1, -1, 1, 1]), [0, 0], [[-1, -1]]) + _check(Range([0, 0, 1, 1], [-1, -1, 1, 1]), [0.5, 0.5], [[0, 0]]) + _check(Range([0, 0, 1, 1], [-1, -1, 1, 1]), [1, 1], [[1, 1]]) - _check(Range(), [-1, -1], [[-3, -3]], **kwargs) - _check(Range(), [0, 0], [[-1, -1]], **kwargs) - _check(Range(), [0.5, 0.5], [[0, 0]], **kwargs) - _check(Range(), [1, 1], [[1, 1]], **kwargs) - - _check(Range(), [[0, .5], [1.5, -.5]], [[-1, 0], [2, -2]], **kwargs) + _check(Range([0, 0, 1, 1], [-1, -1, 1, 1]), + [[0, .5], [1.5, -.5]], [[-1, 0], [2, -2]]) def test_clip_cpu(): - kwargs = dict(bounds=[0, 1, 2, 3]) - _check(Clip(), [0, 0], [0, 0]) # Default bounds. - _check(Clip(), [0, 1], [0, 1], **kwargs) - _check(Clip(), [1, 2], [1, 2], **kwargs) - _check(Clip(), [2, 3], [2, 3], **kwargs) + _check(Clip([0, 1, 2, 3]), [0, 1], [0, 1]) + _check(Clip([0, 1, 2, 3]), [1, 2], [1, 2]) + _check(Clip([0, 1, 2, 3]), [2, 3], [2, 3]) - _check(Clip(), [-1, -1], [], **kwargs) - _check(Clip(), [3, 4], [], **kwargs) - _check(Clip(), [[-1, 0], [3, 4]], [], **kwargs) + _check(Clip([0, 1, 2, 3]), [-1, -1], []) + _check(Clip([0, 1, 2, 3]), [3, 4], []) + _check(Clip([0, 1, 2, 3]), [[-1, 0], [3, 4]], []) def test_subplot_cpu(): shape = (2, 3) - _check(Subplot(), [-1, -1], [-1, +0], index=(0, 0), shape=shape) - _check(Subplot(), [+0, +0], [-2. / 3., .5], index=(0, 0), shape=shape) + _check(Subplot(shape, (0, 0)), [-1, -1], [-1, +0]) + _check(Subplot(shape, (0, 0)), [+0, +0], [-2. / 3., .5]) - _check(Subplot(), [-1, -1], [-1, -1], index=(1, 0), shape=shape) - _check(Subplot(), [+1, +1], [-1. / 3, 0], index=(1, 0), shape=shape) + _check(Subplot(shape, (1, 0)), [-1, -1], [-1, -1]) + _check(Subplot(shape, (1, 0)), [+1, +1], [-1. / 3, 0]) - _check(Subplot(), [0, 1], [0, 0], index=(1, 1), shape=shape) + _check(Subplot(shape, (1, 1)), [0, 1], [0, 0]) #------------------------------------------------------------------------------ @@ -116,22 +112,22 @@ def test_subplot_cpu(): #------------------------------------------------------------------------------ def test_translate_glsl(): - t = Translate(translate='u_translate').glsl('x') + t = Translate('u_translate').glsl('x') assert 'x = x + u_translate' in t def test_scale_glsl(): - assert 'x = x * u_scale' in Scale().glsl('x', scale='u_scale') + assert 'x = x * u_scale' in Scale('u_scale').glsl('x') def test_range_glsl(): - assert Range(from_bounds=[-1, -1, 1, 1]).glsl('x') + assert Range([-1, -1, 1, 1]).glsl('x') expected = ('u_to.xy + (u_to.zw - u_to.xy) * (x - u_from.xy) / ' '(u_from.zw - u_from.xy)') - r = Range(to_bounds='u_to') - assert expected in r.glsl('x', from_bounds='u_from') + r = Range('u_from', 'u_to') + assert expected in r.glsl('x') def test_clip_glsl(): @@ -143,11 +139,11 @@ def test_clip_glsl(): discard; } """).strip() - assert expected in Clip().glsl('x', bounds='b') + assert expected in Clip('b').glsl('x') def test_subplot_glsl(): - glsl = Subplot().glsl('x', shape='u_shape', index='a_index') + glsl = Subplot('u_shape', 'a_index').glsl('x') assert 'x = ' in glsl @@ -170,8 +166,9 @@ def test_transform_chain_empty(array): def test_transform_chain_one(array): - translate = Translate(translate=[1, 2]) - t = TransformChain([translate]) + translate = Translate([1, 2]) + t = TransformChain() + t.add_on_cpu([translate]) assert t.cpu_transforms == [translate] assert t.gpu_transforms == [] @@ -180,9 +177,10 @@ def test_transform_chain_one(array): def test_transform_chain_two(array): - translate = Translate(translate=[1, 2]) - scale = Scale(scale=[.5, .5]) - t = TransformChain([translate, scale]) + translate = Translate([1, 2]) + scale = Scale([.5, .5]) + t = TransformChain() + t.add_on_cpu([translate, scale]) assert t.cpu_transforms == [translate, scale] assert t.gpu_transforms == [] @@ -194,11 +192,11 @@ def test_transform_chain_two(array): def test_transform_chain_complete(array): - t = TransformChain([Scale(scale=.5), - Scale(scale=2.)]) - t.add_on_cpu(Range(from_bounds=[-3, -3, 1, 1])) + t = TransformChain() + t.add_on_cpu([Scale(.5), Scale(2.)]) + t.add_on_cpu(Range([-3, -3, 1, 1])) t.add_on_gpu(Clip()) - t.add_on_gpu([Subplot(shape='u_shape', index='a_box_index')]) + t.add_on_gpu([Subplot('u_shape', 'a_box_index')]) assert len(t.cpu_transforms) == 3 assert len(t.gpu_transforms) == 2 @@ -207,6 +205,10 @@ def test_transform_chain_complete(array): def test_transform_chain_add(): - tc = TransformChain([Scale(scale=.5)]) - tc += TransformChain([Scale(scale=2)]) - ae(tc.apply([3]), [[3]]) + tc = TransformChain() + tc.add_on_cpu([Scale(.5)]) + + tc_2 = TransformChain() + tc_2.add_on_cpu([Scale(2)]) + + ae((tc + tc_2).apply([3]), [[3]]) diff --git a/phy/plot/transform.py b/phy/plot/transform.py index 22b927f35..cbc1859e3 100644 --- a/phy/plot/transform.py +++ b/phy/plot/transform.py @@ -21,12 +21,10 @@ # Utils #------------------------------------------------------------------------------ -def _wrap_apply(f, **kwargs_init): +def _wrap_apply(f): def wrapped(arr, **kwargs): if arr is None or not len(arr): return arr - # Method kwargs first, then we update with the constructor kwargs. - kwargs.update(kwargs_init) arr = np.atleast_2d(arr) arr = arr.astype(np.float32) assert arr.ndim == 2 @@ -39,29 +37,14 @@ def wrapped(arr, **kwargs): return wrapped -def _wrap_glsl(f, **kwargs_init): +def _wrap_glsl(f): def wrapped(var, **kwargs): - # Method kwargs first, then we update with the constructor kwargs. - kwargs.update(kwargs_init) out = f(var, **kwargs) out = dedent(out).strip() return out return wrapped -def _wrap(f, **kwargs_init): - """Pass extra keyword arguments to a function. - - Used to pass constructor arguments to class methods in transforms. - - """ - def wrapped(*args, **kwargs): - # Method kwargs first, then we update with the constructor kwargs. - kwargs.update(kwargs_init) - return f(*args, **kwargs) - return wrapped - - def _glslify(r): """Transform a string or a n-tuple to a valid GLSL expression.""" if isinstance(r, string_types): @@ -107,11 +90,10 @@ def pixels_to_ndc(pos, size=None): #------------------------------------------------------------------------------ class BaseTransform(object): - def __init__(self, **kwargs): - self.__dict__.update(kwargs) - # Pass the constructor kwargs to the methods. - self.apply = _wrap_apply(self.apply, **kwargs) - self.glsl = _wrap_glsl(self.glsl, **kwargs) + def __init__(self, value=None): + self.value = value + self.apply = _wrap_apply(self.apply) + self.glsl = _wrap_glsl(self.glsl) def apply(self, arr): raise NotImplementedError() @@ -121,39 +103,44 @@ def glsl(self, var): class Translate(BaseTransform): - def apply(self, arr, translate=None): + def apply(self, arr): assert isinstance(arr, np.ndarray) - return arr + np.asarray(translate) + return arr + np.asarray(self.value) - def glsl(self, var, translate=None): + def glsl(self, var): assert var return """{var} = {var} + {translate};""".format(var=var, - translate=translate) + translate=self.value) class Scale(BaseTransform): - def apply(self, arr, scale=None): - return arr * np.asarray(scale) + def apply(self, arr): + return arr * np.asarray(self.value) - def glsl(self, var, scale=None): + def glsl(self, var): assert var - return """{var} = {var} * {scale};""".format(var=var, scale=scale) + return """{var} = {var} * {scale};""".format(var=var, scale=self.value) class Range(BaseTransform): - def apply(self, arr, from_bounds=None, to_bounds=NDC): - f0 = np.asarray(from_bounds[:2]) - f1 = np.asarray(from_bounds[2:]) - t0 = np.asarray(to_bounds[:2]) - t1 = np.asarray(to_bounds[2:]) + def __init__(self, from_bounds=None, to_bounds=None): + super(Range, self).__init__() + self.from_bounds = from_bounds or NDC + self.to_bounds = to_bounds or NDC + + def apply(self, arr): + f0 = np.asarray(self.from_bounds[:2]) + f1 = np.asarray(self.from_bounds[2:]) + t0 = np.asarray(self.to_bounds[:2]) + t1 = np.asarray(self.to_bounds[2:]) return t0 + (t1 - t0) * (arr - f0) / (f1 - f0) - def glsl(self, var, from_bounds=None, to_bounds=NDC): + def glsl(self, var): assert var - from_bounds = _glslify(from_bounds) - to_bounds = _glslify(to_bounds) + from_bounds = _glslify(self.from_bounds) + to_bounds = _glslify(self.to_bounds) return ("{var} = {t}.xy + ({t}.zw - {t}.xy) * " "({var} - {f}.xy) / ({f}.zw - {f}.xy);" @@ -161,17 +148,20 @@ def glsl(self, var, from_bounds=None, to_bounds=NDC): class Clip(BaseTransform): - def apply(self, arr, bounds=NDC): - index = ((arr[:, 0] >= bounds[0]) & - (arr[:, 1] >= bounds[1]) & - (arr[:, 0] <= bounds[2]) & - (arr[:, 1] <= bounds[3])) + def __init__(self, bounds=None): + super(Clip, self).__init__() + self.bounds = bounds or NDC + + def apply(self, arr): + index = ((arr[:, 0] >= self.bounds[0]) & + (arr[:, 1] >= self.bounds[1]) & + (arr[:, 0] <= self.bounds[2]) & + (arr[:, 1] <= self.bounds[3])) return arr[index, ...] - def glsl(self, var, bounds=NDC): + def glsl(self, var): assert var - - bounds = _glslify(bounds) + bounds = _glslify(self.bounds) return """ if (({var}.x < {bounds}.x) || @@ -186,25 +176,19 @@ def glsl(self, var, bounds=NDC): class Subplot(Range): """Assume that the from_bounds is [-1, -1, 1, 1].""" - def __init__(self, **kwargs): - super(Subplot, self).__init__(**kwargs) - self.get_bounds = _wrap(self.get_bounds) - - def get_bounds(self, shape=None, index=None): - return subplot_bounds(shape=shape, index=index) + def __init__(self, shape, index=None): + super(Subplot, self).__init__() + self.shape = shape + self.index = index + self.from_bounds = NDC + if isinstance(self.shape, tuple): + self.to_bounds = subplot_bounds(shape=self.shape, index=self.index) - def apply(self, arr, shape=None, index=None): - from_bounds = NDC - to_bounds = self.get_bounds(shape=shape, index=index) - return super(Subplot, self).apply(arr, - from_bounds=from_bounds, - to_bounds=to_bounds) - - def glsl(self, var, shape=None, index=None): + def glsl(self, var): assert var - index = _glslify(index) - shape = _glslify(shape) + index = _glslify(self.index) + shape = _glslify(self.shape) snippet = """ float subplot_width = 2. / {shape}.y; @@ -227,12 +211,10 @@ def glsl(self, var, shape=None, index=None): class TransformChain(object): """A linear sequence of transforms that happen on the CPU and GPU.""" - def __init__(self, cpu_transforms=None, gpu_transforms=None): + def __init__(self): self.transformed_var_name = None self.cpu_transforms = [] self.gpu_transforms = [] - self.add_on_cpu(cpu_transforms or []) - self.add_on_gpu(gpu_transforms or []) def add_on_cpu(self, transforms): """Add some transforms.""" From aade0f96e6d6e0af05feecfa572efe35ee2f574f Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 26 Nov 2015 10:38:24 +0100 Subject: [PATCH 0635/1059] WIP: update plot.base --- phy/plot/tests/test_base.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/phy/plot/tests/test_base.py b/phy/plot/tests/test_base.py index 7c29ca27b..2d5831392 100644 --- a/phy/plot/tests/test_base.py +++ b/phy/plot/tests/test_base.py @@ -67,7 +67,8 @@ def test_glsl_inserter_hook(vertex_shader, fragment_shader): inserter = GLSLInserter() inserter.insert_vert('uniform float boo;', 'header') inserter.insert_frag('// In fragment shader.', 'before_transforms') - tc = TransformChain(gpu_transforms=[Scale(scale=.5)]) + tc = TransformChain() + tc.add_on_gpu([Scale(.5)]) inserter.add_transform_chain(tc) vs, fs = inserter.insert_into_shaders(vertex_shader, fragment_shader) assert 'temp_pos_tr = temp_pos_tr * 0.5;' in vs @@ -109,10 +110,10 @@ def __init__(self): self.vertex_shader = vertex_shader self.fragment_shader = fragment_shader self.set_primitive_type('points') - self.transforms.add_on_cpu(Scale(scale=(.1, .1))) - self.transforms.add_on_cpu(Translate(translate=(-1, -1))) - self.transforms.add_on_cpu(Range(from_bounds=(-1, -1, 1, 1), - to_bounds=(-1.5, -1.5, 1.5, 1.5), + self.transforms.add_on_cpu(Scale((.1, .1))) + self.transforms.add_on_cpu(Translate((-1, -1))) + self.transforms.add_on_cpu(Range((-1, -1, 1, 1), + (-1.5, -1.5, 1.5, 1.5), )) self.inserter.insert_vert('gl_Position.y += 1;', 'after_transforms') @@ -122,8 +123,8 @@ def set_data(self): self.program['a_position'] = self.transforms.apply(data) bounds = subplot_bounds(shape=(2, 3), index=(1, 2)) - canvas.transforms.add_on_gpu([Subplot(shape=(2, 3), index=(1, 2)), - Clip(bounds=bounds), + canvas.transforms.add_on_gpu([Subplot((2, 3), (1, 2)), + Clip(bounds), ]) # We attach the visual to the canvas. By default, a BaseInteract is used. From f5d905a4aae9c064c6032ec5bb6ec389c35d23e6 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 26 Nov 2015 10:43:20 +0100 Subject: [PATCH 0636/1059] Update panzoom --- phy/plot/panzoom.py | 4 ++-- phy/plot/tests/test_panzoom.py | 15 +++++++++++++-- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/phy/plot/panzoom.py b/phy/plot/panzoom.py index fdf927db5..bb5a95485 100644 --- a/phy/plot/panzoom.py +++ b/phy/plot/panzoom.py @@ -390,8 +390,8 @@ def attach(self, canvas): self.canvas = canvas canvas.panzoom = self - canvas.transforms.add_on_gpu([Translate(translate=self.pan_var_name), - Scale(scale=self.zoom_var_name)]) + canvas.transforms.add_on_gpu([Translate(self.pan_var_name), + Scale(self.zoom_var_name)]) # Add the variable declarations. vs = ('uniform vec2 {};\n'.format(self.pan_var_name) + 'uniform vec2 {};\n'.format(self.zoom_var_name)) diff --git a/phy/plot/tests/test_panzoom.py b/phy/plot/tests/test_panzoom.py index 913cc9534..902fb5a6e 100644 --- a/phy/plot/tests/test_panzoom.py +++ b/phy/plot/tests/test_panzoom.py @@ -51,8 +51,6 @@ def panzoom(qtbot, canvas_pz): def test_panzoom_basic_attrs(): pz = PanZoom() - # assert not pz.is_attached() - # Aspect. assert pz.aspect == 1. pz.aspect = 2. @@ -70,6 +68,19 @@ def test_panzoom_basic_attrs(): assert getattr(pz, name) == v * 2 +def test_panzoom_basic_constrain(): + pz = PanZoom(constrain_bounds=(-1, -1, 1, 1)) + + # Aspect. + assert pz.aspect == 1. + pz.aspect = 2. + assert pz.aspect == 2. + + # Constraints. + assert pz.xmin == pz.ymin == -1 + assert pz.xmax == pz.ymax == +1 + + def test_panzoom_basic_pan_zoom(): pz = PanZoom() From 79e1f246b549924f14323622494e1caf58cc3e4b Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 26 Nov 2015 10:51:05 +0100 Subject: [PATCH 0637/1059] Update visuals --- phy/plot/tests/test_visuals.py | 5 +++-- phy/plot/visuals.py | 30 +++++++++++++++++------------- 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/phy/plot/tests/test_visuals.py b/phy/plot/tests/test_visuals.py index 0d31f783c..306fa30f6 100644 --- a/phy/plot/tests/test_visuals.py +++ b/phy/plot/tests/test_visuals.py @@ -95,7 +95,8 @@ def test_plot_1(qtbot, canvas_pz): def test_plot_2(qtbot, canvas_pz): n_signals = 50 - y = 20 * np.random.randn(n_signals, 10) + n_samples = 10 + y = 20 * np.random.randn(n_signals, n_samples) # Signal colors. c = np.random.uniform(.5, 1, size=(n_signals, 4)) @@ -104,7 +105,7 @@ def test_plot_2(qtbot, canvas_pz): # Depth. depth = np.linspace(0., -1., n_signals) - _test_visual(qtbot, canvas_pz, PlotVisual(), + _test_visual(qtbot, canvas_pz, PlotVisual(n_samples=n_samples), y=y, depth=depth, data_bounds=[-1, -50, 1, 50], plot_colors=c) diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index 531dee884..627fa8664 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -80,7 +80,7 @@ def __init__(self, marker=None): self.fragment_shader = self.fragment_shader.replace('%MARKER', self.marker) self.set_primitive_type('points') - self.transforms.add_on_cpu(Range(from_bounds=self.data_bounds)) + self.transforms.add_on_cpu(Range(self.data_bounds)) def set_data(self, pos=None, @@ -110,13 +110,14 @@ class PlotVisual(BaseVisual): def __init__(self, n_samples=None): super(PlotVisual, self).__init__() - self.data_bounds = NDC self.n_samples = n_samples _enable_depth_mask() self.set_shader('plot') self.set_primitive_type('line_strip') - self.transforms.add_on_cpu(Range(from_bounds=self.data_bounds)) + + self.data_range = Range(NDC) + self.transforms.add_on_cpu(self.data_range) def set_data(self, x=None, @@ -146,7 +147,9 @@ def set_data(self, pos[:, 1] = y.ravel() pos = _check_pos_2D(pos) - self.data_bounds = _get_data_bounds(data_bounds, pos) + # Update the data range using the specified data_bounds and the data. + # NOTE: this must be called *before* transforms.apply(). + self.data_range.from_bounds = _get_data_bounds(data_bounds, pos) # Set the transformed position. pos_tr = self.transforms.apply(pos) @@ -174,16 +177,15 @@ class HistogramVisual(BaseVisual): def __init__(self): super(HistogramVisual, self).__init__() self.n_bins = 0 - self.hist_max = 1 self.set_shader('histogram') self.set_primitive_type('triangles') - self.transforms.add_on_cpu(Range(from_bounds=[0, 0, self.n_bins, - self.hist_max], - to_bounds=[0, 0, 1, 1])) + + self.data_range = Range([0, 0, self.n_bins, 1], + [0, 0, 1, 1]) + self.transforms.add_on_cpu(self.data_range) # (0, 0, 1, v) - self.transforms.add_on_gpu(Range(from_bounds='hist_bounds', - to_bounds=NDC)) + self.transforms.add_on_gpu(Range('hist_bounds', NDC)) def set_data(self, hist=None, @@ -198,7 +200,8 @@ def set_data(self, # NOTE: this must be set *before* `apply_cpu_transforms` such # that the histogram is correctly normalized. - self.hist_max = _get_hist_max(hist) + hist_max = _get_hist_max(hist) + self.data_range.from_bounds = [0, 0, self.n_bins, hist_max] # Set the transformed position. pos = np.vstack(_tesselate_histogram(row) for row in hist) @@ -219,7 +222,7 @@ def set_data(self, assert ylim is None or len(ylim) == n_hists hist_bounds = np.c_[np.zeros((n_hists, 2)), np.ones((n_hists, 1)), - ylim / self.hist_max] if ylim is not None else None + ylim / hist_max] if ylim is not None else None hist_bounds = _get_texture(hist_bounds, [0, 0, 1, 1], n_hists, [0, 10]) self.program['u_hist_bounds'] = Texture2D(hist_bounds) @@ -227,7 +230,8 @@ def set_data(self, class TextVisual(BaseVisual): - def __init__(self): + def __init__(self): # pragma: no cover + # TODO: this text visual super(TextVisual, self).__init__() self.set_shader('text') self.set_primitive_type('points') From e15445596c664fa74aa46cfa9481fd56183a0047 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 27 Nov 2015 01:01:00 +0100 Subject: [PATCH 0638/1059] WIP --- phy/plot/panzoom.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/phy/plot/panzoom.py b/phy/plot/panzoom.py index bb5a95485..54addb7c2 100644 --- a/phy/plot/panzoom.py +++ b/phy/plot/panzoom.py @@ -89,11 +89,6 @@ def __init__(self, # Will be set when attached to a canvas. self.canvas = None - def update_program(self, program): - zoom = self._zoom_aspect() - program[self.pan_var_name] = self._pan - program[self.zoom_var_name] = zoom - # Various properties # ------------------------------------------------------------------------- @@ -404,6 +399,11 @@ def attach(self, canvas): self._set_canvas_aspect() + def update_program(self, program): + zoom = self._zoom_aspect() + program[self.pan_var_name] = self._pan + program[self.zoom_var_name] = zoom + def update(self): if not self.canvas: return From 2f518674dd8d677460351fb9a50d82f07a9d5d4b Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 27 Nov 2015 13:35:11 +0100 Subject: [PATCH 0639/1059] WIP: refactor interact --- phy/plot/base.py | 34 ++++++++++++++++++++ phy/plot/interact.py | 56 +++++++++++++++++---------------- phy/plot/panzoom.py | 16 +++------- phy/plot/tests/test_interact.py | 37 +++++++++++----------- phy/plot/transform.py | 2 +- 5 files changed, 87 insertions(+), 58 deletions(-) diff --git a/phy/plot/base.py b/phy/plot/base.py index 099adc41d..3e82f2cca 100644 --- a/phy/plot/base.py +++ b/phy/plot/base.py @@ -13,6 +13,7 @@ from vispy import gloo from vispy.app import Canvas +from vispy.util.event import Event from .transform import TransformChain, Clip from .utils import _load_shader @@ -189,6 +190,12 @@ def __add__(self, inserter): # Base canvas #------------------------------------------------------------------------------ +class VisualEvent(Event): + def __init__(self, type, visual=None): + super(VisualEvent, self).__init__(type) + self.visual = visual + + class BaseCanvas(Canvas): """A blank VisPy canvas with a custom event system that keeps the order.""" def __init__(self, *args, **kwargs): @@ -196,6 +203,7 @@ def __init__(self, *args, **kwargs): self.transforms = TransformChain() self.inserter = GLSLInserter() self.visuals = [] + self.events.add(visual_added=VisualEvent) def add_visual(self, visual): """Add a visual to the canvas, and build its program by the same @@ -222,6 +230,7 @@ def add_visual(self, visual): logger.log(5, "Fragment shader: %s", fs) # Register the visual in the list of visuals in the canvas. self.visuals.append(visual) + self.events.visual_added(visual=visual) def on_resize(self, event): """Resize the OpenGL context.""" @@ -231,3 +240,28 @@ def on_draw(self, e): gloo.clear() for visual in self.visuals: visual.on_draw() + + +#------------------------------------------------------------------------------ +# Base interact +#------------------------------------------------------------------------------ + +class BaseInteract(object): + def attach(self, canvas): + """Attach this interact to a canvas.""" + self.canvas = canvas + canvas.panzoom = self + + @canvas.connect + def on_visual_added(e): + self.update_program(e.visual.program) + + def update_program(self, program): + pass + + def update(self): + if not self.canvas: + return + for visual in self.canvas.visuals: + self.update_program(visual.program) + self.canvas.update() diff --git a/phy/plot/interact.py b/phy/plot/interact.py index c1069a29a..e82dc456f 100644 --- a/phy/plot/interact.py +++ b/phy/plot/interact.py @@ -12,6 +12,7 @@ import numpy as np from vispy.gloo import Texture2D +from .base import BaseInteract from .transform import Scale, Range, Subplot, Clip, NDC from .utils import _get_texture, _get_boxes, _get_box_pos_size @@ -20,7 +21,7 @@ # Grid interact #------------------------------------------------------------------------------ -class Grid(object): +class Grid(BaseInteract): """Grid interact. NOTE: to be used in a grid, a visual must define `a_box_index` @@ -48,20 +49,23 @@ def __init__(self, shape, box_var=None): assert self.shape[0] >= 1 assert self.shape[1] >= 1 - def get_shader_declarations(self): - return ('attribute vec2 a_box_index;\n' - 'uniform float u_grid_zoom;\n', '') - - def get_transforms(self): - # Define the grid transform and clipping. + def attach(self, canvas): + super(Grid, self).attach(canvas) m = 1. - .05 # Margin. - return [Scale(scale='u_grid_zoom'), - Scale(scale=(m, m)), - Clip(bounds=[-m, -m, m, m]), - Subplot(shape=self.shape, index='a_box_index'), - ] + canvas.transforms.add_on_gpu([Scale('u_grid_zoom'), + Scale((m, m)), + Clip([-m, -m, m, m]), + Subplot(self.shape, 'a_box_index'), + ]) + canvas.inserter.insert_vert(""" + attribute vec2 a_box_index; + //uniform float u_grid_shape; + uniform float u_grid_zoom; + """, 'header') + canvas.connect(self.on_key_press) def update_program(self, program): + # program['u_grid_shape'] = self.shape program['u_grid_zoom'] = self._zoom # Only set the default box index if necessary. try: @@ -100,7 +104,7 @@ def on_key_press(self, event): # Boxed interact #------------------------------------------------------------------------------ -class Boxed(object): +class Boxed(BaseInteract): """Boxed interact. NOTE: to be used in a boxed, a visual must define `a_box_index` @@ -143,25 +147,23 @@ def __init__(self, self.n_boxes = len(self._box_bounds) - def get_shader_declarations(self): - return ('#include "utils.glsl"\n\n' - 'attribute float {};\n'.format(self.box_var) + - 'uniform sampler2D u_box_bounds;\n' - 'uniform float n_boxes;', '') - - def get_pre_transforms(self): - return """ + def attach(self, canvas): + super(Boxed, self).attach(canvas) + canvas.transforms.add_on_gpu([Range(NDC, 'box_bounds')]) + canvas.inserter.insert_vert(""" + #include "utils.glsl" + attribute float {}; + uniform sampler2D u_box_bounds; + uniform float n_boxes;""".format(self.box_var), 'header') + canvas.inserter.insert_vert(""" // Fetch the box bounds for the current box (`box_var`). vec4 box_bounds = fetch_texture({}, u_box_bounds, n_boxes); box_bounds = (2 * box_bounds - 1); // See hack in Python. - """.format(self.box_var) - - def get_transforms(self): - return [Range(from_bounds=NDC, - to_bounds='box_bounds'), - ] + """.format(self.box_var), 'before_transforms') + canvas.connect(self.on_key_press) + canvas.connect(self.on_key_release) def update_program(self, program): # Signal bounds (positions). diff --git a/phy/plot/panzoom.py b/phy/plot/panzoom.py index 54addb7c2..3b51dbdb3 100644 --- a/phy/plot/panzoom.py +++ b/phy/plot/panzoom.py @@ -11,6 +11,7 @@ import numpy as np +from .base import BaseInteract from .transform import Translate, Scale, pixels_to_ndc from phy.utils._types import _as_array @@ -19,7 +20,7 @@ # PanZoom class #------------------------------------------------------------------------------ -class PanZoom(object): +class PanZoom(BaseInteract): """Pan and zoom interact. To use it: @@ -382,8 +383,7 @@ def size(self): def attach(self, canvas): """Attach this interact to a canvas.""" - self.canvas = canvas - canvas.panzoom = self + super(PanZoom, self).attach(canvas) canvas.transforms.add_on_gpu([Translate(self.pan_var_name), Scale(self.zoom_var_name)]) @@ -400,13 +400,5 @@ def attach(self, canvas): self._set_canvas_aspect() def update_program(self, program): - zoom = self._zoom_aspect() program[self.pan_var_name] = self._pan - program[self.zoom_var_name] = zoom - - def update(self): - if not self.canvas: - return - for visual in self.canvas.visuals: - self.update_program(visual.program) - self.canvas.update() + program[self.zoom_var_name] = self._zoom_aspect() diff --git a/phy/plot/tests/test_interact.py b/phy/plot/tests/test_interact.py index 56400f410..546410a03 100644 --- a/phy/plot/tests/test_interact.py +++ b/phy/plot/tests/test_interact.py @@ -25,22 +25,21 @@ #------------------------------------------------------------------------------ class MyTestVisual(BaseVisual): - vertex = """ - attribute vec2 a_position; - void main() { - gl_Position = transform(a_position); - gl_PointSize = 2.; - } + def __init__(self): + super(MyTestVisual, self).__init__() + self.vertex_shader = """ + attribute vec2 a_position; + void main() { + gl_Position = transform(a_position); + gl_PointSize = 2.; + } """ - fragment = """ - void main() { - gl_FragColor = vec4(1, 1, 1, 1); - } - """ - gl_primitive_type = 'points' - - def get_shaders(self): - return self.vertex, self.fragment + self.fragment_shader = """ + void main() { + gl_FragColor = vec4(1, 1, 1, 1); + } + """ + self.set_primitive_type('points') def set_data(self): n = 1000 @@ -62,7 +61,7 @@ def _create_visual(qtbot, canvas, interact, box_index): PanZoom(aspect=None, constrain_bounds=NDC).attach(c) visual = MyTestVisual() - visual.attach(c) + c.add_visual(visual) visual.set_data() visual.program['a_box_index'] = box_index.astype(np.float32) @@ -117,8 +116,10 @@ def test_grid_2(qtbot, canvas): box_index = np.repeat(box_index, n, axis=0) class MyGrid(Grid): - def get_pre_transforms(self): - return 'vec2 u_shape = vec2(3, 3);' + def attach(self, canvas): + super(MyGrid, self).attach(canvas) + canvas.inserter.insert_vert('vec2 u_shape = vec2(3, 3);', + 'before_transforms') grid = MyGrid('u_shape') _create_visual(qtbot, canvas, grid, box_index) diff --git a/phy/plot/transform.py b/phy/plot/transform.py index cbc1859e3..1b4648113 100644 --- a/phy/plot/transform.py +++ b/phy/plot/transform.py @@ -181,7 +181,7 @@ def __init__(self, shape, index=None): self.shape = shape self.index = index self.from_bounds = NDC - if isinstance(self.shape, tuple): + if isinstance(self.shape, tuple) and isinstance(self.index, tuple): self.to_bounds = subplot_bounds(shape=self.shape, index=self.index) def glsl(self, var): From 9a0f0363d4ce849dbb2d4432e85e0942526b0e0a Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 27 Nov 2015 17:03:17 +0100 Subject: [PATCH 0640/1059] WIP --- phy/plot/interact.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/phy/plot/interact.py b/phy/plot/interact.py index e82dc456f..7aa847501 100644 --- a/phy/plot/interact.py +++ b/phy/plot/interact.py @@ -59,13 +59,11 @@ def attach(self, canvas): ]) canvas.inserter.insert_vert(""" attribute vec2 a_box_index; - //uniform float u_grid_shape; uniform float u_grid_zoom; """, 'header') canvas.connect(self.on_key_press) def update_program(self, program): - # program['u_grid_shape'] = self.shape program['u_grid_zoom'] = self._zoom # Only set the default box index if necessary. try: From f14c2fe13412bb141b662458c7e1baf4be7a42c1 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 27 Nov 2015 17:27:47 +0100 Subject: [PATCH 0641/1059] Test BaseInteract --- phy/plot/base.py | 2 ++ phy/plot/tests/test_base.py | 26 +++++++++++++++++++++++++- 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/phy/plot/base.py b/phy/plot/base.py index 3e82f2cca..b62172da7 100644 --- a/phy/plot/base.py +++ b/phy/plot/base.py @@ -247,6 +247,8 @@ def on_draw(self, e): #------------------------------------------------------------------------------ class BaseInteract(object): + canvas = None + def attach(self, canvas): """Attach this interact to a canvas.""" self.canvas = canvas diff --git a/phy/plot/tests/test_base.py b/phy/plot/tests/test_base.py index 2d5831392..ffa1fa1d3 100644 --- a/phy/plot/tests/test_base.py +++ b/phy/plot/tests/test_base.py @@ -10,7 +10,7 @@ import numpy as np from pytest import yield_fixture -from ..base import BaseVisual, GLSLInserter +from ..base import BaseVisual, BaseInteract, GLSLInserter from ..transform import (subplot_bounds, Translate, Scale, Range, Clip, Subplot, TransformChain) @@ -135,3 +135,27 @@ def set_data(self): canvas.show() qtbot.waitForWindowShown(canvas.native) # qtbot.stop() + + +def test_interact_1(qtbot, canvas): + interact = BaseInteract() + interact.update() + + class TestVisual(BaseVisual): + def __init__(self): + super(TestVisual, self).__init__() + self.set_shader('simple') + self.set_primitive_type('lines') + + def set_data(self): + self.program['a_position'] = [[-1, 0], [1, 0]] + self.program['u_color'] = [1, 1, 1, 1] + + interact.attach(canvas) + v = TestVisual() + canvas.add_visual(v) + v.set_data() + + canvas.show() + qtbot.waitForWindowShown(canvas.native) + interact.update() From 96c4773029aa391179fbb94673c967d2a5b146f0 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 27 Nov 2015 17:33:23 +0100 Subject: [PATCH 0642/1059] Reorganize plot.utils --- phy/plot/utils.py | 198 ++++++++++++++++++++++++---------------------- 1 file changed, 103 insertions(+), 95 deletions(-) diff --git a/phy/plot/utils.py b/phy/plot/utils.py index c5ef290f6..089cf23d1 100644 --- a/phy/plot/utils.py +++ b/phy/plot/utils.py @@ -19,103 +19,9 @@ #------------------------------------------------------------------------------ -# Misc +# Box positioning #------------------------------------------------------------------------------ -def _load_shader(filename): - """Load a shader file.""" - curdir = op.dirname(op.realpath(__file__)) - glsl_path = op.join(curdir, 'glsl') - path = op.join(glsl_path, filename) - with open(path, 'r') as f: - return f.read() - - -def _tesselate_histogram(hist): - """ - - 2/4 3 - ____ - |\ | - | \ | - | \ | - |___\| - - 0 1/5 - - """ - assert hist.ndim == 1 - nsamples = len(hist) - - x0 = np.arange(nsamples) - - x = np.zeros(6 * nsamples) - y = np.zeros(6 * nsamples) - - x[0::2] = np.repeat(x0, 3) - x[1::2] = x[0::2] + 1 - - y[2::6] = y[3::6] = y[4::6] = hist - - return np.c_[x, y] - - -def _enable_depth_mask(): - gloo.set_state(clear_color='black', - depth_test=True, - depth_range=(0., 1.), - # depth_mask='true', - depth_func='lequal', - blend=True, - blend_func=('src_alpha', 'one_minus_src_alpha')) - gloo.set_clear_depth(1.0) - - -def _get_texture(arr, default, n_items, from_bounds): - """Prepare data to be uploaded as a texture. - - The from_bounds must be specified. - - """ - if not hasattr(default, '__len__'): # pragma: no cover - default = [default] - n_cols = len(default) - if arr is None: - arr = np.tile(default, (n_items, 1)) - assert arr.shape == (n_items, n_cols) - # Convert to 3D texture. - arr = arr[np.newaxis, ...].astype(np.float32) - assert arr.shape == (1, n_items, n_cols) - # NOTE: we need to cast the texture to [0., 1.] (float texture). - # This is easy as soon as we assume that the signal bounds are in - # [-1, 1]. - assert len(from_bounds) == 2 - m, M = map(float, from_bounds) - assert np.all(arr >= m) - assert np.all(arr <= M) - arr = (arr - m) / (M - m) - assert np.all(arr >= 0) - assert np.all(arr <= 1.) - arr = arr.astype(np.float32) - return arr - - -def _get_array(val, shape, default=None): - """Ensure an object is an array with the specified shape.""" - assert val is not None or default is not None - if hasattr(val, '__len__') and len(val) == 0: - val = None - out = np.zeros(shape, dtype=np.float32) - # This solves `ValueError: could not broadcast input array from shape (n) - # into shape (n, 1)`. - if val is not None and isinstance(val, np.ndarray): - if val.size == out.size: - val = val.reshape(out.shape) - out[...] = val if val is not None else default - assert out.shape == shape - return out - - def _boxes_overlap(x0, y0, x1, y1): n = len(x0) overlap_matrix = ((x0 < x1.T) & (x1 > x0.T) & (y0 < y1.T) & (y1 > y0.T)) @@ -203,6 +109,55 @@ def _get_box_pos_size(box_bounds): return np.c_[x, y], (w.mean(), h.mean()) +#------------------------------------------------------------------------------ +# Data validation +#------------------------------------------------------------------------------ + +def _get_texture(arr, default, n_items, from_bounds): + """Prepare data to be uploaded as a texture. + + The from_bounds must be specified. + + """ + if not hasattr(default, '__len__'): # pragma: no cover + default = [default] + n_cols = len(default) + if arr is None: + arr = np.tile(default, (n_items, 1)) + assert arr.shape == (n_items, n_cols) + # Convert to 3D texture. + arr = arr[np.newaxis, ...].astype(np.float32) + assert arr.shape == (1, n_items, n_cols) + # NOTE: we need to cast the texture to [0., 1.] (float texture). + # This is easy as soon as we assume that the signal bounds are in + # [-1, 1]. + assert len(from_bounds) == 2 + m, M = map(float, from_bounds) + assert np.all(arr >= m) + assert np.all(arr <= M) + arr = (arr - m) / (M - m) + assert np.all(arr >= 0) + assert np.all(arr <= 1.) + arr = arr.astype(np.float32) + return arr + + +def _get_array(val, shape, default=None): + """Ensure an object is an array with the specified shape.""" + assert val is not None or default is not None + if hasattr(val, '__len__') and len(val) == 0: + val = None + out = np.zeros(shape, dtype=np.float32) + # This solves `ValueError: could not broadcast input array from shape (n) + # into shape (n, 1)`. + if val is not None and isinstance(val, np.ndarray): + if val.size == out.size: + val = val.reshape(out.shape) + out[...] = val if val is not None else default + assert out.shape == shape + return out + + def _check_data_bounds(data_bounds): assert len(data_bounds) == 4 assert data_bounds[0] < data_bounds[2] @@ -269,3 +224,56 @@ def _get_color(color, default): def _get_linear_x(n_signals, n_samples): return np.tile(np.linspace(-1., 1., n_samples), (n_signals, 1)) + + +#------------------------------------------------------------------------------ +# Misc +#------------------------------------------------------------------------------ + +def _load_shader(filename): + """Load a shader file.""" + curdir = op.dirname(op.realpath(__file__)) + glsl_path = op.join(curdir, 'glsl') + path = op.join(glsl_path, filename) + with open(path, 'r') as f: + return f.read() + + +def _tesselate_histogram(hist): + """ + + 2/4 3 + ____ + |\ | + | \ | + | \ | + |___\| + + 0 1/5 + + """ + assert hist.ndim == 1 + nsamples = len(hist) + + x0 = np.arange(nsamples) + + x = np.zeros(6 * nsamples) + y = np.zeros(6 * nsamples) + + x[0::2] = np.repeat(x0, 3) + x[1::2] = x[0::2] + 1 + + y[2::6] = y[3::6] = y[4::6] = hist + + return np.c_[x, y] + + +def _enable_depth_mask(): + gloo.set_state(clear_color='black', + depth_test=True, + depth_range=(0., 1.), + # depth_mask='true', + depth_func='lequal', + blend=True, + blend_func=('src_alpha', 'one_minus_src_alpha')) + gloo.set_clear_depth(1.0) From af532b4510f20dec3d3a1f65f794f19c9623a721 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 27 Nov 2015 19:27:18 +0100 Subject: [PATCH 0643/1059] Add validation logic in visuals --- phy/plot/tests/test_visuals.py | 21 +++-- phy/plot/utils.py | 22 +++-- phy/plot/visuals.py | 154 +++++++++++++++++++++------------ 3 files changed, 120 insertions(+), 77 deletions(-) diff --git a/phy/plot/tests/test_visuals.py b/phy/plot/tests/test_visuals.py index 306fa30f6..060aebb17 100644 --- a/phy/plot/tests/test_visuals.py +++ b/phy/plot/tests/test_visuals.py @@ -32,19 +32,19 @@ def _test_visual(qtbot, c, v, stop=False, **kwargs): #------------------------------------------------------------------------------ def test_scatter_empty(qtbot, canvas): - pos = np.zeros((0, 2)) - _test_visual(qtbot, canvas, ScatterVisual(), pos=pos) + _test_visual(qtbot, canvas, ScatterVisual(), x=np.zeros(0), y=np.zeros(0)) def test_scatter_markers(qtbot, canvas_pz): c = canvas_pz n = 100 - pos = .2 * np.random.randn(n, 2) + x = .2 * np.random.randn(n) + y = .2 * np.random.randn(n) v = ScatterVisual(marker='vbar') c.add_visual(v) - v.set_data(pos=pos) + v.set_data(x=x, y=y) c.show() qtbot.waitForWindowShown(c.native) @@ -57,7 +57,8 @@ def test_scatter_custom(qtbot, canvas_pz): n = 100 # Random position. - pos = .2 * np.random.randn(n, 2) + x = .2 * np.random.randn(n) + y = .2 * np.random.randn(n) # Random colors. c = np.random.uniform(.4, .7, size=(n, 4)) @@ -67,7 +68,9 @@ def test_scatter_custom(qtbot, canvas_pz): s = 5 + 20 * np.random.rand(n) _test_visual(qtbot, canvas_pz, ScatterVisual(), - pos=pos, color=c, size=s) + x=x, y=y, color=c, size=s) + + # qtbot.stop() #------------------------------------------------------------------------------ @@ -105,10 +108,10 @@ def test_plot_2(qtbot, canvas_pz): # Depth. depth = np.linspace(0., -1., n_signals) - _test_visual(qtbot, canvas_pz, PlotVisual(n_samples=n_samples), + _test_visual(qtbot, canvas_pz, PlotVisual(), y=y, depth=depth, data_bounds=[-1, -50, 1, 50], - plot_colors=c) + color=c) #------------------------------------------------------------------------------ @@ -122,7 +125,7 @@ def test_histogram_empty(qtbot, canvas): def test_histogram_0(qtbot, canvas_pz): - hist = np.zeros((1, 10)) + hist = np.zeros((10,)) _test_visual(qtbot, canvas_pz, HistogramVisual(), hist=hist) diff --git a/phy/plot/utils.py b/phy/plot/utils.py index 089cf23d1..b5fc069bc 100644 --- a/phy/plot/utils.py +++ b/phy/plot/utils.py @@ -182,20 +182,18 @@ def _get_data_bounds(data_bounds, pos): return data_bounds -def _check_pos_2D(pos): - """Check position data before GPU uploading.""" - assert pos is not None - pos = np.asarray(pos, dtype=np.float32) - assert pos.ndim == 2 - return pos +def _get_pos(x, y): + assert x is not None + assert y is not None + + x = np.asarray(x, dtype=np.float32) + y = np.asarray(y, dtype=np.float32) + # Validate the position. + assert x.ndim == y.ndim == 1 + assert x.shape == y.shape -def _get_pos_depth(pos_tr, depth): - """Prepare a (N, 3) position-depth array for GPU uploading.""" - n = pos_tr.shape[0] - pos_tr = _get_array(pos_tr, (n, 2)) - depth = _get_array(depth, (n, 1), 0) - return np.c_[pos_tr, depth] + return x, y def _get_hist_max(hist): diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index 627fa8664..7255bac51 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -17,13 +17,13 @@ _get_texture, _get_array, _get_data_bounds, - _get_pos_depth, - _check_pos_2D, + _get_pos, _get_index, _get_linear_x, _get_hist_max, _get_color, ) +from phy.utils import Bunch #------------------------------------------------------------------------------ @@ -65,8 +65,6 @@ class ScatterVisual(BaseVisual): def __init__(self, marker=None): super(ScatterVisual, self).__init__() - # Default bounds. - self.data_bounds = NDC self.n_points = None # Set the marker type. @@ -80,37 +78,52 @@ def __init__(self, marker=None): self.fragment_shader = self.fragment_shader.replace('%MARKER', self.marker) self.set_primitive_type('points') - self.transforms.add_on_cpu(Range(self.data_bounds)) + self.data_range = Range(NDC) + self.transforms.add_on_cpu(self.data_range) - def set_data(self, - pos=None, - depth=None, + @staticmethod + def validate(x=None, + y=None, color=None, - marker=None, size=None, + depth=None, data_bounds=None, ): - pos = _check_pos_2D(pos) + x, y = _get_pos(x, y) + pos = np.c_[x, y] n = pos.shape[0] - assert pos.shape == (n, 2) - # Set the data bounds from the data. - self.data_bounds = _get_data_bounds(data_bounds, pos) + # Validate the data. + color = _get_array(color, (n, 4), ScatterVisual._default_color) + size = _get_array(size, (n, 1), ScatterVisual._default_marker_size) + depth = _get_array(depth, (n, 1), 0) + data_bounds = _get_data_bounds(data_bounds, pos) - pos_tr = self.transforms.apply(pos) - self.program['a_position'] = _get_pos_depth(pos_tr, depth) - self.program['a_size'] = _get_array(size, (n, 1), - self._default_marker_size) - self.program['a_color'] = _get_array(color, (n, 4), - self._default_color) + return Bunch(pos=pos, color=color, size=size, + depth=depth, data_bounds=data_bounds) + + def set_data(self, + x=None, + y=None, + color=None, + size=None, + depth=None, + data_bounds=None, + ): + data = self.validate(x=x, y=y, color=color, size=size, depth=depth, + data_bounds=data_bounds) + self.data_range.from_bounds = data.data_bounds + pos_tr = self.transforms.apply(data.pos) + self.program['a_position'] = np.c_[pos_tr, data.depth] + self.program['a_size'] = data.size + self.program['a_color'] = data.color class PlotVisual(BaseVisual): _default_color = DEFAULT_COLOR - def __init__(self, n_samples=None): + def __init__(self): super(PlotVisual, self).__init__() - self.n_samples = n_samples _enable_depth_mask() self.set_shader('plot') @@ -119,55 +132,64 @@ def __init__(self, n_samples=None): self.data_range = Range(NDC) self.transforms.add_on_cpu(self.data_range) - def set_data(self, - x=None, + @staticmethod + def validate(x=None, y=None, + color=None, depth=None, data_bounds=None, - plot_colors=None, ): - # Default x coordinates. if x is None: assert y is not None x = _get_linear_x(*y.shape) + # Default x coordinates. assert x is not None assert y is not None + x = np.asarray(x, np.float32) + y = np.asarray(y, np.float32) assert x.ndim == 2 assert x.shape == y.shape n_signals, n_samples = x.shape - if self.n_samples: - assert n_samples == self.n_samples + + # Validate the data. + color = _get_array(color, (n_signals, 4), PlotVisual._default_color) + depth = _get_array(depth, (n_signals,), 0) + + return Bunch(x=x, y=y, + color=color, depth=depth, + data_bounds=data_bounds) + + def set_data(self, + x=None, + y=None, + color=None, + depth=None, + data_bounds=None, + ): + data = self.validate(x=x, y=y, color=color, depth=depth, + data_bounds=data_bounds) + x, y = data.x, data.y + + n_signals, n_samples = x.shape n = n_signals * n_samples - # Generate the (n, 2) pos array. + # Generate the position array. pos = np.empty((n, 2), dtype=np.float32) pos[:, 0] = x.ravel() pos[:, 1] = y.ravel() - pos = _check_pos_2D(pos) - - # Update the data range using the specified data_bounds and the data. - # NOTE: this must be called *before* transforms.apply(). - self.data_range.from_bounds = _get_data_bounds(data_bounds, pos) - - # Set the transformed position. + self.data_range.from_bounds = _get_data_bounds(data.data_bounds, pos) pos_tr = self.transforms.apply(pos) - # Depth. - depth = _get_array(depth, (n_signals,), 0) - depth = np.repeat(depth, n_samples) - self.program['a_position'] = _get_pos_depth(pos_tr, depth) + depth = np.repeat(data.depth, n_samples) - # Generate the signal index. + self.program['a_position'] = np.c_[pos_tr, depth] self.program['a_signal_index'] = _get_index(n_signals, n_samples, n) - - # Signal colors. - plot_colors = _get_texture(plot_colors, self._default_color, - n_signals, [0, 1]) - self.program['u_plot_colors'] = Texture2D(plot_colors) - - # Number of signals. + self.program['u_plot_colors'] = Texture2D(_get_texture(data.color, + PlotVisual._default_color, + n_signals, + [0, 1])) self.program['n_signals'] = n_signals @@ -176,37 +198,57 @@ class HistogramVisual(BaseVisual): def __init__(self): super(HistogramVisual, self).__init__() - self.n_bins = 0 self.set_shader('histogram') self.set_primitive_type('triangles') - self.data_range = Range([0, 0, self.n_bins, 1], + self.data_range = Range([0, 0, 1, 1], [0, 0, 1, 1]) self.transforms.add_on_cpu(self.data_range) # (0, 0, 1, v) self.transforms.add_on_gpu(Range('hist_bounds', NDC)) + @staticmethod + def validate(hist=None, + color=None, + ylim=None, + ): + assert hist is not None + hist = np.asarray(hist, np.float32) + if hist.ndim == 1: + hist = hist[None, :] + assert hist.ndim == 2 + n_hists, n_bins = hist.shape + + # Validate the data. + color = _get_array(color, (n_hists, 4), HistogramVisual._default_color) + + return Bunch(hist=hist, + ylim=ylim, + color=color, + ) + def set_data(self, hist=None, - ylim=None, color=None, + ylim=None, ): - hist = _check_pos_2D(hist) + + data = self.validate(hist=hist, color=color, ylim=ylim) + hist = data.hist + n_hists, n_bins = hist.shape n = 6 * n_hists * n_bins - # Store n_bins for get_transforms(). - self.n_bins = n_bins # NOTE: this must be set *before* `apply_cpu_transforms` such # that the histogram is correctly normalized. hist_max = _get_hist_max(hist) - self.data_range.from_bounds = [0, 0, self.n_bins, hist_max] + self.data_range.from_bounds = [0, 0, n_bins, hist_max] # Set the transformed position. pos = np.vstack(_tesselate_histogram(row) for row in hist) + pos = pos.astype(np.float32) pos_tr = self.transforms.apply(pos) - pos_tr = np.asarray(pos_tr, dtype=np.float32) assert pos_tr.shape == (n, 2) self.program['a_position'] = pos_tr @@ -214,7 +256,7 @@ def set_data(self, self.program['a_hist_index'] = _get_index(n_hists, n_bins * 6, n) # Hist colors. - self.program['u_color'] = _get_texture(color, + self.program['u_color'] = _get_texture(data.color, self._default_color, n_hists, [0, 1]) From 92519258bb75ebfce8abcaab016aa11492179bbd Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 28 Nov 2015 11:50:00 +0100 Subject: [PATCH 0644/1059] Support vector bounds in Range transform --- phy/plot/tests/test_transform.py | 13 +++++++++++++ phy/plot/transform.py | 15 +++++++++------ 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/phy/plot/tests/test_transform.py b/phy/plot/tests/test_transform.py index 771db03cd..9df0eec94 100644 --- a/phy/plot/tests/test_transform.py +++ b/phy/plot/tests/test_transform.py @@ -83,6 +83,19 @@ def test_range_cpu(): [[0, .5], [1.5, -.5]], [[-1, 0], [2, -2]]) +def test_range_cpu_vectorized(): + arr = np.arange(6).reshape((3, 2)) * 1. + arr_tr = arr / 5. + arr_tr[2, :] /= 10 + + f = np.tile([0, 0, 5, 5], (3, 1)) + f[2, :] *= 10 + + t = np.tile([0, 0, 1, 1], (3, 1)) + + _check(Range(f, t), arr, arr_tr) + + def test_clip_cpu(): _check(Clip(), [0, 0], [0, 0]) # Default bounds. diff --git a/phy/plot/transform.py b/phy/plot/transform.py index 1b4648113..e236da310 100644 --- a/phy/plot/transform.py +++ b/phy/plot/transform.py @@ -125,14 +125,17 @@ def glsl(self, var): class Range(BaseTransform): def __init__(self, from_bounds=None, to_bounds=None): super(Range, self).__init__() - self.from_bounds = from_bounds or NDC - self.to_bounds = to_bounds or NDC + self.from_bounds = from_bounds if from_bounds is not None else NDC + self.to_bounds = to_bounds if to_bounds is not None else NDC def apply(self, arr): - f0 = np.asarray(self.from_bounds[:2]) - f1 = np.asarray(self.from_bounds[2:]) - t0 = np.asarray(self.to_bounds[:2]) - t1 = np.asarray(self.to_bounds[2:]) + self.from_bounds = np.asarray(self.from_bounds) + self.to_bounds = np.asarray(self.to_bounds) + + f0 = np.asarray(self.from_bounds[..., :2]) + f1 = np.asarray(self.from_bounds[..., 2:]) + t0 = np.asarray(self.to_bounds[..., :2]) + t1 = np.asarray(self.to_bounds[..., 2:]) return t0 + (t1 - t0) * (arr - f0) / (f1 - f0) From 36ea41f989ff335f2d50531da1fbc9afa244e07f Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 28 Nov 2015 12:09:36 +0100 Subject: [PATCH 0645/1059] Update and test _get_data_bounds() --- phy/plot/tests/test_utils.py | 22 ++++++++++++++++++++++ phy/plot/utils.py | 25 +++++++++++++++---------- 2 files changed, 37 insertions(+), 10 deletions(-) diff --git a/phy/plot/tests/test_utils.py b/phy/plot/tests/test_utils.py index fb0498df5..0933deee8 100644 --- a/phy/plot/tests/test_utils.py +++ b/phy/plot/tests/test_utils.py @@ -19,6 +19,7 @@ from ..utils import (_load_shader, _tesselate_histogram, _enable_depth_mask, + _get_data_bounds, _boxes_overlap, _binary_search, _get_boxes, @@ -58,6 +59,27 @@ def on_draw(e): qtbot.waitForWindowShown(canvas.native) +def test_get_data_bounds(): + db0 = np.array([[0, 1, 4, 5], + [0, 1, 4, 5], + [0, 1, 4, 5]]) + arr = np.arange(6).reshape((3, 2)) + assert np.all(_get_data_bounds(None, arr) == [[0, 1, 4, 5]]) + + db = db0.copy() + assert np.all(_get_data_bounds(db, arr) == [[0, 1, 4, 5]]) + + db = db0.copy() + db[2, :] = [1, 1, 1, 1] + assert np.all(_get_data_bounds(db, arr)[:2, :] == [[0, 1, 4, 5]]) + assert np.all(_get_data_bounds(db, arr)[2, :] == [0, 0, 2, 2]) + + db = db0.copy() + db[:2, :] = [1, 1, 1, 1] + assert np.all(_get_data_bounds(db, arr)[:2, :] == [[0, 0, 2, 2]]) + assert np.all(_get_data_bounds(db, arr)[2, :] == [0, 1, 4, 5]) + + def test_boxes_overlap(): def _get_args(boxes): diff --git a/phy/plot/utils.py b/phy/plot/utils.py index b5fc069bc..94dc47e97 100644 --- a/phy/plot/utils.py +++ b/phy/plot/utils.py @@ -159,9 +159,10 @@ def _get_array(val, shape, default=None): def _check_data_bounds(data_bounds): - assert len(data_bounds) == 4 - assert data_bounds[0] < data_bounds[2] - assert data_bounds[1] < data_bounds[3] + assert data_bounds.ndim == 2 + assert data_bounds.shape[1] == 4 + assert np.all(data_bounds[:, 0] < data_bounds[:, 2]) + assert np.all(data_bounds[:, 1] < data_bounds[:, 3]) def _get_data_bounds(data_bounds, pos): @@ -171,13 +172,17 @@ def _get_data_bounds(data_bounds, pos): if data_bounds is None: m, M = pos.min(axis=0), pos.max(axis=0) data_bounds = [m[0], m[1], M[0], M[1]] - data_bounds = list(data_bounds) - if data_bounds[0] == data_bounds[2]: # pragma: no cover - data_bounds[0] -= 1 - data_bounds[2] += 1 - if data_bounds[1] == data_bounds[3]: - data_bounds[1] -= 1 - data_bounds[3] += 1 + data_bounds = np.atleast_2d(data_bounds) + + ind_x = data_bounds[:, 0] == data_bounds[:, 2] + ind_y = data_bounds[:, 1] == data_bounds[:, 3] + if np.sum(ind_x): + data_bounds[ind_x, 0] -= 1 + data_bounds[ind_x, 2] += 1 + if np.sum(ind_y): + data_bounds[ind_y, 1] -= 1 + data_bounds[ind_y, 3] += 1 + _check_data_bounds(data_bounds) return data_bounds From 5bc487f66daaf2b1d50d61aaec04ffb420300373 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 28 Nov 2015 12:31:57 +0100 Subject: [PATCH 0646/1059] WIP: variable data_bounds in visuals --- phy/plot/tests/test_visuals.py | 2 -- phy/plot/utils.py | 20 +++++++++++++++----- phy/plot/visuals.py | 20 +++++++++++++++++++- 3 files changed, 34 insertions(+), 8 deletions(-) diff --git a/phy/plot/tests/test_visuals.py b/phy/plot/tests/test_visuals.py index 060aebb17..547e19a43 100644 --- a/phy/plot/tests/test_visuals.py +++ b/phy/plot/tests/test_visuals.py @@ -70,8 +70,6 @@ def test_scatter_custom(qtbot, canvas_pz): _test_visual(qtbot, canvas_pz, ScatterVisual(), x=x, y=y, color=c, size=s) - # qtbot.stop() - #------------------------------------------------------------------------------ # Test plot visual diff --git a/phy/plot/utils.py b/phy/plot/utils.py index 94dc47e97..75c1cab7c 100644 --- a/phy/plot/utils.py +++ b/phy/plot/utils.py @@ -165,13 +165,14 @@ def _check_data_bounds(data_bounds): assert np.all(data_bounds[:, 1] < data_bounds[:, 3]) -def _get_data_bounds(data_bounds, pos): +def _get_data_bounds(data_bounds, pos=None, length=None): """"Prepare data bounds, possibly using min/max of the data.""" - if not len(pos): - return data_bounds or NDC if data_bounds is None: - m, M = pos.min(axis=0), pos.max(axis=0) - data_bounds = [m[0], m[1], M[0], M[1]] + if pos is not None and len(pos): + m, M = pos.min(axis=0), pos.max(axis=0) + data_bounds = [m[0], m[1], M[0], M[1]] + else: + data_bounds = NDC data_bounds = np.atleast_2d(data_bounds) ind_x = data_bounds[:, 0] == data_bounds[:, 2] @@ -183,6 +184,15 @@ def _get_data_bounds(data_bounds, pos): data_bounds[ind_y, 1] -= 1 data_bounds[ind_y, 3] += 1 + # Extend the data_bounds if needed. + if length is None: + length = pos.shape[0] if pos is not None else 1 + if data_bounds.shape[0] == 1: + data_bounds = np.tile(data_bounds, (length, 1)) + + # Check the shape of data_bounds. + assert data_bounds.shape == (length, 4) + _check_data_bounds(data_bounds) return data_bounds diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index 7255bac51..b2b34fb48 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -98,6 +98,7 @@ def validate(x=None, size = _get_array(size, (n, 1), ScatterVisual._default_marker_size) depth = _get_array(depth, (n, 1), 0) data_bounds = _get_data_bounds(data_bounds, pos) + assert data_bounds.shape[0] == n return Bunch(pos=pos, color=color, size=size, depth=depth, data_bounds=data_bounds) @@ -157,6 +158,17 @@ def validate(x=None, color = _get_array(color, (n_signals, 4), PlotVisual._default_color) depth = _get_array(depth, (n_signals,), 0) + # Validate data bounds. + if data_bounds is None: + if n_samples > 0: + # NOTE: by default, per-signal normalization. + data_bounds = np.c_[x.min(axis=1), y.min(axis=1), + x.max(axis=1), y.max(axis=1)] + else: + data_bounds = NDC + data_bounds = _get_data_bounds(data_bounds, length=n_signals) + assert data_bounds.shape == (n_signals, 4) + return Bunch(x=x, y=y, color=color, depth=depth, data_bounds=data_bounds) @@ -179,9 +191,15 @@ def set_data(self, pos = np.empty((n, 2), dtype=np.float32) pos[:, 0] = x.ravel() pos[:, 1] = y.ravel() - self.data_range.from_bounds = _get_data_bounds(data.data_bounds, pos) + + # Repeat the data bounds for every vertex in the signals. + data_bounds = np.repeat(data.data_bounds, n_samples, axis=0) + self.data_range.from_bounds = data_bounds + + # Transform the positions. pos_tr = self.transforms.apply(pos) + # Repeat the depth. depth = np.repeat(data.depth, n_samples) self.program['a_position'] = np.c_[pos_tr, depth] From f3317d5398fa7c72c7d212c9ff4a48b5aecadabd Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 28 Nov 2015 13:36:30 +0100 Subject: [PATCH 0647/1059] Update histogram normalization --- phy/plot/glsl/histogram.vert | 5 ----- phy/plot/visuals.py | 30 ++++++++++++++---------------- 2 files changed, 14 insertions(+), 21 deletions(-) diff --git a/phy/plot/glsl/histogram.vert b/phy/plot/glsl/histogram.vert index 790471553..1889ca2f4 100644 --- a/phy/plot/glsl/histogram.vert +++ b/phy/plot/glsl/histogram.vert @@ -4,17 +4,12 @@ attribute vec2 a_position; attribute float a_hist_index; // 0..n_hists-1 uniform sampler2D u_color; -uniform sampler2D u_hist_bounds; uniform float n_hists; varying vec4 v_color; varying float v_hist_index; void main() { - vec4 hist_bounds = fetch_texture(a_hist_index, - u_hist_bounds, - n_hists); - hist_bounds = hist_bounds * 10.; // NOTE: avoid texture clipping gl_Position = transform(a_position); v_color = fetch_texture(a_hist_index, u_color, n_hists); diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index b2b34fb48..0f9e0b164 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -20,7 +20,6 @@ _get_pos, _get_index, _get_linear_x, - _get_hist_max, _get_color, ) from phy.utils import Bunch @@ -220,11 +219,8 @@ def __init__(self): self.set_shader('histogram') self.set_primitive_type('triangles') - self.data_range = Range([0, 0, 1, 1], - [0, 0, 1, 1]) + self.data_range = Range([0, 0, 1, 1]) self.transforms.add_on_cpu(self.data_range) - # (0, 0, 1, v) - self.transforms.add_on_gpu(Range('hist_bounds', NDC)) @staticmethod def validate(hist=None, @@ -241,6 +237,14 @@ def validate(hist=None, # Validate the data. color = _get_array(color, (n_hists, 4), HistogramVisual._default_color) + # Validate ylim. + if ylim is None: + ylim = hist.max() if hist.size > 0 else 1. + ylim = np.atleast_1d(ylim) + if len(ylim) == 1: + ylim = np.tile(ylim, len(ylim)) + assert ylim.shape == (n_hists,) + return Bunch(hist=hist, ylim=ylim, color=color, @@ -260,8 +264,11 @@ def set_data(self, # NOTE: this must be set *before* `apply_cpu_transforms` such # that the histogram is correctly normalized. - hist_max = _get_hist_max(hist) - self.data_range.from_bounds = [0, 0, n_bins, hist_max] + data_bounds = np.c_[np.zeros((n_hists, 2)), + n_bins * np.ones((n_hists, 1)), + data.ylim] + data_bounds = np.repeat(data_bounds, 6 * n_bins, axis=0) + self.data_range.from_bounds = data_bounds # Set the transformed position. pos = np.vstack(_tesselate_histogram(row) for row in hist) @@ -277,15 +284,6 @@ def set_data(self, self.program['u_color'] = _get_texture(data.color, self._default_color, n_hists, [0, 1]) - - # Hist bounds. - assert ylim is None or len(ylim) == n_hists - hist_bounds = np.c_[np.zeros((n_hists, 2)), - np.ones((n_hists, 1)), - ylim / hist_max] if ylim is not None else None - hist_bounds = _get_texture(hist_bounds, [0, 0, 1, 1], - n_hists, [0, 10]) - self.program['u_hist_bounds'] = Texture2D(hist_bounds) self.program['n_hists'] = n_hists From 04841f0a0fc63c1af85981bcfb084d65558f46ec Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 28 Nov 2015 13:58:27 +0100 Subject: [PATCH 0648/1059] Minor fix in interact --- phy/plot/interact.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/phy/plot/interact.py b/phy/plot/interact.py index 7aa847501..a1855d57c 100644 --- a/phy/plot/interact.py +++ b/phy/plot/interact.py @@ -55,21 +55,21 @@ def attach(self, canvas): canvas.transforms.add_on_gpu([Scale('u_grid_zoom'), Scale((m, m)), Clip([-m, -m, m, m]), - Subplot(self.shape, 'a_box_index'), + Subplot(self.shape, self.box_var), ]) canvas.inserter.insert_vert(""" - attribute vec2 a_box_index; + attribute vec2 {}; uniform float u_grid_zoom; - """, 'header') + """.format(self.box_var), 'header') canvas.connect(self.on_key_press) def update_program(self, program): program['u_grid_zoom'] = self._zoom # Only set the default box index if necessary. try: - program['a_box_index'] + program[self.box_var] except KeyError: - program['a_box_index'] = (0, 0) + program[self.box_var] = (0, 0) @property def zoom(self): From a7dd7bce3188ffa9aa3beb18e1e16a14172d1f67 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 28 Nov 2015 15:17:58 +0100 Subject: [PATCH 0649/1059] Minor update in visuals --- phy/plot/visuals.py | 38 +++++++++----------------------------- 1 file changed, 9 insertions(+), 29 deletions(-) diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index 0f9e0b164..62d9d384e 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -87,7 +87,7 @@ def validate(x=None, size=None, depth=None, data_bounds=None, - ): + **kwargs): x, y = _get_pos(x, y) pos = np.c_[x, y] n = pos.shape[0] @@ -102,16 +102,8 @@ def validate(x=None, return Bunch(pos=pos, color=color, size=size, depth=depth, data_bounds=data_bounds) - def set_data(self, - x=None, - y=None, - color=None, - size=None, - depth=None, - data_bounds=None, - ): - data = self.validate(x=x, y=y, color=color, size=size, depth=depth, - data_bounds=data_bounds) + def set_data(self, *args, **kwargs): + data = self.validate(*args, **kwargs) self.data_range.from_bounds = data.data_bounds pos_tr = self.transforms.apply(data.pos) self.program['a_position'] = np.c_[pos_tr, data.depth] @@ -138,7 +130,7 @@ def validate(x=None, color=None, depth=None, data_bounds=None, - ): + **kwargs): if x is None: assert y is not None @@ -172,15 +164,8 @@ def validate(x=None, color=color, depth=depth, data_bounds=data_bounds) - def set_data(self, - x=None, - y=None, - color=None, - depth=None, - data_bounds=None, - ): - data = self.validate(x=x, y=y, color=color, depth=depth, - data_bounds=data_bounds) + def set_data(self, *args, **kwargs): + data = self.validate(*args, **kwargs) x, y = data.x, data.y n_signals, n_samples = x.shape @@ -226,7 +211,7 @@ def __init__(self): def validate(hist=None, color=None, ylim=None, - ): + **kwargs): assert hist is not None hist = np.asarray(hist, np.float32) if hist.ndim == 1: @@ -250,13 +235,8 @@ def validate(hist=None, color=color, ) - def set_data(self, - hist=None, - color=None, - ylim=None, - ): - - data = self.validate(hist=hist, color=color, ylim=ylim) + def set_data(self, *args, **kwargs): + data = self.validate(*args, **kwargs) hist = data.hist n_hists, n_bins = hist.shape From adbadfde14a18088656e92cf2f5fcaef685694ec Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 28 Nov 2015 15:22:42 +0100 Subject: [PATCH 0650/1059] WIP: refactor plot interface --- phy/plot/plot.py | 379 ++++++++---------------------------- phy/plot/tests/test_plot.py | 13 +- phy/plot/visuals.py | 16 +- 3 files changed, 104 insertions(+), 304 deletions(-) diff --git a/phy/plot/plot.py b/phy/plot/plot.py index 8832a613b..f2d5206f5 100644 --- a/phy/plot/plot.py +++ b/phy/plot/plot.py @@ -7,8 +7,9 @@ # Imports #------------------------------------------------------------------------------ -from itertools import groupby from collections import defaultdict +from contextlib import contextmanager +from itertools import groupby import numpy as np @@ -30,321 +31,107 @@ class Accumulator(object): def __init__(self): self._data = defaultdict(list) - def __setitem__(self, name, val): + def add(self, name, val): self._data[name].append(val) def __getitem__(self, name): return np.vstack(self._data[name]).astype(np.float32) -#------------------------------------------------------------------------------ -# Base plotting interface -#------------------------------------------------------------------------------ +def _accumulate(data_list): + acc = Accumulator() + names = set() + for data in data_list: + for name, val in data.items(): + names.add(name) + acc.add(name, val) + return {name: acc[name] for name in names} -def _prepare_scatter(x, y, color=None, size=None, marker=None): - x = np.asarray(x) - y = np.asarray(y) - # Validate x and y. - assert x.ndim == y.ndim == 1 - assert x.shape == y.shape - n = x.shape[0] - # Set the color and size. - color = _get_array(color, (n, 4), ScatterVisual._default_color) - size = _get_array(size, (n, 1), ScatterVisual._default_marker_size) - # Default marker. - marker = marker or ScatterVisual._default_marker - return dict(x=x, y=y, color=color, size=size, marker=marker) - - -def _prepare_plot(x, y, color=None, depth=None, data_bounds=None): - x = np.atleast_2d(x) - y = np.atleast_2d(y) - # Validate x and y. - assert x.ndim == y.ndim == 2 - assert x.shape == y.shape - n_plots, n_samples = x.shape - # Get the colors. - color = _get_array(color, (n_plots, 4), PlotVisual._default_color) - # Get the depth. - depth = _get_array(depth, (n_plots, 1), 0) - return dict(x=x, y=y, color=color, depth=depth, data_bounds=data_bounds) - - -def _prepare_hist(hist, ylim=None, color=None): - hist = np.asarray(hist) - # Validate hist. - if hist.ndim == 1: - hist = hist[np.newaxis, :] - assert hist.ndim == 2 - n_hists, n_samples = hist.shape - # y-limit - if ylim is None: - ylim = hist.max() if hist.size else 1. - if not hasattr(ylim, '__len__'): - ylim = [ylim] - ylim = np.atleast_2d(ylim) - if len(ylim) == 1: - ylim = np.tile(ylim, (n_hists, 1)) - assert len(ylim) == n_hists - # Get the colors. - color = _get_array(color, (n_hists, 4), HistogramVisual._default_color) - return dict(hist=hist, ylim=ylim, color=color) - - -def _prepare_box_index(box_index, n): - if not _is_array_like(box_index): - box_index = np.tile(box_index, (n, 1)) - box_index = np.asarray(box_index, dtype=np.int32) - assert box_index.ndim == 2 - assert box_index.shape[0] == n - return box_index - - -def _build_scatter(items): - """Build scatter items and return parameters for `set_data()`.""" - - ac = Accumulator() - for item in items: - # The item data has already been prepared. - n = len(item.data.x) - ac['pos'] = np.c_[item.data.x, item.data.y] - ac['color'] = item.data.color - ac['size'] = item.data.size - ac['box_index'] = _prepare_box_index(item.box_index, n) - - return (dict(pos=ac['pos'], color=ac['color'], size=ac['size']), - ac['box_index']) - - -def _build_plot(items): - """Build all plot items and return parameters for `set_data()`.""" - - ac = Accumulator() - for item in items: - n = item.data.x.size - ac['x'] = item.data.x - ac['y'] = item.data.y - ac['depth'] = item.data.depth - ac['plot_colors'] = item.data.color - ac['box_index'] = _prepare_box_index(item.box_index, n) - - return (dict(x=ac['x'], y=ac['y'], - plot_colors=ac['plot_colors'], - depth=ac['depth'], - data_bounds=item.data.data_bounds, - ), - ac['box_index']) - - -def _build_histogram(items): - """Build all histogram items and return parameters for `set_data()`.""" - - ac = Accumulator() - for item in items: - n = item.data.hist.size - ac['hist'] = item.data.hist - ac['color'] = item.data.color - ac['ylim'] = item.data.ylim - # NOTE: the `6 * ` comes from the histogram tesselation. - ac['box_index'] = _prepare_box_index(item.box_index, 6 * n) - - return (dict(hist=ac['hist'], ylim=ac['ylim'], color=ac['color']), - ac['box_index']) - - -class ViewItem(Bunch): - """A visual item that will be rendered in batch with other view items - of the same type.""" - def __init__(self, base, visual_class=None, data=None, box_index=None): - super(ViewItem, self).__init__(visual_class=visual_class, - data=Bunch(data), - box_index=box_index, - to_build=True, - ) - self._base = base - - def set_data(self, **kwargs): - self.data.update(kwargs) - self.to_build = True +def _make_scatter_class(marker): + return type('ScatterVisual' + marker.title(), + (ScatterVisual,), {'_default_marker': marker}) + + +#------------------------------------------------------------------------------ +# Plotting interface +#------------------------------------------------------------------------------ class BaseView(BaseCanvas): """High-level plotting canvas.""" - def __init__(self, interacts, **kwargs): + def __init__(self, **kwargs): super(BaseView, self).__init__(**kwargs) - # Attach the passed interacts to the current canvas. - for interact in interacts: - interact.attach(self) - self._items = [] # List of view items instances. - self._visuals = {} - - @property - def panzoom(self): - """PanZoom instance from the interact list, if it exists.""" - for interact in self.interacts: - if isinstance(interact, PanZoom): - return interact - - # To override - # ------------------------------------------------------------------------- - - def __getitem__(self, idx): - class _Proxy(object): - def scatter(s, *args, **kwargs): - kwargs['box_index'] = idx - return self.scatter(*args, **kwargs) - - def plot(s, *args, **kwargs): - kwargs['box_index'] = idx - return self.plot(*args, **kwargs) - - def hist(s, *args, **kwargs): - kwargs['box_index'] = idx - return self.hist(*args, **kwargs) - - return _Proxy() - - def _iter_items(self): - """Iterate over all view items.""" - for item in self._items: - yield item - - def _visuals_to_build(self): - """Return the set of visual classes that need to be rebuilt.""" - visual_classes = set() - for item in self._items: - if item.to_build: - visual_classes.add(item.visual_class) - return visual_classes - - def _get_visual(self, key): - """Create or return a visual from its class or tuple (class, param).""" - if key not in self._visuals: - # Create the visual. - if isinstance(key, tuple): - # Case of the scatter plot, where the visual depends on the - # marker. - v = key[0](key[1]) - else: - v = key() - # Attach the visual to the view. - v.attach(self) - # Store the visual for reuse. - self._visuals[key] = v - return self._visuals[key] - - # Public methods - # ------------------------------------------------------------------------- + self._default_box_index = None + self.clear() + + def clear(self): + self._items = defaultdict(list) + + def _add_item(self, cls, *args, **kwargs): + data = cls.validate(*args, **kwargs) + data['box_index'] = kwargs.get('box_index', self._default_box_index) + self._items[cls].append(data) def plot(self, *args, **kwargs): - """Add a line plot.""" - box_index = kwargs.pop('box_index', None) - data = _prepare_plot(*args, **kwargs) - item = ViewItem(self, visual_class=PlotVisual, - data=data, box_index=box_index) - self._items.append(item) - return item + self._add_item(PlotVisual, *args, **kwargs) def scatter(self, *args, **kwargs): - """Add a scatter plot.""" - box_index = kwargs.pop('box_index', None) - data = _prepare_scatter(*args, **kwargs) - item = ViewItem(self, visual_class=ScatterVisual, - data=data, box_index=box_index) - self._items.append(item) - return item + cls = _make_scatter_class(kwargs.get('marker', + ScatterVisual._default_marker)) + self._add_item(cls, *args, **kwargs) def hist(self, *args, **kwargs): - """Add a histogram plot.""" - box_index = kwargs.pop('box_index', None) - data = _prepare_hist(*args, **kwargs) - item = ViewItem(self, visual_class=HistogramVisual, - data=data, box_index=box_index) - self._items.append(item) - return item + self._add_item(HistogramVisual, *args, **kwargs) - def build(self): - """Build all visuals.""" - visuals_to_build = self._visuals_to_build() - - for visual_class, items in groupby(self._iter_items(), - lambda item: item.visual_class): - items = list(items) - - # Skip visuals that do not need to be built. - if visual_class not in visuals_to_build: - continue - - # Histogram. - # TODO: refactor this (DRY). - if visual_class == HistogramVisual: - data, box_index = _build_histogram(items) - v = self._get_visual(HistogramVisual) - v.set_data(**data) - v.program['a_box_index'] = box_index - for item in items: - item.to_build = False - - # Scatter. - if visual_class == ScatterVisual: - items_grouped = groupby(items, lambda item: item.data.marker) - # One visual per marker type. - for marker, items_scatter in items_grouped: - items_scatter = list(items_scatter) - data, box_index = _build_scatter(items_scatter) - v = self._get_visual((ScatterVisual, marker)) - v.set_data(**data) - v.program['a_box_index'] = box_index - for item in items_scatter: - item.to_build = False - - # Plot. - if visual_class == PlotVisual: - items_grouped = groupby(items, - lambda item: item.data.x.shape[1]) - # HACK: one visual per number of samples, because currently - # a PlotVisual only accepts a regular (n_plots, n_samples) - # array as input. - for n_samples, items_plot in items_grouped: - items_plot = list(items_plot) - data, box_index = _build_plot(items_plot) - v = self._get_visual((PlotVisual, n_samples)) - v.set_data(**data) - v.program['a_box_index'] = box_index - for item in items_plot: - item.to_build = False - - self.update() + def __getitem__(self, box_index): + @contextmanager + def box_index_ctx(): + self._default_box_index = box_index + yield + self._default_box_index = None -#------------------------------------------------------------------------------ -# Plotting interface -#------------------------------------------------------------------------------ + with box_index_ctx(): + return self -class GridView(BaseView): - """A 2D grid with clipping.""" - def __init__(self, shape, **kwargs): - self.n_rows, self.n_cols = shape - pz = PanZoom(aspect=None, constrain_bounds=NDC) - interacts = [Grid(shape), pz] - super(GridView, self).__init__(interacts, **kwargs) - - -class BoxedView(BaseView): - """Subplots at arbitrary positions""" - def __init__(self, box_bounds, **kwargs): - self.n_plots = len(box_bounds) - self._boxed = Boxed(box_bounds) - self._pz = PanZoom(aspect=None, constrain_bounds=NDC) - interacts = [self._boxed, self._pz] - super(BoxedView, self).__init__(interacts, **kwargs) - - -class StackedView(BaseView): - """Stacked subplots""" - def __init__(self, n_plots, **kwargs): - self.n_plots = n_plots - pz = PanZoom(aspect=None, constrain_bounds=NDC) - interacts = [Stacked(n_plots, margin=.1), pz] - super(StackedView, self).__init__(interacts, **kwargs) + def build(self): + for cls, data_list in self._items.items(): + data = _accumulate(data_list) + box_index = data.pop('box_index') + visual = cls() + self.add_visual(visual) + visual.set_data(**data) + try: + visual.program['a_box_index'] + visual.program['a_box_index'] = box_index + except KeyError: + pass + + +# class GridView(BaseView): +# """A 2D grid with clipping.""" +# def __init__(self, shape, **kwargs): +# self.n_rows, self.n_cols = shape +# pz = PanZoom(aspect=None, constrain_bounds=NDC) +# interacts = [Grid(shape), pz] +# super(GridView, self).__init__(interacts, **kwargs) + + +# class BoxedView(BaseView): +# """Subplots at arbitrary positions""" +# def __init__(self, box_bounds, **kwargs): +# self.n_plots = len(box_bounds) +# self._boxed = Boxed(box_bounds) +# self._pz = PanZoom(aspect=None, constrain_bounds=NDC) +# interacts = [self._boxed, self._pz] +# super(BoxedView, self).__init__(interacts, **kwargs) + + +# class StackedView(BaseView): +# """Stacked subplots""" +# def __init__(self, n_plots, **kwargs): +# self.n_plots = n_plots +# pz = PanZoom(aspect=None, constrain_bounds=NDC) +# interacts = [Stacked(n_plots, margin=.1), pz] +# super(StackedView, self).__init__(interacts, **kwargs) diff --git a/phy/plot/tests/test_plot.py b/phy/plot/tests/test_plot.py index 1ca1334c3..0d390d57f 100644 --- a/phy/plot/tests/test_plot.py +++ b/phy/plot/tests/test_plot.py @@ -10,7 +10,7 @@ import numpy as np from ..panzoom import PanZoom -from ..plot import GridView, BoxedView, StackedView +from ..plot import BaseView #, GridView, BoxedView, StackedView from ..utils import _get_linear_x @@ -31,6 +31,17 @@ def _show(qtbot, view, stop=False): # Test plotting interface #------------------------------------------------------------------------------ +def test_base_view(qtbot): + view = BaseView(keys='interactive') + n = 1000 + + x = np.random.randn(n) + y = np.random.randn(n) + + view.scatter(x, y) + _show(qtbot, view, stop=True) + + def test_grid_scatter(qtbot): view = GridView((2, 3)) n = 1000 diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index 62d9d384e..d4269283c 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -83,13 +83,17 @@ def __init__(self, marker=None): @staticmethod def validate(x=None, y=None, + pos=None, color=None, size=None, depth=None, data_bounds=None, - **kwargs): - x, y = _get_pos(x, y) - pos = np.c_[x, y] + ): + if pos is None: + x, y = _get_pos(x, y) + pos = np.c_[x, y] + assert pos.ndim == 2 + assert pos.shape[1] == 2 n = pos.shape[0] # Validate the data. @@ -129,8 +133,7 @@ def validate(x=None, y=None, color=None, depth=None, - data_bounds=None, - **kwargs): + data_bounds=None): if x is None: assert y is not None @@ -210,8 +213,7 @@ def __init__(self): @staticmethod def validate(hist=None, color=None, - ylim=None, - **kwargs): + ylim=None): assert hist is not None hist = np.asarray(hist, np.float32) if hist.ndim == 1: From bca7dffd7887cd1eaebfb0e0555ed9e5fef729ca Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 28 Nov 2015 17:47:00 +0100 Subject: [PATCH 0651/1059] WIP: grid with variable shape --- phy/plot/interact.py | 22 ++++++++++++++-------- phy/plot/tests/test_interact.py | 2 +- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/phy/plot/interact.py b/phy/plot/interact.py index a1855d57c..682534edd 100644 --- a/phy/plot/interact.py +++ b/phy/plot/interact.py @@ -37,17 +37,13 @@ class Grid(BaseInteract): """ - def __init__(self, shape, box_var=None): + def __init__(self, shape=(1, 1), shape_var='u_grid_shape', box_var=None): self._zoom = 1. # Name of the variable with the box index. self.box_var = box_var or 'a_box_index' - - self.shape = shape - if isinstance(self.shape, tuple): - assert len(self.shape) == 2 - assert self.shape[0] >= 1 - assert self.shape[1] >= 1 + self.shape_var = shape_var + self._shape = shape def attach(self, canvas): super(Grid, self).attach(canvas) @@ -55,7 +51,7 @@ def attach(self, canvas): canvas.transforms.add_on_gpu([Scale('u_grid_zoom'), Scale((m, m)), Clip([-m, -m, m, m]), - Subplot(self.shape, self.box_var), + Subplot(self.shape_var, self.box_var), ]) canvas.inserter.insert_vert(""" attribute vec2 {}; @@ -65,6 +61,7 @@ def attach(self, canvas): def update_program(self, program): program['u_grid_zoom'] = self._zoom + program[self.shape_var] = self._shape # Only set the default box index if necessary. try: program[self.box_var] @@ -82,6 +79,15 @@ def zoom(self, value): self._zoom = value self.update() + @property + def shape(self): + return self._shape + + @shape.setter + def shape(self, value): + self._shape = value + self.update() + def on_key_press(self, event): """Pan and zoom with the keyboard.""" key = event.key diff --git a/phy/plot/tests/test_interact.py b/phy/plot/tests/test_interact.py index 546410a03..6cb102277 100644 --- a/phy/plot/tests/test_interact.py +++ b/phy/plot/tests/test_interact.py @@ -121,7 +121,7 @@ def attach(self, canvas): canvas.inserter.insert_vert('vec2 u_shape = vec2(3, 3);', 'before_transforms') - grid = MyGrid('u_shape') + grid = MyGrid(shape_var='u_shape') _create_visual(qtbot, canvas, grid, box_index) # qtbot.stop() From 78c572d646765f90dba61e37fc911d83f8172ba4 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 28 Nov 2015 23:26:21 +0100 Subject: [PATCH 0652/1059] All tests pass --- phy/plot/base.py | 1 - phy/plot/interact.py | 4 +- phy/plot/panzoom.py | 1 + phy/plot/plot.py | 113 +++++++++++++++++--------------- phy/plot/tests/test_interact.py | 9 +-- phy/plot/tests/test_plot.py | 31 +++------ phy/plot/visuals.py | 27 ++++++-- 7 files changed, 98 insertions(+), 88 deletions(-) diff --git a/phy/plot/base.py b/phy/plot/base.py index b62172da7..fbcdb6779 100644 --- a/phy/plot/base.py +++ b/phy/plot/base.py @@ -252,7 +252,6 @@ class BaseInteract(object): def attach(self, canvas): """Attach this interact to a canvas.""" self.canvas = canvas - canvas.panzoom = self @canvas.connect def on_visual_added(e): diff --git a/phy/plot/interact.py b/phy/plot/interact.py index 682534edd..6b0d948ab 100644 --- a/phy/plot/interact.py +++ b/phy/plot/interact.py @@ -55,8 +55,10 @@ def attach(self, canvas): ]) canvas.inserter.insert_vert(""" attribute vec2 {}; + uniform vec2 {}; uniform float u_grid_zoom; - """.format(self.box_var), 'header') + """.format(self.box_var, self.shape_var), + 'header') canvas.connect(self.on_key_press) def update_program(self, program): diff --git a/phy/plot/panzoom.py b/phy/plot/panzoom.py index 3b51dbdb3..9d7cf282f 100644 --- a/phy/plot/panzoom.py +++ b/phy/plot/panzoom.py @@ -384,6 +384,7 @@ def size(self): def attach(self, canvas): """Attach this interact to a canvas.""" super(PanZoom, self).attach(canvas) + canvas.panzoom = self canvas.transforms.add_on_gpu([Translate(self.pan_var_name), Scale(self.zoom_var_name)]) diff --git a/phy/plot/plot.py b/phy/plot/plot.py index f2d5206f5..103001de3 100644 --- a/phy/plot/plot.py +++ b/phy/plot/plot.py @@ -7,13 +7,10 @@ # Imports #------------------------------------------------------------------------------ -from collections import defaultdict -from contextlib import contextmanager -from itertools import groupby +from collections import defaultdict, OrderedDict import numpy as np -from phy.utils import Bunch, _is_array_like from .base import BaseCanvas from .interact import Grid, Boxed, Stacked from .panzoom import PanZoom @@ -59,41 +56,43 @@ def _make_scatter_class(marker): class BaseView(BaseCanvas): """High-level plotting canvas.""" + _default_box_index = (0,) def __init__(self, **kwargs): + if not kwargs.get('keys', None): + kwargs['keys'] = 'interactive' super(BaseView, self).__init__(**kwargs) - self._default_box_index = None self.clear() def clear(self): - self._items = defaultdict(list) + self._items = OrderedDict() def _add_item(self, cls, *args, **kwargs): data = cls.validate(*args, **kwargs) - data['box_index'] = kwargs.get('box_index', self._default_box_index) + n = cls.vertex_count(**data) + box_index = kwargs.get('box_index', self._default_box_index) + k = len(box_index) if hasattr(box_index, '__len__') else 1 + box_index = _get_array(box_index, (n, k)) + data['box_index'] = box_index + if cls not in self._items: + self._items[cls] = [] self._items[cls].append(data) + return data def plot(self, *args, **kwargs): - self._add_item(PlotVisual, *args, **kwargs) + return self._add_item(PlotVisual, *args, **kwargs) def scatter(self, *args, **kwargs): - cls = _make_scatter_class(kwargs.get('marker', + cls = _make_scatter_class(kwargs.pop('marker', ScatterVisual._default_marker)) - self._add_item(cls, *args, **kwargs) + return self._add_item(cls, *args, **kwargs) def hist(self, *args, **kwargs): - self._add_item(HistogramVisual, *args, **kwargs) + return self._add_item(HistogramVisual, *args, **kwargs) def __getitem__(self, box_index): - - @contextmanager - def box_index_ctx(): - self._default_box_index = box_index - yield - self._default_box_index = None - - with box_index_ctx(): - return self + self._default_box_index = box_index + return self def build(self): for cls, data_list in self._items.items(): @@ -102,36 +101,44 @@ def build(self): visual = cls() self.add_visual(visual) visual.set_data(**data) - try: - visual.program['a_box_index'] - visual.program['a_box_index'] = box_index - except KeyError: - pass - - -# class GridView(BaseView): -# """A 2D grid with clipping.""" -# def __init__(self, shape, **kwargs): -# self.n_rows, self.n_cols = shape -# pz = PanZoom(aspect=None, constrain_bounds=NDC) -# interacts = [Grid(shape), pz] -# super(GridView, self).__init__(interacts, **kwargs) - - -# class BoxedView(BaseView): -# """Subplots at arbitrary positions""" -# def __init__(self, box_bounds, **kwargs): -# self.n_plots = len(box_bounds) -# self._boxed = Boxed(box_bounds) -# self._pz = PanZoom(aspect=None, constrain_bounds=NDC) -# interacts = [self._boxed, self._pz] -# super(BoxedView, self).__init__(interacts, **kwargs) - - -# class StackedView(BaseView): -# """Stacked subplots""" -# def __init__(self, n_plots, **kwargs): -# self.n_plots = n_plots -# pz = PanZoom(aspect=None, constrain_bounds=NDC) -# interacts = [Stacked(n_plots, margin=.1), pz] -# super(StackedView, self).__init__(interacts, **kwargs) + visual.program['a_box_index'] = box_index + + +class GridView(BaseView): + """A 2D grid with clipping.""" + _default_box_index = (0, 0) + + def __init__(self, shape=None, **kwargs): + super(GridView, self).__init__(**kwargs) + + self.grid = Grid(shape) + self.grid.attach(self) + + self.panzoom = PanZoom(aspect=None, constrain_bounds=NDC) + self.panzoom.attach(self) + + +class BoxedView(BaseView): + """Subplots at arbitrary positions""" + def __init__(self, box_bounds, **kwargs): + super(BoxedView, self).__init__(**kwargs) + self.n_plots = len(box_bounds) + + self.boxed = Boxed(box_bounds) + self.boxed.attach(self) + + self.panzoom = PanZoom(aspect=None, constrain_bounds=NDC) + self.panzoom.attach(self) + + +class StackedView(BaseView): + """Stacked subplots""" + def __init__(self, n_plots, **kwargs): + super(StackedView, self).__init__(**kwargs) + self.n_plots = n_plots + + self.stacked = Stacked(n_plots, margin=.1) + self.stacked.attach(self) + + self.panzoom = PanZoom(aspect=None, constrain_bounds=NDC) + self.panzoom.attach(self) diff --git a/phy/plot/tests/test_interact.py b/phy/plot/tests/test_interact.py index 6cb102277..54e307743 100644 --- a/phy/plot/tests/test_interact.py +++ b/phy/plot/tests/test_interact.py @@ -115,14 +115,9 @@ def test_grid_2(qtbot, canvas): box_index = [[i, j] for i, j in product(range(2), range(3))] box_index = np.repeat(box_index, n, axis=0) - class MyGrid(Grid): - def attach(self, canvas): - super(MyGrid, self).attach(canvas) - canvas.inserter.insert_vert('vec2 u_shape = vec2(3, 3);', - 'before_transforms') - - grid = MyGrid(shape_var='u_shape') + grid = Grid() _create_visual(qtbot, canvas, grid, box_index) + grid.shape = (3, 3) # qtbot.stop() diff --git a/phy/plot/tests/test_plot.py b/phy/plot/tests/test_plot.py index 0d390d57f..fd6073882 100644 --- a/phy/plot/tests/test_plot.py +++ b/phy/plot/tests/test_plot.py @@ -10,7 +10,7 @@ import numpy as np from ..panzoom import PanZoom -from ..plot import BaseView #, GridView, BoxedView, StackedView +from ..plot import BaseView, GridView, BoxedView, StackedView from ..utils import _get_linear_x @@ -39,12 +39,12 @@ def test_base_view(qtbot): y = np.random.randn(n) view.scatter(x, y) - _show(qtbot, view, stop=True) + _show(qtbot, view) def test_grid_scatter(qtbot): view = GridView((2, 3)) - n = 1000 + n = 100 assert isinstance(view.panzoom, PanZoom) @@ -108,18 +108,19 @@ def test_grid_complete(qtbot): def test_stacked_complete(qtbot): - view = StackedView(4) + view = StackedView(3) t = _get_linear_x(1, 1000).ravel() view[0].scatter(*np.random.randn(2, 100)) # Different types of visuals in the same subplot. - view[2].hist(np.random.rand(5, 10), + view[1].hist(np.random.rand(5, 10), color=np.random.uniform(.4, .9, size=(5, 4))) - view[2].plot(t, np.sin(20 * t), color=(1, 0, 0, 1)) + view[1].plot(t, np.sin(20 * t), color=(1, 0, 0, 1)) - v = view[1].plot(t[::2], np.sin(20 * t[::2]), color=(1, 0, 0, 1)) - v.set_data(color=(0, 1, 0, 1)) + # TODO + # v = view[2].plot(t[::2], np.sin(20 * t[::2]), color=(1, 0, 0, 1)) + # v.update(color=(0, 1, 0, 1)) _show(qtbot, view) @@ -137,16 +138,4 @@ def test_boxed_complete(qtbot): view[2].hist(np.random.rand(5, 10), color=np.random.uniform(.4, .9, size=(5, 4))) - # Build and show. - view.build() - view.show() - - # Change a subplot. - view[2].hist(np.random.rand(5, 10), - color=np.random.uniform(.4, .9, size=(5, 4))) - - # Rebuild and show. - view.build() - qtbot.waitForWindowShown(view.native) - - view.close() + _show(qtbot, view) diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index d4269283c..9cca04b09 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -80,6 +80,12 @@ def __init__(self, marker=None): self.data_range = Range(NDC) self.transforms.add_on_cpu(self.data_range) + @staticmethod + def vertex_count(x=None, y=None, pos=None, **kwargs): + if pos is not None: + return len(pos) + return x.size if x is not None else y.size + @staticmethod def validate(x=None, y=None, @@ -142,8 +148,8 @@ def validate(x=None, # Default x coordinates. assert x is not None assert y is not None - x = np.asarray(x, np.float32) - y = np.asarray(y, np.float32) + x = np.atleast_2d(x).astype(np.float32) + y = np.atleast_2d(y).astype(np.float32) assert x.ndim == 2 assert x.shape == y.shape n_signals, n_samples = x.shape @@ -167,6 +173,10 @@ def validate(x=None, color=color, depth=depth, data_bounds=data_bounds) + @staticmethod + def vertex_count(x=None, y=None, **kwargs): + return x.size if x is not None else y.size + def set_data(self, *args, **kwargs): data = self.validate(*args, **kwargs) x, y = data.x, data.y @@ -229,20 +239,27 @@ def validate(hist=None, ylim = hist.max() if hist.size > 0 else 1. ylim = np.atleast_1d(ylim) if len(ylim) == 1: - ylim = np.tile(ylim, len(ylim)) - assert ylim.shape == (n_hists,) + ylim = np.tile(ylim, n_hists) + if ylim.ndim == 1: + ylim = ylim[:, np.newaxis] + assert ylim.shape == (n_hists, 1) return Bunch(hist=hist, ylim=ylim, color=color, ) + @staticmethod + def vertex_count(hist, **kwargs): + n_hists, n_bins = hist.shape + return 6 * n_hists * n_bins + def set_data(self, *args, **kwargs): data = self.validate(*args, **kwargs) hist = data.hist n_hists, n_bins = hist.shape - n = 6 * n_hists * n_bins + n = self.vertex_count(hist) # NOTE: this must be set *before* `apply_cpu_transforms` such # that the histogram is correctly normalized. From 72c3712b2455226f9a9c2917bb20e1cde4f0419b Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 29 Nov 2015 07:36:43 +0100 Subject: [PATCH 0653/1059] Increase coverage --- phy/plot/base.py | 10 ++++++++++ phy/plot/tests/test_interact.py | 1 + phy/plot/tests/test_visuals.py | 7 ++++--- phy/plot/utils.py | 12 ++---------- phy/plot/visuals.py | 1 + 5 files changed, 18 insertions(+), 13 deletions(-) diff --git a/phy/plot/base.py b/phy/plot/base.py index fbcdb6779..bcaf9e774 100644 --- a/phy/plot/base.py +++ b/phy/plot/base.py @@ -72,6 +72,16 @@ def on_draw(self): # To override # ------------------------------------------------------------------------- + @staticmethod + def validate(**kwargs): + """Make consistent the input data for the visual.""" + return kwargs + + @staticmethod + def vertex_count(**kwargs): + """Return the number of vertices as a function of the input data.""" + return 0 + def set_data(self): """Set data to the program. diff --git a/phy/plot/tests/test_interact.py b/phy/plot/tests/test_interact.py index 54e307743..36014b0c0 100644 --- a/phy/plot/tests/test_interact.py +++ b/phy/plot/tests/test_interact.py @@ -118,6 +118,7 @@ def test_grid_2(qtbot, canvas): grid = Grid() _create_visual(qtbot, canvas, grid, box_index) grid.shape = (3, 3) + assert grid.shape == (3, 3) # qtbot.stop() diff --git a/phy/plot/tests/test_visuals.py b/phy/plot/tests/test_visuals.py index 547e19a43..576054373 100644 --- a/phy/plot/tests/test_visuals.py +++ b/phy/plot/tests/test_visuals.py @@ -20,6 +20,8 @@ def _test_visual(qtbot, c, v, stop=False, **kwargs): c.add_visual(v) + v.validate(**kwargs) + assert v.vertex_count(**kwargs) >= 0 v.set_data(**kwargs) c.show() qtbot.waitForWindowShown(c.native) @@ -57,8 +59,7 @@ def test_scatter_custom(qtbot, canvas_pz): n = 100 # Random position. - x = .2 * np.random.randn(n) - y = .2 * np.random.randn(n) + pos = .2 * np.random.randn(n, 2) # Random colors. c = np.random.uniform(.4, .7, size=(n, 4)) @@ -68,7 +69,7 @@ def test_scatter_custom(qtbot, canvas_pz): s = 5 + 20 * np.random.rand(n) _test_visual(qtbot, canvas_pz, ScatterVisual(), - x=x, y=y, color=c, size=s) + pos=pos, color=c, size=s) #------------------------------------------------------------------------------ diff --git a/phy/plot/utils.py b/phy/plot/utils.py index 75c1cab7c..5ac9bf169 100644 --- a/phy/plot/utils.py +++ b/phy/plot/utils.py @@ -122,7 +122,7 @@ def _get_texture(arr, default, n_items, from_bounds): if not hasattr(default, '__len__'): # pragma: no cover default = [default] n_cols = len(default) - if arr is None: + if arr is None: # pragma: no cover arr = np.tile(default, (n_items, 1)) assert arr.shape == (n_items, n_cols) # Convert to 3D texture. @@ -145,7 +145,7 @@ def _get_texture(arr, default, n_items, from_bounds): def _get_array(val, shape, default=None): """Ensure an object is an array with the specified shape.""" assert val is not None or default is not None - if hasattr(val, '__len__') and len(val) == 0: + if hasattr(val, '__len__') and len(val) == 0: # pragma: no cover val = None out = np.zeros(shape, dtype=np.float32) # This solves `ValueError: could not broadcast input array from shape (n) @@ -211,14 +211,6 @@ def _get_pos(x, y): return x, y -def _get_hist_max(hist): - hist_max = hist.max() if hist.size else 1. - hist_max = float(hist_max) - hist_max = hist_max if hist_max > 0 else 1. - assert hist_max > 0 - return hist_max - - def _get_index(n_items, item_size, n): """Prepare an index attribute for GPU uploading.""" index = np.arange(n_items) diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index 9cca04b09..4452ebb9d 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -251,6 +251,7 @@ def validate(hist=None, @staticmethod def vertex_count(hist, **kwargs): + hist = np.atleast_2d(hist) n_hists, n_bins = hist.shape return 6 * n_hists * n_bins From 6b01de7ac49cf0761dbef8c146cb5fc4d3e4b6e6 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 29 Nov 2015 08:09:58 +0100 Subject: [PATCH 0654/1059] Minor improvements to phy.plot --- phy/plot/__init__.py | 3 +-- phy/plot/plot.py | 14 ++++++++++++-- phy/plot/tests/test_plot.py | 15 +++++++++++++++ phy/plot/visuals.py | 2 +- 4 files changed, 29 insertions(+), 5 deletions(-) diff --git a/phy/plot/__init__.py b/phy/plot/__init__.py index 0029ba8df..b4e1d3453 100644 --- a/phy/plot/__init__.py +++ b/phy/plot/__init__.py @@ -12,8 +12,7 @@ from vispy import config -# from .interact import Grid, Stacked, Boxed -# from .plot import GridView, BoxedView, StackedView # noqa +from .plot import GridView, BoxedView, StackedView # noqa from .transform import Translate, Scale, Range, Subplot, NDC from .panzoom import PanZoom from.visuals import _get_linear_x diff --git a/phy/plot/plot.py b/phy/plot/plot.py index 103001de3..6cc6e13ff 100644 --- a/phy/plot/plot.py +++ b/phy/plot/plot.py @@ -8,6 +8,7 @@ #------------------------------------------------------------------------------ from collections import defaultdict, OrderedDict +from contextlib import contextmanager import numpy as np @@ -68,12 +69,14 @@ def clear(self): self._items = OrderedDict() def _add_item(self, cls, *args, **kwargs): + box_index = kwargs.pop('box_index', self._default_box_index) + k = len(box_index) if hasattr(box_index, '__len__') else 1 + data = cls.validate(*args, **kwargs) n = cls.vertex_count(**data) - box_index = kwargs.get('box_index', self._default_box_index) - k = len(box_index) if hasattr(box_index, '__len__') else 1 box_index = _get_array(box_index, (n, k)) data['box_index'] = box_index + if cls not in self._items: self._items[cls] = [] self._items[cls].append(data) @@ -102,6 +105,13 @@ def build(self): self.add_visual(visual) visual.set_data(**data) visual.program['a_box_index'] = box_index + self.update() + + @contextmanager + def building(self): + self.clear() + yield + self.build() class GridView(BaseView): diff --git a/phy/plot/tests/test_plot.py b/phy/plot/tests/test_plot.py index fd6073882..085614372 100644 --- a/phy/plot/tests/test_plot.py +++ b/phy/plot/tests/test_plot.py @@ -31,6 +31,21 @@ def _show(qtbot, view, stop=False): # Test plotting interface #------------------------------------------------------------------------------ +def test_building(qtbot): + view = BaseView(keys='interactive') + n = 1000 + + x = np.random.randn(n) + y = np.random.randn(n) + + with view.building(): + view.scatter(x, y) + + view.show() + qtbot.waitForWindowShown(view.native) + view.close() + + def test_base_view(qtbot): view = BaseView(keys='interactive') n = 1000 diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index 4452ebb9d..aaf10d8ff 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -156,7 +156,7 @@ def validate(x=None, # Validate the data. color = _get_array(color, (n_signals, 4), PlotVisual._default_color) - depth = _get_array(depth, (n_signals,), 0) + depth = _get_array(depth, (n_signals, 1), 0) # Validate data bounds. if data_bounds is None: From 32f34deef1c463da170865b327465e008df6f79b Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 29 Nov 2015 08:19:59 +0100 Subject: [PATCH 0655/1059] WIP: update manual clustering views --- phy/cluster/manual/tests/conftest.py | 2 +- phy/cluster/manual/tests/test_views.py | 3 +- phy/cluster/manual/views.py | 117 ++++++++++++------------- 3 files changed, 57 insertions(+), 65 deletions(-) diff --git a/phy/cluster/manual/tests/conftest.py b/phy/cluster/manual/tests/conftest.py index 7a48df7a6..94a618c97 100644 --- a/phy/cluster/manual/tests/conftest.py +++ b/phy/cluster/manual/tests/conftest.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -"""Test wizard.""" +"""Test fixtures.""" #------------------------------------------------------------------------------ # Imports diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index c4e803925..61dff7c01 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -117,7 +117,7 @@ def test_trace_view_no_spikes(qtbot): _show(qtbot, v) -def test_trace_view_spikes(qtbot): +def SKIPtest_trace_view_spikes(qtbot): n_samples = 1000 n_channels = 12 sample_rate = 2000. @@ -131,6 +131,7 @@ def test_trace_view_spikes(qtbot): masks = artificial_masks(n_spikes, n_channels) # Create the view. + # TODO: make this work = plots with variable n_samples v = TraceView(traces=traces, sample_rate=sample_rate, spike_times=spike_times, diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 77cf14af4..71c2425b4 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -189,12 +189,6 @@ def __init__(self, assert channel_positions.shape == (self.n_channels, 2) self.channel_positions = channel_positions - # Initialize the subplots. - self._plots = {ch: self[ch].plot(x=[], y=[]) - for ch in range(self.n_channels)} - self.build() - self.update() - def on_select(self, cluster_ids, spike_ids): n_clusters = len(cluster_ids) n_spikes = len(spike_ids) @@ -222,22 +216,20 @@ def on_select(self, cluster_ids, spike_ids): # Plot all waveforms. # OPTIM: avoid the loop. - for ch in range(self.n_channels): - m = masks[:, ch] - depth = _get_depth(m, - spike_clusters_rel=spike_clusters_rel, - n_clusters=n_clusters) - color = _get_color(m, - spike_clusters_rel=spike_clusters_rel, - n_clusters=n_clusters) - self._plots[ch].set_data(x=t, y=w[:, :, ch], - color=color, - depth=depth, - data_bounds=self.data_bounds, - ) - - self.build() - self.update() + with self.building(): + for ch in range(self.n_channels): + m = masks[:, ch] + depth = _get_depth(m, + spike_clusters_rel=spike_clusters_rel, + n_clusters=n_clusters) + color = _get_color(m, + spike_clusters_rel=spike_clusters_rel, + n_clusters=n_clusters) + self[ch].plot(x=t, y=w[:, :, ch], + color=color, + depth=depth, + data_bounds=self.data_bounds, + ) def attach(self, gui): """Attach the view to the GUI.""" @@ -333,7 +325,7 @@ def _load_spikes(self, interval): return self.spike_times[a:b], self.spike_clusters[a:b], self.masks[a:b] def set_interval(self, interval): - + self.clear() start, end = interval color = (.5, .5, .5, 1) @@ -345,7 +337,7 @@ def set_interval(self, interval): assert traces.shape[1] == self.n_channels m, M = traces.min(), traces.max() - data_bounds = [start, m, end, M] + data_bounds = np.array([start, m, end, M]) # Generate the trace plots. # TODO OPTIM: avoid the loop and generate all channel traces in @@ -365,16 +357,16 @@ def set_interval(self, interval): dur_spike = wave_len * dt trace_start = int(self.sample_rate * start) - # ac = Accumulator() for i in range(n_spikes): sample_rel = (int(spike_times[i] * self.sample_rate) - trace_start) mask = self.masks[i] + # TODO # clu = spike_clusters[i] w, ch = _extract_wave(traces, sample_rel, mask, wave_len) n_ch = len(ch) t0 = spike_times[i] - dur_spike / 2. - color = (1, 0, 0, 1) + color = np.array([1, 0, 0, 1]) box_index = np.repeat(ch[:, np.newaxis], wave_len, axis=0) t = t0 + dt * np.arange(wave_len) t = np.tile(t, (n_ch, 1)) @@ -384,6 +376,12 @@ def set_interval(self, interval): self.build() self.update() + def on_select(self, cluster_ids, spike_ids): + pass + + def attach(self, gui): + pass + # ----------------------------------------------------------------------------- # Feature view @@ -522,14 +520,6 @@ def __init__(self, assert spike_times.shape == (self.n_spikes,) self.spike_times = spike_times - # Initialize the subplots. - self._plots = {(i, j): self[i, j].scatter(x=[], y=[], size=[]) - for i in range(self.n_cols) - for j in range(self.n_cols) - } - self.build() - self.update() - def _get_feature(self, dim, spike_ids=None): f = self.features[spike_ids] assert f.ndim == 3 @@ -563,36 +553,37 @@ def on_select(self, cluster_ids, spike_ids): # Plot all features. # TODO: optim: avoid the loop. - for i in range(self.n_cols): - for j in range(self.n_cols): + with self.building(): + for i in range(self.n_cols): + for j in range(self.n_cols): + + x = self._get_feature(x_dim[i, j], spike_ids) + y = self._get_feature(y_dim[i, j], spike_ids) + + mx, dx = _project_mask_depth(x_dim[i, j], masks, + spike_clusters_rel=sc, + n_clusters=n_clusters) + my, dy = _project_mask_depth(y_dim[i, j], masks, + spike_clusters_rel=sc, + n_clusters=n_clusters) + + d = np.maximum(dx, dy) + m = np.maximum(mx, my) + + color = _get_color(m, + spike_clusters_rel=sc, + n_clusters=n_clusters) + + self[i, j].scatter(x=x, + y=y, + color=color, + depth=d, + data_bounds=self.data_bounds, + size=5 * np.ones(n_spikes), + ) - x = self._get_feature(x_dim[i, j], spike_ids) - y = self._get_feature(y_dim[i, j], spike_ids) - - mx, dx = _project_mask_depth(x_dim[i, j], masks, - spike_clusters_rel=sc, - n_clusters=n_clusters) - my, dy = _project_mask_depth(y_dim[i, j], masks, - spike_clusters_rel=sc, - n_clusters=n_clusters) - - d = np.maximum(dx, dy) - m = np.maximum(mx, my) - - color = _get_color(m, - spike_clusters_rel=sc, - n_clusters=n_clusters) - - self._plots[i, j].set_data(x=x, - y=y, - color=color, - depth=d, - data_bounds=self.data_bounds, - size=5 * np.ones(n_spikes), - ) - - self.build() - self.update() + def attach(self, gui): + pass # ----------------------------------------------------------------------------- From 61200380c73c6897679202ac2342d9b41aa941fc Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 29 Nov 2015 08:23:14 +0100 Subject: [PATCH 0656/1059] Update CCG view --- phy/cluster/manual/tests/test_views.py | 2 +- phy/cluster/manual/views.py | 35 +++++++++----------------- 2 files changed, 13 insertions(+), 24 deletions(-) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index 61dff7c01..b68810cda 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -187,7 +187,7 @@ def test_feature_view(qtbot): def test_correlogram_view(qtbot): n_spikes = 50 - n_clusters = 1 + n_clusters = 5 sample_rate = 20000. bin_size = 1e-3 window_size = 50e-3 diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 71c2425b4..d5b6b83b8 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -617,22 +617,12 @@ def __init__(self, self.n_spikes, = self.spike_times.shape # Initialize the view. - self.n_cols = 1 # TODO: dynamic grid shape in interact - self.shape = (self.n_cols, self.n_cols) - super(CorrelogramView, self).__init__(self.shape, keys=keys) + super(CorrelogramView, self).__init__(keys=keys) # Spike clusters. assert spike_clusters.shape == (self.n_spikes,) self.spike_clusters = spike_clusters - # Initialize the subplots. - self._plots = {(i, j): self[i, j].hist(hist=[]) - for i in range(self.n_cols) - for j in range(self.n_cols) - } - self.build() - self.update() - def on_select(self, cluster_ids, spike_ids): n_clusters = len(cluster_ids) n_spikes = len(spike_ids) @@ -655,18 +645,17 @@ def on_select(self, cluster_ids, spike_ids): colors = _selected_clusters_colors(n_clusters) - for i in range(n_clusters): - for j in range(n_clusters): - hist = ccg[i, j, :] - color = colors[i] if i == j else np.ones(3) - color = np.hstack((color, [1])) - self._plots[i, j].set_data(hist=hist, - color=color, - ylim=[lim], - ) - - self.build() - self.update() + self.grid.shape = (n_clusters, n_clusters) + with self.building(): + for i in range(n_clusters): + for j in range(n_clusters): + hist = ccg[i, j, :] + color = colors[i] if i == j else np.ones(3) + color = np.hstack((color, [1])) + self[i, j].hist(hist, + color=color, + ylim=[lim], + ) def attach(self, gui): """Attach the view to the GUI.""" From 0cbad5e3cadba8b076159a9b288ca2e01d6e749c Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 30 Nov 2015 15:01:26 +0100 Subject: [PATCH 0657/1059] Plot visual with arbitrary number of samples --- phy/plot/tests/test_visuals.py | 12 +++- phy/plot/visuals.py | 113 ++++++++++++++++++++------------- 2 files changed, 80 insertions(+), 45 deletions(-) diff --git a/phy/plot/tests/test_visuals.py b/phy/plot/tests/test_visuals.py index 576054373..3e8218e88 100644 --- a/phy/plot/tests/test_visuals.py +++ b/phy/plot/tests/test_visuals.py @@ -31,8 +31,8 @@ def _test_visual(qtbot, c, v, stop=False, **kwargs): #------------------------------------------------------------------------------ # Test scatter visual - #------------------------------------------------------------------------------ + def test_scatter_empty(qtbot, canvas): _test_visual(qtbot, canvas, ScatterVisual(), x=np.zeros(0), y=np.zeros(0)) @@ -113,6 +113,16 @@ def test_plot_2(qtbot, canvas_pz): color=c) +def test_plot_list(qtbot, canvas_pz): + y = [np.random.randn(i) for i in (5, 20)] + + c = np.random.uniform(.5, 1, size=(2, 4)) + c[:, 3] = .5 + + _test_visual(qtbot, canvas_pz, PlotVisual(), + y=y, color=c) + + #------------------------------------------------------------------------------ # Test histogram visual #------------------------------------------------------------------------------ diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index aaf10d8ff..fc9b9caa1 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -139,73 +139,98 @@ def validate(x=None, y=None, color=None, depth=None, - data_bounds=None): - - if x is None: - assert y is not None - x = _get_linear_x(*y.shape) + data_bounds=None, + ): # Default x coordinates. - assert x is not None assert y is not None - x = np.atleast_2d(x).astype(np.float32) - y = np.atleast_2d(y).astype(np.float32) - assert x.ndim == 2 - assert x.shape == y.shape - n_signals, n_samples = x.shape - # Validate the data. + if isinstance(y, np.ndarray) and y.ndim == 2: + if x is None: + x = _get_linear_x(*y.shape) + assert x.ndim == 2 + assert x.shape == y.shape + n_signals, n_samples = y.shape + n = y.size + # Data bounds. + if data_bounds is None: + if n_samples > 0: + # NOTE: by default, per-signal normalization. + data_bounds = np.c_[x.min(axis=1), y.min(axis=1), + x.max(axis=1), y.max(axis=1)] + else: + data_bounds = NDC + x = x.ravel() + y = y.ravel() + elif isinstance(y, list): + if x is None: + x = [np.linspace(-1., 1., len(_)) for _ in y] + assert isinstance(x, list) + # Remove empty elements. + x = [_ for _ in x if len(_)] + y = [_ for _ in y if len(_)] + assert len(x) == len(y) + n_signals = len(x) + n_samples = [len(_) for _ in y] + if data_bounds is None: + xmin = [_.min() for _ in x] + ymin = [_.min() for _ in y] + xmax = [_.max() for _ in x] + ymax = [_.max() for _ in y] + data_bounds = np.c_[xmin, ymin, xmax, ymax] + n = sum(n_samples) + x = np.concatenate(x) + y = np.concatenate(y) + assert x.shape == y.shape == (n,) + # NOTE: n_samples may be an int or a list of ints. + + # Generate the position array. + pos = np.empty((n, 2), dtype=np.float32) + pos[:, 0] = x.ravel() + pos[:, 1] = y.ravel() + assert pos.shape == (n, 2) + + # Generate signal index. + signal_index = np.repeat(np.arange(n_signals), n_samples) + signal_index = _get_array(signal_index, (n, 1)).astype(np.float32) + assert signal_index.shape == (n, 1) + color = _get_array(color, (n_signals, 4), PlotVisual._default_color) + # color = np.repeat(color, n_samples, axis=0).astype(np.float32) + assert color.shape == (n_signals, 4) + depth = _get_array(depth, (n_signals, 1), 0) + depth = np.repeat(depth, n_samples, axis=0).astype(np.float32) + assert depth.shape == (n, 1) - # Validate data bounds. - if data_bounds is None: - if n_samples > 0: - # NOTE: by default, per-signal normalization. - data_bounds = np.c_[x.min(axis=1), y.min(axis=1), - x.max(axis=1), y.max(axis=1)] - else: - data_bounds = NDC data_bounds = _get_data_bounds(data_bounds, length=n_signals) - assert data_bounds.shape == (n_signals, 4) + data_bounds = np.repeat(data_bounds, n_samples, axis=0) + data_bounds = data_bounds.astype(np.float32) + assert data_bounds.shape == (n, 4) - return Bunch(x=x, y=y, + return Bunch(pos=pos, n_signals=n_signals, + signal_index=signal_index, color=color, depth=depth, data_bounds=data_bounds) @staticmethod def vertex_count(x=None, y=None, **kwargs): - return x.size if x is not None else y.size + return y.size if isinstance(y, np.ndarray) else sum(len(_) for _ in y) def set_data(self, *args, **kwargs): data = self.validate(*args, **kwargs) - x, y = data.x, data.y - - n_signals, n_samples = x.shape - n = n_signals * n_samples - - # Generate the position array. - pos = np.empty((n, 2), dtype=np.float32) - pos[:, 0] = x.ravel() - pos[:, 1] = y.ravel() - - # Repeat the data bounds for every vertex in the signals. - data_bounds = np.repeat(data.data_bounds, n_samples, axis=0) - self.data_range.from_bounds = data_bounds # Transform the positions. - pos_tr = self.transforms.apply(pos) - - # Repeat the depth. - depth = np.repeat(data.depth, n_samples) + self.data_range.from_bounds = data.data_bounds + pos_tr = self.transforms.apply(data.pos) - self.program['a_position'] = np.c_[pos_tr, depth] - self.program['a_signal_index'] = _get_index(n_signals, n_samples, n) + self.program['a_position'] = np.c_[pos_tr, data.depth] + self.program['a_signal_index'] = data.signal_index self.program['u_plot_colors'] = Texture2D(_get_texture(data.color, PlotVisual._default_color, - n_signals, + data.n_signals, [0, 1])) - self.program['n_signals'] = n_signals + self.program['n_signals'] = data.n_signals class HistogramVisual(BaseVisual): From 0cfee0f90caa5f5ad7cd6cf8d97d7f0a3bcd0bb8 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 30 Nov 2015 15:28:33 +0100 Subject: [PATCH 0658/1059] WIP: fix plot visual --- phy/plot/tests/test_visuals.py | 4 +-- phy/plot/visuals.py | 63 +++++++++++++++++++--------------- 2 files changed, 38 insertions(+), 29 deletions(-) diff --git a/phy/plot/tests/test_visuals.py b/phy/plot/tests/test_visuals.py index 3e8218e88..422b3ffaf 100644 --- a/phy/plot/tests/test_visuals.py +++ b/phy/plot/tests/test_visuals.py @@ -20,8 +20,8 @@ def _test_visual(qtbot, c, v, stop=False, **kwargs): c.add_visual(v) - v.validate(**kwargs) - assert v.vertex_count(**kwargs) >= 0 + data = v.validate(**kwargs) + assert v.vertex_count(**data) >= 0 v.set_data(**kwargs) c.show() qtbot.waitForWindowShown(c.native) diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index fc9b9caa1..368e3e84c 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -82,9 +82,7 @@ def __init__(self, marker=None): @staticmethod def vertex_count(x=None, y=None, pos=None, **kwargs): - if pos is not None: - return len(pos) - return x.size if x is not None else y.size + return y.size if y is not None else len(pos) @staticmethod def validate(x=None, @@ -160,8 +158,8 @@ def validate(x=None, x.max(axis=1), y.max(axis=1)] else: data_bounds = NDC - x = x.ravel() - y = y.ravel() + # x = x.ravel() + # y = y.ravel() elif isinstance(y, list): if x is None: x = [np.linspace(-1., 1., len(_)) for _ in y] @@ -179,24 +177,13 @@ def validate(x=None, ymax = [_.max() for _ in y] data_bounds = np.c_[xmin, ymin, xmax, ymax] n = sum(n_samples) - x = np.concatenate(x) - y = np.concatenate(y) - assert x.shape == y.shape == (n,) - # NOTE: n_samples may be an int or a list of ints. - - # Generate the position array. - pos = np.empty((n, 2), dtype=np.float32) - pos[:, 0] = x.ravel() - pos[:, 1] = y.ravel() - assert pos.shape == (n, 2) + # x = np.concatenate(x) + # y = np.concatenate(y) - # Generate signal index. - signal_index = np.repeat(np.arange(n_signals), n_samples) - signal_index = _get_array(signal_index, (n, 1)).astype(np.float32) - assert signal_index.shape == (n, 1) + # assert x.shape == y.shape == (n,) + # NOTE: n_samples may be an int or a list of ints. color = _get_array(color, (n_signals, 4), PlotVisual._default_color) - # color = np.repeat(color, n_samples, axis=0).astype(np.float32) assert color.shape == (n_signals, 4) depth = _get_array(depth, (n_signals, 1), 0) @@ -208,29 +195,51 @@ def validate(x=None, data_bounds = data_bounds.astype(np.float32) assert data_bounds.shape == (n, 4) - return Bunch(pos=pos, n_signals=n_signals, - signal_index=signal_index, + return Bunch(x=x, y=y, color=color, depth=depth, data_bounds=data_bounds) @staticmethod - def vertex_count(x=None, y=None, **kwargs): + def vertex_count(y=None, **kwargs): + """Take the output of validate() as input.""" return y.size if isinstance(y, np.ndarray) else sum(len(_) for _ in y) def set_data(self, *args, **kwargs): data = self.validate(*args, **kwargs) + if isinstance(data.y, np.ndarray): + n_signals, n_samples = data.y.shape + n = data.y.size + x, y = data.x, data.y + elif isinstance(data.y, list): + n_signals = len(data.y) + n_samples = [len(_) for _ in data.y] + n = sum(n_samples) + x = np.concatenate(data.x) + y = np.concatenate(data.y) + + # Generate the position array. + pos = np.empty((n, 2), dtype=np.float32) + pos[:, 0] = x.ravel() + pos[:, 1] = y.ravel() + assert pos.shape == (n, 2) + + # Generate signal index. + signal_index = np.repeat(np.arange(n_signals), n_samples) + signal_index = _get_array(signal_index, (n, 1)).astype(np.float32) + assert signal_index.shape == (n, 1) + # Transform the positions. self.data_range.from_bounds = data.data_bounds - pos_tr = self.transforms.apply(data.pos) + pos_tr = self.transforms.apply(pos) self.program['a_position'] = np.c_[pos_tr, data.depth] - self.program['a_signal_index'] = data.signal_index + self.program['a_signal_index'] = signal_index self.program['u_plot_colors'] = Texture2D(_get_texture(data.color, PlotVisual._default_color, - data.n_signals, + n_signals, [0, 1])) - self.program['n_signals'] = data.n_signals + self.program['n_signals'] = n_signals class HistogramVisual(BaseVisual): From 8bc5d7ed4b975249b2544ac13042184c2fb9da88 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 30 Nov 2015 16:26:26 +0100 Subject: [PATCH 0659/1059] WIP --- phy/plot/base.py | 4 ++ phy/plot/plot.py | 24 ++++++-- phy/plot/tests/test_plot.py | 4 +- phy/plot/tests/test_visuals.py | 2 +- phy/plot/visuals.py | 101 +++++++++++++++------------------ 5 files changed, 71 insertions(+), 64 deletions(-) diff --git a/phy/plot/base.py b/phy/plot/base.py index bcaf9e774..5cd5e55e2 100644 --- a/phy/plot/base.py +++ b/phy/plot/base.py @@ -40,6 +40,10 @@ class BaseVisual(object): type of GL primitive. """ + + """Data variables that can be lists of arrays.""" + allow_list = () + def __init__(self): self.gl_primitive_type = None self.transforms = TransformChain() diff --git a/phy/plot/plot.py b/phy/plot/plot.py index 6cc6e13ff..47f244281 100644 --- a/phy/plot/plot.py +++ b/phy/plot/plot.py @@ -32,18 +32,29 @@ def __init__(self): def add(self, name, val): self._data[name].append(val) + def get(self, name): + return self._data[name] + + @property + def names(self): + return set(self._data) + def __getitem__(self, name): return np.vstack(self._data[name]).astype(np.float32) -def _accumulate(data_list): +def _accumulate(data_list, no_concat=()): acc = Accumulator() - names = set() for data in data_list: for name, val in data.items(): - names.add(name) acc.add(name, val) - return {name: acc[name] for name in names} + out = {name: acc[name] for name in acc.names if name not in no_concat} + + # Some variables should not be concatenated but should be kept as lists. + # This is when there can be several arrays of variable length (NumPy + # doesn't support ragged arrays). + out.update({name: acc.get(name) for name in no_concat}) + return out def _make_scatter_class(marker): @@ -79,6 +90,7 @@ def _add_item(self, cls, *args, **kwargs): if cls not in self._items: self._items[cls] = [] + print(data['y']) self._items[cls].append(data) return data @@ -99,7 +111,9 @@ def __getitem__(self, box_index): def build(self): for cls, data_list in self._items.items(): - data = _accumulate(data_list) + # Some variables are not concatenated. They are specified + # in `allow_list`. + data = _accumulate(data_list, cls.allow_list) box_index = data.pop('box_index') visual = cls() self.add_visual(visual) diff --git a/phy/plot/tests/test_plot.py b/phy/plot/tests/test_plot.py index 085614372..a8df5092a 100644 --- a/phy/plot/tests/test_plot.py +++ b/phy/plot/tests/test_plot.py @@ -85,13 +85,13 @@ def test_grid_scatter(qtbot): def test_grid_plot(qtbot): view = GridView((1, 2)) - n_plots, n_samples = 10, 50 + n_plots, n_samples = 2, 5 x = _get_linear_x(n_plots, n_samples) y = np.random.randn(n_plots, n_samples) view[0, 0].plot(x, y) - view[0, 1].plot(x, y, color=np.random.uniform(.5, .8, size=(n_plots, 4))) + # view[0, 1].plot(x, y, color=np.random.uniform(.5, .8, size=(n_plots, 4))) _show(qtbot, view) diff --git a/phy/plot/tests/test_visuals.py b/phy/plot/tests/test_visuals.py index 422b3ffaf..06ed2960d 100644 --- a/phy/plot/tests/test_visuals.py +++ b/phy/plot/tests/test_visuals.py @@ -89,7 +89,7 @@ def test_plot_0(qtbot, canvas_pz): def test_plot_1(qtbot, canvas_pz): - y = .2 * np.random.randn(1, 10) + y = .2 * np.random.randn(10) _test_visual(qtbot, canvas_pz, PlotVisual(), y=y) diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index 368e3e84c..8d9547a98 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -119,8 +119,19 @@ def set_data(self, *args, **kwargs): self.program['a_color'] = data.color +def _as_list(arr): + if isinstance(arr, np.ndarray): + if arr.ndim == 1: + return [arr] + elif arr.ndim == 2: + return list(arr) + assert isinstance(arr, list) + return arr + + class PlotVisual(BaseVisual): _default_color = DEFAULT_COLOR + allow_list = ('x', 'y') def __init__(self): super(PlotVisual, self).__init__() @@ -140,60 +151,38 @@ def validate(x=None, data_bounds=None, ): - # Default x coordinates. assert y is not None + y = _as_list(y) + + if x is None: + x = [np.linspace(-1., 1., len(_)) for _ in y] + x = _as_list(x) + + # Remove empty elements. + x = [_ for _ in x if len(_)] + y = [_ for _ in y if len(_)] + assert len(x) == len(y) + + assert [len(_) for _ in x] == [len(_) for _ in y] - if isinstance(y, np.ndarray) and y.ndim == 2: - if x is None: - x = _get_linear_x(*y.shape) - assert x.ndim == 2 - assert x.shape == y.shape - n_signals, n_samples = y.shape - n = y.size - # Data bounds. - if data_bounds is None: - if n_samples > 0: - # NOTE: by default, per-signal normalization. - data_bounds = np.c_[x.min(axis=1), y.min(axis=1), - x.max(axis=1), y.max(axis=1)] - else: - data_bounds = NDC - # x = x.ravel() - # y = y.ravel() - elif isinstance(y, list): - if x is None: - x = [np.linspace(-1., 1., len(_)) for _ in y] - assert isinstance(x, list) - # Remove empty elements. - x = [_ for _ in x if len(_)] - y = [_ for _ in y if len(_)] - assert len(x) == len(y) - n_signals = len(x) - n_samples = [len(_) for _ in y] - if data_bounds is None: - xmin = [_.min() for _ in x] - ymin = [_.min() for _ in y] - xmax = [_.max() for _ in x] - ymax = [_.max() for _ in y] - data_bounds = np.c_[xmin, ymin, xmax, ymax] - n = sum(n_samples) - # x = np.concatenate(x) - # y = np.concatenate(y) - - # assert x.shape == y.shape == (n,) - # NOTE: n_samples may be an int or a list of ints. + n_signals = len(x) + + if data_bounds is None: + xmin = [_.min() for _ in x] + ymin = [_.min() for _ in y] + xmax = [_.max() for _ in x] + ymax = [_.max() for _ in y] + data_bounds = np.c_[xmin, ymin, xmax, ymax] color = _get_array(color, (n_signals, 4), PlotVisual._default_color) assert color.shape == (n_signals, 4) depth = _get_array(depth, (n_signals, 1), 0) - depth = np.repeat(depth, n_samples, axis=0).astype(np.float32) - assert depth.shape == (n, 1) + assert depth.shape == (n_signals, 1) data_bounds = _get_data_bounds(data_bounds, length=n_signals) - data_bounds = np.repeat(data_bounds, n_samples, axis=0) data_bounds = data_bounds.astype(np.float32) - assert data_bounds.shape == (n, 4) + assert data_bounds.shape == (n_signals, 4) return Bunch(x=x, y=y, color=color, depth=depth, @@ -207,16 +196,12 @@ def vertex_count(y=None, **kwargs): def set_data(self, *args, **kwargs): data = self.validate(*args, **kwargs) - if isinstance(data.y, np.ndarray): - n_signals, n_samples = data.y.shape - n = data.y.size - x, y = data.x, data.y - elif isinstance(data.y, list): - n_signals = len(data.y) - n_samples = [len(_) for _ in data.y] - n = sum(n_samples) - x = np.concatenate(data.x) - y = np.concatenate(data.y) + assert isinstance(data.y, list) + n_signals = len(data.y) + n_samples = [len(_) for _ in data.y] + n = sum(n_samples) + x = np.concatenate(data.x) if len(data.x) else np.array([]) + y = np.concatenate(data.y) if len(data.y) else np.array([]) # Generate the position array. pos = np.empty((n, 2), dtype=np.float32) @@ -230,10 +215,14 @@ def set_data(self, *args, **kwargs): assert signal_index.shape == (n, 1) # Transform the positions. - self.data_range.from_bounds = data.data_bounds + data_bounds = np.repeat(data.data_bounds, n_samples, axis=0) + self.data_range.from_bounds = data_bounds pos_tr = self.transforms.apply(pos) - self.program['a_position'] = np.c_[pos_tr, data.depth] + # Position and depth. + depth = np.repeat(data.depth, n_samples, axis=0) + self.program['a_position'] = np.c_[pos_tr, depth] + self.program['a_signal_index'] = signal_index self.program['u_plot_colors'] = Texture2D(_get_texture(data.color, PlotVisual._default_color, From 09cbf38b843eb6cf7d84a42d2d4b34a89f5d8422 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 30 Nov 2015 17:12:17 +0100 Subject: [PATCH 0660/1059] Fix visuals --- phy/plot/tests/test_visuals.py | 4 +--- phy/plot/visuals.py | 18 ++++++++++++------ 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/phy/plot/tests/test_visuals.py b/phy/plot/tests/test_visuals.py index 06ed2960d..348d2374a 100644 --- a/phy/plot/tests/test_visuals.py +++ b/phy/plot/tests/test_visuals.py @@ -18,15 +18,13 @@ # Fixtures #------------------------------------------------------------------------------ -def _test_visual(qtbot, c, v, stop=False, **kwargs): +def _test_visual(qtbot, c, v, **kwargs): c.add_visual(v) data = v.validate(**kwargs) assert v.vertex_count(**data) >= 0 v.set_data(**kwargs) c.show() qtbot.waitForWindowShown(c.native) - if stop: # pragma: no cover - qtbot.stop() #------------------------------------------------------------------------------ diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index 8d9547a98..da5265e34 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -129,6 +129,14 @@ def _as_list(arr): return arr +def _min(arr): + return arr.min() if len(arr) else 0 + + +def _max(arr): + return arr.max() if len(arr) else 1 + + class PlotVisual(BaseVisual): _default_color = DEFAULT_COLOR allow_list = ('x', 'y') @@ -159,8 +167,6 @@ def validate(x=None, x = _as_list(x) # Remove empty elements. - x = [_ for _ in x if len(_)] - y = [_ for _ in y if len(_)] assert len(x) == len(y) assert [len(_) for _ in x] == [len(_) for _ in y] @@ -168,10 +174,10 @@ def validate(x=None, n_signals = len(x) if data_bounds is None: - xmin = [_.min() for _ in x] - ymin = [_.min() for _ in y] - xmax = [_.max() for _ in x] - ymax = [_.max() for _ in y] + xmin = [_min(_) for _ in x] + ymin = [_min(_) for _ in y] + xmax = [_max(_) for _ in x] + ymax = [_max(_) for _ in y] data_bounds = np.c_[xmin, ymin, xmax, ymax] color = _get_array(color, (n_signals, 4), PlotVisual._default_color) From ccd9bc2e605cb49c0c9c610e1b1088e06fd79073 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 30 Nov 2015 17:19:48 +0100 Subject: [PATCH 0661/1059] Plot tests pass --- phy/plot/plot.py | 7 +++++-- phy/plot/tests/test_plot.py | 4 ++-- phy/plot/tests/test_visuals.py | 5 ++++- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/phy/plot/plot.py b/phy/plot/plot.py index 47f244281..343587dc6 100644 --- a/phy/plot/plot.py +++ b/phy/plot/plot.py @@ -24,6 +24,10 @@ # Utils #------------------------------------------------------------------------------ +def _flatten(l): + return [item for sublist in l for item in sublist] + + class Accumulator(object): """Accumulate arrays for concatenation.""" def __init__(self): @@ -33,7 +37,7 @@ def add(self, name, val): self._data[name].append(val) def get(self, name): - return self._data[name] + return _flatten(self._data[name]) @property def names(self): @@ -90,7 +94,6 @@ def _add_item(self, cls, *args, **kwargs): if cls not in self._items: self._items[cls] = [] - print(data['y']) self._items[cls].append(data) return data diff --git a/phy/plot/tests/test_plot.py b/phy/plot/tests/test_plot.py index a8df5092a..48943556e 100644 --- a/phy/plot/tests/test_plot.py +++ b/phy/plot/tests/test_plot.py @@ -85,13 +85,13 @@ def test_grid_scatter(qtbot): def test_grid_plot(qtbot): view = GridView((1, 2)) - n_plots, n_samples = 2, 5 + n_plots, n_samples = 5, 50 x = _get_linear_x(n_plots, n_samples) y = np.random.randn(n_plots, n_samples) view[0, 0].plot(x, y) - # view[0, 1].plot(x, y, color=np.random.uniform(.5, .8, size=(n_plots, 4))) + view[0, 1].plot(x, y, color=np.random.uniform(.5, .8, size=(n_plots, 4))) _show(qtbot, view) diff --git a/phy/plot/tests/test_visuals.py b/phy/plot/tests/test_visuals.py index 348d2374a..9510897d1 100644 --- a/phy/plot/tests/test_visuals.py +++ b/phy/plot/tests/test_visuals.py @@ -18,13 +18,16 @@ # Fixtures #------------------------------------------------------------------------------ -def _test_visual(qtbot, c, v, **kwargs): +def _test_visual(qtbot, c, v, stop=False, **kwargs): c.add_visual(v) data = v.validate(**kwargs) assert v.vertex_count(**data) >= 0 v.set_data(**kwargs) c.show() qtbot.waitForWindowShown(c.native) + if stop: # pragma: no cover + qtbot.stop() + c.close() #------------------------------------------------------------------------------ From f2f57335713bb63462a470dffb5b96e615b1fb4a Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 30 Nov 2015 17:20:47 +0100 Subject: [PATCH 0662/1059] Flakify --- phy/plot/visuals.py | 1 - 1 file changed, 1 deletion(-) diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index da5265e34..579f0076d 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -19,7 +19,6 @@ _get_data_bounds, _get_pos, _get_index, - _get_linear_x, _get_color, ) from phy.utils import Bunch From 64ef58e4c036ef2d89078425b39193cc35479f42 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 30 Nov 2015 17:26:46 +0100 Subject: [PATCH 0663/1059] Fix import --- phy/plot/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/phy/plot/__init__.py b/phy/plot/__init__.py index b4e1d3453..2b5f48742 100644 --- a/phy/plot/__init__.py +++ b/phy/plot/__init__.py @@ -15,7 +15,7 @@ from .plot import GridView, BoxedView, StackedView # noqa from .transform import Translate, Scale, Range, Subplot, NDC from .panzoom import PanZoom -from.visuals import _get_linear_x +from .utils import _get_linear_x #------------------------------------------------------------------------------ From e21fadec0237a69c54327df7e209efee9e9d49da Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 30 Nov 2015 18:27:39 +0100 Subject: [PATCH 0664/1059] Make trace spike test pass --- phy/cluster/manual/tests/test_views.py | 3 +-- phy/cluster/manual/views.py | 2 +- phy/plot/plot.py | 6 ++++-- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index b68810cda..a3bb700b0 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -117,7 +117,7 @@ def test_trace_view_no_spikes(qtbot): _show(qtbot, v) -def SKIPtest_trace_view_spikes(qtbot): +def test_trace_view_spikes(qtbot): n_samples = 1000 n_channels = 12 sample_rate = 2000. @@ -131,7 +131,6 @@ def SKIPtest_trace_view_spikes(qtbot): masks = artificial_masks(n_spikes, n_channels) # Create the view. - # TODO: make this work = plots with variable n_samples v = TraceView(traces=traces, sample_rate=sample_rate, spike_times=spike_times, diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index d5b6b83b8..a10edf434 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -361,7 +361,7 @@ def set_interval(self, interval): sample_rel = (int(spike_times[i] * self.sample_rate) - trace_start) mask = self.masks[i] - # TODO + # TODO: color of spike = white or color if selected cluster # clu = spike_clusters[i] w, ch = _extract_wave(traces, sample_rel, mask, wave_len) n_ch = len(ch) diff --git a/phy/plot/plot.py b/phy/plot/plot.py index 343587dc6..9a01962aa 100644 --- a/phy/plot/plot.py +++ b/phy/plot/plot.py @@ -85,11 +85,13 @@ def clear(self): def _add_item(self, cls, *args, **kwargs): box_index = kwargs.pop('box_index', self._default_box_index) - k = len(box_index) if hasattr(box_index, '__len__') else 1 data = cls.validate(*args, **kwargs) n = cls.vertex_count(**data) - box_index = _get_array(box_index, (n, k)) + + if not isinstance(box_index, np.ndarray): + k = len(box_index) if hasattr(box_index, '__len__') else 1 + box_index = _get_array(box_index, (n, k)) data['box_index'] = box_index if cls not in self._items: From b3fbcfe3b22437641a8f3ba16d9f9d416442d4c1 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 30 Nov 2015 18:36:03 +0100 Subject: [PATCH 0665/1059] Fixes --- phy/plot/base.py | 5 ++++- phy/plot/plot.py | 1 + phy/plot/visuals.py | 7 +------ 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/phy/plot/base.py b/phy/plot/base.py index 5cd5e55e2..b084317a0 100644 --- a/phy/plot/base.py +++ b/phy/plot/base.py @@ -16,7 +16,7 @@ from vispy.util.event import Event from .transform import TransformChain, Clip -from .utils import _load_shader +from .utils import _load_shader, _enable_depth_mask logger = logging.getLogger(__name__) @@ -219,6 +219,9 @@ def __init__(self, *args, **kwargs): self.visuals = [] self.events.add(visual_added=VisualEvent) + # Enable transparency. + _enable_depth_mask() + def add_visual(self, visual): """Add a visual to the canvas, and build its program by the same occasion. diff --git a/phy/plot/plot.py b/phy/plot/plot.py index 9a01962aa..236e5ba62 100644 --- a/phy/plot/plot.py +++ b/phy/plot/plot.py @@ -82,6 +82,7 @@ def __init__(self, **kwargs): def clear(self): self._items = OrderedDict() + self.visuals = [] def _add_item(self, cls, *args, **kwargs): box_index = kwargs.pop('box_index', self._default_box_index) diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index 579f0076d..7d15ff4aa 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -12,8 +12,7 @@ from .base import BaseVisual from .transform import Range, NDC -from .utils import (_enable_depth_mask, - _tesselate_histogram, +from .utils import (_tesselate_histogram, _get_texture, _get_array, _get_data_bounds, @@ -69,9 +68,6 @@ def __init__(self, marker=None): self.marker = marker or self._default_marker assert self.marker in self._supported_markers - # Enable transparency. - _enable_depth_mask() - self.set_shader('scatter') self.fragment_shader = self.fragment_shader.replace('%MARKER', self.marker) @@ -142,7 +138,6 @@ class PlotVisual(BaseVisual): def __init__(self): super(PlotVisual, self).__init__() - _enable_depth_mask() self.set_shader('plot') self.set_primitive_type('line_strip') From be39507ab4e733fb564923defcaf2a36e882c21c Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 30 Nov 2015 18:51:14 +0100 Subject: [PATCH 0666/1059] Add spike increase test in CCG computation --- phy/stats/ccg.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/phy/stats/ccg.py b/phy/stats/ccg.py index 98281a9f6..c6a038c01 100644 --- a/phy/stats/ccg.py +++ b/phy/stats/ccg.py @@ -89,6 +89,8 @@ def correlograms(spike_times, """ assert sample_rate > 0. + assert np.all(np.diff(spike_times) >= 0), ("The spike times must be " + "increasing.") # Get the spike samples. spike_times = np.asarray(spike_times, dtype=np.float64) From 81bde9a55e2e1a9c62bf2659bd2abdf1cbc6fc7b Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 1 Dec 2015 09:42:48 +0100 Subject: [PATCH 0667/1059] Compute CCGs on excerpts --- phy/cluster/manual/tests/test_views.py | 4 ++-- phy/cluster/manual/views.py | 26 +++++++++++++++++++++++--- 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index a3bb700b0..92ccaf8c3 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -200,8 +200,8 @@ def test_correlogram_view(qtbot): sample_rate=sample_rate, bin_size=bin_size, window_size=window_size, - excerpt_size=None, - n_excerpts=None, + excerpt_size=8, + n_excerpts=5, ) # Select some spikes. diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index a10edf434..c0a47537e 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -13,7 +13,7 @@ from matplotlib.colors import hsv_to_rgb, rgb_to_hsv from six import string_types -from phy.io.array import _index_of, _get_padded +from phy.io.array import _index_of, _get_padded, get_excerpts from phy.electrode.mea import linear_positions from phy.gui import Actions from phy.plot import (BoxedView, StackedView, GridView, @@ -132,6 +132,7 @@ def _get_color(masks, spike_clusters_rel=None, n_clusters=None): # ----------------------------------------------------------------------------- class WaveformView(BoxedView): + # TODO: make this configurable normalization_percentile = .95 normalization_n_spikes = 1000 overlap = True @@ -484,6 +485,7 @@ def _project_mask_depth(dim, masks, spike_clusters_rel=None, n_clusters=None): class FeatureView(GridView): + # TODO: make this configurable normalization_percentile = .95 normalization_n_spikes = 1000 @@ -591,6 +593,10 @@ def attach(self, gui): # ----------------------------------------------------------------------------- class CorrelogramView(GridView): + # TODO: make this configurable + excerpt_size = 10000 + n_excerpts = 100 + def __init__(self, spike_times=None, spike_clusters=None, @@ -611,7 +617,8 @@ def __init__(self, assert window_size > 0 self.window_size = window_size - # TODO: excerpt + self.excerpt_size = excerpt_size or self.excerpt_size + self.n_excerpts = n_excerpts or self.n_excerpts self.spike_times = np.asarray(spike_times) self.n_spikes, = self.spike_times.shape @@ -629,11 +636,24 @@ def on_select(self, cluster_ids, spike_ids): if n_spikes == 0: return - # TODO: excerpt + # Keep spikes belonging to the selected clusters. ind = np.in1d(self.spike_clusters, cluster_ids) st = self.spike_times[ind] sc = self.spike_clusters[ind] + # Take excerpts of the spikes. + n_spikes_total = len(st) + st = get_excerpts(st, excerpt_size=self.excerpt_size, + n_excerpts=self.n_excerpts) + sc = get_excerpts(sc, excerpt_size=self.excerpt_size, + n_excerpts=self.n_excerpts) + n_spikes_exerpts = len(st) + logger.debug("Computing correlograms for clusters %s (%d/%d spikes).", + ', '.join(map(str, cluster_ids)), + n_spikes_exerpts, n_spikes_total, + ) + + # Compute all pairwise correlograms. ccg = correlograms(st, sc, cluster_ids=cluster_ids, sample_rate=self.sample_rate, From df1887e841d909c26b9abfee538908ab2be6a783 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 1 Dec 2015 11:27:41 +0100 Subject: [PATCH 0668/1059] Improve snippets and actions --- phy/gui/actions.py | 88 ++++++++++++++++++++++------------- phy/gui/gui.py | 9 ++-- phy/gui/tests/conftest.py | 4 +- phy/gui/tests/test_actions.py | 41 ++++++++++++---- 4 files changed, 95 insertions(+), 47 deletions(-) diff --git a/phy/gui/actions.py b/phy/gui/actions.py index 064651670..f52d45918 100644 --- a/phy/gui/actions.py +++ b/phy/gui/actions.py @@ -124,8 +124,7 @@ def wrapped(checked, *args, **kwargs): # pragma: no cover sequence = _get_qkeysequence(shortcut) if not isinstance(sequence, (tuple, list)): sequence = [sequence] - for s in sequence: - action.setShortcut(s) + action.setShortcuts(sequence) return action @@ -146,12 +145,6 @@ def __init__(self, gui, default_shortcuts=None): self.gui = gui gui.actions.append(self) - # Create and attach snippets. - self.snippets = Snippets(gui, self) - - def backup(self): - return list(self._actions_dict.values()) - def add(self, callback=None, name=None, shortcut=None, alias=None, verbose=True): """Add an action with a keyboard shortcut.""" @@ -187,6 +180,26 @@ def add(self, callback=None, name=None, shortcut=None, alias=None, if callback: setattr(self, name, callback) + def disable(self, name=None): + """Disable one or all actions.""" + if name is None: + for name in self._actions_dict: + self.disable(name) + return + self._actions_dict[name].qaction.setEnabled(False) + + def enable(self, name=None): + """Enable one or all actions.""" + if name is None: + for name in self._actions_dict: + self.enable(name) + return + self._actions_dict[name].qaction.setEnabled(True) + + def get(self, name): + """Get a QAction instance from its name.""" + return self._actions_dict[name].qaction + def run(self, name, *args): """Run an action as specified by its name.""" assert isinstance(name, string_types) @@ -222,6 +235,12 @@ def show_shortcuts(self): """Print all shortcuts.""" _show_shortcuts(self.shortcuts, self.gui.windowTitle()) + def __contains__(self, name): + return name in self._actions_dict + + def __repr__(self): + return ''.format(sorted(self._actions_dict)) + # ----------------------------------------------------------------------------- # Snippets @@ -260,21 +279,18 @@ class Snippets(object): _snippet_chars = ("abcdefghijklmnopqrstuvwxyz0123456789" " ,.;?!_-+~=*/\(){}[]") - def __init__(self, gui, actions): + def __init__(self, gui): self.gui = gui - assert isinstance(actions, Actions) - self.actions = actions - - # We will keep a backup of all actions so that we can switch - # safely to the set of shortcut actions when snippet mode is on. - self._actions_backup = [] + self.actions = Actions(gui) # Register snippet mode shortcut. - @actions.add(shortcut=':') + @self.actions.add(shortcut=':') def enable_snippet_mode(): self.mode_on() + self._create_snippet_actions() + @property def command(self): """This is used to write a snippet message in the status bar. @@ -350,7 +366,16 @@ def run(self, snippet): logger.info("Processing snippet `%s`.", snippet) try: - self.actions.run(name, *snippet_args[1:]) + # Try to run the snippet on all attached Actions instances. + for actions in self.gui.actions: + try: + actions.run(name, *snippet_args[1:]) + return + except ValueError: + # This Actions instance doesn't contain the requested + # snippet, trying the next attached Actions instance. + pass + logger.warn("Couldn't find action `%s`.", name) except Exception as e: logger.warn("Error when executing snippet: \"%s\".", str(e)) logger.debug(''.join(traceback.format_exception(*sys.exc_info()))) @@ -360,23 +385,22 @@ def is_mode_on(self): def mode_on(self): logger.info("Snippet mode enabled, press `escape` to leave this mode.") - self._actions_backup = self.actions.backup() - # Remove all existing actions. - self.actions.remove_all() - # Add snippet keystroke actions. - self._create_snippet_actions() + + # Silent all actions except the Snippets actions. + for actions in self.gui.actions: + if actions != self.actions: + actions.disable() + self.actions.enable() + self.command = ':' def mode_off(self): self.gui.status_message = '' - # Remove all existing actions. - self.actions.remove_all() + + # Re-enable all actions except the Snippets actions. + self.actions.disable() + for actions in self.gui.actions: + if actions != self.actions: + actions.enable() + logger.info("Snippet mode disabled.") - # Reestablishes the shortcuts. - for action_obj in self._actions_backup: - self.actions.add(callback=action_obj.callback, - name=action_obj.name, - shortcut=action_obj.shortcut, - alias=action_obj.alias, - verbose=False, - ) diff --git a/phy/gui/gui.py b/phy/gui/gui.py index 962479c6e..3c987a443 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -12,7 +12,7 @@ from .qt import (QApplication, QWidget, QDockWidget, QStatusBar, QMainWindow, Qt, QSize, QMetaObject) -from .actions import Actions, _show_shortcuts +from .actions import Actions, _show_shortcuts, Snippets from phy.utils.event import EventEmitter from phy.utils import load_master_config from phy.utils.plugin import get_plugin @@ -125,9 +125,9 @@ def __init__(self, # Default actions. self.default_actions = Actions(self) - @self.default_actions.add(shortcut='HelpContents') + @self.default_actions.add(shortcut=('HelpContents', 'h')) def show_shortcuts(): - shortcuts = {} + shortcuts = self.default_actions.shortcuts for actions in self.actions: shortcuts.update(actions.shortcuts) _show_shortcuts(shortcuts, self.name) @@ -136,6 +136,9 @@ def show_shortcuts(): def exit(): self.close() + # Create and attach snippets. + self.snippets = Snippets(self) + # Events # ------------------------------------------------------------------------- diff --git a/phy/gui/tests/conftest.py b/phy/gui/tests/conftest.py index d2cad174b..e2a2e6bdf 100644 --- a/phy/gui/tests/conftest.py +++ b/phy/gui/tests/conftest.py @@ -29,5 +29,5 @@ def actions(gui): @yield_fixture -def snippets(gui, actions): - yield Snippets(gui, actions) +def snippets(gui): + yield Snippets(gui) diff --git a/phy/gui/tests/test_actions.py b/phy/gui/tests/test_actions.py index cac426772..55594f084 100644 --- a/phy/gui/tests/test_actions.py +++ b/phy/gui/tests/test_actions.py @@ -87,8 +87,13 @@ def show_my_shortcuts(): actions.run('t', 1) assert _res == [(1,)] + assert 'show_my_shortcuts' in actions + assert 'unknown' not in actions + actions.remove_all() + assert '' snippets.mode_on() # ':' - actions._snippet_backspace() + snippets.actions._snippet_backspace() _run('t3 hello') - actions._snippet_activate() # 'Enter' + snippets.actions._snippet_activate() # 'Enter' assert _actions[-1] == (3, ('hello',)) snippets.mode_off() From 35f8b13435803ed376690b460eee63edd0dab1a2 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 1 Dec 2015 11:28:00 +0100 Subject: [PATCH 0669/1059] Minor updates --- phy/cluster/manual/gui_component.py | 3 ++- phy/cluster/manual/tests/test_gui_component.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 99290a63c..a38dd4825 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -175,8 +175,9 @@ class ManualClustering(object): # Misc. 'save': 'Save', + 'show_shortcuts': 'Save', 'undo': 'Undo', - 'redo': 'Redo', + 'redo': ('Redo', 'ctrl+y'), } def __init__(self, diff --git a/phy/cluster/manual/tests/test_gui_component.py b/phy/cluster/manual/tests/test_gui_component.py index b15163eb0..c6a696ab0 100644 --- a/phy/cluster/manual/tests/test_gui_component.py +++ b/phy/cluster/manual/tests/test_gui_component.py @@ -51,6 +51,7 @@ def gui(qtbot): gui.show() qtbot.waitForWindowShown(gui) yield gui + qtbot.wait(5) gui.close() del gui qtbot.wait(5) From 7f9055e3b75d0f65bd1deb5e04d2dc70ab57dd01 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 1 Dec 2015 11:32:41 +0100 Subject: [PATCH 0670/1059] Enable snippet mode action always enabled --- phy/gui/actions.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/phy/gui/actions.py b/phy/gui/actions.py index f52d45918..c3d4975da 100644 --- a/phy/gui/actions.py +++ b/phy/gui/actions.py @@ -290,6 +290,7 @@ def enable_snippet_mode(): self.mode_on() self._create_snippet_actions() + self.mode_off() @property def command(self): @@ -402,5 +403,7 @@ def mode_off(self): for actions in self.gui.actions: if actions != self.actions: actions.enable() + # The `:` shortcut should always be enabled. + self.actions.enable('enable_snippet_mode') logger.info("Snippet mode disabled.") From cd8c1f31a8dc18d938ec134cfad1088a8963c15e Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 1 Dec 2015 11:43:17 +0100 Subject: [PATCH 0671/1059] WIP: save --- phy/cluster/manual/gui_component.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index a38dd4825..c66fe6f59 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -143,7 +143,7 @@ class ManualClustering(object): when clusters are selected cluster(up) when a merge or split happens - save_requested(spike_clusters, cluster_groups) + request_save(spike_clusters, cluster_groups) when a save is requested by the user """ @@ -505,6 +505,6 @@ def redo(self): def save(self): spike_clusters = self.clustering.spike_clusters - groups = {c: self.cluster_meta.get('group', c) + groups = {c: self.cluster_meta.get('group', c) or 'unsorted' for c in self.clustering.cluster_ids} - self.gui.emit('save_requested', spike_clusters, groups) + self.gui.emit('request_save', spike_clusters, groups) From ec321457a2e84c9c218fb04c73dbc89b54ec07d9 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 1 Dec 2015 13:26:19 +0100 Subject: [PATCH 0672/1059] Add some phy.plot docstrings --- phy/plot/base.py | 22 ++++++++++++++++++++++ phy/plot/plot.py | 23 +++++++++++++++++++++++ 2 files changed, 45 insertions(+) diff --git a/phy/plot/base.py b/phy/plot/base.py index b084317a0..696a39e80 100644 --- a/phy/plot/base.py +++ b/phy/plot/base.py @@ -152,6 +152,8 @@ def _insert_glsl(vertex, fragment, to_insert): class GLSLInserter(object): + """Insert GLSL snippets into shader codes.""" + def __init__(self): self._to_insert = defaultdict(list) self.insert_vert('vec2 temp_pos_tr = {{ var }};', @@ -171,12 +173,25 @@ def _insert(self, shader_type, glsl, location): self._to_insert[shader_type, location].append(glsl) def insert_vert(self, glsl, location='transforms'): + """Insert a GLSL snippet into the vertex shader. + + The location can be: + + * `header`: declaration of GLSL variables + * `before_transforms`: just before the transforms in the vertex shader + * `transforms`: where the GPU transforms are applied in the vertex + shader + * `after_transforms`: just after the GPU transforms + + """ self._insert('vert', glsl, location) def insert_frag(self, glsl, location=None): + """Insert a GLSL snippet into the fragment shader.""" self._insert('frag', glsl, location) def add_transform_chain(self, tc): + """Insert the GLSL snippets of a transform chain.""" # Generate the transforms snippet. for t in tc.gpu_transforms: if isinstance(t, Clip): @@ -190,12 +205,14 @@ def add_transform_chain(self, tc): self.insert_frag(clip.glsl('v_temp_pos_tr'), 'before_transforms') def insert_into_shaders(self, vertex, fragment): + """Apply the insertions to shader code.""" to_insert = defaultdict(str) to_insert.update({key: '\n'.join(self._to_insert[key]) for key in self._to_insert}) return _insert_glsl(vertex, fragment, to_insert) def __add__(self, inserter): + """Concatenate two inserters.""" self._to_insert.update(inserter._to_insert) return self @@ -254,6 +271,7 @@ def on_resize(self, event): self.context.set_viewport(0, 0, event.size[0], event.size[1]) def on_draw(self, e): + """Draw all visuals.""" gloo.clear() for visual in self.visuals: visual.on_draw() @@ -264,6 +282,7 @@ def on_draw(self, e): #------------------------------------------------------------------------------ class BaseInteract(object): + """Implement dynamic transforms on a canvas.""" canvas = None def attach(self, canvas): @@ -275,9 +294,12 @@ def on_visual_added(e): self.update_program(e.visual.program) def update_program(self, program): + """Override this method to update programs when `self.update()` + is called.""" pass def update(self): + """Update all visuals in the attached canvas.""" if not self.canvas: return for visual in self.canvas.visuals: diff --git a/phy/plot/plot.py b/phy/plot/plot.py index 236e5ba62..96535ff07 100644 --- a/phy/plot/plot.py +++ b/phy/plot/plot.py @@ -34,20 +34,30 @@ def __init__(self): self._data = defaultdict(list) def add(self, name, val): + """Add an array.""" self._data[name].append(val) def get(self, name): + """Return the list of arrays for a given name.""" return _flatten(self._data[name]) @property def names(self): + """List of names.""" return set(self._data) def __getitem__(self, name): + """Concatenate all arrays with a given name.""" return np.vstack(self._data[name]).astype(np.float32) def _accumulate(data_list, no_concat=()): + """Concatenate a list of dicts `(name, array)`. + + You can specify some names which arrays should not be concatenated. + This is necessary with lists of plots with different sizes. + + """ acc = Accumulator() for data in data_list: for name, val in data.items(): @@ -62,6 +72,7 @@ def _accumulate(data_list, no_concat=()): def _make_scatter_class(marker): + """Return a temporary ScatterVisual class with a given marker.""" return type('ScatterVisual' + marker.title(), (ScatterVisual,), {'_default_marker': marker}) @@ -81,10 +92,12 @@ def __init__(self, **kwargs): self.clear() def clear(self): + """Reset the view.""" self._items = OrderedDict() self.visuals = [] def _add_item(self, cls, *args, **kwargs): + """Add a plot item.""" box_index = kwargs.pop('box_index', self._default_box_index) data = cls.validate(*args, **kwargs) @@ -101,14 +114,17 @@ def _add_item(self, cls, *args, **kwargs): return data def plot(self, *args, **kwargs): + """Add a line plot.""" return self._add_item(PlotVisual, *args, **kwargs) def scatter(self, *args, **kwargs): + """Add a scatter plot.""" cls = _make_scatter_class(kwargs.pop('marker', ScatterVisual._default_marker)) return self._add_item(cls, *args, **kwargs) def hist(self, *args, **kwargs): + """Add some histograms.""" return self._add_item(HistogramVisual, *args, **kwargs) def __getitem__(self, box_index): @@ -116,6 +132,12 @@ def __getitem__(self, box_index): return self def build(self): + """Build all added items. + + Visuals are created, added, and built. The `set_data()` methods can + be called afterwards. + + """ for cls, data_list in self._items.items(): # Some variables are not concatenated. They are specified # in `allow_list`. @@ -129,6 +151,7 @@ def build(self): @contextmanager def building(self): + """Context manager to specify the plots.""" self.clear() yield self.build() From 0f1958531804fc58240859a93574ad78f86c8b65 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 1 Dec 2015 14:52:43 +0100 Subject: [PATCH 0673/1059] Fix path bug on Windows --- phy/utils/tests/test_plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/phy/utils/tests/test_plugin.py b/phy/utils/tests/test_plugin.py index 96abdfd1c..8d81dd9f6 100644 --- a/phy/utils/tests/test_plugin.py +++ b/phy/utils/tests/test_plugin.py @@ -52,7 +52,7 @@ def _write_my_plugins_dir_in_config(temp_user_dir): # Now, we specify the path to the plugin in the phy config file. config_contents = """ c = get_config() - c.Plugins.dirs = ['%s'] + c.Plugins.dirs = [r'%s'] """ _write_text(op.join(temp_user_dir, 'phy_config.py'), config_contents % op.join(temp_user_dir, 'my_plugins/')) From 0a5745636fd76b7057e91835f610d75d6dc8fbe9 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 1 Dec 2015 15:36:42 +0100 Subject: [PATCH 0674/1059] Minor fixes --- phy/cluster/manual/gui_component.py | 2 +- phy/gui/actions.py | 3 ++- phy/gui/gui.py | 2 +- phy/gui/widgets.py | 4 ++-- 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index c66fe6f59..15b34aaf9 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -177,7 +177,7 @@ class ManualClustering(object): 'save': 'Save', 'show_shortcuts': 'Save', 'undo': 'Undo', - 'redo': ('Redo', 'ctrl+y'), + 'redo': ('ctrl+shift+z', 'ctrl+y'), } def __init__(self, diff --git a/phy/gui/actions.py b/phy/gui/actions.py index c3d4975da..cc5e47e75 100644 --- a/phy/gui/actions.py +++ b/phy/gui/actions.py @@ -98,7 +98,8 @@ def _show_shortcuts(shortcuts, name=None): print('Keyboard shortcuts' + name) for name in sorted(shortcuts): shortcut = _get_shortcut_string(shortcuts[name]) - print('{0:<40}: {1:s}'.format(name, shortcut)) + if not name.startswith('_'): + print('{0:<40}: {1:s}'.format(name, shortcut)) print() diff --git a/phy/gui/gui.py b/phy/gui/gui.py index 3c987a443..327b7a848 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -132,7 +132,7 @@ def show_shortcuts(): shortcuts.update(actions.shortcuts) _show_shortcuts(shortcuts, self.name) - @self.default_actions.add(shortcut='Quit') + @self.default_actions.add(shortcut='ctrl+q') def exit(): self.close() diff --git a/phy/gui/widgets.py b/phy/gui/widgets.py index 440e9022b..fab72cfd9 100644 --- a/phy/gui/widgets.py +++ b/phy/gui/widgets.py @@ -262,7 +262,7 @@ def set_rows(self, ids): sort_dir = sort_dir or default_sort_dir or 'desc' # Set the rows. - logger.debug("Set %d rows in the table.", len(ids)) + logger.log(5, "Set %d rows in the table.", len(ids)) items = [self._get_row(id) for id in ids] data = _create_json_dict(items=items, cols=self.column_names, @@ -275,7 +275,7 @@ def set_rows(self, ids): def sort_by(self, name, sort_dir='asc'): """Sort by a given variable.""" - logger.debug("Sort by `%s` %s.", name, sort_dir) + logger.log(5, "Sort by `%s` %s.", name, sort_dir) self.eval_js('table.sortBy("{}", "{}");'.format(name, sort_dir)) def next(self): From e8bf2bc9aa3a163220fb982538165ed6311317be Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 9 Dec 2015 22:31:14 +0100 Subject: [PATCH 0675/1059] WIP: add docs --- docs/index.md | 18 ++++++++++++++++++ docs/install.md | 4 ++++ docs/overview.md | 13 +++++++++++++ mkdocs.yml | 13 +++++++++++++ 4 files changed, 48 insertions(+) create mode 100644 docs/index.md create mode 100644 docs/install.md create mode 100644 docs/overview.md create mode 100644 mkdocs.yml diff --git a/docs/index.md b/docs/index.md new file mode 100644 index 000000000..94645ed60 --- /dev/null +++ b/docs/index.md @@ -0,0 +1,18 @@ +# phy documentation + +phy is an ephys data analysis library. It provides everything you need to spikesort and analyze extracellular multielectrode recordings, including components to build analysis command-line and graphical applications. + + +## Frequently Asked Questions + +### What file formats does phy support? + +None. + +### Are there ready-to-use scripts and GUIs for spike sorting? + +No. + +### Can you add feature X? + +No. diff --git a/docs/install.md b/docs/install.md new file mode 100644 index 000000000..d748597d9 --- /dev/null +++ b/docs/install.md @@ -0,0 +1,4 @@ +# Installation + +* Install Anaconda +* `conda install -c kwikteam/phy` (TODO) diff --git a/docs/overview.md b/docs/overview.md new file mode 100644 index 000000000..41c2c903d --- /dev/null +++ b/docs/overview.md @@ -0,0 +1,13 @@ +# Overview + +phy provides a set of generic tools for building data-intensive command-line and graphical applications with fast visualization capabilities. In addition, it implements ephys data analysis functions and manual clustering routines. Overall, these tools allow you to build specific data analysis applications, for example, a manual clustering GUI. + +phy is entirely agnostic to file formats and processing workflows. As such, it can be easily integrated with any existing system. + +## List of components + +* GUI +* Plotting +* Manual clustering routines +* Analysis functions +* Utilities diff --git a/mkdocs.yml b/mkdocs.yml new file mode 100644 index 000000000..e1d42d93a --- /dev/null +++ b/mkdocs.yml @@ -0,0 +1,13 @@ +site_name: phy +pages: +- Home: 'index.md' +- Installation: 'install.md' +- Overview: 'overview.md' +- GUI: 'gui.md' +- Plotting: 'plot.md' +- Command-line interface: 'cli.md' +- Configuration and plugin system: 'config.md' +- Manual clustering: 'cluster-manual.md' +- Analysis functions: 'analysis.md' +- API reference: 'api.md' +theme: readthedocs From f761f9b98f9b658914caaa8a450eef79c92d6e0d Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 9 Dec 2015 22:37:23 +0100 Subject: [PATCH 0676/1059] Add empty docs --- .gitignore | 1 + docs/analysis.md | 0 docs/api.md | 0 docs/cluster-manual.md | 0 docs/config.md | 0 docs/gui.md | 0 docs/plot.md | 0 7 files changed, 1 insertion(+) create mode 100644 docs/analysis.md create mode 100644 docs/api.md create mode 100644 docs/cluster-manual.md create mode 100644 docs/config.md create mode 100644 docs/gui.md create mode 100644 docs/plot.md diff --git a/.gitignore b/.gitignore index 7f2da8ac9..075cbd602 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ contrib data doc +phy-doc docker experimental htmlcov diff --git a/docs/analysis.md b/docs/analysis.md new file mode 100644 index 000000000..e69de29bb diff --git a/docs/api.md b/docs/api.md new file mode 100644 index 000000000..e69de29bb diff --git a/docs/cluster-manual.md b/docs/cluster-manual.md new file mode 100644 index 000000000..e69de29bb diff --git a/docs/config.md b/docs/config.md new file mode 100644 index 000000000..e69de29bb diff --git a/docs/gui.md b/docs/gui.md new file mode 100644 index 000000000..e69de29bb diff --git a/docs/plot.md b/docs/plot.md new file mode 100644 index 000000000..e69de29bb From b6f285a9911e1608276a2eeb299c35d7164b3c5e Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 9 Dec 2015 22:38:36 +0100 Subject: [PATCH 0677/1059] Add empty docs --- docs/cli.md | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 docs/cli.md diff --git a/docs/cli.md b/docs/cli.md new file mode 100644 index 000000000..e69de29bb From b99d543ab0417b175a12862c95b2ea846ff0e68d Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 9 Dec 2015 23:05:11 +0100 Subject: [PATCH 0678/1059] Add FAQ --- docs/faq.md | 46 ++++++++++++++++++++++++++++++++++++++++++++++ docs/index.md | 15 --------------- mkdocs.yml | 1 + 3 files changed, 47 insertions(+), 15 deletions(-) create mode 100644 docs/faq.md diff --git a/docs/faq.md b/docs/faq.md new file mode 100644 index 000000000..8b72fdadc --- /dev/null +++ b/docs/faq.md @@ -0,0 +1,46 @@ +# Frequently Asked Questions + +Read this section to understand what phy is and is not. + +## What file formats does phy support? + +None, but [phycontrib](https://github.com/kwikteam/phy-contrib) contains a set of plugins supporting file formats like the Kwik format. + +If a file format becomes a widely-used standard in the future, we might support it directly in phy. + +## Are there ready-to-use scripts and GUIs for spike sorting in phy? + +No, but there are some in [phycontrib](https://github.com/kwikteam/phy-contrib). + +## Can you add feature X? + +No. + +In principle, you should be able to implement it by writing a plugin, and if not, we'd be happy to help. + +If the majority of our users are desperately asking one particular feature, we might consider implementing it. + +Here is a nice explanation of why *No* should be the default answer here (from the *Getting Real* book by 37signals): + +> **Each time you say yes to a feature, you're adopting a child**. You have to take your baby through a whole chain of events (e.g. design, implementation, testing, etc.). And once that feature's out there, you're stuck with it. Just try to take a released feature away from customers and see how pissed off they get. + +> Make each feature work hard to be implemented. Make each feature prove itself and show that it's a survivor. **It's like "Fight Club." You should only consider features if they're willing to stand on the porch for three days waiting to be let in.** + +> That's why you start with no. Every new feature request that comes to us – or from us – meets a no. We listen but don't act. **The initial response is "not now."** If a request for a feature keeps coming back, that's when we know it's time to take a deeper look. Then, and only then, do we start considering the feature for real. + +> And what do you say to people who complain when you won't adopt their feature idea? Remind them why they like the app in the first place. "You like it because we say no. You like it because it doesn't do 100 other things. You like it because it doesn't try to please everyone all the time." + +References: + +* [Why you should write buggy software with as few features as possible](https://astrocompute.wordpress.com/2013/07/11/why-you-should-write-buggy-software-with-as-few-features-as-possible-no-really/) +* [Getting Real](https://basecamp.com/books/Getting%20Real.pdf) + +## Where can I get some help? + +[On the mailing list](https://groups.google.com/forum/#!forum/phy-users). + +## Who is developing phy? + +Cyrille Rossant, Max Hunter, aided by Shabnam Kadir, Nick Steinmetz, from the Cortexlab (University College London), led by Kenneth Harris and Matteo Carandini. + +## How do I cite phy? diff --git a/docs/index.md b/docs/index.md index 94645ed60..9793c3f2e 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1,18 +1,3 @@ # phy documentation phy is an ephys data analysis library. It provides everything you need to spikesort and analyze extracellular multielectrode recordings, including components to build analysis command-line and graphical applications. - - -## Frequently Asked Questions - -### What file formats does phy support? - -None. - -### Are there ready-to-use scripts and GUIs for spike sorting? - -No. - -### Can you add feature X? - -No. diff --git a/mkdocs.yml b/mkdocs.yml index e1d42d93a..64ecb72a7 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -1,6 +1,7 @@ site_name: phy pages: - Home: 'index.md' +- FAQ: 'faq.md' - Installation: 'install.md' - Overview: 'overview.md' - GUI: 'gui.md' From f4958e2e9fd3ccc8d4463b5914132d715895cf08 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 10 Dec 2015 00:42:12 +0100 Subject: [PATCH 0679/1059] WIP: GUI doc --- docs/gui.md | 162 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 162 insertions(+) diff --git a/docs/gui.md b/docs/gui.md index e69de29bb..99faeae5c 100644 --- a/docs/gui.md +++ b/docs/gui.md @@ -0,0 +1,162 @@ +# GUI + +`phy.gui` provides generic Qt-based GUI components. You don't need to know Qt to use `phy.gui`, although it might help. + +## Creating a Qt application + +You need to create a Qt application before creating and using GUIs. There is a single Qt application object in a Python interpreter. + +In IPython (console or notebook), you can just use the following magic command before doing anything Qt-related: + +```python +>>> %gui qt +``` + +In other situations, like in regular Python scripts, you need to: + +* Call `phy.gui.create_app()` once, before you create a GUI. +* Call `phy.gui.run_app()` to launch your application. This blocks the Python interpreter and runs the Qt event loop. Generally, when this call returns, the application exits. + +For interactive use and explorative work, it is highly recommended to use IPython, for example with a Jupyter Notebook. + +## Creating a GUI + +phy provides a **GUI**, a main window with dockable widgets (`QMainWindow`). By default, a GUI is empty, but you can add views. A view is any Qt widget or a VisPy canvas. + +Let's create an empty GUI: + +```python +>>> from phy.gui import GUI +>>> gui = GUI(position=(400, 200), size=(600, 400)) +>>> gui.show() +``` + +## Adding a visualization + +We can add any Qt widget with `gui.add_view(widget)`, as well as visualizations with VisPy or matplotlib (which are fully compatible with Qt and phy). + +### With VisPy + +The `gui.add_view()` method accepts any VisPy canvas. For example, here we add an empty VisPy window: + +```python +>>> from vispy.app import Canvas +>>> from vispy import gloo +... +>>> c = Canvas() +... +>>> @c.connect +... def on_draw(e): +... gloo.clear('purple') +... +>>> gui.add_view(c) + +``` + +We can now dock and undock our widget from the GUI. This is particularly convenient when there are many widgets. + +### With matplotlib + +Here we add a matplotlib figure to our GUI: + +```python +>>> import numpy as np +>>> import matplotlib.pyplot as plt +... +>>> f = plt.figure() +>>> ax = f.add_subplot(111) +>>> t = np.linspace(-10., 10., 1000) +>>> ax.plot(t, np.sin(t)) +... +>>> # TODO: implement this directly in phy +... from matplotlib.backends.backend_qt4agg import FigureCanvasQTAgg as FigureCanvas +>>> gui.add_view(FigureCanvas(f)) + +``` + +## Adding an HTML widget + +phy provides an `HTMLWidget` component which allows you to create widgets in HTML. This is just a `QWebView` with some user-friendly facilities. + +First, let's create a standalone HTML widget: + +```python +>>> from phy.gui.widgets import HTMLWidget +>>> widget = HTMLWidget() +>>> widget.set_body("Hello world!") +>>> widget.show() +``` + +Now that our widget is created, let's add it to the GUI: + +```python +>>> gui.add_view(widget) +``` + +You'll find in the API reference other methods to edit the styles, scripts, header, and body of the HTML widget. + +### Interactivity with Javascript + +We can use Javascript in an HTML widget, and we can make Python and Javascript communicate. + +```python +>>> from phy.gui.widgets import HTMLWidget +>>> widget = HTMLWidget() +>>> widget.set_body('
') +>>> # We can execute Javascript code from Python. +... widget.eval_js("document.getElementById('mydiv').innerHTML='hello'") +>>> widget.show() +>>> gui.add_view(widget) + +``` + +You can use `widget.eval_js()` to evaluate Javascript code from Python. Conversely, you can use `widget.some_method()` from Javascript, where `some_method()` is a method implemented in a subclass of `HTMLWidget`. + +## Other GUI methods + +Let's display the list of views in the GUI: + +```python +>>> gui.list_views() +[, + , + , + ] +``` + +The following method allows you to check how many views of each class there are: + +```python +>>> gui.view_count() +{'canvas': 1, 'figurecanvasqtagg': 1, 'htmlwidget': 2} +``` + +Use the following property to change the status bar: + +```python +>>> gui.status_message = "Hello world" +``` + +Finally, the following methods allow you to save/restore the state of the GUI and the widgets: + +```python +>>> gs = gui.save_geometry_state() +``` + +```python +>>> gui.restore_geometry_state(gs) +``` + +The object `gs` is a JSON-serializable Python dictionary. + +## Adding actions + +```python + +``` + +## Snippets + +```python + +``` From c668ab4b4c7a3a225e82e9626e7d27cd64d99432 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 10 Dec 2015 11:14:16 +0100 Subject: [PATCH 0680/1059] Update GUI doc --- docs/gui.md | 58 +++++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 50 insertions(+), 8 deletions(-) diff --git a/docs/gui.md b/docs/gui.md index 99faeae5c..47181ba1a 100644 --- a/docs/gui.md +++ b/docs/gui.md @@ -29,6 +29,7 @@ Let's create an empty GUI: >>> from phy.gui import GUI >>> gui = GUI(position=(400, 200), size=(600, 400)) >>> gui.show() +INFO:phy.gui.actions:Snippet mode disabled. ``` ## Adding a visualization @@ -50,7 +51,7 @@ The `gui.add_view()` method accepts any VisPy canvas. For example, here we add a ... gloo.clear('purple') ... >>> gui.add_view(c) - + ``` We can now dock and undock our widget from the GUI. This is particularly convenient when there are many widgets. @@ -71,7 +72,7 @@ Here we add a matplotlib figure to our GUI: >>> # TODO: implement this directly in phy ... from matplotlib.backends.backend_qt4agg import FigureCanvasQTAgg as FigureCanvas >>> gui.add_view(FigureCanvas(f)) - + ``` ## Adding an HTML widget @@ -91,6 +92,7 @@ Now that our widget is created, let's add it to the GUI: ```python >>> gui.add_view(widget) + ``` You'll find in the API reference other methods to edit the styles, scripts, header, and body of the HTML widget. @@ -107,7 +109,7 @@ We can use Javascript in an HTML widget, and we can make Python and Javascript c ... widget.eval_js("document.getElementById('mydiv').innerHTML='hello'") >>> widget.show() >>> gui.add_view(widget) - + ``` You can use `widget.eval_js()` to evaluate Javascript code from Python. Conversely, you can use `widget.some_method()` from Javascript, where `some_method()` is a method implemented in a subclass of `HTMLWidget`. @@ -118,10 +120,10 @@ Let's display the list of views in the GUI: ```python >>> gui.list_views() -[, - , - , - ] +[, + , + , + ] ``` The following method allows you to check how many views of each class there are: @@ -151,12 +153,52 @@ The object `gs` is a JSON-serializable Python dictionary. ## Adding actions +An **action** is a Python function that can be run by the user by clicking on a button or pressing a keyboard shortcut. You can create an `Actions` object to specify a list of actions attached to a GUI. + ```python +>>> from phy.gui import Actions +>>> actions = Actions(gui) +... +>>> @actions.add(shortcut='ctrl+h') +... def hello(): +... print("Hello world!") +``` + +Now, if you press *Ctrl+H* in the GUI, you'll see ̀`Hello world!` printed in the console. + +Once an action is added, you can call it with `actions.hello()` where `hello` is the name of the action. By default, this is the name of the associated function, but you can also specify the name explicitly with the `name=...` keyword argument in `actions.add()`. +You'll find more details about `actions.add()` in the API reference. + +Every GUI comes with a `default_actions` property which implements actions always available in GUIs: + +```python +>>> gui.default_actions + ``` -## Snippets +For example, the following action shows the shortcuts of all actions attached to the GUI: ```python +>>> gui.default_actions.show_shortcuts() +Keyboard shortcuts for GUI +enable_snippet_mode : : +exit : ctrl+q +hello : ctrl+h +show_shortcuts : f1, h ``` + +## Snippets + +The GUI provides a convenient system to quickly execute actions without leaving one's keyboard. Inspired by console-based text editors like *vim*, it is enabled by pressing `:` on the keyboard. Once this mode is enabled, what you type is displayed in the status bar. Then, you can call a function by typing its name or its alias. You can also use arguments to the actions, using a special syntax. Here is an example. + +```python +>>> @actions.add(alias='c') +... def select(ids, obj): +... print("Select %s with %s" % (ids, obj)) +``` + +Now, pressing `:c 3-6 hello` followed by the `Enter` keystrokes displays `Select [3, 4, 5, 6] with hello` in the console. + +By convention, multiple arguments are separated by spaces, sequences of numbers are given either with `2,3,5,7` or `3-6` for consecutive numbers. If an alias is not specified when adding the action, you can always use the full action's name. From bbb36bf3889210284ecd39ee2e0489e4c128c29b Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 10 Dec 2015 11:21:10 +0100 Subject: [PATCH 0681/1059] Support matplotlib figures in GUI.add_view() --- phy/gui/gui.py | 33 +++++++++++++++++++++++++-------- phy/gui/tests/test_gui.py | 23 ++++++++++++++++++++--- 2 files changed, 45 insertions(+), 11 deletions(-) diff --git a/phy/gui/gui.py b/phy/gui/gui.py index 327b7a848..9cd2f6590 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -28,6 +28,29 @@ def _title(widget): return str(widget.windowTitle()).lower() +def _try_get_vispy_canvas(view): + # Get the Qt widget from a VisPy canvas. + try: + from vispy.app import Canvas + if isinstance(view, Canvas): + view = view.native + except ImportError: # pragma: no cover + pass + return view + + +def _try_get_matplotlib_canvas(view): + # Get the Qt widget from a matplotlib figure. + try: + from matplotlib.pyplot import Figure + from matplotlib.backends.backend_qt4agg import FigureCanvasQTAgg + if isinstance(view, Figure): + view = FigureCanvasQTAgg(view) + except ImportError: # pragma: no cover + pass + return view + + def load_gui_plugins(gui, plugins=None, session=None): """Attach a list of plugins to a GUI. @@ -182,15 +205,9 @@ def add_view(self, **kwargs): """Add a widget to the main window.""" - try: - from vispy.app import Canvas - if isinstance(view, Canvas): - title = title or view.__class__.__name__ - view = view.native - except ImportError: # pragma: no cover - pass - title = title or view.__class__.__name__ + view = _try_get_vispy_canvas(view) + view = _try_get_matplotlib_canvas(view) # Create the gui widget. dockwidget = DockWidget(self) diff --git a/phy/gui/tests/test_gui.py b/phy/gui/tests/test_gui.py index 9a13cff2e..839d62cc9 100644 --- a/phy/gui/tests/test_gui.py +++ b/phy/gui/tests/test_gui.py @@ -8,8 +8,11 @@ from pytest import raises -from ..qt import Qt, QApplication -from ..gui import GUI, load_gui_plugins +from ..qt import Qt, QApplication, QWidget +from ..gui import (GUI, load_gui_plugins, + _try_get_matplotlib_canvas, + _try_get_vispy_canvas, + ) from phy.utils import IPlugin from phy.utils._color import _random_color @@ -32,7 +35,21 @@ def on_draw(e): # pragma: no cover #------------------------------------------------------------------------------ -# Test gui +# Test views +#------------------------------------------------------------------------------ + +def test_vispy_view(): + from vispy.app import Canvas + assert isinstance(_try_get_vispy_canvas(Canvas()), QWidget) + + +def test_matplotlib_view(): + from matplotlib.pyplot import Figure + assert isinstance(_try_get_matplotlib_canvas(Figure()), QWidget) + + +#------------------------------------------------------------------------------ +# Test GUI #------------------------------------------------------------------------------ def test_gui_noapp(): From 5cd0c20190917fee63755f3e6daf744239fa807d Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 10 Dec 2015 11:27:39 +0100 Subject: [PATCH 0682/1059] Save the GUI status message when the snippet mode is enabled --- docs/gui.md | 6 +----- phy/gui/actions.py | 9 +++++++-- phy/gui/tests/test_actions.py | 7 +++++++ 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/docs/gui.md b/docs/gui.md index 47181ba1a..473949acc 100644 --- a/docs/gui.md +++ b/docs/gui.md @@ -29,7 +29,6 @@ Let's create an empty GUI: >>> from phy.gui import GUI >>> gui = GUI(position=(400, 200), size=(600, 400)) >>> gui.show() -INFO:phy.gui.actions:Snippet mode disabled. ``` ## Adding a visualization @@ -68,10 +67,7 @@ Here we add a matplotlib figure to our GUI: >>> ax = f.add_subplot(111) >>> t = np.linspace(-10., 10., 1000) >>> ax.plot(t, np.sin(t)) -... ->>> # TODO: implement this directly in phy -... from matplotlib.backends.backend_qt4agg import FigureCanvasQTAgg as FigureCanvas ->>> gui.add_view(FigureCanvas(f)) +>>> gui.add_view(f) ``` diff --git a/phy/gui/actions.py b/phy/gui/actions.py index cc5e47e75..e59ddfb69 100644 --- a/phy/gui/actions.py +++ b/phy/gui/actions.py @@ -282,6 +282,7 @@ class Snippets(object): def __init__(self, gui): self.gui = gui + self._status_message = gui.status_message self.actions = Actions(gui) @@ -387,6 +388,8 @@ def is_mode_on(self): def mode_on(self): logger.info("Snippet mode enabled, press `escape` to leave this mode.") + # Save the current status message. + self._status_message = self.gui.status_message # Silent all actions except the Snippets actions. for actions in self.gui.actions: @@ -397,7 +400,9 @@ def mode_on(self): self.command = ':' def mode_off(self): - self.gui.status_message = '' + # Reset the GUI status message that was set before the mode was + # activated. + self.gui.status_message = self._status_message # Re-enable all actions except the Snippets actions. self.actions.disable() @@ -407,4 +412,4 @@ def mode_off(self): # The `:` shortcut should always be enabled. self.actions.enable('enable_snippet_mode') - logger.info("Snippet mode disabled.") + logger.debug("Snippet mode disabled.") diff --git a/phy/gui/tests/test_actions.py b/phy/gui/tests/test_actions.py index 55594f084..c029cb726 100644 --- a/phy/gui/tests/test_actions.py +++ b/phy/gui/tests/test_actions.py @@ -134,6 +134,13 @@ def press(): assert actions.get('press').isEnabled() +def test_snippets_message(qtbot, gui): + gui.status_message = 'Hello world!' + gui.snippets.mode_on() + gui.snippets.mode_off() + assert gui.status_message == 'Hello world!' + + def test_snippets_gui(qtbot, gui, actions): qtbot.addWidget(gui) gui.show() From 38df646915b66e3459eb7757a6cd3d33cb154804 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 10 Dec 2015 11:49:56 +0100 Subject: [PATCH 0683/1059] Table documentation --- docs/gui.md | 47 ++++++++++++++++++++++++++++++++--- phy/gui/__init__.py | 1 + phy/gui/actions.py | 2 -- phy/gui/tests/test_widgets.py | 4 +++ phy/gui/widgets.py | 8 +++--- 5 files changed, 53 insertions(+), 9 deletions(-) diff --git a/docs/gui.md b/docs/gui.md index 473949acc..b807af675 100644 --- a/docs/gui.md +++ b/docs/gui.md @@ -50,7 +50,7 @@ The `gui.add_view()` method accepts any VisPy canvas. For example, here we add a ... gloo.clear('purple') ... >>> gui.add_view(c) - + ``` We can now dock and undock our widget from the GUI. This is particularly convenient when there are many widgets. @@ -68,7 +68,7 @@ Here we add a matplotlib figure to our GUI: >>> t = np.linspace(-10., 10., 1000) >>> ax.plot(t, np.sin(t)) >>> gui.add_view(f) - + ``` ## Adding an HTML widget @@ -78,7 +78,7 @@ phy provides an `HTMLWidget` component which allows you to create widgets in HTM First, let's create a standalone HTML widget: ```python ->>> from phy.gui.widgets import HTMLWidget +>>> from phy.gui import HTMLWidget >>> widget = HTMLWidget() >>> widget.set_body("Hello world!") >>> widget.show() @@ -93,12 +93,51 @@ Now that our widget is created, let's add it to the GUI: You'll find in the API reference other methods to edit the styles, scripts, header, and body of the HTML widget. +### Table + +phy also provides a `Table` widget written in HTML and Javascript (using the [tablesort](https://github.com/tristen/tablesort) Javascript library). This widget shows a table of items, where every item (row) has an id, and every column is defined as a function `id => string`, the string being the contents of a row's cell in the table. The table can be sorted by every column. + +One or several items can be selected by the user. The `select` event is raised when rows are selected. Here is a complete example: + +--- +scrolled: true +... + +```python +>>> from phy.gui import Table +>>> table = Table() +... +>>> # We add a column in the table. +... @table.add_column +... def name(id): +... # This function takes an id as input and returns a string. +... return "My id is %d" % id +... +>>> # Now we add some rows. +... table.set_rows([2, 3, 5, 7]) +... +>>> # We print something when items are selected. +... @table.connect_ +... def on_select(ids): +... # NOTE: we use `connect_` and not `connect`, because `connect` is +... # a Qt method associated to every Qt widget, and `Table` is a subclass +... # of `QWidget`. Using `connect_` ensures that we're using phy's event +... # system, not Qt's. +... print("The items %s have been selected." % ids) +... +>>> table.show() +The items [3] have been selected. +The items [3, 5] have been selected. +``` + + + ### Interactivity with Javascript We can use Javascript in an HTML widget, and we can make Python and Javascript communicate. ```python ->>> from phy.gui.widgets import HTMLWidget +>>> from phy.gui import HTMLWidget >>> widget = HTMLWidget() >>> widget.set_body('
') >>> # We can execute Javascript code from Python. diff --git a/phy/gui/__init__.py b/phy/gui/__init__.py index d9aa2be96..1e7576a2e 100644 --- a/phy/gui/__init__.py +++ b/phy/gui/__init__.py @@ -6,3 +6,4 @@ from .qt import require_qt, create_app, run_app from .gui import GUI, load_gui_plugins from .actions import Actions +from .widgets import HTMLWidget, Table diff --git a/phy/gui/actions.py b/phy/gui/actions.py index e59ddfb69..2fde837f9 100644 --- a/phy/gui/actions.py +++ b/phy/gui/actions.py @@ -411,5 +411,3 @@ def mode_off(self): actions.enable() # The `:` shortcut should always be enabled. self.actions.enable('enable_snippet_mode') - - logger.debug("Snippet mode disabled.") diff --git a/phy/gui/tests/test_widgets.py b/phy/gui/tests/test_widgets.py index 6c43b7fc0..caeb8c2b5 100644 --- a/phy/gui/tests/test_widgets.py +++ b/phy/gui/tests/test_widgets.py @@ -88,6 +88,10 @@ def on_test(arg): # Test table #------------------------------------------------------------------------------ +def test_table_current_sort(): + assert Table().current_sort == (None, None) + + def test_table_default_sort(qtbot): table = Table() table.show() diff --git a/phy/gui/widgets.py b/phy/gui/widgets.py index fab72cfd9..dba4e52f9 100644 --- a/phy/gui/widgets.py +++ b/phy/gui/widgets.py @@ -245,6 +245,7 @@ def add_column(self, func, name=None, show=True): @property def column_names(self): + """List of column names.""" return [name for (name, d) in self._columns.items() if d.get('show', True)] @@ -279,11 +280,11 @@ def sort_by(self, name, sort_dir='asc'): self.eval_js('table.sortBy("{}", "{}");'.format(name, sort_dir)) def next(self): - """Select the next non-skip row.""" + """Select the next non-skipped row.""" self.eval_js('table.next();') def previous(self): - """Select the previous non-skip row.""" + """Select the previous non-skipped row.""" self.eval_js('table.previous();') def select(self, ids, do_emit=True): @@ -297,6 +298,7 @@ def default_sort(self): return self._default_sort def set_default_sort(self, name, sort_dir='desc'): + """Set the default sort column.""" self._default_sort = name, sort_dir @property @@ -307,4 +309,4 @@ def selected(self): @property def current_sort(self): """Current sort: a tuple `(name, dir)`.""" - return tuple(self.eval_js('table.currentSort()')) + return tuple(self.eval_js('table.currentSort()') or (None, None)) From d3e8930e3c539e105f0f31173148e26e811ad083 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 10 Dec 2015 11:51:23 +0100 Subject: [PATCH 0684/1059] Fix --- docs/gui.md | 5 ----- 1 file changed, 5 deletions(-) diff --git a/docs/gui.md b/docs/gui.md index b807af675..cd5c208e5 100644 --- a/docs/gui.md +++ b/docs/gui.md @@ -99,10 +99,6 @@ phy also provides a `Table` widget written in HTML and Javascript (using the [ta One or several items can be selected by the user. The `select` event is raised when rows are selected. Here is a complete example: ---- -scrolled: true -... - ```python >>> from phy.gui import Table >>> table = Table() @@ -131,7 +127,6 @@ The items [3, 5] have been selected. ``` - ### Interactivity with Javascript We can use Javascript in an HTML widget, and we can make Python and Javascript communicate. From cd007985e90a64ab00e75ff52481d06a9aaa31bb Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 10 Dec 2015 11:56:04 +0100 Subject: [PATCH 0685/1059] Minor updates in GUI doc --- docs/gui.md | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/docs/gui.md b/docs/gui.md index cd5c208e5..d9b957540 100644 --- a/docs/gui.md +++ b/docs/gui.md @@ -21,7 +21,7 @@ For interactive use and explorative work, it is highly recommended to use IPytho ## Creating a GUI -phy provides a **GUI**, a main window with dockable widgets (`QMainWindow`). By default, a GUI is empty, but you can add views. A view is any Qt widget or a VisPy canvas. +phy provides a **GUI**, a main window with dockable widgets (`QMainWindow`). By default, a GUI is empty, but you can add views. A view is any Qt widget or a matplotlib or VisPy canvas. Let's create an empty GUI: @@ -134,7 +134,7 @@ We can use Javascript in an HTML widget, and we can make Python and Javascript c ```python >>> from phy.gui import HTMLWidget >>> widget = HTMLWidget() ->>> widget.set_body('
') +>>> widget.set_body('
') >>> # We can execute Javascript code from Python. ... widget.eval_js("document.getElementById('mydiv').innerHTML='hello'") >>> widget.show() @@ -142,7 +142,7 @@ We can use Javascript in an HTML widget, and we can make Python and Javascript c ``` -You can use `widget.eval_js()` to evaluate Javascript code from Python. Conversely, you can use `widget.some_method()` from Javascript, where `some_method()` is a method implemented in a subclass of `HTMLWidget`. +You can use `widget.eval_js()` to evaluate Javascript code from Python. Conversely, you can use `widget.some_method()` from Javascript, where `some_method()` is a method implemented in your widget (which should be a subclass of `HTMLWidget`). ## Other GUI methods @@ -194,7 +194,7 @@ An **action** is a Python function that can be run by the user by clicking on a ... print("Hello world!") ``` -Now, if you press *Ctrl+H* in the GUI, you'll see ̀`Hello world!` printed in the console. +Now, if you press *Ctrl+H* in the GUI, you'll see `Hello world!` printed in the console. Once an action is added, you can call it with `actions.hello()` where `hello` is the name of the action. By default, this is the name of the associated function, but you can also specify the name explicitly with the `name=...` keyword argument in `actions.add()`. @@ -219,6 +219,8 @@ hello : ctrl+h show_shortcuts : f1, h ``` +You can create multiple `Actions` instance for a single GUI, which allows you to separate between different sets of actions. + ## Snippets The GUI provides a convenient system to quickly execute actions without leaving one's keyboard. Inspired by console-based text editors like *vim*, it is enabled by pressing `:` on the keyboard. Once this mode is enabled, what you type is displayed in the status bar. Then, you can call a function by typing its name or its alias. You can also use arguments to the actions, using a special syntax. Here is an example. @@ -229,6 +231,6 @@ The GUI provides a convenient system to quickly execute actions without leaving ... print("Select %s with %s" % (ids, obj)) ``` -Now, pressing `:c 3-6 hello` followed by the `Enter` keystrokes displays `Select [3, 4, 5, 6] with hello` in the console. +Now, pressing `:c 3-6 hello` followed by the `Enter` keystroke displays `Select [3, 4, 5, 6] with hello` in the console. By convention, multiple arguments are separated by spaces, sequences of numbers are given either with `2,3,5,7` or `3-6` for consecutive numbers. If an alias is not specified when adding the action, you can always use the full action's name. From 88eeebe4f31e9dae2d36455e178b1c3503c11a46 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 10 Dec 2015 12:07:55 +0100 Subject: [PATCH 0686/1059] Add SimpleView --- phy/plot/__init__.py | 2 +- phy/plot/plot.py | 15 ++++++++++++++- phy/plot/tests/test_plot.py | 6 +++--- 3 files changed, 18 insertions(+), 5 deletions(-) diff --git a/phy/plot/__init__.py b/phy/plot/__init__.py index 2b5f48742..8554d8e0f 100644 --- a/phy/plot/__init__.py +++ b/phy/plot/__init__.py @@ -12,7 +12,7 @@ from vispy import config -from .plot import GridView, BoxedView, StackedView # noqa +from .plot import SimpleView, GridView, BoxedView, StackedView # noqa from .transform import Translate, Scale, Range, Subplot, NDC from .panzoom import PanZoom from .utils import _get_linear_x diff --git a/phy/plot/plot.py b/phy/plot/plot.py index 96535ff07..8bf1dd844 100644 --- a/phy/plot/plot.py +++ b/phy/plot/plot.py @@ -146,7 +146,11 @@ def build(self): visual = cls() self.add_visual(visual) visual.set_data(**data) - visual.program['a_box_index'] = box_index + try: + visual.program['a_box_index'] + visual.program['a_box_index'] = box_index + except KeyError: + pass self.update() @contextmanager @@ -157,6 +161,15 @@ def building(self): self.build() +class SimpleView(BaseView): + """A simple view.""" + def __init__(self, shape=None, **kwargs): + super(SimpleView, self).__init__(**kwargs) + + self.panzoom = PanZoom(aspect=None, constrain_bounds=NDC) + self.panzoom.attach(self) + + class GridView(BaseView): """A 2D grid with clipping.""" _default_box_index = (0, 0) diff --git a/phy/plot/tests/test_plot.py b/phy/plot/tests/test_plot.py index 48943556e..eddc4f677 100644 --- a/phy/plot/tests/test_plot.py +++ b/phy/plot/tests/test_plot.py @@ -10,7 +10,7 @@ import numpy as np from ..panzoom import PanZoom -from ..plot import BaseView, GridView, BoxedView, StackedView +from ..plot import BaseView, SimpleView, GridView, BoxedView, StackedView from ..utils import _get_linear_x @@ -46,8 +46,8 @@ def test_building(qtbot): view.close() -def test_base_view(qtbot): - view = BaseView(keys='interactive') +def test_simple_view(qtbot): + view = SimpleView(keys='interactive') n = 1000 x = np.random.randn(n) From 3581c76d449ff23e9f9277edde0921f4d74a16bf Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 10 Dec 2015 12:07:59 +0100 Subject: [PATCH 0687/1059] Start plot doc --- docs/overview.md | 2 ++ docs/plot.md | 30 ++++++++++++++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/docs/overview.md b/docs/overview.md index 41c2c903d..5654e8505 100644 --- a/docs/overview.md +++ b/docs/overview.md @@ -8,6 +8,8 @@ phy is entirely agnostic to file formats and processing workflows. As such, it c * GUI * Plotting +* Command-line interface +* Configuration and plugin system * Manual clustering routines * Analysis functions * Utilities diff --git a/docs/plot.md b/docs/plot.md index e69de29bb..01d6bc799 100644 --- a/docs/plot.md +++ b/docs/plot.md @@ -0,0 +1,30 @@ +# Plotting with VisPy + +phy provides a simple and fast plotting system based on VisPy's low-level **gloo** interface. This plotting system is entirely generic. + +```python +>>> %gui qt +``` + +```python +>>> import numpy as np +>>> from phy.plot import SimpleView +``` + +```python +>>> view = SimpleView() +... +>>> n = 1000 +>>> x, y = np.random.randn(2, n) +>>> c = np.random.uniform(.3, .7, (n, 4)) +>>> s = np.random.uniform(5, 30, n) +... +>>> with view.building(): +... view.scatter(x, y, color=c, size=s, marker='disc') +... +>>> view.show() +``` + +```python + +``` From 72f4fa220127718c88bee6cee4c6c0e26b23e8c4 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 10 Dec 2015 12:28:55 +0100 Subject: [PATCH 0688/1059] WIP: plot doc --- docs/plot.md | 52 +++++++++++++++++++++++++++++++++++++++++++++--- phy/plot/plot.py | 6 +----- 2 files changed, 50 insertions(+), 8 deletions(-) diff --git a/docs/plot.md b/docs/plot.md index 01d6bc799..1adb8d93f 100644 --- a/docs/plot.md +++ b/docs/plot.md @@ -1,14 +1,20 @@ # Plotting with VisPy -phy provides a simple and fast plotting system based on VisPy's low-level **gloo** interface. This plotting system is entirely generic. +phy provides a simple and fast plotting system based on VisPy's low-level **gloo** interface. This plotting system is entirely generic. Currently, it privileges speed and scalability over quality. In other words, you can display millions of points at very high speed, but the plotting quality is not as good as matplotlib, for example. + +First, we need to activate the Qt event loop in IPython, or create and run the Qt application in a script. ```python >>> %gui qt ``` +## Simple view + +Let's create a simple view with a scatter plot. + ```python >>> import numpy as np ->>> from phy.plot import SimpleView +>>> from phy.plot import SimpleView, GridView, BoxedView, StackedView ``` ```python @@ -19,12 +25,52 @@ phy provides a simple and fast plotting system based on VisPy's low-level **gloo >>> c = np.random.uniform(.3, .7, (n, 4)) >>> s = np.random.uniform(5, 30, n) ... ->>> with view.building(): +>>> # NOTE: currently, the building process needs to be explicit. +... # All commands that construct the view should be enclosed in this +... # context manager, or at least one should ensure that +... # `view.clear()` and `view.build()` are called before and after +... # the building commands. +... with view.building(): ... view.scatter(x, y, color=c, size=s, marker='disc') ... >>> view.show() ``` +Note that you can pan and zoom with the mouse and keyboard. + +The other plotting commands currently supported are `plot()` and `hist()`. We're planning to add support for text in the near future. + +## Grid view + +The `GridView` lets you create multiple subplots arranged in a grid. Subplots are all individually clipped, which means that their viewports never overlap across the grid boundaries. Here is an example: + ```python +>>> view = GridView((1, 2)) # the shape is `(n_rows, n_cols)` +... +>>> x = np.linspace(-10., 10., 1000) +... +>>> with view.building(): +... view[0, 0].plot(x, np.sin(x)) +... view[0, 1].plot(x, np.cos(x), color=(1, 0, 0, 1)) +... +>>> view.show() +``` + +Subplots are created with the `view[i, j]` syntax. The indexing scheme works like mathematical matrices (origin at the upper left). + +Note that there are no axes at this point, but we'll be working on it. Also, independent per-subplot panning and zooming is not supported and this is unlikely to change in the foreseable future. + +## Stacked view + +The stacked view lets you stack several subplots vertically with no clipping. An example is a trace view showing a multichannel time-dependent signal. +```python +>>> view = StackedView(50) +... +>>> with view.building(): +... for i in range(view.n_plots): +... view[i].plot(y=np.random.randn(2000), +... color=np.random.uniform(.5, .9, 4)) +... +>>> view.show() ``` diff --git a/phy/plot/plot.py b/phy/plot/plot.py index 8bf1dd844..14826ad37 100644 --- a/phy/plot/plot.py +++ b/phy/plot/plot.py @@ -146,11 +146,7 @@ def build(self): visual = cls() self.add_visual(visual) visual.set_data(**data) - try: - visual.program['a_box_index'] - visual.program['a_box_index'] = box_index - except KeyError: - pass + visual.program['a_box_index'] = box_index self.update() @contextmanager From 46c1716c0e9d84762a6a68cb9ba55f1bc159bcbc Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 10 Dec 2015 12:45:55 +0100 Subject: [PATCH 0689/1059] Add BoxedView doc --- docs/plot.md | 27 +++++++++++++++++++++++++++ phy/plot/plot.py | 12 ++++++++---- phy/plot/tests/test_plot.py | 2 +- 3 files changed, 36 insertions(+), 5 deletions(-) diff --git a/docs/plot.md b/docs/plot.md index 1adb8d93f..ac8b5eaaa 100644 --- a/docs/plot.md +++ b/docs/plot.md @@ -74,3 +74,30 @@ The stacked view lets you stack several subplots vertically with no clipping. An ... >>> view.show() ``` + +## Boxed view + +The boxed view lets you put subplots at arbitrary locations. You can dynamically change the positions and the sizes of the boxes. An example is the waveform view, where line plots are positioned at the recording sites on a multielectrode array. + +```python +>>> # Generate box positions along a circle. +... dt = np.pi / 10 +>>> t = np.arange(0, 2 * np.pi, dt) +>>> x = np.cos(t) +>>> y = np.sin(t) +>>> box_pos = np.c_[x, y] +... +>>> view = BoxedView(box_pos=box_pos) +... +>>> with view.building(): +... for i in range(view.n_plots): +... # Create the subplots. +... view[i].plot(y=np.random.randn(10, 100), +... color=np.random.uniform(.5, .9, 4)) +... +>>> view.show() +``` + +You can use `ctrl+arrows` and `shift+arrows` to change the scaling of the positions and boxes. + +## Data normalization diff --git a/phy/plot/plot.py b/phy/plot/plot.py index 14826ad37..33caf10b4 100644 --- a/phy/plot/plot.py +++ b/phy/plot/plot.py @@ -146,7 +146,8 @@ def build(self): visual = cls() self.add_visual(visual) visual.set_data(**data) - visual.program['a_box_index'] = box_index + if 'a_box_index' in visual.program: + visual.program['a_box_index'] = box_index self.update() @contextmanager @@ -182,11 +183,14 @@ def __init__(self, shape=None, **kwargs): class BoxedView(BaseView): """Subplots at arbitrary positions""" - def __init__(self, box_bounds, **kwargs): + def __init__(self, box_bounds=None, box_pos=None, box_size=None, **kwargs): super(BoxedView, self).__init__(**kwargs) - self.n_plots = len(box_bounds) + self.n_plots = (len(box_bounds) + if box_bounds is not None else len(box_pos)) - self.boxed = Boxed(box_bounds) + self.boxed = Boxed(box_bounds=box_bounds, + box_pos=box_pos, + box_size=box_size) self.boxed.attach(self) self.panzoom = PanZoom(aspect=None, constrain_bounds=NDC) diff --git a/phy/plot/tests/test_plot.py b/phy/plot/tests/test_plot.py index eddc4f677..a99d92d37 100644 --- a/phy/plot/tests/test_plot.py +++ b/phy/plot/tests/test_plot.py @@ -47,7 +47,7 @@ def test_building(qtbot): def test_simple_view(qtbot): - view = SimpleView(keys='interactive') + view = SimpleView() n = 1000 x = np.random.randn(n) From 385edf737950b4b6f9de94088bd33e4b4f5d64ed Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 10 Dec 2015 13:07:26 +0100 Subject: [PATCH 0690/1059] Fix bug with latest VisPy release --- phy/plot/plot.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/phy/plot/plot.py b/phy/plot/plot.py index 33caf10b4..15eab0d84 100644 --- a/phy/plot/plot.py +++ b/phy/plot/plot.py @@ -146,7 +146,10 @@ def build(self): visual = cls() self.add_visual(visual) visual.set_data(**data) - if 'a_box_index' in visual.program: + # NOTE: visual.program.__contains__ is implemented in vispy master + # so we can replace this with `if 'a_box_index' in visual.program` + # after the next VisPy release. + if 'a_box_index' in visual.program._code_variables: visual.program['a_box_index'] = box_index self.update() From fbfe6659f2a911fd6377acee98202641bcbf1c91 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 10 Dec 2015 13:16:44 +0100 Subject: [PATCH 0691/1059] Add data normalization doc --- docs/plot.md | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/docs/plot.md b/docs/plot.md index ac8b5eaaa..71e6de8cc 100644 --- a/docs/plot.md +++ b/docs/plot.md @@ -1,6 +1,6 @@ # Plotting with VisPy -phy provides a simple and fast plotting system based on VisPy's low-level **gloo** interface. This plotting system is entirely generic. Currently, it privileges speed and scalability over quality. In other words, you can display millions of points at very high speed, but the plotting quality is not as good as matplotlib, for example. +phy provides a simple and fast plotting system based on VisPy's low-level **gloo** interface. This plotting system is entirely generic. Currently, it privileges speed and scalability over quality. In other words, you can display millions of points at very high speed, but the plotting quality is not as good as matplotlib, for example. While this sytem uses the GPU extensively, knowledge of GPU or OpenGL is not required for most purposes. First, we need to activate the Qt event loop in IPython, or create and run the Qt application in a script. @@ -101,3 +101,19 @@ The boxed view lets you put subplots at arbitrary locations. You can dynamically You can use `ctrl+arrows` and `shift+arrows` to change the scaling of the positions and boxes. ## Data normalization + +Data normalization is supported via the `data_bounds` keyword. This is a 4-tuple `(xmin, ymin, xmax, ymax)` with the coordinates of the viewport in the data coordinate system. By default, this is obtained with the min and max of the data. Here is an example: + +```python +>>> view = StackedView(2) +... +>>> n = 100 +>>> x = np.linspace(0., 1., n) +>>> y = np.random.rand(n) +... +>>> with view.building(): +... view[0].plot(x, y, data_bounds=(0, 0, 1, 1)) +... view[1].plot(x, y, data_bounds=(0, -10, 1, 10)) +... +>>> view.show() +``` From 1ef9798cecb4f5308e461731fa594c268450baa7 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 10 Dec 2015 13:26:54 +0100 Subject: [PATCH 0692/1059] Update --- docs/index.md | 14 +++++++++++++- docs/overview.md | 15 --------------- mkdocs.yml | 1 - 3 files changed, 13 insertions(+), 17 deletions(-) delete mode 100644 docs/overview.md diff --git a/docs/index.md b/docs/index.md index 9793c3f2e..517838cd9 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1,3 +1,15 @@ # phy documentation -phy is an ephys data analysis library. It provides everything you need to spikesort and analyze extracellular multielectrode recordings, including components to build analysis command-line and graphical applications. +phy is an **electrophysiology data analysis library**. It provides spike sorting and analysis routines for extracellular multielectrode recordings, as well as generic components to build command-line and graphical applications. Overall, phy lets you build custom data analysis applications, for example, a manual clustering GUI. + +phy is entirely agnostic to file formats and processing workflows. As such, it can be easily integrated with any existing system. + +## List of components + +* GUI +* Plotting +* Command-line interface +* Configuration and plugin system +* Manual clustering routines +* Analysis functions +* Utilities diff --git a/docs/overview.md b/docs/overview.md deleted file mode 100644 index 5654e8505..000000000 --- a/docs/overview.md +++ /dev/null @@ -1,15 +0,0 @@ -# Overview - -phy provides a set of generic tools for building data-intensive command-line and graphical applications with fast visualization capabilities. In addition, it implements ephys data analysis functions and manual clustering routines. Overall, these tools allow you to build specific data analysis applications, for example, a manual clustering GUI. - -phy is entirely agnostic to file formats and processing workflows. As such, it can be easily integrated with any existing system. - -## List of components - -* GUI -* Plotting -* Command-line interface -* Configuration and plugin system -* Manual clustering routines -* Analysis functions -* Utilities diff --git a/mkdocs.yml b/mkdocs.yml index 64ecb72a7..1bcd90f11 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -3,7 +3,6 @@ pages: - Home: 'index.md' - FAQ: 'faq.md' - Installation: 'install.md' -- Overview: 'overview.md' - GUI: 'gui.md' - Plotting: 'plot.md' - Command-line interface: 'cli.md' From a75a050435ddaae0b6961afcaf8708c493ea5ee1 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 10 Dec 2015 13:32:17 +0100 Subject: [PATCH 0693/1059] Update install instructions --- docs/install.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/install.md b/docs/install.md index d748597d9..25b1adb1a 100644 --- a/docs/install.md +++ b/docs/install.md @@ -1,4 +1,5 @@ # Installation * Install Anaconda -* `conda install -c kwikteam/phy` (TODO) +* To install the development version: `pip install git+https://github.com/kwikteam/phy.git@kill-kwik` +* (not done yet) to install the stable version: `conda install -c kwikteam/phy` From e27513bcb4a5b82ed69fc2189b9c1173a3e2805c Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 10 Dec 2015 13:45:53 +0100 Subject: [PATCH 0694/1059] Add config/plugin doc --- docs/cli.md | 1 + docs/config.md | 31 +++++++++++++++++++++++++++++++ docs/index.md | 2 +- mkdocs.yml | 2 +- 4 files changed, 34 insertions(+), 2 deletions(-) diff --git a/docs/cli.md b/docs/cli.md index e69de29bb..3f213d436 100644 --- a/docs/cli.md +++ b/docs/cli.md @@ -0,0 +1 @@ +# CLI diff --git a/docs/config.md b/docs/config.md index e69de29bb..09ed3b89e 100644 --- a/docs/config.md +++ b/docs/config.md @@ -0,0 +1,31 @@ +# Configuration and plugin system + +phy uses part of the **traitlets** package for its config system. **This is still a work in progress**. + +## Configuration file + +It is in `~/.phy/phy_config.py`. It is a regular Python file. It should begin with `c = get_config()` with no import (this function is injected in the namespace automatically by the config system). + +Then, you can set configuration options as follows: + +`c.SomeClass.some_param = some_value` + +## Plugin system + +A plugin is a Python class deriving from `phy.IPlugin`. To ensure that phy knows about your plugin, just make sure that your class is imported in the Python namespace. + +Here are two common methods: + +* Implement your plugin in a Python file and put this file in `~/.phy/plugins/`: it will be automatically discovered by phy. +* Edit `c.Plugins.dirs = ['/path/to/folder']` in your `phy_config.py` file: all Python scripts there will be automatically imported. + +Here is a minimal plugin template: + +```python +from phy import IPlugin + +class MyPlugin(IPlugin): + pass +``` + +We'll see in the next sections what methods you can implement in your plugins, and how to use them in your applications. diff --git a/docs/index.md b/docs/index.md index 517838cd9..b8a67892d 100644 --- a/docs/index.md +++ b/docs/index.md @@ -8,8 +8,8 @@ phy is entirely agnostic to file formats and processing workflows. As such, it c * GUI * Plotting -* Command-line interface * Configuration and plugin system +* Command-line interface * Manual clustering routines * Analysis functions * Utilities diff --git a/mkdocs.yml b/mkdocs.yml index 1bcd90f11..e23fd38e5 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -5,8 +5,8 @@ pages: - Installation: 'install.md' - GUI: 'gui.md' - Plotting: 'plot.md' -- Command-line interface: 'cli.md' - Configuration and plugin system: 'config.md' +- Command-line interface: 'cli.md' - Manual clustering: 'cluster-manual.md' - Analysis functions: 'analysis.md' - API reference: 'api.md' From dd561719d62bbe562c36e7ecc3c3f4c2653b757e Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 10 Dec 2015 13:55:42 +0100 Subject: [PATCH 0695/1059] WIP: CLI doc --- docs/cli.md | 62 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) diff --git a/docs/cli.md b/docs/cli.md index 3f213d436..2b42a372e 100644 --- a/docs/cli.md +++ b/docs/cli.md @@ -1 +1,63 @@ # CLI + +When you install phy, a command-line tool named `phy` is installed: + +```bash +$ phy +Usage: phy [OPTIONS] COMMAND [ARGS]... + + By default, the `phy` command does nothing. Add subcommands with plugins + using `attach_to_cli()` and the `click` library. + +Options: + --version Show the version and exit. + -h, --help Show this message and exit. +``` + +This command doesn't do anything by default, but it serves as an entry-point for your applications. + +## Adding a subcommand + +A subcommand is called with `phy subcommand ...` from the command-line. To create a subcommand, create a new plugin, and implement the `attach_to_cli(cli)` method. This uses the [click](http://click.pocoo.org/5/) library. + +Here is an example. Create a file in `~/.phy/plugins/hello.py` and write the following: + +``` +from phy import IPlugin +import click + + +class MyPlugin(IPlugin): + def attach_to_cli(self, cli): + @cli.command('hello') + @click.argument('name') + def hello(name): + print("Hello %s!" % name) +``` + +Then, type the following in a system shell: + +```bash +$ phy +Usage: phy [OPTIONS] COMMAND [ARGS]... + + By default, the `phy` command does nothing. Add subcommands with plugins + using `attach_to_cli()` and the `click` library. + +Options: + --version Show the version and exit. + -h, --help Show this message and exit. + +Commands: + hello + +$ phy hello +Usage: phy hello [OPTIONS] NAME + +Error: Missing argument "name". + +$ phy hello world +Hello world! +``` + +When the `phy` CLI is created, the `attach_to_cli(cli)` method of all discovered plugins are called. Refer to the click documentation to create subcommands with phy. From 1e5c06e4bd6125d1ecc00a9c767d27e90dd97ad1 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 10 Dec 2015 14:16:09 +0100 Subject: [PATCH 0696/1059] Add GUI CLI/plugin doc --- docs/cli.md | 92 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 92 insertions(+) diff --git a/docs/cli.md b/docs/cli.md index 2b42a372e..a4e8cdebc 100644 --- a/docs/cli.md +++ b/docs/cli.md @@ -61,3 +61,95 @@ Hello world! ``` When the `phy` CLI is created, the `attach_to_cli(cli)` method of all discovered plugins are called. Refer to the click documentation to create subcommands with phy. + +## Creating a graphical application + +You can use this system to create a graphical application that is launched with `phy some_subcommand`. Moreover, your graphical application can itself accept user-defined plugins. + +Here is a complete example. Write the following in `~/.phy/plugins/mygui.py`: + +``` +import click +from phy import IPlugin +from phy.gui import GUI, HTMLWidget, create_app, run_app, load_gui_plugins +from phy.utils import Bunch + + +class MyGUI(GUI): + def __init__(self, name, plugins=None): + super(MyGUI, self).__init__() + + # We create a widget. + view = HTMLWidget() + view.set_body("Hello %s!" % name) + view.show() + self.add_view(view) + + # We load all plugins attached to that GUI. + session = Bunch(name=name) + load_gui_plugins(self, plugins, session) + + +class MyGUIPlugin(IPlugin): + def attach_to_cli(self, cli): + + @cli.command(name='mygui') + @click.argument('name') + def mygui(name): + + # Create the Qt application. + create_app() + + # Show the GUI. + gui = MyGUI(name) + gui.show() + + # Start the Qt event loop. + run_app() + + # Close the GUI. + gui.close() + del gui +``` + +Now, you can call `phy mygui world` to open a GUI showing `Hello world!`. + +## GUI plugins + +Your users can now create plugins for your graphical application, by creating a plugin with the `attach_to_gui(gui, session)` method. In this method, you can add actions, add views, and do anything provided by the GUI API. + +The `session` object is any Python object passed to the plugins by the GUI. Generally, it is a `Bunch` instance (just a Python dictionary with the additional `bunch.name` syntax) containing any data that you want to pass to the plugins. + +Here is a complete example. There are three steps. + +### Creating the plugin + +First, create a file in `~/.phy/plugins/mygui_plugin.py` with the following: + +``` +from phy import IPlugin +from phy.gui import Actions + + +class MyGUIPlugin(IPlugin): + def attach_to_gui(self, gui, session): + actions = Actions(gui) + + @actions.add(shortcut='a') + def myaction(): + print("Hello %s!" % session.name) +``` + +### Activating the plugin + +Next, add the following line in `~/.phy/phy_config.py`: + +``` +c.MyGUI.plugins = ['MyGUIPlugin'] +``` + +This is the list of the plugin names to activate automatically when creating a `MyGUI` instance. When you create a GUI from Python, you can also pass the list of plugins to activate as follows: `gui = MyGUI(name, plugins=[...])`. + +### Testing the plugin + +Finally, launch the GUI with `phy mygui world` and press `a` in the GUI. It should print `Hello world!` in the console. From 9e7540cee1d540d249605fb005355802af61e397 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 10 Dec 2015 14:21:42 +0100 Subject: [PATCH 0697/1059] WIP: manual clustering doc --- docs/api.md | 3 +++ docs/cluster-manual.md | 17 +++++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/docs/api.md b/docs/api.md index e69de29bb..0de7ede1d 100644 --- a/docs/api.md +++ b/docs/api.md @@ -0,0 +1,3 @@ +# API reference + +TODO. In the meantime, see the code directly on GitHub. diff --git a/docs/cluster-manual.md b/docs/cluster-manual.md index e69de29bb..87e0938cd 100644 --- a/docs/cluster-manual.md +++ b/docs/cluster-manual.md @@ -0,0 +1,17 @@ +# Manual clustering + +## History + +## Clustering + +## Cluster metadata + +## Cluster view + +## Waveform view + +## Feature view + +## Trace view + +## Manual clustering GUI component From cfd25dd4c98ce8a2da54cf61aa1044a19c287cf5 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 10 Dec 2015 17:35:24 +0100 Subject: [PATCH 0698/1059] WIP: manual clustering doc --- docs/cluster-manual.md | 134 ++++++++++++++++++++++++++++++++++- phy/cluster/manual/_utils.py | 7 ++ 2 files changed, 140 insertions(+), 1 deletion(-) diff --git a/docs/cluster-manual.md b/docs/cluster-manual.md index 87e0938cd..a448789a0 100644 --- a/docs/cluster-manual.md +++ b/docs/cluster-manual.md @@ -1,11 +1,143 @@ # Manual clustering -## History +The `phy.cluster.manual` package provides manual clustering routines. The components can be used independently in a modular way. ## Clustering +The `Clustering` class implements the logic of assigning clusters to spikes as a succession of undoable merge and split operations. Also, it provides efficient methods to retrieve the set of spikes belonging to one or several clusters. + +Create an instance with `clustering = Clustering(spike_clusters)` where `spike_clusters` is an `n_spikes`-long array containing the cluster number of every spike. + +Notable properties are: + +* `clustering.spikes_per_cluster`: a dictionary `{cluster_id: spike_ids}`. +* `clustering.cluster_ids`: array of all non-empty clusters +* `clustering.cluster_counts`: dictionary with the number of spikes in each cluster + +Notable methods are: + +* `clustering.new_cluster_id()`: generate a new unique cluster id +* `clustering.spikes_in_clusters(cluster_ids)`: return the array of spike ids belonging to a set of clusters. +* `clustering.merge(cluster_ids)`: merge some clusters. +* `clustering.split(spike_ids)`: create a new cluster from a set of spikes. **Note**: this will change the cluster ids of all affected clusters. For example, splitting a single spike belonging to cluster 10 containing 100 spikes leads to the deletion of cluster 10, and the creation of clusters N (99 spikes) and N+1 (1 spike). This is to ensure that cluster ids are always unique. +* `clustering.undo()`: undo the last operation. +* `clustering.redo()`: redo the last operation. + +### UpdateInfo + +Every clustering action returns an `UpdateInfo` object and emits a `cluster` event. + +An `UpdateInfo` object is a `Bunch` instance (dictionary with dotted attribute access) with several keys, including: + +* `description`: can be `merge` or `assign` +* `history`: can be `None` (default), `'undo'`, or `'redo'` +* `added`: list of new clusters +* `deleted`: list of removed clusters + +A `Clustering` object emits the `cluster` event after every clustering action (including undo and redo). To register a callback function to this event, use the `connect()` method. + +Here is a complete example: + +```python +>>> import numpy as np +>>> from phy.cluster.manual import Clustering +``` + +```python +>>> clustering = Clustering(np.arange(5)) +``` + +```python +>>> @clustering.connect +... def on_cluster(up): +... print("A %s just occurred." % up.description) +``` + +```python +>>> up = clustering.merge([0, 1, 2]) +A merge just occurred. +``` + +```python +>>> clustering.cluster_ids +array([3, 4, 5]) +``` + +```python +>>> clustering.cluster_counts +{3: 1, 4: 1, 5: 3} +``` + +```python +>>> for key, val in up.items(): +... print(key, "=", val) +undo_state = None +description = merge +metadata_value = None +spike_ids = [0 1 2] +old_spikes_per_cluster = {0: array([0]), 1: array([1]), 2: array([2])} +descendants = [(0, 5), (1, 5), (2, 5)] +added = [5] +new_spikes_per_cluster = {5: array([0, 1, 2])} +history = None +metadata_changed = [] +deleted = [0, 1, 2] +``` + + + ## Cluster metadata +The `ClusterMeta` class implement the logic of assigning metadata to every cluster (for example, a cluster group) as a succession of undoable operations. + +Here is an example. + +```python +>>> from phy.cluster.manual import ClusterMeta +>>> cm = ClusterMeta() +``` + +```python +>>> cm.add_field('group', default_value='unsorted') +``` + +```python +>>> cm.get('group', 3) +'unsorted' +``` + +```python +>>> cm.set('group', 3, 'good') + good> +``` + +```python +>>> cm.set('group', 3, 'bad') + bad> +``` + +```python +>>> cm.get('group', 3) +'bad' +``` + +```python +>>> cm.undo() + bad> +``` + +```python +>>> cm.get('group', 3) +'good' +``` + +You can import and export data from a dictionary using the `to_dict()` and `from_dict()` methods. + +```python +>>> cm.to_dict('group') +{3: 'good'} +``` + ## Cluster view ## Waveform view diff --git a/phy/cluster/manual/_utils.py b/phy/cluster/manual/_utils.py index 98ec04a1d..fd6107f8c 100644 --- a/phy/cluster/manual/_utils.py +++ b/phy/cluster/manual/_utils.py @@ -112,9 +112,11 @@ def _reset_data(self): @property def fields(self): + """List of fields.""" return sorted(self._fields.keys()) def add_field(self, name, default_value=None): + """Add a field with an optional default value.""" self._fields[name] = default_value def func(cluster): @@ -123,6 +125,7 @@ def func(cluster): setattr(self, name, func) def from_dict(self, dic): + """Import data from a {cluster: {field: value}} dictionary.""" self._reset_data() for cluster, vals in dic.items(): for field, value in vals.items(): @@ -130,11 +133,14 @@ def from_dict(self, dic): self._data_base = deepcopy(self._data) def to_dict(self, field): + """Export data to a {cluster: value} dictionary, for a particular + field.""" assert field in self._fields, "This field doesn't exist" return {cluster: self.get(field, cluster) for cluster in self._data.keys()} def set(self, field, clusters, value, add_to_stack=True): + """Set the value of one of several clusters.""" assert field in self._fields clusters = _as_list(clusters) @@ -156,6 +162,7 @@ def set(self, field, clusters, value, add_to_stack=True): return up def get(self, field, cluster): + """Retrieve the value of one cluster.""" if _is_list(cluster): return [self.get(field, c) for c in cluster] assert field in self._fields From fd66471da68427af372368b01465c86f6770840f Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 10 Dec 2015 17:49:00 +0100 Subject: [PATCH 0699/1059] WIP: manual clustering doc --- docs/cluster-manual.md | 33 +++++++++++++++++++++++++++------ 1 file changed, 27 insertions(+), 6 deletions(-) diff --git a/docs/cluster-manual.md b/docs/cluster-manual.md index a448789a0..825ec6af4 100644 --- a/docs/cluster-manual.md +++ b/docs/cluster-manual.md @@ -84,8 +84,6 @@ metadata_changed = [] deleted = [0, 1, 2] ``` - - ## Cluster metadata The `ClusterMeta` class implement the logic of assigning metadata to every cluster (for example, a cluster group) as a succession of undoable operations. @@ -138,12 +136,35 @@ You can import and export data from a dictionary using the `to_dict()` and `from {3: 'good'} ``` -## Cluster view +## Views + +There are several views typically associated with manual clustering operations. + +### Waveform view + +The waveform view displays action potentials across all channels, following the probe geometry. + +### Feature view -## Waveform view +The feature view shows the principal components of spikes across multiple dimensions. -## Feature view +### Trace view -## Trace view +### Correlogram view + +The correlogram view computes and shows all pairwise correlograms of a set of clusters. ## Manual clustering GUI component + +The `ManualClustering` component encapsulates all the logic for a manual clustering GUI: + +* cluster views +* selection of clusters +* navigation with a wizard +* clustering actions: merge, split, undo stack +* moving clusters to groups + +Create an object with `mc = ManualClustering(spike_clusters)`. Then you can attach it to a GUI to bring manual clustering facilities to the GUI: `mc.attach(gui)`. This adds the manual clustering actions and the two cluster views. + +### Cluster view + From 6722ba58e5743db5f60d5ec09743727feda3640a Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 10 Dec 2015 19:25:42 +0100 Subject: [PATCH 0700/1059] WIP: manual clustering doc --- docs/cluster-manual.md | 42 +++++++++++++++++++++++++++-- phy/cluster/manual/gui_component.py | 2 ++ 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/docs/cluster-manual.md b/docs/cluster-manual.md index 825ec6af4..e526ea379 100644 --- a/docs/cluster-manual.md +++ b/docs/cluster-manual.md @@ -164,7 +164,45 @@ The `ManualClustering` component encapsulates all the logic for a manual cluster * clustering actions: merge, split, undo stack * moving clusters to groups -Create an object with `mc = ManualClustering(spike_clusters)`. Then you can attach it to a GUI to bring manual clustering facilities to the GUI: `mc.attach(gui)`. This adds the manual clustering actions and the two cluster views. +Create an object with `mc = ManualClustering(spike_clusters)`. Then you can attach it to a GUI to bring manual clustering facilities to the GUI: `mc.attach(gui)`. This adds the manual clustering actions and the two tables to the GUI: the cluster view and the similarity view. -### Cluster view +The main objects are the following: + +`mc.clustering`: a `Clustering` instance +`mc.cluster_meta`: a `ClusterMeta` instance +`mc.cluster_view`: the cluster view (derives from `Table`) +`mc.similarity_view`: the similarity view (derives from `Table`) +`mc.actions`: the clustering actions (instance of `Actions`) + +In practice, you generally access this object from a GUI plugin, available in `session.manual_clustering`. + +### Cluster and similarity view + +The cluster view shows the list of all clusters with their ids, while the similarity view shows the list of all clusters sorted by decreasing similarity wrt the currently-selected clusters in the cluster view. + +You can add a new column in both views as follows: + +```python +>>> @mc.add_column +... def n_spikes(cluster_id): +... return mc.clustering.cluster_counts[cluster_id] +``` + +The similarity view has an additional column compared to the cluster view: `similarity` with respect to the currently-selected clusters in the cluster view. + +See also the following methods: + +* `mc.set_default_sort(name)`: set a column as default sort in the quality cluster view +* `mc.set_similarity_func(func)`: set a similarity function for the similarity view + +### Cluster selection + +The `ManualClustering` instance is responsible for the selection of the clusters. + +* `mc.select(cluster_ids)`: select some clusters +* `mc.selected`: list of currently-selected clusters + +When the selection changes, the attached GUI raises the `select(cluster_ids, spike_ids)` event. + +Other events are `cluster(up)` when a clustering action occurs, and `request_save(spike_clusters, cluster_groups)` when the user wants to save the results of the manual clustering session. diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 15b34aaf9..938813776 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -315,6 +315,8 @@ def on_request_undo_state(up): self.clustering.connect(on_request_undo_state) self.cluster_meta.connect(on_request_undo_state) + self._update_cluster_view() + def _update_cluster_view(self): """Initialize the cluster view with cluster data.""" self.cluster_view.set_rows(self.clustering.cluster_ids) From db24a85409c7034a2e48ed31f75ae440eb17f361 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 10 Dec 2015 19:28:59 +0100 Subject: [PATCH 0701/1059] Clearer color of good clusters --- phy/cluster/manual/gui_component.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 938813776..165e0d7cf 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -112,7 +112,7 @@ def __init__(self): super(ClusterView, self).__init__() self.add_styles(''' table tr[data-good='true'] { - color: #B4DEA6; + color: #86D16D; } ''') From 5fe50c16168a950face254c76510678ad9451fcd Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 12 Dec 2015 08:37:08 +0100 Subject: [PATCH 0702/1059] Update CONTRIBUTE.md --- CONTRIBUTE.md | 28 +++++++++++++--------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/CONTRIBUTE.md b/CONTRIBUTE.md index d717d3c79..c4446d498 100644 --- a/CONTRIBUTE.md +++ b/CONTRIBUTE.md @@ -7,37 +7,35 @@ Please read this entire document before contributing any code. On your development computer: * Use the very latest Anaconda release (`conda update conda anaconda`). -* Use a special `phy` conda environment based on the latest Python 3.x (3.4 at the time of writing). +* Use a special `phy` conda environment based on Python 3.5. * Have another `phy2` clone environment based on Python 2.7. -* phy only supports Python 2.7, and Python 3.4+. -* Use `six` for writing compatible code (see [the documentation here](http://pythonhosted.org/six/)) -* You need the following dependencies for development (not required for using phy): pytest, pip, flake8, coverage, coveralls. -* For IPython, use the IPython git `master` branch (or version 3.0 when it will be released in early 2015). +* Install the dev dependencies: `pip install -r requirements-dev.txt` +* phy only supports Python 2.7 and Python 3.4+. +* Use the `six` library for writing compatible code (see [the documentation here](http://pythonhosted.org/six/)) A few rules: * Every module `phy/mypackage/mymodule.py` must come with a `phy/mypackage/tests/test_mymodule.py` test module that contains a bunch of `test_*()` functions. * Never import test modules in the main code. -* Do not import packages from `phy/__init__.py`. Every subpackage `phy.stuff` will need to be imported explicitly by the user. Dependencies required by this subpackage will only be loaded when the subpackage is loaded. This ensures that users can use `phy.subpackageA` without having to install the dependencies required by `phy.subpackageB`. -* phy's required dependencies are: numpy. Every subpackage can come with further dependencies. For example, `phy.io.kwik` depends on h5py. +* In general, do not import packages from `phy/__init__.py`. Every subpackage `phy.stuff` will need to be imported explicitly by the user. Dependencies required by this subpackage will only be loaded when the subpackage is loaded. This ensures that users can use `phy.subpackageA` without having to install the dependencies required by `phy.subpackageB`. +* phy's required dependencies are: pip, traitlets, click, numpy. Every subpackage may come with further dependencies. * You can experiment with ideas and prototypes in the `kwikteam/experimental` repo. Use a different folder for every experiment. -* `kwikteam/phy` will only contain a `master` branch and release tags. There should be no experimental/debugging code in the entire repository. +* `kwikteam/phy` will only contain a `master` branch, release tags, and possibly one refactoring branch. There should be no experimental/debugging code in the entire repository. ### GitHub flow -* Work through PRs from `yourfork/specialbranch` against `phy/master` exclusively. +* Work through PRs from `yourfork/specialbranch` against `phy/master` or `phy/kill-kwik` exclusively. * Set `upstream` to `kwikteam/phy` and `origin` to your fork. * When master and your PR's branch are out of sync, [rebase your branch in your fork](https://groups.google.com/forum/#!msg/vispy-dev/q-UNjxburGA/wYNkZRXiySwJ). * Two-pairs-of-eyes rule: every line of code needs to be reviewed by 2 people, including the author. * Never merge your own PR to the main phy repository, unless in exceptional circumstances. * A PR is assumed to be **not ready for merge** unless explicitly stated otherwise. * Always run `make test` before stating that a PR is ready to merge (and ideally before pushing on your PR's branch). -* We try to have a code coverage close to 100%: always test all features you implement, and verify through code coverage that all lines are covered by your tests. +* We try to have a code coverage close to 100%: always test all features you implement, and verify through code coverage that all lines are covered by your tests. Use `#pragma: no cover` comments for lines that don't absolutely need to be covered (for example, rare exception-raising code). * Always wait for Travis to be green before merging. * `phy/master` should always be stable and deployable. * Use [semantic versioning](http://www.semver.org) for stable releases. -* Do not make too many releases until the software is mature enough. Early adopters can work directly off `master`. * We follow almost all [PEP8 rules](https://www.python.org/dev/peps/pep-0008/), except [for a few exceptions](https://github.com/kwikteam/phy/blob/master/Makefile#L24). @@ -55,10 +53,10 @@ A few rules: Make sure your text editor is configured to: -* automatically clean blank lines (i.e. no lines containing only whitespace) -* use four spaces per indent level, and **never** tab indents -* enforce an empty blank line at the end of every text file -* display a vertical ruler at 79 characters (length limit of every line) +* Automatically clean blank lines (i.e. no lines containing only whitespace) +* Use four spaces per indent level, and **never** tab indents +* Enforce an empty blank line at the end of every text file +* Display a vertical ruler at 79 characters (length limit of every line) Below is a settings script for the popular text editor [Sublime](http://www.sublimetext.com) which you can put into your ```Preferences -> Settings (User)```: From e67d3798194d2d67bcf24b45e6f903feef0cda76 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 12 Dec 2015 08:56:50 +0100 Subject: [PATCH 0703/1059] Add environment.yml for conda dependencies --- .travis.yml | 9 +++------ environment.yml | 3 +-- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/.travis.yml b/.travis.yml index 6a39c9c27..fbe3bec3e 100644 --- a/.travis.yml +++ b/.travis.yml @@ -19,12 +19,9 @@ install: - conda update -q conda - conda info -a # Create the environment. - - conda create -q -n testenv python=$TRAVIS_PYTHON_VERSION - - source activate testenv - - conda install pip numpy=1.9 vispy matplotlib scipy h5py pyqt ipython requests six dill ipyparallel joblib dask click - # NOTE: cython is only temporarily needed for building KK2. - # Once we have KK2 binary builds on binstar we can remove this dependency. - - conda install cython && pip install klustakwik2 + - conda env create + - source activate phy + - conda install python=$TRAVIS_PYTHON_VERSION # Dev requirements - pip install -r requirements-dev.txt - pip install -e . diff --git a/environment.yml b/environment.yml index 0b61d5aae..6948e83eb 100644 --- a/environment.yml +++ b/environment.yml @@ -3,7 +3,7 @@ channels: - kwikteam dependencies: - python - - numpy=1.9 + - numpy - vispy - matplotlib - scipy @@ -11,7 +11,6 @@ dependencies: - pyqt - ipython - requests - - traitlets - six - ipyparallel - joblib From f7238aa79097e6033adba48dcba64d6260d28489 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 12 Dec 2015 09:01:04 +0100 Subject: [PATCH 0704/1059] WIP: update Travis --- .travis.yml | 3 +-- environment.yml | 1 + 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.travis.yml b/.travis.yml index fbe3bec3e..4844c62a8 100644 --- a/.travis.yml +++ b/.travis.yml @@ -19,9 +19,8 @@ install: - conda update -q conda - conda info -a # Create the environment. - - conda env create + - conda env create python=$TRAVIS_PYTHON_VERSION - source activate phy - - conda install python=$TRAVIS_PYTHON_VERSION # Dev requirements - pip install -r requirements-dev.txt - pip install -e . diff --git a/environment.yml b/environment.yml index 6948e83eb..07b23df0e 100644 --- a/environment.yml +++ b/environment.yml @@ -11,6 +11,7 @@ dependencies: - pyqt - ipython - requests + - traitlets - six - ipyparallel - joblib From f27847391e01de08bd4a7e65370c4b1ab0469d33 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 12 Dec 2015 09:06:07 +0100 Subject: [PATCH 0705/1059] WIP: fix Travis --- environment.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/environment.yml b/environment.yml index 07b23df0e..0b61d5aae 100644 --- a/environment.yml +++ b/environment.yml @@ -3,7 +3,7 @@ channels: - kwikteam dependencies: - python - - numpy + - numpy=1.9 - vispy - matplotlib - scipy From 1c9445c9f23b64dcfc348a4d1fd60ced540027a0 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 12 Dec 2015 09:15:38 +0100 Subject: [PATCH 0706/1059] Rename cluster_counts to spike_counts --- docs/cluster-manual.md | 6 +++--- phy/cluster/manual/clustering.py | 2 +- phy/cluster/manual/tests/test_clustering.py | 14 +++++++------- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/docs/cluster-manual.md b/docs/cluster-manual.md index e526ea379..1cac448af 100644 --- a/docs/cluster-manual.md +++ b/docs/cluster-manual.md @@ -12,7 +12,7 @@ Notable properties are: * `clustering.spikes_per_cluster`: a dictionary `{cluster_id: spike_ids}`. * `clustering.cluster_ids`: array of all non-empty clusters -* `clustering.cluster_counts`: dictionary with the number of spikes in each cluster +* `clustering.spike_counts`: dictionary with the number of spikes in each cluster Notable methods are: @@ -64,7 +64,7 @@ array([3, 4, 5]) ``` ```python ->>> clustering.cluster_counts +>>> clustering.spike_counts {3: 1, 4: 1, 5: 3} ``` @@ -185,7 +185,7 @@ You can add a new column in both views as follows: ```python >>> @mc.add_column ... def n_spikes(cluster_id): -... return mc.clustering.cluster_counts[cluster_id] +... return mc.clustering.spike_counts[cluster_id] ``` The similarity view has an additional column compared to the cluster view: `similarity` with respect to the currently-selected clusters in the cluster view. diff --git a/phy/cluster/manual/clustering.py b/phy/cluster/manual/clustering.py index 267bcc4a9..d5e01efc4 100644 --- a/phy/cluster/manual/clustering.py +++ b/phy/cluster/manual/clustering.py @@ -195,7 +195,7 @@ def cluster_ids(self): return np.array(sorted(self._spikes_per_cluster)) @property - def cluster_counts(self): + def spike_counts(self): """Dictionary with the number of spikes in each cluster.""" return {cluster: len(self._spikes_per_cluster[cluster]) for cluster in self.cluster_ids} diff --git a/phy/cluster/manual/tests/test_clustering.py b/phy/cluster/manual/tests/test_clustering.py index c81c88607..792fdecbe 100644 --- a/phy/cluster/manual/tests/test_clustering.py +++ b/phy/cluster/manual/tests/test_clustering.py @@ -456,8 +456,8 @@ def test_clustering_long(): assert clustering.new_cluster_id() == n_clusters assert clustering.n_clusters == n_clusters - assert len(clustering.cluster_counts) == n_clusters - assert sum(itervalues(clustering.cluster_counts)) == n_spikes + assert len(clustering.spike_counts) == n_clusters + assert sum(itervalues(clustering.spike_counts)) == n_spikes _check_spikes_per_cluster(clustering) # Updating a cluster, method 1. @@ -479,18 +479,18 @@ def test_clustering_long(): new_cluster = 101 clustering.assign(np.arange(0, 10), new_cluster) assert new_cluster in clustering.cluster_ids - assert clustering.cluster_counts[new_cluster] == 10 + assert clustering.spike_counts[new_cluster] == 10 assert np.all(clustering.spike_clusters[:10] == new_cluster) _check_spikes_per_cluster(clustering) # Merge. - count = clustering.cluster_counts.copy() + count = clustering.spike_counts.copy() my_spikes_0 = np.nonzero(np.in1d(clustering.spike_clusters, [2, 3]))[0] info = clustering.merge([2, 3]) my_spikes = info.spike_ids ae(my_spikes, my_spikes_0) assert (new_cluster + 1) in clustering.cluster_ids - assert clustering.cluster_counts[new_cluster + 1] == count[2] + count[3] + assert clustering.spike_counts[new_cluster + 1] == count[2] + count[3] assert np.all(clustering.spike_clusters[my_spikes] == (new_cluster + 1)) _check_spikes_per_cluster(clustering) @@ -498,13 +498,13 @@ def test_clustering_long(): clustering.spike_clusters[:] = spike_clusters_base[:] clustering._update_all_spikes_per_cluster() my_spikes_0 = np.nonzero(np.in1d(clustering.spike_clusters, [4, 6]))[0] - count = clustering.cluster_counts + count = clustering.spike_counts count4, count6 = count[4], count[6] info = clustering.merge([4, 6], 11) my_spikes = info.spike_ids ae(my_spikes, my_spikes_0) assert 11 in clustering.cluster_ids - assert clustering.cluster_counts[11] == count4 + count6 + assert clustering.spike_counts[11] == count4 + count6 assert np.all(clustering.spike_clusters[my_spikes] == 11) _check_spikes_per_cluster(clustering) From e55b94fca07891fc29fa64fcd3829443f9b0e3a4 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 12 Dec 2015 10:15:42 +0100 Subject: [PATCH 0707/1059] WIP: refactor navigation in HTML table --- phy/gui/static/table.js | 51 ++++++++++++++++++++++++++--------- phy/gui/tests/test_widgets.py | 2 +- 2 files changed, 40 insertions(+), 13 deletions(-) diff --git a/phy/gui/static/table.js b/phy/gui/static/table.js index 6f0b9e6c3..faaae6069 100644 --- a/phy/gui/static/table.js +++ b/phy/gui/static/table.js @@ -159,25 +159,52 @@ Table.prototype.clear = function() { this.selected = []; }; +Table.prototype.firstRow = function() { + return this.el.rows[1]; +}; + +Table.prototype.nextRow = function(id) { + // TODO: what to do when doing next() while several items are selected. + var i0 = 1; + if (id !== undefined) { + i0 = this.rows[id].rowIndex; + } + var that = this; + return { + i: i0, + increment: function () { + if (this.i < this.n - 1) { + this.i++; + return true; + } + return false; + }, + n: that.el.rows.length, + row: function () { return that.el.rows[this.i]; }, + next: function () { + this.increment(); + return this.row(); + } + }; +}; + Table.prototype.next = function() { // TODO: what to do when doing next() while several items are selected. var id = this.selected[0]; - if (id === undefined) { - var row = null; - var i0 = 1; // 1, not 0, because we skip the header. + if (id == undefined) { + var row = this.firstRow(); } else { - var row = this.rows[id]; - var i0 = row.rowIndex + 1; - } - for (var i = i0; i < this.el.rows.length; i++) { - row = this.el.rows[i]; - if (row.dataset.skip != 'true') { - this.select([row.dataset.id]); - row.scrollIntoView(false); - return; + // Select the next non-skip. + var iterator = this.nextRow(id); + var row = iterator.next(); + while (row.dataset.skip == 'true') { + row = iterator.next(); } } + this.select([row.dataset.id]); + row.scrollIntoView(false); + return; }; Table.prototype.previous = function() { diff --git a/phy/gui/tests/test_widgets.py b/phy/gui/tests/test_widgets.py index caeb8c2b5..4b6940f16 100644 --- a/phy/gui/tests/test_widgets.py +++ b/phy/gui/tests/test_widgets.py @@ -131,7 +131,7 @@ def test_table_nav_first(qtbot, table): assert table.selected == [0] -def test_table_nav(qtbot, table): +def test_table_nav_0(qtbot, table): table.select([4]) table.next() From 8f064e1184724ea48525452e2e91866a8ceb0001 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 12 Dec 2015 10:23:14 +0100 Subject: [PATCH 0708/1059] Refactor previous --- phy/gui/static/table.js | 51 ++++++++++++++++++----------------- phy/gui/tests/test_widgets.py | 5 ++++ 2 files changed, 31 insertions(+), 25 deletions(-) diff --git a/phy/gui/static/table.js b/phy/gui/static/table.js index faaae6069..df5572a5c 100644 --- a/phy/gui/static/table.js +++ b/phy/gui/static/table.js @@ -163,7 +163,11 @@ Table.prototype.firstRow = function() { return this.el.rows[1]; }; -Table.prototype.nextRow = function(id) { +Table.prototype.lastRow = function() { + return this.el.rows[this.el.rows.length - 1]; +}; + +Table.prototype.rowIterator = function(id) { // TODO: what to do when doing next() while several items are selected. var i0 = 1; if (id !== undefined) { @@ -172,17 +176,18 @@ Table.prototype.nextRow = function(id) { var that = this; return { i: i0, - increment: function () { - if (this.i < this.n - 1) { - this.i++; - return true; - } - return false; - }, n: that.el.rows.length, row: function () { return that.el.rows[this.i]; }, + previous: function () { + if (this.i > 1) { + this.i--; + } + return this.row(); + }, next: function () { - this.increment(); + if (this.i < this.n - 1) { + this.i++; + } return this.row(); } }; @@ -196,7 +201,7 @@ Table.prototype.next = function() { } else { // Select the next non-skip. - var iterator = this.nextRow(id); + var iterator = this.rowIterator(id); var row = iterator.next(); while (row.dataset.skip == 'true') { row = iterator.next(); @@ -208,24 +213,20 @@ Table.prototype.next = function() { }; Table.prototype.previous = function() { - // TODO: what to do when doing next() while several items are selected. + // TODO: what to do when doing previous() while several items are selected. var id = this.selected[0]; - if (id === undefined) { - var row = null; - var i0 = this.rows.length - 1; + if (id == undefined) { + var row = this.lastRow(); } else { - var row = this.rows[id]; - var i0 = row.rowIndex - 1; - } - - // NOTE: i >= 1 because we skip the header column. - for (var i = i0; i >= 1; i--) { - row = this.el.rows[i]; - if (row.dataset.skip != 'true') { - this.select([row.dataset.id]); - // row.scrollIntoView(false); - return; + // Select the previous non-skip. + var iterator = this.rowIterator(id); + var row = iterator.previous(); + while (row.dataset.skip == 'true') { + row = iterator.previous(); } } + this.select([row.dataset.id]); + row.scrollIntoView(false); + return; }; diff --git a/phy/gui/tests/test_widgets.py b/phy/gui/tests/test_widgets.py index 4b6940f16..a9c1ff22b 100644 --- a/phy/gui/tests/test_widgets.py +++ b/phy/gui/tests/test_widgets.py @@ -131,6 +131,11 @@ def test_table_nav_first(qtbot, table): assert table.selected == [0] +def test_table_nav_last(qtbot, table): + table.previous() + assert table.selected == [9] + + def test_table_nav_0(qtbot, table): table.select([4]) From 7502e0e5759b3202979a635596ff5fbce43f7697 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 12 Dec 2015 10:43:29 +0100 Subject: [PATCH 0709/1059] Implement shift selection in table --- phy/gui/static/table.css | 7 +++++++ phy/gui/static/table.js | 12 ++++++++++++ 2 files changed, 19 insertions(+) diff --git a/phy/gui/static/table.css b/phy/gui/static/table.css index 001a20e88..7e15b1c4d 100644 --- a/phy/gui/static/table.css +++ b/phy/gui/static/table.css @@ -57,3 +57,10 @@ table tr.pinned { table tr[data-skip='true'] { color: #888; } + +table tr, table td { + user-select: none; /* CSS3 (little to no support) */ + -ms-user-select: none; /* IE 10+ */ + -moz-user-select: none; /* Gecko (Firefox) */ + -webkit-user-select: none; /* Webkit (Safari, Chrome) */ +} diff --git a/phy/gui/static/table.js b/phy/gui/static/table.js index df5572a5c..8d99e5d2c 100644 --- a/phy/gui/static/table.js +++ b/phy/gui/static/table.js @@ -92,6 +92,18 @@ Table.prototype.setData = function(data) { that.select(that.selected.concat([id])); } } + else if (evt.shiftKey) { + var clicked_idx = that.rows[id].rowIndex; + var sel_idx = that.rows[that.selected[0]].rowIndex; + if (sel_idx == undefined) return; + var i0 = Math.min(clicked_idx, sel_idx); + var i1 = Math.max(clicked_idx, sel_idx); + var sel = []; + for (var i = i0; i <= i1; i++) { + sel.push(that.el.rows[i].dataset.id); + } + that.select(sel); + } // Otherwise, select just that item. else { that.select([id]); From f2b3355006cf95f6ff091976922ae1f1a3d14197 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 12 Dec 2015 13:57:37 +0100 Subject: [PATCH 0710/1059] WIP: bug fixes --- phy/gui/static/table.js | 60 +++++++++++++++++------------------ phy/gui/tests/test_widgets.py | 2 +- phy/gui/widgets.py | 2 +- 3 files changed, 32 insertions(+), 32 deletions(-) diff --git a/phy/gui/static/table.js b/phy/gui/static/table.js index 8d99e5d2c..dbf5e5fcf 100644 --- a/phy/gui/static/table.js +++ b/phy/gui/static/table.js @@ -54,6 +54,7 @@ Table.prototype.setData = function(data) { this.headers[key] = th; } thead.appendChild(tr); + this.nrows = data.items.length; // Data rows. for (var i = 0; i < data.items.length; i++) { @@ -121,6 +122,14 @@ Table.prototype.setData = function(data) { this.tablesort = new Tablesort(this.el); }; +Table.prototype.rowId = function(i) { + return this.el.rows[i].dataset.id; +}; + +Table.prototype.isRowSkipped = function(i) { + return this.el.rows[i].dataset.skip == 'true'; +}; + Table.prototype.sortBy = function(header, dir) { dir = typeof dir !== 'undefined' ? dir : 'asc'; if (this.headers[header] == undefined) @@ -179,9 +188,10 @@ Table.prototype.lastRow = function() { return this.el.rows[this.el.rows.length - 1]; }; -Table.prototype.rowIterator = function(id) { +Table.prototype.rowIterator = function(id, doSkip) { + doSkip = typeof doSkip !== 'undefined' ? doSkip : true; // TODO: what to do when doing next() while several items are selected. - var i0 = 1; + var i0 = undefined; if (id !== undefined) { i0 = this.rows[id].rowIndex; } @@ -191,16 +201,24 @@ Table.prototype.rowIterator = function(id) { n: that.el.rows.length, row: function () { return that.el.rows[this.i]; }, previous: function () { - if (this.i > 1) { - this.i--; + if (this.i == undefined) this.i = 1; + for (var i = this.i - 1; i >= 1; i--) { + if (!doSkip || !that.isRowSkipped(i)) { + this.i = i; + return this.row(); + } } - return this.row(); + return that.firstRow(); }, next: function () { - if (this.i < this.n - 1) { - this.i++; + if (this.i == undefined) this.i = this.n - 1; + for (var i = this.i + 1; i < this.n; i++) { + if (!doSkip || !that.isRowSkipped(i)) { + this.i = i; + return this.row(); + } } - return this.row(); + return that.firstRow(); } }; }; @@ -208,17 +226,8 @@ Table.prototype.rowIterator = function(id) { Table.prototype.next = function() { // TODO: what to do when doing next() while several items are selected. var id = this.selected[0]; - if (id == undefined) { - var row = this.firstRow(); - } - else { - // Select the next non-skip. - var iterator = this.rowIterator(id); - var row = iterator.next(); - while (row.dataset.skip == 'true') { - row = iterator.next(); - } - } + var iterator = this.rowIterator(id); + var row = iterator.next(); this.select([row.dataset.id]); row.scrollIntoView(false); return; @@ -227,17 +236,8 @@ Table.prototype.next = function() { Table.prototype.previous = function() { // TODO: what to do when doing previous() while several items are selected. var id = this.selected[0]; - if (id == undefined) { - var row = this.lastRow(); - } - else { - // Select the previous non-skip. - var iterator = this.rowIterator(id); - var row = iterator.previous(); - while (row.dataset.skip == 'true') { - row = iterator.previous(); - } - } + var iterator = this.rowIterator(id); + var row = iterator.previous(); this.select([row.dataset.id]); row.scrollIntoView(false); return; diff --git a/phy/gui/tests/test_widgets.py b/phy/gui/tests/test_widgets.py index a9c1ff22b..d0d6cfd85 100644 --- a/phy/gui/tests/test_widgets.py +++ b/phy/gui/tests/test_widgets.py @@ -133,7 +133,7 @@ def test_table_nav_first(qtbot, table): def test_table_nav_last(qtbot, table): table.previous() - assert table.selected == [9] + assert table.selected == [0] def test_table_nav_0(qtbot, table): diff --git a/phy/gui/widgets.py b/phy/gui/widgets.py index dba4e52f9..2191d6e76 100644 --- a/phy/gui/widgets.py +++ b/phy/gui/widgets.py @@ -59,7 +59,7 @@ class WebPage(QWebPage): def javaScriptConsoleMessage(self, msg, line, source): - logger.debug(msg) # pragma: no cover + logger.debug("Error line %d: %s", line, msg) # pragma: no cover class HTMLWidget(QWebView): From b182defc5290307c40aee8752eae7a35e2357d44 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 12 Dec 2015 14:23:48 +0100 Subject: [PATCH 0711/1059] Fix bugs --- phy/gui/static/table.js | 8 ++++---- phy/gui/tests/test_widgets.py | 16 +++++++++++++++- phy/gui/widgets.py | 2 +- 3 files changed, 20 insertions(+), 6 deletions(-) diff --git a/phy/gui/static/table.js b/phy/gui/static/table.js index dbf5e5fcf..8d6767baa 100644 --- a/phy/gui/static/table.js +++ b/phy/gui/static/table.js @@ -201,24 +201,24 @@ Table.prototype.rowIterator = function(id, doSkip) { n: that.el.rows.length, row: function () { return that.el.rows[this.i]; }, previous: function () { - if (this.i == undefined) this.i = 1; + if (this.i == undefined) this.i = this.n; for (var i = this.i - 1; i >= 1; i--) { if (!doSkip || !that.isRowSkipped(i)) { this.i = i; return this.row(); } } - return that.firstRow(); + return this.row(); }, next: function () { - if (this.i == undefined) this.i = this.n - 1; + if (this.i == undefined) this.i = 0; for (var i = this.i + 1; i < this.n; i++) { if (!doSkip || !that.isRowSkipped(i)) { this.i = i; return this.row(); } } - return that.firstRow(); + return this.row(); } }; }; diff --git a/phy/gui/tests/test_widgets.py b/phy/gui/tests/test_widgets.py index d0d6cfd85..fc3601241 100644 --- a/phy/gui/tests/test_widgets.py +++ b/phy/gui/tests/test_widgets.py @@ -133,7 +133,21 @@ def test_table_nav_first(qtbot, table): def test_table_nav_last(qtbot, table): table.previous() - assert table.selected == [0] + assert table.selected == [9] + + +def test_table_nav_edge_0(qtbot, table): + # The first item is skipped. + table.set_rows([4, 5]) + table.next() + assert table.selected == [5] + + +def test_table_nav_edge_1(qtbot, table): + # The last item is skipped. + table.set_rows([3, 4]) + table.previous() + assert table.selected == [3] def test_table_nav_0(qtbot, table): diff --git a/phy/gui/widgets.py b/phy/gui/widgets.py index 2191d6e76..15bb16097 100644 --- a/phy/gui/widgets.py +++ b/phy/gui/widgets.py @@ -59,7 +59,7 @@ class WebPage(QWebPage): def javaScriptConsoleMessage(self, msg, line, source): - logger.debug("Error line %d: %s", line, msg) # pragma: no cover + logger.debug("[%d] %s", line, msg) # pragma: no cover class HTMLWidget(QWebView): From ae40170f2bfe53985411d76834f6f6e2d1f86b69 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 12 Dec 2015 14:49:31 +0100 Subject: [PATCH 0712/1059] Minor updates --- phy/gui/static/table.js | 8 ++++++++ phy/gui/widgets.py | 2 +- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/phy/gui/static/table.js b/phy/gui/static/table.js index 8d6767baa..87c1e32bc 100644 --- a/phy/gui/static/table.js +++ b/phy/gui/static/table.js @@ -134,6 +134,14 @@ Table.prototype.sortBy = function(header, dir) { dir = typeof dir !== 'undefined' ? dir : 'asc'; if (this.headers[header] == undefined) throw "The column `" + header + "` doesn't exist." + + // Remove all sort classes. + for (var i = 0; i < this.headers.length; i++) { + this.headers[i].classList.remove('sort-up'); + this.headers[i].classList.remove('sort-down'); + } + + // Add sort. this.tablesort.sortTable(this.headers[header]); if (dir == 'desc') { this.tablesort.sortTable(this.headers[header]); diff --git a/phy/gui/widgets.py b/phy/gui/widgets.py index 15bb16097..947d83a8f 100644 --- a/phy/gui/widgets.py +++ b/phy/gui/widgets.py @@ -33,7 +33,7 @@ background-color: black; color: white; font-family: sans-serif; - font-size: 14pt; + font-size: 12pt; margin: 5px 10px; } """ From c0f0e8057be7fa2739c0d89a94897474d9224c74 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 12 Dec 2015 19:07:47 +0100 Subject: [PATCH 0713/1059] Add transforms argument in BaseCanvas.add_visual() --- phy/plot/base.py | 12 +++++++++--- phy/plot/tests/test_base.py | 4 ++++ phy/plot/transform.py | 2 ++ 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/phy/plot/base.py b/phy/plot/base.py index 696a39e80..9ce87f905 100644 --- a/phy/plot/base.py +++ b/phy/plot/base.py @@ -239,7 +239,7 @@ def __init__(self, *args, **kwargs): # Enable transparency. _enable_depth_mask() - def add_visual(self, visual): + def add_visual(self, visual, transforms=None): """Add a visual to the canvas, and build its program by the same occasion. @@ -251,8 +251,14 @@ def add_visual(self, visual): inserter = visual.inserter # Add the visual's transforms. inserter.add_transform_chain(visual.transforms) - # Then, add the canvas' transforms. - inserter.add_transform_chain(self.transforms) + # Then, add the canvas' transforms... + if transforms is None: + inserter.add_transform_chain(self.transforms) + # or user-specified transforms. + else: + tc = TransformChain() + tc.add_on_gpu(transforms) + inserter.add_transform_chain(tc) # Also, add the canvas' inserter. inserter += self.inserter # Now, we insert the transforms GLSL into the shaders. diff --git a/phy/plot/tests/test_base.py b/phy/plot/tests/test_base.py index ffa1fa1d3..71066057b 100644 --- a/phy/plot/tests/test_base.py +++ b/phy/plot/tests/test_base.py @@ -132,6 +132,10 @@ def set_data(self): canvas.add_visual(v) v.set_data() + v = TestVisual() + canvas.add_visual(v, transforms=[Subplot((10, 10), (0, 0))]) + v.set_data() + canvas.show() qtbot.waitForWindowShown(canvas.native) # qtbot.stop() diff --git a/phy/plot/transform.py b/phy/plot/transform.py index e236da310..0e8213690 100644 --- a/phy/plot/transform.py +++ b/phy/plot/transform.py @@ -224,12 +224,14 @@ def add_on_cpu(self, transforms): if not isinstance(transforms, list): transforms = [transforms] self.cpu_transforms.extend(transforms or []) + return self def add_on_gpu(self, transforms): """Add some transforms.""" if not isinstance(transforms, list): transforms = [transforms] self.gpu_transforms.extend(transforms or []) + return self def get(self, class_name): """Get a transform in the chain from its name.""" From 9351e3572de4eba6235c4376694f6568eb9c6998 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 14 Dec 2015 12:23:37 +0100 Subject: [PATCH 0714/1059] WIP: line visual --- phy/plot/tests/test_visuals.py | 15 +++++- phy/plot/visuals.py | 98 ++++++++++++++++++++++++++++++++++ 2 files changed, 112 insertions(+), 1 deletion(-) diff --git a/phy/plot/tests/test_visuals.py b/phy/plot/tests/test_visuals.py index 9510897d1..b446e1405 100644 --- a/phy/plot/tests/test_visuals.py +++ b/phy/plot/tests/test_visuals.py @@ -10,7 +10,7 @@ import numpy as np from ..visuals import (ScatterVisual, PlotVisual, HistogramVisual, - BoxVisual, AxesVisual, + BoxVisual, AxesVisual, LineVisual, ) @@ -159,6 +159,19 @@ def test_histogram_2(qtbot, canvas_pz): hist=hist, color=c, ylim=2 * np.ones(n_hists)) +#------------------------------------------------------------------------------ +# Test line visual +#------------------------------------------------------------------------------ + +def test_line_empty(qtbot, canvas): + _test_visual(qtbot, canvas, LineVisual()) + + +def test_line_0(qtbot, canvas_pz): + _test_visual(qtbot, canvas_pz, LineVisual(), + color=(1., 0., 0., .5)) + + #------------------------------------------------------------------------------ # Test box visual #------------------------------------------------------------------------------ diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index 7d15ff4aa..4b8a0b38c 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -132,6 +132,24 @@ def _max(arr): return arr.max() if len(arr) else 1 +def _validate_line_coord(x, n, default): + assert n >= 0 + if x is None: + x = default + if not hasattr(x, '__len__'): + x = x * np.ones(n) + assert isinstance(x, np.ndarray) + assert x.shape == (n,) + return x.astype(np.float32) + + +def _get_length(*args): + for arg in args: + if hasattr(arg, '__len__'): + return len(arg) + return 1 + + class PlotVisual(BaseVisual): _default_color = DEFAULT_COLOR allow_list = ('x', 'y') @@ -321,6 +339,86 @@ def set_data(self): pass +class LineVisual(BaseVisual): + _default_color = (.35, .35, .35, 1.) + + def __init__(self): + super(LineVisual, self).__init__() + self.set_shader('simple') + self.set_primitive_type('lines') + + self.data_range = Range(NDC) + self.transforms.add_on_cpu(self.data_range) + + @staticmethod + def validate(x0=None, + y0=None, + x1=None, + y1=None, + color=None, + data_bounds=None, + ): + + # Get the number of lines. + n_lines = _get_length(x0, y0, x1, y1) + x0 = _validate_line_coord(x0, n_lines, -1) + y0 = _validate_line_coord(y0, n_lines, -1) + x1 = _validate_line_coord(x1, n_lines, +1) + y1 = _validate_line_coord(y1, n_lines, +1) + + assert x0.shape == y0.shape == x1.shape == y1.shape == (n_lines,) + + if data_bounds is None: + xmin = min(x0.min(), x1.min()) + ymin = min(y0.min(), y1.min()) + xmax = max(x0.max(), x1.max()) + ymax = max(y0.max(), y1.max()) + data_bounds = np.c_[xmin, ymin, xmax, ymax] + + color = _get_array(color, (4,), LineVisual._default_color) + # assert color.shape == (n_lines, 4) + assert len(color) == 4 + + data_bounds = _get_data_bounds(data_bounds, length=n_lines) + data_bounds = data_bounds.astype(np.float32) + assert data_bounds.shape == (n_lines, 4) + + return Bunch(x0=x0, + y0=y0, + x1=x1, + y1=y1, + color=color, + data_bounds=data_bounds, + ) + + @staticmethod + def vertex_count(x0=None, y0=None, x1=None, y1=None, **kwargs): + """Take the output of validate() as input.""" + return 2 * _get_length(x0, y0, x1, y1) + + def set_data(self, *args, **kwargs): + data = self.validate(*args, **kwargs) + pos = np.c_[data.x0, data.y0, data.x1, data.y1] + n_lines = pos.shape[0] + n_vertices = 2 * n_lines + pos = pos.reshape((-1, 2)) + + # Transform the positions. + data_bounds = np.repeat(data.data_bounds, n_vertices, axis=0) + self.data_range.from_bounds = data_bounds + pos_tr = self.transforms.apply(pos) + + # Position. + assert pos_tr.shape == (n_vertices, 2) + self.program['a_position'] = pos_tr + + # Color. + # a_color = np.repeat(data.color, 2, axis=0).astype(np.float32) + # assert a_color.shape == (n_vertices, 4) + # self.program['a_color'] = a_color + self.program['u_color'] = data.color + + class BoxVisual(BaseVisual): _default_color = (.35, .35, .35, 1.) From 25c6c793123847cef288d0f357dfad4097f8dd0b Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 14 Dec 2015 13:25:32 +0100 Subject: [PATCH 0715/1059] Remove box and axis visuals --- phy/plot/base.py | 4 +-- phy/plot/tests/test_visuals.py | 41 +++--------------------- phy/plot/utils.py | 7 ---- phy/plot/visuals.py | 58 +++------------------------------- 4 files changed, 11 insertions(+), 99 deletions(-) diff --git a/phy/plot/base.py b/phy/plot/base.py index 9ce87f905..c61352178 100644 --- a/phy/plot/base.py +++ b/phy/plot/base.py @@ -79,12 +79,12 @@ def on_draw(self): @staticmethod def validate(**kwargs): """Make consistent the input data for the visual.""" - return kwargs + return kwargs # pragma: no cover @staticmethod def vertex_count(**kwargs): """Return the number of vertices as a function of the input data.""" - return 0 + return 0 # pragma: no cover def set_data(self): """Set data to the program. diff --git a/phy/plot/tests/test_visuals.py b/phy/plot/tests/test_visuals.py index b446e1405..89c91ef50 100644 --- a/phy/plot/tests/test_visuals.py +++ b/phy/plot/tests/test_visuals.py @@ -10,7 +10,7 @@ import numpy as np from ..visuals import (ScatterVisual, PlotVisual, HistogramVisual, - BoxVisual, AxesVisual, LineVisual, + LineVisual, ) @@ -168,40 +168,7 @@ def test_line_empty(qtbot, canvas): def test_line_0(qtbot, canvas_pz): + y = np.linspace(-.5, .5, 10) _test_visual(qtbot, canvas_pz, LineVisual(), - color=(1., 0., 0., .5)) - - -#------------------------------------------------------------------------------ -# Test box visual -#------------------------------------------------------------------------------ - -def test_box_empty(qtbot, canvas): - _test_visual(qtbot, canvas, BoxVisual()) - - -def test_box_0(qtbot, canvas_pz): - _test_visual(qtbot, canvas_pz, BoxVisual(), - bounds=(-.5, -.5, 0., 0.), - color=(1., 0., 0., .5)) - - -#------------------------------------------------------------------------------ -# Test axes visual -#------------------------------------------------------------------------------ - -def test_axes_empty(qtbot, canvas): - _test_visual(qtbot, canvas, AxesVisual()) - - -def test_axes_0(qtbot, canvas_pz): - _test_visual(qtbot, canvas_pz, AxesVisual(), - xs=[0]) - - -def test_axes_1(qtbot, canvas_pz): - _test_visual(qtbot, canvas_pz, AxesVisual(), - xs=[-.25, -.1], - ys=[-.15], - bounds=(-.5, -.5, 0., 0.), - color=(0., 1., 0., .5)) + y0=y, y1=y, data_bounds=[-1, -1, 1, 1], + color=(1., 1., 0., .5)) diff --git a/phy/plot/utils.py b/phy/plot/utils.py index 5ac9bf169..d34c0d055 100644 --- a/phy/plot/utils.py +++ b/phy/plot/utils.py @@ -220,13 +220,6 @@ def _get_index(n_items, item_size, n): return index -def _get_color(color, default): - if color is None: - color = default - assert len(color) == 4 - return color - - def _get_linear_x(n_signals, n_samples): return np.tile(np.linspace(-1., 1., n_samples), (n_signals, 1)) diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index 4b8a0b38c..dca49dbe4 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -18,7 +18,6 @@ _get_data_bounds, _get_pos, _get_index, - _get_color, ) from phy.utils import Bunch @@ -376,7 +375,6 @@ def validate(x0=None, data_bounds = np.c_[xmin, ymin, xmax, ymax] color = _get_array(color, (4,), LineVisual._default_color) - # assert color.shape == (n_lines, 4) assert len(color) == 4 data_bounds = _get_data_bounds(data_bounds, length=n_lines) @@ -398,13 +396,16 @@ def vertex_count(x0=None, y0=None, x1=None, y1=None, **kwargs): def set_data(self, *args, **kwargs): data = self.validate(*args, **kwargs) - pos = np.c_[data.x0, data.y0, data.x1, data.y1] + pos = np.c_[data.x0, data.y0, data.x1, data.y1].astype(np.float32) + assert pos.ndim == 2 + assert pos.shape[1] == 4 + assert pos.dtype == np.float32 n_lines = pos.shape[0] n_vertices = 2 * n_lines pos = pos.reshape((-1, 2)) # Transform the positions. - data_bounds = np.repeat(data.data_bounds, n_vertices, axis=0) + data_bounds = np.repeat(data.data_bounds, 2, axis=0) self.data_range.from_bounds = data_bounds pos_tr = self.transforms.apply(pos) @@ -413,53 +414,4 @@ def set_data(self, *args, **kwargs): self.program['a_position'] = pos_tr # Color. - # a_color = np.repeat(data.color, 2, axis=0).astype(np.float32) - # assert a_color.shape == (n_vertices, 4) - # self.program['a_color'] = a_color self.program['u_color'] = data.color - - -class BoxVisual(BaseVisual): - _default_color = (.35, .35, .35, 1.) - - def __init__(self): - super(BoxVisual, self).__init__() - self.set_shader('simple') - self.set_primitive_type('lines') - - def set_data(self, bounds=NDC, color=None): - # Set the position. - x0, y0, x1, y1 = bounds - arr = np.array([[x0, y0], - [x0, y1], - [x0, y1], - [x1, y1], - [x1, y1], - [x1, y0], - [x1, y0], - [x0, y0]], dtype=np.float32) - self.program['a_position'] = self.transforms.apply(arr) - - # Set the color - self.program['u_color'] = _get_color(color, self._default_color) - - -class AxesVisual(BaseVisual): - _default_color = (.2, .2, .2, 1.) - - def __init__(self): - super(AxesVisual, self).__init__() - self.set_shader('simple') - self.set_primitive_type('lines') - - def set_data(self, xs=(), ys=(), bounds=NDC, color=None): - # Set the position. - arr = [[x, bounds[1], x, bounds[3]] for x in xs] - arr += [[bounds[0], y, bounds[2], y] for y in ys] - arr = np.hstack(arr or [[]]).astype(np.float32) - arr = arr.reshape((-1, 2)).astype(np.float32) - position = self.transforms.apply(arr) - self.program['a_position'] = position - - # Set the color - self.program['u_color'] = _get_color(color, self._default_color) From 83f612ec1296ed238d443cc5ad85d8d44c8c825d Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 14 Dec 2015 13:40:16 +0100 Subject: [PATCH 0716/1059] Fixes --- phy/plot/plot.py | 6 +++++- phy/plot/tests/test_plot.py | 9 +++++++++ phy/plot/tests/test_visuals.py | 3 +-- phy/plot/visuals.py | 36 +++++++++++++++++++++------------- 4 files changed, 37 insertions(+), 17 deletions(-) diff --git a/phy/plot/plot.py b/phy/plot/plot.py index 15eab0d84..e2443aa66 100644 --- a/phy/plot/plot.py +++ b/phy/plot/plot.py @@ -17,7 +17,7 @@ from .panzoom import PanZoom from .transform import NDC from .utils import _get_array -from .visuals import ScatterVisual, PlotVisual, HistogramVisual +from .visuals import ScatterVisual, PlotVisual, HistogramVisual, LineVisual #------------------------------------------------------------------------------ @@ -127,6 +127,10 @@ def hist(self, *args, **kwargs): """Add some histograms.""" return self._add_item(HistogramVisual, *args, **kwargs) + def lines(self, *args, **kwargs): + """Add some lines.""" + return self._add_item(LineVisual, *args, **kwargs) + def __getitem__(self, box_index): self._default_box_index = box_index return self diff --git a/phy/plot/tests/test_plot.py b/phy/plot/tests/test_plot.py index a99d92d37..125ae6fbe 100644 --- a/phy/plot/tests/test_plot.py +++ b/phy/plot/tests/test_plot.py @@ -109,6 +109,15 @@ def test_grid_hist(qtbot): _show(qtbot, view) +def test_grid_lines(qtbot): + view = GridView((1, 2)) + + view[0, 0].lines(y0=-.5, y1=-.5) + view[0, 1].lines(y0=+.5, y1=+.5) + + _show(qtbot, view) + + def test_grid_complete(qtbot): view = GridView((2, 2)) t = _get_linear_x(1, 1000).ravel() diff --git a/phy/plot/tests/test_visuals.py b/phy/plot/tests/test_visuals.py index 89c91ef50..6448b82a0 100644 --- a/phy/plot/tests/test_visuals.py +++ b/phy/plot/tests/test_visuals.py @@ -170,5 +170,4 @@ def test_line_empty(qtbot, canvas): def test_line_0(qtbot, canvas_pz): y = np.linspace(-.5, .5, 10) _test_visual(qtbot, canvas_pz, LineVisual(), - y0=y, y1=y, data_bounds=[-1, -1, 1, 1], - color=(1., 1., 0., .5)) + y0=y, y1=y, data_bounds=[-1, -1, 1, 1]) diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index dca49dbe4..5bbdfef89 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -138,7 +138,9 @@ def _validate_line_coord(x, n, default): if not hasattr(x, '__len__'): x = x * np.ones(n) assert isinstance(x, np.ndarray) - assert x.shape == (n,) + if x.ndim == 1: + x = x[:, None] + assert x.shape == (n, 1) return x.astype(np.float32) @@ -339,12 +341,18 @@ def set_data(self): class LineVisual(BaseVisual): - _default_color = (.35, .35, .35, 1.) + """Lines. - def __init__(self): + Note: currently, all lines shall have the same color. + + """ + _default_color = (1., 1., 1., 1.) + + def __init__(self, color=None): super(LineVisual, self).__init__() self.set_shader('simple') self.set_primitive_type('lines') + self.color = color or self._default_color self.data_range = Range(NDC) self.transforms.add_on_cpu(self.data_range) @@ -354,7 +362,7 @@ def validate(x0=None, y0=None, x1=None, y1=None, - color=None, + # color=None, data_bounds=None, ): @@ -365,17 +373,16 @@ def validate(x0=None, x1 = _validate_line_coord(x1, n_lines, +1) y1 = _validate_line_coord(y1, n_lines, +1) - assert x0.shape == y0.shape == x1.shape == y1.shape == (n_lines,) + assert x0.shape == y0.shape == x1.shape == y1.shape == (n_lines, 1) + # By default, we assume that the coordinates are in NDC. if data_bounds is None: - xmin = min(x0.min(), x1.min()) - ymin = min(y0.min(), y1.min()) - xmax = max(x0.max(), x1.max()) - ymax = max(y0.max(), y1.max()) - data_bounds = np.c_[xmin, ymin, xmax, ymax] + data_bounds = NDC - color = _get_array(color, (4,), LineVisual._default_color) - assert len(color) == 4 + # NOTE: currently, we don't support custom colors. We could do it + # by replacing the uniform by an attribute in the shaders. + # color = _get_array(color, (4,), LineVisual._default_color) + # assert len(color) == 4 data_bounds = _get_data_bounds(data_bounds, length=n_lines) data_bounds = data_bounds.astype(np.float32) @@ -385,7 +392,7 @@ def validate(x0=None, y0=y0, x1=x1, y1=y1, - color=color, + # color=color, data_bounds=data_bounds, ) @@ -414,4 +421,5 @@ def set_data(self, *args, **kwargs): self.program['a_position'] = pos_tr # Color. - self.program['u_color'] = data.color + # self.program['u_color'] = data.color + self.program['u_color'] = self.color From 57c611a1494f9bc9feee5e1ce254548f5cbfe6b0 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 14 Dec 2015 14:12:28 +0100 Subject: [PATCH 0717/1059] WIP: add boxes in grid --- phy/plot/interact.py | 5 +++-- phy/plot/plot.py | 13 +++++++++++++ phy/plot/visuals.py | 5 +++-- 3 files changed, 19 insertions(+), 4 deletions(-) diff --git a/phy/plot/interact.py b/phy/plot/interact.py index 6b0d948ab..2aad29fbd 100644 --- a/phy/plot/interact.py +++ b/phy/plot/interact.py @@ -47,9 +47,10 @@ def __init__(self, shape=(1, 1), shape_var='u_grid_shape', box_var=None): def attach(self, canvas): super(Grid, self).attach(canvas) - m = 1. - .05 # Margin. + m = 1. - .05 + m2 = 1. - .075 canvas.transforms.add_on_gpu([Scale('u_grid_zoom'), - Scale((m, m)), + Scale((m2, m2)), Clip([-m, -m, m, m]), Subplot(self.shape_var, self.box_var), ]) diff --git a/phy/plot/plot.py b/phy/plot/plot.py index e2443aa66..94d411f01 100644 --- a/phy/plot/plot.py +++ b/phy/plot/plot.py @@ -187,6 +187,19 @@ def __init__(self, shape=None, **kwargs): self.panzoom = PanZoom(aspect=None, constrain_bounds=NDC) self.panzoom.attach(self) + def build(self): + n, m = self.grid.shape + a = .01 # margin + for i in range(n): + for j in range(m): + self[i, j].lines(x0=[-1, +1, +1, -1], + y0=[-1, -1, +1, +1], + x1=[+1, +1, -1, -1], + y1=[-1, +1, +1, -1], + data_bounds=[-1 + a, -1 + a, 1 - a, 1 - a], + ) + super(GridView, self).build() + class BoxedView(BaseView): """Subplots at arbitrary positions""" diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index 5bbdfef89..213785c41 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -137,11 +137,12 @@ def _validate_line_coord(x, n, default): x = default if not hasattr(x, '__len__'): x = x * np.ones(n) + x = np.asarray(x, dtype=np.float32) assert isinstance(x, np.ndarray) if x.ndim == 1: x = x[:, None] assert x.shape == (n, 1) - return x.astype(np.float32) + return x def _get_length(*args): @@ -346,7 +347,7 @@ class LineVisual(BaseVisual): Note: currently, all lines shall have the same color. """ - _default_color = (1., 1., 1., 1.) + _default_color = (.3, .3, .3, 1.) def __init__(self, color=None): super(LineVisual, self).__init__() From d957d60f928e770ddc6419ab95d55111def00f9b Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 14 Dec 2015 14:47:37 +0100 Subject: [PATCH 0718/1059] Done boxes in grid view --- phy/plot/interact.py | 2 +- phy/plot/plot.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/phy/plot/interact.py b/phy/plot/interact.py index 2aad29fbd..c6ff61df3 100644 --- a/phy/plot/interact.py +++ b/phy/plot/interact.py @@ -47,7 +47,7 @@ def __init__(self, shape=(1, 1), shape_var='u_grid_shape', box_var=None): def attach(self, canvas): super(Grid, self).attach(canvas) - m = 1. - .05 + m = 1. - .025 m2 = 1. - .075 canvas.transforms.add_on_gpu([Scale('u_grid_zoom'), Scale((m2, m2)), diff --git a/phy/plot/plot.py b/phy/plot/plot.py index 94d411f01..3745235ca 100644 --- a/phy/plot/plot.py +++ b/phy/plot/plot.py @@ -189,7 +189,7 @@ def __init__(self, shape=None, **kwargs): def build(self): n, m = self.grid.shape - a = .01 # margin + a = .045 # margin for i in range(n): for j in range(m): self[i, j].lines(x0=[-1, +1, +1, -1], From 3e4e73e820eee9bcc6e723c4c1dc56b52e94caea Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 14 Dec 2015 15:19:20 +0100 Subject: [PATCH 0719/1059] WIP: increase coverage in manual clustering views --- .coveragerc | 1 + phy/cluster/manual/tests/test_views.py | 31 ++++++++++++++++++++------ phy/cluster/manual/views.py | 19 +++++++--------- phy/plot/plot.py | 2 +- 4 files changed, 34 insertions(+), 19 deletions(-) diff --git a/.coveragerc b/.coveragerc index c705eacb3..e7222f912 100644 --- a/.coveragerc +++ b/.coveragerc @@ -12,3 +12,4 @@ exclude_lines = raise AssertionError raise NotImplementedError pass + return diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index 92ccaf8c3..1ec0fc3a8 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -8,7 +8,7 @@ import numpy as np from numpy.testing import assert_equal as ae -from pytest import raises +from pytest import raises, yield_fixture from phy.io.mock import (artificial_waveforms, artificial_features, @@ -17,9 +17,11 @@ artificial_masks, artificial_traces, ) +from phy.gui import GUI from phy.electrode.mea import staggered_positions from ..views import (WaveformView, FeatureView, CorrelogramView, TraceView, - _extract_wave) + _extract_wave, _selected_clusters_colors, + ) #------------------------------------------------------------------------------ @@ -34,6 +36,15 @@ def _show(qtbot, view, stop=False): view.close() +@yield_fixture +def gui(qtbot): + gui = GUI(position=(200, 100), size=(800, 600)) + # gui.show() + # qtbot.waitForWindowShown(gui) + yield gui + gui.close() + + #------------------------------------------------------------------------------ # Test utils #------------------------------------------------------------------------------ @@ -62,11 +73,18 @@ def test_extract_wave(): [[16, 17, 18], [21, 22, 23], [0, 0, 0], [0, 0, 0]]) +def test_selected_clusters_colors(): + assert _selected_clusters_colors().shape[0] > 10 + assert _selected_clusters_colors(0).shape[0] == 0 + assert _selected_clusters_colors(1).shape[0] == 1 + assert _selected_clusters_colors(100).shape[0] == 100 + + #------------------------------------------------------------------------------ # Test waveform view #------------------------------------------------------------------------------ -def test_waveform_view(qtbot): +def test_waveform_view(qtbot, gui): n_spikes = 20 n_samples = 30 n_channels = 40 @@ -84,13 +102,13 @@ def test_waveform_view(qtbot): channel_positions=channel_positions, ) # Select some spikes. - spike_ids = np.arange(5) + spike_ids = np.arange(10) cluster_ids = np.unique(spike_clusters[spike_ids]) v.on_select(cluster_ids, spike_ids) # Show the view. - v.show() - qtbot.waitForWindowShown(v.native) + v.attach(gui) + gui.show() # Select other spikes. spike_ids = np.arange(2, 10) @@ -98,7 +116,6 @@ def test_waveform_view(qtbot): v.on_select(cluster_ids, spike_ids) # qtbot.stop() - v.close() #------------------------------------------------------------------------------ diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index c0a47537e..7eef1c97d 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -14,7 +14,6 @@ from six import string_types from phy.io.array import _index_of, _get_padded, get_excerpts -from phy.electrode.mea import linear_positions from phy.gui import Actions from phy.plot import (BoxedView, StackedView, GridView, _get_linear_x) @@ -64,7 +63,7 @@ def _extract_wave(traces, spk, mask, wave_len=None): channels = np.nonzero(mask > .1)[0] # There should be at least one non-masked channel. if not len(channels): - return + return # pragma: no cover i = spk - wave_len // 2 j = spk + wave_len // 2 a, b = max(0, i), min(j, n_samples - 1) @@ -147,7 +146,7 @@ def __init__(self, spike_clusters=None, channel_positions=None, shortcuts=None, - keys='interactive', + keys=None, ): """ @@ -164,8 +163,6 @@ def __init__(self, self._spike_ids = None # Initialize the view. - if channel_positions is None: - channel_positions = linear_positions(self.n_channels) box_bounds = _get_boxes(channel_positions) super(WaveformView, self).__init__(box_bounds, keys=keys) @@ -494,7 +491,7 @@ def __init__(self, masks=None, spike_times=None, spike_clusters=None, - keys='interactive', + keys=None, ): assert features.ndim == 3 @@ -605,7 +602,7 @@ def __init__(self, window_size=None, excerpt_size=None, n_excerpts=None, - keys='interactive', + keys=None, ): assert sample_rate > 0 @@ -648,10 +645,10 @@ def on_select(self, cluster_ids, spike_ids): sc = get_excerpts(sc, excerpt_size=self.excerpt_size, n_excerpts=self.n_excerpts) n_spikes_exerpts = len(st) - logger.debug("Computing correlograms for clusters %s (%d/%d spikes).", - ', '.join(map(str, cluster_ids)), - n_spikes_exerpts, n_spikes_total, - ) + logger.log(5, "Computing correlograms for clusters %s (%d/%d spikes).", + ', '.join(map(str, cluster_ids)), + n_spikes_exerpts, n_spikes_total, + ) # Compute all pairwise correlograms. ccg = correlograms(st, sc, diff --git a/phy/plot/plot.py b/phy/plot/plot.py index 3745235ca..3915f7003 100644 --- a/phy/plot/plot.py +++ b/phy/plot/plot.py @@ -87,7 +87,7 @@ class BaseView(BaseCanvas): def __init__(self, **kwargs): if not kwargs.get('keys', None): - kwargs['keys'] = 'interactive' + kwargs['keys'] = None super(BaseView, self).__init__(**kwargs) self.clear() From aeb44a5b3bf01fc8c5e2806dd3f720ec64dd4ad8 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 14 Dec 2015 15:25:29 +0100 Subject: [PATCH 0720/1059] Increase coverage --- phy/cluster/manual/tests/test_views.py | 13 ++++++---- phy/cluster/manual/views.py | 33 -------------------------- 2 files changed, 8 insertions(+), 38 deletions(-) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index 1ec0fc3a8..e134285ca 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -115,6 +115,8 @@ def test_waveform_view(qtbot, gui): cluster_ids = np.unique(spike_clusters[spike_ids]) v.on_select(cluster_ids, spike_ids) + v.toggle_waveform_overlap() + # qtbot.stop() @@ -139,19 +141,19 @@ def test_trace_view_spikes(qtbot): n_channels = 12 sample_rate = 2000. n_spikes = 20 - n_clusters = 3 + # n_clusters = 3 traces = artificial_traces(n_samples, n_channels) spike_times = artificial_spike_samples(n_spikes) / sample_rate # spike_times = [.1, .2] - spike_clusters = artificial_spike_clusters(n_spikes, n_clusters) + # spike_clusters = artificial_spike_clusters(n_spikes, n_clusters) masks = artificial_masks(n_spikes, n_channels) # Create the view. v = TraceView(traces=traces, sample_rate=sample_rate, spike_times=spike_times, - spike_clusters=spike_clusters, + # spike_clusters=spike_clusters, masks=masks, n_samples_per_spike=6, ) @@ -201,7 +203,7 @@ def test_feature_view(qtbot): # Test correlogram view #------------------------------------------------------------------------------ -def test_correlogram_view(qtbot): +def test_correlogram_view(qtbot, gui): n_spikes = 50 n_clusters = 5 sample_rate = 20000. @@ -235,5 +237,6 @@ def test_correlogram_view(qtbot): cluster_ids = np.unique(spike_clusters[spike_ids]) v.on_select(cluster_ids, spike_ids) + v.attach(gui) + gui.show() # qtbot.stop() - v.close() diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 7eef1c97d..6826029d8 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -11,7 +11,6 @@ import numpy as np from matplotlib.colors import hsv_to_rgb, rgb_to_hsv -from six import string_types from phy.io.array import _index_of, _get_padded, get_excerpts from phy.gui import Actions @@ -19,7 +18,6 @@ _get_linear_x) from phy.plot.utils import _get_boxes from phy.stats import correlograms -from phy.utils._types import _is_integer logger = logging.getLogger(__name__) @@ -385,24 +383,6 @@ def attach(self, gui): # Feature view # ----------------------------------------------------------------------------- -def _check_dimension(dim, n_channels, n_features): - """Check that a dimension is valid.""" - if _is_integer(dim): - dim = (dim, 0) - if isinstance(dim, tuple): - assert len(dim) == 2 - channel, feature = dim - assert _is_integer(channel) - assert _is_integer(feature) - assert 0 <= channel < n_channels - assert 0 <= feature < n_features - elif isinstance(dim, string_types): - assert dim == 'time' - elif dim: - raise ValueError('{0} should be (channel, feature) '.format(dim) + - 'or one of the extra features.') - - def _dimensions_matrix(x_channels, y_channels): """Dimensions matrix.""" # time, depth time, (x, 0) time, (y, 0) time, (z, 0) @@ -453,19 +433,6 @@ def _dimensions_for_clusters(cluster_ids, n_cols=None, return _dimensions_matrix(x_channels, y_channels) -def _smart_dim(dim, n_features=None, prev_dim=None, prev_dim_other=None): - channel, feature = dim - prev_channel, prev_feature = prev_dim - # Scroll the feature if the channel is the same. - if prev_channel == channel: - feature = (prev_feature + 1) % n_features - # Scroll the feature if it is the same than in the other axis. - if (prev_dim_other != 'time' and - prev_dim_other == (channel, feature)): - feature = (feature + 1) % n_features - dim = (channel, feature) - - def _project_mask_depth(dim, masks, spike_clusters_rel=None, n_clusters=None): """Return the mask and depth vectors for a given dimension.""" n_spikes = masks.shape[0] From 8246ea4eefbf575d5526772961ddf8b28e27db9d Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 14 Dec 2015 16:11:21 +0100 Subject: [PATCH 0721/1059] WIP: refactor trace view --- phy/cluster/manual/tests/test_views.py | 8 +- phy/cluster/manual/views.py | 103 ++++++++++++++++--------- 2 files changed, 72 insertions(+), 39 deletions(-) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index e134285ca..1caa9f1e3 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -141,21 +141,21 @@ def test_trace_view_spikes(qtbot): n_channels = 12 sample_rate = 2000. n_spikes = 20 - # n_clusters = 3 + n_clusters = 3 traces = artificial_traces(n_samples, n_channels) spike_times = artificial_spike_samples(n_spikes) / sample_rate - # spike_times = [.1, .2] - # spike_clusters = artificial_spike_clusters(n_spikes, n_clusters) + spike_clusters = artificial_spike_clusters(n_spikes, n_clusters) masks = artificial_masks(n_spikes, n_channels) # Create the view. v = TraceView(traces=traces, sample_rate=sample_rate, spike_times=spike_times, - # spike_clusters=spike_clusters, + spike_clusters=spike_clusters, masks=masks, n_samples_per_spike=6, + keys='interactive', ) _show(qtbot, v) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 6826029d8..5172a934c 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -259,11 +259,13 @@ def __init__(self, spike_clusters=None, masks=None, n_samples_per_spike=None, + keys=None, ): # Sample rate. assert sample_rate > 0 self.sample_rate = sample_rate + self.dt = 1. / self.sample_rate # Traces. assert traces.ndim == 2 @@ -294,7 +296,7 @@ def __init__(self, self.spike_times = self.spike_clusters = self.masks = None # Initialize the view. - super(TraceView, self).__init__(self.n_channels) + super(TraceView, self).__init__(self.n_channels, keys=keys) # TODO: choose the interval. self.set_interval((0., .25)) @@ -315,59 +317,90 @@ def _load_traces(self, interval): return traces def _load_spikes(self, interval): + """Return spike times, spike clusters, masks.""" assert self.spike_times is not None # Keep the spikes in the interval. a, b = self.spike_times.searchsorted(interval) return self.spike_times[a:b], self.spike_clusters[a:b], self.masks[a:b] - def set_interval(self, interval): + def _plot_traces(self, traces, start=None, data_bounds=None): + t = start + np.arange(traces.shape[0]) * self.dt + gray = .4 + for ch in range(self.n_channels): + self[ch].plot(t, traces[:, ch], + color=(gray, gray, gray, 1), + data_bounds=data_bounds) + + def _plot_spike(self, spike_idx, start=None, cluster_ids=None, + traces=None, spike_times=None, spike_clusters=None, + masks=None, data_bounds=None): + + wave_len = self.n_samples_per_spike + dur_spike = wave_len * self.dt + trace_start = int(self.sample_rate * start) + + # Find the first x of the spike, relative to the start of + # the interval + sample_rel = (int(spike_times[spike_idx] * self.sample_rate) - + trace_start) + + # Determine the color as a function of the spike's cluster. + clu = spike_clusters[spike_idx] + if cluster_ids is None or clu not in cluster_ids: + gray = .9 + color = (gray, gray, gray, 1) + else: + clu_rel = cluster_ids.index(clu) + color = _COLORMAP[clu_rel % len(_COLORMAP)] + + # Extract the waveform from the traces. + w, ch = _extract_wave(traces, sample_rel, + self.masks[spike_idx], wave_len) + + # Generate the x coordinates of the waveform. + t0 = spike_times[spike_idx] - dur_spike / 2. + t = t0 + self.dt * np.arange(wave_len) + t = np.tile(t, (len(ch), 1)) + + # The box index depends on the channel. + box_index = np.repeat(ch[:, np.newaxis], wave_len, axis=0) + self.plot(t, w.T, color=color, box_index=box_index, + data_bounds=data_bounds) + + def set_interval(self, interval, cluster_ids=None): + """Display the traces and spikes in a given interval.""" self.clear() start, end = interval - color = (.5, .5, .5, 1) - - dt = 1. / self.sample_rate # Load traces. traces = self._load_traces(interval) - n_samples = traces.shape[0] assert traces.shape[1] == self.n_channels + # Determine the data bounds. m, M = traces.min(), traces.max() data_bounds = np.array([start, m, end, M]) - # Generate the trace plots. - # TODO OPTIM: avoid the loop and generate all channel traces in + # Plot the traces. + # OPTIM: avoid the loop and generate all channel traces in # one pass with NumPy (but need to set a_box_index manually too). - # t = _get_linear_x(1, traces.shape[0]) - t = start + np.arange(n_samples) * dt - for ch in range(self.n_channels): - self[ch].plot(t, traces[:, ch], color=color, - data_bounds=data_bounds) + self._plot_traces(traces, start=start, data_bounds=data_bounds) # Display the spikes. if self.spike_times is not None: - wave_len = self.n_samples_per_spike + # Load the spikes. spike_times, spike_clusters, masks = self._load_spikes(interval) - n_spikes = len(spike_times) - dt = 1. / float(self.sample_rate) - dur_spike = wave_len * dt - trace_start = int(self.sample_rate * start) - - for i in range(n_spikes): - sample_rel = (int(spike_times[i] * self.sample_rate) - - trace_start) - mask = self.masks[i] - # TODO: color of spike = white or color if selected cluster - # clu = spike_clusters[i] - w, ch = _extract_wave(traces, sample_rel, mask, wave_len) - n_ch = len(ch) - t0 = spike_times[i] - dur_spike / 2. - color = np.array([1, 0, 0, 1]) - box_index = np.repeat(ch[:, np.newaxis], wave_len, axis=0) - t = t0 + dt * np.arange(wave_len) - t = np.tile(t, (n_ch, 1)) - self.plot(t, w.T, color=color, box_index=box_index, - data_bounds=data_bounds) + + # Plot every spike. + for i in range(len(spike_times)): + self._plot_spike(i, + start=start, + cluster_ids=cluster_ids, + traces=traces, + spike_times=spike_times, + spike_clusters=spike_clusters, + masks=masks, + data_bounds=data_bounds, + ) self.build() self.update() @@ -518,7 +551,7 @@ def on_select(self, cluster_ids, spike_ids): best_channels_func=None) # Plot all features. - # TODO: optim: avoid the loop. + # OPTIM: avoid the loop. with self.building(): for i in range(self.n_cols): for j in range(self.n_cols): From 83d805faff7715d0727faa42fe6eea4ea7352cf5 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 14 Dec 2015 16:37:16 +0100 Subject: [PATCH 0722/1059] WIP: trace view --- phy/cluster/manual/tests/test_views.py | 19 ++++++++++--- phy/cluster/manual/views.py | 37 +++++++++++++++++++------- 2 files changed, 44 insertions(+), 12 deletions(-) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index 1caa9f1e3..771d214b0 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -136,7 +136,7 @@ def test_trace_view_no_spikes(qtbot): _show(qtbot, v) -def test_trace_view_spikes(qtbot): +def test_trace_view_spikes(qtbot, gui): n_samples = 1000 n_channels = 12 sample_rate = 2000. @@ -155,9 +155,22 @@ def test_trace_view_spikes(qtbot): spike_clusters=spike_clusters, masks=masks, n_samples_per_spike=6, - keys='interactive', ) - _show(qtbot, v) + + # Select some spikes. + spike_ids = np.arange(10) + cluster_ids = np.unique(spike_clusters[spike_ids]) + v.on_select(cluster_ids, spike_ids) + + # Show the view. + v.attach(gui) + gui.show() + + # Select other spikes. + spike_ids = np.arange(2, 10) + cluster_ids = np.unique(spike_clusters[spike_ids]) + v.on_select(cluster_ids, spike_ids) + # qtbot.stop() #------------------------------------------------------------------------------ diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 5172a934c..981317f40 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -284,8 +284,8 @@ def __init__(self, self.spike_times = spike_times # Spike clusters. - if spike_clusters is None: - spike_clusters = np.zeros(self.n_spikes) + spike_clusters = (np.zeros(self.n_spikes) if spike_clusters is None + else spike_clusters) assert spike_clusters.shape == (self.n_spikes,) self.spike_clusters = spike_clusters @@ -344,6 +344,10 @@ def _plot_spike(self, spike_idx, start=None, cluster_ids=None, sample_rel = (int(spike_times[spike_idx] * self.sample_rate) - trace_start) + # Extract the waveform from the traces. + w, ch = _extract_wave(traces, sample_rel, + masks[spike_idx], wave_len) + # Determine the color as a function of the spike's cluster. clu = spike_clusters[spike_idx] if cluster_ids is None or clu not in cluster_ids: @@ -351,11 +355,12 @@ def _plot_spike(self, spike_idx, start=None, cluster_ids=None, color = (gray, gray, gray, 1) else: clu_rel = cluster_ids.index(clu) - color = _COLORMAP[clu_rel % len(_COLORMAP)] - - # Extract the waveform from the traces. - w, ch = _extract_wave(traces, sample_rel, - self.masks[spike_idx], wave_len) + r, g, b = (_COLORMAP[clu_rel % len(_COLORMAP)] / 255.) + color = (r, g, b, 1.) + sc = clu_rel * np.ones(len(ch), dtype=np.int32) + color = _get_color(masks[spike_idx, ch], + spike_clusters_rel=sc, + n_clusters=len(cluster_ids)) # Generate the x coordinates of the waveform. t0 = spike_times[spike_idx] - dur_spike / 2. @@ -371,6 +376,7 @@ def set_interval(self, interval, cluster_ids=None): """Display the traces and spikes in a given interval.""" self.clear() start, end = interval + cluster_ids = list(cluster_ids) if cluster_ids is not None else None # Load traces. traces = self._load_traces(interval) @@ -406,10 +412,23 @@ def set_interval(self, interval, cluster_ids=None): self.update() def on_select(self, cluster_ids, spike_ids): - pass + # TODO: choose the interval. + self.set_interval((0., .25), cluster_ids=cluster_ids) def attach(self, gui): - pass + """Attach the view to the GUI.""" + + # Disable keyboard pan so that we can use arrows as global shortcuts + # in the GUI. + self.panzoom.enable_keyboard_pan = False + + gui.add_view(self) + + gui.connect_(self.on_select) + # gui.connect_(self.on_cluster) + + # self.actions = Actions(gui, default_shortcuts=self.shortcuts) + # self.actions.add(self.toggle_waveform_overlap) # ----------------------------------------------------------------------------- From a5c1e97d16f948d2a70fc96a5079438b0ceb8d8e Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 14 Dec 2015 17:34:38 +0100 Subject: [PATCH 0723/1059] WIP: trace view actions --- phy/cluster/manual/tests/test_views.py | 7 +- phy/cluster/manual/views.py | 90 ++++++++++++++++++++------ phy/plot/interact.py | 12 ++-- 3 files changed, 83 insertions(+), 26 deletions(-) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index 771d214b0..78d711795 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -140,7 +140,7 @@ def test_trace_view_spikes(qtbot, gui): n_samples = 1000 n_channels = 12 sample_rate = 2000. - n_spikes = 20 + n_spikes = 50 n_clusters = 3 traces = artificial_traces(n_samples, n_channels) @@ -170,6 +170,11 @@ def test_trace_view_spikes(qtbot, gui): spike_ids = np.arange(2, 10) cluster_ids = np.unique(spike_clusters[spike_ids]) v.on_select(cluster_ids, spike_ids) + + v.go_to(.5) + v.go_to(-.5) + v.go_left() + v.go_right() # qtbot.stop() diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 981317f40..d6a43c5cd 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -129,7 +129,6 @@ def _get_color(masks, spike_clusters_rel=None, n_clusters=None): # ----------------------------------------------------------------------------- class WaveformView(BoxedView): - # TODO: make this configurable normalization_percentile = .95 normalization_n_spikes = 1000 overlap = True @@ -252,6 +251,13 @@ def toggle_waveform_overlap(self): # ----------------------------------------------------------------------------- class TraceView(StackedView): + interval_duration = .25 # default duration of the interval + shift_amount = .1 + default_shortcuts = { + 'go_left': 'ctrl+left', + 'go_right': 'ctrl+right', + } + def __init__(self, traces=None, sample_rate=None, @@ -259,9 +265,14 @@ def __init__(self, spike_clusters=None, masks=None, n_samples_per_spike=None, + shortcuts=None, keys=None, ): + # Load default shortcuts, and override with any user shortcuts. + self.shortcuts = self.default_shortcuts.copy() + self.shortcuts.update(shortcuts or {}) + # Sample rate. assert sample_rate > 0 self.sample_rate = sample_rate @@ -271,10 +282,11 @@ def __init__(self, assert traces.ndim == 2 self.n_samples, self.n_channels = traces.shape self.traces = traces + self.duration = self.dt * self.n_samples # Number of samples per spike. self.n_samples_per_spike = (n_samples_per_spike or - int(.002 * sample_rate)) + round(.002 * sample_rate)) # Spike times. if spike_times is not None: @@ -298,15 +310,19 @@ def __init__(self, # Initialize the view. super(TraceView, self).__init__(self.n_channels, keys=keys) - # TODO: choose the interval. - self.set_interval((0., .25)) + # Initial interval. + self.cluster_ids = [] + self.set_interval((0., self.interval_duration)) + + # We use ctrl+left|right to navigate in the view. + self.enable_box_width_shortcuts = False def _load_traces(self, interval): """Load traces in an interval (in seconds).""" start, end = interval - i, j = int(self.sample_rate * start), int(self.sample_rate * end) + i, j = round(self.sample_rate * start), round(self.sample_rate * end) traces = self.traces[i:j, :] # Detrend the traces. @@ -331,17 +347,17 @@ def _plot_traces(self, traces, start=None, data_bounds=None): color=(gray, gray, gray, 1), data_bounds=data_bounds) - def _plot_spike(self, spike_idx, start=None, cluster_ids=None, + def _plot_spike(self, spike_idx, start=None, traces=None, spike_times=None, spike_clusters=None, masks=None, data_bounds=None): wave_len = self.n_samples_per_spike dur_spike = wave_len * self.dt - trace_start = int(self.sample_rate * start) + trace_start = round(self.sample_rate * start) # Find the first x of the spike, relative to the start of # the interval - sample_rel = (int(spike_times[spike_idx] * self.sample_rate) - + sample_rel = (round(spike_times[spike_idx] * self.sample_rate) - trace_start) # Extract the waveform from the traces. @@ -350,17 +366,17 @@ def _plot_spike(self, spike_idx, start=None, cluster_ids=None, # Determine the color as a function of the spike's cluster. clu = spike_clusters[spike_idx] - if cluster_ids is None or clu not in cluster_ids: + if self.cluster_ids is None or clu not in self.cluster_ids: gray = .9 color = (gray, gray, gray, 1) else: - clu_rel = cluster_ids.index(clu) + clu_rel = self.cluster_ids.index(clu) r, g, b = (_COLORMAP[clu_rel % len(_COLORMAP)] / 255.) color = (r, g, b, 1.) sc = clu_rel * np.ones(len(ch), dtype=np.int32) color = _get_color(masks[spike_idx, ch], spike_clusters_rel=sc, - n_clusters=len(cluster_ids)) + n_clusters=len(self.cluster_ids)) # Generate the x coordinates of the waveform. t0 = spike_times[spike_idx] - dur_spike / 2. @@ -372,11 +388,26 @@ def _plot_spike(self, spike_idx, start=None, cluster_ids=None, self.plot(t, w.T, color=color, box_index=box_index, data_bounds=data_bounds) - def set_interval(self, interval, cluster_ids=None): + def _restrict_interval(self, interval): + start, end = interval + # Restrict the interval to the boundaries of the traces. + if start < 0: + end += (-start) + start = 0 + elif end >= self.duration: + start -= (end - self.duration) + end = self.duration + start = np.clip(start, 0, end) + end = np.clip(end, start, self.duration) + assert 0 <= start < end <= self.duration + return start, end + + def set_interval(self, interval): """Display the traces and spikes in a given interval.""" self.clear() + interval = self._restrict_interval(interval) + self.interval = interval start, end = interval - cluster_ids = list(cluster_ids) if cluster_ids is not None else None # Load traces. traces = self._load_traces(interval) @@ -400,7 +431,6 @@ def set_interval(self, interval, cluster_ids=None): for i in range(len(spike_times)): self._plot_spike(i, start=start, - cluster_ids=cluster_ids, traces=traces, spike_times=spike_times, spike_clusters=spike_clusters, @@ -412,8 +442,8 @@ def set_interval(self, interval, cluster_ids=None): self.update() def on_select(self, cluster_ids, spike_ids): - # TODO: choose the interval. - self.set_interval((0., .25), cluster_ids=cluster_ids) + self.cluster_ids = list(cluster_ids) + self.set_interval((0., .25)) def attach(self, gui): """Attach the view to the GUI.""" @@ -427,8 +457,30 @@ def attach(self, gui): gui.connect_(self.on_select) # gui.connect_(self.on_cluster) - # self.actions = Actions(gui, default_shortcuts=self.shortcuts) - # self.actions.add(self.toggle_waveform_overlap) + self.actions = Actions(gui, default_shortcuts=self.shortcuts) + self.actions.add(self.go_to, alias='tg') + self.actions.add(self.shift, alias='ts') + self.actions.add(self.go_right) + self.actions.add(self.go_left) + + def go_to(self, time): + start, end = self.interval + half_dur = (end - start) * .5 + self.set_interval((time - half_dur, time + half_dur)) + + def shift(self, delay): + time = sum(self.interval) * .5 + self.go_to(time + delay) + + def go_right(self): + start, end = self.interval + delay = (end - start) * .2 + self.shift(delay) + + def go_left(self): + start, end = self.interval + delay = (end - start) * .2 + self.shift(-delay) # ----------------------------------------------------------------------------- @@ -501,7 +553,6 @@ def _project_mask_depth(dim, masks, spike_clusters_rel=None, n_clusters=None): class FeatureView(GridView): - # TODO: make this configurable normalization_percentile = .95 normalization_n_spikes = 1000 @@ -609,7 +660,6 @@ def attach(self, gui): # ----------------------------------------------------------------------------- class CorrelogramView(GridView): - # TODO: make this configurable excerpt_size = 10000 n_excerpts = 100 diff --git a/phy/plot/interact.py b/phy/plot/interact.py index c6ff61df3..1fd778d64 100644 --- a/phy/plot/interact.py +++ b/phy/plot/interact.py @@ -153,6 +153,7 @@ def __init__(self, assert self._box_bounds.shape[1] == 4 self.n_boxes = len(self._box_bounds) + self.enable_box_width_shortcuts = True def attach(self, canvas): super(Boxed, self).attach(canvas) @@ -231,11 +232,12 @@ def on_key_press(self, event): if ctrl and key in self._arrows + self._pm: coeff = 1.1 box_size = np.array(self.box_size) - if key == 'Left': - box_size[0] /= coeff - elif key == 'Right': - box_size[0] *= coeff - elif key in ('Down', '-'): + if self.enable_box_width_shortcuts: + if key == 'Left': + box_size[0] /= coeff + elif key == 'Right': + box_size[0] *= coeff + if key in ('Down', '-'): box_size[1] /= coeff elif key in ('Up', '+'): box_size[1] *= coeff From 27cec61f3190673593f2f54b3dad4d09e49ae226 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 14 Dec 2015 18:41:03 +0100 Subject: [PATCH 0724/1059] Bug fixes --- phy/cluster/manual/views.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index d6a43c5cd..9f576cf09 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -279,7 +279,7 @@ def __init__(self, self.dt = 1. / self.sample_rate # Traces. - assert traces.ndim == 2 + assert len(traces.shape) == 2 self.n_samples, self.n_channels = traces.shape self.traces = traces self.duration = self.dt * self.n_samples @@ -323,10 +323,11 @@ def _load_traces(self, interval): start, end = interval i, j = round(self.sample_rate * start), round(self.sample_rate * end) + i, j = int(i), int(j) traces = self.traces[i:j, :] # Detrend the traces. - m = np.mean(traces[::10, :], axis=0) + m = np.mean(traces[::10, :], axis=0).astype(traces.dtype) traces -= m # Create the plots. @@ -564,7 +565,7 @@ def __init__(self, keys=None, ): - assert features.ndim == 3 + assert len(features.shape) == 3 self.n_spikes, self.n_channels, self.n_features = features.shape self.n_cols = self.n_features + 1 self.shape = (self.n_cols, self.n_cols) From 825662aa21419dbda892d4dc59dffa7a4fcf21aa Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 14 Dec 2015 18:44:43 +0100 Subject: [PATCH 0725/1059] Minor updates --- phy/cluster/manual/views.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 9f576cf09..d95348d45 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -251,11 +251,11 @@ def toggle_waveform_overlap(self): # ----------------------------------------------------------------------------- class TraceView(StackedView): - interval_duration = .25 # default duration of the interval + interval_duration = .5 # default duration of the interval shift_amount = .1 default_shortcuts = { - 'go_left': 'ctrl+left', - 'go_right': 'ctrl+right', + 'go_left': 'alt+left', + 'go_right': 'alt+right', } def __init__(self, @@ -314,9 +314,6 @@ def __init__(self, self.cluster_ids = [] self.set_interval((0., self.interval_duration)) - # We use ctrl+left|right to navigate in the view. - self.enable_box_width_shortcuts = False - def _load_traces(self, interval): """Load traces in an interval (in seconds).""" @@ -444,7 +441,7 @@ def set_interval(self, interval): def on_select(self, cluster_ids, spike_ids): self.cluster_ids = list(cluster_ids) - self.set_interval((0., .25)) + self.set_interval(self.interval) def attach(self, gui): """Attach the view to the GUI.""" From a80e0cd5b3d0a50c8dc7d5ea1f6dd50ea47b9cfc Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 14 Dec 2015 20:22:09 +0100 Subject: [PATCH 0726/1059] Fix aspect ratio issues in boxed and stacked --- phy/plot/interact.py | 30 ++++++++++++++++++------------ phy/plot/utils.py | 23 ++++++++++++++--------- 2 files changed, 32 insertions(+), 21 deletions(-) diff --git a/phy/plot/interact.py b/phy/plot/interact.py index 1fd778d64..84c97594f 100644 --- a/phy/plot/interact.py +++ b/phy/plot/interact.py @@ -136,8 +136,11 @@ def __init__(self, box_bounds=None, box_pos=None, box_size=None, - box_var=None): + box_var=None, + keep_aspect_ratio=True, + ): self._key_pressed = None + self.keep_aspect_ratio = keep_aspect_ratio # Name of the variable with the box index. self.box_var = box_var or 'a_box_index' @@ -147,13 +150,13 @@ def __init__(self, assert box_pos is not None # This will find a good box size automatically if it is not # specified. - box_bounds = _get_boxes(box_pos, size=box_size) + box_bounds = _get_boxes(box_pos, size=box_size, + keep_aspect_ratio=self.keep_aspect_ratio) self._box_bounds = np.atleast_2d(box_bounds) assert self._box_bounds.shape[1] == 4 self.n_boxes = len(self._box_bounds) - self.enable_box_width_shortcuts = True def attach(self, canvas): super(Boxed, self).attach(canvas) @@ -201,7 +204,8 @@ def box_pos(self): @box_pos.setter def box_pos(self, val): assert val.shape == (self.n_boxes, 2) - self.box_bounds = _get_boxes(val, size=self.box_size) + self.box_bounds = _get_boxes(val, size=self.box_size, + keep_aspect_ratio=self.keep_aspect_ratio) @property def box_size(self): @@ -211,7 +215,8 @@ def box_size(self): @box_size.setter def box_size(self, val): assert len(val) == 2 - self.box_bounds = _get_boxes(self.box_pos, size=val) + self.box_bounds = _get_boxes(self.box_pos, size=val, + keep_aspect_ratio=self.keep_aspect_ratio) # Interaction event callbacks #-------------------------------------------------------------------------- @@ -232,12 +237,11 @@ def on_key_press(self, event): if ctrl and key in self._arrows + self._pm: coeff = 1.1 box_size = np.array(self.box_size) - if self.enable_box_width_shortcuts: - if key == 'Left': - box_size[0] /= coeff - elif key == 'Right': - box_size[0] *= coeff - if key in ('Down', '-'): + if key == 'Left': + box_size[0] /= coeff + elif key == 'Right': + box_size[0] *= coeff + elif key in ('Down', '-'): box_size[1] /= coeff elif key in ('Up', '+'): box_size[1] *= coeff @@ -294,4 +298,6 @@ def __init__(self, n_boxes, margin=0, box_var=None): b[:, 3] = np.linspace(-1 + 2. / n_boxes - margin, 1., n_boxes) b = b[::-1, :] - super(Stacked, self).__init__(b, box_var=box_var) + super(Stacked, self).__init__(b, box_var=box_var, + keep_aspect_ratio=False, + ) diff --git a/phy/plot/utils.py b/phy/plot/utils.py index d34c0d055..33a5f85f8 100644 --- a/phy/plot/utils.py +++ b/phy/plot/utils.py @@ -51,6 +51,7 @@ def _get_box_size(x, y, ar=.5, margin=0): # Deal with degenerate x case. xmin, xmax = x.min(), x.max() if xmin == xmax: + # If all positions are vertical, the width can be maximum. wmax = 1. else: wmax = xmax - xmin @@ -58,18 +59,20 @@ def _get_box_size(x, y, ar=.5, margin=0): def f1(w): """Return true if the configuration with the current box size is non-overlapping.""" + # NOTE: w|h are the *half* width|height. h = w * ar # fixed aspect ratio return not _boxes_overlap(x - w, y - h, x + w, y + h) # Find the largest box size leading to non-overlapping boxes. w = _binary_search(f1, 0, wmax) w = w * (1 - margin) # margin + # Clip the half-width. h = w * ar # aspect ratio return w, h -def _get_boxes(pos, size=None, margin=0): +def _get_boxes(pos, size=None, margin=0, keep_aspect_ratio=True): """Generate non-overlapping boxes in NDC from a set of positions.""" # Get x, y. @@ -85,15 +88,17 @@ def _get_boxes(pos, size=None, margin=0): # Renormalize the whole thing by keeping the aspect ratio. x0min, y0min, x1max, y1max = x0.min(), y0.min(), x1.max(), y1.max() - dx = x1max - x0min - dy = y1max - y0min - if dx > dy: - b = (x0min, (y1max + y0min) / 2. - dx / 2., - x1max, (y1max + y0min) / 2. + dx / 2.) + if not keep_aspect_ratio: + b = (x0min, y0min, x1max, y1max) else: - b = ((x1max + x0min) / 2. - dy / 2., y0min, - (x1max + x0min) / 2. + dy / 2., y1max) - + dx = x1max - x0min + dy = y1max - y0min + if dx > dy: + b = (x0min, (y1max + y0min) / 2. - dx / 2., + x1max, (y1max + y0min) / 2. + dx / 2.) + else: + b = ((x1max + x0min) / 2. - dy / 2., y0min, + (x1max + x0min) / 2. + dy / 2., y1max) r = Range(from_bounds=b, to_bounds=(-1, -1, 1, 1)) return np.c_[r.apply(np.c_[x0, y0]), r.apply(np.c_[x1, y1])] From 936993457e6f02c5acff7bd6c717f40d8b39cb76 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 14 Dec 2015 20:24:47 +0100 Subject: [PATCH 0727/1059] Increase coverage --- phy/plot/tests/test_utils.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/phy/plot/tests/test_utils.py b/phy/plot/tests/test_utils.py index 0933deee8..a8138f432 100644 --- a/phy/plot/tests/test_utils.py +++ b/phy/plot/tests/test_utils.py @@ -113,6 +113,11 @@ def test_get_boxes(): ac(boxes, [[-1, -.25, 0, .25], [+0, -.25, 1, .25]], atol=1e-4) + positions = [[-1, 0], [1, 0]] + boxes = _get_boxes(positions, keep_aspect_ratio=False) + ac(boxes, [[-1, -1, 0, 1], + [0, -1, 1, 1]], atol=1e-4) + positions = linear_positions(4) boxes = _get_boxes(positions) ac(boxes, [[-0.5, -1.0, +0.5, -0.5], From aef864af8f2918b8d9313110556d3c7bf4db9a8b Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 14 Dec 2015 21:51:35 +0100 Subject: [PATCH 0728/1059] Attach features --- phy/cluster/manual/views.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index d95348d45..21f9077fd 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -650,7 +650,16 @@ def on_select(self, cluster_ids, spike_ids): ) def attach(self, gui): - pass + """Attach the view to the GUI.""" + + # Disable keyboard pan so that we can use arrows as global shortcuts + # in the GUI. + self.panzoom.enable_keyboard_pan = False + + gui.add_view(self) + + gui.connect_(self.on_select) + # gui.connect_(self.on_cluster) # ----------------------------------------------------------------------------- From c13dae98c6f66f002b270698d050bb71d1677500 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 14 Dec 2015 21:54:02 +0100 Subject: [PATCH 0729/1059] Increase coverage --- phy/cluster/manual/tests/test_views.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index 78d711795..0945e3f0d 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -182,7 +182,7 @@ def test_trace_view_spikes(qtbot, gui): # Test feature view #------------------------------------------------------------------------------ -def test_feature_view(qtbot): +def test_feature_view(gui, qtbot): n_spikes = 50 n_channels = 5 n_clusters = 2 @@ -204,9 +204,8 @@ def test_feature_view(qtbot): cluster_ids = np.unique(spike_clusters[spike_ids]) v.on_select(cluster_ids, spike_ids) - # Show the view. - v.show() - qtbot.waitForWindowShown(v.native) + v.attach(gui) + gui.show() # Select other spikes. spike_ids = np.arange(2, 10) @@ -214,7 +213,6 @@ def test_feature_view(qtbot): v.on_select(cluster_ids, spike_ids) # qtbot.stop() - v.close() #------------------------------------------------------------------------------ From 026764d7a32400e14dfa3dcd5a84142278f3f1e9 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 15 Dec 2015 09:09:35 +0100 Subject: [PATCH 0730/1059] Remove global zoom in grid interact --- phy/plot/interact.py | 34 +-------------------------------- phy/plot/tests/test_interact.py | 20 ------------------- 2 files changed, 1 insertion(+), 53 deletions(-) diff --git a/phy/plot/interact.py b/phy/plot/interact.py index 84c97594f..cfb31d1b0 100644 --- a/phy/plot/interact.py +++ b/phy/plot/interact.py @@ -38,8 +38,6 @@ class Grid(BaseInteract): """ def __init__(self, shape=(1, 1), shape_var='u_grid_shape', box_var=None): - self._zoom = 1. - # Name of the variable with the box index. self.box_var = box_var or 'a_box_index' self.shape_var = shape_var @@ -49,21 +47,17 @@ def attach(self, canvas): super(Grid, self).attach(canvas) m = 1. - .025 m2 = 1. - .075 - canvas.transforms.add_on_gpu([Scale('u_grid_zoom'), - Scale((m2, m2)), + canvas.transforms.add_on_gpu([Scale((m2, m2)), Clip([-m, -m, m, m]), Subplot(self.shape_var, self.box_var), ]) canvas.inserter.insert_vert(""" attribute vec2 {}; uniform vec2 {}; - uniform float u_grid_zoom; """.format(self.box_var, self.shape_var), 'header') - canvas.connect(self.on_key_press) def update_program(self, program): - program['u_grid_zoom'] = self._zoom program[self.shape_var] = self._shape # Only set the default box index if necessary. try: @@ -71,17 +65,6 @@ def update_program(self, program): except KeyError: program[self.box_var] = (0, 0) - @property - def zoom(self): - """Zoom level.""" - return self._zoom - - @zoom.setter - def zoom(self, value): - """Zoom level.""" - self._zoom = value - self.update() - @property def shape(self): return self._shape @@ -91,21 +74,6 @@ def shape(self, value): self._shape = value self.update() - def on_key_press(self, event): - """Pan and zoom with the keyboard.""" - key = event.key - - # Zoom. - if key in ('-', '+') and event.modifiers == ('Control',): - k = .05 if key == '+' else -.05 - self.zoom *= math.exp(1.5 * k) - self.update() - - # Reset with 'R'. - if key == 'R': - self.zoom = 1. - self.update() - #------------------------------------------------------------------------------ # Boxed interact diff --git a/phy/plot/tests/test_interact.py b/phy/plot/tests/test_interact.py index 36014b0c0..fa690a539 100644 --- a/phy/plot/tests/test_interact.py +++ b/phy/plot/tests/test_interact.py @@ -85,26 +85,6 @@ def test_grid_1(qtbot, canvas): grid = Grid((2, 3)) _create_visual(qtbot, canvas, grid, box_index) - # No effect without modifiers. - c.events.key_press(key=keys.Key('+')) - assert grid.zoom == 1. - - # Zoom with the keyboard. - c.events.key_press(key=keys.Key('+'), modifiers=(keys.CONTROL,)) - assert grid.zoom > 1 - - # Unzoom with the keyboard. - c.events.key_press(key=keys.Key('-'), modifiers=(keys.CONTROL,)) - assert grid.zoom == 1. - - # Set the zoom explicitly. - grid.zoom = 2 - assert grid.zoom == 2. - - # Press 'R'. - c.events.key_press(key=keys.Key('r')) - assert grid.zoom == 1. - # qtbot.stop() From 85a71c4b90aea0e059c86331e0e32369d4722f64 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 15 Dec 2015 14:22:58 +0100 Subject: [PATCH 0731/1059] WIP: attributes in feature view --- phy/cluster/manual/tests/test_views.py | 8 ++-- phy/cluster/manual/views.py | 54 +++++++++++++++++++++----- 2 files changed, 48 insertions(+), 14 deletions(-) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index 0945e3f0d..4b9234deb 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -207,10 +207,10 @@ def test_feature_view(gui, qtbot): v.attach(gui) gui.show() - # Select other spikes. - spike_ids = np.arange(2, 10) - cluster_ids = np.unique(spike_clusters[spike_ids]) - v.on_select(cluster_ids, spike_ids) + # # Select other spikes. + # spike_ids = np.arange(2, 10) + # cluster_ids = np.unique(spike_clusters[spike_ids]) + # v.on_select(cluster_ids, spike_ids) # qtbot.stop() diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 21f9077fd..de9b5e194 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -73,6 +73,8 @@ def _extract_wave(traces, spk, mask, wave_len=None): def _get_data_bounds(arr, n_spikes=None, percentile=None): n = arr.shape[0] + n_spikes = n_spikes or n + percentile = percentile or 100 k = max(1, n // n_spikes) w = np.abs(arr[::k]) n = w.shape[0] @@ -552,7 +554,7 @@ def _project_mask_depth(dim, masks, spike_clusters_rel=None, n_clusters=None): class FeatureView(GridView): normalization_percentile = .95 - normalization_n_spikes = 1000 + normalization_n_spikes = 10000 def __init__(self, features=None, @@ -585,23 +587,53 @@ def __init__(self, # Spike times. assert spike_times.shape == (self.n_spikes,) - self.spike_times = spike_times + + # Attributes: extra features. This is a dictionary + # {name: (array, data_bounds)} + # where each array is a `(n_spikes,)` array. + self.attributes = {} + + self.add_attribute('time', spike_times) def _get_feature(self, dim, spike_ids=None): f = self.features[spike_ids] assert f.ndim == 3 - if dim == 'time': - t = self.spike_times[spike_ids] - t0, t1 = self.spike_times[0], self.spike_times[-1] - t = -1 + 2 * (t - t0) / float(t1 - t0) - return .9 * t + if dim in self.attributes: + # Extra features like time. + values, _ = self.attributes[dim] + assert values.shape == (self.n_spikes,) + return values else: assert len(dim) == 2 ch, fet = dim - # TODO: normalization of features return f[:, ch, fet] + def _get_dim_bounds_single(self, dim): + """Return the min and max of the bounds for a single dimension.""" + if dim in self.attributes: + # Attribute: the data bounds were computed in add_attribute(). + _, y0, _, y1 = self.attributes[dim][1] + else: + # Features: the data bounds were computed in the constructor. + _, y0, _, y1 = self.data_bounds + return y0, y1 + + def _get_dim_bounds(self, x_dim, y_dim): + """Return the data bounds of a subplot, as a function of the + two x-y dimensions.""" + x0, x1 = self._get_dim_bounds_single(x_dim) + y0, y1 = self._get_dim_bounds_single(y_dim) + return [x0, y0, x1, y1] + + def add_attribute(self, name, values): + assert values.shape == (self.n_spikes,) + bounds = _get_data_bounds(values, + n_spikes=self.normalization_n_spikes, + percentile=self.normalization_percentile, + ) + self.attributes[name] = (values, bounds) + def on_select(self, cluster_ids, spike_ids): n_clusters = len(cluster_ids) n_spikes = len(spike_ids) @@ -619,7 +651,6 @@ def on_select(self, cluster_ids, spike_ids): best_channels_func=None) # Plot all features. - # OPTIM: avoid the loop. with self.building(): for i in range(self.n_cols): for j in range(self.n_cols): @@ -627,6 +658,9 @@ def on_select(self, cluster_ids, spike_ids): x = self._get_feature(x_dim[i, j], spike_ids) y = self._get_feature(y_dim[i, j], spike_ids) + data_bounds = self._get_dim_bounds(x_dim[i, j], + y_dim[i, j]) + mx, dx = _project_mask_depth(x_dim[i, j], masks, spike_clusters_rel=sc, n_clusters=n_clusters) @@ -645,7 +679,7 @@ def on_select(self, cluster_ids, spike_ids): y=y, color=color, depth=d, - data_bounds=self.data_bounds, + data_bounds=data_bounds, size=5 * np.ones(n_spikes), ) From 5e67aa9e9666ef31ba8f166d0ed9fd675ca4aae8 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 15 Dec 2015 14:32:30 +0100 Subject: [PATCH 0732/1059] Fixes --- phy/cluster/manual/tests/test_views.py | 8 +-- phy/cluster/manual/views.py | 99 ++++++++++++++++---------- 2 files changed, 66 insertions(+), 41 deletions(-) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index 4b9234deb..0945e3f0d 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -207,10 +207,10 @@ def test_feature_view(gui, qtbot): v.attach(gui) gui.show() - # # Select other spikes. - # spike_ids = np.arange(2, 10) - # cluster_ids = np.unique(spike_clusters[spike_ids]) - # v.on_select(cluster_ids, spike_ids) + # Select other spikes. + spike_ids = np.arange(2, 10) + cluster_ids = np.unique(spike_clusters[spike_ids]) + v.on_select(cluster_ids, spike_ids) # qtbot.stop() diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index de9b5e194..e7ad98e1b 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -595,6 +595,21 @@ def __init__(self, self.add_attribute('time', spike_times) + def add_attribute(self, name, values): + """Add an attribute (aka extra feature). + + The values should be a 1D array with `n_spikes` elements. + + By default, there is the `time` attribute. + + """ + assert values.shape == (self.n_spikes,) + bounds = _get_data_bounds(values, + n_spikes=self.normalization_n_spikes, + percentile=self.normalization_percentile, + ) + self.attributes[name] = (values, bounds) + def _get_feature(self, dim, spike_ids=None): f = self.features[spike_ids] assert f.ndim == 3 @@ -602,7 +617,8 @@ def _get_feature(self, dim, spike_ids=None): if dim in self.attributes: # Extra features like time. values, _ = self.attributes[dim] - assert values.shape == (self.n_spikes,) + values = values[spike_ids] + assert values.shape == (len(spike_ids),) return values else: assert len(dim) == 2 @@ -626,20 +642,51 @@ def _get_dim_bounds(self, x_dim, y_dim): y0, y1 = self._get_dim_bounds_single(y_dim) return [x0, y0, x1, y1] - def add_attribute(self, name, values): - assert values.shape == (self.n_spikes,) - bounds = _get_data_bounds(values, - n_spikes=self.normalization_n_spikes, - percentile=self.normalization_percentile, - ) - self.attributes[name] = (values, bounds) + def _plot_features(self, i, j, x_dim, y_dim, + cluster_ids=None, spike_ids=None, + masks=None, spike_clusters_rel=None): + sc = spike_clusters_rel + n_clusters = len(cluster_ids) + + # Retrieve the x and y values for the subplot. + x = self._get_feature(x_dim[i, j], spike_ids) + y = self._get_feature(y_dim[i, j], spike_ids) + + # Retrieve the data bounds. + data_bounds = self._get_dim_bounds(x_dim[i, j], + y_dim[i, j]) + + # Retrieve the masks and depth. + mx, dx = _project_mask_depth(x_dim[i, j], masks, + spike_clusters_rel=sc, + n_clusters=n_clusters) + my, dy = _project_mask_depth(y_dim[i, j], masks, + spike_clusters_rel=sc, + n_clusters=n_clusters) + + d = np.maximum(dx, dy) + m = np.maximum(mx, my) + + # Get the color of the markers. + color = _get_color(m, + spike_clusters_rel=sc, + n_clusters=n_clusters) + + # Create the scatter plot for the current subplot. + self[i, j].scatter(x=x, + y=y, + color=color, + depth=d, + data_bounds=data_bounds, + size=5 * np.ones(len(spike_ids)), + ) def on_select(self, cluster_ids, spike_ids): - n_clusters = len(cluster_ids) n_spikes = len(spike_ids) if n_spikes == 0: return + # Get the masks for the selected spikes. masks = self.masks[spike_ids] sc = _get_spike_clusters_rel(self.spike_clusters, spike_ids, @@ -654,34 +701,12 @@ def on_select(self, cluster_ids, spike_ids): with self.building(): for i in range(self.n_cols): for j in range(self.n_cols): - - x = self._get_feature(x_dim[i, j], spike_ids) - y = self._get_feature(y_dim[i, j], spike_ids) - - data_bounds = self._get_dim_bounds(x_dim[i, j], - y_dim[i, j]) - - mx, dx = _project_mask_depth(x_dim[i, j], masks, - spike_clusters_rel=sc, - n_clusters=n_clusters) - my, dy = _project_mask_depth(y_dim[i, j], masks, - spike_clusters_rel=sc, - n_clusters=n_clusters) - - d = np.maximum(dx, dy) - m = np.maximum(mx, my) - - color = _get_color(m, - spike_clusters_rel=sc, - n_clusters=n_clusters) - - self[i, j].scatter(x=x, - y=y, - color=color, - depth=d, - data_bounds=data_bounds, - size=5 * np.ones(n_spikes), - ) + self._plot_features(i, j, x_dim, y_dim, + cluster_ids=cluster_ids, + spike_ids=spike_ids, + masks=masks, + spike_clusters_rel=sc, + ) def attach(self, gui): """Attach the view to the GUI.""" From 3b948705bf271e93692fe6282cac39e89b7fedc9 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 15 Dec 2015 14:40:23 +0100 Subject: [PATCH 0733/1059] Fix feature normalization --- phy/cluster/manual/views.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index e7ad98e1b..f60819b24 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -73,14 +73,15 @@ def _extract_wave(traces, spk, mask, wave_len=None): def _get_data_bounds(arr, n_spikes=None, percentile=None): n = arr.shape[0] - n_spikes = n_spikes or n - percentile = percentile or 100 - k = max(1, n // n_spikes) + k = max(1, n // n_spikes) if n_spikes else 1 w = np.abs(arr[::k]) n = w.shape[0] w = w.reshape((n, -1)) w = w.max(axis=1) - m = np.percentile(w, percentile) + if percentile is not None: + m = np.percentile(w, percentile) + else: + m = w.max() return [-1, -m, +1, +m] @@ -604,10 +605,7 @@ def add_attribute(self, name, values): """ assert values.shape == (self.n_spikes,) - bounds = _get_data_bounds(values, - n_spikes=self.normalization_n_spikes, - percentile=self.normalization_percentile, - ) + bounds = _get_data_bounds(values) self.attributes[name] = (values, bounds) def _get_feature(self, dim, spike_ids=None): From a9e297b5883a5cb90584a610148fd5b1a9932897 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 15 Dec 2015 14:47:13 +0100 Subject: [PATCH 0734/1059] Add column with number of spikes --- phy/cluster/manual/gui_component.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 165e0d7cf..3d420d56a 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -198,7 +198,16 @@ def __init__(self, self.clustering = Clustering(spike_clusters) self.cluster_meta = create_cluster_meta(cluster_groups) self._global_history = GlobalHistory(process_ups=_process_ups) + self._register_logging() + # Create the cluster views. + self._create_cluster_views() + self._add_default_columns() + + # Internal methods + # ------------------------------------------------------------------------- + + def _register_logging(self): # Log the actions. @self.clustering.connect def on_cluster(up): @@ -226,10 +235,12 @@ def on_cluster(up): if self.gui: self.gui.emit('cluster', up) - # Create the cluster views. - self._create_cluster_views() - + def _add_default_columns(self): # Default columns. + @self.add_column(name='n_spikes') + def n_spikes(cluster_id): + return self.clustering.spike_counts[cluster_id] + def skip(cluster_id): """Whether to skip that cluster.""" return (self.cluster_meta.get('group', cluster_id) @@ -247,9 +258,6 @@ def similarity(cluster_id): return self.similarity_func(cluster_id, self._best) self.similarity_view.add_column(similarity) - # Internal methods - # ------------------------------------------------------------------------- - def _create_actions(self, gui): self.actions = Actions(gui, default_shortcuts=self.shortcuts) From 451a9153d8466ea7045d1f6d88a1eca629b851ba Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 15 Dec 2015 14:47:59 +0100 Subject: [PATCH 0735/1059] Flakify --- phy/plot/interact.py | 2 -- phy/plot/tests/test_interact.py | 1 - 2 files changed, 3 deletions(-) diff --git a/phy/plot/interact.py b/phy/plot/interact.py index cfb31d1b0..e25910c05 100644 --- a/phy/plot/interact.py +++ b/phy/plot/interact.py @@ -7,8 +7,6 @@ # Imports #------------------------------------------------------------------------------ -import math - import numpy as np from vispy.gloo import Texture2D diff --git a/phy/plot/tests/test_interact.py b/phy/plot/tests/test_interact.py index fa690a539..c1fda6e1a 100644 --- a/phy/plot/tests/test_interact.py +++ b/phy/plot/tests/test_interact.py @@ -76,7 +76,6 @@ def _create_visual(qtbot, canvas, interact, box_index): def test_grid_1(qtbot, canvas): - c = canvas n = 1000 box_index = [[i, j] for i, j in product(range(2), range(3))] From 6df047e5cd69ddc6234dcb6c31ae490bf997ce45 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 15 Dec 2015 14:52:54 +0100 Subject: [PATCH 0736/1059] Fix attribute normalization in feature view --- phy/cluster/manual/views.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index f60819b24..6bd371032 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -605,8 +605,8 @@ def add_attribute(self, name, values): """ assert values.shape == (self.n_spikes,) - bounds = _get_data_bounds(values) - self.attributes[name] = (values, bounds) + lim = values.min(), values.max() + self.attributes[name] = (values, lim) def _get_feature(self, dim, spike_ids=None): f = self.features[spike_ids] @@ -627,7 +627,7 @@ def _get_dim_bounds_single(self, dim): """Return the min and max of the bounds for a single dimension.""" if dim in self.attributes: # Attribute: the data bounds were computed in add_attribute(). - _, y0, _, y1 = self.attributes[dim][1] + y0, y1 = self.attributes[dim][1] else: # Features: the data bounds were computed in the constructor. _, y0, _, y1 = self.data_bounds From 8157a0b5ec59351ce0bdbfa7dcdf1fe76af8496c Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 15 Dec 2015 15:01:56 +0100 Subject: [PATCH 0737/1059] WIP: fix minor graphical issue with markers --- phy/plot/glsl/scatter.frag | 8 ++++---- phy/plot/glsl/utils.glsl | 26 +++++++++++++++++++++++++- 2 files changed, 29 insertions(+), 5 deletions(-) diff --git a/phy/plot/glsl/scatter.frag b/phy/plot/glsl/scatter.frag index 9a6782b4c..a88110647 100644 --- a/phy/plot/glsl/scatter.frag +++ b/phy/plot/glsl/scatter.frag @@ -1,13 +1,13 @@ -#include "antialias/filled.glsl" #include "markers/%MARKER.glsl" +#include "utils.glsl" varying vec4 v_color; varying float v_size; void main() { - vec2 P = gl_PointCoord.xy - vec2(0.5,0.5); - float point_size = v_size + 2. * (1.0 + 1.5*1.0); - float distance = marker_%MARKER(P*point_size, v_size); + vec2 P = gl_PointCoord.xy - vec2(0.5, 0.5); + float point_size = v_size + 5.; + float distance = marker_%MARKER(P *point_size, v_size); gl_FragColor = filled(distance, 1.0, 1.0, v_color); } diff --git a/phy/plot/glsl/utils.glsl b/phy/plot/glsl/utils.glsl index 5ba113336..f4df55b31 100644 --- a/phy/plot/glsl/utils.glsl +++ b/phy/plot/glsl/utils.glsl @@ -1,4 +1,28 @@ - vec4 fetch_texture(float index, sampler2D texture, float size) { return texture2D(texture, vec2(index / (size - 1.), .5)); } + +vec4 filled(float distance, float linewidth, float antialias, + vec4 bg_color) +{ + vec4 frag_color; + float t = linewidth/2.0 - antialias; + float signed_distance = distance; + float border_distance = abs(signed_distance) - t; + float alpha = border_distance/antialias; + alpha = exp(-alpha*alpha); + + if (border_distance < 0.0) + frag_color = bg_color; + else if (signed_distance < 0.0) + frag_color = bg_color; + else { + if (abs(signed_distance) < (linewidth/2.0 + antialias)) { + frag_color = vec4(bg_color.rgb, alpha * bg_color.a); + } + else { + discard; + } + } + return frag_color; +} From 5440ce68c44253078549b597c9d381af60c0b9e2 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 15 Dec 2015 15:11:48 +0100 Subject: [PATCH 0738/1059] Fix --- phy/plot/glsl/scatter.frag | 4 ++-- phy/plot/glsl/utils.glsl | 25 ------------------------- 2 files changed, 2 insertions(+), 27 deletions(-) diff --git a/phy/plot/glsl/scatter.frag b/phy/plot/glsl/scatter.frag index a88110647..16097b612 100644 --- a/phy/plot/glsl/scatter.frag +++ b/phy/plot/glsl/scatter.frag @@ -1,5 +1,5 @@ #include "markers/%MARKER.glsl" -#include "utils.glsl" +#include "antialias/filled.glsl" varying vec4 v_color; varying float v_size; @@ -7,7 +7,7 @@ varying float v_size; void main() { vec2 P = gl_PointCoord.xy - vec2(0.5, 0.5); - float point_size = v_size + 5.; + float point_size = v_size + 5.; float distance = marker_%MARKER(P *point_size, v_size); gl_FragColor = filled(distance, 1.0, 1.0, v_color); } diff --git a/phy/plot/glsl/utils.glsl b/phy/plot/glsl/utils.glsl index f4df55b31..944fa9e86 100644 --- a/phy/plot/glsl/utils.glsl +++ b/phy/plot/glsl/utils.glsl @@ -1,28 +1,3 @@ vec4 fetch_texture(float index, sampler2D texture, float size) { return texture2D(texture, vec2(index / (size - 1.), .5)); } - -vec4 filled(float distance, float linewidth, float antialias, - vec4 bg_color) -{ - vec4 frag_color; - float t = linewidth/2.0 - antialias; - float signed_distance = distance; - float border_distance = abs(signed_distance) - t; - float alpha = border_distance/antialias; - alpha = exp(-alpha*alpha); - - if (border_distance < 0.0) - frag_color = bg_color; - else if (signed_distance < 0.0) - frag_color = bg_color; - else { - if (abs(signed_distance) < (linewidth/2.0 + antialias)) { - frag_color = vec4(bg_color.rgb, alpha * bg_color.a); - } - else { - discard; - } - } - return frag_color; -} From 06cfbc87cc3dc10272c8f312877e9e406ae059f2 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 15 Dec 2015 16:32:52 +0100 Subject: [PATCH 0739/1059] WIP --- phy/plot/interact.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/phy/plot/interact.py b/phy/plot/interact.py index e25910c05..18793fa3d 100644 --- a/phy/plot/interact.py +++ b/phy/plot/interact.py @@ -35,6 +35,9 @@ class Grid(BaseInteract): """ + _margin_scale = 1 - .075 + _margin_clip = 1 - .025 + def __init__(self, shape=(1, 1), shape_var='u_grid_shape', box_var=None): # Name of the variable with the box index. self.box_var = box_var or 'a_box_index' @@ -43,10 +46,13 @@ def __init__(self, shape=(1, 1), shape_var='u_grid_shape', box_var=None): def attach(self, canvas): super(Grid, self).attach(canvas) - m = 1. - .025 - m2 = 1. - .075 - canvas.transforms.add_on_gpu([Scale((m2, m2)), - Clip([-m, -m, m, m]), + canvas.transforms.add_on_gpu([Scale((self._margin_scale, + self._margin_scale)), + Clip([-self._margin_clip, + -self._margin_clip, + +self._margin_clip, + +self._margin_clip, + ]), Subplot(self.shape_var, self.box_var), ]) canvas.inserter.insert_vert(""" From 5ef8d2cf45f9b73d68116af9cc7be83629af1cc5 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 15 Dec 2015 16:42:33 +0100 Subject: [PATCH 0740/1059] Minor fixes --- phy/cluster/manual/views.py | 5 +---- phy/plot/plot.py | 12 ++++++------ 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 6bd371032..46fad4cfc 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -78,10 +78,7 @@ def _get_data_bounds(arr, n_spikes=None, percentile=None): n = w.shape[0] w = w.reshape((n, -1)) w = w.max(axis=1) - if percentile is not None: - m = np.percentile(w, percentile) - else: - m = w.max() + m = np.percentile(w, percentile) return [-1, -m, +1, +m] diff --git a/phy/plot/plot.py b/phy/plot/plot.py index 3915f7003..f37f7ebaa 100644 --- a/phy/plot/plot.py +++ b/phy/plot/plot.py @@ -189,14 +189,14 @@ def __init__(self, shape=None, **kwargs): def build(self): n, m = self.grid.shape - a = .045 # margin + a = 1 + 0.03 # margin for i in range(n): for j in range(m): - self[i, j].lines(x0=[-1, +1, +1, -1], - y0=[-1, -1, +1, +1], - x1=[+1, +1, -1, -1], - y1=[-1, +1, +1, -1], - data_bounds=[-1 + a, -1 + a, 1 - a, 1 - a], + self[i, j].lines(x0=[-a, +a, +a, -a], + y0=[-a, -a, +a, +a], + x1=[+a, +a, -a, -a], + y1=[-a, +a, +a, -a], + data_bounds=NDC, ) super(GridView, self).build() From 9f63a4be30fd10eff47281acd5036e37019ec758 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 15 Dec 2015 19:55:05 +0100 Subject: [PATCH 0741/1059] Best channels selection in feature view --- phy/cluster/manual/views.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 46fad4cfc..e434099cd 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -592,6 +592,7 @@ def __init__(self, self.attributes = {} self.add_attribute('time', spike_times) + self.best_channels_func = None def add_attribute(self, name, values): """Add an attribute (aka extra feature). @@ -676,6 +677,10 @@ def _plot_features(self, i, j, x_dim, y_dim, size=5 * np.ones(len(spike_ids)), ) + def set_best_channels_func(self, func): + """Set a function `cluster_id => list of best channels`.""" + self.best_channels_func = func + def on_select(self, cluster_ids, spike_ids): n_spikes = len(spike_ids) if n_spikes == 0: @@ -687,10 +692,16 @@ def on_select(self, cluster_ids, spike_ids): spike_ids, cluster_ids) + f = self.best_channels_func x_dim, y_dim = _dimensions_for_clusters(cluster_ids, n_cols=self.n_cols, - # TODO - best_channels_func=None) + best_channels_func=f) + + # Set a non-time attribute as y coordinate in the top-left subplot. + attrs = sorted(self.attributes) + attrs.remove('time') + if attrs: + y_dim[0, 0] = attrs[0] # Plot all features. with self.building(): From 6cf78f743fefdf12201d6c0084d8ac1d6a0bc222 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 15 Dec 2015 20:05:34 +0100 Subject: [PATCH 0742/1059] Improve x-scaling of non-overlapping waveforms --- phy/cluster/manual/tests/test_views.py | 1 + phy/cluster/manual/views.py | 6 ++++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index 0945e3f0d..8314f7611 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -115,6 +115,7 @@ def test_waveform_view(qtbot, gui): cluster_ids = np.unique(spike_clusters[spike_ids]) v.on_select(cluster_ids, spike_ids) + v.toggle_waveform_overlap() v.toggle_waveform_overlap() # qtbot.stop() diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index e434099cd..72969f601 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -131,7 +131,7 @@ def _get_color(masks, spike_clusters_rel=None, n_clusters=None): class WaveformView(BoxedView): normalization_percentile = .95 normalization_n_spikes = 1000 - overlap = True + overlap = False default_shortcuts = { 'toggle_waveform_overlap': 'o', @@ -202,9 +202,11 @@ def on_select(self, cluster_ids, spike_ids): w = self.waveforms[spike_ids] t = _get_linear_x(n_spikes, self.n_samples) # Overlap. - if self.overlap: + if not self.overlap: t = t + 2.5 * (spike_clusters_rel[:, np.newaxis] - (n_clusters - 1) / 2.) + # The total width should not depend on the number of clusters. + t /= n_clusters # Depth as a function of the cluster index and masks. masks = self.masks[spike_ids] From ae9c99cc43acb657abec01d946e19c7bc06840f9 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 15 Dec 2015 20:17:08 +0100 Subject: [PATCH 0743/1059] Toggle correlogram normalization --- phy/cluster/manual/tests/test_views.py | 9 ++++++ phy/cluster/manual/views.py | 44 ++++++++++++++++++++------ 2 files changed, 44 insertions(+), 9 deletions(-) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index 8314f7611..06c5a7b9f 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -200,6 +200,13 @@ def test_feature_view(gui, qtbot): spike_times=spike_times, spike_clusters=spike_clusters, ) + + @v.set_best_channels_func + def best_channels(cluster_id): + return list(range(n_channels)) + + v.add_attribute('sine', np.sin(np.linspace(-10., 10., n_spikes))) + # Select some spikes. spike_ids = np.arange(n_spikes) cluster_ids = np.unique(spike_clusters[spike_ids]) @@ -254,6 +261,8 @@ def test_correlogram_view(qtbot, gui): cluster_ids = np.unique(spike_clusters[spike_ids]) v.on_select(cluster_ids, spike_ids) + v.toggle_normalization() + v.attach(gui) gui.show() # qtbot.stop() diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 72969f601..6de01a81b 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -540,7 +540,7 @@ def _dimensions_for_clusters(cluster_ids, n_cols=None, def _project_mask_depth(dim, masks, spike_clusters_rel=None, n_clusters=None): """Return the mask and depth vectors for a given dimension.""" n_spikes = masks.shape[0] - if dim != 'time': + if isinstance(dim, tuple): ch, fet = dim m = masks[:, ch] d = _get_depth(m, @@ -736,6 +736,11 @@ def attach(self, gui): class CorrelogramView(GridView): excerpt_size = 10000 n_excerpts = 100 + uniform_normalization = False + default_shortcuts = { + 'go_left': 'alt+left', + 'go_right': 'alt+right', + } def __init__(self, spike_times=None, @@ -745,9 +750,17 @@ def __init__(self, window_size=None, excerpt_size=None, n_excerpts=None, + shortcuts=None, keys=None, ): + # Load default shortcuts, and override with any user shortcuts. + self.shortcuts = self.default_shortcuts.copy() + self.shortcuts.update(shortcuts or {}) + + self._cluster_ids = None + self._spike_ids = None + assert sample_rate > 0 self.sample_rate = sample_rate @@ -770,11 +783,7 @@ def __init__(self, assert spike_clusters.shape == (self.n_spikes,) self.spike_clusters = spike_clusters - def on_select(self, cluster_ids, spike_ids): - n_clusters = len(cluster_ids) - n_spikes = len(spike_ids) - if n_spikes == 0: - return + def _compute_correlograms(self, cluster_ids): # Keep spikes belonging to the selected clusters. ind = np.in1d(self.spike_clusters, cluster_ids) @@ -801,7 +810,19 @@ def on_select(self, cluster_ids, spike_ids): window_size=self.window_size, ) - lim = ccg.max() + return ccg + + def on_select(self, cluster_ids, spike_ids): + self._cluster_ids = cluster_ids + self._spike_ids = spike_ids + + n_clusters = len(cluster_ids) + n_spikes = len(spike_ids) + if n_spikes == 0: + return + + ccg = self._compute_correlograms(cluster_ids) + ylim = [ccg.max()] if not self.uniform_normalization else None colors = _selected_clusters_colors(n_clusters) @@ -814,9 +835,13 @@ def on_select(self, cluster_ids, spike_ids): color = np.hstack((color, [1])) self[i, j].hist(hist, color=color, - ylim=[lim], + ylim=ylim, ) + def toggle_normalization(self): + self.uniform_normalization = not self.uniform_normalization + self.on_select(self._cluster_ids, self._spike_ids) + def attach(self, gui): """Attach the view to the GUI.""" @@ -829,4 +854,5 @@ def attach(self, gui): gui.connect_(self.on_select) # gui.connect_(self.on_cluster) - # self.actions = Actions(gui, default_shortcuts=self.shortcuts) + self.actions = Actions(gui, default_shortcuts=self.shortcuts) + self.actions.add(self.toggle_normalization, shortcut='n') From 57d4baf1f715f6cb1795b75a4b5967377fdc455f Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 15 Dec 2015 20:33:21 +0100 Subject: [PATCH 0744/1059] Menu bar in GUI --- phy/gui/actions.py | 12 ++++++++---- phy/gui/gui.py | 15 ++++++++++++++- phy/gui/qt.py | 2 +- phy/gui/tests/test_actions.py | 11 +++++++++++ 4 files changed, 34 insertions(+), 6 deletions(-) diff --git a/phy/gui/actions.py b/phy/gui/actions.py index 2fde837f9..431bbf42c 100644 --- a/phy/gui/actions.py +++ b/phy/gui/actions.py @@ -116,7 +116,7 @@ def _alias(name): @require_qt def _create_qaction(gui, name, callback, shortcut): # Create the QAction instance. - action = QAction(name, gui) + action = QAction(name.title(), gui) def wrapped(checked, *args, **kwargs): # pragma: no cover return callback(*args, **kwargs) @@ -147,12 +147,13 @@ def __init__(self, gui, default_shortcuts=None): gui.actions.append(self) def add(self, callback=None, name=None, shortcut=None, alias=None, - verbose=True): + menu=None, verbose=True): """Add an action with a keyboard shortcut.""" # TODO: add menu_name option and create menu bar if callback is None: # Allow to use either add(func) or @add or @add(...). - return partial(self.add, name=name, shortcut=shortcut, alias=alias) + return partial(self.add, name=name, shortcut=shortcut, + alias=alias, menu=menu) assert callback # Get the name from the callback function if needed. @@ -168,11 +169,14 @@ def add(self, callback=None, name=None, shortcut=None, alias=None, # Create and register the action. action = _create_qaction(self.gui, name, callback, shortcut) action_obj = Bunch(qaction=action, name=name, alias=alias, - shortcut=shortcut, callback=callback) + shortcut=shortcut, callback=callback, menu=menu) if verbose and not name.startswith('_'): logger.log(5, "Add action `%s` (%s).", name, _get_shortcut_string(action.shortcut())) self.gui.addAction(action) + # Add the action to the menu. + if menu: + self.gui.get_menu(menu).addAction(action) self._actions_dict[name] = action_obj # Register the alias -> name mapping. self._aliases[alias] = name diff --git a/phy/gui/gui.py b/phy/gui/gui.py index 9cd2f6590..1e8f53373 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -11,7 +11,7 @@ import logging from .qt import (QApplication, QWidget, QDockWidget, QStatusBar, QMainWindow, - Qt, QSize, QMetaObject) + Qt, QSize, QMetaObject, QMenuBar,) from .actions import Actions, _show_shortcuts, Snippets from phy.utils.event import EventEmitter from phy.utils import load_master_config @@ -136,6 +136,10 @@ def __init__(self, QMainWindow.AllowNestedDocks | QMainWindow.AnimatedDocks ) + + # Mapping {name: menuBar}. + self._menus = {} + # We can derive from EventEmitter because of a conflict with connect. self._event = EventEmitter() @@ -261,6 +265,15 @@ def view_count(self): counts[_title(view)] += 1 return dict(counts) + # Menu bar + # ------------------------------------------------------------------------- + + def get_menu(self, name): + """Return or create a menu.""" + if name not in self._menus: + self._menus[name] = self.menuBar().addMenu(name) + return self._menus[name] + # Status bar # ------------------------------------------------------------------------- diff --git a/phy/gui/qt.py b/phy/gui/qt.py index f8d4d6a75..1a0a66ca4 100644 --- a/phy/gui/qt.py +++ b/phy/gui/qt.py @@ -23,7 +23,7 @@ pyqtSignal, pyqtSlot, QSize, QUrl) from PyQt4.QtGui import (QKeySequence, QAction, QStatusBar, # noqa QMainWindow, QDockWidget, QWidget, - QMessageBox, QApplication, + QMessageBox, QApplication, QMenuBar, ) from PyQt4.QtWebKit import QWebView, QWebPage, QWebSettings # noqa diff --git a/phy/gui/tests/test_actions.py b/phy/gui/tests/test_actions.py index c029cb726..a56cd42d9 100644 --- a/phy/gui/tests/test_actions.py +++ b/phy/gui/tests/test_actions.py @@ -99,6 +99,17 @@ def show_my_shortcuts(): # Test actions and snippet #------------------------------------------------------------------------------ +def test_actions_gui_menu(qtbot, gui, actions): + qtbot.addWidget(gui) + + @actions.add(shortcut='g', menu='&File') + def press(): + pass + + gui.show() + # qtbot.stop() + + def test_actions_gui(qtbot, gui, actions): qtbot.addWidget(gui) gui.show() From a56d50eba6a0c94525498a6b858a1d2b244227b9 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 15 Dec 2015 20:39:16 +0100 Subject: [PATCH 0745/1059] Add menu in manual clustering component --- phy/cluster/manual/gui_component.py | 36 ++++++++++++++++------------- phy/gui/actions.py | 6 ++++- phy/gui/gui.py | 2 +- phy/gui/tests/test_actions.py | 2 ++ 4 files changed, 28 insertions(+), 18 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 3d420d56a..a20c29634 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -262,34 +262,38 @@ def _create_actions(self, gui): self.actions = Actions(gui, default_shortcuts=self.shortcuts) # Selection. - self.actions.add(self.select, alias='c') + self.actions.add(self.select, alias='c', menu='&Cluster') + self.actions.separator('&Cluster') # Clustering. - self.actions.add(self.merge, alias='g') - self.actions.add(self.split, alias='k') + self.actions.add(self.merge, alias='g', menu='&Cluster') + self.actions.add(self.split, alias='k', menu='&Cluster') + self.actions.separator('&Cluster') # Move. self.actions.add(self.move) for group in ('noise', 'mua', 'good'): self.actions.add(partial(self.move_best, group), - name='move_best_to_' + group) + name='move_best_to_' + group, menu='&Cluster') self.actions.add(partial(self.move_similar, group), - name='move_similar_to_' + group) + name='move_similar_to_' + group, menu='&Cluster') self.actions.add(partial(self.move_all, group), - name='move_all_to_' + group) - - # Wizard. - self.actions.add(self.reset) - self.actions.add(self.next) - self.actions.add(self.previous) - self.actions.add(self.next_best) - self.actions.add(self.previous_best) + name='move_all_to_' + group, menu='&Cluster') + self.actions.separator('&Cluster') # Others. - self.actions.add(self.undo) - self.actions.add(self.redo) - self.actions.add(self.save) + self.actions.add(self.undo, menu='&Cluster') + self.actions.add(self.redo, menu='&Cluster') + self.actions.add(self.save, menu='&Cluster') + + # Wizard. + self.actions.add(self.reset, menu='&Wizard') + self.actions.add(self.next, menu='&Wizard') + self.actions.add(self.previous, menu='&Wizard') + self.actions.add(self.next_best, menu='&Wizard') + self.actions.add(self.previous_best, menu='&Wizard') + self.actions.separator('&Cluster') def _create_cluster_views(self): # Create the cluster view. diff --git a/phy/gui/actions.py b/phy/gui/actions.py index 431bbf42c..b3a3ff394 100644 --- a/phy/gui/actions.py +++ b/phy/gui/actions.py @@ -116,7 +116,7 @@ def _alias(name): @require_qt def _create_qaction(gui, name, callback, shortcut): # Create the QAction instance. - action = QAction(name.title(), gui) + action = QAction(name.capitalize().replace('_', ' '), gui) def wrapped(checked, *args, **kwargs): # pragma: no cover return callback(*args, **kwargs) @@ -185,6 +185,10 @@ def add(self, callback=None, name=None, shortcut=None, alias=None, if callback: setattr(self, name, callback) + def separator(self, menu): + """Add a separator""" + self.gui.get_menu(menu).addSeparator() + def disable(self, name=None): """Disable one or all actions.""" if name is None: diff --git a/phy/gui/gui.py b/phy/gui/gui.py index 1e8f53373..ce8db284a 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -11,7 +11,7 @@ import logging from .qt import (QApplication, QWidget, QDockWidget, QStatusBar, QMainWindow, - Qt, QSize, QMetaObject, QMenuBar,) + Qt, QSize, QMetaObject) from .actions import Actions, _show_shortcuts, Snippets from phy.utils.event import EventEmitter from phy.utils import load_master_config diff --git a/phy/gui/tests/test_actions.py b/phy/gui/tests/test_actions.py index a56cd42d9..c93af619d 100644 --- a/phy/gui/tests/test_actions.py +++ b/phy/gui/tests/test_actions.py @@ -106,6 +106,8 @@ def test_actions_gui_menu(qtbot, gui, actions): def press(): pass + actions.separator('File') + gui.show() # qtbot.stop() From e9a834f814d8cde783e2179292a419fd88028fb6 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 15 Dec 2015 20:46:49 +0100 Subject: [PATCH 0746/1059] Fix menu order --- phy/gui/gui.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/phy/gui/gui.py b/phy/gui/gui.py index ce8db284a..1966dd1db 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -152,17 +152,18 @@ def __init__(self, # Default actions. self.default_actions = Actions(self) - @self.default_actions.add(shortcut=('HelpContents', 'h')) + @self.default_actions.add(shortcut='ctrl+q', menu='&File') + def exit(): + self.close() + + @self.default_actions.add(shortcut=('HelpContents', 'h'), + menu='&Help') def show_shortcuts(): shortcuts = self.default_actions.shortcuts for actions in self.actions: shortcuts.update(actions.shortcuts) _show_shortcuts(shortcuts, self.name) - @self.default_actions.add(shortcut='ctrl+q') - def exit(): - self.close() - # Create and attach snippets. self.snippets = Snippets(self) From 5ef852a6425b80c9340e59cbba94bdbdf9d5b5f2 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 15 Dec 2015 21:08:05 +0100 Subject: [PATCH 0747/1059] Compress log file --- phy/utils/cli.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/phy/utils/cli.py b/phy/utils/cli.py index bec6dc9c2..58b43a7c7 100644 --- a/phy/utils/cli.py +++ b/phy/utils/cli.py @@ -8,6 +8,7 @@ # Imports #------------------------------------------------------------------------------ +import gzip import logging import os import os.path as op @@ -42,9 +43,11 @@ def exceptionHandler(exception_type, exception, traceback): # pragma: no cover def _add_log_file(filename): - """Create a `phy.log` log file with DEBUG level in the + """Create a `phy.log.gz` log file with DEBUG level in the current directory.""" - handler = logging.FileHandler(filename) + log_file = gzip.open(filename, mode='wt', encoding='utf-8') + handler = logging.StreamHandler(log_file) + handler.setLevel(logging.DEBUG) formatter = _Formatter(fmt=_logger_fmt, datefmt='%Y-%m-%d %H:%M:%S') @@ -65,7 +68,7 @@ def phy(ctx): using `attach_to_cli()` and the `click` library.""" # Create a `phy.log` log file with DEBUG level in the current directory. - _add_log_file(op.join(os.getcwd(), 'phy.log')) + _add_log_file(op.join(os.getcwd(), 'phy.log.gz')) #------------------------------------------------------------------------------ From b30c3ec423f8476934d90afbca505fa6d20cbe90 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 15 Dec 2015 21:27:45 +0100 Subject: [PATCH 0748/1059] Emit add_view signal in GUI --- .gitignore | 2 +- phy/gui/gui.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 075cbd602..2e500629a 100644 --- a/.gitignore +++ b/.gitignore @@ -12,7 +12,7 @@ wiki .ipynb_checkpoints .*fuse* *.orig -*.log +*.log* .eggs .profile __pycache__ diff --git a/phy/gui/gui.py b/phy/gui/gui.py index 1966dd1db..fc30fffdc 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -105,6 +105,7 @@ class GUI(QMainWindow): close show + add_view Note ---- @@ -244,6 +245,7 @@ def add_view(self, if floating is not None: dockwidget.setFloating(floating) dockwidget.show() + self.emit('add_view', view) return dockwidget def list_views(self, title='', is_visible=True): From 9937ec8241c438cd3090830f6abb95d17609bcc2 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 15 Dec 2015 21:46:12 +0100 Subject: [PATCH 0749/1059] Fix Python 2 encoding error in gzip log file --- phy/utils/cli.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/phy/utils/cli.py b/phy/utils/cli.py index 58b43a7c7..68144ed53 100644 --- a/phy/utils/cli.py +++ b/phy/utils/cli.py @@ -45,7 +45,7 @@ def exceptionHandler(exception_type, exception, traceback): # pragma: no cover def _add_log_file(filename): """Create a `phy.log.gz` log file with DEBUG level in the current directory.""" - log_file = gzip.open(filename, mode='wt', encoding='utf-8') + log_file = gzip.open(filename, mode='wt') handler = logging.StreamHandler(log_file) handler.setLevel(logging.DEBUG) From 24c2488dd806661697d45dee4744a6a6d1e133ad Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 15 Dec 2015 22:00:55 +0100 Subject: [PATCH 0750/1059] WIP: inverse transforms --- phy/plot/tests/test_transform.py | 14 +++++++++++-- phy/plot/transform.py | 35 ++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 2 deletions(-) diff --git a/phy/plot/tests/test_transform.py b/phy/plot/tests/test_transform.py index 9df0eec94..a1bf29290 100644 --- a/phy/plot/tests/test_transform.py +++ b/phy/plot/tests/test_transform.py @@ -23,7 +23,7 @@ # Fixtures #------------------------------------------------------------------------------ -def _check(transform, array, expected): +def _check_forward(transform, array, expected): transformed = transform.apply(array) if array is None or not len(array): assert transformed == array @@ -35,7 +35,17 @@ def _check(transform, array, expected): if not len(transformed): assert not len(expected) else: - assert np.allclose(transformed, expected) + assert np.allclose(transformed, expected, atol=1e-7) + + +def _check(transform, array, expected): + _check_forward(transform, array, expected) + # Test the inverse transform if it is implemented. + try: + inv = transform.inverse() + _check_forward(inv, expected, array) + except NotImplementedError: + pass #------------------------------------------------------------------------------ diff --git a/phy/plot/transform.py b/phy/plot/transform.py index 0e8213690..b787fe53c 100644 --- a/phy/plot/transform.py +++ b/phy/plot/transform.py @@ -54,6 +54,22 @@ def _glslify(r): return 'vec{}({})'.format(len(r), ', '.join(map(str, r))) +def _minus(value): + if isinstance(value, np.ndarray): + return -value + else: + assert len(value) == 2 + return -value[0], -value[1] + + +def _inverse(value): + if isinstance(value, np.ndarray): + return 1. / value + else: + assert len(value) == 2 + return 1. / value[0], 1. / value[1] + + def subplot_bounds(shape=None, index=None): i, j = index n_rows, n_cols = shape @@ -101,6 +117,9 @@ def apply(self, arr): def glsl(self, var): raise NotImplementedError() + def inverse(self): + raise NotImplementedError() + class Translate(BaseTransform): def apply(self, arr): @@ -112,6 +131,12 @@ def glsl(self, var): return """{var} = {var} + {translate};""".format(var=var, translate=self.value) + def inverse(self): + if isinstance(self.value, string_types): + return Translate('-' + self.value) + else: + return Translate(_minus(self.value)) + class Scale(BaseTransform): def apply(self, arr): @@ -121,6 +146,12 @@ def glsl(self, var): assert var return """{var} = {var} * {scale};""".format(var=var, scale=self.value) + def inverse(self): + if isinstance(self.value, string_types): + return Scale('1.0 / ' + self.value) + else: + return Scale(_inverse(self.value)) + class Range(BaseTransform): def __init__(self, from_bounds=None, to_bounds=None): @@ -149,6 +180,10 @@ def glsl(self, var): "({var} - {f}.xy) / ({f}.zw - {f}.xy);" "").format(var=var, f=from_bounds, t=to_bounds) + def inverse(self): + return Range(from_bounds=self.to_bounds, + to_bounds=self.from_bounds) + class Clip(BaseTransform): def __init__(self, bounds=None): From 9908fb4181736bcbde40bcf856f83cd7ae22391e Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 15 Dec 2015 22:36:05 +0100 Subject: [PATCH 0751/1059] WIP: u_window_size uniform --- phy/plot/base.py | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/phy/plot/base.py b/phy/plot/base.py index c61352178..b10a9c417 100644 --- a/phy/plot/base.py +++ b/phy/plot/base.py @@ -48,9 +48,13 @@ def __init__(self): self.gl_primitive_type = None self.transforms = TransformChain() self.inserter = GLSLInserter() + self.inserter.insert_vert('uniform vec2 u_window_size;', 'header') # The program will be set by the canvas when the visual is # added to the canvas. self.program = None + # Whether u_window_size is used in the shaders. Allow to avoid + # the warning in VisPy when setting an inactive uniform. + self._use_window_size = False # Visual definition # ------------------------------------------------------------------------- @@ -59,6 +63,12 @@ def set_shader(self, name): self.vertex_shader = _load_shader(name + '.vert') self.fragment_shader = _load_shader(name + '.frag') + # HACK: we check whether u_window_size is used in order to avoid + # the VisPy warning. + s = self.vertex_shader + self.fragment_shader + s = s.replace('u_window_size;', '') + self._use_window_size = ('u_window_size' in s) + def set_primitive_type(self, primitive_type): self.gl_primitive_type = primitive_type @@ -73,6 +83,10 @@ def on_draw(self): logger.debug("Skipping drawing visual `%s` because the program " "has not been built yet.", self) + def on_resize(self, size): + if self._use_window_size: + self.program['u_window_size'] = size + # To override # ------------------------------------------------------------------------- @@ -207,13 +221,15 @@ def add_transform_chain(self, tc): def insert_into_shaders(self, vertex, fragment): """Apply the insertions to shader code.""" to_insert = defaultdict(str) - to_insert.update({key: '\n'.join(self._to_insert[key]) + to_insert.update({key: '\n'.join(self._to_insert[key]) + '\n' for key in self._to_insert}) return _insert_glsl(vertex, fragment, to_insert) def __add__(self, inserter): """Concatenate two inserters.""" - self._to_insert.update(inserter._to_insert) + for key, values in self._to_insert.items(): + values.extend([_ for _ in inserter._to_insert[key] + if _ not in values]) return self @@ -268,6 +284,8 @@ def add_visual(self, visual, transforms=None): visual.program = gloo.Program(vs, fs) logger.log(5, "Vertex shader: %s", vs) logger.log(5, "Fragment shader: %s", fs) + # Initialize the size. + visual.on_resize(self.size) # Register the visual in the list of visuals in the canvas. self.visuals.append(visual) self.events.visual_added(visual=visual) @@ -275,6 +293,9 @@ def add_visual(self, visual, transforms=None): def on_resize(self, event): """Resize the OpenGL context.""" self.context.set_viewport(0, 0, event.size[0], event.size[1]) + for visual in self.visuals: + visual.on_resize(event.size) + self.update() def on_draw(self, e): """Draw all visuals.""" From 084d03542b993806ee2e019c49eae510d6dec678 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 15 Dec 2015 23:43:53 +0100 Subject: [PATCH 0752/1059] Improve Subplot transform --- phy/plot/transform.py | 32 ++++++++++++-------------------- 1 file changed, 12 insertions(+), 20 deletions(-) diff --git a/phy/plot/transform.py b/phy/plot/transform.py index b787fe53c..61d35cf2f 100644 --- a/phy/plot/transform.py +++ b/phy/plot/transform.py @@ -86,6 +86,14 @@ def subplot_bounds(shape=None, index=None): return [x, y, x + width, y + height] +def subplot_bounds_glsl(shape=None, index=None): + x0 = '-1.0 + 2.0 * {i}.y / {s}.y'.format(s=shape, i=index) + y0 = '+1.0 - 2.0 * ({i}.x + 1) / {s}.x'.format(s=shape, i=index) + x1 = '-1.0 + 2.0 * ({i}.y + 1) / {s}.y'.format(s=shape, i=index) + y1 = '+1.0 - 2.0 * ({i}.x) / {s}.x'.format(s=shape, i=index) + return 'vec4({x0}, {y0}, {x1}, {y1})'.format(x0=x0, y0=y0, x1=x1, y1=y1) + + def pixels_to_ndc(pos, size=None): """Convert from pixels to normalized device coordinates (in [-1, 1]).""" pos = np.asarray(pos, dtype=np.float32) @@ -221,26 +229,10 @@ def __init__(self, shape, index=None): self.from_bounds = NDC if isinstance(self.shape, tuple) and isinstance(self.index, tuple): self.to_bounds = subplot_bounds(shape=self.shape, index=self.index) - - def glsl(self, var): - assert var - - index = _glslify(self.index) - shape = _glslify(self.shape) - - snippet = """ - float subplot_width = 2. / {shape}.y; - float subplot_height = 2. / {shape}.x; - - float subplot_x = -1.0 + {index}.y * subplot_width; - float subplot_y = +1.0 - ({index}.x + 1) * subplot_height; - - {var} = vec2(subplot_x + subplot_width * ({var}.x + 1) * .5, - subplot_y + subplot_height * ({var}.y + 1) * .5); - """.format(index=index, shape=shape, var=var) - - snippet = snippet.format(index=index, shape=shape, var=var) - return snippet + elif (isinstance(self.shape, string_types) and + isinstance(self.index, string_types)): + self.to_bounds = subplot_bounds_glsl(shape=self.shape, + index=self.index) #------------------------------------------------------------------------------ From ed18b624f288f988197516d90482cc591e6e1ffd Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 15 Dec 2015 23:51:57 +0100 Subject: [PATCH 0753/1059] Fix u_window_size --- phy/plot/base.py | 15 +++++---------- phy/plot/tests/test_base.py | 4 ++-- 2 files changed, 7 insertions(+), 12 deletions(-) diff --git a/phy/plot/base.py b/phy/plot/base.py index b10a9c417..345341ea3 100644 --- a/phy/plot/base.py +++ b/phy/plot/base.py @@ -52,9 +52,6 @@ def __init__(self): # The program will be set by the canvas when the visual is # added to the canvas. self.program = None - # Whether u_window_size is used in the shaders. Allow to avoid - # the warning in VisPy when setting an inactive uniform. - self._use_window_size = False # Visual definition # ------------------------------------------------------------------------- @@ -63,12 +60,6 @@ def set_shader(self, name): self.vertex_shader = _load_shader(name + '.vert') self.fragment_shader = _load_shader(name + '.frag') - # HACK: we check whether u_window_size is used in order to avoid - # the VisPy warning. - s = self.vertex_shader + self.fragment_shader - s = s.replace('u_window_size;', '') - self._use_window_size = ('u_window_size' in s) - def set_primitive_type(self, primitive_type): self.gl_primitive_type = primitive_type @@ -84,7 +75,11 @@ def on_draw(self): "has not been built yet.", self) def on_resize(self, size): - if self._use_window_size: + # HACK: we check whether u_window_size is used in order to avoid + # the VisPy warning. We only update it if that uniform is active. + s = '\n'.join(self.program.shaders) + s = s.replace('uniform vec2 u_window_size;', '') + if 'u_window_size' in s: self.program['u_window_size'] = size # To override diff --git a/phy/plot/tests/test_base.py b/phy/plot/tests/test_base.py index 71066057b..8d5c897cc 100644 --- a/phy/plot/tests/test_base.py +++ b/phy/plot/tests/test_base.py @@ -115,8 +115,8 @@ def __init__(self): self.transforms.add_on_cpu(Range((-1, -1, 1, 1), (-1.5, -1.5, 1.5, 1.5), )) - self.inserter.insert_vert('gl_Position.y += 1;', - 'after_transforms') + s = 'gl_Position.y += (1 + 1e-8 * u_window_size.x);' + self.inserter.insert_vert(s, 'after_transforms') def set_data(self): data = np.random.uniform(0, 20, (1000, 2)).astype(np.float32) From 4a151b6e664d6de2a9ccdf7e7224566a5c6e5a40 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 15 Dec 2015 23:57:53 +0100 Subject: [PATCH 0754/1059] Remove scaling in grid interact --- phy/plot/interact.py | 7 ++----- phy/plot/plot.py | 2 +- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/phy/plot/interact.py b/phy/plot/interact.py index 18793fa3d..00a37b785 100644 --- a/phy/plot/interact.py +++ b/phy/plot/interact.py @@ -35,8 +35,7 @@ class Grid(BaseInteract): """ - _margin_scale = 1 - .075 - _margin_clip = 1 - .025 + _margin_clip = 1 - .035 def __init__(self, shape=(1, 1), shape_var='u_grid_shape', box_var=None): # Name of the variable with the box index. @@ -46,9 +45,7 @@ def __init__(self, shape=(1, 1), shape_var='u_grid_shape', box_var=None): def attach(self, canvas): super(Grid, self).attach(canvas) - canvas.transforms.add_on_gpu([Scale((self._margin_scale, - self._margin_scale)), - Clip([-self._margin_clip, + canvas.transforms.add_on_gpu([Clip([-self._margin_clip, -self._margin_clip, +self._margin_clip, +self._margin_clip, diff --git a/phy/plot/plot.py b/phy/plot/plot.py index f37f7ebaa..7ad960a70 100644 --- a/phy/plot/plot.py +++ b/phy/plot/plot.py @@ -189,7 +189,7 @@ def __init__(self, shape=None, **kwargs): def build(self): n, m = self.grid.shape - a = 1 + 0.03 # margin + a = self.grid._margin_clip for i in range(n): for j in range(m): self[i, j].lines(x0=[-a, +a, +a, -a], From b11abba71bd2002e1906c317ad12da9de2870f6f Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 15 Dec 2015 23:58:14 +0100 Subject: [PATCH 0755/1059] Flakify --- phy/plot/interact.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/phy/plot/interact.py b/phy/plot/interact.py index 00a37b785..69d96512f 100644 --- a/phy/plot/interact.py +++ b/phy/plot/interact.py @@ -11,7 +11,7 @@ from vispy.gloo import Texture2D from .base import BaseInteract -from .transform import Scale, Range, Subplot, Clip, NDC +from .transform import Range, Subplot, Clip, NDC from .utils import _get_texture, _get_boxes, _get_box_pos_size From e3b434a6d380cd95a9e709240f9e2c9b2a33cb3b Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 16 Dec 2015 11:35:56 +0100 Subject: [PATCH 0756/1059] Fix grid boxes --- phy/plot/base.py | 18 ++++++++-------- phy/plot/interact.py | 42 ++++++++++++++++++++++++++++++------- phy/plot/plot.py | 15 +++---------- phy/plot/tests/test_base.py | 2 +- phy/plot/transform.py | 9 ++++++++ phy/plot/visuals.py | 2 ++ 6 files changed, 59 insertions(+), 29 deletions(-) diff --git a/phy/plot/base.py b/phy/plot/base.py index 345341ea3..78e33e2f2 100644 --- a/phy/plot/base.py +++ b/phy/plot/base.py @@ -52,6 +52,7 @@ def __init__(self): # The program will be set by the canvas when the visual is # added to the canvas. self.program = None + self.set_canvas_transforms_filter(lambda t: t) # Visual definition # ------------------------------------------------------------------------- @@ -104,6 +105,10 @@ def set_data(self): """ raise NotImplementedError() + def set_canvas_transforms_filter(self, f): + """Set a function filtering the canvas' transforms.""" + self.canvas_transforms_filter = f + #------------------------------------------------------------------------------ # Build program with interacts @@ -250,7 +255,7 @@ def __init__(self, *args, **kwargs): # Enable transparency. _enable_depth_mask() - def add_visual(self, visual, transforms=None): + def add_visual(self, visual): """Add a visual to the canvas, and build its program by the same occasion. @@ -262,14 +267,9 @@ def add_visual(self, visual, transforms=None): inserter = visual.inserter # Add the visual's transforms. inserter.add_transform_chain(visual.transforms) - # Then, add the canvas' transforms... - if transforms is None: - inserter.add_transform_chain(self.transforms) - # or user-specified transforms. - else: - tc = TransformChain() - tc.add_on_gpu(transforms) - inserter.add_transform_chain(tc) + # Then, add the canvas' transforms. + canvas_transforms = visual.canvas_transforms_filter(self.transforms) + inserter.add_transform_chain(canvas_transforms) # Also, add the canvas' inserter. inserter += self.inserter # Now, we insert the transforms GLSL into the shaders. diff --git a/phy/plot/interact.py b/phy/plot/interact.py index 69d96512f..4e215fd23 100644 --- a/phy/plot/interact.py +++ b/phy/plot/interact.py @@ -11,8 +11,9 @@ from vispy.gloo import Texture2D from .base import BaseInteract -from .transform import Range, Subplot, Clip, NDC +from .transform import Scale, Range, Subplot, Clip, NDC from .utils import _get_texture, _get_boxes, _get_box_pos_size +from .visuals import LineVisual #------------------------------------------------------------------------------ @@ -35,7 +36,7 @@ class Grid(BaseInteract): """ - _margin_clip = 1 - .035 + margin = .075 def __init__(self, shape=(1, 1), shape_var='u_grid_shape', box_var=None): # Name of the variable with the box index. @@ -45,11 +46,10 @@ def __init__(self, shape=(1, 1), shape_var='u_grid_shape', box_var=None): def attach(self, canvas): super(Grid, self).attach(canvas) - canvas.transforms.add_on_gpu([Clip([-self._margin_clip, - -self._margin_clip, - +self._margin_clip, - +self._margin_clip, - ]), + ms = 1 - self.margin + mc = 1 - self.margin + canvas.transforms.add_on_gpu([Scale((ms, ms)), + Clip([-mc, -mc, +mc, +mc]), Subplot(self.shape_var, self.box_var), ]) canvas.inserter.insert_vert(""" @@ -58,6 +58,34 @@ def attach(self, canvas): """.format(self.box_var, self.shape_var), 'header') + def add_boxes(self, canvas): + n, m = self.shape + n_boxes = n * m + a = 1 + .05 + + x0 = np.tile([-a, +a, +a, -a], n_boxes) + y0 = np.tile([-a, -a, +a, +a], n_boxes) + x1 = np.tile([+a, +a, -a, -a], n_boxes) + y1 = np.tile([-a, +a, +a, -a], n_boxes) + + box_index = [] + for i in range(n): + for j in range(m): + box_index.append([i, j]) + box_index = np.vstack(box_index) + box_index = np.repeat(box_index, 8, axis=0) + box_index = box_index.astype(np.float32) + + boxes = LineVisual() + + @boxes.set_canvas_transforms_filter + def _remove_clip(tc): + return tc.remove('Clip') + + canvas.add_visual(boxes) + boxes.set_data(x0=x0, y0=y0, x1=x1, y1=y1) + boxes.program['a_box_index'] = box_index + def update_program(self, program): program[self.shape_var] = self._shape # Only set the default box index if necessary. diff --git a/phy/plot/plot.py b/phy/plot/plot.py index 7ad960a70..fae472e2f 100644 --- a/phy/plot/plot.py +++ b/phy/plot/plot.py @@ -187,18 +187,9 @@ def __init__(self, shape=None, **kwargs): self.panzoom = PanZoom(aspect=None, constrain_bounds=NDC) self.panzoom.attach(self) - def build(self): - n, m = self.grid.shape - a = self.grid._margin_clip - for i in range(n): - for j in range(m): - self[i, j].lines(x0=[-a, +a, +a, -a], - y0=[-a, -a, +a, +a], - x1=[+a, +a, -a, -a], - y1=[-a, +a, +a, -a], - data_bounds=NDC, - ) - super(GridView, self).build() + # NOTE: we need to add the grid boxes after PanZoom so that + # they work with pan & zoom. + self.grid.add_boxes(self) class BoxedView(BaseView): diff --git a/phy/plot/tests/test_base.py b/phy/plot/tests/test_base.py index 8d5c897cc..2921e62d9 100644 --- a/phy/plot/tests/test_base.py +++ b/phy/plot/tests/test_base.py @@ -133,7 +133,7 @@ def set_data(self): v.set_data() v = TestVisual() - canvas.add_visual(v, transforms=[Subplot((10, 10), (0, 0))]) + canvas.add_visual(v) v.set_data() canvas.show() diff --git a/phy/plot/transform.py b/phy/plot/transform.py index 61d35cf2f..1771df14b 100644 --- a/phy/plot/transform.py +++ b/phy/plot/transform.py @@ -266,6 +266,15 @@ def get(self, class_name): if transform.__class__.__name__ == class_name: return transform + def remove(self, class_name): + """Remove a transform in the chain.""" + cpu_transforms = [t for t in self.cpu_transforms + if t.__class__.__name__ != class_name] + gpu_transforms = [t for t in self.gpu_transforms + if t.__class__.__name__ != class_name] + return (TransformChain().add_on_cpu(cpu_transforms). + add_on_gpu(gpu_transforms)) + def apply(self, arr): """Apply all CPU transforms on an array.""" for t in self.cpu_transforms: diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index 213785c41..4a3a04269 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -367,6 +367,8 @@ def validate(x0=None, data_bounds=None, ): + # TODO: single argument pos (n, 4) instead of x0 y0 etc. + # Get the number of lines. n_lines = _get_length(x0, y0, x1, y1) x0 = _validate_line_coord(x0, n_lines, -1) From 9b351e4d6404046ce417425440f54cbe82d764bd Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 16 Dec 2015 11:53:45 +0100 Subject: [PATCH 0757/1059] Feature scaling --- phy/cluster/manual/tests/test_views.py | 3 +++ phy/cluster/manual/views.py | 37 ++++++++++++++++++++++++-- phy/plot/interact.py | 2 ++ 3 files changed, 40 insertions(+), 2 deletions(-) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index 06c5a7b9f..9146e77f2 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -220,6 +220,9 @@ def best_channels(cluster_id): cluster_ids = np.unique(spike_clusters[spike_ids]) v.on_select(cluster_ids, spike_ids) + v.increase_feature_scaling() + v.decrease_feature_scaling() + # qtbot.stop() diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 6de01a81b..e4450bf9e 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -554,16 +554,27 @@ def _project_mask_depth(dim, masks, spike_clusters_rel=None, n_clusters=None): class FeatureView(GridView): normalization_percentile = .95 - normalization_n_spikes = 10000 + normalization_n_spikes = 1000 + _feature_scaling = 1. + + default_shortcuts = { + 'increase_feature_scaling': 'ctrl++', + 'decrease_feature_scaling': 'ctrl+-', + } def __init__(self, features=None, masks=None, spike_times=None, spike_clusters=None, + shortcuts=None, keys=None, ): + # Load default shortcuts, and override with any user shortcuts. + self.shortcuts = self.default_shortcuts.copy() + self.shortcuts.update(shortcuts or {}) + assert len(features.shape) == 3 self.n_spikes, self.n_channels, self.n_features = features.shape self.n_cols = self.n_features + 1 @@ -621,7 +632,7 @@ def _get_feature(self, dim, spike_ids=None): else: assert len(dim) == 2 ch, fet = dim - return f[:, ch, fet] + return f[:, ch, fet] * self._feature_scaling def _get_dim_bounds_single(self, dim): """Return the min and max of the bounds for a single dimension.""" @@ -688,6 +699,9 @@ def on_select(self, cluster_ids, spike_ids): if n_spikes == 0: return + self._cluster_ids = cluster_ids + self._spike_ids = spike_ids + # Get the masks for the selected spikes. masks = self.masks[spike_ids] sc = _get_spike_clusters_rel(self.spike_clusters, @@ -728,6 +742,25 @@ def attach(self, gui): gui.connect_(self.on_select) # gui.connect_(self.on_cluster) + self.actions = Actions(gui, default_shortcuts=self.shortcuts) + self.actions.add(self.increase_feature_scaling) + self.actions.add(self.decrease_feature_scaling) + + def increase_feature_scaling(self): + self.feature_scaling *= 1.2 + + def decrease_feature_scaling(self): + self.feature_scaling /= 1.2 + + @property + def feature_scaling(self): + return self._feature_scaling + + @feature_scaling.setter + def feature_scaling(self, value): + self._feature_scaling = value + self.on_select(self._cluster_ids, self._spike_ids) + # ----------------------------------------------------------------------------- # Correlogram view diff --git a/phy/plot/interact.py b/phy/plot/interact.py index 4e215fd23..69ae61625 100644 --- a/phy/plot/interact.py +++ b/phy/plot/interact.py @@ -59,6 +59,8 @@ def attach(self, canvas): 'header') def add_boxes(self, canvas): + if not isinstance(self.shape, tuple): + return n, m = self.shape n_boxes = n * m a = 1 + .05 From 6157b43e7f2541acd7ba96fe8c09268ace891332 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 16 Dec 2015 12:03:10 +0100 Subject: [PATCH 0758/1059] Fix bug in plot where scatter visuals were not collected --- phy/plot/plot.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/phy/plot/plot.py b/phy/plot/plot.py index fae472e2f..f8348e42f 100644 --- a/phy/plot/plot.py +++ b/phy/plot/plot.py @@ -71,10 +71,18 @@ def _accumulate(data_list, no_concat=()): return out +# NOTE: we ensure that we only create every type *once*, so that +# BaseView._items has only one key for any class. +_SCATTER_CLASSES = {} + + def _make_scatter_class(marker): """Return a temporary ScatterVisual class with a given marker.""" - return type('ScatterVisual' + marker.title(), - (ScatterVisual,), {'_default_marker': marker}) + name = 'ScatterVisual' + marker.title() + if name not in _SCATTER_CLASSES: + cls = type(name, (ScatterVisual,), {'_default_marker': marker}) + _SCATTER_CLASSES[name] = cls + return _SCATTER_CLASSES[name] #------------------------------------------------------------------------------ From f047c1f3cee73ecf0f09d395799563d36a83aff0 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 16 Dec 2015 12:13:18 +0100 Subject: [PATCH 0759/1059] Increase coverage --- phy/cluster/manual/views.py | 2 ++ phy/plot/interact.py | 8 ++++---- phy/plot/plot.py | 4 ---- phy/plot/tests/test_interact.py | 2 ++ phy/plot/tests/test_transform.py | 2 ++ phy/plot/transform.py | 11 ++++++----- 6 files changed, 16 insertions(+), 13 deletions(-) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index e4450bf9e..75c86f3ee 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -729,6 +729,8 @@ def on_select(self, cluster_ids, spike_ids): masks=masks, spike_clusters_rel=sc, ) + # Add the boxes. + self.grid.add_boxes(self, self.shape) def attach(self, gui): """Attach the view to the GUI.""" diff --git a/phy/plot/interact.py b/phy/plot/interact.py index 69ae61625..be795c42f 100644 --- a/phy/plot/interact.py +++ b/phy/plot/interact.py @@ -58,10 +58,10 @@ def attach(self, canvas): """.format(self.box_var, self.shape_var), 'header') - def add_boxes(self, canvas): - if not isinstance(self.shape, tuple): - return - n, m = self.shape + def add_boxes(self, canvas, shape=None): + shape = shape or self.shape + assert isinstance(shape, tuple) + n, m = shape n_boxes = n * m a = 1 + .05 diff --git a/phy/plot/plot.py b/phy/plot/plot.py index f8348e42f..4a7ed6a15 100644 --- a/phy/plot/plot.py +++ b/phy/plot/plot.py @@ -195,10 +195,6 @@ def __init__(self, shape=None, **kwargs): self.panzoom = PanZoom(aspect=None, constrain_bounds=NDC) self.panzoom.attach(self) - # NOTE: we need to add the grid boxes after PanZoom so that - # they work with pan & zoom. - self.grid.add_boxes(self) - class BoxedView(BaseView): """Subplots at arbitrary positions""" diff --git a/phy/plot/tests/test_interact.py b/phy/plot/tests/test_interact.py index c1fda6e1a..414a83472 100644 --- a/phy/plot/tests/test_interact.py +++ b/phy/plot/tests/test_interact.py @@ -84,6 +84,8 @@ def test_grid_1(qtbot, canvas): grid = Grid((2, 3)) _create_visual(qtbot, canvas, grid, box_index) + grid.add_boxes(canvas) + # qtbot.stop() diff --git a/phy/plot/tests/test_transform.py b/phy/plot/tests/test_transform.py index a1bf29290..ae31ccf0a 100644 --- a/phy/plot/tests/test_transform.py +++ b/phy/plot/tests/test_transform.py @@ -226,6 +226,8 @@ def test_transform_chain_complete(array): ae(t.apply(array), [[0, .5], [1, 1.5]]) + assert len(t.remove('Scale').cpu_transforms) == len(t.cpu_transforms) - 2 + def test_transform_chain_add(): tc = TransformChain() diff --git a/phy/plot/transform.py b/phy/plot/transform.py index 1771df14b..5e4a055b1 100644 --- a/phy/plot/transform.py +++ b/phy/plot/transform.py @@ -266,12 +266,13 @@ def get(self, class_name): if transform.__class__.__name__ == class_name: return transform - def remove(self, class_name): + def _remove_transform(self, transforms, name): + return [t for t in transforms if t.__class__.__name__ != name] + + def remove(self, name): """Remove a transform in the chain.""" - cpu_transforms = [t for t in self.cpu_transforms - if t.__class__.__name__ != class_name] - gpu_transforms = [t for t in self.gpu_transforms - if t.__class__.__name__ != class_name] + cpu_transforms = self._remove_transform(self.cpu_transforms, name) + gpu_transforms = self._remove_transform(self.gpu_transforms, name) return (TransformChain().add_on_cpu(cpu_transforms). add_on_gpu(gpu_transforms)) From 545c8047384d99351dbaaa47dce44d5455718b1b Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 16 Dec 2015 14:41:33 +0100 Subject: [PATCH 0760/1059] Refactor view layout interface --- docs/plot.md | 16 +++---- phy/plot/__init__.py | 2 +- phy/plot/plot.py | 83 ++++++++++++------------------------- phy/plot/tests/test_plot.py | 20 ++++----- 4 files changed, 47 insertions(+), 74 deletions(-) diff --git a/docs/plot.md b/docs/plot.md index 71e6de8cc..bd11dfa5b 100644 --- a/docs/plot.md +++ b/docs/plot.md @@ -14,11 +14,11 @@ Let's create a simple view with a scatter plot. ```python >>> import numpy as np ->>> from phy.plot import SimpleView, GridView, BoxedView, StackedView +>>> from phy.plot import View ``` ```python ->>> view = SimpleView() +>>> view = View() ... >>> n = 1000 >>> x, y = np.random.randn(2, n) @@ -40,12 +40,14 @@ Note that you can pan and zoom with the mouse and keyboard. The other plotting commands currently supported are `plot()` and `hist()`. We're planning to add support for text in the near future. +Several layouts are supported for subplots. + ## Grid view -The `GridView` lets you create multiple subplots arranged in a grid. Subplots are all individually clipped, which means that their viewports never overlap across the grid boundaries. Here is an example: +The Grid view lets you create multiple subplots arranged in a grid (like in matplotlib). Subplots are all individually clipped, which means that their viewports never overlap across the grid boundaries. Here is an example: ```python ->>> view = GridView((1, 2)) # the shape is `(n_rows, n_cols)` +>>> view = View(layout='grid', shape=(1, 2)) # the shape is `(n_rows, n_cols)` ... >>> x = np.linspace(-10., 10., 1000) ... @@ -65,7 +67,7 @@ Note that there are no axes at this point, but we'll be working on it. Also, ind The stacked view lets you stack several subplots vertically with no clipping. An example is a trace view showing a multichannel time-dependent signal. ```python ->>> view = StackedView(50) +>>> view = View(layout='stacked', n_plots=50) ... >>> with view.building(): ... for i in range(view.n_plots): @@ -87,7 +89,7 @@ The boxed view lets you put subplots at arbitrary locations. You can dynamically >>> y = np.sin(t) >>> box_pos = np.c_[x, y] ... ->>> view = BoxedView(box_pos=box_pos) +>>> view = View(layout='boxed', box_pos=box_pos) ... >>> with view.building(): ... for i in range(view.n_plots): @@ -105,7 +107,7 @@ You can use `ctrl+arrows` and `shift+arrows` to change the scaling of the positi Data normalization is supported via the `data_bounds` keyword. This is a 4-tuple `(xmin, ymin, xmax, ymax)` with the coordinates of the viewport in the data coordinate system. By default, this is obtained with the min and max of the data. Here is an example: ```python ->>> view = StackedView(2) +>>> view = View(layout='stacked', n_plots=2) ... >>> n = 100 >>> x = np.linspace(0., 1., n) diff --git a/phy/plot/__init__.py b/phy/plot/__init__.py index 8554d8e0f..e5d3c1f12 100644 --- a/phy/plot/__init__.py +++ b/phy/plot/__init__.py @@ -12,7 +12,7 @@ from vispy import config -from .plot import SimpleView, GridView, BoxedView, StackedView # noqa +from .plot import View # noqa from .transform import Translate, Scale, Range, Subplot, NDC from .panzoom import PanZoom from .utils import _get_linear_x diff --git a/phy/plot/plot.py b/phy/plot/plot.py index 4a7ed6a15..6899265de 100644 --- a/phy/plot/plot.py +++ b/phy/plot/plot.py @@ -72,7 +72,7 @@ def _accumulate(data_list, no_concat=()): # NOTE: we ensure that we only create every type *once*, so that -# BaseView._items has only one key for any class. +# View._items has only one key for any class. _SCATTER_CLASSES = {} @@ -89,14 +89,37 @@ def _make_scatter_class(marker): # Plotting interface #------------------------------------------------------------------------------ -class BaseView(BaseCanvas): +class View(BaseCanvas): """High-level plotting canvas.""" _default_box_index = (0,) - def __init__(self, **kwargs): + def __init__(self, layout=None, shape=None, n_plots=None, + box_bounds=None, box_pos=None, box_size=None, **kwargs): if not kwargs.get('keys', None): kwargs['keys'] = None - super(BaseView, self).__init__(**kwargs) + super(View, self).__init__(**kwargs) + + if layout == 'grid': + self._default_box_index = (0, 0) + self.grid = Grid(shape) + self.grid.attach(self) + + elif layout == 'boxed': + self.n_plots = (len(box_bounds) + if box_bounds is not None else len(box_pos)) + self.boxed = Boxed(box_bounds=box_bounds, + box_pos=box_pos, + box_size=box_size) + self.boxed.attach(self) + + elif layout == 'stacked': + self.n_plots = n_plots + self.stacked = Stacked(n_plots, margin=.1) + self.stacked.attach(self) + + self.panzoom = PanZoom(aspect=None, constrain_bounds=NDC) + self.panzoom.attach(self) + self.clear() def clear(self): @@ -171,55 +194,3 @@ def building(self): self.clear() yield self.build() - - -class SimpleView(BaseView): - """A simple view.""" - def __init__(self, shape=None, **kwargs): - super(SimpleView, self).__init__(**kwargs) - - self.panzoom = PanZoom(aspect=None, constrain_bounds=NDC) - self.panzoom.attach(self) - - -class GridView(BaseView): - """A 2D grid with clipping.""" - _default_box_index = (0, 0) - - def __init__(self, shape=None, **kwargs): - super(GridView, self).__init__(**kwargs) - - self.grid = Grid(shape) - self.grid.attach(self) - - self.panzoom = PanZoom(aspect=None, constrain_bounds=NDC) - self.panzoom.attach(self) - - -class BoxedView(BaseView): - """Subplots at arbitrary positions""" - def __init__(self, box_bounds=None, box_pos=None, box_size=None, **kwargs): - super(BoxedView, self).__init__(**kwargs) - self.n_plots = (len(box_bounds) - if box_bounds is not None else len(box_pos)) - - self.boxed = Boxed(box_bounds=box_bounds, - box_pos=box_pos, - box_size=box_size) - self.boxed.attach(self) - - self.panzoom = PanZoom(aspect=None, constrain_bounds=NDC) - self.panzoom.attach(self) - - -class StackedView(BaseView): - """Stacked subplots""" - def __init__(self, n_plots, **kwargs): - super(StackedView, self).__init__(**kwargs) - self.n_plots = n_plots - - self.stacked = Stacked(n_plots, margin=.1) - self.stacked.attach(self) - - self.panzoom = PanZoom(aspect=None, constrain_bounds=NDC) - self.panzoom.attach(self) diff --git a/phy/plot/tests/test_plot.py b/phy/plot/tests/test_plot.py index 125ae6fbe..1a92cfe2c 100644 --- a/phy/plot/tests/test_plot.py +++ b/phy/plot/tests/test_plot.py @@ -10,7 +10,7 @@ import numpy as np from ..panzoom import PanZoom -from ..plot import BaseView, SimpleView, GridView, BoxedView, StackedView +from ..plot import View from ..utils import _get_linear_x @@ -32,7 +32,7 @@ def _show(qtbot, view, stop=False): #------------------------------------------------------------------------------ def test_building(qtbot): - view = BaseView(keys='interactive') + view = View(keys='interactive') n = 1000 x = np.random.randn(n) @@ -47,7 +47,7 @@ def test_building(qtbot): def test_simple_view(qtbot): - view = SimpleView() + view = View() n = 1000 x = np.random.randn(n) @@ -58,7 +58,7 @@ def test_simple_view(qtbot): def test_grid_scatter(qtbot): - view = GridView((2, 3)) + view = View(layout='grid', shape=(2, 3)) n = 100 assert isinstance(view.panzoom, PanZoom) @@ -84,7 +84,7 @@ def test_grid_scatter(qtbot): def test_grid_plot(qtbot): - view = GridView((1, 2)) + view = View(layout='grid', shape=(1, 2)) n_plots, n_samples = 5, 50 x = _get_linear_x(n_plots, n_samples) @@ -97,7 +97,7 @@ def test_grid_plot(qtbot): def test_grid_hist(qtbot): - view = GridView((3, 3)) + view = View(layout='grid', shape=(3, 3)) hist = np.random.rand(3, 3, 20) @@ -110,7 +110,7 @@ def test_grid_hist(qtbot): def test_grid_lines(qtbot): - view = GridView((1, 2)) + view = View(layout='grid', shape=(1, 2)) view[0, 0].lines(y0=-.5, y1=-.5) view[0, 1].lines(y0=+.5, y1=+.5) @@ -119,7 +119,7 @@ def test_grid_lines(qtbot): def test_grid_complete(qtbot): - view = GridView((2, 2)) + view = View(layout='grid', shape=(2, 2)) t = _get_linear_x(1, 1000).ravel() view[0, 0].scatter(*np.random.randn(2, 100)) @@ -132,7 +132,7 @@ def test_grid_complete(qtbot): def test_stacked_complete(qtbot): - view = StackedView(3) + view = View(layout='stacked', n_plots=3) t = _get_linear_x(1, 1000).ravel() view[0].scatter(*np.random.randn(2, 100)) @@ -154,7 +154,7 @@ def test_boxed_complete(qtbot): b = np.zeros((n, 4)) b[:, 0] = b[:, 1] = np.linspace(-1., 1. - 2. / 3., n) b[:, 2] = b[:, 3] = np.linspace(-1. + 2. / 3., 1., n) - view = BoxedView(b) + view = View(layout='boxed', box_bounds=b) t = _get_linear_x(1, 1000).ravel() view[0].scatter(*np.random.randn(2, 100)) From 8c8e8117e04567cd5e1f7ea1c5aeb332cbacad4e Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 16 Dec 2015 14:45:10 +0100 Subject: [PATCH 0761/1059] Update views --- phy/cluster/manual/views.py | 31 +++++++++++++++++++++---------- 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 75c86f3ee..d08f284f8 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -14,8 +14,7 @@ from phy.io.array import _index_of, _get_padded, get_excerpts from phy.gui import Actions -from phy.plot import (BoxedView, StackedView, GridView, - _get_linear_x) +from phy.plot import View, _get_linear_x from phy.plot.utils import _get_boxes from phy.stats import correlograms @@ -128,7 +127,7 @@ def _get_color(masks, spike_clusters_rel=None, n_clusters=None): # Waveform view # ----------------------------------------------------------------------------- -class WaveformView(BoxedView): +class WaveformView(View): normalization_percentile = .95 normalization_n_spikes = 1000 overlap = False @@ -161,7 +160,10 @@ def __init__(self, # Initialize the view. box_bounds = _get_boxes(channel_positions) - super(WaveformView, self).__init__(box_bounds, keys=keys) + super(WaveformView, self).__init__(layout='boxed', + box_bounds=box_bounds, + keys=keys, + ) # Waveforms. assert waveforms.ndim == 3 @@ -252,7 +254,7 @@ def toggle_waveform_overlap(self): # Trace view # ----------------------------------------------------------------------------- -class TraceView(StackedView): +class TraceView(View): interval_duration = .5 # default duration of the interval shift_amount = .1 default_shortcuts = { @@ -310,7 +312,10 @@ def __init__(self, self.spike_times = self.spike_clusters = self.masks = None # Initialize the view. - super(TraceView, self).__init__(self.n_channels, keys=keys) + super(TraceView, self).__init__(layout='stacked', + n_plots=self.n_channels, + keys=keys, + ) # Initial interval. self.cluster_ids = [] @@ -552,7 +557,7 @@ def _project_mask_depth(dim, masks, spike_clusters_rel=None, n_clusters=None): return m, d -class FeatureView(GridView): +class FeatureView(View): normalization_percentile = .95 normalization_n_spikes = 1000 _feature_scaling = 1. @@ -582,7 +587,10 @@ def __init__(self, self.features = features # Initialize the view. - super(FeatureView, self).__init__(self.shape, keys=keys) + super(FeatureView, self).__init__(layout='grid', + shape=self.shape, + keys=keys, + ) # Feature normalization. self.data_bounds = _get_data_bounds(features, @@ -768,7 +776,7 @@ def feature_scaling(self, value): # Correlogram view # ----------------------------------------------------------------------------- -class CorrelogramView(GridView): +class CorrelogramView(View): excerpt_size = 10000 n_excerpts = 100 uniform_normalization = False @@ -812,7 +820,10 @@ def __init__(self, self.n_spikes, = self.spike_times.shape # Initialize the view. - super(CorrelogramView, self).__init__(keys=keys) + super(CorrelogramView, self).__init__(layout='grid', + shape=(1, 1), + keys=keys, + ) # Spike clusters. assert spike_clusters.shape == (self.n_spikes,) From dcb0ea6282fad61958e12b46af57331b4f7fb23e Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 16 Dec 2015 15:03:37 +0100 Subject: [PATCH 0762/1059] Refactor views --- phy/cluster/manual/views.py | 182 +++++++++++++++--------------------- 1 file changed, 73 insertions(+), 109 deletions(-) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index d08f284f8..1cdd1b08b 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -123,11 +123,50 @@ def _get_color(masks, spike_clusters_rel=None, n_clusters=None): return color +# ----------------------------------------------------------------------------- +# Manual clustering view +# ----------------------------------------------------------------------------- + +class ManualClusteringView(View): + default_shortcuts = { + } + + def __init__(self, shortcuts=None, **kwargs): + + # Load default shortcuts, and override with any user shortcuts. + self.shortcuts = self.default_shortcuts.copy() + self.shortcuts.update(shortcuts or {}) + + self.cluster_ids = None + self.spike_ids = None + + super(ManualClusteringView, self).__init__(**kwargs) + + def on_select(self, cluster_ids=None, spike_ids=None): + cluster_ids = (cluster_ids if cluster_ids is not None + else self.cluster_ids) + spike_ids = (spike_ids if spike_ids is not None + else self.spike_ids) + self.cluster_ids = list(cluster_ids) + self.spike_ids = spike_ids + + def attach(self, gui): + """Attach the view to the GUI.""" + + # Disable keyboard pan so that we can use arrows as global shortcuts + # in the GUI. + self.panzoom.enable_keyboard_pan = False + + gui.add_view(self) + gui.connect_(self.on_select) + self.actions = Actions(gui, default_shortcuts=self.shortcuts) + + # ----------------------------------------------------------------------------- # Waveform view # ----------------------------------------------------------------------------- -class WaveformView(View): +class WaveformView(ManualClusteringView): normalization_percentile = .95 normalization_n_spikes = 1000 overlap = False @@ -141,29 +180,18 @@ def __init__(self, masks=None, spike_clusters=None, channel_positions=None, - shortcuts=None, - keys=None, - ): + **kwargs): """ The channel order in waveforms needs to correspond to the one in channel_positions. """ - - # Load default shortcuts, and override with any user shortcuts. - self.shortcuts = self.default_shortcuts.copy() - self.shortcuts.update(shortcuts or {}) - - self._cluster_ids = None - self._spike_ids = None - # Initialize the view. box_bounds = _get_boxes(channel_positions) super(WaveformView, self).__init__(layout='boxed', box_bounds=box_bounds, - keys=keys, - ) + **kwargs) # Waveforms. assert waveforms.ndim == 3 @@ -186,15 +214,15 @@ def __init__(self, assert channel_positions.shape == (self.n_channels, 2) self.channel_positions = channel_positions - def on_select(self, cluster_ids, spike_ids): + def on_select(self, cluster_ids=None, spike_ids=None): + super(WaveformView, self).on_select(cluster_ids=cluster_ids, + spike_ids=spike_ids) + cluster_ids, spike_ids = self.cluster_ids, self.spike_ids n_clusters = len(cluster_ids) n_spikes = len(spike_ids) if n_spikes == 0: return - self._cluster_ids = cluster_ids - self._spike_ids = spike_ids - # Relative spike clusters. spike_clusters_rel = _get_spike_clusters_rel(self.spike_clusters, spike_ids, @@ -232,29 +260,19 @@ def on_select(self, cluster_ids, spike_ids): def attach(self, gui): """Attach the view to the GUI.""" - - # Disable keyboard pan so that we can use arrows as global shortcuts - # in the GUI. - self.panzoom.enable_keyboard_pan = False - - gui.add_view(self) - - gui.connect_(self.on_select) - # gui.connect_(self.on_cluster) - - self.actions = Actions(gui, default_shortcuts=self.shortcuts) + super(WaveformView, self).attach(gui) self.actions.add(self.toggle_waveform_overlap) def toggle_waveform_overlap(self): self.overlap = not self.overlap - self.on_select(self._cluster_ids, self._spike_ids) + self.on_select() # ----------------------------------------------------------------------------- # Trace view # ----------------------------------------------------------------------------- -class TraceView(View): +class TraceView(ManualClusteringView): interval_duration = .5 # default duration of the interval shift_amount = .1 default_shortcuts = { @@ -269,13 +287,7 @@ def __init__(self, spike_clusters=None, masks=None, n_samples_per_spike=None, - shortcuts=None, - keys=None, - ): - - # Load default shortcuts, and override with any user shortcuts. - self.shortcuts = self.default_shortcuts.copy() - self.shortcuts.update(shortcuts or {}) + **kwargs): # Sample rate. assert sample_rate > 0 @@ -314,11 +326,9 @@ def __init__(self, # Initialize the view. super(TraceView, self).__init__(layout='stacked', n_plots=self.n_channels, - keys=keys, - ) + **kwargs) # Initial interval. - self.cluster_ids = [] self.set_interval((0., self.interval_duration)) def _load_traces(self, interval): @@ -446,23 +456,14 @@ def set_interval(self, interval): self.build() self.update() - def on_select(self, cluster_ids, spike_ids): - self.cluster_ids = list(cluster_ids) + def on_select(self, cluster_ids=None, spike_ids=None): + super(TraceView, self).on_select(cluster_ids=cluster_ids, + spike_ids=spike_ids) self.set_interval(self.interval) def attach(self, gui): """Attach the view to the GUI.""" - - # Disable keyboard pan so that we can use arrows as global shortcuts - # in the GUI. - self.panzoom.enable_keyboard_pan = False - - gui.add_view(self) - - gui.connect_(self.on_select) - # gui.connect_(self.on_cluster) - - self.actions = Actions(gui, default_shortcuts=self.shortcuts) + super(TraceView, self).attach(gui) self.actions.add(self.go_to, alias='tg') self.actions.add(self.shift, alias='ts') self.actions.add(self.go_right) @@ -557,7 +558,7 @@ def _project_mask_depth(dim, masks, spike_clusters_rel=None, n_clusters=None): return m, d -class FeatureView(View): +class FeatureView(ManualClusteringView): normalization_percentile = .95 normalization_n_spikes = 1000 _feature_scaling = 1. @@ -572,13 +573,7 @@ def __init__(self, masks=None, spike_times=None, spike_clusters=None, - shortcuts=None, - keys=None, - ): - - # Load default shortcuts, and override with any user shortcuts. - self.shortcuts = self.default_shortcuts.copy() - self.shortcuts.update(shortcuts or {}) + **kwargs): assert len(features.shape) == 3 self.n_spikes, self.n_channels, self.n_features = features.shape @@ -589,8 +584,7 @@ def __init__(self, # Initialize the view. super(FeatureView, self).__init__(layout='grid', shape=self.shape, - keys=keys, - ) + **kwargs) # Feature normalization. self.data_bounds = _get_data_bounds(features, @@ -702,14 +696,14 @@ def set_best_channels_func(self, func): """Set a function `cluster_id => list of best channels`.""" self.best_channels_func = func - def on_select(self, cluster_ids, spike_ids): + def on_select(self, cluster_ids=None, spike_ids=None): + super(FeatureView, self).on_select(cluster_ids=cluster_ids, + spike_ids=spike_ids) + cluster_ids, spike_ids = self.cluster_ids, self.spike_ids n_spikes = len(spike_ids) if n_spikes == 0: return - self._cluster_ids = cluster_ids - self._spike_ids = spike_ids - # Get the masks for the selected spikes. masks = self.masks[spike_ids] sc = _get_spike_clusters_rel(self.spike_clusters, @@ -742,17 +736,7 @@ def on_select(self, cluster_ids, spike_ids): def attach(self, gui): """Attach the view to the GUI.""" - - # Disable keyboard pan so that we can use arrows as global shortcuts - # in the GUI. - self.panzoom.enable_keyboard_pan = False - - gui.add_view(self) - - gui.connect_(self.on_select) - # gui.connect_(self.on_cluster) - - self.actions = Actions(gui, default_shortcuts=self.shortcuts) + super(FeatureView, self).attach(gui) self.actions.add(self.increase_feature_scaling) self.actions.add(self.decrease_feature_scaling) @@ -769,14 +753,14 @@ def feature_scaling(self): @feature_scaling.setter def feature_scaling(self, value): self._feature_scaling = value - self.on_select(self._cluster_ids, self._spike_ids) + self.on_select() # ----------------------------------------------------------------------------- # Correlogram view # ----------------------------------------------------------------------------- -class CorrelogramView(View): +class CorrelogramView(ManualClusteringView): excerpt_size = 10000 n_excerpts = 100 uniform_normalization = False @@ -793,16 +777,7 @@ def __init__(self, window_size=None, excerpt_size=None, n_excerpts=None, - shortcuts=None, - keys=None, - ): - - # Load default shortcuts, and override with any user shortcuts. - self.shortcuts = self.default_shortcuts.copy() - self.shortcuts.update(shortcuts or {}) - - self._cluster_ids = None - self._spike_ids = None + **kwargs): assert sample_rate > 0 self.sample_rate = sample_rate @@ -822,8 +797,7 @@ def __init__(self, # Initialize the view. super(CorrelogramView, self).__init__(layout='grid', shape=(1, 1), - keys=keys, - ) + **kwargs) # Spike clusters. assert spike_clusters.shape == (self.n_spikes,) @@ -858,10 +832,10 @@ def _compute_correlograms(self, cluster_ids): return ccg - def on_select(self, cluster_ids, spike_ids): - self._cluster_ids = cluster_ids - self._spike_ids = spike_ids - + def on_select(self, cluster_ids=None, spike_ids=None): + super(CorrelogramView, self).on_select(cluster_ids=cluster_ids, + spike_ids=spike_ids) + cluster_ids, spike_ids = self.cluster_ids, self.spike_ids n_clusters = len(cluster_ids) n_spikes = len(spike_ids) if n_spikes == 0: @@ -886,19 +860,9 @@ def on_select(self, cluster_ids, spike_ids): def toggle_normalization(self): self.uniform_normalization = not self.uniform_normalization - self.on_select(self._cluster_ids, self._spike_ids) + self.on_select() def attach(self, gui): """Attach the view to the GUI.""" - - # Disable keyboard pan so that we can use arrows as global shortcuts - # in the GUI. - self.panzoom.enable_keyboard_pan = False - - gui.add_view(self) - - gui.connect_(self.on_select) - # gui.connect_(self.on_cluster) - - self.actions = Actions(gui, default_shortcuts=self.shortcuts) + super(CorrelogramView, self).attach(gui) self.actions.add(self.toggle_normalization, shortcut='n') From e7443c05f330c6bda058b4e15474dcd5305c58f3 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 16 Dec 2015 15:18:57 +0100 Subject: [PATCH 0763/1059] Minor fixes --- phy/gui/gui.py | 2 +- phy/utils/cli.py | 10 ++++------ 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/phy/gui/gui.py b/phy/gui/gui.py index fc30fffdc..172474e27 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -73,7 +73,7 @@ def load_gui_plugins(gui, plugins=None, session=None): # Attach the plugins to the GUI. for plugin in plugins: - logger.info("Attach plugin `%s` to %s.", plugin, name) + logger.debug("Attach plugin `%s` to %s.", plugin, name) get_plugin(plugin)().attach_to_gui(gui, session) diff --git a/phy/utils/cli.py b/phy/utils/cli.py index 68144ed53..7891df690 100644 --- a/phy/utils/cli.py +++ b/phy/utils/cli.py @@ -8,7 +8,6 @@ # Imports #------------------------------------------------------------------------------ -import gzip import logging import os import os.path as op @@ -43,10 +42,9 @@ def exceptionHandler(exception_type, exception, traceback): # pragma: no cover def _add_log_file(filename): - """Create a `phy.log.gz` log file with DEBUG level in the + """Create a `phy.log` log file with DEBUG level in the current directory.""" - log_file = gzip.open(filename, mode='wt') - handler = logging.StreamHandler(log_file) + handler = logging.FileHandler(filename) handler.setLevel(logging.DEBUG) formatter = _Formatter(fmt=_logger_fmt, @@ -68,7 +66,7 @@ def phy(ctx): using `attach_to_cli()` and the `click` library.""" # Create a `phy.log` log file with DEBUG level in the current directory. - _add_log_file(op.join(os.getcwd(), 'phy.log.gz')) + _add_log_file(op.join(os.getcwd(), 'phy.log')) #------------------------------------------------------------------------------ @@ -87,7 +85,7 @@ def load_cli_plugins(cli): for plugin in plugins: if not hasattr(plugin, 'attach_to_cli'): # pragma: no cover continue - logger.info("Attach plugin `%s` to CLI.", plugin.__name__) + logger.debug("Attach plugin `%s` to CLI.", plugin.__name__) # NOTE: plugin is a class, so we need to instantiate it. plugin().attach_to_cli(cli) From 7a97d5c8402ab5d755f4b372c4b61985ae93fe9c Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 16 Dec 2015 15:30:10 +0100 Subject: [PATCH 0764/1059] Add grouped_mean() function --- phy/io/array.py | 23 +++++++++++++++++++++++ phy/io/tests/test_array.py | 7 +++++++ 2 files changed, 30 insertions(+) diff --git a/phy/io/array.py b/phy/io/array.py index 42c2628af..b2d1d1700 100644 --- a/phy/io/array.py +++ b/phy/io/array.py @@ -358,6 +358,29 @@ def _flatten_per_cluster(per_cluster): return np.sort(np.concatenate(list(per_cluster.values()))).astype(np.int64) +def grouped_mean(arr, spike_clusters): + """Compute the mean of a spike-dependent quantity for every cluster. + + The two arguments should be 1D array with `n_spikes` elements. + + The output is a 1D array with `n_clusters` elements. The clusters are + sorted in increasing order. + + """ + arr = np.asarray(arr) + spike_clusters = np.asarray(spike_clusters) + assert arr.ndim == 1 + assert arr.shape[0] == len(spike_clusters) + cluster_ids = _unique(spike_clusters) + spike_clusters_rel = _index_of(spike_clusters, cluster_ids) + spike_counts = np.bincount(spike_clusters_rel) + assert len(spike_counts) == len(cluster_ids) + t = np.zeros(len(cluster_ids)) + # Compute the sum with possible repetitions. + np.add.at(t, spike_clusters_rel, arr) + return t / spike_counts + + def regular_subset(spikes, n_spikes_max=None, offset=0): """Prune the current selection to get at most n_spikes_max spikes.""" assert spikes is not None diff --git a/phy/io/tests/test_array.py b/phy/io/tests/test_array.py index 634bf3c2d..b4d2fa9de 100644 --- a/phy/io/tests/test_array.py +++ b/phy/io/tests/test_array.py @@ -23,6 +23,7 @@ regular_subset, excerpts, data_chunk, + grouped_mean, get_excerpts, _range_from_slice, _pad, @@ -344,6 +345,12 @@ def test_flatten_per_cluster(): ae(arr, [2, 3, 5, 7, 11]) +def test_grouped_mean(): + spike_clusters = np.array([2, 3, 2, 2, 5]) + arr = spike_clusters * 10 + ae(grouped_mean(arr, spike_clusters), [20, 30, 50]) + + def test_select_spikes(): with raises(AssertionError): select_spikes() From 841ab95cda11d41c5e7bfdc17b86db54ad74ebdd Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 16 Dec 2015 17:13:29 +0100 Subject: [PATCH 0765/1059] WIP: GUIState and create_gui() function --- phy/gui/__init__.py | 2 +- phy/gui/gui.py | 73 ++++++++++++++++++++++++--------------- phy/gui/tests/test_gui.py | 29 +++++++++++----- 3 files changed, 68 insertions(+), 36 deletions(-) diff --git a/phy/gui/__init__.py b/phy/gui/__init__.py index 1e7576a2e..afbbbc979 100644 --- a/phy/gui/__init__.py +++ b/phy/gui/__init__.py @@ -4,6 +4,6 @@ """GUI routines.""" from .qt import require_qt, create_app, run_app -from .gui import GUI, load_gui_plugins +from .gui import GUI, GUIState, create_gui from .actions import Actions from .widgets import HTMLWidget, Table diff --git a/phy/gui/gui.py b/phy/gui/gui.py index 172474e27..7f29a81a0 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -14,7 +14,7 @@ Qt, QSize, QMetaObject) from .actions import Actions, _show_shortcuts, Snippets from phy.utils.event import EventEmitter -from phy.utils import load_master_config +from phy.utils import load_master_config, Bunch, _load_json, _save_json from phy.utils.plugin import get_plugin logger = logging.getLogger(__name__) @@ -51,32 +51,6 @@ def _try_get_matplotlib_canvas(view): return view -def load_gui_plugins(gui, plugins=None, session=None): - """Attach a list of plugins to a GUI. - - By default, the list of plugins is taken from the `c.TheGUI.plugins` - parameter, where `TheGUI` is the name of the GUI class. - - """ - session = session or {} - plugins = plugins or [] - - # GUI name. - name = gui.name - - # If no plugins are specified, load the master config and - # get the list of user plugins to attach to the GUI. - config = load_master_config() - plugins_conf = config[name].plugins - plugins_conf = plugins_conf if isinstance(plugins_conf, list) else [] - plugins.extend(plugins_conf) - - # Attach the plugins to the GUI. - for plugin in plugins: - logger.debug("Attach plugin `%s` to %s.", plugin, name) - get_plugin(plugin)().attach_to_gui(gui, session) - - class DockWidget(QDockWidget): """A QDockWidget that can emit events.""" def __init__(self, *args, **kwargs): @@ -322,3 +296,48 @@ def restore_geometry_state(self, gs): self.restoreGeometry((gs['geometry'])) if gs.get('state', None): self.restoreState((gs['state'])) + + +# ----------------------------------------------------------------------------- +# GUI state, creator, plugins +# ----------------------------------------------------------------------------- + +class GUIState(Bunch): + def __init__(self, geometry_state=None, plugins=None, **kwargs): + super(GUIState, self).__init__(geomety_state=geometry_state, + plugins=plugins or [], + **kwargs) + + def to_json(self, filename): + _save_json(filename, self) + + def from_json(self, filename): + self.update(_load_json(filename)) + + +def create_gui(model=None, state=None): + """Create a GUI with a model and a GUI state. + + By default, the list of plugins is taken from the `c.TheGUI.plugins` + parameter, where `TheGUI` is the name of the GUI class. + + """ + gui = GUI() + state = state or GUIState() + plugins = state.plugins + # GUI name. + name = gui.name + + # If no plugins are specified, load the master config and + # get the list of user plugins to attach to the GUI. + config = load_master_config() + plugins_conf = config[name].plugins + plugins_conf = plugins_conf if isinstance(plugins_conf, list) else [] + plugins.extend(plugins_conf) + + # Attach the plugins to the GUI. + for plugin in plugins: + logger.debug("Attach plugin `%s` to %s.", plugin, name) + get_plugin(plugin)().attach_to_gui(gui, state=state, model=model) + + return gui diff --git a/phy/gui/tests/test_gui.py b/phy/gui/tests/test_gui.py index 839d62cc9..9fa585d95 100644 --- a/phy/gui/tests/test_gui.py +++ b/phy/gui/tests/test_gui.py @@ -6,10 +6,13 @@ # Imports #------------------------------------------------------------------------------ +import os.path as op + from pytest import raises from ..qt import Qt, QApplication, QWidget -from ..gui import (GUI, load_gui_plugins, +from ..gui import (GUI, GUIState, + create_gui, _try_get_matplotlib_canvas, _try_get_vispy_canvas, ) @@ -93,19 +96,29 @@ def on_close_widget(): gui.default_actions.exit() -def test_load_gui_plugins(gui, tempdir): +def test_gui_state(tempdir): + path = op.join(tempdir, 'state.json') + + state = GUIState(hello='world') + state.to_json(path) + + state = GUIState() + state.from_json(path) + assert state.hello == 'world' + - load_gui_plugins(gui) +def test_create_gui_1(qapp, tempdir): _tmp = [] class MyPlugin(IPlugin): - def attach_to_gui(self, gui, session): - _tmp.append(session) + def attach_to_gui(self, gui, model=None, state=None): + _tmp.append(state.hello) - load_gui_plugins(gui, plugins=['MyPlugin'], session='hello') + gui = create_gui(state=GUIState(plugins=['MyPlugin'], hello='world')) + assert gui - assert _tmp == ['hello'] + assert _tmp == ['world'] def test_gui_component(gui): @@ -130,7 +143,7 @@ def test_gui_status_message(gui): assert gui.status_message == ':hello world!' -def test_gui_state(qtbot): +def test_gui_geometry_state(qtbot): _gs = [] gui = GUI(size=(100, 100)) qtbot.addWidget(gui) From bb4a5efca392592f67f549d7b1c6de806dbf4bb3 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 16 Dec 2015 17:41:00 +0100 Subject: [PATCH 0766/1059] Add ContextPlugin --- phy/io/context.py | 9 ++++++++- phy/io/tests/test_context.py | 10 +++++++++- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/phy/io/context.py b/phy/io/context.py index b4a70aba9..e70d21186 100644 --- a/phy/io/context.py +++ b/phy/io/context.py @@ -23,7 +23,8 @@ "Install it with `conda install dask`.") from .array import read_array, write_array -from phy.utils import Bunch, _save_json, _load_json, _ensure_dir_exists +from phy.utils import (Bunch, _save_json, _load_json, _ensure_dir_exists, + IPlugin,) logger = logging.getLogger(__name__) @@ -278,6 +279,12 @@ def __setstate__(self, state): self._set_memory(state['cache_dir']) +class ContextPlugin(IPlugin): + def attach_to_gui(self, gui, model=None, state=None): + # Create the computing context. + gui.context = Context(op.join(op.dirname(model.path), '.phy/')) + + #------------------------------------------------------------------------------ # Task #------------------------------------------------------------------------------ diff --git a/phy/io/tests/test_context.py b/phy/io/tests/test_context.py index 5578a3156..207b86e83 100644 --- a/phy/io/tests/test_context.py +++ b/phy/io/tests/test_context.py @@ -14,9 +14,10 @@ from pytest import yield_fixture, mark, raises from six.moves import cPickle -from ..context import (Context, Task, +from ..context import (Context, ContextPlugin, Task, _iter_chunks_dask, write_array, read_array, ) +from phy.utils import Bunch #------------------------------------------------------------------------------ @@ -122,6 +123,13 @@ def test_pickle_cache(tempdir, parallel_context): assert ctx.cache_dir == parallel_context.cache_dir +def test_context_plugin(tempdir): + gui = Bunch() + path = op.join(tempdir, 'model.ext') + ContextPlugin().attach_to_gui(gui, model=Bunch(path=path), state=Bunch()) + assert gui.context.cache_dir == op.dirname(path) + '/.phy' + + #------------------------------------------------------------------------------ # Test map #------------------------------------------------------------------------------ From cdc343fde7a31095784d850f768d08f0d073c2cc Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 16 Dec 2015 17:49:53 +0100 Subject: [PATCH 0767/1059] Add ManualClusteringPlugin --- phy/cluster/manual/gui_component.py | 35 +++++++++++++++++++ .../manual/tests/test_gui_component.py | 18 +++++++++- 2 files changed, 52 insertions(+), 1 deletion(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index a20c29634..62b2bdf9a 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -22,6 +22,7 @@ from phy.gui.actions import Actions from phy.gui.widgets import Table from phy.io.array import select_spikes +from phy.utils import IPlugin logger = logging.getLogger(__name__) @@ -522,3 +523,37 @@ def save(self): groups = {c: self.cluster_meta.get('group', c) or 'unsorted' for c in self.clustering.cluster_ids} self.gui.emit('request_save', spike_clusters, groups) + + +class ManualClusteringPlugin(IPlugin): + def attach_to_gui(self, gui, model=None, state=None): + + # Attach the manual clustering logic (wizard, merge, split, + # undo stack) to the GUI. + n = state.n_spikes_max_per_cluster + mc = ManualClustering(model.spike_clusters, + cluster_groups=model.cluster_groups, + n_spikes_max_per_cluster=n, + ) + mc.attach(gui) + + spc = mc.clustering.spikes_per_cluster + nfc = model.n_features_per_channel + + q, s = default_wizard_functions(waveforms=model.waveforms, + features=model.features, + masks=model.masks, + n_features_per_channel=nfc, + spikes_per_cluster=spc, + ) + + ctx = getattr(gui, 'context', None) + if ctx: # pragma: no cover + q, s = ctx.cache(q), ctx.cache(s) + else: + logger.warn("Context not available, unable to cache " + "the wizard functions.") + + mc.add_column(q, name='quality') + mc.set_default_sort('quality') + mc.set_similarity_func(s) diff --git a/phy/cluster/manual/tests/test_gui_component.py b/phy/cluster/manual/tests/test_gui_component.py index c6a696ab0..e661c46e4 100644 --- a/phy/cluster/manual/tests/test_gui_component.py +++ b/phy/cluster/manual/tests/test_gui_component.py @@ -12,7 +12,10 @@ import numpy as np from numpy.testing import assert_array_equal as ae -from ..gui_component import ManualClustering, default_wizard_functions +from ..gui_component import (ManualClustering, + ManualClusteringPlugin, + default_wizard_functions, + ) from phy.gui import GUI from phy.io.array import _spikes_per_cluster from phy.io.mock import (artificial_waveforms, @@ -20,6 +23,7 @@ artificial_features, artificial_spike_clusters, ) +from phy.utils import Bunch #------------------------------------------------------------------------------ @@ -61,6 +65,18 @@ def gui(qtbot): # Test GUI component #------------------------------------------------------------------------------ +def test_manual_clustering_plugin(qtbot, gui): + model = Bunch(spike_clusters=[0, 1, 2], + cluster_groups=None, + n_features_per_channel=2, + waveforms=np.zeros((3, 4, 1)), + features=np.zeros((3, 1, 2)), + masks=np.zeros((3, 1)), + ) + state = Bunch(n_spikes_max_per_cluster=10) + ManualClusteringPlugin().attach_to_gui(gui, model=model, state=state) + + def test_manual_clustering_edge_cases(manual_clustering): mc = manual_clustering From 292c87eae5863b4033588009d3c7d30221f0075f Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 16 Dec 2015 18:51:44 +0100 Subject: [PATCH 0768/1059] View plugins --- phy/cluster/manual/tests/test_views.py | 238 +++++++++---------------- phy/cluster/manual/views.py | 63 ++++++- phy/gui/gui.py | 3 + 3 files changed, 150 insertions(+), 154 deletions(-) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index 9146e77f2..7a22fb234 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -6,10 +6,13 @@ # Imports #------------------------------------------------------------------------------ +from contextlib import contextmanager + import numpy as np from numpy.testing import assert_equal as ae from pytest import raises, yield_fixture +from phy.io.array import _spikes_per_cluster from phy.io.mock import (artificial_waveforms, artificial_features, artificial_spike_clusters, @@ -17,11 +20,10 @@ artificial_masks, artificial_traces, ) -from phy.gui import GUI +from phy.gui import create_gui from phy.electrode.mea import staggered_positions -from ..views import (WaveformView, FeatureView, CorrelogramView, TraceView, - _extract_wave, _selected_clusters_colors, - ) +from phy.utils import Bunch +from ..views import TraceView, _extract_wave, _selected_clusters_colors #------------------------------------------------------------------------------ @@ -36,12 +38,65 @@ def _show(qtbot, view, stop=False): view.close() +@yield_fixture(scope='session') +def model(): + model = Bunch() + + n_spikes = 51 + n_samples_w = 31 + n_samples_t = 20000 + n_channels = 11 + n_clusters = 3 + n_features = 4 + + model.n_channels = n_channels + model.n_spikes = n_spikes + model.sample_rate = 20000. + model.spike_times = artificial_spike_samples(n_spikes) * 1. + model.spike_times /= model.spike_times[-1] + model.spike_clusters = artificial_spike_clusters(n_spikes, n_clusters) + model.channel_positions = staggered_positions(n_channels) + model.waveforms = artificial_waveforms(n_spikes, n_samples_w, n_channels) + model.masks = artificial_masks(n_spikes, n_channels) + model.traces = artificial_traces(n_samples_t, n_channels) + model.features = artificial_features(n_spikes, n_channels, n_features) + model.spikes_per_cluster = _spikes_per_cluster(model.spike_clusters) + + yield model + + @yield_fixture -def gui(qtbot): - gui = GUI(position=(200, 100), size=(800, 600)) - # gui.show() - # qtbot.waitForWindowShown(gui) - yield gui +def state(): + state = Bunch() + state.CorrelogramView1 = Bunch(bin_size=1e-3, + window_size=50e-3, + excerpt_size=8, + n_excerpts=5, + ) + state.n_samples_per_spike = 6 + yield state + + +@contextmanager +def _test_view(view_name, model=None, state=None): + state.plugins = [view_name + 'Plugin'] + gui = create_gui(model=model, state=state) + gui.show() + + v = gui.list_views(view_name)[0].view + + # Select some spikes. + spike_ids = np.arange(10) + cluster_ids = np.unique(model.spike_clusters[spike_ids]) + v.on_select(cluster_ids, spike_ids) + + # Select other spikes. + spike_ids = np.arange(2, 10) + cluster_ids = np.unique(model.spike_clusters[spike_ids]) + v.on_select(cluster_ids, spike_ids) + + yield v + gui.close() @@ -84,39 +139,10 @@ def test_selected_clusters_colors(): # Test waveform view #------------------------------------------------------------------------------ -def test_waveform_view(qtbot, gui): - n_spikes = 20 - n_samples = 30 - n_channels = 40 - n_clusters = 3 - - waveforms = artificial_waveforms(n_spikes, n_samples, n_channels) - masks = artificial_masks(n_spikes, n_channels) - spike_clusters = artificial_spike_clusters(n_spikes, n_clusters) - channel_positions = staggered_positions(n_channels) - - # Create the view. - v = WaveformView(waveforms=waveforms, - masks=masks, - spike_clusters=spike_clusters, - channel_positions=channel_positions, - ) - # Select some spikes. - spike_ids = np.arange(10) - cluster_ids = np.unique(spike_clusters[spike_ids]) - v.on_select(cluster_ids, spike_ids) - - # Show the view. - v.attach(gui) - gui.show() - - # Select other spikes. - spike_ids = np.arange(2, 10) - cluster_ids = np.unique(spike_clusters[spike_ids]) - v.on_select(cluster_ids, spike_ids) - - v.toggle_waveform_overlap() - v.toggle_waveform_overlap() +def test_waveform_view(qtbot, model, state): + with _test_view('WaveformView', model=model, state=state) as v: + v.toggle_waveform_overlap() + v.toggle_waveform_overlap() # qtbot.stop() @@ -137,45 +163,13 @@ def test_trace_view_no_spikes(qtbot): _show(qtbot, v) -def test_trace_view_spikes(qtbot, gui): - n_samples = 1000 - n_channels = 12 - sample_rate = 2000. - n_spikes = 50 - n_clusters = 3 - - traces = artificial_traces(n_samples, n_channels) - spike_times = artificial_spike_samples(n_spikes) / sample_rate - spike_clusters = artificial_spike_clusters(n_spikes, n_clusters) - masks = artificial_masks(n_spikes, n_channels) - - # Create the view. - v = TraceView(traces=traces, - sample_rate=sample_rate, - spike_times=spike_times, - spike_clusters=spike_clusters, - masks=masks, - n_samples_per_spike=6, - ) - - # Select some spikes. - spike_ids = np.arange(10) - cluster_ids = np.unique(spike_clusters[spike_ids]) - v.on_select(cluster_ids, spike_ids) - - # Show the view. - v.attach(gui) - gui.show() - - # Select other spikes. - spike_ids = np.arange(2, 10) - cluster_ids = np.unique(spike_clusters[spike_ids]) - v.on_select(cluster_ids, spike_ids) +def test_trace_view_spikes(qtbot, model, state): + with _test_view('TraceView', model=model, state=state) as v: + v.go_to(.5) + v.go_to(-.5) + v.go_left() + v.go_right() - v.go_to(.5) - v.go_to(-.5) - v.go_left() - v.go_right() # qtbot.stop() @@ -183,45 +177,17 @@ def test_trace_view_spikes(qtbot, gui): # Test feature view #------------------------------------------------------------------------------ -def test_feature_view(gui, qtbot): - n_spikes = 50 - n_channels = 5 - n_clusters = 2 - n_features = 4 - - features = artificial_features(n_spikes, n_channels, n_features) - masks = artificial_masks(n_spikes, n_channels) - spike_clusters = artificial_spike_clusters(n_spikes, n_clusters) - spike_times = artificial_spike_samples(n_spikes) / 20000. +def test_feature_view(qtbot, model, state): + with _test_view('FeatureView', model=model, state=state) as v: - # Create the view. - v = FeatureView(features=features, - masks=masks, - spike_times=spike_times, - spike_clusters=spike_clusters, - ) + @v.set_best_channels_func + def best_channels(cluster_id): + return list(range(model.n_channels)) - @v.set_best_channels_func - def best_channels(cluster_id): - return list(range(n_channels)) + v.add_attribute('sine', np.sin(np.linspace(-10., 10., model.n_spikes))) - v.add_attribute('sine', np.sin(np.linspace(-10., 10., n_spikes))) - - # Select some spikes. - spike_ids = np.arange(n_spikes) - cluster_ids = np.unique(spike_clusters[spike_ids]) - v.on_select(cluster_ids, spike_ids) - - v.attach(gui) - gui.show() - - # Select other spikes. - spike_ids = np.arange(2, 10) - cluster_ids = np.unique(spike_clusters[spike_ids]) - v.on_select(cluster_ids, spike_ids) - - v.increase_feature_scaling() - v.decrease_feature_scaling() + v.increase_feature_scaling() + v.decrease_feature_scaling() # qtbot.stop() @@ -230,42 +196,8 @@ def best_channels(cluster_id): # Test correlogram view #------------------------------------------------------------------------------ -def test_correlogram_view(qtbot, gui): - n_spikes = 50 - n_clusters = 5 - sample_rate = 20000. - bin_size = 1e-3 - window_size = 50e-3 - - spike_clusters = artificial_spike_clusters(n_spikes, n_clusters) - spike_times = artificial_spike_samples(n_spikes) / sample_rate +def test_correlogram_view(qtbot, model, state): + with _test_view('CorrelogramView', model=model, state=state) as v: + v.toggle_normalization() - # Create the view. - v = CorrelogramView(spike_times=spike_times, - spike_clusters=spike_clusters, - sample_rate=sample_rate, - bin_size=bin_size, - window_size=window_size, - excerpt_size=8, - n_excerpts=5, - ) - - # Select some spikes. - spike_ids = np.arange(n_spikes) - cluster_ids = np.unique(spike_clusters[spike_ids]) - v.on_select(cluster_ids, spike_ids) - - # Show the view. - v.show() - qtbot.waitForWindowShown(v.native) - - # Select other spikes. - spike_ids = np.arange(2, 10) - cluster_ids = np.unique(spike_clusters[spike_ids]) - v.on_select(cluster_ids, spike_ids) - - v.toggle_normalization() - - v.attach(gui) - gui.show() # qtbot.stop() diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 1cdd1b08b..55aba5977 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -17,6 +17,8 @@ from phy.plot import View, _get_linear_x from phy.plot.utils import _get_boxes from phy.stats import correlograms +from phy.stats.clusters import mean, unmasked_channels, sorted_main_channels +from phy.utils import IPlugin logger = logging.getLogger(__name__) @@ -161,6 +163,8 @@ def attach(self, gui): gui.connect_(self.on_select) self.actions = Actions(gui, default_shortcuts=self.shortcuts) + self.show() + # ----------------------------------------------------------------------------- # Waveform view @@ -268,6 +272,17 @@ def toggle_waveform_overlap(self): self.on_select() +class WaveformViewPlugin(IPlugin): + def attach_to_gui(self, gui, model=None, state=None): + w = WaveformView(waveforms=model.waveforms, + masks=model.masks, + spike_clusters=model.spike_clusters, + channel_positions=model.channel_positions, + ) + w.attach(gui) + # TODO: scaling factors + + # ----------------------------------------------------------------------------- # Trace view # ----------------------------------------------------------------------------- @@ -489,6 +504,17 @@ def go_left(self): self.shift(-delay) +class TraceViewPlugin(IPlugin): + def attach_to_gui(self, gui, model=None, state=None): + t = TraceView(traces=model.traces, + sample_rate=model.sample_rate, + spike_times=model.spike_times, + spike_clusters=model.spike_clusters, + masks=model.masks, + ) + t.attach(gui) + + # ----------------------------------------------------------------------------- # Feature view # ----------------------------------------------------------------------------- @@ -539,7 +565,7 @@ def _dimensions_for_clusters(cluster_ids, n_cols=None, # Now, select the right number of channels in the x axis. x_channels = x_channels[:n_cols - 1] if len(x_channels) < n_cols - 1: - x_channels = y_channels + x_channels = y_channels # pragma: no cover return _dimensions_matrix(x_channels, y_channels) @@ -756,6 +782,28 @@ def feature_scaling(self, value): self.on_select() +class FeatureViewPlugin(IPlugin): + def attach_to_gui(self, gui, model=None, state=None): + + f = FeatureView(features=model.features, + masks=model.masks, + spike_clusters=model.spike_clusters, + spike_times=model.spike_times, + ) + + @f.set_best_channels_func + def best_channels(cluster_id): + """Select the best channels for a given cluster.""" + # TODO: better perf with cluster stats and cache + spike_ids = model.spikes_per_cluster[cluster_id] + m = model.masks[spike_ids] + mean_masks = mean(m) + uch = unmasked_channels(mean_masks) + return sorted_main_channels(mean_masks, uch) + + f.attach(gui) + + # ----------------------------------------------------------------------------- # Correlogram view # ----------------------------------------------------------------------------- @@ -866,3 +914,16 @@ def attach(self, gui): """Attach the view to the GUI.""" super(CorrelogramView, self).attach(gui) self.actions.add(self.toggle_normalization, shortcut='n') + + +class CorrelogramViewPlugin(IPlugin): + def attach_to_gui(self, gui, model=None, state=None): + ccg = CorrelogramView(spike_times=model.spike_times, + spike_clusters=model.spike_clusters, + sample_rate=model.sample_rate, + bin_size=state.CorrelogramView1.bin_size, + window_size=state.CorrelogramView1.window_size, + excerpt_size=state.CorrelogramView1.excerpt_size, + n_excerpts=state.CorrelogramView1.n_excerpts, + ) + ccg.attach(gui) diff --git a/phy/gui/gui.py b/phy/gui/gui.py index 7f29a81a0..00205697a 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -185,6 +185,7 @@ def add_view(self, **kwargs): """Add a widget to the main window.""" + original_view = view title = title or view.__class__.__name__ view = _try_get_vispy_canvas(view) view = _try_get_matplotlib_canvas(view) @@ -194,6 +195,7 @@ def add_view(self, dockwidget.setObjectName(title) dockwidget.setWindowTitle(title) dockwidget.setWidget(view) + dockwidget.view = original_view # Set gui widget options. options = QDockWidget.DockWidgetMovable @@ -220,6 +222,7 @@ def add_view(self, dockwidget.setFloating(floating) dockwidget.show() self.emit('add_view', view) + logger.debug("Add %s to %s.", title, self) return dockwidget def list_views(self, title='', is_visible=True): From 49b2e8d4ccdd0fbc22b24dc12558bfdfe3a846c2 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 16 Dec 2015 19:09:19 +0100 Subject: [PATCH 0769/1059] Update log level --- phy/gui/gui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/phy/gui/gui.py b/phy/gui/gui.py index 00205697a..9d626feef 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -222,7 +222,7 @@ def add_view(self, dockwidget.setFloating(floating) dockwidget.show() self.emit('add_view', view) - logger.debug("Add %s to %s.", title, self) + logger.log(5, "Add %s to GUI.", title) return dockwidget def list_views(self, title='', is_visible=True): From a55de4c5b588bd1aa51c217a90bfdc498cae8556 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 16 Dec 2015 19:25:59 +0100 Subject: [PATCH 0770/1059] WIP: refactor GUI --- phy/gui/gui.py | 35 +++++++++++++++++++++++------------ 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/phy/gui/gui.py b/phy/gui/gui.py index 9d626feef..16926d172 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -90,7 +90,7 @@ class GUI(QMainWindow): def __init__(self, position=None, size=None, - title=None, + name=None, ): # HACK to ensure that closeEvent is called only twice (seems like a # Qt bug). @@ -98,26 +98,22 @@ def __init__(self, if not QApplication.instance(): # pragma: no cover raise RuntimeError("A Qt application must be created.") super(GUI, self).__init__() - if title is None: - title = self.__class__.__name__ - self.setWindowTitle(title) - if position is not None: - self.move(position[0], position[1]) - if size is not None: - self.resize(QSize(size[0], size[1])) - self.setObjectName(title) QMetaObject.connectSlotsByName(self) self.setDockOptions(QMainWindow.AllowTabbedDocks | QMainWindow.AllowNestedDocks | QMainWindow.AnimatedDocks ) + self._set_name(name) + self._set_pos_size(position, size) + # Mapping {name: menuBar}. self._menus = {} # We can derive from EventEmitter because of a conflict with connect. self._event = EventEmitter() + # Status bar. self._status_bar = QStatusBar() self.setStatusBar(self._status_bar) @@ -125,6 +121,24 @@ def __init__(self, self.actions = [] # Default actions. + self._set_default_actions() + + # Create and attach snippets. + self.snippets = Snippets(self) + + def _set_name(self, name): + if name is None: + name = self.__class__.__name__ + self.setWindowTitle(name) + self.setObjectName(name) + + def _set_pos_size(self, position, size): + if position is not None: + self.move(position[0], position[1]) + if size is not None: + self.resize(QSize(size[0], size[1])) + + def _set_default_actions(self): self.default_actions = Actions(self) @self.default_actions.add(shortcut='ctrl+q', menu='&File') @@ -139,9 +153,6 @@ def show_shortcuts(): shortcuts.update(actions.shortcuts) _show_shortcuts(shortcuts, self.name) - # Create and attach snippets. - self.snippets = Snippets(self) - # Events # ------------------------------------------------------------------------- From c53191c50325d09065a24b11891027959081c27d Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 16 Dec 2015 19:38:35 +0100 Subject: [PATCH 0771/1059] WIP: refactor GUI --- phy/gui/gui.py | 101 +++++++++++++++++++++----------------- phy/gui/tests/test_gui.py | 7 ++- 2 files changed, 63 insertions(+), 45 deletions(-) diff --git a/phy/gui/gui.py b/phy/gui/gui.py index 16926d172..65ccbd65d 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -69,6 +69,38 @@ def closeEvent(self, e): super(DockWidget, self).closeEvent(e) +def _create_dock_widget(widget, name, closable=True, floatable=True): + # Create the gui widget. + dock_widget = DockWidget() + dock_widget.setObjectName(name) + dock_widget.setWindowTitle(name) + dock_widget.setWidget(widget) + + # Set gui widget options. + options = QDockWidget.DockWidgetMovable + if closable: + options = options | QDockWidget.DockWidgetClosable + if floatable: + options = options | QDockWidget.DockWidgetFloatable + + dock_widget.setFeatures(options) + dock_widget.setAllowedAreas(Qt.LeftDockWidgetArea | + Qt.RightDockWidgetArea | + Qt.TopDockWidgetArea | + Qt.BottomDockWidgetArea + ) + + return dock_widget + + +def _get_dock_position(position): + return {'left': Qt.LeftDockWidgetArea, + 'right': Qt.RightDockWidgetArea, + 'top': Qt.TopDockWidgetArea, + 'bottom': Qt.BottomDockWidgetArea, + }[position or 'right'] + + class GUI(QMainWindow): """A Qt main window holding docking Qt or VisPy widgets. @@ -80,6 +112,7 @@ class GUI(QMainWindow): close show add_view + close_view Note ---- @@ -188,61 +221,41 @@ def show(self): def add_view(self, view, - title=None, + name=None, position=None, closable=True, floatable=True, - floating=None, - **kwargs): + floating=None): """Add a widget to the main window.""" - original_view = view - title = title or view.__class__.__name__ - view = _try_get_vispy_canvas(view) - view = _try_get_matplotlib_canvas(view) - - # Create the gui widget. - dockwidget = DockWidget(self) - dockwidget.setObjectName(title) - dockwidget.setWindowTitle(title) - dockwidget.setWidget(view) - dockwidget.view = original_view - - # Set gui widget options. - options = QDockWidget.DockWidgetMovable - if closable: - options = options | QDockWidget.DockWidgetClosable - if floatable: - options = options | QDockWidget.DockWidgetFloatable - - dockwidget.setFeatures(options) - dockwidget.setAllowedAreas(Qt.LeftDockWidgetArea | - Qt.RightDockWidgetArea | - Qt.TopDockWidgetArea | - Qt.BottomDockWidgetArea - ) - - q_position = { - 'left': Qt.LeftDockWidgetArea, - 'right': Qt.RightDockWidgetArea, - 'top': Qt.TopDockWidgetArea, - 'bottom': Qt.BottomDockWidgetArea, - }[position or 'right'] - self.addDockWidget(q_position, dockwidget) + name = name or view.__class__.__name__ + widget = _try_get_vispy_canvas(view) + widget = _try_get_matplotlib_canvas(widget) + + dock_widget = _create_dock_widget(widget, name, + closable=closable, + floatable=floatable, + ) + self.addDockWidget(_get_dock_position(position), dock_widget) if floating is not None: - dockwidget.setFloating(floating) - dockwidget.show() + dock_widget.setFloating(floating) + + @dock_widget.connect_ + def on_close_widget(): + self.emit('close_view', view) + + dock_widget.show() self.emit('add_view', view) - logger.log(5, "Add %s to GUI.", title) - return dockwidget + logger.log(5, "Add %s to GUI.", name) + return dock_widget - def list_views(self, title='', is_visible=True): - """List all views which title start with a given string.""" - title = title.lower() + def list_views(self, name='', is_visible=True): + """List all views which name start with a given string.""" + name = name.lower() children = self.findChildren(QWidget) return [child for child in children if isinstance(child, QDockWidget) and - _title(child).startswith(title) and + _title(child).startswith(name) and (child.isVisible() if is_visible else True) and child.width() >= 10 and child.height() >= 10 diff --git a/phy/gui/tests/test_gui.py b/phy/gui/tests/test_gui.py index 9fa585d95..22ad46a10 100644 --- a/phy/gui/tests/test_gui.py +++ b/phy/gui/tests/test_gui.py @@ -90,8 +90,13 @@ def on_show(): @view.connect_ def on_close_widget(): _close.append(0) + + @gui.connect_ + def on_close_view(view): + _close.append(1) + view.close() - assert _close == [0] + assert _close == [1, 0] gui.default_actions.exit() From c6a0d4c8b44d2cffd97e87e3fa3b8c58da6ee6d2 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 16 Dec 2015 21:19:34 +0100 Subject: [PATCH 0772/1059] Update tests --- phy/gui/gui.py | 29 ++++++++++++++++++----------- phy/gui/tests/test_gui.py | 6 +++--- 2 files changed, 21 insertions(+), 14 deletions(-) diff --git a/phy/gui/gui.py b/phy/gui/gui.py index 65ccbd65d..3fe66976a 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -24,10 +24,6 @@ # GUI main window # ----------------------------------------------------------------------------- -def _title(widget): - return str(widget.windowTitle()).lower() - - def _try_get_vispy_canvas(view): # Get the Qt widget from a VisPy canvas. try: @@ -164,6 +160,8 @@ def _set_name(self, name): name = self.__class__.__name__ self.setWindowTitle(name) self.setObjectName(name) + # Set the name in the GUI. + self.__name__ = name def _set_pos_size(self, position, size): if position is not None: @@ -184,7 +182,7 @@ def show_shortcuts(): shortcuts = self.default_actions.shortcuts for actions in self.actions: shortcuts.update(actions.shortcuts) - _show_shortcuts(shortcuts, self.name) + _show_shortcuts(shortcuts, self.__name__) # Events # ------------------------------------------------------------------------- @@ -219,6 +217,13 @@ def show(self): # Views # ------------------------------------------------------------------------- + def _get_view_name(self, view): + """The view name is the class name followed by 1, 2, or n.""" + name = view.__class__.__name__ + views = self.list_views(name) + n = len(views) + 1 + return '{:s}{:d}'.format(name, n) + def add_view(self, view, name=None, @@ -228,7 +233,9 @@ def add_view(self, floating=None): """Add a widget to the main window.""" - name = name or view.__class__.__name__ + name = name or self._get_view_name(view) + # Set the name in the view. + view.__name__ = name widget = _try_get_vispy_canvas(view) widget = _try_get_matplotlib_canvas(widget) @@ -239,6 +246,7 @@ def add_view(self, self.addDockWidget(_get_dock_position(position), dock_widget) if floating is not None: dock_widget.setFloating(floating) + dock_widget.view = view @dock_widget.connect_ def on_close_widget(): @@ -251,11 +259,10 @@ def on_close_widget(): def list_views(self, name='', is_visible=True): """List all views which name start with a given string.""" - name = name.lower() children = self.findChildren(QWidget) - return [child for child in children + return [child.view for child in children if isinstance(child, QDockWidget) and - _title(child).startswith(name) and + child.view.__name__.startswith(name) and (child.isVisible() if is_visible else True) and child.width() >= 10 and child.height() >= 10 @@ -266,7 +273,7 @@ def view_count(self): views = self.list_views() counts = defaultdict(lambda: 0) for view in views: - counts[_title(view)] += 1 + counts[view.__name__] += 1 return dict(counts) # Menu bar @@ -353,7 +360,7 @@ def create_gui(model=None, state=None): state = state or GUIState() plugins = state.plugins # GUI name. - name = gui.name + name = gui.__name__ # If no plugins are specified, load the master config and # get the list of user plugins to attach to the GUI. diff --git a/phy/gui/tests/test_gui.py b/phy/gui/tests/test_gui.py index 22ad46a10..e411e27af 100644 --- a/phy/gui/tests/test_gui.py +++ b/phy/gui/tests/test_gui.py @@ -76,12 +76,12 @@ def on_show(): qtbot.keyPress(gui, Qt.Key_Control) qtbot.keyRelease(gui, Qt.Key_Control) - view = gui.add_view(_create_canvas(), 'view1', floating=True) - gui.add_view(_create_canvas(), 'view2') + view = gui.add_view(_create_canvas(), floating=True) + gui.add_view(_create_canvas()) view.setFloating(False) gui.show() - assert len(gui.list_views('view')) == 2 + assert len(gui.list_views('Canvas')) == 2 # Check that the close_widget event is fired when the gui widget is # closed. From 72605fc56dc3b46e7ff245429d15bb26b408788c Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 16 Dec 2015 21:22:09 +0100 Subject: [PATCH 0773/1059] Fixes --- phy/cluster/manual/gui_component.py | 4 ++-- phy/cluster/manual/tests/test_views.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 62b2bdf9a..73ade1ac0 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -427,8 +427,8 @@ def attach(self, gui): self._create_actions(gui) # Add the cluster views. - gui.add_view(self.cluster_view, title='ClusterView') - gui.add_view(self.similarity_view, title='SimilarityView') + gui.add_view(self.cluster_view, name='ClusterView') + gui.add_view(self.similarity_view, name='SimilarityView') # Update the cluster views and selection when a cluster event occurs. self.gui.connect_(self.on_cluster) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index 7a22fb234..df8081e51 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -83,7 +83,7 @@ def _test_view(view_name, model=None, state=None): gui = create_gui(model=model, state=state) gui.show() - v = gui.list_views(view_name)[0].view + v = gui.list_views(view_name)[0] # Select some spikes. spike_ids = np.arange(10) From 511da79ec78fb07bdd9d9a812fbb092575a05487 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 16 Dec 2015 21:45:28 +0100 Subject: [PATCH 0774/1059] GUI name --- phy/gui/gui.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/phy/gui/gui.py b/phy/gui/gui.py index 3fe66976a..ff9bca61d 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -313,7 +313,6 @@ def save_geometry_state(self): return { 'geometry': self.saveGeometry(), 'state': self.saveState(), - 'view_count': self.view_count(), } def restore_geometry_state(self, gs): @@ -349,14 +348,14 @@ def from_json(self, filename): self.update(_load_json(filename)) -def create_gui(model=None, state=None): +def create_gui(name=None, model=None, state=None): """Create a GUI with a model and a GUI state. By default, the list of plugins is taken from the `c.TheGUI.plugins` parameter, where `TheGUI` is the name of the GUI class. """ - gui = GUI() + gui = GUI(name=name) state = state or GUIState() plugins = state.plugins # GUI name. From 7690674d476a660579ebf89f1986508c0f66f985 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 17 Dec 2015 10:19:47 +0100 Subject: [PATCH 0775/1059] Smaller marker size in feature view --- phy/cluster/manual/tests/test_views.py | 2 +- phy/cluster/manual/views.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index df8081e51..c69a506dc 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -189,7 +189,7 @@ def best_channels(cluster_id): v.increase_feature_scaling() v.decrease_feature_scaling() - # qtbot.stop() + # qtbot.stop() #------------------------------------------------------------------------------ diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 55aba5977..eca4ecb13 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -587,6 +587,7 @@ def _project_mask_depth(dim, masks, spike_clusters_rel=None, n_clusters=None): class FeatureView(ManualClusteringView): normalization_percentile = .95 normalization_n_spikes = 1000 + _default_marker_size = 3. _feature_scaling = 1. default_shortcuts = { @@ -710,12 +711,13 @@ def _plot_features(self, i, j, x_dim, y_dim, n_clusters=n_clusters) # Create the scatter plot for the current subplot. + ms = self._default_marker_size self[i, j].scatter(x=x, y=y, color=color, depth=d, data_bounds=data_bounds, - size=5 * np.ones(len(spike_ids)), + size=ms * np.ones(len(spike_ids)), ) def set_best_channels_func(self, func): From 78e436ff02b560f9201626ee8e11a28b817a5042 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 17 Dec 2015 10:53:07 +0100 Subject: [PATCH 0776/1059] WIP --- phy/cluster/manual/gui_component.py | 3 ++- phy/cluster/manual/views.py | 14 ++++++++++++-- phy/gui/gui.py | 17 ++++++++++++++++- phy/gui/tests/test_gui.py | 16 ++++++++++++++++ 4 files changed, 46 insertions(+), 4 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 73ade1ac0..eab2269a3 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -554,6 +554,7 @@ def attach_to_gui(self, gui, model=None, state=None): logger.warn("Context not available, unable to cache " "the wizard functions.") - mc.add_column(q, name='quality') + # Add the quality column in the cluster view. + mc.cluster_view.add_column(q, name='quality') mc.set_default_sort('quality') mc.set_similarity_func(s) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index eca4ecb13..fbb8f237f 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -184,6 +184,7 @@ def __init__(self, masks=None, spike_clusters=None, channel_positions=None, + box_bounds=None, **kwargs): """ @@ -192,7 +193,8 @@ def __init__(self, """ # Initialize the view. - box_bounds = _get_boxes(channel_positions) + box_bounds = (_get_boxes(channel_positions) if box_bounds is None + else box_bounds) super(WaveformView, self).__init__(layout='boxed', box_bounds=box_bounds, **kwargs) @@ -274,13 +276,21 @@ def toggle_waveform_overlap(self): class WaveformViewPlugin(IPlugin): def attach_to_gui(self, gui, model=None, state=None): + # NOTE: we assume that the state contains fields for every view. + # Load the box_bounds from the state. + box_bounds = state.WaveformView1.box_bounds w = WaveformView(waveforms=model.waveforms, masks=model.masks, spike_clusters=model.spike_clusters, channel_positions=model.channel_positions, + box_bounds=box_bounds, ) w.attach(gui) - # TODO: scaling factors + + @gui.connect_ + def on_close(): + # Save the box bounds. + state[w.__name__].box_bounds = w.stacked.box_bounds # ----------------------------------------------------------------------------- diff --git a/phy/gui/gui.py b/phy/gui/gui.py index ff9bca61d..4931c19c4 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -228,7 +228,7 @@ def add_view(self, view, name=None, position=None, - closable=True, + closable=False, floatable=True, floating=None): """Add a widget to the main window.""" @@ -335,6 +335,21 @@ def restore_geometry_state(self, gs): # GUI state, creator, plugins # ----------------------------------------------------------------------------- +class DefaultBunch(defaultdict): + def __init__(self, *args, **kwargs): + super(DefaultBunch, self).__init__(*args, **kwargs) + self.__dict__ = self + + def __missing__(self, item) + pass + + +class DefaultDictBunch(defaultdict): + def __init__(self, **kwargs): + super(DefaultDictBunch, self).__init__(DefaultBunch, **kwargs) + self.__dict__ = self + + class GUIState(Bunch): def __init__(self, geometry_state=None, plugins=None, **kwargs): super(GUIState, self).__init__(geomety_state=geometry_state, diff --git a/phy/gui/tests/test_gui.py b/phy/gui/tests/test_gui.py index e411e27af..b5df04d65 100644 --- a/phy/gui/tests/test_gui.py +++ b/phy/gui/tests/test_gui.py @@ -15,6 +15,7 @@ create_gui, _try_get_matplotlib_canvas, _try_get_vispy_canvas, + DefaultBunch, DefaultDictBunch, ) from phy.utils import IPlugin from phy.utils._color import _random_color @@ -55,6 +56,21 @@ def test_matplotlib_view(): # Test GUI #------------------------------------------------------------------------------ +def test_default_bunch(): + b = DefaultBunch() + assert b.hello is None + b.hello = 'world' + assert b.hello == 'world' + + b = DefaultDictBunch() + assert len(b.unknown) == 0 + assert b.hello.world is None + assert len(b.hello) == 1 + assert b.hello == {'world': None} + b.hello.dolly = '!' + assert b.hello.dolly == '!' + + def test_gui_noapp(): if not QApplication.instance(): with raises(RuntimeError): # pragma: no cover From 1f28cfb7d990e22b1f9c2114d4bca85b3c496b83 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 17 Dec 2015 15:02:06 +0100 Subject: [PATCH 0777/1059] Remove default bunch --- phy/gui/gui.py | 15 --------------- phy/gui/tests/test_gui.py | 16 ---------------- 2 files changed, 31 deletions(-) diff --git a/phy/gui/gui.py b/phy/gui/gui.py index 4931c19c4..2e9dd11cf 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -335,21 +335,6 @@ def restore_geometry_state(self, gs): # GUI state, creator, plugins # ----------------------------------------------------------------------------- -class DefaultBunch(defaultdict): - def __init__(self, *args, **kwargs): - super(DefaultBunch, self).__init__(*args, **kwargs) - self.__dict__ = self - - def __missing__(self, item) - pass - - -class DefaultDictBunch(defaultdict): - def __init__(self, **kwargs): - super(DefaultDictBunch, self).__init__(DefaultBunch, **kwargs) - self.__dict__ = self - - class GUIState(Bunch): def __init__(self, geometry_state=None, plugins=None, **kwargs): super(GUIState, self).__init__(geomety_state=geometry_state, diff --git a/phy/gui/tests/test_gui.py b/phy/gui/tests/test_gui.py index b5df04d65..e411e27af 100644 --- a/phy/gui/tests/test_gui.py +++ b/phy/gui/tests/test_gui.py @@ -15,7 +15,6 @@ create_gui, _try_get_matplotlib_canvas, _try_get_vispy_canvas, - DefaultBunch, DefaultDictBunch, ) from phy.utils import IPlugin from phy.utils._color import _random_color @@ -56,21 +55,6 @@ def test_matplotlib_view(): # Test GUI #------------------------------------------------------------------------------ -def test_default_bunch(): - b = DefaultBunch() - assert b.hello is None - b.hello = 'world' - assert b.hello == 'world' - - b = DefaultDictBunch() - assert len(b.unknown) == 0 - assert b.hello.world is None - assert len(b.hello) == 1 - assert b.hello == {'world': None} - b.hello.dolly = '!' - assert b.hello.dolly == '!' - - def test_gui_noapp(): if not QApplication.instance(): with raises(RuntimeError): # pragma: no cover From 87c839de7716f3411fb0a0d5156cb5862e29d763 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 17 Dec 2015 15:13:18 +0100 Subject: [PATCH 0778/1059] WIP: GUI state view param --- phy/gui/gui.py | 9 +++++++++ phy/gui/tests/test_gui.py | 13 +++++++++++-- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/phy/gui/gui.py b/phy/gui/gui.py index 2e9dd11cf..09c0e8001 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -347,6 +347,15 @@ def to_json(self, filename): def from_json(self, filename): self.update(_load_json(filename)) + def get_view_param(self, view_name, name): + return self.get(view_name + '1', Bunch()).get(name, None) + + def set_view_params(self, view, **kwargs): + view_name = view.__name__ + if view_name not in self: + self[view_name] = Bunch() + self[view_name].update(kwargs) + def create_gui(name=None, model=None, state=None): """Create a GUI with a model and a GUI state. diff --git a/phy/gui/tests/test_gui.py b/phy/gui/tests/test_gui.py index e411e27af..446175db4 100644 --- a/phy/gui/tests/test_gui.py +++ b/phy/gui/tests/test_gui.py @@ -16,7 +16,7 @@ _try_get_matplotlib_canvas, _try_get_vispy_canvas, ) -from phy.utils import IPlugin +from phy.utils import IPlugin, Bunch from phy.utils._color import _random_color @@ -101,7 +101,7 @@ def on_close_view(view): gui.default_actions.exit() -def test_gui_state(tempdir): +def test_gui_state_json(tempdir): path = op.join(tempdir, 'state.json') state = GUIState(hello='world') @@ -112,6 +112,15 @@ def test_gui_state(tempdir): assert state.hello == 'world' +def test_gui_state_view(): + view = Bunch(__name__='myview1') + state = GUIState() + state.set_view_params(view, hello='world') + state.get_view_param('unknown', 'hello') is None + state.get_view_param('myview', 'unknown') is None + state.get_view_param('myview', 'hello') == 'world' + + def test_create_gui_1(qapp, tempdir): _tmp = [] From d435e5ffe2ff0cebaa8a455f06e8d677433ea7d2 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 17 Dec 2015 15:34:04 +0100 Subject: [PATCH 0779/1059] Add global save in context --- phy/io/context.py | 16 ++++++++++++---- phy/io/tests/test_context.py | 17 +++++++++++++++-- 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/phy/io/context.py b/phy/io/context.py index e70d21186..66881df3f 100644 --- a/phy/io/context.py +++ b/phy/io/context.py @@ -25,6 +25,7 @@ from .array import read_array, write_array from phy.utils import (Bunch, _save_json, _load_json, _ensure_dir_exists, IPlugin,) +from phy.utils.config import phy_user_dir logger = logging.getLogger(__name__) @@ -251,15 +252,22 @@ def map(self, f, *args): else: return self._map_serial(f, *args) - def save(self, name, data): + def _get_path(self, name, location): + if location == 'local': + return op.join(self.cache_dir, name + '.json') + elif location == 'global': + return op.join(phy_user_dir(), name + '.json') + + def save(self, name, data, location='local'): """Save a dictionary in a JSON file within the cache directory.""" - path = op.join(self.cache_dir, name + '.json') + path = self._get_path(name, location) _ensure_dir_exists(op.dirname(path)) + logger.debug("Save data to `%s`.", path) _save_json(path, data) - def load(self, name): + def load(self, name, location='local'): """Load saved data from the cache directory.""" - path = op.join(self.cache_dir, name + '.json') + path = self._get_path(name, location) if not op.exists(path): logger.debug("The file `%s` doesn't exist.", path) return diff --git a/phy/io/tests/test_context.py b/phy/io/tests/test_context.py index 207b86e83..53dcbbd3d 100644 --- a/phy/io/tests/test_context.py +++ b/phy/io/tests/test_context.py @@ -59,6 +59,16 @@ def parallel_context(tempdir, ipy_client, request): yield ctx +@yield_fixture +def temp_phy_user_dir(tempdir): + """Use a temporary phy user directory.""" + import phy.io.context + f = phy.io.context.phy_user_dir + phy.io.context.phy_user_dir = lambda: tempdir + yield + phy.io.context.phy_user_dir = f + + #------------------------------------------------------------------------------ # ipyparallel tests #------------------------------------------------------------------------------ @@ -81,12 +91,15 @@ def test_read_write(tempdir): ae(read_array(op.join(tempdir, 'test.npy')), x) -def test_context_load_save(context): +def test_context_load_save(tempdir, context, temp_phy_user_dir): assert context.load('unexisting') is None context.save('a/hello', {'text': 'world'}) assert context.load('a/hello')['text'] == 'world' + context.save('a/hello', {'text': 'world!'}, location='global') + assert context.load('a/hello', location='global')['text'] == 'world!' + def test_context_cache(context): @@ -127,7 +140,7 @@ def test_context_plugin(tempdir): gui = Bunch() path = op.join(tempdir, 'model.ext') ContextPlugin().attach_to_gui(gui, model=Bunch(path=path), state=Bunch()) - assert gui.context.cache_dir == op.dirname(path) + '/.phy' + assert op.dirname(path) + '/.phy' in gui.context.cache_dir #------------------------------------------------------------------------------ From 755dc0b7a8f107acd0024abcdc7dd7b93240c4fe Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 17 Dec 2015 15:54:17 +0100 Subject: [PATCH 0780/1059] WIP: add GUI state save plugins --- phy/gui/gui.py | 28 +++++++++- phy/gui/tests/test_gui.py | 115 +++++++++++++++++++++----------------- phy/io/__init__.py | 1 + 3 files changed, 91 insertions(+), 53 deletions(-) diff --git a/phy/gui/gui.py b/phy/gui/gui.py index 09c0e8001..b2f28cfbf 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -15,7 +15,7 @@ from .actions import Actions, _show_shortcuts, Snippets from phy.utils.event import EventEmitter from phy.utils import load_master_config, Bunch, _load_json, _save_json -from phy.utils.plugin import get_plugin +from phy.utils.plugin import get_plugin, IPlugin logger = logging.getLogger(__name__) @@ -357,6 +357,32 @@ def set_view_params(self, view, **kwargs): self[view_name].update(kwargs) +class SaveGeometryStatePlugin(IPlugin): + def attach_to_gui(self, gui, state=None, model=None): + + @gui.connect_ + def on_close(): + gs = gui.save_geometry_state() + state['geometry_state'] = gs + + @gui.connect_ + def on_show(): + gs = state['geometry_state'] + gui.restore_geometry_state(gs) + + +class SaveGUIStatePlugin(IPlugin): + def attach_to_gui(self, gui, state=None, model=None): + + state_name = '{}/state'.format(gui.name) + location = state.get('state_save_location', 'global') + + @gui.connect_ + def on_close(): + logger.debug("Save GUI state to %s.", state_name) + gui.context.save(state_name, state, location=location) + + def create_gui(name=None, model=None, state=None): """Create a GUI with a model and a GUI state. diff --git a/phy/gui/tests/test_gui.py b/phy/gui/tests/test_gui.py index 446175db4..1b88ed70c 100644 --- a/phy/gui/tests/test_gui.py +++ b/phy/gui/tests/test_gui.py @@ -15,8 +15,11 @@ create_gui, _try_get_matplotlib_canvas, _try_get_vispy_canvas, + SaveGUIStatePlugin, + SaveGeometryStatePlugin, ) -from phy.utils import IPlugin, Bunch +from phy.io import Context +from phy.utils import IPlugin, Bunch, _load_json from phy.utils._color import _random_color @@ -76,7 +79,7 @@ def on_show(): qtbot.keyPress(gui, Qt.Key_Control) qtbot.keyRelease(gui, Qt.Key_Control) - view = gui.add_view(_create_canvas(), floating=True) + view = gui.add_view(_create_canvas(), floating=True, closable=True) gui.add_view(_create_canvas()) view.setFloating(False) gui.show() @@ -101,56 +104,6 @@ def on_close_view(view): gui.default_actions.exit() -def test_gui_state_json(tempdir): - path = op.join(tempdir, 'state.json') - - state = GUIState(hello='world') - state.to_json(path) - - state = GUIState() - state.from_json(path) - assert state.hello == 'world' - - -def test_gui_state_view(): - view = Bunch(__name__='myview1') - state = GUIState() - state.set_view_params(view, hello='world') - state.get_view_param('unknown', 'hello') is None - state.get_view_param('myview', 'unknown') is None - state.get_view_param('myview', 'hello') == 'world' - - -def test_create_gui_1(qapp, tempdir): - - _tmp = [] - - class MyPlugin(IPlugin): - def attach_to_gui(self, gui, model=None, state=None): - _tmp.append(state.hello) - - gui = create_gui(state=GUIState(plugins=['MyPlugin'], hello='world')) - assert gui - - assert _tmp == ['world'] - - -def test_gui_component(gui): - - class TestComponent(object): - def __init__(self, arg): - self._arg = arg - - def attach(self, gui): - gui._attached = self._arg - return 'attached' - - tc = TestComponent(3) - - assert tc.attach(gui) == 'attached' - assert gui._attached == 3 - - def test_gui_status_message(gui): assert gui.status_message == '' gui.status_message = ':hello world!' @@ -204,3 +157,61 @@ def on_show(): } gui.close() + + +#------------------------------------------------------------------------------ +# Test GUI plugin +#------------------------------------------------------------------------------ + +def test_gui_state_json(tempdir): + path = op.join(tempdir, 'state.json') + + state = GUIState(hello='world') + state.to_json(path) + + state = GUIState() + state.from_json(path) + assert state.hello == 'world' + + +def test_gui_state_view(): + view = Bunch(__name__='myview1') + state = GUIState() + state.set_view_params(view, hello='world') + state.get_view_param('unknown', 'hello') is None + state.get_view_param('myview', 'unknown') is None + state.get_view_param('myview', 'hello') == 'world' + + +def test_create_gui_1(qapp, tempdir): + + _tmp = [] + + class MyPlugin(IPlugin): + def attach_to_gui(self, gui, model=None, state=None): + _tmp.append(state.hello) + + gui = create_gui(state=GUIState(plugins=['MyPlugin'], hello='world')) + assert gui + + assert _tmp == ['world'] + + +def test_save_geometry_state(gui): + state = Bunch() + SaveGeometryStatePlugin().attach_to_gui(gui, state=state) + gui.close() + + assert state.geometry_state['geometry'] + assert state.geometry_state['state'] + + gui.show() + + +def test_save_gui_state(gui, tempdir): + gui.context = Context(tempdir) + state = Bunch(hello='world', state_save_location='local') + SaveGUIStatePlugin().attach_to_gui(gui, state=state) + gui.close() + json = _load_json(op.join(tempdir, 'GUI/state.json')) + assert json['hello'] == 'world' diff --git a/phy/io/__init__.py b/phy/io/__init__.py index d068cad04..f4fd14812 100644 --- a/phy/io/__init__.py +++ b/phy/io/__init__.py @@ -3,4 +3,5 @@ """Input/output.""" +from .context import Context from .traces import read_dat, read_kwd From 6cdc244ef619da201ee68d8bb5f98058a294d812 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 17 Dec 2015 16:26:10 +0100 Subject: [PATCH 0781/1059] Update GUI save state --- phy/gui/gui.py | 50 +++++++++++++++++++++++++++------------ phy/gui/tests/test_gui.py | 12 ++-------- phy/utils/__init__.py | 2 +- 3 files changed, 38 insertions(+), 26 deletions(-) diff --git a/phy/gui/gui.py b/phy/gui/gui.py index b2f28cfbf..3543060fd 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -9,12 +9,16 @@ from collections import defaultdict import logging +import os.path as op + +from six import string_types from .qt import (QApplication, QWidget, QDockWidget, QStatusBar, QMainWindow, Qt, QSize, QMetaObject) from .actions import Actions, _show_shortcuts, Snippets from phy.utils.event import EventEmitter -from phy.utils import load_master_config, Bunch, _load_json, _save_json +from phy.utils import (load_master_config, Bunch, _load_json, _save_json, + _ensure_dir_exists, phy_user_dir,) from phy.utils.plugin import get_plugin, IPlugin logger = logging.getLogger(__name__) @@ -335,6 +339,25 @@ def restore_geometry_state(self, gs): # GUI state, creator, plugins # ----------------------------------------------------------------------------- +def _get_path(name): + return op.join(phy_user_dir(), name + '.json') + + +def _save(name, data): + path = _get_path(name) + _ensure_dir_exists(op.dirname(path)) + logger.debug("Save data to `%s`.", path) + _save_json(path, data) + + +def _load(name): + path = _get_path(name) + if not op.exists(path): + logger.debug("The file `%s` doesn't exist.", path) + return + return _load_json(path) + + class GUIState(Bunch): def __init__(self, geometry_state=None, plugins=None, **kwargs): super(GUIState, self).__init__(geomety_state=geometry_state, @@ -351,7 +374,7 @@ def get_view_param(self, view_name, name): return self.get(view_name + '1', Bunch()).get(name, None) def set_view_params(self, view, **kwargs): - view_name = view.__name__ + view_name = view if isinstance(view, string_types) else view.__name__ if view_name not in self: self[view_name] = Bunch() self[view_name].update(kwargs) @@ -367,22 +390,10 @@ def on_close(): @gui.connect_ def on_show(): - gs = state['geometry_state'] + gs = state.get('geometry_state', None) gui.restore_geometry_state(gs) -class SaveGUIStatePlugin(IPlugin): - def attach_to_gui(self, gui, state=None, model=None): - - state_name = '{}/state'.format(gui.name) - location = state.get('state_save_location', 'global') - - @gui.connect_ - def on_close(): - logger.debug("Save GUI state to %s.", state_name) - gui.context.save(state_name, state, location=location) - - def create_gui(name=None, model=None, state=None): """Create a GUI with a model and a GUI state. @@ -396,6 +407,10 @@ def create_gui(name=None, model=None, state=None): # GUI name. name = gui.__name__ + # Load the state from disk. + state_name = '{}/state'.format(gui.name) + state.update(_load(state_name) or {}) + # If no plugins are specified, load the master config and # get the list of user plugins to attach to the GUI. config = load_master_config() @@ -408,4 +423,9 @@ def create_gui(name=None, model=None, state=None): logger.debug("Attach plugin `%s` to %s.", plugin, name) get_plugin(plugin)().attach_to_gui(gui, state=state, model=model) + # Save the state to disk. + @gui.connect_ + def on_close(): + _save(state_name, state) + return gui diff --git a/phy/gui/tests/test_gui.py b/phy/gui/tests/test_gui.py index 1b88ed70c..561f51cf2 100644 --- a/phy/gui/tests/test_gui.py +++ b/phy/gui/tests/test_gui.py @@ -15,7 +15,6 @@ create_gui, _try_get_matplotlib_canvas, _try_get_vispy_canvas, - SaveGUIStatePlugin, SaveGeometryStatePlugin, ) from phy.io import Context @@ -196,6 +195,8 @@ def attach_to_gui(self, gui, model=None, state=None): assert _tmp == ['world'] + gui.close() + def test_save_geometry_state(gui): state = Bunch() @@ -206,12 +207,3 @@ def test_save_geometry_state(gui): assert state.geometry_state['state'] gui.show() - - -def test_save_gui_state(gui, tempdir): - gui.context = Context(tempdir) - state = Bunch(hello='world', state_save_location='local') - SaveGUIStatePlugin().attach_to_gui(gui, state=state) - gui.close() - json = _load_json(op.join(tempdir, 'GUI/state.json')) - assert json['hello'] == 'world' diff --git a/phy/utils/__init__.py b/phy/utils/__init__.py index 187ae593a..f298945f0 100644 --- a/phy/utils/__init__.py +++ b/phy/utils/__init__.py @@ -8,4 +8,4 @@ Bunch, _is_list) from .event import EventEmitter, ProgressReporter from .plugin import IPlugin, get_plugin, get_all_plugins -from .config import _ensure_dir_exists, load_master_config +from .config import _ensure_dir_exists, load_master_config, phy_user_dir From 35f668bc4db79ed2b5334e866593e38ea488c12d Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 17 Dec 2015 16:44:56 +0100 Subject: [PATCH 0782/1059] Fixes --- phy/cluster/manual/tests/test_views.py | 19 ++++++++++--------- phy/cluster/manual/views.py | 4 ++-- phy/gui/gui.py | 18 +++++++++++++++--- 3 files changed, 27 insertions(+), 14 deletions(-) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index c69a506dc..d3cb88e5e 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -20,7 +20,7 @@ artificial_masks, artificial_traces, ) -from phy.gui import create_gui +from phy.gui import create_gui, GUIState from phy.electrode.mea import staggered_positions from phy.utils import Bunch from ..views import TraceView, _extract_wave, _selected_clusters_colors @@ -65,14 +65,15 @@ def model(): yield model -@yield_fixture +@yield_fixture(scope='function') def state(): - state = Bunch() - state.CorrelogramView1 = Bunch(bin_size=1e-3, - window_size=50e-3, - excerpt_size=8, - n_excerpts=5, - ) + state = GUIState() + state.set_view_params('CorrelogramView1', + bin_size=1e-3, + window_size=50e-3, + excerpt_size=8, + n_excerpts=5, + ) state.n_samples_per_spike = 6 yield state @@ -144,7 +145,7 @@ def test_waveform_view(qtbot, model, state): v.toggle_waveform_overlap() v.toggle_waveform_overlap() - # qtbot.stop() + # qtbot.stop() #------------------------------------------------------------------------------ diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index fbb8f237f..f80811719 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -278,7 +278,7 @@ class WaveformViewPlugin(IPlugin): def attach_to_gui(self, gui, model=None, state=None): # NOTE: we assume that the state contains fields for every view. # Load the box_bounds from the state. - box_bounds = state.WaveformView1.box_bounds + box_bounds = state.get_view_param('WaveformView', 'box_bounds') w = WaveformView(waveforms=model.waveforms, masks=model.masks, spike_clusters=model.spike_clusters, @@ -290,7 +290,7 @@ def attach_to_gui(self, gui, model=None, state=None): @gui.connect_ def on_close(): # Save the box bounds. - state[w.__name__].box_bounds = w.stacked.box_bounds + state.set_view_params(w, box_bounds=w.boxed.box_bounds) # ----------------------------------------------------------------------------- diff --git a/phy/gui/gui.py b/phy/gui/gui.py index 3543060fd..e445b79f1 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -350,17 +350,27 @@ def _save(name, data): _save_json(path, data) +def _bunchify(b): + """Ensure all dict elements are Bunch.""" + assert isinstance(b, dict) + b = Bunch(b) + for k in b: + if isinstance(b[k], dict): + b[k] = Bunch(b[k]) + return b + + def _load(name): path = _get_path(name) if not op.exists(path): logger.debug("The file `%s` doesn't exist.", path) return - return _load_json(path) + return _bunchify(_load_json(path)) class GUIState(Bunch): def __init__(self, geometry_state=None, plugins=None, **kwargs): - super(GUIState, self).__init__(geomety_state=geometry_state, + super(GUIState, self).__init__(geometry_state=geometry_state, plugins=plugins or [], **kwargs) @@ -368,6 +378,7 @@ def to_json(self, filename): _save_json(filename, self) def from_json(self, filename): + # TODO: remove? self.update(_load_json(filename)) def get_view_param(self, view_name, name): @@ -403,13 +414,14 @@ def create_gui(name=None, model=None, state=None): """ gui = GUI(name=name) state = state or GUIState() + assert isinstance(state, GUIState) plugins = state.plugins # GUI name. name = gui.__name__ # Load the state from disk. state_name = '{}/state'.format(gui.name) - state.update(_load(state_name) or {}) + state.update(_load(state_name) or Bunch()) # If no plugins are specified, load the master config and # get the list of user plugins to attach to the GUI. From f280a3716a0cc5943f36b4e0ca8d79a76c56ad0c Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 17 Dec 2015 17:06:54 +0100 Subject: [PATCH 0783/1059] Flakify --- phy/gui/tests/test_gui.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/phy/gui/tests/test_gui.py b/phy/gui/tests/test_gui.py index 561f51cf2..89baaf11c 100644 --- a/phy/gui/tests/test_gui.py +++ b/phy/gui/tests/test_gui.py @@ -17,8 +17,7 @@ _try_get_vispy_canvas, SaveGeometryStatePlugin, ) -from phy.io import Context -from phy.utils import IPlugin, Bunch, _load_json +from phy.utils import IPlugin, Bunch from phy.utils._color import _random_color From d7146c2d34fdd88f1eb7242636cecaaee5169c1a Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 17 Dec 2015 19:32:48 +0100 Subject: [PATCH 0784/1059] WIP: update GUI state --- phy/gui/gui.py | 98 ++++++++++++++++++--------------------- phy/gui/tests/test_gui.py | 30 ++++++------ 2 files changed, 57 insertions(+), 71 deletions(-) diff --git a/phy/gui/gui.py b/phy/gui/gui.py index e445b79f1..9cd9bac06 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -17,7 +17,8 @@ Qt, QSize, QMetaObject) from .actions import Actions, _show_shortcuts, Snippets from phy.utils.event import EventEmitter -from phy.utils import (load_master_config, Bunch, _load_json, _save_json, +from phy.utils import (load_master_config, Bunch, _bunchify, + _load_json, _save_json, _ensure_dir_exists, phy_user_dir,) from phy.utils.plugin import get_plugin, IPlugin @@ -339,50 +340,24 @@ def restore_geometry_state(self, gs): # GUI state, creator, plugins # ----------------------------------------------------------------------------- -def _get_path(name): - return op.join(phy_user_dir(), name + '.json') - - -def _save(name, data): - path = _get_path(name) - _ensure_dir_exists(op.dirname(path)) - logger.debug("Save data to `%s`.", path) - _save_json(path, data) - - -def _bunchify(b): - """Ensure all dict elements are Bunch.""" - assert isinstance(b, dict) - b = Bunch(b) - for k in b: - if isinstance(b[k], dict): - b[k] = Bunch(b[k]) - return b - - -def _load(name): - path = _get_path(name) - if not op.exists(path): - logger.debug("The file `%s` doesn't exist.", path) - return - return _bunchify(_load_json(path)) - - class GUIState(Bunch): - def __init__(self, geometry_state=None, plugins=None, **kwargs): - super(GUIState, self).__init__(geometry_state=geometry_state, - plugins=plugins or [], - **kwargs) - - def to_json(self, filename): - _save_json(filename, self) + """Represent the state of the GUI: positions of the views and + all parameters associated to the GUI and views. - def from_json(self, filename): - # TODO: remove? - self.update(_load_json(filename)) + This is automatically loaded from the configuration directory. - def get_view_param(self, view_name, name): - return self.get(view_name + '1', Bunch()).get(name, None) + """ + def __init__(self, name='GUI', config_dir=None, **kwargs): + super(GUIState, self).__init__(**kwargs) + self.name = name + self.config_dir = config_dir or phy_user_dir() + _ensure_dir_exists(op.join(self.config_dir, self.name)) + self.load() + + def get_view_params(self, view_name, *names): + # TODO: how to choose view index + return [self.get(view_name + '1', Bunch()).get(name, None) + for name in names] def set_view_params(self, view, **kwargs): view_name = view if isinstance(view, string_types) else view.__name__ @@ -390,6 +365,25 @@ def set_view_params(self, view, **kwargs): self[view_name] = Bunch() self[view_name].update(kwargs) + @property + def path(self): + return op.join(self.config_dir, self.name, 'state.json') + + def load(self): + """Load the state from the JSON file in the config dir.""" + if not op.exists(self.path): + logger.debug("The GUI state file `%s` doesn't exist.", self.path) + # TODO: create the default state. + return + assert op.exists(self.path) + logger.debug("Load the GUI state from `%s`.", self.path) + self.update(_bunchify(_load_json(self.path))) + + def save(self): + """Save the state to the JSON file in the config dir.""" + logger.debug("Save the GUI state to `%s`.", self.path) + _save_json(self.path, self) + class SaveGeometryStatePlugin(IPlugin): def attach_to_gui(self, gui, state=None, model=None): @@ -405,28 +399,24 @@ def on_show(): gui.restore_geometry_state(gs) -def create_gui(name=None, model=None, state=None): - """Create a GUI with a model and a GUI state. +def create_gui(name=None, model=None, plugins=None, config_dir=None): + """Create a GUI with a model and a list of plugins. By default, the list of plugins is taken from the `c.TheGUI.plugins` parameter, where `TheGUI` is the name of the GUI class. """ gui = GUI(name=name) - state = state or GUIState() - assert isinstance(state, GUIState) - plugins = state.plugins - # GUI name. name = gui.__name__ + plugins = plugins or [] - # Load the state from disk. - state_name = '{}/state'.format(gui.name) - state.update(_load(state_name) or Bunch()) + # Load the state. + state = GUIState(gui.name, config_dir=config_dir) + gui.state = state # If no plugins are specified, load the master config and # get the list of user plugins to attach to the GUI. - config = load_master_config() - plugins_conf = config[name].plugins + plugins_conf = load_master_config()[name].plugins plugins_conf = plugins_conf if isinstance(plugins_conf, list) else [] plugins.extend(plugins_conf) @@ -438,6 +428,6 @@ def create_gui(name=None, model=None, state=None): # Save the state to disk. @gui.connect_ def on_close(): - _save(state_name, state) + state.save() return gui diff --git a/phy/gui/tests/test_gui.py b/phy/gui/tests/test_gui.py index 89baaf11c..e72a3d78b 100644 --- a/phy/gui/tests/test_gui.py +++ b/phy/gui/tests/test_gui.py @@ -17,7 +17,7 @@ _try_get_vispy_canvas, SaveGeometryStatePlugin, ) -from phy.utils import IPlugin, Bunch +from phy.utils import IPlugin, Bunch, _save_json, _ensure_dir_exists from phy.utils._color import _random_color @@ -161,38 +161,34 @@ def on_show(): # Test GUI plugin #------------------------------------------------------------------------------ -def test_gui_state_json(tempdir): - path = op.join(tempdir, 'state.json') - - state = GUIState(hello='world') - state.to_json(path) - - state = GUIState() - state.from_json(path) - assert state.hello == 'world' - - def test_gui_state_view(): view = Bunch(__name__='myview1') state = GUIState() state.set_view_params(view, hello='world') - state.get_view_param('unknown', 'hello') is None - state.get_view_param('myview', 'unknown') is None - state.get_view_param('myview', 'hello') == 'world' + state.get_view_params('unknown', 'hello') == [None] + state.get_view_params('myview', 'unknown') == [None] + state.get_view_params('myview', 'hello') == ['world'] def test_create_gui_1(qapp, tempdir): + _ensure_dir_exists(op.join(tempdir, 'GUI/')) + path = op.join(tempdir, 'GUI/state.json') + _save_json(path, {'hello': 'world'}) + _tmp = [] class MyPlugin(IPlugin): def attach_to_gui(self, gui, model=None, state=None): _tmp.append(state.hello) - gui = create_gui(state=GUIState(plugins=['MyPlugin'], hello='world')) + gui = create_gui(plugins=['MyPlugin'], config_dir=tempdir) assert gui - assert _tmp == ['world'] + gui.state.hello = 'dolly' + gui.state.save() + + assert GUIState(config_dir=tempdir).hello == 'dolly' gui.close() From ab425e80b523f4e5bb298e646390ad1c7b0e5467 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 17 Dec 2015 19:33:15 +0100 Subject: [PATCH 0785/1059] Add _bunchify() function --- phy/utils/__init__.py | 2 +- phy/utils/_types.py | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/phy/utils/__init__.py b/phy/utils/__init__.py index f298945f0..484fc229a 100644 --- a/phy/utils/__init__.py +++ b/phy/utils/__init__.py @@ -5,7 +5,7 @@ from ._misc import _load_json, _save_json from ._types import (_is_array_like, _as_array, _as_tuple, _as_list, - Bunch, _is_list) + Bunch, _is_list, _bunchify) from .event import EventEmitter, ProgressReporter from .plugin import IPlugin, get_plugin, get_all_plugins from .config import _ensure_dir_exists, load_master_config, phy_user_dir diff --git a/phy/utils/_types.py b/phy/utils/_types.py index f68432b39..021b25e01 100644 --- a/phy/utils/_types.py +++ b/phy/utils/_types.py @@ -31,6 +31,16 @@ def copy(self): return Bunch(super(Bunch, self).copy()) +def _bunchify(b): + """Ensure all dict elements are Bunch.""" + assert isinstance(b, dict) + b = Bunch(b) + for k in b: + if isinstance(b[k], dict): + b[k] = Bunch(b[k]) + return b + + def _is_list(obj): return isinstance(obj, list) From f3c8bec409cba2ef0381cb2dabc47002c4097290 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 17 Dec 2015 22:10:21 +0100 Subject: [PATCH 0786/1059] Update views --- phy/cluster/manual/tests/test_views.py | 48 ++++++------- phy/cluster/manual/views.py | 99 ++++++++++++++++---------- 2 files changed, 86 insertions(+), 61 deletions(-) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index d3cb88e5e..eb8451102 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -10,6 +10,7 @@ import numpy as np from numpy.testing import assert_equal as ae +from numpy.testing import assert_allclose as ac from pytest import raises, yield_fixture from phy.io.array import _spikes_per_cluster @@ -65,23 +66,18 @@ def model(): yield model -@yield_fixture(scope='function') -def state(): - state = GUIState() - state.set_view_params('CorrelogramView1', - bin_size=1e-3, - window_size=50e-3, - excerpt_size=8, - n_excerpts=5, - ) - state.n_samples_per_spike = 6 - yield state +@contextmanager +def _test_view(view_name, model=None, tempdir=None): + # Save a test GUI state JSON file in the tempdir. + state = GUIState(config_dir=tempdir) + state.set_view_params('TraceView1', box_size=(1., .01)) + state.set_view_params('FeatureView1', feature_scaling=.5) + state.save() -@contextmanager -def _test_view(view_name, model=None, state=None): - state.plugins = [view_name + 'Plugin'] - gui = create_gui(model=model, state=state) + # Create the GUI. + plugins = [view_name + 'Plugin'] + gui = create_gui(model=model, plugins=plugins, config_dir=tempdir) gui.show() v = gui.list_views(view_name)[0] @@ -140,8 +136,8 @@ def test_selected_clusters_colors(): # Test waveform view #------------------------------------------------------------------------------ -def test_waveform_view(qtbot, model, state): - with _test_view('WaveformView', model=model, state=state) as v: +def test_waveform_view(qtbot, model, tempdir): + with _test_view('WaveformView', model=model, tempdir=tempdir) as v: v.toggle_waveform_overlap() v.toggle_waveform_overlap() @@ -164,22 +160,26 @@ def test_trace_view_no_spikes(qtbot): _show(qtbot, v) -def test_trace_view_spikes(qtbot, model, state): - with _test_view('TraceView', model=model, state=state) as v: +def test_trace_view_spikes(qtbot, model, tempdir): + with _test_view('TraceView', model=model, tempdir=tempdir) as v: + ac(v.stacked.box_size, (1., .01), atol=1e-2) + v.go_to(.5) v.go_to(-.5) v.go_left() v.go_right() - # qtbot.stop() + # qtbot.stop() #------------------------------------------------------------------------------ # Test feature view #------------------------------------------------------------------------------ -def test_feature_view(qtbot, model, state): - with _test_view('FeatureView', model=model, state=state) as v: +def test_feature_view(qtbot, model, tempdir): + with _test_view('FeatureView', model=model, tempdir=tempdir) as v: + + assert v.feature_scaling == .5 @v.set_best_channels_func def best_channels(cluster_id): @@ -197,8 +197,8 @@ def best_channels(cluster_id): # Test correlogram view #------------------------------------------------------------------------------ -def test_correlogram_view(qtbot, model, state): - with _test_view('CorrelogramView', model=model, state=state) as v: +def test_correlogram_view(qtbot, model, tempdir): + with _test_view('CorrelogramView', model=model, tempdir=tempdir) as v: v.toggle_normalization() # qtbot.stop() diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index f80811719..b5b98407b 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -278,19 +278,19 @@ class WaveformViewPlugin(IPlugin): def attach_to_gui(self, gui, model=None, state=None): # NOTE: we assume that the state contains fields for every view. # Load the box_bounds from the state. - box_bounds = state.get_view_param('WaveformView', 'box_bounds') - w = WaveformView(waveforms=model.waveforms, - masks=model.masks, - spike_clusters=model.spike_clusters, - channel_positions=model.channel_positions, - box_bounds=box_bounds, - ) - w.attach(gui) + box_bounds, = state.get_view_params('WaveformView', 'box_bounds') + view = WaveformView(waveforms=model.waveforms, + masks=model.masks, + spike_clusters=model.spike_clusters, + channel_positions=model.channel_positions, + box_bounds=box_bounds, + ) + view.attach(gui) @gui.connect_ def on_close(): # Save the box bounds. - state.set_view_params(w, box_bounds=w.boxed.box_bounds) + state.set_view_params(view, box_bounds=view.boxed.box_bounds) # ----------------------------------------------------------------------------- @@ -516,13 +516,21 @@ def go_left(self): class TraceViewPlugin(IPlugin): def attach_to_gui(self, gui, model=None, state=None): - t = TraceView(traces=model.traces, - sample_rate=model.sample_rate, - spike_times=model.spike_times, - spike_clusters=model.spike_clusters, - masks=model.masks, - ) - t.attach(gui) + view = TraceView(traces=model.traces, + sample_rate=model.sample_rate, + spike_times=model.spike_times, + spike_clusters=model.spike_clusters, + masks=model.masks, + ) + b, = state.get_view_params('TraceView', 'box_size') + if b: + view.stacked.box_size = b + view.attach(gui) + + @gui.connect_ + def on_close(): + # Save the box bounds. + state.set_view_params(view, box_size=view.stacked.box_size) # ----------------------------------------------------------------------------- @@ -780,9 +788,11 @@ def attach(self, gui): def increase_feature_scaling(self): self.feature_scaling *= 1.2 + self.on_select() def decrease_feature_scaling(self): self.feature_scaling /= 1.2 + self.on_select() @property def feature_scaling(self): @@ -791,19 +801,21 @@ def feature_scaling(self): @feature_scaling.setter def feature_scaling(self, value): self._feature_scaling = value - self.on_select() class FeatureViewPlugin(IPlugin): def attach_to_gui(self, gui, model=None, state=None): - f = FeatureView(features=model.features, - masks=model.masks, - spike_clusters=model.spike_clusters, - spike_times=model.spike_times, - ) + view = FeatureView(features=model.features, + masks=model.masks, + spike_clusters=model.spike_clusters, + spike_times=model.spike_times, + ) + fs, = state.get_view_params('FeatureView', 'feature_scaling') + if fs: + view.feature_scaling = fs - @f.set_best_channels_func + @view.set_best_channels_func def best_channels(cluster_id): """Select the best channels for a given cluster.""" # TODO: better perf with cluster stats and cache @@ -813,7 +825,12 @@ def best_channels(cluster_id): uch = unmasked_channels(mean_masks) return sorted_main_channels(mean_masks, uch) - f.attach(gui) + view.attach(gui) + + @gui.connect_ + def on_close(): + # Save the box bounds. + state.set_view_params(view, feature_scaling=view.feature_scaling) # ----------------------------------------------------------------------------- @@ -823,6 +840,8 @@ def best_channels(cluster_id): class CorrelogramView(ManualClusteringView): excerpt_size = 10000 n_excerpts = 100 + bin_size = 1e-3 + window_size = 50e-3 uniform_normalization = False default_shortcuts = { 'go_left': 'alt+left', @@ -842,11 +861,11 @@ def __init__(self, assert sample_rate > 0 self.sample_rate = sample_rate - assert bin_size > 0 - self.bin_size = bin_size + self.bin_size = bin_size or self.bin_size + assert self.bin_size > 0 - assert window_size > 0 - self.window_size = window_size + self.window_size = window_size or self.window_size + assert self.window_size > 0 self.excerpt_size = excerpt_size or self.excerpt_size self.n_excerpts = n_excerpts or self.n_excerpts @@ -930,12 +949,18 @@ def attach(self, gui): class CorrelogramViewPlugin(IPlugin): def attach_to_gui(self, gui, model=None, state=None): - ccg = CorrelogramView(spike_times=model.spike_times, - spike_clusters=model.spike_clusters, - sample_rate=model.sample_rate, - bin_size=state.CorrelogramView1.bin_size, - window_size=state.CorrelogramView1.window_size, - excerpt_size=state.CorrelogramView1.excerpt_size, - n_excerpts=state.CorrelogramView1.n_excerpts, - ) - ccg.attach(gui) + bs, ws, es, ne = state.get_view_params('CorrelogramView', + 'bin_size', + 'window_size', + 'excerpt_size', + 'n_excerpts', + ) + view = CorrelogramView(spike_times=model.spike_times, + spike_clusters=model.spike_clusters, + sample_rate=model.sample_rate, + bin_size=bs, + window_size=ws, + excerpt_size=es, + n_excerpts=ne, + ) + view.attach(gui) From de7c06e3b0b3ffc2a8e151aefafb96d21ccfaeae Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 17 Dec 2015 22:24:13 +0100 Subject: [PATCH 0787/1059] Minor updates --- phy/cluster/manual/gui_component.py | 2 -- phy/cluster/manual/tests/test_views.py | 2 ++ phy/cluster/manual/views.py | 5 ++++- phy/utils/_misc.py | 2 +- 4 files changed, 7 insertions(+), 4 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index eab2269a3..579b15cdf 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -530,10 +530,8 @@ def attach_to_gui(self, gui, model=None, state=None): # Attach the manual clustering logic (wizard, merge, split, # undo stack) to the GUI. - n = state.n_spikes_max_per_cluster mc = ManualClustering(model.spike_clusters, cluster_groups=model.cluster_groups, - n_spikes_max_per_cluster=n, ) mc.attach(gui) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index eb8451102..95dfe907d 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -71,6 +71,7 @@ def _test_view(view_name, model=None, tempdir=None): # Save a test GUI state JSON file in the tempdir. state = GUIState(config_dir=tempdir) + state.set_view_params('WaveformView1', box_size=(.1, .1)) state.set_view_params('TraceView1', box_size=(1., .01)) state.set_view_params('FeatureView1', feature_scaling=.5) state.save() @@ -138,6 +139,7 @@ def test_selected_clusters_colors(): def test_waveform_view(qtbot, model, tempdir): with _test_view('WaveformView', model=model, tempdir=tempdir) as v: + ac(v.boxed.box_size, (.1, .1), atol=1e-2) v.toggle_waveform_overlap() v.toggle_waveform_overlap() diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index b5b98407b..7261317a6 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -283,10 +283,13 @@ def attach_to_gui(self, gui, model=None, state=None): masks=model.masks, spike_clusters=model.spike_clusters, channel_positions=model.channel_positions, - box_bounds=box_bounds, ) view.attach(gui) + b, = state.get_view_params('WaveformView', 'box_size') + if b: + view.boxed.box_size = b + @gui.connect_ def on_close(): # Save the box bounds. diff --git a/phy/utils/_misc.py b/phy/utils/_misc.py index ffd427e32..cb5c42ab9 100644 --- a/phy/utils/_misc.py +++ b/phy/utils/_misc.py @@ -98,7 +98,7 @@ def _save_json(path, data): data = _stringify_keys(data) path = op.realpath(op.expanduser(path)) with open(path, 'w') as f: - json.dump(data, f, cls=_CustomEncoder, indent=2) + json.dump(data, f, cls=_CustomEncoder, indent=2, sort_keys=True) #------------------------------------------------------------------------------ From ab93bf2697320cf62132bcd0ca739e902a9f8f1c Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 17 Dec 2015 22:26:03 +0100 Subject: [PATCH 0788/1059] Fix --- phy/cluster/manual/views.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 7261317a6..37770b2a3 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -276,9 +276,6 @@ def toggle_waveform_overlap(self): class WaveformViewPlugin(IPlugin): def attach_to_gui(self, gui, model=None, state=None): - # NOTE: we assume that the state contains fields for every view. - # Load the box_bounds from the state. - box_bounds, = state.get_view_params('WaveformView', 'box_bounds') view = WaveformView(waveforms=model.waveforms, masks=model.masks, spike_clusters=model.spike_clusters, @@ -293,7 +290,7 @@ def attach_to_gui(self, gui, model=None, state=None): @gui.connect_ def on_close(): # Save the box bounds. - state.set_view_params(view, box_bounds=view.boxed.box_bounds) + state.set_view_params(view, box_size=view.boxed.box_size) # ----------------------------------------------------------------------------- From 15f47b18fe6f68cca197cb3687b77ffc237e8799 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 18 Dec 2015 09:10:55 +0100 Subject: [PATCH 0789/1059] Remove some fields in GUI state JSON --- phy/gui/gui.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/phy/gui/gui.py b/phy/gui/gui.py index 9cd9bac06..6b5e0f28d 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -382,7 +382,8 @@ def load(self): def save(self): """Save the state to the JSON file in the config dir.""" logger.debug("Save the GUI state to `%s`.", self.path) - _save_json(self.path, self) + _save_json(self.path, {k: v for k, v in self.items() + if k not in ('config_dir', 'name')}) class SaveGeometryStatePlugin(IPlugin): From a9abb3f4da18715be020e57a91df1150d83d039e Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 18 Dec 2015 11:54:14 +0100 Subject: [PATCH 0790/1059] WIP: improve scaling in waveform view --- phy/cluster/manual/tests/test_views.py | 2 +- phy/cluster/manual/views.py | 99 +++++++++++++++++++++++--- phy/plot/interact.py | 48 ------------- phy/plot/tests/test_interact.py | 34 --------- 4 files changed, 92 insertions(+), 91 deletions(-) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index 95dfe907d..840757b85 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -139,7 +139,7 @@ def test_selected_clusters_colors(): def test_waveform_view(qtbot, model, tempdir): with _test_view('WaveformView', model=model, tempdir=tempdir) as v: - ac(v.boxed.box_size, (.1, .1), atol=1e-2) + ac(v.boxed.box_size, (.1818, .0909), atol=1e-2) v.toggle_waveform_overlap() v.toggle_waveform_overlap() diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 37770b2a3..8ada49f9f 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -174,9 +174,22 @@ class WaveformView(ManualClusteringView): normalization_percentile = .95 normalization_n_spikes = 1000 overlap = False + box_coeff = 1.1 default_shortcuts = { 'toggle_waveform_overlap': 'o', + + # Box scaling. + 'widen': 'ctrl+right', + 'narrow': 'ctrl+left', + 'increase': 'ctrl+up', + 'decrease': 'ctrl+down', + + # Probe scaling. + 'extend_horizontally': 'shift+right', + 'shrink_horizontally': 'shift+left', + 'extend_vertically': 'shift+up', + 'shrink_vertically': 'shift+down', } def __init__(self, @@ -184,7 +197,8 @@ def __init__(self, masks=None, spike_clusters=None, channel_positions=None, - box_bounds=None, + box_scaling=None, + probe_scaling=None, **kwargs): """ @@ -193,12 +207,23 @@ def __init__(self, """ # Initialize the view. - box_bounds = (_get_boxes(channel_positions) if box_bounds is None - else box_bounds) + box_bounds = _get_boxes(channel_positions) super(WaveformView, self).__init__(layout='boxed', box_bounds=box_bounds, **kwargs) + # Box and probe scaling. + self.box_scaling = np.array(box_scaling if box_scaling is not None + else (1., 1.)) + self.probe_scaling = np.array(probe_scaling + if probe_scaling is not None + else (1., 1.)) + + # Make a copy of the initial box pos and size. We'll apply the scaling + # to these quantities. + self.box_pos = np.array(self.boxed.box_pos) + self.box_size = np.array(self.boxed.box_size) + # Waveforms. assert waveforms.ndim == 3 self.n_spikes, self.n_samples, self.n_channels = waveforms.shape @@ -269,28 +294,86 @@ def attach(self, gui): super(WaveformView, self).attach(gui) self.actions.add(self.toggle_waveform_overlap) + # Box scaling. + self.actions.add(self.widen) + self.actions.add(self.narrow) + self.actions.add(self.increase) + self.actions.add(self.decrease) + + # Probe scaling. + self.actions.add(self.extend_horizontally) + self.actions.add(self.shrink_horizontally) + self.actions.add(self.extend_vertically) + self.actions.add(self.shrink_vertically) + def toggle_waveform_overlap(self): self.overlap = not self.overlap self.on_select() + # Box scaling + # ------------------------------------------------------------------------- + + def _update_box_size(self): + self.boxed.box_size = self.box_size * self.box_scaling + + def widen(self): + self.box_scaling[0] *= self.box_coeff + self._update_box_size() + + def narrow(self): + self.box_scaling[0] /= self.box_coeff + self._update_box_size() + + def increase(self): + self.box_scaling[1] *= self.box_coeff + self._update_box_size() + + def decrease(self): + self.box_scaling[1] /= self.box_coeff + self._update_box_size() + + # Probe scaling + # ------------------------------------------------------------------------- + + def _update_box_pos(self): + self.boxed.box_pos = self.box_pos * self.probe_scaling + + def extend_horizontally(self): + self.probe_scaling[0] *= self.box_coeff + self._update_box_pos() + + def shrink_horizontally(self): + self.probe_scaling[0] /= self.box_coeff + self._update_box_pos() + + def extend_vertically(self): + self.probe_scaling[1] *= self.box_coeff + self._update_box_pos() + + def shrink_vertically(self): + self.probe_scaling[1] /= self.box_coeff + self._update_box_pos() + class WaveformViewPlugin(IPlugin): def attach_to_gui(self, gui, model=None, state=None): + bs, ps = state.get_view_params('WaveformView', 'box_scaling', + 'probe_scaling') view = WaveformView(waveforms=model.waveforms, masks=model.masks, spike_clusters=model.spike_clusters, channel_positions=model.channel_positions, + box_scaling=bs, + probe_scaling=ps, ) view.attach(gui) - b, = state.get_view_params('WaveformView', 'box_size') - if b: - view.boxed.box_size = b - @gui.connect_ def on_close(): # Save the box bounds. - state.set_view_params(view, box_size=view.boxed.box_size) + state.set_view_params(view, + box_scaling=view.box_scaling, + probe_scaling=view.probe_scaling) # ----------------------------------------------------------------------------- diff --git a/phy/plot/interact.py b/phy/plot/interact.py index be795c42f..0d270542c 100644 --- a/phy/plot/interact.py +++ b/phy/plot/interact.py @@ -172,8 +172,6 @@ def attach(self, canvas): n_boxes); box_bounds = (2 * box_bounds - 1); // See hack in Python. """.format(self.box_var), 'before_transforms') - canvas.connect(self.on_key_press) - canvas.connect(self.on_key_release) def update_program(self, program): # Signal bounds (positions). @@ -217,52 +215,6 @@ def box_size(self, val): self.box_bounds = _get_boxes(self.box_pos, size=val, keep_aspect_ratio=self.keep_aspect_ratio) - # Interaction event callbacks - #-------------------------------------------------------------------------- - - _arrows = ('Left', 'Right', 'Up', 'Down') - _pm = ('+', '-') - - def on_key_press(self, event): - """Handle key press events.""" - key = event.key - - self._key_pressed = key - - ctrl = 'Control' in event.modifiers - shift = 'Shift' in event.modifiers - - # Box scale. - if ctrl and key in self._arrows + self._pm: - coeff = 1.1 - box_size = np.array(self.box_size) - if key == 'Left': - box_size[0] /= coeff - elif key == 'Right': - box_size[0] *= coeff - elif key in ('Down', '-'): - box_size[1] /= coeff - elif key in ('Up', '+'): - box_size[1] *= coeff - self.box_size = box_size - - # Probe scale. - if shift and key in self._arrows: - coeff = 1.1 - box_pos = self.box_pos - if key == 'Left': - box_pos[:, 0] /= coeff - elif key == 'Right': - box_pos[:, 0] *= coeff - elif key == 'Down': - box_pos[:, 1] /= coeff - elif key == 'Up': - box_pos[:, 1] *= coeff - self.box_pos = box_pos - - def on_key_release(self, event): - self._key_pressed = None # pragma: no cover - class Stacked(Boxed): """Stacked interact. diff --git a/phy/plot/tests/test_interact.py b/phy/plot/tests/test_interact.py index 414a83472..04cc9915a 100644 --- a/phy/plot/tests/test_interact.py +++ b/phy/plot/tests/test_interact.py @@ -11,8 +11,6 @@ import numpy as np from numpy.testing import assert_equal as ae -from numpy.testing import assert_allclose as ac -from vispy.util import keys from ..base import BaseVisual from ..interact import Grid, Boxed, Stacked @@ -121,38 +119,6 @@ def test_boxed_1(qtbot, canvas): ae(boxed.box_bounds, b) boxed.box_bounds = b - # Change box vertical size. - bs = boxed.box_size - for k in (('+', '-'), ('Up', 'Down')): - canvas.events.key_press(key=keys.Key(k[0]), modifiers=(keys.CONTROL,)) - assert boxed.box_size[1] > bs[1] - canvas.events.key_press(key=keys.Key(k[1]), modifiers=(keys.CONTROL,)) - ac(boxed.box_size[1], bs[1], atol=1e-3) - - # Change box horizontal size. - bs = boxed.box_size - canvas.events.key_press(key=keys.Key('Left'), modifiers=(keys.CONTROL,)) - assert boxed.box_size[0] < bs[0] - canvas.events.key_press(key=keys.Key('Right'), modifiers=(keys.CONTROL,)) - ac(boxed.box_size[0], bs[0], atol=1e-3) - - # Change box vertical positions. - bp = boxed.box_pos - canvas.events.key_press(key=keys.Key('Up'), modifiers=(keys.SHIFT,)) - assert np.all(np.abs(boxed.box_pos[:, 1]) > np.abs(bp[:, 1])) - canvas.events.key_press(key=keys.Key('Down'), modifiers=(keys.SHIFT,)) - ac(boxed.box_pos, bp, atol=1e-3) - - # Change box horizontal positions. - bp = boxed.box_pos - canvas.events.key_press(key=keys.Key('Left'), modifiers=(keys.SHIFT,)) - assert np.all(np.abs(boxed.box_pos[:, 0]) < np.abs(bp[:, 0])) - canvas.events.key_press(key=keys.Key('Right'), modifiers=(keys.SHIFT,)) - ac(boxed.box_pos, bp, atol=1e-3) - - # Release a key. - canvas.events.key_release(key=keys.Key('Right'), modifiers=(keys.SHIFT,)) - # qtbot.stop() From 84f6a9a7b37fb6ab1c0a136d789f62cc09cf2079 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 18 Dec 2015 12:07:42 +0100 Subject: [PATCH 0791/1059] Fix waveform scaling --- phy/cluster/manual/tests/test_views.py | 22 ++++++++++++++++++++++ phy/cluster/manual/views.py | 25 ++++++++++++------------- phy/plot/interact.py | 10 ++++++++++ phy/plot/tests/test_interact.py | 4 ++++ 4 files changed, 48 insertions(+), 13 deletions(-) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index 840757b85..a40d14bd7 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -143,6 +143,28 @@ def test_waveform_view(qtbot, model, tempdir): v.toggle_waveform_overlap() v.toggle_waveform_overlap() + # Box scaling. + bs = v.boxed.box_size + v.increase() + v.decrease() + ac(v.boxed.box_size, bs) + + bs = v.boxed.box_size + v.widen() + v.narrow() + ac(v.boxed.box_size, bs) + + # Probe scaling. + bp = v.boxed.box_pos + v.extend_horizontally() + v.shrink_horizontally() + ac(v.boxed.box_pos, bp) + + bp = v.boxed.box_pos + v.extend_vertically() + v.shrink_vertically() + ac(v.boxed.box_pos, bp) + # qtbot.stop() diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 8ada49f9f..0d2907952 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -223,6 +223,7 @@ def __init__(self, # to these quantities. self.box_pos = np.array(self.boxed.box_pos) self.box_size = np.array(self.boxed.box_size) + self._update_boxes() # Waveforms. assert waveforms.ndim == 3 @@ -313,46 +314,44 @@ def toggle_waveform_overlap(self): # Box scaling # ------------------------------------------------------------------------- - def _update_box_size(self): - self.boxed.box_size = self.box_size * self.box_scaling + def _update_boxes(self): + self.boxed.update_boxes(self.box_pos * self.probe_scaling, + self.box_size * self.box_scaling) def widen(self): self.box_scaling[0] *= self.box_coeff - self._update_box_size() + self._update_boxes() def narrow(self): self.box_scaling[0] /= self.box_coeff - self._update_box_size() + self._update_boxes() def increase(self): self.box_scaling[1] *= self.box_coeff - self._update_box_size() + self._update_boxes() def decrease(self): self.box_scaling[1] /= self.box_coeff - self._update_box_size() + self._update_boxes() # Probe scaling # ------------------------------------------------------------------------- - def _update_box_pos(self): - self.boxed.box_pos = self.box_pos * self.probe_scaling - def extend_horizontally(self): self.probe_scaling[0] *= self.box_coeff - self._update_box_pos() + self._update_boxes() def shrink_horizontally(self): self.probe_scaling[0] /= self.box_coeff - self._update_box_pos() + self._update_boxes() def extend_vertically(self): self.probe_scaling[1] *= self.box_coeff - self._update_box_pos() + self._update_boxes() def shrink_vertically(self): self.probe_scaling[1] /= self.box_coeff - self._update_box_pos() + self._update_boxes() class WaveformViewPlugin(IPlugin): diff --git a/phy/plot/interact.py b/phy/plot/interact.py index 0d270542c..e3117ad8c 100644 --- a/phy/plot/interact.py +++ b/phy/plot/interact.py @@ -176,6 +176,7 @@ def attach(self, canvas): def update_program(self, program): # Signal bounds (positions). box_bounds = _get_texture(self._box_bounds, NDC, self.n_boxes, [-1, 1]) + # TODO OPTIM: set the texture at initialization and update the data program['u_box_bounds'] = Texture2D(box_bounds, internalformat='rgba32f') program['n_boxes'] = self.n_boxes @@ -215,6 +216,15 @@ def box_size(self, val): self.box_bounds = _get_boxes(self.box_pos, size=val, keep_aspect_ratio=self.keep_aspect_ratio) + def update_boxes(self, box_pos, box_size): + """Set the box bounds from specified box positions and sizes.""" + assert box_pos.shape == (self.n_boxes, 2) + assert len(box_size) == 2 + self.box_bounds = _get_boxes(box_pos, + size=box_size, + keep_aspect_ratio=self.keep_aspect_ratio, + ) + class Stacked(Boxed): """Stacked interact. diff --git a/phy/plot/tests/test_interact.py b/phy/plot/tests/test_interact.py index 04cc9915a..5cd1c5dc7 100644 --- a/phy/plot/tests/test_interact.py +++ b/phy/plot/tests/test_interact.py @@ -11,6 +11,7 @@ import numpy as np from numpy.testing import assert_equal as ae +from numpy.testing import assert_allclose as ac from ..base import BaseVisual from ..interact import Grid, Boxed, Stacked @@ -119,6 +120,9 @@ def test_boxed_1(qtbot, canvas): ae(boxed.box_bounds, b) boxed.box_bounds = b + boxed.update_boxes(boxed.box_pos, boxed.box_size) + ac(boxed.box_bounds, b) + # qtbot.stop() From b05eb772dad071aa8b69758f47c3dd10a7970f07 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 18 Dec 2015 12:10:38 +0100 Subject: [PATCH 0792/1059] Minor update --- phy/cluster/manual/views.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 0d2907952..f76778804 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -371,8 +371,8 @@ def attach_to_gui(self, gui, model=None, state=None): def on_close(): # Save the box bounds. state.set_view_params(view, - box_scaling=view.box_scaling, - probe_scaling=view.probe_scaling) + box_scaling=tuple(view.box_scaling), + probe_scaling=tuple(view.probe_scaling)) # ----------------------------------------------------------------------------- From 628f3baf881a4500136f087f673a05e21532c122 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 18 Dec 2015 12:20:03 +0100 Subject: [PATCH 0793/1059] WIP: scaling in trace view --- phy/cluster/manual/tests/test_views.py | 6 +-- phy/cluster/manual/views.py | 72 ++++++++++++++++++++------ 2 files changed, 58 insertions(+), 20 deletions(-) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index a40d14bd7..6195915cc 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -186,7 +186,7 @@ def test_trace_view_no_spikes(qtbot): def test_trace_view_spikes(qtbot, model, tempdir): with _test_view('TraceView', model=model, tempdir=tempdir) as v: - ac(v.stacked.box_size, (1., .01), atol=1e-2) + ac(v.stacked.box_size, (1., .08181), atol=1e-3) v.go_to(.5) v.go_to(-.5) @@ -211,8 +211,8 @@ def best_channels(cluster_id): v.add_attribute('sine', np.sin(np.linspace(-10., 10., model.n_spikes))) - v.increase_feature_scaling() - v.decrease_feature_scaling() + v.increase() + v.decrease() # qtbot.stop() diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index f76778804..16b2f1f3b 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -174,7 +174,7 @@ class WaveformView(ManualClusteringView): normalization_percentile = .95 normalization_n_spikes = 1000 overlap = False - box_coeff = 1.1 + scaling_coeff = 1.1 default_shortcuts = { 'toggle_waveform_overlap': 'o', @@ -319,38 +319,38 @@ def _update_boxes(self): self.box_size * self.box_scaling) def widen(self): - self.box_scaling[0] *= self.box_coeff + self.box_scaling[0] *= self.scaling_coeff self._update_boxes() def narrow(self): - self.box_scaling[0] /= self.box_coeff + self.box_scaling[0] /= self.scaling_coeff self._update_boxes() def increase(self): - self.box_scaling[1] *= self.box_coeff + self.box_scaling[1] *= self.scaling_coeff self._update_boxes() def decrease(self): - self.box_scaling[1] /= self.box_coeff + self.box_scaling[1] /= self.scaling_coeff self._update_boxes() # Probe scaling # ------------------------------------------------------------------------- def extend_horizontally(self): - self.probe_scaling[0] *= self.box_coeff + self.probe_scaling[0] *= self.scaling_coeff self._update_boxes() def shrink_horizontally(self): - self.probe_scaling[0] /= self.box_coeff + self.probe_scaling[0] /= self.scaling_coeff self._update_boxes() def extend_vertically(self): - self.probe_scaling[1] *= self.box_coeff + self.probe_scaling[1] *= self.scaling_coeff self._update_boxes() def shrink_vertically(self): - self.probe_scaling[1] /= self.box_coeff + self.probe_scaling[1] /= self.scaling_coeff self._update_boxes() @@ -382,9 +382,12 @@ def on_close(): class TraceView(ManualClusteringView): interval_duration = .5 # default duration of the interval shift_amount = .1 + scaling_coeff = 1.1 default_shortcuts = { 'go_left': 'alt+left', 'go_right': 'alt+right', + 'increase': 'alt+up', + 'decrease': 'alt+down', } def __init__(self, @@ -394,6 +397,7 @@ def __init__(self, spike_clusters=None, masks=None, n_samples_per_spike=None, + scaling=None, **kwargs): # Sample rate. @@ -434,10 +438,20 @@ def __init__(self, super(TraceView, self).__init__(layout='stacked', n_plots=self.n_channels, **kwargs) + # Box and probe scaling. + self.scaling = scaling or 1. + + # Make a copy of the initial box pos and size. We'll apply the scaling + # to these quantities. + self.box_size = np.array(self.stacked.box_size) + self._update_boxes() # Initial interval. self.set_interval((0., self.interval_duration)) + # Internal methods + # ------------------------------------------------------------------------- + def _load_traces(self, interval): """Load traces in an interval (in seconds).""" @@ -524,6 +538,9 @@ def _restrict_interval(self, interval): assert 0 <= start < end <= self.duration return start, end + # Public methods + # ------------------------------------------------------------------------- + def set_interval(self, interval): """Display the traces and spikes in a given interval.""" self.clear() @@ -575,6 +592,11 @@ def attach(self, gui): self.actions.add(self.shift, alias='ts') self.actions.add(self.go_right) self.actions.add(self.go_left) + self.actions.add(self.increase) + self.actions.add(self.decrease) + + # Navigation + # ------------------------------------------------------------------------- def go_to(self, time): start, end = self.interval @@ -595,24 +617,40 @@ def go_left(self): delay = (end - start) * .2 self.shift(-delay) + # Channel scaling + # ------------------------------------------------------------------------- + + # TODO: ctrl+alt+left/right to increase duration + # TODO: current interval, current central time + + def _update_boxes(self): + self.stacked.box_size = self.box_size * self.scaling + + def increase(self): + self.scaling *= self.scaling_coeff + self._update_boxes() + + def decrease(self): + self.scaling /= self.scaling_coeff + self._update_boxes() + class TraceViewPlugin(IPlugin): def attach_to_gui(self, gui, model=None, state=None): + s, = state.get_view_params('TraceView', 'scaling') view = TraceView(traces=model.traces, sample_rate=model.sample_rate, spike_times=model.spike_times, spike_clusters=model.spike_clusters, masks=model.masks, + scaling=s, ) - b, = state.get_view_params('TraceView', 'box_size') - if b: - view.stacked.box_size = b view.attach(gui) @gui.connect_ def on_close(): # Save the box bounds. - state.set_view_params(view, box_size=view.stacked.box_size) + state.set_view_params(view, scaling=view.scaling) # ----------------------------------------------------------------------------- @@ -865,14 +903,14 @@ def on_select(self, cluster_ids=None, spike_ids=None): def attach(self, gui): """Attach the view to the GUI.""" super(FeatureView, self).attach(gui) - self.actions.add(self.increase_feature_scaling) - self.actions.add(self.decrease_feature_scaling) + self.actions.add(self.increase) + self.actions.add(self.decrease) - def increase_feature_scaling(self): + def increase(self): self.feature_scaling *= 1.2 self.on_select() - def decrease_feature_scaling(self): + def decrease(self): self.feature_scaling /= 1.2 self.on_select() From 88ff48d05342335d266155611832a6a0e9fdc26c Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 18 Dec 2015 13:09:56 +0100 Subject: [PATCH 0794/1059] Add actions in trace view --- phy/cluster/manual/tests/test_views.py | 27 +++++++++++++++++ phy/cluster/manual/views.py | 41 +++++++++++++++++++++----- 2 files changed, 61 insertions(+), 7 deletions(-) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index 6195915cc..1d7cfe162 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -53,6 +53,7 @@ def model(): model.n_channels = n_channels model.n_spikes = n_spikes model.sample_rate = 20000. + model.duration = n_samples_t / float(model.sample_rate) model.spike_times = artificial_spike_samples(n_spikes) * 1. model.spike_times /= model.spike_times[-1] model.spike_clusters = artificial_spike_clusters(n_spikes, n_clusters) @@ -187,11 +188,37 @@ def test_trace_view_no_spikes(qtbot): def test_trace_view_spikes(qtbot, model, tempdir): with _test_view('TraceView', model=model, tempdir=tempdir) as v: ac(v.stacked.box_size, (1., .08181), atol=1e-3) + assert v.time == .25 v.go_to(.5) + assert v.time == .5 + v.go_to(-.5) + assert v.time == .25 + v.go_left() + assert v.time == .25 + v.go_right() + assert v.time == .35 + + # Change interval size. + v.set_interval((.25, .75)) + ac(v.interval, (.25, .75)) + v.widen() + ac(v.interval, (.225, .775)) + v.narrow() + ac(v.interval, (.25, .75)) + + # Widen the max interval. + v.set_interval((0, model.duration)) + v.widen() + + # Change channel scaling. + bs = v.stacked.box_size + v.increase() + v.decrease() + ac(v.stacked.box_size, bs, atol=1e-3) # qtbot.stop() diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 16b2f1f3b..91f4d8d56 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -386,8 +386,10 @@ class TraceView(ManualClusteringView): default_shortcuts = { 'go_left': 'alt+left', 'go_right': 'alt+right', - 'increase': 'alt+up', 'decrease': 'alt+down', + 'increase': 'alt+up', + 'widen': 'ctrl+alt+left', + 'narrow': 'ctrl+alt+right', } def __init__(self, @@ -594,35 +596,60 @@ def attach(self, gui): self.actions.add(self.go_left) self.actions.add(self.increase) self.actions.add(self.decrease) + self.actions.add(self.widen) + self.actions.add(self.narrow) # Navigation # ------------------------------------------------------------------------- + @property + def time(self): + """Time at the center of the window.""" + return sum(self.interval) * .5 + + @property + def half_duration(self): + """Half of the duration of the current interval.""" + a, b = self.interval + return (b - a) * .5 + def go_to(self, time): + """Go to a specific time (in seconds).""" start, end = self.interval - half_dur = (end - start) * .5 + half_dur = self.half_duration self.set_interval((time - half_dur, time + half_dur)) def shift(self, delay): - time = sum(self.interval) * .5 - self.go_to(time + delay) + """Shift the interval by a given delay (in seconds).""" + self.go_to(self.time + delay) def go_right(self): + """Go to right.""" start, end = self.interval delay = (end - start) * .2 self.shift(delay) def go_left(self): + """Go to left.""" start, end = self.interval delay = (end - start) * .2 self.shift(-delay) + def widen(self): + """Increase the interval size.""" + t, h = self.time, self.half_duration + h *= self.scaling_coeff + self.set_interval((t - h, t + h)) + + def narrow(self): + """Decrease the interval size.""" + t, h = self.time, self.half_duration + h /= self.scaling_coeff + self.set_interval((t - h, t + h)) + # Channel scaling # ------------------------------------------------------------------------- - # TODO: ctrl+alt+left/right to increase duration - # TODO: current interval, current central time - def _update_boxes(self): self.stacked.box_size = self.box_size * self.scaling From 2aa18715a8a154579e8dd5741364e7a264daf964 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 18 Dec 2015 13:12:42 +0100 Subject: [PATCH 0795/1059] Fix --- phy/cluster/manual/views.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 91f4d8d56..68eafab07 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -756,8 +756,8 @@ class FeatureView(ManualClusteringView): _feature_scaling = 1. default_shortcuts = { - 'increase_feature_scaling': 'ctrl++', - 'decrease_feature_scaling': 'ctrl+-', + 'increase': 'ctrl++', + 'decrease': 'ctrl+-', } def __init__(self, From f359fb4a418cfbf652dc89dbef8327d53afe616d Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 18 Dec 2015 13:44:31 +0100 Subject: [PATCH 0796/1059] Add Actions names --- phy/cluster/manual/gui_component.py | 3 ++- phy/cluster/manual/views.py | 4 +++- phy/gui/actions.py | 14 +++++++++----- phy/gui/gui.py | 12 +++++------- phy/gui/tests/test_actions.py | 11 +++++++++++ 5 files changed, 30 insertions(+), 14 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 579b15cdf..02ab541bb 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -260,7 +260,8 @@ def similarity(cluster_id): self.similarity_view.add_column(similarity) def _create_actions(self, gui): - self.actions = Actions(gui, default_shortcuts=self.shortcuts) + self.actions = Actions(gui, name='Manual clustering', + default_shortcuts=self.shortcuts) # Selection. self.actions.add(self.select, alias='c', menu='&Cluster') diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 68eafab07..10fab0c78 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -161,7 +161,9 @@ def attach(self, gui): gui.add_view(self) gui.connect_(self.on_select) - self.actions = Actions(gui, default_shortcuts=self.shortcuts) + self.actions = Actions(gui, + name=self.__name__, + default_shortcuts=self.shortcuts) self.show() diff --git a/phy/gui/actions.py b/phy/gui/actions.py index b3a3ff394..23feb07cc 100644 --- a/phy/gui/actions.py +++ b/phy/gui/actions.py @@ -92,7 +92,7 @@ def _get_qkeysequence(shortcut): def _show_shortcuts(shortcuts, name=None): """Display shortcuts.""" name = name or '' - print() + print('') if name: name = ' for ' + name print('Keyboard shortcuts' + name) @@ -100,7 +100,6 @@ def _show_shortcuts(shortcuts, name=None): shortcut = _get_shortcut_string(shortcuts[name]) if not name.startswith('_'): print('{0:<40}: {1:s}'.format(name, shortcut)) - print() # ----------------------------------------------------------------------------- @@ -139,10 +138,11 @@ class Actions(object): * Display all shortcuts """ - def __init__(self, gui, default_shortcuts=None): + def __init__(self, gui, name=None, default_shortcuts=None): self._actions_dict = {} self._aliases = {} self._default_shortcuts = default_shortcuts or {} + self.name = name self.gui = gui gui.actions.append(self) @@ -242,7 +242,11 @@ def shortcuts(self): def show_shortcuts(self): """Print all shortcuts.""" - _show_shortcuts(self.shortcuts, self.gui.windowTitle()) + gui_name = self.gui.name + actions_name = self.name + name = ('{} - {}'.format(gui_name, actions_name) + if actions_name else gui_name) + _show_shortcuts(self.shortcuts, name) def __contains__(self, name): return name in self._actions_dict @@ -292,7 +296,7 @@ def __init__(self, gui): self.gui = gui self._status_message = gui.status_message - self.actions = Actions(gui) + self.actions = Actions(gui, name='Snippets') # Register snippet mode shortcut. @self.actions.add(shortcut=':') diff --git a/phy/gui/gui.py b/phy/gui/gui.py index 6b5e0f28d..66ec1236e 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -15,7 +15,7 @@ from .qt import (QApplication, QWidget, QDockWidget, QStatusBar, QMainWindow, Qt, QSize, QMetaObject) -from .actions import Actions, _show_shortcuts, Snippets +from .actions import Actions, Snippets from phy.utils.event import EventEmitter from phy.utils import (load_master_config, Bunch, _bunchify, _load_json, _save_json, @@ -175,19 +175,17 @@ def _set_pos_size(self, position, size): self.resize(QSize(size[0], size[1])) def _set_default_actions(self): - self.default_actions = Actions(self) + self.default_actions = Actions(self, name='Default') @self.default_actions.add(shortcut='ctrl+q', menu='&File') def exit(): self.close() @self.default_actions.add(shortcut=('HelpContents', 'h'), - menu='&Help') - def show_shortcuts(): - shortcuts = self.default_actions.shortcuts + menu='&File') + def show_all_shortcuts(): for actions in self.actions: - shortcuts.update(actions.shortcuts) - _show_shortcuts(shortcuts, self.__name__) + actions.show_shortcuts() # Events # ------------------------------------------------------------------------- diff --git a/phy/gui/tests/test_actions.py b/phy/gui/tests/test_actions.py index c93af619d..4077e6595 100644 --- a/phy/gui/tests/test_actions.py +++ b/phy/gui/tests/test_actions.py @@ -126,8 +126,19 @@ def press(): actions.press() assert _press == [0] + # Show action shortcuts. + with captured_output() as (stdout, stderr): + actions.show_shortcuts() + assert 'g\n' in stdout.getvalue() + + # Show default action shortcuts. with captured_output() as (stdout, stderr): gui.default_actions.show_shortcuts() + assert 'q\n' in stdout.getvalue() + + # Show all action shortcuts. + with captured_output() as (stdout, stderr): + gui.default_actions.show_all_shortcuts() assert 'g\n' in stdout.getvalue() From 0374f26ea3d1369f47d22c15c41c8b52a60ba5c4 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 18 Dec 2015 13:51:55 +0100 Subject: [PATCH 0797/1059] Default menu in actions --- phy/cluster/manual/gui_component.py | 30 +++++++++++++++-------------- phy/cluster/manual/views.py | 3 ++- phy/gui/actions.py | 8 +++++--- phy/gui/gui.py | 13 ++++++------- 4 files changed, 29 insertions(+), 25 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 02ab541bb..4219ff49c 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -260,34 +260,36 @@ def similarity(cluster_id): self.similarity_view.add_column(similarity) def _create_actions(self, gui): - self.actions = Actions(gui, name='Manual clustering', + self.actions = Actions(gui, + name='Clustering', + menu='&Clustering', default_shortcuts=self.shortcuts) # Selection. - self.actions.add(self.select, alias='c', menu='&Cluster') - self.actions.separator('&Cluster') + self.actions.add(self.select, alias='c') + self.actions.separator() # Clustering. - self.actions.add(self.merge, alias='g', menu='&Cluster') - self.actions.add(self.split, alias='k', menu='&Cluster') - self.actions.separator('&Cluster') + self.actions.add(self.merge, alias='g') + self.actions.add(self.split, alias='k') + self.actions.separator() # Move. self.actions.add(self.move) for group in ('noise', 'mua', 'good'): self.actions.add(partial(self.move_best, group), - name='move_best_to_' + group, menu='&Cluster') + name='move_best_to_' + group) self.actions.add(partial(self.move_similar, group), - name='move_similar_to_' + group, menu='&Cluster') + name='move_similar_to_' + group) self.actions.add(partial(self.move_all, group), - name='move_all_to_' + group, menu='&Cluster') - self.actions.separator('&Cluster') + name='move_all_to_' + group) + self.actions.separator() # Others. - self.actions.add(self.undo, menu='&Cluster') - self.actions.add(self.redo, menu='&Cluster') - self.actions.add(self.save, menu='&Cluster') + self.actions.add(self.undo) + self.actions.add(self.redo) + self.actions.add(self.save) # Wizard. self.actions.add(self.reset, menu='&Wizard') @@ -295,7 +297,7 @@ def _create_actions(self, gui): self.actions.add(self.previous, menu='&Wizard') self.actions.add(self.next_best, menu='&Wizard') self.actions.add(self.previous_best, menu='&Wizard') - self.actions.separator('&Cluster') + self.actions.separator() def _create_cluster_views(self): # Create the cluster view. diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 10fab0c78..e9866265a 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -162,7 +162,8 @@ def attach(self, gui): gui.add_view(self) gui.connect_(self.on_select) self.actions = Actions(gui, - name=self.__name__, + name=self.__class__.__name__, + menu=self.__class__.__name__, default_shortcuts=self.shortcuts) self.show() diff --git a/phy/gui/actions.py b/phy/gui/actions.py index 23feb07cc..ffcd657b2 100644 --- a/phy/gui/actions.py +++ b/phy/gui/actions.py @@ -138,11 +138,12 @@ class Actions(object): * Display all shortcuts """ - def __init__(self, gui, name=None, default_shortcuts=None): + def __init__(self, gui, name=None, menu=None, default_shortcuts=None): self._actions_dict = {} self._aliases = {} self._default_shortcuts = default_shortcuts or {} self.name = name + self.menu = menu self.gui = gui gui.actions.append(self) @@ -175,6 +176,7 @@ def add(self, callback=None, name=None, shortcut=None, alias=None, _get_shortcut_string(action.shortcut())) self.gui.addAction(action) # Add the action to the menu. + menu = menu or self.menu if menu: self.gui.get_menu(menu).addAction(action) self._actions_dict[name] = action_obj @@ -185,9 +187,9 @@ def add(self, callback=None, name=None, shortcut=None, alias=None, if callback: setattr(self, name, callback) - def separator(self, menu): + def separator(self, menu=None): """Add a separator""" - self.gui.get_menu(menu).addSeparator() + self.gui.get_menu(menu or self.menu).addSeparator() def disable(self, name=None): """Disable one or all actions.""" diff --git a/phy/gui/gui.py b/phy/gui/gui.py index 66ec1236e..edf9dc87d 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -175,18 +175,17 @@ def _set_pos_size(self, position, size): self.resize(QSize(size[0], size[1])) def _set_default_actions(self): - self.default_actions = Actions(self, name='Default') + self.default_actions = Actions(self, name='Default', menu='&File') - @self.default_actions.add(shortcut='ctrl+q', menu='&File') - def exit(): - self.close() - - @self.default_actions.add(shortcut=('HelpContents', 'h'), - menu='&File') + @self.default_actions.add(shortcut=('HelpContents', 'h')) def show_all_shortcuts(): for actions in self.actions: actions.show_shortcuts() + @self.default_actions.add(shortcut='ctrl+q') + def exit(): + self.close() + # Events # ------------------------------------------------------------------------- From 93714066999510c6639c5ec2636571fae90f187b Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 18 Dec 2015 14:18:16 +0100 Subject: [PATCH 0798/1059] Add Selector class --- phy/io/array.py | 31 +++++++++++++++++++++++++++++++ phy/io/tests/test_array.py | 10 ++++++++++ 2 files changed, 41 insertions(+) diff --git a/phy/io/array.py b/phy/io/array.py index b2d1d1700..fdf3a3411 100644 --- a/phy/io/array.py +++ b/phy/io/array.py @@ -418,3 +418,34 @@ def select_spikes(cluster_ids=None, spikes = spikes_per_cluster[cluster] selection[cluster] = regular_subset(spikes, n_spikes_max=n) return _flatten_per_cluster(selection) + + +class Selector(object): + """This object is passed with the `select` event when clusters are + selected. It allows to make selections of spikes.""" + def __init__(self, + spike_clusters=None, + spikes_per_cluster=None, + spike_ids=None, + ): + self.spike_clusters = spike_clusters + self.spikes_per_cluster = spikes_per_cluster + self.n_spikes = len(spike_clusters) + self.spike_ids = (np.asarray(spike_ids) if spike_ids is not None + else np.arange(self.n_spikes)) + + def select_spikes(self, cluster_ids=None, + max_n_spikes_per_cluster=None): + if cluster_ids is None or not len(cluster_ids): + return None + ns = max_n_spikes_per_cluster + assert len(cluster_ids) >= 1 + # Select all spikes belonging to the cluster. + if ns is None: + spikes_rel = _spikes_in_clusters(self.spike_clusters, cluster_ids) + return (self.spike_ids[spikes_rel] + if self.spike_ids is not None else spikes_rel) + # Select a subset of the spikes. + return select_spikes(cluster_ids, + spikes_per_cluster=self.spikes_per_cluster, + max_n_spikes_per_cluster=ns) diff --git a/phy/io/tests/test_array.py b/phy/io/tests/test_array.py index b4d2fa9de..dab2837d5 100644 --- a/phy/io/tests/test_array.py +++ b/phy/io/tests/test_array.py @@ -19,6 +19,7 @@ _spikes_per_cluster, _flatten_per_cluster, select_spikes, + Selector, chunk_bounds, regular_subset, excerpts, @@ -355,6 +356,7 @@ def test_select_spikes(): with raises(AssertionError): select_spikes() spikes = [2, 3, 5, 7, 11] + sc = [2, 3, 3, 2, 2] spc = {2: [2, 7, 11], 3: [3, 5], 5: []} ae(select_spikes([], spikes_per_cluster=spc), []) ae(select_spikes([2, 3, 5], spikes_per_cluster=spc), spikes) @@ -364,3 +366,11 @@ def test_select_spikes(): ae(select_spikes([2, 3, 5], None, spikes_per_cluster=spc), spikes) ae(select_spikes([2, 3, 5], 1, spikes_per_cluster=spc), [2, 3]) ae(select_spikes([2, 5], 2, spikes_per_cluster=spc), [2]) + + sel = Selector(spike_clusters=sc, + spikes_per_cluster=spc, + spike_ids=spikes, + ) + assert sel.select_spikes() is None + ae(sel.select_spikes([2, 5]), spc[2]) + ae(sel.select_spikes([2, 5], 2), [2]) From c3c82b34bf467045a93ae19ea6a6690235aff6ea Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 18 Dec 2015 14:22:10 +0100 Subject: [PATCH 0799/1059] WIP: use Selector in select event --- phy/cluster/manual/gui_component.py | 31 ++++++++++--------- .../manual/tests/test_gui_component.py | 2 +- phy/io/array.py | 2 +- 3 files changed, 18 insertions(+), 17 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 4219ff49c..18e671670 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -21,7 +21,7 @@ ) from phy.gui.actions import Actions from phy.gui.widgets import Table -from phy.io.array import select_spikes +from phy.io.array import select_spikes, Selector from phy.utils import IPlugin logger = logging.getLogger(__name__) @@ -52,13 +52,15 @@ def default_wizard_functions(waveforms=None, masks=None, n_features_per_channel=None, spikes_per_cluster=None, + max_n_spikes_per_cluster=1000, ): spc = spikes_per_cluster nfc = n_features_per_channel + maxn = max_n_spikes_per_cluster def max_waveform_amplitude_quality(cluster): spike_ids = select_spikes(cluster_ids=[cluster], - max_n_spikes_per_cluster=100, + max_n_spikes_per_cluster=maxn, spikes_per_cluster=spc, ) m = np.atleast_2d(masks[spike_ids]) @@ -72,11 +74,11 @@ def max_waveform_amplitude_quality(cluster): def mean_masked_features_similarity(c0, c1): s0 = select_spikes(cluster_ids=[c0], - max_n_spikes_per_cluster=100, + max_n_spikes_per_cluster=maxn, spikes_per_cluster=spc, ) s1 = select_spikes(cluster_ids=[c1], - max_n_spikes_per_cluster=100, + max_n_spikes_per_cluster=maxn, spikes_per_cluster=spc, ) @@ -131,7 +133,6 @@ class ManualClustering(object): spike_clusters : ndarray cluster_groups : dictionary - n_spikes_max_per_cluster : int shortcuts : dict GUI events @@ -140,7 +141,7 @@ class ManualClustering(object): When this component is attached to a GUI, the GUI emits the following events: - select(cluster_ids, spike_ids) + select(cluster_ids, selector) when clusters are selected cluster(up) when a merge or split happens @@ -184,12 +185,10 @@ class ManualClustering(object): def __init__(self, spike_clusters, cluster_groups=None, - n_spikes_max_per_cluster=100, shortcuts=None, ): self.gui = None - self.n_spikes_max_per_cluster = n_spikes_max_per_cluster # Load default shortcuts, and override with any user shortcuts. self.shortcuts = self.default_shortcuts.copy() @@ -201,6 +200,13 @@ def __init__(self, self._global_history = GlobalHistory(process_ups=_process_ups) self._register_logging() + # Create the spike selector. + sc = self.clustering.spike_clusters + spc = self.clustering.spikes_per_cluster + self.selector = Selector(spike_clusters=sc, + spikes_per_cluster=spc, + ) + # Create the cluster views. self._create_cluster_views() self._add_default_columns() @@ -351,14 +357,9 @@ def _update_similarity_view(self): def _emit_select(self, cluster_ids): """Choose spikes from the specified clusters and emit the `select` event on the GUI.""" - # Choose a spike subset. - spike_ids = select_spikes(np.array(cluster_ids), - self.n_spikes_max_per_cluster, - self.clustering.spikes_per_cluster) - logger.debug("Select clusters: %s (%d spikes).", - ', '.join(map(str, cluster_ids)), len(spike_ids)) + logger.debug("Select clusters: %s.", ', '.join(map(str, cluster_ids))) if self.gui: - self.gui.emit('select', cluster_ids, spike_ids) + self.gui.emit('select', cluster_ids, self.selector) # Public methods # ------------------------------------------------------------------------- diff --git a/phy/cluster/manual/tests/test_gui_component.py b/phy/cluster/manual/tests/test_gui_component.py index e661c46e4..41189b09a 100644 --- a/phy/cluster/manual/tests/test_gui_component.py +++ b/phy/cluster/manual/tests/test_gui_component.py @@ -73,7 +73,7 @@ def test_manual_clustering_plugin(qtbot, gui): features=np.zeros((3, 1, 2)), masks=np.zeros((3, 1)), ) - state = Bunch(n_spikes_max_per_cluster=10) + state = Bunch() ManualClusteringPlugin().attach_to_gui(gui, model=model, state=state) diff --git a/phy/io/array.py b/phy/io/array.py index fdf3a3411..57f0a95f3 100644 --- a/phy/io/array.py +++ b/phy/io/array.py @@ -432,7 +432,7 @@ def __init__(self, self.spikes_per_cluster = spikes_per_cluster self.n_spikes = len(spike_clusters) self.spike_ids = (np.asarray(spike_ids) if spike_ids is not None - else np.arange(self.n_spikes)) + else None) def select_spikes(self, cluster_ids=None, max_n_spikes_per_cluster=None): From 659c20e9358aea6284015ce3861ee2fb1bc0eb1b Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 18 Dec 2015 14:31:13 +0100 Subject: [PATCH 0800/1059] Update views with selector --- phy/cluster/manual/tests/test_views.py | 12 ++++---- phy/cluster/manual/views.py | 38 ++++++++++++++++---------- 2 files changed, 30 insertions(+), 20 deletions(-) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index 1d7cfe162..9d49caec3 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -13,7 +13,7 @@ from numpy.testing import assert_allclose as ac from pytest import raises, yield_fixture -from phy.io.array import _spikes_per_cluster +from phy.io.array import _spikes_per_cluster, Selector from phy.io.mock import (artificial_waveforms, artificial_features, artificial_spike_clusters, @@ -87,12 +87,14 @@ def _test_view(view_name, model=None, tempdir=None): # Select some spikes. spike_ids = np.arange(10) cluster_ids = np.unique(model.spike_clusters[spike_ids]) - v.on_select(cluster_ids, spike_ids) + v.on_select(cluster_ids, spike_ids=spike_ids) # Select other spikes. - spike_ids = np.arange(2, 10) - cluster_ids = np.unique(model.spike_clusters[spike_ids]) - v.on_select(cluster_ids, spike_ids) + cluster_ids = [0, 2] + sel = Selector(spike_clusters=model.spike_clusters, + spikes_per_cluster=model.spikes_per_cluster, + ) + v.on_select(cluster_ids, selector=sel) yield v diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index e9866265a..dc30ba2b6 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -130,6 +130,7 @@ def _get_color(masks, spike_clusters_rel=None, n_clusters=None): # ----------------------------------------------------------------------------- class ManualClusteringView(View): + max_n_spikes_per_cluster = None default_shortcuts = { } @@ -144,13 +145,18 @@ def __init__(self, shortcuts=None, **kwargs): super(ManualClusteringView, self).__init__(**kwargs) - def on_select(self, cluster_ids=None, spike_ids=None): - cluster_ids = (cluster_ids if cluster_ids is not None - else self.cluster_ids) - spike_ids = (spike_ids if spike_ids is not None - else self.spike_ids) - self.cluster_ids = list(cluster_ids) - self.spike_ids = spike_ids + def on_select(self, cluster_ids=None, selector=None, spike_ids=None): + cluster_ids = list(cluster_ids if cluster_ids is not None + else self.cluster_ids) + if spike_ids is None: + # Use the selector to select some or all of the spikes. + if selector: + ns = self.max_n_spikes_per_cluster + spike_ids = selector.select_spikes(cluster_ids, ns) + else: + spike_ids = self.spike_ids + self.cluster_ids = cluster_ids + self.spike_ids = np.asarray(spike_ids) def attach(self, gui): """Attach the view to the GUI.""" @@ -174,6 +180,7 @@ def attach(self, gui): # ----------------------------------------------------------------------------- class WaveformView(ManualClusteringView): + max_n_spikes_per_cluster = 100 normalization_percentile = .95 normalization_n_spikes = 1000 overlap = False @@ -249,9 +256,9 @@ def __init__(self, assert channel_positions.shape == (self.n_channels, 2) self.channel_positions = channel_positions - def on_select(self, cluster_ids=None, spike_ids=None): + def on_select(self, cluster_ids=None, **kwargs): super(WaveformView, self).on_select(cluster_ids=cluster_ids, - spike_ids=spike_ids) + **kwargs) cluster_ids, spike_ids = self.cluster_ids, self.spike_ids n_clusters = len(cluster_ids) n_spikes = len(spike_ids) @@ -585,9 +592,9 @@ def set_interval(self, interval): self.build() self.update() - def on_select(self, cluster_ids=None, spike_ids=None): + def on_select(self, cluster_ids=None, **kwargs): super(TraceView, self).on_select(cluster_ids=cluster_ids, - spike_ids=spike_ids) + **kwargs) self.set_interval(self.interval) def attach(self, gui): @@ -753,6 +760,7 @@ def _project_mask_depth(dim, masks, spike_clusters_rel=None, n_clusters=None): class FeatureView(ManualClusteringView): + max_n_spikes_per_cluster = 100000 normalization_percentile = .95 normalization_n_spikes = 1000 _default_marker_size = 3. @@ -892,9 +900,9 @@ def set_best_channels_func(self, func): """Set a function `cluster_id => list of best channels`.""" self.best_channels_func = func - def on_select(self, cluster_ids=None, spike_ids=None): + def on_select(self, cluster_ids=None, **kwargs): super(FeatureView, self).on_select(cluster_ids=cluster_ids, - spike_ids=spike_ids) + **kwargs) cluster_ids, spike_ids = self.cluster_ids, self.spike_ids n_spikes = len(spike_ids) if n_spikes == 0: @@ -1061,9 +1069,9 @@ def _compute_correlograms(self, cluster_ids): return ccg - def on_select(self, cluster_ids=None, spike_ids=None): + def on_select(self, cluster_ids=None, **kwargs): super(CorrelogramView, self).on_select(cluster_ids=cluster_ids, - spike_ids=spike_ids) + **kwargs) cluster_ids, spike_ids = self.cluster_ids, self.spike_ids n_clusters = len(cluster_ids) n_spikes = len(spike_ids) From 98c732a4b7f922e1d7d5a9121424a7b4b8604a9f Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 18 Dec 2015 14:48:49 +0100 Subject: [PATCH 0801/1059] Minor fixes --- phy/cluster/manual/gui_component.py | 6 +++++- phy/io/context.py | 5 ++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 18e671670..2678eed4f 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -68,6 +68,7 @@ def max_waveform_amplitude_quality(cluster): mean_masks = mean(m) mean_waveforms = mean(w) q = max_waveform_amplitude(mean_masks, mean_waveforms) + q = np.asscalar(q) logger.debug("Computed cluster quality for %d: %.3f.", cluster, q) return q @@ -359,7 +360,10 @@ def _emit_select(self, cluster_ids): `select` event on the GUI.""" logger.debug("Select clusters: %s.", ', '.join(map(str, cluster_ids))) if self.gui: - self.gui.emit('select', cluster_ids, self.selector) + self.gui.emit('select', + cluster_ids=cluster_ids, + selector=self.selector, + ) # Public methods # ------------------------------------------------------------------------- diff --git a/phy/io/context.py b/phy/io/context.py index 66881df3f..dbbbef9da 100644 --- a/phy/io/context.py +++ b/phy/io/context.py @@ -157,7 +157,10 @@ def _set_memory(self, cache_dir): # Try importing joblib. try: from joblib import Memory - self._memory = Memory(cachedir=self.cache_dir, verbose=0) + self._memory = Memory(cachedir=self.cache_dir, + mmap_mode=None, + verbose=0, + ) logger.debug("Initialize joblib cache dir at `%s`.", self.cache_dir) except ImportError: # pragma: no cover From 2b7cec20b1cf54ae53c4ecfc81566e488c8cc973 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 18 Dec 2015 14:59:48 +0100 Subject: [PATCH 0802/1059] GUI subtitle --- phy/gui/gui.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/phy/gui/gui.py b/phy/gui/gui.py index edf9dc87d..c927e0fd1 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -125,6 +125,7 @@ def __init__(self, position=None, size=None, name=None, + subtitle=None, ): # HACK to ensure that closeEvent is called only twice (seems like a # Qt bug). @@ -138,7 +139,7 @@ def __init__(self, QMainWindow.AnimatedDocks ) - self._set_name(name) + self._set_name(name, subtitle) self._set_pos_size(position, size) # Mapping {name: menuBar}. @@ -160,13 +161,14 @@ def __init__(self, # Create and attach snippets. self.snippets = Snippets(self) - def _set_name(self, name): + def _set_name(self, name, subtitle): if name is None: name = self.__class__.__name__ - self.setWindowTitle(name) + title = name if not subtitle else name + ' - ' + subtitle + self.setWindowTitle(title) self.setObjectName(name) # Set the name in the GUI. - self.__name__ = name + self.name = name def _set_pos_size(self, position, size): if position is not None: @@ -290,10 +292,6 @@ def get_menu(self, name): # Status bar # ------------------------------------------------------------------------- - @property - def name(self): - return str(self.windowTitle()) - @property def status_message(self): """The message in the status bar.""" @@ -397,15 +395,16 @@ def on_show(): gui.restore_geometry_state(gs) -def create_gui(name=None, model=None, plugins=None, config_dir=None): +def create_gui(name=None, subtitle=None, model=None, + plugins=None, config_dir=None): """Create a GUI with a model and a list of plugins. By default, the list of plugins is taken from the `c.TheGUI.plugins` parameter, where `TheGUI` is the name of the GUI class. """ - gui = GUI(name=name) - name = gui.__name__ + gui = GUI(name=name, subtitle=subtitle) + name = gui.name plugins = plugins or [] # Load the state. From 4f4d5fad3d7055a0ccf3d3461e96ba802c79bfa2 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 18 Dec 2015 16:16:28 +0100 Subject: [PATCH 0803/1059] Lock status message --- phy/gui/actions.py | 4 ++++ phy/gui/gui.py | 20 ++++++++++++++++++++ phy/gui/tests/test_gui.py | 7 +++++++ 3 files changed, 31 insertions(+) diff --git a/phy/gui/actions.py b/phy/gui/actions.py index ffcd657b2..ae622825f 100644 --- a/phy/gui/actions.py +++ b/phy/gui/actions.py @@ -323,7 +323,9 @@ def command(self): @command.setter def command(self, value): value += self.cursor + self.gui.unlock_status() self.gui.status_message = value + self.gui.lock_status() def _backspace(self): """Erase the last character in the snippet command.""" @@ -404,6 +406,7 @@ def mode_on(self): logger.info("Snippet mode enabled, press `escape` to leave this mode.") # Save the current status message. self._status_message = self.gui.status_message + self.gui.lock_status() # Silent all actions except the Snippets actions. for actions in self.gui.actions: @@ -414,6 +417,7 @@ def mode_on(self): self.command = ':' def mode_off(self): + self.gui.unlock_status() # Reset the GUI status message that was set before the mode was # activated. self.gui.status_message = self._status_message diff --git a/phy/gui/gui.py b/phy/gui/gui.py index c927e0fd1..b76d81d59 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -69,6 +69,10 @@ def closeEvent(self, e): self.emit('close_widget') super(DockWidget, self).closeEvent(e) + def enterEvent(self, e): # pragma: no cover + self.emit('enter_widget') + super(DockWidget, self).enterEvent(e) + def _create_dock_widget(widget, name, closable=True, floatable=True): # Create the gui widget. @@ -149,6 +153,7 @@ def __init__(self, self._event = EventEmitter() # Status bar. + self._lock_status = False self._status_bar = QStatusBar() self.setStatusBar(self._status_bar) @@ -252,10 +257,17 @@ def add_view(self, dock_widget.setFloating(floating) dock_widget.view = view + # Emit the close_view event when the dock widget is closed. @dock_widget.connect_ def on_close_widget(): self.emit('close_view', view) + # Change the status bar when the mouse enters in a widget. + @dock_widget.connect_ + def on_enter_widget(): # pragma: no cover + if getattr(view, 'status', None): + self.status_message = view.status + dock_widget.show() self.emit('add_view', view) logger.log(5, "Add %s to GUI.", name) @@ -299,8 +311,16 @@ def status_message(self): @status_message.setter def status_message(self, value): + if self._lock_status: + return self._status_bar.showMessage(str(value)) + def lock_status(self): + self._lock_status = True + + def unlock_status(self): + self._lock_status = False + # State # ------------------------------------------------------------------------- diff --git a/phy/gui/tests/test_gui.py b/phy/gui/tests/test_gui.py index e72a3d78b..d74828a97 100644 --- a/phy/gui/tests/test_gui.py +++ b/phy/gui/tests/test_gui.py @@ -107,6 +107,13 @@ def test_gui_status_message(gui): gui.status_message = ':hello world!' assert gui.status_message == ':hello world!' + gui.lock_status() + gui.status_message = '' + assert gui.status_message == ':hello world!' + gui.unlock_status() + gui.status_message = '' + assert gui.status_message == '' + def test_gui_geometry_state(qtbot): _gs = [] From c10222a4d08b4f7d0433b9d425b1ea7cff60dda6 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 18 Dec 2015 16:20:36 +0100 Subject: [PATCH 0804/1059] WIP: views set the status message --- phy/cluster/manual/views.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index dc30ba2b6..dc8f74272 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -140,6 +140,10 @@ def __init__(self, shortcuts=None, **kwargs): self.shortcuts = self.default_shortcuts.copy() self.shortcuts.update(shortcuts or {}) + # Message to show in the status bar. + self.status = None + + # Keep track of the selected clusters and spikes. self.cluster_ids = None self.spike_ids = None @@ -564,6 +568,9 @@ def set_interval(self, interval): traces = self._load_traces(interval) assert traces.shape[1] == self.n_channels + # Set the status message. + self.status = 'Interval: {:.3f}s - {:.3f}s'.format(start, end) + # Determine the data bounds. m, M = traces.min(), traces.max() data_bounds = np.array([start, m, end, M]) @@ -919,6 +926,13 @@ def on_select(self, cluster_ids=None, **kwargs): n_cols=self.n_cols, best_channels_func=f) + # Set the status message. + n = self.n_cols + self.status = 'Channels: ' + ', '.join(map(str, (y_dim[0, i] + for i in range(1, n)))) + self.status += ' - ' + self.status += ', '.join(map(str, (y_dim[i, 0] for i in range(1, n)))) + # Set a non-time attribute as y coordinate in the top-left subplot. attrs = sorted(self.attributes) attrs.remove('time') @@ -1083,6 +1097,10 @@ def on_select(self, cluster_ids=None, **kwargs): colors = _selected_clusters_colors(n_clusters) + # Set the status message. + b, w = self.bin_size * 1000, self.window_size * 1000 + self.status = 'Bin: {:.1f}. Window: {:.1f}.'.format(b, w) + self.grid.shape = (n_clusters, n_clusters) with self.building(): for i in range(n_clusters): From abc643ad6408795207d9591b5a2ec1cd7e8ee612 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 18 Dec 2015 17:49:34 +0100 Subject: [PATCH 0805/1059] WIP: view status --- phy/cluster/manual/tests/test_views.py | 2 +- phy/cluster/manual/views.py | 29 ++++++++++++++++++++------ phy/gui/gui.py | 10 --------- 3 files changed, 24 insertions(+), 17 deletions(-) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index 9d49caec3..fd4bc2cdf 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -254,4 +254,4 @@ def test_correlogram_view(qtbot, model, tempdir): with _test_view('CorrelogramView', model=model, tempdir=tempdir) as v: v.toggle_normalization() - # qtbot.stop() + # qtbot.stop() diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index dc8f74272..146636dfb 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -11,6 +11,7 @@ import numpy as np from matplotlib.colors import hsv_to_rgb, rgb_to_hsv +from vispy.util.event import Event from phy.io.array import _index_of, _get_padded, get_excerpts from phy.gui import Actions @@ -129,6 +130,12 @@ def _get_color(masks, spike_clusters_rel=None, n_clusters=None): # Manual clustering view # ----------------------------------------------------------------------------- +class StatusEvent(Event): + def __init__(self, type, message=None): + super(StatusEvent, self).__init__(type) + self.message = message + + class ManualClusteringView(View): max_n_spikes_per_cluster = None default_shortcuts = { @@ -148,6 +155,7 @@ def __init__(self, shortcuts=None, **kwargs): self.spike_ids = None super(ManualClusteringView, self).__init__(**kwargs) + self.events.add(status=StatusEvent) def on_select(self, cluster_ids=None, selector=None, spike_ids=None): cluster_ids = list(cluster_ids if cluster_ids is not None @@ -176,8 +184,18 @@ def attach(self, gui): menu=self.__class__.__name__, default_shortcuts=self.shortcuts) + # Update the GUI status message when the `self.set_status()` method + # is called, i.e. when the `status` event is raised by the VisPy + # view. + @self.connect + def on_status(e): + gui.status_message = e.message + self.show() + def set_status(self, message): + self.events.status(message=message) + # ----------------------------------------------------------------------------- # Waveform view @@ -569,7 +587,7 @@ def set_interval(self, interval): assert traces.shape[1] == self.n_channels # Set the status message. - self.status = 'Interval: {:.3f}s - {:.3f}s'.format(start, end) + self.set_status('Interval: {:.3f} s - {:.3f} s'.format(start, end)) # Determine the data bounds. m, M = traces.min(), traces.max() @@ -928,10 +946,9 @@ def on_select(self, cluster_ids=None, **kwargs): # Set the status message. n = self.n_cols - self.status = 'Channels: ' + ', '.join(map(str, (y_dim[0, i] - for i in range(1, n)))) - self.status += ' - ' - self.status += ', '.join(map(str, (y_dim[i, 0] for i in range(1, n)))) + ch_i = ', '.join(map(str, (y_dim[0, i] for i in range(1, n)))) + ch_j = ', '.join(map(str, (y_dim[i, 0] for i in range(1, n)))) + self.set_status('Channels: {} - {}'.format(ch_i, ch_j)) # Set a non-time attribute as y coordinate in the top-left subplot. attrs = sorted(self.attributes) @@ -1099,7 +1116,7 @@ def on_select(self, cluster_ids=None, **kwargs): # Set the status message. b, w = self.bin_size * 1000, self.window_size * 1000 - self.status = 'Bin: {:.1f}. Window: {:.1f}.'.format(b, w) + self.set_status('Bin: {:.1f} ms. Window: {:.1f} ms.'.format(b, w)) self.grid.shape = (n_clusters, n_clusters) with self.building(): diff --git a/phy/gui/gui.py b/phy/gui/gui.py index b76d81d59..0653137f8 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -69,10 +69,6 @@ def closeEvent(self, e): self.emit('close_widget') super(DockWidget, self).closeEvent(e) - def enterEvent(self, e): # pragma: no cover - self.emit('enter_widget') - super(DockWidget, self).enterEvent(e) - def _create_dock_widget(widget, name, closable=True, floatable=True): # Create the gui widget. @@ -262,12 +258,6 @@ def add_view(self, def on_close_widget(): self.emit('close_view', view) - # Change the status bar when the mouse enters in a widget. - @dock_widget.connect_ - def on_enter_widget(): # pragma: no cover - if getattr(view, 'status', None): - self.status_message = view.status - dock_widget.show() self.emit('add_view', view) logger.log(5, "Add %s to GUI.", name) From aab2dd8be0433527cad1c47a13a229db677958f8 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 18 Dec 2015 18:03:27 +0100 Subject: [PATCH 0806/1059] Minor updates in view status --- phy/cluster/manual/views.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 146636dfb..5262ff215 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -193,9 +193,16 @@ def on_status(e): self.show() - def set_status(self, message): + def set_status(self, message=None): + message = message or self.status + if not message: + return + self.status = message self.events.status(message=message) + def on_mouse_move(self, e): + self.set_status() + # ----------------------------------------------------------------------------- # Waveform view @@ -575,7 +582,7 @@ def _restrict_interval(self, interval): # Public methods # ------------------------------------------------------------------------- - def set_interval(self, interval): + def set_interval(self, interval, change_status=True): """Display the traces and spikes in a given interval.""" self.clear() interval = self._restrict_interval(interval) @@ -587,7 +594,8 @@ def set_interval(self, interval): assert traces.shape[1] == self.n_channels # Set the status message. - self.set_status('Interval: {:.3f} s - {:.3f} s'.format(start, end)) + if change_status: + self.set_status('Interval: {:.3f} s - {:.3f} s'.format(start, end)) # Determine the data bounds. m, M = traces.min(), traces.max() @@ -620,7 +628,7 @@ def set_interval(self, interval): def on_select(self, cluster_ids=None, **kwargs): super(TraceView, self).on_select(cluster_ids=cluster_ids, **kwargs) - self.set_interval(self.interval) + self.set_interval(self.interval, change_status=False) def attach(self, gui): """Attach the view to the GUI.""" @@ -1071,6 +1079,10 @@ def __init__(self, assert spike_clusters.shape == (self.n_spikes,) self.spike_clusters = spike_clusters + # Set the status message. + b, w = self.bin_size * 1000, self.window_size * 1000 + self.set_status('Bin: {:.1f} ms. Window: {:.1f} ms.'.format(b, w)) + def _compute_correlograms(self, cluster_ids): # Keep spikes belonging to the selected clusters. @@ -1114,10 +1126,6 @@ def on_select(self, cluster_ids=None, **kwargs): colors = _selected_clusters_colors(n_clusters) - # Set the status message. - b, w = self.bin_size * 1000, self.window_size * 1000 - self.set_status('Bin: {:.1f} ms. Window: {:.1f} ms.'.format(b, w)) - self.grid.shape = (n_clusters, n_clusters) with self.building(): for i in range(n_clusters): From f017e4c0daa9f751de7f8dda82cd8e1530aad03c Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 18 Dec 2015 18:07:09 +0100 Subject: [PATCH 0807/1059] Save CCG normalization in view --- phy/cluster/manual/tests/test_views.py | 1 + phy/cluster/manual/views.py | 25 ++++++++++++++++++------- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index fd4bc2cdf..3135042e9 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -75,6 +75,7 @@ def _test_view(view_name, model=None, tempdir=None): state.set_view_params('WaveformView1', box_size=(.1, .1)) state.set_view_params('TraceView1', box_size=(1., .01)) state.set_view_params('FeatureView1', feature_scaling=.5) + state.set_view_params('CorrelogramView1', uniform_normalization=True) state.save() # Create the GUI. diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 5262ff215..f11209eb4 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -200,7 +200,7 @@ def set_status(self, message=None): self.status = message self.events.status(message=message) - def on_mouse_move(self, e): + def on_mouse_move(self, e): # pragma: no cover self.set_status() @@ -1150,12 +1150,14 @@ def attach(self, gui): class CorrelogramViewPlugin(IPlugin): def attach_to_gui(self, gui, model=None, state=None): - bs, ws, es, ne = state.get_view_params('CorrelogramView', - 'bin_size', - 'window_size', - 'excerpt_size', - 'n_excerpts', - ) + bs, ws, es, ne, un = state.get_view_params('CorrelogramView', + 'bin_size', + 'window_size', + 'excerpt_size', + 'n_excerpts', + 'uniform_normalization', + ) + view = CorrelogramView(spike_times=model.spike_times, spike_clusters=model.spike_clusters, sample_rate=model.sample_rate, @@ -1164,4 +1166,13 @@ def attach_to_gui(self, gui, model=None, state=None): excerpt_size=es, n_excerpts=ne, ) + if un is not None: + view.uniform_normalization = un view.attach(gui) + + @gui.connect_ + def on_close(): + # Save the normalization. + un = view.uniform_normalization + state.set_view_params(view, + uniform_normalization=un) From d257e758ace62ad6a731bbb5ceabe042d8446fe2 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 18 Dec 2015 18:26:11 +0100 Subject: [PATCH 0808/1059] Change correlogram bin window actions --- phy/cluster/manual/tests/test_views.py | 4 +- phy/cluster/manual/views.py | 56 +++++++++++++++++++------- phy/gui/actions.py | 5 ++- 3 files changed, 48 insertions(+), 17 deletions(-) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index 3135042e9..1c3498c22 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -72,7 +72,7 @@ def _test_view(view_name, model=None, tempdir=None): # Save a test GUI state JSON file in the tempdir. state = GUIState(config_dir=tempdir) - state.set_view_params('WaveformView1', box_size=(.1, .1)) + state.set_view_params('WaveformView1', overlap=False, box_size=(.1, .1)) state.set_view_params('TraceView1', box_size=(1., .01)) state.set_view_params('FeatureView1', feature_scaling=.5) state.set_view_params('CorrelogramView1', uniform_normalization=True) @@ -255,4 +255,6 @@ def test_correlogram_view(qtbot, model, tempdir): with _test_view('CorrelogramView', model=model, tempdir=tempdir) as v: v.toggle_normalization() + v.set_bin(1) + v.set_window(100) # qtbot.stop() diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index f11209eb4..a6ef9be46 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -158,8 +158,8 @@ def __init__(self, shortcuts=None, **kwargs): self.events.add(status=StatusEvent) def on_select(self, cluster_ids=None, selector=None, spike_ids=None): - cluster_ids = list(cluster_ids if cluster_ids is not None - else self.cluster_ids) + cluster_ids = (cluster_ids if cluster_ids is not None + else self.cluster_ids) if spike_ids is None: # Use the selector to select some or all of the spikes. if selector: @@ -167,8 +167,8 @@ def on_select(self, cluster_ids=None, selector=None, spike_ids=None): spike_ids = selector.select_spikes(cluster_ids, ns) else: spike_ids = self.spike_ids - self.cluster_ids = cluster_ids - self.spike_ids = np.asarray(spike_ids) + self.cluster_ids = list(cluster_ids) if cluster_ids is not None else [] + self.spike_ids = np.asarray(spike_ids if spike_ids is not None else []) def attach(self, gui): """Attach the view to the GUI.""" @@ -395,8 +395,11 @@ def shrink_vertically(self): class WaveformViewPlugin(IPlugin): def attach_to_gui(self, gui, model=None, state=None): - bs, ps = state.get_view_params('WaveformView', 'box_scaling', - 'probe_scaling') + bs, ps, ov = state.get_view_params('WaveformView', + 'box_scaling', + 'probe_scaling', + 'overlap', + ) view = WaveformView(waveforms=model.waveforms, masks=model.masks, spike_clusters=model.spike_clusters, @@ -406,12 +409,17 @@ def attach_to_gui(self, gui, model=None, state=None): ) view.attach(gui) + if ov is not None: + view.overlap = ov + @gui.connect_ def on_close(): # Save the box bounds. state.set_view_params(view, box_scaling=tuple(view.box_scaling), - probe_scaling=tuple(view.probe_scaling)) + probe_scaling=tuple(view.probe_scaling), + overlap=view.overlap, + ) # ----------------------------------------------------------------------------- @@ -1058,12 +1066,6 @@ def __init__(self, assert sample_rate > 0 self.sample_rate = sample_rate - self.bin_size = bin_size or self.bin_size - assert self.bin_size > 0 - - self.window_size = window_size or self.window_size - assert self.window_size > 0 - self.excerpt_size = excerpt_size or self.excerpt_size self.n_excerpts = n_excerpts or self.n_excerpts @@ -1079,6 +1081,17 @@ def __init__(self, assert spike_clusters.shape == (self.n_spikes,) self.spike_clusters = spike_clusters + # Set the default bin and window size. + self.set_bin_window(bin_size=bin_size, window_size=window_size) + + def set_bin_window(self, bin_size=None, window_size=None): + bin_size = bin_size or self.bin_size + window_size = window_size or self.window_size + assert 1e-6 < bin_size < 1e3 + assert 1e-6 < window_size < 1e3 + assert bin_size < window_size + self.bin_size = bin_size + self.window_size = window_size # Set the status message. b, w = self.bin_size * 1000, self.window_size * 1000 self.set_status('Bin: {:.1f} ms. Window: {:.1f} ms.'.format(b, w)) @@ -1146,6 +1159,18 @@ def attach(self, gui): """Attach the view to the GUI.""" super(CorrelogramView, self).attach(gui) self.actions.add(self.toggle_normalization, shortcut='n') + self.actions.add(self.set_bin, alias='cb') + self.actions.add(self.set_window, alias='cw') + + def set_bin(self, bin_size): + """Set the correlogram bin size in milliseconds.""" + self.set_bin_window(bin_size=bin_size * 1e-3) + self.on_select() + + def set_window(self, window_size): + """Set the correlogram window size in milliseconds.""" + self.set_bin_window(window_size=window_size * 1e-3) + self.on_select() class CorrelogramViewPlugin(IPlugin): @@ -1175,4 +1200,7 @@ def on_close(): # Save the normalization. un = view.uniform_normalization state.set_view_params(view, - uniform_normalization=un) + uniform_normalization=un, + bin_size=view.bin_size, + window_size=view.window_size, + ) diff --git a/phy/gui/actions.py b/phy/gui/actions.py index ae622825f..febe302a9 100644 --- a/phy/gui/actions.py +++ b/phy/gui/actions.py @@ -177,7 +177,8 @@ def add(self, callback=None, name=None, shortcut=None, alias=None, self.gui.addAction(action) # Add the action to the menu. menu = menu or self.menu - if menu: + # Do not show private actions in the menu. + if menu and not name.startswith('_'): self.gui.get_menu(menu).addAction(action) self._actions_dict[name] = action_obj # Register the alias -> name mapping. @@ -298,7 +299,7 @@ def __init__(self, gui): self.gui = gui self._status_message = gui.status_message - self.actions = Actions(gui, name='Snippets') + self.actions = Actions(gui, name='Snippets', menu='Snippets') # Register snippet mode shortcut. @self.actions.add(shortcut=':') From 622b113876208e6fbe6f4955a090c09d904c477f Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 18 Dec 2015 18:30:23 +0100 Subject: [PATCH 0809/1059] Better detrending in trace view --- phy/cluster/manual/views.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index a6ef9be46..4bb799a74 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -427,6 +427,7 @@ def on_close(): # ----------------------------------------------------------------------------- class TraceView(ManualClusteringView): + n_samples_for_mean = 1000 interval_duration = .5 # default duration of the interval shift_amount = .1 scaling_coeff = 1.1 @@ -460,6 +461,10 @@ def __init__(self, self.traces = traces self.duration = self.dt * self.n_samples + # Compute the mean traces in order to detrend the traces. + k = max(1, self.n_samples // self.n_samples_for_mean) + self.mean_traces = np.mean(traces[::k, :], axis=0).astype(traces.dtype) + # Number of samples per spike. self.n_samples_per_spike = (n_samples_per_spike or round(.002 * sample_rate)) @@ -511,8 +516,7 @@ def _load_traces(self, interval): traces = self.traces[i:j, :] # Detrend the traces. - m = np.mean(traces[::10, :], axis=0).astype(traces.dtype) - traces -= m + traces -= self.mean_traces # Create the plots. return traces From 5c9804536dd2e900a8f07c2a7a305df20415f0dd Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 18 Dec 2015 18:50:01 +0100 Subject: [PATCH 0810/1059] Add docstrings to actions --- phy/cluster/manual/gui_component.py | 27 +++++++++++++++++++++++---- phy/cluster/manual/views.py | 19 +++++++++++++++++-- phy/gui/actions.py | 17 ++++++++++++++--- phy/gui/gui.py | 2 ++ 4 files changed, 56 insertions(+), 9 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 2678eed4f..2d6822b86 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -286,11 +286,16 @@ def _create_actions(self, gui): for group in ('noise', 'mua', 'good'): self.actions.add(partial(self.move_best, group), - name='move_best_to_' + group) + name='move_best_to_' + group, + docstring='Move the best clusters to %s.' % group) self.actions.add(partial(self.move_similar, group), - name='move_similar_to_' + group) + name='move_similar_to_' + group, + docstring='Move the similar clusters to %s.' % + group) self.actions.add(partial(self.move_all, group), - name='move_all_to_' + group) + name='move_all_to_' + group, + docstring='Move all selected clusters to %s.' % + group) self.actions.separator() # Others. @@ -446,7 +451,7 @@ def attach(self, gui): # ------------------------------------------------------------------------- def select(self, *cluster_ids): - """Select action: select clusters in the cluster view.""" + """Select a list of clusters.""" # HACK: allow for `select(1, 2, 3)` in addition to `select([1, 2, 3])` # This makes it more convenient to select multiple clusters with # the snippet: `:c 1 2 3` instead of `:c 1,2,3`. @@ -463,6 +468,7 @@ def selected(self): # ------------------------------------------------------------------------- def merge(self, cluster_ids=None): + """Merge the selected clusters.""" if cluster_ids is None: cluster_ids = self.selected if len(cluster_ids or []) <= 1: @@ -471,6 +477,7 @@ def merge(self, cluster_ids=None): self._global_history.action(self.clustering) def split(self, spike_ids): + """Split the selected spikes (NOT IMPLEMENTED YET).""" if len(spike_ids) == 0: return # TODO: connect to request_split emitted by view @@ -481,52 +488,64 @@ def split(self, spike_ids): # ------------------------------------------------------------------------- def move(self, cluster_ids, group): + """Move clusters to a group.""" if len(cluster_ids) == 0: return self.cluster_meta.set('group', cluster_ids, group) self._global_history.action(self.cluster_meta) def move_best(self, group): + """Move all selected best clusters to a group.""" self.move(self.cluster_view.selected, group) def move_similar(self, group): + """Move all selected similar clusters to a group.""" self.move(self.similarity_view.selected, group) def move_all(self, group): + """Move all selected clusters to a group.""" self.move(self.selected, group) # Wizard actions # ------------------------------------------------------------------------- def reset(self): + """Reset the wizard.""" self._update_cluster_view() self.cluster_view.next() def next_best(self): + """Select the next best cluster.""" self.cluster_view.next() def previous_best(self): + """Select the previous best cluster.""" self.cluster_view.previous() def next(self): + """Select the next cluster.""" if not self.selected: self.cluster_view.next() else: self.similarity_view.next() def previous(self): + """Select the previous cluster.""" self.similarity_view.previous() # Other actions # ------------------------------------------------------------------------- def undo(self): + """Undo the last action.""" self._global_history.undo() def redo(self): + """Undo the last undone action.""" self._global_history.redo() def save(self): + """Save the manual clustering back to disk.""" spike_clusters = self.clustering.spike_clusters groups = {c: self.cluster_meta.get('group', c) or 'unsorted' for c in self.clustering.cluster_ids} diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 4bb799a74..8bb54424b 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -347,6 +347,7 @@ def attach(self, gui): self.actions.add(self.shrink_vertically) def toggle_waveform_overlap(self): + """Toggle the overlap of the waveforms.""" self.overlap = not self.overlap self.on_select() @@ -358,18 +359,22 @@ def _update_boxes(self): self.box_size * self.box_scaling) def widen(self): + """Increase the horizontal scaling of the waveforms.""" self.box_scaling[0] *= self.scaling_coeff self._update_boxes() def narrow(self): + """Decrease the horizontal scaling of the waveforms.""" self.box_scaling[0] /= self.scaling_coeff self._update_boxes() def increase(self): + """Increase the vertical scaling of the waveforms.""" self.box_scaling[1] *= self.scaling_coeff self._update_boxes() def decrease(self): + """Decrease the vertical scaling of the waveforms.""" self.box_scaling[1] /= self.scaling_coeff self._update_boxes() @@ -377,18 +382,22 @@ def decrease(self): # ------------------------------------------------------------------------- def extend_horizontally(self): + """Increase the horizontal scaling of the probe.""" self.probe_scaling[0] *= self.scaling_coeff self._update_boxes() def shrink_horizontally(self): + """Decrease the horizontal scaling of the waveforms.""" self.probe_scaling[0] /= self.scaling_coeff self._update_boxes() def extend_vertically(self): + """Increase the vertical scaling of the waveforms.""" self.probe_scaling[1] *= self.scaling_coeff self._update_boxes() def shrink_vertically(self): + """Decrease the vertical scaling of the waveforms.""" self.probe_scaling[1] /= self.scaling_coeff self._update_boxes() @@ -709,10 +718,12 @@ def _update_boxes(self): self.stacked.box_size = self.box_size * self.scaling def increase(self): + """Increase the scaling of the traces.""" self.scaling *= self.scaling_coeff self._update_boxes() def decrease(self): + """Decrease the scaling of the traces.""" self.scaling /= self.scaling_coeff self._update_boxes() @@ -996,10 +1007,12 @@ def attach(self, gui): self.actions.add(self.decrease) def increase(self): + """Increase the scaling of the features.""" self.feature_scaling *= 1.2 self.on_select() def decrease(self): + """Decrease the scaling of the features.""" self.feature_scaling /= 1.2 self.on_select() @@ -1089,6 +1102,7 @@ def __init__(self, self.set_bin_window(bin_size=bin_size, window_size=window_size) def set_bin_window(self, bin_size=None, window_size=None): + """Set the bin and window sizes.""" bin_size = bin_size or self.bin_size window_size = window_size or self.window_size assert 1e-6 < bin_size < 1e3 @@ -1156,6 +1170,7 @@ def on_select(self, cluster_ids=None, **kwargs): ) def toggle_normalization(self): + """Change the normalization of the correlograms.""" self.uniform_normalization = not self.uniform_normalization self.on_select() @@ -1167,12 +1182,12 @@ def attach(self, gui): self.actions.add(self.set_window, alias='cw') def set_bin(self, bin_size): - """Set the correlogram bin size in milliseconds.""" + """Set the correlogram bin size (in milliseconds).""" self.set_bin_window(bin_size=bin_size * 1e-3) self.on_select() def set_window(self, window_size): - """Set the correlogram window size in milliseconds.""" + """Set the correlogram window size (in milliseconds).""" self.set_bin_window(window_size=window_size * 1e-3) self.on_select() diff --git a/phy/gui/actions.py b/phy/gui/actions.py index febe302a9..35a5ec7d3 100644 --- a/phy/gui/actions.py +++ b/phy/gui/actions.py @@ -9,6 +9,7 @@ from functools import partial import logging +import re import sys import traceback @@ -113,7 +114,7 @@ def _alias(name): @require_qt -def _create_qaction(gui, name, callback, shortcut): +def _create_qaction(gui, name, callback, shortcut, docstring=None): # Create the QAction instance. action = QAction(name.capitalize().replace('_', ' '), gui) @@ -125,6 +126,9 @@ def wrapped(checked, *args, **kwargs): # pragma: no cover if not isinstance(sequence, (tuple, list)): sequence = [sequence] action.setShortcuts(sequence) + assert docstring + action.setStatusTip(docstring) + action.setWhatsThis(docstring) return action @@ -148,7 +152,7 @@ def __init__(self, gui, name=None, menu=None, default_shortcuts=None): gui.actions.append(self) def add(self, callback=None, name=None, shortcut=None, alias=None, - menu=None, verbose=True): + docstring=None, menu=None, verbose=True): """Add an action with a keyboard shortcut.""" # TODO: add menu_name option and create menu bar if callback is None: @@ -167,8 +171,13 @@ def add(self, callback=None, name=None, shortcut=None, alias=None, if name in self._actions_dict: return + # Set the status tip from the function's docstring. + docstring = docstring or callback.__doc__ or name + docstring = re.sub(r'[\s]{2,}', ' ', docstring) + # Create and register the action. - action = _create_qaction(self.gui, name, callback, shortcut) + action = _create_qaction(self.gui, name, callback, shortcut, + docstring=docstring) action_obj = Bunch(qaction=action, name=name, alias=alias, shortcut=shortcut, callback=callback, menu=menu) if verbose and not name.startswith('_'): @@ -304,6 +313,8 @@ def __init__(self, gui): # Register snippet mode shortcut. @self.actions.add(shortcut=':') def enable_snippet_mode(): + """Enable the snippet mode (type action alias in the status + bar).""" self.mode_on() self._create_snippet_actions() diff --git a/phy/gui/gui.py b/phy/gui/gui.py index 0653137f8..1718ae48b 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -182,11 +182,13 @@ def _set_default_actions(self): @self.default_actions.add(shortcut=('HelpContents', 'h')) def show_all_shortcuts(): + """Show the shortcuts of all actions.""" for actions in self.actions: actions.show_shortcuts() @self.default_actions.add(shortcut='ctrl+q') def exit(): + """Close the GUI.""" self.close() # Events From e139c2e44f1f03d9f07634027e876e332cd4493f Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 18 Dec 2015 19:21:59 +0100 Subject: [PATCH 0811/1059] Add background spikes in feature view --- phy/cluster/manual/views.py | 78 +++++++++++++++++++++++++------------ 1 file changed, 54 insertions(+), 24 deletions(-) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 8bb54424b..cd1d16019 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -100,8 +100,12 @@ def _get_depth(masks, spike_clusters_rel=None, n_clusters=None): mask and cluster index.""" n_spikes = len(masks) assert masks.shape == (n_spikes,) - depth = (-0.1 - (spike_clusters_rel + masks) / - float(n_clusters + 10.)) + # Fixed depth for background spikes. + if spike_clusters_rel is None: + depth = .5 * np.ones(n_spikes) + else: + depth = (-0.1 - (spike_clusters_rel + masks) / + float(n_clusters + 10.)) depth[masks <= 0.25] = 0 assert depth.shape == (n_spikes,) return depth @@ -110,19 +114,26 @@ def _get_depth(masks, spike_clusters_rel=None, n_clusters=None): def _get_color(masks, spike_clusters_rel=None, n_clusters=None): """Return the color of vertices as a function of the mask and cluster index.""" - n_spikes = len(masks) + n_spikes = masks.shape[0] + # The transparency depends on whether the spike clusters are specified. + # For background spikes, we use a smaller alpha. + alpha = .5 if spike_clusters_rel is not None else .25 assert masks.shape == (n_spikes,) - assert spike_clusters_rel.shape == (n_spikes,) # Generate the colors. colors = _selected_clusters_colors(n_clusters) # Color as a function of the mask. - color = colors[spike_clusters_rel] + if spike_clusters_rel is not None: + assert spike_clusters_rel.shape == (n_spikes,) + color = colors[spike_clusters_rel] + else: + # Fixed color when the spike clusters are not specified. + color = .5 * np.ones((n_spikes, 3)) hsv = rgb_to_hsv(color[:, :3]) # Change the saturation and value as a function of the mask. hsv[:, 1] *= masks hsv[:, 2] *= .5 * (1. + masks) color = hsv_to_rgb(hsv) - color = np.c_[color, .5 * np.ones((n_spikes, 1))] + color = np.c_[color, alpha * np.ones((n_spikes, 1))] return color @@ -819,6 +830,7 @@ class FeatureView(ManualClusteringView): max_n_spikes_per_cluster = 100000 normalization_percentile = .95 normalization_n_spikes = 1000 + n_spikes_bg = 10000 _default_marker_size = 3. _feature_scaling = 1. @@ -853,6 +865,11 @@ def __init__(self, # Masks. self.masks = masks + # Background spikes. + k = max(1, self.n_spikes // self.n_spikes_bg) + self.spike_ids_bg = slice(None, None, k) + self.masks_bg = self.masks[self.spike_ids_bg] + # Spike clusters. assert spike_clusters.shape == (self.n_spikes,) self.spike_clusters = spike_clusters @@ -888,7 +905,7 @@ def _get_feature(self, dim, spike_ids=None): # Extra features like time. values, _ = self.attributes[dim] values = values[spike_ids] - assert values.shape == (len(spike_ids),) + assert values.shape == (f.shape[0],) return values else: assert len(dim) == 2 @@ -912,15 +929,16 @@ def _get_dim_bounds(self, x_dim, y_dim): y0, y1 = self._get_dim_bounds_single(y_dim) return [x0, y0, x1, y1] - def _plot_features(self, i, j, x_dim, y_dim, - cluster_ids=None, spike_ids=None, + def _plot_features(self, i, j, x_dim, y_dim, x, y, masks=None, spike_clusters_rel=None): - sc = spike_clusters_rel - n_clusters = len(cluster_ids) + """Plot the features in a subplot.""" + assert x.shape == y.shape + n_spikes = x.shape[0] - # Retrieve the x and y values for the subplot. - x = self._get_feature(x_dim[i, j], spike_ids) - y = self._get_feature(y_dim[i, j], spike_ids) + sc = spike_clusters_rel + if sc is not None: + assert sc.shape == (n_spikes,) + n_clusters = len(self.cluster_ids) # Retrieve the data bounds. data_bounds = self._get_dim_bounds(x_dim[i, j], @@ -933,23 +951,25 @@ def _plot_features(self, i, j, x_dim, y_dim, my, dy = _project_mask_depth(y_dim[i, j], masks, spike_clusters_rel=sc, n_clusters=n_clusters) + assert mx.shape == my.shape == dx.shape == dy.shape == (n_spikes,) d = np.maximum(dx, dy) m = np.maximum(mx, my) # Get the color of the markers. - color = _get_color(m, - spike_clusters_rel=sc, - n_clusters=n_clusters) + color = _get_color(m, spike_clusters_rel=sc, n_clusters=n_clusters) + assert color.shape == (n_spikes, 4) # Create the scatter plot for the current subplot. - ms = self._default_marker_size + # The marker size is smaller for background spikes. + ms = (self._default_marker_size + if spike_clusters_rel is not None else 1.) self[i, j].scatter(x=x, y=y, color=color, depth=d, data_bounds=data_bounds, - size=ms * np.ones(len(spike_ids)), + size=ms * np.ones(n_spikes), ) def set_best_channels_func(self, func): @@ -991,12 +1011,22 @@ def on_select(self, cluster_ids=None, **kwargs): with self.building(): for i in range(self.n_cols): for j in range(self.n_cols): - self._plot_features(i, j, x_dim, y_dim, - cluster_ids=cluster_ids, - spike_ids=spike_ids, + + # Retrieve the x and y values for the subplot. + x = self._get_feature(x_dim[i, j], self.spike_ids) + y = self._get_feature(y_dim[i, j], self.spike_ids) + + # Retrieve the x and y values for the background spikes. + x_bg = self._get_feature(x_dim[i, j], self.spike_ids_bg) + y_bg = self._get_feature(y_dim[i, j], self.spike_ids_bg) + + # Background features. + self._plot_features(i, j, x_dim, y_dim, x_bg, y_bg, + masks=self.masks_bg) + # Cluster features. + self._plot_features(i, j, x_dim, y_dim, x, y, masks=masks, - spike_clusters_rel=sc, - ) + spike_clusters_rel=sc) # Add the boxes. self.grid.add_boxes(self, self.shape) From c2e3ab8a51e7d81263997841c28983a85259eb6d Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 18 Dec 2015 21:12:21 +0100 Subject: [PATCH 0812/1059] WIP: refactor LineVisual --- phy/plot/glsl/line.frag | 5 ++++ phy/plot/glsl/line.vert | 8 +++++ phy/plot/tests/test_visuals.py | 8 +++-- phy/plot/visuals.py | 54 ++++++++++------------------------ 4 files changed, 35 insertions(+), 40 deletions(-) create mode 100644 phy/plot/glsl/line.frag create mode 100644 phy/plot/glsl/line.vert diff --git a/phy/plot/glsl/line.frag b/phy/plot/glsl/line.frag new file mode 100644 index 000000000..7c31f83aa --- /dev/null +++ b/phy/plot/glsl/line.frag @@ -0,0 +1,5 @@ +varying vec4 v_color; + +void main() { + gl_FragColor = v_color; +} diff --git a/phy/plot/glsl/line.vert b/phy/plot/glsl/line.vert new file mode 100644 index 000000000..a5c0c5474 --- /dev/null +++ b/phy/plot/glsl/line.vert @@ -0,0 +1,8 @@ +attribute vec2 a_position; +attribute vec4 a_color; +varying vec4 v_color; + +void main() { + gl_Position = transform(a_position); + v_color = a_color; +} diff --git a/phy/plot/tests/test_visuals.py b/phy/plot/tests/test_visuals.py index 6448b82a0..e677a1529 100644 --- a/phy/plot/tests/test_visuals.py +++ b/phy/plot/tests/test_visuals.py @@ -164,10 +164,14 @@ def test_histogram_2(qtbot, canvas_pz): #------------------------------------------------------------------------------ def test_line_empty(qtbot, canvas): - _test_visual(qtbot, canvas, LineVisual()) + pos = np.zeros((0, 4)) + _test_visual(qtbot, canvas, LineVisual(), pos=pos) def test_line_0(qtbot, canvas_pz): + n = 10 y = np.linspace(-.5, .5, 10) + pos = np.c_[-np.ones(n), y, np.ones(n), y] + color = np.random.uniform(.5, .9, (n, 4)) _test_visual(qtbot, canvas_pz, LineVisual(), - y0=y, y1=y, data_bounds=[-1, -1, 1, 1]) + pos=pos, color=color, data_bounds=[-1, -1, 1, 1]) diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index 4a3a04269..e703e192c 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -351,62 +351,40 @@ class LineVisual(BaseVisual): def __init__(self, color=None): super(LineVisual, self).__init__() - self.set_shader('simple') + self.set_shader('line') self.set_primitive_type('lines') - self.color = color or self._default_color - self.data_range = Range(NDC) self.transforms.add_on_cpu(self.data_range) @staticmethod - def validate(x0=None, - y0=None, - x1=None, - y1=None, - # color=None, - data_bounds=None, - ): - - # TODO: single argument pos (n, 4) instead of x0 y0 etc. - - # Get the number of lines. - n_lines = _get_length(x0, y0, x1, y1) - x0 = _validate_line_coord(x0, n_lines, -1) - y0 = _validate_line_coord(y0, n_lines, -1) - x1 = _validate_line_coord(x1, n_lines, +1) - y1 = _validate_line_coord(y1, n_lines, +1) + def validate(pos=None, color=None, data_bounds=None): + assert pos is not None + pos = np.asarray(pos) + assert pos.ndim == 2 + n_lines = pos.shape[0] + assert pos.shape[1] == 4 + pos = pos.astype(np.float32) - assert x0.shape == y0.shape == x1.shape == y1.shape == (n_lines, 1) + # Color. + color = _get_array(color, (n_lines, 4), LineVisual._default_color) # By default, we assume that the coordinates are in NDC. if data_bounds is None: data_bounds = NDC - - # NOTE: currently, we don't support custom colors. We could do it - # by replacing the uniform by an attribute in the shaders. - # color = _get_array(color, (4,), LineVisual._default_color) - # assert len(color) == 4 - data_bounds = _get_data_bounds(data_bounds, length=n_lines) data_bounds = data_bounds.astype(np.float32) assert data_bounds.shape == (n_lines, 4) - return Bunch(x0=x0, - y0=y0, - x1=x1, - y1=y1, - # color=color, - data_bounds=data_bounds, - ) + return Bunch(pos=pos, color=color, data_bounds=data_bounds) @staticmethod - def vertex_count(x0=None, y0=None, x1=None, y1=None, **kwargs): + def vertex_count(pos=None, **kwargs): """Take the output of validate() as input.""" - return 2 * _get_length(x0, y0, x1, y1) + return pos.shape[0] * 2 def set_data(self, *args, **kwargs): data = self.validate(*args, **kwargs) - pos = np.c_[data.x0, data.y0, data.x1, data.y1].astype(np.float32) + pos = data.pos assert pos.ndim == 2 assert pos.shape[1] == 4 assert pos.dtype == np.float32 @@ -424,5 +402,5 @@ def set_data(self, *args, **kwargs): self.program['a_position'] = pos_tr # Color. - # self.program['u_color'] = data.color - self.program['u_color'] = self.color + color = np.repeat(data.color, 2, axis=0).astype(np.float32) + self.program['a_color'] = color From bb844d701a2eea60968c97df7cc0c6c4e14028e7 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 18 Dec 2015 21:18:40 +0100 Subject: [PATCH 0813/1059] WIP: refactor LineVisual --- phy/plot/interact.py | 12 +++++++----- phy/plot/tests/test_plot.py | 4 ++-- phy/plot/visuals.py | 23 +---------------------- 3 files changed, 10 insertions(+), 29 deletions(-) diff --git a/phy/plot/interact.py b/phy/plot/interact.py index e3117ad8c..7f553f25d 100644 --- a/phy/plot/interact.py +++ b/phy/plot/interact.py @@ -65,10 +65,12 @@ def add_boxes(self, canvas, shape=None): n_boxes = n * m a = 1 + .05 - x0 = np.tile([-a, +a, +a, -a], n_boxes) - y0 = np.tile([-a, -a, +a, +a], n_boxes) - x1 = np.tile([+a, +a, -a, -a], n_boxes) - y1 = np.tile([-a, +a, +a, -a], n_boxes) + pos = np.array([[-a, -a, +a, -a], + [+a, -a, +a, +a], + [+a, +a, -a, +a], + [-a, +a, -a, -a], + ]) + pos = np.tile(pos, (n_boxes, 1)) box_index = [] for i in range(n): @@ -85,7 +87,7 @@ def _remove_clip(tc): return tc.remove('Clip') canvas.add_visual(boxes) - boxes.set_data(x0=x0, y0=y0, x1=x1, y1=y1) + boxes.set_data(pos=pos) boxes.program['a_box_index'] = box_index def update_program(self, program): diff --git a/phy/plot/tests/test_plot.py b/phy/plot/tests/test_plot.py index 1a92cfe2c..62f5cbef3 100644 --- a/phy/plot/tests/test_plot.py +++ b/phy/plot/tests/test_plot.py @@ -112,8 +112,8 @@ def test_grid_hist(qtbot): def test_grid_lines(qtbot): view = View(layout='grid', shape=(1, 2)) - view[0, 0].lines(y0=-.5, y1=-.5) - view[0, 1].lines(y0=+.5, y1=+.5) + view[0, 0].lines(pos=[-1, -.5, +1, -.5]) + view[0, 1].lines(pos=[-1, +.5, +1, +.5]) _show(qtbot, view) diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index e703e192c..c73fa81d9 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -131,27 +131,6 @@ def _max(arr): return arr.max() if len(arr) else 1 -def _validate_line_coord(x, n, default): - assert n >= 0 - if x is None: - x = default - if not hasattr(x, '__len__'): - x = x * np.ones(n) - x = np.asarray(x, dtype=np.float32) - assert isinstance(x, np.ndarray) - if x.ndim == 1: - x = x[:, None] - assert x.shape == (n, 1) - return x - - -def _get_length(*args): - for arg in args: - if hasattr(arg, '__len__'): - return len(arg) - return 1 - - class PlotVisual(BaseVisual): _default_color = DEFAULT_COLOR allow_list = ('x', 'y') @@ -359,7 +338,7 @@ def __init__(self, color=None): @staticmethod def validate(pos=None, color=None, data_bounds=None): assert pos is not None - pos = np.asarray(pos) + pos = np.atleast_2d(pos) assert pos.ndim == 2 n_lines = pos.shape[0] assert pos.shape[1] == 4 From 19862604d77152649ea0301d3d521c135bfc2d53 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 18 Dec 2015 21:22:21 +0100 Subject: [PATCH 0814/1059] Add axes in feature view --- phy/cluster/manual/views.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index cd1d16019..55a858425 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -1027,6 +1027,12 @@ def on_select(self, cluster_ids=None, **kwargs): self._plot_features(i, j, x_dim, y_dim, x, y, masks=masks, spike_clusters_rel=sc) + + # Add axes. + self[i, j].lines(pos=[[-1, 0, +1, 0], + [0, -1, 0, +1]], + color=(.25, .25, .25, .5)) + # Add the boxes. self.grid.add_boxes(self, self.shape) From ed2c46cef46c9203a8b2b76f8ce7347ac044b1a8 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 18 Dec 2015 22:00:46 +0100 Subject: [PATCH 0815/1059] Add PanZoom.set_range() --- phy/plot/panzoom.py | 20 +++++++++++++++++++- phy/plot/tests/test_panzoom.py | 18 ++++++++++++++++++ phy/plot/tests/test_transform.py | 8 ++++++++ phy/plot/transform.py | 11 ++++++++++- 4 files changed, 55 insertions(+), 2 deletions(-) diff --git a/phy/plot/panzoom.py b/phy/plot/panzoom.py index 9d7cf282f..3f5f3faaf 100644 --- a/phy/plot/panzoom.py +++ b/phy/plot/panzoom.py @@ -12,7 +12,7 @@ import numpy as np from .base import BaseInteract -from .transform import Translate, Scale, pixels_to_ndc +from .transform import Translate, Scale, TransformChain, pixels_to_ndc from phy.utils._types import _as_array @@ -283,6 +283,24 @@ def zoom_delta(self, d, p=(0., 0.), c=1.): self.update() + def set_range(self, bounds): + """Zoom to fit a box.""" + # a * (-1 + t) = v0 + # a * (+1 + t) = v1 + bounds = np.asarray(bounds, dtype=np.float64) + v0 = bounds[:2] + v1 = bounds[2:] + self.zoom = (v1 - v0) / 2. + self.pan = v1 / self.zoom - 1 + + def get_range(self): + """Return the bounds currently visible.""" + v0 = np.array([-1., -1.]) + v1 = np.array([+1., +1.]) + x0, y0 = self.zoom * (v0 + self.pan) + x1, y1 = self.zoom * (v1 + self.pan) + return (x0, y0, x1, y1) + # Event callbacks # ------------------------------------------------------------------------- diff --git a/phy/plot/tests/test_panzoom.py b/phy/plot/tests/test_panzoom.py index 902fb5a6e..8511cf101 100644 --- a/phy/plot/tests/test_panzoom.py +++ b/phy/plot/tests/test_panzoom.py @@ -8,6 +8,7 @@ #------------------------------------------------------------------------------ import numpy as np +from numpy.testing import assert_allclose as ac from pytest import yield_fixture from vispy.app import MouseEvent from vispy.util import keys @@ -156,6 +157,23 @@ def test_panzoom_constraints_z(): assert pz.zoom == [2, 2] +def test_panzoom_set_range(): + pz = PanZoom() + + def _test_range(*bounds): + pz.set_range(bounds) + ac(pz.get_range(), bounds) + + _test_range(-1, -1, 1, 1) + _test_range(-.5, -.5, .5, .5) + _test_range(0, 0, 1, 1) + _test_range(-.5, 0, 0, 1) + + +#------------------------------------------------------------------------------ +# Test panzoom on canvas +#------------------------------------------------------------------------------ + def test_panzoom_pan_mouse(qtbot, canvas_pz, panzoom): c = canvas_pz pz = panzoom diff --git a/phy/plot/tests/test_transform.py b/phy/plot/tests/test_transform.py index ae31ccf0a..c7f57e01d 100644 --- a/phy/plot/tests/test_transform.py +++ b/phy/plot/tests/test_transform.py @@ -237,3 +237,11 @@ def test_transform_chain_add(): tc_2.add_on_cpu([Scale(2)]) ae((tc + tc_2).apply([3]), [[3]]) + + +def test_transform_chain_inverse(): + tc = TransformChain() + tc.add_on_cpu([Scale(.5), Translate((1, 0)), Scale(2)]) + tci = tc.inverse() + ae(tc.apply([[1, 0]]), [[3, 0]]) + ae(tci.apply([[3, 0]]), [[1, 0]]) diff --git a/phy/plot/transform.py b/phy/plot/transform.py index 5e4a055b1..252bcd9bb 100644 --- a/phy/plot/transform.py +++ b/phy/plot/transform.py @@ -65,9 +65,11 @@ def _minus(value): def _inverse(value): if isinstance(value, np.ndarray): return 1. / value - else: + elif hasattr(value, '__len__'): assert len(value) == 2 return 1. / value[0], 1. / value[1] + else: + return 1. / value def subplot_bounds(shape=None, index=None): @@ -282,6 +284,13 @@ def apply(self, arr): arr = t.apply(arr) return arr + def inverse(self): + """Return the inverse chain of transforms.""" + transforms = self.cpu_transforms + self.gpu_transforms + inv_transforms = [transform.inverse() + for transform in transforms[::-1]] + return TransformChain().add_on_cpu(inv_transforms) + def __add__(self, tc): assert isinstance(tc, TransformChain) assert tc.transformed_var_name == self.transformed_var_name From c369adfda42bb9b91a22afa7c008fa85f3cd29cb Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 18 Dec 2015 22:04:38 +0100 Subject: [PATCH 0816/1059] Flakify --- phy/plot/panzoom.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/phy/plot/panzoom.py b/phy/plot/panzoom.py index 3f5f3faaf..b33531ded 100644 --- a/phy/plot/panzoom.py +++ b/phy/plot/panzoom.py @@ -12,7 +12,7 @@ import numpy as np from .base import BaseInteract -from .transform import Translate, Scale, TransformChain, pixels_to_ndc +from .transform import Translate, Scale, pixels_to_ndc from phy.utils._types import _as_array From 9299f2c95d79994a9d27f347c24812e6625d6900 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 18 Dec 2015 22:25:35 +0100 Subject: [PATCH 0817/1059] Fix PanZoom.set_range() --- phy/plot/panzoom.py | 26 ++++++++++++++++++-------- phy/plot/tests/test_panzoom.py | 9 ++++++++- 2 files changed, 26 insertions(+), 9 deletions(-) diff --git a/phy/plot/panzoom.py b/phy/plot/panzoom.py index b33531ded..42459c1da 100644 --- a/phy/plot/panzoom.py +++ b/phy/plot/panzoom.py @@ -283,22 +283,32 @@ def zoom_delta(self, d, p=(0., 0.), c=1.): self.update() + def set_pan_zoom(self, pan, zoom): + self._pan = pan + self._zoom = np.clip(zoom, self._zmin, self._zmax) + + # Constrain bounding box. + self._constrain_pan() + self._constrain_zoom() + + self.update() + def set_range(self, bounds): """Zoom to fit a box.""" - # a * (-1 + t) = v0 - # a * (+1 + t) = v1 + # a * (v0 + t) = -1 + # a * (v1 + t) = +1 + # => + # a * (v1 - v0) = 2 bounds = np.asarray(bounds, dtype=np.float64) v0 = bounds[:2] v1 = bounds[2:] - self.zoom = (v1 - v0) / 2. - self.pan = v1 / self.zoom - 1 + self.set_pan_zoom(-.5 * (v0 + v1), 2. / (v1 - v0)) def get_range(self): """Return the bounds currently visible.""" - v0 = np.array([-1., -1.]) - v1 = np.array([+1., +1.]) - x0, y0 = self.zoom * (v0 + self.pan) - x1, y1 = self.zoom * (v1 + self.pan) + p, z = np.asarray(self.pan), np.asarray(self.zoom) + x0, y0 = -1. / z - p + x1, y1 = +1. / z - p return (x0, y0, x1, y1) # Event callbacks diff --git a/phy/plot/tests/test_panzoom.py b/phy/plot/tests/test_panzoom.py index 8511cf101..87469e11c 100644 --- a/phy/plot/tests/test_panzoom.py +++ b/phy/plot/tests/test_panzoom.py @@ -165,9 +165,16 @@ def _test_range(*bounds): ac(pz.get_range(), bounds) _test_range(-1, -1, 1, 1) + ac(pz.zoom, (1, 1)) + _test_range(-.5, -.5, .5, .5) + ac(pz.zoom, (2, 2)) + _test_range(0, 0, 1, 1) - _test_range(-.5, 0, 0, 1) + ac(pz.zoom, (2, 2)) + + _test_range(-1, 0, 1, 1) + ac(pz.zoom, (1, 2)) #------------------------------------------------------------------------------ From 819cadace978f41023757cd351ddd98ab8e18c52 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 18 Dec 2015 22:42:58 +0100 Subject: [PATCH 0818/1059] Keep aspect in PanZoom.set_range() --- phy/plot/panzoom.py | 10 +++++++--- phy/plot/tests/test_panzoom.py | 3 +++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/phy/plot/panzoom.py b/phy/plot/panzoom.py index 42459c1da..da88afaf2 100644 --- a/phy/plot/panzoom.py +++ b/phy/plot/panzoom.py @@ -283,7 +283,7 @@ def zoom_delta(self, d, p=(0., 0.), c=1.): self.update() - def set_pan_zoom(self, pan, zoom): + def set_pan_zoom(self, pan=None, zoom=None): self._pan = pan self._zoom = np.clip(zoom, self._zmin, self._zmax) @@ -293,7 +293,7 @@ def set_pan_zoom(self, pan, zoom): self.update() - def set_range(self, bounds): + def set_range(self, bounds, keep_aspect=False): """Zoom to fit a box.""" # a * (v0 + t) = -1 # a * (v1 + t) = +1 @@ -302,7 +302,11 @@ def set_range(self, bounds): bounds = np.asarray(bounds, dtype=np.float64) v0 = bounds[:2] v1 = bounds[2:] - self.set_pan_zoom(-.5 * (v0 + v1), 2. / (v1 - v0)) + pan = -.5 * (v0 + v1) + zoom = 2. / (v1 - v0) + if keep_aspect: + zoom = zoom.min() * np.ones(2) + self.set_pan_zoom(pan=pan, zoom=zoom) def get_range(self): """Return the bounds currently visible.""" diff --git a/phy/plot/tests/test_panzoom.py b/phy/plot/tests/test_panzoom.py index 87469e11c..a0bb5c838 100644 --- a/phy/plot/tests/test_panzoom.py +++ b/phy/plot/tests/test_panzoom.py @@ -176,6 +176,9 @@ def _test_range(*bounds): _test_range(-1, 0, 1, 1) ac(pz.zoom, (1, 2)) + pz.set_range((-1, 0, 1, 1), keep_aspect=True) + ac(pz.zoom, (1, 1)) + #------------------------------------------------------------------------------ # Test panzoom on canvas From 1e385edce64fcc489b0c7dc630f2a016a950ff61 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 18 Dec 2015 22:43:58 +0100 Subject: [PATCH 0819/1059] Add zoom_on_channels() in WaveformView --- phy/cluster/manual/tests/test_views.py | 2 ++ phy/cluster/manual/views.py | 13 +++++++++++++ 2 files changed, 15 insertions(+) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index 1c3498c22..80680b019 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -169,6 +169,8 @@ def test_waveform_view(qtbot, model, tempdir): v.shrink_vertically() ac(v.boxed.box_pos, bp) + v.zoom_on_channels([0, 2, 4]) + # qtbot.stop() diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 55a858425..42b2f6083 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -412,6 +412,19 @@ def shrink_vertically(self): self.probe_scaling[1] /= self.scaling_coeff self._update_boxes() + # Navigation + # ------------------------------------------------------------------------- + + def zoom_on_channels(self, channels_rel): + """Zoom on some channels.""" + channels_rel = np.asarray(channels_rel, dtype=np.int32) + assert 0 <= channels_rel.min() <= channels_rel.max() < self.n_channels + # Bounds of the channels. + b = self.boxed.box_bounds[channels_rel] + x0, y0 = b[:, :2].min(axis=0) + x1, y1 = b[:, 2:].max(axis=0) + self.panzoom.set_range((x0, y0, x1, y1), keep_aspect=True) + class WaveformViewPlugin(IPlugin): def attach_to_gui(self, gui, model=None, state=None): From be4e48f17a93d12f1cabc8a4723d18ede8762a3b Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 19 Dec 2015 10:20:25 +0100 Subject: [PATCH 0820/1059] Implement memcache in Context --- phy/io/context.py | 43 +++++++++++++++++++++++++++++++---- phy/io/tests/test_context.py | 44 ++++++++++++++++++++++++++++++++---- 2 files changed, 77 insertions(+), 10 deletions(-) diff --git a/phy/io/context.py b/phy/io/context.py index dbbbef9da..070b2807e 100644 --- a/phy/io/context.py +++ b/phy/io/context.py @@ -6,6 +6,7 @@ # Imports #------------------------------------------------------------------------------ +from functools import wraps import logging import os import os.path as op @@ -140,10 +141,15 @@ def _ensure_cache_dirs_exist(cache_dir, name): os.makedirs(dirpath) +def _fullname(o): + """Return the fully-qualified name of an object.""" + return o.__module__ + "." + o.__class__.__name__ + + class Context(object): """Handle function cacheing and parallel map with ipyparallel.""" - def __init__(self, cache_dir, ipy_view=None): - + def __init__(self, cache_dir, ipy_view=None, verbose=0): + self.verbose = verbose # Make sure the cache directory exists. self.cache_dir = op.realpath(op.expanduser(cache_dir)) if not op.exists(self.cache_dir): @@ -152,6 +158,7 @@ def __init__(self, cache_dir, ipy_view=None): self._set_memory(self.cache_dir) self.ipy_view = ipy_view if ipy_view else None + self._memcache = {} def _set_memory(self, cache_dir): # Try importing joblib. @@ -159,7 +166,7 @@ def _set_memory(self, cache_dir): from joblib import Memory self._memory = Memory(cachedir=self.cache_dir, mmap_mode=None, - verbose=0, + verbose=self.verbose, ) logger.debug("Initialize joblib cache dir at `%s`.", self.cache_dir) @@ -180,12 +187,38 @@ def ipy_view(self, value): # Dill is necessary because we need to serialize closures. value.use_dill() - def cache(self, f): + def cache(self, f=None, memcache=False): """Cache a function using the context's cache directory.""" + if f is None: + return lambda _: self.cache(_, memcache=memcache) if self._memory is None: # pragma: no cover logger.debug("Joblib is not installed: skipping cacheing.") return - return self._memory.cache(f) + disk_cached = self._memory.cache(f) + name = _fullname(f) + if memcache: + from joblib import hash + # Create the cache dictionary for the function. + if name not in self._memcache: + self._memcache[name] = {} + + c = self._memcache[name] + + @wraps(f) + def mem_cached(*args, **kwargs): + """Cache the function in memory.""" + h = hash((args, kwargs)) + if h in c: + logger.debug("Retrieve `%s()` from the cache.", name) + return c[h] + else: + logger.debug("Compute `%s()`.", name) + out = disk_cached(*args, **kwargs) + c[h] = out + return out + return mem_cached + else: + return disk_cached def map_dask_array(self, func, da, *args, **kwargs): """Map a function on the chunks of a dask array, and return a diff --git a/phy/io/tests/test_context.py b/phy/io/tests/test_context.py index 53dcbbd3d..66a90a47a 100644 --- a/phy/io/tests/test_context.py +++ b/phy/io/tests/test_context.py @@ -8,6 +8,7 @@ import os import os.path as op +import shutil import numpy as np from numpy.testing import assert_array_equal as ae @@ -16,6 +17,7 @@ from ..context import (Context, ContextPlugin, Task, _iter_chunks_dask, write_array, read_array, + _fullname, ) from phy.utils import Bunch @@ -24,7 +26,7 @@ # Fixtures #------------------------------------------------------------------------------ -@yield_fixture(scope='module') +@yield_fixture() def ipy_client(): def iptest_stdstreams_fileno(): @@ -46,7 +48,7 @@ def iptest_stdstreams_fileno(): @yield_fixture(scope='function') def context(tempdir): - ctx = Context('{}/cache/'.format(tempdir)) + ctx = Context('{}/cache/'.format(tempdir), verbose=1) yield ctx @@ -75,9 +77,6 @@ def temp_phy_user_dir(tempdir): def test_client_1(ipy_client): assert ipy_client.ids == [0, 1] - - -def test_client_2(ipy_client): assert ipy_client[:].map_sync(lambda x: x * x, [1, 2, 3]) == [1, 4, 9] @@ -126,6 +125,41 @@ def f(x): assert len(_res) == 2 +def test_context_memmap(tempdir, context): + + _res = [] + + @context.cache(memcache=True) + def f(x): + _res.append(x) + return x ** 2 + + # Compute the function a first time. + x = np.arange(10) + ae(f(x), x ** 2) + assert len(_res) == 1 + + # The second time, the memory cache is used. + ae(f(x), x ** 2) + assert len(_res) == 1 + + # We artificially clear the memory cache. + context._memcache[_fullname(f)].clear() + + # This time, the result is loaded from disk. + ae(f(x), x ** 2) + assert len(_res) == 1 + + # Remove the cache directory. + assert context.cache_dir.startswith(tempdir) + shutil.rmtree(context.cache_dir) + context._memcache[_fullname(f)].clear() + + # Now, the result is re-computed. + ae(f(x), x ** 2) + assert len(_res) == 2 + + def test_pickle_cache(tempdir, parallel_context): """Make sure the Context is picklable.""" with open(op.join(tempdir, 'test.pkl'), 'wb') as f: From b0cc164668a5850f7c4238315bdb8bb2a0208664 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 19 Dec 2015 10:55:00 +0100 Subject: [PATCH 0821/1059] Add ClusterStats --- phy/stats/clusters.py | 39 ++++++++++++++++++++++-------- phy/stats/tests/test_clusters.py | 41 +++++++++++++++++++++++--------- 2 files changed, 59 insertions(+), 21 deletions(-) diff --git a/phy/stats/clusters.py b/phy/stats/clusters.py index 9df52acea..22eb260bc 100644 --- a/phy/stats/clusters.py +++ b/phy/stats/clusters.py @@ -17,16 +17,16 @@ def mean(x): return x.mean(axis=0) -def unmasked_channels(mean_masks, min_mask=.1): +def get_unmasked_channels(mean_masks, min_mask=.1): return np.nonzero(mean_masks > min_mask)[0] -def mean_probe_position(mean_masks, site_positions): +def get_mean_probe_position(mean_masks, site_positions): return (np.sum(site_positions * mean_masks[:, np.newaxis], axis=0) / max(1, np.sum(mean_masks))) -def sorted_main_channels(mean_masks, unmasked_channels): +def get_sorted_main_channels(mean_masks, unmasked_channels): # Weighted mean of the channels, weighted by the mean masks. main_channels = np.argsort(mean_masks)[::-1] main_channels = np.array([c for c in main_channels @@ -38,7 +38,7 @@ def sorted_main_channels(mean_masks, unmasked_channels): # Wizard measures #------------------------------------------------------------------------------ -def max_waveform_amplitude(mean_masks, mean_waveforms): +def get_max_waveform_amplitude(mean_masks, mean_waveforms): assert mean_waveforms.ndim == 2 n_samples, n_channels = mean_waveforms.shape @@ -53,12 +53,12 @@ def max_waveform_amplitude(mean_masks, mean_waveforms): return np.max(M - m) -def mean_masked_features_distance(mean_features_0, - mean_features_1, - mean_masks_0, - mean_masks_1, - n_features_per_channel=None, - ): +def get_mean_masked_features_distance(mean_features_0, + mean_features_1, + mean_masks_0, + mean_masks_1, + n_features_per_channel=None, + ): """Compute the distance between the mean masked features.""" assert n_features_per_channel > 0 @@ -76,3 +76,22 @@ def mean_masked_features_distance(mean_features_0, d_1 = mu_1 * omeg_1 return np.linalg.norm(d_0 - d_1) + + +#------------------------------------------------------------------------------ +# Cluster stats object +#------------------------------------------------------------------------------ + +class ClusterStats(object): + def __init__(self, context=None): + self.context = context + self._stats = {} + + def add(self, f, name=None): + if f is None: + return lambda _: self.add(_, name=name) + name = name or f.__name__ + if self.context: + f = self.context.cache(f, memcache=True) + self._stats[name] = f + setattr(self, name, f) diff --git a/phy/stats/tests/test_clusters.py b/phy/stats/tests/test_clusters.py index f9de58370..452fbe1c9 100644 --- a/phy/stats/tests/test_clusters.py +++ b/phy/stats/tests/test_clusters.py @@ -12,17 +12,19 @@ from pytest import yield_fixture from ..clusters import (mean, - unmasked_channels, - mean_probe_position, - sorted_main_channels, - mean_masked_features_distance, - max_waveform_amplitude, + get_unmasked_channels, + get_mean_probe_position, + get_sorted_main_channels, + get_mean_masked_features_distance, + get_max_waveform_amplitude, + ClusterStats, ) from phy.electrode.mea import staggered_positions from phy.io.mock import (artificial_features, artificial_masks, artificial_waveforms, ) +from phy.io.context import Context #------------------------------------------------------------------------------ @@ -86,7 +88,7 @@ def test_unmasked_channels(masks, n_channels): # Compute the mean masks. mean_masks = mean(masks) # Find the unmasked channels. - channels = unmasked_channels(mean_masks, threshold) + channels = get_unmasked_channels(mean_masks, threshold) # These are 0, 2, 4, etc. ae(channels, np.arange(0, n_channels, 2)) @@ -94,7 +96,7 @@ def test_unmasked_channels(masks, n_channels): def test_mean_probe_position(masks, site_positions): masks[:, ::2] *= .05 mean_masks = mean(masks) - mean_pos = mean_probe_position(mean_masks, site_positions) + mean_pos = get_mean_probe_position(mean_masks, site_positions) assert mean_pos.shape == (2,) assert mean_pos[0] < 0 assert mean_pos[1] > 0 @@ -104,7 +106,8 @@ def test_sorted_main_channels(masks): masks *= .05 masks[:, [5, 7]] *= 20 mean_masks = mean(masks) - channels = sorted_main_channels(mean_masks, unmasked_channels(mean_masks)) + channels = get_sorted_main_channels(mean_masks, + get_unmasked_channels(mean_masks)) assert np.all(np.in1d(channels, [5, 7])) @@ -118,7 +121,7 @@ def test_max_waveform_amplitude(masks, waveforms): mean_waveforms = mean(waveforms) mean_masks = mean(masks) - amplitude = max_waveform_amplitude(mean_masks, mean_waveforms) + amplitude = get_max_waveform_amplitude(mean_masks, mean_waveforms) assert amplitude > 0 @@ -138,6 +141,22 @@ def test_mean_masked_features_distance(features, # Check the distance. d_expected = np.sqrt(n_features_per_channel) * shift - d_computed = mean_masked_features_distance(f0, f1, m0, m1, - n_features_per_channel) + d_computed = get_mean_masked_features_distance(f0, f1, m0, m1, + n_features_per_channel) ac(d_expected, d_computed) + + +#------------------------------------------------------------------------------ +# Test ClusterStats +#------------------------------------------------------------------------------ + +def test_cluster_stats(tempdir): + context = Context(tempdir) + cs = ClusterStats(context=context) + + @cs.add + def f(x): + return x * x + + assert cs.f(3) == 9 + assert cs.f(3) == 9 From 471c286798dda06e64cd43faa47ae8deb416c152 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 19 Dec 2015 11:14:20 +0100 Subject: [PATCH 0822/1059] WIP: integrate cluster stats in manual clustering component --- phy/cluster/manual/__init__.py | 2 +- phy/cluster/manual/gui_component.py | 146 ++++++++---------- phy/cluster/manual/tests/conftest.py | 41 +++++ .../manual/tests/test_gui_component.py | 69 +++------ phy/cluster/manual/tests/test_views.py | 44 +----- phy/cluster/manual/views.py | 20 +-- 6 files changed, 143 insertions(+), 179 deletions(-) diff --git a/phy/cluster/manual/__init__.py b/phy/cluster/manual/__init__.py index b9a3246ba..6618d2eba 100644 --- a/phy/cluster/manual/__init__.py +++ b/phy/cluster/manual/__init__.py @@ -5,5 +5,5 @@ from ._utils import ClusterMeta from .clustering import Clustering -from .gui_component import ManualClustering, default_wizard_functions +from .gui_component import ManualClustering, create_cluster_stats from .views import WaveformView, TraceView, FeatureView, CorrelogramView diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 2d6822b86..824dcc6ab 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -16,12 +16,15 @@ from ._utils import create_cluster_meta from .clustering import Clustering from phy.stats.clusters import (mean, - max_waveform_amplitude, - mean_masked_features_distance, + get_max_waveform_amplitude, + get_mean_masked_features_distance, + get_unmasked_channels, + get_sorted_main_channels, + ClusterStats, ) from phy.gui.actions import Actions from phy.gui.widgets import Table -from phy.io.array import select_spikes, Selector +from phy.io.array import Selector from phy.utils import IPlugin logger = logging.getLogger(__name__) @@ -47,64 +50,60 @@ def _process_ups(ups): # pragma: no cover raise NotImplementedError() -def default_wizard_functions(waveforms=None, - features=None, - masks=None, - n_features_per_channel=None, - spikes_per_cluster=None, - max_n_spikes_per_cluster=1000, - ): - spc = spikes_per_cluster - nfc = n_features_per_channel - maxn = max_n_spikes_per_cluster - - def max_waveform_amplitude_quality(cluster): - spike_ids = select_spikes(cluster_ids=[cluster], - max_n_spikes_per_cluster=maxn, - spikes_per_cluster=spc, - ) - m = np.atleast_2d(masks[spike_ids]) - w = np.atleast_3d(waveforms[spike_ids]) - mean_masks = mean(m) - mean_waveforms = mean(w) - q = max_waveform_amplitude(mean_masks, mean_waveforms) - q = np.asscalar(q) - logger.debug("Computed cluster quality for %d: %.3f.", - cluster, q) - return q - - def mean_masked_features_similarity(c0, c1): - s0 = select_spikes(cluster_ids=[c0], - max_n_spikes_per_cluster=maxn, - spikes_per_cluster=spc, - ) - s1 = select_spikes(cluster_ids=[c1], - max_n_spikes_per_cluster=maxn, - spikes_per_cluster=spc, - ) - - f0 = features[s0] - m0 = np.atleast_2d(masks[s0]) - - f1 = features[s1] - m1 = np.atleast_2d(masks[s1]) - - mf0 = mean(f0) - mm0 = mean(m0) - - mf1 = mean(f1) - mm1 = mean(m1) - - d = mean_masked_features_distance(mf0, mf1, mm0, mm1, - n_features_per_channel=nfc, - ) - d = 1. / max(1e-10, d) # From distance to similarity. - logger.log(5, "Computed cluster similarity for (%d, %d): %.3f.", - c0, c1, d) - return d - - return (max_waveform_amplitude_quality, - mean_masked_features_similarity) +# ----------------------------------------------------------------------------- +# Cluster statistics +# ----------------------------------------------------------------------------- + +def create_cluster_stats(model, selector=None, context=None, + max_n_spikes_per_cluster=1000): + cs = ClusterStats(context=context) + ns = max_n_spikes_per_cluster + + def select(cluster_id): + assert cluster_id >= 0 + return selector.select_spikes([cluster_id], + max_n_spikes_per_cluster=ns) + + @cs.add + def mean_masks(cluster_id): + spike_ids = select(cluster_id) + return (mean(model.masks[spike_ids])) + + @cs.add + def mean_features(cluster_id): + spike_ids = select(cluster_id) + return (mean(model.features[spike_ids])) + + @cs.add + def mean_waveforms(cluster_id): + spike_ids = select(cluster_id) + return (mean(model.waveforms[spike_ids])) + + @cs.add + def best_channels(cluster_id): + mm = cs.mean_masks(cluster_id) + uch = get_unmasked_channels(mm) + return get_sorted_main_channels(mm, uch) + + @cs.add + def max_waveform_amplitude(cluster_id): + mm = cs.mean_masks(cluster_id) + mw = cs.mean_waveforms(cluster_id) + return np.asscalar(get_max_waveform_amplitude(mm, mw)) + + @cs.add + def mean_masked_features_score(cluster_0, cluster_1): + mf0 = cs.mean_features(cluster_0) + mf1 = cs.mean_features(cluster_1) + mm0 = cs.mean_masks(cluster_0) + mm1 = cs.mean_masks(cluster_1) + nfpc = model.n_features_per_channel + d = get_mean_masked_features_distance(mf0, mf1, mm0, mm1, + n_features_per_channel=nfpc) + s = 1. / max(1e-10, d) + return s + + return cs # ----------------------------------------------------------------------------- @@ -562,24 +561,13 @@ def attach_to_gui(self, gui, model=None, state=None): ) mc.attach(gui) - spc = mc.clustering.spikes_per_cluster - nfc = model.n_features_per_channel - - q, s = default_wizard_functions(waveforms=model.waveforms, - features=model.features, - masks=model.masks, - n_features_per_channel=nfc, - spikes_per_cluster=spc, - ) - - ctx = getattr(gui, 'context', None) - if ctx: # pragma: no cover - q, s = ctx.cache(q), ctx.cache(s) - else: - logger.warn("Context not available, unable to cache " - "the wizard functions.") + # Create the cluster stats. + cs = create_cluster_stats(model, + selector=mc.selector, + context=getattr(gui, 'context', None)) + mc.cluster_stats = cs # Add the quality column in the cluster view. - mc.cluster_view.add_column(q, name='quality') + mc.cluster_view.add_column(cs.max_waveform_amplitude, name='quality') mc.set_default_sort('quality') - mc.set_similarity_func(s) + mc.set_similarity_func(cs.mean_masked_features_score) diff --git a/phy/cluster/manual/tests/conftest.py b/phy/cluster/manual/tests/conftest.py index 94a618c97..13f90d460 100644 --- a/phy/cluster/manual/tests/conftest.py +++ b/phy/cluster/manual/tests/conftest.py @@ -8,6 +8,17 @@ from pytest import yield_fixture +from phy.electrode.mea import staggered_positions +from phy.io.array import _spikes_per_cluster +from phy.io.mock import (artificial_waveforms, + artificial_features, + artificial_spike_clusters, + artificial_spike_samples, + artificial_masks, + artificial_traces, + ) +from phy.utils import Bunch + #------------------------------------------------------------------------------ # Fixtures @@ -32,3 +43,33 @@ def quality(): @yield_fixture def similarity(): yield lambda c, d: c * 1.01 + d + + +@yield_fixture(scope='session') +def model(): + model = Bunch() + + n_spikes = 51 + n_samples_w = 31 + n_samples_t = 20000 + n_channels = 11 + n_clusters = 3 + n_features = 4 + + model.n_channels = n_channels + model.n_spikes = n_spikes + model.sample_rate = 20000. + model.duration = n_samples_t / float(model.sample_rate) + model.spike_times = artificial_spike_samples(n_spikes) * 1. + model.spike_times /= model.spike_times[-1] + model.spike_clusters = artificial_spike_clusters(n_spikes, n_clusters) + model.channel_positions = staggered_positions(n_channels) + model.waveforms = artificial_waveforms(n_spikes, n_samples_w, n_channels) + model.masks = artificial_masks(n_spikes, n_channels) + model.traces = artificial_traces(n_samples_t, n_channels) + model.features = artificial_features(n_spikes, n_channels, n_features) + model.spikes_per_cluster = _spikes_per_cluster(model.spike_clusters) + model.n_features_per_channel = n_features + model.n_samples_waveforms = n_samples_w + + yield model diff --git a/phy/cluster/manual/tests/test_gui_component.py b/phy/cluster/manual/tests/test_gui_component.py index 41189b09a..e82bae5b6 100644 --- a/phy/cluster/manual/tests/test_gui_component.py +++ b/phy/cluster/manual/tests/test_gui_component.py @@ -14,10 +14,10 @@ from ..gui_component import (ManualClustering, ManualClusteringPlugin, - default_wizard_functions, + create_cluster_stats, ) from phy.gui import GUI -from phy.io.array import _spikes_per_cluster +from phy.io.array import _spikes_per_cluster, Selector from phy.io.mock import (artificial_waveforms, artificial_masks, artificial_features, @@ -61,6 +61,24 @@ def gui(qtbot): qtbot.wait(5) +#------------------------------------------------------------------------------ +# Test cluster stats +#------------------------------------------------------------------------------ + +def test_create_cluster_stats(model): + selector = Selector(spike_clusters=model.spike_clusters, + spikes_per_cluster=model.spikes_per_cluster) + cs = create_cluster_stats(model, selector=selector) + assert cs.mean_masks(1).shape == (model.n_channels,) + assert cs.mean_features(1).shape == (model.n_channels, + model.n_features_per_channel) + assert cs.mean_waveforms(1).shape == (model.n_samples_waveforms, + model.n_channels) + assert 1 <= cs.best_channels(1).shape[0] <= model.n_channels + assert 0 < cs.max_waveform_amplitude(1) < 1 + assert cs.mean_masked_features_score(1, 2) > 0 + + #------------------------------------------------------------------------------ # Test GUI component #------------------------------------------------------------------------------ @@ -109,53 +127,6 @@ def test_manual_clustering_edge_cases(manual_clustering): mc.save() -def test_manual_clustering_default_metrics(qtbot, gui): - - n_spikes = 10 - n_samples = 4 - n_channels = 7 - n_clusters = 3 - npc = 2 - - sc = artificial_spike_clusters(n_spikes, n_clusters) - spc = _spikes_per_cluster(sc) - - waveforms = artificial_waveforms(n_spikes, n_samples, n_channels) - features = artificial_features(n_spikes, n_channels, npc) - masks = artificial_masks(n_spikes, n_channels) - - mc = ManualClustering(sc) - - q, s = default_wizard_functions(waveforms=waveforms, - features=features, - masks=masks, - n_features_per_channel=npc, - spikes_per_cluster=spc, - ) - - @mc.add_column() - def quality(cluster): - return q(cluster) - - mc.set_default_sort('quality', 'desc') - mc.set_similarity_func(s) - - best = sorted([(c, q(c)) for c in spc], key=itemgetter(1))[-1][0] - - similarity = [(d, s(best, d)) for d in spc if d != best] - match = sorted(similarity, key=itemgetter(1))[-1][0] - - mc.attach(gui) - - mc.cluster_view.next() - assert mc.cluster_view.selected == [best] - - mc.similarity_view.next() - - assert mc.similarity_view.selected == [match] - assert mc.selected == [best, match] - - def test_manual_clustering_skip(qtbot, gui, manual_clustering): mc = manual_clustering diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index 80680b019..fc9324d85 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -11,19 +11,11 @@ import numpy as np from numpy.testing import assert_equal as ae from numpy.testing import assert_allclose as ac -from pytest import raises, yield_fixture - -from phy.io.array import _spikes_per_cluster, Selector -from phy.io.mock import (artificial_waveforms, - artificial_features, - artificial_spike_clusters, - artificial_spike_samples, - artificial_masks, - artificial_traces, - ) +from pytest import raises + from phy.gui import create_gui, GUIState -from phy.electrode.mea import staggered_positions -from phy.utils import Bunch +from phy.io.array import Selector +from phy.io.mock import artificial_traces from ..views import TraceView, _extract_wave, _selected_clusters_colors @@ -39,34 +31,6 @@ def _show(qtbot, view, stop=False): view.close() -@yield_fixture(scope='session') -def model(): - model = Bunch() - - n_spikes = 51 - n_samples_w = 31 - n_samples_t = 20000 - n_channels = 11 - n_clusters = 3 - n_features = 4 - - model.n_channels = n_channels - model.n_spikes = n_spikes - model.sample_rate = 20000. - model.duration = n_samples_t / float(model.sample_rate) - model.spike_times = artificial_spike_samples(n_spikes) * 1. - model.spike_times /= model.spike_times[-1] - model.spike_clusters = artificial_spike_clusters(n_spikes, n_clusters) - model.channel_positions = staggered_positions(n_channels) - model.waveforms = artificial_waveforms(n_spikes, n_samples_w, n_channels) - model.masks = artificial_masks(n_spikes, n_channels) - model.traces = artificial_traces(n_samples_t, n_channels) - model.features = artificial_features(n_spikes, n_channels, n_features) - model.spikes_per_cluster = _spikes_per_cluster(model.spike_clusters) - - yield model - - @contextmanager def _test_view(view_name, model=None, tempdir=None): diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 42b2f6083..9dcbc4fd0 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -18,7 +18,6 @@ from phy.plot import View, _get_linear_x from phy.plot.utils import _get_boxes from phy.stats import correlograms -from phy.stats.clusters import mean, unmasked_channels, sorted_main_channels from phy.utils import IPlugin logger = logging.getLogger(__name__) @@ -1086,15 +1085,16 @@ def attach_to_gui(self, gui, model=None, state=None): if fs: view.feature_scaling = fs - @view.set_best_channels_func - def best_channels(cluster_id): - """Select the best channels for a given cluster.""" - # TODO: better perf with cluster stats and cache - spike_ids = model.spikes_per_cluster[cluster_id] - m = model.masks[spike_ids] - mean_masks = mean(m) - uch = unmasked_channels(mean_masks) - return sorted_main_channels(mean_masks, uch) + # TODO + # @view.set_best_channels_func + # def best_channels(cluster_id): + # """Select the best channels for a given cluster.""" + # # TODO: better perf with cluster stats and cache + # spike_ids = model.spikes_per_cluster[cluster_id] + # m = model.masks[spike_ids] + # mean_masks = mean(m) + # uch = unmasked_channels(mean_masks) + # return sorted_main_channels(mean_masks, uch) view.attach(gui) From ea3867255635456a10b0b58d954e2f07f37f4fd5 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 19 Dec 2015 11:23:33 +0100 Subject: [PATCH 0823/1059] Attach best_channels() stats to the feature view in the feature plugin --- phy/cluster/manual/gui_component.py | 3 ++- phy/cluster/manual/tests/conftest.py | 1 + phy/cluster/manual/tests/test_views.py | 12 +++++++----- phy/cluster/manual/views.py | 14 ++++---------- 4 files changed, 14 insertions(+), 16 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 824dcc6ab..22d957af2 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -560,12 +560,13 @@ def attach_to_gui(self, gui, model=None, state=None): cluster_groups=model.cluster_groups, ) mc.attach(gui) + gui.manual_clustering = mc # Create the cluster stats. cs = create_cluster_stats(model, selector=mc.selector, context=getattr(gui, 'context', None)) - mc.cluster_stats = cs + gui.cluster_stats = cs # Add the quality column in the cluster view. mc.cluster_view.add_column(cs.max_waveform_amplitude, name='quality') diff --git a/phy/cluster/manual/tests/conftest.py b/phy/cluster/manual/tests/conftest.py index 13f90d460..a3295bd06 100644 --- a/phy/cluster/manual/tests/conftest.py +++ b/phy/cluster/manual/tests/conftest.py @@ -71,5 +71,6 @@ def model(): model.spikes_per_cluster = _spikes_per_cluster(model.spike_clusters) model.n_features_per_channel = n_features model.n_samples_waveforms = n_samples_w + model.cluster_groups = {c: None for c in range(n_clusters)} yield model diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index fc9324d85..8e1da420b 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -42,8 +42,13 @@ def _test_view(view_name, model=None, tempdir=None): state.set_view_params('CorrelogramView1', uniform_normalization=True) state.save() + # Create the selector. + selector = Selector(spike_clusters=model.spike_clusters, + spikes_per_cluster=model.spikes_per_cluster, + ) + # Create the GUI. - plugins = [view_name + 'Plugin'] + plugins = ['ManualClusteringPlugin', view_name + 'Plugin'] gui = create_gui(model=model, plugins=plugins, config_dir=tempdir) gui.show() @@ -56,10 +61,7 @@ def _test_view(view_name, model=None, tempdir=None): # Select other spikes. cluster_ids = [0, 2] - sel = Selector(spike_clusters=model.spike_clusters, - spikes_per_cluster=model.spikes_per_cluster, - ) - v.on_select(cluster_ids, selector=sel) + v.on_select(cluster_ids, selector=selector) yield v diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 9dcbc4fd0..f93645b58 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -1085,16 +1085,10 @@ def attach_to_gui(self, gui, model=None, state=None): if fs: view.feature_scaling = fs - # TODO - # @view.set_best_channels_func - # def best_channels(cluster_id): - # """Select the best channels for a given cluster.""" - # # TODO: better perf with cluster stats and cache - # spike_ids = model.spikes_per_cluster[cluster_id] - # m = model.masks[spike_ids] - # mean_masks = mean(m) - # uch = unmasked_channels(mean_masks) - # return sorted_main_channels(mean_masks, uch) + # Attach the best_channels() function from the cluster stats. + cs = getattr(gui, 'cluster_stats', None) + if cs: + view.set_best_channels_func(cs.best_channels) view.attach(gui) From 0a55ca072c4c369a2cbccf49448adcb8f9ec2bf5 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 19 Dec 2015 11:24:10 +0100 Subject: [PATCH 0824/1059] Flakify --- phy/cluster/manual/tests/test_gui_component.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/phy/cluster/manual/tests/test_gui_component.py b/phy/cluster/manual/tests/test_gui_component.py index e82bae5b6..7ea2ecf47 100644 --- a/phy/cluster/manual/tests/test_gui_component.py +++ b/phy/cluster/manual/tests/test_gui_component.py @@ -6,8 +6,6 @@ # Imports #------------------------------------------------------------------------------ -from operator import itemgetter - from pytest import yield_fixture import numpy as np from numpy.testing import assert_array_equal as ae @@ -17,12 +15,7 @@ create_cluster_stats, ) from phy.gui import GUI -from phy.io.array import _spikes_per_cluster, Selector -from phy.io.mock import (artificial_waveforms, - artificial_masks, - artificial_features, - artificial_spike_clusters, - ) +from phy.io.array import Selector from phy.utils import Bunch From 20bb25c861b3a438a0e5a88edbf93a8f9751e1a4 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 19 Dec 2015 11:35:56 +0100 Subject: [PATCH 0825/1059] Fix bug with _fullname(function) --- phy/io/context.py | 4 ++-- phy/io/tests/test_context.py | 7 +++++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/phy/io/context.py b/phy/io/context.py index 070b2807e..09fb6f48d 100644 --- a/phy/io/context.py +++ b/phy/io/context.py @@ -142,8 +142,8 @@ def _ensure_cache_dirs_exist(cache_dir, name): def _fullname(o): - """Return the fully-qualified name of an object.""" - return o.__module__ + "." + o.__class__.__name__ + """Return the fully-qualified name of a function.""" + return o.__module__ + "." + o.__name__ class Context(object): diff --git a/phy/io/tests/test_context.py b/phy/io/tests/test_context.py index 66a90a47a..ad899580c 100644 --- a/phy/io/tests/test_context.py +++ b/phy/io/tests/test_context.py @@ -84,6 +84,13 @@ def test_client_1(ipy_client): # Test utils and cache #------------------------------------------------------------------------------ +def test_fullname(): + def myfunction(x): + return x + + assert _fullname(myfunction) == 'phy.io.tests.test_context.myfunction' + + def test_read_write(tempdir): x = np.arange(10) write_array(op.join(tempdir, 'test.npy'), x) From 0183a6cefc7fc84b70bcd89f196c7a58f3787a4e Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 19 Dec 2015 11:40:08 +0100 Subject: [PATCH 0826/1059] Add some asserts --- phy/cluster/manual/gui_component.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 22d957af2..b43341b6f 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -67,17 +67,24 @@ def select(cluster_id): @cs.add def mean_masks(cluster_id): spike_ids = select(cluster_id) - return (mean(model.masks[spike_ids])) + masks = np.atleast_2d(model.masks[spike_ids]) + assert masks.ndim == 2 + return mean(masks) @cs.add def mean_features(cluster_id): spike_ids = select(cluster_id) - return (mean(model.features[spike_ids])) + features = np.atleast_2d(model.features[spike_ids]) + assert features.ndim == 3 + return mean(features) @cs.add def mean_waveforms(cluster_id): spike_ids = select(cluster_id) - return (mean(model.waveforms[spike_ids])) + waveforms = np.atleast_2d(model.waveforms[spike_ids]) + assert waveforms.ndim == 3 + mw = mean(waveforms) + return mw @cs.add def best_channels(cluster_id): @@ -89,6 +96,7 @@ def best_channels(cluster_id): def max_waveform_amplitude(cluster_id): mm = cs.mean_masks(cluster_id) mw = cs.mean_waveforms(cluster_id) + assert mw.ndim == 2 return np.asscalar(get_max_waveform_amplitude(mm, mw)) @cs.add From 76ee5fa07244fa51db30304d2d29ac39a56dae1f Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 19 Dec 2015 11:47:59 +0100 Subject: [PATCH 0827/1059] Fix --- phy/io/tests/test_context.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/phy/io/tests/test_context.py b/phy/io/tests/test_context.py index ad899580c..b16d9b9b8 100644 --- a/phy/io/tests/test_context.py +++ b/phy/io/tests/test_context.py @@ -26,7 +26,7 @@ # Fixtures #------------------------------------------------------------------------------ -@yield_fixture() +@yield_fixture(scope='module') def ipy_client(): def iptest_stdstreams_fileno(): From fbe8e1ffd471e55b83b2c726aa5d5bb52550d2bd Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 19 Dec 2015 13:42:06 +0100 Subject: [PATCH 0828/1059] Close the file in discover_plugins() --- phy/utils/plugin.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/phy/utils/plugin.py b/phy/utils/plugin.py index f931f2660..76d5612b1 100644 --- a/phy/utils/plugin.py +++ b/phy/utils/plugin.py @@ -108,6 +108,8 @@ def discover_plugins(dirs): mod = imp.load_module(modname, file, path, descr) # noqa except Exception as e: # pragma: no cover logger.exception(e) + finally: + file.close() return IPluginRegistry.plugins From f4618c2fa8a06deddfa9d421f04de43aa1b85b0a Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 19 Dec 2015 13:59:43 +0100 Subject: [PATCH 0829/1059] WIP: zoom on best channels in waveform view --- phy/cluster/manual/gui_component.py | 4 +++- phy/cluster/manual/tests/test_views.py | 23 ++++++----------------- phy/cluster/manual/views.py | 15 +++++++++++++++ 3 files changed, 24 insertions(+), 18 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index b43341b6f..074a0c6fe 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -361,7 +361,9 @@ def _update_similarity_view(self): clusters.""" assert self.similarity_func selection = self.cluster_view.selected - cluster_id = self.cluster_view.selected[0] + if not len(selection): + return + cluster_id = selection[0] self._best = cluster_id self.similarity_view.set_rows([c for c in self.clustering.cluster_ids if c not in selection]) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index 8e1da420b..f0f5cbf87 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -42,28 +42,17 @@ def _test_view(view_name, model=None, tempdir=None): state.set_view_params('CorrelogramView1', uniform_normalization=True) state.save() - # Create the selector. - selector = Selector(spike_clusters=model.spike_clusters, - spikes_per_cluster=model.spikes_per_cluster, - ) - # Create the GUI. - plugins = ['ManualClusteringPlugin', view_name + 'Plugin'] + plugins = ['ManualClusteringPlugin', + view_name + 'Plugin'] gui = create_gui(model=model, plugins=plugins, config_dir=tempdir) gui.show() - v = gui.list_views(view_name)[0] - - # Select some spikes. - spike_ids = np.arange(10) - cluster_ids = np.unique(model.spike_clusters[spike_ids]) - v.on_select(cluster_ids, spike_ids=spike_ids) - - # Select other spikes. - cluster_ids = [0, 2] - v.on_select(cluster_ids, selector=selector) + gui.manual_clustering.select([]) + gui.manual_clustering.select([0]) + gui.manual_clustering.select([0, 2]) - yield v + yield gui.list_views(view_name)[0] gui.close() diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index f93645b58..3beec6887 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -342,6 +342,19 @@ def on_select(self, cluster_ids=None, **kwargs): def attach(self, gui): """Attach the view to the GUI.""" super(WaveformView, self).attach(gui) + + # Zoom on the best channels when selecting clusters. + cs = getattr(gui, 'cluster_stats', None) + if cs: + @gui.connect_ + def on_select(cluster_ids): + best_channels = set() + for cluster_id in cluster_ids: + channels = cs.best_channels(cluster_id) + for channel in channels: + best_channels.add(channel) + self.zoom_on_channels(list(best_channels)) + self.actions.add(self.toggle_waveform_overlap) # Box scaling. @@ -416,6 +429,8 @@ def shrink_vertically(self): def zoom_on_channels(self, channels_rel): """Zoom on some channels.""" + if channels_rel is None or not len(channels_rel): + return channels_rel = np.asarray(channels_rel, dtype=np.int32) assert 0 <= channels_rel.min() <= channels_rel.max() < self.n_channels # Bounds of the channels. From 275452612dc4d7ec7e1325698aaf786606d0e1a1 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 19 Dec 2015 14:22:32 +0100 Subject: [PATCH 0830/1059] Fixes --- phy/cluster/manual/gui_component.py | 9 +++++++++ phy/cluster/manual/tests/test_views.py | 7 ------- phy/cluster/manual/views.py | 19 ++++++++----------- phy/io/context.py | 4 ++-- phy/stats/clusters.py | 2 +- 5 files changed, 20 insertions(+), 21 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 074a0c6fe..89f53c48b 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -92,6 +92,15 @@ def best_channels(cluster_id): uch = get_unmasked_channels(mm) return get_sorted_main_channels(mm, uch) + @cs.add + def best_channels_multiple(cluster_ids): + best_channels = [] + for cluster in cluster_ids: + channels = cs.best_channels(cluster) + best_channels.extend([ch for ch in channels + if ch not in best_channels]) + return best_channels + @cs.add def max_waveform_amplitude(cluster_id): mm = cs.mean_masks(cluster_id) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index f0f5cbf87..c90f82c2a 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -14,7 +14,6 @@ from pytest import raises from phy.gui import create_gui, GUIState -from phy.io.array import Selector from phy.io.mock import artificial_traces from ..views import TraceView, _extract_wave, _selected_clusters_colors @@ -189,13 +188,7 @@ def test_trace_view_spikes(qtbot, model, tempdir): def test_feature_view(qtbot, model, tempdir): with _test_view('FeatureView', model=model, tempdir=tempdir) as v: - assert v.feature_scaling == .5 - - @v.set_best_channels_func - def best_channels(cluster_id): - return list(range(model.n_channels)) - v.add_attribute('sine', np.sin(np.linspace(-10., 10., model.n_spikes))) v.increase() diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 3beec6887..491f053d7 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -348,12 +348,8 @@ def attach(self, gui): if cs: @gui.connect_ def on_select(cluster_ids): - best_channels = set() - for cluster_id in cluster_ids: - channels = cs.best_channels(cluster_id) - for channel in channels: - best_channels.add(channel) - self.zoom_on_channels(list(best_channels)) + best_channels = cs.best_channels_multiple(cluster_ids) + self.zoom_on_channels(best_channels) self.actions.add(self.toggle_waveform_overlap) @@ -825,14 +821,15 @@ def _dimensions_for_clusters(cluster_ids, n_cols=None, if not n: return {}, {} best_channels_func = best_channels_func or (lambda _: range(n_cols)) - x_channels = best_channels_func(cluster_ids[min(1, n - 1)]) - y_channels = best_channels_func(cluster_ids[0]) + x_channels = best_channels_func([cluster_ids[min(1, n - 1)]]) + y_channels = best_channels_func([cluster_ids[0]]) y_channels = y_channels[:n_cols - 1] # For the x axis, remove the channels that already are in # the y axis. x_channels = [c for c in x_channels if c not in y_channels] # Now, select the right number of channels in the x axis. x_channels = x_channels[:n_cols - 1] + # TODO: improve the choice of the channels here. if len(x_channels) < n_cols - 1: x_channels = y_channels # pragma: no cover return _dimensions_matrix(x_channels, y_channels) @@ -1096,6 +1093,8 @@ def attach_to_gui(self, gui, model=None, state=None): spike_clusters=model.spike_clusters, spike_times=model.spike_times, ) + view.attach(gui) + fs, = state.get_view_params('FeatureView', 'feature_scaling') if fs: view.feature_scaling = fs @@ -1103,9 +1102,7 @@ def attach_to_gui(self, gui, model=None, state=None): # Attach the best_channels() function from the cluster stats. cs = getattr(gui, 'cluster_stats', None) if cs: - view.set_best_channels_func(cs.best_channels) - - view.attach(gui) + view.set_best_channels_func(cs.best_channels_multiple) @gui.connect_ def on_close(): diff --git a/phy/io/context.py b/phy/io/context.py index 09fb6f48d..748c5b386 100644 --- a/phy/io/context.py +++ b/phy/io/context.py @@ -209,10 +209,10 @@ def mem_cached(*args, **kwargs): """Cache the function in memory.""" h = hash((args, kwargs)) if h in c: - logger.debug("Retrieve `%s()` from the cache.", name) + # Retrieve the value from the memcache. return c[h] else: - logger.debug("Compute `%s()`.", name) + # Call and cache the function. out = disk_cached(*args, **kwargs) c[h] = out return out diff --git a/phy/stats/clusters.py b/phy/stats/clusters.py index 22eb260bc..60f2a12fd 100644 --- a/phy/stats/clusters.py +++ b/phy/stats/clusters.py @@ -17,7 +17,7 @@ def mean(x): return x.mean(axis=0) -def get_unmasked_channels(mean_masks, min_mask=.1): +def get_unmasked_channels(mean_masks, min_mask=.25): return np.nonzero(mean_masks > min_mask)[0] From 83b1c0416f9c29a10a47dca4578c55611c703183 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 19 Dec 2015 14:34:50 +0100 Subject: [PATCH 0831/1059] Fixes --- phy/cluster/manual/views.py | 2 +- phy/utils/event.py | 9 +-------- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 491f053d7..2c352bbfc 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -347,7 +347,7 @@ def attach(self, gui): cs = getattr(gui, 'cluster_stats', None) if cs: @gui.connect_ - def on_select(cluster_ids): + def on_select(cluster_ids=None, selector=None, spike_ids=None): best_channels = cs.best_channels_multiple(cluster_ids) self.zoom_on_channels(best_channels) diff --git a/phy/utils/event.py b/phy/utils/event.py index 05f492b10..b04ca2a6a 100644 --- a/phy/utils/event.py +++ b/phy/utils/event.py @@ -11,7 +11,6 @@ import re from collections import defaultdict from functools import partial -from inspect import getargspec #------------------------------------------------------------------------------ @@ -112,19 +111,13 @@ def emit(self, event, *args, **kwargs): """Call all callback functions registered with an event. Any positional and keyword arguments can be passed here, and they will - be fowarded to the callback functions. + be forwarded to the callback functions. Return the list of callback return results. """ res = [] for callback in self._callbacks.get(event, []): - argspec = getargspec(callback) - if not argspec.keywords: - # Only keep the kwargs that are part of the callback's - # arg spec, unless the callback accepts `**kwargs`. - kwargs = {n: v for n, v in kwargs.items() - if n in argspec.args} res.append(callback(*args, **kwargs)) return res From d30c0bd3122eda77c40b5ab1abdb52b3ba881389 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 19 Dec 2015 14:46:20 +0100 Subject: [PATCH 0832/1059] Improve cluster stats cache --- phy/cluster/manual/gui_component.py | 6 +++--- phy/stats/clusters.py | 19 +++++++++++++++---- phy/stats/tests/test_clusters.py | 2 +- 3 files changed, 19 insertions(+), 8 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 89f53c48b..fe9977e1e 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -86,7 +86,7 @@ def mean_waveforms(cluster_id): mw = mean(waveforms) return mw - @cs.add + @cs.add(cache='memory') def best_channels(cluster_id): mm = cs.mean_masks(cluster_id) uch = get_unmasked_channels(mm) @@ -101,14 +101,14 @@ def best_channels_multiple(cluster_ids): if ch not in best_channels]) return best_channels - @cs.add + @cs.add(cache='memory') def max_waveform_amplitude(cluster_id): mm = cs.mean_masks(cluster_id) mw = cs.mean_waveforms(cluster_id) assert mw.ndim == 2 return np.asscalar(get_max_waveform_amplitude(mm, mw)) - @cs.add + @cs.add(cache='memory') def mean_masked_features_score(cluster_0, cluster_1): mf0 = cs.mean_features(cluster_0) mf1 = cs.mean_features(cluster_1) diff --git a/phy/stats/clusters.py b/phy/stats/clusters.py index 60f2a12fd..10de46aaa 100644 --- a/phy/stats/clusters.py +++ b/phy/stats/clusters.py @@ -87,11 +87,22 @@ def __init__(self, context=None): self.context = context self._stats = {} - def add(self, f, name=None): + def add(self, f=None, name=None, cache=None): + """Add a cluster statistic. + + Parameters + ---------- + f : function + name : str + cache : str + Can be `None` (no cache), `disk`, or `memory`. In the latter case + the function will also be cached on disk. + + """ if f is None: - return lambda _: self.add(_, name=name) + return lambda _: self.add(_, name=name, cache=cache) name = name or f.__name__ - if self.context: - f = self.context.cache(f, memcache=True) + if cache and self.context: + f = self.context.cache(f, memcache=(cache == 'memory')) self._stats[name] = f setattr(self, name, f) diff --git a/phy/stats/tests/test_clusters.py b/phy/stats/tests/test_clusters.py index 452fbe1c9..e3a2220a8 100644 --- a/phy/stats/tests/test_clusters.py +++ b/phy/stats/tests/test_clusters.py @@ -154,7 +154,7 @@ def test_cluster_stats(tempdir): context = Context(tempdir) cs = ClusterStats(context=context) - @cs.add + @cs.add(cache='memory') def f(x): return x * x From e7c3cd0681447304ca020b3e8d9f7593cd84f73c Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 19 Dec 2015 15:15:52 +0100 Subject: [PATCH 0833/1059] WIP: refactor best channels in feature view --- phy/cluster/manual/gui_component.py | 1 + phy/cluster/manual/views.py | 52 ++++++++--------------------- 2 files changed, 14 insertions(+), 39 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index fe9977e1e..8954c5a5e 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -106,6 +106,7 @@ def max_waveform_amplitude(cluster_id): mm = cs.mean_masks(cluster_id) mw = cs.mean_waveforms(cluster_id) assert mw.ndim == 2 + logger.debug("Computing the quality of cluster %d.", cluster_id) return np.asscalar(get_max_waveform_amplitude(mm, mw)) @cs.add(cache='memory') diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 2c352bbfc..e9780ea8f 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -784,9 +784,9 @@ def on_close(): # Feature view # ----------------------------------------------------------------------------- -def _dimensions_matrix(x_channels, y_channels): +def _dimensions_matrix(x_channels, y_channels, top_left_attribute=None): """Dimensions matrix.""" - # time, depth time, (x, 0) time, (y, 0) time, (z, 0) + # time, attr time, (x, 0) time, (y, 0) time, (z, 0) # time, (x', 0) (x', 0), (x, 0) (x', 1), (y, 0) (x', 2), (z, 0) # time, (y', 0) (y', 0), (x, 1) (y', 1), (y, 1) (y', 2), (z, 1) # time, (z', 0) (z', 0), (x, 2) (z', 1), (y, 2) (z', 2), (z, 2) @@ -795,9 +795,8 @@ def _dimensions_matrix(x_channels, y_channels): assert len(y_channels) == n y_dim = {} x_dim = {} - # TODO: extra feature like probe depth x_dim[0, 0] = 'time' - y_dim[0, 0] = 'time' + y_dim[0, 0] = top_left_attribute or 'time' # Time in first column and first row. for i in range(1, n + 1): @@ -814,27 +813,6 @@ def _dimensions_matrix(x_channels, y_channels): return x_dim, y_dim -def _dimensions_for_clusters(cluster_ids, n_cols=None, - best_channels_func=None): - """Return the dimension matrix for the selected clusters.""" - n = len(cluster_ids) - if not n: - return {}, {} - best_channels_func = best_channels_func or (lambda _: range(n_cols)) - x_channels = best_channels_func([cluster_ids[min(1, n - 1)]]) - y_channels = best_channels_func([cluster_ids[0]]) - y_channels = y_channels[:n_cols - 1] - # For the x axis, remove the channels that already are in - # the y axis. - x_channels = [c for c in x_channels if c not in y_channels] - # Now, select the right number of channels in the x axis. - x_channels = x_channels[:n_cols - 1] - # TODO: improve the choice of the channels here. - if len(x_channels) < n_cols - 1: - x_channels = y_channels # pragma: no cover - return _dimensions_matrix(x_channels, y_channels) - - def _project_mask_depth(dim, masks, spike_clusters_rel=None, n_clusters=None): """Return the mask and depth vectors for a given dimension.""" n_spikes = masks.shape[0] @@ -907,9 +885,8 @@ def __init__(self, self.attributes = {} self.add_attribute('time', spike_times) - self.best_channels_func = None - def add_attribute(self, name, values): + def add_attribute(self, name, values, top_left=True): """Add an attribute (aka extra feature). The values should be a 1D array with `n_spikes` elements. @@ -920,6 +897,9 @@ def add_attribute(self, name, values): assert values.shape == (self.n_spikes,) lim = values.min(), values.max() self.attributes[name] = (values, lim) + # Register the attribute to use in the top-left subplot. + if top_left: + self.top_left_attribute = name def _get_feature(self, dim, spike_ids=None): f = self.features[spike_ids] @@ -996,9 +976,9 @@ def _plot_features(self, i, j, x_dim, y_dim, x, y, size=ms * np.ones(n_spikes), ) - def set_best_channels_func(self, func): - """Set a function `cluster_id => list of best channels`.""" - self.best_channels_func = func + def _best_channels(self, cluster_ids): + channels = np.arange(min(self.n_channels - 1, self.n_cols - 1)) + return channels, channels def on_select(self, cluster_ids=None, **kwargs): super(FeatureView, self).on_select(cluster_ids=cluster_ids, @@ -1014,10 +994,9 @@ def on_select(self, cluster_ids=None, **kwargs): spike_ids, cluster_ids) - f = self.best_channels_func - x_dim, y_dim = _dimensions_for_clusters(cluster_ids, - n_cols=self.n_cols, - best_channels_func=f) + x_channels, y_channels = self._best_channels(cluster_ids) + x_dim, y_dim = _dimensions_matrix(x_channels, y_channels, + self.top_left_attribute) # Set the status message. n = self.n_cols @@ -1099,11 +1078,6 @@ def attach_to_gui(self, gui, model=None, state=None): if fs: view.feature_scaling = fs - # Attach the best_channels() function from the cluster stats. - cs = getattr(gui, 'cluster_stats', None) - if cs: - view.set_best_channels_func(cs.best_channels_multiple) - @gui.connect_ def on_close(): # Save the box bounds. From d763a19f33a29d807ad15860a5ee2fb43a42a306 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 19 Dec 2015 15:44:43 +0100 Subject: [PATCH 0834/1059] WIP: add register/request mechanism in GUI --- phy/gui/gui.py | 18 ++++++++++++++++++ phy/gui/tests/test_gui.py | 9 +++++++++ 2 files changed, 27 insertions(+) diff --git a/phy/gui/gui.py b/phy/gui/gui.py index 1718ae48b..f649e4dc5 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -142,6 +142,9 @@ def __init__(self, self._set_name(name, subtitle) self._set_pos_size(position, size) + # Registered functions. + self._registered = {} + # Mapping {name: menuBar}. self._menus = {} @@ -203,6 +206,21 @@ def connect_(self, *args, **kwargs): def unconnect_(self, *args, **kwargs): self._event.unconnect(*args, **kwargs) + def register(self, func=None, name=None): + """Register a function for a given name.""" + if func is None: + return lambda _: self.register(func=_, name=name) + name = name or func.__name__ + self._registered[name] = func + + def request(self, name, *args, **kwargs): + """Request the result of a possibly registered function.""" + if name in self._registered: + return self._registered[name](*args, **kwargs) + else: + logger.debug("No registered function for `%s`.", name) + return None + def closeEvent(self, e): """Qt slot when the window is closed.""" if self._closed: diff --git a/phy/gui/tests/test_gui.py b/phy/gui/tests/test_gui.py index d74828a97..1fde6b571 100644 --- a/phy/gui/tests/test_gui.py +++ b/phy/gui/tests/test_gui.py @@ -102,6 +102,15 @@ def on_close_view(view): gui.default_actions.exit() +def test_gui_register(gui): + @gui.register + def hello(msg): + return 'hello ' + msg + + assert gui.request('hello', 'world') == 'hello world' + assert gui.request('unknown') is None + + def test_gui_status_message(gui): assert gui.status_message == '' gui.status_message = ':hello world!' From de1720bc2e679977cf46a4d98ca23eb58291b1d4 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 19 Dec 2015 16:03:42 +0100 Subject: [PATCH 0835/1059] WIP: better channel selection in feature view --- phy/cluster/manual/gui_component.py | 4 + .../manual/tests/test_gui_component.py | 4 +- phy/cluster/manual/tests/test_views.py | 12 ++- phy/cluster/manual/views.py | 82 +++++++++++++------ 4 files changed, 76 insertions(+), 26 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 8954c5a5e..b7f8b58e5 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -588,6 +588,10 @@ def attach_to_gui(self, gui, model=None, state=None): context=getattr(gui, 'context', None)) gui.cluster_stats = cs + @gui.register + def best_channels(cluster_ids): + return cs.best_channels_multiple(cluster_ids) + # Add the quality column in the cluster view. mc.cluster_view.add_column(cs.max_waveform_amplitude, name='quality') mc.set_default_sort('quality') diff --git a/phy/cluster/manual/tests/test_gui_component.py b/phy/cluster/manual/tests/test_gui_component.py index 7ea2ecf47..698e8c841 100644 --- a/phy/cluster/manual/tests/test_gui_component.py +++ b/phy/cluster/manual/tests/test_gui_component.py @@ -82,11 +82,13 @@ def test_manual_clustering_plugin(qtbot, gui): n_features_per_channel=2, waveforms=np.zeros((3, 4, 1)), features=np.zeros((3, 1, 2)), - masks=np.zeros((3, 1)), + masks=.75 * np.ones((3, 1)), ) state = Bunch() ManualClusteringPlugin().attach_to_gui(gui, model=model, state=state) + assert gui.request('best_channels', [0, 1]) == [0] + def test_manual_clustering_edge_cases(manual_clustering): mc = manual_clustering diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index c90f82c2a..96fd7aa1f 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -15,7 +15,8 @@ from phy.gui import create_gui, GUIState from phy.io.mock import artificial_traces -from ..views import TraceView, _extract_wave, _selected_clusters_colors +from ..views import (TraceView, _extract_wave, _selected_clusters_colors, + _extend) #------------------------------------------------------------------------------ @@ -60,6 +61,15 @@ def _test_view(view_name, model=None, tempdir=None): # Test utils #------------------------------------------------------------------------------ +def test_extend(): + l = list(range(5)) + assert _extend(l) == l + assert _extend(l, 0) == [] + assert _extend(l, 4) == list(range(4)) + assert _extend(l, 5) == l + assert _extend(l, 6) == (l + [4]) + + def test_extract_wave(): traces = np.arange(30).reshape((6, 5)) mask = np.array([0, 1, 1, .5, 0]) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index e9780ea8f..957c18174 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -136,6 +136,17 @@ def _get_color(masks, spike_clusters_rel=None, n_clusters=None): return color +def _extend(channels, n=None): + channels = list(channels) + if n is None: + return channels + if len(channels) < n: + channels.extend([channels[-1]] * (n - len(channels))) + channels = channels[:n] + assert len(channels) == n + return channels + + # ----------------------------------------------------------------------------- # Manual clustering view # ----------------------------------------------------------------------------- @@ -160,6 +171,9 @@ def __init__(self, shortcuts=None, **kwargs): # Message to show in the status bar. self.status = None + # Attached GUI. + self.gui = None + # Keep track of the selected clusters and spikes. self.cluster_ids = None self.spike_ids = None @@ -180,6 +194,21 @@ def on_select(self, cluster_ids=None, selector=None, spike_ids=None): self.cluster_ids = list(cluster_ids) if cluster_ids is not None else [] self.spike_ids = np.asarray(spike_ids if spike_ids is not None else []) + def _best_channels(self, cluster_ids, n_channels_requested=None): + """Request best channels for a set of clusters.""" + # Number of channels to find on each axis. + n = n_channels_requested or self.n_channels + # Request the best channels to the GUI. + channels = (self.gui.request('best_channels', cluster_ids) + if self.gui else None) + # By default, select the first channels. + if channels is None or not len(channels): + return + assert len(channels) + # Repeat some channels if there aren't enough. + channels = _extend(channels, n) + return channels + def attach(self, gui): """Attach the view to the GUI.""" @@ -188,6 +217,7 @@ def attach(self, gui): self.panzoom.enable_keyboard_pan = False gui.add_view(self) + self.gui = gui gui.connect_(self.on_select) self.actions = Actions(gui, name=self.__class__.__name__, @@ -339,18 +369,14 @@ def on_select(self, cluster_ids=None, **kwargs): data_bounds=self.data_bounds, ) + # Zoom on the best channels when selecting clusters. + channels = self._best_channels(cluster_ids) + if channels is not None: + self.zoom_on_channels(channels) + def attach(self, gui): """Attach the view to the GUI.""" super(WaveformView, self).attach(gui) - - # Zoom on the best channels when selecting clusters. - cs = getattr(gui, 'cluster_stats', None) - if cs: - @gui.connect_ - def on_select(cluster_ids=None, selector=None, spike_ids=None): - best_channels = cs.best_channels_multiple(cluster_ids) - self.zoom_on_channels(best_channels) - self.actions.add(self.toggle_waveform_overlap) # Box scaling. @@ -784,29 +810,32 @@ def on_close(): # Feature view # ----------------------------------------------------------------------------- -def _dimensions_matrix(x_channels, y_channels, top_left_attribute=None): +def _dimensions_matrix(x_channels, y_channels, n_cols=None, + top_left_attribute=None): """Dimensions matrix.""" # time, attr time, (x, 0) time, (y, 0) time, (z, 0) # time, (x', 0) (x', 0), (x, 0) (x', 1), (y, 0) (x', 2), (z, 0) # time, (y', 0) (y', 0), (x, 1) (y', 1), (y, 1) (y', 2), (z, 1) # time, (z', 0) (z', 0), (x, 2) (z', 1), (y, 2) (z', 2), (z, 2) - n = len(x_channels) - assert len(y_channels) == n + assert n_cols > 0 + assert len(x_channels) >= n_cols - 1 + assert len(y_channels) >= n_cols - 1 + y_dim = {} x_dim = {} x_dim[0, 0] = 'time' y_dim[0, 0] = top_left_attribute or 'time' # Time in first column and first row. - for i in range(1, n + 1): + for i in range(1, n_cols): x_dim[0, i] = 'time' y_dim[0, i] = (x_channels[i - 1], 0) x_dim[i, 0] = 'time' y_dim[i, 0] = (y_channels[i - 1], 0) - for i in range(1, n + 1): - for j in range(1, n + 1): + for i in range(1, n_cols): + for j in range(1, n_cols): x_dim[i, j] = (x_channels[i - 1], j - 1) y_dim[i, j] = (y_channels[j - 1], i - 1) @@ -976,10 +1005,6 @@ def _plot_features(self, i, j, x_dim, y_dim, x, y, size=ms * np.ones(n_spikes), ) - def _best_channels(self, cluster_ids): - channels = np.arange(min(self.n_channels - 1, self.n_cols - 1)) - return channels, channels - def on_select(self, cluster_ids=None, **kwargs): super(FeatureView, self).on_select(cluster_ids=cluster_ids, **kwargs) @@ -994,14 +1019,23 @@ def on_select(self, cluster_ids=None, **kwargs): spike_ids, cluster_ids) - x_channels, y_channels = self._best_channels(cluster_ids) + # Select the channels to show. + n = self.n_cols - 1 + channels = self._best_channels(cluster_ids, 2 * n) + channels = (channels if channels is not None + else list(range(self.n_channels))) + channels = _extend(channels, 2 * n) + assert len(channels) == 2 * n + x_channels, y_channels = channels[:n], channels[n:] + # Select the dimensions. + tla = self.top_left_attribute x_dim, y_dim = _dimensions_matrix(x_channels, y_channels, - self.top_left_attribute) + n_cols=self.n_cols, + top_left_attribute=tla) # Set the status message. - n = self.n_cols - ch_i = ', '.join(map(str, (y_dim[0, i] for i in range(1, n)))) - ch_j = ', '.join(map(str, (y_dim[i, 0] for i in range(1, n)))) + ch_i = ', '.join(map(str, x_channels)) + ch_j = ', '.join(map(str, y_channels)) self.set_status('Channels: {} - {}'.format(ch_i, ch_j)) # Set a non-time attribute as y coordinate in the top-left subplot. From 20762a5406af6c13845743bd756335ae7f582dc0 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 19 Dec 2015 17:02:11 +0100 Subject: [PATCH 0836/1059] Minor fixes --- phy/cluster/manual/views.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 957c18174..5e82e67eb 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -532,7 +532,7 @@ def __init__(self, # Compute the mean traces in order to detrend the traces. k = max(1, self.n_samples // self.n_samples_for_mean) - self.mean_traces = np.mean(traces[::k, :], axis=0).astype(traces.dtype) + self.mean_traces = np.mean(traces[::k], axis=0).astype(traces.dtype) # Number of samples per spike. self.n_samples_per_spike = (n_samples_per_spike or @@ -582,7 +582,7 @@ def _load_traces(self, interval): i, j = round(self.sample_rate * start), round(self.sample_rate * end) i, j = int(i), int(j) - traces = self.traces[i:j, :] + traces = self.traces[i:j] # Detrend the traces. traces -= self.mean_traces From 70173ac0219d34ae7b352b03c46c467b246bba57 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 19 Dec 2015 17:20:59 +0100 Subject: [PATCH 0837/1059] WIP: picking in boxed view --- phy/plot/interact.py | 7 +++++++ phy/plot/panzoom.py | 10 ++++++++++ phy/plot/tests/test_interact.py | 3 +++ phy/plot/tests/test_panzoom.py | 7 +++++++ 4 files changed, 27 insertions(+) diff --git a/phy/plot/interact.py b/phy/plot/interact.py index 7f553f25d..e7e0a43cd 100644 --- a/phy/plot/interact.py +++ b/phy/plot/interact.py @@ -218,6 +218,13 @@ def box_size(self, val): self.box_bounds = _get_boxes(self.box_pos, size=val, keep_aspect_ratio=self.keep_aspect_ratio) + def get_closest_box(self, pos): + """Get the box closest to some position.""" + pos = np.atleast_2d(pos) + d = np.sum((np.array(self.box_pos) - pos) ** 2, axis=1) + idx = np.argmin(d) + return idx + def update_boxes(self, box_pos, box_size): """Set the box bounds from specified box positions and sizes.""" assert box_pos.shape == (self.n_boxes, 2) diff --git a/phy/plot/panzoom.py b/phy/plot/panzoom.py index da88afaf2..a63fca7f5 100644 --- a/phy/plot/panzoom.py +++ b/phy/plot/panzoom.py @@ -212,6 +212,14 @@ def _constrain_zoom(self): self._zoom[1] = max(self._zoom[1], 1. / (self.ymax - self._pan[1])) + def get_mouse_pos(self, pos): + """Return the mouse coordinates in NDC, taking panzoom into account.""" + position = np.asarray(self._normalize(pos)) + zoom = np.asarray(self._zoom_aspect()) + pan = np.asarray(self.pan) + mouse_pos = ((position / zoom) - pan) + return mouse_pos + # Pan and zoom # ------------------------------------------------------------------------- @@ -412,6 +420,8 @@ def on_key_press(self, event): def size(self): if self.canvas: return self.canvas.size + else: + return (1, 1) def attach(self, canvas): """Attach this interact to a canvas.""" diff --git a/phy/plot/tests/test_interact.py b/phy/plot/tests/test_interact.py index 5cd1c5dc7..b2d60ccd9 100644 --- a/phy/plot/tests/test_interact.py +++ b/phy/plot/tests/test_interact.py @@ -139,6 +139,9 @@ def test_boxed_2(qtbot, canvas): boxed.box_pos *= .25 boxed.box_size = [1, .1] + idx = boxed.get_closest_box((.5, .25)) + assert idx == 4 + # qtbot.stop() diff --git a/phy/plot/tests/test_panzoom.py b/phy/plot/tests/test_panzoom.py index a0bb5c838..8d2189c7b 100644 --- a/phy/plot/tests/test_panzoom.py +++ b/phy/plot/tests/test_panzoom.py @@ -180,6 +180,13 @@ def _test_range(*bounds): ac(pz.zoom, (1, 1)) +def test_panzoom_mouse_pos(): + pz = PanZoom() + pz.zoom_delta((10, 10), (.5, .25)) + pos = pz.get_mouse_pos((.01, -.01)) + ac(pos, (.5, .25), atol=1e-3) + + #------------------------------------------------------------------------------ # Test panzoom on canvas #------------------------------------------------------------------------------ From fdbae0eef124a46ffe88755ca5dacb668bc30c42 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 19 Dec 2015 17:55:05 +0100 Subject: [PATCH 0838/1059] Implement channel selection in the feature view from the waveform view --- phy/cluster/manual/tests/test_views.py | 21 +++++- phy/cluster/manual/views.py | 99 ++++++++++++++++++++++---- 2 files changed, 107 insertions(+), 13 deletions(-) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index 96fd7aa1f..d5b05a508 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -12,6 +12,7 @@ from numpy.testing import assert_equal as ae from numpy.testing import assert_allclose as ac from pytest import raises +from vispy.util import keys from phy.gui import create_gui, GUIState from phy.io.mock import artificial_traces @@ -52,7 +53,9 @@ def _test_view(view_name, model=None, tempdir=None): gui.manual_clustering.select([0]) gui.manual_clustering.select([0, 2]) - yield gui.list_views(view_name)[0] + view = gui.list_views(view_name)[0] + view.gui = gui + yield view gui.close() @@ -135,6 +138,19 @@ def test_waveform_view(qtbot, model, tempdir): v.zoom_on_channels([0, 2, 4]) + # Simulate channel selection. + _clicked = [] + + @v.gui.connect_ + def on_channel_click(channel_idx=None, button=None, key=None): + _clicked.append((channel_idx, button, key)) + + v.events.key_press(key=keys.Key('2')) + v.events.mouse_press(pos=(0., 0.), button=1) + v.events.key_release(key=keys.Key('2')) + + assert _clicked == [(0, 1, 2)] + # qtbot.stop() @@ -204,6 +220,9 @@ def test_feature_view(qtbot, model, tempdir): v.increase() v.decrease() + v.on_channel_click(channel_idx=3, button=1, key=2) + v.clear_channels() + # qtbot.stop() diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 5e82e67eb..aebd7df02 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -248,6 +248,14 @@ def on_mouse_move(self, e): # pragma: no cover # Waveform view # ----------------------------------------------------------------------------- +class ChannelClick(Event): + def __init__(self, type, channel_idx=None, key=None, button=None): + super(ChannelClick, self).__init__(type) + self.channel_idx = channel_idx + self.key = key + self.button = button + + class WaveformView(ManualClusteringView): max_n_spikes_per_cluster = 100 normalization_percentile = .95 @@ -285,12 +293,16 @@ def __init__(self, in channel_positions. """ + self._key_pressed = None + # Initialize the view. box_bounds = _get_boxes(channel_positions) super(WaveformView, self).__init__(layout='boxed', box_bounds=box_bounds, **kwargs) + self.events.add(channel_click=ChannelClick) + # Box and probe scaling. self.box_scaling = np.array(box_scaling if box_scaling is not None else (1., 1.)) @@ -391,6 +403,15 @@ def attach(self, gui): self.actions.add(self.extend_vertically) self.actions.add(self.shrink_vertically) + # We forward the event from VisPy to the phy GUI. + @self.connect + def on_channel_click(e): + gui.emit('channel_click', + channel_idx=e.channel_idx, + key=e.key, + button=e.button, + ) + def toggle_waveform_overlap(self): """Toggle the overlap of the waveforms.""" self.overlap = not self.overlap @@ -461,6 +482,25 @@ def zoom_on_channels(self, channels_rel): x1, y1 = b[:, 2:].max(axis=0) self.panzoom.set_range((x0, y0, x1, y1), keep_aspect=True) + def on_key_press(self, event): + """Handle key press events.""" + key = event.key + self._key_pressed = key + + def on_mouse_press(self, e): + key = self._key_pressed + if 'Control' in e.modifiers or key in map(str, range(10)): + key = int(key.name) if key in map(str, range(10)) else None + # Get mouse position in NDC. + mouse_pos = self.panzoom.get_mouse_pos(e.pos) + channel_idx = self.boxed.get_closest_box(mouse_pos) + self.events.channel_click(channel_idx=channel_idx, + key=key, + button=e.button) + + def on_key_release(self, event): + self._key_pressed = None + class WaveformViewPlugin(IPlugin): def attach_to_gui(self, gui, model=None, state=None): @@ -908,11 +948,14 @@ def __init__(self, # Spike times. assert spike_times.shape == (self.n_spikes,) + # Channels to show. + self.x_channels = None + self.y_channels = None + # Attributes: extra features. This is a dictionary # {name: (array, data_bounds)} # where each array is a `(n_spikes,)` array. self.attributes = {} - self.add_attribute('time', spike_times) def add_attribute(self, name, values, top_left=True): @@ -1005,6 +1048,21 @@ def _plot_features(self, i, j, x_dim, y_dim, x, y, size=ms * np.ones(n_spikes), ) + def _get_channel_dims(self, cluster_ids): + """Select the channels to show by default.""" + n = self.n_cols - 1 + channels = self._best_channels(cluster_ids, 2 * n) + channels = (channels if channels is not None + else list(range(self.n_channels))) + channels = _extend(channels, 2 * n) + assert len(channels) == 2 * n + return channels[:n], channels[n:] + + def clear_channels(self): + """Reset the dimensions.""" + self.x_channels = self.y_channels = None + self.on_select() + def on_select(self, cluster_ids=None, **kwargs): super(FeatureView, self).on_select(cluster_ids=cluster_ids, **kwargs) @@ -1019,23 +1077,21 @@ def on_select(self, cluster_ids=None, **kwargs): spike_ids, cluster_ids) - # Select the channels to show. - n = self.n_cols - 1 - channels = self._best_channels(cluster_ids, 2 * n) - channels = (channels if channels is not None - else list(range(self.n_channels))) - channels = _extend(channels, 2 * n) - assert len(channels) == 2 * n - x_channels, y_channels = channels[:n], channels[n:] # Select the dimensions. + # TODO: toggle automatic selection of the channels + x_ch, y_ch = self._get_channel_dims(cluster_ids) + if self.x_channels is None: + self.x_channels = x_ch + if self.y_channels is None: + self.y_channels = y_ch tla = self.top_left_attribute - x_dim, y_dim = _dimensions_matrix(x_channels, y_channels, + x_dim, y_dim = _dimensions_matrix(self.x_channels, self.y_channels, n_cols=self.n_cols, top_left_attribute=tla) # Set the status message. - ch_i = ', '.join(map(str, x_channels)) - ch_j = ', '.join(map(str, y_channels)) + ch_i = ', '.join(map(str, self.x_channels)) + ch_j = ', '.join(map(str, self.y_channels)) self.set_status('Channels: {} - {}'.format(ch_i, ch_j)) # Set a non-time attribute as y coordinate in the top-left subplot. @@ -1078,6 +1134,25 @@ def attach(self, gui): super(FeatureView, self).attach(gui) self.actions.add(self.increase) self.actions.add(self.decrease) + self.actions.add(self.clear_channels) + + gui.connect_(self.on_channel_click) + + def on_channel_click(self, channel_idx=None, key=None, button=None): + """Respond to the click on a channel.""" + if key is None or not (1 <= key <= (self.n_cols - 1)): + return + # Get the axis from the pressed button (1, 2, etc.) + axis = 'x' if button == 1 else 'y' + # Get the existing channels. + channels = self.x_channels if axis == 'x' else self.y_channels + if channels is None: + return + assert len(channels) == self.n_cols - 1 + assert 0 <= channel_idx < self.n_channels + # Update the channel. + channels[key - 1] = channel_idx + self.on_select() def increase(self): """Increase the scaling of the features.""" From a4b80a85ceaebf532320cce1805bbb45d2d911bb Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 19 Dec 2015 22:14:27 +0100 Subject: [PATCH 0839/1059] WIP: multitouch support on OS X in PanZoom --- phy/plot/panzoom.py | 40 +++++++++++++++++++++++++++++++++++++- phy/plot/tests/conftest.py | 2 +- 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/phy/plot/panzoom.py b/phy/plot/panzoom.py index a63fca7f5..050429b44 100644 --- a/phy/plot/panzoom.py +++ b/phy/plot/panzoom.py @@ -8,6 +8,7 @@ #------------------------------------------------------------------------------ import math +import sys import numpy as np @@ -57,6 +58,7 @@ def __init__(self, constrain_bounds=None, pan_var_name='u_pan', zoom_var_name='u_zoom', + enable_mouse_wheel=None, ): if constrain_bounds: assert xmin is None @@ -84,6 +86,14 @@ def __init__(self, self._wheel_coeff = self._default_wheel_coeff self.enable_keyboard_pan = True + + # Touch-related variables. + self._pinch = None + self._last_pinch_scale = None + if enable_mouse_wheel is None: + enable_mouse_wheel = sys.platform != 'darwin' + self.enable_mouse_wheel = enable_mouse_wheel + self._zoom_to_pointer = True self._canvas_aspect = np.ones(2) @@ -385,8 +395,33 @@ def on_mouse_move(self, event): c = np.sqrt(self.size[0]) * .03 self.zoom_delta((dx, dy), (x0, y0), c=c) + def on_touch(self, event): + if event.type == 'end': + self._pinch = None + elif event.type == 'pinch': + if event.scale in (1., self._last_pinch_scale): + self._pinch = None + return + self._last_pinch_scale = event.scale + x0, y0 = self._normalize(event.pos) + s = math.log(event.scale / event.last_scale) + c = np.sqrt(self.size[0]) * .05 + self.zoom_delta((s, s), + (x0, y0), + c=c) + self._pinch = True + elif event.type == 'touch': + if self._pinch: + return + x0, y0 = self._normalize(np.array(event.pos).mean(axis=0)) + x1, y1 = self._normalize(np.array(event.last_pos).mean(axis=0)) + dx, dy = x0 - x1, y0 - y1 + c = 5 + self.pan_delta((c * dx, c * dy)) + def on_mouse_wheel(self, event): """Zoom with the mouse wheel.""" + # NOTE: not called on OS X because of touchpad if event.modifiers: return dx = np.sign(event.delta[1]) * self._wheel_coeff @@ -437,9 +472,12 @@ def attach(self, canvas): canvas.connect(self.on_resize) canvas.connect(self.on_mouse_move) - canvas.connect(self.on_mouse_wheel) + canvas.connect(self.on_touch) canvas.connect(self.on_key_press) + if self.enable_mouse_wheel: + canvas.connect(self.on_mouse_wheel) + self._set_canvas_aspect() def update_program(self, program): diff --git a/phy/plot/tests/conftest.py b/phy/plot/tests/conftest.py index a330ac4a2..522a356ef 100644 --- a/phy/plot/tests/conftest.py +++ b/phy/plot/tests/conftest.py @@ -27,5 +27,5 @@ def canvas(qapp): @yield_fixture def canvas_pz(canvas): - PanZoom().attach(canvas) + PanZoom(enable_mouse_wheel=True).attach(canvas) yield canvas From 7a3fe2f65a086d75470dfc62a9f01399088b2b37 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 19 Dec 2015 23:52:51 +0100 Subject: [PATCH 0840/1059] Increase coverage in panzoom with touch --- phy/plot/tests/test_panzoom.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/phy/plot/tests/test_panzoom.py b/phy/plot/tests/test_panzoom.py index 8d2189c7b..62479aa1d 100644 --- a/phy/plot/tests/test_panzoom.py +++ b/phy/plot/tests/test_panzoom.py @@ -213,6 +213,20 @@ def test_panzoom_pan_mouse(qtbot, canvas_pz, panzoom): # qtbot.stop() +def test_panzoom_touch(qtbot, canvas_pz, panzoom): + c = canvas_pz + pz = panzoom + + # Pan with mouse. + c.events.touch(type='pinch', pos=(0, 0), scale=1, last_scale=1) + c.events.touch(type='pinch', pos=(0, 0), scale=2, last_scale=1) + assert pz.zoom[0] >= 2 + c.events.touch(type='end') + + c.events.touch(type='touch', pos=(0.1, 0), last_pos=(0, 0)) + assert pz.pan[0] >= 1 + + def test_panzoom_pan_keyboard(qtbot, canvas_pz, panzoom): c = canvas_pz pz = panzoom From af0847cccbaa9a3f0a8ec3f46ec7a28592c64b53 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 20 Dec 2015 09:30:55 +0100 Subject: [PATCH 0841/1059] WIP: remove read trace functions --- phy/io/__init__.py | 1 - phy/io/tests/test_traces.py | 62 ------------------------ phy/io/traces.py | 94 ------------------------------------- 3 files changed, 157 deletions(-) delete mode 100644 phy/io/tests/test_traces.py delete mode 100644 phy/io/traces.py diff --git a/phy/io/__init__.py b/phy/io/__init__.py index f4fd14812..d29518b80 100644 --- a/phy/io/__init__.py +++ b/phy/io/__init__.py @@ -4,4 +4,3 @@ """Input/output.""" from .context import Context -from .traces import read_dat, read_kwd diff --git a/phy/io/tests/test_traces.py b/phy/io/tests/test_traces.py deleted file mode 100644 index 5ad397534..000000000 --- a/phy/io/tests/test_traces.py +++ /dev/null @@ -1,62 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Tests of read traces functions.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import os.path as op - -import numpy as np -from numpy.testing import assert_array_equal as ae -from pytest import raises - -from ..traces import read_dat, _dat_n_samples, read_kwd, read_ns5 -from ..mock import artificial_traces - - -#------------------------------------------------------------------------------ -# Tests -#------------------------------------------------------------------------------ - -def test_read_dat(tempdir): - n_samples = 100 - n_channels = 10 - - arr = artificial_traces(n_samples, n_channels) - - path = op.join(tempdir, 'test') - arr.tofile(path) - assert _dat_n_samples(path, dtype=np.float64, - n_channels=n_channels) == n_samples - data = read_dat(path, dtype=arr.dtype, shape=arr.shape) - ae(arr, data) - data = read_dat(path, dtype=arr.dtype, n_channels=n_channels) - ae(arr, data) - - -def test_read_kwd(tempdir): - from h5py import File - - n_samples = 100 - n_channels = 10 - arr = artificial_traces(n_samples, n_channels) - path = op.join(tempdir, 'test.kwd') - - with File(path, 'w') as f: - g0 = f.create_group('/recordings/0') - g1 = f.create_group('/recordings/1') - - arr0 = arr[:n_samples // 2, ...] - arr1 = arr[n_samples // 2:, ...] - - g0.create_dataset('data', data=arr0) - g1.create_dataset('data', data=arr1) - - ae(read_kwd(path)[...], arr) - - -def test_read_ns5(): - with raises(NotImplementedError): - read_ns5('') diff --git a/phy/io/traces.py b/phy/io/traces.py deleted file mode 100644 index 875e5f083..000000000 --- a/phy/io/traces.py +++ /dev/null @@ -1,94 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Raw data readers.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import os.path as op - -import numpy as np - - -#------------------------------------------------------------------------------ -# Raw data readers -#------------------------------------------------------------------------------ - -def _read_recording(filename, rec_name): - """Open a file and return a recording dataset. - - WARNING: the file is not closed when the function returns, so that the - memory-mapped array can still be accessed from disk. - - """ - from h5py import File - f = File(filename, mode='r') - return f['/recordings/{}/data'.format(rec_name)] - - -def read_kwd(filename): - """Read all traces in a `.kwd` file.""" - from h5py import File - from dask.array import Array - - with File(filename, mode='r') as f: - rec_names = sorted([name for name in f['/recordings']]) - shapes = [f['/recordings/{}/data'.format(name)].shape - for name in rec_names] - - # Create the dask graph for all recordings from the .kwdd file. - dask = {('data', idx, 0): (_read_recording, filename, rec_name) - for (idx, rec_name) in enumerate(rec_names)} - - # Make sure all recordings have the same number of channels. - n_cols = shapes[0][1] - assert all(shape[1] == n_cols for shape in shapes) - - # Create the dask Array. - chunks = (tuple(shape[0] for shape in shapes), (n_cols,)) - return Array(dask, 'data', chunks) - - -def _dat_n_samples(filename, dtype=None, n_channels=None): - assert dtype is not None - item_size = np.dtype(dtype).itemsize - n_samples = op.getsize(filename) // (item_size * n_channels) - assert n_samples >= 0 - return n_samples - - -def read_dat(filename, dtype=None, shape=None, offset=0, n_channels=None): - """Read traces from a flat binary `.dat` file. - - The output is a memory-mapped file. - - Parameters - ---------- - - filename : str - The path to the `.dat` file. - dtype : dtype - The NumPy dtype. - offset : 0 - The header size. - n_channels : int - The number of channels in the data. - shape : tuple (optional) - The array shape. Typically `(n_samples, n_channels)`. The shape is - automatically computed from the file size if the number of channels - and dtype are specified. - - """ - if shape is None: - assert n_channels > 0 - n_samples = _dat_n_samples(filename, dtype=dtype, - n_channels=n_channels) - shape = (n_samples, n_channels) - return np.memmap(filename, dtype=dtype, shape=shape, - mode='r', offset=offset) - - -def read_ns5(filename): - # TODO - raise NotImplementedError() From c85231de710d7d00aa2f44bbd647cc892f40dcbc Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 20 Dec 2015 10:48:07 +0100 Subject: [PATCH 0842/1059] Bug fixes --- phy/cluster/manual/views.py | 4 +++- phy/io/context.py | 2 +- phy/stats/clusters.py | 2 ++ 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index aebd7df02..e0d019df2 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -572,7 +572,9 @@ def __init__(self, # Compute the mean traces in order to detrend the traces. k = max(1, self.n_samples // self.n_samples_for_mean) - self.mean_traces = np.mean(traces[::k], axis=0).astype(traces.dtype) + # NOTE: only use the first 100000 samples for perf reasons. + self.mean_traces = traces[0:100000:k].mean(axis=0) + self.mean_traces = self.mean_traces.astype(traces.dtype) # Number of samples per spike. self.n_samples_per_spike = (n_samples_per_spike or diff --git a/phy/io/context.py b/phy/io/context.py index 748c5b386..943180008 100644 --- a/phy/io/context.py +++ b/phy/io/context.py @@ -193,7 +193,7 @@ def cache(self, f=None, memcache=False): return lambda _: self.cache(_, memcache=memcache) if self._memory is None: # pragma: no cover logger.debug("Joblib is not installed: skipping cacheing.") - return + return f disk_cached = self._memory.cache(f) name = _fullname(f) if memcache: diff --git a/phy/stats/clusters.py b/phy/stats/clusters.py index 10de46aaa..d7229272c 100644 --- a/phy/stats/clusters.py +++ b/phy/stats/clusters.py @@ -104,5 +104,7 @@ def add(self, f=None, name=None, cache=None): name = name or f.__name__ if cache and self.context: f = self.context.cache(f, memcache=(cache == 'memory')) + assert f self._stats[name] = f setattr(self, name, f) + return f From 1f1910b052f3ae361fb3dfff2442d9b9ceb242db Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 20 Dec 2015 11:42:15 +0100 Subject: [PATCH 0843/1059] WIP: support channel order in trace view --- phy/cluster/manual/gui_component.py | 8 ++++---- phy/cluster/manual/tests/conftest.py | 2 ++ phy/cluster/manual/views.py | 27 ++++++++++++++++++++++----- 3 files changed, 28 insertions(+), 9 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index b7f8b58e5..ad2e3c11c 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -59,10 +59,10 @@ def create_cluster_stats(model, selector=None, context=None, cs = ClusterStats(context=context) ns = max_n_spikes_per_cluster - def select(cluster_id): + def select(cluster_id, n=None): assert cluster_id >= 0 - return selector.select_spikes([cluster_id], - max_n_spikes_per_cluster=ns) + n = n or ns + return selector.select_spikes([cluster_id], max_n_spikes_per_cluster=n) @cs.add def mean_masks(cluster_id): @@ -80,7 +80,7 @@ def mean_features(cluster_id): @cs.add def mean_waveforms(cluster_id): - spike_ids = select(cluster_id) + spike_ids = select(cluster_id, ns // 10) waveforms = np.atleast_2d(model.waveforms[spike_ids]) assert waveforms.ndim == 3 mw = mean(waveforms) diff --git a/phy/cluster/manual/tests/conftest.py b/phy/cluster/manual/tests/conftest.py index a3295bd06..6ba47ef0c 100644 --- a/phy/cluster/manual/tests/conftest.py +++ b/phy/cluster/manual/tests/conftest.py @@ -57,6 +57,8 @@ def model(): n_features = 4 model.n_channels = n_channels + # TODO: test with permutation and dead channels + model.channel_order = None model.n_spikes = n_spikes model.sample_rate = 20000. model.duration = n_samples_t / float(model.sample_rate) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index e0d019df2..1b1ef911d 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -536,7 +536,7 @@ def on_close(): # ----------------------------------------------------------------------------- class TraceView(ManualClusteringView): - n_samples_for_mean = 1000 + n_samples_for_mean = 10000 interval_duration = .5 # default duration of the interval shift_amount = .1 scaling_coeff = 1.1 @@ -555,6 +555,7 @@ def __init__(self, spike_times=None, spike_clusters=None, masks=None, + channel_order=None, n_samples_per_spike=None, scaling=None, **kwargs): @@ -566,14 +567,23 @@ def __init__(self, # Traces. assert len(traces.shape) == 2 - self.n_samples, self.n_channels = traces.shape + self.n_samples, self.n_channels_traces = traces.shape self.traces = traces self.duration = self.dt * self.n_samples + # Channel ordering and dead channels. + # We do traces[..., channel_order] whenever we load traces + # so that the channels match those in masks. + self.n_channels = (self.n_channels_traces if channel_order is None + else len(channel_order)) + self.channel_order = (channel_order if channel_order is not None + else slice(None, None, None)) + # Compute the mean traces in order to detrend the traces. k = max(1, self.n_samples // self.n_samples_for_mean) - # NOTE: only use the first 100000 samples for perf reasons. - self.mean_traces = traces[0:100000:k].mean(axis=0) + # NOTE: the virtual memory mapped traces only works on contiguous + # data so we cannot load ::k here. + self.mean_traces = self.traces[:k, self.channel_order].mean(axis=0) self.mean_traces = self.mean_traces.astype(traces.dtype) # Number of samples per spike. @@ -624,7 +634,11 @@ def _load_traces(self, interval): i, j = round(self.sample_rate * start), round(self.sample_rate * end) i, j = int(i), int(j) - traces = self.traces[i:j] + + # We load the traces and select the requested channels. + assert self.traces.shape[1] == self.n_channels_traces + traces = self.traces[i:j, self.channel_order] + assert traces.shape[1] == self.n_channels # Detrend the traces. traces -= self.mean_traces @@ -714,6 +728,8 @@ def set_interval(self, interval, change_status=True): # Load traces. traces = self._load_traces(interval) + # NOTE: once loaded, the traces do not contain the dead channels + # so there are `n_channels_order` channels here. assert traces.shape[1] == self.n_channels # Set the status message. @@ -838,6 +854,7 @@ def attach_to_gui(self, gui, model=None, state=None): spike_times=model.spike_times, spike_clusters=model.spike_clusters, masks=model.masks, + channel_order=model.channel_order, scaling=s, ) view.attach(gui) From 705346ac749a12544b99a2261e38f1fb93f9367a Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 20 Dec 2015 15:23:00 +0100 Subject: [PATCH 0844/1059] WIP: cluster store --- phy/cluster/manual/store.py | 137 +++++++++++++++++++++++++ phy/cluster/manual/tests/test_store.py | 40 ++++++++ phy/stats/clusters.py | 32 ------ phy/stats/tests/test_clusters.py | 17 --- 4 files changed, 177 insertions(+), 49 deletions(-) create mode 100644 phy/cluster/manual/store.py create mode 100644 phy/cluster/manual/tests/test_store.py diff --git a/phy/cluster/manual/store.py b/phy/cluster/manual/store.py new file mode 100644 index 000000000..fbfc38d7b --- /dev/null +++ b/phy/cluster/manual/store.py @@ -0,0 +1,137 @@ +# -*- coding: utf-8 -*- + +"""Manual clustering GUI component.""" + + +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- + +import logging + +import numpy as np + +from phy.stats.clusters import (mean, + get_max_waveform_amplitude, + get_mean_masked_features_distance, + get_unmasked_channels, + get_sorted_main_channels, + ) +from phy.utils import IPlugin + +logger = logging.getLogger(__name__) + + +# ----------------------------------------------------------------------------- +# Cluster statistics +# ----------------------------------------------------------------------------- + +class ClusterStats(object): + def __init__(self, context=None): + self.context = context + self._stats = {} + + def add(self, f=None, name=None, cache=None): + """Add a cluster statistic. + + Parameters + ---------- + f : function + name : str + cache : str + Can be `None` (no cache), `disk`, or `memory`. In the latter case + the function will also be cached on disk. + + """ + if f is None: + return lambda _: self.add(_, name=name, cache=cache) + name = name or f.__name__ + if cache and self.context: + f = self.context.cache(f, memcache=(cache == 'memory')) + assert f + self._stats[name] = f + setattr(self, name, f) + return f + + def attach(self, gui): + gui.register(self, name='cluster_stats') + + +class ClusterStatsPlugin(IPlugin): + def attach_to_gui(self, gui, model=None, state=None): + mc = gui.request('manual_clustering') + if not mc: + return + ctx = gui.request('context') + cs = create_cluster_stats(model, selector=mc.selector, context=ctx) + cs.attach(gui) + + +def create_cluster_stats(model, selector=None, context=None, + max_n_spikes_per_cluster=1000): + cs = ClusterStats(context=context) + ns = max_n_spikes_per_cluster + + def select(cluster_id, n=None): + assert cluster_id >= 0 + n = n or ns + return selector.select_spikes([cluster_id], max_n_spikes_per_cluster=n) + + @cs.add + def mean_masks(cluster_id): + spike_ids = select(cluster_id) + masks = np.atleast_2d(model.masks[spike_ids]) + assert masks.ndim == 2 + return mean(masks) + + @cs.add + def mean_features(cluster_id): + spike_ids = select(cluster_id) + features = np.atleast_2d(model.features[spike_ids]) + assert features.ndim == 3 + return mean(features) + + @cs.add + def mean_waveforms(cluster_id): + spike_ids = select(cluster_id, ns // 10) + waveforms = np.atleast_2d(model.waveforms[spike_ids]) + assert waveforms.ndim == 3 + mw = mean(waveforms) + return mw + + @cs.add(cache='memory') + def best_channels(cluster_id): + mm = cs.mean_masks(cluster_id) + uch = get_unmasked_channels(mm) + return get_sorted_main_channels(mm, uch) + + @cs.add + def best_channels_multiple(cluster_ids): + best_channels = [] + for cluster in cluster_ids: + channels = cs.best_channels(cluster) + best_channels.extend([ch for ch in channels + if ch not in best_channels]) + return best_channels + + @cs.add(cache='memory') + def max_waveform_amplitude(cluster_id): + mm = cs.mean_masks(cluster_id) + mw = cs.mean_waveforms(cluster_id) + assert mw.ndim == 2 + logger.debug("Computing the quality of cluster %d.", cluster_id) + return np.asscalar(get_max_waveform_amplitude(mm, mw)) + + @cs.add(cache='memory') + def mean_masked_features_score(cluster_0, cluster_1): + mf0 = cs.mean_features(cluster_0) + mf1 = cs.mean_features(cluster_1) + mm0 = cs.mean_masks(cluster_0) + mm1 = cs.mean_masks(cluster_1) + nfpc = model.n_features_per_channel + d = get_mean_masked_features_distance(mf0, mf1, mm0, mm1, + n_features_per_channel=nfpc) + s = 1. / max(1e-10, d) + return s + + return cs diff --git a/phy/cluster/manual/tests/test_store.py b/phy/cluster/manual/tests/test_store.py new file mode 100644 index 000000000..a3fc31e67 --- /dev/null +++ b/phy/cluster/manual/tests/test_store.py @@ -0,0 +1,40 @@ +# -*- coding: utf-8 -*- + +"""Test GUI component.""" + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +from ..store import create_cluster_stats, ClusterStats +from phy.io import Context, Selector + + +#------------------------------------------------------------------------------ +# Test cluster stats +#------------------------------------------------------------------------------ + +def test_create_cluster_stats(model): + selector = Selector(spike_clusters=model.spike_clusters, + spikes_per_cluster=model.spikes_per_cluster) + cs = create_cluster_stats(model, selector=selector) + assert cs.mean_masks(1).shape == (model.n_channels,) + assert cs.mean_features(1).shape == (model.n_channels, + model.n_features_per_channel) + assert cs.mean_waveforms(1).shape == (model.n_samples_waveforms, + model.n_channels) + assert 1 <= cs.best_channels(1).shape[0] <= model.n_channels + assert 0 < cs.max_waveform_amplitude(1) < 1 + assert cs.mean_masked_features_score(1, 2) > 0 + + +def test_cluster_stats(tempdir): + context = Context(tempdir) + cs = ClusterStats(context=context) + + @cs.add(cache='memory') + def f(x): + return x * x + + assert cs.f(3) == 9 + assert cs.f(3) == 9 diff --git a/phy/stats/clusters.py b/phy/stats/clusters.py index d7229272c..bff81e5ad 100644 --- a/phy/stats/clusters.py +++ b/phy/stats/clusters.py @@ -76,35 +76,3 @@ def get_mean_masked_features_distance(mean_features_0, d_1 = mu_1 * omeg_1 return np.linalg.norm(d_0 - d_1) - - -#------------------------------------------------------------------------------ -# Cluster stats object -#------------------------------------------------------------------------------ - -class ClusterStats(object): - def __init__(self, context=None): - self.context = context - self._stats = {} - - def add(self, f=None, name=None, cache=None): - """Add a cluster statistic. - - Parameters - ---------- - f : function - name : str - cache : str - Can be `None` (no cache), `disk`, or `memory`. In the latter case - the function will also be cached on disk. - - """ - if f is None: - return lambda _: self.add(_, name=name, cache=cache) - name = name or f.__name__ - if cache and self.context: - f = self.context.cache(f, memcache=(cache == 'memory')) - assert f - self._stats[name] = f - setattr(self, name, f) - return f diff --git a/phy/stats/tests/test_clusters.py b/phy/stats/tests/test_clusters.py index e3a2220a8..cb83caea3 100644 --- a/phy/stats/tests/test_clusters.py +++ b/phy/stats/tests/test_clusters.py @@ -17,7 +17,6 @@ get_sorted_main_channels, get_mean_masked_features_distance, get_max_waveform_amplitude, - ClusterStats, ) from phy.electrode.mea import staggered_positions from phy.io.mock import (artificial_features, @@ -144,19 +143,3 @@ def test_mean_masked_features_distance(features, d_computed = get_mean_masked_features_distance(f0, f1, m0, m1, n_features_per_channel) ac(d_expected, d_computed) - - -#------------------------------------------------------------------------------ -# Test ClusterStats -#------------------------------------------------------------------------------ - -def test_cluster_stats(tempdir): - context = Context(tempdir) - cs = ClusterStats(context=context) - - @cs.add(cache='memory') - def f(x): - return x * x - - assert cs.f(3) == 9 - assert cs.f(3) == 9 From a29fc6e837903a2b7599cf2b74d8f119cb95e3f5 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 20 Dec 2015 15:23:40 +0100 Subject: [PATCH 0845/1059] Minor updates in phy.io --- phy/io/__init__.py | 1 + phy/io/tests/test_context.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/phy/io/__init__.py b/phy/io/__init__.py index d29518b80..ce3214b80 100644 --- a/phy/io/__init__.py +++ b/phy/io/__init__.py @@ -4,3 +4,4 @@ """Input/output.""" from .context import Context +from .array import Selector, select_spikes diff --git a/phy/io/tests/test_context.py b/phy/io/tests/test_context.py index b16d9b9b8..dfdfe8646 100644 --- a/phy/io/tests/test_context.py +++ b/phy/io/tests/test_context.py @@ -158,7 +158,7 @@ def f(x): assert len(_res) == 1 # Remove the cache directory. - assert context.cache_dir.startswith(tempdir) + assert context.cache_dir.replace('/private', '').startswith(tempdir) shutil.rmtree(context.cache_dir) context._memcache[_fullname(f)].clear() From da8487161820525769208185ed2a9158bcd9a7e2 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 20 Dec 2015 15:24:43 +0100 Subject: [PATCH 0846/1059] Accept non functions in GUI.register() --- phy/gui/gui.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/phy/gui/gui.py b/phy/gui/gui.py index f649e4dc5..59bf032b6 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -216,7 +216,11 @@ def register(self, func=None, name=None): def request(self, name, *args, **kwargs): """Request the result of a possibly registered function.""" if name in self._registered: - return self._registered[name](*args, **kwargs) + obj = self._registered[name] + if hasattr(obj, '__call__'): + return obj(*args, **kwargs) + else: + return obj else: logger.debug("No registered function for `%s`.", name) return None From e317ff2ab054ae7f5cf61d9e739624f41e5a280e Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 20 Dec 2015 15:25:12 +0100 Subject: [PATCH 0847/1059] Flakify --- phy/stats/tests/test_clusters.py | 1 - 1 file changed, 1 deletion(-) diff --git a/phy/stats/tests/test_clusters.py b/phy/stats/tests/test_clusters.py index cb83caea3..2fe7dac97 100644 --- a/phy/stats/tests/test_clusters.py +++ b/phy/stats/tests/test_clusters.py @@ -23,7 +23,6 @@ artificial_masks, artificial_waveforms, ) -from phy.io.context import Context #------------------------------------------------------------------------------ From 46b7c870f82f578d457a521070b044713e551117 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 20 Dec 2015 15:37:38 +0100 Subject: [PATCH 0848/1059] WIP: cluster store --- phy/cluster/manual/__init__.py | 3 +- phy/cluster/manual/gui_component.py | 101 ++---------------- .../manual/tests/test_gui_component.py | 22 ---- phy/cluster/manual/tests/test_views.py | 3 +- phy/cluster/manual/views.py | 2 +- phy/gui/gui.py | 6 +- 6 files changed, 15 insertions(+), 122 deletions(-) diff --git a/phy/cluster/manual/__init__.py b/phy/cluster/manual/__init__.py index 6618d2eba..eeeadc5c8 100644 --- a/phy/cluster/manual/__init__.py +++ b/phy/cluster/manual/__init__.py @@ -5,5 +5,6 @@ from ._utils import ClusterMeta from .clustering import Clustering -from .gui_component import ManualClustering, create_cluster_stats +from .gui_component import ManualClustering +from .store import ClusterStats, create_cluster_stats from .views import WaveformView, TraceView, FeatureView, CorrelogramView diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index ad2e3c11c..1f0788386 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -10,18 +10,9 @@ from functools import partial import logging -import numpy as np - from ._history import GlobalHistory from ._utils import create_cluster_meta from .clustering import Clustering -from phy.stats.clusters import (mean, - get_max_waveform_amplitude, - get_mean_masked_features_distance, - get_unmasked_channels, - get_sorted_main_channels, - ClusterStats, - ) from phy.gui.actions import Actions from phy.gui.widgets import Table from phy.io.array import Selector @@ -50,80 +41,6 @@ def _process_ups(ups): # pragma: no cover raise NotImplementedError() -# ----------------------------------------------------------------------------- -# Cluster statistics -# ----------------------------------------------------------------------------- - -def create_cluster_stats(model, selector=None, context=None, - max_n_spikes_per_cluster=1000): - cs = ClusterStats(context=context) - ns = max_n_spikes_per_cluster - - def select(cluster_id, n=None): - assert cluster_id >= 0 - n = n or ns - return selector.select_spikes([cluster_id], max_n_spikes_per_cluster=n) - - @cs.add - def mean_masks(cluster_id): - spike_ids = select(cluster_id) - masks = np.atleast_2d(model.masks[spike_ids]) - assert masks.ndim == 2 - return mean(masks) - - @cs.add - def mean_features(cluster_id): - spike_ids = select(cluster_id) - features = np.atleast_2d(model.features[spike_ids]) - assert features.ndim == 3 - return mean(features) - - @cs.add - def mean_waveforms(cluster_id): - spike_ids = select(cluster_id, ns // 10) - waveforms = np.atleast_2d(model.waveforms[spike_ids]) - assert waveforms.ndim == 3 - mw = mean(waveforms) - return mw - - @cs.add(cache='memory') - def best_channels(cluster_id): - mm = cs.mean_masks(cluster_id) - uch = get_unmasked_channels(mm) - return get_sorted_main_channels(mm, uch) - - @cs.add - def best_channels_multiple(cluster_ids): - best_channels = [] - for cluster in cluster_ids: - channels = cs.best_channels(cluster) - best_channels.extend([ch for ch in channels - if ch not in best_channels]) - return best_channels - - @cs.add(cache='memory') - def max_waveform_amplitude(cluster_id): - mm = cs.mean_masks(cluster_id) - mw = cs.mean_waveforms(cluster_id) - assert mw.ndim == 2 - logger.debug("Computing the quality of cluster %d.", cluster_id) - return np.asscalar(get_max_waveform_amplitude(mm, mw)) - - @cs.add(cache='memory') - def mean_masked_features_score(cluster_0, cluster_1): - mf0 = cs.mean_features(cluster_0) - mf1 = cs.mean_features(cluster_1) - mm0 = cs.mean_masks(cluster_0) - mm1 = cs.mean_masks(cluster_1) - nfpc = model.n_features_per_channel - d = get_mean_masked_features_distance(mf0, mf1, mm0, mm1, - n_features_per_channel=nfpc) - s = 1. / max(1e-10, d) - return s - - return cs - - # ----------------------------------------------------------------------------- # Clustering GUI component # ----------------------------------------------------------------------------- @@ -229,6 +146,8 @@ def __init__(self, self._create_cluster_views() self._add_default_columns() + self.similarity_func = None + # Internal methods # ------------------------------------------------------------------------- @@ -369,7 +288,8 @@ def _update_cluster_view(self): def _update_similarity_view(self): """Update the similarity view with matches for the specified clusters.""" - assert self.similarity_func + if not self.similarity_func: + return selection = self.cluster_view.selected if not len(selection): return @@ -582,17 +502,10 @@ def attach_to_gui(self, gui, model=None, state=None): mc.attach(gui) gui.manual_clustering = mc - # Create the cluster stats. - cs = create_cluster_stats(model, - selector=mc.selector, - context=getattr(gui, 'context', None)) - gui.cluster_stats = cs - - @gui.register - def best_channels(cluster_ids): - return cs.best_channels_multiple(cluster_ids) - # Add the quality column in the cluster view. + cs = gui.request('cluster_stats') + if not cs: + return mc.cluster_view.add_column(cs.max_waveform_amplitude, name='quality') mc.set_default_sort('quality') mc.set_similarity_func(cs.mean_masked_features_score) diff --git a/phy/cluster/manual/tests/test_gui_component.py b/phy/cluster/manual/tests/test_gui_component.py index 698e8c841..bc4206778 100644 --- a/phy/cluster/manual/tests/test_gui_component.py +++ b/phy/cluster/manual/tests/test_gui_component.py @@ -12,10 +12,8 @@ from ..gui_component import (ManualClustering, ManualClusteringPlugin, - create_cluster_stats, ) from phy.gui import GUI -from phy.io.array import Selector from phy.utils import Bunch @@ -54,24 +52,6 @@ def gui(qtbot): qtbot.wait(5) -#------------------------------------------------------------------------------ -# Test cluster stats -#------------------------------------------------------------------------------ - -def test_create_cluster_stats(model): - selector = Selector(spike_clusters=model.spike_clusters, - spikes_per_cluster=model.spikes_per_cluster) - cs = create_cluster_stats(model, selector=selector) - assert cs.mean_masks(1).shape == (model.n_channels,) - assert cs.mean_features(1).shape == (model.n_channels, - model.n_features_per_channel) - assert cs.mean_waveforms(1).shape == (model.n_samples_waveforms, - model.n_channels) - assert 1 <= cs.best_channels(1).shape[0] <= model.n_channels - assert 0 < cs.max_waveform_amplitude(1) < 1 - assert cs.mean_masked_features_score(1, 2) > 0 - - #------------------------------------------------------------------------------ # Test GUI component #------------------------------------------------------------------------------ @@ -87,8 +67,6 @@ def test_manual_clustering_plugin(qtbot, gui): state = Bunch() ManualClusteringPlugin().attach_to_gui(gui, model=model, state=state) - assert gui.request('best_channels', [0, 1]) == [0] - def test_manual_clustering_edge_cases(manual_clustering): mc = manual_clustering diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index d5b05a508..536444164 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -44,7 +44,8 @@ def _test_view(view_name, model=None, tempdir=None): state.save() # Create the GUI. - plugins = ['ManualClusteringPlugin', + plugins = ['ClusterStatsPlugin', + 'ManualClusteringPlugin', view_name + 'Plugin'] gui = create_gui(model=model, plugins=plugins, config_dir=tempdir) gui.show() diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 1b1ef911d..5c1d75b4a 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -199,7 +199,7 @@ def _best_channels(self, cluster_ids, n_channels_requested=None): # Number of channels to find on each axis. n = n_channels_requested or self.n_channels # Request the best channels to the GUI. - channels = (self.gui.request('best_channels', cluster_ids) + channels = (self.gui.request('cluster_stats', cluster_ids) if self.gui else None) # By default, select the first channels. if channels is None or not len(channels): diff --git a/phy/gui/gui.py b/phy/gui/gui.py index 59bf032b6..fca2a6bbf 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -207,14 +207,14 @@ def unconnect_(self, *args, **kwargs): self._event.unconnect(*args, **kwargs) def register(self, func=None, name=None): - """Register a function for a given name.""" + """Register a object for a given name.""" if func is None: return lambda _: self.register(func=_, name=name) name = name or func.__name__ self._registered[name] = func def request(self, name, *args, **kwargs): - """Request the result of a possibly registered function.""" + """Request the result of a possibly registered object.""" if name in self._registered: obj = self._registered[name] if hasattr(obj, '__call__'): @@ -222,7 +222,7 @@ def request(self, name, *args, **kwargs): else: return obj else: - logger.debug("No registered function for `%s`.", name) + logger.debug("No registered object for `%s`.", name) return None def closeEvent(self, e): From bd915bc009bb5ea8c4002d9864d849f1784b354e Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 20 Dec 2015 15:42:45 +0100 Subject: [PATCH 0849/1059] WIP --- phy/cluster/manual/gui_component.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 1f0788386..f8a8d9bf4 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -374,6 +374,7 @@ def on_cluster(self, up): def attach(self, gui): self.gui = gui + gui.register(self, name='manual_clustering') # Create the actions. self._create_actions(gui) @@ -382,6 +383,14 @@ def attach(self, gui): gui.add_view(self.cluster_view, name='ClusterView') gui.add_view(self.similarity_view, name='SimilarityView') + # Add the quality column in the cluster view. + cs = gui.request('cluster_stats') + if cs: + self.cluster_view.add_column(cs.max_waveform_amplitude, + name='quality') + self.set_default_sort('quality') + self.set_similarity_func(cs.mean_masked_features_score) + # Update the cluster views and selection when a cluster event occurs. self.gui.connect_(self.on_cluster) return self @@ -501,11 +510,3 @@ def attach_to_gui(self, gui, model=None, state=None): ) mc.attach(gui) gui.manual_clustering = mc - - # Add the quality column in the cluster view. - cs = gui.request('cluster_stats') - if not cs: - return - mc.cluster_view.add_column(cs.max_waveform_amplitude, name='quality') - mc.set_default_sort('quality') - mc.set_similarity_func(cs.mean_masked_features_score) From f253f8f11fff653fe022c9d6312df713a009c0de Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 20 Dec 2015 15:51:54 +0100 Subject: [PATCH 0850/1059] Fix --- phy/cluster/manual/store.py | 9 +++++---- phy/cluster/manual/views.py | 4 ++-- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/phy/cluster/manual/store.py b/phy/cluster/manual/store.py index fbfc38d7b..1e3a16197 100644 --- a/phy/cluster/manual/store.py +++ b/phy/cluster/manual/store.py @@ -11,6 +11,7 @@ import numpy as np +from phy.io.array import Selector from phy.stats.clusters import (mean, get_max_waveform_amplitude, get_mean_masked_features_distance, @@ -59,11 +60,11 @@ def attach(self, gui): class ClusterStatsPlugin(IPlugin): def attach_to_gui(self, gui, model=None, state=None): - mc = gui.request('manual_clustering') - if not mc: - return ctx = gui.request('context') - cs = create_cluster_stats(model, selector=mc.selector, context=ctx) + selector = Selector(spike_clusters=model.spike_clusters, + spikes_per_cluster=model.spikes_per_cluster, + ) + cs = create_cluster_stats(model, selector=selector, context=ctx) cs.attach(gui) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 5c1d75b4a..0118c3dd3 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -199,8 +199,8 @@ def _best_channels(self, cluster_ids, n_channels_requested=None): # Number of channels to find on each axis. n = n_channels_requested or self.n_channels # Request the best channels to the GUI. - channels = (self.gui.request('cluster_stats', cluster_ids) - if self.gui else None) + cs = self.gui.request('cluster_stats') if self.gui else None + channels = cs.best_channels_multiple(cluster_ids) if cs else None # By default, select the first channels. if channels is None or not len(channels): return From 1916ad156b6dbffb7b794eff81cca66c4e83ca3e Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 20 Dec 2015 18:38:10 +0100 Subject: [PATCH 0851/1059] Rename cluster stats to cluster store --- phy/cluster/manual/__init__.py | 2 +- phy/cluster/manual/gui_component.py | 2 +- phy/cluster/manual/store.py | 12 ++++++------ phy/cluster/manual/tests/test_store.py | 10 +++++----- phy/cluster/manual/tests/test_views.py | 2 +- phy/cluster/manual/views.py | 2 +- 6 files changed, 15 insertions(+), 15 deletions(-) diff --git a/phy/cluster/manual/__init__.py b/phy/cluster/manual/__init__.py index eeeadc5c8..eaf61a82e 100644 --- a/phy/cluster/manual/__init__.py +++ b/phy/cluster/manual/__init__.py @@ -6,5 +6,5 @@ from ._utils import ClusterMeta from .clustering import Clustering from .gui_component import ManualClustering -from .store import ClusterStats, create_cluster_stats +from .store import ClusterStore, create_cluster_store from .views import WaveformView, TraceView, FeatureView, CorrelogramView diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index f8a8d9bf4..736e60e32 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -384,7 +384,7 @@ def attach(self, gui): gui.add_view(self.similarity_view, name='SimilarityView') # Add the quality column in the cluster view. - cs = gui.request('cluster_stats') + cs = gui.request('cluster_store') if cs: self.cluster_view.add_column(cs.max_waveform_amplitude, name='quality') diff --git a/phy/cluster/manual/store.py b/phy/cluster/manual/store.py index 1e3a16197..09057a1e1 100644 --- a/phy/cluster/manual/store.py +++ b/phy/cluster/manual/store.py @@ -27,7 +27,7 @@ # Cluster statistics # ----------------------------------------------------------------------------- -class ClusterStats(object): +class ClusterStore(object): def __init__(self, context=None): self.context = context self._stats = {} @@ -55,22 +55,22 @@ def add(self, f=None, name=None, cache=None): return f def attach(self, gui): - gui.register(self, name='cluster_stats') + gui.register(self, name='cluster_store') -class ClusterStatsPlugin(IPlugin): +class ClusterStorePlugin(IPlugin): def attach_to_gui(self, gui, model=None, state=None): ctx = gui.request('context') selector = Selector(spike_clusters=model.spike_clusters, spikes_per_cluster=model.spikes_per_cluster, ) - cs = create_cluster_stats(model, selector=selector, context=ctx) + cs = create_cluster_store(model, selector=selector, context=ctx) cs.attach(gui) -def create_cluster_stats(model, selector=None, context=None, +def create_cluster_store(model, selector=None, context=None, max_n_spikes_per_cluster=1000): - cs = ClusterStats(context=context) + cs = ClusterStore(context=context) ns = max_n_spikes_per_cluster def select(cluster_id, n=None): diff --git a/phy/cluster/manual/tests/test_store.py b/phy/cluster/manual/tests/test_store.py index a3fc31e67..c76174bb0 100644 --- a/phy/cluster/manual/tests/test_store.py +++ b/phy/cluster/manual/tests/test_store.py @@ -6,7 +6,7 @@ # Imports #------------------------------------------------------------------------------ -from ..store import create_cluster_stats, ClusterStats +from ..store import create_cluster_store, ClusterStore from phy.io import Context, Selector @@ -14,10 +14,10 @@ # Test cluster stats #------------------------------------------------------------------------------ -def test_create_cluster_stats(model): +def test_create_cluster_store(model): selector = Selector(spike_clusters=model.spike_clusters, spikes_per_cluster=model.spikes_per_cluster) - cs = create_cluster_stats(model, selector=selector) + cs = create_cluster_store(model, selector=selector) assert cs.mean_masks(1).shape == (model.n_channels,) assert cs.mean_features(1).shape == (model.n_channels, model.n_features_per_channel) @@ -28,9 +28,9 @@ def test_create_cluster_stats(model): assert cs.mean_masked_features_score(1, 2) > 0 -def test_cluster_stats(tempdir): +def test_cluster_store(tempdir): context = Context(tempdir) - cs = ClusterStats(context=context) + cs = ClusterStore(context=context) @cs.add(cache='memory') def f(x): diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index 536444164..367a781a9 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -44,7 +44,7 @@ def _test_view(view_name, model=None, tempdir=None): state.save() # Create the GUI. - plugins = ['ClusterStatsPlugin', + plugins = ['ClusterStorePlugin', 'ManualClusteringPlugin', view_name + 'Plugin'] gui = create_gui(model=model, plugins=plugins, config_dir=tempdir) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 0118c3dd3..e2655e43a 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -199,7 +199,7 @@ def _best_channels(self, cluster_ids, n_channels_requested=None): # Number of channels to find on each axis. n = n_channels_requested or self.n_channels # Request the best channels to the GUI. - cs = self.gui.request('cluster_stats') if self.gui else None + cs = self.gui.request('cluster_store') if self.gui else None channels = cs.best_channels_multiple(cluster_ids) if cs else None # By default, select the first channels. if channels is None or not len(channels): From 6e94cd22d7f1f47c7bc046553b16db9a98fb970d Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 20 Dec 2015 18:46:55 +0100 Subject: [PATCH 0852/1059] Remove gui.manual_clustering --- phy/cluster/manual/gui_component.py | 1 - phy/cluster/manual/tests/test_views.py | 8 +++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 736e60e32..4c8a3d300 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -509,4 +509,3 @@ def attach_to_gui(self, gui, model=None, state=None): cluster_groups=model.cluster_groups, ) mc.attach(gui) - gui.manual_clustering = mc diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index 367a781a9..51c5ba25f 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -50,9 +50,11 @@ def _test_view(view_name, model=None, tempdir=None): gui = create_gui(model=model, plugins=plugins, config_dir=tempdir) gui.show() - gui.manual_clustering.select([]) - gui.manual_clustering.select([0]) - gui.manual_clustering.select([0, 2]) + mc = gui.request('manual_clustering') + assert mc + mc.select([]) + mc.select([0]) + mc.select([0, 2]) view = gui.list_views(view_name)[0] view.gui = gui From 5e193f9f123fdb121237a8b65e58d2dfc4a9efc0 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 20 Dec 2015 18:55:10 +0100 Subject: [PATCH 0853/1059] Add model data in cluster store --- phy/cluster/manual/store.py | 33 +++++++++++++++++++------- phy/cluster/manual/tests/test_store.py | 10 ++++++++ 2 files changed, 35 insertions(+), 8 deletions(-) diff --git a/phy/cluster/manual/store.py b/phy/cluster/manual/store.py index 09057a1e1..10043da14 100644 --- a/phy/cluster/manual/store.py +++ b/phy/cluster/manual/store.py @@ -78,27 +78,44 @@ def select(cluster_id, n=None): n = n or ns return selector.select_spikes([cluster_id], max_n_spikes_per_cluster=n) + # Model data. + # ------------------------------------------------------------------------- + @cs.add - def mean_masks(cluster_id): + def masks(cluster_id): spike_ids = select(cluster_id) masks = np.atleast_2d(model.masks[spike_ids]) assert masks.ndim == 2 - return mean(masks) + return masks @cs.add - def mean_features(cluster_id): + def mean_masks(cluster_id): + return mean(cs.masks(cluster_id)) + + @cs.add + def features(cluster_id): spike_ids = select(cluster_id) features = np.atleast_2d(model.features[spike_ids]) assert features.ndim == 3 - return mean(features) + return features @cs.add - def mean_waveforms(cluster_id): + def mean_features(cluster_id): + return mean(cs.features(cluster_id)) + + @cs.add + def waveforms(cluster_id): spike_ids = select(cluster_id, ns // 10) waveforms = np.atleast_2d(model.waveforms[spike_ids]) assert waveforms.ndim == 3 - mw = mean(waveforms) - return mw + return waveforms + + @cs.add + def mean_waveforms(cluster_id): + return mean(cs.waveforms(cluster_id)) + + # Statistics. + # ------------------------------------------------------------------------- @cs.add(cache='memory') def best_channels(cluster_id): @@ -106,7 +123,7 @@ def best_channels(cluster_id): uch = get_unmasked_channels(mm) return get_sorted_main_channels(mm, uch) - @cs.add + @cs.add(cache='memory') def best_channels_multiple(cluster_ids): best_channels = [] for cluster in cluster_ids: diff --git a/phy/cluster/manual/tests/test_store.py b/phy/cluster/manual/tests/test_store.py index c76174bb0..76c94a409 100644 --- a/phy/cluster/manual/tests/test_store.py +++ b/phy/cluster/manual/tests/test_store.py @@ -18,11 +18,21 @@ def test_create_cluster_store(model): selector = Selector(spike_clusters=model.spike_clusters, spikes_per_cluster=model.spikes_per_cluster) cs = create_cluster_store(model, selector=selector) + + nspk = len(model.spikes_per_cluster[1]) + + assert cs.masks(1).shape == (nspk, model.n_channels) + assert cs.features(1).shape == (nspk, model.n_channels, + model.n_features_per_channel) + assert cs.waveforms(1).shape == (nspk, model.n_samples_waveforms, + model.n_channels) + assert cs.mean_masks(1).shape == (model.n_channels,) assert cs.mean_features(1).shape == (model.n_channels, model.n_features_per_channel) assert cs.mean_waveforms(1).shape == (model.n_samples_waveforms, model.n_channels) + assert 1 <= cs.best_channels(1).shape[0] <= model.n_channels assert 0 < cs.max_waveform_amplitude(1) < 1 assert cs.mean_masked_features_score(1, 2) > 0 From bdcbbf9be3d97dbbd3fc3d34594be043704e0092 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 20 Dec 2015 20:46:07 +0100 Subject: [PATCH 0854/1059] WIP: selection data in cluster store --- phy/cluster/manual/store.py | 68 +++++++++++++++++++++----- phy/cluster/manual/tests/test_store.py | 36 +++++++++----- 2 files changed, 79 insertions(+), 25 deletions(-) diff --git a/phy/cluster/manual/store.py b/phy/cluster/manual/store.py index 10043da14..f52fabf2f 100644 --- a/phy/cluster/manual/store.py +++ b/phy/cluster/manual/store.py @@ -7,6 +7,7 @@ # Imports # ----------------------------------------------------------------------------- +from functools import wraps import logging import numpy as np @@ -68,51 +69,92 @@ def attach_to_gui(self, gui, model=None, state=None): cs.attach(gui) -def create_cluster_store(model, selector=None, context=None, - max_n_spikes_per_cluster=1000): +def create_cluster_store(model, selector=None, context=None): cs = ClusterStore(context=context) - ns = max_n_spikes_per_cluster + + # TODO: make this configurable. + max_n_spikes_per_cluster = { + 'masks': 1000, + 'features': 10000, + 'waveforms': 100, + } def select(cluster_id, n=None): assert cluster_id >= 0 - n = n or ns return selector.select_spikes([cluster_id], max_n_spikes_per_cluster=n) # Model data. # ------------------------------------------------------------------------- + def concat(f): + """Take a function accepting a single cluster, and return a function + accepting multiple clusters.""" + @wraps(f) + def wrapped(cluster_ids): + # Single cluster. + if not hasattr(cluster_ids, '__len__'): + return f(cluster_ids) + # Concatenate the result of multiple clusters. + spike_ids_l, data_l = zip(*(f(c) for c in cluster_ids)) + return np.hstack(spike_ids_l), np.vstack(data_l) + return wrapped + @cs.add + @concat def masks(cluster_id): - spike_ids = select(cluster_id) + spike_ids = select(cluster_id, max_n_spikes_per_cluster['masks']) masks = np.atleast_2d(model.masks[spike_ids]) assert masks.ndim == 2 - return masks + return spike_ids, masks + + @cs.add + @concat + def features_masks(cluster_id): + spike_ids = select(cluster_id, max_n_spikes_per_cluster['features']) + fm = np.atleast_3d(model.features_masks[spike_ids]) + assert fm.ndim == 3 + return spike_ids, fm @cs.add def mean_masks(cluster_id): - return mean(cs.masks(cluster_id)) + # We access [1] because we return spike_ids, masks. + return mean(cs.masks(cluster_id)[1]) @cs.add + @concat def features(cluster_id): - spike_ids = select(cluster_id) + spike_ids = select(cluster_id, max_n_spikes_per_cluster['features']) features = np.atleast_2d(model.features[spike_ids]) assert features.ndim == 3 - return features + return spike_ids, features @cs.add def mean_features(cluster_id): - return mean(cs.features(cluster_id)) + return mean(cs.features(cluster_id)[1]) @cs.add + @concat def waveforms(cluster_id): - spike_ids = select(cluster_id, ns // 10) + spike_ids = select(cluster_id, max_n_spikes_per_cluster['waveforms']) + waveforms = np.atleast_2d(model.waveforms[spike_ids]) + assert waveforms.ndim == 3 + return spike_ids, waveforms + + @cs.add + @concat + def waveforms_masks(cluster_id): + spike_ids = select(cluster_id, max_n_spikes_per_cluster['waveforms']) waveforms = np.atleast_2d(model.waveforms[spike_ids]) assert waveforms.ndim == 3 - return waveforms + masks = np.atleast_2d(model.masks[spike_ids]) + assert masks.ndim == 2 + # Ensure that both arrays have the same number of channels. + assert masks.shape[1] == waveforms.shape[2] + return spike_ids, waveforms, masks @cs.add def mean_waveforms(cluster_id): - return mean(cs.waveforms(cluster_id)) + return mean(cs.waveforms(cluster_id)[1]) # Statistics. # ------------------------------------------------------------------------- diff --git a/phy/cluster/manual/tests/test_store.py b/phy/cluster/manual/tests/test_store.py index 76c94a409..e88656b68 100644 --- a/phy/cluster/manual/tests/test_store.py +++ b/phy/cluster/manual/tests/test_store.py @@ -19,21 +19,33 @@ def test_create_cluster_store(model): spikes_per_cluster=model.spikes_per_cluster) cs = create_cluster_store(model, selector=selector) - nspk = len(model.spikes_per_cluster[1]) + nc = model.n_channels + nfpc = model.n_features_per_channel + ns = len(model.spikes_per_cluster[1]) + nsw = model.n_samples_waveforms - assert cs.masks(1).shape == (nspk, model.n_channels) - assert cs.features(1).shape == (nspk, model.n_channels, - model.n_features_per_channel) - assert cs.waveforms(1).shape == (nspk, model.n_samples_waveforms, - model.n_channels) + def _check(out, *shape): + spikes, arr = out + assert spikes.shape[0] == shape[0] + assert arr.shape == shape - assert cs.mean_masks(1).shape == (model.n_channels,) - assert cs.mean_features(1).shape == (model.n_channels, - model.n_features_per_channel) - assert cs.mean_waveforms(1).shape == (model.n_samples_waveforms, - model.n_channels) + def _check2(arr, *shape): + assert arr.shape == shape - assert 1 <= cs.best_channels(1).shape[0] <= model.n_channels + _check(cs.masks(1), ns, nc) + _check(cs.features(1), ns, nc, nfpc) + _check(cs.waveforms(1), ns, nsw, nc) + + _check(cs.features_masks(1), ns, nc * nfpc, 2) + spike_ids, w, m = cs.waveforms_masks(1) + _check((spike_ids, w), ns, nsw, nc) + _check((spike_ids, m), ns, nc) + + assert cs.mean_masks(1).shape == (nc,) + assert cs.mean_features(1).shape == (nc, nfpc) + assert cs.mean_waveforms(1).shape == (nsw, nc) + + assert 1 <= cs.best_channels(1).shape[0] <= nc assert 0 < cs.max_waveform_amplitude(1) < 1 assert cs.mean_masked_features_score(1, 2) > 0 From 7f4a430c6f4cd09eb474a658efbe2cd3edcf7595 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 20 Dec 2015 20:46:59 +0100 Subject: [PATCH 0855/1059] WIP: select event with just cluster_ids argument --- phy/cluster/manual/gui_component.py | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 4c8a3d300..13bd6be84 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -15,7 +15,6 @@ from .clustering import Clustering from phy.gui.actions import Actions from phy.gui.widgets import Table -from phy.io.array import Selector from phy.utils import IPlugin logger = logging.getLogger(__name__) @@ -76,7 +75,7 @@ class ManualClustering(object): When this component is attached to a GUI, the GUI emits the following events: - select(cluster_ids, selector) + select(cluster_ids) when clusters are selected cluster(up) when a merge or split happens @@ -135,13 +134,6 @@ def __init__(self, self._global_history = GlobalHistory(process_ups=_process_ups) self._register_logging() - # Create the spike selector. - sc = self.clustering.spike_clusters - spc = self.clustering.spikes_per_cluster - self.selector = Selector(spike_clusters=sc, - spikes_per_cluster=spc, - ) - # Create the cluster views. self._create_cluster_views() self._add_default_columns() @@ -304,10 +296,7 @@ def _emit_select(self, cluster_ids): `select` event on the GUI.""" logger.debug("Select clusters: %s.", ', '.join(map(str, cluster_ids))) if self.gui: - self.gui.emit('select', - cluster_ids=cluster_ids, - selector=self.selector, - ) + self.gui.emit('select', cluster_ids) # Public methods # ------------------------------------------------------------------------- From 692584f8cfeb5d75a6d38529ed1029286188a3a9 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 20 Dec 2015 20:47:23 +0100 Subject: [PATCH 0856/1059] Add features_masks array in mock model --- phy/cluster/manual/tests/conftest.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/phy/cluster/manual/tests/conftest.py b/phy/cluster/manual/tests/conftest.py index 6ba47ef0c..de5c9d2ec 100644 --- a/phy/cluster/manual/tests/conftest.py +++ b/phy/cluster/manual/tests/conftest.py @@ -6,6 +6,7 @@ # Imports #------------------------------------------------------------------------------ +import numpy as np from pytest import yield_fixture from phy.electrode.mea import staggered_positions @@ -70,6 +71,12 @@ def model(): model.masks = artificial_masks(n_spikes, n_channels) model.traces = artificial_traces(n_samples_t, n_channels) model.features = artificial_features(n_spikes, n_channels, n_features) + + # features_masks array + f = model.features.reshape((n_spikes, -1)) + m = np.repeat(model.masks, n_features, axis=1) + model.features_masks = np.dstack((f, m)) + model.spikes_per_cluster = _spikes_per_cluster(model.spike_clusters) model.n_features_per_channel = n_features model.n_samples_waveforms = n_samples_w From c4e42f194f209ad27463b5ddcc5286cbec957448 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 20 Dec 2015 21:06:06 +0100 Subject: [PATCH 0857/1059] Increase cluster store coverage --- phy/cluster/manual/store.py | 40 ++++++++++++++++++-------- phy/cluster/manual/tests/test_store.py | 38 +++++++++++++++--------- 2 files changed, 53 insertions(+), 25 deletions(-) diff --git a/phy/cluster/manual/store.py b/phy/cluster/manual/store.py index f52fabf2f..6236811d0 100644 --- a/phy/cluster/manual/store.py +++ b/phy/cluster/manual/store.py @@ -77,15 +77,14 @@ def create_cluster_store(model, selector=None, context=None): 'masks': 1000, 'features': 10000, 'waveforms': 100, + 'waveform_lim': 1000, # used to compute the waveform bounds } + # TODO: add trace mean def select(cluster_id, n=None): assert cluster_id >= 0 return selector.select_spikes([cluster_id], max_n_spikes_per_cluster=n) - # Model data. - # ------------------------------------------------------------------------- - def concat(f): """Take a function accepting a single cluster, and return a function accepting multiple clusters.""" @@ -99,6 +98,9 @@ def wrapped(cluster_ids): return np.hstack(spike_ids_l), np.vstack(data_l) return wrapped + # Model data. + # ------------------------------------------------------------------------- + @cs.add @concat def masks(cluster_id): @@ -115,11 +117,6 @@ def features_masks(cluster_id): assert fm.ndim == 3 return spike_ids, fm - @cs.add - def mean_masks(cluster_id): - # We access [1] because we return spike_ids, masks. - return mean(cs.masks(cluster_id)[1]) - @cs.add @concat def features(cluster_id): @@ -128,10 +125,6 @@ def features(cluster_id): assert features.ndim == 3 return spike_ids, features - @cs.add - def mean_features(cluster_id): - return mean(cs.features(cluster_id)[1]) - @cs.add @concat def waveforms(cluster_id): @@ -140,6 +133,17 @@ def waveforms(cluster_id): assert waveforms.ndim == 3 return spike_ids, waveforms + @cs.add + def waveform_lim(percentile=95): + """Return the 95% percentile of all waveform amplitudes.""" + k = max(1, model.n_spikes // max_n_spikes_per_cluster['waveform_lim']) + w = np.abs(model.waveforms[::k]) + n = w.shape[0] + w = w.reshape((n, -1)) + w = w.max(axis=1) + m = np.percentile(w, percentile) + return m + @cs.add @concat def waveforms_masks(cluster_id): @@ -152,6 +156,18 @@ def waveforms_masks(cluster_id): assert masks.shape[1] == waveforms.shape[2] return spike_ids, waveforms, masks + # Mean quantities. + # ------------------------------------------------------------------------- + + @cs.add + def mean_masks(cluster_id): + # We access [1] because we return spike_ids, masks. + return mean(cs.masks(cluster_id)[1]) + + @cs.add + def mean_features(cluster_id): + return mean(cs.features(cluster_id)[1]) + @cs.add def mean_waveforms(cluster_id): return mean(cs.waveforms(cluster_id)[1]) diff --git a/phy/cluster/manual/tests/test_store.py b/phy/cluster/manual/tests/test_store.py index e88656b68..e4ec824cf 100644 --- a/phy/cluster/manual/tests/test_store.py +++ b/phy/cluster/manual/tests/test_store.py @@ -14,6 +14,18 @@ # Test cluster stats #------------------------------------------------------------------------------ +def test_cluster_store(tempdir): + context = Context(tempdir) + cs = ClusterStore(context=context) + + @cs.add(cache='memory') + def f(x): + return x * x + + assert cs.f(3) == 9 + assert cs.f(3) == 9 + + def test_create_cluster_store(model): selector = Selector(spike_clusters=model.spike_clusters, spikes_per_cluster=model.spikes_per_cluster) @@ -22,6 +34,7 @@ def test_create_cluster_store(model): nc = model.n_channels nfpc = model.n_features_per_channel ns = len(model.spikes_per_cluster[1]) + ns2 = len(model.spikes_per_cluster[2]) nsw = model.n_samples_waveforms def _check(out, *shape): @@ -32,6 +45,7 @@ def _check(out, *shape): def _check2(arr, *shape): assert arr.shape == shape + # Model data. _check(cs.masks(1), ns, nc) _check(cs.features(1), ns, nc, nfpc) _check(cs.waveforms(1), ns, nsw, nc) @@ -41,22 +55,20 @@ def _check2(arr, *shape): _check((spike_ids, w), ns, nsw, nc) _check((spike_ids, m), ns, nc) + # Test concat multiple clusters. + spike_ids, fm = cs.features_masks([1, 2]) + assert len(spike_ids) == ns + ns2 + assert fm.shape == (ns + ns2, nc * nfpc, 2) + + # Test means. assert cs.mean_masks(1).shape == (nc,) assert cs.mean_features(1).shape == (nc, nfpc) assert cs.mean_waveforms(1).shape == (nsw, nc) - assert 1 <= cs.best_channels(1).shape[0] <= nc + assert 0 < cs.waveform_lim() < 1 + + # Statistics. + assert 1 <= len(cs.best_channels(1)) <= nc + assert 1 <= len(cs.best_channels_multiple([1, 2])) <= nc assert 0 < cs.max_waveform_amplitude(1) < 1 assert cs.mean_masked_features_score(1, 2) > 0 - - -def test_cluster_store(tempdir): - context = Context(tempdir) - cs = ClusterStore(context=context) - - @cs.add(cache='memory') - def f(x): - return x * x - - assert cs.f(3) == 9 - assert cs.f(3) == 9 From 449da21371e671b353ea1de61240a594550414a3 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 20 Dec 2015 21:10:22 +0100 Subject: [PATCH 0858/1059] WIP: update views with cluster store --- phy/cluster/manual/store.py | 4 +- phy/cluster/manual/views.py | 79 ++++++++++++++++++------------------- 2 files changed, 41 insertions(+), 42 deletions(-) diff --git a/phy/cluster/manual/store.py b/phy/cluster/manual/store.py index 6236811d0..d767acf65 100644 --- a/phy/cluster/manual/store.py +++ b/phy/cluster/manual/store.py @@ -94,8 +94,8 @@ def wrapped(cluster_ids): if not hasattr(cluster_ids, '__len__'): return f(cluster_ids) # Concatenate the result of multiple clusters. - spike_ids_l, data_l = zip(*(f(c) for c in cluster_ids)) - return np.hstack(spike_ids_l), np.vstack(data_l) + arrs = zip(*(f(c) for c in cluster_ids)) + return tuple(np.concatenate(_, axis=0) for _ in arrs) return wrapped # Model data. diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index e2655e43a..21b4f8ca0 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -73,6 +73,7 @@ def _extract_wave(traces, spk, mask, wave_len=None): def _get_data_bounds(arr, n_spikes=None, percentile=None): + # TODO: move to cluster store. n = arr.shape[0] k = max(1, n // n_spikes) if n_spikes else 1 w = np.abs(arr[::k]) @@ -181,21 +182,13 @@ def __init__(self, shortcuts=None, **kwargs): super(ManualClusteringView, self).__init__(**kwargs) self.events.add(status=StatusEvent) - def on_select(self, cluster_ids=None, selector=None, spike_ids=None): + def on_select(self, cluster_ids=None): cluster_ids = (cluster_ids if cluster_ids is not None else self.cluster_ids) - if spike_ids is None: - # Use the selector to select some or all of the spikes. - if selector: - ns = self.max_n_spikes_per_cluster - spike_ids = selector.select_spikes(cluster_ids, ns) - else: - spike_ids = self.spike_ids self.cluster_ids = list(cluster_ids) if cluster_ids is not None else [] - self.spike_ids = np.asarray(spike_ids if spike_ids is not None else []) def _best_channels(self, cluster_ids, n_channels_requested=None): - """Request best channels for a set of clusters.""" + """Return the best channels for a set of clusters.""" # Number of channels to find on each axis. n = n_channels_requested or self.n_channels # Request the best channels to the GUI. @@ -257,9 +250,9 @@ def __init__(self, type, channel_idx=None, key=None, button=None): class WaveformView(ManualClusteringView): - max_n_spikes_per_cluster = 100 - normalization_percentile = .95 - normalization_n_spikes = 1000 + # max_n_spikes_per_cluster = 100 + # normalization_percentile = .95 + # normalization_n_spikes = 1000 overlap = False scaling_coeff = 1.1 @@ -280,12 +273,13 @@ class WaveformView(ManualClusteringView): } def __init__(self, - waveforms=None, - masks=None, + waveforms_masks=None, spike_clusters=None, channel_positions=None, box_scaling=None, probe_scaling=None, + n_samples=None, + waveform_lim=None, **kwargs): """ @@ -295,6 +289,15 @@ def __init__(self, """ self._key_pressed = None + # Channel positions and n_channels. + assert channel_positions is not None + self.channel_positions = np.asarray(channel_positions) + self.n_channels = self.channel_positions.shape[0] + + # Number of samples per waveform. + assert n_samples > 0 + self.n_samples = n_samples + # Initialize the view. box_bounds = _get_boxes(channel_positions) super(WaveformView, self).__init__(layout='boxed', @@ -316,43 +319,39 @@ def __init__(self, self.box_size = np.array(self.boxed.box_size) self._update_boxes() - # Waveforms. - assert waveforms.ndim == 3 - self.n_spikes, self.n_samples, self.n_channels = waveforms.shape - self.waveforms = waveforms + # Data: functions cluster_id => waveforms. + self.waveforms_masks = waveforms_masks # Waveform normalization. - self.data_bounds = _get_data_bounds(waveforms, - self.normalization_n_spikes, - self.normalization_percentile) - - # Masks. - self.masks = masks + assert waveform_lim > 0 + self.data_bounds = [-1, -waveform_lim, +1, +waveform_lim] # Spike clusters. - assert spike_clusters.shape == (self.n_spikes,) self.spike_clusters = spike_clusters # Channel positions. assert channel_positions.shape == (self.n_channels, 2) self.channel_positions = channel_positions - def on_select(self, cluster_ids=None, **kwargs): - super(WaveformView, self).on_select(cluster_ids=cluster_ids, - **kwargs) - cluster_ids, spike_ids = self.cluster_ids, self.spike_ids + def on_select(self, cluster_ids=None): + super(WaveformView, self).on_select(cluster_ids) + cluster_ids = self.cluster_ids n_clusters = len(cluster_ids) - n_spikes = len(spike_ids) - if n_spikes == 0: + if n_clusters == 0: return + # Load the waveform subset. + spike_ids, w, masks = self.waveforms_masks(cluster_ids) + n_spikes = len(spike_ids) + assert w.shape == (n_spikes, self.n_samples, self.n_channels) + assert masks.shape == (n_spikes, self.n_channels) + # Relative spike clusters. spike_clusters_rel = _get_spike_clusters_rel(self.spike_clusters, - spike_ids, - cluster_ids) + spike_ids, cluster_ids) + assert spike_clusters_rel.shape == (n_spikes,) # Fetch the waveforms. - w = self.waveforms[spike_ids] t = _get_linear_x(n_spikes, self.n_samples) # Overlap. if not self.overlap: @@ -361,9 +360,6 @@ def on_select(self, cluster_ids=None, **kwargs): # The total width should not depend on the number of clusters. t /= n_clusters - # Depth as a function of the cluster index and masks. - masks = self.masks[spike_ids] - # Plot all waveforms. # OPTIM: avoid the loop. with self.building(): @@ -509,12 +505,15 @@ def attach_to_gui(self, gui, model=None, state=None): 'probe_scaling', 'overlap', ) - view = WaveformView(waveforms=model.waveforms, - masks=model.masks, + cs = gui.request('cluster_store') + assert cs # We need the cluster store to retrieve the data. + view = WaveformView(waveforms_masks=cs.waveforms_masks, spike_clusters=model.spike_clusters, channel_positions=model.channel_positions, + n_samples=model.n_samples_waveforms, box_scaling=bs, probe_scaling=ps, + waveform_lim=cs.waveform_lim(), ) view.attach(gui) From 27c030524695984565f697d754af16b0ff905022 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 20 Dec 2015 21:15:54 +0100 Subject: [PATCH 0859/1059] Add trace lim in cluster store --- phy/cluster/manual/store.py | 11 ++++++++++- phy/cluster/manual/tests/test_store.py | 2 ++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/phy/cluster/manual/store.py b/phy/cluster/manual/store.py index d767acf65..9fd119062 100644 --- a/phy/cluster/manual/store.py +++ b/phy/cluster/manual/store.py @@ -78,8 +78,8 @@ def create_cluster_store(model, selector=None, context=None): 'features': 10000, 'waveforms': 100, 'waveform_lim': 1000, # used to compute the waveform bounds + 'trace_lim': 10000, } - # TODO: add trace mean def select(cluster_id, n=None): assert cluster_id >= 0 @@ -210,4 +210,13 @@ def mean_masked_features_score(cluster_0, cluster_1): s = 1. / max(1e-10, d) return s + # Traces. + # ------------------------------------------------------------------------- + + @cs.add + def trace_lim(): + n = max_n_spikes_per_cluster['trace_lim'] + mt = model.traces[:n, model.channel_order].mean(axis=0) + return mt.astype(model.traces.dtype) + return cs diff --git a/phy/cluster/manual/tests/test_store.py b/phy/cluster/manual/tests/test_store.py index e4ec824cf..e8563a64a 100644 --- a/phy/cluster/manual/tests/test_store.py +++ b/phy/cluster/manual/tests/test_store.py @@ -65,7 +65,9 @@ def _check2(arr, *shape): assert cs.mean_features(1).shape == (nc, nfpc) assert cs.mean_waveforms(1).shape == (nsw, nc) + # Limits. assert 0 < cs.waveform_lim() < 1 + assert cs.trace_lim().shape == (1, nc) # Statistics. assert 1 <= len(cs.best_channels(1)) <= nc From d065376dd6bd0b0aba2b1370fdc925a166249c5e Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 20 Dec 2015 21:21:17 +0100 Subject: [PATCH 0860/1059] Rename trace_lim to mean_traces --- phy/cluster/manual/store.py | 6 +++--- phy/cluster/manual/tests/test_store.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/phy/cluster/manual/store.py b/phy/cluster/manual/store.py index 9fd119062..3521d3a37 100644 --- a/phy/cluster/manual/store.py +++ b/phy/cluster/manual/store.py @@ -78,7 +78,7 @@ def create_cluster_store(model, selector=None, context=None): 'features': 10000, 'waveforms': 100, 'waveform_lim': 1000, # used to compute the waveform bounds - 'trace_lim': 10000, + 'mean_traces': 10000, } def select(cluster_id, n=None): @@ -214,8 +214,8 @@ def mean_masked_features_score(cluster_0, cluster_1): # ------------------------------------------------------------------------- @cs.add - def trace_lim(): - n = max_n_spikes_per_cluster['trace_lim'] + def mean_traces(): + n = max_n_spikes_per_cluster['mean_traces'] mt = model.traces[:n, model.channel_order].mean(axis=0) return mt.astype(model.traces.dtype) diff --git a/phy/cluster/manual/tests/test_store.py b/phy/cluster/manual/tests/test_store.py index e8563a64a..9d06fcff6 100644 --- a/phy/cluster/manual/tests/test_store.py +++ b/phy/cluster/manual/tests/test_store.py @@ -67,7 +67,7 @@ def _check2(arr, *shape): # Limits. assert 0 < cs.waveform_lim() < 1 - assert cs.trace_lim().shape == (1, nc) + assert cs.mean_traces().shape == (1, nc) # Statistics. assert 1 <= len(cs.best_channels(1)) <= nc From b5fac5905cb4ec6b51e40ddb308da91972bc025e Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 20 Dec 2015 21:24:57 +0100 Subject: [PATCH 0861/1059] WIP: update trace view --- phy/cluster/manual/tests/test_views.py | 3 ++- phy/cluster/manual/views.py | 21 ++++++++++----------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index 51c5ba25f..79c275c3b 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -167,9 +167,10 @@ def test_trace_view_no_spikes(qtbot): sample_rate = 2000. traces = artificial_traces(n_samples, n_channels) + mt = np.atleast_2d(traces.mean(axis=0)) # Create the view. - v = TraceView(traces=traces, sample_rate=sample_rate) + v = TraceView(traces=traces, sample_rate=sample_rate, mean_traces=mt) _show(qtbot, v) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 21b4f8ca0..217c852e8 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -250,9 +250,6 @@ def __init__(self, type, channel_idx=None, key=None, button=None): class WaveformView(ManualClusteringView): - # max_n_spikes_per_cluster = 100 - # normalization_percentile = .95 - # normalization_n_spikes = 1000 overlap = False scaling_coeff = 1.1 @@ -535,7 +532,6 @@ def on_close(): # ----------------------------------------------------------------------------- class TraceView(ManualClusteringView): - n_samples_for_mean = 10000 interval_duration = .5 # default duration of the interval shift_amount = .1 scaling_coeff = 1.1 @@ -553,10 +549,11 @@ def __init__(self, sample_rate=None, spike_times=None, spike_clusters=None, - masks=None, + masks=None, # full array of masks channel_order=None, n_samples_per_spike=None, scaling=None, + mean_traces=None, **kwargs): # Sample rate. @@ -578,12 +575,9 @@ def __init__(self, self.channel_order = (channel_order if channel_order is not None else slice(None, None, None)) - # Compute the mean traces in order to detrend the traces. - k = max(1, self.n_samples // self.n_samples_for_mean) - # NOTE: the virtual memory mapped traces only works on contiguous - # data so we cannot load ::k here. - self.mean_traces = self.traces[:k, self.channel_order].mean(axis=0) - self.mean_traces = self.mean_traces.astype(traces.dtype) + # Used to detrend the traces. + assert mean_traces.shape == (1, self.n_channels) + self.mean_traces = mean_traces # Number of samples per spike. self.n_samples_per_spike = (n_samples_per_spike or @@ -848,6 +842,10 @@ def decrease(self): class TraceViewPlugin(IPlugin): def attach_to_gui(self, gui, model=None, state=None): s, = state.get_view_params('TraceView', 'scaling') + + cs = gui.request('cluster_store') + assert cs # We need the cluster store to retrieve the data. + view = TraceView(traces=model.traces, sample_rate=model.sample_rate, spike_times=model.spike_times, @@ -855,6 +853,7 @@ def attach_to_gui(self, gui, model=None, state=None): masks=model.masks, channel_order=model.channel_order, scaling=s, + mean_traces=cs.mean_traces(), ) view.attach(gui) From 5ee3269faa60b66c706e5bbfa390fa0babe842c6 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 20 Dec 2015 21:30:06 +0100 Subject: [PATCH 0862/1059] Add background features in store --- phy/cluster/manual/store.py | 12 ++++++++++++ phy/cluster/manual/tests/test_store.py | 4 ++++ 2 files changed, 16 insertions(+) diff --git a/phy/cluster/manual/store.py b/phy/cluster/manual/store.py index 3521d3a37..dd2cb24d8 100644 --- a/phy/cluster/manual/store.py +++ b/phy/cluster/manual/store.py @@ -76,6 +76,7 @@ def create_cluster_store(model, selector=None, context=None): max_n_spikes_per_cluster = { 'masks': 1000, 'features': 10000, + 'background_features': 10000, 'waveforms': 100, 'waveform_lim': 1000, # used to compute the waveform bounds 'mean_traces': 10000, @@ -125,6 +126,17 @@ def features(cluster_id): assert features.ndim == 3 return spike_ids, features + @cs.add + def background_features(): + n = max_n_spikes_per_cluster['background_features'] + k = max(1, model.n_spikes // n) + features = model.features[::k] + spike_ids = np.arange(0, model.n_spikes, k) + assert spike_ids.shape == (features.shape[0],) + assert features.ndim == 3 + assert features.shape[0] <= n + return spike_ids, features + @cs.add @concat def waveforms(cluster_id): diff --git a/phy/cluster/manual/tests/test_store.py b/phy/cluster/manual/tests/test_store.py index 9d06fcff6..ffffb5878 100644 --- a/phy/cluster/manual/tests/test_store.py +++ b/phy/cluster/manual/tests/test_store.py @@ -55,6 +55,10 @@ def _check2(arr, *shape): _check((spike_ids, w), ns, nsw, nc) _check((spike_ids, m), ns, nc) + spike_ids, bgf = cs.background_features() + assert bgf.ndim == 3 + assert spike_ids.shape == (bgf.shape[0],) + # Test concat multiple clusters. spike_ids, fm = cs.features_masks([1, 2]) assert len(spike_ids) == ns + ns2 From b4d9ed3eb362411438a322777ab83ebf9932bc3d Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 20 Dec 2015 21:38:10 +0100 Subject: [PATCH 0863/1059] Background features and masks --- phy/cluster/manual/store.py | 11 +++++++---- phy/cluster/manual/tests/test_store.py | 7 +++++-- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/phy/cluster/manual/store.py b/phy/cluster/manual/store.py index dd2cb24d8..960f3c8f8 100644 --- a/phy/cluster/manual/store.py +++ b/phy/cluster/manual/store.py @@ -76,7 +76,7 @@ def create_cluster_store(model, selector=None, context=None): max_n_spikes_per_cluster = { 'masks': 1000, 'features': 10000, - 'background_features': 10000, + 'background_features_masks': 10000, 'waveforms': 100, 'waveform_lim': 1000, # used to compute the waveform bounds 'mean_traces': 10000, @@ -127,15 +127,18 @@ def features(cluster_id): return spike_ids, features @cs.add - def background_features(): - n = max_n_spikes_per_cluster['background_features'] + def background_features_masks(): + n = max_n_spikes_per_cluster['background_features_masks'] k = max(1, model.n_spikes // n) features = model.features[::k] + masks = model.masks[::k] spike_ids = np.arange(0, model.n_spikes, k) assert spike_ids.shape == (features.shape[0],) assert features.ndim == 3 assert features.shape[0] <= n - return spike_ids, features + assert masks.ndim == 2 + assert masks.shape[0] == features.shape[0] + return spike_ids, features, masks @cs.add @concat diff --git a/phy/cluster/manual/tests/test_store.py b/phy/cluster/manual/tests/test_store.py index ffffb5878..dcb1aa7ed 100644 --- a/phy/cluster/manual/tests/test_store.py +++ b/phy/cluster/manual/tests/test_store.py @@ -55,9 +55,12 @@ def _check2(arr, *shape): _check((spike_ids, w), ns, nsw, nc) _check((spike_ids, m), ns, nc) - spike_ids, bgf = cs.background_features() + spike_ids, bgf, bgm = cs.background_features_masks() assert bgf.ndim == 3 - assert spike_ids.shape == (bgf.shape[0],) + assert bgf.shape[1:] == (nc, nfpc) + assert bgm.ndim == 2 + assert bgm.shape[1] == nc + assert spike_ids.shape == (bgf.shape[0],) == (bgm.shape[0],) # Test concat multiple clusters. spike_ids, fm = cs.features_masks([1, 2]) From f6365f50d8817004e5220ec336183b6e2d63845a Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 20 Dec 2015 22:02:22 +0100 Subject: [PATCH 0864/1059] Add feature_lim in cluster store --- phy/cluster/manual/store.py | 13 +++++++++++++ phy/cluster/manual/tests/test_store.py | 1 + 2 files changed, 14 insertions(+) diff --git a/phy/cluster/manual/store.py b/phy/cluster/manual/store.py index 960f3c8f8..1b67a5ad5 100644 --- a/phy/cluster/manual/store.py +++ b/phy/cluster/manual/store.py @@ -79,6 +79,7 @@ def create_cluster_store(model, selector=None, context=None): 'background_features_masks': 10000, 'waveforms': 100, 'waveform_lim': 1000, # used to compute the waveform bounds + 'feature_lim': 1000, # used to compute the waveform bounds 'mean_traces': 10000, } @@ -126,6 +127,18 @@ def features(cluster_id): assert features.ndim == 3 return spike_ids, features + @cs.add + def feature_lim(percentile=95): + """Return the 95% percentile of all feature amplitudes.""" + # TODO: refactor with waveforms and _get_data_bounds + k = max(1, model.n_spikes // max_n_spikes_per_cluster['feature_lim']) + w = np.abs(model.features[::k]) + n = w.shape[0] + w = w.reshape((n, -1)) + w = w.max(axis=1) + m = np.percentile(w, percentile) + return m + @cs.add def background_features_masks(): n = max_n_spikes_per_cluster['background_features_masks'] diff --git a/phy/cluster/manual/tests/test_store.py b/phy/cluster/manual/tests/test_store.py index dcb1aa7ed..168f72923 100644 --- a/phy/cluster/manual/tests/test_store.py +++ b/phy/cluster/manual/tests/test_store.py @@ -74,6 +74,7 @@ def _check2(arr, *shape): # Limits. assert 0 < cs.waveform_lim() < 1 + assert 0 < cs.feature_lim() < 1 assert cs.mean_traces().shape == (1, nc) # Statistics. From 69484d6769e87d4887bf5dcd352daeefa2ed259d Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 20 Dec 2015 22:08:21 +0100 Subject: [PATCH 0865/1059] Update cluster_store.features_masks --- phy/cluster/manual/store.py | 7 ++++++- phy/cluster/manual/tests/test_store.py | 8 +++++--- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/phy/cluster/manual/store.py b/phy/cluster/manual/store.py index 1b67a5ad5..7bf22ade3 100644 --- a/phy/cluster/manual/store.py +++ b/phy/cluster/manual/store.py @@ -116,8 +116,13 @@ def masks(cluster_id): def features_masks(cluster_id): spike_ids = select(cluster_id, max_n_spikes_per_cluster['features']) fm = np.atleast_3d(model.features_masks[spike_ids]) + ns = fm.shape[0] + nc = model.n_channels + nfpc = model.n_features_per_channel assert fm.ndim == 3 - return spike_ids, fm + f = fm[..., 0].reshape((ns, nc, nfpc)) + m = fm[:, ::nfpc, 1] + return spike_ids, f, m @cs.add @concat diff --git a/phy/cluster/manual/tests/test_store.py b/phy/cluster/manual/tests/test_store.py index 168f72923..902f275dc 100644 --- a/phy/cluster/manual/tests/test_store.py +++ b/phy/cluster/manual/tests/test_store.py @@ -50,11 +50,12 @@ def _check2(arr, *shape): _check(cs.features(1), ns, nc, nfpc) _check(cs.waveforms(1), ns, nsw, nc) - _check(cs.features_masks(1), ns, nc * nfpc, 2) + # Waveforms masks. spike_ids, w, m = cs.waveforms_masks(1) _check((spike_ids, w), ns, nsw, nc) _check((spike_ids, m), ns, nc) + # Background feature masks. spike_ids, bgf, bgm = cs.background_features_masks() assert bgf.ndim == 3 assert bgf.shape[1:] == (nc, nfpc) @@ -63,9 +64,10 @@ def _check2(arr, *shape): assert spike_ids.shape == (bgf.shape[0],) == (bgm.shape[0],) # Test concat multiple clusters. - spike_ids, fm = cs.features_masks([1, 2]) + spike_ids, f, m = cs.features_masks([1, 2]) assert len(spike_ids) == ns + ns2 - assert fm.shape == (ns + ns2, nc * nfpc, 2) + assert f.shape == (ns + ns2, nc, nfpc) + assert m.shape == (ns + ns2, nc) # Test means. assert cs.mean_masks(1).shape == (nc,) From 5d8b0970af1b66b2d88b06c589d8af68fce6f0b1 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 20 Dec 2015 22:09:37 +0100 Subject: [PATCH 0866/1059] WIP: update feature view --- phy/cluster/manual/views.py | 106 +++++++++++++++++++----------------- 1 file changed, 56 insertions(+), 50 deletions(-) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 217c852e8..b0a554c9d 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -177,7 +177,6 @@ def __init__(self, shortcuts=None, **kwargs): # Keep track of the selected clusters and spikes. self.cluster_ids = None - self.spike_ids = None super(ManualClusteringView, self).__init__(**kwargs) self.events.add(status=StatusEvent) @@ -757,9 +756,8 @@ def set_interval(self, interval, change_status=True): self.build() self.update() - def on_select(self, cluster_ids=None, **kwargs): - super(TraceView, self).on_select(cluster_ids=cluster_ids, - **kwargs) + def on_select(self, cluster_ids=None): + super(TraceView, self).on_select(cluster_ids) self.set_interval(self.interval, change_status=False) def attach(self, gui): @@ -915,10 +913,6 @@ def _project_mask_depth(dim, masks, spike_clusters_rel=None, n_clusters=None): class FeatureView(ManualClusteringView): - max_n_spikes_per_cluster = 100000 - normalization_percentile = .95 - normalization_n_spikes = 1000 - n_spikes_bg = 10000 _default_marker_size = 3. _feature_scaling = 1. @@ -928,17 +922,30 @@ class FeatureView(ManualClusteringView): } def __init__(self, - features=None, - masks=None, + features_masks=None, # function cluster_id => (spk, f, m) + background_features_masks=None, # (spk, f, m) spike_times=None, spike_clusters=None, + n_channels=None, + n_features_per_channel=None, + feature_lim=None, **kwargs): - assert len(features.shape) == 3 - self.n_spikes, self.n_channels, self.n_features = features.shape - self.n_cols = self.n_features + 1 + assert features_masks + self.features_masks = features_masks + + # This is a tuple (spikes, features, masks). + self.background_features_masks = background_features_masks + + self.n_features_per_channel = n_features_per_channel + assert n_channels > 0 + self.n_channels = n_channels + + self.n_spikes = spike_times.shape[0] + assert self.n_spikes >= 0 + + self.n_cols = self.n_features_per_channel + 1 self.shape = (self.n_cols, self.n_cols) - self.features = features # Initialize the view. super(FeatureView, self).__init__(layout='grid', @@ -946,17 +953,7 @@ def __init__(self, **kwargs) # Feature normalization. - self.data_bounds = _get_data_bounds(features, - self.normalization_n_spikes, - self.normalization_percentile) - - # Masks. - self.masks = masks - - # Background spikes. - k = max(1, self.n_spikes // self.n_spikes_bg) - self.spike_ids_bg = slice(None, None, k) - self.masks_bg = self.masks[self.spike_ids_bg] + self.data_bounds = [-1, -feature_lim, +1, +feature_lim] # Spike clusters. assert spike_clusters.shape == (self.n_spikes,) @@ -990,15 +987,12 @@ def add_attribute(self, name, values, top_left=True): if top_left: self.top_left_attribute = name - def _get_feature(self, dim, spike_ids=None): - f = self.features[spike_ids] - assert f.ndim == 3 - + def _get_feature(self, dim, spike_ids, f): if dim in self.attributes: # Extra features like time. values, _ = self.attributes[dim] values = values[spike_ids] - assert values.shape == (f.shape[0],) + # assert values.shape == (f.shape[0],) return values else: assert len(dim) == 2 @@ -1034,8 +1028,7 @@ def _plot_features(self, i, j, x_dim, y_dim, x, y, n_clusters = len(self.cluster_ids) # Retrieve the data bounds. - data_bounds = self._get_dim_bounds(x_dim[i, j], - y_dim[i, j]) + data_bounds = self._get_dim_bounds(x_dim[i, j], y_dim[i, j]) # Retrieve the masks and depth. mx, dx = _project_mask_depth(x_dim[i, j], masks, @@ -1080,20 +1073,26 @@ def clear_channels(self): self.x_channels = self.y_channels = None self.on_select() - def on_select(self, cluster_ids=None, **kwargs): - super(FeatureView, self).on_select(cluster_ids=cluster_ids, - **kwargs) - cluster_ids, spike_ids = self.cluster_ids, self.spike_ids - n_spikes = len(spike_ids) - if n_spikes == 0: + def on_select(self, cluster_ids=None): + super(FeatureView, self).on_select(cluster_ids) + cluster_ids = self.cluster_ids + n_clusters = len(cluster_ids) + if n_clusters == 0: return - # Get the masks for the selected spikes. - masks = self.masks[spike_ids] - sc = _get_spike_clusters_rel(self.spike_clusters, - spike_ids, + # Get the spikes, features, masks. + spike_ids, f, masks = self.features_masks(cluster_ids) + assert f.ndim == 3 + assert masks.ndim == 2 + assert spike_ids.shape[0] == f.shape[0] == masks.shape[0] + + # Get the spike clusters. + sc = _get_spike_clusters_rel(self.spike_clusters, spike_ids, cluster_ids) + # Get the background features. + spike_ids_bg, features_bg, masks_bg = self.background_features_masks + # Select the dimensions. # TODO: toggle automatic selection of the channels x_ch, y_ch = self._get_channel_dims(cluster_ids) @@ -1123,16 +1122,18 @@ def on_select(self, cluster_ids=None, **kwargs): for j in range(self.n_cols): # Retrieve the x and y values for the subplot. - x = self._get_feature(x_dim[i, j], self.spike_ids) - y = self._get_feature(y_dim[i, j], self.spike_ids) + x = self._get_feature(x_dim[i, j], spike_ids, f) + y = self._get_feature(y_dim[i, j], spike_ids, f) # Retrieve the x and y values for the background spikes. - x_bg = self._get_feature(x_dim[i, j], self.spike_ids_bg) - y_bg = self._get_feature(y_dim[i, j], self.spike_ids_bg) + x_bg = self._get_feature(x_dim[i, j], spike_ids_bg, + features_bg) + y_bg = self._get_feature(y_dim[i, j], spike_ids_bg, + features_bg) # Background features. self._plot_features(i, j, x_dim, y_dim, x_bg, y_bg, - masks=self.masks_bg) + masks=masks_bg) # Cluster features. self._plot_features(i, j, x_dim, y_dim, x, y, masks=masks, @@ -1192,11 +1193,16 @@ def feature_scaling(self, value): class FeatureViewPlugin(IPlugin): def attach_to_gui(self, gui, model=None, state=None): - - view = FeatureView(features=model.features, - masks=model.masks, + cs = gui.request('cluster_store') + assert cs + bg = cs.background_features_masks() + view = FeatureView(features_masks=cs.features_masks, + background_features_masks=bg, spike_clusters=model.spike_clusters, spike_times=model.spike_times, + n_channels=model.n_channels, + n_features_per_channel=model.n_features_per_channel, + feature_lim=cs.feature_lim(), ) view.attach(gui) From 1c42b21de88455c5044cf62c3b7773fcefc52014 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 20 Dec 2015 22:19:46 +0100 Subject: [PATCH 0867/1059] WIP: update correlogram view --- phy/cluster/manual/store.py | 15 +++++++++++++++ phy/cluster/manual/tests/test_store.py | 3 --- phy/cluster/manual/views.py | 22 ++++------------------ 3 files changed, 19 insertions(+), 21 deletions(-) diff --git a/phy/cluster/manual/store.py b/phy/cluster/manual/store.py index 7bf22ade3..958c560db 100644 --- a/phy/cluster/manual/store.py +++ b/phy/cluster/manual/store.py @@ -24,6 +24,21 @@ logger = logging.getLogger(__name__) +# ----------------------------------------------------------------------------- +# Utils +# ----------------------------------------------------------------------------- + +def _get_data_bounds(arr, n_spikes=None, percentile=None): + n = arr.shape[0] + k = max(1, n // n_spikes) if n_spikes else 1 + arr = np.abs(arr[::k]) + n = arr.shape[0] + arr = arr.reshape((n, -1)) + arr = arr.max(axis=1) + m = np.percentile(arr, percentile) + return m + + # ----------------------------------------------------------------------------- # Cluster statistics # ----------------------------------------------------------------------------- diff --git a/phy/cluster/manual/tests/test_store.py b/phy/cluster/manual/tests/test_store.py index 902f275dc..c354a27ac 100644 --- a/phy/cluster/manual/tests/test_store.py +++ b/phy/cluster/manual/tests/test_store.py @@ -42,9 +42,6 @@ def _check(out, *shape): assert spikes.shape[0] == shape[0] assert arr.shape == shape - def _check2(arr, *shape): - assert arr.shape == shape - # Model data. _check(cs.masks(1), ns, nc) _check(cs.features(1), ns, nc, nfpc) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index b0a554c9d..f5b118eee 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -72,18 +72,6 @@ def _extract_wave(traces, spk, mask, wave_len=None): return data, channels -def _get_data_bounds(arr, n_spikes=None, percentile=None): - # TODO: move to cluster store. - n = arr.shape[0] - k = max(1, n // n_spikes) if n_spikes else 1 - w = np.abs(arr[::k]) - n = w.shape[0] - w = w.reshape((n, -1)) - w = w.max(axis=1) - m = np.percentile(w, percentile) - return [-1, -m, +1, +m] - - def _get_spike_clusters_rel(spike_clusters, spike_ids, cluster_ids): # Relative spike clusters. # NOTE: the order of the clusters in cluster_ids matters. @@ -1304,13 +1292,11 @@ def _compute_correlograms(self, cluster_ids): return ccg - def on_select(self, cluster_ids=None, **kwargs): - super(CorrelogramView, self).on_select(cluster_ids=cluster_ids, - **kwargs) - cluster_ids, spike_ids = self.cluster_ids, self.spike_ids + def on_select(self, cluster_ids=None): + super(CorrelogramView, self).on_select(cluster_ids) + cluster_ids = self.cluster_ids n_clusters = len(cluster_ids) - n_spikes = len(spike_ids) - if n_spikes == 0: + if n_clusters == 0: return ccg = self._compute_correlograms(cluster_ids) From 648bdefc695809f29834703592596bc3ee629780 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 20 Dec 2015 22:25:18 +0100 Subject: [PATCH 0868/1059] Fixes --- phy/cluster/manual/store.py | 1 - phy/cluster/manual/views.py | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/phy/cluster/manual/store.py b/phy/cluster/manual/store.py index 958c560db..017605ea4 100644 --- a/phy/cluster/manual/store.py +++ b/phy/cluster/manual/store.py @@ -168,7 +168,6 @@ def background_features_masks(): spike_ids = np.arange(0, model.n_spikes, k) assert spike_ids.shape == (features.shape[0],) assert features.ndim == 3 - assert features.shape[0] <= n assert masks.ndim == 2 assert masks.shape[0] == features.shape[0] return spike_ids, features, masks diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index f5b118eee..f47ec1166 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -563,8 +563,8 @@ def __init__(self, else slice(None, None, None)) # Used to detrend the traces. - assert mean_traces.shape == (1, self.n_channels) - self.mean_traces = mean_traces + self.mean_traces = np.atleast_2d(mean_traces) + assert self.mean_traces.shape == (1, self.n_channels) # Number of samples per spike. self.n_samples_per_spike = (n_samples_per_spike or From 10e9d19b49656090860a8ec92c1f728f77a1abfd Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 20 Dec 2015 22:45:11 +0100 Subject: [PATCH 0869/1059] Fixes --- phy/cluster/manual/store.py | 6 ++++-- phy/io/context.py | 3 ++- phy/io/tests/test_context.py | 10 +--------- 3 files changed, 7 insertions(+), 12 deletions(-) diff --git a/phy/cluster/manual/store.py b/phy/cluster/manual/store.py index 017605ea4..bda0653dd 100644 --- a/phy/cluster/manual/store.py +++ b/phy/cluster/manual/store.py @@ -90,8 +90,8 @@ def create_cluster_store(model, selector=None, context=None): # TODO: make this configurable. max_n_spikes_per_cluster = { 'masks': 1000, - 'features': 10000, - 'background_features_masks': 10000, + 'features': 1000, + 'background_features_masks': 1000, 'waveforms': 100, 'waveform_lim': 1000, # used to compute the waveform bounds 'feature_lim': 1000, # used to compute the waveform bounds @@ -252,6 +252,8 @@ def mean_masked_features_score(cluster_0, cluster_1): mm0 = cs.mean_masks(cluster_0) mm1 = cs.mean_masks(cluster_1) nfpc = model.n_features_per_channel + logger.debug("Computing the similarity of clusters %d, %d.", + cluster_0, cluster_1) d = get_mean_masked_features_distance(mf0, mf1, mm0, mm1, n_features_per_channel=nfpc) s = 1. / max(1e-10, d) diff --git a/phy/io/context.py b/phy/io/context.py index 943180008..897051ded 100644 --- a/phy/io/context.py +++ b/phy/io/context.py @@ -326,7 +326,8 @@ def __setstate__(self, state): class ContextPlugin(IPlugin): def attach_to_gui(self, gui, model=None, state=None): # Create the computing context. - gui.context = Context(op.join(op.dirname(model.path), '.phy/')) + gui.register(Context(op.join(op.dirname(model.path), '.phy/')), + name='context') #------------------------------------------------------------------------------ diff --git a/phy/io/tests/test_context.py b/phy/io/tests/test_context.py index dfdfe8646..e1a4cb724 100644 --- a/phy/io/tests/test_context.py +++ b/phy/io/tests/test_context.py @@ -15,11 +15,10 @@ from pytest import yield_fixture, mark, raises from six.moves import cPickle -from ..context import (Context, ContextPlugin, Task, +from ..context import (Context, Task, _iter_chunks_dask, write_array, read_array, _fullname, ) -from phy.utils import Bunch #------------------------------------------------------------------------------ @@ -177,13 +176,6 @@ def test_pickle_cache(tempdir, parallel_context): assert ctx.cache_dir == parallel_context.cache_dir -def test_context_plugin(tempdir): - gui = Bunch() - path = op.join(tempdir, 'model.ext') - ContextPlugin().attach_to_gui(gui, model=Bunch(path=path), state=Bunch()) - assert op.dirname(path) + '/.phy' in gui.context.cache_dir - - #------------------------------------------------------------------------------ # Test map #------------------------------------------------------------------------------ From f05e8da275d1e8b5076b249ec64dd09263ec9f55 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 21 Dec 2015 09:09:40 +0100 Subject: [PATCH 0870/1059] Refactor _get_data_lim() --- phy/cluster/manual/store.py | 33 ++++++++++----------------------- 1 file changed, 10 insertions(+), 23 deletions(-) diff --git a/phy/cluster/manual/store.py b/phy/cluster/manual/store.py index bda0653dd..349c83cfd 100644 --- a/phy/cluster/manual/store.py +++ b/phy/cluster/manual/store.py @@ -28,15 +28,13 @@ # Utils # ----------------------------------------------------------------------------- -def _get_data_bounds(arr, n_spikes=None, percentile=None): +def _get_data_lim(arr, n_spikes=None, percentile=None): n = arr.shape[0] k = max(1, n // n_spikes) if n_spikes else 1 arr = np.abs(arr[::k]) n = arr.shape[0] arr = arr.reshape((n, -1)) - arr = arr.max(axis=1) - m = np.percentile(arr, percentile) - return m + return arr.max() # ----------------------------------------------------------------------------- @@ -148,16 +146,10 @@ def features(cluster_id): return spike_ids, features @cs.add - def feature_lim(percentile=95): - """Return the 95% percentile of all feature amplitudes.""" - # TODO: refactor with waveforms and _get_data_bounds - k = max(1, model.n_spikes // max_n_spikes_per_cluster['feature_lim']) - w = np.abs(model.features[::k]) - n = w.shape[0] - w = w.reshape((n, -1)) - w = w.max(axis=1) - m = np.percentile(w, percentile) - return m + def feature_lim(): + """Return the max of a subset of the feature amplitudes.""" + return _get_data_lim(model.features, + max_n_spikes_per_cluster['feature_lim']) @cs.add def background_features_masks(): @@ -181,15 +173,10 @@ def waveforms(cluster_id): return spike_ids, waveforms @cs.add - def waveform_lim(percentile=95): - """Return the 95% percentile of all waveform amplitudes.""" - k = max(1, model.n_spikes // max_n_spikes_per_cluster['waveform_lim']) - w = np.abs(model.waveforms[::k]) - n = w.shape[0] - w = w.reshape((n, -1)) - w = w.max(axis=1) - m = np.percentile(w, percentile) - return m + def waveform_lim(): + """Return the max of a subset of the waveform amplitudes.""" + return _get_data_lim(model.waveforms, + max_n_spikes_per_cluster['waveform_lim']) @cs.add @concat From d2bd99892bbba46fce90f6c9f95ed773f92b1181 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 21 Dec 2015 09:12:12 +0100 Subject: [PATCH 0871/1059] Clear table with empty data --- phy/gui/static/table.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/phy/gui/static/table.js b/phy/gui/static/table.js index 87c1e32bc..f9818f27c 100644 --- a/phy/gui/static/table.js +++ b/phy/gui/static/table.js @@ -27,7 +27,7 @@ Table.prototype.setData = function(data) { data.cols: list of column names data.items: list of rows (each row is an object {col: value}) */ - if (data.items.length == 0) return; + // if (data.items.length == 0) return; // Reinitialize the state. this.selected = []; From 35ca61a2ff7ce0721e001ffdbdd21b5da26b12e0 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 21 Dec 2015 09:41:00 +0100 Subject: [PATCH 0872/1059] Use attribute instead of texture for plot colors --- phy/plot/glsl/plot.vert | 6 ++---- phy/plot/visuals.py | 15 ++++++++------- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/phy/plot/glsl/plot.vert b/phy/plot/glsl/plot.vert index a7a0dbf00..05583b8b1 100644 --- a/phy/plot/glsl/plot.vert +++ b/phy/plot/glsl/plot.vert @@ -1,11 +1,9 @@ #include "utils.glsl" attribute vec3 a_position; +attribute vec4 a_color; attribute float a_signal_index; // 0..n_signals-1 -uniform sampler2D u_plot_colors; -uniform float n_signals; - varying vec4 v_color; varying float v_signal_index; @@ -14,6 +12,6 @@ void main() { gl_Position = transform(xy); gl_Position.z = a_position.z; - v_color = fetch_texture(a_signal_index, u_plot_colors, n_signals); + v_color = a_color; v_signal_index = a_signal_index; } diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index c73fa81d9..f5dadf772 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -8,7 +8,6 @@ #------------------------------------------------------------------------------ import numpy as np -from vispy.gloo import Texture2D from .base import BaseVisual from .transform import Range, NDC @@ -208,6 +207,12 @@ def set_data(self, *args, **kwargs): pos[:, 1] = y.ravel() assert pos.shape == (n, 2) + # Generate the color attribute. + color = data.color + assert color.shape == (n_signals, 4) + color = np.repeat(color, n_samples, axis=0) + assert color.shape == (n, 4) + # Generate signal index. signal_index = np.repeat(np.arange(n_signals), n_samples) signal_index = _get_array(signal_index, (n, 1)).astype(np.float32) @@ -221,13 +226,9 @@ def set_data(self, *args, **kwargs): # Position and depth. depth = np.repeat(data.depth, n_samples, axis=0) self.program['a_position'] = np.c_[pos_tr, depth] - + self.program['a_color'] = color self.program['a_signal_index'] = signal_index - self.program['u_plot_colors'] = Texture2D(_get_texture(data.color, - PlotVisual._default_color, - n_signals, - [0, 1])) - self.program['n_signals'] = n_signals + # self.program['n_signals'] = n_signals class HistogramVisual(BaseVisual): From ea0a647a48d04883cf4dfe6ee9feaa479b535fbd Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 21 Dec 2015 09:49:12 +0100 Subject: [PATCH 0873/1059] Increase coverage --- phy/cluster/manual/tests/conftest.py | 7 +++++-- phy/cluster/manual/tests/test_views.py | 3 ++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/phy/cluster/manual/tests/conftest.py b/phy/cluster/manual/tests/conftest.py index de5c9d2ec..8cb476d8e 100644 --- a/phy/cluster/manual/tests/conftest.py +++ b/phy/cluster/manual/tests/conftest.py @@ -6,6 +6,8 @@ # Imports #------------------------------------------------------------------------------ +import os.path as op + import numpy as np from pytest import yield_fixture @@ -46,8 +48,8 @@ def similarity(): yield lambda c, d: c * 1.01 + d -@yield_fixture(scope='session') -def model(): +@yield_fixture +def model(tempdir): model = Bunch() n_spikes = 51 @@ -57,6 +59,7 @@ def model(): n_clusters = 3 n_features = 4 + model.path = op.join(tempdir, 'test') model.n_channels = n_channels # TODO: test with permutation and dead channels model.channel_order = None diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index 79c275c3b..4752b9068 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -44,7 +44,8 @@ def _test_view(view_name, model=None, tempdir=None): state.save() # Create the GUI. - plugins = ['ClusterStorePlugin', + plugins = ['ContextPlugin', + 'ClusterStorePlugin', 'ManualClusteringPlugin', view_name + 'Plugin'] gui = create_gui(model=model, plugins=plugins, config_dir=tempdir) From 71d1ad4273145676734c5860dc975ff6d0bd4044 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 21 Dec 2015 09:55:13 +0100 Subject: [PATCH 0874/1059] Fixing travis --- phy/cluster/manual/tests/test_store.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/phy/cluster/manual/tests/test_store.py b/phy/cluster/manual/tests/test_store.py index c354a27ac..df36d2bae 100644 --- a/phy/cluster/manual/tests/test_store.py +++ b/phy/cluster/manual/tests/test_store.py @@ -72,7 +72,7 @@ def _check(out, *shape): assert cs.mean_waveforms(1).shape == (nsw, nc) # Limits. - assert 0 < cs.waveform_lim() < 1 + assert 0 < cs.waveform_lim() < 3 assert 0 < cs.feature_lim() < 1 assert cs.mean_traces().shape == (1, nc) From d40b375f1c919c79e7a0b88f7e6bb9cb7d2cb0e9 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 21 Dec 2015 10:05:13 +0100 Subject: [PATCH 0875/1059] Fixing travis --- phy/cluster/manual/tests/test_store.py | 2 +- conftest.py => phy/conftest.py | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename conftest.py => phy/conftest.py (100%) diff --git a/phy/cluster/manual/tests/test_store.py b/phy/cluster/manual/tests/test_store.py index df36d2bae..23483191e 100644 --- a/phy/cluster/manual/tests/test_store.py +++ b/phy/cluster/manual/tests/test_store.py @@ -73,7 +73,7 @@ def _check(out, *shape): # Limits. assert 0 < cs.waveform_lim() < 3 - assert 0 < cs.feature_lim() < 1 + assert 0 < cs.feature_lim() < 3 assert cs.mean_traces().shape == (1, nc) # Statistics. diff --git a/conftest.py b/phy/conftest.py similarity index 100% rename from conftest.py rename to phy/conftest.py From b690aadba47b6371c83c2a538f37991daf8d2fbc Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 21 Dec 2015 10:18:07 +0100 Subject: [PATCH 0876/1059] WIP: tweak manual_clustering API --- phy/cluster/manual/gui_component.py | 7 +++++-- phy/cluster/manual/tests/test_gui_component.py | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 13bd6be84..e0cc9f63d 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -301,11 +301,14 @@ def _emit_select(self, cluster_ids): # Public methods # ------------------------------------------------------------------------- - def add_column(self, func=None, name=None, show=True): + def add_column(self, func=None, name=None, show=True, default=False): if func is None: - return lambda f: self.add_column(f, name=name, show=show) + return lambda f: self.add_column(f, name=name, show=show, + default=default) self.cluster_view.add_column(func, name=name, show=show) self.similarity_view.add_column(func, name=name, show=show) + if default: + self.set_default_sort(name) def set_default_sort(self, name, sort_dir='desc'): logger.debug("Set default sort `%s` %s.", name, sort_dir) diff --git a/phy/cluster/manual/tests/test_gui_component.py b/phy/cluster/manual/tests/test_gui_component.py index bc4206778..017d1ef61 100644 --- a/phy/cluster/manual/tests/test_gui_component.py +++ b/phy/cluster/manual/tests/test_gui_component.py @@ -149,7 +149,7 @@ def test_manual_clustering_split_2(gui, quality, similarity): mc = ManualClustering(spike_clusters) mc.attach(gui) - mc.add_column(quality, name='quality') + mc.add_column(quality, name='quality', default=True) mc.set_default_sort('quality', 'desc') mc.set_similarity_func(similarity) From 483d64e72988e80288a6ac27dad6a70f31d887c4 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 21 Dec 2015 10:23:55 +0100 Subject: [PATCH 0877/1059] Bug fixes --- phy/cluster/manual/gui_component.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index e0cc9f63d..5fdfa2244 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -275,6 +275,7 @@ def on_request_undo_state(up): def _update_cluster_view(self): """Initialize the cluster view with cluster data.""" + logger.log(5, "Update the cluster view.") self.cluster_view.set_rows(self.clustering.cluster_ids) def _update_similarity_view(self): @@ -285,6 +286,7 @@ def _update_similarity_view(self): selection = self.cluster_view.selected if not len(selection): return + logger.log(5, "Update the similarity view.") cluster_id = selection[0] self._best = cluster_id self.similarity_view.set_rows([c for c in self.clustering.cluster_ids @@ -305,12 +307,15 @@ def add_column(self, func=None, name=None, show=True, default=False): if func is None: return lambda f: self.add_column(f, name=name, show=show, default=default) + name = name or func.__name__ + assert name self.cluster_view.add_column(func, name=name, show=show) self.similarity_view.add_column(func, name=name, show=show) if default: self.set_default_sort(name) def set_default_sort(self, name, sort_dir='desc'): + assert name logger.debug("Set default sort `%s` %s.", name, sort_dir) # Set the default sort. self.cluster_view.set_default_sort(name, sort_dir) @@ -321,6 +326,7 @@ def set_default_sort(self, name, sort_dir='desc'): def set_similarity_func(self, f): """Set the similarity function.""" + logger.debug("Set similarity function `%s`.", f.__name__) self.similarity_func = f def on_cluster(self, up): From 3e0150c2bbb407c34507f700412ee15a91406a1a Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 21 Dec 2015 10:38:42 +0100 Subject: [PATCH 0878/1059] Make sure the similarity function returns a scalar --- phy/cluster/manual/gui_component.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 5fdfa2244..4bb309f16 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -10,6 +10,8 @@ from functools import partial import logging +import numpy as np + from ._history import GlobalHistory from ._utils import create_cluster_meta from .clustering import Clustering @@ -327,7 +329,13 @@ def set_default_sort(self, name, sort_dir='desc'): def set_similarity_func(self, f): """Set the similarity function.""" logger.debug("Set similarity function `%s`.", f.__name__) - self.similarity_func = f + + # Make sure the function returns a scalar. + def wrapped(cluster_0, cluster_1): + out = f(cluster_0, cluster_1) + return np.asscalar(out) + + self.similarity_func = wrapped def on_cluster(self, up): """Update the cluster views after clustering actions.""" From 753323ad9dde05b6f0c4091755fd78bee672da00 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 21 Dec 2015 18:18:30 +0100 Subject: [PATCH 0879/1059] WIP --- phy/cluster/manual/store.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/phy/cluster/manual/store.py b/phy/cluster/manual/store.py index 349c83cfd..2fdef27d8 100644 --- a/phy/cluster/manual/store.py +++ b/phy/cluster/manual/store.py @@ -72,16 +72,6 @@ def attach(self, gui): gui.register(self, name='cluster_store') -class ClusterStorePlugin(IPlugin): - def attach_to_gui(self, gui, model=None, state=None): - ctx = gui.request('context') - selector = Selector(spike_clusters=model.spike_clusters, - spikes_per_cluster=model.spikes_per_cluster, - ) - cs = create_cluster_store(model, selector=selector, context=ctx) - cs.attach(gui) - - def create_cluster_store(model, selector=None, context=None): cs = ClusterStore(context=context) @@ -256,3 +246,14 @@ def mean_traces(): return mt.astype(model.traces.dtype) return cs + + +class ClusterStorePlugin(IPlugin): + def attach_to_gui(self, gui, model=None, state=None): + ctx = gui.request('context') + assert ctx + selector = Selector(spike_clusters=model.spike_clusters, + spikes_per_cluster=model.spikes_per_cluster, + ) + cs = create_cluster_store(model, selector=selector, context=ctx) + cs.attach(gui) From 52bbc6309a82cab4f5e38583de8cb0c5a70ef1a4 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 21 Dec 2015 19:30:53 +0100 Subject: [PATCH 0880/1059] Update docs --- docs/cluster-manual.md | 11 ++++++++++- docs/gui.md | 23 ++++++++++++++++++++--- docs/io.md | 29 +++++++++++++++++++++++++++++ mkdocs.yml | 1 + 4 files changed, 60 insertions(+), 4 deletions(-) create mode 100644 docs/io.md diff --git a/docs/cluster-manual.md b/docs/cluster-manual.md index 1cac448af..103053a50 100644 --- a/docs/cluster-manual.md +++ b/docs/cluster-manual.md @@ -150,6 +150,8 @@ The feature view shows the principal components of spikes across multiple dimens ### Trace view +The trace view shows the continuous traces from multiple channels with spikes superimposed. The spikes are in white except those belonging to the selected clusters, which are in the colors of the clusters. + ### Correlogram view The correlogram view computes and shows all pairwise correlograms of a set of clusters. @@ -174,7 +176,7 @@ The main objects are the following: `mc.similarity_view`: the similarity view (derives from `Table`) `mc.actions`: the clustering actions (instance of `Actions`) -In practice, you generally access this object from a GUI plugin, available in `session.manual_clustering`. +Use `gui.request('manual_clustering')` to get the `ManualClustering` instance inside the `attach_to_gui(gui, model=None, state=None)` method of a GUI plugin. ### Cluster and similarity view @@ -206,3 +208,10 @@ When the selection changes, the attached GUI raises the `select(cluster_ids, spi Other events are `cluster(up)` when a clustering action occurs, and `request_save(spike_clusters, cluster_groups)` when the user wants to save the results of the manual clustering session. +## Cluster store + +The **cluster store** contains a library of functions computing data and statistics for every cluster. These functions are cached on disk and possibly in memory. A `ClusterStore` instance is initialized with a `Context` which provides the caching facilities. You can add new functions with the `add(f)` method/decorator. + +The `create_cluster_store()` function creates a cluster store with a built-in library of functions that take the data from a model (for example, the `KwikModel` that works with the Kwik format). + +Use `gui.request('cluster_store')` to get the cluster store instance inside the `attach_to_gui(gui, model=None, state=None)` method of a GUI plugin. diff --git a/docs/gui.md b/docs/gui.md index d9b957540..daf77971c 100644 --- a/docs/gui.md +++ b/docs/gui.md @@ -160,7 +160,6 @@ The following method allows you to check how many views of each class there are: ```python >>> gui.view_count() -{'canvas': 1, 'figurecanvasqtagg': 1, 'htmlwidget': 2} ``` Use the following property to change the status bar: @@ -183,7 +182,7 @@ The object `gs` is a JSON-serializable Python dictionary. ## Adding actions -An **action** is a Python function that can be run by the user by clicking on a button or pressing a keyboard shortcut. You can create an `Actions` object to specify a list of actions attached to a GUI. +An **action** is a Python function that the user can run from the menu bar or with a keyboard shortcut. You can create an `Actions` object to specify a list of actions attached to a GUI. ```python >>> from phy.gui import Actions @@ -198,7 +197,7 @@ Now, if you press *Ctrl+H* in the GUI, you'll see `Hello world!` printed in the Once an action is added, you can call it with `actions.hello()` where `hello` is the name of the action. By default, this is the name of the associated function, but you can also specify the name explicitly with the `name=...` keyword argument in `actions.add()`. -You'll find more details about `actions.add()` in the API reference. +You'll find more details about `actions.add()` in the API reference. For example, use the `menu='MenuName'` keyword argument to add the action to a menu in the menu bar. Every GUI comes with a `default_actions` property which implements actions always available in GUIs: @@ -234,3 +233,21 @@ The GUI provides a convenient system to quickly execute actions without leaving Now, pressing `:c 3-6 hello` followed by the `Enter` keystroke displays `Select [3, 4, 5, 6] with hello` in the console. By convention, multiple arguments are separated by spaces, sequences of numbers are given either with `2,3,5,7` or `3-6` for consecutive numbers. If an alias is not specified when adding the action, you can always use the full action's name. + +## GUI plugins + +To create a specific GUI, you implement functionality in components and create GUI plugins to allow users to activate these components in their GUI. You can also specify the list of default plugins. + +You create a GUI with the function `gui = create_gui(name, model=model, plugins=plugins)`. The model provides the data while the list of plugins defines the functionality of the GUI. + +Plugins create views and actions and have full control on the GUI. Also, they can register objects or functions with `gui.register(obj, name='my_object')`. Other plugins can then retrieve these objects with `gui.request(name)`. This is how plugins can communicate. This function returns `None` if the object is not available. + +Note that the order of how plugins are attached matters, since you can only request components that have been registered in the previously-attached plugins. + +To create a GUI plugin, just define a class deriving from `IPlugin` and implementing `attach_to_gui(gui, model=None, state=None)`. + +## GUI state + +The **GUI state** is a special Python dictionary that holds info and parameters about a particular GUI session, like its position and size, the positions of the widgets, and other user preferences. This state is automatically persisted to disk (in JSON) in the config directory (passed as a parameter in the `create_gui()` function). By default, this is `~/.phy/gui_name/state.json`. + +Plugins can simply add fields to the GUI state and it will be persisted. There are special methods for GUI parameters: `state.save_gui_params()` and `state.load_gui_params()`. diff --git a/docs/io.md b/docs/io.md new file mode 100644 index 000000000..1bd8d3065 --- /dev/null +++ b/docs/io.md @@ -0,0 +1,29 @@ +## Data utilities + +The `phy.io` package contains utilities related to array manipulation, mock datasets, caching, and parallel computing context. + +### Array + +The `array` module contains functions to select subsets of large data arrays, and to obtain the spikes belonging to a set of clusters (notably the `Selector` class) + +### Context + +The `Context` provides facilities to accelerate computations through caching (with **joblib**) and parallel computing (with **ipyparallel**). + +A `Context` is initialized with a cache directory (typically a subdirectory `.phy` within the directory containing the data). You can also provide an `ipy_view` instance to use parallel computing with the ipyparallel package. + +#### Cache + +Use `f = context.cache(f)` to cache a function. By default, the decorated function will be cached on disk in the cache directory, using joblib. NumPy arrays are fully and efficiently supported. + +With the `memcache=True` argument, you can *also* use memory caching. This is interesting for example when caching functions returning a scalar for every cluster. This is the case with the functions computing the quality and similarity of clusters. These functions are called a lot during a manual clustering session. + +#### Parallel computing + +Use the `map()` and `map_async()` to call a function on multiple arguments in sequence or in parallel if an ipyparallel context is available (`ipy_view` keyword in the `Context`'s constructor). + +There is also an experimental `map_dask_array(f, da)` method to map in parallel a function that processes a single chunk of a **dask Array**. The result of every computation unit is saved in a `.npy` file in the cache directory, and the result is a new dask Array that is dynamically memory-mapped from the stack of `.npy` files. **The cache directory needs to be available from all computing units for this method to work** (using a network file system). Doing it this way should mitigate the performance issues with transferring large amounts of data over the network. + +#### Store + +You can store JSON-serializable Python dictionaries with `context.save()` and `context.load()`. The files are saved in the cache directory. NumPy array and Qt buffers are fully supported. You can save the GUI state and geometry there for example. diff --git a/mkdocs.yml b/mkdocs.yml index e23fd38e5..88ae426fa 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -7,6 +7,7 @@ pages: - Plotting: 'plot.md' - Configuration and plugin system: 'config.md' - Command-line interface: 'cli.md' +- Data utilities: 'io.md' - Manual clustering: 'cluster-manual.md' - Analysis functions: 'analysis.md' - API reference: 'api.md' From 4c3078e8b1a9d53fc2797fe13f58aaf401f8c8d8 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 21 Dec 2015 21:53:15 +0100 Subject: [PATCH 0881/1059] WIP: update GUI docs --- docs/cli.md | 92 ----------------------------------------------------- docs/gui.md | 90 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 90 insertions(+), 92 deletions(-) diff --git a/docs/cli.md b/docs/cli.md index a4e8cdebc..2b42a372e 100644 --- a/docs/cli.md +++ b/docs/cli.md @@ -61,95 +61,3 @@ Hello world! ``` When the `phy` CLI is created, the `attach_to_cli(cli)` method of all discovered plugins are called. Refer to the click documentation to create subcommands with phy. - -## Creating a graphical application - -You can use this system to create a graphical application that is launched with `phy some_subcommand`. Moreover, your graphical application can itself accept user-defined plugins. - -Here is a complete example. Write the following in `~/.phy/plugins/mygui.py`: - -``` -import click -from phy import IPlugin -from phy.gui import GUI, HTMLWidget, create_app, run_app, load_gui_plugins -from phy.utils import Bunch - - -class MyGUI(GUI): - def __init__(self, name, plugins=None): - super(MyGUI, self).__init__() - - # We create a widget. - view = HTMLWidget() - view.set_body("Hello %s!" % name) - view.show() - self.add_view(view) - - # We load all plugins attached to that GUI. - session = Bunch(name=name) - load_gui_plugins(self, plugins, session) - - -class MyGUIPlugin(IPlugin): - def attach_to_cli(self, cli): - - @cli.command(name='mygui') - @click.argument('name') - def mygui(name): - - # Create the Qt application. - create_app() - - # Show the GUI. - gui = MyGUI(name) - gui.show() - - # Start the Qt event loop. - run_app() - - # Close the GUI. - gui.close() - del gui -``` - -Now, you can call `phy mygui world` to open a GUI showing `Hello world!`. - -## GUI plugins - -Your users can now create plugins for your graphical application, by creating a plugin with the `attach_to_gui(gui, session)` method. In this method, you can add actions, add views, and do anything provided by the GUI API. - -The `session` object is any Python object passed to the plugins by the GUI. Generally, it is a `Bunch` instance (just a Python dictionary with the additional `bunch.name` syntax) containing any data that you want to pass to the plugins. - -Here is a complete example. There are three steps. - -### Creating the plugin - -First, create a file in `~/.phy/plugins/mygui_plugin.py` with the following: - -``` -from phy import IPlugin -from phy.gui import Actions - - -class MyGUIPlugin(IPlugin): - def attach_to_gui(self, gui, session): - actions = Actions(gui) - - @actions.add(shortcut='a') - def myaction(): - print("Hello %s!" % session.name) -``` - -### Activating the plugin - -Next, add the following line in `~/.phy/phy_config.py`: - -``` -c.MyGUI.plugins = ['MyGUIPlugin'] -``` - -This is the list of the plugin names to activate automatically when creating a `MyGUI` instance. When you create a GUI from Python, you can also pass the list of plugins to activate as follows: `gui = MyGUI(name, plugins=[...])`. - -### Testing the plugin - -Finally, launch the GUI with `phy mygui world` and press `a` in the GUI. It should print `Hello world!` in the console. diff --git a/docs/gui.md b/docs/gui.md index daf77971c..3c2f858a9 100644 --- a/docs/gui.md +++ b/docs/gui.md @@ -250,4 +250,94 @@ To create a GUI plugin, just define a class deriving from `IPlugin` and implemen The **GUI state** is a special Python dictionary that holds info and parameters about a particular GUI session, like its position and size, the positions of the widgets, and other user preferences. This state is automatically persisted to disk (in JSON) in the config directory (passed as a parameter in the `create_gui()` function). By default, this is `~/.phy/gui_name/state.json`. +The GUI state is a `Bunch` instance, which derives from `dict` to support the additional `bunch.name` syntax. + Plugins can simply add fields to the GUI state and it will be persisted. There are special methods for GUI parameters: `state.save_gui_params()` and `state.load_gui_params()`. + + +## Example + +In this example we'll create a graphical application that is launched with `phy some_subcommand` and that can accept user-defined plugins. + +### GUI application and CLI plugin + +First, write the following in `~/.phy/plugins/mygui.py`: + +``` +import click +from phy import IPlugin +from phy.gui import GUI, HTMLWidget, create_app, run_app, load_gui_plugins +from phy.utils import Bunch + + +class MyGUI(GUI): + def __init__(self, name, plugins=None): + super(MyGUI, self).__init__() + + # We create a widget. + view = HTMLWidget() + view.set_body("Hello %s!" % name) + view.show() + self.add_view(view) + + # We load all plugins attached to that GUI. + session = Bunch(name=name) + load_gui_plugins(self, plugins, session) + + +class MyGUIPlugin(IPlugin): + def attach_to_cli(self, cli): + + @cli.command(name='mygui') + @click.argument('name') + def mygui(name): + + # Create the Qt application. + create_app() + + # Show the GUI. + gui = MyGUI(name) + gui.show() + + # Start the Qt event loop. + run_app() + + # Close the GUI. + gui.close() + del gui +``` + +Now, you can call `phy mygui world` to open a GUI showing `Hello world!`. + + +### Creating the plugin + +Now, let's create a plugin for the GUI. Create a file in `~/.phy/plugins/mygui_plugin.py` with the following: + +``` +from phy import IPlugin +from phy.gui import Actions + + +class MyGUIPlugin(IPlugin): + def attach_to_gui(self, gui, model=None, state=None): + actions = Actions(gui) + + @actions.add(shortcut='a') + def myaction(): + print("Hello %s!" % state.name) +``` + +### Activating the plugin + +Next, add the following line in `~/.phy/phy_config.py`: + +``` +c.MyGUI.plugins = ['MyGUIPlugin'] +``` + +This is the list of the plugin names to activate automatically when creating a `MyGUI` instance. When you create a GUI from Python, you can also pass the list of plugins to activate as follows: `gui = MyGUI(name, plugins=[...])`. + +### Testing the plugin + +Finally, launch the GUI with `phy mygui world` and press `a` in the GUI. It should print `Hello world!` in the console. From a9b19ab02bb60df75fd0b9b576f7efd4c38a4bc9 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 21 Dec 2015 22:14:29 +0100 Subject: [PATCH 0882/1059] WIP: update GUI docs --- docs/gui.md | 98 ++++++++++++----------------------------------------- 1 file changed, 22 insertions(+), 76 deletions(-) diff --git a/docs/gui.md b/docs/gui.md index 3c2f858a9..1a28d5aa1 100644 --- a/docs/gui.md +++ b/docs/gui.md @@ -126,7 +126,6 @@ The items [3] have been selected. The items [3, 5] have been selected. ``` - ### Interactivity with Javascript We can use Javascript in an HTML widget, and we can make Python and Javascript communicate. @@ -257,87 +256,34 @@ Plugins can simply add fields to the GUI state and it will be persisted. There a ## Example -In this example we'll create a graphical application that is launched with `phy some_subcommand` and that can accept user-defined plugins. - -### GUI application and CLI plugin - -First, write the following in `~/.phy/plugins/mygui.py`: - -``` -import click -from phy import IPlugin -from phy.gui import GUI, HTMLWidget, create_app, run_app, load_gui_plugins -from phy.utils import Bunch - - -class MyGUI(GUI): - def __init__(self, name, plugins=None): - super(MyGUI, self).__init__() - - # We create a widget. - view = HTMLWidget() - view.set_body("Hello %s!" % name) - view.show() - self.add_view(view) +In this example we'll create a GUI plugin and show how to activate it. - # We load all plugins attached to that GUI. - session = Bunch(name=name) - load_gui_plugins(self, plugins, session) - - -class MyGUIPlugin(IPlugin): - def attach_to_cli(self, cli): - - @cli.command(name='mygui') - @click.argument('name') - def mygui(name): - - # Create the Qt application. - create_app() - - # Show the GUI. - gui = MyGUI(name) - gui.show() - - # Start the Qt event loop. - run_app() - - # Close the GUI. - gui.close() - del gui +```python +>>> from phy import IPlugin +>>> from phy.gui import GUI, HTMLWidget, create_app, run_app, create_gui +>>> from phy.utils import Bunch ``` -Now, you can call `phy mygui world` to open a GUI showing `Hello world!`. - - -### Creating the plugin - -Now, let's create a plugin for the GUI. Create a file in `~/.phy/plugins/mygui_plugin.py` with the following: - +```python +>>> class MyComponent(IPlugin): +... def attach_to_gui(self, gui, model=None, state=None): +... # We create a widget. +... view = HTMLWidget() +... view.set_body("Hello %s!" % model.name) +... view.show() +... gui.add_view(view) +DEBUG:phy.utils.plugin:Register plugin `MyComponent`. ``` -from phy import IPlugin -from phy.gui import Actions - -class MyGUIPlugin(IPlugin): - def attach_to_gui(self, gui, model=None, state=None): - actions = Actions(gui) - - @actions.add(shortcut='a') - def myaction(): - print("Hello %s!" % state.name) +```python +>>> gui = create_gui('MyGUI', model=Bunch(name='world'), plugins=['MyComponent']) +DEBUG:phy.gui.gui:The GUI state file `/Users/cyrille/.phy/MyGUI/state.json` doesn't exist. +DEBUG:phy.gui.gui:Attach plugin `MyComponent` to MyGUI. ``` -### Activating the plugin - -Next, add the following line in `~/.phy/phy_config.py`: - -``` -c.MyGUI.plugins = ['MyGUIPlugin'] +```python +>>> gui.show() +DEBUG:phy.gui.gui:Save the GUI state to `/Users/cyrille/.phy/MyGUI/state.json`. ``` -This is the list of the plugin names to activate automatically when creating a `MyGUI` instance. When you create a GUI from Python, you can also pass the list of plugins to activate as follows: `gui = MyGUI(name, plugins=[...])`. - -### Testing the plugin - -Finally, launch the GUI with `phy mygui world` and press `a` in the GUI. It should print `Hello world!` in the console. +This opens a GUI showing `Hello world!`. From 3b7f99396edb0e841830df40f3e825a75ccf62a2 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 21 Dec 2015 22:21:57 +0100 Subject: [PATCH 0883/1059] WIP: manual clustering docs --- docs/cluster-manual.md | 58 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/docs/cluster-manual.md b/docs/cluster-manual.md index 103053a50..6f3f43135 100644 --- a/docs/cluster-manual.md +++ b/docs/cluster-manual.md @@ -215,3 +215,61 @@ The **cluster store** contains a library of functions computing data and statist The `create_cluster_store()` function creates a cluster store with a built-in library of functions that take the data from a model (for example, the `KwikModel` that works with the Kwik format). Use `gui.request('cluster_store')` to get the cluster store instance inside the `attach_to_gui(gui, model=None, state=None)` method of a GUI plugin. + +## GUI plugins + +You can create plugins to customize the manual clustering GUI. Here is a complete example showing how to change the quality and similarity measures. Put the following in `~/.phy/phy_config.py`: + +```python +import numpy as np +from phy import IPlugin + +# We write the plugin directly in the config file here, for simplicity. +# When dealing with more plugins it is a better practice to put them +# in separate files in ~/.phy/plugins/ or in your own repo that you can +# refer to in c.Plugins.dirs = ['/path/to/myplugindir']. +class MyPlugin(IPlugin): + def attach_to_gui(self, gui, model=None, state=None): + + # The model contains the data, the state contains the parameters + # for that session. + + # We can get GUI components with `gui.request(name)`. + # These are the two main components. There is also `context` to + # deal with the cache and parallel computing context. + mc = gui.request('manual_clustering') + cs = gui.request('cluster_store') + + # We add a column in the cluster view and set it as the default. + # This will be automatically cached. + @mc.add_column(default=True) + def mymeasure(cluster_id): + # This function takes a cluster id as input and returns a scalar. + + # We retrieve the spike_ids and waveforms for that cluster. + # spike_ids is a (n_spikes,) array. + # waveforms is a (n_spikes, n_samples, n_channels) array. + spike_ids, waveforms = cs.waveforms(cluster_id) + return waveforms.max() + + # We set the similarity function. + @mc.set_similarity_func + def mysim(cluster_0, cluster_1): + # This function returns a score for every pair of clusters. + + # Here we compute a distance between the mean masks. + m0 = cs.mean_masks(cluster_0) # (n_channels,) array + m1 = cs.mean_masks(cluster_1) # (n_channels,) array + distance = np.sum((m1 - m0) ** 2) + + # We need to convert the distance to a score: higher = better + # similarity. + score = -distance + return score + +# Now we set the config object. +c = get_config() + +# Here we say that we always want to load our plugin in the KwikGUI. +c.KwikGUI.plugins = ['MyPlugin'] +``` From 0bb056e8d699414c117d7e4b7964cc99af86f55f Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 21 Dec 2015 23:32:59 +0100 Subject: [PATCH 0884/1059] Cache objects on disk in cluster store by default --- phy/cluster/manual/store.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/phy/cluster/manual/store.py b/phy/cluster/manual/store.py index 2fdef27d8..db6194b9e 100644 --- a/phy/cluster/manual/store.py +++ b/phy/cluster/manual/store.py @@ -46,7 +46,7 @@ def __init__(self, context=None): self.context = context self._stats = {} - def add(self, f=None, name=None, cache=None): + def add(self, f=None, name=None, cache='disk'): """Add a cluster statistic. Parameters From 38d9d80dfef3f12ac27b62af4ce4d232d85fe611 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 21 Dec 2015 23:48:20 +0100 Subject: [PATCH 0885/1059] Fixes --- phy/io/context.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/phy/io/context.py b/phy/io/context.py index 897051ded..0ecfd5264 100644 --- a/phy/io/context.py +++ b/phy/io/context.py @@ -143,7 +143,7 @@ def _ensure_cache_dirs_exist(cache_dir, name): def _fullname(o): """Return the fully-qualified name of a function.""" - return o.__module__ + "." + o.__name__ + return o.__module__ + "." + o.__name__ if o.__module__ else o.__name__ class Context(object): @@ -194,6 +194,7 @@ def cache(self, f=None, memcache=False): if self._memory is None: # pragma: no cover logger.debug("Joblib is not installed: skipping cacheing.") return f + assert f disk_cached = self._memory.cache(f) name = _fullname(f) if memcache: From fdc1e171d35da22fe9e17435721bf7c41de19637 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 22 Dec 2015 09:41:39 +0100 Subject: [PATCH 0886/1059] WIP: improve efficiency of similarity logic --- phy/cluster/manual/store.py | 17 ++++++++++++++++- phy/cluster/manual/tests/conftest.py | 1 + phy/cluster/manual/tests/test_store.py | 4 ++++ 3 files changed, 21 insertions(+), 1 deletion(-) diff --git a/phy/cluster/manual/store.py b/phy/cluster/manual/store.py index db6194b9e..1c6d77652 100644 --- a/phy/cluster/manual/store.py +++ b/phy/cluster/manual/store.py @@ -9,6 +9,7 @@ from functools import wraps import logging +from operator import itemgetter import numpy as np @@ -37,6 +38,15 @@ def _get_data_lim(arr, n_spikes=None, percentile=None): return arr.max() +def get_closest_clusters(cluster_id, cluster_ids, sim_func): + """Return a list of pairs `(cluster, similarity)` sorted by decreasing + similarity to a given cluster.""" + l = [(candidate, sim_func(cluster_id, candidate)) + for candidate in cluster_ids] + l = sorted(l, key=itemgetter(1), reverse=True) + return l + + # ----------------------------------------------------------------------------- # Cluster statistics # ----------------------------------------------------------------------------- @@ -222,7 +232,7 @@ def max_waveform_amplitude(cluster_id): logger.debug("Computing the quality of cluster %d.", cluster_id) return np.asscalar(get_max_waveform_amplitude(mm, mw)) - @cs.add(cache='memory') + @cs.add(cache=None) def mean_masked_features_score(cluster_0, cluster_1): mf0 = cs.mean_features(cluster_0) mf1 = cs.mean_features(cluster_1) @@ -236,6 +246,11 @@ def mean_masked_features_score(cluster_0, cluster_1): s = 1. / max(1e-10, d) return s + @cs.add(cache='memory') + def most_similar_clusters(cluster_id): + return get_closest_clusters(cluster_id, model.cluster_ids, + cs.mean_masked_features_score) + # Traces. # ------------------------------------------------------------------------- diff --git a/phy/cluster/manual/tests/conftest.py b/phy/cluster/manual/tests/conftest.py index 8cb476d8e..129f5fb24 100644 --- a/phy/cluster/manual/tests/conftest.py +++ b/phy/cluster/manual/tests/conftest.py @@ -69,6 +69,7 @@ def model(tempdir): model.spike_times = artificial_spike_samples(n_spikes) * 1. model.spike_times /= model.spike_times[-1] model.spike_clusters = artificial_spike_clusters(n_spikes, n_clusters) + model.cluster_ids = np.unique(model.spike_clusters) model.channel_positions = staggered_positions(n_channels) model.waveforms = artificial_waveforms(n_spikes, n_samples_w, n_channels) model.masks = artificial_masks(n_spikes, n_channels) diff --git a/phy/cluster/manual/tests/test_store.py b/phy/cluster/manual/tests/test_store.py index 23483191e..d9da57892 100644 --- a/phy/cluster/manual/tests/test_store.py +++ b/phy/cluster/manual/tests/test_store.py @@ -6,6 +6,8 @@ # Imports #------------------------------------------------------------------------------ +import numpy as np + from ..store import create_cluster_store, ClusterStore from phy.io import Context, Selector @@ -81,3 +83,5 @@ def _check(out, *shape): assert 1 <= len(cs.best_channels_multiple([1, 2])) <= nc assert 0 < cs.max_waveform_amplitude(1) < 1 assert cs.mean_masked_features_score(1, 2) > 0 + + assert np.array(cs.most_similar_clusters(1)).shape == (3, 2) From 98ade77cb34781fa7b5828da04bdee93458e13e0 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 22 Dec 2015 10:19:02 +0100 Subject: [PATCH 0887/1059] The similarity function now returns a list of closest clusters. Much faster --- phy/cluster/manual/__init__.py | 2 +- phy/cluster/manual/gui_component.py | 33 ++++++++++++++++++---------- phy/cluster/manual/store.py | 17 ++++++++++---- phy/cluster/manual/tests/conftest.py | 6 +++-- 4 files changed, 39 insertions(+), 19 deletions(-) diff --git a/phy/cluster/manual/__init__.py b/phy/cluster/manual/__init__.py index eaf61a82e..8e27fa803 100644 --- a/phy/cluster/manual/__init__.py +++ b/phy/cluster/manual/__init__.py @@ -6,5 +6,5 @@ from ._utils import ClusterMeta from .clustering import Clustering from .gui_component import ManualClustering -from .store import ClusterStore, create_cluster_store +from .store import ClusterStore, create_cluster_store, get_closest_clusters from .views import WaveformView, TraceView, FeatureView, CorrelogramView diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 4bb309f16..a38d07eb1 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -10,8 +10,6 @@ from functools import partial import logging -import numpy as np - from ._history import GlobalHistory from ._utils import create_cluster_meta from .clustering import Clustering @@ -140,6 +138,7 @@ def __init__(self, self._create_cluster_views() self._add_default_columns() + self._similarity = {} self.similarity_func = None # Internal methods @@ -193,7 +192,13 @@ def good(cluster_id): self._best = None def similarity(cluster_id): - return self.similarity_func(cluster_id, self._best) + # NOTE: there is a dictionary with the similarity to the current + # best cluster. It is updated when the selection changes in the + # cluster view. This is a bit of a hack: the HTML table expects + # a function that returns a value for every row, but here we + # cache all similarity view rows in self._similarity for + # performance reasons. + return self._similarity.get(cluster_id, 0) self.similarity_view.add_column(similarity) def _create_actions(self, gui): @@ -288,9 +293,12 @@ def _update_similarity_view(self): selection = self.cluster_view.selected if not len(selection): return - logger.log(5, "Update the similarity view.") cluster_id = selection[0] + logger.log(5, "Update the similarity view.") + # This is a list of pairs (closest_cluster, similarity). self._best = cluster_id + self._similarity = {cl: s + for (cl, s) in self.similarity_func(cluster_id)} self.similarity_view.set_rows([c for c in self.clustering.cluster_ids if c not in selection]) self.similarity_view.sort_by('similarity', 'desc') @@ -327,15 +335,16 @@ def set_default_sort(self, name, sort_dir='desc'): self.cluster_view.sort_by(name, sort_dir) def set_similarity_func(self, f): - """Set the similarity function.""" - logger.debug("Set similarity function `%s`.", f.__name__) + """Set the similarity function. - # Make sure the function returns a scalar. - def wrapped(cluster_0, cluster_1): - out = f(cluster_0, cluster_1) - return np.asscalar(out) + This is a function that returns an ordered list of pairs + `(candidate, similarity)` for any given cluster. This list can have + a fixed number of elements for performance reasons (keeping the best + 20 candidates for example). - self.similarity_func = wrapped + """ + logger.debug("Set similarity function `%s`.", f.__name__) + self.similarity_func = f def on_cluster(self, up): """Update the cluster views after clustering actions.""" @@ -395,7 +404,7 @@ def attach(self, gui): self.cluster_view.add_column(cs.max_waveform_amplitude, name='quality') self.set_default_sort('quality') - self.set_similarity_func(cs.mean_masked_features_score) + self.set_similarity_func(cs.most_similar_clusters) # Update the cluster views and selection when a cluster event occurs. self.gui.connect_(self.on_cluster) diff --git a/phy/cluster/manual/store.py b/phy/cluster/manual/store.py index 1c6d77652..8844bb55b 100644 --- a/phy/cluster/manual/store.py +++ b/phy/cluster/manual/store.py @@ -38,13 +38,20 @@ def _get_data_lim(arr, n_spikes=None, percentile=None): return arr.max() -def get_closest_clusters(cluster_id, cluster_ids, sim_func): +def get_closest_clusters(cluster_id, cluster_ids, sim_func, max_n=None): """Return a list of pairs `(cluster, similarity)` sorted by decreasing similarity to a given cluster.""" - l = [(candidate, sim_func(cluster_id, candidate)) + + def _as_float(x): + """Ensure the value is a float.""" + if isinstance(x, np.generic): + return np.asscalar(x) + return float(x) + + l = [(candidate, _as_float(sim_func(cluster_id, candidate))) for candidate in cluster_ids] l = sorted(l, key=itemgetter(1), reverse=True) - return l + return l[:max_n] # ----------------------------------------------------------------------------- @@ -95,6 +102,7 @@ def create_cluster_store(model, selector=None, context=None): 'feature_lim': 1000, # used to compute the waveform bounds 'mean_traces': 10000, } + max_n_similar_clusters = 20 def select(cluster_id, n=None): assert cluster_id >= 0 @@ -249,7 +257,8 @@ def mean_masked_features_score(cluster_0, cluster_1): @cs.add(cache='memory') def most_similar_clusters(cluster_id): return get_closest_clusters(cluster_id, model.cluster_ids, - cs.mean_masked_features_score) + cs.mean_masked_features_score, + max_n_similar_clusters) # Traces. # ------------------------------------------------------------------------- diff --git a/phy/cluster/manual/tests/conftest.py b/phy/cluster/manual/tests/conftest.py index 129f5fb24..1c7423bde 100644 --- a/phy/cluster/manual/tests/conftest.py +++ b/phy/cluster/manual/tests/conftest.py @@ -21,6 +21,7 @@ artificial_traces, ) from phy.utils import Bunch +from phy.cluster.manual.store import get_closest_clusters #------------------------------------------------------------------------------ @@ -44,8 +45,9 @@ def quality(): @yield_fixture -def similarity(): - yield lambda c, d: c * 1.01 + d +def similarity(cluster_ids): + sim = lambda c, d: (c * 1.01 + d) + yield lambda c: get_closest_clusters(c, cluster_ids, sim) @yield_fixture From c4f489d5f6c15a2a56a6042ab23408b1fa195320 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 22 Dec 2015 10:19:58 +0100 Subject: [PATCH 0888/1059] Update docs --- docs/cluster-manual.md | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/docs/cluster-manual.md b/docs/cluster-manual.md index 6f3f43135..e7cdf46d7 100644 --- a/docs/cluster-manual.md +++ b/docs/cluster-manual.md @@ -221,8 +221,11 @@ Use `gui.request('cluster_store')` to get the cluster store instance inside the You can create plugins to customize the manual clustering GUI. Here is a complete example showing how to change the quality and similarity measures. Put the following in `~/.phy/phy_config.py`: ```python + import numpy as np from phy import IPlugin +from phy.cluster.manual import get_closest_clusters + # We write the plugin directly in the config file here, for simplicity. # When dealing with more plugins it is a better practice to put them @@ -231,9 +234,6 @@ from phy import IPlugin class MyPlugin(IPlugin): def attach_to_gui(self, gui, model=None, state=None): - # The model contains the data, the state contains the parameters - # for that session. - # We can get GUI components with `gui.request(name)`. # These are the two main components. There is also `context` to # deal with the cache and parallel computing context. @@ -241,8 +241,8 @@ class MyPlugin(IPlugin): cs = gui.request('cluster_store') # We add a column in the cluster view and set it as the default. - # This will be automatically cached. @mc.add_column(default=True) + @cs.add(cache='memory') def mymeasure(cluster_id): # This function takes a cluster id as input and returns a scalar. @@ -252,8 +252,6 @@ class MyPlugin(IPlugin): spike_ids, waveforms = cs.waveforms(cluster_id) return waveforms.max() - # We set the similarity function. - @mc.set_similarity_func def mysim(cluster_0, cluster_1): # This function returns a score for every pair of clusters. @@ -267,9 +265,23 @@ class MyPlugin(IPlugin): score = -distance return score + # We set the similarity function. + @mc.set_similarity_func + @cs.add(cache='memory') + def myclosest(cluster_id): + """This function returns the list of closest clusters as + a list of `(cluster, sim)` pairs. + + By default, the 20 closest clusters are kept. + + """ + return get_closest_clusters(cluster_id, model.cluster_ids, mysim) + + # Now we set the config object. c = get_config() # Here we say that we always want to load our plugin in the KwikGUI. c.KwikGUI.plugins = ['MyPlugin'] + ``` From 2bd0295561c429245523f767ac58567e48b36561 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 22 Dec 2015 10:24:43 +0100 Subject: [PATCH 0889/1059] Fix bug in automatic channel selection in feature view --- phy/cluster/manual/views.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index f47ec1166..832ad444c 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -1084,10 +1084,8 @@ def on_select(self, cluster_ids=None): # Select the dimensions. # TODO: toggle automatic selection of the channels x_ch, y_ch = self._get_channel_dims(cluster_ids) - if self.x_channels is None: - self.x_channels = x_ch - if self.y_channels is None: - self.y_channels = y_ch + self.x_channels = x_ch + self.y_channels = y_ch tla = self.top_left_attribute x_dim, y_dim = _dimensions_matrix(self.x_channels, self.y_channels, n_cols=self.n_cols, From ee540ec2439dbdbcffe1be686a0f60e9401198a3 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 22 Dec 2015 10:31:27 +0100 Subject: [PATCH 0890/1059] Toggle automatic selection of channels in the feature view --- phy/cluster/manual/tests/test_views.py | 1 + phy/cluster/manual/views.py | 19 +++++++++++++++---- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index 4752b9068..4839d3179 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -227,6 +227,7 @@ def test_feature_view(qtbot, model, tempdir): v.on_channel_click(channel_idx=3, button=1, key=2) v.clear_channels() + v.toggle_automatic_channel_selection() # qtbot.stop() diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 832ad444c..d97b62678 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -951,6 +951,7 @@ def __init__(self, assert spike_times.shape == (self.n_spikes,) # Channels to show. + self.fixed_channels = False self.x_channels = None self.y_channels = None @@ -1082,11 +1083,14 @@ def on_select(self, cluster_ids=None): spike_ids_bg, features_bg, masks_bg = self.background_features_masks # Select the dimensions. - # TODO: toggle automatic selection of the channels - x_ch, y_ch = self._get_channel_dims(cluster_ids) - self.x_channels = x_ch - self.y_channels = y_ch + # Choose the channels automatically unless fixed_channels is set. + if (not self.fixed_channels or self.x_channels is None or + self.y_channels is None): + self.x_channels, self.y_channels = self._get_channel_dims( + cluster_ids) tla = self.top_left_attribute + assert self.x_channels + assert self.y_channels x_dim, y_dim = _dimensions_matrix(self.x_channels, self.y_channels, n_cols=self.n_cols, top_left_attribute=tla) @@ -1139,6 +1143,7 @@ def attach(self, gui): self.actions.add(self.increase) self.actions.add(self.decrease) self.actions.add(self.clear_channels) + self.actions.add(self.toggle_automatic_channel_selection) gui.connect_(self.on_channel_click) @@ -1156,8 +1161,14 @@ def on_channel_click(self, channel_idx=None, key=None, button=None): assert 0 <= channel_idx < self.n_channels # Update the channel. channels[key - 1] = channel_idx + self.fixed_channels = True self.on_select() + def toggle_automatic_channel_selection(self): + """Toggle the automatic selection of channels when the cluster + selection changes.""" + self.fixed_channels = not self.fixed_channels + def increase(self): """Increase the scaling of the features.""" self.feature_scaling *= 1.2 From 69f845ce5facb9c189ae785c936c01fc865564e2 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 22 Dec 2015 12:39:59 +0100 Subject: [PATCH 0891/1059] WIP: fixing joblib cache --- phy/cluster/manual/store.py | 75 +++++++++++++++----------- phy/cluster/manual/tests/test_store.py | 3 +- phy/io/array.py | 5 +- phy/io/context.py | 8 +-- phy/io/tests/test_array.py | 6 +-- 5 files changed, 57 insertions(+), 40 deletions(-) diff --git a/phy/cluster/manual/store.py b/phy/cluster/manual/store.py index 8844bb55b..ec3de625e 100644 --- a/phy/cluster/manual/store.py +++ b/phy/cluster/manual/store.py @@ -54,6 +54,28 @@ def _as_float(x): return l[:max_n] +def _log(f): + @wraps(f) + def wrapped(*args, **kwargs): + logger.log(10, "Compute %s(%s).", f.__name__, str(args)) + return f(*args, **kwargs) + return wrapped + + +def _concat(f): + """Take a function accepting a single cluster, and return a function + accepting multiple clusters.""" + @wraps(f) + def wrapped(cluster_ids): + # Single cluster. + if not hasattr(cluster_ids, '__len__'): + return f(cluster_ids) + # Concatenate the result of multiple clusters. + arrs = zip(*(f(c) for c in cluster_ids)) + return tuple(np.concatenate(_, axis=0) for _ in arrs) + return wrapped + + # ----------------------------------------------------------------------------- # Cluster statistics # ----------------------------------------------------------------------------- @@ -63,7 +85,7 @@ def __init__(self, context=None): self.context = context self._stats = {} - def add(self, f=None, name=None, cache='disk'): + def add(self, f=None, name=None, cache='disk', concat=None): """Add a cluster statistic. Parameters @@ -76,11 +98,14 @@ def add(self, f=None, name=None, cache='disk'): """ if f is None: - return lambda _: self.add(_, name=name, cache=cache) + return lambda _: self.add(_, name=name, cache=cache, concat=concat) name = name or f.__name__ if cache and self.context: + f = _log(f) f = self.context.cache(f, memcache=(cache == 'memory')) assert f + if concat: + f = _concat(f) self._stats[name] = f setattr(self, name, f) return f @@ -108,32 +133,18 @@ def select(cluster_id, n=None): assert cluster_id >= 0 return selector.select_spikes([cluster_id], max_n_spikes_per_cluster=n) - def concat(f): - """Take a function accepting a single cluster, and return a function - accepting multiple clusters.""" - @wraps(f) - def wrapped(cluster_ids): - # Single cluster. - if not hasattr(cluster_ids, '__len__'): - return f(cluster_ids) - # Concatenate the result of multiple clusters. - arrs = zip(*(f(c) for c in cluster_ids)) - return tuple(np.concatenate(_, axis=0) for _ in arrs) - return wrapped - # Model data. # ------------------------------------------------------------------------- - @cs.add - @concat + @cs.add(concat=True) def masks(cluster_id): spike_ids = select(cluster_id, max_n_spikes_per_cluster['masks']) masks = np.atleast_2d(model.masks[spike_ids]) assert masks.ndim == 2 + # print("m", cluster_id, spike_ids.shape, spike_ids[0], spike_ids[-1], masks.shape) return spike_ids, masks - @cs.add - @concat + @cs.add(concat=True) def features_masks(cluster_id): spike_ids = select(cluster_id, max_n_spikes_per_cluster['features']) fm = np.atleast_3d(model.features_masks[spike_ids]) @@ -145,8 +156,7 @@ def features_masks(cluster_id): m = fm[:, ::nfpc, 1] return spike_ids, f, m - @cs.add - @concat + @cs.add(concat=True) def features(cluster_id): spike_ids = select(cluster_id, max_n_spikes_per_cluster['features']) features = np.atleast_2d(model.features[spike_ids]) @@ -172,11 +182,12 @@ def background_features_masks(): assert masks.shape[0] == features.shape[0] return spike_ids, features, masks - @cs.add - @concat + @cs.add(concat=True) def waveforms(cluster_id): - spike_ids = select(cluster_id, max_n_spikes_per_cluster['waveforms']) + spike_ids = select(cluster_id, + max_n_spikes_per_cluster['waveforms']) waveforms = np.atleast_2d(model.waveforms[spike_ids]) + # print("w", cluster_id, spike_ids.shape, spike_ids[0], spike_ids[-1], waveforms.shape) assert waveforms.ndim == 3 return spike_ids, waveforms @@ -186,10 +197,10 @@ def waveform_lim(): return _get_data_lim(model.waveforms, max_n_spikes_per_cluster['waveform_lim']) - @cs.add - @concat + @cs.add(concat=True) def waveforms_masks(cluster_id): - spike_ids = select(cluster_id, max_n_spikes_per_cluster['waveforms']) + spike_ids = select(cluster_id, + max_n_spikes_per_cluster['waveforms']) waveforms = np.atleast_2d(model.waveforms[spike_ids]) assert waveforms.ndim == 3 masks = np.atleast_2d(model.masks[spike_ids]) @@ -237,7 +248,6 @@ def max_waveform_amplitude(cluster_id): mm = cs.mean_masks(cluster_id) mw = cs.mean_waveforms(cluster_id) assert mw.ndim == 2 - logger.debug("Computing the quality of cluster %d.", cluster_id) return np.asscalar(get_max_waveform_amplitude(mm, mw)) @cs.add(cache=None) @@ -247,8 +257,6 @@ def mean_masked_features_score(cluster_0, cluster_1): mm0 = cs.mean_masks(cluster_0) mm1 = cs.mean_masks(cluster_1) nfpc = model.n_features_per_channel - logger.debug("Computing the similarity of clusters %d, %d.", - cluster_0, cluster_1) d = get_mean_masked_features_distance(mf0, mf1, mm0, mm1, n_features_per_channel=nfpc) s = 1. / max(1e-10, d) @@ -275,9 +283,14 @@ def mean_traces(): class ClusterStorePlugin(IPlugin): def attach_to_gui(self, gui, model=None, state=None): ctx = gui.request('context') + + def spikes_per_cluster(cluster_id): + mc = gui.request('manual_clustering') + return mc.clustering.spikes_per_cluster[cluster_id] + assert ctx selector = Selector(spike_clusters=model.spike_clusters, - spikes_per_cluster=model.spikes_per_cluster, + spikes_per_cluster=spikes_per_cluster, ) cs = create_cluster_store(model, selector=selector, context=ctx) cs.attach(gui) diff --git a/phy/cluster/manual/tests/test_store.py b/phy/cluster/manual/tests/test_store.py index d9da57892..c6e882bba 100644 --- a/phy/cluster/manual/tests/test_store.py +++ b/phy/cluster/manual/tests/test_store.py @@ -29,8 +29,9 @@ def f(x): def test_create_cluster_store(model): + spc = lambda c: model.spikes_per_cluster[c] selector = Selector(spike_clusters=model.spike_clusters, - spikes_per_cluster=model.spikes_per_cluster) + spikes_per_cluster=spc) cs = create_cluster_store(model, selector=selector) nc = model.n_channels diff --git a/phy/io/array.py b/phy/io/array.py index 57f0a95f3..ccea62157 100644 --- a/phy/io/array.py +++ b/phy/io/array.py @@ -405,7 +405,7 @@ def select_spikes(cluster_ids=None, if not len(cluster_ids): return np.array([], dtype=np.int64) if max_n_spikes_per_cluster in (None, 0): - selection = {c: spikes_per_cluster[c] for c in cluster_ids} + selection = {c: spikes_per_cluster(c) for c in cluster_ids} else: assert max_n_spikes_per_cluster > 0 selection = {} @@ -415,7 +415,7 @@ def select_spikes(cluster_ids=None, # are more clusters. n = int(max_n_spikes_per_cluster * exp(-.1 * (n_clusters - 1))) n = max(1, n) - spikes = spikes_per_cluster[cluster] + spikes = spikes_per_cluster(cluster) selection[cluster] = regular_subset(spikes, n_spikes_max=n) return _flatten_per_cluster(selection) @@ -428,6 +428,7 @@ def __init__(self, spikes_per_cluster=None, spike_ids=None, ): + # NOTE: spikes_per_cluster is a function. self.spike_clusters = spike_clusters self.spikes_per_cluster = spikes_per_cluster self.n_spikes = len(spike_clusters) diff --git a/phy/io/context.py b/phy/io/context.py index 0ecfd5264..32e5ca984 100644 --- a/phy/io/context.py +++ b/phy/io/context.py @@ -148,7 +148,7 @@ def _fullname(o): class Context(object): """Handle function cacheing and parallel map with ipyparallel.""" - def __init__(self, cache_dir, ipy_view=None, verbose=0): + def __init__(self, cache_dir, ipy_view=None, verbose=100): self.verbose = verbose # Make sure the cache directory exists. self.cache_dir = op.realpath(op.expanduser(cache_dir)) @@ -210,9 +210,11 @@ def mem_cached(*args, **kwargs): """Cache the function in memory.""" h = hash((args, kwargs)) if h in c: + logger.debug("Get %s(%s) from memcache.", name, str(args)) # Retrieve the value from the memcache. return c[h] else: + logger.debug("Get %s(%s) from joblib.", name, str(args)) # Call and cache the function. out = disk_cached(*args, **kwargs) c[h] = out @@ -327,8 +329,8 @@ def __setstate__(self, state): class ContextPlugin(IPlugin): def attach_to_gui(self, gui, model=None, state=None): # Create the computing context. - gui.register(Context(op.join(op.dirname(model.path), '.phy/')), - name='context') + ctx = Context(op.join(op.dirname(model.path), '.phy/')) + gui.register(ctx, name='context') #------------------------------------------------------------------------------ diff --git a/phy/io/tests/test_array.py b/phy/io/tests/test_array.py index dab2837d5..c57e1d548 100644 --- a/phy/io/tests/test_array.py +++ b/phy/io/tests/test_array.py @@ -357,10 +357,10 @@ def test_select_spikes(): select_spikes() spikes = [2, 3, 5, 7, 11] sc = [2, 3, 3, 2, 2] - spc = {2: [2, 7, 11], 3: [3, 5], 5: []} + spc = lambda c: {2: [2, 7, 11], 3: [3, 5], 5: []}.get(c, None) ae(select_spikes([], spikes_per_cluster=spc), []) ae(select_spikes([2, 3, 5], spikes_per_cluster=spc), spikes) - ae(select_spikes([2, 5], spikes_per_cluster=spc), spc[2]) + ae(select_spikes([2, 5], spikes_per_cluster=spc), spc(2)) ae(select_spikes([2, 3, 5], 0, spikes_per_cluster=spc), spikes) ae(select_spikes([2, 3, 5], None, spikes_per_cluster=spc), spikes) @@ -372,5 +372,5 @@ def test_select_spikes(): spike_ids=spikes, ) assert sel.select_spikes() is None - ae(sel.select_spikes([2, 5]), spc[2]) + ae(sel.select_spikes([2, 5]), spc(2)) ae(sel.select_spikes([2, 5], 2), [2]) From 58cf2819b674526307a3544137505f6ddd3d8c63 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 22 Dec 2015 12:42:09 +0100 Subject: [PATCH 0892/1059] Flakify --- phy/cluster/manual/store.py | 4 +--- phy/io/context.py | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/phy/cluster/manual/store.py b/phy/cluster/manual/store.py index ec3de625e..6bd1bfacb 100644 --- a/phy/cluster/manual/store.py +++ b/phy/cluster/manual/store.py @@ -141,7 +141,6 @@ def masks(cluster_id): spike_ids = select(cluster_id, max_n_spikes_per_cluster['masks']) masks = np.atleast_2d(model.masks[spike_ids]) assert masks.ndim == 2 - # print("m", cluster_id, spike_ids.shape, spike_ids[0], spike_ids[-1], masks.shape) return spike_ids, masks @cs.add(concat=True) @@ -187,7 +186,6 @@ def waveforms(cluster_id): spike_ids = select(cluster_id, max_n_spikes_per_cluster['waveforms']) waveforms = np.atleast_2d(model.waveforms[spike_ids]) - # print("w", cluster_id, spike_ids.shape, spike_ids[0], spike_ids[-1], waveforms.shape) assert waveforms.ndim == 3 return spike_ids, waveforms @@ -200,7 +198,7 @@ def waveform_lim(): @cs.add(concat=True) def waveforms_masks(cluster_id): spike_ids = select(cluster_id, - max_n_spikes_per_cluster['waveforms']) + max_n_spikes_per_cluster['waveforms']) waveforms = np.atleast_2d(model.waveforms[spike_ids]) assert waveforms.ndim == 3 masks = np.atleast_2d(model.masks[spike_ids]) diff --git a/phy/io/context.py b/phy/io/context.py index 32e5ca984..6090d61af 100644 --- a/phy/io/context.py +++ b/phy/io/context.py @@ -148,7 +148,7 @@ def _fullname(o): class Context(object): """Handle function cacheing and parallel map with ipyparallel.""" - def __init__(self, cache_dir, ipy_view=None, verbose=100): + def __init__(self, cache_dir, ipy_view=None, verbose=0): self.verbose = verbose # Make sure the cache directory exists. self.cache_dir = op.realpath(op.expanduser(cache_dir)) From f93916088921e7e3a0c26b91b27cd252b462a0c9 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 22 Dec 2015 13:31:50 +0100 Subject: [PATCH 0893/1059] Add comment --- phy/cluster/manual/store.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/phy/cluster/manual/store.py b/phy/cluster/manual/store.py index 6bd1bfacb..3e65b54fb 100644 --- a/phy/cluster/manual/store.py +++ b/phy/cluster/manual/store.py @@ -282,6 +282,9 @@ class ClusterStorePlugin(IPlugin): def attach_to_gui(self, gui, model=None, state=None): ctx = gui.request('context') + # NOTE: we get the spikes_per_cluster from the Clustering instance. + # We need to access it from a function to avoid circular dependencies + # between the cluster store and manual clustering plugins. def spikes_per_cluster(cluster_id): mc = gui.request('manual_clustering') return mc.clustering.spikes_per_cluster[cluster_id] From 5c8837d24b8099bb0af96c013100a6ce33d91d2d Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 22 Dec 2015 13:54:54 +0100 Subject: [PATCH 0894/1059] WIP: ensure cluster_ids are ints to avoid cache issues --- phy/cluster/manual/gui_component.py | 8 +++++--- phy/cluster/manual/store.py | 17 ++++++----------- phy/cluster/manual/views.py | 1 + phy/gui/widgets.py | 3 +++ phy/utils/__init__.py | 1 + phy/utils/_types.py | 11 +++++++++++ 6 files changed, 27 insertions(+), 14 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index a38d07eb1..60a420846 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -283,7 +283,8 @@ def on_request_undo_state(up): def _update_cluster_view(self): """Initialize the cluster view with cluster data.""" logger.log(5, "Update the cluster view.") - self.cluster_view.set_rows(self.clustering.cluster_ids) + cluster_ids = [int(c) for c in self.clustering.cluster_ids] + self.cluster_view.set_rows(cluster_ids) def _update_similarity_view(self): """Update the similarity view with matches for the specified @@ -297,9 +298,10 @@ def _update_similarity_view(self): logger.log(5, "Update the similarity view.") # This is a list of pairs (closest_cluster, similarity). self._best = cluster_id - self._similarity = {cl: s + self._similarity = {int(cl): s for (cl, s) in self.similarity_func(cluster_id)} - self.similarity_view.set_rows([c for c in self.clustering.cluster_ids + self.similarity_view.set_rows([int(c) + for c in self.clustering.cluster_ids if c not in selection]) self.similarity_view.sort_by('similarity', 'desc') diff --git a/phy/cluster/manual/store.py b/phy/cluster/manual/store.py index 3e65b54fb..f1c66dac8 100644 --- a/phy/cluster/manual/store.py +++ b/phy/cluster/manual/store.py @@ -20,7 +20,7 @@ get_unmasked_channels, get_sorted_main_channels, ) -from phy.utils import IPlugin +from phy.utils import IPlugin, _as_scalar, _as_scalars logger = logging.getLogger(__name__) @@ -41,15 +41,8 @@ def _get_data_lim(arr, n_spikes=None, percentile=None): def get_closest_clusters(cluster_id, cluster_ids, sim_func, max_n=None): """Return a list of pairs `(cluster, similarity)` sorted by decreasing similarity to a given cluster.""" - - def _as_float(x): - """Ensure the value is a float.""" - if isinstance(x, np.generic): - return np.asscalar(x) - return float(x) - - l = [(candidate, _as_float(sim_func(cluster_id, candidate))) - for candidate in cluster_ids] + l = [(_as_scalar(candidate), _as_scalar(sim_func(cluster_id, candidate))) + for candidate in _as_scalars(cluster_ids)] l = sorted(l, key=itemgetter(1), reverse=True) return l[:max_n] @@ -57,7 +50,7 @@ def _as_float(x): def _log(f): @wraps(f) def wrapped(*args, **kwargs): - logger.log(10, "Compute %s(%s).", f.__name__, str(args)) + logger.log(5, "Compute %s(%s).", f.__name__, str(args)) return f(*args, **kwargs) return wrapped @@ -130,6 +123,7 @@ def create_cluster_store(model, selector=None, context=None): max_n_similar_clusters = 20 def select(cluster_id, n=None): + assert isinstance(cluster_id, int) assert cluster_id >= 0 return selector.select_spikes([cluster_id], max_n_spikes_per_cluster=n) @@ -262,6 +256,7 @@ def mean_masked_features_score(cluster_0, cluster_1): @cs.add(cache='memory') def most_similar_clusters(cluster_id): + assert isinstance(cluster_id, int) return get_closest_clusters(cluster_id, model.cluster_ids, cs.mean_masked_features_score, max_n_similar_clusters) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index d97b62678..e87c6aadd 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -173,6 +173,7 @@ def on_select(self, cluster_ids=None): cluster_ids = (cluster_ids if cluster_ids is not None else self.cluster_ids) self.cluster_ids = list(cluster_ids) if cluster_ids is not None else [] + self.cluster_ids = [int(c) for c in self.cluster_ids] def _best_channels(self, cluster_ids, n_channels_requested=None): """Return the best channels for a set of clusters.""" diff --git a/phy/gui/widgets.py b/phy/gui/widgets.py index 947d83a8f..51dbd4588 100644 --- a/phy/gui/widgets.py +++ b/phy/gui/widgets.py @@ -255,6 +255,9 @@ def _get_row(self, id): def set_rows(self, ids): """Set the rows of the table.""" + # NOTE: make sure we have integers and not np.generic objects. + assert all(isinstance(i, int) for i in ids) + # Determine the sort column and dir to set after the rows. sort_col, sort_dir = self.current_sort default_sort_col, default_sort_dir = self.default_sort diff --git a/phy/utils/__init__.py b/phy/utils/__init__.py index 484fc229a..45ced54a5 100644 --- a/phy/utils/__init__.py +++ b/phy/utils/__init__.py @@ -5,6 +5,7 @@ from ._misc import _load_json, _save_json from ._types import (_is_array_like, _as_array, _as_tuple, _as_list, + _as_scalar, _as_scalars, Bunch, _is_list, _bunchify) from .event import EventEmitter, ProgressReporter from .plugin import IPlugin, get_plugin, get_all_plugins diff --git a/phy/utils/_types.py b/phy/utils/_types.py index 021b25e01..72394744a 100644 --- a/phy/utils/_types.py +++ b/phy/utils/_types.py @@ -45,6 +45,17 @@ def _is_list(obj): return isinstance(obj, list) +def _as_scalar(obj): + if isinstance(obj, np.generic): + return np.asscalar(obj) + assert isinstance(obj, (int, float)) + return obj + + +def _as_scalars(arr): + return [_as_scalar(x) for x in arr] + + def _is_integer(x): return isinstance(x, integer_types + (np.generic,)) From ba980bbe270e994ae7e39fe6695445758f041675 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 25 Dec 2015 22:03:15 +0100 Subject: [PATCH 0895/1059] WIP: fixing appveyor --- appveyor.yml | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/appveyor.yml b/appveyor.yml index a999f35e9..fa1b115c8 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -9,16 +9,19 @@ environment: PYTHON_ARCH: "64" install: - # Install app from latest script - - "powershell iex ((new-object net.webclient).DownloadString('http://phy.cortexlab.net/install/latest.ps1'))" - - "SET PATH=%HOMEPATH%\\miniconda3;%HOMEPATH%\\miniconda3\\Scripts;%PATH%" - - # We want to use the git version of phy, so let's uninstall the pip version - - "pip uninstall phy --yes" - - "pip install -r requirements-dev.txt" - - "python setup.py develop" + - pushd ~ + - (new-object System.Net.WebClient).DownloadFile('https://repo.continuum.io/miniconda/Miniconda3-latest-Windows-x86_64.exe', "$Home\miniconda3.exe") + - .\miniconda3.exe /RegisterPython=1 /S /D="$Home\miniconda3" | Out-Null + - $env:Path += ";$Home\miniconda3\;$Home\miniconda3\Scripts" + - conda config --set ssl_verify false + - conda env create python=3.5 -y + - source activate phy + - conda install -c kwikteam klustakwik2 -y + - conda config --set ssl_verify true + - pip install -e . + - popd + - pip install -r requirements-dev.txt build: false # Not a C# project, build stuff at the test step instead. - test_script: # Run the project tests - - py.test phy -m "not long" + - py.test phy From 73eebb8d148c0087618fd4a41a2ab92df793482d Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 25 Dec 2015 22:07:42 +0100 Subject: [PATCH 0896/1059] WIP --- appveyor.yml | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/appveyor.yml b/appveyor.yml index fa1b115c8..32cdafd62 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -9,7 +9,6 @@ environment: PYTHON_ARCH: "64" install: - - pushd ~ - (new-object System.Net.WebClient).DownloadFile('https://repo.continuum.io/miniconda/Miniconda3-latest-Windows-x86_64.exe', "$Home\miniconda3.exe") - .\miniconda3.exe /RegisterPython=1 /S /D="$Home\miniconda3" | Out-Null - $env:Path += ";$Home\miniconda3\;$Home\miniconda3\Scripts" @@ -18,9 +17,8 @@ install: - source activate phy - conda install -c kwikteam klustakwik2 -y - conda config --set ssl_verify true - - pip install -e . - - popd - pip install -r requirements-dev.txt + - pip install -e . build: false # Not a C# project, build stuff at the test step instead. test_script: # Run the project tests From 0949f2cd99d6393ab5906e999a7a29e6791d3c1b Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 25 Dec 2015 22:10:55 +0100 Subject: [PATCH 0897/1059] WIP --- appveyor.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/appveyor.yml b/appveyor.yml index 32cdafd62..fa941d3d3 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -9,7 +9,7 @@ environment: PYTHON_ARCH: "64" install: - - (new-object System.Net.WebClient).DownloadFile('https://repo.continuum.io/miniconda/Miniconda3-latest-Windows-x86_64.exe', "$Home\miniconda3.exe") + - (new-object System.Net.WebClient).DownloadFile('http://repo.continuum.io/miniconda/Miniconda3-latest-Windows-x86_64.exe', "$Home\miniconda3.exe") - .\miniconda3.exe /RegisterPython=1 /S /D="$Home\miniconda3" | Out-Null - $env:Path += ";$Home\miniconda3\;$Home\miniconda3\Scripts" - conda config --set ssl_verify false From 2e12e80e281fd7da8992559a7aae6f13b059432d Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 25 Dec 2015 22:13:52 +0100 Subject: [PATCH 0898/1059] WIP --- appveyor.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/appveyor.yml b/appveyor.yml index fa941d3d3..53da5b15a 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -9,7 +9,7 @@ environment: PYTHON_ARCH: "64" install: - - (new-object System.Net.WebClient).DownloadFile('http://repo.continuum.io/miniconda/Miniconda3-latest-Windows-x86_64.exe', "$Home\miniconda3.exe") + - Invoke-WebRequest "http://repo.continuum.io/miniconda/Miniconda3-latest-Windows-x86_64.exe" -OutFile "$Home\miniconda3.exe" - .\miniconda3.exe /RegisterPython=1 /S /D="$Home\miniconda3" | Out-Null - $env:Path += ";$Home\miniconda3\;$Home\miniconda3\Scripts" - conda config --set ssl_verify false From b9fe6d59b2776b7934a245eec0fcd3877c6b88c4 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 25 Dec 2015 22:16:49 +0100 Subject: [PATCH 0899/1059] WIP --- appveyor.yml | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/appveyor.yml b/appveyor.yml index 53da5b15a..0df8c0710 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -2,15 +2,13 @@ # This file was based on Olivier Grisel's python-appveyor-demo environment: - matrix: - - PYTHON: "C:\\Python34-conda64" - PYTHON_VERSION: "3.4" + - PYTHON: "C:\\Python35-conda64" + PYTHON_VERSION: "3.5" PYTHON_ARCH: "64" - install: - - Invoke-WebRequest "http://repo.continuum.io/miniconda/Miniconda3-latest-Windows-x86_64.exe" -OutFile "$Home\miniconda3.exe" - - .\miniconda3.exe /RegisterPython=1 /S /D="$Home\miniconda3" | Out-Null + - ps: Start-FileDownload 'http://repo.continuum.io/miniconda/Miniconda3-latest-Windows-x86_64.exe' + - .\Miniconda3-latest-Windows-x86_64.exe /RegisterPython=1 /S /D="$Home\miniconda3" | Out-Null - $env:Path += ";$Home\miniconda3\;$Home\miniconda3\Scripts" - conda config --set ssl_verify false - conda env create python=3.5 -y From f19ce84e708f8f2ca97c757906c398c1984b17c2 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 25 Dec 2015 22:19:25 +0100 Subject: [PATCH 0900/1059] WIP --- appveyor.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/appveyor.yml b/appveyor.yml index 0df8c0710..710815fe4 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -8,7 +8,7 @@ environment: PYTHON_ARCH: "64" install: - ps: Start-FileDownload 'http://repo.continuum.io/miniconda/Miniconda3-latest-Windows-x86_64.exe' - - .\Miniconda3-latest-Windows-x86_64.exe /RegisterPython=1 /S /D="$Home\miniconda3" | Out-Null + - .\Miniconda3-latest-Windows-x86_64.exe /RegisterPython=1 /S /D="$Home\miniconda3" - $env:Path += ";$Home\miniconda3\;$Home\miniconda3\Scripts" - conda config --set ssl_verify false - conda env create python=3.5 -y From 13dc308ae662f551da92ad019e302b040da0d26e Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 25 Dec 2015 22:22:29 +0100 Subject: [PATCH 0901/1059] WIP --- appveyor.yml | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/appveyor.yml b/appveyor.yml index 710815fe4..6b2b3c23b 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -8,15 +8,15 @@ environment: PYTHON_ARCH: "64" install: - ps: Start-FileDownload 'http://repo.continuum.io/miniconda/Miniconda3-latest-Windows-x86_64.exe' - - .\Miniconda3-latest-Windows-x86_64.exe /RegisterPython=1 /S /D="$Home\miniconda3" - - $env:Path += ";$Home\miniconda3\;$Home\miniconda3\Scripts" - - conda config --set ssl_verify false - - conda env create python=3.5 -y - - source activate phy - - conda install -c kwikteam klustakwik2 -y - - conda config --set ssl_verify true - - pip install -r requirements-dev.txt - - pip install -e . + - ps: .\Miniconda3-latest-Windows-x86_64.exe /RegisterPython=1 /S /D="$Home\miniconda3" | Out-Null + - ps: $env:Path += ";$Home\miniconda3\;$Home\miniconda3\Scripts" + - ps: conda config --set ssl_verify false + - ps: conda env create python=3.5 -y + - ps: source activate phy + - ps: conda install -c kwikteam klustakwik2 -y + - ps: conda config --set ssl_verify true + - ps: pip install -r requirements-dev.txt + - ps: pip install -e . build: false # Not a C# project, build stuff at the test step instead. test_script: # Run the project tests From d59e99965040019301f4ecb47e4e8ea9ad411a8b Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 25 Dec 2015 22:24:01 +0100 Subject: [PATCH 0902/1059] WIP --- appveyor.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/appveyor.yml b/appveyor.yml index 6b2b3c23b..fef322810 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -11,7 +11,7 @@ install: - ps: .\Miniconda3-latest-Windows-x86_64.exe /RegisterPython=1 /S /D="$Home\miniconda3" | Out-Null - ps: $env:Path += ";$Home\miniconda3\;$Home\miniconda3\Scripts" - ps: conda config --set ssl_verify false - - ps: conda env create python=3.5 -y + - ps: conda env create python=3.5 - ps: source activate phy - ps: conda install -c kwikteam klustakwik2 -y - ps: conda config --set ssl_verify true From 177bef779376b992ee034120df51e9dafbefcb59 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 25 Dec 2015 22:32:47 +0100 Subject: [PATCH 0903/1059] WIP: remove KK2 from pip in environment.yml --- .travis.yml | 1 + environment.yml | 2 -- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/.travis.yml b/.travis.yml index 4844c62a8..5ed964668 100644 --- a/.travis.yml +++ b/.travis.yml @@ -21,6 +21,7 @@ install: # Create the environment. - conda env create python=$TRAVIS_PYTHON_VERSION - source activate phy + - pip install klustakwik2 # Dev requirements - pip install -r requirements-dev.txt - pip install -e . diff --git a/environment.yml b/environment.yml index 0b61d5aae..a2b8dec3a 100644 --- a/environment.yml +++ b/environment.yml @@ -18,5 +18,3 @@ dependencies: - dask - cython - click - - pip: - - klustakwik2 From e67dd33cf6e0b2f64a2492f71d4bc675f6e45cd2 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 25 Dec 2015 22:43:32 +0100 Subject: [PATCH 0904/1059] WIP --- appveyor.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/appveyor.yml b/appveyor.yml index fef322810..a98c709cc 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -12,7 +12,7 @@ install: - ps: $env:Path += ";$Home\miniconda3\;$Home\miniconda3\Scripts" - ps: conda config --set ssl_verify false - ps: conda env create python=3.5 - - ps: source activate phy + - ps: activate phy - ps: conda install -c kwikteam klustakwik2 -y - ps: conda config --set ssl_verify true - ps: pip install -r requirements-dev.txt From 5d32aa21cb3ded36b09fc2862767423dda3053f1 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 25 Dec 2015 22:49:13 +0100 Subject: [PATCH 0905/1059] WIP: py3.4 in appveyor --- appveyor.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/appveyor.yml b/appveyor.yml index a98c709cc..a14397ac3 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -3,15 +3,15 @@ environment: matrix: - - PYTHON: "C:\\Python35-conda64" - PYTHON_VERSION: "3.5" + - PYTHON: "C:\\Python34-conda64" + PYTHON_VERSION: "3.4" PYTHON_ARCH: "64" install: - ps: Start-FileDownload 'http://repo.continuum.io/miniconda/Miniconda3-latest-Windows-x86_64.exe' - ps: .\Miniconda3-latest-Windows-x86_64.exe /RegisterPython=1 /S /D="$Home\miniconda3" | Out-Null - ps: $env:Path += ";$Home\miniconda3\;$Home\miniconda3\Scripts" - ps: conda config --set ssl_verify false - - ps: conda env create python=3.5 + - ps: conda env create python=3.4 - ps: activate phy - ps: conda install -c kwikteam klustakwik2 -y - ps: conda config --set ssl_verify true From fd6e86f5643914c1d2cf41777d0e8a8a4487bee9 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 25 Dec 2015 22:55:21 +0100 Subject: [PATCH 0906/1059] WIP --- appveyor.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/appveyor.yml b/appveyor.yml index a14397ac3..6521ec916 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -7,8 +7,8 @@ environment: PYTHON_VERSION: "3.4" PYTHON_ARCH: "64" install: - - ps: Start-FileDownload 'http://repo.continuum.io/miniconda/Miniconda3-latest-Windows-x86_64.exe' - - ps: .\Miniconda3-latest-Windows-x86_64.exe /RegisterPython=1 /S /D="$Home\miniconda3" | Out-Null + - ps: Start-FileDownload 'http://repo.continuum.io/miniconda/Miniconda-latest-Windows-x86_64.exe' + - ps: .\Miniconda-latest-Windows-x86_64.exe /RegisterPython=1 /S /D="$Home\miniconda3" | Out-Null - ps: $env:Path += ";$Home\miniconda3\;$Home\miniconda3\Scripts" - ps: conda config --set ssl_verify false - ps: conda env create python=3.4 From 7d0431e026ebfcaba77308461d8516a33dc41819 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 5 Jan 2016 19:29:46 +0100 Subject: [PATCH 0907/1059] WIP: bug fixes in views and store when data is missing --- phy/cluster/manual/store.py | 13 +++++++++++-- phy/cluster/manual/views.py | 13 ++++++++++--- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/phy/cluster/manual/store.py b/phy/cluster/manual/store.py index f1c66dac8..2a7084b70 100644 --- a/phy/cluster/manual/store.py +++ b/phy/cluster/manual/store.py @@ -133,7 +133,10 @@ def select(cluster_id, n=None): @cs.add(concat=True) def masks(cluster_id): spike_ids = select(cluster_id, max_n_spikes_per_cluster['masks']) - masks = np.atleast_2d(model.masks[spike_ids]) + if model.masks is None: + masks = np.ones((len(spike_ids), len(model.channel_order))) + else: + masks = np.atleast_2d(model.masks[spike_ids]) assert masks.ndim == 2 return spike_ids, masks @@ -152,7 +155,13 @@ def features_masks(cluster_id): @cs.add(concat=True) def features(cluster_id): spike_ids = select(cluster_id, max_n_spikes_per_cluster['features']) - features = np.atleast_2d(model.features[spike_ids]) + if model.features is None: + features = np.zeros((len(spike_ids), + len(model.channel_order), + model.n_features_per_channel, + )) + else: + features = np.atleast_2d(model.features[spike_ids]) assert features.ndim == 3 return spike_ids, features diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index e87c6aadd..e49d95402 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -585,7 +585,8 @@ def __init__(self, self.spike_clusters = spike_clusters # Masks. - assert masks.shape == (self.n_spikes, self.n_channels) + if masks is not None: + assert masks.shape == (self.n_spikes, self.n_channels) self.masks = masks else: self.spike_times = self.spike_clusters = self.masks = None @@ -622,7 +623,7 @@ def _load_traces(self, interval): assert traces.shape[1] == self.n_channels # Detrend the traces. - traces -= self.mean_traces + traces = traces - self.mean_traces # Create the plots. return traces @@ -632,7 +633,13 @@ def _load_spikes(self, interval): assert self.spike_times is not None # Keep the spikes in the interval. a, b = self.spike_times.searchsorted(interval) - return self.spike_times[a:b], self.spike_clusters[a:b], self.masks[a:b] + spike_times = self.spike_times[a:b] + spike_clusters = self.spike_clusters[a:b] + n_spikes = len(spike_times) + assert len(spike_clusters) == n_spikes + masks = (self.masks[a:b] if self.masks is not None + else np.ones(n_spikes)) + return spike_times, spike_clusters, masks def _plot_traces(self, traces, start=None, data_bounds=None): t = start + np.arange(traces.shape[0]) * self.dt From 0ec97c5cda5ded9a2204bad84b3e6a48feca19c1 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 5 Jan 2016 19:48:55 +0100 Subject: [PATCH 0908/1059] Fix --- phy/cluster/manual/views.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index e49d95402..a007f003d 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -637,8 +637,12 @@ def _load_spikes(self, interval): spike_clusters = self.spike_clusters[a:b] n_spikes = len(spike_times) assert len(spike_clusters) == n_spikes + # TODO: make this cleaner + nc = (self.n_channels + if isinstance(self.channel_order, slice) + else len(self.channel_order)) masks = (self.masks[a:b] if self.masks is not None - else np.ones(n_spikes)) + else np.ones((n_spikes, nc))) return spike_times, spike_clusters, masks def _plot_traces(self, traces, start=None, data_bounds=None): From 7537e2c7dc00475dc6c1dc3a8b13fa8f4f959450 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 8 Jan 2016 14:46:48 +0100 Subject: [PATCH 0909/1059] Remove old fixture --- phy/conftest.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/phy/conftest.py b/phy/conftest.py index 1e56705d9..869c5fb27 100644 --- a/phy/conftest.py +++ b/phy/conftest.py @@ -40,9 +40,3 @@ def chdir_tempdir(): os.chdir(tempdir) yield tempdir os.chdir(curdir) - - -@yield_fixture -def tempdir_bis(): - with TemporaryDirectory() as tempdir: - yield tempdir From e95da596bc70b7afa8c2870aac4e1a0e2790eb68 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 8 Jan 2016 15:14:56 +0100 Subject: [PATCH 0910/1059] WIP: remove channel order in views --- phy/cluster/manual/store.py | 6 +++--- phy/cluster/manual/tests/test_store.py | 2 +- phy/cluster/manual/views.py | 22 ++++------------------ 3 files changed, 8 insertions(+), 22 deletions(-) diff --git a/phy/cluster/manual/store.py b/phy/cluster/manual/store.py index 2a7084b70..036c807bc 100644 --- a/phy/cluster/manual/store.py +++ b/phy/cluster/manual/store.py @@ -134,7 +134,7 @@ def select(cluster_id, n=None): def masks(cluster_id): spike_ids = select(cluster_id, max_n_spikes_per_cluster['masks']) if model.masks is None: - masks = np.ones((len(spike_ids), len(model.channel_order))) + masks = np.ones((len(spike_ids), model.n_channels)) else: masks = np.atleast_2d(model.masks[spike_ids]) assert masks.ndim == 2 @@ -157,7 +157,7 @@ def features(cluster_id): spike_ids = select(cluster_id, max_n_spikes_per_cluster['features']) if model.features is None: features = np.zeros((len(spike_ids), - len(model.channel_order), + model.n_channels, model.n_features_per_channel, )) else: @@ -276,7 +276,7 @@ def most_similar_clusters(cluster_id): @cs.add def mean_traces(): n = max_n_spikes_per_cluster['mean_traces'] - mt = model.traces[:n, model.channel_order].mean(axis=0) + mt = model.traces[:n, :].mean(axis=0) return mt.astype(model.traces.dtype) return cs diff --git a/phy/cluster/manual/tests/test_store.py b/phy/cluster/manual/tests/test_store.py index c6e882bba..d85cd0a88 100644 --- a/phy/cluster/manual/tests/test_store.py +++ b/phy/cluster/manual/tests/test_store.py @@ -77,7 +77,7 @@ def _check(out, *shape): # Limits. assert 0 < cs.waveform_lim() < 3 assert 0 < cs.feature_lim() < 3 - assert cs.mean_traces().shape == (1, nc) + assert cs.mean_traces().shape == (nc,) # Statistics. assert 1 <= len(cs.best_channels(1)) <= nc diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index a007f003d..e5ff865bc 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -538,7 +538,6 @@ def __init__(self, spike_times=None, spike_clusters=None, masks=None, # full array of masks - channel_order=None, n_samples_per_spike=None, scaling=None, mean_traces=None, @@ -551,18 +550,10 @@ def __init__(self, # Traces. assert len(traces.shape) == 2 - self.n_samples, self.n_channels_traces = traces.shape + self.n_samples, self.n_channels = traces.shape self.traces = traces self.duration = self.dt * self.n_samples - # Channel ordering and dead channels. - # We do traces[..., channel_order] whenever we load traces - # so that the channels match those in masks. - self.n_channels = (self.n_channels_traces if channel_order is None - else len(channel_order)) - self.channel_order = (channel_order if channel_order is not None - else slice(None, None, None)) - # Used to detrend the traces. self.mean_traces = np.atleast_2d(mean_traces) assert self.mean_traces.shape == (1, self.n_channels) @@ -618,8 +609,8 @@ def _load_traces(self, interval): i, j = int(i), int(j) # We load the traces and select the requested channels. - assert self.traces.shape[1] == self.n_channels_traces - traces = self.traces[i:j, self.channel_order] + assert self.traces.shape[1] == self.n_channels + traces = self.traces[i:j, :] assert traces.shape[1] == self.n_channels # Detrend the traces. @@ -637,12 +628,8 @@ def _load_spikes(self, interval): spike_clusters = self.spike_clusters[a:b] n_spikes = len(spike_times) assert len(spike_clusters) == n_spikes - # TODO: make this cleaner - nc = (self.n_channels - if isinstance(self.channel_order, slice) - else len(self.channel_order)) masks = (self.masks[a:b] if self.masks is not None - else np.ones((n_spikes, nc))) + else np.ones((n_spikes, self.n_channels))) return spike_times, spike_clusters, masks def _plot_traces(self, traces, start=None, data_bounds=None): @@ -849,7 +836,6 @@ def attach_to_gui(self, gui, model=None, state=None): spike_times=model.spike_times, spike_clusters=model.spike_clusters, masks=model.masks, - channel_order=model.channel_order, scaling=s, mean_traces=cs.mean_traces(), ) From 865b9ca4a4ce5e292ef9ea27fa28589896b66bc1 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 8 Jan 2016 16:19:09 +0100 Subject: [PATCH 0911/1059] Refactor data functions of clustering views --- phy/cluster/manual/store.py | 43 +++++++++---- phy/cluster/manual/tests/test_store.py | 28 +++++---- phy/cluster/manual/views.py | 85 +++++++++++++------------- phy/plot/plot.py | 3 +- 4 files changed, 90 insertions(+), 69 deletions(-) diff --git a/phy/cluster/manual/store.py b/phy/cluster/manual/store.py index 036c807bc..341be38ac 100644 --- a/phy/cluster/manual/store.py +++ b/phy/cluster/manual/store.py @@ -14,13 +14,14 @@ import numpy as np from phy.io.array import Selector +from phy.plot.plot import _accumulate from phy.stats.clusters import (mean, get_max_waveform_amplitude, get_mean_masked_features_distance, get_unmasked_channels, get_sorted_main_channels, ) -from phy.utils import IPlugin, _as_scalar, _as_scalars +from phy.utils import Bunch, IPlugin, _as_scalar, _as_scalars logger = logging.getLogger(__name__) @@ -64,8 +65,7 @@ def wrapped(cluster_ids): if not hasattr(cluster_ids, '__len__'): return f(cluster_ids) # Concatenate the result of multiple clusters. - arrs = zip(*(f(c) for c in cluster_ids)) - return tuple(np.concatenate(_, axis=0) for _ in arrs) + return Bunch(_accumulate([f(c) for c in cluster_ids])) return wrapped @@ -127,6 +127,10 @@ def select(cluster_id, n=None): assert cluster_id >= 0 return selector.select_spikes([cluster_id], max_n_spikes_per_cluster=n) + def _get_data(**kwargs): + kwargs['spike_clusters'] = model.spike_clusters[kwargs['spike_ids']] + return Bunch(**kwargs) + # Model data. # ------------------------------------------------------------------------- @@ -138,7 +142,9 @@ def masks(cluster_id): else: masks = np.atleast_2d(model.masks[spike_ids]) assert masks.ndim == 2 - return spike_ids, masks + return _get_data(spike_ids=spike_ids, + masks=masks, + ) @cs.add(concat=True) def features_masks(cluster_id): @@ -150,7 +156,10 @@ def features_masks(cluster_id): assert fm.ndim == 3 f = fm[..., 0].reshape((ns, nc, nfpc)) m = fm[:, ::nfpc, 1] - return spike_ids, f, m + return _get_data(spike_ids=spike_ids, + features=f, + masks=m, + ) @cs.add(concat=True) def features(cluster_id): @@ -163,7 +172,9 @@ def features(cluster_id): else: features = np.atleast_2d(model.features[spike_ids]) assert features.ndim == 3 - return spike_ids, features + return _get_data(spike_ids=spike_ids, + features=features, + ) @cs.add def feature_lim(): @@ -182,7 +193,10 @@ def background_features_masks(): assert features.ndim == 3 assert masks.ndim == 2 assert masks.shape[0] == features.shape[0] - return spike_ids, features, masks + return _get_data(spike_ids=spike_ids, + features=features, + masks=masks, + ) @cs.add(concat=True) def waveforms(cluster_id): @@ -190,7 +204,9 @@ def waveforms(cluster_id): max_n_spikes_per_cluster['waveforms']) waveforms = np.atleast_2d(model.waveforms[spike_ids]) assert waveforms.ndim == 3 - return spike_ids, waveforms + return _get_data(spike_ids=spike_ids, + waveforms=waveforms, + ) @cs.add def waveform_lim(): @@ -208,7 +224,10 @@ def waveforms_masks(cluster_id): assert masks.ndim == 2 # Ensure that both arrays have the same number of channels. assert masks.shape[1] == waveforms.shape[2] - return spike_ids, waveforms, masks + return _get_data(spike_ids=spike_ids, + waveforms=waveforms, + masks=masks, + ) # Mean quantities. # ------------------------------------------------------------------------- @@ -216,15 +235,15 @@ def waveforms_masks(cluster_id): @cs.add def mean_masks(cluster_id): # We access [1] because we return spike_ids, masks. - return mean(cs.masks(cluster_id)[1]) + return mean(cs.masks(cluster_id).masks) @cs.add def mean_features(cluster_id): - return mean(cs.features(cluster_id)[1]) + return mean(cs.features(cluster_id).features) @cs.add def mean_waveforms(cluster_id): - return mean(cs.waveforms(cluster_id)[1]) + return mean(cs.waveforms(cluster_id).waveforms) # Statistics. # ------------------------------------------------------------------------- diff --git a/phy/cluster/manual/tests/test_store.py b/phy/cluster/manual/tests/test_store.py index d85cd0a88..293695dc9 100644 --- a/phy/cluster/manual/tests/test_store.py +++ b/phy/cluster/manual/tests/test_store.py @@ -40,23 +40,24 @@ def test_create_cluster_store(model): ns2 = len(model.spikes_per_cluster[2]) nsw = model.n_samples_waveforms - def _check(out, *shape): - spikes, arr = out + def _check(out, name, *shape): + spikes = out.pop('spike_ids') + arr = out[name] assert spikes.shape[0] == shape[0] assert arr.shape == shape # Model data. - _check(cs.masks(1), ns, nc) - _check(cs.features(1), ns, nc, nfpc) - _check(cs.waveforms(1), ns, nsw, nc) - - # Waveforms masks. - spike_ids, w, m = cs.waveforms_masks(1) - _check((spike_ids, w), ns, nsw, nc) - _check((spike_ids, m), ns, nc) + _check(cs.masks(1), 'masks', ns, nc) + _check(cs.features(1), 'features', ns, nc, nfpc) + _check(cs.waveforms(1), 'waveforms', ns, nsw, nc) + _check(cs.waveforms_masks(1), 'waveforms', ns, nsw, nc) + _check(cs.waveforms_masks(1), 'masks', ns, nc) # Background feature masks. - spike_ids, bgf, bgm = cs.background_features_masks() + data = cs.background_features_masks() + spike_ids = data.spike_ids + bgf = data.features + bgm = data.masks assert bgf.ndim == 3 assert bgf.shape[1:] == (nc, nfpc) assert bgm.ndim == 2 @@ -64,7 +65,10 @@ def _check(out, *shape): assert spike_ids.shape == (bgf.shape[0],) == (bgm.shape[0],) # Test concat multiple clusters. - spike_ids, f, m = cs.features_masks([1, 2]) + data = cs.features_masks([1, 2]) + spike_ids = data.spike_ids + f = data.features + m = data.masks assert len(spike_ids) == ns + ns2 assert f.shape == (ns + ns2, nc, nfpc) assert m.shape == (ns + ns2, nc) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index e5ff865bc..816768a41 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -72,17 +72,6 @@ def _extract_wave(traces, spk, mask, wave_len=None): return data, channels -def _get_spike_clusters_rel(spike_clusters, spike_ids, cluster_ids): - # Relative spike clusters. - # NOTE: the order of the clusters in cluster_ids matters. - # It will influence the relative index of the clusters, which - # in return influence the depth. - spike_clusters = spike_clusters[spike_ids] - assert np.all(np.in1d(spike_clusters, cluster_ids)) - spike_clusters_rel = _index_of(spike_clusters, cluster_ids) - return spike_clusters_rel - - def _get_depth(masks, spike_clusters_rel=None, n_clusters=None): """Return the OpenGL z-depth of vertices as a function of the mask and cluster index.""" @@ -147,7 +136,11 @@ def __init__(self, type, message=None): class ManualClusteringView(View): - max_n_spikes_per_cluster = None + """Base class for clustering views. + + The views take their data with functions `cluster_ids: spike_ids, data`. + + """ default_shortcuts = { } @@ -259,19 +252,12 @@ class WaveformView(ManualClusteringView): def __init__(self, waveforms_masks=None, - spike_clusters=None, channel_positions=None, box_scaling=None, probe_scaling=None, n_samples=None, waveform_lim=None, **kwargs): - """ - - The channel order in waveforms needs to correspond to the one - in channel_positions. - - """ self._key_pressed = None # Channel positions and n_channels. @@ -311,14 +297,11 @@ def __init__(self, assert waveform_lim > 0 self.data_bounds = [-1, -waveform_lim, +1, +waveform_lim] - # Spike clusters. - self.spike_clusters = spike_clusters - # Channel positions. assert channel_positions.shape == (self.n_channels, 2) self.channel_positions = channel_positions - def on_select(self, cluster_ids=None): + def on_select(self, cluster_ids=None, zoom_on_channels=True): super(WaveformView, self).on_select(cluster_ids) cluster_ids = self.cluster_ids n_clusters = len(cluster_ids) @@ -326,14 +309,17 @@ def on_select(self, cluster_ids=None): return # Load the waveform subset. - spike_ids, w, masks = self.waveforms_masks(cluster_ids) + data = self.waveforms_masks(cluster_ids) + spike_ids = data.spike_ids + spike_clusters = data.spike_clusters + w = data.waveforms + masks = data.masks n_spikes = len(spike_ids) assert w.shape == (n_spikes, self.n_samples, self.n_channels) assert masks.shape == (n_spikes, self.n_channels) # Relative spike clusters. - spike_clusters_rel = _get_spike_clusters_rel(self.spike_clusters, - spike_ids, cluster_ids) + spike_clusters_rel = _index_of(spike_clusters, cluster_ids) assert spike_clusters_rel.shape == (n_spikes,) # Fetch the waveforms. @@ -364,7 +350,7 @@ def on_select(self, cluster_ids=None): # Zoom on the best channels when selecting clusters. channels = self._best_channels(cluster_ids) - if channels is not None: + if channels is not None and zoom_on_channels: self.zoom_on_channels(channels) def attach(self, gui): @@ -396,7 +382,7 @@ def on_channel_click(e): def toggle_waveform_overlap(self): """Toggle the overlap of the waveforms.""" self.overlap = not self.overlap - self.on_select() + self.on_select(zoom_on_channels=False) # Box scaling # ------------------------------------------------------------------------- @@ -493,7 +479,6 @@ def attach_to_gui(self, gui, model=None, state=None): cs = gui.request('cluster_store') assert cs # We need the cluster store to retrieve the data. view = WaveformView(waveforms_masks=cs.waveforms_masks, - spike_clusters=model.spike_clusters, channel_positions=model.channel_positions, n_samples=model.n_samples_waveforms, box_scaling=bs, @@ -908,14 +893,23 @@ class FeatureView(ManualClusteringView): } def __init__(self, - features_masks=None, # function cluster_id => (spk, f, m) - background_features_masks=None, # (spk, f, m) + features_masks=None, + background_features_masks=None, spike_times=None, - spike_clusters=None, n_channels=None, n_features_per_channel=None, feature_lim=None, **kwargs): + """ + features_masks is a function : + `cluster_ids: Bunch(spike_ids, + features, + masks, + spike_clusters, + spike_times)` + background_features_masks is a Bunch(...) like above. + + """ assert features_masks self.features_masks = features_masks @@ -927,7 +921,9 @@ def __init__(self, assert n_channels > 0 self.n_channels = n_channels + # Spike times. self.n_spikes = spike_times.shape[0] + assert spike_times.shape == (self.n_spikes,) assert self.n_spikes >= 0 self.n_cols = self.n_features_per_channel + 1 @@ -941,15 +937,11 @@ def __init__(self, # Feature normalization. self.data_bounds = [-1, -feature_lim, +1, +feature_lim] - # Spike clusters. - assert spike_clusters.shape == (self.n_spikes,) - self.spike_clusters = spike_clusters - - # Spike times. - assert spike_times.shape == (self.n_spikes,) + # If this is True, the channels won't be automatically chosen + # when new clusters are selected. + self.fixed_channels = False # Channels to show. - self.fixed_channels = False self.x_channels = None self.y_channels = None @@ -1068,17 +1060,23 @@ def on_select(self, cluster_ids=None): return # Get the spikes, features, masks. - spike_ids, f, masks = self.features_masks(cluster_ids) + data = self.features_masks(cluster_ids) + spike_ids = data.spike_ids + spike_clusters = data.spike_clusters + f = data.features + masks = data.masks assert f.ndim == 3 assert masks.ndim == 2 assert spike_ids.shape[0] == f.shape[0] == masks.shape[0] # Get the spike clusters. - sc = _get_spike_clusters_rel(self.spike_clusters, spike_ids, - cluster_ids) + sc = _index_of(spike_clusters, cluster_ids) # Get the background features. - spike_ids_bg, features_bg, masks_bg = self.background_features_masks + data_bg = self.background_features_masks + spike_ids_bg = data_bg.spike_ids + features_bg = data_bg.features + masks_bg = data_bg.masks # Select the dimensions. # Choose the channels automatically unless fixed_channels is set. @@ -1193,7 +1191,6 @@ def attach_to_gui(self, gui, model=None, state=None): bg = cs.background_features_masks() view = FeatureView(features_masks=cs.features_masks, background_features_masks=bg, - spike_clusters=model.spike_clusters, spike_times=model.spike_times, n_channels=model.n_channels, n_features_per_channel=model.n_features_per_channel, diff --git a/phy/plot/plot.py b/phy/plot/plot.py index 6899265de..b23cde259 100644 --- a/phy/plot/plot.py +++ b/phy/plot/plot.py @@ -48,7 +48,8 @@ def names(self): def __getitem__(self, name): """Concatenate all arrays with a given name.""" - return np.vstack(self._data[name]).astype(np.float32) + return np.concatenate(self._data[name], axis=0). \ + astype(self._data[name][0].dtype) def _accumulate(data_list, no_concat=()): From 5ca63a98f3235c6267e96b590a5a64cb5b13f712 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 8 Jan 2016 16:43:23 +0100 Subject: [PATCH 0912/1059] Change plugin interface --- phy/cluster/manual/gui_component.py | 4 ++-- phy/cluster/manual/store.py | 3 ++- .../manual/tests/test_gui_component.py | 3 ++- phy/cluster/manual/views.py | 16 ++++++++++---- phy/gui/gui.py | 22 ++++++++++++------- phy/gui/tests/test_gui.py | 12 +++++----- phy/io/context.py | 3 ++- 7 files changed, 40 insertions(+), 23 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 60a420846..1ac4abeb4 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -518,8 +518,8 @@ def save(self): class ManualClusteringPlugin(IPlugin): - def attach_to_gui(self, gui, model=None, state=None): - + def attach_to_gui(self, gui): + model = gui.request('model') # Attach the manual clustering logic (wizard, merge, split, # undo stack) to the GUI. mc = ManualClustering(model.spike_clusters, diff --git a/phy/cluster/manual/store.py b/phy/cluster/manual/store.py index 341be38ac..f48d1626c 100644 --- a/phy/cluster/manual/store.py +++ b/phy/cluster/manual/store.py @@ -302,8 +302,9 @@ def mean_traces(): class ClusterStorePlugin(IPlugin): - def attach_to_gui(self, gui, model=None, state=None): + def attach_to_gui(self, gui): ctx = gui.request('context') + model = gui.request('model') # NOTE: we get the spikes_per_cluster from the Clustering instance. # We need to access it from a function to avoid circular dependencies diff --git a/phy/cluster/manual/tests/test_gui_component.py b/phy/cluster/manual/tests/test_gui_component.py index 017d1ef61..ebacecb0b 100644 --- a/phy/cluster/manual/tests/test_gui_component.py +++ b/phy/cluster/manual/tests/test_gui_component.py @@ -65,7 +65,8 @@ def test_manual_clustering_plugin(qtbot, gui): masks=.75 * np.ones((3, 1)), ) state = Bunch() - ManualClusteringPlugin().attach_to_gui(gui, model=model, state=state) + gui.register(model=model, state=state) + ManualClusteringPlugin().attach_to_gui(gui) def test_manual_clustering_edge_cases(manual_clustering): diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 816768a41..bed90c062 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -470,7 +470,9 @@ def on_key_release(self, event): class WaveformViewPlugin(IPlugin): - def attach_to_gui(self, gui, model=None, state=None): + def attach_to_gui(self, gui): + state = gui.state + model = gui.request('model') bs, ps, ov = state.get_view_params('WaveformView', 'box_scaling', 'probe_scaling', @@ -810,7 +812,9 @@ def decrease(self): class TraceViewPlugin(IPlugin): - def attach_to_gui(self, gui, model=None, state=None): + def attach_to_gui(self, gui): + state = gui.state + model = gui.request('model') s, = state.get_view_params('TraceView', 'scaling') cs = gui.request('cluster_store') @@ -1185,8 +1189,10 @@ def feature_scaling(self, value): class FeatureViewPlugin(IPlugin): - def attach_to_gui(self, gui, model=None, state=None): + def attach_to_gui(self, gui): + state = gui.state cs = gui.request('cluster_store') + model = gui.request('model') assert cs bg = cs.background_features_masks() view = FeatureView(features_masks=cs.features_masks, @@ -1344,7 +1350,9 @@ def set_window(self, window_size): class CorrelogramViewPlugin(IPlugin): - def attach_to_gui(self, gui, model=None, state=None): + def attach_to_gui(self, gui): + state = gui.state + model = gui.request('model') bs, ws, es, ne, un = state.get_view_params('CorrelogramView', 'bin_size', 'window_size', diff --git a/phy/gui/gui.py b/phy/gui/gui.py index fca2a6bbf..e8bf4b4ea 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -206,12 +206,14 @@ def connect_(self, *args, **kwargs): def unconnect_(self, *args, **kwargs): self._event.unconnect(*args, **kwargs) - def register(self, func=None, name=None): + def register(self, obj=None, name=None, **kwargs): """Register a object for a given name.""" - if func is None: - return lambda _: self.register(func=_, name=name) - name = name or func.__name__ - self._registered[name] = func + for n, o in kwargs.items(): + self.register(o, n) + if obj is None: + return lambda _: self.register(obj=_, name=name) + name = name or obj.__name__ + self._registered[name] = obj def request(self, name, *args, **kwargs): """Request the result of a possibly registered object.""" @@ -416,7 +418,8 @@ def save(self): class SaveGeometryStatePlugin(IPlugin): - def attach_to_gui(self, gui, state=None, model=None): + def attach_to_gui(self, gui): + state = gui.state @gui.connect_ def on_close(): @@ -431,7 +434,7 @@ def on_show(): def create_gui(name=None, subtitle=None, model=None, plugins=None, config_dir=None): - """Create a GUI with a model and a list of plugins. + """Create a GUI with a list of plugins. By default, the list of plugins is taken from the `c.TheGUI.plugins` parameter, where `TheGUI` is the name of the GUI class. @@ -445,6 +448,9 @@ def create_gui(name=None, subtitle=None, model=None, state = GUIState(gui.name, config_dir=config_dir) gui.state = state + # Register the model. + gui.register(model=model) + # If no plugins are specified, load the master config and # get the list of user plugins to attach to the GUI. plugins_conf = load_master_config()[name].plugins @@ -454,7 +460,7 @@ def create_gui(name=None, subtitle=None, model=None, # Attach the plugins to the GUI. for plugin in plugins: logger.debug("Attach plugin `%s` to %s.", plugin, name) - get_plugin(plugin)().attach_to_gui(gui, state=state, model=model) + get_plugin(plugin)().attach_to_gui(gui) # Save the state to disk. @gui.connect_ diff --git a/phy/gui/tests/test_gui.py b/phy/gui/tests/test_gui.py index 1fde6b571..744656379 100644 --- a/phy/gui/tests/test_gui.py +++ b/phy/gui/tests/test_gui.py @@ -195,8 +195,8 @@ def test_create_gui_1(qapp, tempdir): _tmp = [] class MyPlugin(IPlugin): - def attach_to_gui(self, gui, model=None, state=None): - _tmp.append(state.hello) + def attach_to_gui(self, gui): + _tmp.append(gui.state.hello) gui = create_gui(plugins=['MyPlugin'], config_dir=tempdir) assert gui @@ -210,11 +210,11 @@ def attach_to_gui(self, gui, model=None, state=None): def test_save_geometry_state(gui): - state = Bunch() - SaveGeometryStatePlugin().attach_to_gui(gui, state=state) + gui.state = Bunch() + SaveGeometryStatePlugin().attach_to_gui(gui) gui.close() - assert state.geometry_state['geometry'] - assert state.geometry_state['state'] + assert gui.state.geometry_state['geometry'] + assert gui.state.geometry_state['state'] gui.show() diff --git a/phy/io/context.py b/phy/io/context.py index 6090d61af..38c0ff15f 100644 --- a/phy/io/context.py +++ b/phy/io/context.py @@ -327,7 +327,8 @@ def __setstate__(self, state): class ContextPlugin(IPlugin): - def attach_to_gui(self, gui, model=None, state=None): + def attach_to_gui(self, gui): + model = gui.request('model') # Create the computing context. ctx = Context(op.join(op.dirname(model.path), '.phy/')) gui.register(ctx, name='context') From 6ddcd38d19ac2f0a1b0a8bb98f75df2aebf96fa7 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 8 Jan 2016 17:03:25 +0100 Subject: [PATCH 0913/1059] Quality and similarity functions are now defined in the GUI state --- phy/cluster/manual/gui_component.py | 12 ++++---- phy/cluster/manual/store.py | 3 ++ .../manual/tests/test_gui_component.py | 28 +++++++++---------- phy/cluster/manual/tests/test_store.py | 2 +- phy/cluster/manual/tests/test_views.py | 4 +++ 5 files changed, 29 insertions(+), 20 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 1ac4abeb4..c6df02e8f 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -402,11 +402,13 @@ def attach(self, gui): # Add the quality column in the cluster view. cs = gui.request('cluster_store') - if cs: - self.cluster_view.add_column(cs.max_waveform_amplitude, - name='quality') - self.set_default_sort('quality') - self.set_similarity_func(cs.most_similar_clusters) + if cs and 'ClusterView' in gui.state: + # Names of the quality and similarity functions. + quality = gui.state.ClusterView.quality + similarity = gui.state.ClusterView.similarity + self.cluster_view.add_column(cs.get(quality), name=quality) + self.set_default_sort(quality) + self.set_similarity_func(cs.get(similarity)) # Update the cluster views and selection when a cluster event occurs. self.gui.connect_(self.on_cluster) diff --git a/phy/cluster/manual/store.py b/phy/cluster/manual/store.py index f48d1626c..6ac1a02e4 100644 --- a/phy/cluster/manual/store.py +++ b/phy/cluster/manual/store.py @@ -103,6 +103,9 @@ def add(self, f=None, name=None, cache='disk', concat=None): setattr(self, name, f) return f + def get(self, name): + return self._stats[name] + def attach(self, gui): gui.register(self, name='cluster_store') diff --git a/phy/cluster/manual/tests/test_gui_component.py b/phy/cluster/manual/tests/test_gui_component.py index ebacecb0b..a968942f6 100644 --- a/phy/cluster/manual/tests/test_gui_component.py +++ b/phy/cluster/manual/tests/test_gui_component.py @@ -21,6 +21,19 @@ # Fixtures #------------------------------------------------------------------------------ +@yield_fixture +def gui(qtbot): + gui = GUI(position=(200, 100), size=(500, 500)) + gui.state = Bunch() + gui.show() + qtbot.waitForWindowShown(gui) + yield gui + qtbot.wait(5) + gui.close() + del gui + qtbot.wait(5) + + @yield_fixture def manual_clustering(qtbot, gui, cluster_ids, cluster_groups, quality, similarity): @@ -40,18 +53,6 @@ def manual_clustering(qtbot, gui, cluster_ids, cluster_groups, del mc -@yield_fixture -def gui(qtbot): - gui = GUI(position=(200, 100), size=(500, 500)) - gui.show() - qtbot.waitForWindowShown(gui) - yield gui - qtbot.wait(5) - gui.close() - del gui - qtbot.wait(5) - - #------------------------------------------------------------------------------ # Test GUI component #------------------------------------------------------------------------------ @@ -64,8 +65,7 @@ def test_manual_clustering_plugin(qtbot, gui): features=np.zeros((3, 1, 2)), masks=.75 * np.ones((3, 1)), ) - state = Bunch() - gui.register(model=model, state=state) + gui.register(model=model) ManualClusteringPlugin().attach_to_gui(gui) diff --git a/phy/cluster/manual/tests/test_store.py b/phy/cluster/manual/tests/test_store.py index 293695dc9..5041ddc79 100644 --- a/phy/cluster/manual/tests/test_store.py +++ b/phy/cluster/manual/tests/test_store.py @@ -25,7 +25,7 @@ def f(x): return x * x assert cs.f(3) == 9 - assert cs.f(3) == 9 + assert cs.get('f')(3) == 9 def test_create_cluster_store(model): diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index 4839d3179..84161782a 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -14,6 +14,7 @@ from pytest import raises from vispy.util import keys +from phy.utils import Bunch from phy.gui import create_gui, GUIState from phy.io.mock import artificial_traces from ..views import (TraceView, _extract_wave, _selected_clusters_colors, @@ -41,6 +42,9 @@ def _test_view(view_name, model=None, tempdir=None): state.set_view_params('TraceView1', box_size=(1., .01)) state.set_view_params('FeatureView1', feature_scaling=.5) state.set_view_params('CorrelogramView1', uniform_normalization=True) + # quality and similarity functions for the cluster view. + state.ClusterView = Bunch(quality='max_waveform_amplitude', + similarity='most_similar_clusters') state.save() # Create the GUI. From bc0ec2599bf7a98d7f9916d95efae1514c7fc6f5 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 8 Jan 2016 19:17:31 +0100 Subject: [PATCH 0914/1059] WIP: refactor views --- phy/cluster/manual/__init__.py | 1 - phy/cluster/manual/store.py | 324 ------------------------- phy/cluster/manual/tests/conftest.py | 56 +---- phy/cluster/manual/tests/test_store.py | 92 ------- phy/cluster/manual/tests/test_views.py | 224 ++++++++++++++++- phy/io/array.py | 53 ++++ phy/io/store.py | 103 ++++++++ phy/io/tests/test_store.py | 26 ++ phy/plot/plot.py | 51 +--- phy/traces/tests/test_spike_detect.py | 13 +- phy/utils/plugin.py | 3 +- phy/utils/tests/test_plugin.py | 2 +- 12 files changed, 405 insertions(+), 543 deletions(-) delete mode 100644 phy/cluster/manual/store.py delete mode 100644 phy/cluster/manual/tests/test_store.py create mode 100644 phy/io/store.py create mode 100644 phy/io/tests/test_store.py diff --git a/phy/cluster/manual/__init__.py b/phy/cluster/manual/__init__.py index 8e27fa803..d6b871021 100644 --- a/phy/cluster/manual/__init__.py +++ b/phy/cluster/manual/__init__.py @@ -6,5 +6,4 @@ from ._utils import ClusterMeta from .clustering import Clustering from .gui_component import ManualClustering -from .store import ClusterStore, create_cluster_store, get_closest_clusters from .views import WaveformView, TraceView, FeatureView, CorrelogramView diff --git a/phy/cluster/manual/store.py b/phy/cluster/manual/store.py deleted file mode 100644 index 6ac1a02e4..000000000 --- a/phy/cluster/manual/store.py +++ /dev/null @@ -1,324 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Manual clustering GUI component.""" - - -# ----------------------------------------------------------------------------- -# Imports -# ----------------------------------------------------------------------------- - -from functools import wraps -import logging -from operator import itemgetter - -import numpy as np - -from phy.io.array import Selector -from phy.plot.plot import _accumulate -from phy.stats.clusters import (mean, - get_max_waveform_amplitude, - get_mean_masked_features_distance, - get_unmasked_channels, - get_sorted_main_channels, - ) -from phy.utils import Bunch, IPlugin, _as_scalar, _as_scalars - -logger = logging.getLogger(__name__) - - -# ----------------------------------------------------------------------------- -# Utils -# ----------------------------------------------------------------------------- - -def _get_data_lim(arr, n_spikes=None, percentile=None): - n = arr.shape[0] - k = max(1, n // n_spikes) if n_spikes else 1 - arr = np.abs(arr[::k]) - n = arr.shape[0] - arr = arr.reshape((n, -1)) - return arr.max() - - -def get_closest_clusters(cluster_id, cluster_ids, sim_func, max_n=None): - """Return a list of pairs `(cluster, similarity)` sorted by decreasing - similarity to a given cluster.""" - l = [(_as_scalar(candidate), _as_scalar(sim_func(cluster_id, candidate))) - for candidate in _as_scalars(cluster_ids)] - l = sorted(l, key=itemgetter(1), reverse=True) - return l[:max_n] - - -def _log(f): - @wraps(f) - def wrapped(*args, **kwargs): - logger.log(5, "Compute %s(%s).", f.__name__, str(args)) - return f(*args, **kwargs) - return wrapped - - -def _concat(f): - """Take a function accepting a single cluster, and return a function - accepting multiple clusters.""" - @wraps(f) - def wrapped(cluster_ids): - # Single cluster. - if not hasattr(cluster_ids, '__len__'): - return f(cluster_ids) - # Concatenate the result of multiple clusters. - return Bunch(_accumulate([f(c) for c in cluster_ids])) - return wrapped - - -# ----------------------------------------------------------------------------- -# Cluster statistics -# ----------------------------------------------------------------------------- - -class ClusterStore(object): - def __init__(self, context=None): - self.context = context - self._stats = {} - - def add(self, f=None, name=None, cache='disk', concat=None): - """Add a cluster statistic. - - Parameters - ---------- - f : function - name : str - cache : str - Can be `None` (no cache), `disk`, or `memory`. In the latter case - the function will also be cached on disk. - - """ - if f is None: - return lambda _: self.add(_, name=name, cache=cache, concat=concat) - name = name or f.__name__ - if cache and self.context: - f = _log(f) - f = self.context.cache(f, memcache=(cache == 'memory')) - assert f - if concat: - f = _concat(f) - self._stats[name] = f - setattr(self, name, f) - return f - - def get(self, name): - return self._stats[name] - - def attach(self, gui): - gui.register(self, name='cluster_store') - - -def create_cluster_store(model, selector=None, context=None): - cs = ClusterStore(context=context) - - # TODO: make this configurable. - max_n_spikes_per_cluster = { - 'masks': 1000, - 'features': 1000, - 'background_features_masks': 1000, - 'waveforms': 100, - 'waveform_lim': 1000, # used to compute the waveform bounds - 'feature_lim': 1000, # used to compute the waveform bounds - 'mean_traces': 10000, - } - max_n_similar_clusters = 20 - - def select(cluster_id, n=None): - assert isinstance(cluster_id, int) - assert cluster_id >= 0 - return selector.select_spikes([cluster_id], max_n_spikes_per_cluster=n) - - def _get_data(**kwargs): - kwargs['spike_clusters'] = model.spike_clusters[kwargs['spike_ids']] - return Bunch(**kwargs) - - # Model data. - # ------------------------------------------------------------------------- - - @cs.add(concat=True) - def masks(cluster_id): - spike_ids = select(cluster_id, max_n_spikes_per_cluster['masks']) - if model.masks is None: - masks = np.ones((len(spike_ids), model.n_channels)) - else: - masks = np.atleast_2d(model.masks[spike_ids]) - assert masks.ndim == 2 - return _get_data(spike_ids=spike_ids, - masks=masks, - ) - - @cs.add(concat=True) - def features_masks(cluster_id): - spike_ids = select(cluster_id, max_n_spikes_per_cluster['features']) - fm = np.atleast_3d(model.features_masks[spike_ids]) - ns = fm.shape[0] - nc = model.n_channels - nfpc = model.n_features_per_channel - assert fm.ndim == 3 - f = fm[..., 0].reshape((ns, nc, nfpc)) - m = fm[:, ::nfpc, 1] - return _get_data(spike_ids=spike_ids, - features=f, - masks=m, - ) - - @cs.add(concat=True) - def features(cluster_id): - spike_ids = select(cluster_id, max_n_spikes_per_cluster['features']) - if model.features is None: - features = np.zeros((len(spike_ids), - model.n_channels, - model.n_features_per_channel, - )) - else: - features = np.atleast_2d(model.features[spike_ids]) - assert features.ndim == 3 - return _get_data(spike_ids=spike_ids, - features=features, - ) - - @cs.add - def feature_lim(): - """Return the max of a subset of the feature amplitudes.""" - return _get_data_lim(model.features, - max_n_spikes_per_cluster['feature_lim']) - - @cs.add - def background_features_masks(): - n = max_n_spikes_per_cluster['background_features_masks'] - k = max(1, model.n_spikes // n) - features = model.features[::k] - masks = model.masks[::k] - spike_ids = np.arange(0, model.n_spikes, k) - assert spike_ids.shape == (features.shape[0],) - assert features.ndim == 3 - assert masks.ndim == 2 - assert masks.shape[0] == features.shape[0] - return _get_data(spike_ids=spike_ids, - features=features, - masks=masks, - ) - - @cs.add(concat=True) - def waveforms(cluster_id): - spike_ids = select(cluster_id, - max_n_spikes_per_cluster['waveforms']) - waveforms = np.atleast_2d(model.waveforms[spike_ids]) - assert waveforms.ndim == 3 - return _get_data(spike_ids=spike_ids, - waveforms=waveforms, - ) - - @cs.add - def waveform_lim(): - """Return the max of a subset of the waveform amplitudes.""" - return _get_data_lim(model.waveforms, - max_n_spikes_per_cluster['waveform_lim']) - - @cs.add(concat=True) - def waveforms_masks(cluster_id): - spike_ids = select(cluster_id, - max_n_spikes_per_cluster['waveforms']) - waveforms = np.atleast_2d(model.waveforms[spike_ids]) - assert waveforms.ndim == 3 - masks = np.atleast_2d(model.masks[spike_ids]) - assert masks.ndim == 2 - # Ensure that both arrays have the same number of channels. - assert masks.shape[1] == waveforms.shape[2] - return _get_data(spike_ids=spike_ids, - waveforms=waveforms, - masks=masks, - ) - - # Mean quantities. - # ------------------------------------------------------------------------- - - @cs.add - def mean_masks(cluster_id): - # We access [1] because we return spike_ids, masks. - return mean(cs.masks(cluster_id).masks) - - @cs.add - def mean_features(cluster_id): - return mean(cs.features(cluster_id).features) - - @cs.add - def mean_waveforms(cluster_id): - return mean(cs.waveforms(cluster_id).waveforms) - - # Statistics. - # ------------------------------------------------------------------------- - - @cs.add(cache='memory') - def best_channels(cluster_id): - mm = cs.mean_masks(cluster_id) - uch = get_unmasked_channels(mm) - return get_sorted_main_channels(mm, uch) - - @cs.add(cache='memory') - def best_channels_multiple(cluster_ids): - best_channels = [] - for cluster in cluster_ids: - channels = cs.best_channels(cluster) - best_channels.extend([ch for ch in channels - if ch not in best_channels]) - return best_channels - - @cs.add(cache='memory') - def max_waveform_amplitude(cluster_id): - mm = cs.mean_masks(cluster_id) - mw = cs.mean_waveforms(cluster_id) - assert mw.ndim == 2 - return np.asscalar(get_max_waveform_amplitude(mm, mw)) - - @cs.add(cache=None) - def mean_masked_features_score(cluster_0, cluster_1): - mf0 = cs.mean_features(cluster_0) - mf1 = cs.mean_features(cluster_1) - mm0 = cs.mean_masks(cluster_0) - mm1 = cs.mean_masks(cluster_1) - nfpc = model.n_features_per_channel - d = get_mean_masked_features_distance(mf0, mf1, mm0, mm1, - n_features_per_channel=nfpc) - s = 1. / max(1e-10, d) - return s - - @cs.add(cache='memory') - def most_similar_clusters(cluster_id): - assert isinstance(cluster_id, int) - return get_closest_clusters(cluster_id, model.cluster_ids, - cs.mean_masked_features_score, - max_n_similar_clusters) - - # Traces. - # ------------------------------------------------------------------------- - - @cs.add - def mean_traces(): - n = max_n_spikes_per_cluster['mean_traces'] - mt = model.traces[:n, :].mean(axis=0) - return mt.astype(model.traces.dtype) - - return cs - - -class ClusterStorePlugin(IPlugin): - def attach_to_gui(self, gui): - ctx = gui.request('context') - model = gui.request('model') - - # NOTE: we get the spikes_per_cluster from the Clustering instance. - # We need to access it from a function to avoid circular dependencies - # between the cluster store and manual clustering plugins. - def spikes_per_cluster(cluster_id): - mc = gui.request('manual_clustering') - return mc.clustering.spikes_per_cluster[cluster_id] - - assert ctx - selector = Selector(spike_clusters=model.spike_clusters, - spikes_per_cluster=spikes_per_cluster, - ) - cs = create_cluster_store(model, selector=selector, context=ctx) - cs.attach(gui) diff --git a/phy/cluster/manual/tests/conftest.py b/phy/cluster/manual/tests/conftest.py index 1c7423bde..aaa648936 100644 --- a/phy/cluster/manual/tests/conftest.py +++ b/phy/cluster/manual/tests/conftest.py @@ -6,22 +6,9 @@ # Imports #------------------------------------------------------------------------------ -import os.path as op - -import numpy as np from pytest import yield_fixture -from phy.electrode.mea import staggered_positions -from phy.io.array import _spikes_per_cluster -from phy.io.mock import (artificial_waveforms, - artificial_features, - artificial_spike_clusters, - artificial_spike_samples, - artificial_masks, - artificial_traces, - ) -from phy.utils import Bunch -from phy.cluster.manual.store import get_closest_clusters +from phy.io.store import get_closest_clusters #------------------------------------------------------------------------------ @@ -48,44 +35,3 @@ def quality(): def similarity(cluster_ids): sim = lambda c, d: (c * 1.01 + d) yield lambda c: get_closest_clusters(c, cluster_ids, sim) - - -@yield_fixture -def model(tempdir): - model = Bunch() - - n_spikes = 51 - n_samples_w = 31 - n_samples_t = 20000 - n_channels = 11 - n_clusters = 3 - n_features = 4 - - model.path = op.join(tempdir, 'test') - model.n_channels = n_channels - # TODO: test with permutation and dead channels - model.channel_order = None - model.n_spikes = n_spikes - model.sample_rate = 20000. - model.duration = n_samples_t / float(model.sample_rate) - model.spike_times = artificial_spike_samples(n_spikes) * 1. - model.spike_times /= model.spike_times[-1] - model.spike_clusters = artificial_spike_clusters(n_spikes, n_clusters) - model.cluster_ids = np.unique(model.spike_clusters) - model.channel_positions = staggered_positions(n_channels) - model.waveforms = artificial_waveforms(n_spikes, n_samples_w, n_channels) - model.masks = artificial_masks(n_spikes, n_channels) - model.traces = artificial_traces(n_samples_t, n_channels) - model.features = artificial_features(n_spikes, n_channels, n_features) - - # features_masks array - f = model.features.reshape((n_spikes, -1)) - m = np.repeat(model.masks, n_features, axis=1) - model.features_masks = np.dstack((f, m)) - - model.spikes_per_cluster = _spikes_per_cluster(model.spike_clusters) - model.n_features_per_channel = n_features - model.n_samples_waveforms = n_samples_w - model.cluster_groups = {c: None for c in range(n_clusters)} - - yield model diff --git a/phy/cluster/manual/tests/test_store.py b/phy/cluster/manual/tests/test_store.py deleted file mode 100644 index 5041ddc79..000000000 --- a/phy/cluster/manual/tests/test_store.py +++ /dev/null @@ -1,92 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Test GUI component.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import numpy as np - -from ..store import create_cluster_store, ClusterStore -from phy.io import Context, Selector - - -#------------------------------------------------------------------------------ -# Test cluster stats -#------------------------------------------------------------------------------ - -def test_cluster_store(tempdir): - context = Context(tempdir) - cs = ClusterStore(context=context) - - @cs.add(cache='memory') - def f(x): - return x * x - - assert cs.f(3) == 9 - assert cs.get('f')(3) == 9 - - -def test_create_cluster_store(model): - spc = lambda c: model.spikes_per_cluster[c] - selector = Selector(spike_clusters=model.spike_clusters, - spikes_per_cluster=spc) - cs = create_cluster_store(model, selector=selector) - - nc = model.n_channels - nfpc = model.n_features_per_channel - ns = len(model.spikes_per_cluster[1]) - ns2 = len(model.spikes_per_cluster[2]) - nsw = model.n_samples_waveforms - - def _check(out, name, *shape): - spikes = out.pop('spike_ids') - arr = out[name] - assert spikes.shape[0] == shape[0] - assert arr.shape == shape - - # Model data. - _check(cs.masks(1), 'masks', ns, nc) - _check(cs.features(1), 'features', ns, nc, nfpc) - _check(cs.waveforms(1), 'waveforms', ns, nsw, nc) - _check(cs.waveforms_masks(1), 'waveforms', ns, nsw, nc) - _check(cs.waveforms_masks(1), 'masks', ns, nc) - - # Background feature masks. - data = cs.background_features_masks() - spike_ids = data.spike_ids - bgf = data.features - bgm = data.masks - assert bgf.ndim == 3 - assert bgf.shape[1:] == (nc, nfpc) - assert bgm.ndim == 2 - assert bgm.shape[1] == nc - assert spike_ids.shape == (bgf.shape[0],) == (bgm.shape[0],) - - # Test concat multiple clusters. - data = cs.features_masks([1, 2]) - spike_ids = data.spike_ids - f = data.features - m = data.masks - assert len(spike_ids) == ns + ns2 - assert f.shape == (ns + ns2, nc, nfpc) - assert m.shape == (ns + ns2, nc) - - # Test means. - assert cs.mean_masks(1).shape == (nc,) - assert cs.mean_features(1).shape == (nc, nfpc) - assert cs.mean_waveforms(1).shape == (nsw, nc) - - # Limits. - assert 0 < cs.waveform_lim() < 3 - assert 0 < cs.feature_lim() < 3 - assert cs.mean_traces().shape == (nc,) - - # Statistics. - assert 1 <= len(cs.best_channels(1)) <= nc - assert 1 <= len(cs.best_channels_multiple([1, 2])) <= nc - assert 0 < cs.max_waveform_amplitude(1) < 1 - assert cs.mean_masked_features_score(1, 2) > 0 - - assert np.array(cs.most_similar_clusters(1)).shape == (3, 2) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index 84161782a..64e2b7be0 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -14,13 +14,201 @@ from pytest import raises from vispy.util import keys -from phy.utils import Bunch +from phy.electrode.mea import staggered_positions from phy.gui import create_gui, GUIState -from phy.io.mock import artificial_traces +from phy.io.array import _spikes_per_cluster +from phy.io.mock import (artificial_waveforms, + artificial_features, + artificial_masks, + artificial_traces, + ) +from phy.io.store import ClusterStore, get_closest_clusters +from phy.stats.clusters import (mean, + get_max_waveform_amplitude, + get_mean_masked_features_distance, + get_unmasked_channels, + get_sorted_main_channels, + ) +from phy.utils import Bunch, IPlugin from ..views import (TraceView, _extract_wave, _selected_clusters_colors, _extend) +#------------------------------------------------------------------------------ +# Fixtures +#------------------------------------------------------------------------------ + +def create_model(): + model = Bunch() + + n_samples_waveforms = 31 + n_samples_t = 20000 + n_channels = 11 + n_clusters = 3 + model.n_spikes_per_cluster = 51 + n_spikes_total = n_clusters * model.n_spikes_per_cluster + n_features_per_channel = 4 + + model.path = '' + model.n_channels = n_channels + model.n_spikes = n_spikes_total + model.sample_rate = 20000. + model.duration = n_samples_t / float(model.sample_rate) + model.spike_times = np.linspace(0., model.duration, n_spikes_total) + model.spike_clusters = np.repeat(np.arange(n_clusters), + model.n_spikes_per_cluster) + model.cluster_ids = np.unique(model.spike_clusters) + model.channel_positions = staggered_positions(n_channels) + model.traces = artificial_traces(n_samples_t, n_channels) + model.masks = artificial_masks(n_spikes_total, n_channels) + + model.spikes_per_cluster = _spikes_per_cluster(model.spike_clusters) + model.n_features_per_channel = n_features_per_channel + model.n_samples_waveforms = n_samples_waveforms + model.cluster_groups = {c: None for c in range(n_clusters)} + + return model + + +def create_cluster_store(model): + cs = ClusterStore() + + def get_waveforms(n): + return artificial_waveforms(n, + model.n_samples_waveforms, + model.n_channels) + + def get_masks(n): + return artificial_masks(n, model.n_channels) + + def get_features(n): + return artificial_features(n, + model.n_channels, + model.n_features_per_channel) + + def get_spike_ids(cluster_id): + n = model.n_spikes_per_cluster + return np.arange(n) + n * cluster_id + + def _get_data(**kwargs): + kwargs['spike_clusters'] = model.spike_clusters[kwargs['spike_ids']] + return Bunch(**kwargs) + + @cs.add(concat=True) + def masks(cluster_id): + return _get_data(spike_ids=get_spike_ids(cluster_id), + masks=get_masks(model.n_spikes_per_cluster)) + + @cs.add(concat=True) + def features(cluster_id): + return _get_data(spike_ids=get_spike_ids(cluster_id), + features=get_features(model.n_spikes_per_cluster)) + + @cs.add(concat=True) + def features_masks(cluster_id): + return _get_data(spike_ids=get_spike_ids(cluster_id), + features=get_features(model.n_spikes_per_cluster), + masks=get_masks(model.n_spikes_per_cluster)) + + @cs.add + def feature_lim(): + """Return the max of a subset of the feature amplitudes.""" + return 1 + + @cs.add + def background_features_masks(): + f = get_features(model.n_spikes) + m = model.masks + return _get_data(spike_ids=np.arange(model.n_spikes), + features=f, masks=m) + + @cs.add(concat=True) + def waveforms(cluster_id): + return _get_data(spike_ids=get_spike_ids(cluster_id), + waveforms=get_waveforms(model.n_spikes_per_cluster)) + + @cs.add + def waveform_lim(): + """Return the max of a subset of the waveform amplitudes.""" + return 1 + + @cs.add(concat=True) + def waveforms_masks(cluster_id): + return _get_data(spike_ids=get_spike_ids(cluster_id), + waveforms=get_waveforms(model.n_spikes_per_cluster), + masks=get_masks(model.n_spikes_per_cluster), + ) + + # Mean quantities. + # ------------------------------------------------------------------------- + + @cs.add + def mean_masks(cluster_id): + # We access [1] because we return spike_ids, masks. + return mean(cs.masks(cluster_id).masks) + + @cs.add + def mean_features(cluster_id): + return mean(cs.features(cluster_id).features) + + @cs.add + def mean_waveforms(cluster_id): + return mean(cs.waveforms(cluster_id).waveforms) + + # Statistics. + # ------------------------------------------------------------------------- + + @cs.add(cache='memory') + def best_channels(cluster_id): + mm = cs.mean_masks(cluster_id) + uch = get_unmasked_channels(mm) + return get_sorted_main_channels(mm, uch) + + @cs.add(cache='memory') + def best_channels_multiple(cluster_ids): + best_channels = [] + for cluster in cluster_ids: + channels = cs.best_channels(cluster) + best_channels.extend([ch for ch in channels + if ch not in best_channels]) + return best_channels + + @cs.add(cache='memory') + def max_waveform_amplitude(cluster_id): + mm = cs.mean_masks(cluster_id) + mw = cs.mean_waveforms(cluster_id) + assert mw.ndim == 2 + return np.asscalar(get_max_waveform_amplitude(mm, mw)) + + @cs.add(cache=None) + def mean_masked_features_score(cluster_0, cluster_1): + mf0 = cs.mean_features(cluster_0) + mf1 = cs.mean_features(cluster_1) + mm0 = cs.mean_masks(cluster_0) + mm1 = cs.mean_masks(cluster_1) + nfpc = model.n_features_per_channel + d = get_mean_masked_features_distance(mf0, mf1, mm0, mm1, + n_features_per_channel=nfpc) + s = 1. / max(1e-10, d) + return s + + @cs.add(cache='memory') + def most_similar_clusters(cluster_id): + assert isinstance(cluster_id, int) + return get_closest_clusters(cluster_id, model.cluster_ids, + cs.mean_masked_features_score) + + # Traces. + # ------------------------------------------------------------------------- + + @cs.add + def mean_traces(): + mt = model.traces[:, :].mean(axis=0) + return mt.astype(model.traces.dtype) + + return cs + + #------------------------------------------------------------------------------ # Utils #------------------------------------------------------------------------------ @@ -34,7 +222,14 @@ def _show(qtbot, view, stop=False): @contextmanager -def _test_view(view_name, model=None, tempdir=None): +def _test_view(view_name, tempdir=None): + + model = create_model() + + class ClusterStorePlugin(IPlugin): + def attach_to_gui(self, gui): + cs = create_cluster_store(model) + cs.attach(gui) # Save a test GUI state JSON file in the tempdir. state = GUIState(config_dir=tempdir) @@ -63,6 +258,7 @@ def _test_view(view_name, model=None, tempdir=None): view = gui.list_views(view_name)[0] view.gui = gui + view.model = model # HACK yield view gui.close() @@ -116,8 +312,8 @@ def test_selected_clusters_colors(): # Test waveform view #------------------------------------------------------------------------------ -def test_waveform_view(qtbot, model, tempdir): - with _test_view('WaveformView', model=model, tempdir=tempdir) as v: +def test_waveform_view(qtbot, tempdir): + with _test_view('WaveformView', tempdir=tempdir) as v: ac(v.boxed.box_size, (.1818, .0909), atol=1e-2) v.toggle_waveform_overlap() v.toggle_waveform_overlap() @@ -179,8 +375,8 @@ def test_trace_view_no_spikes(qtbot): _show(qtbot, v) -def test_trace_view_spikes(qtbot, model, tempdir): - with _test_view('TraceView', model=model, tempdir=tempdir) as v: +def test_trace_view_spikes(qtbot, tempdir): + with _test_view('TraceView', tempdir=tempdir) as v: ac(v.stacked.box_size, (1., .08181), atol=1e-3) assert v.time == .25 @@ -205,7 +401,7 @@ def test_trace_view_spikes(qtbot, model, tempdir): ac(v.interval, (.25, .75)) # Widen the max interval. - v.set_interval((0, model.duration)) + v.set_interval((0, v.model.duration)) v.widen() # Change channel scaling. @@ -221,10 +417,11 @@ def test_trace_view_spikes(qtbot, model, tempdir): # Test feature view #------------------------------------------------------------------------------ -def test_feature_view(qtbot, model, tempdir): - with _test_view('FeatureView', model=model, tempdir=tempdir) as v: +def test_feature_view(qtbot, tempdir): + with _test_view('FeatureView', tempdir=tempdir) as v: assert v.feature_scaling == .5 - v.add_attribute('sine', np.sin(np.linspace(-10., 10., model.n_spikes))) + v.add_attribute('sine', + np.sin(np.linspace(-10., 10., v.model.n_spikes))) v.increase() v.decrease() @@ -240,10 +437,11 @@ def test_feature_view(qtbot, model, tempdir): # Test correlogram view #------------------------------------------------------------------------------ -def test_correlogram_view(qtbot, model, tempdir): - with _test_view('CorrelogramView', model=model, tempdir=tempdir) as v: +def test_correlogram_view(qtbot, tempdir): + with _test_view('CorrelogramView', tempdir=tempdir) as v: v.toggle_normalization() v.set_bin(1) v.set_window(100) + # qtbot.stop() diff --git a/phy/io/array.py b/phy/io/array.py index ccea62157..0a6ce9fae 100644 --- a/phy/io/array.py +++ b/phy/io/array.py @@ -6,6 +6,7 @@ # Imports #------------------------------------------------------------------------------ +from collections import defaultdict import logging import math from math import floor, exp @@ -450,3 +451,55 @@ def select_spikes(self, cluster_ids=None, return select_spikes(cluster_ids, spikes_per_cluster=self.spikes_per_cluster, max_n_spikes_per_cluster=ns) + + +# ----------------------------------------------------------------------------- +# Accumulator +# ----------------------------------------------------------------------------- + +def _flatten(l): + return [item for sublist in l for item in sublist] + + +class Accumulator(object): + """Accumulate arrays for concatenation.""" + def __init__(self): + self._data = defaultdict(list) + + def add(self, name, val): + """Add an array.""" + self._data[name].append(val) + + def get(self, name): + """Return the list of arrays for a given name.""" + return _flatten(self._data[name]) + + @property + def names(self): + """List of names.""" + return set(self._data) + + def __getitem__(self, name): + """Concatenate all arrays with a given name.""" + return np.concatenate(self._data[name], axis=0). \ + astype(self._data[name][0].dtype) + + +def _accumulate(data_list, no_concat=()): + """Concatenate a list of dicts `(name, array)`. + + You can specify some names which arrays should not be concatenated. + This is necessary with lists of plots with different sizes. + + """ + acc = Accumulator() + for data in data_list: + for name, val in data.items(): + acc.add(name, val) + out = {name: acc[name] for name in acc.names if name not in no_concat} + + # Some variables should not be concatenated but should be kept as lists. + # This is when there can be several arrays of variable length (NumPy + # doesn't support ragged arrays). + out.update({name: acc.get(name) for name in no_concat}) + return out diff --git a/phy/io/store.py b/phy/io/store.py new file mode 100644 index 000000000..fdf886461 --- /dev/null +++ b/phy/io/store.py @@ -0,0 +1,103 @@ +# -*- coding: utf-8 -*- + +"""Cluster store.""" + + +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- + +from functools import wraps +import logging +from operator import itemgetter + +import numpy as np + +from .array import _accumulate +from phy.utils import Bunch, _as_scalar, _as_scalars + +logger = logging.getLogger(__name__) + + +# ----------------------------------------------------------------------------- +# Utils +# ----------------------------------------------------------------------------- + +def _get_data_lim(arr, n_spikes=None, percentile=None): + n = arr.shape[0] + k = max(1, n // n_spikes) if n_spikes else 1 + arr = np.abs(arr[::k]) + n = arr.shape[0] + arr = arr.reshape((n, -1)) + return arr.max() + + +def get_closest_clusters(cluster_id, cluster_ids, sim_func, max_n=None): + """Return a list of pairs `(cluster, similarity)` sorted by decreasing + similarity to a given cluster.""" + l = [(_as_scalar(candidate), _as_scalar(sim_func(cluster_id, candidate))) + for candidate in _as_scalars(cluster_ids)] + l = sorted(l, key=itemgetter(1), reverse=True) + return l[:max_n] + + +def _log(f): + @wraps(f) + def wrapped(*args, **kwargs): + logger.log(5, "Compute %s(%s).", f.__name__, str(args)) + return f(*args, **kwargs) + return wrapped + + +def _concat(f): + """Take a function accepting a single cluster, and return a function + accepting multiple clusters.""" + @wraps(f) + def wrapped(cluster_ids): + # Single cluster. + if not hasattr(cluster_ids, '__len__'): + return f(cluster_ids) + # Concatenate the result of multiple clusters. + return Bunch(_accumulate([f(c) for c in cluster_ids])) + return wrapped + + +# ----------------------------------------------------------------------------- +# Cluster statistics +# ----------------------------------------------------------------------------- + +class ClusterStore(object): + def __init__(self, context=None): + self.context = context + self._stats = {} + + def add(self, f=None, name=None, cache='disk', concat=None): + """Add a cluster statistic. + + Parameters + ---------- + f : function + name : str + cache : str + Can be `None` (no cache), `disk`, or `memory`. In the latter case + the function will also be cached on disk. + + """ + if f is None: + return lambda _: self.add(_, name=name, cache=cache, concat=concat) + name = name or f.__name__ + if cache and self.context: + f = _log(f) + f = self.context.cache(f, memcache=(cache == 'memory')) + assert f + if concat: + f = _concat(f) + self._stats[name] = f + setattr(self, name, f) + return f + + def get(self, name): + return self._stats[name] + + def attach(self, gui): + gui.register(self, name='cluster_store') diff --git a/phy/io/tests/test_store.py b/phy/io/tests/test_store.py new file mode 100644 index 000000000..3b5560e43 --- /dev/null +++ b/phy/io/tests/test_store.py @@ -0,0 +1,26 @@ +# -*- coding: utf-8 -*- + +"""Test cluster store.""" + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +from ..store import ClusterStore +from phy.io import Context + + +#------------------------------------------------------------------------------ +# Test cluster stats +#------------------------------------------------------------------------------ + +def test_cluster_store(tempdir): + context = Context(tempdir) + cs = ClusterStore(context=context) + + @cs.add(cache='memory') + def f(x): + return x * x + + assert cs.f(3) == 9 + assert cs.get('f')(3) == 9 diff --git a/phy/plot/plot.py b/phy/plot/plot.py index b23cde259..54417a915 100644 --- a/phy/plot/plot.py +++ b/phy/plot/plot.py @@ -7,11 +7,12 @@ # Imports #------------------------------------------------------------------------------ -from collections import defaultdict, OrderedDict +from collections import OrderedDict from contextlib import contextmanager import numpy as np +from phy.io.array import _accumulate from .base import BaseCanvas from .interact import Grid, Boxed, Stacked from .panzoom import PanZoom @@ -24,54 +25,6 @@ # Utils #------------------------------------------------------------------------------ -def _flatten(l): - return [item for sublist in l for item in sublist] - - -class Accumulator(object): - """Accumulate arrays for concatenation.""" - def __init__(self): - self._data = defaultdict(list) - - def add(self, name, val): - """Add an array.""" - self._data[name].append(val) - - def get(self, name): - """Return the list of arrays for a given name.""" - return _flatten(self._data[name]) - - @property - def names(self): - """List of names.""" - return set(self._data) - - def __getitem__(self, name): - """Concatenate all arrays with a given name.""" - return np.concatenate(self._data[name], axis=0). \ - astype(self._data[name][0].dtype) - - -def _accumulate(data_list, no_concat=()): - """Concatenate a list of dicts `(name, array)`. - - You can specify some names which arrays should not be concatenated. - This is necessary with lists of plots with different sizes. - - """ - acc = Accumulator() - for data in data_list: - for name, val in data.items(): - acc.add(name, val) - out = {name: acc[name] for name in acc.names if name not in no_concat} - - # Some variables should not be concatenated but should be kept as lists. - # This is when there can be several arrays of variable length (NumPy - # doesn't support ragged arrays). - out.update({name: acc.get(name) for name in no_concat}) - return out - - # NOTE: we ensure that we only create every type *once*, so that # View._items has only one key for any class. _SCATTER_CLASSES = {} diff --git a/phy/traces/tests/test_spike_detect.py b/phy/traces/tests/test_spike_detect.py index 1424fe21b..e8cb7c3fa 100644 --- a/phy/traces/tests/test_spike_detect.py +++ b/phy/traces/tests/test_spike_detect.py @@ -8,7 +8,7 @@ import numpy as np from numpy.testing import assert_array_equal as ae -from pytest import yield_fixture +from pytest import fixture from phy.io.datasets import download_test_data from phy.io.tests.test_context import (ipy_client, context, # noqa @@ -26,16 +26,16 @@ # Fixtures #------------------------------------------------------------------------------ -@yield_fixture +@fixture def traces(): path = download_test_data('test-32ch-10s.dat') traces = np.fromfile(path, dtype=np.int16).reshape((200000, 32)) traces = traces[:20000] - yield traces + return traces -@yield_fixture(params=[(True,), (False,)]) +@fixture(params=[(True,), (False,)]) def spike_detector(request): remap = request.param[0] @@ -50,7 +50,7 @@ def spike_detector(request): site_label_to_traces_row=site_label_to_traces_row, sample_rate=sample_rate) - yield sd + return sd #------------------------------------------------------------------------------ @@ -158,7 +158,8 @@ def test_detect_simple(spike_detector, traces): # _plot(sd, traces, spike_samples, masks) -def test_detect_context(spike_detector, traces, parallel_context): # noqa +# NOTE: skip for now to accelerate the test suite... +def _test_detect_context(spike_detector, traces, parallel_context): # noqa sd = spike_detector sd.set_context(parallel_context) diff --git a/phy/utils/plugin.py b/phy/utils/plugin.py index 76d5612b1..0da72720a 100644 --- a/phy/utils/plugin.py +++ b/phy/utils/plugin.py @@ -49,9 +49,8 @@ class IPlugin(with_metaclass(IPluginRegistry)): def get_plugin(name): """Get a plugin class from its name.""" - name = name.lower() for (plugin,) in IPluginRegistry.plugins: - if name in plugin.__name__.lower(): + if name in plugin.__name__: return plugin raise ValueError("The plugin %s cannot be found." % name) diff --git a/phy/utils/tests/test_plugin.py b/phy/utils/tests/test_plugin.py index 8d81dd9f6..bd53fc8d5 100644 --- a/phy/utils/tests/test_plugin.py +++ b/phy/utils/tests/test_plugin.py @@ -67,7 +67,7 @@ class MyPlugin(IPlugin): pass assert IPluginRegistry.plugins == [(MyPlugin,)] - assert get_plugin('myplugin').__name__ == 'MyPlugin' + assert get_plugin('MyPlugin').__name__ == 'MyPlugin' with raises(ValueError): get_plugin('unknown') From dcd03696130ba95b9cdd3af2038914d75183f1b6 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 9 Jan 2016 10:04:36 +0100 Subject: [PATCH 0915/1059] Allow asymmetrical waveforms in trace view --- phy/cluster/manual/tests/test_views.py | 16 +++++--------- phy/cluster/manual/views.py | 30 +++++++++++++++++--------- 2 files changed, 25 insertions(+), 21 deletions(-) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index 64e2b7be0..bf59e3cc7 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -11,7 +11,6 @@ import numpy as np from numpy.testing import assert_equal as ae from numpy.testing import assert_allclose as ac -from pytest import raises from vispy.util import keys from phy.electrode.mea import staggered_positions @@ -281,23 +280,18 @@ def test_extract_wave(): traces = np.arange(30).reshape((6, 5)) mask = np.array([0, 1, 1, .5, 0]) wave_len = 4 + hwl = wave_len // 2 - with raises(ValueError): - _extract_wave(traces, -1, mask, wave_len) - - with raises(ValueError): - _extract_wave(traces, 20, mask, wave_len) - - ae(_extract_wave(traces, 0, mask, wave_len)[0], + ae(_extract_wave(traces, 0 - hwl, 0 + hwl, mask, wave_len)[0], [[0, 0, 0], [0, 0, 0], [1, 2, 3], [6, 7, 8]]) - ae(_extract_wave(traces, 1, mask, wave_len)[0], + ae(_extract_wave(traces, 1 - hwl, 1 + hwl, mask, wave_len)[0], [[0, 0, 0], [1, 2, 3], [6, 7, 8], [11, 12, 13]]) - ae(_extract_wave(traces, 2, mask, wave_len)[0], + ae(_extract_wave(traces, 2 - hwl, 2 + hwl, mask, wave_len)[0], [[1, 2, 3], [6, 7, 8], [11, 12, 13], [16, 17, 18]]) - ae(_extract_wave(traces, 5, mask, wave_len)[0], + ae(_extract_wave(traces, 5 - hwl, 5 + hwl, mask, wave_len)[0], [[16, 17, 18], [21, 22, 23], [0, 0, 0], [0, 0, 0]]) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index bed90c062..4a3f7772a 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -54,17 +54,14 @@ def _selected_clusters_colors(n_clusters=None): return colors[:n_clusters, ...] / 255. -def _extract_wave(traces, spk, mask, wave_len=None): +def _extract_wave(traces, start, end, mask, wave_len=None): n_samples, n_channels = traces.shape - if not (0 <= spk < n_samples): - raise ValueError() assert mask.shape == (n_channels,) channels = np.nonzero(mask > .1)[0] # There should be at least one non-masked channel. if not len(channels): return # pragma: no cover - i = spk - wave_len // 2 - j = spk + wave_len // 2 + i, j = start, end a, b = max(0, i), min(j, n_samples - 1) data = traces[a:b, channels] data = _get_padded(data, i - a, i - a + wave_len) @@ -631,17 +628,29 @@ def _plot_spike(self, spike_idx, start=None, traces=None, spike_times=None, spike_clusters=None, masks=None, data_bounds=None): - wave_len = self.n_samples_per_spike - dur_spike = wave_len * self.dt + # Can be a tuple or a scalar. + if isinstance(self.n_samples_per_spike, tuple): + wave_len = sum(self.n_samples_per_spike) # in samples + dur_spike = wave_len * self.dt # in seconds + wave_start = self.n_samples_per_spike[0] * self.dt # in seconds + else: + wave_len = self.n_samples_per_spike + dur_spike = wave_len * self.dt + wave_start = -dur_spike * .5 + trace_start = round(self.sample_rate * start) # Find the first x of the spike, relative to the start of # the interval - sample_rel = (round(spike_times[spike_idx] * self.sample_rate) - + spike_start = spike_times[spike_idx] + wave_start + spike_end = spike_times[spike_idx] + wave_start + dur_spike + sample_start = (round(spike_start * self.sample_rate) - + trace_start) + sample_end = (round(spike_end * self.sample_rate) - trace_start) # Extract the waveform from the traces. - w, ch = _extract_wave(traces, sample_rel, + w, ch = _extract_wave(traces, sample_start, sample_end, masks[spike_idx], wave_len) # Determine the color as a function of the spike's cluster. @@ -659,7 +668,7 @@ def _plot_spike(self, spike_idx, start=None, n_clusters=len(self.cluster_ids)) # Generate the x coordinates of the waveform. - t0 = spike_times[spike_idx] - dur_spike / 2. + t0 = spike_times[spike_idx] + wave_start t = t0 + self.dt * np.arange(wave_len) t = np.tile(t, (len(ch), 1)) @@ -824,6 +833,7 @@ def attach_to_gui(self, gui): sample_rate=model.sample_rate, spike_times=model.spike_times, spike_clusters=model.spike_clusters, + n_samples_per_spike=model.n_samples_waveforms, masks=model.masks, scaling=s, mean_traces=cs.mean_traces(), From 423d6ead8f31894b55d20bb28b77f6901abccfb8 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 9 Jan 2016 10:18:02 +0100 Subject: [PATCH 0916/1059] Allow asymmetrical waveforms in waveform view --- phy/cluster/manual/views.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 4a3f7772a..9e51918ff 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -263,6 +263,8 @@ def __init__(self, self.n_channels = self.channel_positions.shape[0] # Number of samples per waveform. + if isinstance(n_samples, tuple): + n_samples = sum(n_samples) assert n_samples > 0 self.n_samples = n_samples From bc19e3537e4691a7ad9d4b4cd8e95112ae72ca39 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 9 Jan 2016 10:51:07 +0100 Subject: [PATCH 0917/1059] Minor updates --- phy/cluster/manual/gui_component.py | 12 +++++++----- phy/gui/gui.py | 9 +++++++-- phy/io/context.py | 4 +++- phy/io/store.py | 2 +- 4 files changed, 18 insertions(+), 9 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index c6df02e8f..39bebc5e8 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -404,11 +404,13 @@ def attach(self, gui): cs = gui.request('cluster_store') if cs and 'ClusterView' in gui.state: # Names of the quality and similarity functions. - quality = gui.state.ClusterView.quality - similarity = gui.state.ClusterView.similarity - self.cluster_view.add_column(cs.get(quality), name=quality) - self.set_default_sort(quality) - self.set_similarity_func(cs.get(similarity)) + quality = gui.state.ClusterView.get('quality', None) + similarity = gui.state.ClusterView.get('similarity', None) + if quality: + self.cluster_view.add_column(cs.get(quality), name=quality) + self.set_default_sort(quality) + if similarity: + self.set_similarity_func(cs.get(similarity)) # Update the cluster views and selection when a cluster event occurs. self.gui.connect_(self.on_cluster) diff --git a/phy/gui/gui.py b/phy/gui/gui.py index e8bf4b4ea..e78473cd0 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -432,7 +432,7 @@ def on_show(): gui.restore_geometry_state(gs) -def create_gui(name=None, subtitle=None, model=None, +def create_gui(name=None, subtitle=None, model=None, state=None, plugins=None, config_dir=None): """Create a GUI with a list of plugins. @@ -445,7 +445,12 @@ def create_gui(name=None, subtitle=None, model=None, plugins = plugins or [] # Load the state. - state = GUIState(gui.name, config_dir=config_dir) + state = state or {} + # Ensure all dicts are Bunches. + for k in state.keys(): + if isinstance(state[k], dict) and not isinstance(state[k], Bunch): + state[k] = Bunch(state[k]) + state = GUIState(gui.name, config_dir=config_dir, **(state or {})) gui.state = state # Register the model. diff --git a/phy/io/context.py b/phy/io/context.py index 38c0ff15f..dcfb1e206 100644 --- a/phy/io/context.py +++ b/phy/io/context.py @@ -329,8 +329,10 @@ def __setstate__(self, state): class ContextPlugin(IPlugin): def attach_to_gui(self, gui): model = gui.request('model') + # Find the path. + path = getattr(model, 'path', '') # Create the computing context. - ctx = Context(op.join(op.dirname(model.path), '.phy/')) + ctx = Context(op.join(op.dirname(path), '.phy/')) gui.register(ctx, name='context') diff --git a/phy/io/store.py b/phy/io/store.py index fdf886461..1ead0fe1f 100644 --- a/phy/io/store.py +++ b/phy/io/store.py @@ -97,7 +97,7 @@ def add(self, f=None, name=None, cache='disk', concat=None): return f def get(self, name): - return self._stats[name] + return self._stats.get(name, None) def attach(self, gui): gui.register(self, name='cluster_store') From ac0619ad69a74aada2fc115cc3a4aeb419d13aca Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 9 Jan 2016 11:30:56 +0100 Subject: [PATCH 0918/1059] Add stacked origin --- phy/cluster/manual/views.py | 5 ++++- phy/plot/interact.py | 5 +++-- phy/plot/plot.py | 4 ++-- phy/plot/tests/test_interact.py | 2 +- 4 files changed, 10 insertions(+), 6 deletions(-) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 9e51918ff..3d8521004 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -526,6 +526,7 @@ def __init__(self, masks=None, # full array of masks n_samples_per_spike=None, scaling=None, + origin=None, mean_traces=None, **kwargs): @@ -570,6 +571,7 @@ def __init__(self, # Initialize the view. super(TraceView, self).__init__(layout='stacked', + origin=origin, n_plots=self.n_channels, **kwargs) # Box and probe scaling. @@ -826,7 +828,7 @@ class TraceViewPlugin(IPlugin): def attach_to_gui(self, gui): state = gui.state model = gui.request('model') - s, = state.get_view_params('TraceView', 'scaling') + s, o = state.get_view_params('TraceView', 'scaling', 'origin') cs = gui.request('cluster_store') assert cs # We need the cluster store to retrieve the data. @@ -837,6 +839,7 @@ def attach_to_gui(self, gui): spike_clusters=model.spike_clusters, n_samples_per_spike=model.n_samples_waveforms, masks=model.masks, + origin=o, scaling=s, mean_traces=cs.mean_traces(), ) diff --git a/phy/plot/interact.py b/phy/plot/interact.py index e7e0a43cd..465e4396f 100644 --- a/phy/plot/interact.py +++ b/phy/plot/interact.py @@ -253,7 +253,7 @@ class Stacked(Boxed): Name of the GLSL variable with the box index. """ - def __init__(self, n_boxes, margin=0, box_var=None): + def __init__(self, n_boxes, margin=0, box_var=None, origin=None): # The margin must be in [-1, 1] margin = np.clip(margin, -1, 1) @@ -266,7 +266,8 @@ def __init__(self, n_boxes, margin=0, box_var=None): b[:, 1] = np.linspace(-1, 1 - 2. / n_boxes + margin, n_boxes) b[:, 2] = 1 b[:, 3] = np.linspace(-1 + 2. / n_boxes - margin, 1., n_boxes) - b = b[::-1, :] + if origin == 'upper': + b = b[::-1, :] super(Stacked, self).__init__(b, box_var=box_var, keep_aspect_ratio=False, diff --git a/phy/plot/plot.py b/phy/plot/plot.py index 54417a915..a88b960f1 100644 --- a/phy/plot/plot.py +++ b/phy/plot/plot.py @@ -47,7 +47,7 @@ class View(BaseCanvas): """High-level plotting canvas.""" _default_box_index = (0,) - def __init__(self, layout=None, shape=None, n_plots=None, + def __init__(self, layout=None, shape=None, n_plots=None, origin=None, box_bounds=None, box_pos=None, box_size=None, **kwargs): if not kwargs.get('keys', None): kwargs['keys'] = None @@ -68,7 +68,7 @@ def __init__(self, layout=None, shape=None, n_plots=None, elif layout == 'stacked': self.n_plots = n_plots - self.stacked = Stacked(n_plots, margin=.1) + self.stacked = Stacked(n_plots, margin=.1, origin=origin) self.stacked.attach(self) self.panzoom = PanZoom(aspect=None, constrain_bounds=NDC) diff --git a/phy/plot/tests/test_interact.py b/phy/plot/tests/test_interact.py index b2d60ccd9..0d72117ee 100644 --- a/phy/plot/tests/test_interact.py +++ b/phy/plot/tests/test_interact.py @@ -150,7 +150,7 @@ def test_stacked_1(qtbot, canvas): n = 1000 box_index = np.repeat(np.arange(6), n, axis=0) - stacked = Stacked(n_boxes=6, margin=-10) + stacked = Stacked(n_boxes=6, margin=-10, origin='upper') _create_visual(qtbot, canvas, stacked, box_index) # qtbot.stop() From e1d4d5eb6d7238eef00e274e9ce193134300a54a Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 9 Jan 2016 13:37:59 +0100 Subject: [PATCH 0919/1059] Remove similarity view if there is no similarity function --- phy/cluster/manual/gui_component.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 39bebc5e8..6ff9309eb 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -398,7 +398,6 @@ def attach(self, gui): # Add the cluster views. gui.add_view(self.cluster_view, name='ClusterView') - gui.add_view(self.similarity_view, name='SimilarityView') # Add the quality column in the cluster view. cs = gui.request('cluster_store') @@ -411,6 +410,7 @@ def attach(self, gui): self.set_default_sort(quality) if similarity: self.set_similarity_func(cs.get(similarity)) + gui.add_view(self.similarity_view, name='SimilarityView') # Update the cluster views and selection when a cluster event occurs. self.gui.connect_(self.on_cluster) From f1435938cff9ae4fe093d74eb7d9fd30d719db3a Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 9 Jan 2016 14:24:57 +0100 Subject: [PATCH 0920/1059] Toggle zoom on channels in waveform view --- phy/cluster/manual/tests/test_views.py | 4 +++ phy/cluster/manual/views.py | 36 ++++++++++++++++---------- 2 files changed, 26 insertions(+), 14 deletions(-) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index bf59e3cc7..9fad168c9 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -309,9 +309,13 @@ def test_selected_clusters_colors(): def test_waveform_view(qtbot, tempdir): with _test_view('WaveformView', tempdir=tempdir) as v: ac(v.boxed.box_size, (.1818, .0909), atol=1e-2) + v.toggle_waveform_overlap() v.toggle_waveform_overlap() + v.toggle_zoom_on_channels() + v.toggle_zoom_on_channels() + # Box scaling. bs = v.boxed.box_size v.increase() diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 3d8521004..ac68bb85d 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -233,6 +233,7 @@ class WaveformView(ManualClusteringView): default_shortcuts = { 'toggle_waveform_overlap': 'o', + 'toggle_zoom_on_channels': 'z', # Box scaling. 'widen': 'ctrl+right', @@ -256,6 +257,7 @@ def __init__(self, waveform_lim=None, **kwargs): self._key_pressed = None + self.do_zoom_on_channels = True # Channel positions and n_channels. assert channel_positions is not None @@ -263,8 +265,8 @@ def __init__(self, self.n_channels = self.channel_positions.shape[0] # Number of samples per waveform. - if isinstance(n_samples, tuple): - n_samples = sum(n_samples) + n_samples = (sum(map(abs, n_samples)) if isinstance(n_samples, tuple) + else n_samples) assert n_samples > 0 self.n_samples = n_samples @@ -300,7 +302,7 @@ def __init__(self, assert channel_positions.shape == (self.n_channels, 2) self.channel_positions = channel_positions - def on_select(self, cluster_ids=None, zoom_on_channels=True): + def on_select(self, cluster_ids=None): super(WaveformView, self).on_select(cluster_ids) cluster_ids = self.cluster_ids n_clusters = len(cluster_ids) @@ -349,13 +351,14 @@ def on_select(self, cluster_ids=None, zoom_on_channels=True): # Zoom on the best channels when selecting clusters. channels = self._best_channels(cluster_ids) - if channels is not None and zoom_on_channels: + if channels is not None and self.do_zoom_on_channels: self.zoom_on_channels(channels) def attach(self, gui): """Attach the view to the GUI.""" super(WaveformView, self).attach(gui) self.actions.add(self.toggle_waveform_overlap) + self.actions.add(self.toggle_zoom_on_channels) # Box scaling. self.actions.add(self.widen) @@ -381,7 +384,12 @@ def on_channel_click(e): def toggle_waveform_overlap(self): """Toggle the overlap of the waveforms.""" self.overlap = not self.overlap - self.on_select(zoom_on_channels=False) + tmp = self.do_zoom_on_channels + self.on_select() + self.do_zoom_on_channels = tmp + + def toggle_zoom_on_channels(self): + self.do_zoom_on_channels = not self.do_zoom_on_channels # Box scaling # ------------------------------------------------------------------------- @@ -549,6 +557,12 @@ def __init__(self, self.n_samples_per_spike = (n_samples_per_spike or round(.002 * sample_rate)) + # Can be a tuple or a scalar. + if not isinstance(self.n_samples_per_spike, tuple): + ns = self.n_samples_per_spike + self.n_samples_per_spike = (-ns // 2, ns // 2) + # Now n_samples_per_spike is a tuple. + # Spike times. if spike_times is not None: spike_times = np.asarray(spike_times) @@ -632,15 +646,9 @@ def _plot_spike(self, spike_idx, start=None, traces=None, spike_times=None, spike_clusters=None, masks=None, data_bounds=None): - # Can be a tuple or a scalar. - if isinstance(self.n_samples_per_spike, tuple): - wave_len = sum(self.n_samples_per_spike) # in samples - dur_spike = wave_len * self.dt # in seconds - wave_start = self.n_samples_per_spike[0] * self.dt # in seconds - else: - wave_len = self.n_samples_per_spike - dur_spike = wave_len * self.dt - wave_start = -dur_spike * .5 + wave_len = sum(map(abs, self.n_samples_per_spike)) # in samples + dur_spike = wave_len * self.dt # in seconds + wave_start = self.n_samples_per_spike[0] * self.dt # in seconds trace_start = round(self.sample_rate * start) From ce0dce3e114dbc028e3584de8c00f8623f4dc9fc Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sat, 9 Jan 2016 15:10:45 +0100 Subject: [PATCH 0921/1059] Show waveform means --- phy/cluster/manual/tests/test_views.py | 13 +++++++++-- phy/cluster/manual/views.py | 30 ++++++++++++++++++++++---- 2 files changed, 37 insertions(+), 6 deletions(-) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index 9fad168c9..cc5c03f97 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -133,9 +133,15 @@ def waveform_lim(): @cs.add(concat=True) def waveforms_masks(cluster_id): + w = get_waveforms(model.n_spikes_per_cluster) + m = get_masks(model.n_spikes_per_cluster) + mw = cs.mean_waveforms(cluster_id)[np.newaxis, ...] + mm = cs.mean_masks(cluster_id)[np.newaxis, ...] return _get_data(spike_ids=get_spike_ids(cluster_id), - waveforms=get_waveforms(model.n_spikes_per_cluster), - masks=get_masks(model.n_spikes_per_cluster), + waveforms=w, + masks=m, + mean_waveforms=mw, + mean_masks=mm, ) # Mean quantities. @@ -313,6 +319,9 @@ def test_waveform_view(qtbot, tempdir): v.toggle_waveform_overlap() v.toggle_waveform_overlap() + v.toggle_show_means() + v.toggle_show_means() + v.toggle_zoom_on_channels() v.toggle_zoom_on_channels() diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index ac68bb85d..cd83af110 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -85,13 +85,13 @@ def _get_depth(masks, spike_clusters_rel=None, n_clusters=None): return depth -def _get_color(masks, spike_clusters_rel=None, n_clusters=None): +def _get_color(masks, spike_clusters_rel=None, n_clusters=None, alpha=.5): """Return the color of vertices as a function of the mask and cluster index.""" n_spikes = masks.shape[0] # The transparency depends on whether the spike clusters are specified. # For background spikes, we use a smaller alpha. - alpha = .5 if spike_clusters_rel is not None else .25 + alpha = alpha if spike_clusters_rel is not None else .25 assert masks.shape == (n_spikes,) # Generate the colors. colors = _selected_clusters_colors(n_clusters) @@ -234,6 +234,7 @@ class WaveformView(ManualClusteringView): default_shortcuts = { 'toggle_waveform_overlap': 'o', 'toggle_zoom_on_channels': 'z', + 'toggle_show_means': 'm', # Box scaling. 'widen': 'ctrl+right', @@ -257,6 +258,7 @@ def __init__(self, waveform_lim=None, **kwargs): self._key_pressed = None + self.do_show_means = False self.do_zoom_on_channels = True # Channel positions and n_channels. @@ -302,6 +304,18 @@ def __init__(self, assert channel_positions.shape == (self.n_channels, 2) self.channel_positions = channel_positions + def _get_data(self, cluster_ids): + d = self.waveforms_masks(cluster_ids) + d.alpha = .5 + # Toggle waveform means. + if self.do_show_means: + d.waveforms = d.mean_waveforms + d.masks = d.mean_masks + d.spike_ids = np.arange(len(cluster_ids)) + d.spike_clusters = np.array(cluster_ids) + d.alpha = 1. + return d + def on_select(self, cluster_ids=None): super(WaveformView, self).on_select(cluster_ids) cluster_ids = self.cluster_ids @@ -310,7 +324,8 @@ def on_select(self, cluster_ids=None): return # Load the waveform subset. - data = self.waveforms_masks(cluster_ids) + data = self._get_data(cluster_ids) + alpha = data.alpha spike_ids = data.spike_ids spike_clusters = data.spike_clusters w = data.waveforms @@ -342,7 +357,9 @@ def on_select(self, cluster_ids=None): n_clusters=n_clusters) color = _get_color(m, spike_clusters_rel=spike_clusters_rel, - n_clusters=n_clusters) + n_clusters=n_clusters, + alpha=alpha, + ) self[ch].plot(x=t, y=w[:, :, ch], color=color, depth=depth, @@ -359,6 +376,7 @@ def attach(self, gui): super(WaveformView, self).attach(gui) self.actions.add(self.toggle_waveform_overlap) self.actions.add(self.toggle_zoom_on_channels) + self.actions.add(self.toggle_show_means) # Box scaling. self.actions.add(self.widen) @@ -391,6 +409,10 @@ def toggle_waveform_overlap(self): def toggle_zoom_on_channels(self): self.do_zoom_on_channels = not self.do_zoom_on_channels + def toggle_show_means(self): + self.do_show_means = not self.do_show_means + self.on_select() + # Box scaling # ------------------------------------------------------------------------- From 1e1a72694102b962ff3d9967df133840ce819d28 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 13 Jan 2016 14:34:45 +0100 Subject: [PATCH 0922/1059] View state in GUI --- phy/gui/gui.py | 41 ++++++++++++++++++++------------------- phy/gui/tests/test_gui.py | 10 +++++----- 2 files changed, 26 insertions(+), 25 deletions(-) diff --git a/phy/gui/gui.py b/phy/gui/gui.py index e78473cd0..89a00b0af 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -248,12 +248,11 @@ def show(self): # Views # ------------------------------------------------------------------------- - def _get_view_name(self, view): - """The view name is the class name followed by 1, 2, or n.""" + def _get_view_index(self, view): + """Index of a view in the GUI: 0 for the first view of a given + class, 1 for the next, and so on.""" name = view.__class__.__name__ - views = self.list_views(name) - n = len(views) + 1 - return '{:s}{:d}'.format(name, n) + return len(self.list_views(name)) def add_view(self, view, @@ -264,13 +263,16 @@ def add_view(self, floating=None): """Add a widget to the main window.""" - name = name or self._get_view_name(view) # Set the name in the view. - view.__name__ = name + view.view_index = self._get_view_index(view) + # The view name is ``, e.g. `MyView0`. + view.name = name or view.__class__.__name__ + str(view.view_index) + + # Get the Qt canvas for VisPy and matplotlib views. widget = _try_get_vispy_canvas(view) widget = _try_get_matplotlib_canvas(widget) - dock_widget = _create_dock_widget(widget, name, + dock_widget = _create_dock_widget(widget, view.name, closable=closable, floatable=floatable, ) @@ -286,7 +288,7 @@ def on_close_widget(): dock_widget.show() self.emit('add_view', view) - logger.log(5, "Add %s to GUI.", name) + logger.log(5, "Add %s to GUI.", view.name) return dock_widget def list_views(self, name='', is_visible=True): @@ -294,7 +296,7 @@ def list_views(self, name='', is_visible=True): children = self.findChildren(QWidget) return [child.view for child in children if isinstance(child, QDockWidget) and - child.view.__name__.startswith(name) and + child.view.name.startswith(name) and (child.isVisible() if is_visible else True) and child.width() >= 10 and child.height() >= 10 @@ -305,7 +307,7 @@ def view_count(self): views = self.list_views() counts = defaultdict(lambda: 0) for view in views: - counts[view.__name__] += 1 + counts[view.name] += 1 return dict(counts) # Menu bar @@ -385,16 +387,15 @@ def __init__(self, name='GUI', config_dir=None, **kwargs): _ensure_dir_exists(op.join(self.config_dir, self.name)) self.load() - def get_view_params(self, view_name, *names): - # TODO: how to choose view index - return [self.get(view_name + '1', Bunch()).get(name, None) - for name in names] + def get_view_state(self, view): + """Return the state of a view.""" + return self.get(view.name, Bunch()) - def set_view_params(self, view, **kwargs): - view_name = view if isinstance(view, string_types) else view.__name__ - if view_name not in self: - self[view_name] = Bunch() - self[view_name].update(kwargs) + def update_view_state(self, view, state): + """Update the state of a view.""" + if view.name not in self: + self[view.name] = Bunch() + self[view.name].update(state) @property def path(self): diff --git a/phy/gui/tests/test_gui.py b/phy/gui/tests/test_gui.py index 744656379..166419e9b 100644 --- a/phy/gui/tests/test_gui.py +++ b/phy/gui/tests/test_gui.py @@ -178,12 +178,12 @@ def on_show(): #------------------------------------------------------------------------------ def test_gui_state_view(): - view = Bunch(__name__='myview1') + view = Bunch(name='MyView0') state = GUIState() - state.set_view_params(view, hello='world') - state.get_view_params('unknown', 'hello') == [None] - state.get_view_params('myview', 'unknown') == [None] - state.get_view_params('myview', 'hello') == ['world'] + state.update_view_state(view, dict(hello='world')) + assert not state.get_view_state(Bunch(name='MyView')) + assert not state.get_view_state(Bunch(name='MyView1')) + assert state.get_view_state(view) == Bunch(hello='world') def test_create_gui_1(qapp, tempdir): From 349746a73a8ff8c943336bf856ba30fb25a6753d Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 13 Jan 2016 16:04:05 +0100 Subject: [PATCH 0923/1059] WIP: refactor GUI plugins --- phy/gui/gui.py | 21 +++++++-------------- phy/io/context.py | 13 +------------ phy/io/store.py | 3 --- 3 files changed, 8 insertions(+), 29 deletions(-) diff --git a/phy/gui/gui.py b/phy/gui/gui.py index 89a00b0af..86e202c96 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -11,8 +11,6 @@ import logging import os.path as op -from six import string_types - from .qt import (QApplication, QWidget, QDockWidget, QStatusBar, QMainWindow, Qt, QSize, QMetaObject) from .actions import Actions, Snippets @@ -433,8 +431,8 @@ def on_show(): gui.restore_geometry_state(gs) -def create_gui(name=None, subtitle=None, model=None, state=None, - plugins=None, config_dir=None): +def create_gui(name=None, subtitle=None, model=None, + plugins=None, **state_kwargs): """Create a GUI with a list of plugins. By default, the list of plugins is taken from the `c.TheGUI.plugins` @@ -445,17 +443,12 @@ def create_gui(name=None, subtitle=None, model=None, state=None, name = gui.name plugins = plugins or [] - # Load the state. - state = state or {} - # Ensure all dicts are Bunches. - for k in state.keys(): - if isinstance(state[k], dict) and not isinstance(state[k], Bunch): - state[k] = Bunch(state[k]) - state = GUIState(gui.name, config_dir=config_dir, **(state or {})) - gui.state = state + # Create the state. + state = GUIState(gui.name, **state_kwargs) - # Register the model. - gui.register(model=model) + # Make the state and model accessible. + gui.state = state + gui.model = model # If no plugins are specified, load the master config and # get the list of user plugins to attach to the GUI. diff --git a/phy/io/context.py b/phy/io/context.py index dcfb1e206..60bc92be3 100644 --- a/phy/io/context.py +++ b/phy/io/context.py @@ -24,8 +24,7 @@ "Install it with `conda install dask`.") from .array import read_array, write_array -from phy.utils import (Bunch, _save_json, _load_json, _ensure_dir_exists, - IPlugin,) +from phy.utils import (Bunch, _save_json, _load_json, _ensure_dir_exists,) from phy.utils.config import phy_user_dir logger = logging.getLogger(__name__) @@ -326,16 +325,6 @@ def __setstate__(self, state): self._set_memory(state['cache_dir']) -class ContextPlugin(IPlugin): - def attach_to_gui(self, gui): - model = gui.request('model') - # Find the path. - path = getattr(model, 'path', '') - # Create the computing context. - ctx = Context(op.join(op.dirname(path), '.phy/')) - gui.register(ctx, name='context') - - #------------------------------------------------------------------------------ # Task #------------------------------------------------------------------------------ diff --git a/phy/io/store.py b/phy/io/store.py index 1ead0fe1f..f2e783065 100644 --- a/phy/io/store.py +++ b/phy/io/store.py @@ -98,6 +98,3 @@ def add(self, f=None, name=None, cache='disk', concat=None): def get(self, name): return self._stats.get(name, None) - - def attach(self, gui): - gui.register(self, name='cluster_store') From 61bbda51c5046cf395b73f43749706b18db708ba Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 13 Jan 2016 16:22:11 +0100 Subject: [PATCH 0924/1059] WIP: refactor view plugins --- phy/cluster/manual/gui_component.py | 15 +- phy/cluster/manual/tests/test_views.py | 161 ++++++++++---------- phy/cluster/manual/views.py | 200 ++++++++++++++++--------- 3 files changed, 214 insertions(+), 162 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 6ff9309eb..d78ef6f20 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -15,7 +15,6 @@ from .clustering import Clustering from phy.gui.actions import Actions from phy.gui.widgets import Table -from phy.utils import IPlugin logger = logging.getLogger(__name__) @@ -400,7 +399,8 @@ def attach(self, gui): gui.add_view(self.cluster_view, name='ClusterView') # Add the quality column in the cluster view. - cs = gui.request('cluster_store') + # TODO + cs = gui.model.store if cs and 'ClusterView' in gui.state: # Names of the quality and similarity functions. quality = gui.state.ClusterView.get('quality', None) @@ -519,14 +519,3 @@ def save(self): groups = {c: self.cluster_meta.get('group', c) or 'unsorted' for c in self.clustering.cluster_ids} self.gui.emit('request_save', spike_clusters, groups) - - -class ManualClusteringPlugin(IPlugin): - def attach_to_gui(self, gui): - model = gui.request('model') - # Attach the manual clustering logic (wizard, merge, split, - # undo stack) to the GUI. - mc = ManualClustering(model.spike_clusters, - cluster_groups=model.cluster_groups, - ) - mc.attach(gui) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index cc5c03f97..cfdf80890 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -6,15 +6,14 @@ # Imports #------------------------------------------------------------------------------ -from contextlib import contextmanager - import numpy as np from numpy.testing import assert_equal as ae from numpy.testing import assert_allclose as ac from vispy.util import keys +from pytest import fixture from phy.electrode.mea import staggered_positions -from phy.gui import create_gui, GUIState +from phy.gui import create_gui from phy.io.array import _spikes_per_cluster from phy.io.mock import (artificial_waveforms, artificial_features, @@ -28,9 +27,16 @@ get_unmasked_channels, get_sorted_main_channels, ) -from phy.utils import Bunch, IPlugin -from ..views import (TraceView, _extract_wave, _selected_clusters_colors, - _extend) +from phy.utils import Bunch +from ..gui_component import ManualClustering +from ..views import (WaveformView, + FeatureView, + CorrelogramView, + TraceView, + _extract_wave, + _selected_clusters_colors, + _extend, + ) #------------------------------------------------------------------------------ @@ -48,7 +54,6 @@ def create_model(): n_spikes_total = n_clusters * model.n_spikes_per_cluster n_features_per_channel = 4 - model.path = '' model.n_channels = n_channels model.n_spikes = n_spikes_total model.sample_rate = 20000. @@ -66,6 +71,9 @@ def create_model(): model.n_samples_waveforms = n_samples_waveforms model.cluster_groups = {c: None for c in range(n_clusters)} + # TODO: make this cleaner by abstracting the store away + model.store = create_cluster_store(model) + return model @@ -226,48 +234,40 @@ def _show(qtbot, view, stop=False): view.close() -@contextmanager -def _test_view(view_name, tempdir=None): - - model = create_model() - - class ClusterStorePlugin(IPlugin): - def attach_to_gui(self, gui): - cs = create_cluster_store(model) - cs.attach(gui) - +@fixture +def state(tempdir): # Save a test GUI state JSON file in the tempdir. - state = GUIState(config_dir=tempdir) - state.set_view_params('WaveformView1', overlap=False, box_size=(.1, .1)) - state.set_view_params('TraceView1', box_size=(1., .01)) - state.set_view_params('FeatureView1', feature_scaling=.5) - state.set_view_params('CorrelogramView1', uniform_normalization=True) + state = Bunch() + state.WaveformView0 = Bunch(overlap=False) + state.TraceView0 = Bunch(box_size=(1., .01)) + state.FeatureView0 = Bunch(feature_scaling=.5) + state.CorrelogramView0 = Bunch(uniform_normalization=True) + # quality and similarity functions for the cluster view. state.ClusterView = Bunch(quality='max_waveform_amplitude', similarity='most_similar_clusters') - state.save() - - # Create the GUI. - plugins = ['ContextPlugin', - 'ClusterStorePlugin', - 'ManualClusteringPlugin', - view_name + 'Plugin'] - gui = create_gui(model=model, plugins=plugins, config_dir=tempdir) - gui.show() + return state + + +@fixture +def gui(tempdir, state): + model = create_model() + gui = create_gui(model=model, config_dir=tempdir, **state) + mc = ManualClustering(model.spike_clusters, + cluster_groups=model.cluster_groups,) + mc.attach(gui) + gui.register(manual_clustering=mc) + return gui + +def _select_clusters(gui): + gui.show() mc = gui.request('manual_clustering') assert mc mc.select([]) mc.select([0]) mc.select([0, 2]) - view = gui.list_views(view_name)[0] - view.gui = gui - view.model = model # HACK - yield view - - gui.close() - #------------------------------------------------------------------------------ # Test utils @@ -312,57 +312,66 @@ def test_selected_clusters_colors(): # Test waveform view #------------------------------------------------------------------------------ -def test_waveform_view(qtbot, tempdir): - with _test_view('WaveformView', tempdir=tempdir) as v: - ac(v.boxed.box_size, (.1818, .0909), atol=1e-2) +def test_waveform_view(qtbot, gui): + v = WaveformView(waveforms_masks=gui.model.store.waveforms_masks, + channel_positions=gui.model.channel_positions, + n_samples=gui.model.n_samples_waveforms, + waveform_lim=gui.model.store.waveform_lim(), + best_channels=(lambda clusters: [0, 1, 2]), + ) + v.attach(gui) - v.toggle_waveform_overlap() - v.toggle_waveform_overlap() + _select_clusters(gui) - v.toggle_show_means() - v.toggle_show_means() + ac(v.boxed.box_size, (.1818, .0909), atol=1e-2) - v.toggle_zoom_on_channels() - v.toggle_zoom_on_channels() + v.toggle_waveform_overlap() + v.toggle_waveform_overlap() - # Box scaling. - bs = v.boxed.box_size - v.increase() - v.decrease() - ac(v.boxed.box_size, bs) + v.toggle_show_means() + v.toggle_show_means() - bs = v.boxed.box_size - v.widen() - v.narrow() - ac(v.boxed.box_size, bs) + v.toggle_zoom_on_channels() + v.toggle_zoom_on_channels() - # Probe scaling. - bp = v.boxed.box_pos - v.extend_horizontally() - v.shrink_horizontally() - ac(v.boxed.box_pos, bp) + # Box scaling. + bs = v.boxed.box_size + v.increase() + v.decrease() + ac(v.boxed.box_size, bs) - bp = v.boxed.box_pos - v.extend_vertically() - v.shrink_vertically() - ac(v.boxed.box_pos, bp) + bs = v.boxed.box_size + v.widen() + v.narrow() + ac(v.boxed.box_size, bs) - v.zoom_on_channels([0, 2, 4]) + # Probe scaling. + bp = v.boxed.box_pos + v.extend_horizontally() + v.shrink_horizontally() + ac(v.boxed.box_pos, bp) - # Simulate channel selection. - _clicked = [] + bp = v.boxed.box_pos + v.extend_vertically() + v.shrink_vertically() + ac(v.boxed.box_pos, bp) - @v.gui.connect_ - def on_channel_click(channel_idx=None, button=None, key=None): - _clicked.append((channel_idx, button, key)) + v.zoom_on_channels([0, 2, 4]) - v.events.key_press(key=keys.Key('2')) - v.events.mouse_press(pos=(0., 0.), button=1) - v.events.key_release(key=keys.Key('2')) + # Simulate channel selection. + _clicked = [] - assert _clicked == [(0, 1, 2)] + @v.gui.connect_ + def on_channel_click(channel_idx=None, button=None, key=None): + _clicked.append((channel_idx, button, key)) - # qtbot.stop() + v.events.key_press(key=keys.Key('2')) + v.events.mouse_press(pos=(0., 0.), button=1) + v.events.key_release(key=keys.Key('2')) + + assert _clicked == [(0, 1, 2)] + + # qtbot.stop() #------------------------------------------------------------------------------ diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index cd83af110..fbc77a5b2 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -18,7 +18,7 @@ from phy.plot import View, _get_linear_x from phy.plot.utils import _get_boxes from phy.stats import correlograms -from phy.utils import IPlugin +from phy.utils import IPlugin, Bunch logger = logging.getLogger(__name__) @@ -165,20 +165,20 @@ def on_select(self, cluster_ids=None): self.cluster_ids = list(cluster_ids) if cluster_ids is not None else [] self.cluster_ids = [int(c) for c in self.cluster_ids] - def _best_channels(self, cluster_ids, n_channels_requested=None): - """Return the best channels for a set of clusters.""" - # Number of channels to find on each axis. - n = n_channels_requested or self.n_channels - # Request the best channels to the GUI. - cs = self.gui.request('cluster_store') if self.gui else None - channels = cs.best_channels_multiple(cluster_ids) if cs else None - # By default, select the first channels. - if channels is None or not len(channels): - return - assert len(channels) - # Repeat some channels if there aren't enough. - channels = _extend(channels, n) - return channels + # def _best_channels(self, cluster_ids, n_channels_requested=None): + # """Return the best channels for a set of clusters.""" + # # Number of channels to find on each axis. + # n = n_channels_requested or self.n_channels + # # Request the best channels to the GUI. + # cs = self.gui.request('cluster_store') if self.gui else None + # channels = cs.best_channels_multiple(cluster_ids) if cs else None + # # By default, select the first channels. + # if channels is None or not len(channels): + # return + # assert len(channels) + # # Repeat some channels if there aren't enough. + # channels = _extend(channels, n) + # return channels def attach(self, gui): """Attach the view to the GUI.""" @@ -189,6 +189,10 @@ def attach(self, gui): gui.add_view(self) self.gui = gui + + # Set the view state. + self.set_state(gui.state.get_view_state(self)) + gui.connect_(self.on_select) self.actions = Actions(gui, name=self.__class__.__name__, @@ -202,8 +206,36 @@ def attach(self, gui): def on_status(e): gui.status_message = e.message + # Save the view state in the GUI state. + @gui.connect_ + def on_close(): + gui.state.update_view_state(self, self.state) + self.show() + @property + def state(self): + """View state. + + This Bunch will be automatically persisted in the GUI state when the + GUI is closed. + + To be overriden. + + """ + return Bunch() + + def set_state(self, state): + """Set the view state. + + The passed object is the persisted `self.state` bunch. + + May be overriden. + + """ + for k, v in state.items(): + setattr(self, k, v) + def set_status(self, message=None): message = message or self.status if not message: @@ -228,7 +260,6 @@ def __init__(self, type, channel_idx=None, key=None, button=None): class WaveformView(ManualClusteringView): - overlap = False scaling_coeff = 1.1 default_shortcuts = { @@ -252,15 +283,17 @@ class WaveformView(ManualClusteringView): def __init__(self, waveforms_masks=None, channel_positions=None, - box_scaling=None, - probe_scaling=None, n_samples=None, waveform_lim=None, + best_channels=None, **kwargs): self._key_pressed = None - self.do_show_means = False + self._do_show_means = False + self._overlap = False self.do_zoom_on_channels = True + self.best_channels = best_channels or (lambda clusters: []) + # Channel positions and n_channels. assert channel_positions is not None self.channel_positions = np.asarray(channel_positions) @@ -281,11 +314,8 @@ def __init__(self, self.events.add(channel_click=ChannelClick) # Box and probe scaling. - self.box_scaling = np.array(box_scaling if box_scaling is not None - else (1., 1.)) - self.probe_scaling = np.array(probe_scaling - if probe_scaling is not None - else (1., 1.)) + self._box_scaling = np.ones(2) + self._probe_scaling = np.ones(2) # Make a copy of the initial box pos and size. We'll apply the scaling # to these quantities. @@ -367,10 +397,19 @@ def on_select(self, cluster_ids=None): ) # Zoom on the best channels when selecting clusters. - channels = self._best_channels(cluster_ids) + channels = self.best_channels(cluster_ids) if channels is not None and self.do_zoom_on_channels: self.zoom_on_channels(channels) + @property + def state(self): + return Bunch(box_scaling=tuple(self.box_scaling), + probe_scaling=tuple(self.probe_scaling), + overlap=self.overlap, + do_show_means=self.do_show_means, + do_zoom_on_channels=self.do_zoom_on_channels, + ) + def attach(self, gui): """Attach the view to the GUI.""" super(WaveformView, self).attach(gui) @@ -399,19 +438,41 @@ def on_channel_click(e): button=e.button, ) - def toggle_waveform_overlap(self): - """Toggle the overlap of the waveforms.""" - self.overlap = not self.overlap + # Overlap + # ------------------------------------------------------------------------- + + @property + def overlap(self): + return self._overlap + + @overlap.setter + def overlap(self, value): + self._overlap = value + # HACK: temporarily disable automatic zoom on channels when + # changing the overlap. tmp = self.do_zoom_on_channels + self.do_zoom_on_channels = False self.on_select() self.do_zoom_on_channels = tmp - def toggle_zoom_on_channels(self): - self.do_zoom_on_channels = not self.do_zoom_on_channels + def toggle_waveform_overlap(self): + """Toggle the overlap of the waveforms.""" + self.overlap = not self.overlap + + # Show means + # ------------------------------------------------------------------------- + + @property + def do_show_means(self): + return self._do_show_means + + @do_show_means.setter + def do_show_means(self, value): + self._do_show_means = value + self.on_select() def toggle_show_means(self): self.do_show_means = not self.do_show_means - self.on_select() # Box scaling # ------------------------------------------------------------------------- @@ -420,52 +481,75 @@ def _update_boxes(self): self.boxed.update_boxes(self.box_pos * self.probe_scaling, self.box_size * self.box_scaling) + @property + def box_scaling(self): + return self._box_scaling + + @box_scaling.setter + def box_scaling(self, value): + assert len(value) == 2 + self._box_scaling = np.array(value) + self._update_boxes() + def widen(self): """Increase the horizontal scaling of the waveforms.""" - self.box_scaling[0] *= self.scaling_coeff + self._box_scaling[0] *= self.scaling_coeff self._update_boxes() def narrow(self): """Decrease the horizontal scaling of the waveforms.""" - self.box_scaling[0] /= self.scaling_coeff + self._box_scaling[0] /= self.scaling_coeff self._update_boxes() def increase(self): """Increase the vertical scaling of the waveforms.""" - self.box_scaling[1] *= self.scaling_coeff + self._box_scaling[1] *= self.scaling_coeff self._update_boxes() def decrease(self): """Decrease the vertical scaling of the waveforms.""" - self.box_scaling[1] /= self.scaling_coeff + self._box_scaling[1] /= self.scaling_coeff self._update_boxes() # Probe scaling # ------------------------------------------------------------------------- + @property + def probe_scaling(self): + return self._probe_scaling + + @probe_scaling.setter + def probe_scaling(self, value): + assert len(value) == 2 + self._probe_scaling = np.array(value) + self._update_boxes() + def extend_horizontally(self): """Increase the horizontal scaling of the probe.""" - self.probe_scaling[0] *= self.scaling_coeff + self._probe_scaling[0] *= self.scaling_coeff self._update_boxes() def shrink_horizontally(self): """Decrease the horizontal scaling of the waveforms.""" - self.probe_scaling[0] /= self.scaling_coeff + self._probe_scaling[0] /= self.scaling_coeff self._update_boxes() def extend_vertically(self): """Increase the vertical scaling of the waveforms.""" - self.probe_scaling[1] *= self.scaling_coeff + self._probe_scaling[1] *= self.scaling_coeff self._update_boxes() def shrink_vertically(self): """Decrease the vertical scaling of the waveforms.""" - self.probe_scaling[1] /= self.scaling_coeff + self._probe_scaling[1] /= self.scaling_coeff self._update_boxes() # Navigation # ------------------------------------------------------------------------- + def toggle_zoom_on_channels(self): + self.do_zoom_on_channels = not self.do_zoom_on_channels + def zoom_on_channels(self, channels_rel): """Zoom on some channels.""" if channels_rel is None or not len(channels_rel): @@ -498,39 +582,6 @@ def on_key_release(self, event): self._key_pressed = None -class WaveformViewPlugin(IPlugin): - def attach_to_gui(self, gui): - state = gui.state - model = gui.request('model') - bs, ps, ov = state.get_view_params('WaveformView', - 'box_scaling', - 'probe_scaling', - 'overlap', - ) - cs = gui.request('cluster_store') - assert cs # We need the cluster store to retrieve the data. - view = WaveformView(waveforms_masks=cs.waveforms_masks, - channel_positions=model.channel_positions, - n_samples=model.n_samples_waveforms, - box_scaling=bs, - probe_scaling=ps, - waveform_lim=cs.waveform_lim(), - ) - view.attach(gui) - - if ov is not None: - view.overlap = ov - - @gui.connect_ - def on_close(): - # Save the box bounds. - state.set_view_params(view, - box_scaling=tuple(view.box_scaling), - probe_scaling=tuple(view.probe_scaling), - overlap=view.overlap, - ) - - # ----------------------------------------------------------------------------- # Trace view # ----------------------------------------------------------------------------- @@ -948,6 +999,7 @@ def __init__(self, n_channels=None, n_features_per_channel=None, feature_lim=None, + best_channels=None, **kwargs): """ features_masks is a function : @@ -960,6 +1012,8 @@ def __init__(self, """ + self.best_channels = best_channels or (lambda clusters: []) + assert features_masks self.features_masks = features_masks @@ -1089,7 +1143,7 @@ def _plot_features(self, i, j, x_dim, y_dim, x, y, def _get_channel_dims(self, cluster_ids): """Select the channels to show by default.""" n = self.n_cols - 1 - channels = self._best_channels(cluster_ids, 2 * n) + channels = self.best_channels(cluster_ids, 2 * n) channels = (channels if channels is not None else list(range(self.n_channels))) channels = _extend(channels, 2 * n) From 38672d80f08fd19caa07f372e60c7d65962f7f69 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 13 Jan 2016 17:58:30 +0100 Subject: [PATCH 0925/1059] Refactor trace view --- phy/cluster/manual/tests/test_views.py | 71 +++++++++++++--------- phy/cluster/manual/views.py | 82 +++++++++++--------------- 2 files changed, 77 insertions(+), 76 deletions(-) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index cfdf80890..26595956d 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -239,7 +239,7 @@ def state(tempdir): # Save a test GUI state JSON file in the tempdir. state = Bunch() state.WaveformView0 = Bunch(overlap=False) - state.TraceView0 = Bunch(box_size=(1., .01)) + state.TraceView0 = Bunch(scaling=1.) state.FeatureView0 = Bunch(feature_scaling=.5) state.CorrelogramView0 = Bunch(uniform_normalization=True) @@ -391,42 +391,55 @@ def test_trace_view_no_spikes(qtbot): _show(qtbot, v) -def test_trace_view_spikes(qtbot, tempdir): - with _test_view('TraceView', tempdir=tempdir) as v: - ac(v.stacked.box_size, (1., .08181), atol=1e-3) - assert v.time == .25 +def test_trace_view_spikes(qtbot, gui): + model = gui.model + cs = model.store + v = TraceView(traces=model.traces, + sample_rate=model.sample_rate, + spike_times=model.spike_times, + spike_clusters=model.spike_clusters, + n_samples_per_spike=model.n_samples_waveforms, + masks=model.masks, + mean_traces=cs.mean_traces(), + ) + v.attach(gui) - v.go_to(.5) - assert v.time == .5 + _select_clusters(gui) - v.go_to(-.5) - assert v.time == .25 + ac(v.stacked.box_size, (1., .08181), atol=1e-3) + assert v.time == .25 - v.go_left() - assert v.time == .25 + v.go_to(.5) + assert v.time == .5 - v.go_right() - assert v.time == .35 + v.go_to(-.5) + assert v.time == .25 - # Change interval size. - v.set_interval((.25, .75)) - ac(v.interval, (.25, .75)) - v.widen() - ac(v.interval, (.225, .775)) - v.narrow() - ac(v.interval, (.25, .75)) + v.go_left() + assert v.time == .25 - # Widen the max interval. - v.set_interval((0, v.model.duration)) - v.widen() + v.go_right() + assert v.time == .35 - # Change channel scaling. - bs = v.stacked.box_size - v.increase() - v.decrease() - ac(v.stacked.box_size, bs, atol=1e-3) + # Change interval size. + v.set_interval((.25, .75)) + ac(v.interval, (.25, .75)) + v.widen() + ac(v.interval, (.225, .775)) + v.narrow() + ac(v.interval, (.25, .75)) - # qtbot.stop() + # Widen the max interval. + v.set_interval((0, model.duration)) + v.widen() + + # Change channel scaling. + bs = v.stacked.box_size + v.increase() + v.decrease() + ac(v.stacked.box_size, bs, atol=1e-3) + + # qtbot.stop() #------------------------------------------------------------------------------ diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index fbc77a5b2..4473d28d1 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -165,21 +165,6 @@ def on_select(self, cluster_ids=None): self.cluster_ids = list(cluster_ids) if cluster_ids is not None else [] self.cluster_ids = [int(c) for c in self.cluster_ids] - # def _best_channels(self, cluster_ids, n_channels_requested=None): - # """Return the best channels for a set of clusters.""" - # # Number of channels to find on each axis. - # n = n_channels_requested or self.n_channels - # # Request the best channels to the GUI. - # cs = self.gui.request('cluster_store') if self.gui else None - # channels = cs.best_channels_multiple(cluster_ids) if cs else None - # # By default, select the first channels. - # if channels is None or not len(channels): - # return - # assert len(channels) - # # Repeat some channels if there aren't enough. - # channels = _extend(channels, n) - # return channels - def attach(self, gui): """Attach the view to the GUI.""" @@ -606,8 +591,6 @@ def __init__(self, spike_clusters=None, masks=None, # full array of masks n_samples_per_spike=None, - scaling=None, - origin=None, mean_traces=None, **kwargs): @@ -656,13 +639,15 @@ def __init__(self, else: self.spike_times = self.spike_clusters = self.masks = None + # Box and probe scaling. + self._scaling = 1. + self._origin = None + # Initialize the view. super(TraceView, self).__init__(layout='stacked', - origin=origin, + origin=self.origin, n_plots=self.n_channels, **kwargs) - # Box and probe scaling. - self.scaling = scaling or 1. # Make a copy of the initial box pos and size. We'll apply the scaling # to these quantities. @@ -840,6 +825,36 @@ def attach(self, gui): self.actions.add(self.widen) self.actions.add(self.narrow) + @property + def state(self): + return Bunch(scaling=self.scaling, + origin=self.origin, + ) + + # Scaling + # ------------------------------------------------------------------------- + + @property + def scaling(self): + return self._scaling + + @scaling.setter + def scaling(self, value): + self._scaling = value + self._update_boxes() + + # Origin + # ------------------------------------------------------------------------- + + @property + def origin(self): + return self._origin + + @origin.setter + def origin(self, value): + self._origin = value + self._update_boxes() + # Navigation # ------------------------------------------------------------------------- @@ -905,33 +920,6 @@ def decrease(self): self._update_boxes() -class TraceViewPlugin(IPlugin): - def attach_to_gui(self, gui): - state = gui.state - model = gui.request('model') - s, o = state.get_view_params('TraceView', 'scaling', 'origin') - - cs = gui.request('cluster_store') - assert cs # We need the cluster store to retrieve the data. - - view = TraceView(traces=model.traces, - sample_rate=model.sample_rate, - spike_times=model.spike_times, - spike_clusters=model.spike_clusters, - n_samples_per_spike=model.n_samples_waveforms, - masks=model.masks, - origin=o, - scaling=s, - mean_traces=cs.mean_traces(), - ) - view.attach(gui) - - @gui.connect_ - def on_close(): - # Save the box bounds. - state.set_view_params(view, scaling=view.scaling) - - # ----------------------------------------------------------------------------- # Feature view # ----------------------------------------------------------------------------- From c5cd67fb5cd1dbd82cb41abec6b5f20e7ec27425 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 13 Jan 2016 18:05:26 +0100 Subject: [PATCH 0926/1059] Refactor feature view --- phy/cluster/manual/tests/test_views.py | 34 ++++++---- phy/cluster/manual/views.py | 91 +++++++++++--------------- 2 files changed, 62 insertions(+), 63 deletions(-) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index 26595956d..02ddfcfc9 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -446,20 +446,32 @@ def test_trace_view_spikes(qtbot, gui): # Test feature view #------------------------------------------------------------------------------ -def test_feature_view(qtbot, tempdir): - with _test_view('FeatureView', tempdir=tempdir) as v: - assert v.feature_scaling == .5 - v.add_attribute('sine', - np.sin(np.linspace(-10., 10., v.model.n_spikes))) +def test_feature_view(qtbot, gui): + model = gui.model + cs = model.store + v = FeatureView(features_masks=cs.features_masks, + background_features_masks=cs.background_features_masks(), + spike_times=model.spike_times, + n_channels=model.n_channels, + n_features_per_channel=model.n_features_per_channel, + feature_lim=cs.feature_lim(), + ) + v.attach(gui) - v.increase() - v.decrease() + _select_clusters(gui) - v.on_channel_click(channel_idx=3, button=1, key=2) - v.clear_channels() - v.toggle_automatic_channel_selection() + assert v.feature_scaling == .5 + v.add_attribute('sine', + np.sin(np.linspace(-10., 10., model.n_spikes))) - # qtbot.stop() + v.increase() + v.decrease() + + v.on_channel_click(channel_idx=3, button=1, key=2) + v.clear_channels() + v.toggle_automatic_channel_selection() + + # qtbot.stop() #------------------------------------------------------------------------------ diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 4473d28d1..9e92a22ab 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -277,7 +277,7 @@ def __init__(self, self._overlap = False self.do_zoom_on_channels = True - self.best_channels = best_channels or (lambda clusters: []) + self.best_channels = best_channels or (lambda clusters, n=None: []) # Channel positions and n_channels. assert channel_positions is not None @@ -973,7 +973,6 @@ def _project_mask_depth(dim, masks, spike_clusters_rel=None, n_clusters=None): class FeatureView(ManualClusteringView): _default_marker_size = 3. - _feature_scaling = 1. default_shortcuts = { 'increase': 'ctrl++', @@ -999,8 +998,9 @@ def __init__(self, background_features_masks is a Bunch(...) like above. """ + self._scaling = 1. - self.best_channels = best_channels or (lambda clusters: []) + self.best_channels = best_channels or (lambda clusters, n=None: []) assert features_masks self.features_masks = features_masks @@ -1042,20 +1042,8 @@ def __init__(self, self.attributes = {} self.add_attribute('time', spike_times) - def add_attribute(self, name, values, top_left=True): - """Add an attribute (aka extra feature). - - The values should be a 1D array with `n_spikes` elements. - - By default, there is the `time` attribute. - - """ - assert values.shape == (self.n_spikes,) - lim = values.min(), values.max() - self.attributes[name] = (values, lim) - # Register the attribute to use in the top-left subplot. - if top_left: - self.top_left_attribute = name + # Internal methods + # ------------------------------------------------------------------------- def _get_feature(self, dim, spike_ids, f): if dim in self.attributes: @@ -1067,7 +1055,7 @@ def _get_feature(self, dim, spike_ids, f): else: assert len(dim) == 2 ch, fet = dim - return f[:, ch, fet] * self._feature_scaling + return f[:, ch, fet] * self._scaling def _get_dim_bounds_single(self, dim): """Return the min and max of the bounds for a single dimension.""" @@ -1132,12 +1120,30 @@ def _get_channel_dims(self, cluster_ids): """Select the channels to show by default.""" n = self.n_cols - 1 channels = self.best_channels(cluster_ids, 2 * n) - channels = (channels if channels is not None + channels = (channels if channels else list(range(self.n_channels))) channels = _extend(channels, 2 * n) assert len(channels) == 2 * n return channels[:n], channels[n:] + # Public methods + # ------------------------------------------------------------------------- + + def add_attribute(self, name, values, top_left=True): + """Add an attribute (aka extra feature). + + The values should be a 1D array with `n_spikes` elements. + + By default, there is the `time` attribute. + + """ + assert values.shape == (self.n_spikes,) + lim = values.min(), values.max() + self.attributes[name] = (values, lim) + # Register the attribute to use in the top-left subplot. + if top_left: + self.top_left_attribute = name + def clear_channels(self): """Reset the dimensions.""" self.x_channels = self.y_channels = None @@ -1234,6 +1240,10 @@ def attach(self, gui): gui.connect_(self.on_channel_click) + @property + def state(self): + return Bunch(scaling=self.scaling) + def on_channel_click(self, channel_idx=None, key=None, button=None): """Respond to the click on a channel.""" if key is None or not (1 <= key <= (self.n_cols - 1)): @@ -1258,47 +1268,24 @@ def toggle_automatic_channel_selection(self): def increase(self): """Increase the scaling of the features.""" - self.feature_scaling *= 1.2 + self.scaling *= 1.2 self.on_select() def decrease(self): """Decrease the scaling of the features.""" - self.feature_scaling /= 1.2 + self.scaling /= 1.2 self.on_select() - @property - def feature_scaling(self): - return self._feature_scaling - - @feature_scaling.setter - def feature_scaling(self, value): - self._feature_scaling = value - - -class FeatureViewPlugin(IPlugin): - def attach_to_gui(self, gui): - state = gui.state - cs = gui.request('cluster_store') - model = gui.request('model') - assert cs - bg = cs.background_features_masks() - view = FeatureView(features_masks=cs.features_masks, - background_features_masks=bg, - spike_times=model.spike_times, - n_channels=model.n_channels, - n_features_per_channel=model.n_features_per_channel, - feature_lim=cs.feature_lim(), - ) - view.attach(gui) + # Feature scaling + # ------------------------------------------------------------------------- - fs, = state.get_view_params('FeatureView', 'feature_scaling') - if fs: - view.feature_scaling = fs + @property + def scaling(self): + return self._scaling - @gui.connect_ - def on_close(): - # Save the box bounds. - state.set_view_params(view, feature_scaling=view.feature_scaling) + @scaling.setter + def scaling(self, value): + self._scaling = value # ----------------------------------------------------------------------------- From 4663db88425fd7aad107de4625d0a936ac02b1e5 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 13 Jan 2016 18:12:46 +0100 Subject: [PATCH 0927/1059] Refactor trace view --- phy/cluster/manual/gui_component.py | 4 +- .../manual/tests/test_gui_component.py | 13 ----- phy/cluster/manual/tests/test_views.py | 23 ++++++-- phy/cluster/manual/views.py | 57 +++++-------------- 4 files changed, 32 insertions(+), 65 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index d78ef6f20..1050b0176 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -399,8 +399,8 @@ def attach(self, gui): gui.add_view(self.cluster_view, name='ClusterView') # Add the quality column in the cluster view. - # TODO - cs = gui.model.store + # TODO: access the model directly + cs = getattr(getattr(gui, 'model', None), 'store', None) if cs and 'ClusterView' in gui.state: # Names of the quality and similarity functions. quality = gui.state.ClusterView.get('quality', None) diff --git a/phy/cluster/manual/tests/test_gui_component.py b/phy/cluster/manual/tests/test_gui_component.py index a968942f6..758027087 100644 --- a/phy/cluster/manual/tests/test_gui_component.py +++ b/phy/cluster/manual/tests/test_gui_component.py @@ -11,7 +11,6 @@ from numpy.testing import assert_array_equal as ae from ..gui_component import (ManualClustering, - ManualClusteringPlugin, ) from phy.gui import GUI from phy.utils import Bunch @@ -57,18 +56,6 @@ def manual_clustering(qtbot, gui, cluster_ids, cluster_groups, # Test GUI component #------------------------------------------------------------------------------ -def test_manual_clustering_plugin(qtbot, gui): - model = Bunch(spike_clusters=[0, 1, 2], - cluster_groups=None, - n_features_per_channel=2, - waveforms=np.zeros((3, 4, 1)), - features=np.zeros((3, 1, 2)), - masks=.75 * np.ones((3, 1)), - ) - gui.register(model=model) - ManualClusteringPlugin().attach_to_gui(gui) - - def test_manual_clustering_edge_cases(manual_clustering): mc = manual_clustering diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index 02ddfcfc9..05255c3d4 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -372,6 +372,7 @@ def on_channel_click(channel_idx=None, button=None, key=None): assert _clicked == [(0, 1, 2)] # qtbot.stop() + gui.close() #------------------------------------------------------------------------------ @@ -440,6 +441,7 @@ def test_trace_view_spikes(qtbot, gui): ac(v.stacked.box_size, bs, atol=1e-3) # qtbot.stop() + gui.close() #------------------------------------------------------------------------------ @@ -472,17 +474,26 @@ def test_feature_view(qtbot, gui): v.toggle_automatic_channel_selection() # qtbot.stop() + gui.close() #------------------------------------------------------------------------------ # Test correlogram view #------------------------------------------------------------------------------ -def test_correlogram_view(qtbot, tempdir): - with _test_view('CorrelogramView', tempdir=tempdir) as v: - v.toggle_normalization() +def test_correlogram_view(qtbot, gui): + model = gui.model + v = CorrelogramView(spike_times=model.spike_times, + spike_clusters=model.spike_clusters, + sample_rate=model.sample_rate, + ) + v.attach(gui) + _select_clusters(gui) - v.set_bin(1) - v.set_window(100) + v.toggle_normalization() - # qtbot.stop() + v.set_bin(1) + v.set_window(100) + + # qtbot.stop() + gui.close() diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 9e92a22ab..8cf86f271 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -18,7 +18,7 @@ from phy.plot import View, _get_linear_x from phy.plot.utils import _get_boxes from phy.stats import correlograms -from phy.utils import IPlugin, Bunch +from phy.utils import Bunch logger = logging.getLogger(__name__) @@ -1298,6 +1298,7 @@ class CorrelogramView(ManualClusteringView): bin_size = 1e-3 window_size = 50e-3 uniform_normalization = False + default_shortcuts = { 'go_left': 'alt+left', 'go_right': 'alt+right', @@ -1307,18 +1308,11 @@ def __init__(self, spike_times=None, spike_clusters=None, sample_rate=None, - bin_size=None, - window_size=None, - excerpt_size=None, - n_excerpts=None, **kwargs): assert sample_rate > 0 self.sample_rate = sample_rate - self.excerpt_size = excerpt_size or self.excerpt_size - self.n_excerpts = n_excerpts or self.n_excerpts - self.spike_times = np.asarray(spike_times) self.n_spikes, = self.spike_times.shape @@ -1332,7 +1326,8 @@ def __init__(self, self.spike_clusters = spike_clusters # Set the default bin and window size. - self.set_bin_window(bin_size=bin_size, window_size=window_size) + self.set_bin_window(bin_size=self.bin_size, + window_size=self.window_size) def set_bin_window(self, bin_size=None, window_size=None): """Set the bin and window sizes.""" @@ -1412,6 +1407,15 @@ def attach(self, gui): self.actions.add(self.set_bin, alias='cb') self.actions.add(self.set_window, alias='cw') + @property + def state(self): + return Bunch(bin_size=self.bin_size, + window_size=self.window_size, + excerpt_size=self.excerpt_size, + n_excerpts=self.n_excerpts, + uniform_normalization=self.uniform_normalization, + ) + def set_bin(self, bin_size): """Set the correlogram bin size (in milliseconds).""" self.set_bin_window(bin_size=bin_size * 1e-3) @@ -1421,38 +1425,3 @@ def set_window(self, window_size): """Set the correlogram window size (in milliseconds).""" self.set_bin_window(window_size=window_size * 1e-3) self.on_select() - - -class CorrelogramViewPlugin(IPlugin): - def attach_to_gui(self, gui): - state = gui.state - model = gui.request('model') - bs, ws, es, ne, un = state.get_view_params('CorrelogramView', - 'bin_size', - 'window_size', - 'excerpt_size', - 'n_excerpts', - 'uniform_normalization', - ) - - view = CorrelogramView(spike_times=model.spike_times, - spike_clusters=model.spike_clusters, - sample_rate=model.sample_rate, - bin_size=bs, - window_size=ws, - excerpt_size=es, - n_excerpts=ne, - ) - if un is not None: - view.uniform_normalization = un - view.attach(gui) - - @gui.connect_ - def on_close(): - # Save the normalization. - un = view.uniform_normalization - state.set_view_params(view, - uniform_normalization=un, - bin_size=view.bin_size, - window_size=view.window_size, - ) From d93af0a40840640600a49b3aea20f2be79a41b44 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 13 Jan 2016 18:17:07 +0100 Subject: [PATCH 0928/1059] Increase coverage --- phy/cluster/manual/tests/test_views.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index 05255c3d4..bf212ddff 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -317,7 +317,7 @@ def test_waveform_view(qtbot, gui): channel_positions=gui.model.channel_positions, n_samples=gui.model.n_samples_waveforms, waveform_lim=gui.model.store.waveform_lim(), - best_channels=(lambda clusters: [0, 1, 2]), + best_channels=gui.model.store.best_channels_multiple, ) v.attach(gui) @@ -356,6 +356,14 @@ def test_waveform_view(qtbot, gui): v.shrink_vertically() ac(v.boxed.box_pos, bp) + a, b = v.probe_scaling + v.probe_scaling = (a, b * 2) + ac(v.probe_scaling, (a, b * 2)) + + a, b = v.box_scaling + v.box_scaling = (a * 2, b) + ac(v.box_scaling, (a * 2, b)) + v.zoom_on_channels([0, 2, 4]) # Simulate channel selection. @@ -440,6 +448,9 @@ def test_trace_view_spikes(qtbot, gui): v.decrease() ac(v.stacked.box_size, bs, atol=1e-3) + v.origin = 'upper' + assert v.origin == 'upper' + # qtbot.stop() gui.close() From 23a215645f817a6efbe53c61fc414bdfb7fc630a Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 13 Jan 2016 18:39:31 +0100 Subject: [PATCH 0929/1059] Minor update --- phy/cluster/manual/tests/test_views.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index bf212ddff..96576dff8 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -313,11 +313,13 @@ def test_selected_clusters_colors(): #------------------------------------------------------------------------------ def test_waveform_view(qtbot, gui): - v = WaveformView(waveforms_masks=gui.model.store.waveforms_masks, - channel_positions=gui.model.channel_positions, - n_samples=gui.model.n_samples_waveforms, - waveform_lim=gui.model.store.waveform_lim(), - best_channels=gui.model.store.best_channels_multiple, + model = gui.model + cs = model.store + v = WaveformView(waveforms_masks=model.store.waveforms_masks, + channel_positions=model.channel_positions, + n_samples=model.n_samples_waveforms, + waveform_lim=cs.waveform_lim(), + best_channels=cs.best_channels_multiple, ) v.attach(gui) From 557c163ff8bbf9b5837b0246f803ae973970d7bf Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 13 Jan 2016 19:02:36 +0100 Subject: [PATCH 0930/1059] Fix save GUI state --- phy/cluster/manual/views.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 8cf86f271..c67c8c8d0 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -195,6 +195,9 @@ def on_status(e): @gui.connect_ def on_close(): gui.state.update_view_state(self, self.state) + # NOTE: create_gui() already saves the state, but the event + # is registered *before* we add all views. + gui.state.save() self.show() From 16dde32ed3331608cb62c75bcba58a70f2149709 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 13 Jan 2016 20:22:40 +0100 Subject: [PATCH 0931/1059] Lint --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index b8de731d4..677afbc9a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -5,4 +5,4 @@ universal = 1 norecursedirs = experimental _* [flake8] -ignore=E265 +ignore=E265,E731 From a4ba95d8fda5a42402c9f2a8c8d0643fb285506f Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 15 Jan 2016 15:44:00 +0100 Subject: [PATCH 0932/1059] Add scatter view --- phy/cluster/manual/tests/test_views.py | 22 ++++++++++ phy/cluster/manual/views.py | 59 ++++++++++++++++++++++++++ 2 files changed, 81 insertions(+) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index 96576dff8..ac0158d0f 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -33,6 +33,7 @@ FeatureView, CorrelogramView, TraceView, + ScatterView, _extract_wave, _selected_clusters_colors, _extend, @@ -490,6 +491,27 @@ def test_feature_view(qtbot, gui): gui.close() +#------------------------------------------------------------------------------ +# Test scatter view +#------------------------------------------------------------------------------ + +def test_scatter_view(qtbot, gui): + n = 1000 + v = ScatterView(coords=lambda c: Bunch(x=np.random.randn(n), + y=np.random.randn(n), + spike_ids=np.arange(n), + spike_clusters=np.ones(n). + astype(np.int32) * c[0], + ), + bounds=[-3, -3, 3, 3], + ) + v.attach(gui) + + _select_clusters(gui) + + # qtbot.stop() + + #------------------------------------------------------------------------------ # Test correlogram view #------------------------------------------------------------------------------ diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index c67c8c8d0..a0ef7ca4e 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -1428,3 +1428,62 @@ def set_window(self, window_size): """Set the correlogram window size (in milliseconds).""" self.set_bin_window(window_size=window_size * 1e-3) self.on_select() + + +# ----------------------------------------------------------------------------- +# Scatter view +# ----------------------------------------------------------------------------- + +class ScatterView(ManualClusteringView): + _default_marker_size = 3. + + def __init__(self, + coords=None, # function clusters: Bunch(x, y) + bounds=None, + **kwargs): + + assert coords + self.coords = coords + + # Initialize the view. + super(ScatterView, self).__init__(**kwargs) + + # Feature normalization. + self.data_bounds = bounds + + def on_select(self, cluster_ids=None): + super(ScatterView, self).on_select(cluster_ids) + cluster_ids = self.cluster_ids + n_clusters = len(cluster_ids) + if n_clusters == 0: + return + + # Get the spike times and amplitudes + data = self.coords(cluster_ids) + spike_ids = data.spike_ids + spike_clusters = data.spike_clusters + x = data.x + y = data.y + n_spikes = len(spike_ids) + assert n_spikes > 0 + assert spike_clusters.shape == (n_spikes,) + assert x.shape == (n_spikes,) + assert y.shape == (n_spikes,) + + # Get the spike clusters. + sc = _index_of(spike_clusters, cluster_ids) + + # Plot the amplitudes. + with self.building(): + m = np.ones(n_spikes) + # Get the color of the markers. + color = _get_color(m, spike_clusters_rel=sc, n_clusters=n_clusters) + assert color.shape == (n_spikes, 4) + ms = (self._default_marker_size if sc is not None else 1.) + + self.scatter(x=x, + y=y, + color=color, + data_bounds=self.data_bounds, + size=ms * np.ones(n_spikes), + ) From 9e24600b90f59623bd8946e2289bc2ce8d8a2ce9 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 15 Jan 2016 15:47:03 +0100 Subject: [PATCH 0933/1059] Rename bounds to data_bounds --- phy/cluster/manual/tests/test_views.py | 3 ++- phy/cluster/manual/views.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index ac0158d0f..0bd4b53b2 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -503,13 +503,14 @@ def test_scatter_view(qtbot, gui): spike_clusters=np.ones(n). astype(np.int32) * c[0], ), - bounds=[-3, -3, 3, 3], + data_bounds=[-3, -3, 3, 3], ) v.attach(gui) _select_clusters(gui) # qtbot.stop() + gui.close() #------------------------------------------------------------------------------ diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index a0ef7ca4e..361e37b03 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -1439,7 +1439,7 @@ class ScatterView(ManualClusteringView): def __init__(self, coords=None, # function clusters: Bunch(x, y) - bounds=None, + data_bounds=None, **kwargs): assert coords @@ -1449,7 +1449,7 @@ def __init__(self, super(ScatterView, self).__init__(**kwargs) # Feature normalization. - self.data_bounds = bounds + self.data_bounds = data_bounds def on_select(self, cluster_ids=None): super(ScatterView, self).on_select(cluster_ids) From 8afdc0ae36a8f08abf41d14efcd6ac026c428afe Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 20 Jan 2016 11:05:00 +0100 Subject: [PATCH 0934/1059] Pass quality/similarity functions to manual clustering component --- phy/cluster/manual/gui_component.py | 23 ++++++++--------- phy/cluster/manual/tests/conftest.py | 25 +++++++++++-------- .../manual/tests/test_gui_component.py | 6 ++--- 3 files changed, 28 insertions(+), 26 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 1050b0176..22bf02874 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -119,9 +119,13 @@ def __init__(self, spike_clusters, cluster_groups=None, shortcuts=None, + quality=None, + similarity=None, ): self.gui = None + self.quality = quality + self.similarity = similarity # Load default shortcuts, and override with any user shortcuts. self.shortcuts = self.default_shortcuts.copy() @@ -399,18 +403,13 @@ def attach(self, gui): gui.add_view(self.cluster_view, name='ClusterView') # Add the quality column in the cluster view. - # TODO: access the model directly - cs = getattr(getattr(gui, 'model', None), 'store', None) - if cs and 'ClusterView' in gui.state: - # Names of the quality and similarity functions. - quality = gui.state.ClusterView.get('quality', None) - similarity = gui.state.ClusterView.get('similarity', None) - if quality: - self.cluster_view.add_column(cs.get(quality), name=quality) - self.set_default_sort(quality) - if similarity: - self.set_similarity_func(cs.get(similarity)) - gui.add_view(self.similarity_view, name='SimilarityView') + if self.quality: + self.cluster_view.add_column(self.quality, + name=self.quality.__name__) + self.set_default_sort(self.quality.__name__) + if self.similarity: + self.set_similarity_func(self.similarity) + gui.add_view(self.similarity_view, name='SimilarityView') # Update the cluster views and selection when a cluster event occurs. self.gui.connect_(self.on_cluster) diff --git a/phy/cluster/manual/tests/conftest.py b/phy/cluster/manual/tests/conftest.py index aaa648936..902581a12 100644 --- a/phy/cluster/manual/tests/conftest.py +++ b/phy/cluster/manual/tests/conftest.py @@ -6,7 +6,7 @@ # Imports #------------------------------------------------------------------------------ -from pytest import yield_fixture +from pytest import fixture from phy.io.store import get_closest_clusters @@ -15,23 +15,28 @@ # Fixtures #------------------------------------------------------------------------------ -@yield_fixture +@fixture def cluster_ids(): - yield [0, 1, 2, 10, 11, 20, 30] - # i, g, N, i, g, N, N + return [0, 1, 2, 10, 11, 20, 30] + # i, g, N, i, g, N, N -@yield_fixture +@fixture def cluster_groups(): - yield {0: 'noise', 1: 'good', 10: 'mua', 11: 'good'} + return {0: 'noise', 1: 'good', 10: 'mua', 11: 'good'} -@yield_fixture +@fixture def quality(): - yield lambda c: c + def quality(c): + return c + return quality -@yield_fixture +@fixture def similarity(cluster_ids): sim = lambda c, d: (c * 1.01 + d) - yield lambda c: get_closest_clusters(c, cluster_ids, sim) + + def similarity(c): + return get_closest_clusters(c, cluster_ids, sim) + return similarity diff --git a/phy/cluster/manual/tests/test_gui_component.py b/phy/cluster/manual/tests/test_gui_component.py index 758027087..60ea36017 100644 --- a/phy/cluster/manual/tests/test_gui_component.py +++ b/phy/cluster/manual/tests/test_gui_component.py @@ -41,13 +41,11 @@ def manual_clustering(qtbot, gui, cluster_ids, cluster_groups, mc = ManualClustering(spike_clusters, cluster_groups=cluster_groups, shortcuts={'undo': 'ctrl+z'}, + quality=quality, + similarity=similarity, ) mc.attach(gui) - mc.add_column(quality, name='quality') - mc.set_default_sort('quality', 'desc') - mc.set_similarity_func(similarity) - yield mc del mc From 408d2eeacce93fb880f391973ca9ee3d77ce96a6 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 20 Jan 2016 11:08:54 +0100 Subject: [PATCH 0935/1059] WIP: Remove store --- phy/cluster/manual/tests/conftest.py | 2 +- phy/cluster/manual/tests/test_views.py | 3 +- phy/io/array.py | 34 +++++++++ phy/io/store.py | 100 ------------------------- phy/io/tests/test_store.py | 26 ------- 5 files changed, 36 insertions(+), 129 deletions(-) delete mode 100644 phy/io/store.py delete mode 100644 phy/io/tests/test_store.py diff --git a/phy/cluster/manual/tests/conftest.py b/phy/cluster/manual/tests/conftest.py index 902581a12..4b37e30d2 100644 --- a/phy/cluster/manual/tests/conftest.py +++ b/phy/cluster/manual/tests/conftest.py @@ -8,7 +8,7 @@ from pytest import fixture -from phy.io.store import get_closest_clusters +from phy.io.array import get_closest_clusters #------------------------------------------------------------------------------ diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index 0bd4b53b2..8e2f8eab0 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -14,13 +14,12 @@ from phy.electrode.mea import staggered_positions from phy.gui import create_gui -from phy.io.array import _spikes_per_cluster +from phy.io.array import _spikes_per_cluster, get_closest_clusters from phy.io.mock import (artificial_waveforms, artificial_features, artificial_masks, artificial_traces, ) -from phy.io.store import ClusterStore, get_closest_clusters from phy.stats.clusters import (mean, get_max_waveform_amplitude, get_mean_masked_features_distance, diff --git a/phy/io/array.py b/phy/io/array.py index 0a6ce9fae..77ef6695c 100644 --- a/phy/io/array.py +++ b/phy/io/array.py @@ -7,13 +7,16 @@ #------------------------------------------------------------------------------ from collections import defaultdict +from functools import wraps import logging import math from math import floor, exp +from operator import itemgetter import os.path as op import numpy as np +from phy.utils import Bunch, _as_scalar, _as_scalars from phy.utils._types import _as_array, _is_array_like logger = logging.getLogger(__name__) @@ -189,6 +192,37 @@ def _in_polygon(points, polygon): return path.contains_points(points) +def _get_data_lim(arr, n_spikes=None, percentile=None): + n = arr.shape[0] + k = max(1, n // n_spikes) if n_spikes else 1 + arr = np.abs(arr[::k]) + n = arr.shape[0] + arr = arr.reshape((n, -1)) + return arr.max() + + +def get_closest_clusters(cluster_id, cluster_ids, sim_func, max_n=None): + """Return a list of pairs `(cluster, similarity)` sorted by decreasing + similarity to a given cluster.""" + l = [(_as_scalar(candidate), _as_scalar(sim_func(cluster_id, candidate))) + for candidate in _as_scalars(cluster_ids)] + l = sorted(l, key=itemgetter(1), reverse=True) + return l[:max_n] + + +def _concat(f): + """Take a function accepting a single cluster, and return a function + accepting multiple clusters.""" + @wraps(f) + def wrapped(cluster_ids): + # Single cluster. + if not hasattr(cluster_ids, '__len__'): + return f(cluster_ids) + # Concatenate the result of multiple clusters. + return Bunch(_accumulate([f(c) for c in cluster_ids])) + return wrapped + + # ----------------------------------------------------------------------------- # I/O functions # ----------------------------------------------------------------------------- diff --git a/phy/io/store.py b/phy/io/store.py deleted file mode 100644 index f2e783065..000000000 --- a/phy/io/store.py +++ /dev/null @@ -1,100 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Cluster store.""" - - -# ----------------------------------------------------------------------------- -# Imports -# ----------------------------------------------------------------------------- - -from functools import wraps -import logging -from operator import itemgetter - -import numpy as np - -from .array import _accumulate -from phy.utils import Bunch, _as_scalar, _as_scalars - -logger = logging.getLogger(__name__) - - -# ----------------------------------------------------------------------------- -# Utils -# ----------------------------------------------------------------------------- - -def _get_data_lim(arr, n_spikes=None, percentile=None): - n = arr.shape[0] - k = max(1, n // n_spikes) if n_spikes else 1 - arr = np.abs(arr[::k]) - n = arr.shape[0] - arr = arr.reshape((n, -1)) - return arr.max() - - -def get_closest_clusters(cluster_id, cluster_ids, sim_func, max_n=None): - """Return a list of pairs `(cluster, similarity)` sorted by decreasing - similarity to a given cluster.""" - l = [(_as_scalar(candidate), _as_scalar(sim_func(cluster_id, candidate))) - for candidate in _as_scalars(cluster_ids)] - l = sorted(l, key=itemgetter(1), reverse=True) - return l[:max_n] - - -def _log(f): - @wraps(f) - def wrapped(*args, **kwargs): - logger.log(5, "Compute %s(%s).", f.__name__, str(args)) - return f(*args, **kwargs) - return wrapped - - -def _concat(f): - """Take a function accepting a single cluster, and return a function - accepting multiple clusters.""" - @wraps(f) - def wrapped(cluster_ids): - # Single cluster. - if not hasattr(cluster_ids, '__len__'): - return f(cluster_ids) - # Concatenate the result of multiple clusters. - return Bunch(_accumulate([f(c) for c in cluster_ids])) - return wrapped - - -# ----------------------------------------------------------------------------- -# Cluster statistics -# ----------------------------------------------------------------------------- - -class ClusterStore(object): - def __init__(self, context=None): - self.context = context - self._stats = {} - - def add(self, f=None, name=None, cache='disk', concat=None): - """Add a cluster statistic. - - Parameters - ---------- - f : function - name : str - cache : str - Can be `None` (no cache), `disk`, or `memory`. In the latter case - the function will also be cached on disk. - - """ - if f is None: - return lambda _: self.add(_, name=name, cache=cache, concat=concat) - name = name or f.__name__ - if cache and self.context: - f = _log(f) - f = self.context.cache(f, memcache=(cache == 'memory')) - assert f - if concat: - f = _concat(f) - self._stats[name] = f - setattr(self, name, f) - return f - - def get(self, name): - return self._stats.get(name, None) diff --git a/phy/io/tests/test_store.py b/phy/io/tests/test_store.py deleted file mode 100644 index 3b5560e43..000000000 --- a/phy/io/tests/test_store.py +++ /dev/null @@ -1,26 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Test cluster store.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -from ..store import ClusterStore -from phy.io import Context - - -#------------------------------------------------------------------------------ -# Test cluster stats -#------------------------------------------------------------------------------ - -def test_cluster_store(tempdir): - context = Context(tempdir) - cs = ClusterStore(context=context) - - @cs.add(cache='memory') - def f(x): - return x * x - - assert cs.f(3) == 9 - assert cs.get('f')(3) == 9 From b731ae4bf5f4ebe061be45d2bcc93722dd9c7f80 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 20 Jan 2016 11:17:26 +0100 Subject: [PATCH 0936/1059] Remove store in view tests --- phy/cluster/manual/tests/test_views.py | 106 ++++++++++++------------- 1 file changed, 50 insertions(+), 56 deletions(-) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index 8e2f8eab0..3eebe3125 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -14,7 +14,7 @@ from phy.electrode.mea import staggered_positions from phy.gui import create_gui -from phy.io.array import _spikes_per_cluster, get_closest_clusters +from phy.io.array import _spikes_per_cluster, get_closest_clusters, _concat from phy.io.mock import (artificial_waveforms, artificial_features, artificial_masks, @@ -63,6 +63,8 @@ def create_model(): model.n_spikes_per_cluster) model.cluster_ids = np.unique(model.spike_clusters) model.channel_positions = staggered_positions(n_channels) + + # TODO: remove and replace by functions model.traces = artificial_traces(n_samples_t, n_channels) model.masks = artificial_masks(n_spikes_total, n_channels) @@ -71,15 +73,6 @@ def create_model(): model.n_samples_waveforms = n_samples_waveforms model.cluster_groups = {c: None for c in range(n_clusters)} - # TODO: make this cleaner by abstracting the store away - model.store = create_cluster_store(model) - - return model - - -def create_cluster_store(model): - cs = ClusterStore() - def get_waveforms(n): return artificial_waveforms(n, model.n_samples_waveforms, @@ -101,125 +94,128 @@ def _get_data(**kwargs): kwargs['spike_clusters'] = model.spike_clusters[kwargs['spike_ids']] return Bunch(**kwargs) - @cs.add(concat=True) + @_concat def masks(cluster_id): return _get_data(spike_ids=get_spike_ids(cluster_id), masks=get_masks(model.n_spikes_per_cluster)) + # model.masks = masks - @cs.add(concat=True) + @_concat def features(cluster_id): return _get_data(spike_ids=get_spike_ids(cluster_id), features=get_features(model.n_spikes_per_cluster)) + model.features = features - @cs.add(concat=True) + @_concat def features_masks(cluster_id): return _get_data(spike_ids=get_spike_ids(cluster_id), features=get_features(model.n_spikes_per_cluster), masks=get_masks(model.n_spikes_per_cluster)) + model.features_masks = features_masks - @cs.add def feature_lim(): """Return the max of a subset of the feature amplitudes.""" return 1 + model.feature_lim = feature_lim - @cs.add def background_features_masks(): f = get_features(model.n_spikes) m = model.masks return _get_data(spike_ids=np.arange(model.n_spikes), features=f, masks=m) + model.background_features_masks = background_features_masks - @cs.add(concat=True) def waveforms(cluster_id): return _get_data(spike_ids=get_spike_ids(cluster_id), waveforms=get_waveforms(model.n_spikes_per_cluster)) + model.waveforms = waveforms - @cs.add def waveform_lim(): """Return the max of a subset of the waveform amplitudes.""" return 1 + model.waveform_lim = waveform_lim - @cs.add(concat=True) + @_concat def waveforms_masks(cluster_id): w = get_waveforms(model.n_spikes_per_cluster) m = get_masks(model.n_spikes_per_cluster) - mw = cs.mean_waveforms(cluster_id)[np.newaxis, ...] - mm = cs.mean_masks(cluster_id)[np.newaxis, ...] + mw = mean_waveforms(cluster_id)[np.newaxis, ...] + mm = mean_masks(cluster_id)[np.newaxis, ...] return _get_data(spike_ids=get_spike_ids(cluster_id), waveforms=w, masks=m, mean_waveforms=mw, mean_masks=mm, ) + model.waveforms_masks = waveforms_masks # Mean quantities. # ------------------------------------------------------------------------- - @cs.add def mean_masks(cluster_id): # We access [1] because we return spike_ids, masks. - return mean(cs.masks(cluster_id).masks) + return mean(masks(cluster_id).masks) + model.mean_masks = mean_masks - @cs.add def mean_features(cluster_id): - return mean(cs.features(cluster_id).features) + return mean(features(cluster_id).features) + model.mean_features = mean_features - @cs.add def mean_waveforms(cluster_id): - return mean(cs.waveforms(cluster_id).waveforms) + return mean(waveforms(cluster_id).waveforms) + model.mean_waveforms = mean_waveforms # Statistics. # ------------------------------------------------------------------------- - @cs.add(cache='memory') def best_channels(cluster_id): - mm = cs.mean_masks(cluster_id) + mm = mean_masks(cluster_id) uch = get_unmasked_channels(mm) return get_sorted_main_channels(mm, uch) + model.best_channels = best_channels - @cs.add(cache='memory') def best_channels_multiple(cluster_ids): - best_channels = [] + bc = [] for cluster in cluster_ids: - channels = cs.best_channels(cluster) - best_channels.extend([ch for ch in channels - if ch not in best_channels]) - return best_channels + channels = best_channels(cluster) + bc.extend([ch for ch in channels if ch not in bc]) + return bc + model.best_channels_multiple = best_channels_multiple - @cs.add(cache='memory') def max_waveform_amplitude(cluster_id): - mm = cs.mean_masks(cluster_id) - mw = cs.mean_waveforms(cluster_id) + mm = mean_masks(cluster_id) + mw = mean_waveforms(cluster_id) assert mw.ndim == 2 return np.asscalar(get_max_waveform_amplitude(mm, mw)) + model.max_waveform_amplitude = max_waveform_amplitude - @cs.add(cache=None) def mean_masked_features_score(cluster_0, cluster_1): - mf0 = cs.mean_features(cluster_0) - mf1 = cs.mean_features(cluster_1) - mm0 = cs.mean_masks(cluster_0) - mm1 = cs.mean_masks(cluster_1) + mf0 = mean_features(cluster_0) + mf1 = mean_features(cluster_1) + mm0 = mean_masks(cluster_0) + mm1 = mean_masks(cluster_1) nfpc = model.n_features_per_channel d = get_mean_masked_features_distance(mf0, mf1, mm0, mm1, n_features_per_channel=nfpc) s = 1. / max(1e-10, d) return s + model.mean_masked_features_score = mean_masked_features_score - @cs.add(cache='memory') def most_similar_clusters(cluster_id): assert isinstance(cluster_id, int) return get_closest_clusters(cluster_id, model.cluster_ids, - cs.mean_masked_features_score) + mean_masked_features_score) + model.most_similar_clusters = most_similar_clusters # Traces. # ------------------------------------------------------------------------- - @cs.add def mean_traces(): mt = model.traces[:, :].mean(axis=0) return mt.astype(model.traces.dtype) + model.mean_traces = mean_traces - return cs + return model #------------------------------------------------------------------------------ @@ -314,12 +310,11 @@ def test_selected_clusters_colors(): def test_waveform_view(qtbot, gui): model = gui.model - cs = model.store - v = WaveformView(waveforms_masks=model.store.waveforms_masks, + v = WaveformView(waveforms_masks=model.waveforms_masks, channel_positions=model.channel_positions, n_samples=model.n_samples_waveforms, - waveform_lim=cs.waveform_lim(), - best_channels=cs.best_channels_multiple, + waveform_lim=model.waveform_lim(), + best_channels=model.best_channels_multiple, ) v.attach(gui) @@ -404,14 +399,13 @@ def test_trace_view_no_spikes(qtbot): def test_trace_view_spikes(qtbot, gui): model = gui.model - cs = model.store v = TraceView(traces=model.traces, sample_rate=model.sample_rate, spike_times=model.spike_times, spike_clusters=model.spike_clusters, n_samples_per_spike=model.n_samples_waveforms, masks=model.masks, - mean_traces=cs.mean_traces(), + mean_traces=model.mean_traces(), ) v.attach(gui) @@ -463,13 +457,13 @@ def test_trace_view_spikes(qtbot, gui): def test_feature_view(qtbot, gui): model = gui.model - cs = model.store - v = FeatureView(features_masks=cs.features_masks, - background_features_masks=cs.background_features_masks(), + bfm = model.background_features_masks() + v = FeatureView(features_masks=model.features_masks, + background_features_masks=bfm, spike_times=model.spike_times, n_channels=model.n_channels, n_features_per_channel=model.n_features_per_channel, - feature_lim=cs.feature_lim(), + feature_lim=model.feature_lim(), ) v.attach(gui) From f73c79ad334a0ee6cbd460b0b62fd4be20771433 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 21 Jan 2016 12:05:13 +0100 Subject: [PATCH 0937/1059] Update trace view --- phy/cluster/manual/tests/test_views.py | 49 +++++++------- phy/cluster/manual/views.py | 92 ++++++++------------------ 2 files changed, 54 insertions(+), 87 deletions(-) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index 3eebe3125..6f55bbdad 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -33,6 +33,7 @@ CorrelogramView, TraceView, ScatterView, + select_traces, _extract_wave, _selected_clusters_colors, _extend, @@ -64,9 +65,25 @@ def create_model(): model.cluster_ids = np.unique(model.spike_clusters) model.channel_positions = staggered_positions(n_channels) - # TODO: remove and replace by functions - model.traces = artificial_traces(n_samples_t, n_channels) - model.masks = artificial_masks(n_spikes_total, n_channels) + all_traces = artificial_traces(n_samples_t, n_channels) + all_masks = artificial_masks(n_spikes_total, n_channels) + + def traces(interval): + """Load traces and spikes in an interval.""" + tr = select_traces(all_traces, interval, + sample_rate=model.sample_rate, + ) + # Find spikes. + a, b = model.spike_times.searchsorted(interval) + st = model.spike_times[a:b] + sc = model.spike_clusters[a:b] + m = all_masks[a:b, :] + return Bunch(traces=tr, + spike_times=st, + spike_clusters=sc, + masks=m, + ) + model.traces = traces model.spikes_per_cluster = _spikes_per_cluster(model.spike_clusters) model.n_features_per_channel = n_features_per_channel @@ -120,7 +137,7 @@ def feature_lim(): def background_features_masks(): f = get_features(model.n_spikes) - m = model.masks + m = all_masks return _get_data(spike_ids=np.arange(model.n_spikes), features=f, masks=m) model.background_features_masks = background_features_masks @@ -211,8 +228,8 @@ def most_similar_clusters(cluster_id): # ------------------------------------------------------------------------- def mean_traces(): - mt = model.traces[:, :].mean(axis=0) - return mt.astype(model.traces.dtype) + mt = all_traces.mean(axis=0) + return mt.astype(all_traces.dtype) model.mean_traces = mean_traces return model @@ -384,27 +401,13 @@ def on_channel_click(channel_idx=None, button=None, key=None): # Test trace view #------------------------------------------------------------------------------ -def test_trace_view_no_spikes(qtbot): - n_samples = 1000 - n_channels = 12 - sample_rate = 2000. - - traces = artificial_traces(n_samples, n_channels) - mt = np.atleast_2d(traces.mean(axis=0)) - - # Create the view. - v = TraceView(traces=traces, sample_rate=sample_rate, mean_traces=mt) - _show(qtbot, v) - - -def test_trace_view_spikes(qtbot, gui): +def test_trace_view(qtbot, gui): model = gui.model v = TraceView(traces=model.traces, sample_rate=model.sample_rate, - spike_times=model.spike_times, - spike_clusters=model.spike_clusters, n_samples_per_spike=model.n_samples_waveforms, - masks=model.masks, + duration=model.duration, + n_channels=model.n_channels, mean_traces=model.mean_traces(), ) v.attach(gui) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 361e37b03..3e53b3666 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -574,6 +574,15 @@ def on_key_release(self, event): # Trace view # ----------------------------------------------------------------------------- +def select_traces(traces, interval, sample_rate=None): + """Load traces in an interval (in seconds).""" + start, end = interval + i, j = round(sample_rate * start), round(sample_rate * end) + i, j = int(i), int(j) + traces = traces[i:j, :] + return traces + + class TraceView(ManualClusteringView): interval_duration = .5 # default duration of the interval shift_amount = .1 @@ -590,23 +599,29 @@ class TraceView(ManualClusteringView): def __init__(self, traces=None, sample_rate=None, - spike_times=None, - spike_clusters=None, - masks=None, # full array of masks + duration=None, + n_channels=None, n_samples_per_spike=None, mean_traces=None, **kwargs): + # traces is a function interval => {traces, spike_times, + # spike_clusters, masks} + # Sample rate. assert sample_rate > 0 self.sample_rate = sample_rate self.dt = 1. / self.sample_rate # Traces. - assert len(traces.shape) == 2 - self.n_samples, self.n_channels = traces.shape + assert hasattr(traces, '__call__') self.traces = traces - self.duration = self.dt * self.n_samples + + assert duration >= 0 + self.duration = duration + + assert n_channels >= 0 + self.n_channels = n_channels # Used to detrend the traces. self.mean_traces = np.atleast_2d(mean_traces) @@ -622,26 +637,6 @@ def __init__(self, self.n_samples_per_spike = (-ns // 2, ns // 2) # Now n_samples_per_spike is a tuple. - # Spike times. - if spike_times is not None: - spike_times = np.asarray(spike_times) - self.n_spikes = len(spike_times) - assert spike_times.shape == (self.n_spikes,) - self.spike_times = spike_times - - # Spike clusters. - spike_clusters = (np.zeros(self.n_spikes) if spike_clusters is None - else spike_clusters) - assert spike_clusters.shape == (self.n_spikes,) - self.spike_clusters = spike_clusters - - # Masks. - if masks is not None: - assert masks.shape == (self.n_spikes, self.n_channels) - self.masks = masks - else: - self.spike_times = self.spike_clusters = self.masks = None - # Box and probe scaling. self._scaling = 1. self._origin = None @@ -663,38 +658,6 @@ def __init__(self, # Internal methods # ------------------------------------------------------------------------- - def _load_traces(self, interval): - """Load traces in an interval (in seconds).""" - - start, end = interval - - i, j = round(self.sample_rate * start), round(self.sample_rate * end) - i, j = int(i), int(j) - - # We load the traces and select the requested channels. - assert self.traces.shape[1] == self.n_channels - traces = self.traces[i:j, :] - assert traces.shape[1] == self.n_channels - - # Detrend the traces. - traces = traces - self.mean_traces - - # Create the plots. - return traces - - def _load_spikes(self, interval): - """Return spike times, spike clusters, masks.""" - assert self.spike_times is not None - # Keep the spikes in the interval. - a, b = self.spike_times.searchsorted(interval) - spike_times = self.spike_times[a:b] - spike_clusters = self.spike_clusters[a:b] - n_spikes = len(spike_times) - assert len(spike_clusters) == n_spikes - masks = (self.masks[a:b] if self.masks is not None - else np.ones((n_spikes, self.n_channels))) - return spike_times, spike_clusters, masks - def _plot_traces(self, traces, start=None, data_bounds=None): t = start + np.arange(traces.shape[0]) * self.dt gray = .4 @@ -775,7 +738,12 @@ def set_interval(self, interval, change_status=True): start, end = interval # Load traces. - traces = self._load_traces(interval) + d = self.traces(interval) + traces = d.traces - self.mean_traces + spike_times = d.spike_times + spike_clusters = d.spike_clusters + masks = d.masks + # NOTE: once loaded, the traces do not contain the dead channels # so there are `n_channels_order` channels here. assert traces.shape[1] == self.n_channels @@ -794,11 +762,7 @@ def set_interval(self, interval, change_status=True): self._plot_traces(traces, start=start, data_bounds=data_bounds) # Display the spikes. - if self.spike_times is not None: - # Load the spikes. - spike_times, spike_clusters, masks = self._load_spikes(interval) - - # Plot every spike. + if spike_times is not None: for i in range(len(spike_times)): self._plot_spike(i, start=start, From 7bc53f53bf6370392a57ebb9f1a3f9a06740be33 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 21 Jan 2016 14:42:53 +0100 Subject: [PATCH 0938/1059] WIP: refactor views --- phy/cluster/manual/tests/test_views.py | 39 ++++++-------------- phy/cluster/manual/views.py | 50 +++++++------------------- 2 files changed, 22 insertions(+), 67 deletions(-) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index 6f55bbdad..08d751646 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -115,37 +115,25 @@ def _get_data(**kwargs): def masks(cluster_id): return _get_data(spike_ids=get_spike_ids(cluster_id), masks=get_masks(model.n_spikes_per_cluster)) - # model.masks = masks @_concat def features(cluster_id): - return _get_data(spike_ids=get_spike_ids(cluster_id), - features=get_features(model.n_spikes_per_cluster)) - model.features = features - - @_concat - def features_masks(cluster_id): return _get_data(spike_ids=get_spike_ids(cluster_id), features=get_features(model.n_spikes_per_cluster), masks=get_masks(model.n_spikes_per_cluster)) - model.features_masks = features_masks + model.features = features def feature_lim(): """Return the max of a subset of the feature amplitudes.""" return 1 model.feature_lim = feature_lim - def background_features_masks(): + def background_features(): f = get_features(model.n_spikes) m = all_masks return _get_data(spike_ids=np.arange(model.n_spikes), features=f, masks=m) - model.background_features_masks = background_features_masks - - def waveforms(cluster_id): - return _get_data(spike_ids=get_spike_ids(cluster_id), - waveforms=get_waveforms(model.n_spikes_per_cluster)) - model.waveforms = waveforms + model.background_features = background_features def waveform_lim(): """Return the max of a subset of the waveform amplitudes.""" @@ -153,18 +141,14 @@ def waveform_lim(): model.waveform_lim = waveform_lim @_concat - def waveforms_masks(cluster_id): + def waveforms(cluster_id): w = get_waveforms(model.n_spikes_per_cluster) m = get_masks(model.n_spikes_per_cluster) - mw = mean_waveforms(cluster_id)[np.newaxis, ...] - mm = mean_masks(cluster_id)[np.newaxis, ...] return _get_data(spike_ids=get_spike_ids(cluster_id), waveforms=w, masks=m, - mean_waveforms=mw, - mean_masks=mm, ) - model.waveforms_masks = waveforms_masks + model.waveforms = waveforms # Mean quantities. # ------------------------------------------------------------------------- @@ -327,7 +311,7 @@ def test_selected_clusters_colors(): def test_waveform_view(qtbot, gui): model = gui.model - v = WaveformView(waveforms_masks=model.waveforms_masks, + v = WaveformView(waveforms=model.waveforms, channel_positions=model.channel_positions, n_samples=model.n_samples_waveforms, waveform_lim=model.waveform_lim(), @@ -342,9 +326,6 @@ def test_waveform_view(qtbot, gui): v.toggle_waveform_overlap() v.toggle_waveform_overlap() - v.toggle_show_means() - v.toggle_show_means() - v.toggle_zoom_on_channels() v.toggle_zoom_on_channels() @@ -393,7 +374,7 @@ def on_channel_click(channel_idx=None, button=None, key=None): assert _clicked == [(0, 1, 2)] - # qtbot.stop() + qtbot.stop() gui.close() @@ -460,9 +441,9 @@ def test_trace_view(qtbot, gui): def test_feature_view(qtbot, gui): model = gui.model - bfm = model.background_features_masks() - v = FeatureView(features_masks=model.features_masks, - background_features_masks=bfm, + bfm = model.background_features() + v = FeatureView(features=model.features, + background_features=bfm, spike_times=model.spike_times, n_channels=model.n_channels, n_features_per_channel=model.n_features_per_channel, diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 3e53b3666..efc569424 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -253,7 +253,6 @@ class WaveformView(ManualClusteringView): default_shortcuts = { 'toggle_waveform_overlap': 'o', 'toggle_zoom_on_channels': 'z', - 'toggle_show_means': 'm', # Box scaling. 'widen': 'ctrl+right', @@ -269,14 +268,13 @@ class WaveformView(ManualClusteringView): } def __init__(self, - waveforms_masks=None, + waveforms=None, channel_positions=None, n_samples=None, waveform_lim=None, best_channels=None, **kwargs): self._key_pressed = None - self._do_show_means = False self._overlap = False self.do_zoom_on_channels = True @@ -312,7 +310,7 @@ def __init__(self, self._update_boxes() # Data: functions cluster_id => waveforms. - self.waveforms_masks = waveforms_masks + self.waveforms = waveforms # Waveform normalization. assert waveform_lim > 0 @@ -323,15 +321,8 @@ def __init__(self, self.channel_positions = channel_positions def _get_data(self, cluster_ids): - d = self.waveforms_masks(cluster_ids) + d = self.waveforms(cluster_ids) d.alpha = .5 - # Toggle waveform means. - if self.do_show_means: - d.waveforms = d.mean_waveforms - d.masks = d.mean_masks - d.spike_ids = np.arange(len(cluster_ids)) - d.spike_clusters = np.array(cluster_ids) - d.alpha = 1. return d def on_select(self, cluster_ids=None): @@ -394,7 +385,6 @@ def state(self): return Bunch(box_scaling=tuple(self.box_scaling), probe_scaling=tuple(self.probe_scaling), overlap=self.overlap, - do_show_means=self.do_show_means, do_zoom_on_channels=self.do_zoom_on_channels, ) @@ -403,7 +393,6 @@ def attach(self, gui): super(WaveformView, self).attach(gui) self.actions.add(self.toggle_waveform_overlap) self.actions.add(self.toggle_zoom_on_channels) - self.actions.add(self.toggle_show_means) # Box scaling. self.actions.add(self.widen) @@ -447,21 +436,6 @@ def toggle_waveform_overlap(self): """Toggle the overlap of the waveforms.""" self.overlap = not self.overlap - # Show means - # ------------------------------------------------------------------------- - - @property - def do_show_means(self): - return self._do_show_means - - @do_show_means.setter - def do_show_means(self, value): - self._do_show_means = value - self.on_select() - - def toggle_show_means(self): - self.do_show_means = not self.do_show_means - # Box scaling # ------------------------------------------------------------------------- @@ -947,8 +921,8 @@ class FeatureView(ManualClusteringView): } def __init__(self, - features_masks=None, - background_features_masks=None, + features=None, + background_features=None, spike_times=None, n_channels=None, n_features_per_channel=None, @@ -956,24 +930,24 @@ def __init__(self, best_channels=None, **kwargs): """ - features_masks is a function : + features is a function : `cluster_ids: Bunch(spike_ids, features, masks, spike_clusters, spike_times)` - background_features_masks is a Bunch(...) like above. + background_features is a Bunch(...) like above. """ self._scaling = 1. self.best_channels = best_channels or (lambda clusters, n=None: []) - assert features_masks - self.features_masks = features_masks + assert features + self.features = features # This is a tuple (spikes, features, masks). - self.background_features_masks = background_features_masks + self.background_features = background_features self.n_features_per_channel = n_features_per_channel assert n_channels > 0 @@ -1124,7 +1098,7 @@ def on_select(self, cluster_ids=None): return # Get the spikes, features, masks. - data = self.features_masks(cluster_ids) + data = self.features(cluster_ids) spike_ids = data.spike_ids spike_clusters = data.spike_clusters f = data.features @@ -1137,7 +1111,7 @@ def on_select(self, cluster_ids=None): sc = _index_of(spike_clusters, cluster_ids) # Get the background features. - data_bg = self.background_features_masks + data_bg = self.background_features spike_ids_bg = data_bg.spike_ids features_bg = data_bg.features masks_bg = data_bg.masks From b4b1e495e1d342a774069e9511b918e8f225d664 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 21 Jan 2016 14:43:01 +0100 Subject: [PATCH 0939/1059] Fix --- phy/cluster/manual/tests/test_views.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index 08d751646..59c911407 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -374,7 +374,7 @@ def on_channel_click(channel_idx=None, button=None, key=None): assert _clicked == [(0, 1, 2)] - qtbot.stop() + # qtbot.stop() gui.close() From 9967a716d969c54285a23d44a5ff606bbfeab28e Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 22 Jan 2016 14:41:59 +0100 Subject: [PATCH 0940/1059] Fix --- phy/cluster/manual/gui_component.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 22bf02874..2ddf1cc5d 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -406,7 +406,8 @@ def attach(self, gui): if self.quality: self.cluster_view.add_column(self.quality, name=self.quality.__name__) - self.set_default_sort(self.quality.__name__) + self.set_default_sort(self.quality.__name__ + if self.quality else 'n_spikes') if self.similarity: self.set_similarity_func(self.similarity) gui.add_view(self.similarity_view, name='SimilarityView') From 4b34f4d4d68e75cb99c7f4cd9cffd14fe40f8b47 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 22 Jan 2016 17:14:40 +0100 Subject: [PATCH 0941/1059] Better trace detrending --- phy/cluster/manual/tests/test_views.py | 17 ++++------------ phy/cluster/manual/views.py | 27 +++++++++++--------------- 2 files changed, 15 insertions(+), 29 deletions(-) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index 59c911407..241d1c8d9 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -208,14 +208,6 @@ def most_similar_clusters(cluster_id): mean_masked_features_score) model.most_similar_clusters = most_similar_clusters - # Traces. - # ------------------------------------------------------------------------- - - def mean_traces(): - mt = all_traces.mean(axis=0) - return mt.astype(all_traces.dtype) - model.mean_traces = mean_traces - return model @@ -286,16 +278,16 @@ def test_extract_wave(): hwl = wave_len // 2 ae(_extract_wave(traces, 0 - hwl, 0 + hwl, mask, wave_len)[0], - [[0, 0, 0], [0, 0, 0], [1, 2, 3], [6, 7, 8]]) + [[0, 0], [0, 0], [1, 2], [6, 7]]) ae(_extract_wave(traces, 1 - hwl, 1 + hwl, mask, wave_len)[0], - [[0, 0, 0], [1, 2, 3], [6, 7, 8], [11, 12, 13]]) + [[0, 0], [1, 2], [6, 7], [11, 12]]) ae(_extract_wave(traces, 2 - hwl, 2 + hwl, mask, wave_len)[0], - [[1, 2, 3], [6, 7, 8], [11, 12, 13], [16, 17, 18]]) + [[1, 2], [6, 7], [11, 12], [16, 17]]) ae(_extract_wave(traces, 5 - hwl, 5 + hwl, mask, wave_len)[0], - [[16, 17, 18], [21, 22, 23], [0, 0, 0], [0, 0, 0]]) + [[16, 17], [21, 22], [0, 0], [0, 0]]) def test_selected_clusters_colors(): @@ -389,7 +381,6 @@ def test_trace_view(qtbot, gui): n_samples_per_spike=model.n_samples_waveforms, duration=model.duration, n_channels=model.n_channels, - mean_traces=model.mean_traces(), ) v.attach(gui) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index efc569424..7277c1dbd 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -55,9 +55,10 @@ def _selected_clusters_colors(n_clusters=None): def _extract_wave(traces, start, end, mask, wave_len=None): + mask_threshold = .5 n_samples, n_channels = traces.shape assert mask.shape == (n_channels,) - channels = np.nonzero(mask > .1)[0] + channels = np.nonzero(mask > mask_threshold)[0] # There should be at least one non-masked channel. if not len(channels): return # pragma: no cover @@ -101,7 +102,7 @@ def _get_color(masks, spike_clusters_rel=None, n_clusters=None, alpha=.5): color = colors[spike_clusters_rel] else: # Fixed color when the spike clusters are not specified. - color = .5 * np.ones((n_spikes, 3)) + color = np.ones((n_spikes, 3)) hsv = rgb_to_hsv(color[:, :3]) # Change the saturation and value as a function of the mask. hsv[:, 1] *= masks @@ -576,7 +577,6 @@ def __init__(self, duration=None, n_channels=None, n_samples_per_spike=None, - mean_traces=None, **kwargs): # traces is a function interval => {traces, spike_times, @@ -597,10 +597,6 @@ def __init__(self, assert n_channels >= 0 self.n_channels = n_channels - # Used to detrend the traces. - self.mean_traces = np.atleast_2d(mean_traces) - assert self.mean_traces.shape == (1, self.n_channels) - # Number of samples per spike. self.n_samples_per_spike = (n_samples_per_spike or round(.002 * sample_rate)) @@ -634,7 +630,7 @@ def __init__(self, def _plot_traces(self, traces, start=None, data_bounds=None): t = start + np.arange(traces.shape[0]) * self.dt - gray = .4 + gray = .3 for ch in range(self.n_channels): self[ch].plot(t, traces[:, ch], color=(gray, gray, gray, 1), @@ -666,16 +662,15 @@ def _plot_spike(self, spike_idx, start=None, # Determine the color as a function of the spike's cluster. clu = spike_clusters[spike_idx] if self.cluster_ids is None or clu not in self.cluster_ids: - gray = .9 - color = (gray, gray, gray, 1) + sc = None + n_clusters = None else: clu_rel = self.cluster_ids.index(clu) - r, g, b = (_COLORMAP[clu_rel % len(_COLORMAP)] / 255.) - color = (r, g, b, 1.) sc = clu_rel * np.ones(len(ch), dtype=np.int32) - color = _get_color(masks[spike_idx, ch], - spike_clusters_rel=sc, - n_clusters=len(self.cluster_ids)) + n_clusters = len(self.cluster_ids) + color = _get_color(masks[spike_idx, ch], + spike_clusters_rel=sc, + n_clusters=n_clusters) # Generate the x coordinates of the waveform. t0 = spike_times[spike_idx] + wave_start @@ -713,7 +708,7 @@ def set_interval(self, interval, change_status=True): # Load traces. d = self.traces(interval) - traces = d.traces - self.mean_traces + traces = d.traces - np.mean(d.traces, axis=0) spike_times = d.spike_times spike_clusters = d.spike_clusters masks = d.masks From e90b53ab93e7ead4d43dda71e62edcd38c353bbe Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 22 Jan 2016 17:21:13 +0100 Subject: [PATCH 0942/1059] Remove unused test code --- phy/cluster/manual/tests/test_views.py | 33 -------------------------- 1 file changed, 33 deletions(-) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index 241d1c8d9..9a1ef961c 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -183,31 +183,6 @@ def best_channels_multiple(cluster_ids): return bc model.best_channels_multiple = best_channels_multiple - def max_waveform_amplitude(cluster_id): - mm = mean_masks(cluster_id) - mw = mean_waveforms(cluster_id) - assert mw.ndim == 2 - return np.asscalar(get_max_waveform_amplitude(mm, mw)) - model.max_waveform_amplitude = max_waveform_amplitude - - def mean_masked_features_score(cluster_0, cluster_1): - mf0 = mean_features(cluster_0) - mf1 = mean_features(cluster_1) - mm0 = mean_masks(cluster_0) - mm1 = mean_masks(cluster_1) - nfpc = model.n_features_per_channel - d = get_mean_masked_features_distance(mf0, mf1, mm0, mm1, - n_features_per_channel=nfpc) - s = 1. / max(1e-10, d) - return s - model.mean_masked_features_score = mean_masked_features_score - - def most_similar_clusters(cluster_id): - assert isinstance(cluster_id, int) - return get_closest_clusters(cluster_id, model.cluster_ids, - mean_masked_features_score) - model.most_similar_clusters = most_similar_clusters - return model @@ -215,14 +190,6 @@ def most_similar_clusters(cluster_id): # Utils #------------------------------------------------------------------------------ -def _show(qtbot, view, stop=False): - view.show() - qtbot.waitForWindowShown(view.native) - if stop: # pragma: no cover - qtbot.stop() - view.close() - - @fixture def state(tempdir): # Save a test GUI state JSON file in the tempdir. From 5ff3cc7d2e3915ffe10be0044537864dc96faaff Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 22 Jan 2016 21:25:24 +0100 Subject: [PATCH 0943/1059] Flakify --- phy/cluster/manual/tests/test_views.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index 9a1ef961c..6e766656f 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -14,15 +14,13 @@ from phy.electrode.mea import staggered_positions from phy.gui import create_gui -from phy.io.array import _spikes_per_cluster, get_closest_clusters, _concat +from phy.io.array import _spikes_per_cluster, _concat from phy.io.mock import (artificial_waveforms, artificial_features, artificial_masks, artificial_traces, ) from phy.stats.clusters import (mean, - get_max_waveform_amplitude, - get_mean_masked_features_distance, get_unmasked_channels, get_sorted_main_channels, ) From f5b56faf55fbd056b01c5492136f30bf371a5bc6 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 25 Jan 2016 11:02:25 +0100 Subject: [PATCH 0944/1059] Remove register/request in GUI --- phy/cluster/manual/gui_component.py | 1 - phy/cluster/manual/tests/test_views.py | 2 -- phy/gui/gui.py | 25 +------------------------ phy/gui/tests/test_gui.py | 9 --------- 4 files changed, 1 insertion(+), 36 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 2ddf1cc5d..5fd906827 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -394,7 +394,6 @@ def on_cluster(self, up): def attach(self, gui): self.gui = gui - gui.register(self, name='manual_clustering') # Create the actions. self._create_actions(gui) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index 6e766656f..3adb3cf22 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -210,13 +210,11 @@ def gui(tempdir, state): mc = ManualClustering(model.spike_clusters, cluster_groups=model.cluster_groups,) mc.attach(gui) - gui.register(manual_clustering=mc) return gui def _select_clusters(gui): gui.show() - mc = gui.request('manual_clustering') assert mc mc.select([]) mc.select([0]) diff --git a/phy/gui/gui.py b/phy/gui/gui.py index 86e202c96..6f516c2c9 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -204,27 +204,6 @@ def connect_(self, *args, **kwargs): def unconnect_(self, *args, **kwargs): self._event.unconnect(*args, **kwargs) - def register(self, obj=None, name=None, **kwargs): - """Register a object for a given name.""" - for n, o in kwargs.items(): - self.register(o, n) - if obj is None: - return lambda _: self.register(obj=_, name=name) - name = name or obj.__name__ - self._registered[name] = obj - - def request(self, name, *args, **kwargs): - """Request the result of a possibly registered object.""" - if name in self._registered: - obj = self._registered[name] - if hasattr(obj, '__call__'): - return obj(*args, **kwargs) - else: - return obj - else: - logger.debug("No registered object for `%s`.", name) - return None - def closeEvent(self, e): """Qt slot when the window is closed.""" if self._closed: @@ -431,8 +410,7 @@ def on_show(): gui.restore_geometry_state(gs) -def create_gui(name=None, subtitle=None, model=None, - plugins=None, **state_kwargs): +def create_gui(name=None, subtitle=None, plugins=None, **state_kwargs): """Create a GUI with a list of plugins. By default, the list of plugins is taken from the `c.TheGUI.plugins` @@ -448,7 +426,6 @@ def create_gui(name=None, subtitle=None, model=None, # Make the state and model accessible. gui.state = state - gui.model = model # If no plugins are specified, load the master config and # get the list of user plugins to attach to the GUI. diff --git a/phy/gui/tests/test_gui.py b/phy/gui/tests/test_gui.py index 166419e9b..071154721 100644 --- a/phy/gui/tests/test_gui.py +++ b/phy/gui/tests/test_gui.py @@ -102,15 +102,6 @@ def on_close_view(view): gui.default_actions.exit() -def test_gui_register(gui): - @gui.register - def hello(msg): - return 'hello ' + msg - - assert gui.request('hello', 'world') == 'hello world' - assert gui.request('unknown') is None - - def test_gui_status_message(gui): assert gui.status_message == '' gui.status_message = ':hello world!' From e1a71da15cda4c79b239ba984876a288094d3f5f Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 25 Jan 2016 11:14:32 +0100 Subject: [PATCH 0945/1059] Fix --- phy/cluster/manual/tests/test_views.py | 5 ++++- phy/gui/gui.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index 3adb3cf22..a1cb72210 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -206,15 +206,18 @@ def state(tempdir): @fixture def gui(tempdir, state): model = create_model() - gui = create_gui(model=model, config_dir=tempdir, **state) + gui = create_gui(config_dir=tempdir, **state) mc = ManualClustering(model.spike_clusters, cluster_groups=model.cluster_groups,) mc.attach(gui) + gui.model = model + gui.manual_clustering = mc return gui def _select_clusters(gui): gui.show() + mc = gui.manual_clustering assert mc mc.select([]) mc.select([0]) diff --git a/phy/gui/gui.py b/phy/gui/gui.py index 6f516c2c9..9d8639758 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -424,7 +424,7 @@ def create_gui(name=None, subtitle=None, plugins=None, **state_kwargs): # Create the state. state = GUIState(gui.name, **state_kwargs) - # Make the state and model accessible. + # Make the state. gui.state = state # If no plugins are specified, load the master config and From 2fad72495438b047dabce2ca42d6b2d21e011ac5 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 25 Jan 2016 12:11:46 +0100 Subject: [PATCH 0946/1059] WIP: refactor spikes_per_cluster --- phy/cluster/manual/_utils.py | 2 - phy/cluster/manual/clustering.py | 60 ++----------------- phy/cluster/manual/gui_component.py | 6 +- phy/cluster/manual/tests/test_clustering.py | 53 ++-------------- .../manual/tests/test_gui_component.py | 6 +- phy/cluster/manual/tests/test_views.py | 6 +- phy/io/array.py | 15 +---- phy/io/tests/test_array.py | 6 +- 8 files changed, 27 insertions(+), 127 deletions(-) diff --git a/phy/cluster/manual/_utils.py b/phy/cluster/manual/_utils.py index fd6107f8c..b1ec99c86 100644 --- a/phy/cluster/manual/_utils.py +++ b/phy/cluster/manual/_utils.py @@ -61,8 +61,6 @@ def __init__(self, **kwargs): descendants=[], # pairs of (old_cluster, new_cluster) metadata_changed=[], # clusters with changed metadata metadata_value=None, # new metadata value - old_spikes_per_cluster={}, # only for the affected clusters - new_spikes_per_cluster={}, # only for the affected clusters undo_state=None, # returned during an undo: it contains # information about the undone action ) diff --git a/phy/cluster/manual/clustering.py b/phy/cluster/manual/clustering.py index d5e01efc4..1948c0a8e 100644 --- a/phy/cluster/manual/clustering.py +++ b/phy/cluster/manual/clustering.py @@ -11,7 +11,6 @@ from phy.utils._types import _as_array, _is_array_like from phy.io.array import (_unique, _spikes_in_clusters, - _spikes_per_cluster, ) from ._utils import UpdateInfo from ._history import History @@ -84,9 +83,7 @@ def _extend_assignment(spike_ids, extended_spike_clusters)) -def _assign_update_info(spike_ids, - old_spike_clusters, old_spikes_per_cluster, - new_spike_clusters, new_spikes_per_cluster): +def _assign_update_info(spike_ids, old_spike_clusters, new_spike_clusters): old_clusters = _unique(old_spike_clusters) new_clusters = _unique(new_spike_clusters) descendants = list(set(zip(old_spike_clusters, @@ -96,8 +93,6 @@ def _assign_update_info(spike_ids, added=list(new_clusters), deleted=list(old_clusters), descendants=descendants, - old_spikes_per_cluster=old_spikes_per_cluster, - new_spikes_per_cluster=new_spikes_per_cluster, ) return update_info @@ -148,12 +143,6 @@ class Clustering(EventEmitter): information about the clusters. metadata_changed : list List of clusters with changed metadata (cluster group changes) - old_spikes_per_cluster : dict - Dictionary of `{cluster: spikes}` for the old clusters and - old clustering. - new_spikes_per_cluster : dict - Dictionary of `{cluster: spikes}` for the new clusters and - new clustering. """ @@ -164,8 +153,7 @@ def __init__(self, spike_clusters): self._spike_clusters = _as_array(spike_clusters) self._n_spikes = len(self._spike_clusters) self._spike_ids = np.arange(self._n_spikes).astype(np.int64) - # Create the spikes per cluster structure. - self._update_all_spikes_per_cluster() + self._new_cluster_id = self._spike_clusters.max() + 1 # Keep a copy of the original spike clusters assignment. self._spike_clusters_base = self._spike_clusters.copy() @@ -177,28 +165,17 @@ def reset(self): """ self._undo_stack.clear() self._spike_clusters = self._spike_clusters_base - self._update_all_spikes_per_cluster() + self._new_cluster_id = self._spike_clusters.max() + 1 @property def spike_clusters(self): """A n_spikes-long vector containing the cluster ids of all spikes.""" return self._spike_clusters - @property - def spikes_per_cluster(self): - """A dictionary `{cluster: spikes}`.""" - return self._spikes_per_cluster - @property def cluster_ids(self): """Ordered list of ids of all non-empty clusters.""" - return np.array(sorted(self._spikes_per_cluster)) - - @property - def spike_counts(self): - """Dictionary with the number of spikes in each cluster.""" - return {cluster: len(self._spikes_per_cluster[cluster]) - for cluster in self.cluster_ids} + return np.unique(self._spike_clusters) def new_cluster_id(self): """Generate a brand new cluster id. @@ -233,13 +210,6 @@ def spikes_in_clusters(self, clusters): # Actions #-------------------------------------------------------------------------- - def _update_all_spikes_per_cluster(self): - # Reset the new cluster id. - self._new_cluster_id = self._spike_clusters.max() + 1 - # Update the spikes_per_cluster dict. - self._spikes_per_cluster = _spikes_per_cluster(self._spike_clusters, - self._spike_ids) - def _do_assign(self, spike_ids, new_spike_clusters): """Make spike-cluster assignments after the spike selection has been extended to full clusters.""" @@ -265,19 +235,10 @@ def _do_assign(self, spike_ids, new_spike_clusters): if len(new_clusters) == 1: return self._do_merge(spike_ids, old_clusters, new_clusters[0]) - old_spikes_per_cluster = {cluster: self._spikes_per_cluster[cluster] - for cluster in old_clusters} - new_spikes_per_cluster = _spikes_per_cluster(new_spike_clusters, - spike_ids) - self._spikes_per_cluster.update(new_spikes_per_cluster) - # All old clusters are deleted. - for cluster in old_clusters: - del self._spikes_per_cluster[cluster] - # We return the UpdateInfo structure. up = _assign_update_info(spike_ids, - old_spike_clusters, old_spikes_per_cluster, - new_spike_clusters, new_spikes_per_cluster) + old_spike_clusters, + new_spike_clusters) # We update the new cluster id (strictly increasing during a session). self._new_cluster_id = max(self._new_cluster_id, max(up.added) + 1) @@ -290,22 +251,13 @@ def _do_merge(self, spike_ids, cluster_ids, to): # Create the UpdateInfo instance here. descendants = [(cluster, to) for cluster in cluster_ids] - old_spc = {k: self._spikes_per_cluster[k] for k in cluster_ids} - new_spc = {to: spike_ids} up = UpdateInfo(description='merge', spike_ids=spike_ids, added=[to], deleted=list(cluster_ids), descendants=descendants, - old_spikes_per_cluster=old_spc, - new_spikes_per_cluster=new_spc, ) - # Update the spikes_per_cluster structure directly. - self._spikes_per_cluster[to] = spike_ids - for cluster in cluster_ids: - del self._spikes_per_cluster[cluster] - # We update the new cluster id (strictly increasing during a session). self._new_cluster_id = max(max(up.added) + 1, self._new_cluster_id) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 5fd906827..1fcc48499 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -117,6 +117,7 @@ class ManualClustering(object): def __init__(self, spike_clusters, + spikes_per_cluster, cluster_groups=None, shortcuts=None, quality=None, @@ -127,6 +128,9 @@ def __init__(self, self.quality = quality self.similarity = similarity + assert hasattr(spikes_per_cluster, '__call__') + self.spikes_per_cluster = spikes_per_cluster + # Load default shortcuts, and override with any user shortcuts. self.shortcuts = self.default_shortcuts.copy() self.shortcuts.update(shortcuts or {}) @@ -179,7 +183,7 @@ def _add_default_columns(self): # Default columns. @self.add_column(name='n_spikes') def n_spikes(cluster_id): - return self.clustering.spike_counts[cluster_id] + return len(self.spikes_per_cluster(cluster_id)) def skip(cluster_id): """Whether to skip that cluster.""" diff --git a/phy/cluster/manual/tests/test_clustering.py b/phy/cluster/manual/tests/test_clustering.py index 792fdecbe..86124dc67 100644 --- a/phy/cluster/manual/tests/test_clustering.py +++ b/phy/cluster/manual/tests/test_clustering.py @@ -9,7 +9,6 @@ import numpy as np from numpy.testing import assert_array_equal as ae from pytest import raises -from six import itervalues from phy.io.mock import artificial_spike_clusters from phy.io.array import (_spikes_in_clusters,) @@ -105,25 +104,6 @@ def test_extend_assignment(): # Test clustering #------------------------------------------------------------------------------ -def _flatten_spikes_per_cluster(spikes_per_cluster): - """Convert a dictionary {cluster: list_of_spikes} to a - spike_clusters array.""" - clusters = sorted(spikes_per_cluster) - clusters_arr = np.concatenate([(cluster * - np.ones(len(spikes_per_cluster[cluster]))) - for cluster in clusters]).astype(np.int64) - spikes_arr = np.concatenate([spikes_per_cluster[cluster] - for cluster in clusters]) - spike_clusters = np.vstack((spikes_arr, clusters_arr)) - ind = np.argsort(spike_clusters[0, :]) - return spike_clusters[1, ind] - - -def _check_spikes_per_cluster(clustering): - ae(_flatten_spikes_per_cluster(clustering.spikes_per_cluster), - clustering.spike_clusters) - - def test_clustering_split(): spike_clusters = np.array([2, 5, 3, 2, 7, 5, 2]) @@ -155,13 +135,11 @@ def test_clustering_split(): for to_split in splits: clustering.reset() clustering.split(to_split) - _check_spikes_per_cluster(clustering) # Test many splits, without reset this time. clustering.reset() for to_split in splits: clustering.split(to_split) - _check_spikes_per_cluster(clustering) def test_clustering_descendants_merge(): @@ -178,7 +156,6 @@ def test_clustering_descendants_merge(): new = up.added[0] assert new == 8 assert up.descendants == [(2, 8), (3, 8)] - _check_spikes_per_cluster(clustering) with raises(ValueError): up = clustering.merge([2, 8]) @@ -187,7 +164,6 @@ def test_clustering_descendants_merge(): new = up.added[0] assert new == 9 assert up.descendants == [(5, 9), (8, 9)] - _check_spikes_per_cluster(clustering) def test_clustering_descendants_split(): @@ -207,7 +183,6 @@ def test_clustering_descendants_split(): assert up.added == [8, 9] assert up.descendants == [(2, 8), (2, 9)] ae(clustering.spike_clusters, [8, 5, 3, 9, 7, 5, 9]) - _check_spikes_per_cluster(clustering) # Undo. up = clustering.undo() @@ -215,7 +190,6 @@ def test_clustering_descendants_split(): assert up.added == [2] assert set(up.descendants) == set([(8, 2), (9, 2)]) ae(clustering.spike_clusters, spike_clusters) - _check_spikes_per_cluster(clustering) # Redo. up = clustering.redo() @@ -223,7 +197,6 @@ def test_clustering_descendants_split(): assert up.added == [8, 9] assert up.descendants == [(2, 8), (2, 9)] ae(clustering.spike_clusters, [8, 5, 3, 9, 7, 5, 9]) - _check_spikes_per_cluster(clustering) # Second split: just replace cluster 8 by 10 (1 spike in it). up = clustering.split([0]) @@ -231,7 +204,6 @@ def test_clustering_descendants_split(): assert up.added == [10] assert up.descendants == [(8, 10)] ae(clustering.spike_clusters, [10, 5, 3, 9, 7, 5, 9]) - _check_spikes_per_cluster(clustering) # Undo again. up = clustering.undo() @@ -239,7 +211,6 @@ def test_clustering_descendants_split(): assert up.added == [8] assert up.descendants == [(10, 8)] ae(clustering.spike_clusters, [8, 5, 3, 9, 7, 5, 9]) - _check_spikes_per_cluster(clustering) def test_clustering_merge(): @@ -456,67 +427,53 @@ def test_clustering_long(): assert clustering.new_cluster_id() == n_clusters assert clustering.n_clusters == n_clusters - assert len(clustering.spike_counts) == n_clusters - assert sum(itervalues(clustering.spike_counts)) == n_spikes - _check_spikes_per_cluster(clustering) - # Updating a cluster, method 1. spike_clusters_new = spike_clusters.copy() spike_clusters_new[:10] = 100 clustering.spike_clusters[:] = spike_clusters_new[:] # Need to update explicitely. - clustering._update_all_spikes_per_cluster() + clustering._new_cluster_id = 101 ae(clustering.cluster_ids, np.r_[np.arange(n_clusters), 100]) # Updating a cluster, method 2. clustering.spike_clusters[:] = spike_clusters_base[:] clustering.spike_clusters[:10] = 100 - # Need to update manually. - clustering._update_all_spikes_per_cluster() + # HACK: need to update manually here. + clustering._new_cluster_id = 101 ae(clustering.cluster_ids, np.r_[np.arange(n_clusters), 100]) # Assign. new_cluster = 101 clustering.assign(np.arange(0, 10), new_cluster) assert new_cluster in clustering.cluster_ids - assert clustering.spike_counts[new_cluster] == 10 assert np.all(clustering.spike_clusters[:10] == new_cluster) - _check_spikes_per_cluster(clustering) # Merge. - count = clustering.spike_counts.copy() my_spikes_0 = np.nonzero(np.in1d(clustering.spike_clusters, [2, 3]))[0] info = clustering.merge([2, 3]) my_spikes = info.spike_ids ae(my_spikes, my_spikes_0) assert (new_cluster + 1) in clustering.cluster_ids - assert clustering.spike_counts[new_cluster + 1] == count[2] + count[3] assert np.all(clustering.spike_clusters[my_spikes] == (new_cluster + 1)) - _check_spikes_per_cluster(clustering) # Merge to a given cluster. clustering.spike_clusters[:] = spike_clusters_base[:] - clustering._update_all_spikes_per_cluster() + clustering._new_cluster_id = 11 + my_spikes_0 = np.nonzero(np.in1d(clustering.spike_clusters, [4, 6]))[0] - count = clustering.spike_counts - count4, count6 = count[4], count[6] info = clustering.merge([4, 6], 11) my_spikes = info.spike_ids ae(my_spikes, my_spikes_0) assert 11 in clustering.cluster_ids - assert clustering.spike_counts[11] == count4 + count6 assert np.all(clustering.spike_clusters[my_spikes] == 11) - _check_spikes_per_cluster(clustering) # Split. my_spikes = [1, 3, 5] clustering.split(my_spikes) assert np.all(clustering.spike_clusters[my_spikes] == 12) - _check_spikes_per_cluster(clustering) # Assign. clusters = [0, 1, 2] clustering.assign(my_spikes, clusters) clu = clustering.spike_clusters[my_spikes] ae(clu - clu[0], clusters) - _check_spikes_per_cluster(clustering) diff --git a/phy/cluster/manual/tests/test_gui_component.py b/phy/cluster/manual/tests/test_gui_component.py index 60ea36017..fffde618d 100644 --- a/phy/cluster/manual/tests/test_gui_component.py +++ b/phy/cluster/manual/tests/test_gui_component.py @@ -12,6 +12,7 @@ from ..gui_component import (ManualClustering, ) +from phy.io.array import _spikes_in_clusters from phy.gui import GUI from phy.utils import Bunch @@ -37,8 +38,10 @@ def gui(qtbot): def manual_clustering(qtbot, gui, cluster_ids, cluster_groups, quality, similarity): spike_clusters = np.array(cluster_ids) + spikes_per_cluster = lambda c: [c] mc = ManualClustering(spike_clusters, + spikes_per_cluster, cluster_groups=cluster_groups, shortcuts={'undo': 'ctrl+z'}, quality=quality, @@ -132,7 +135,8 @@ def test_manual_clustering_split(manual_clustering): def test_manual_clustering_split_2(gui, quality, similarity): spike_clusters = np.array([0, 0, 1]) - mc = ManualClustering(spike_clusters) + mc = ManualClustering(spike_clusters, + lambda c: _spikes_in_clusters(spike_clusters, [c])) mc.attach(gui) mc.add_column(quality, name='quality', default=True) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index a1cb72210..e081afda1 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -14,7 +14,7 @@ from phy.electrode.mea import staggered_positions from phy.gui import create_gui -from phy.io.array import _spikes_per_cluster, _concat +from phy.io.array import _spikes_in_clusters, _concat from phy.io.mock import (artificial_waveforms, artificial_features, artificial_masks, @@ -83,7 +83,8 @@ def traces(interval): ) model.traces = traces - model.spikes_per_cluster = _spikes_per_cluster(model.spike_clusters) + sc = model.spike_clusters + model.spikes_per_cluster = lambda c: _spikes_in_clusters(sc, [c]) model.n_features_per_channel = n_features_per_channel model.n_samples_waveforms = n_samples_waveforms model.cluster_groups = {c: None for c in range(n_clusters)} @@ -208,6 +209,7 @@ def gui(tempdir, state): model = create_model() gui = create_gui(config_dir=tempdir, **state) mc = ManualClustering(model.spike_clusters, + model.spikes_per_cluster, cluster_groups=model.cluster_groups,) mc.attach(gui) gui.model = model diff --git a/phy/io/array.py b/phy/io/array.py index 77ef6695c..04e683b38 100644 --- a/phy/io/array.py +++ b/phy/io/array.py @@ -458,17 +458,9 @@ def select_spikes(cluster_ids=None, class Selector(object): """This object is passed with the `select` event when clusters are selected. It allows to make selections of spikes.""" - def __init__(self, - spike_clusters=None, - spikes_per_cluster=None, - spike_ids=None, - ): + def __init__(self, spikes_per_cluster): # NOTE: spikes_per_cluster is a function. - self.spike_clusters = spike_clusters self.spikes_per_cluster = spikes_per_cluster - self.n_spikes = len(spike_clusters) - self.spike_ids = (np.asarray(spike_ids) if spike_ids is not None - else None) def select_spikes(self, cluster_ids=None, max_n_spikes_per_cluster=None): @@ -476,11 +468,6 @@ def select_spikes(self, cluster_ids=None, return None ns = max_n_spikes_per_cluster assert len(cluster_ids) >= 1 - # Select all spikes belonging to the cluster. - if ns is None: - spikes_rel = _spikes_in_clusters(self.spike_clusters, cluster_ids) - return (self.spike_ids[spikes_rel] - if self.spike_ids is not None else spikes_rel) # Select a subset of the spikes. return select_spikes(cluster_ids, spikes_per_cluster=self.spikes_per_cluster, diff --git a/phy/io/tests/test_array.py b/phy/io/tests/test_array.py index c57e1d548..e2dce63da 100644 --- a/phy/io/tests/test_array.py +++ b/phy/io/tests/test_array.py @@ -356,7 +356,6 @@ def test_select_spikes(): with raises(AssertionError): select_spikes() spikes = [2, 3, 5, 7, 11] - sc = [2, 3, 3, 2, 2] spc = lambda c: {2: [2, 7, 11], 3: [3, 5], 5: []}.get(c, None) ae(select_spikes([], spikes_per_cluster=spc), []) ae(select_spikes([2, 3, 5], spikes_per_cluster=spc), spikes) @@ -367,10 +366,7 @@ def test_select_spikes(): ae(select_spikes([2, 3, 5], 1, spikes_per_cluster=spc), [2, 3]) ae(select_spikes([2, 5], 2, spikes_per_cluster=spc), [2]) - sel = Selector(spike_clusters=sc, - spikes_per_cluster=spc, - spike_ids=spikes, - ) + sel = Selector(spc) assert sel.select_spikes() is None ae(sel.select_spikes([2, 5]), spc(2)) ae(sel.select_spikes([2, 5], 2), [2]) From ddea07c55db4dcad4035f3121f7525eb5358284f Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 25 Jan 2016 16:30:42 +0100 Subject: [PATCH 0947/1059] WIP: rename _concat --- phy/cluster/manual/gui_component.py | 3 +++ phy/cluster/manual/tests/test_views.py | 8 ++++---- phy/io/array.py | 2 +- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 1fcc48499..e9ebfc270 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -65,8 +65,11 @@ class ManualClustering(object): ---------- spike_clusters : ndarray + spikes_per_clusters : function `cluster_id -> spike_ids` cluster_groups : dictionary shortcuts : dict + quality: func + similarity: func GUI events ---------- diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index e081afda1..a35a1be6d 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -14,7 +14,7 @@ from phy.electrode.mea import staggered_positions from phy.gui import create_gui -from phy.io.array import _spikes_in_clusters, _concat +from phy.io.array import _spikes_in_clusters, concat_per_cluster from phy.io.mock import (artificial_waveforms, artificial_features, artificial_masks, @@ -110,12 +110,12 @@ def _get_data(**kwargs): kwargs['spike_clusters'] = model.spike_clusters[kwargs['spike_ids']] return Bunch(**kwargs) - @_concat + @concat_per_cluster def masks(cluster_id): return _get_data(spike_ids=get_spike_ids(cluster_id), masks=get_masks(model.n_spikes_per_cluster)) - @_concat + @concat_per_cluster def features(cluster_id): return _get_data(spike_ids=get_spike_ids(cluster_id), features=get_features(model.n_spikes_per_cluster), @@ -139,7 +139,7 @@ def waveform_lim(): return 1 model.waveform_lim = waveform_lim - @_concat + @concat_per_cluster def waveforms(cluster_id): w = get_waveforms(model.n_spikes_per_cluster) m = get_masks(model.n_spikes_per_cluster) diff --git a/phy/io/array.py b/phy/io/array.py index 04e683b38..4dc7502d7 100644 --- a/phy/io/array.py +++ b/phy/io/array.py @@ -210,7 +210,7 @@ def get_closest_clusters(cluster_id, cluster_ids, sim_func, max_n=None): return l[:max_n] -def _concat(f): +def concat_per_cluster(f): """Take a function accepting a single cluster, and return a function accepting multiple clusters.""" @wraps(f) From f1e683fa1baf69eaf5beb8b1edcc565c445742fc Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 26 Jan 2016 12:01:27 +0100 Subject: [PATCH 0948/1059] WIP: refactor trace view --- phy/cluster/manual/tests/test_views.py | 8 +++--- phy/cluster/manual/views.py | 39 +++++++++++++------------- 2 files changed, 24 insertions(+), 23 deletions(-) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index a35a1be6d..c17e025fc 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -245,16 +245,16 @@ def test_extract_wave(): wave_len = 4 hwl = wave_len // 2 - ae(_extract_wave(traces, 0 - hwl, 0 + hwl, mask, wave_len)[0], + ae(_extract_wave(traces, 0 - hwl, mask, wave_len)[0], [[0, 0], [0, 0], [1, 2], [6, 7]]) - ae(_extract_wave(traces, 1 - hwl, 1 + hwl, mask, wave_len)[0], + ae(_extract_wave(traces, 1 - hwl, mask, wave_len)[0], [[0, 0], [1, 2], [6, 7], [11, 12]]) - ae(_extract_wave(traces, 2 - hwl, 2 + hwl, mask, wave_len)[0], + ae(_extract_wave(traces, 2 - hwl, mask, wave_len)[0], [[1, 2], [6, 7], [11, 12], [16, 17]]) - ae(_extract_wave(traces, 5 - hwl, 5 + hwl, mask, wave_len)[0], + ae(_extract_wave(traces, 5 - hwl, mask, wave_len)[0], [[16, 17], [21, 22], [0, 0], [0, 0]]) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 7277c1dbd..be0b3ad0b 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -54,7 +54,7 @@ def _selected_clusters_colors(n_clusters=None): return colors[:n_clusters, ...] / 255. -def _extract_wave(traces, start, end, mask, wave_len=None): +def _extract_wave(traces, start, mask, wave_len=None): mask_threshold = .5 n_samples, n_channels = traces.shape assert mask.shape == (n_channels,) @@ -62,7 +62,7 @@ def _extract_wave(traces, start, end, mask, wave_len=None): # There should be at least one non-masked channel. if not len(channels): return # pragma: no cover - i, j = start, end + i, j = start, start + wave_len a, b = max(0, i), min(j, n_samples - 1) data = traces[a:b, channels] data = _get_padded(data, i - a, i - a + wave_len) @@ -604,7 +604,7 @@ def __init__(self, # Can be a tuple or a scalar. if not isinstance(self.n_samples_per_spike, tuple): ns = self.n_samples_per_spike - self.n_samples_per_spike = (-ns // 2, ns // 2) + self.n_samples_per_spike = (ns // 2, ns // 2) # Now n_samples_per_spike is a tuple. # Box and probe scaling. @@ -640,25 +640,27 @@ def _plot_spike(self, spike_idx, start=None, traces=None, spike_times=None, spike_clusters=None, masks=None, data_bounds=None): + sr = self.sample_rate wave_len = sum(map(abs, self.n_samples_per_spike)) # in samples - dur_spike = wave_len * self.dt # in seconds wave_start = self.n_samples_per_spike[0] * self.dt # in seconds - - trace_start = round(self.sample_rate * start) + trace_start = round(sr * start) # Find the first x of the spike, relative to the start of # the interval - spike_start = spike_times[spike_idx] + wave_start - spike_end = spike_times[spike_idx] + wave_start + dur_spike - sample_start = (round(spike_start * self.sample_rate) - - trace_start) - sample_end = (round(spike_end * self.sample_rate) - - trace_start) + spike_start = spike_times[spike_idx] - wave_start # in seconds + sample_start = round(spike_start * sr) - trace_start # Extract the waveform from the traces. - w, ch = _extract_wave(traces, sample_start, sample_end, + w, ch = _extract_wave(traces, sample_start, masks[spike_idx], wave_len) + # w: (n_samples, n_unmasked_channels) + # ch: (n_unmasked_channels,) with the channel indices + # spike_start (abs in sec) + # n_samples_per_spike (bef > 0, aft > 0) + # color: int (cluster rel) or (rgba) + # data_bounds, start (of the traces subset, in seconds) + # Determine the color as a function of the spike's cluster. clu = spike_clusters[spike_idx] if self.cluster_ids is None or clu not in self.cluster_ids: @@ -673,9 +675,8 @@ def _plot_spike(self, spike_idx, start=None, n_clusters=n_clusters) # Generate the x coordinates of the waveform. - t0 = spike_times[spike_idx] + wave_start - t = t0 + self.dt * np.arange(wave_len) - t = np.tile(t, (len(ch), 1)) + t = spike_start + self.dt * np.arange(wave_len) + t = np.tile(t, (len(ch), 1)) # (n_unmasked_channels, n_samples) # The box index depends on the channel. box_index = np.repeat(ch[:, np.newaxis], wave_len, axis=0) @@ -709,9 +710,9 @@ def set_interval(self, interval, change_status=True): # Load traces. d = self.traces(interval) traces = d.traces - np.mean(d.traces, axis=0) - spike_times = d.spike_times - spike_clusters = d.spike_clusters - masks = d.masks + spike_times = d.spike_times # (n,) + spike_clusters = d.spike_clusters # (n,) + masks = d.masks # (n, n_channels) # NOTE: once loaded, the traces do not contain the dead channels # so there are `n_channels_order` channels here. From 292a0728f6be2a2d17d6ef172869c7c5cf48aea7 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 26 Jan 2016 13:08:32 +0100 Subject: [PATCH 0949/1059] WIP --- phy/cluster/manual/views.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index be0b3ad0b..2bdf911c9 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -651,8 +651,7 @@ def _plot_spike(self, spike_idx, start=None, sample_start = round(spike_start * sr) - trace_start # Extract the waveform from the traces. - w, ch = _extract_wave(traces, sample_start, - masks[spike_idx], wave_len) + w, ch = _extract_wave(traces, sample_start, masks[spike_idx], wave_len) # w: (n_samples, n_unmasked_channels) # ch: (n_unmasked_channels,) with the channel indices From ff4c15fe196fc436bff2370bd6e715a57c235983 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 26 Jan 2016 17:28:54 +0100 Subject: [PATCH 0950/1059] WIP: update trace view interface --- phy/cluster/manual/tests/test_views.py | 67 +++++++++--- phy/cluster/manual/views.py | 142 ++++++++++--------------- 2 files changed, 106 insertions(+), 103 deletions(-) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index c17e025fc..b38e885f1 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -48,8 +48,8 @@ def create_model(): n_samples_waveforms = 31 n_samples_t = 20000 n_channels = 11 - n_clusters = 3 - model.n_spikes_per_cluster = 51 + n_clusters = 4 + model.n_spikes_per_cluster = 50 n_spikes_total = n_clusters * model.n_spikes_per_cluster n_features_per_channel = 4 @@ -57,12 +57,19 @@ def create_model(): model.n_spikes = n_spikes_total model.sample_rate = 20000. model.duration = n_samples_t / float(model.sample_rate) - model.spike_times = np.linspace(0., model.duration, n_spikes_total) + model.spike_times = np.arange(0, model.duration, 100. / model.sample_rate) model.spike_clusters = np.repeat(np.arange(n_clusters), model.n_spikes_per_cluster) + assert len(model.spike_times) == len(model.spike_clusters) model.cluster_ids = np.unique(model.spike_clusters) model.channel_positions = staggered_positions(n_channels) + sc = model.spike_clusters + model.spikes_per_cluster = lambda c: _spikes_in_clusters(sc, [c]) + model.n_features_per_channel = n_features_per_channel + model.n_samples_waveforms = n_samples_waveforms + model.cluster_groups = {c: None for c in range(n_clusters)} + all_traces = artificial_traces(n_samples_t, n_channels) all_masks = artificial_masks(n_spikes_total, n_channels) @@ -71,23 +78,50 @@ def traces(interval): tr = select_traces(all_traces, interval, sample_rate=model.sample_rate, ) + return [tr] + model.traces = traces + + def spikes_traces(interval): + # TODO OPTIM: we're loading the traces twice (model.traces and here) + traces = model.traces(interval)[0] + + sr = model.sample_rate + ns = model.n_samples_waveforms + if not isinstance(ns, tuple): + ns = (ns // 2, ns // 2) + offset_samples = ns[0] + wave_len = ns[0] + ns[1] + # Find spikes. a, b = model.spike_times.searchsorted(interval) st = model.spike_times[a:b] sc = model.spike_clusters[a:b] m = all_masks[a:b, :] - return Bunch(traces=tr, - spike_times=st, - spike_clusters=sc, - masks=m, - ) - model.traces = traces - - sc = model.spike_clusters - model.spikes_per_cluster = lambda c: _spikes_in_clusters(sc, [c]) - model.n_features_per_channel = n_features_per_channel - model.n_samples_waveforms = n_samples_waveforms - model.cluster_groups = {c: None for c in range(n_clusters)} + n = len(st) + assert len(sc) == n + assert m.shape[0] == n + + # Extract waveforms. + spikes = [] + for i in range(n): + b = Bunch() + # Find the start of the waveform in the extracted traces. + sample_start = int(round((st[i] - interval[0]) * sr)) + sample_start -= offset_samples + b.waveforms, b.channels = _extract_wave(traces, + sample_start, + m[i], + wave_len) + # Masks on unmasked channels. + b.masks = m[i, b.channels] + b.spike_time = st[i] + b.spike_cluster = sc[i] + b.offset_samples = offset_samples + + spikes.append(b) + return spikes + + model.spikes_traces = spikes_traces def get_waveforms(n): return artificial_waveforms(n, @@ -344,9 +378,10 @@ def on_channel_click(channel_idx=None, button=None, key=None): def test_trace_view(qtbot, gui): model = gui.model + v = TraceView(traces=model.traces, + spikes=model.spikes_traces, sample_rate=model.sample_rate, - n_samples_per_spike=model.n_samples_waveforms, duration=model.duration, n_channels=model.n_channels, ) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 2bdf911c9..a5bbe051d 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -54,8 +54,7 @@ def _selected_clusters_colors(n_clusters=None): return colors[:n_clusters, ...] / 255. -def _extract_wave(traces, start, mask, wave_len=None): - mask_threshold = .5 +def _extract_wave(traces, start, mask, wave_len=None, mask_threshold=.5): n_samples, n_channels = traces.shape assert mask.shape == (n_channels,) channels = np.nonzero(mask > mask_threshold)[0] @@ -573,23 +572,25 @@ class TraceView(ManualClusteringView): def __init__(self, traces=None, + spikes=None, sample_rate=None, duration=None, n_channels=None, - n_samples_per_spike=None, **kwargs): - # traces is a function interval => {traces, spike_times, - # spike_clusters, masks} + # traces is a function interval => [traces] + # spikes is a function interval => [Bunch(...)] # Sample rate. assert sample_rate > 0 self.sample_rate = sample_rate self.dt = 1. / self.sample_rate - # Traces. + # Traces and spikes. assert hasattr(traces, '__call__') self.traces = traces + assert hasattr(spikes, '__call__') + self.spikes = spikes assert duration >= 0 self.duration = duration @@ -597,20 +598,16 @@ def __init__(self, assert n_channels >= 0 self.n_channels = n_channels - # Number of samples per spike. - self.n_samples_per_spike = (n_samples_per_spike or - round(.002 * sample_rate)) - - # Can be a tuple or a scalar. - if not isinstance(self.n_samples_per_spike, tuple): - ns = self.n_samples_per_spike - self.n_samples_per_spike = (ns // 2, ns // 2) - # Now n_samples_per_spike is a tuple. - # Box and probe scaling. self._scaling = 1. self._origin = None + # Default data bounds. + # TODO: better way of finding data bounds for the traces. + tr = traces((0, 1))[0] + m, M = tr.min(), tr.max() + self.data_bounds = np.array([0, m, 1, M]) + # Initialize the view. super(TraceView, self).__init__(layout='stacked', origin=self.origin, @@ -628,59 +625,48 @@ def __init__(self, # Internal methods # ------------------------------------------------------------------------- - def _plot_traces(self, traces, start=None, data_bounds=None): - t = start + np.arange(traces.shape[0]) * self.dt + def _plot_traces(self, traces): + assert traces.shape[1] == self.n_channels + t = self.interval[0] + np.arange(traces.shape[0]) * self.dt gray = .3 for ch in range(self.n_channels): self[ch].plot(t, traces[:, ch], color=(gray, gray, gray, 1), - data_bounds=data_bounds) - - def _plot_spike(self, spike_idx, start=None, - traces=None, spike_times=None, spike_clusters=None, - masks=None, data_bounds=None): - - sr = self.sample_rate - wave_len = sum(map(abs, self.n_samples_per_spike)) # in samples - wave_start = self.n_samples_per_spike[0] * self.dt # in seconds - trace_start = round(sr * start) + data_bounds=self.data_bounds) - # Find the first x of the spike, relative to the start of - # the interval - spike_start = spike_times[spike_idx] - wave_start # in seconds - sample_start = round(spike_start * sr) - trace_start + def _plot_spike(self, waveforms=None, channels=None, masks=None, + spike_time=None, spike_cluster=None, offset_samples=0, + color=None): - # Extract the waveform from the traces. - w, ch = _extract_wave(traces, sample_start, masks[spike_idx], wave_len) + n_samples, n_channels = waveforms.shape + assert len(channels) == n_channels + assert len(masks) == n_channels + sr = float(self.sample_rate) - # w: (n_samples, n_unmasked_channels) - # ch: (n_unmasked_channels,) with the channel indices - # spike_start (abs in sec) - # n_samples_per_spike (bef > 0, aft > 0) - # color: int (cluster rel) or (rgba) - # data_bounds, start (of the traces subset, in seconds) + t0 = spike_time - offset_samples / sr # Determine the color as a function of the spike's cluster. - clu = spike_clusters[spike_idx] - if self.cluster_ids is None or clu not in self.cluster_ids: - sc = None - n_clusters = None - else: - clu_rel = self.cluster_ids.index(clu) - sc = clu_rel * np.ones(len(ch), dtype=np.int32) - n_clusters = len(self.cluster_ids) - color = _get_color(masks[spike_idx, ch], - spike_clusters_rel=sc, - n_clusters=n_clusters) + if color is None: + clu = spike_cluster + if self.cluster_ids is None or clu not in self.cluster_ids: + sc = None + n_clusters = None + else: + clu_rel = self.cluster_ids.index(clu) + sc = clu_rel * np.ones(n_channels, dtype=np.int32) + n_clusters = len(self.cluster_ids) + color = _get_color(masks, + spike_clusters_rel=sc, + n_clusters=n_clusters) # Generate the x coordinates of the waveform. - t = spike_start + self.dt * np.arange(wave_len) - t = np.tile(t, (len(ch), 1)) # (n_unmasked_channels, n_samples) + t = t0 + self.dt * np.arange(n_samples) + t = np.tile(t, (n_channels, 1)) # (n_unmasked_channels, n_samples) # The box index depends on the channel. - box_index = np.repeat(ch[:, np.newaxis], wave_len, axis=0) - self.plot(t, w.T, color=color, box_index=box_index, - data_bounds=data_bounds) + box_index = np.repeat(channels[:, np.newaxis], n_samples, axis=0) + self.plot(t, waveforms.T, color=color, box_index=box_index, + data_bounds=self.data_bounds) def _restrict_interval(self, interval): start, end = interval @@ -705,42 +691,24 @@ def set_interval(self, interval, change_status=True): interval = self._restrict_interval(interval) self.interval = interval start, end = interval - - # Load traces. - d = self.traces(interval) - traces = d.traces - np.mean(d.traces, axis=0) - spike_times = d.spike_times # (n,) - spike_clusters = d.spike_clusters # (n,) - masks = d.masks # (n, n_channels) - - # NOTE: once loaded, the traces do not contain the dead channels - # so there are `n_channels_order` channels here. - assert traces.shape[1] == self.n_channels - + # Update the data bounds on the x axis. + self.data_bounds[0] = start + self.data_bounds[2] = end # Set the status message. if change_status: self.set_status('Interval: {:.3f} s - {:.3f} s'.format(start, end)) - # Determine the data bounds. - m, M = traces.min(), traces.max() - data_bounds = np.array([start, m, end, M]) - # Plot the traces. - # OPTIM: avoid the loop and generate all channel traces in - # one pass with NumPy (but need to set a_box_index manually too). - self._plot_traces(traces, start=start, data_bounds=data_bounds) - - # Display the spikes. - if spike_times is not None: - for i in range(len(spike_times)): - self._plot_spike(i, - start=start, - traces=traces, - spike_times=spike_times, - spike_clusters=spike_clusters, - masks=masks, - data_bounds=data_bounds, - ) + all_traces = self.traces(interval) + assert isinstance(all_traces, (tuple, list)) + for traces in all_traces: + self._plot_traces(traces) + + # Plot the spikes. + spikes = self.spikes(interval) + assert isinstance(spikes, (tuple, list)) + for spike in spikes: + self._plot_spike(**spike) self.build() self.update() From ed4a06f4c54c69d35a77485485ef0d2b3d01f1f0 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 26 Jan 2016 17:37:07 +0100 Subject: [PATCH 0951/1059] Refactor function to extract spikes from traces --- phy/cluster/manual/tests/test_views.py | 47 +++++--------------------- phy/cluster/manual/views.py | 41 ++++++++++++++++++++++ 2 files changed, 50 insertions(+), 38 deletions(-) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index b38e885f1..7ee50b198 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -33,6 +33,7 @@ ScatterView, select_traces, _extract_wave, + extract_spikes, _selected_clusters_colors, _extend, ) @@ -82,45 +83,15 @@ def traces(interval): model.traces = traces def spikes_traces(interval): - # TODO OPTIM: we're loading the traces twice (model.traces and here) traces = model.traces(interval)[0] - - sr = model.sample_rate - ns = model.n_samples_waveforms - if not isinstance(ns, tuple): - ns = (ns // 2, ns // 2) - offset_samples = ns[0] - wave_len = ns[0] + ns[1] - - # Find spikes. - a, b = model.spike_times.searchsorted(interval) - st = model.spike_times[a:b] - sc = model.spike_clusters[a:b] - m = all_masks[a:b, :] - n = len(st) - assert len(sc) == n - assert m.shape[0] == n - - # Extract waveforms. - spikes = [] - for i in range(n): - b = Bunch() - # Find the start of the waveform in the extracted traces. - sample_start = int(round((st[i] - interval[0]) * sr)) - sample_start -= offset_samples - b.waveforms, b.channels = _extract_wave(traces, - sample_start, - m[i], - wave_len) - # Masks on unmasked channels. - b.masks = m[i, b.channels] - b.spike_time = st[i] - b.spike_cluster = sc[i] - b.offset_samples = offset_samples - - spikes.append(b) - return spikes - + b = extract_spikes(traces, interval, + sample_rate=model.sample_rate, + spike_times=model.spike_times, + spike_clusters=model.spike_clusters, + all_masks=all_masks, + n_samples_waveforms=model.n_samples_waveforms, + ) + return b model.spikes_traces = spikes_traces def get_waveforms(n): diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index a5bbe051d..275c5b58d 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -557,6 +557,47 @@ def select_traces(traces, interval, sample_rate=None): return traces +def extract_spikes(traces, interval, sample_rate=None, + spike_times=None, spike_clusters=None, + all_masks=None, + n_samples_waveforms=None): + sr = sample_rate + ns = n_samples_waveforms + if not isinstance(ns, tuple): + ns = (ns // 2, ns // 2) + offset_samples = ns[0] + wave_len = ns[0] + ns[1] + + # Find spikes. + a, b = spike_times.searchsorted(interval) + st = spike_times[a:b] + sc = spike_clusters[a:b] + m = all_masks[a:b, :] + n = len(st) + assert len(sc) == n + assert m.shape[0] == n + + # Extract waveforms. + spikes = [] + for i in range(n): + b = Bunch() + # Find the start of the waveform in the extracted traces. + sample_start = int(round((st[i] - interval[0]) * sr)) + sample_start -= offset_samples + b.waveforms, b.channels = _extract_wave(traces, + sample_start, + m[i], + wave_len) + # Masks on unmasked channels. + b.masks = m[i, b.channels] + b.spike_time = st[i] + b.spike_cluster = sc[i] + b.offset_samples = offset_samples + + spikes.append(b) + return spikes + + class TraceView(ManualClusteringView): interval_duration = .5 # default duration of the interval shift_amount = .1 From ecaf3c8b17a6535e269072a181bfb0435a67d45f Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 26 Jan 2016 18:05:17 +0100 Subject: [PATCH 0952/1059] Fix --- phy/cluster/manual/views.py | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 275c5b58d..202f10a52 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -572,7 +572,7 @@ def extract_spikes(traces, interval, sample_rate=None, a, b = spike_times.searchsorted(interval) st = spike_times[a:b] sc = spike_clusters[a:b] - m = all_masks[a:b, :] + m = all_masks[a:b] n = len(st) assert len(sc) == n assert m.shape[0] == n @@ -643,12 +643,6 @@ def __init__(self, self._scaling = 1. self._origin = None - # Default data bounds. - # TODO: better way of finding data bounds for the traces. - tr = traces((0, 1))[0] - m, M = tr.min(), tr.max() - self.data_bounds = np.array([0, m, 1, M]) - # Initialize the view. super(TraceView, self).__init__(layout='stacked', origin=self.origin, @@ -669,11 +663,14 @@ def __init__(self, def _plot_traces(self, traces): assert traces.shape[1] == self.n_channels t = self.interval[0] + np.arange(traces.shape[0]) * self.dt + t = np.tile(t, (self.n_channels, 1)) gray = .3 - for ch in range(self.n_channels): - self[ch].plot(t, traces[:, ch], + channels = np.arange(self.n_channels) + for ch in channels: + self[ch].plot(t[ch, :], traces[:, ch], color=(gray, gray, gray, 1), - data_bounds=self.data_bounds) + data_bounds=self.data_bounds, + ) def _plot_spike(self, waveforms=None, channels=None, masks=None, spike_time=None, spike_cluster=None, offset_samples=0, @@ -732,16 +729,18 @@ def set_interval(self, interval, change_status=True): interval = self._restrict_interval(interval) self.interval = interval start, end = interval - # Update the data bounds on the x axis. - self.data_bounds[0] = start - self.data_bounds[2] = end # Set the status message. if change_status: self.set_status('Interval: {:.3f} s - {:.3f} s'.format(start, end)) - # Plot the traces. + # Load the traces. all_traces = self.traces(interval) assert isinstance(all_traces, (tuple, list)) + # Default data bounds. + m, M = all_traces[0].min(), all_traces[0].max() + self.data_bounds = np.array([start, m, end, M]) + + # Plot the traces. for traces in all_traces: self._plot_traces(traces) From cba8b0c5fc6268c8e21165ec2a3f2911b20050e1 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 26 Jan 2016 19:13:53 +0100 Subject: [PATCH 0953/1059] Update --- phy/cluster/manual/tests/test_views.py | 4 ++-- phy/cluster/manual/views.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index 7ee50b198..e25237279 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -82,8 +82,8 @@ def traces(interval): return [tr] model.traces = traces - def spikes_traces(interval): - traces = model.traces(interval)[0] + def spikes_traces(interval, traces): + traces = traces[0] b = extract_spikes(traces, interval, sample_rate=model.sample_rate, spike_times=model.spike_times, diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 202f10a52..fa80c0d10 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -745,7 +745,7 @@ def set_interval(self, interval, change_status=True): self._plot_traces(traces) # Plot the spikes. - spikes = self.spikes(interval) + spikes = self.spikes(interval, all_traces) assert isinstance(spikes, (tuple, list)) for spike in spikes: self._plot_spike(**spike) From 1fb782f295ba399a1aa8df074fc4603d0b6bf1be Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 27 Jan 2016 15:38:07 +0100 Subject: [PATCH 0954/1059] WIP: cache() and memcache() methods in Context --- phy/io/context.py | 53 +++++++++++++++++------------------- phy/io/tests/test_context.py | 5 ++-- 2 files changed, 28 insertions(+), 30 deletions(-) diff --git a/phy/io/context.py b/phy/io/context.py index 60bc92be3..71762d65a 100644 --- a/phy/io/context.py +++ b/phy/io/context.py @@ -186,41 +186,38 @@ def ipy_view(self, value): # Dill is necessary because we need to serialize closures. value.use_dill() - def cache(self, f=None, memcache=False): + def cache(self, f): """Cache a function using the context's cache directory.""" - if f is None: - return lambda _: self.cache(_, memcache=memcache) if self._memory is None: # pragma: no cover logger.debug("Joblib is not installed: skipping cacheing.") return f assert f disk_cached = self._memory.cache(f) + return disk_cached + + def memcache(self, f): + from joblib import hash name = _fullname(f) - if memcache: - from joblib import hash - # Create the cache dictionary for the function. - if name not in self._memcache: - self._memcache[name] = {} - - c = self._memcache[name] - - @wraps(f) - def mem_cached(*args, **kwargs): - """Cache the function in memory.""" - h = hash((args, kwargs)) - if h in c: - logger.debug("Get %s(%s) from memcache.", name, str(args)) - # Retrieve the value from the memcache. - return c[h] - else: - logger.debug("Get %s(%s) from joblib.", name, str(args)) - # Call and cache the function. - out = disk_cached(*args, **kwargs) - c[h] = out - return out - return mem_cached - else: - return disk_cached + # Create the cache dictionary for the function. + if name not in self._memcache: + self._memcache[name] = {} + c = self._memcache[name] + + @wraps(f) + def memcached(*args, **kwargs): + """Cache the function in memory.""" + h = hash((args, kwargs)) + if h in c: + logger.debug("Get %s(%s) from memcache.", name, str(args)) + # Retrieve the value from the memcache. + return c[h] + else: + logger.debug("Get %s(%s) from joblib.", name, str(args)) + # Call and cache the function. + out = f(*args, **kwargs) + c[h] = out + return out + return memcached def map_dask_array(self, func, da, *args, **kwargs): """Map a function on the chunks of a dask array, and return a diff --git a/phy/io/tests/test_context.py b/phy/io/tests/test_context.py index e1a4cb724..c87bcaa43 100644 --- a/phy/io/tests/test_context.py +++ b/phy/io/tests/test_context.py @@ -131,11 +131,12 @@ def f(x): assert len(_res) == 2 -def test_context_memmap(tempdir, context): +def test_context_memcache(tempdir, context): _res = [] - @context.cache(memcache=True) + @context.memcache + @context.cache def f(x): _res.append(x) return x ** 2 From 20b4a018bdcb5bc27b18e6511bf232ebc9172147 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 27 Jan 2016 15:39:09 +0100 Subject: [PATCH 0955/1059] Minor clean-up in gui_component --- phy/cluster/manual/gui_component.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index e9ebfc270..07beb22cf 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -149,7 +149,6 @@ def __init__(self, self._add_default_columns() self._similarity = {} - self.similarity_func = None # Internal methods # ------------------------------------------------------------------------- @@ -299,7 +298,7 @@ def _update_cluster_view(self): def _update_similarity_view(self): """Update the similarity view with matches for the specified clusters.""" - if not self.similarity_func: + if not self.similarity: return selection = self.cluster_view.selected if not len(selection): @@ -309,10 +308,11 @@ def _update_similarity_view(self): # This is a list of pairs (closest_cluster, similarity). self._best = cluster_id self._similarity = {int(cl): s - for (cl, s) in self.similarity_func(cluster_id)} - self.similarity_view.set_rows([int(c) - for c in self.clustering.cluster_ids - if c not in selection]) + for (cl, s) in self.similarity(cluster_id)} + clusters = list(map(int, self.clustering.cluster_ids)) + self.similarity_view.set_rows([c for c in clusters + if c not in selection and + c in self._similarity]) self.similarity_view.sort_by('similarity', 'desc') def _emit_select(self, cluster_ids): @@ -356,7 +356,7 @@ def set_similarity_func(self, f): """ logger.debug("Set similarity function `%s`.", f.__name__) - self.similarity_func = f + self.similarity = f def on_cluster(self, up): """Update the cluster views after clustering actions.""" From a99e5a07193ed820f4d4b3ad27d8357e0d601652 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 27 Jan 2016 15:53:17 +0100 Subject: [PATCH 0956/1059] WIP: clean up similarity in gui_component --- phy/cluster/manual/gui_component.py | 47 ++++++++----------- .../manual/tests/test_gui_component.py | 5 +- 2 files changed, 22 insertions(+), 30 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 07beb22cf..2cdc4a860 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -128,8 +128,8 @@ def __init__(self, ): self.gui = None - self.quality = quality - self.similarity = similarity + self.quality = quality # function cluster => quality + self.similarity = similarity # function cluster => [(cl, sim), ...] assert hasattr(spikes_per_cluster, '__call__') self.spikes_per_cluster = spikes_per_cluster @@ -148,7 +148,8 @@ def __init__(self, self._create_cluster_views() self._add_default_columns() - self._similarity = {} + self._best = None + self._current_similarity_values = {} # Internal methods # ------------------------------------------------------------------------- @@ -187,28 +188,25 @@ def _add_default_columns(self): def n_spikes(cluster_id): return len(self.spikes_per_cluster(cluster_id)) + @self.add_column(show=False) def skip(cluster_id): """Whether to skip that cluster.""" return (self.cluster_meta.get('group', cluster_id) in ('noise', 'mua')) - self.add_column(skip, show=False) + @self.add_column(show=False) def good(cluster_id): """Good column for color.""" return self.cluster_meta.get('group', cluster_id) == 'good' - self.add_column(good, show=False) - - self._best = None + @self.similarity_view.add_column def similarity(cluster_id): # NOTE: there is a dictionary with the similarity to the current # best cluster. It is updated when the selection changes in the # cluster view. This is a bit of a hack: the HTML table expects # a function that returns a value for every row, but here we - # cache all similarity view rows in self._similarity for - # performance reasons. - return self._similarity.get(cluster_id, 0) - self.similarity_view.add_column(similarity) + # cache all similarity view rows in self._current_similarity_values + return self._current_similarity_values.get(cluster_id, 0) def _create_actions(self, gui): self.actions = Actions(gui, @@ -304,15 +302,21 @@ def _update_similarity_view(self): if not len(selection): return cluster_id = selection[0] + self._best = cluster_id logger.log(5, "Update the similarity view.") # This is a list of pairs (closest_cluster, similarity). - self._best = cluster_id - self._similarity = {int(cl): s - for (cl, s) in self.similarity(cluster_id)} + similarities = self.similarity(cluster_id) + # We save the similarity values wrt the currently-selected clusters. + self._current_similarity_values = {int(cl): s + for (cl, s) in similarities} clusters = list(map(int, self.clustering.cluster_ids)) + # The similarity view will use these values. self.similarity_view.set_rows([c for c in clusters if c not in selection and - c in self._similarity]) + c in self._current_similarity_values]) + # The similarity name is always 'similarity' because we use + # a special function to retrieve the similarity values from the + # self._current_similarity_values dictionary. self.similarity_view.sort_by('similarity', 'desc') def _emit_select(self, cluster_ids): @@ -346,18 +350,6 @@ def set_default_sort(self, name, sort_dir='desc'): # Sort by the default sort. self.cluster_view.sort_by(name, sort_dir) - def set_similarity_func(self, f): - """Set the similarity function. - - This is a function that returns an ordered list of pairs - `(candidate, similarity)` for any given cluster. This list can have - a fixed number of elements for performance reasons (keeping the best - 20 candidates for example). - - """ - logger.debug("Set similarity function `%s`.", f.__name__) - self.similarity = f - def on_cluster(self, up): """Update the cluster views after clustering actions.""" @@ -415,7 +407,6 @@ def attach(self, gui): self.set_default_sort(self.quality.__name__ if self.quality else 'n_spikes') if self.similarity: - self.set_similarity_func(self.similarity) gui.add_view(self.similarity_view, name='SimilarityView') # Update the cluster views and selection when a cluster event occurs. diff --git a/phy/cluster/manual/tests/test_gui_component.py b/phy/cluster/manual/tests/test_gui_component.py index fffde618d..a43a99d90 100644 --- a/phy/cluster/manual/tests/test_gui_component.py +++ b/phy/cluster/manual/tests/test_gui_component.py @@ -136,12 +136,13 @@ def test_manual_clustering_split_2(gui, quality, similarity): spike_clusters = np.array([0, 0, 1]) mc = ManualClustering(spike_clusters, - lambda c: _spikes_in_clusters(spike_clusters, [c])) + lambda c: _spikes_in_clusters(spike_clusters, [c]), + similarity=similarity, + ) mc.attach(gui) mc.add_column(quality, name='quality', default=True) mc.set_default_sort('quality', 'desc') - mc.set_similarity_func(similarity) mc.split([0]) assert mc.selected == [2, 3] From 03fc791bcbf9730b1a6ef43b422eedba323ebe33 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 27 Jan 2016 16:21:46 +0100 Subject: [PATCH 0957/1059] WIP: reorganize cluster stats --- phy/stats/clusters.py | 10 ++++++---- phy/stats/tests/test_clusters.py | 9 +++++---- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/phy/stats/clusters.py b/phy/stats/clusters.py index bff81e5ad..72bf3d25e 100644 --- a/phy/stats/clusters.py +++ b/phy/stats/clusters.py @@ -38,7 +38,8 @@ def get_sorted_main_channels(mean_masks, unmasked_channels): # Wizard measures #------------------------------------------------------------------------------ -def get_max_waveform_amplitude(mean_masks, mean_waveforms): +def get_waveform_amplitude(mean_masks, mean_waveforms): + """Return the amplitude of the waveforms on all channels.""" assert mean_waveforms.ndim == 2 n_samples, n_channels = mean_waveforms.shape @@ -46,11 +47,12 @@ def get_max_waveform_amplitude(mean_masks, mean_waveforms): assert mean_masks.ndim == 1 assert mean_masks.shape == (n_channels,) - mean_waveforms = mean_masks * mean_waveforms + mean_waveforms = mean_waveforms * mean_masks + assert mean_waveforms.shape == (n_samples, n_channels) # Amplitudes. - m, M = mean_waveforms.min(axis=1), mean_waveforms.max(axis=1) - return np.max(M - m) + m, M = mean_waveforms.min(axis=0), mean_waveforms.max(axis=0) + return M - m def get_mean_masked_features_distance(mean_features_0, diff --git a/phy/stats/tests/test_clusters.py b/phy/stats/tests/test_clusters.py index 2fe7dac97..643f9f80d 100644 --- a/phy/stats/tests/test_clusters.py +++ b/phy/stats/tests/test_clusters.py @@ -16,7 +16,7 @@ get_mean_probe_position, get_sorted_main_channels, get_mean_masked_features_distance, - get_max_waveform_amplitude, + get_waveform_amplitude, ) from phy.electrode.mea import staggered_positions from phy.io.mock import (artificial_features, @@ -109,7 +109,7 @@ def test_sorted_main_channels(masks): assert np.all(np.in1d(channels, [5, 7])) -def test_max_waveform_amplitude(masks, waveforms): +def test_waveform_amplitude(masks, waveforms): waveforms *= .1 masks *= .1 @@ -119,8 +119,9 @@ def test_max_waveform_amplitude(masks, waveforms): mean_waveforms = mean(waveforms) mean_masks = mean(masks) - amplitude = get_max_waveform_amplitude(mean_masks, mean_waveforms) - assert amplitude > 0 + amplitude = get_waveform_amplitude(mean_masks, mean_waveforms) + assert np.all(amplitude >= 0) + assert amplitude.shape == (mean_waveforms.shape[1],) def test_mean_masked_features_distance(features, From 086ba8def12eab56f3d8e27e82b34788ef24de47 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 27 Jan 2016 16:37:10 +0100 Subject: [PATCH 0958/1059] Increase coverage --- phy/io/array.py | 2 +- phy/io/tests/test_array.py | 7 +++++++ phy/traces/spike_detect.py | 2 +- phy/traces/tests/test_spike_detect.py | 26 +++++++++++++------------- phy/utils/tests/test_types.py | 9 ++++++++- 5 files changed, 30 insertions(+), 16 deletions(-) diff --git a/phy/io/array.py b/phy/io/array.py index 4dc7502d7..03df1608e 100644 --- a/phy/io/array.py +++ b/phy/io/array.py @@ -192,7 +192,7 @@ def _in_polygon(points, polygon): return path.contains_points(points) -def _get_data_lim(arr, n_spikes=None, percentile=None): +def _get_data_lim(arr, n_spikes=None): n = arr.shape[0] k = max(1, n // n_spikes) if n_spikes else 1 arr = np.abs(arr[::k]) diff --git a/phy/io/tests/test_array.py b/phy/io/tests/test_array.py index e2dce63da..81a3c53c0 100644 --- a/phy/io/tests/test_array.py +++ b/phy/io/tests/test_array.py @@ -18,6 +18,7 @@ _spikes_in_clusters, _spikes_per_cluster, _flatten_per_cluster, + _get_data_lim, select_spikes, Selector, chunk_bounds, @@ -127,6 +128,12 @@ def test_get_padded(): ae(_get_padded(arr, -2, 3).ravel(), [0, 0, 1, 2, 3]) +def test_get_data_lim(): + arr = np.random.rand(10, 5) + assert 0 < _get_data_lim(arr) < 1 + assert 0 < _get_data_lim(arr, 2) < 1 + + def test_unique(): """Test _unique() function""" _unique([]) diff --git a/phy/traces/spike_detect.py b/phy/traces/spike_detect.py index 2929a4d99..45763aa2c 100644 --- a/phy/traces/spike_detect.py +++ b/phy/traces/spike_detect.py @@ -224,7 +224,7 @@ def detect(self, traces, thresholds=None): # Extract the spikes, masks, waveforms. if not self.ctx: return self.extract_spikes(traces, thresholds=thresholds) - else: + else: # pragma: no cover # skipped for now in the test suite import dask.array as da # Chunking parameters. diff --git a/phy/traces/tests/test_spike_detect.py b/phy/traces/tests/test_spike_detect.py index e8cb7c3fa..526575de7 100644 --- a/phy/traces/tests/test_spike_detect.py +++ b/phy/traces/tests/test_spike_detect.py @@ -158,20 +158,20 @@ def test_detect_simple(spike_detector, traces): # _plot(sd, traces, spike_samples, masks) -# NOTE: skip for now to accelerate the test suite... -def _test_detect_context(spike_detector, traces, parallel_context): # noqa - sd = spike_detector - sd.set_context(parallel_context) +# # NOTE: skip for now to accelerate the test suite... +# def _test_detect_context(spike_detector, traces, parallel_context): # noqa +# sd = spike_detector +# sd.set_context(parallel_context) - spike_samples, masks, _ = sd.detect(traces) +# spike_samples, masks, _ = sd.detect(traces) - n_channels = sd.n_channels - n_spikes = len(spike_samples) +# n_channels = sd.n_channels +# n_spikes = len(spike_samples) - assert spike_samples.dtype == np.int64 - assert spike_samples.ndim == 1 +# assert spike_samples.dtype == np.int64 +# assert spike_samples.ndim == 1 - assert masks.dtype == np.float32 - assert masks.ndim == 2 - assert masks.shape == (n_spikes, n_channels) - # _plot(sd, traces, spike_samples.compute(), masks.compute()) +# assert masks.dtype == np.float32 +# assert masks.ndim == 2 +# assert masks.shape == (n_spikes, n_channels) +# # _plot(sd, traces, spike_samples.compute(), masks.compute()) diff --git a/phy/utils/tests/test_types.py b/phy/utils/tests/test_types.py index 115785e4b..94ee65c05 100644 --- a/phy/utils/tests/test_types.py +++ b/phy/utils/tests/test_types.py @@ -9,7 +9,7 @@ import numpy as np from pytest import raises -from .._types import (Bunch, _is_integer, _is_list, _is_float, +from .._types import (Bunch, _bunchify, _is_integer, _is_list, _is_float, _as_list, _is_array_like, _as_array, _as_tuple, ) @@ -27,6 +27,13 @@ def test_bunch(): assert obj.copy() == obj +def test_bunchify(): + d = {'a': {'b': 0}} + b = _bunchify(d) + assert isinstance(b, Bunch) + assert isinstance(b['a'], Bunch) + + def test_number(): assert not _is_integer(None) assert not _is_integer(3.) From d0c973e4e7d6b79bca88ace65ce8031ccc934bf1 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 27 Jan 2016 17:02:46 +0100 Subject: [PATCH 0959/1059] WIP: update similarity logic in gui_component --- phy/cluster/manual/gui_component.py | 25 +++++++++++-------- .../manual/tests/test_gui_component.py | 7 +++--- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 2cdc4a860..4226243a6 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -7,6 +7,7 @@ # Imports # ----------------------------------------------------------------------------- +from collections import OrderedDict from functools import partial import logging @@ -199,7 +200,6 @@ def good(cluster_id): """Good column for color.""" return self.cluster_meta.get('group', cluster_id) == 'good' - @self.similarity_view.add_column def similarity(cluster_id): # NOTE: there is a dictionary with the similarity to the current # best cluster. It is updated when the selection changes in the @@ -207,6 +207,9 @@ def similarity(cluster_id): # a function that returns a value for every row, but here we # cache all similarity view rows in self._current_similarity_values return self._current_similarity_values.get(cluster_id, 0) + if self.similarity: + self.similarity_view.add_column(similarity, + name=self.similarity.__name__) def _create_actions(self, gui): self.actions = Actions(gui, @@ -307,17 +310,19 @@ def _update_similarity_view(self): # This is a list of pairs (closest_cluster, similarity). similarities = self.similarity(cluster_id) # We save the similarity values wrt the currently-selected clusters. - self._current_similarity_values = {int(cl): s - for (cl, s) in similarities} - clusters = list(map(int, self.clustering.cluster_ids)) + # Note that we keep the order of the output of the self.similary() + # function. + clusters_sim = OrderedDict([(int(cl), s) for (cl, s) in similarities]) + # List of similar clusters, remove non-existing ones. + clusters = [c for c in clusters_sim.keys() + if c in self.clustering.cluster_ids] # The similarity view will use these values. + self._current_similarity_values = clusters_sim + # Set the rows of the similarity view. + # TODO: instead of the self._current_similarity_values hack, + # give the possibility to specify the values here (?). self.similarity_view.set_rows([c for c in clusters - if c not in selection and - c in self._current_similarity_values]) - # The similarity name is always 'similarity' because we use - # a special function to retrieve the similarity values from the - # self._current_similarity_values dictionary. - self.similarity_view.sort_by('similarity', 'desc') + if c not in selection]) def _emit_select(self, cluster_ids): """Choose spikes from the specified clusters and emit the diff --git a/phy/cluster/manual/tests/test_gui_component.py b/phy/cluster/manual/tests/test_gui_component.py index a43a99d90..0db27c19e 100644 --- a/phy/cluster/manual/tests/test_gui_component.py +++ b/phy/cluster/manual/tests/test_gui_component.py @@ -6,7 +6,7 @@ # Imports #------------------------------------------------------------------------------ -from pytest import yield_fixture +from pytest import yield_fixture, fixture import numpy as np from numpy.testing import assert_array_equal as ae @@ -34,7 +34,7 @@ def gui(qtbot): qtbot.wait(5) -@yield_fixture +@fixture def manual_clustering(qtbot, gui, cluster_ids, cluster_groups, quality, similarity): spike_clusters = np.array(cluster_ids) @@ -49,8 +49,7 @@ def manual_clustering(qtbot, gui, cluster_ids, cluster_groups, ) mc.attach(gui) - yield mc - del mc + return mc #------------------------------------------------------------------------------ From bbc7a931994e0e6824fd1d7465be4948bd1c95a9 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 27 Jan 2016 18:09:56 +0100 Subject: [PATCH 0960/1059] Optimization --- phy/cluster/manual/gui_component.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 4226243a6..bf1e4a080 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -305,6 +305,7 @@ def _update_similarity_view(self): if not len(selection): return cluster_id = selection[0] + cluster_ids = self.clustering.cluster_ids self._best = cluster_id logger.log(5, "Update the similarity view.") # This is a list of pairs (closest_cluster, similarity). @@ -315,7 +316,7 @@ def _update_similarity_view(self): clusters_sim = OrderedDict([(int(cl), s) for (cl, s) in similarities]) # List of similar clusters, remove non-existing ones. clusters = [c for c in clusters_sim.keys() - if c in self.clustering.cluster_ids] + if c in cluster_ids] # The similarity view will use these values. self._current_similarity_values = clusters_sim # Set the rows of the similarity view. From 45de0fe158d8695e22d4526010b102bce68dd29a Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 27 Jan 2016 18:21:35 +0100 Subject: [PATCH 0961/1059] Color for traces --- phy/cluster/manual/tests/test_views.py | 4 ++-- phy/cluster/manual/views.py | 11 ++++++----- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index e25237279..b03f987f9 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -79,11 +79,11 @@ def traces(interval): tr = select_traces(all_traces, interval, sample_rate=model.sample_rate, ) - return [tr] + return [Bunch(traces=tr)] model.traces = traces def spikes_traces(interval, traces): - traces = traces[0] + traces = traces[0].traces b = extract_spikes(traces, interval, sample_rate=model.sample_rate, spike_times=model.spike_times, diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index fa80c0d10..b5569d7a5 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -602,6 +602,7 @@ class TraceView(ManualClusteringView): interval_duration = .5 # default duration of the interval shift_amount = .1 scaling_coeff = 1.1 + default_trace_color = (.3, .3, .3, 1.) default_shortcuts = { 'go_left': 'alt+left', 'go_right': 'alt+right', @@ -660,15 +661,15 @@ def __init__(self, # Internal methods # ------------------------------------------------------------------------- - def _plot_traces(self, traces): + def _plot_traces(self, traces=None, color=None): assert traces.shape[1] == self.n_channels t = self.interval[0] + np.arange(traces.shape[0]) * self.dt t = np.tile(t, (self.n_channels, 1)) - gray = .3 + color = color or self.default_trace_color channels = np.arange(self.n_channels) for ch in channels: self[ch].plot(t[ch, :], traces[:, ch], - color=(gray, gray, gray, 1), + color=color, data_bounds=self.data_bounds, ) @@ -737,12 +738,12 @@ def set_interval(self, interval, change_status=True): all_traces = self.traces(interval) assert isinstance(all_traces, (tuple, list)) # Default data bounds. - m, M = all_traces[0].min(), all_traces[0].max() + m, M = all_traces[0].traces.min(), all_traces[0].traces.max() self.data_bounds = np.array([start, m, end, M]) # Plot the traces. for traces in all_traces: - self._plot_traces(traces) + self._plot_traces(**traces) # Plot the spikes. spikes = self.spikes(interval, all_traces) From cf84d83c56dc12b61f92257a245c97de43590253 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 27 Jan 2016 20:02:33 +0100 Subject: [PATCH 0962/1059] Default trace view in the middle --- phy/cluster/manual/tests/test_views.py | 6 +++--- phy/cluster/manual/views.py | 11 +++++++---- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index b03f987f9..5a30b99ac 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -361,11 +361,11 @@ def test_trace_view(qtbot, gui): _select_clusters(gui) ac(v.stacked.box_size, (1., .08181), atol=1e-3) - assert v.time == .25 - - v.go_to(.5) assert v.time == .5 + v.go_to(.25) + assert v.time == .25 + v.go_to(-.5) assert v.time == .25 diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index b5569d7a5..45a45bf0a 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -656,7 +656,8 @@ def __init__(self, self._update_boxes() # Initial interval. - self.set_interval((0., self.interval_duration)) + self.interval = None + self.go_to(duration / 2.) # Internal methods # ------------------------------------------------------------------------- @@ -811,12 +812,14 @@ def time(self): @property def half_duration(self): """Half of the duration of the current interval.""" - a, b = self.interval - return (b - a) * .5 + if self.interval is not None: + a, b = self.interval + return (b - a) * .5 + else: + return self.interval_duration * .5 def go_to(self, time): """Go to a specific time (in seconds).""" - start, end = self.interval half_dur = self.half_duration self.set_interval((time - half_dur, time + half_dur)) From 104324e6a51c56bad07cc1da032e3dde6eb1a8fd Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 27 Jan 2016 21:08:46 +0100 Subject: [PATCH 0963/1059] Fix floating-point precision issue in plot: convert to float32 at the last moment --- phy/plot/interact.py | 4 +-- phy/plot/plot.py | 2 +- phy/plot/tests/test_base.py | 5 ++-- phy/plot/tests/test_transform.py | 15 ++++++----- phy/plot/tests/test_utils.py | 4 +-- phy/plot/transform.py | 8 +++--- phy/plot/utils.py | 13 +++++----- phy/plot/visuals.py | 44 +++++++++++++++----------------- 8 files changed, 47 insertions(+), 48 deletions(-) diff --git a/phy/plot/interact.py b/phy/plot/interact.py index 465e4396f..a439c9d17 100644 --- a/phy/plot/interact.py +++ b/phy/plot/interact.py @@ -78,7 +78,6 @@ def add_boxes(self, canvas, shape=None): box_index.append([i, j]) box_index = np.vstack(box_index) box_index = np.repeat(box_index, 8, axis=0) - box_index = box_index.astype(np.float32) boxes = LineVisual() @@ -88,7 +87,7 @@ def _remove_clip(tc): canvas.add_visual(boxes) boxes.set_data(pos=pos) - boxes.program['a_box_index'] = box_index + boxes.program['a_box_index'] = box_index.astype(np.float32) def update_program(self, program): program[self.shape_var] = self._shape @@ -178,6 +177,7 @@ def attach(self, canvas): def update_program(self, program): # Signal bounds (positions). box_bounds = _get_texture(self._box_bounds, NDC, self.n_boxes, [-1, 1]) + box_bounds = box_bounds.astype(np.float32) # TODO OPTIM: set the texture at initialization and update the data program['u_box_bounds'] = Texture2D(box_bounds, internalformat='rgba32f') diff --git a/phy/plot/plot.py b/phy/plot/plot.py index a88b960f1..153c04fc8 100644 --- a/phy/plot/plot.py +++ b/phy/plot/plot.py @@ -139,7 +139,7 @@ def build(self): # so we can replace this with `if 'a_box_index' in visual.program` # after the next VisPy release. if 'a_box_index' in visual.program._code_variables: - visual.program['a_box_index'] = box_index + visual.program['a_box_index'] = box_index.astype(np.float32) self.update() @contextmanager diff --git a/phy/plot/tests/test_base.py b/phy/plot/tests/test_base.py index 2921e62d9..8f2984ac0 100644 --- a/phy/plot/tests/test_base.py +++ b/phy/plot/tests/test_base.py @@ -119,8 +119,9 @@ def __init__(self): self.inserter.insert_vert(s, 'after_transforms') def set_data(self): - data = np.random.uniform(0, 20, (1000, 2)).astype(np.float32) - self.program['a_position'] = self.transforms.apply(data) + data = np.random.uniform(0, 20, (1000, 2)) + pos = self.transforms.apply(data).astype(np.float32) + self.program['a_position'] = pos bounds = subplot_bounds(shape=(2, 3), index=(1, 2)) canvas.transforms.add_on_gpu([Subplot((2, 3), (1, 2)), diff --git a/phy/plot/tests/test_transform.py b/phy/plot/tests/test_transform.py index c7f57e01d..c8ea810ae 100644 --- a/phy/plot/tests/test_transform.py +++ b/phy/plot/tests/test_transform.py @@ -26,12 +26,11 @@ def _check_forward(transform, array, expected): transformed = transform.apply(array) if array is None or not len(array): - assert transformed == array + assert transformed is None or not len(transformed) return array = np.atleast_2d(array) if isinstance(array, np.ndarray): assert transformed.shape[1] == array.shape[1] - assert transformed.dtype == np.float32 if not len(transformed): assert not len(expected) else: @@ -39,6 +38,8 @@ def _check_forward(transform, array, expected): def _check(transform, array, expected): + array = np.array(array, dtype=np.float64) + expected = np.array(expected, dtype=np.float64) _check_forward(transform, array, expected) # Test the inverse transform if it is implemented. try: @@ -176,7 +177,7 @@ def test_subplot_glsl(): @yield_fixture def array(): - yield np.array([[-1, 0], [1, 2]]) + yield np.array([[-1., 0.], [1., 2.]]) def test_transform_chain_empty(array): @@ -234,14 +235,14 @@ def test_transform_chain_add(): tc.add_on_cpu([Scale(.5)]) tc_2 = TransformChain() - tc_2.add_on_cpu([Scale(2)]) + tc_2.add_on_cpu([Scale(2.)]) - ae((tc + tc_2).apply([3]), [[3]]) + ae((tc + tc_2).apply([3.]), [[3.]]) def test_transform_chain_inverse(): tc = TransformChain() tc.add_on_cpu([Scale(.5), Translate((1, 0)), Scale(2)]) tci = tc.inverse() - ae(tc.apply([[1, 0]]), [[3, 0]]) - ae(tci.apply([[3, 0]]), [[1, 0]]) + ae(tc.apply([[1., 0.]]), [[3., 0.]]) + ae(tci.apply([[3., 0.]]), [[1., 0.]]) diff --git a/phy/plot/tests/test_utils.py b/phy/plot/tests/test_utils.py index a8138f432..8cef15ab7 100644 --- a/phy/plot/tests/test_utils.py +++ b/phy/plot/tests/test_utils.py @@ -128,8 +128,8 @@ def test_get_boxes(): positions = staggered_positions(8) boxes = _get_boxes(positions) - ae(boxes[:, 1], np.arange(.75, -1.1, -.25)) - ae(boxes[:, 3], np.arange(1, -.76, -.25)) + ac(boxes[:, 1], np.arange(.75, -1.1, -.25), atol=1e-6) + ac(boxes[:, 3], np.arange(1, -.76, -.25), atol=1e-7) def test_get_box_pos_size(): diff --git a/phy/plot/transform.py b/phy/plot/transform.py index 252bcd9bb..d7b94d2f7 100644 --- a/phy/plot/transform.py +++ b/phy/plot/transform.py @@ -26,10 +26,10 @@ def wrapped(arr, **kwargs): if arr is None or not len(arr): return arr arr = np.atleast_2d(arr) - arr = arr.astype(np.float32) assert arr.ndim == 2 + assert arr.dtype == np.float64 out = f(arr, **kwargs) - out = out.astype(np.float32) + assert out.dtype == np.float64 out = np.atleast_2d(out) assert out.ndim == 2 assert out.shape[1] == arr.shape[1] @@ -98,8 +98,8 @@ def subplot_bounds_glsl(shape=None, index=None): def pixels_to_ndc(pos, size=None): """Convert from pixels to normalized device coordinates (in [-1, 1]).""" - pos = np.asarray(pos, dtype=np.float32) - size = np.asarray(size, dtype=np.float32) + pos = np.asarray(pos, dtype=np.float64) + size = np.asarray(size, dtype=np.float64) pos = pos / (size / 2.) - 1 # Flip y, because the origin in pixels is at the top left corner of the # window. diff --git a/phy/plot/utils.py b/phy/plot/utils.py index 33a5f85f8..ff5b3e6c4 100644 --- a/phy/plot/utils.py +++ b/phy/plot/utils.py @@ -76,7 +76,7 @@ def _get_boxes(pos, size=None, margin=0, keep_aspect_ratio=True): """Generate non-overlapping boxes in NDC from a set of positions.""" # Get x, y. - pos = np.asarray(pos, dtype=np.float32) + pos = np.asarray(pos, dtype=np.float64) x, y = pos.T x = x[:, np.newaxis] y = y[:, np.newaxis] @@ -131,7 +131,7 @@ def _get_texture(arr, default, n_items, from_bounds): arr = np.tile(default, (n_items, 1)) assert arr.shape == (n_items, n_cols) # Convert to 3D texture. - arr = arr[np.newaxis, ...].astype(np.float32) + arr = arr[np.newaxis, ...].astype(np.float64) assert arr.shape == (1, n_items, n_cols) # NOTE: we need to cast the texture to [0., 1.] (float texture). # This is easy as soon as we assume that the signal bounds are in @@ -143,7 +143,6 @@ def _get_texture(arr, default, n_items, from_bounds): arr = (arr - m) / (M - m) assert np.all(arr >= 0) assert np.all(arr <= 1.) - arr = arr.astype(np.float32) return arr @@ -152,7 +151,7 @@ def _get_array(val, shape, default=None): assert val is not None or default is not None if hasattr(val, '__len__') and len(val) == 0: # pragma: no cover val = None - out = np.zeros(shape, dtype=np.float32) + out = np.zeros(shape, dtype=np.float64) # This solves `ValueError: could not broadcast input array from shape (n) # into shape (n, 1)`. if val is not None and isinstance(val, np.ndarray): @@ -206,8 +205,8 @@ def _get_pos(x, y): assert x is not None assert y is not None - x = np.asarray(x, dtype=np.float32) - y = np.asarray(y, dtype=np.float32) + x = np.asarray(x, dtype=np.float64) + y = np.asarray(y, dtype=np.float64) # Validate the position. assert x.ndim == y.ndim == 1 @@ -220,7 +219,7 @@ def _get_index(n_items, item_size, n): """Prepare an index attribute for GPU uploading.""" index = np.arange(n_items) index = np.repeat(index, item_size) - index = index.astype(np.float32) + index = index.astype(np.float64) assert index.shape == (n,) return index diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index f5dadf772..23617784a 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -107,9 +107,10 @@ def set_data(self, *args, **kwargs): data = self.validate(*args, **kwargs) self.data_range.from_bounds = data.data_bounds pos_tr = self.transforms.apply(data.pos) - self.program['a_position'] = np.c_[pos_tr, data.depth] - self.program['a_size'] = data.size - self.program['a_color'] = data.color + pos_tr = np.c_[pos_tr, data.depth] + self.program['a_position'] = pos_tr.astype(np.float32) + self.program['a_size'] = data.size.astype(np.float32) + self.program['a_color'] = data.color.astype(np.float32) def _as_list(arr): @@ -179,7 +180,7 @@ def validate(x=None, assert depth.shape == (n_signals, 1) data_bounds = _get_data_bounds(data_bounds, length=n_signals) - data_bounds = data_bounds.astype(np.float32) + data_bounds = data_bounds.astype(np.float64) assert data_bounds.shape == (n_signals, 4) return Bunch(x=x, y=y, @@ -202,7 +203,7 @@ def set_data(self, *args, **kwargs): y = np.concatenate(data.y) if len(data.y) else np.array([]) # Generate the position array. - pos = np.empty((n, 2), dtype=np.float32) + pos = np.empty((n, 2), dtype=np.float64) pos[:, 0] = x.ravel() pos[:, 1] = y.ravel() assert pos.shape == (n, 2) @@ -215,7 +216,7 @@ def set_data(self, *args, **kwargs): # Generate signal index. signal_index = np.repeat(np.arange(n_signals), n_samples) - signal_index = _get_array(signal_index, (n, 1)).astype(np.float32) + signal_index = _get_array(signal_index, (n, 1)) assert signal_index.shape == (n, 1) # Transform the positions. @@ -225,10 +226,9 @@ def set_data(self, *args, **kwargs): # Position and depth. depth = np.repeat(data.depth, n_samples, axis=0) - self.program['a_position'] = np.c_[pos_tr, depth] - self.program['a_color'] = color - self.program['a_signal_index'] = signal_index - # self.program['n_signals'] = n_signals + self.program['a_position'] = np.c_[pos_tr, depth].astype(np.float32) + self.program['a_color'] = color.astype(np.float32) + self.program['a_signal_index'] = signal_index.astype(np.float32) class HistogramVisual(BaseVisual): @@ -248,7 +248,7 @@ def validate(hist=None, color=None, ylim=None): assert hist is not None - hist = np.asarray(hist, np.float32) + hist = np.asarray(hist, np.float64) if hist.ndim == 1: hist = hist[None, :] assert hist.ndim == 2 @@ -295,18 +295,17 @@ def set_data(self, *args, **kwargs): # Set the transformed position. pos = np.vstack(_tesselate_histogram(row) for row in hist) - pos = pos.astype(np.float32) pos_tr = self.transforms.apply(pos) assert pos_tr.shape == (n, 2) - self.program['a_position'] = pos_tr + self.program['a_position'] = pos_tr.astype(np.float32) # Generate the hist index. - self.program['a_hist_index'] = _get_index(n_hists, n_bins * 6, n) + hist_index = _get_index(n_hists, n_bins * 6, n) + self.program['a_hist_index'] = hist_index.astype(np.float32) # Hist colors. - self.program['u_color'] = _get_texture(data.color, - self._default_color, - n_hists, [0, 1]) + tex = _get_texture(data.color, self._default_color, n_hists, [0, 1]) + self.program['u_color'] = tex.astype(np.float32) self.program['n_hists'] = n_hists @@ -343,7 +342,6 @@ def validate(pos=None, color=None, data_bounds=None): assert pos.ndim == 2 n_lines = pos.shape[0] assert pos.shape[1] == 4 - pos = pos.astype(np.float32) # Color. color = _get_array(color, (n_lines, 4), LineVisual._default_color) @@ -352,7 +350,7 @@ def validate(pos=None, color=None, data_bounds=None): if data_bounds is None: data_bounds = NDC data_bounds = _get_data_bounds(data_bounds, length=n_lines) - data_bounds = data_bounds.astype(np.float32) + data_bounds = data_bounds.astype(np.float64) assert data_bounds.shape == (n_lines, 4) return Bunch(pos=pos, color=color, data_bounds=data_bounds) @@ -367,7 +365,7 @@ def set_data(self, *args, **kwargs): pos = data.pos assert pos.ndim == 2 assert pos.shape[1] == 4 - assert pos.dtype == np.float32 + assert pos.dtype == np.float64 n_lines = pos.shape[0] n_vertices = 2 * n_lines pos = pos.reshape((-1, 2)) @@ -379,8 +377,8 @@ def set_data(self, *args, **kwargs): # Position. assert pos_tr.shape == (n_vertices, 2) - self.program['a_position'] = pos_tr + self.program['a_position'] = pos_tr.astype(np.float32) # Color. - color = np.repeat(data.color, 2, axis=0).astype(np.float32) - self.program['a_color'] = color + color = np.repeat(data.color, 2, axis=0) + self.program['a_color'] = color.astype(np.float32) From 8a8946b53ae04c5db0039ac72edd1efab1bf2d0c Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 27 Jan 2016 21:15:04 +0100 Subject: [PATCH 0964/1059] Fix --- phy/cluster/manual/views.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 45a45bf0a..1967457e0 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -1171,8 +1171,8 @@ def on_select(self, cluster_ids=None): spike_clusters_rel=sc) # Add axes. - self[i, j].lines(pos=[[-1, 0, +1, 0], - [0, -1, 0, +1]], + self[i, j].lines(pos=[[-1., 0., +1., 0.], + [0., -1., 0., +1.]], color=(.25, .25, .25, .5)) # Add the boxes. From 08f0891413b75d602f6f8307a09a2498ba9785a0 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 28 Jan 2016 16:20:01 +0100 Subject: [PATCH 0965/1059] Minor fixes in views --- phy/cluster/manual/views.py | 36 +++++++++++++++++++++--------------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 1967457e0..61be457fe 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -115,6 +115,8 @@ def _extend(channels, n=None): channels = list(channels) if n is None: return channels + if not len(channels): + channels = [0] if len(channels) < n: channels.extend([channels[-1]] * (n - len(channels))) channels = channels[:n] @@ -278,7 +280,7 @@ def __init__(self, self._overlap = False self.do_zoom_on_channels = True - self.best_channels = best_channels or (lambda clusters, n=None: []) + self.best_channels = best_channels or (lambda clusters: []) # Channel positions and n_channels. assert channel_positions is not None @@ -948,7 +950,7 @@ def __init__(self, """ self._scaling = 1. - self.best_channels = best_channels or (lambda clusters, n=None: []) + self.best_channels = best_channels or (lambda clusters=None: []) assert features self.features = features @@ -1067,8 +1069,8 @@ def _plot_features(self, i, j, x_dim, y_dim, x, y, def _get_channel_dims(self, cluster_ids): """Select the channels to show by default.""" n = self.n_cols - 1 - channels = self.best_channels(cluster_ids, 2 * n) - channels = (channels if channels + channels = self.best_channels(cluster_ids) + channels = (channels if channels is not None else list(range(self.n_channels))) channels = _extend(channels, 2 * n) assert len(channels) == 2 * n @@ -1119,9 +1121,10 @@ def on_select(self, cluster_ids=None): # Get the background features. data_bg = self.background_features - spike_ids_bg = data_bg.spike_ids - features_bg = data_bg.features - masks_bg = data_bg.masks + if data_bg is not None: + spike_ids_bg = data_bg.spike_ids + features_bg = data_bg.features + masks_bg = data_bg.masks # Select the dimensions. # Choose the channels automatically unless fixed_channels is set. @@ -1156,15 +1159,18 @@ def on_select(self, cluster_ids=None): x = self._get_feature(x_dim[i, j], spike_ids, f) y = self._get_feature(y_dim[i, j], spike_ids, f) - # Retrieve the x and y values for the background spikes. - x_bg = self._get_feature(x_dim[i, j], spike_ids_bg, - features_bg) - y_bg = self._get_feature(y_dim[i, j], spike_ids_bg, - features_bg) + if data_bg is not None: + # Retrieve the x and y values for the background + # spikes. + x_bg = self._get_feature(x_dim[i, j], spike_ids_bg, + features_bg) + y_bg = self._get_feature(y_dim[i, j], spike_ids_bg, + features_bg) + + # Background features. + self._plot_features(i, j, x_dim, y_dim, x_bg, y_bg, + masks=masks_bg) - # Background features. - self._plot_features(i, j, x_dim, y_dim, x_bg, y_bg, - masks=masks_bg) # Cluster features. self._plot_features(i, j, x_dim, y_dim, x, y, masks=masks, From b3816dd45e9459be78bfdc8a9502f2418c0ad5b6 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 28 Jan 2016 18:33:45 +0100 Subject: [PATCH 0966/1059] Minor update in views --- phy/cluster/manual/views.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 61be457fe..476b90f51 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -1411,18 +1411,22 @@ def on_select(self, cluster_ids=None): # Get the spike times and amplitudes data = self.coords(cluster_ids) + if data is None: + return spike_ids = data.spike_ids - spike_clusters = data.spike_clusters x = data.x y = data.y n_spikes = len(spike_ids) assert n_spikes > 0 - assert spike_clusters.shape == (n_spikes,) assert x.shape == (n_spikes,) assert y.shape == (n_spikes,) - # Get the spike clusters. - sc = _index_of(spike_clusters, cluster_ids) + spike_clusters = data.spike_clusters + if spike_clusters is not None: + assert spike_clusters.shape == (n_spikes,) + sc = _index_of(spike_clusters, cluster_ids) + else: # pragma: no cover + sc = None # Plot the amplitudes. with self.building(): @@ -1430,7 +1434,8 @@ def on_select(self, cluster_ids=None): # Get the color of the markers. color = _get_color(m, spike_clusters_rel=sc, n_clusters=n_clusters) assert color.shape == (n_spikes, 4) - ms = (self._default_marker_size if sc is not None else 1.) + # ms = (self._default_marker_size if sc is not None else 1.) + ms = self._default_marker_size self.scatter(x=x, y=y, From aecb0ae60afc2f897601987e1c6e7e674d275a3b Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 28 Jan 2016 18:34:07 +0100 Subject: [PATCH 0967/1059] Flakify --- phy/cluster/manual/views.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 476b90f51..ed55928f9 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -1425,7 +1425,7 @@ def on_select(self, cluster_ids=None): if spike_clusters is not None: assert spike_clusters.shape == (n_spikes,) sc = _index_of(spike_clusters, cluster_ids) - else: # pragma: no cover + else: # pragma: no cover sc = None # Plot the amplitudes. From 0d1dfac26148363803d95246bd7f3697813b4a21 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 28 Jan 2016 21:21:44 +0100 Subject: [PATCH 0968/1059] Revert unnecessary change in feature view --- phy/cluster/manual/views.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index ed55928f9..b597c9c08 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -1414,19 +1414,17 @@ def on_select(self, cluster_ids=None): if data is None: return spike_ids = data.spike_ids + spike_clusters = data.spike_clusters x = data.x y = data.y n_spikes = len(spike_ids) assert n_spikes > 0 + assert spike_clusters.shape == (n_spikes,) assert x.shape == (n_spikes,) assert y.shape == (n_spikes,) - spike_clusters = data.spike_clusters - if spike_clusters is not None: - assert spike_clusters.shape == (n_spikes,) - sc = _index_of(spike_clusters, cluster_ids) - else: # pragma: no cover - sc = None + # Get the spike clusters. + sc = _index_of(spike_clusters, cluster_ids) # Plot the amplitudes. with self.building(): @@ -1434,8 +1432,7 @@ def on_select(self, cluster_ids=None): # Get the color of the markers. color = _get_color(m, spike_clusters_rel=sc, n_clusters=n_clusters) assert color.shape == (n_spikes, 4) - # ms = (self._default_marker_size if sc is not None else 1.) - ms = self._default_marker_size + ms = (self._default_marker_size if sc is not None else 1.) self.scatter(x=x, y=y, From 5a43a99ba716fed9eb5fe15047d6c54b84f5a957 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 29 Jan 2016 14:47:55 +0100 Subject: [PATCH 0969/1059] Disable slow debug log --- phy/io/context.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/phy/io/context.py b/phy/io/context.py index 71762d65a..6bedb9a3a 100644 --- a/phy/io/context.py +++ b/phy/io/context.py @@ -208,12 +208,10 @@ def memcached(*args, **kwargs): """Cache the function in memory.""" h = hash((args, kwargs)) if h in c: - logger.debug("Get %s(%s) from memcache.", name, str(args)) - # Retrieve the value from the memcache. + # logger.debug("Get %s(%s) from memcache.", name, str(args)) return c[h] else: - logger.debug("Get %s(%s) from joblib.", name, str(args)) - # Call and cache the function. + # logger.debug("Compute %s(%s).", name, str(args)) out = f(*args, **kwargs) c[h] = out return out From bf39dbc3e11ea247e8bcf51510b3e2d96c7ca88a Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 1 Feb 2016 16:46:43 +0100 Subject: [PATCH 0970/1059] Detrend traces in select_traces() --- phy/cluster/manual/views.py | 1 + 1 file changed, 1 insertion(+) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index b597c9c08..96c81bd83 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -556,6 +556,7 @@ def select_traces(traces, interval, sample_rate=None): i, j = round(sample_rate * start), round(sample_rate * end) i, j = int(i), int(j) traces = traces[i:j, :] + traces = traces - np.mean(traces, axis=0) return traces From 0069e0dba39217e1ae79a1185cf6fc54c1cdbf32 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 1 Feb 2016 20:51:16 +0100 Subject: [PATCH 0971/1059] WIP: add Controller --- phy/cluster/manual/controller.py | 300 ++++++++++++++++++++ phy/cluster/manual/tests/conftest.py | 55 +++- phy/cluster/manual/tests/test_controller.py | 17 ++ phy/cluster/manual/tests/test_views.py | 232 +-------------- phy/cluster/manual/views.py | 8 +- phy/io/array.py | 6 +- 6 files changed, 390 insertions(+), 228 deletions(-) create mode 100644 phy/cluster/manual/controller.py create mode 100644 phy/cluster/manual/tests/test_controller.py diff --git a/phy/cluster/manual/controller.py b/phy/cluster/manual/controller.py new file mode 100644 index 000000000..c77653ae3 --- /dev/null +++ b/phy/cluster/manual/controller.py @@ -0,0 +1,300 @@ +# -*- coding: utf-8 -*- + +"""Controller: model -> views.""" + + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +import logging + +import numpy as np + +from phy.cluster.manual.gui_component import ManualClustering +from phy.cluster.manual.views import (WaveformView, + TraceView, + FeatureView, + CorrelogramView, + select_traces, + extract_spikes, + ) +from phy.io.array import _get_data_lim, concat_per_cluster +from phy.io import Context, Selector +from phy.stats.clusters import (mean, + get_waveform_amplitude, + ) +from phy.utils import Bunch + +logger = logging.getLogger(__name__) + + +#------------------------------------------------------------------------------ +# Kwik GUI +#------------------------------------------------------------------------------ + +class Controller(object): + """Take data out of the model and feeds it to views.""" + # responsible for the cache + def __init__(self): + self._init_data() + self._init_selector() + self._init_context() + + # Internal methods + # ------------------------------------------------------------------------- + + def _init_data(self): # pragma: no cover + self.cache_dir = None + # Child classes must set these variables. + self.spike_times = None # (n_spikes,) array + + # TODO: make sure these structures are updated during a session + self.spike_clusters = None # (n_spikes,) array + self.cluster_groups = None # dict {cluster_id: None/'noise'/'mua'} + self.cluster_ids = None + + self.channel_positions = None # (n_channels, 2) array + self.n_samples_waveforms = None # int > 0 + self.n_channels = None # int > 0 + self.n_features_per_channel = None # int > 0 + self.sample_rate = None # float + self.duration = None # float + + self.all_masks = None + self.all_waveforms = None + self.all_features = None + + def _init_selector(self): + self.selector = Selector(self.spikes_per_cluster) + + def _init_context(self): + assert self.cache_dir + self.context = Context(self.cache_dir) + ctx = self.context + + self.get_masks = ctx.cache(self.get_masks) + self.get_features = ctx.cache(self.get_features) + self.get_waveforms = ctx.cache(self.get_waveforms) + self.get_background_features = ctx.cache(self.get_background_features) + + self.get_mean_masks = ctx.memcache(self.get_mean_masks) + self.get_mean_features = ctx.memcache(self.get_mean_features) + self.get_mean_waveforms = ctx.memcache(self.get_mean_waveforms) + + self.get_waveform_lim = ctx.cache(self.get_waveform_lim) + self.get_feature_lim = ctx.cache(self.get_feature_lim) + + self.get_waveform_amplitude = ctx.memcache(ctx.cache( + self.get_waveforms_amplitude)) + self.get_best_channel_position = ctx.memcache( + self.get_best_channel_position) + self.get_close_clusters = ctx.memcache(ctx.cache( + self.get_close_clusters)) + + self.spikes_per_cluster = ctx.memcache(self.spikes_per_cluster) + + def _select_spikes(self, cluster_id, n_max=None): + assert isinstance(cluster_id, int) + assert cluster_id >= 0 + return self.selector.select_spikes([cluster_id], n_max) + + def _select_data(self, cluster_id, arr, n_max=None): + spike_ids = self._select_spikes(cluster_id, n_max) + b = Bunch() + b.data = arr[spike_ids] + b.spike_ids = spike_ids + b.spike_clusters = self.spike_clusters[spike_ids] + b.masks = self.all_masks[spike_ids] + return b + + def _data_lim(self, arr, n_max): + return _get_data_lim(arr, n_spikes=n_max) + + # Masks + # ------------------------------------------------------------------------- + + # Is cached in _init_context() + @concat_per_cluster + def get_masks(self, cluster_id): + return self._select_data(cluster_id, + self.all_masks, + 100, # TODO + ) + + def get_mean_masks(self, cluster_id): + return mean(self.get_masks(cluster_id).data) + + # Waveforms + # ------------------------------------------------------------------------- + + # Is cached in _init_context() + @concat_per_cluster + def get_waveforms(self, cluster_id): + return self._select_data(cluster_id, + self.all_waveforms, + 100, # TODO + ) + + def get_mean_waveforms(self, cluster_id): + return mean(self.get_waveforms(cluster_id).data) + + def get_waveform_lim(self): + return 1 + + def get_waveforms_amplitude(self, cluster_id): + mm = self.get_mean_masks(cluster_id) + mw = self.get_mean_waveforms(cluster_id) + assert mw.ndim == 2 + return get_waveform_amplitude(mm, mw) + + # Features + # ------------------------------------------------------------------------- + + # Is cached in _init_context() + @concat_per_cluster + def get_features(self, cluster_id): + return self._select_data(cluster_id, + self.all_features, + 1000, # TODO + ) + + def get_background_features(self): + k = max(1, int(self.n_spikes // 1000)) + spike_ids = slice(None, None, k) + b = Bunch() + b.data = self.all_features[spike_ids] + b.spike_ids = spike_ids + b.spike_clusters = self.spike_clusters[spike_ids] + b.masks = self.all_masks[spike_ids] + return b + + def get_mean_features(self, cluster_id): + return mean(self.get_features(cluster_id).data) + + def get_feature_lim(self): + return 1 + + # Traces + # ------------------------------------------------------------------------- + + def get_traces(self, interval): + tr = select_traces(self.all_traces, interval, + sample_rate=self.sample_rate, + ) + return [Bunch(traces=tr)] + + def get_spikes_traces(self, interval, traces): + # NOTE: we extract the spikes from the first traces array. + traces = traces[0].traces + b = extract_spikes(traces, interval, + sample_rate=self.sample_rate, + spike_times=self.spike_times, + spike_clusters=self.spike_clusters, + all_masks=self.all_masks, + n_samples_waveforms=self.n_samples_waveforms, + ) + return b + + # Cluster statistics + # ------------------------------------------------------------------------- + + def get_best_channel(self, cluster_id): + wa = self.get_waveforms_amplitude(cluster_id) + return int(wa.argmax()) + + def get_best_channels(self, cluster_ids): + channels = [self.get_best_channel(cluster_id) + for cluster_id in cluster_ids] + return list(set(channels)) + + # def get_channels_by_amplitude(self, cluster_ids): + # wa = self.get_waveforms_amplitude(cluster_ids[0]) + # return np.argsort(wa)[::-1].tolist() + + def get_best_channel_position(self, cluster_id): + cha = self.get_best_channel(cluster_id) + return tuple(self.channel_positions[cha]) + + def get_probe_depth(self, cluster_id): + return self.get_best_channel_position(cluster_id)[1] + + def get_close_clusters(self, cluster_id): + assert isinstance(cluster_id, int) + # Position of the cluster's best channel. + pos0 = self.get_best_channel_position(cluster_id) + n = len(pos0) + assert n in (2, 3) + # Positions of all clusters' best channels. + clusters = self.cluster_ids + pos = np.vstack([self.get_best_channel_position(int(clu)) + for clu in clusters]) + assert pos.shape == (len(clusters), n) + # Distance of all clusters to the current cluster. + dist = (pos - pos0) ** 2 + assert dist.shape == (len(clusters), n) + dist = np.sum(dist, axis=1) ** .5 + assert dist.shape == (len(clusters),) + # Closest clusters. + ind = np.argsort(dist) + ind = ind[:100] # TODO + return [(int(clusters[i]), float(dist[i])) for i in ind] + + def spikes_per_cluster(self, cluster_id): + return np.nonzero(self.spike_clusters == cluster_id)[0] + + # View methods + # ------------------------------------------------------------------------- + + def add_waveform_view(self, gui): + v = WaveformView(waveforms=self.get_waveforms, + channel_positions=self.channel_positions, + n_samples=self.n_samples_waveforms, + waveform_lim=self.get_waveform_lim(), + best_channels=self.get_best_channels, + ) + v.attach(gui) + return v + + def add_trace_view(self, gui): + v = TraceView(traces=self.get_traces, + spikes=self.get_spikes_traces, + sample_rate=self.sample_rate, + duration=self.duration, + n_channels=self.n_channels, + ) + v.attach(gui) + return v + + def add_feature_view(self, gui): + v = FeatureView(features=self.get_features, + background_features=self.get_background_features(), + spike_times=self.spike_times, + n_channels=self.n_channels, + n_features_per_channel=self.n_features_per_channel, + feature_lim=self.get_feature_lim(), + best_channels=self.get_best_channels, + ) + v.attach(gui) + return v + + def add_correlogram_view(self, gui): + v = CorrelogramView(spike_times=self.spike_times, + spike_clusters=self.spike_clusters, + sample_rate=self.sample_rate, + ) + v.attach(gui) + return v + + # GUI methods + # ------------------------------------------------------------------------- + + def set_manual_clustering(self, gui): + mc = ManualClustering(self.spike_clusters, + self.spikes_per_cluster, + similarity=self.get_close_clusters, + cluster_groups=self.cluster_groups, + ) + self.manual_clustering = mc + mc.add_column(self.get_probe_depth) + mc.attach(gui) diff --git a/phy/cluster/manual/tests/conftest.py b/phy/cluster/manual/tests/conftest.py index 4b37e30d2..a22d6bd9e 100644 --- a/phy/cluster/manual/tests/conftest.py +++ b/phy/cluster/manual/tests/conftest.py @@ -8,7 +8,18 @@ from pytest import fixture -from phy.io.array import get_closest_clusters +import numpy as np + +from phy.cluster.manual.controller import Controller +from phy.electrode.mea import staggered_positions +from phy.io.array import (get_closest_clusters, + _spikes_in_clusters, + ) +from phy.io.mock import (artificial_waveforms, + artificial_features, + artificial_masks, + artificial_traces, + ) #------------------------------------------------------------------------------ @@ -40,3 +51,45 @@ def similarity(cluster_ids): def similarity(c): return get_closest_clusters(c, cluster_ids, sim) return similarity + + +class MockController(Controller): + def __init__(self, tempdir): + self.tempdir = tempdir + super(MockController, self).__init__() + + def _init_data(self): + self.cache_dir = self.tempdir + self.n_samples_waveforms = 31 + self.n_samples_t = 20000 + self.n_channels = 11 + self.n_clusters = 4 + self.n_spikes_per_cluster = 50 + n_spikes_total = self.n_clusters * self.n_spikes_per_cluster + n_features_per_channel = 4 + + self.n_channels = self.n_channels + self.n_spikes = n_spikes_total + self.sample_rate = 20000. + self.duration = self.n_samples_t / float(self.sample_rate) + self.spike_times = np.arange(0, self.duration, 100. / self.sample_rate) + self.spike_clusters = np.repeat(np.arange(self.n_clusters), + self.n_spikes_per_cluster) + assert len(self.spike_times) == len(self.spike_clusters) + self.cluster_ids = np.unique(self.spike_clusters) + self.channel_positions = staggered_positions(self.n_channels) + + sc = self.spike_clusters + self.spikes_per_cluster = lambda c: _spikes_in_clusters(sc, [c]) + self.spike_count = lambda c: len(self.spikes_per_cluster(c)) + self.n_features_per_channel = n_features_per_channel + self.cluster_groups = {c: None for c in range(self.n_clusters)} + + self.all_traces = artificial_traces(self.n_samples_t, self.n_channels) + self.all_masks = artificial_masks(n_spikes_total, self.n_channels) + self.all_waveforms = artificial_waveforms(n_spikes_total, + self.n_samples_waveforms, + self.n_channels) + self.all_features = artificial_features(n_spikes_total, + self.n_channels, + self.n_features_per_channel) diff --git a/phy/cluster/manual/tests/test_controller.py b/phy/cluster/manual/tests/test_controller.py new file mode 100644 index 000000000..2ecc0e430 --- /dev/null +++ b/phy/cluster/manual/tests/test_controller.py @@ -0,0 +1,17 @@ +# -*- coding: utf-8 -*- + +"""Test controller.""" + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + + +#------------------------------------------------------------------------------ +# Fixtures +#------------------------------------------------------------------------------ + + +#------------------------------------------------------------------------------ +# Utils +#------------------------------------------------------------------------------ diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index 5a30b99ac..ce1eb3856 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -12,184 +12,16 @@ from vispy.util import keys from pytest import fixture -from phy.electrode.mea import staggered_positions from phy.gui import create_gui -from phy.io.array import _spikes_in_clusters, concat_per_cluster -from phy.io.mock import (artificial_waveforms, - artificial_features, - artificial_masks, - artificial_traces, - ) -from phy.stats.clusters import (mean, - get_unmasked_channels, - get_sorted_main_channels, - ) from phy.utils import Bunch -from ..gui_component import ManualClustering -from ..views import (WaveformView, - FeatureView, - CorrelogramView, - TraceView, - ScatterView, - select_traces, +from .conftest import MockController +from ..views import (ScatterView, _extract_wave, - extract_spikes, _selected_clusters_colors, _extend, ) -#------------------------------------------------------------------------------ -# Fixtures -#------------------------------------------------------------------------------ - -def create_model(): - model = Bunch() - - n_samples_waveforms = 31 - n_samples_t = 20000 - n_channels = 11 - n_clusters = 4 - model.n_spikes_per_cluster = 50 - n_spikes_total = n_clusters * model.n_spikes_per_cluster - n_features_per_channel = 4 - - model.n_channels = n_channels - model.n_spikes = n_spikes_total - model.sample_rate = 20000. - model.duration = n_samples_t / float(model.sample_rate) - model.spike_times = np.arange(0, model.duration, 100. / model.sample_rate) - model.spike_clusters = np.repeat(np.arange(n_clusters), - model.n_spikes_per_cluster) - assert len(model.spike_times) == len(model.spike_clusters) - model.cluster_ids = np.unique(model.spike_clusters) - model.channel_positions = staggered_positions(n_channels) - - sc = model.spike_clusters - model.spikes_per_cluster = lambda c: _spikes_in_clusters(sc, [c]) - model.n_features_per_channel = n_features_per_channel - model.n_samples_waveforms = n_samples_waveforms - model.cluster_groups = {c: None for c in range(n_clusters)} - - all_traces = artificial_traces(n_samples_t, n_channels) - all_masks = artificial_masks(n_spikes_total, n_channels) - - def traces(interval): - """Load traces and spikes in an interval.""" - tr = select_traces(all_traces, interval, - sample_rate=model.sample_rate, - ) - return [Bunch(traces=tr)] - model.traces = traces - - def spikes_traces(interval, traces): - traces = traces[0].traces - b = extract_spikes(traces, interval, - sample_rate=model.sample_rate, - spike_times=model.spike_times, - spike_clusters=model.spike_clusters, - all_masks=all_masks, - n_samples_waveforms=model.n_samples_waveforms, - ) - return b - model.spikes_traces = spikes_traces - - def get_waveforms(n): - return artificial_waveforms(n, - model.n_samples_waveforms, - model.n_channels) - - def get_masks(n): - return artificial_masks(n, model.n_channels) - - def get_features(n): - return artificial_features(n, - model.n_channels, - model.n_features_per_channel) - - def get_spike_ids(cluster_id): - n = model.n_spikes_per_cluster - return np.arange(n) + n * cluster_id - - def _get_data(**kwargs): - kwargs['spike_clusters'] = model.spike_clusters[kwargs['spike_ids']] - return Bunch(**kwargs) - - @concat_per_cluster - def masks(cluster_id): - return _get_data(spike_ids=get_spike_ids(cluster_id), - masks=get_masks(model.n_spikes_per_cluster)) - - @concat_per_cluster - def features(cluster_id): - return _get_data(spike_ids=get_spike_ids(cluster_id), - features=get_features(model.n_spikes_per_cluster), - masks=get_masks(model.n_spikes_per_cluster)) - model.features = features - - def feature_lim(): - """Return the max of a subset of the feature amplitudes.""" - return 1 - model.feature_lim = feature_lim - - def background_features(): - f = get_features(model.n_spikes) - m = all_masks - return _get_data(spike_ids=np.arange(model.n_spikes), - features=f, masks=m) - model.background_features = background_features - - def waveform_lim(): - """Return the max of a subset of the waveform amplitudes.""" - return 1 - model.waveform_lim = waveform_lim - - @concat_per_cluster - def waveforms(cluster_id): - w = get_waveforms(model.n_spikes_per_cluster) - m = get_masks(model.n_spikes_per_cluster) - return _get_data(spike_ids=get_spike_ids(cluster_id), - waveforms=w, - masks=m, - ) - model.waveforms = waveforms - - # Mean quantities. - # ------------------------------------------------------------------------- - - def mean_masks(cluster_id): - # We access [1] because we return spike_ids, masks. - return mean(masks(cluster_id).masks) - model.mean_masks = mean_masks - - def mean_features(cluster_id): - return mean(features(cluster_id).features) - model.mean_features = mean_features - - def mean_waveforms(cluster_id): - return mean(waveforms(cluster_id).waveforms) - model.mean_waveforms = mean_waveforms - - # Statistics. - # ------------------------------------------------------------------------- - - def best_channels(cluster_id): - mm = mean_masks(cluster_id) - uch = get_unmasked_channels(mm) - return get_sorted_main_channels(mm, uch) - model.best_channels = best_channels - - def best_channels_multiple(cluster_ids): - bc = [] - for cluster in cluster_ids: - channels = best_channels(cluster) - bc.extend([ch for ch in channels if ch not in bc]) - return bc - model.best_channels_multiple = best_channels_multiple - - return model - - #------------------------------------------------------------------------------ # Utils #------------------------------------------------------------------------------ @@ -202,29 +34,20 @@ def state(tempdir): state.TraceView0 = Bunch(scaling=1.) state.FeatureView0 = Bunch(feature_scaling=.5) state.CorrelogramView0 = Bunch(uniform_normalization=True) - - # quality and similarity functions for the cluster view. - state.ClusterView = Bunch(quality='max_waveform_amplitude', - similarity='most_similar_clusters') return state @fixture def gui(tempdir, state): - model = create_model() gui = create_gui(config_dir=tempdir, **state) - mc = ManualClustering(model.spike_clusters, - model.spikes_per_cluster, - cluster_groups=model.cluster_groups,) - mc.attach(gui) - gui.model = model - gui.manual_clustering = mc + gui.controller = MockController(tempdir) + gui.controller.set_manual_clustering(gui) return gui def _select_clusters(gui): gui.show() - mc = gui.manual_clustering + mc = gui.controller.manual_clustering assert mc mc.select([]) mc.select([0]) @@ -275,15 +98,7 @@ def test_selected_clusters_colors(): #------------------------------------------------------------------------------ def test_waveform_view(qtbot, gui): - model = gui.model - v = WaveformView(waveforms=model.waveforms, - channel_positions=model.channel_positions, - n_samples=model.n_samples_waveforms, - waveform_lim=model.waveform_lim(), - best_channels=model.best_channels_multiple, - ) - v.attach(gui) - + v = gui.controller.add_waveform_view(gui) _select_clusters(gui) ac(v.boxed.box_size, (.1818, .0909), atol=1e-2) @@ -339,7 +154,7 @@ def on_channel_click(channel_idx=None, button=None, key=None): assert _clicked == [(0, 1, 2)] - # qtbot.stop() + qtbot.stop() gui.close() @@ -348,15 +163,7 @@ def on_channel_click(channel_idx=None, button=None, key=None): #------------------------------------------------------------------------------ def test_trace_view(qtbot, gui): - model = gui.model - - v = TraceView(traces=model.traces, - spikes=model.spikes_traces, - sample_rate=model.sample_rate, - duration=model.duration, - n_channels=model.n_channels, - ) - v.attach(gui) + v = gui.controller.add_trace_view(gui) _select_clusters(gui) @@ -384,7 +191,7 @@ def test_trace_view(qtbot, gui): ac(v.interval, (.25, .75)) # Widen the max interval. - v.set_interval((0, model.duration)) + v.set_interval((0, gui.controller.duration)) v.widen() # Change channel scaling. @@ -405,22 +212,12 @@ def test_trace_view(qtbot, gui): #------------------------------------------------------------------------------ def test_feature_view(qtbot, gui): - model = gui.model - bfm = model.background_features() - v = FeatureView(features=model.features, - background_features=bfm, - spike_times=model.spike_times, - n_channels=model.n_channels, - n_features_per_channel=model.n_features_per_channel, - feature_lim=model.feature_lim(), - ) - v.attach(gui) - + v = gui.controller.add_feature_view(gui) _select_clusters(gui) assert v.feature_scaling == .5 v.add_attribute('sine', - np.sin(np.linspace(-10., 10., model.n_spikes))) + np.sin(np.linspace(-10., 10., gui.controller.n_spikes))) v.increase() v.decrease() @@ -460,12 +257,7 @@ def test_scatter_view(qtbot, gui): #------------------------------------------------------------------------------ def test_correlogram_view(qtbot, gui): - model = gui.model - v = CorrelogramView(spike_times=model.spike_times, - spike_clusters=model.spike_clusters, - sample_rate=model.sample_rate, - ) - v.attach(gui) + v = gui.controller.add_correlogram_view(gui) _select_clusters(gui) v.toggle_normalization() diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 96c81bd83..83fcc006b 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -115,7 +115,7 @@ def _extend(channels, n=None): channels = list(channels) if n is None: return channels - if not len(channels): + if not len(channels): # pragma: no cover channels = [0] if len(channels) < n: channels.extend([channels[-1]] * (n - len(channels))) @@ -339,7 +339,7 @@ def on_select(self, cluster_ids=None): alpha = data.alpha spike_ids = data.spike_ids spike_clusters = data.spike_clusters - w = data.waveforms + w = data.data masks = data.masks n_spikes = len(spike_ids) assert w.shape == (n_spikes, self.n_samples, self.n_channels) @@ -1111,7 +1111,7 @@ def on_select(self, cluster_ids=None): data = self.features(cluster_ids) spike_ids = data.spike_ids spike_clusters = data.spike_clusters - f = data.features + f = data.data masks = data.masks assert f.ndim == 3 assert masks.ndim == 2 @@ -1124,7 +1124,7 @@ def on_select(self, cluster_ids=None): data_bg = self.background_features if data_bg is not None: spike_ids_bg = data_bg.spike_ids - features_bg = data_bg.features + features_bg = data_bg.data masks_bg = data_bg.masks # Select the dimensions. diff --git a/phy/io/array.py b/phy/io/array.py index 03df1608e..bc581667e 100644 --- a/phy/io/array.py +++ b/phy/io/array.py @@ -214,12 +214,12 @@ def concat_per_cluster(f): """Take a function accepting a single cluster, and return a function accepting multiple clusters.""" @wraps(f) - def wrapped(cluster_ids): + def wrapped(self, cluster_ids): # Single cluster. if not hasattr(cluster_ids, '__len__'): - return f(cluster_ids) + return f(self, cluster_ids) # Concatenate the result of multiple clusters. - return Bunch(_accumulate([f(c) for c in cluster_ids])) + return Bunch(_accumulate([f(self, c) for c in cluster_ids])) return wrapped From 753f10e18573fdb8a40372d8d17b964d11bd04c7 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 1 Feb 2016 20:51:40 +0100 Subject: [PATCH 0972/1059] Fix --- phy/cluster/manual/tests/test_views.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index ce1eb3856..529ec22d4 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -154,7 +154,7 @@ def on_channel_click(channel_idx=None, button=None, key=None): assert _clicked == [(0, 1, 2)] - qtbot.stop() + # qtbot.stop() gui.close() From 9309316d137b51376469c944735d6613503fb606 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 1 Feb 2016 21:36:48 +0100 Subject: [PATCH 0973/1059] Update --- phy/cluster/manual/controller.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/phy/cluster/manual/controller.py b/phy/cluster/manual/controller.py index c77653ae3..b9a583311 100644 --- a/phy/cluster/manual/controller.py +++ b/phy/cluster/manual/controller.py @@ -41,6 +41,8 @@ def __init__(self): self._init_selector() self._init_context() + self.n_spikes = len(self.spike_times) + # Internal methods # ------------------------------------------------------------------------- @@ -61,9 +63,10 @@ def _init_data(self): # pragma: no cover self.sample_rate = None # float self.duration = None # float - self.all_masks = None - self.all_waveforms = None - self.all_features = None + self.all_masks = None # (n_spikes, n_channels) + self.all_waveforms = None # (n_spikes, n_samples, n_channels) + self.all_features = None # (n_spikes, n_channels, n_features) + self.all_traces = None # (n_samples_traces, n_channels) def _init_selector(self): self.selector = Selector(self.spikes_per_cluster) From 8bbca6d2f78b54caabbc28f6e224d716de67d28b Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 2 Feb 2016 14:38:13 +0100 Subject: [PATCH 0974/1059] Fix joblib --- phy/io/context.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/phy/io/context.py b/phy/io/context.py index 6bedb9a3a..3efce6055 100644 --- a/phy/io/context.py +++ b/phy/io/context.py @@ -192,7 +192,7 @@ def cache(self, f): logger.debug("Joblib is not installed: skipping cacheing.") return f assert f - disk_cached = self._memory.cache(f) + disk_cached = self._memory.cache(f, ignore=['self']) return disk_cached def memcache(self, f): From 92555880f1c4a2e412c4b0fbd8e9d5011cb84f4a Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 2 Feb 2016 14:49:52 +0100 Subject: [PATCH 0975/1059] Fix --- phy/io/context.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/phy/io/context.py b/phy/io/context.py index 3efce6055..603dab81e 100644 --- a/phy/io/context.py +++ b/phy/io/context.py @@ -7,6 +7,7 @@ #------------------------------------------------------------------------------ from functools import wraps +import inspect import logging import os import os.path as op @@ -192,7 +193,12 @@ def cache(self, f): logger.debug("Joblib is not installed: skipping cacheing.") return f assert f - disk_cached = self._memory.cache(f, ignore=['self']) + # NOTE: discard self in instance methods. + if 'self' in inspect.getargspec(f).args: + ignore = ['self'] + else: + ignore = None + disk_cached = self._memory.cache(f, ignore=ignore) return disk_cached def memcache(self, f): From c9cb4fb77935a7e512ceedbd64d9fa40c07a1f6c Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 2 Feb 2016 15:08:25 +0100 Subject: [PATCH 0976/1059] Better best channels in feature view --- phy/cluster/manual/controller.py | 8 ++-- phy/cluster/manual/views.py | 69 +++++++++++++++++--------------- 2 files changed, 40 insertions(+), 37 deletions(-) diff --git a/phy/cluster/manual/controller.py b/phy/cluster/manual/controller.py index b9a583311..aa6c45417 100644 --- a/phy/cluster/manual/controller.py +++ b/phy/cluster/manual/controller.py @@ -211,9 +211,9 @@ def get_best_channels(self, cluster_ids): for cluster_id in cluster_ids] return list(set(channels)) - # def get_channels_by_amplitude(self, cluster_ids): - # wa = self.get_waveforms_amplitude(cluster_ids[0]) - # return np.argsort(wa)[::-1].tolist() + def get_channels_by_amplitude(self, cluster_ids): + wa = self.get_waveforms_amplitude(cluster_ids[0]) + return np.argsort(wa)[::-1].tolist() def get_best_channel_position(self, cluster_id): cha = self.get_best_channel(cluster_id) @@ -276,7 +276,7 @@ def add_feature_view(self, gui): n_channels=self.n_channels, n_features_per_channel=self.n_features_per_channel, feature_lim=self.get_feature_lim(), - best_channels=self.get_best_channels, + best_channels=self.get_channels_by_amplitude, ) v.attach(gui) return v diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 83fcc006b..d946c4a8c 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -875,34 +875,41 @@ def decrease(self): # Feature view # ----------------------------------------------------------------------------- -def _dimensions_matrix(x_channels, y_channels, n_cols=None, - top_left_attribute=None): - """Dimensions matrix.""" - # time, attr time, (x, 0) time, (y, 0) time, (z, 0) - # time, (x', 0) (x', 0), (x, 0) (x', 1), (y, 0) (x', 2), (z, 0) - # time, (y', 0) (y', 0), (x, 1) (y', 1), (y, 1) (y', 2), (z, 1) - # time, (z', 0) (z', 0), (x, 2) (z', 1), (y, 2) (z', 2), (z, 2) +def _dimensions_matrix(channels, n_cols=None, top_left_attribute=None): + """Dimension matrix.""" + # time, attr time, (x, 0) time, (y, 0) time, (z, 0) + # time, (x, 1) (x, 0), (x, 1) (x, 0), (y, 0) (x, 0), (z, 0) + # time, (y, 1) (x, 1), (y, 1) (y, 0), (y, 1) (y, 0), (z, 0) + # time, (z, 1) (x, 1), (z, 1) (y, 1), (z, 1) (z, 0), (z, 1) assert n_cols > 0 - assert len(x_channels) >= n_cols - 1 - assert len(y_channels) >= n_cols - 1 + assert len(channels) >= n_cols - 1 y_dim = {} x_dim = {} x_dim[0, 0] = 'time' y_dim[0, 0] = top_left_attribute or 'time' - # Time in first column and first row. for i in range(1, n_cols): + # First line. x_dim[0, i] = 'time' - y_dim[0, i] = (x_channels[i - 1], 0) + y_dim[0, i] = (channels[i - 1], 0) + # First column. x_dim[i, 0] = 'time' - y_dim[i, 0] = (y_channels[i - 1], 0) + y_dim[i, 0] = (channels[i - 1], 1) + # Diagonal. + x_dim[i, i] = (channels[i - 1], 0) + y_dim[i, i] = (channels[i - 1], 1) for i in range(1, n_cols): - for j in range(1, n_cols): - x_dim[i, j] = (x_channels[i - 1], j - 1) - y_dim[i, j] = (y_channels[j - 1], i - 1) + for j in range(i + 1, n_cols): + assert j > i + # Above the diagonal. + x_dim[i, j] = (channels[i - 1], 0) + y_dim[i, j] = (channels[j - 1], 0) + # Below the diagonal. + x_dim[j, i] = (channels[i - 1], 1) + y_dim[j, i] = (channels[j - 1], 1) return x_dim, y_dim @@ -984,8 +991,7 @@ def __init__(self, self.fixed_channels = False # Channels to show. - self.x_channels = None - self.y_channels = None + self.channels = None # Attributes: extra features. This is a dictionary # {name: (array, data_bounds)} @@ -1006,6 +1012,7 @@ def _get_feature(self, dim, spike_ids, f): else: assert len(dim) == 2 ch, fet = dim + assert fet < f.shape[2] return f[:, ch, fet] * self._scaling def _get_dim_bounds_single(self, dim): @@ -1073,9 +1080,9 @@ def _get_channel_dims(self, cluster_ids): channels = self.best_channels(cluster_ids) channels = (channels if channels is not None else list(range(self.n_channels))) - channels = _extend(channels, 2 * n) - assert len(channels) == 2 * n - return channels[:n], channels[n:] + channels = _extend(channels, n) + assert len(channels) == n + return channels # Public methods # ------------------------------------------------------------------------- @@ -1097,7 +1104,7 @@ def add_attribute(self, name, values, top_left=True): def clear_channels(self): """Reset the dimensions.""" - self.x_channels = self.y_channels = None + self.channels = None self.on_select() def on_select(self, cluster_ids=None): @@ -1129,21 +1136,17 @@ def on_select(self, cluster_ids=None): # Select the dimensions. # Choose the channels automatically unless fixed_channels is set. - if (not self.fixed_channels or self.x_channels is None or - self.y_channels is None): - self.x_channels, self.y_channels = self._get_channel_dims( - cluster_ids) + if (not self.fixed_channels or self.channels is None): + self.channels = self._get_channel_dims(cluster_ids) tla = self.top_left_attribute - assert self.x_channels - assert self.y_channels - x_dim, y_dim = _dimensions_matrix(self.x_channels, self.y_channels, + assert self.channels + x_dim, y_dim = _dimensions_matrix(self.channels, n_cols=self.n_cols, top_left_attribute=tla) # Set the status message. - ch_i = ', '.join(map(str, self.x_channels)) - ch_j = ', '.join(map(str, self.y_channels)) - self.set_status('Channels: {} - {}'.format(ch_i, ch_j)) + ch = ', '.join(map(str, self.channels)) + self.set_status('Channels: {}'.format(ch)) # Set a non-time attribute as y coordinate in the top-left subplot. attrs = sorted(self.attributes) @@ -1204,9 +1207,9 @@ def on_channel_click(self, channel_idx=None, key=None, button=None): if key is None or not (1 <= key <= (self.n_cols - 1)): return # Get the axis from the pressed button (1, 2, etc.) - axis = 'x' if button == 1 else 'y' + # axis = 'x' if button == 1 else 'y' # Get the existing channels. - channels = self.x_channels if axis == 'x' else self.y_channels + channels = self.channels if channels is None: return assert len(channels) == self.n_cols - 1 From 37ab1184394b75108873c053d32a0495fcf9f83d Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 2 Feb 2016 15:15:02 +0100 Subject: [PATCH 0977/1059] Waveform and feature data lim --- phy/cluster/manual/controller.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/phy/cluster/manual/controller.py b/phy/cluster/manual/controller.py index aa6c45417..47147cca1 100644 --- a/phy/cluster/manual/controller.py +++ b/phy/cluster/manual/controller.py @@ -143,7 +143,7 @@ def get_mean_waveforms(self, cluster_id): return mean(self.get_waveforms(cluster_id).data) def get_waveform_lim(self): - return 1 + return self._data_lim(self.all_waveforms, 100) # TODO def get_waveforms_amplitude(self, cluster_id): mm = self.get_mean_masks(cluster_id) @@ -176,7 +176,7 @@ def get_mean_features(self, cluster_id): return mean(self.get_features(cluster_id).data) def get_feature_lim(self): - return 1 + return self._data_lim(self.all_features, 100) # TODO # Traces # ------------------------------------------------------------------------- From d1e89f1ca85cf293b52cac3f90aa0502f125674c Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 3 Feb 2016 13:18:39 +0100 Subject: [PATCH 0978/1059] Clear the scatter view when there is no data --- phy/cluster/manual/tests/test_views.py | 2 +- phy/cluster/manual/views.py | 1 + phy/plot/plot.py | 1 + 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index 529ec22d4..09e684eca 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -241,7 +241,7 @@ def test_scatter_view(qtbot, gui): spike_ids=np.arange(n), spike_clusters=np.ones(n). astype(np.int32) * c[0], - ), + ) if 2 not in c else None, data_bounds=[-3, -3, 3, 3], ) v.attach(gui) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index d946c4a8c..4abba43c9 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -1416,6 +1416,7 @@ def on_select(self, cluster_ids=None): # Get the spike times and amplitudes data = self.coords(cluster_ids) if data is None: + self.clear() return spike_ids = data.spike_ids spike_clusters = data.spike_clusters diff --git a/phy/plot/plot.py b/phy/plot/plot.py index 153c04fc8..3c02aff79 100644 --- a/phy/plot/plot.py +++ b/phy/plot/plot.py @@ -80,6 +80,7 @@ def clear(self): """Reset the view.""" self._items = OrderedDict() self.visuals = [] + self.update() def _add_item(self, cls, *args, **kwargs): """Add a plot item.""" From cb31494193a9ccdf5b850ed285fc9a93562a29bd Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 3 Feb 2016 13:26:07 +0100 Subject: [PATCH 0979/1059] Resiliency to failing waveform extraction in trace view --- phy/cluster/manual/views.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 4abba43c9..318a2d448 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -587,10 +587,11 @@ def extract_spikes(traces, interval, sample_rate=None, # Find the start of the waveform in the extracted traces. sample_start = int(round((st[i] - interval[0]) * sr)) sample_start -= offset_samples - b.waveforms, b.channels = _extract_wave(traces, - sample_start, - m[i], - wave_len) + o = _extract_wave(traces, sample_start, m[i], wave_len) + if o is None: # pragma: no cover + logger.debug("Unable to extract spike %d.", i) + continue + b.waveforms, b.channels = o # Masks on unmasked channels. b.masks = m[i, b.channels] b.spike_time = st[i] From 7471785e63d14596fde0303056fe9ed4d0acdb8d Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 3 Feb 2016 15:54:09 +0100 Subject: [PATCH 0980/1059] WIP: support multiple sets of waveforms in waveform view --- phy/cluster/manual/controller.py | 11 ++++----- phy/cluster/manual/tests/test_views.py | 1 + phy/cluster/manual/views.py | 33 +++++++++++++------------- phy/io/array.py | 13 +++++++--- 4 files changed, 33 insertions(+), 25 deletions(-) diff --git a/phy/cluster/manual/controller.py b/phy/cluster/manual/controller.py index 47147cca1..3895e12b5 100644 --- a/phy/cluster/manual/controller.py +++ b/phy/cluster/manual/controller.py @@ -134,13 +134,13 @@ def get_mean_masks(self, cluster_id): # Is cached in _init_context() @concat_per_cluster def get_waveforms(self, cluster_id): - return self._select_data(cluster_id, - self.all_waveforms, - 100, # TODO - ) + return [self._select_data(cluster_id, + self.all_waveforms, + 100, # TODO + )] def get_mean_waveforms(self, cluster_id): - return mean(self.get_waveforms(cluster_id).data) + return mean(self.get_waveforms(cluster_id)[0].data) def get_waveform_lim(self): return self._data_lim(self.all_waveforms, 100) # TODO @@ -252,7 +252,6 @@ def spikes_per_cluster(self, cluster_id): def add_waveform_view(self, gui): v = WaveformView(waveforms=self.get_waveforms, channel_positions=self.channel_positions, - n_samples=self.n_samples_waveforms, waveform_lim=self.get_waveform_lim(), best_channels=self.get_best_channels, ) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index 09e684eca..233f00e13 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -140,6 +140,7 @@ def test_waveform_view(qtbot, gui): ac(v.box_scaling, (a * 2, b)) v.zoom_on_channels([0, 2, 4]) + v.next_data() # Simulate channel selection. _clicked = [] diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 318a2d448..2d025d350 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -255,6 +255,7 @@ class WaveformView(ManualClusteringView): default_shortcuts = { 'toggle_waveform_overlap': 'o', 'toggle_zoom_on_channels': 'z', + 'next_data': 'w', # Box scaling. 'widen': 'ctrl+right', @@ -272,13 +273,13 @@ class WaveformView(ManualClusteringView): def __init__(self, waveforms=None, channel_positions=None, - n_samples=None, waveform_lim=None, best_channels=None, **kwargs): self._key_pressed = None self._overlap = False self.do_zoom_on_channels = True + self.data_index = 0 self.best_channels = best_channels or (lambda clusters: []) @@ -287,12 +288,6 @@ def __init__(self, self.channel_positions = np.asarray(channel_positions) self.n_channels = self.channel_positions.shape[0] - # Number of samples per waveform. - n_samples = (sum(map(abs, n_samples)) if isinstance(n_samples, tuple) - else n_samples) - assert n_samples > 0 - self.n_samples = n_samples - # Initialize the view. box_bounds = _get_boxes(channel_positions) super(WaveformView, self).__init__(layout='boxed', @@ -322,11 +317,6 @@ def __init__(self, assert channel_positions.shape == (self.n_channels, 2) self.channel_positions = channel_positions - def _get_data(self, cluster_ids): - d = self.waveforms(cluster_ids) - d.alpha = .5 - return d - def on_select(self, cluster_ids=None): super(WaveformView, self).on_select(cluster_ids) cluster_ids = self.cluster_ids @@ -335,14 +325,18 @@ def on_select(self, cluster_ids=None): return # Load the waveform subset. - data = self._get_data(cluster_ids) - alpha = data.alpha + data = self.waveforms(cluster_ids) + # Take one element in the list. + data = data[self.data_index % len(data)] + alpha = data.get('alpha', .5) spike_ids = data.spike_ids spike_clusters = data.spike_clusters w = data.data masks = data.masks n_spikes = len(spike_ids) - assert w.shape == (n_spikes, self.n_samples, self.n_channels) + assert w.ndim == 3 + n_samples = w.shape[1] + assert w.shape == (n_spikes, n_samples, self.n_channels) assert masks.shape == (n_spikes, self.n_channels) # Relative spike clusters. @@ -350,7 +344,7 @@ def on_select(self, cluster_ids=None): assert spike_clusters_rel.shape == (n_spikes,) # Fetch the waveforms. - t = _get_linear_x(n_spikes, self.n_samples) + t = _get_linear_x(n_spikes, n_samples) # Overlap. if not self.overlap: t = t + 2.5 * (spike_clusters_rel[:, np.newaxis] - @@ -408,6 +402,8 @@ def attach(self, gui): self.actions.add(self.extend_vertically) self.actions.add(self.shrink_vertically) + self.actions.add(self.next_data) + # We forward the event from VisPy to the phy GUI. @self.connect def on_channel_click(e): @@ -511,6 +507,11 @@ def shrink_vertically(self): # Navigation # ------------------------------------------------------------------------- + def next_data(self): + """Show the next set of waveforms (if any).""" + self.data_index += 1 + self.on_select() + def toggle_zoom_on_channels(self): self.do_zoom_on_channels = not self.do_zoom_on_channels diff --git a/phy/io/array.py b/phy/io/array.py index bc581667e..80fb4af6b 100644 --- a/phy/io/array.py +++ b/phy/io/array.py @@ -219,7 +219,15 @@ def wrapped(self, cluster_ids): if not hasattr(cluster_ids, '__len__'): return f(self, cluster_ids) # Concatenate the result of multiple clusters. - return Bunch(_accumulate([f(self, c) for c in cluster_ids])) + l = [f(self, c) for c in cluster_ids] + # Handle the case where every function returns a list of Bunch. + if l and isinstance(l[0], list): + # We assume that all items have the same length. + n = len(l[0]) + return [Bunch(_accumulate([item[i] for item in l])) + for i in range(n)] + else: + return Bunch(_accumulate(l)) return wrapped @@ -502,8 +510,7 @@ def names(self): def __getitem__(self, name): """Concatenate all arrays with a given name.""" - return np.concatenate(self._data[name], axis=0). \ - astype(self._data[name][0].dtype) + return np.concatenate(self._data[name], axis=0) def _accumulate(data_list, no_concat=()): From 0d0f8ed276e10c9be3015e92059bf7324e3e3d2e Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 3 Feb 2016 17:21:44 +0100 Subject: [PATCH 0981/1059] Fix waveform normalization --- phy/cluster/manual/controller.py | 21 +++++++++++++++++---- phy/cluster/manual/views.py | 10 +++++----- phy/io/array.py | 6 +++++- 3 files changed, 27 insertions(+), 10 deletions(-) diff --git a/phy/cluster/manual/controller.py b/phy/cluster/manual/controller.py index 3895e12b5..998db5060 100644 --- a/phy/cluster/manual/controller.py +++ b/phy/cluster/manual/controller.py @@ -85,7 +85,7 @@ def _init_context(self): self.get_mean_features = ctx.memcache(self.get_mean_features) self.get_mean_waveforms = ctx.memcache(self.get_mean_waveforms) - self.get_waveform_lim = ctx.cache(self.get_waveform_lim) + self.get_waveform_lims = ctx.cache(self.get_waveform_lims) self.get_feature_lim = ctx.cache(self.get_feature_lim) self.get_waveform_amplitude = ctx.memcache(ctx.cache( @@ -142,8 +142,21 @@ def get_waveforms(self, cluster_id): def get_mean_waveforms(self, cluster_id): return mean(self.get_waveforms(cluster_id)[0].data) - def get_waveform_lim(self): - return self._data_lim(self.all_waveforms, 100) # TODO + def get_waveform_lims(self): + n_spikes = 100 # TODO + arr = self.all_waveforms + n = arr.shape[0] + k = max(1, n // n_spikes) + # Extract waveforms. + arr = arr[::k] + # Take the corresponding masks. + masks = self.all_masks[::k].copy() + arr = arr * masks[:, np.newaxis, :] + # NOTE: on some datasets, there are a few outliers that screw up + # the normalization. These parameters should be customizable. + m = np.percentile(arr, .05) + M = np.percentile(arr, 99.95) + return m, M def get_waveforms_amplitude(self, cluster_id): mm = self.get_mean_masks(cluster_id) @@ -252,7 +265,7 @@ def spikes_per_cluster(self, cluster_id): def add_waveform_view(self, gui): v = WaveformView(waveforms=self.get_waveforms, channel_positions=self.channel_positions, - waveform_lim=self.get_waveform_lim(), + waveform_lims=self.get_waveform_lims(), best_channels=self.get_best_channels, ) v.attach(gui) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 2d025d350..a6b2a9957 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -273,7 +273,7 @@ class WaveformView(ManualClusteringView): def __init__(self, waveforms=None, channel_positions=None, - waveform_lim=None, + waveform_lims=None, best_channels=None, **kwargs): self._key_pressed = None @@ -310,8 +310,8 @@ def __init__(self, self.waveforms = waveforms # Waveform normalization. - assert waveform_lim > 0 - self.data_bounds = [-1, -waveform_lim, +1, +waveform_lim] + assert len(waveform_lims) == 2 + self.data_bounds = [-1, waveform_lims[0], +1, waveform_lims[1]] # Channel positions. assert channel_positions.shape == (self.n_channels, 2) @@ -1098,8 +1098,8 @@ def add_attribute(self, name, values, top_left=True): """ assert values.shape == (self.n_spikes,) - lim = values.min(), values.max() - self.attributes[name] = (values, lim) + lims = values.min(), values.max() + self.attributes[name] = (values, lims) # Register the attribute to use in the top-left subplot. if top_left: self.top_left_attribute = name diff --git a/phy/io/array.py b/phy/io/array.py index 80fb4af6b..60b6ca8fc 100644 --- a/phy/io/array.py +++ b/phy/io/array.py @@ -510,7 +510,11 @@ def names(self): def __getitem__(self, name): """Concatenate all arrays with a given name.""" - return np.concatenate(self._data[name], axis=0) + l = self._data[name] + # Process scalars: only return the first one and don't concatenate. + if len(l) and not hasattr(l[0], '__len__'): + return l[0] + return np.concatenate(l, axis=0) def _accumulate(data_list, no_concat=()): From be324cdb5ad238fdf86ecd55c39a148ab7c31cec Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 3 Feb 2016 17:43:29 +0100 Subject: [PATCH 0982/1059] Fix test --- phy/cluster/manual/tests/test_views.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index 233f00e13..eddea5c3e 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -140,7 +140,6 @@ def test_waveform_view(qtbot, gui): ac(v.box_scaling, (a * 2, b)) v.zoom_on_channels([0, 2, 4]) - v.next_data() # Simulate channel selection. _clicked = [] @@ -155,6 +154,8 @@ def on_channel_click(channel_idx=None, button=None, key=None): assert _clicked == [(0, 1, 2)] + v.next_data() + # qtbot.stop() gui.close() From 4918bfeecde6995b1bf7d752ec0ca861e67b8898 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 5 Feb 2016 12:10:58 +0100 Subject: [PATCH 0983/1059] WIP: add TextVisual --- phy/plot/glsl/text.frag | 8 +++ phy/plot/glsl/text.vert | 43 +++++++++++++ phy/plot/static/chars.txt | 1 + phy/plot/tests/test_visuals.py | 21 ++++++- phy/plot/visuals.py | 109 +++++++++++++++++++++++++++++++-- 5 files changed, 176 insertions(+), 6 deletions(-) create mode 100644 phy/plot/glsl/text.frag create mode 100644 phy/plot/glsl/text.vert create mode 100644 phy/plot/static/chars.txt diff --git a/phy/plot/glsl/text.frag b/phy/plot/glsl/text.frag new file mode 100644 index 000000000..ef8644e2d --- /dev/null +++ b/phy/plot/glsl/text.frag @@ -0,0 +1,8 @@ + +uniform sampler2D u_tex; + +varying vec2 v_tex_coords; + +void main() { + gl_FragColor = texture2D(u_tex, v_tex_coords); +} diff --git a/phy/plot/glsl/text.vert b/phy/plot/glsl/text.vert new file mode 100644 index 000000000..fd3bd9941 --- /dev/null +++ b/phy/plot/glsl/text.vert @@ -0,0 +1,43 @@ + +attribute vec2 a_position; // text position +attribute float a_glyph_index; // glyph index in the text +attribute float a_quad_index; // quad index in the glyph +attribute float a_char_index; // index of the glyph in the texture + +uniform vec2 u_glyph_size; // (w, h) + +varying vec2 v_tex_coords; + +const float rows = 6; +const float cols = 16; + +void main() { + float w = u_glyph_size.x / u_window_size.x; + float h = u_glyph_size.y / u_window_size.y; + + float dx = mod(a_quad_index, 2.); + float dy = 0.; + if ((2. <= a_quad_index) && (a_quad_index <= 4.)) { + dy = 1.; + } + + // Position of the glyph. + gl_Position = transform(a_position); + gl_Position.xy = gl_Position.xy + vec2(a_glyph_index * w + dx * w, dy * h); + + // Index in the texture + float i = floor(a_char_index / cols); + float j = mod(a_char_index, cols); + + // uv position in the texture for the glyph. + vec2 uv = vec2(j, rows - 1. - i); + uv /= vec2(cols, rows); + + // Little margin to avoid edge effects between glyphs. + dx = .01 + .98 * dx; + dy = .01 + .98 * dy; + // Texture coordinates for the fragment shader. + vec2 duv = vec2(dx / cols, dy /rows); + + v_tex_coords = uv + duv; +} diff --git a/phy/plot/static/chars.txt b/phy/plot/static/chars.txt new file mode 100644 index 000000000..8c24215d1 --- /dev/null +++ b/phy/plot/static/chars.txt @@ -0,0 +1 @@ + !"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\]^_`abcdefghijklmnopqrstuvwxyz{|}~ \ No newline at end of file diff --git a/phy/plot/tests/test_visuals.py b/phy/plot/tests/test_visuals.py index e677a1529..83e218b73 100644 --- a/phy/plot/tests/test_visuals.py +++ b/phy/plot/tests/test_visuals.py @@ -10,7 +10,7 @@ import numpy as np from ..visuals import (ScatterVisual, PlotVisual, HistogramVisual, - LineVisual, + LineVisual, TextVisual, ) @@ -175,3 +175,22 @@ def test_line_0(qtbot, canvas_pz): color = np.random.uniform(.5, .9, (n, 4)) _test_visual(qtbot, canvas_pz, LineVisual(), pos=pos, color=color, data_bounds=[-1, -1, 1, 1]) + + +#------------------------------------------------------------------------------ +# Test text visual +#------------------------------------------------------------------------------ + +def test_text_empty(qtbot, canvas): + pos = np.zeros((0, 2)) + _test_visual(qtbot, canvas, TextVisual(), pos=pos, text=[]) + + +def test_text_0(qtbot, canvas_pz): + text = '0123456789' + text = [text[:n] for n in range(1, 11)] + + pos = np.c_[np.linspace(-.5, .5, 10), np.linspace(-.5, .5, 10)] + + _test_visual(qtbot, canvas_pz, TextVisual(), + pos=pos, text=text) diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index 23617784a..37e053123 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -7,7 +7,10 @@ # Imports #------------------------------------------------------------------------------ +import os.path as op + import numpy as np +from vispy.gloo import Texture2D from .base import BaseVisual from .transform import Range, NDC @@ -310,14 +313,110 @@ def set_data(self, *args, **kwargs): class TextVisual(BaseVisual): - def __init__(self): # pragma: no cover - # TODO: this text visual + _default_color = (1., 1., 1., 1.) + + def __init__(self, color=None): super(TextVisual, self).__init__() self.set_shader('text') - self.set_primitive_type('points') + self.set_primitive_type('triangles') + self.data_range = Range(NDC) + self.transforms.add_on_cpu(self.data_range) + + # Load the font. + # TODO: compress the npy file with gzip + curdir = op.realpath(op.dirname(__file__)) + font_name = 'SourceCodePro-Regular' + font_size = 48 + fn = '%s-%d.npy' % (font_name, font_size) + self._tex = np.load(op.join(curdir, 'static', fn)) + with open(op.join(curdir, 'static', 'chars.txt'), 'r') as f: + self._chars = f.read() + + def _get_glyph_indices(self, s): + return [self._chars.index(char) for char in s] + + @staticmethod + def validate(pos=None, text=None, color=None, data_bounds=None): + assert pos is not None + pos = np.atleast_2d(pos) + assert pos.ndim == 2 + assert pos.shape[1] == 2 + n_text = pos.shape[0] + + assert len(text) == n_text + + # Color. + color = color if color is not None else TextVisual._default_color + assert len(color) == 4 + + # By default, we assume that the coordinates are in NDC. + if data_bounds is None: + data_bounds = NDC + data_bounds = _get_data_bounds(data_bounds) + data_bounds = data_bounds.astype(np.float64) + assert data_bounds.shape == (1, 4) + + return Bunch(pos=pos, text=text, color=color, data_bounds=data_bounds) + + @staticmethod + def vertex_count(pos=None, **kwargs): + """Take the output of validate() as input.""" + # Total number of glyphs * 6 (6 vertices per glyph). + return sum(map(len, kwargs['text'])) * 6 + + def set_data(self, *args, **kwargs): + data = self.validate(*args, **kwargs) + pos = data.pos + assert pos.ndim == 2 + assert pos.shape[1] == 2 + assert pos.dtype == np.float64 + + # TODO: color + + # Concatenate all strings. + text = data.text + lengths = list(map(len, text)) + text = ''.join(text) + a_char_index = self._get_glyph_indices(text) + n_glyphs = len(a_char_index) + + tex = self._tex + glyph_height = tex.shape[0] // 6 + glyph_width = tex.shape[1] // 16 + glyph_size = (glyph_width, glyph_height) + + # Position of all glyphs. + a_position = np.repeat(pos, lengths, axis=0) + if not len(lengths): + a_glyph_index = np.zeros((0,)) + else: + a_glyph_index = np.concatenate([np.arange(n) for n in lengths]) + a_quad_index = np.arange(6) + + a_position = np.repeat(a_position, 6, axis=0) + a_glyph_index = np.repeat(a_glyph_index, 6) + a_quad_index = np.tile(a_quad_index, n_glyphs) + a_char_index = np.repeat(a_char_index, 6) + + n_vertices = n_glyphs * 6 + assert a_position.shape == (n_vertices, 2) + assert a_glyph_index.shape == (n_vertices,) + assert a_quad_index.shape == (n_vertices,) + assert a_char_index.shape == (n_vertices,) + + # Transform the positions. + self.data_range.from_bounds = data.data_bounds + pos_tr = self.transforms.apply(a_position) + assert pos_tr.shape == (n_vertices, 2) + + self.program['a_position'] = a_position.astype(np.float32) + self.program['a_glyph_index'] = a_glyph_index.astype(np.float32) + self.program['a_quad_index'] = a_quad_index.astype(np.float32) + self.program['a_char_index'] = a_char_index.astype(np.float32) + + self.program['u_glyph_size'] = glyph_size - def set_data(self): - pass + self.program['u_tex'] = Texture2D(tex[::-1, :]) class LineVisual(BaseVisual): From 43bb5ed73ed65a1e25f7dafe930d5d713ad72262 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 5 Feb 2016 12:23:53 +0100 Subject: [PATCH 0984/1059] Add gzipped texture font --- phy/plot/static/SourceCodePro-Regular-48.npy.gz | Bin 0 -> 20536 bytes phy/plot/visuals.py | 13 ++++++++++--- 2 files changed, 10 insertions(+), 3 deletions(-) create mode 100644 phy/plot/static/SourceCodePro-Regular-48.npy.gz diff --git a/phy/plot/static/SourceCodePro-Regular-48.npy.gz b/phy/plot/static/SourceCodePro-Regular-48.npy.gz new file mode 100644 index 0000000000000000000000000000000000000000..992f5cce3d3ccfc7297bec0959529af2e6e7e639 GIT binary patch literal 20536 zcmYhCV{m5Cm-XYMW7{@5wv&!+bexXuz-Zv z)T+HMQ6vn^n@@u}2)K%!i=(lrsGW(aqN5#yvZ=X?wV@*e8yBOky$8sp9<(=xI9hkA z^e?{|UNHy=NHGZ2nq18qKR+1F6Jl!Gxxa9%U{bu=Z=De)7JEDBYee=cNAa(1j%abu6 z6(u?XJ~S*okTvY|2%ri3=86axu!miXZ*utCblh>aG}ImxG`GSP*~pKiQY*?-NmEI@ zi}Vm_m?i|TAg_nV@P4=t!{{KZ2BR+?qf{1<7C_yKE1ARN$Qv{6dA_r$w)^_KpUMXM z=&2l;gOhfwIEv|!b?Bmfz_fhU;CybZZUwnV35#u^}{(nC*k5&>sLv<)$Z~t z4(T41e}p^MGw6ixm=$30uf03>R7)zlO2`+Q6&{kid-Ii&@0VGY{^6LH`Qm!hSGif& zJ5@h^@HNC0IOaile8~0;O)Ov>t;czj!J+)vKh`}-!l!XX+SzLwdH0#}C3#`f-}Pn) z9U3vfl^^6=Y~>!E^Obtd>2;BrW1t9Ekdx#-nWE;lEFUk@4)6V*xANDZbcX8^z|So? z+hzZ_@jz{bLCXh?BhC_h@ljsc6N;Gi&#bP1(CbSn7hnxt4}xiyC z`-Va_7giy+w*o?;u8#V-F3pDq&% zDp-D*r?GIs;Zfcp_9F*M$;IB{n^D_1iuj_w1MVbx3b^tDO>mcGs~`v`u@HX7cT>{u zM#}R>cb=+=#}sjD_m` zEw^;@FK^n10nMB$jUdD^Uw@ElWVf^|e#?^HH?kDUgiY+;q|@I2Jus&pD=s=?fQ@j% z7qn?364HnYBlN`0bO=bYB~6hS3zaSZ#rP$)0b@_4LrNfom9zzwMtoh-zy!5OVP#U~ zNJsj`tfe!pc$gFQ0Wbc#pxW6PNWOjbQaUd!Vp6If?W;iuqAB{NnSMaFk7KRO|*i%e1>zpc}po>bbN11#&8NH^uYTHzSttTtpf>gGT;tsCG0uGOd`TAh)=FMgGb& ztWX{jSTSL9U{rJ=5D7Zods$T>PBAuEs{I7Av%?W%A|Im!K>&8o#5VSB~ zk{_7=`Y&-!u1)*JYNC@08;L8+m?a1@JNy(otTBSgz`#SFzoP@j5Qcs;gat*hS2AcV-wV0k`U5V_# zku4W~d@nzQNMZsvZz}>nQgmaLt}6PZU<n9jeM+DRuv<XbHQgVM z^uD)HbzlD$XLo1VQDD@1!^(E3qVmL{fu$V;q*UXiS{wsWxY4qSokMaR=TdTVkE?n8 zfW@kSb{CpW=CqjnAYtrGei+P!{ySeO6-;k$n@sq zJ}f6P%5G_-v8XXKcHYXW)f5eqLGgpm3OBl4@%QPt4HFoYlG)P=P11pn*u>vg#!n%Qi{ z++WztW&Y5m731If<+;LPs>+m`f|XZ=sA2k6@fT?)UTsJeh?o&jcPS?lj5N8Ib zxI8of=+V%{`hfOHH;AuBp* z+rlLXlTx82u}ADiQn|Yc2N{Qed5v8qHzx;(W}xqLyyNj)Qj7S(acyJ+`khA?s#r7N zq4_{p4z|MiuO|S~9nM$mvxFyt&I=GSdBGt>D(Xb`Ig7JHDL@z8-p3s?ixNiMnnmb9 z7-&RR3KDbT3vzZ!gW*fkeeVfS__D$m1KxayHQtI@z>iG*z5;oXTrRb|Ny(MMa7}HyKxuHwHMPl!MlGCzZ0LkLJk@kN-I< zT7>YZIxjXrjFe7af7J|i$fpWhi2(c(xWQp{qD03H_h_>iMFRqv@jBC&$ox(R%%w^zh0_>V){BKJxJ%crJXiKN%}t!k%VT z);sk6(P_S>*bNrEf!xFu(>d|BGo-#~T;`T8d4w?jBy%oNb;Sv=`{^=>8}bp^wB%JO zYK^p(EL3LSb{N{^ml+)A$}=jTdXpKS$9H!vJ~{p3@-Xuvda zUcDQ(Zt42S@ssRg#|ale^{b0c4&e)obo=8^3>Auz*BsF@-rO&|e8wKPqOnD~#HGn> z+IL-5u>-^29sdF^si$#aHY}J;6TX9+o${*%tE)JZWTDrE2mbjBVyjD*<3I*^tM)0u z7i(KgKGSUGL~#c@g!?^hCyCe5B;l>x-zH3bLeqoWJ9cbh0h?}Np?@v=mfv#kAa8Pa z2gE@+F|p;?37w>F)FK};f~T@jX@FV{i?ZL4S`myK)+PKNZ_9%!{SJ|*d9J%_?GsOylwE2lTY-)MBMLLQyR81dE3N#ryuAZ1HxH!8*k>RYN5B3qZx5YRwCG_m)5`JhAPhO6W#W@flSW{C`qU|2p0_OTHd3h2QPG=9o+{W_xjv z$$g^@J^d?*6v|{cK?fwv9EYN5;?eGqTyiLpz+db4eapK4md^dX>+LUzq5OsTJDh~v z;kO0%#y3}nW)pm#R(G+nlrhXk6q5ufttX}y z%EfbUIuMe(9ZMng<=j$9gYW)1uW#eCd$N;jBlN4X`18)r_GoGFnFFaz+j^UmgOG-= z<90GL8#k|bJ+f5t-u2Vrn|t@nghn~rjlrue9}C|Vsb6;4jf+Ks=Z&grYp z+QSLIrAy~yUQaNY*ZO?HT3d~s* z;HY*llT?eY#(l(0y~STX@zJL&FL_Y3e}~q*qHLDs^9UuYjyjw5cfJ%5_9>!EKf34oTS|ww(9}b_+nJ@oeX1-SPnY2iB{m z`^&~Xg!BvA%P6w=^Sqj)E31f0@D1OV1Q&8E>Fz!bcL-8$#W1~gL=D@m&+2Iz^y>~% zsOhp}ov^z{JcPo)>+bAG6J$jj$=az4u|w&S;6p^l6y?%|*r1rg51fCo)yGYT>RO4b zm!+sd=})AE|BHYmNSe*a zg1X;l7PN}6_C>kdU~XKVRIMNLT2ey_(P+W9oh#HSH6=VtSPsYqm~r8k9ny8L>fvvj z{wAJ!C_%Ds=mxyTtQes(l=po=VM>>KP5@Udvs!$&FBh)$yN~+Qvkiuu(p(8vh?8!guF4K-NcU}2CK73d?pRXvkb+bYyRBqro#O$eaaDQ{qiA9u z7*2jkLVj_c(r4B9NAfwe>;qh7af{|B?9XS2?~e*NWkKq|slipzU+p>GEZvghucaao zvD}RCoe?t@2jYy7@rvs!sR+GL;(Zb-D|IbPG055VW?0I_}x<=jC{I1bJ-6M4o-i z>!aT`s3`lf_CgNM%BD*yb@{&5UuO0=&bYr&w~)0bmb56JQJp3)qW08x>4pDExfM4u z&%0@!Tuiz~`C@-lO!Bo&ktX|k<&Jy5fNTrz0CIOs!$fpI9?N7xK^_DZ(b36QC{s1k zyL#_eI#ARuK2St)#h`LL-|o;dv{9YN*T6DqrcS`d7UVZ<<$*x_yPb7UDzCpL#IVGoWK_pY{GG6K+d}1RUMglI&W83>T_Ff z>_j-l1$n+_Q8?5=fx}aVG5n;xA5jH<)^ZC&UJ4fi%p%q)8KWDXl6li`XUZP;)aRbx zOX=XGa9A6{Wue<*!|e#elj!P?w=(71^eAwJeHb>7^y8}EfE;1{K6>aYJ!{APDmUNI zd^c?k0rLR%cd6IbRnNY?=hZNAn87f@d^HIsSth=w*X+zgOH#n^XD;Piy^u|}{1JC? z(2Hh=j~t5AZC1$hTt8LfCMgn?`vZ>UvWXRFhoKEcRMAVs7ZG%RYM&Vgaw}+a^9W@y zp7A*wGBJm;5TcfX{NI3lp`9Qb=JW(s4aMm8xaO-^P4W3YUN1i!1n-QgUQMta!cxle_O65Kvf z{|}v}en-0WXV*JsikHh=Y9Y4`hlo4_y%;oX-w&~9n{E|kRw#ZoRct)`8B-FEc^LOP zqo+FDKG^sF8bB++5K5S?UX5nL!oe@I=>oiDx3LEdx4~zdjY9WGX=jW-+G3 z5lAsN!Lce}ZtvX7aeCrL8QaPmtk8c7;&yqG8HrLvZSV)9wu}f+v_tvSIOS|L^Sz@! z{bCwIPj1(Cerq?yj#u;o^YjWTov(p%OYuEl+xaH{mfm#foR#?l6`3SIze5{hYw>MY zw%7Ry|CH3zsOF>VYkB)oAe}qbwO1yU85tDK!mON5-$2N@SAF$F(uO#eME=J~THw|N zQ^q|5m;SXt%fu{g_C!5{m~F~Rp9~Iw1ByLER`LJ4@^YPO3JR!4{5AK zvLCJfRF4}?Z1S`bm+Jgxu|`hh5N*iEcYWyrBnlwL$h>IcpO>(bpFsyERE)T`^FEW5 z#CP>77#jCZ^ZU-g=n5KAxiy>;kD;KUye7Q_aXa|zaid_c9i0sUe)i=o;@7f#vI-AEh_=m zjPy7#utTs{uP7uG0w0EKCsQ6~?S%cx1^*Ut2OX(PYLaGGv zFl9!Jt=e`Un}Y2^F&1#|`CU4vDBq$5CS=(oaG4Vqe%2yUX|`)y%0R&Q(YbonGbNmPc`-CrSCn z_rtz638>V<6h{)>fYbQ`0?n(4nAkIo>Z*h9RNjr1j*lib27Yg+u zmYbX<(FVJ>$GW~ndLqe(=gB3c0Lc1?As@m}Ve$C$PZ6D%{o*UYj5MX@#9PA-5@Uo7 zkgF8}yM~EyXlmd66ENfx;H0M_Wf8%y|B=;cwSrBOQGe*?Br-1w8Ux=f!TAR20MBA@ z=b*Az9y&(g^)n#o!v3M`$4@}LPu&A?lT~(|y6#u{BY)E%wmDwb?=|^s;7@DA&q`6{ zHc-VgA7*T?N&PDdRpR*{*iKBous9zBEbbFuOzj@$$lBJu6ezK(jv=EMd&RXMdC=?m zVAmFp*!)*&)F5Djr(audQfwi=?4NR7L?GHCS&F8LVWnY#TT|?rIxtan`(*z^@)8*eI*ht0k|#vE^>kt zNyPKc0RO(q%o~glvMg1EpL6wJj-zBEl>GUA8*+3U_{&aaG;b6)^Y^)3hJXiPGOsgH z5XQ%;(9g|^zz236Mef9sF3oXn6%YNHb_eVJF$VNb^sfe!=%Ke}H~US>PyDMI_!zNh zhKkcOh6J&s0~$qEMGaqQY0J1%%5xil_2I3Eo;&4a9nSmjm+7?-y-nCdn_n99cAh%nXJ za_&KRE?CLmmi@<#?<4lBWA9I%ue}kdI%&$8)rg0w41#*XMf}I`**{I)hlS{m19D3^ zIAhXzPMgXxX8AWQo4RguEYP4_6lr!as26Kqxo2btDlZT@t< z+Zoq$HAY|Qqs)sSND)m3DQyrcBKtcJ83c?G!bIUQ{lQKnz7q(O?TOwP4x?4KC?HZy zzI%L%d&VeIjQU7N*=FPOeMss|tB7yJbMX7L+aO3D@+>e^h>9Gl9EFGCM=V}S=xi>D z$f1m=!FcY?@Mwghhm!6x3@6z?*ZBVpSYW~qNFoz;Yuvk z5AjAK9*?`4xR?sM{n>XO$vV?E2cvB}rpFX2SGg&va0Oy8n$Ykup_)`iTRtLzv7I@x;N8>rKXPGu0Vsyq1w=}`|-_INGoDym^%QeqvtHhRS zKz)em#uF^6weEe^UznDtvv?_ZVnKaM?;IQOfMZy%%gX&EK%5LzsLA#b-#9Nxg0A_s zcN-o2?s~pdIb6x=o7iE43=)2kh=B-WIno{e$&kJ3c`#vg>99 zQWc8}hA^|8L&a+TlOAwE1lyFj6>IYR-taFvPH=po{WP-#!t`vL2Yv!rN;e+L4hx?; ztc$wrptOh_{e`2LKuu>H2*LD|MGy{wuScp2jpb}0)NAOZGuM-0#_tm6+T(2>9^_Cw!p@=xubm9N%KMzALaqk!M?V48iKTr@}oDFWnH@#l$; zR2u7-vJa{+vWaNFS#TKw>vhE>Qrqx7km2|EJ@mD1>S1#N07niq6XrggSUp)ZqhkCM z!?*XJE+7hHJWH_Vni7r*!cM;DXcck-Tiu#^WR~nbrx($7q)xx&Y?!$^ykDj`AThBI z9UWcz{JThe$Ie>9ccEyz2b4vsNT}%fR3eI#B=AVnJnA`wxJBFQ?S9Y9d zd^D(lOKiqid`b$aUo}aY8ey>gfs*iOBEgA2Jh{`6N7h+aFGtV7BeC8XN_4@=6~gO! zwy;f`tn@;#>-S%A zVb}>46Wxy^!#pFmiWTk&Y2{WBu+BW(Nq#MuSF~7%(D*~ouu;}Xb&ev4& zuG3NNJMHh$sA@=2ELW3(I3=4WyYNFhoCEhx+(uz0YHrT5V4zX7&AVt`L_KB>{*jl5HBhEAK`1@)6HV%%P-G!n1 zdt`zXG_}E)PB_F8Gck}^VM}>3r=Wg^2k0dWc9vm3^*a-g3e|r@HcSIh0CKSY<=Kh* zn2dif2wXDoAGhM8;#UEX+h17s^~VE&SW;yzI%&YMd99=jpNXd;Gl6L54Nn)%2L4Ox z7p{mYL1vjQ`+jHg>jmp>O8C`lH1X?(W9)NJWkyaQ;ZO=&%f?$ z9b1f3AzJy`pk}7Z(SAV{q}1lRCi5}lUl?Vs)girssBT=(m4n)@LvN>nZ^&N%xjiqu zyLbF4;r00;K$BWXg8s+w=NY2lA4CA^X!D#BczWPzTDZ`pqFTa1q31Qy)6@jy_AH}^ z8@&mz)F=j$D?SanM&zPG+-nP`&w6-YA{wn<=rsxbupBk>a!kR4(ra3pquUtcv2SF0 z-NDNn^tD-q&**M^I-PnqW@b>=A4t76#SF&DLeTHi%;T-XvUF1jc7iO$5}P1^@yXZa z1h&i=En_xSFmi*QS5JQ^U_>*|)dah2xT-Q|&x^5gf&7v_yis2nXkIv~U~=3+%~poF z8~ky1?~D%79ehDP*{e{Yafsc6J{iVTXr5WNKQkq^2-Jf`a);J& zcEs#xrA>e`B(tENL1{q8c*2%fIN)xEYzNE48NKEAquLq?~w6u^4^Y^5TEU zBJky{V||6A(;}j`Vrt)+AYN9{cbo=L*ok-bJX&o-C-oIM!J>2*%KItt0M&HUiDt+s zt_17mOoQ}1&4C{K$ZQ*!C0Z_+tYIrgh43cU{{N=Ry>goF+TYFxievX+cCbAw;pY7s z0?mTzQP@!6PnFp{=a+U4)trpE@r#e`KFIWhlN}lJ2}xfTZYOeCUcV%rJCKa3YL3)h zSC=+2-)9zH=DA^$U_cJXz%x8e8M9t)F7_W}#NdsD zb{~W?LZ+=9n^dhO&BtN<&smzn`0rdPVIX1&s`yDp?`23yfCrQKKQ6VY9Ou{sqjzoT z`Y|YVucx=vkF_H1hs$ILTmfqzR$gv6HN%uf=C^#@ z(XZK5gK|#ld<4a%z5h`{pBF5qouB_ti3h<=9`vgjvX75%K5Z)5aW?u#gCwFi-SYO7n601QVjeyR3v z!NdL6l#rPP^)5>Hl^WV-jv6)IV;cCWAl zo)*9r<7pZ~LuC`z)reEV+!#Y7Bo)tCCgix+2389c-91Fo@i@;cFV3{@BD4a6g9`MX zv~*uf2hXc;wjLIBoP|ucv#7=S8GA2NTwIbC`RSTZs4KkM6M90{d$fMaJx|=FeK=ny~yEd$YBT9HD!|0;U z&`j|IYA;-du3DZ~D=xvM0ks#I(iiqZJ=~Bb+s*TL>BFi zusyq$b9twrvO8#r%=q=O=^T&dd~N%~lGRDxM;KHYgl4ODG9Qp|h?SmJC207~#3E@I2Nwc+B2b@D^?7zfltIRo)|lgY7E| zsOKG)Dzt>Jige`O>gLote2NQP7l+~DVnk!f=`2OIdJ%x~4nxw`Pgyve+@GiDuTES* zDZUJIROG^4#$Cc^&|oBq*k`GX#F@LIWC?_5LufBtf@U~)Rwp*Ebp}G!0?Fvs_{%x1 zqF4s{IpBS$KM({GYQt(Uq! zTzVLQJY_c~?4~5)@2UdiE^?2y>ru44P31u!1H3wtO@x$HOx(-py>b$x5lSZ!*eGB^ zrd4gI2{8wv^4&dKgTEXcBIoApywMWqK4V5w_oC2LXMg7rgRJcWG#dk}%cDSc^*fWQ z>c4_BH_&ho+{w`QGd}R2h9sa7`-=eLb%G?8$zy^;VigW;T`%oivDNRJ3qxS;uQT%2 zHpvq9Z?iR44%AFInpY{Cw7g~qz2wd7q)?MA)nM*y*uu(^x#`x>c7 z8@=D)12nrcU`ATfS$};00SP$@e|=X)2U;$G<_6{`rM@Ainghh=fL;6%gfedk{8=O> zp_Jc1^zQ`*n9u}U9!fT0vL&P_tP0XS{^P$TmRk{4hw|s9*75>lBAJ~Gn7xG1_Q`=R zb7GXchRDZE7id18BL&C2v2SJmhtB@`n%tDkzZ&;JR#WVD3-_QSQdTkMsBu`9w z!Sp02F|)q15vWnqgYNCu%vv#>CW0j2f=#$@wQk>;3Gm||k5D5!+ou^r%pNC{9aE2v*H`NrM*P=TVbtT&~oEfmyl<>Rndm+YW^r0z}VCMJ?3V+oeCHs}c zL1o7Qn68>rsOo;0u+M_l!}lShwiSGFok|BQR|G5GBDcOQSccq-X!QVcA z_&(Ffv_=)&FcvzWKAlga1cMmwl<77jMiN<9>%y|UDgw$VNmYqhnPvti zZgJmR8~NhpdNe|UC!`v6e^x=Jn=k5;I_qrf6~Y^Paz}f!mA4Ew@dI*;KH0%A^s7#O zjw#{MaV0E%{EPGHSxH_mnqIVSkRGQF+}77Fw}FMi#0DAo>8Y}x$6LXrqjmN{rm1)T zmylYqnA=qWb{MfgElhjk4;f$emMrcFOc4sQJ5tj9!fst2**>t@Gvb3n)#B#yzcKi2 zW$q$)rGa8fWXKq*`t7_u4&F}Mi8@QE2jzJCL(@>184w+kRYldx`)cgdl*`Z)aA)<@ zI}LB`{}jpe6x!4zyYOlvTg^finThyVMwhzbG=ri93&VO^(1;@jAy*=x_`42kPoIbh z6r#r}_x9gW#V^DBf=pii!rd`_GiL7|dLeQ6U<6fxlq9n2O-5QSKvHFw1PiR1iO}}@ z+R%2&Lw*S-%t)CKS2~+4Va}{jxFx9R<1@Jd1EHnVG14cV55txTk?wXTgUn{|MR!6F zspV2T!rbqh22oYCxO^Kt3J{0#(bzFO9=IjDPDG@VWKnLr3udwIr&W_f$FRnL1ZK;| zZ0qB^R33)wiYNc zXu|!yTOZLyU#mf1IAx9xu>L?*3XWdPFlxJN^@0)^jO!SX^(c*Y92l}$)n{h`+S`+! z>b7n~^_$9$_Hb(ci4?Qe+xc!-+H!)%@tERU94zKxq}%?*A1oqEr*DY1FEDYeVIOW{ z+N`fy8@ML>0b-N%Ke;kk`cDr|I~r-@6@)6Bo!p=VDXdb#W{;3RK?gs)q8x@#gJKPD zATDP#C#8xLw~9O)ibt~YXEGfO^_ZTqo%f4SdB+)UrQ*u75oPBz$I@@4T&E{lndS_DFCfMR&~^symh1!3}_g zuDUQ12k{s+jnODPK{qXN%Y#brA!N!$xziU2jgCo%+4k+Nzj`e>0dE?@rj3H;VDWcyH#Djwc=y37zm%Kmj%Y(QeTL{fC?zGw>FIgO zKk!P3B1V!ZmRi{e(_yX-VF*);N<6IFE*h&@$TmlOS4j$j2NZOR9s?(C53tB=2p0Ww zpg^O<6qSOkd{iL*aLYWmuTk?WdkbG5L)Y=?sc4tqdAS}I|Hk;2qlImE5Doo=tqp5q ziR{2TW@$hA*C%BW9j~Q$FAnojNLI4AqPzmV)bUtjQIv(ITfV1%s!wB0t(;)f-@iTtEF75acgbpR-{edaM$kaz_=y>4~P zKA)Z>J3|UA|5$C0M`@2~^*EB`M@g>ffCBvNB)GI8j9@i9Wj);M>G|JQ_Y9X2)*f($Ibd*dO^V(W!c z=}VHiFVi;ALOLB=R5!}>?^)au>SNmI)SS2YvbP&XL&b|!w7O5q>oGzL&Y%9_dv?0G z>VB?N;j=WI3E#@?aHr<8B_%GLwb=%T%5h|YLN^e{vO^yDz3>c9Lj?5mq43Xm136eX~#qx31o^*O}B3i8tN4LSD@#{o4+23EkFX{|eHGNUoozS)J{Y*_}|!xY7k zxG~%1%Os)|_9Iv8)^vGuna)0rRZc739OSVB$p#DpuRjM=z>y&C4nOd9R{dV&R>w88 z;SU>8Q!-Y=i1(y}l|`tTq3?(L&dXs~IjpwpM@10yUCs6FdWc{QB6v-AUK{0FVG*u~ zsE?|3&$~F&TdnFK~X>ofnZ}v&wd? zy|LDP6ko-zQD~{Y;it#K>x`>_F16;-Akf3geMvjkR&F_6=XlE<7O9MVh#v?h6$~LK z#2&AmQ^yk3}|ePuxH2gD&6oB(qLTVO6i)8AgOL1dN~4HvIk{%1ro zP@Ls)Gh@KVuWvHQK9ojSTzo--vP4<`J4tqR&4ix49m}uZibD8GWx2h#ti+&`Ocb35 zth8p($J5`w!W$z~%IK)7Gm5P!gn)$CJ<~REh)gp314P&`U;wQ4)(5$2cfaXz1WF?5 z-`W*ceEn#l-;T7@8pZTHIR7(R`%gAAI{`g72LUO`jVYi|LIv}BbW*U!S>J9Ju=|~p z&zLAXHy29OV{eX z9#5eYot>0NS_^BWduv8G4_U8Ag|Gop= z0%M>jCQF-0Oah)XCdXfNV~ZY+PD3=9ypG=RM7lr@{y4)-GKCHp{kPL#FoPAqv&+Fx z@-^@7;Mf3vmG)~K(hCmPO_v>hZbE2{0SRE4%o(WBkmgkQ?}cGuxC8&q{5iT%#3>ng z%KN{E<}<1}9~@>KIp@9S(%uIy{)GS|nra1{wXM$r1(1|s|GO&?1q zIhKW@@3c{}wsHTxjTf9v!k;21x}9IrVHuY1}(Wu1Ra{WYh} zEJ82#HBOhQ@qCfLe>%S31F9Q6Z~AY$ z$EK%c)}NR*7gv4nEM@Ai!}gQt?ikOItsnfzf*J{Im!#mA^{&Cds3F06r*IzELH+k8 zJEv+F-X&wJRFJnr5pblFjuFlFTdvQ>g9C^Qw1^Sigb&W=`Fr?W>s(mr&=S z6~)x%^A;-$E7u*e!aF4IHJ%I)gHP>G9Uh19&5}gD3z*;1&nIxo-!yUeP9RdxpY4vB zfM=RAbpkRXYC*J{T+47#U$vY2%Spww2wPj@-xLb%F|K8i$?p*y=Cgq6+AGOEr?;*C=7KSTlS#D+Ps=4S2#*z4UU6q@fy4u% zx*ET#-qpe#mLN6Ri`Vw8xW^v$*=BklN7%0_cHH7nQ4ZR_UH&1|i#rYQ)lnbQZJ7@o zr#aSsQClJ_g_7R(2l_#s(uA?#37HekmW>m!{#KO3xHoBP{#j-A}$BM295Mbm!SxiK(HuQ==h@SfrHOng%Q3VStd zIs^F$s1%_j1Yu{(I_E32^3T@AQ!~1)_OH(M;bu#x9iIFPzR0nF+M;?mq8H(Ru3A`#;JX~@_jZZ?+88YY6gF>rGnb#jR5nj{;y9GRJnz8;1D&rVbE4h2nttWND3&_>| z%v|G^TY;9Ku2R53FZs()?jNgLg6opCL&U;_T%8eKD*}DO-M$&Tp}Svg>Qc1%XnBHq zmkp+;Kc>)TtTp^0+}j;n?qy{tE-23Ex9)b9BA(I$L;luFq+={VxtE>CZv7>V#9-3pE7-U2q+uO{VxKQ zmRqAO#86BlwqSog50hL>Vz8i634WkYKQk^oxB5mT_NQCEj+KW0AQZw7&F3mnxLpBfXBVS4mQcil_bFKp~2B%X0LcA_*GwIoyr z*CQRps5io=_krf8+y~4gW^{%4qg=(;Rcxh10d;rF}p88?;!1X%P)%7ECuijRZlmAl(;OU)j&Q^c5X(G?Ke|n#DOU zy0d?zi)@D%#(B*?Aqfj@zXgLMN-I^k%@-**NLfPyVI|ctH)wny@qAKnQ6F0ObusoM zfSi4ptN+r^3kTlJr}VDAuq)F=)^HeIAFy|Gt5Xd_ql3NzWwz*%6goJrkL{y2KDGfK z{v6d0;RM8VF<4y657V4rSEE0s>i#arc0r0TX-zDPw~du_E&4epWpP(er=3qv@vS~a zHW`!(43622?A66Vj0)7=>2;XZ;HQ%z{MA8~&wHGS1x>J>D_3I6v8$X`RI^u!XzoZd zl3=-jHxo^LeZ}u==R}0JP{#He-mVlo>mnz%ld_$u0G9%MiSlk?j1tksA9>dLks2&$ zJvf^Tqz=5~WU_}-z{L?)zKD*W*&eqvTs6lGA~4(nxN!66ISRj?Y}}Qj zJbmh3x@JA+<^mycx&_IJsZ$sQX=LxX#eR>afO3YU0%0gd$s+eKqiFjmB*W0hf-2mO zsca{-YROUJZMyqr@HvA85C2I;ZqUcT*VsbJ%yERGP2KFP43B#`Ww)3xHUw!ie((iWXbFs)=VMmUM3f}|iPheL*)?EO!-3TG zP1%f4`qoGa#G#b~ApT@x?kge>W2kb^XGY->JjeWjW=luZUU?+}6F6;fapH!sj1 z^9G5`zLwv5P?rWryS360Awd$^`zoV5;!0F7V)n{{uNLL$PyNE8n(@)FRw=SYO}LPQ zinuOB!r%;;OViV_@fSoiR~l7s2vSc>*v@G0yERtu8=1zxUySmAdmV<*kRF-`_O;WO z(s>Tt(yNpVw%~BAS)8l&xC$bQm!iRAYCX(o3;<{(vj3P>msbpo1p&3a?xYWS4P01| zdo2ykn;||bBF&3rCex{ZfKwlE@6Ah}>RX%5VvHnFCXQ&R)oAQc^^|3n#l_ykdP~0i zvAH~%5;|jUJKfk}+(=^n=g^F0xnCOTY-1a}@)Wm=YudJvn=cI}nPC^vytXcUNC6#= z3rfl>cQXvj${58~#h4dwdK@2rL!rMTLf(&O4F@ihgRBJkWGb-O1RKkhq6D2e8AsJ5 zXIQ_#w1tv#?GQ}u^cCV2I2(Ktj*qvgd9iGCs@EC4$9(ay@S2El&X>=he&8%Isud8h zkyy=p>gILsZ|}BkPcXxB$F+GDU!k{1>`Sqn5e!a-$xYiAH=6b+yvcLt{_4)_Ty1xB zif<&*LIqfgmV5@gJ)dPS2;{9W)Ido_r(G`WtI*D%86Qaz%6krvJ#`YxyRZx1t~U1m zzW|&MWAHnlS=(y=VQ1{PLYu>PzrvIGELXB07hibQs6P%Oc}=;mtoWI{6P$+~HP+cZ z+9uWufuGuW+w2BmFTb1XO7T++o22~@;)Hz@&(9I$oA{o76yC)bNNTry#W2%@^w)S} zM9(d2^^5~E67)#woPt*OVY9=Nk%(ED z^mw-3ny+YOsEdGOil?WCGFP8f585;d&Z4Wp)x z>DNwDyS%^3c-xqqJ*w@ zOr1EieHgJ1YxTT+>gXFBA$0#UbfJ$`f++O_>7jL~RG&1L6~5*oaL@12Z#zGvK<;+y^Q>V2O+@T1SFCiqv{)c;aY$XkMCFN&} z7H6K@edbP*b+n$dijv~@$*BaLkC1~N2lz{Zfo71)`k8>F}wY&A+F@4>mW=FLdFk#B@?t&s!q5YtHrrpt7 z9!YD@e~F(NCF18BLEMS*Q+pd=Jv#nX*p1K)3_r=QshSoyR(Dj9pb{s|qW)gs`{QyQ zy^HA>xq2G^)_KhIYhYb*x_)N5q5S(a3$WrYZ;Q{)7^Ne=<$R%g@9AE|C@jlN@9qH| z2PqMsZwixvb_a&L>L~Lu7@7sP&uk#t!LAbV^B$MFBI*2WmZHg~rdAy<~FX$E7A&xPBymuK*h& ztL0m7y=6`bh*L&I`_{PtMYfO4xkSj{XSh(tlSj8%;ftx?*c4 z{gpq2eY%kNNk1wcC}1+^_pC%7A=B_YMklzjM*em(Iq1JI;@M79FQ~YgykcF(FAx8q(ftH5i6MXdY^2JQkipL?hw-x}1F|Dd<_l4%oQ zIXa`jQYiR|!0A z0QyUaVvU7Ph2RD`_-s|lf_aRc<7$${bro|xlT?#Elor2;Q+G{0Ng=;e-Dsr?~S6cyEIDw%6Uiz+3jCfDmRr)f#ZM3F6gm=tz)$v8h)Vpm@!J)6*t~?Bq`iC% zCvNGwU{GlML>Dg`JNF@SWm>?Km7<8QG4o-m=Y)aIO@xNun6@yY!3^zMaGX#zF0!9m zqXh=Y`usY{IP9>gZZK%M;{|blCXUV{_fDI3Hiv(5DS4Rt8&n;HR{7aa-f$a{4$=+h zIg6jzOYpVFpDz5o1nlD*SKwjuQh|I^3!!+@Cn)iqLgZ%*dBP0y(c=z~^~20TZwKnE zvYX7izv)PV_|ZPfv{j9$A;rDOdz~smml6HkXK7S~?T2Mf>W`0WZjY+kBjS2w_y z77D%@s_+O8jh`A>8ZC4dM02okns8!4Eu|p}J!JmFX$rSLI3L2!x~+R&D@~v?eRil; zl!hLY^lU0V&lN)AEZX3k?qR3g`Kfs>Uj*D!5)mNPIIbD014-vup%qMijs!T~I$zh; zT=+>*iH*0{6?oWhKV!xiF_2Q{UgDI0L**v|$(h)L8=08T?F2M>*$NBAvy!@{?le8Cs+(@UZ`WYVkLJtnVd2+8l?8Oo;rXl7jA6 zk56HtjA`-t`qsgVkQb+&3Gf#Ay%b<4-#UuHa8NRtBIc7W2ck5zz*(R#s^U)MPft?g zYo5<0oW4l%EUIVX=(M9mjGf=VRWhfwC{NG~em+1RBT1{9O9P#`@Utc}-?{=1`|T&C zxXT~c(V_C@2ufC{{Omw3pQpcGAd8MI>mu{8Vk$~No+Ge>yqiQmf1hjo3?0r~79sm* zb74lEwudA8D*T4-b0q83V50;?%H^3jB9h~_=Bi95>HMshqwo4YN!VK1jh}77#+Kg+ zF2KWn`)N^=8FWYg7aqYbRDNoI5jg-Iy((t%fs+sc@@lQBkM7HV1yGMZ3kRKrp=XZb z`^5$N(Fu2eDH2y5}ZTtOC+IR6Ef}OI0PaJ_Pd(=NzRvGF0@9eY9V<{xWwz=v+1WI23XdyG6QJ5Q-WW(>;0)Ri5{L{$2S@%8ek>_#mF2_MGHFFxBtPD?D#h0}s3QGy3D}xa*eqYVcYM z#>$kygJYNk8yDI00yy|ths;n<;zX0p??25aQfA@_W9w%T5p&b=Bee)|e||YT zIs-%vC0}N}<{+Z&YgJJXPjz?pE&)HYgXbsV%O~b_H0j+Qo}byKQXYkGuGc>G#drWp z;9=K(-Vg8*gVq&C1D*0k5Gpp;>AHajZ9emWH~A}Y@)H?%Fn)(aOvRc5%yor)#Mc8C zvIzHRjT4_1M9eL6M7512!!-M0<{(?AL?VE}EuHm+R^O<;hB1LbEUEbIIBHa+TH#a;+>P307ZY9dKtCoPD+d|dzqG<;3aKZ3^ zzt4^a9BR&eWHJ(BYhAvK!f zp&L?HMf&#bwx5JNK(>8=_YBGWS=34rX)JYN7U#@nEE1ok1pIV2ZE%pz&t02s$`JY4 z;5ECWukxn%b(h&$spkhF-^L6|Kgnv=cjA~njunN{Y5ne=HmYNk-;Oyc zs!iW}rcWBxLv8O=yZ0SaCywYMwt42jannX$RzX^&6e%`CaggnoO2E$rVKqYzjzdB6 z^Hd|;dXgSw`=x9#1R_VS!}e=fF_I(vv?+Qcz&>dNY`?UX3p^x}+pt9)Fhc}{g(Cd4 z3B~s@{)@?oDDGcL4tWi78@7=|1`X|=LEoyRrB5+xp963(M2NmK3IP5d)UvYt*itcU zo`go%yyJx;RUqR~D8f&ZI-$_)3_cbh^{@g(&vIr;hs`XX=dgJw2}SrxRDk>YUO)O< zJn#7|^hRn<|MNuf{LEkXh`ZxnuoU5^rgb$xmG6QJP2-^ChRuV)t2yJeT4BeMittmT z;kF26O;JlUC8IJNwg^8pZCq+*LE6s~OYgz{o5S}1OA&tlx2TNx`9A<+Ig>CI0s#L< BDct}7 literal 0 HcmV?d00001 diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index 37e053123..bcd60158b 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -7,6 +7,7 @@ # Imports #------------------------------------------------------------------------------ +import gzip import os.path as op import numpy as np @@ -313,6 +314,11 @@ def set_data(self, *args, **kwargs): class TextVisual(BaseVisual): + """Display strings at multiple locations. + + Currently, the color, font family, and font size is not customizable. + + """ _default_color = (1., 1., 1., 1.) def __init__(self, color=None): @@ -323,12 +329,13 @@ def __init__(self, color=None): self.transforms.add_on_cpu(self.data_range) # Load the font. - # TODO: compress the npy file with gzip curdir = op.realpath(op.dirname(__file__)) font_name = 'SourceCodePro-Regular' font_size = 48 - fn = '%s-%d.npy' % (font_name, font_size) - self._tex = np.load(op.join(curdir, 'static', fn)) + # The font texture is gzipped. + fn = '%s-%d.npy.gz' % (font_name, font_size) + with gzip.open(op.join(curdir, 'static', fn), 'rb') as f: + self._tex = np.load(f) with open(op.join(curdir, 'static', 'chars.txt'), 'r') as f: self._chars = f.read() From d75902b63383c8d7c15a54492b44efdb858e879a Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 5 Feb 2016 14:28:40 +0100 Subject: [PATCH 0985/1059] Add text() function in plotting interface --- MANIFEST.in | 1 + phy/plot/plot.py | 7 ++++++- phy/plot/tests/test_plot.py | 9 +++++++++ phy/plot/visuals.py | 25 +++++++++++++++---------- setup.py | 2 +- 5 files changed, 32 insertions(+), 12 deletions(-) diff --git a/MANIFEST.in b/MANIFEST.in index f42e27c56..35bff6167 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -4,6 +4,7 @@ include README.md recursive-include tests * recursive-include phy/electrode/probes *.prb recursive-include phy/plot/glsl *.vert *.frag *.glsl +recursive-include phy/plot/static *.npy *.gz *.txt recursive-include phy/cluster/manual/static *.html *.css recursive-include phy/gui/static *.html *.css *.js recursive-exclude * __pycache__ diff --git a/phy/plot/plot.py b/phy/plot/plot.py index 3c02aff79..64b17760e 100644 --- a/phy/plot/plot.py +++ b/phy/plot/plot.py @@ -18,7 +18,8 @@ from .panzoom import PanZoom from .transform import NDC from .utils import _get_array -from .visuals import ScatterVisual, PlotVisual, HistogramVisual, LineVisual +from .visuals import (ScatterVisual, PlotVisual, HistogramVisual, + LineVisual, TextVisual) #------------------------------------------------------------------------------ @@ -113,6 +114,10 @@ def hist(self, *args, **kwargs): """Add some histograms.""" return self._add_item(HistogramVisual, *args, **kwargs) + def text(self, *args, **kwargs): + """Add text.""" + return self._add_item(TextVisual, *args, **kwargs) + def lines(self, *args, **kwargs): """Add some lines.""" return self._add_item(LineVisual, *args, **kwargs) diff --git a/phy/plot/tests/test_plot.py b/phy/plot/tests/test_plot.py index 62f5cbef3..6346ae202 100644 --- a/phy/plot/tests/test_plot.py +++ b/phy/plot/tests/test_plot.py @@ -118,6 +118,15 @@ def test_grid_lines(qtbot): _show(qtbot, view) +def test_grid_text(qtbot): + view = View(layout='grid', shape=(2, 1)) + + view[0, 0].text(pos=(0, 0), text='Hello world!') + view[1, 0].text(pos=[[-.5, 0], [+.5, 0]], text=['|', ':)']) + + _show(qtbot, view) + + def test_grid_complete(qtbot): view = View(layout='grid', shape=(2, 2)) t = _get_linear_x(1, 1000).ravel() diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index bcd60158b..686d060ac 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -11,6 +11,7 @@ import os.path as op import numpy as np +from six import string_types from vispy.gloo import Texture2D from .base import BaseVisual @@ -338,6 +339,8 @@ def __init__(self, color=None): self._tex = np.load(f) with open(op.join(curdir, 'static', 'chars.txt'), 'r') as f: self._chars = f.read() + self.color = color if color is not None else self._default_color + assert len(self.color) == 4 def _get_glyph_indices(self, s): return [self._chars.index(char) for char in s] @@ -350,20 +353,19 @@ def validate(pos=None, text=None, color=None, data_bounds=None): assert pos.shape[1] == 2 n_text = pos.shape[0] + if isinstance(text, string_types): + text = [text] assert len(text) == n_text - # Color. - color = color if color is not None else TextVisual._default_color - assert len(color) == 4 - # By default, we assume that the coordinates are in NDC. if data_bounds is None: data_bounds = NDC - data_bounds = _get_data_bounds(data_bounds) + data_bounds = _get_data_bounds(data_bounds, pos) + assert data_bounds.shape[0] == n_text data_bounds = data_bounds.astype(np.float64) - assert data_bounds.shape == (1, 4) + assert data_bounds.shape == (n_text, 4) - return Bunch(pos=pos, text=text, color=color, data_bounds=data_bounds) + return Bunch(pos=pos, text=text, data_bounds=data_bounds) @staticmethod def vertex_count(pos=None, **kwargs): @@ -378,8 +380,6 @@ def set_data(self, *args, **kwargs): assert pos.shape[1] == 2 assert pos.dtype == np.float64 - # TODO: color - # Concatenate all strings. text = data.text lengths = list(map(len, text)) @@ -412,7 +412,11 @@ def set_data(self, *args, **kwargs): assert a_char_index.shape == (n_vertices,) # Transform the positions. - self.data_range.from_bounds = data.data_bounds + data_bounds = data.data_bounds + data_bounds = np.repeat(data_bounds, lengths, axis=0) + data_bounds = np.repeat(data_bounds, 6, axis=0) + assert data_bounds.shape == (n_vertices, 4) + self.data_range.from_bounds = data_bounds pos_tr = self.transforms.apply(a_position) assert pos_tr.shape == (n_vertices, 2) @@ -422,6 +426,7 @@ def set_data(self, *args, **kwargs): self.program['a_char_index'] = a_char_index.astype(np.float32) self.program['u_glyph_size'] = glyph_size + # TODO: color self.program['u_tex'] = Texture2D(tex[::-1, :]) diff --git a/setup.py b/setup.py index be69360bb..2554c9f38 100644 --- a/setup.py +++ b/setup.py @@ -50,7 +50,7 @@ def _package_tree(pkgroot): packages=_package_tree('phy'), package_dir={'phy': 'phy'}, package_data={ - 'phy': ['*.vert', '*.frag', '*.glsl', + 'phy': ['*.vert', '*.frag', '*.glsl', '*.npy', '*.gz', '*.txt', '*.html', '*.css', '*.js', '*.prb'], }, entry_points={ From e180d040333d3edf103c9aecfdfd960b4250edaf Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 5 Feb 2016 14:41:11 +0100 Subject: [PATCH 0986/1059] Add text anchor --- phy/plot/glsl/text.vert | 3 +++ phy/plot/visuals.py | 24 +++++++++++++++++++++--- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/phy/plot/glsl/text.vert b/phy/plot/glsl/text.vert index fd3bd9941..1a78c28c7 100644 --- a/phy/plot/glsl/text.vert +++ b/phy/plot/glsl/text.vert @@ -3,6 +3,8 @@ attribute vec2 a_position; // text position attribute float a_glyph_index; // glyph index in the text attribute float a_quad_index; // quad index in the glyph attribute float a_char_index; // index of the glyph in the texture +attribute float a_lengths; +attribute float a_anchor; uniform vec2 u_glyph_size; // (w, h) @@ -24,6 +26,7 @@ void main() { // Position of the glyph. gl_Position = transform(a_position); gl_Position.xy = gl_Position.xy + vec2(a_glyph_index * w + dx * w, dy * h); + gl_Position.x += (a_anchor - .5) * a_lengths * w; // Index in the texture float i = floor(a_char_index / cols); diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index 686d060ac..cdf49c09e 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -346,7 +346,8 @@ def _get_glyph_indices(self, s): return [self._chars.index(char) for char in s] @staticmethod - def validate(pos=None, text=None, color=None, data_bounds=None): + def validate(pos=None, text=None, anchor=None, + data_bounds=None): assert pos is not None pos = np.atleast_2d(pos) assert pos.ndim == 2 @@ -357,6 +358,11 @@ def validate(pos=None, text=None, color=None, data_bounds=None): text = [text] assert len(text) == n_text + anchor = anchor if anchor is not None else 0. + if not hasattr(anchor, '__len__'): + anchor = [anchor] * n_text + assert len(anchor) == n_text + # By default, we assume that the coordinates are in NDC. if data_bounds is None: data_bounds = NDC @@ -365,7 +371,8 @@ def validate(pos=None, text=None, color=None, data_bounds=None): data_bounds = data_bounds.astype(np.float64) assert data_bounds.shape == (n_text, 4) - return Bunch(pos=pos, text=text, data_bounds=data_bounds) + return Bunch(pos=pos, text=text, anchor=anchor, + data_bounds=data_bounds) @staticmethod def vertex_count(pos=None, **kwargs): @@ -400,16 +407,25 @@ def set_data(self, *args, **kwargs): a_glyph_index = np.concatenate([np.arange(n) for n in lengths]) a_quad_index = np.arange(6) + a_anchor = data.anchor + a_position = np.repeat(a_position, 6, axis=0) a_glyph_index = np.repeat(a_glyph_index, 6) a_quad_index = np.tile(a_quad_index, n_glyphs) a_char_index = np.repeat(a_char_index, 6) + a_anchor = np.repeat(a_anchor, lengths) + a_anchor = np.repeat(a_anchor, 6) + + a_lengths = np.repeat(lengths, lengths) + a_lengths = np.repeat(a_lengths, 6) + n_vertices = n_glyphs * 6 assert a_position.shape == (n_vertices, 2) assert a_glyph_index.shape == (n_vertices,) assert a_quad_index.shape == (n_vertices,) - assert a_char_index.shape == (n_vertices,) + assert a_anchor.shape == (n_vertices,) + assert a_lengths.shape == (n_vertices,) # Transform the positions. data_bounds = data.data_bounds @@ -424,6 +440,8 @@ def set_data(self, *args, **kwargs): self.program['a_glyph_index'] = a_glyph_index.astype(np.float32) self.program['a_quad_index'] = a_quad_index.astype(np.float32) self.program['a_char_index'] = a_char_index.astype(np.float32) + self.program['a_anchor'] = a_anchor.astype(np.float32) + self.program['a_lengths'] = a_lengths.astype(np.float32) self.program['u_glyph_size'] = glyph_size # TODO: color From e98990fa51c23a30c173246d26c5a41f62edd54c Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 5 Feb 2016 18:00:45 +0100 Subject: [PATCH 0987/1059] Update text test --- phy/plot/tests/test_plot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/phy/plot/tests/test_plot.py b/phy/plot/tests/test_plot.py index 6346ae202..e747e1e37 100644 --- a/phy/plot/tests/test_plot.py +++ b/phy/plot/tests/test_plot.py @@ -121,7 +121,7 @@ def test_grid_lines(qtbot): def test_grid_text(qtbot): view = View(layout='grid', shape=(2, 1)) - view[0, 0].text(pos=(0, 0), text='Hello world!') + view[0, 0].text(pos=(0, 0), text='Hello world!', anchor=0.) view[1, 0].text(pos=[[-.5, 0], [+.5, 0]], text=['|', ':)']) _show(qtbot, view) From 7b78ac91ae863756814ba3ecbbae6875d0ac0b7c Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 7 Feb 2016 11:38:45 +0100 Subject: [PATCH 0988/1059] Add polygon visual --- phy/plot/tests/test_visuals.py | 19 +++++++++++- phy/plot/visuals.py | 57 +++++++++++++++++++++++++++++++--- 2 files changed, 70 insertions(+), 6 deletions(-) diff --git a/phy/plot/tests/test_visuals.py b/phy/plot/tests/test_visuals.py index 83e218b73..f02c78ee1 100644 --- a/phy/plot/tests/test_visuals.py +++ b/phy/plot/tests/test_visuals.py @@ -10,7 +10,7 @@ import numpy as np from ..visuals import (ScatterVisual, PlotVisual, HistogramVisual, - LineVisual, TextVisual, + LineVisual, PolygonVisual, TextVisual, ) @@ -177,6 +177,23 @@ def test_line_0(qtbot, canvas_pz): pos=pos, color=color, data_bounds=[-1, -1, 1, 1]) +#------------------------------------------------------------------------------ +# Test polygon visual +#------------------------------------------------------------------------------ + +def test_polygon_empty(qtbot, canvas): + pos = np.zeros((0, 2)) + _test_visual(qtbot, canvas, PolygonVisual(), pos=pos) + + +def test_polygon_0(qtbot, canvas_pz): + n = 9 + x = .5 * np.cos(np.linspace(0., 2 * np.pi, n)) + y = .5 * np.sin(np.linspace(0., 2 * np.pi, n)) + pos = np.c_[x, y] + _test_visual(qtbot, canvas_pz, PolygonVisual(), pos=pos) + + #------------------------------------------------------------------------------ # Test text visual #------------------------------------------------------------------------------ diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index cdf49c09e..63bcb1f58 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -450,11 +450,7 @@ def set_data(self, *args, **kwargs): class LineVisual(BaseVisual): - """Lines. - - Note: currently, all lines shall have the same color. - - """ + """Lines.""" _default_color = (.3, .3, .3, 1.) def __init__(self, color=None): @@ -511,3 +507,54 @@ def set_data(self, *args, **kwargs): # Color. color = np.repeat(data.color, 2, axis=0) self.program['a_color'] = color.astype(np.float32) + + +class PolygonVisual(BaseVisual): + """Polygon.""" + _default_color = (1., 1., 1., 1.) + + def __init__(self): + super(PolygonVisual, self).__init__() + self.set_shader('simple') + self.set_primitive_type('line_loop') + self.data_range = Range(NDC) + self.transforms.add_on_cpu(self.data_range) + + @staticmethod + def validate(pos=None, data_bounds=None): + assert pos is not None + pos = np.atleast_2d(pos) + assert pos.ndim == 2 + assert pos.shape[1] == 2 + + # By default, we assume that the coordinates are in NDC. + if data_bounds is None: + data_bounds = NDC + data_bounds = _get_data_bounds(data_bounds) + data_bounds = data_bounds.astype(np.float64) + assert data_bounds.shape == (1, 4) + + return Bunch(pos=pos, data_bounds=data_bounds) + + @staticmethod + def vertex_count(pos=None, **kwargs): + """Take the output of validate() as input.""" + return pos.shape[0] + + def set_data(self, *args, **kwargs): + data = self.validate(*args, **kwargs) + pos = data.pos + assert pos.ndim == 2 + assert pos.shape[1] == 2 + assert pos.dtype == np.float64 + n_vertices = pos.shape[0] + + # Transform the positions. + self.data_range.from_bounds = data.data_bounds + pos_tr = self.transforms.apply(pos) + + # Position. + assert pos_tr.shape == (n_vertices, 2) + self.program['a_position'] = pos_tr.astype(np.float32) + + self.program['u_color'] = self._default_color From c5947d845264f3661b4ce8d853195cc0470f9c3d Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 7 Feb 2016 12:13:01 +0100 Subject: [PATCH 0989/1059] WIP: lasso tool --- phy/plot/utils.py | 44 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/phy/plot/utils.py b/phy/plot/utils.py index ff5b3e6c4..f5d3ef70f 100644 --- a/phy/plot/utils.py +++ b/phy/plot/utils.py @@ -13,6 +13,7 @@ import numpy as np from vispy import gloo + from .transform import Range, NDC logger = logging.getLogger(__name__) @@ -228,6 +229,49 @@ def _get_linear_x(n_signals, n_samples): return np.tile(np.linspace(-1., 1., n_samples), (n_signals, 1)) +#------------------------------------------------------------------------------ +# Interactive tools +#------------------------------------------------------------------------------ + +class Lasso(object): + def __init__(self): + self._points = [] + + def add(self, pos): + self._points.append(pos) + + @property + def points(self): + l = self._points + # Close the loop. + if l: + l.append(l[0]) + out = np.array(l) + assert out.ndim == 2 + assert out.shape[1] == 2 + return l + + def clear(self): + self._points = [] + + def in_polygon(self, points): + pass + + def attach(self, canvas): + canvas.connect(self.on_mouse_press) + from .visuals import PolygonVisual + self.visual = PolygonVisual() + canvas.add_visual(self.visual) + + def on_mouse_press(self, e): + if 'Control' in e.modifiers: + if e.button == 1: + self.add(e.pos) + else: + self.clear() + self.visual.set_data(pos=self.points) + + #------------------------------------------------------------------------------ # Misc #------------------------------------------------------------------------------ From 92fbaf2f48ff965da1540d53dde3403d683e8dba Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 7 Feb 2016 16:34:23 +0100 Subject: [PATCH 0990/1059] Add lasso --- phy/plot/tests/test_visuals.py | 30 ++++++++++++++++ phy/plot/utils.py | 44 ----------------------- phy/plot/visuals.py | 65 +++++++++++++++++++++++++++++++--- 3 files changed, 91 insertions(+), 48 deletions(-) diff --git a/phy/plot/tests/test_visuals.py b/phy/plot/tests/test_visuals.py index f02c78ee1..5cf0b77fa 100644 --- a/phy/plot/tests/test_visuals.py +++ b/phy/plot/tests/test_visuals.py @@ -8,9 +8,12 @@ #------------------------------------------------------------------------------ import numpy as np +from numpy.testing import assert_array_equal as ae +from vispy.util import keys from ..visuals import (ScatterVisual, PlotVisual, HistogramVisual, LineVisual, PolygonVisual, TextVisual, + Lasso, ) @@ -194,6 +197,33 @@ def test_polygon_0(qtbot, canvas_pz): _test_visual(qtbot, canvas_pz, PolygonVisual(), pos=pos) +def test_lasso(qtbot, canvas_pz): + v = TextVisual() + canvas_pz.add_visual(v) + v.set_data(text="Hello") + + l = Lasso() + l.attach(canvas_pz) + canvas_pz.show() + qtbot.waitForWindowShown(canvas_pz.native) + + ev = canvas_pz.events + ev.mouse_press(pos=(0, 0), button=1, modifiers=(keys.CONTROL,)) + l.add((+1, -1)) + l.add((+1, +1)) + l.add((-1, +1)) + assert l.count == 4 + assert l.polygon.shape == (4, 2) + b = [[-1, -1], [+1, -1], [+1, +1], [-1, +1]] + ae(l.in_polygon(b), [False, False, True, True]) + + ev.mouse_press(pos=(0, 0), button=2, modifiers=(keys.CONTROL,)) + assert l.count == 0 + + # qtbot.stop() + canvas_pz.close() + + #------------------------------------------------------------------------------ # Test text visual #------------------------------------------------------------------------------ diff --git a/phy/plot/utils.py b/phy/plot/utils.py index f5d3ef70f..ff5b3e6c4 100644 --- a/phy/plot/utils.py +++ b/phy/plot/utils.py @@ -13,7 +13,6 @@ import numpy as np from vispy import gloo - from .transform import Range, NDC logger = logging.getLogger(__name__) @@ -229,49 +228,6 @@ def _get_linear_x(n_signals, n_samples): return np.tile(np.linspace(-1., 1., n_samples), (n_signals, 1)) -#------------------------------------------------------------------------------ -# Interactive tools -#------------------------------------------------------------------------------ - -class Lasso(object): - def __init__(self): - self._points = [] - - def add(self, pos): - self._points.append(pos) - - @property - def points(self): - l = self._points - # Close the loop. - if l: - l.append(l[0]) - out = np.array(l) - assert out.ndim == 2 - assert out.shape[1] == 2 - return l - - def clear(self): - self._points = [] - - def in_polygon(self, points): - pass - - def attach(self, canvas): - canvas.connect(self.on_mouse_press) - from .visuals import PolygonVisual - self.visual = PolygonVisual() - canvas.add_visual(self.visual) - - def on_mouse_press(self, e): - if 'Control' in e.modifiers: - if e.button == 1: - self.add(e.pos) - else: - self.clear() - self.visual.set_data(pos=self.points) - - #------------------------------------------------------------------------------ # Misc #------------------------------------------------------------------------------ diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index 63bcb1f58..e78117dc8 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -15,7 +15,7 @@ from vispy.gloo import Texture2D from .base import BaseVisual -from .transform import Range, NDC +from .transform import Range, NDC, pixels_to_ndc from .utils import (_tesselate_histogram, _get_texture, _get_array, @@ -23,6 +23,7 @@ _get_pos, _get_index, ) +from phy.io.array import _in_polygon from phy.utils import Bunch @@ -348,14 +349,18 @@ def _get_glyph_indices(self, s): @staticmethod def validate(pos=None, text=None, anchor=None, data_bounds=None): + + if isinstance(text, string_types): + text = [text] + + if pos is None: + pos = np.zeros((len(text), 2)) + assert pos is not None pos = np.atleast_2d(pos) assert pos.ndim == 2 assert pos.shape[1] == 2 n_text = pos.shape[0] - - if isinstance(text, string_types): - text = [text] assert len(text) == n_text anchor = anchor if anchor is not None else 0. @@ -558,3 +563,55 @@ def set_data(self, *args, **kwargs): self.program['a_position'] = pos_tr.astype(np.float32) self.program['u_color'] = self._default_color + + +#------------------------------------------------------------------------------ +# Interactive tools +#------------------------------------------------------------------------------ + +class Lasso(object): + def __init__(self): + self._points = [] + + def add(self, pos): + self._points.append(pos) + + @property + def polygon(self): + l = self._points + # # Close the loop. + # if l: + # l.append(l[0]) + out = np.array(l, dtype=np.float64) + out = np.reshape(out, (out.size // 2, 2)) + assert out.ndim == 2 + assert out.shape[1] == 2 + return out + + def clear(self): + self._points = [] + + @property + def count(self): + return len(self._points) + + def in_polygon(self, pos): + return _in_polygon(pos, self.polygon) + + def attach(self, canvas): + canvas.connect(self.on_mouse_press) + from .visuals import PolygonVisual + self.visual = PolygonVisual() + canvas.add_visual(self.visual) + self.visual.set_data(pos=self.polygon) + self.canvas = canvas + + def on_mouse_press(self, e): + if 'Control' in e.modifiers: + if e.button == 1: + pos = self.canvas.panzoom.get_mouse_pos(e.pos) + self.add(pos) + else: + self.clear() + self.visual.set_data(pos=self.polygon) + self.canvas.update() From 9283b2452519d3a85afeb1f91ce44d5781eebd84 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Sun, 7 Feb 2016 16:34:50 +0100 Subject: [PATCH 0991/1059] Flakify --- phy/plot/visuals.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index e78117dc8..f880adaad 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -15,7 +15,7 @@ from vispy.gloo import Texture2D from .base import BaseVisual -from .transform import Range, NDC, pixels_to_ndc +from .transform import Range, NDC from .utils import (_tesselate_histogram, _get_texture, _get_array, From ee1bc799df20d1b2ccfa890823b8ae4e414cc9d0 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 8 Feb 2016 12:06:06 +0100 Subject: [PATCH 0992/1059] WIP: update concat_per_cluster --- phy/cluster/manual/controller.py | 10 ++++------ phy/io/array.py | 6 +++--- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/phy/cluster/manual/controller.py b/phy/cluster/manual/controller.py index 998db5060..48b94bf9c 100644 --- a/phy/cluster/manual/controller.py +++ b/phy/cluster/manual/controller.py @@ -76,9 +76,10 @@ def _init_context(self): self.context = Context(self.cache_dir) ctx = self.context - self.get_masks = ctx.cache(self.get_masks) - self.get_features = ctx.cache(self.get_features) - self.get_waveforms = ctx.cache(self.get_waveforms) + self.get_masks = concat_per_cluster(ctx.cache(self.get_masks)) + self.get_features = concat_per_cluster(ctx.cache(self.get_features)) + self.get_waveforms = concat_per_cluster(ctx.cache(self.get_waveforms)) + self.get_background_features = ctx.cache(self.get_background_features) self.get_mean_masks = ctx.memcache(self.get_mean_masks) @@ -118,7 +119,6 @@ def _data_lim(self, arr, n_max): # ------------------------------------------------------------------------- # Is cached in _init_context() - @concat_per_cluster def get_masks(self, cluster_id): return self._select_data(cluster_id, self.all_masks, @@ -132,7 +132,6 @@ def get_mean_masks(self, cluster_id): # ------------------------------------------------------------------------- # Is cached in _init_context() - @concat_per_cluster def get_waveforms(self, cluster_id): return [self._select_data(cluster_id, self.all_waveforms, @@ -168,7 +167,6 @@ def get_waveforms_amplitude(self, cluster_id): # ------------------------------------------------------------------------- # Is cached in _init_context() - @concat_per_cluster def get_features(self, cluster_id): return self._select_data(cluster_id, self.all_features, diff --git a/phy/io/array.py b/phy/io/array.py index 60b6ca8fc..81ad505fb 100644 --- a/phy/io/array.py +++ b/phy/io/array.py @@ -214,12 +214,12 @@ def concat_per_cluster(f): """Take a function accepting a single cluster, and return a function accepting multiple clusters.""" @wraps(f) - def wrapped(self, cluster_ids): + def wrapped(cluster_ids): # Single cluster. if not hasattr(cluster_ids, '__len__'): - return f(self, cluster_ids) + return f(cluster_ids) # Concatenate the result of multiple clusters. - l = [f(self, c) for c in cluster_ids] + l = [f(c) for c in cluster_ids] # Handle the case where every function returns a list of Bunch. if l and isinstance(l[0], list): # We assume that all items have the same length. From d0389649d4c494494224e9524a6b61836b7955f3 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 9 Feb 2016 12:10:05 +0100 Subject: [PATCH 0993/1059] Add polygon shaders --- phy/plot/glsl/polygon.frag | 5 +++++ phy/plot/glsl/polygon.vert | 7 +++++++ 2 files changed, 12 insertions(+) create mode 100644 phy/plot/glsl/polygon.frag create mode 100644 phy/plot/glsl/polygon.vert diff --git a/phy/plot/glsl/polygon.frag b/phy/plot/glsl/polygon.frag new file mode 100644 index 000000000..8df552c17 --- /dev/null +++ b/phy/plot/glsl/polygon.frag @@ -0,0 +1,5 @@ +uniform vec4 u_color; + +void main() { + gl_FragColor = u_color; +} diff --git a/phy/plot/glsl/polygon.vert b/phy/plot/glsl/polygon.vert new file mode 100644 index 000000000..a82e7671b --- /dev/null +++ b/phy/plot/glsl/polygon.vert @@ -0,0 +1,7 @@ +attribute vec2 a_position; +uniform vec4 u_color; + +void main() { + gl_Position = transform(a_position); + gl_Position.z = -.5; +} From 952cca84c86dc74e20eb5443c192d3effb285027 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 9 Feb 2016 13:40:56 +0100 Subject: [PATCH 0994/1059] WIP: update transforms --- phy/plot/tests/test_transform.py | 7 ++---- phy/plot/transform.py | 41 +++++++++++++++++++------------- 2 files changed, 26 insertions(+), 22 deletions(-) diff --git a/phy/plot/tests/test_transform.py b/phy/plot/tests/test_transform.py index c8ea810ae..67c507e83 100644 --- a/phy/plot/tests/test_transform.py +++ b/phy/plot/tests/test_transform.py @@ -42,11 +42,8 @@ def _check(transform, array, expected): expected = np.array(expected, dtype=np.float64) _check_forward(transform, array, expected) # Test the inverse transform if it is implemented. - try: - inv = transform.inverse() - _check_forward(inv, expected, array) - except NotImplementedError: - pass + inv = transform.inverse() + _check_forward(inv, expected, array) #------------------------------------------------------------------------------ diff --git a/phy/plot/transform.py b/phy/plot/transform.py index d7b94d2f7..de6f90fec 100644 --- a/phy/plot/transform.py +++ b/phy/plot/transform.py @@ -132,9 +132,10 @@ def inverse(self): class Translate(BaseTransform): - def apply(self, arr): + def apply(self, arr, value=None): assert isinstance(arr, np.ndarray) - return arr + np.asarray(self.value) + value = value if value is not None else self.value + return arr + np.asarray(value) def glsl(self, var): assert var @@ -149,8 +150,9 @@ def inverse(self): class Scale(BaseTransform): - def apply(self, arr): - return arr * np.asarray(self.value) + def apply(self, arr, value=None): + value = value if value is not None else self.value + return arr * np.asarray(value) def glsl(self, var): assert var @@ -169,14 +171,15 @@ def __init__(self, from_bounds=None, to_bounds=None): self.from_bounds = from_bounds if from_bounds is not None else NDC self.to_bounds = to_bounds if to_bounds is not None else NDC - def apply(self, arr): - self.from_bounds = np.asarray(self.from_bounds) - self.to_bounds = np.asarray(self.to_bounds) - - f0 = np.asarray(self.from_bounds[..., :2]) - f1 = np.asarray(self.from_bounds[..., 2:]) - t0 = np.asarray(self.to_bounds[..., :2]) - t1 = np.asarray(self.to_bounds[..., 2:]) + def apply(self, arr, from_bounds=None, to_bounds=None): + from_bounds = np.asarray(from_bounds if from_bounds is not None + else self.from_bounds) + to_bounds = np.asarray(to_bounds if to_bounds is not None + else self.to_bounds) + f0 = from_bounds[..., :2] + f1 = from_bounds[..., 2:] + t0 = to_bounds[..., :2] + t1 = to_bounds[..., 2:] return t0 + (t1 - t0) * (arr - f0) / (f1 - f0) @@ -200,11 +203,12 @@ def __init__(self, bounds=None): super(Clip, self).__init__() self.bounds = bounds or NDC - def apply(self, arr): - index = ((arr[:, 0] >= self.bounds[0]) & - (arr[:, 1] >= self.bounds[1]) & - (arr[:, 0] <= self.bounds[2]) & - (arr[:, 1] <= self.bounds[3])) + def apply(self, arr, bounds=None): + bounds = bounds if bounds is not None else self.bounds + index = ((arr[:, 0] >= bounds[0]) & + (arr[:, 1] >= bounds[1]) & + (arr[:, 0] <= bounds[2]) & + (arr[:, 1] <= bounds[3])) return arr[index, ...] def glsl(self, var): @@ -220,6 +224,9 @@ def glsl(self, var): }} """.format(bounds=bounds, var=var) + def inverse(self): + return self + class Subplot(Range): """Assume that the from_bounds is [-1, -1, 1, 1].""" From c952bc4a3b0963f75981a01993d5d0b09f4254c8 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 9 Feb 2016 13:50:40 +0100 Subject: [PATCH 0995/1059] Add map/imap in PanZoom --- phy/plot/panzoom.py | 15 +++++++++++++-- phy/plot/tests/test_panzoom.py | 11 +++++++++++ 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/phy/plot/panzoom.py b/phy/plot/panzoom.py index 050429b44..252a12b86 100644 --- a/phy/plot/panzoom.py +++ b/phy/plot/panzoom.py @@ -99,6 +99,8 @@ def __init__(self, # Will be set when attached to a canvas. self.canvas = None + self._translate = Translate(self.pan_var_name) + self._scale = Scale(self.zoom_var_name) # Various properties # ------------------------------------------------------------------------- @@ -463,8 +465,7 @@ def attach(self, canvas): super(PanZoom, self).attach(canvas) canvas.panzoom = self - canvas.transforms.add_on_gpu([Translate(self.pan_var_name), - Scale(self.zoom_var_name)]) + canvas.transforms.add_on_gpu([self._translate, self._scale]) # Add the variable declarations. vs = ('uniform vec2 {};\n'.format(self.pan_var_name) + 'uniform vec2 {};\n'.format(self.zoom_var_name)) @@ -480,6 +481,16 @@ def attach(self, canvas): self._set_canvas_aspect() + def map(self, arr): + arr = Translate(self.pan).apply(arr) + arr = Scale(self.zoom).apply(arr) + return arr + + def imap(self, arr): + arr = Scale(self.zoom).inverse().apply(arr) + arr = Translate(self.pan).inverse().apply(arr) + return arr + def update_program(self, program): program[self.pan_var_name] = self._pan program[self.zoom_var_name] = self._zoom_aspect() diff --git a/phy/plot/tests/test_panzoom.py b/phy/plot/tests/test_panzoom.py index 62479aa1d..94659c29d 100644 --- a/phy/plot/tests/test_panzoom.py +++ b/phy/plot/tests/test_panzoom.py @@ -112,6 +112,17 @@ def test_panzoom_basic_pan_zoom(): assert pz.zoom[1] > 3 * pz.zoom[0] +def test_panzoom_map(): + pz = PanZoom() + pz.pan = (1., -1.) + ac(pz.map([0., 0.]), [[1., -1.]]) + + pz.zoom = (2., .5) + ac(pz.map([0., 0.]), [[2., -.5]]) + + ac(pz.imap([2., -.5]), [[0., 0.]]) + + def test_panzoom_constraints_x(): pz = PanZoom() pz.xmin, pz.xmax = -2, 2 From 280f894794f76395e989dcaeb1ac9ced8c9c7c28 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 9 Feb 2016 14:10:50 +0100 Subject: [PATCH 0996/1059] Add map and imap in interacts --- phy/plot/interact.py | 38 +++++++++++++++++++++++++++------ phy/plot/tests/test_interact.py | 22 +++++++++++++++++++ 2 files changed, 53 insertions(+), 7 deletions(-) diff --git a/phy/plot/interact.py b/phy/plot/interact.py index a439c9d17..9e439c6fd 100644 --- a/phy/plot/interact.py +++ b/phy/plot/interact.py @@ -43,21 +43,35 @@ def __init__(self, shape=(1, 1), shape_var='u_grid_shape', box_var=None): self.box_var = box_var or 'a_box_index' self.shape_var = shape_var self._shape = shape + ms = 1 - self.margin + mc = 1 - self.margin + self._transforms = [Scale((ms, ms)), + Clip([-mc, -mc, +mc, +mc]), + Subplot(self.shape_var, self.box_var), + ] def attach(self, canvas): super(Grid, self).attach(canvas) - ms = 1 - self.margin - mc = 1 - self.margin - canvas.transforms.add_on_gpu([Scale((ms, ms)), - Clip([-mc, -mc, +mc, +mc]), - Subplot(self.shape_var, self.box_var), - ]) + canvas.transforms.add_on_gpu(self._transforms) canvas.inserter.insert_vert(""" attribute vec2 {}; uniform vec2 {}; """.format(self.box_var, self.shape_var), 'header') + def map(self, arr, box=None): + assert box is not None + assert len(box) == 2 + arr = self._transforms[0].apply(arr) + arr = Subplot(self.shape, box).apply(arr) + return arr + + def imap(self, arr, box=None): + assert box is not None + arr = Subplot(self.shape, box).inverse().apply(arr) + arr = self._transforms[0].inverse().apply(arr) + return arr + def add_boxes(self, canvas, shape=None): shape = shape or self.shape assert isinstance(shape, tuple) @@ -157,10 +171,11 @@ def __init__(self, assert self._box_bounds.shape[1] == 4 self.n_boxes = len(self._box_bounds) + self._transforms = [Range(NDC, 'box_bounds')] def attach(self, canvas): super(Boxed, self).attach(canvas) - canvas.transforms.add_on_gpu([Range(NDC, 'box_bounds')]) + canvas.transforms.add_on_gpu(self._transforms) canvas.inserter.insert_vert(""" #include "utils.glsl" attribute float {}; @@ -174,6 +189,15 @@ def attach(self, canvas): box_bounds = (2 * box_bounds - 1); // See hack in Python. """.format(self.box_var), 'before_transforms') + def map(self, arr, box=None): + assert box is not None + assert 0 <= box < len(self.box_bounds) + return Range(NDC, self.box_bounds[box]).apply(arr) + + def imap(self, arr, box=None): + assert 0 <= box < len(self.box_bounds) + return Range(NDC, self.box_bounds[box]).inverse().apply(arr) + def update_program(self, program): # Signal bounds (positions). box_bounds = _get_texture(self._box_bounds, NDC, self.n_boxes, [-1, 1]) diff --git a/phy/plot/tests/test_interact.py b/phy/plot/tests/test_interact.py index 0d72117ee..fe9a1e7ec 100644 --- a/phy/plot/tests/test_interact.py +++ b/phy/plot/tests/test_interact.py @@ -73,6 +73,15 @@ def _create_visual(qtbot, canvas, interact, box_index): # Test grid #------------------------------------------------------------------------------ +def test_grid_interact(): + grid = Grid((4, 8)) + ac(grid.map([0., 0.], (0, 0)), [[-0.875, 0.75]]) + ac(grid.map([0., 0.], (1, 3)), [[-0.125, 0.25]]) + ac(grid.map([0., 0.], (3, 7)), [[0.875, -0.75]]) + + ac(grid.imap([[0.875, -0.75]], (3, 7)), [[0., 0.]]) + + def test_grid_1(qtbot, canvas): n = 1000 @@ -145,6 +154,19 @@ def test_boxed_2(qtbot, canvas): # qtbot.stop() +def test_boxed_interact(): + + n = 8 + b = np.zeros((n, 4)) + b[:, 0] = b[:, 1] = np.linspace(-1., 1. - 1. / 4., n) + b[:, 2] = b[:, 3] = np.linspace(-1. + 1. / 4., 1., n) + + boxed = Boxed(box_bounds=b) + ac(boxed.map([0., 0.], 0), [[-.875, -.875]]) + ac(boxed.map([0., 0.], 7), [[.875, .875]]) + ac(boxed.imap([[.875, .875]], 7), [[0., 0.]]) + + def test_stacked_1(qtbot, canvas): n = 1000 From 509d70bf22b3ca0bc049a619749f2546529a9694 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 9 Feb 2016 14:53:31 +0100 Subject: [PATCH 0997/1059] Add interact.get_closest_box() --- phy/plot/interact.py | 14 ++++++++++++ phy/plot/tests/test_interact.py | 38 +++++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+) diff --git a/phy/plot/interact.py b/phy/plot/interact.py index 9e439c6fd..b533b427f 100644 --- a/phy/plot/interact.py +++ b/phy/plot/interact.py @@ -103,6 +103,13 @@ def _remove_clip(tc): boxes.set_data(pos=pos) boxes.program['a_box_index'] = box_index.astype(np.float32) + def get_closest_box(self, pos): + x, y = pos + rows, cols = self.shape + j = np.clip(int(cols * (1. + x) / 2.), 0, cols - 1) + i = np.clip(int(rows * (1. - y) / 2.), 0, rows - 1) + return i, j + def update_program(self, program): program[self.shape_var] = self._shape # Only set the default box index if necessary. @@ -220,6 +227,13 @@ def box_bounds(self, val): self._box_bounds = val self.update() + def get_closest_box(self, pos): + pos = np.atleast_2d(pos) + d = np.sum((pos - self.box_pos) ** 2, axis=0) + i = np.argmin(d) + assert 0 <= i < self.n_boxes + return i + @property def box_pos(self): box_pos, _ = _get_box_pos_size(self._box_bounds) diff --git a/phy/plot/tests/test_interact.py b/phy/plot/tests/test_interact.py index fe9a1e7ec..7255bd82a 100644 --- a/phy/plot/tests/test_interact.py +++ b/phy/plot/tests/test_interact.py @@ -82,6 +82,15 @@ def test_grid_interact(): ac(grid.imap([[0.875, -0.75]], (3, 7)), [[0., 0.]]) +def test_grid_closest_box(): + grid = Grid((3, 7)) + ac(grid.get_closest_box((0., 0.)), (1, 3)) + ac(grid.get_closest_box((-1., +1.)), (0, 0)) + ac(grid.get_closest_box((+1., -1.)), (2, 6)) + ac(grid.get_closest_box((-1., -1.)), (2, 0)) + ac(grid.get_closest_box((+1., +1.)), (0, 6)) + + def test_grid_1(qtbot, canvas): n = 1000 @@ -112,6 +121,10 @@ def test_grid_2(qtbot, canvas): # qtbot.stop() +#------------------------------------------------------------------------------ +# Test boxed +#------------------------------------------------------------------------------ + def test_boxed_1(qtbot, canvas): n = 6 @@ -167,6 +180,21 @@ def test_boxed_interact(): ac(boxed.imap([[.875, .875]], 7), [[0., 0.]]) +def test_boxed_closest_box(): + b = np.array([[-.5, -.5, 0., 0.], + [0., 0., +.5, +.5]]) + boxed = Boxed(box_bounds=b) + + ac(boxed.get_closest_box((-1, -1)), 0) + ac(boxed.get_closest_box((-0.001, 0)), 0) + ac(boxed.get_closest_box((+0.001, 0)), 1) + ac(boxed.get_closest_box((-1, +1)), 0) + + +#------------------------------------------------------------------------------ +# Test stacked +#------------------------------------------------------------------------------ + def test_stacked_1(qtbot, canvas): n = 1000 @@ -176,3 +204,13 @@ def test_stacked_1(qtbot, canvas): _create_visual(qtbot, canvas, stacked, box_index) # qtbot.stop() + + +def test_stacked_closest_box(): + stacked = Stacked(n_boxes=4, origin='upper') + ac(stacked.get_closest_box((-.5, .9)), 0) + ac(stacked.get_closest_box((+.5, -.9)), 3) + + stacked = Stacked(n_boxes=4, origin='lower') + ac(stacked.get_closest_box((-.5, .9)), 3) + ac(stacked.get_closest_box((+.5, -.9)), 0) From c85d2e76fb29e0f9967c0c421916d8399fc8767e Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 9 Feb 2016 15:53:19 +0100 Subject: [PATCH 0998/1059] Integrate lasso tool in View --- phy/io/array.py | 2 + phy/plot/interact.py | 7 -- phy/plot/plot.py | 115 +++++++++++++++++++++++++++++++-- phy/plot/tests/test_plot.py | 78 ++++++++++++++++++++++ phy/plot/tests/test_visuals.py | 31 +-------- phy/plot/visuals.py | 60 ++--------------- 6 files changed, 195 insertions(+), 98 deletions(-) diff --git a/phy/io/array.py b/phy/io/array.py index 81ad505fb..7694fa1dc 100644 --- a/phy/io/array.py +++ b/phy/io/array.py @@ -188,6 +188,8 @@ def _in_polygon(points, polygon): polygon = _as_array(polygon) assert points.ndim == 2 assert polygon.ndim == 2 + if len(polygon): + polygon = np.vstack((polygon, polygon[0])) path = Path(polygon, closed=True) return path.contains_points(points) diff --git a/phy/plot/interact.py b/phy/plot/interact.py index b533b427f..064ecda6a 100644 --- a/phy/plot/interact.py +++ b/phy/plot/interact.py @@ -227,13 +227,6 @@ def box_bounds(self, val): self._box_bounds = val self.update() - def get_closest_box(self, pos): - pos = np.atleast_2d(pos) - d = np.sum((pos - self.box_pos) ** 2, axis=0) - i = np.argmin(d) - assert 0 <= i < self.n_boxes - return i - @property def box_pos(self): box_pos, _ = _get_box_pos_size(self._box_bounds) diff --git a/phy/plot/plot.py b/phy/plot/plot.py index 64b17760e..4588e05a1 100644 --- a/phy/plot/plot.py +++ b/phy/plot/plot.py @@ -12,14 +12,15 @@ import numpy as np -from phy.io.array import _accumulate +from phy.io.array import _accumulate, _in_polygon +from phy.utils._types import _as_tuple from .base import BaseCanvas from .interact import Grid, Boxed, Stacked from .panzoom import PanZoom from .transform import NDC from .utils import _get_array from .visuals import (ScatterVisual, PlotVisual, HistogramVisual, - LineVisual, TextVisual) + LineVisual, TextVisual, PolygonVisual) #------------------------------------------------------------------------------ @@ -49,15 +50,19 @@ class View(BaseCanvas): _default_box_index = (0,) def __init__(self, layout=None, shape=None, n_plots=None, origin=None, - box_bounds=None, box_pos=None, box_size=None, **kwargs): + box_bounds=None, box_pos=None, box_size=None, + enable_lasso=False, + **kwargs): if not kwargs.get('keys', None): kwargs['keys'] = None super(View, self).__init__(**kwargs) + self.layout = layout if layout == 'grid': self._default_box_index = (0, 0) self.grid = Grid(shape) self.grid.attach(self) + self.interact = self.grid elif layout == 'boxed': self.n_plots = (len(box_bounds) @@ -66,15 +71,26 @@ def __init__(self, layout=None, shape=None, n_plots=None, origin=None, box_pos=box_pos, box_size=box_size) self.boxed.attach(self) + self.interact = self.boxed elif layout == 'stacked': self.n_plots = n_plots self.stacked = Stacked(n_plots, margin=.1, origin=origin) self.stacked.attach(self) + self.interact = self.stacked + + else: + self.interact = None self.panzoom = PanZoom(aspect=None, constrain_bounds=NDC) self.panzoom.attach(self) + if enable_lasso: + self.lasso = Lasso() + self.lasso.attach(self) + else: + self.lasso = None + self.clear() def clear(self): @@ -91,7 +107,7 @@ def _add_item(self, cls, *args, **kwargs): n = cls.vertex_count(**data) if not isinstance(box_index, np.ndarray): - k = len(box_index) if hasattr(box_index, '__len__') else 1 + k = len(self._default_box_index) box_index = _get_array(box_index, (n, k)) data['box_index'] = box_index @@ -123,7 +139,7 @@ def lines(self, *args, **kwargs): return self._add_item(LineVisual, *args, **kwargs) def __getitem__(self, box_index): - self._default_box_index = box_index + self._default_box_index = _as_tuple(box_index) return self def build(self): @@ -146,11 +162,100 @@ def build(self): # after the next VisPy release. if 'a_box_index' in visual.program._code_variables: visual.program['a_box_index'] = box_index.astype(np.float32) + # TODO: refactor this when there is the possibility to update existing + # visuals without recreating the whole scene. + if self.lasso: + self.lasso.create_visual() self.update() + def get_pos_from_mouse(self, pos, box): + # From window coordinates to NDC (pan & zoom taken into account). + pos = self.panzoom.get_mouse_pos(pos) + # From NDC to data coordinates. + pos = self.interact.imap(pos, box) if self.interact else pos + return pos + @contextmanager def building(self): """Context manager to specify the plots.""" self.clear() yield self.build() + + +#------------------------------------------------------------------------------ +# Interactive tools +#------------------------------------------------------------------------------ + +class Lasso(object): + def __init__(self): + self._points = [] + self.view = None + self.visual = None + self.box = None + + def add(self, pos): + self._points.append(pos) + + @property + def polygon(self): + l = self._points + # Close the polygon. + # l = l + l[0] if len(l) else l + out = np.array(l, dtype=np.float64) + out = np.reshape(out, (out.size // 2, 2)) + assert out.ndim == 2 + assert out.shape[1] == 2 + return out + + def clear(self): + self._points = [] + + @property + def count(self): + return len(self._points) + + def in_polygon(self, pos): + return _in_polygon(pos, self.polygon) + + def attach(self, view): + view.connect(self.on_mouse_press) + self.view = view + + def create_visual(self): + self.visual = PolygonVisual() + self.view.add_visual(self.visual) + self.update_visual() + + def update_visual(self): + if not self.visual: + return + # Update the polygon. + self.visual.set_data(pos=self.polygon) + # Set the box index for the polygon, depending on the box + # where the first point was clicked in. + box = (self.box if self.box is not None + else self.view._default_box_index) + k = len(self.view._default_box_index) + n = self.visual.vertex_count(pos=self.polygon) + box_index = _get_array(box, (n, k)).astype(np.float32) + self.visual.program['a_box_index'] = box_index + self.view.update() + + def on_mouse_press(self, e): + if 'Control' in e.modifiers: + if e.button == 1: + # Find the box. + ndc = self.view.panzoom.get_mouse_pos(e.pos) + # NOTE: we don't update the box after the second point. + # In other words, the first point determines the box for the + # lasso. + if self.box is None and self.view.interact: + self.box = self.view.interact.get_closest_box(ndc) + # Transform from window coordinates to NDC. + pos = self.view.get_pos_from_mouse(e.pos, self.box) + self.add(pos) + else: + self.clear() + self.box = None + self.update_visual() diff --git a/phy/plot/tests/test_plot.py b/phy/plot/tests/test_plot.py index e747e1e37..b298e9f5d 100644 --- a/phy/plot/tests/test_plot.py +++ b/phy/plot/tests/test_plot.py @@ -8,9 +8,12 @@ #------------------------------------------------------------------------------ import numpy as np +from numpy.testing import assert_array_equal as ae +from vispy.util import keys from ..panzoom import PanZoom from ..plot import View +from ..transform import NDC from ..utils import _get_linear_x @@ -57,6 +60,10 @@ def test_simple_view(qtbot): _show(qtbot, view) +#------------------------------------------------------------------------------ +# Test visuals in grid +#------------------------------------------------------------------------------ + def test_grid_scatter(qtbot): view = View(layout='grid', shape=(2, 3)) n = 100 @@ -140,6 +147,10 @@ def test_grid_complete(qtbot): _show(qtbot, view) +#------------------------------------------------------------------------------ +# Test other interact +#------------------------------------------------------------------------------ + def test_stacked_complete(qtbot): view = View(layout='stacked', n_plots=3) @@ -172,3 +183,70 @@ def test_boxed_complete(qtbot): color=np.random.uniform(.4, .9, size=(5, 4))) _show(qtbot, view) + + +#------------------------------------------------------------------------------ +# Test lasso +#------------------------------------------------------------------------------ + +def test_lasso_simple(qtbot): + view = View(enable_lasso=True, keys='interactive') + n = 1000 + + x = np.random.randn(n) + y = np.random.randn(n) + + view.scatter(x, y) + + l = view.lasso + ev = view.events + ev.mouse_press(pos=(0, 0), button=1, modifiers=(keys.CONTROL,)) + l.add((+1, -1)) + l.add((+1, +1)) + l.add((-1, +1)) + assert l.count == 4 + assert l.polygon.shape == (4, 2) + b = [[-1, -1], [+1, -1], [+1, +1], [-1, +1]] + ae(l.in_polygon(b), [False, False, True, True]) + + ev.mouse_press(pos=(0, 0), button=2, modifiers=(keys.CONTROL,)) + assert l.count == 0 + + _show(qtbot, view) + + +def test_lasso_grid(qtbot): + view = View(layout='grid', shape=(1, 2), + enable_lasso=True, keys='interactive') + x, y = np.meshgrid(np.linspace(-1., 1., 20), np.linspace(-1., 1., 20)) + x, y = x.ravel(), y.ravel() + view[0, 1].scatter(x, y, data_bounds=NDC) + + l = view.lasso + ev = view.events + + # Square selection in the left panel. + ev.mouse_press(pos=(100, 100), button=1, modifiers=(keys.CONTROL,)) + assert l.box == (0, 0) + ev.mouse_press(pos=(200, 100), button=1, modifiers=(keys.CONTROL,)) + ev.mouse_press(pos=(200, 200), button=1, modifiers=(keys.CONTROL,)) + ev.mouse_press(pos=(100, 200), button=1, modifiers=(keys.CONTROL,)) + assert l.box == (0, 0) + + # Clear. + ev.mouse_press(pos=(100, 200), button=2, modifiers=(keys.CONTROL,)) + assert l.box is None + + # Square selection in the right panel. + ev.mouse_press(pos=(500, 100), button=1, modifiers=(keys.CONTROL,)) + assert l.box == (0, 1) + ev.mouse_press(pos=(700, 100), button=1, modifiers=(keys.CONTROL,)) + ev.mouse_press(pos=(700, 300), button=1, modifiers=(keys.CONTROL,)) + ev.mouse_press(pos=(500, 300), button=1, modifiers=(keys.CONTROL,)) + assert l.box == (0, 1) + + ind = l.in_polygon(np.c_[x, y]) + view[0, 1].scatter(x[ind], y[ind], color=(1., 0., 0., 1.), + data_bounds=NDC) + + _show(qtbot, view) diff --git a/phy/plot/tests/test_visuals.py b/phy/plot/tests/test_visuals.py index 5cf0b77fa..2f60c71dd 100644 --- a/phy/plot/tests/test_visuals.py +++ b/phy/plot/tests/test_visuals.py @@ -8,12 +8,9 @@ #------------------------------------------------------------------------------ import numpy as np -from numpy.testing import assert_array_equal as ae -from vispy.util import keys from ..visuals import (ScatterVisual, PlotVisual, HistogramVisual, LineVisual, PolygonVisual, TextVisual, - Lasso, ) @@ -197,33 +194,6 @@ def test_polygon_0(qtbot, canvas_pz): _test_visual(qtbot, canvas_pz, PolygonVisual(), pos=pos) -def test_lasso(qtbot, canvas_pz): - v = TextVisual() - canvas_pz.add_visual(v) - v.set_data(text="Hello") - - l = Lasso() - l.attach(canvas_pz) - canvas_pz.show() - qtbot.waitForWindowShown(canvas_pz.native) - - ev = canvas_pz.events - ev.mouse_press(pos=(0, 0), button=1, modifiers=(keys.CONTROL,)) - l.add((+1, -1)) - l.add((+1, +1)) - l.add((-1, +1)) - assert l.count == 4 - assert l.polygon.shape == (4, 2) - b = [[-1, -1], [+1, -1], [+1, +1], [-1, +1]] - ae(l.in_polygon(b), [False, False, True, True]) - - ev.mouse_press(pos=(0, 0), button=2, modifiers=(keys.CONTROL,)) - assert l.count == 0 - - # qtbot.stop() - canvas_pz.close() - - #------------------------------------------------------------------------------ # Test text visual #------------------------------------------------------------------------------ @@ -231,6 +201,7 @@ def test_lasso(qtbot, canvas_pz): def test_text_empty(qtbot, canvas): pos = np.zeros((0, 2)) _test_visual(qtbot, canvas, TextVisual(), pos=pos, text=[]) + _test_visual(qtbot, canvas, TextVisual()) def test_text_0(qtbot, canvas_pz): diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index f880adaad..19614e78f 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -23,7 +23,6 @@ _get_pos, _get_index, ) -from phy.io.array import _in_polygon from phy.utils import Bunch @@ -350,9 +349,10 @@ def _get_glyph_indices(self, s): def validate(pos=None, text=None, anchor=None, data_bounds=None): + if text is None: + text = [] if isinstance(text, string_types): text = [text] - if pos is None: pos = np.zeros((len(text), 2)) @@ -516,11 +516,11 @@ def set_data(self, *args, **kwargs): class PolygonVisual(BaseVisual): """Polygon.""" - _default_color = (1., 1., 1., 1.) + _default_color = (.5, .5, .5, 1.) def __init__(self): super(PolygonVisual, self).__init__() - self.set_shader('simple') + self.set_shader('polygon') self.set_primitive_type('line_loop') self.data_range = Range(NDC) self.transforms.add_on_cpu(self.data_range) @@ -563,55 +563,3 @@ def set_data(self, *args, **kwargs): self.program['a_position'] = pos_tr.astype(np.float32) self.program['u_color'] = self._default_color - - -#------------------------------------------------------------------------------ -# Interactive tools -#------------------------------------------------------------------------------ - -class Lasso(object): - def __init__(self): - self._points = [] - - def add(self, pos): - self._points.append(pos) - - @property - def polygon(self): - l = self._points - # # Close the loop. - # if l: - # l.append(l[0]) - out = np.array(l, dtype=np.float64) - out = np.reshape(out, (out.size // 2, 2)) - assert out.ndim == 2 - assert out.shape[1] == 2 - return out - - def clear(self): - self._points = [] - - @property - def count(self): - return len(self._points) - - def in_polygon(self, pos): - return _in_polygon(pos, self.polygon) - - def attach(self, canvas): - canvas.connect(self.on_mouse_press) - from .visuals import PolygonVisual - self.visual = PolygonVisual() - canvas.add_visual(self.visual) - self.visual.set_data(pos=self.polygon) - self.canvas = canvas - - def on_mouse_press(self, e): - if 'Control' in e.modifiers: - if e.button == 1: - pos = self.canvas.panzoom.get_mouse_pos(e.pos) - self.add(pos) - else: - self.clear() - self.visual.set_data(pos=self.polygon) - self.canvas.update() From c08a670d7a061ccd17dcc9c93a7471567576581c Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 9 Feb 2016 16:29:50 +0100 Subject: [PATCH 0999/1059] WIP: implement split with lasso --- phy/cluster/manual/gui_component.py | 7 +++- .../manual/tests/test_gui_component.py | 41 ++++++++++++++++++- phy/cluster/manual/views.py | 30 ++++++++++++++ phy/plot/plot.py | 4 +- 4 files changed, 78 insertions(+), 4 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index bf1e4a080..92d190f8b 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -11,6 +11,8 @@ from functools import partial import logging +import numpy as np + from ._history import GlobalHistory from ._utils import create_cluster_meta from .clustering import Clustering @@ -448,11 +450,12 @@ def merge(self, cluster_ids=None): self.clustering.merge(cluster_ids) self._global_history.action(self.clustering) - def split(self, spike_ids): + def split(self, spike_ids=None): """Split the selected spikes (NOT IMPLEMENTED YET).""" + if spike_ids is None: + spike_ids = np.concatenate(self.gui.emit('request_split')) if len(spike_ids) == 0: return - # TODO: connect to request_split emitted by view self.clustering.split(spike_ids) self._global_history.action(self.clustering) diff --git a/phy/cluster/manual/tests/test_gui_component.py b/phy/cluster/manual/tests/test_gui_component.py index 0db27c19e..bfb202faf 100644 --- a/phy/cluster/manual/tests/test_gui_component.py +++ b/phy/cluster/manual/tests/test_gui_component.py @@ -9,12 +9,14 @@ from pytest import yield_fixture, fixture import numpy as np from numpy.testing import assert_array_equal as ae +from vispy.util import keys from ..gui_component import (ManualClustering, ) from phy.io.array import _spikes_in_clusters -from phy.gui import GUI +from phy.gui import GUI, create_gui from phy.utils import Bunch +from .conftest import MockController #------------------------------------------------------------------------------ @@ -147,6 +149,43 @@ def test_manual_clustering_split_2(gui, quality, similarity): assert mc.selected == [2, 3] +def test_manual_clustering_split_lasso(tempdir, qtbot): + gui = create_gui(config_dir=tempdir) + gui.controller = MockController(tempdir) + gui.controller.set_manual_clustering(gui) + mc = gui.controller.manual_clustering + + gui.controller = MockController(tempdir) + view = gui.controller.add_feature_view(gui) + + gui.show() + + # Select one cluster. + mc.select(0) + + # Simulate a lasso. + ev = view.events + ev.mouse_press(pos=(190, 10), button=1, modifiers=(keys.CONTROL,)) + ev.mouse_press(pos=(200, 10), button=1, modifiers=(keys.CONTROL,)) + ev.mouse_press(pos=(200, 30), button=1, modifiers=(keys.CONTROL,)) + ev.mouse_press(pos=(190, 30), button=1, modifiers=(keys.CONTROL,)) + + ups = [] + + @mc.clustering.connect + def on_cluster(up): + ups.append(up) + + mc.split() + up = ups[0] + assert up.description == 'assign' + assert up.added == [4, 5] + assert up.deleted == [0] + + # qtbot.stop() + gui.close() + + def test_manual_clustering_move_1(manual_clustering): mc = manual_clustering diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index a6b2a9957..8b6c8cae0 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -16,6 +16,7 @@ from phy.io.array import _index_of, _get_padded, get_excerpts from phy.gui import Actions from phy.plot import View, _get_linear_x +from phy.plot.transform import Range from phy.plot.utils import _get_boxes from phy.stats import correlograms from phy.utils import Bunch @@ -983,6 +984,7 @@ def __init__(self, # Initialize the view. super(FeatureView, self).__init__(layout='grid', shape=self.shape, + enable_lasso=True, **kwargs) # Feature normalization. @@ -1199,6 +1201,7 @@ def attach(self, gui): self.actions.add(self.toggle_automatic_channel_selection) gui.connect_(self.on_channel_click) + gui.connect_(self.on_request_split) @property def state(self): @@ -1221,6 +1224,33 @@ def on_channel_click(self, channel_idx=None, key=None, button=None): self.fixed_channels = True self.on_select() + def on_request_split(self): + """Return the spikes enclosed by the lasso.""" + if self.lasso.count < 3: + return [] + tla = self.top_left_attribute + assert self.channels + x_dim, y_dim = _dimensions_matrix(self.channels, + n_cols=self.n_cols, + top_left_attribute=tla) + data = self.features(self.cluster_ids) + spike_ids = data.spike_ids + f = data.data + i, j = self.lasso.box + + # TODO: refactor and load all features. + x = self._get_feature(x_dim[i, j], spike_ids, f) + y = self._get_feature(y_dim[i, j], spike_ids, f) + pos = np.c_[x, y] + + # Retrieve the data bounds. + data_bounds = self._get_dim_bounds(x_dim[i, j], y_dim[i, j]) + pos = Range(from_bounds=data_bounds).apply(pos) + + ind = self.lasso.in_polygon(pos) + self.lasso.clear() + return spike_ids[ind] + def toggle_automatic_channel_selection(self): """Toggle the automatic selection of channels when the cluster selection changes.""" diff --git a/phy/plot/plot.py b/phy/plot/plot.py index 4588e05a1..68f25cca8 100644 --- a/phy/plot/plot.py +++ b/phy/plot/plot.py @@ -196,6 +196,7 @@ def __init__(self): def add(self, pos): self._points.append(pos) + self.update_visual() @property def polygon(self): @@ -210,6 +211,8 @@ def polygon(self): def clear(self): self._points = [] + self.box = None + self.update_visual() @property def count(self): @@ -258,4 +261,3 @@ def on_mouse_press(self, e): else: self.clear() self.box = None - self.update_visual() From fddfbb028e73650e9f863efeb42eeb44dc8bb68e Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 9 Feb 2016 16:36:20 +0100 Subject: [PATCH 1000/1059] Minor update --- phy/cluster/manual/views.py | 1 + phy/gui/actions.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 8b6c8cae0..58859649c 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -938,6 +938,7 @@ class FeatureView(ManualClusteringView): default_shortcuts = { 'increase': 'ctrl++', 'decrease': 'ctrl+-', + 'toggle_automatic_channel_selection': 'c', } def __init__(self, diff --git a/phy/gui/actions.py b/phy/gui/actions.py index 35a5ec7d3..545edf01b 100644 --- a/phy/gui/actions.py +++ b/phy/gui/actions.py @@ -100,7 +100,7 @@ def _show_shortcuts(shortcuts, name=None): for name in sorted(shortcuts): shortcut = _get_shortcut_string(shortcuts[name]) if not name.startswith('_'): - print('{0:<40}: {1:s}'.format(name, shortcut)) + print('- {0:<40}: {1:s}'.format(name, shortcut)) # ----------------------------------------------------------------------------- From d95d92355dea5303d195a7ed89dce4be4745f4c2 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 9 Feb 2016 17:01:43 +0100 Subject: [PATCH 1001/1059] Fixes for split --- phy/cluster/manual/controller.py | 4 ++-- phy/cluster/manual/views.py | 4 ++-- phy/io/array.py | 6 +++--- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/phy/cluster/manual/controller.py b/phy/cluster/manual/controller.py index 48b94bf9c..530f3a4e7 100644 --- a/phy/cluster/manual/controller.py +++ b/phy/cluster/manual/controller.py @@ -167,10 +167,10 @@ def get_waveforms_amplitude(self, cluster_id): # ------------------------------------------------------------------------- # Is cached in _init_context() - def get_features(self, cluster_id): + def get_features(self, cluster_id, load_all=False): return self._select_data(cluster_id, self.all_features, - 1000, # TODO + 1000 if not load_all else None, # TODO ) def get_background_features(self): diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 58859649c..d592984d6 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -1234,7 +1234,7 @@ def on_request_split(self): x_dim, y_dim = _dimensions_matrix(self.channels, n_cols=self.n_cols, top_left_attribute=tla) - data = self.features(self.cluster_ids) + data = self.features(self.cluster_ids, load_all=True) spike_ids = data.spike_ids f = data.data i, j = self.lasso.box @@ -1242,7 +1242,7 @@ def on_request_split(self): # TODO: refactor and load all features. x = self._get_feature(x_dim[i, j], spike_ids, f) y = self._get_feature(y_dim[i, j], spike_ids, f) - pos = np.c_[x, y] + pos = np.c_[x, y].astype(np.float64) # Retrieve the data bounds. data_bounds = self._get_dim_bounds(x_dim[i, j], y_dim[i, j]) diff --git a/phy/io/array.py b/phy/io/array.py index 7694fa1dc..b6439e6b4 100644 --- a/phy/io/array.py +++ b/phy/io/array.py @@ -216,12 +216,12 @@ def concat_per_cluster(f): """Take a function accepting a single cluster, and return a function accepting multiple clusters.""" @wraps(f) - def wrapped(cluster_ids): + def wrapped(cluster_ids, **kwargs): # Single cluster. if not hasattr(cluster_ids, '__len__'): - return f(cluster_ids) + return f(cluster_ids, **kwargs) # Concatenate the result of multiple clusters. - l = [f(c) for c in cluster_ids] + l = [f(c, **kwargs) for c in cluster_ids] # Handle the case where every function returns a list of Bunch. if l and isinstance(l[0], list): # We assume that all items have the same length. From 61431a7e4f33782c7337ba5402c731056f918694 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 9 Feb 2016 17:05:55 +0100 Subject: [PATCH 1002/1059] Improve selection after split --- phy/cluster/manual/gui_component.py | 8 +++++++- phy/cluster/manual/tests/test_gui_component.py | 2 +- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 92d190f8b..4de209d9a 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -375,7 +375,13 @@ def on_cluster(self, up): self.cluster_view.select(clusters_0) self.similarity_view.select(clusters_1) elif up.added: - self.select(up.added) + if up.description == 'assign': + # NOTE: we reverse the order such that the last selected + # cluster (with a new color) is the split cluster. + added = up.added[::-1] + else: + added = up.added + self.select(added) if similar: self.similarity_view.next() elif up.metadata_changed: diff --git a/phy/cluster/manual/tests/test_gui_component.py b/phy/cluster/manual/tests/test_gui_component.py index bfb202faf..aebf0625d 100644 --- a/phy/cluster/manual/tests/test_gui_component.py +++ b/phy/cluster/manual/tests/test_gui_component.py @@ -146,7 +146,7 @@ def test_manual_clustering_split_2(gui, quality, similarity): mc.set_default_sort('quality', 'desc') mc.split([0]) - assert mc.selected == [2, 3] + assert mc.selected == [3, 2] def test_manual_clustering_split_lasso(tempdir, qtbot): From fe2b51a913180db07dec425fa68bda589546acc3 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 10 Feb 2016 18:18:13 +0100 Subject: [PATCH 1003/1059] Fix subsample issue in trace view --- phy/cluster/manual/views.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index d592984d6..6ce527c37 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -631,7 +631,7 @@ def __init__(self, # Sample rate. assert sample_rate > 0 - self.sample_rate = sample_rate + self.sample_rate = float(sample_rate) self.dt = 1. / self.sample_rate # Traces and spikes. @@ -716,6 +716,10 @@ def _plot_spike(self, waveforms=None, channels=None, masks=None, def _restrict_interval(self, interval): start, end = interval + # Round the times to full samples to avoid subsampling shifts + # in the traces. + start = int(round(start * self.sample_rate)) / self.sample_rate + end = int(round(end * self.sample_rate)) / self.sample_rate # Restrict the interval to the boundaries of the traces. if start < 0: end += (-start) @@ -1302,7 +1306,7 @@ def __init__(self, **kwargs): assert sample_rate > 0 - self.sample_rate = sample_rate + self.sample_rate = float(sample_rate) self.spike_times = np.asarray(spike_times) self.n_spikes, = self.spike_times.shape From fbd5b814c8cc3ba680f1a7452ad3e6a845427f81 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 11 Feb 2016 08:31:41 +0100 Subject: [PATCH 1004/1059] Update dimensions in feature view --- .../manual/tests/test_gui_component.py | 8 +- phy/cluster/manual/views.py | 78 ++++++++----------- 2 files changed, 37 insertions(+), 49 deletions(-) diff --git a/phy/cluster/manual/tests/test_gui_component.py b/phy/cluster/manual/tests/test_gui_component.py index aebf0625d..424eedee5 100644 --- a/phy/cluster/manual/tests/test_gui_component.py +++ b/phy/cluster/manual/tests/test_gui_component.py @@ -165,10 +165,10 @@ def test_manual_clustering_split_lasso(tempdir, qtbot): # Simulate a lasso. ev = view.events - ev.mouse_press(pos=(190, 10), button=1, modifiers=(keys.CONTROL,)) - ev.mouse_press(pos=(200, 10), button=1, modifiers=(keys.CONTROL,)) - ev.mouse_press(pos=(200, 30), button=1, modifiers=(keys.CONTROL,)) - ev.mouse_press(pos=(190, 30), button=1, modifiers=(keys.CONTROL,)) + ev.mouse_press(pos=(210, 10), button=1, modifiers=(keys.CONTROL,)) + ev.mouse_press(pos=(280, 10), button=1, modifiers=(keys.CONTROL,)) + ev.mouse_press(pos=(280, 30), button=1, modifiers=(keys.CONTROL,)) + ev.mouse_press(pos=(210, 30), button=1, modifiers=(keys.CONTROL,)) ups = [] diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 6ce527c37..796077f4e 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -7,7 +7,10 @@ # Imports # ----------------------------------------------------------------------------- +import inspect +from itertools import product import logging +import re import numpy as np from matplotlib.colors import hsv_to_rgb, rgb_to_hsv @@ -883,41 +886,30 @@ def decrease(self): # ----------------------------------------------------------------------------- def _dimensions_matrix(channels, n_cols=None, top_left_attribute=None): - """Dimension matrix.""" - # time, attr time, (x, 0) time, (y, 0) time, (z, 0) - # time, (x, 1) (x, 0), (x, 1) (x, 0), (y, 0) (x, 0), (z, 0) - # time, (y, 1) (x, 1), (y, 1) (y, 0), (y, 1) (y, 0), (z, 0) - # time, (z, 1) (x, 1), (z, 1) (y, 1), (z, 1) (z, 0), (z, 1) - - assert n_cols > 0 - assert len(channels) >= n_cols - 1 - - y_dim = {} - x_dim = {} - x_dim[0, 0] = 'time' - y_dim[0, 0] = top_left_attribute or 'time' - - for i in range(1, n_cols): - # First line. - x_dim[0, i] = 'time' - y_dim[0, i] = (channels[i - 1], 0) - # First column. - x_dim[i, 0] = 'time' - y_dim[i, 0] = (channels[i - 1], 1) - # Diagonal. - x_dim[i, i] = (channels[i - 1], 0) - y_dim[i, i] = (channels[i - 1], 1) - - for i in range(1, n_cols): - for j in range(i + 1, n_cols): - assert j > i - # Above the diagonal. - x_dim[i, j] = (channels[i - 1], 0) - y_dim[i, j] = (channels[j - 1], 0) - # Below the diagonal. - x_dim[j, i] = (channels[i - 1], 1) - y_dim[j, i] = (channels[j - 1], 1) - + """ + time,x0 y0,x0 x1,x0 y1,x0 + x0,y0 time,y0 x1,y0 y1,y0 + x0,x1 y0,x1 time,x1 y1,x1 + x0,y1 y0,y1 x1,y1 time,y1 + """ + # Generate the dimensions matrix from the docstring. + ds = inspect.getdoc(_dimensions_matrix).strip() + x, y = channels[:2] + + def _get_dim(d): + if d == 'time': + return d + assert re.match(r'[xy][01]', d) + c = x if d[0] == 'x' else y + f = int(d[1]) + return c, f + + dims = [[_.split(',') for _ in re.split(r' +', line.strip())] + for line in ds.splitlines()] + x_dim = {(i, j): _get_dim(dims[i][j][0]) + for i, j in product(range(4), range(4))} + y_dim = {(i, j): _get_dim(dims[i][j][1]) + for i, j in product(range(4), range(4))} return x_dim, y_dim @@ -983,7 +975,7 @@ def __init__(self, assert spike_times.shape == (self.n_spikes,) assert self.n_spikes >= 0 - self.n_cols = self.n_features_per_channel + 1 + self.n_cols = 4 self.shape = (self.n_cols, self.n_cols) # Initialize the view. @@ -1085,7 +1077,7 @@ def _plot_features(self, i, j, x_dim, y_dim, x, y, def _get_channel_dims(self, cluster_ids): """Select the channels to show by default.""" - n = self.n_cols - 1 + n = 2 channels = self.best_channels(cluster_ids) channels = (channels if channels is not None else list(range(self.n_channels))) @@ -1214,18 +1206,14 @@ def state(self): def on_channel_click(self, channel_idx=None, key=None, button=None): """Respond to the click on a channel.""" - if key is None or not (1 <= key <= (self.n_cols - 1)): - return - # Get the axis from the pressed button (1, 2, etc.) - # axis = 'x' if button == 1 else 'y' - # Get the existing channels. channels = self.channels if channels is None: return - assert len(channels) == self.n_cols - 1 + assert len(channels) == 2 assert 0 <= channel_idx < self.n_channels - # Update the channel. - channels[key - 1] = channel_idx + # Get the axis from the pressed button (1, 2, etc.) + # axis = 'x' if button == 1 else 'y' + channels[0 if button == 1 else 1] = channel_idx self.fixed_channels = True self.on_select() From 78a2a687077a08b55b556b40bdd41df78a17e9d3 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 11 Feb 2016 09:09:50 +0100 Subject: [PATCH 1005/1059] WIP: persist memcache on disk with pickle --- phy/io/context.py | 41 +++++++++++++++++++++++++++--------- phy/io/tests/test_context.py | 14 +++--------- 2 files changed, 34 insertions(+), 21 deletions(-) diff --git a/phy/io/context.py b/phy/io/context.py index 603dab81e..6a2896761 100644 --- a/phy/io/context.py +++ b/phy/io/context.py @@ -14,7 +14,7 @@ from traitlets.config.configurable import Configurable import numpy as np -from six.moves.cPickle import dump +from six.moves.cPickle import dump, load from six import string_types try: from dask.array import Array @@ -156,6 +156,11 @@ def __init__(self, cache_dir, ipy_view=None, verbose=0): logger.debug("Create cache directory `%s`.", self.cache_dir) os.makedirs(self.cache_dir) + # Ensure the memcache directory exists. + path = op.join(self.cache_dir, 'memcache') + if not op.exists(path): + os.mkdir(path) + self._set_memory(self.cache_dir) self.ipy_view = ipy_view if ipy_view else None self._memcache = {} @@ -201,25 +206,41 @@ def cache(self, f): disk_cached = self._memory.cache(f, ignore=ignore) return disk_cached + def load_memcache(self, name): + # Load the memcache from disk, if it exists. + path = op.join(self.cache_dir, 'memcache', name + '.pkl') + if op.exists(path): + logger.debug("Load memcache for `%s`.", name) + with open(path, 'rb') as fd: + cache = load(fd) + else: + cache = {} + self._memcache[name] = cache + return cache + + def save_memcache(self): + for name, cache in self._memcache.items(): + path = op.join(self.cache_dir, 'memcache', name + '.pkl') + logger.debug("Save memcache for `%s`.", name) + with open(path, 'wb') as fd: + dump(cache, fd) + def memcache(self, f): from joblib import hash name = _fullname(f) - # Create the cache dictionary for the function. - if name not in self._memcache: - self._memcache[name] = {} - c = self._memcache[name] + cache = self.load_memcache(name) @wraps(f) def memcached(*args, **kwargs): """Cache the function in memory.""" h = hash((args, kwargs)) - if h in c: - # logger.debug("Get %s(%s) from memcache.", name, str(args)) - return c[h] + if h in cache: + logger.debug("Get %s(%s) from memcache.", name, str(args)) + return cache[h] else: - # logger.debug("Compute %s(%s).", name, str(args)) + logger.debug("Compute %s(%s).", name, str(args)) out = f(*args, **kwargs) - c[h] = out + cache[h] = out return out return memcached diff --git a/phy/io/tests/test_context.py b/phy/io/tests/test_context.py index c87bcaa43..53259d9a0 100644 --- a/phy/io/tests/test_context.py +++ b/phy/io/tests/test_context.py @@ -136,7 +136,6 @@ def test_context_memcache(tempdir, context): _res = [] @context.memcache - @context.cache def f(x): _res.append(x) return x ** 2 @@ -151,21 +150,14 @@ def f(x): assert len(_res) == 1 # We artificially clear the memory cache. - context._memcache[_fullname(f)].clear() + context.save_memcache() + del context._memcache[_fullname(f)] + context.load_memcache(_fullname(f)) # This time, the result is loaded from disk. ae(f(x), x ** 2) assert len(_res) == 1 - # Remove the cache directory. - assert context.cache_dir.replace('/private', '').startswith(tempdir) - shutil.rmtree(context.cache_dir) - context._memcache[_fullname(f)].clear() - - # Now, the result is re-computed. - ae(f(x), x ** 2) - assert len(_res) == 2 - def test_pickle_cache(tempdir, parallel_context): """Make sure the Context is picklable.""" From 7b072a7e8b82864d4ea4efff36b5168dd13c7c7f Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 11 Feb 2016 09:13:25 +0100 Subject: [PATCH 1006/1059] WIP: update memcached methods --- phy/cluster/manual/controller.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/phy/cluster/manual/controller.py b/phy/cluster/manual/controller.py index 530f3a4e7..23d00f718 100644 --- a/phy/cluster/manual/controller.py +++ b/phy/cluster/manual/controller.py @@ -79,22 +79,19 @@ def _init_context(self): self.get_masks = concat_per_cluster(ctx.cache(self.get_masks)) self.get_features = concat_per_cluster(ctx.cache(self.get_features)) self.get_waveforms = concat_per_cluster(ctx.cache(self.get_waveforms)) - self.get_background_features = ctx.cache(self.get_background_features) self.get_mean_masks = ctx.memcache(self.get_mean_masks) self.get_mean_features = ctx.memcache(self.get_mean_features) self.get_mean_waveforms = ctx.memcache(self.get_mean_waveforms) - self.get_waveform_lims = ctx.cache(self.get_waveform_lims) - self.get_feature_lim = ctx.cache(self.get_feature_lim) + self.get_waveform_lims = ctx.memcache(self.get_waveform_lims) + self.get_feature_lim = ctx.memcache(self.get_feature_lim) - self.get_waveform_amplitude = ctx.memcache(ctx.cache( - self.get_waveforms_amplitude)) - self.get_best_channel_position = ctx.memcache( - self.get_best_channel_position) - self.get_close_clusters = ctx.memcache(ctx.cache( - self.get_close_clusters)) + self.get_close_clusters = ctx.memcache( + self.get_close_clusters) + self.get_probe_depth = ctx.memcache( + self.get_probe_depth) self.spikes_per_cluster = ctx.memcache(self.spikes_per_cluster) From 76a35f22077777394aa4a26fc87c744951bbf415 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 11 Feb 2016 09:14:51 +0100 Subject: [PATCH 1007/1059] Remove debug --- phy/io/context.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/phy/io/context.py b/phy/io/context.py index 6a2896761..949147f02 100644 --- a/phy/io/context.py +++ b/phy/io/context.py @@ -235,10 +235,10 @@ def memcached(*args, **kwargs): """Cache the function in memory.""" h = hash((args, kwargs)) if h in cache: - logger.debug("Get %s(%s) from memcache.", name, str(args)) + # logger.debug("Get %s(%s) from memcache.", name, str(args)) return cache[h] else: - logger.debug("Compute %s(%s).", name, str(args)) + # logger.debug("Compute %s(%s).", name, str(args)) out = f(*args, **kwargs) cache[h] = out return out From 52bb9a7900181482ab6b5a6adbb776821ed944be Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 11 Feb 2016 09:15:11 +0100 Subject: [PATCH 1008/1059] Flakify --- phy/io/tests/test_context.py | 1 - 1 file changed, 1 deletion(-) diff --git a/phy/io/tests/test_context.py b/phy/io/tests/test_context.py index 53259d9a0..0964ee885 100644 --- a/phy/io/tests/test_context.py +++ b/phy/io/tests/test_context.py @@ -8,7 +8,6 @@ import os import os.path as op -import shutil import numpy as np from numpy.testing import assert_array_equal as ae From 6b792cf3c22d9726b6dc3814b871c5acc8e046b9 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 12 Feb 2016 12:24:31 +0100 Subject: [PATCH 1009/1059] WIP: improve table widget sort implementation --- phy/gui/static/table.js | 60 ++++++++++++++++++++++++----------------- phy/gui/widgets.py | 26 +++++++++--------- 2 files changed, 50 insertions(+), 36 deletions(-) diff --git a/phy/gui/static/table.js b/phy/gui/static/table.js index f9818f27c..7ed693e57 100644 --- a/phy/gui/static/table.js +++ b/phy/gui/static/table.js @@ -11,38 +11,35 @@ function isFloat(n) { return n === Number(n) && n % 1 !== 0; } +function clear(e) { + while (e.firstChild) { + e.removeChild(e.firstChild); + } +} -// Table class. +// Table class. var Table = function (el) { this.el = el; this.selected = []; this.headers = {}; // {name: th} mapping this.rows = {}; // {id: tr} mapping - this.tablesort = null; -}; -Table.prototype.setData = function(data) { - /* - data.cols: list of column names - data.items: list of rows (each row is an object {col: value}) - */ - // if (data.items.length == 0) return; + var thead = document.createElement("thead"); + this.el.appendChild(thead); - // Reinitialize the state. - this.selected = []; + var tbody = document.createElement("tbody"); + this.el.appendChild(tbody); +}; + +Table.prototype.setHeaders = function(data) { this.rows = {}; var that = this; var keys = data.cols; - // Clear the table. - while (this.el.firstChild) { - this.el.removeChild(this.el.firstChild); - } - - var thead = document.createElement("thead"); - var tbody = document.createElement("tbody"); + var thead = this.el.getElementsByTagName("thead")[0]; + clear(thead); // Header. var tr = document.createElement("tr"); @@ -54,6 +51,27 @@ Table.prototype.setData = function(data) { this.headers[key] = th; } thead.appendChild(tr); + + // Enable the tablesort plugin. + this.tablesort = new Tablesort(this.el); +} + +Table.prototype.setData = function(data) { + /* + data.cols: list of column names + data.items: list of rows (each row is an object {col: value}) + */ + // if (data.items.length == 0) return; + + // Reinitialize the state. + this.selected = []; + this.rows = {}; + var keys = data.cols; + var that = this; + + // Clear the table body. + var tbody = this.el.getElementsByTagName("tbody")[0]; + clear(tbody); this.nrows = data.items.length; // Data rows. @@ -114,12 +132,6 @@ Table.prototype.setData = function(data) { tbody.appendChild(tr); this.rows[data.items[i].id] = tr; } - - this.el.appendChild(thead); - this.el.appendChild(tbody); - - // Enable the tablesort plugin. - this.tablesort = new Tablesort(this.el); }; Table.prototype.rowId = function(i) { diff --git a/phy/gui/widgets.py b/phy/gui/widgets.py index 51dbd4588..92d149bd5 100644 --- a/phy/gui/widgets.py +++ b/phy/gui/widgets.py @@ -11,6 +11,7 @@ import json import logging import os.path as op +from operator import itemgetter from six import text_type @@ -235,12 +236,17 @@ def add_column(self, func, name=None, show=True): assert func name = name or func.__name__ if name == '': - raise ValueError("Please provide a valid name for " + - name) + raise ValueError("Please provide a valid name for " + name) d = {'func': func, 'show': show, } self._columns[name] = d + + # Update the headers in the widget. + data = _create_json_dict(cols=self.column_names, + ) + self.eval_js('table.setHeaders({});'.format(data)) + return func @property @@ -257,26 +263,22 @@ def set_rows(self, ids): """Set the rows of the table.""" # NOTE: make sure we have integers and not np.generic objects. assert all(isinstance(i, int) for i in ids) - - # Determine the sort column and dir to set after the rows. - sort_col, sort_dir = self.current_sort default_sort_col, default_sort_dir = self.default_sort - sort_col = sort_col or default_sort_col - sort_dir = sort_dir or default_sort_dir or 'desc' - # Set the rows. logger.log(5, "Set %d rows in the table.", len(ids)) items = [self._get_row(id) for id in ids] + + # Optionally sort the rows before passing them to the widget. + if default_sort_col: + items = sorted(items, key=itemgetter(default_sort_col), + reverse=(default_sort_dir == 'desc')) + data = _create_json_dict(items=items, cols=self.column_names, ) self.eval_js('table.setData({});'.format(data)) - # Sort. - if sort_col: - self.sort_by(sort_col, sort_dir) - def sort_by(self, name, sort_dir='asc'): """Sort by a given variable.""" logger.log(5, "Sort by `%s` %s.", name, sort_dir) From a4d5657074c1011f66a424f76ed3d4a864b430e6 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 12 Feb 2016 12:25:56 +0100 Subject: [PATCH 1010/1059] Fixing sort in cluster views --- phy/cluster/manual/controller.py | 2 +- phy/cluster/manual/gui_component.py | 12 +++++++++--- phy/cluster/manual/tests/test_gui_component.py | 1 + 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/phy/cluster/manual/controller.py b/phy/cluster/manual/controller.py index 23d00f718..d19fb1926 100644 --- a/phy/cluster/manual/controller.py +++ b/phy/cluster/manual/controller.py @@ -306,5 +306,5 @@ def set_manual_clustering(self, gui): cluster_groups=self.cluster_groups, ) self.manual_clustering = mc - mc.add_column(self.get_probe_depth) + mc.add_column(self.get_probe_depth, name='probe_depth') mc.attach(gui) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 4de209d9a..d9245cb18 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -417,9 +417,15 @@ def attach(self, gui): # Add the quality column in the cluster view. if self.quality: self.cluster_view.add_column(self.quality, - name=self.quality.__name__) - self.set_default_sort(self.quality.__name__ - if self.quality else 'n_spikes') + name=self.quality.__name__, + ) + + # Update the cluster view and sort by n_spikes at the beginning. + self._update_cluster_view() + # if not self.quality: + # self.cluster_view.sort_by('n_spikes', 'desc') + + # Add the similarity view if there is a similarity function. if self.similarity: gui.add_view(self.similarity_view, name='SimilarityView') diff --git a/phy/cluster/manual/tests/test_gui_component.py b/phy/cluster/manual/tests/test_gui_component.py index 424eedee5..86714ab3b 100644 --- a/phy/cluster/manual/tests/test_gui_component.py +++ b/phy/cluster/manual/tests/test_gui_component.py @@ -50,6 +50,7 @@ def manual_clustering(qtbot, gui, cluster_ids, cluster_groups, similarity=similarity, ) mc.attach(gui) + mc.set_default_sort(quality.__name__) return mc From 1c17f13501d33dfa1fffcd177c5dfad78209f437 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 12 Feb 2016 13:28:11 +0100 Subject: [PATCH 1011/1059] WIP: sort in table widget --- phy/gui/static/table.js | 9 ++++++--- phy/gui/widgets.py | 21 ++++++++++++++------- 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/phy/gui/static/table.js b/phy/gui/static/table.js index 7ed693e57..0a07927ab 100644 --- a/phy/gui/static/table.js +++ b/phy/gui/static/table.js @@ -24,6 +24,7 @@ var Table = function (el) { this.selected = []; this.headers = {}; // {name: th} mapping this.rows = {}; // {id: tr} mapping + this.cols = []; var thead = document.createElement("thead"); this.el.appendChild(thead); @@ -37,6 +38,7 @@ Table.prototype.setHeaders = function(data) { var that = this; var keys = data.cols; + this.cols = data.cols; var thead = this.el.getElementsByTagName("thead")[0]; clear(thead); @@ -148,9 +150,9 @@ Table.prototype.sortBy = function(header, dir) { throw "The column `" + header + "` doesn't exist." // Remove all sort classes. - for (var i = 0; i < this.headers.length; i++) { - this.headers[i].classList.remove('sort-up'); - this.headers[i].classList.remove('sort-down'); + for (var i = 0; i < this.cols.length; i++) { + var name = this.cols[i]; + this.headers[name].classList = ""; } // Add sort. @@ -158,6 +160,7 @@ Table.prototype.sortBy = function(header, dir) { if (dir == 'desc') { this.tablesort.sortTable(this.headers[header]); } + }; Table.prototype.currentSort = function() { diff --git a/phy/gui/widgets.py b/phy/gui/widgets.py index 92d149bd5..4c675c580 100644 --- a/phy/gui/widgets.py +++ b/phy/gui/widgets.py @@ -11,7 +11,6 @@ import json import logging import os.path as op -from operator import itemgetter from six import text_type @@ -263,22 +262,30 @@ def set_rows(self, ids): """Set the rows of the table.""" # NOTE: make sure we have integers and not np.generic objects. assert all(isinstance(i, int) for i in ids) + + # Determine the sort column and dir to set after the rows. + sort_col, sort_dir = self.current_sort default_sort_col, default_sort_dir = self.default_sort + sort_col = sort_col or default_sort_col + sort_dir = sort_dir or default_sort_dir or 'desc' + # Set the rows. logger.log(5, "Set %d rows in the table.", len(ids)) items = [self._get_row(id) for id in ids] - - # Optionally sort the rows before passing them to the widget. - if default_sort_col: - items = sorted(items, key=itemgetter(default_sort_col), - reverse=(default_sort_dir == 'desc')) - + # Sort the rows before passing them to the widget. + # if sort_col: + # items = sorted(items, key=itemgetter(sort_col), + # reverse=(sort_dir == 'desc')) data = _create_json_dict(items=items, cols=self.column_names, ) self.eval_js('table.setData({});'.format(data)) + # Sort. + if sort_col: + self.sort_by(sort_col, sort_dir) + def sort_by(self, name, sort_dir='asc'): """Sort by a given variable.""" logger.log(5, "Sort by `%s` %s.", name, sort_dir) From 012ff85de3985d0ab5d89e001836220ed01a67f7 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 12 Feb 2016 13:42:08 +0100 Subject: [PATCH 1012/1059] Update tablesort to 4.0 --- phy/gui/static/table.js | 10 +++++----- phy/gui/static/tablesort.min.js | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/phy/gui/static/table.js b/phy/gui/static/table.js index 0a07927ab..f2654bdbb 100644 --- a/phy/gui/static/table.js +++ b/phy/gui/static/table.js @@ -152,15 +152,15 @@ Table.prototype.sortBy = function(header, dir) { // Remove all sort classes. for (var i = 0; i < this.cols.length; i++) { var name = this.cols[i]; - this.headers[name].classList = ""; + this.headers[name].classList.remove("sort-up"); + this.headers[name].classList.remove("sort-down"); } + var order = (dir == 'asc') ? "sort-up" : "sort-down"; + this.headers[header].classList.add(order); + // Add sort. this.tablesort.sortTable(this.headers[header]); - if (dir == 'desc') { - this.tablesort.sortTable(this.headers[header]); - } - }; Table.prototype.currentSort = function() { diff --git a/phy/gui/static/tablesort.min.js b/phy/gui/static/tablesort.min.js index 33ded4fbb..0d1fdb157 100644 --- a/phy/gui/static/tablesort.min.js +++ b/phy/gui/static/tablesort.min.js @@ -1,5 +1,5 @@ /*! - * tablesort v3.1.0 (2015-07-03) + * tablesort v4.0.0 (2015-12-17) * http://tristen.ca/tablesort/demo/ * Copyright (c) 2015 ; Licensed MIT -*/!function(){function a(b,c){if(!(this instanceof a))return new a(b,c);if(!b||"TABLE"!==b.tagName)throw new Error("Element must be a table");this.init(b,c||{})}var b=[],c=function(a){var b;return window.CustomEvent&&"function"==typeof window.CustomEvent?b=new CustomEvent(a):(b=document.createEvent("CustomEvent"),b.initCustomEvent(a,!1,!1,void 0)),b},d=function(a){return a.getAttribute("data-sort")||a.textContent||a.innerText||""},e=function(a,b){return a=a.toLowerCase(),b=b.toLowerCase(),a===b?0:b>a?1:-1},f=function(a,b){return function(c,d){var e=a(c.td,d.td);return 0===e?b?d.index-c.index:c.index-d.index:e}};a.extend=function(a,c,d){if("function"!=typeof c||"function"!=typeof d)throw new Error("Pattern and sort must be a function");b.push({name:a,pattern:c,sort:d})},a.prototype={init:function(a,b){var c,d,e,f,g=this;if(g.table=a,g.thead=!1,g.options=b,a.rows&&a.rows.length>0&&(a.tHead&&a.tHead.rows.length>0?(c=a.tHead.rows[a.tHead.rows.length-1],g.thead=!0):c=a.rows[0]),c){var h=function(){g.current&&g.current!==this&&(g.current.classList.remove("sort-up"),g.current.classList.remove("sort-down")),g.current=this,g.sortTable(this)};for(e=0;e0&&m.push(l),n++;if(!m)return}for(n=0;nn;n++)r[n]?(l=r[n],t++):l=q[n-t].tr,i.table.tBodies[0].appendChild(l);i.table.dispatchEvent(c("afterSort"))}},refresh:function(){void 0!==this.current&&this.sortTable(this.current,!0)}},"undefined"!=typeof module&&module.exports?module.exports=a:window.Tablesort=a}(); \ No newline at end of file +*/!function(){function a(b,c){if(!(this instanceof a))return new a(b,c);if(!b||"TABLE"!==b.tagName)throw new Error("Element must be a table");this.init(b,c||{})}var b=[],c=function(a){var b;return window.CustomEvent&&"function"==typeof window.CustomEvent?b=new CustomEvent(a):(b=document.createEvent("CustomEvent"),b.initCustomEvent(a,!1,!1,void 0)),b},d=function(a){return a.getAttribute("data-sort")||a.textContent||a.innerText||""},e=function(a,b){return a=a.toLowerCase(),b=b.toLowerCase(),a===b?0:b>a?1:-1},f=function(a,b){return function(c,d){var e=a(c.td,d.td);return 0===e?b?d.index-c.index:c.index-d.index:e}};a.extend=function(a,c,d){if("function"!=typeof c||"function"!=typeof d)throw new Error("Pattern and sort must be a function");b.push({name:a,pattern:c,sort:d})},a.prototype={init:function(a,b){var c,d,e,f,g=this;if(g.table=a,g.thead=!1,g.options=b,a.rows&&a.rows.length>0&&(a.tHead&&a.tHead.rows.length>0?(c=a.tHead.rows[a.tHead.rows.length-1],g.thead=!0):c=a.rows[0]),c){var h=function(){g.current&&g.current!==this&&(g.current.classList.remove("sort-up"),g.current.classList.remove("sort-down")),g.current=this,g.sortTable(this)};for(e=0;e0&&m.push(l),n++;if(!m)return}for(n=0;nn;n++)s[n]?(l=s[n],u++):l=r[n-u].tr,i.table.tBodies[0].appendChild(l);i.table.dispatchEvent(c("afterSort"))}},refresh:function(){void 0!==this.current&&this.sortTable(this.current,!0)}},"undefined"!=typeof module&&module.exports?module.exports=a:window.Tablesort=a}(); \ No newline at end of file From 43a1c8e76661ca3ed28bb2c7261695406b97751e Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 17 Feb 2016 17:58:01 +0100 Subject: [PATCH 1013/1059] Update constrain bounds in plot --- phy/plot/plot.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/phy/plot/plot.py b/phy/plot/plot.py index 68f25cca8..426afcef1 100644 --- a/phy/plot/plot.py +++ b/phy/plot/plot.py @@ -82,7 +82,8 @@ def __init__(self, layout=None, shape=None, n_plots=None, origin=None, else: self.interact = None - self.panzoom = PanZoom(aspect=None, constrain_bounds=NDC) + self.panzoom = PanZoom(aspect=None, + constrain_bounds=[-2, -2, +2, +2]) self.panzoom.attach(self) if enable_lasso: From 5ac196a20d9504f0903171297027a936a642b964 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 17 Feb 2016 18:03:00 +0100 Subject: [PATCH 1014/1059] Flakify --- phy/plot/plot.py | 1 - 1 file changed, 1 deletion(-) diff --git a/phy/plot/plot.py b/phy/plot/plot.py index 426afcef1..48cbfaf4e 100644 --- a/phy/plot/plot.py +++ b/phy/plot/plot.py @@ -17,7 +17,6 @@ from .base import BaseCanvas from .interact import Grid, Boxed, Stacked from .panzoom import PanZoom -from .transform import NDC from .utils import _get_array from .visuals import (ScatterVisual, PlotVisual, HistogramVisual, LineVisual, TextVisual, PolygonVisual) From 0fddbe5e9faea0cb06a3e3e3c085725fa5f478b8 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 1 Mar 2016 15:52:42 +0100 Subject: [PATCH 1015/1059] Save the new cluster id in the cache --- phy/cluster/manual/clustering.py | 10 +++++++--- phy/cluster/manual/controller.py | 18 +++++++++++++++++- phy/cluster/manual/gui_component.py | 4 +++- phy/io/context.py | 2 +- phy/io/tests/test_context.py | 2 +- 5 files changed, 29 insertions(+), 7 deletions(-) diff --git a/phy/cluster/manual/clustering.py b/phy/cluster/manual/clustering.py index 1948c0a8e..f95a96eea 100644 --- a/phy/cluster/manual/clustering.py +++ b/phy/cluster/manual/clustering.py @@ -146,14 +146,18 @@ class Clustering(EventEmitter): """ - def __init__(self, spike_clusters): + def __init__(self, spike_clusters, new_cluster_id=None): super(Clustering, self).__init__() self._undo_stack = History(base_item=(None, None, None)) # Spike -> cluster mapping. self._spike_clusters = _as_array(spike_clusters) self._n_spikes = len(self._spike_clusters) self._spike_ids = np.arange(self._n_spikes).astype(np.int64) - self._new_cluster_id = self._spike_clusters.max() + 1 + self._new_cluster_id_0 = (new_cluster_id or + self._spike_clusters.max() + 1) + self._new_cluster_id = self._new_cluster_id_0 + assert self._new_cluster_id >= 0 + assert np.all(self._spike_clusters < self._new_cluster_id) # Keep a copy of the original spike clusters assignment. self._spike_clusters_base = self._spike_clusters.copy() @@ -165,7 +169,7 @@ def reset(self): """ self._undo_stack.clear() self._spike_clusters = self._spike_clusters_base - self._new_cluster_id = self._spike_clusters.max() + 1 + self._new_cluster_id = self._new_cluster_id_0 @property def spike_clusters(self): diff --git a/phy/cluster/manual/controller.py b/phy/cluster/manual/controller.py index d19fb1926..fc7658d09 100644 --- a/phy/cluster/manual/controller.py +++ b/phy/cluster/manual/controller.py @@ -299,12 +299,28 @@ def add_correlogram_view(self, gui): # GUI methods # ------------------------------------------------------------------------- + def similarity(self, cluster_id): + return self.get_close_clusters(cluster_id) + def set_manual_clustering(self, gui): + # Load the new cluster id. + new_cluster_id = self.context.load('new_cluster_id'). \ + get('new_cluster_id', None) mc = ManualClustering(self.spike_clusters, self.spikes_per_cluster, - similarity=self.get_close_clusters, + similarity=self.similarity, cluster_groups=self.cluster_groups, + new_cluster_id=new_cluster_id, ) + + # Save the new cluster id on disk. + @mc.clustering.connect + def on_cluster(up): + new_cluster_id = mc.clustering.new_cluster_id() + logger.debug("Save the new cluster id: %d", new_cluster_id) + self.context.save('new_cluster_id', + dict(new_cluster_id=new_cluster_id)) + self.manual_clustering = mc mc.add_column(self.get_probe_depth, name='probe_depth') mc.attach(gui) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index d9245cb18..e93674556 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -128,6 +128,7 @@ def __init__(self, shortcuts=None, quality=None, similarity=None, + new_cluster_id=None, ): self.gui = None @@ -142,7 +143,8 @@ def __init__(self, self.shortcuts.update(shortcuts or {}) # Create Clustering and ClusterMeta. - self.clustering = Clustering(spike_clusters) + self.clustering = Clustering(spike_clusters, + new_cluster_id=new_cluster_id) self.cluster_meta = create_cluster_meta(cluster_groups) self._global_history = GlobalHistory(process_ups=_process_ups) self._register_logging() diff --git a/phy/io/context.py b/phy/io/context.py index 949147f02..fe4befa5f 100644 --- a/phy/io/context.py +++ b/phy/io/context.py @@ -330,7 +330,7 @@ def load(self, name, location='local'): path = self._get_path(name, location) if not op.exists(path): logger.debug("The file `%s` doesn't exist.", path) - return + return {} return _load_json(path) def __getstate__(self): diff --git a/phy/io/tests/test_context.py b/phy/io/tests/test_context.py index 0964ee885..f7236ad3a 100644 --- a/phy/io/tests/test_context.py +++ b/phy/io/tests/test_context.py @@ -96,7 +96,7 @@ def test_read_write(tempdir): def test_context_load_save(tempdir, context, temp_phy_user_dir): - assert context.load('unexisting') is None + assert not context.load('unexisting') context.save('a/hello', {'text': 'world'}) assert context.load('a/hello')['text'] == 'world' From 517e23e32b726dce40f611ecc3d0f6e70079f3f2 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 1 Mar 2016 16:07:12 +0100 Subject: [PATCH 1016/1059] WIP: fixing travis --- environment.yml | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/environment.yml b/environment.yml index a2b8dec3a..be01624ca 100644 --- a/environment.yml +++ b/environment.yml @@ -3,7 +3,8 @@ channels: - kwikteam dependencies: - python - - numpy=1.9 + - pip + - numpy - vispy - matplotlib - scipy @@ -15,6 +16,8 @@ dependencies: - six - ipyparallel - joblib - - dask - cython - click + - cloudpickle + - toolz + - dill From 306fb06f68d89d3c7bd1c14feb3cde2c670401d3 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 1 Mar 2016 16:18:57 +0100 Subject: [PATCH 1017/1059] Add dask dependency --- environment.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/environment.yml b/environment.yml index be01624ca..eb45a3e29 100644 --- a/environment.yml +++ b/environment.yml @@ -18,6 +18,7 @@ dependencies: - joblib - cython - click + - dask - cloudpickle - toolz - dill From 28c78a9b1790ca37c6239f11abff6a96d77f0712 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 1 Mar 2016 16:30:42 +0100 Subject: [PATCH 1018/1059] WIP --- environment.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/environment.yml b/environment.yml index eb45a3e29..9adcdf1a6 100644 --- a/environment.yml +++ b/environment.yml @@ -4,7 +4,7 @@ channels: dependencies: - python - pip - - numpy + - numpy=1.9 - vispy - matplotlib - scipy From 1394a8b7be65854d48c113e30dc181781bff6629 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 2 Mar 2016 15:10:16 +0100 Subject: [PATCH 1019/1059] WIP: refactor plugins --- phy/__init__.py | 2 +- phy/io/context.py | 8 ++------ phy/utils/__init__.py | 4 ++-- phy/utils/_misc.py | 5 +++++ phy/utils/cli.py | 18 +++++++++-------- phy/utils/config.py | 36 +++++++++++++++++++++++++++------- phy/utils/plugin.py | 31 ++++++----------------------- phy/utils/tests/test_cli.py | 2 +- phy/utils/tests/test_plugin.py | 35 ++------------------------------- 9 files changed, 58 insertions(+), 83 deletions(-) diff --git a/phy/__init__.py b/phy/__init__.py index 73127b8aa..5086ab6ab 100644 --- a/phy/__init__.py +++ b/phy/__init__.py @@ -17,7 +17,7 @@ from .io.datasets import download_file, download_sample_data from .utils.config import load_master_config from .utils._misc import _git_version -from .utils.plugin import IPlugin, get_plugin, get_all_plugins +from .utils.plugin import IPlugin, get_plugin, discover_plugins #------------------------------------------------------------------------------ diff --git a/phy/io/context.py b/phy/io/context.py index fe4befa5f..a666cb0e3 100644 --- a/phy/io/context.py +++ b/phy/io/context.py @@ -25,7 +25,8 @@ "Install it with `conda install dask`.") from .array import read_array, write_array -from phy.utils import (Bunch, _save_json, _load_json, _ensure_dir_exists,) +from phy.utils import (Bunch, _save_json, _load_json, + _ensure_dir_exists, _fullname,) from phy.utils.config import phy_user_dir logger = logging.getLogger(__name__) @@ -141,11 +142,6 @@ def _ensure_cache_dirs_exist(cache_dir, name): os.makedirs(dirpath) -def _fullname(o): - """Return the fully-qualified name of a function.""" - return o.__module__ + "." + o.__name__ if o.__module__ else o.__name__ - - class Context(object): """Handle function cacheing and parallel map with ipyparallel.""" def __init__(self, cache_dir, ipy_view=None, verbose=0): diff --git a/phy/utils/__init__.py b/phy/utils/__init__.py index 45ced54a5..c4648a542 100644 --- a/phy/utils/__init__.py +++ b/phy/utils/__init__.py @@ -3,10 +3,10 @@ """Utilities.""" -from ._misc import _load_json, _save_json +from ._misc import _load_json, _save_json, _fullname from ._types import (_is_array_like, _as_array, _as_tuple, _as_list, _as_scalar, _as_scalars, Bunch, _is_list, _bunchify) from .event import EventEmitter, ProgressReporter -from .plugin import IPlugin, get_plugin, get_all_plugins +from .plugin import IPlugin, get_plugin from .config import _ensure_dir_exists, load_master_config, phy_user_dir diff --git a/phy/utils/_misc.py b/phy/utils/_misc.py index cb5c42ab9..1f11a6d43 100644 --- a/phy/utils/_misc.py +++ b/phy/utils/_misc.py @@ -105,6 +105,11 @@ def _save_json(path, data): # Various Python utility functions #------------------------------------------------------------------------------ +def _fullname(o): + """Return the fully-qualified name of a function.""" + return o.__module__ + "." + o.__name__ if o.__module__ else o.__name__ + + def _read_python(path): path = op.realpath(op.expanduser(path)) assert op.exists(path) diff --git a/phy/utils/cli.py b/phy/utils/cli.py index 7891df690..0033d53e0 100644 --- a/phy/utils/cli.py +++ b/phy/utils/cli.py @@ -17,7 +17,8 @@ import click from phy import (add_default_handler, DEBUG, _Formatter, _logger_fmt, - __version_git__) + __version_git__, discover_plugins) +from phy.utils import _fullname logger = logging.getLogger(__name__) @@ -73,21 +74,22 @@ def phy(ctx): # CLI plugins #------------------------------------------------------------------------------ -def load_cli_plugins(cli): +def load_cli_plugins(cli, user_dir=None): """Load all plugins and attach them to a CLI object.""" from .config import load_master_config - from .plugin import get_all_plugins - config = load_master_config() - plugins = get_all_plugins(config) + config = load_master_config(user_dir=user_dir) + plugins = discover_plugins(config.Plugins.dirs) - # TODO: try/except to avoid crashing if a plugin is broken. for plugin in plugins: if not hasattr(plugin, 'attach_to_cli'): # pragma: no cover continue - logger.debug("Attach plugin `%s` to CLI.", plugin.__name__) + logger.debug("Attach plugin `%s` to CLI.", _fullname(plugin)) # NOTE: plugin is a class, so we need to instantiate it. - plugin().attach_to_cli(cli) + try: + plugin().attach_to_cli(cli) + except Exception as e: + logger.error("Error when loading plugin `%s`: %s", plugin, e) # Load all plugins when importing this module. diff --git a/phy/utils/config.py b/phy/utils/config.py index 0e312cf7f..81b4f5d79 100644 --- a/phy/utils/config.py +++ b/phy/utils/config.py @@ -9,6 +9,7 @@ import logging import os import os.path as op +from textwrap import dedent from traitlets.config import (Config, PyFileConfigLoader, @@ -41,6 +42,7 @@ def _load_config(path): path = op.realpath(path) dirpath, filename = op.split(path) file_ext = op.splitext(path)[1] + logger.debug("Load config file `%s`.", path) if file_ext == '.py': config = PyFileConfigLoader(filename, dirpath, log=logger).load_config() @@ -50,15 +52,35 @@ def _load_config(path): return config +def _default_config(user_dir=None): + path = op.join(user_dir or '~/.phy/', 'plugins/') + return dedent(""" + # You can also put your plugins in ~/.phy/plugins/. + + from phy import IPlugin + + class MyPlugin(IPlugin): + def attach_to_cli(self, cli): + pass + + + c = get_config() + c.Plugins.dirs = ['{}'] + """.format(path)) + + def load_master_config(user_dir=None): - """Load a master Config file from `~/.phy/phy_config.py|json`.""" + """Load a master Config file from `~/.phy/phy_config.py`.""" user_dir = user_dir or phy_user_dir() - c = Config() - paths = [op.join(user_dir, 'phy_config.json'), - op.join(user_dir, 'phy_config.py')] - for path in paths: - c.update(_load_config(path)) - return c + path = op.join(user_dir, 'phy_config.py') + # Create a default config file if necessary. + if not op.exists(path): + _ensure_dir_exists(op.dirname(path)) + logger.debug("Creating default phy config file at `%s`.", path) + with open(path, 'w') as f: + f.write(_default_config(user_dir=user_dir)) + assert op.exists(path) + return _load_config(path) def save_config(path, config): diff --git a/phy/utils/plugin.py b/phy/utils/plugin.py index 0da72720a..b38214d7d 100644 --- a/phy/utils/plugin.py +++ b/phy/utils/plugin.py @@ -18,7 +18,7 @@ from six import with_metaclass -from . import config +from ._misc import _fullname logger = logging.getLogger(__name__) @@ -32,10 +32,10 @@ class IPluginRegistry(type): def __init__(cls, name, bases, attrs): if name != 'IPlugin': - logger.debug("Register plugin `%s`.", name) - plugin_tuple = (cls,) - if plugin_tuple not in IPluginRegistry.plugins: - IPluginRegistry.plugins.append(plugin_tuple) + logger.debug("Register plugin `%s`.", _fullname(cls)) + if _fullname(cls) not in (_fullname(_) + for _ in IPluginRegistry.plugins): + IPluginRegistry.plugins.append(cls) class IPlugin(with_metaclass(IPluginRegistry)): @@ -49,7 +49,7 @@ class IPlugin(with_metaclass(IPluginRegistry)): def get_plugin(name): """Get a plugin class from its name.""" - for (plugin,) in IPluginRegistry.plugins: + for plugin in IPluginRegistry.plugins: if name in plugin.__name__: return plugin raise ValueError("The plugin %s cannot be found." % name) @@ -110,22 +110,3 @@ def discover_plugins(dirs): finally: file.close() return IPluginRegistry.plugins - - -def _builtin_plugins_dir(): - return op.realpath(op.join(op.dirname(__file__), '../plugins/')) - - -def _user_plugins_dir(): - return op.expanduser(op.join(config.phy_user_dir(), 'plugins/')) - - -def get_all_plugins(config=None): - """Load all builtin and user plugins.""" - # By default, builtin and default user plugin. - dirs = [_builtin_plugins_dir(), _user_plugins_dir()] - # Add Plugins.dirs from the optionally-passed config object. - if config and isinstance(config.Plugins.dirs, list): - dirs += config.Plugins.dirs - logger.debug("Discovering plugins in: %s.", ', '.join(dirs)) - return [plugin for (plugin,) in discover_plugins(dirs)] diff --git a/phy/utils/tests/test_cli.py b/phy/utils/tests/test_cli.py index a648f286a..99aed19e9 100644 --- a/phy/utils/tests/test_cli.py +++ b/phy/utils/tests/test_cli.py @@ -63,7 +63,7 @@ def hello(): # NOTE: make the import after the temp_user_dir fixture, to avoid # loading any user plugin affecting the CLI. from ..cli import phy, load_cli_plugins - load_cli_plugins(phy) + load_cli_plugins(phy, user_dir=temp_user_dir) # The plugin should have added a new command. result = runner.invoke(phy, ['--help']) diff --git a/phy/utils/tests/test_plugin.py b/phy/utils/tests/test_plugin.py index bd53fc8d5..a8119f84a 100644 --- a/phy/utils/tests/test_plugin.py +++ b/phy/utils/tests/test_plugin.py @@ -15,10 +15,8 @@ IPlugin, get_plugin, discover_plugins, - get_all_plugins, ) from .._misc import _write_text -from ..config import load_master_config #------------------------------------------------------------------------------ @@ -66,7 +64,7 @@ def test_plugin_1(no_native_plugins): class MyPlugin(IPlugin): pass - assert IPluginRegistry.plugins == [(MyPlugin,)] + assert IPluginRegistry.plugins == [MyPlugin] assert get_plugin('MyPlugin').__name__ == 'MyPlugin' with raises(ValueError): @@ -80,33 +78,4 @@ def test_discover_plugins(tempdir, no_native_plugins): plugins = discover_plugins([tempdir]) assert plugins - assert plugins[0][0].__name__ == 'MyPlugin' - - -def test_get_all_plugins(plugin): - temp_user_dir, in_default_dir, path = plugin - n_builtin_plugins = 0 - - plugins = get_all_plugins() - - def _assert_loaded(): - assert len(plugins) == n_builtin_plugins + 1 - p = plugins[-1] - assert p.__name__ == 'MyPlugin' - - if in_default_dir: - # Create a plugin in the default plugins directory: it will be - # discovered and automatically loaded by get_all_plugins(). - _assert_loaded() - else: - assert len(plugins) == n_builtin_plugins - - # This time, we write the custom plugins path in the config file. - _write_my_plugins_dir_in_config(temp_user_dir) - - # We reload all plugins with the master config object. - config = load_master_config() - plugins = get_all_plugins(config) - - # This time, the plugin should be found. - _assert_loaded() + assert plugins[0].__name__ == 'MyPlugin' From a0ea3fe4b8aa1104693ed899d034baa5e4cfba1c Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 2 Mar 2016 15:30:30 +0100 Subject: [PATCH 1020/1059] Fix bug with conflicting plugin names between tests --- phy/gui/tests/test_gui.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/phy/gui/tests/test_gui.py b/phy/gui/tests/test_gui.py index 071154721..e5c57965b 100644 --- a/phy/gui/tests/test_gui.py +++ b/phy/gui/tests/test_gui.py @@ -185,11 +185,11 @@ def test_create_gui_1(qapp, tempdir): _tmp = [] - class MyPlugin(IPlugin): + class MyGUIPlugin(IPlugin): def attach_to_gui(self, gui): _tmp.append(gui.state.hello) - gui = create_gui(plugins=['MyPlugin'], config_dir=tempdir) + gui = create_gui(plugins=['MyGUIPlugin'], config_dir=tempdir) assert gui assert _tmp == ['world'] gui.state.hello = 'dolly' From a04154a47f1c396305373d89c60ba7ae22166176 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 2 Mar 2016 16:19:54 +0100 Subject: [PATCH 1021/1059] WIP: remove create_gui() --- phy/gui/__init__.py | 2 +- phy/gui/gui.py | 49 +++++++++----------------------------- phy/gui/tests/conftest.py | 4 ++-- phy/gui/tests/test_gui.py | 50 ++++++++++----------------------------- 4 files changed, 26 insertions(+), 79 deletions(-) diff --git a/phy/gui/__init__.py b/phy/gui/__init__.py index afbbbc979..bb3202e71 100644 --- a/phy/gui/__init__.py +++ b/phy/gui/__init__.py @@ -4,6 +4,6 @@ """GUI routines.""" from .qt import require_qt, create_app, run_app -from .gui import GUI, GUIState, create_gui +from .gui import GUI, GUIState from .actions import Actions from .widgets import HTMLWidget, Table diff --git a/phy/gui/gui.py b/phy/gui/gui.py index 9d8639758..1a845f445 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -15,10 +15,10 @@ Qt, QSize, QMetaObject) from .actions import Actions, Snippets from phy.utils.event import EventEmitter -from phy.utils import (load_master_config, Bunch, _bunchify, +from phy.utils import (Bunch, _bunchify, _load_json, _save_json, _ensure_dir_exists, phy_user_dir,) -from phy.utils.plugin import get_plugin, IPlugin +from phy.utils.plugin import IPlugin logger = logging.getLogger(__name__) @@ -124,6 +124,7 @@ def __init__(self, size=None, name=None, subtitle=None, + **kwargs ): # HACK to ensure that closeEvent is called only twice (seems like a # Qt bug). @@ -163,6 +164,14 @@ def __init__(self, # Create and attach snippets. self.snippets = Snippets(self) + # Create the state. + self.state = GUIState(self.name, **kwargs) + + # Save the state to disk when closing the GUI. + @self.connect_ + def on_close(): + self.state.save() + def _set_name(self, name, subtitle): if name is None: name = self.__class__.__name__ @@ -408,39 +417,3 @@ def on_close(): def on_show(): gs = state.get('geometry_state', None) gui.restore_geometry_state(gs) - - -def create_gui(name=None, subtitle=None, plugins=None, **state_kwargs): - """Create a GUI with a list of plugins. - - By default, the list of plugins is taken from the `c.TheGUI.plugins` - parameter, where `TheGUI` is the name of the GUI class. - - """ - gui = GUI(name=name, subtitle=subtitle) - name = gui.name - plugins = plugins or [] - - # Create the state. - state = GUIState(gui.name, **state_kwargs) - - # Make the state. - gui.state = state - - # If no plugins are specified, load the master config and - # get the list of user plugins to attach to the GUI. - plugins_conf = load_master_config()[name].plugins - plugins_conf = plugins_conf if isinstance(plugins_conf, list) else [] - plugins.extend(plugins_conf) - - # Attach the plugins to the GUI. - for plugin in plugins: - logger.debug("Attach plugin `%s` to %s.", plugin, name) - get_plugin(plugin)().attach_to_gui(gui) - - # Save the state to disk. - @gui.connect_ - def on_close(): - state.save() - - return gui diff --git a/phy/gui/tests/conftest.py b/phy/gui/tests/conftest.py index e2a2e6bdf..0ad5a9d9f 100644 --- a/phy/gui/tests/conftest.py +++ b/phy/gui/tests/conftest.py @@ -17,8 +17,8 @@ #------------------------------------------------------------------------------ @yield_fixture -def gui(qapp): - gui = GUI(position=(200, 100), size=(100, 100)) +def gui(tempdir, qapp): + gui = GUI(position=(200, 100), size=(100, 100), config_dir=tempdir) yield gui gui.close() diff --git a/phy/gui/tests/test_gui.py b/phy/gui/tests/test_gui.py index e5c57965b..3967e4801 100644 --- a/phy/gui/tests/test_gui.py +++ b/phy/gui/tests/test_gui.py @@ -6,18 +6,15 @@ # Imports #------------------------------------------------------------------------------ -import os.path as op - from pytest import raises from ..qt import Qt, QApplication, QWidget from ..gui import (GUI, GUIState, - create_gui, _try_get_matplotlib_canvas, _try_get_vispy_canvas, SaveGeometryStatePlugin, ) -from phy.utils import IPlugin, Bunch, _save_json, _ensure_dir_exists +from phy.utils import Bunch from phy.utils._color import _random_color @@ -56,15 +53,15 @@ def test_matplotlib_view(): # Test GUI #------------------------------------------------------------------------------ -def test_gui_noapp(): +def test_gui_noapp(tempdir): if not QApplication.instance(): with raises(RuntimeError): # pragma: no cover - GUI() + GUI(config_dir=tempdir) -def test_gui_1(qtbot): +def test_gui_1(tempdir, qtbot): - gui = GUI(position=(200, 100), size=(100, 100)) + gui = GUI(position=(200, 100), size=(100, 100), config_dir=tempdir) qtbot.addWidget(gui) assert gui.name == 'GUI' @@ -115,9 +112,9 @@ def test_gui_status_message(gui): assert gui.status_message == '' -def test_gui_geometry_state(qtbot): +def test_gui_geometry_state(tempdir, qtbot): _gs = [] - gui = GUI(size=(100, 100)) + gui = GUI(size=(100, 100), config_dir=tempdir) qtbot.addWidget(gui) gui.add_view(_create_canvas(), 'view1') @@ -140,7 +137,7 @@ def on_close(): gui.close() # Recreate the GUI with the saved state. - gui = GUI() + gui = GUI(config_dir=tempdir) gui.add_view(_create_canvas(), 'view1') gui.add_view(_create_canvas(), 'view2') @@ -168,40 +165,17 @@ def on_show(): # Test GUI plugin #------------------------------------------------------------------------------ -def test_gui_state_view(): +def test_gui_state_view(tempdir): view = Bunch(name='MyView0') - state = GUIState() + state = GUIState(config_dir=tempdir) state.update_view_state(view, dict(hello='world')) assert not state.get_view_state(Bunch(name='MyView')) assert not state.get_view_state(Bunch(name='MyView1')) assert state.get_view_state(view) == Bunch(hello='world') -def test_create_gui_1(qapp, tempdir): - - _ensure_dir_exists(op.join(tempdir, 'GUI/')) - path = op.join(tempdir, 'GUI/state.json') - _save_json(path, {'hello': 'world'}) - - _tmp = [] - - class MyGUIPlugin(IPlugin): - def attach_to_gui(self, gui): - _tmp.append(gui.state.hello) - - gui = create_gui(plugins=['MyGUIPlugin'], config_dir=tempdir) - assert gui - assert _tmp == ['world'] - gui.state.hello = 'dolly' - gui.state.save() - - assert GUIState(config_dir=tempdir).hello == 'dolly' - - gui.close() - - -def test_save_geometry_state(gui): - gui.state = Bunch() +def test_save_geometry_state(tempdir, gui): + gui.state = GUIState(config_dir=tempdir) SaveGeometryStatePlugin().attach_to_gui(gui) gui.close() From 71cd6d46cde2c5984b5581ed24e5aa584ef4c693 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 2 Mar 2016 16:22:33 +0100 Subject: [PATCH 1022/1059] WIP: remove create_gui() --- phy/cluster/manual/tests/test_gui_component.py | 10 ++++------ phy/cluster/manual/tests/test_views.py | 4 ++-- phy/gui/tests/test_gui.py | 3 +-- 3 files changed, 7 insertions(+), 10 deletions(-) diff --git a/phy/cluster/manual/tests/test_gui_component.py b/phy/cluster/manual/tests/test_gui_component.py index 86714ab3b..fcac0648c 100644 --- a/phy/cluster/manual/tests/test_gui_component.py +++ b/phy/cluster/manual/tests/test_gui_component.py @@ -14,8 +14,7 @@ from ..gui_component import (ManualClustering, ) from phy.io.array import _spikes_in_clusters -from phy.gui import GUI, create_gui -from phy.utils import Bunch +from phy.gui import GUI from .conftest import MockController @@ -24,9 +23,8 @@ #------------------------------------------------------------------------------ @yield_fixture -def gui(qtbot): - gui = GUI(position=(200, 100), size=(500, 500)) - gui.state = Bunch() +def gui(tempdir, qtbot): + gui = GUI(position=(200, 100), size=(500, 500), config_dir=tempdir) gui.show() qtbot.waitForWindowShown(gui) yield gui @@ -151,7 +149,7 @@ def test_manual_clustering_split_2(gui, quality, similarity): def test_manual_clustering_split_lasso(tempdir, qtbot): - gui = create_gui(config_dir=tempdir) + gui = GUI(config_dir=tempdir) gui.controller = MockController(tempdir) gui.controller.set_manual_clustering(gui) mc = gui.controller.manual_clustering diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index eddea5c3e..f591094be 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -12,7 +12,7 @@ from vispy.util import keys from pytest import fixture -from phy.gui import create_gui +from phy.gui import GUI from phy.utils import Bunch from .conftest import MockController from ..views import (ScatterView, @@ -39,7 +39,7 @@ def state(tempdir): @fixture def gui(tempdir, state): - gui = create_gui(config_dir=tempdir, **state) + gui = GUI(config_dir=tempdir, **state) gui.controller = MockController(tempdir) gui.controller.set_manual_clustering(gui) return gui diff --git a/phy/gui/tests/test_gui.py b/phy/gui/tests/test_gui.py index 3967e4801..e7bf36734 100644 --- a/phy/gui/tests/test_gui.py +++ b/phy/gui/tests/test_gui.py @@ -162,7 +162,7 @@ def on_show(): #------------------------------------------------------------------------------ -# Test GUI plugin +# Test GUI state #------------------------------------------------------------------------------ def test_gui_state_view(tempdir): @@ -175,7 +175,6 @@ def test_gui_state_view(tempdir): def test_save_geometry_state(tempdir, gui): - gui.state = GUIState(config_dir=tempdir) SaveGeometryStatePlugin().attach_to_gui(gui) gui.close() From e88444003e55dc8d362f0ab40350b349eb5e87b5 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 2 Mar 2016 16:44:17 +0100 Subject: [PATCH 1023/1059] Add controller.create_gui() --- phy/cluster/manual/controller.py | 21 +++++++++++++++++++++ phy/cluster/manual/tests/test_controller.py | 15 +++++++++++---- phy/gui/gui.py | 15 ++++++++++++++- 3 files changed, 46 insertions(+), 5 deletions(-) diff --git a/phy/cluster/manual/controller.py b/phy/cluster/manual/controller.py index fc7658d09..62584a88a 100644 --- a/phy/cluster/manual/controller.py +++ b/phy/cluster/manual/controller.py @@ -19,6 +19,7 @@ select_traces, extract_spikes, ) +from phy.gui import GUI from phy.io.array import _get_data_lim, concat_per_cluster from phy.io import Context, Selector from phy.stats.clusters import (mean, @@ -324,3 +325,23 @@ def on_cluster(up): self.manual_clustering = mc mc.add_column(self.get_probe_depth, name='probe_depth') mc.attach(gui) + + def create_gui(self, name=None, subtitle=None, + plugins=None, config_dir=None): + """Create a manual clustering GUI.""" + gui = GUI(name=name, subtitle=subtitle, config_dir=config_dir) + self.set_manual_clustering(gui) + + # Add views. + self.add_correlogram_view(gui) + if self.all_features is not None: + self.add_feature_view(gui) + if self.all_waveforms is not None: + self.add_waveform_view(gui) + if self.all_traces is not None: + self.add_trace_view(gui) + + # Attach the specified plugins. + gui.attach_plugins(plugins) + + return gui diff --git a/phy/cluster/manual/tests/test_controller.py b/phy/cluster/manual/tests/test_controller.py index 2ecc0e430..c3a4d3894 100644 --- a/phy/cluster/manual/tests/test_controller.py +++ b/phy/cluster/manual/tests/test_controller.py @@ -6,12 +6,19 @@ # Imports #------------------------------------------------------------------------------ +from .conftest import MockController + #------------------------------------------------------------------------------ -# Fixtures +# Test controller #------------------------------------------------------------------------------ +def test_controller_1(qtbot, tempdir): + controller = MockController(tempdir) + gui = controller.create_gui('MyGUI', config_dir=tempdir) + gui.show() -#------------------------------------------------------------------------------ -# Utils -#------------------------------------------------------------------------------ + controller.manual_clustering.select([2, 3]) + + # qtbot.stop() + gui.close() diff --git a/phy/gui/gui.py b/phy/gui/gui.py index 1a845f445..ea57bb791 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -18,7 +18,7 @@ from phy.utils import (Bunch, _bunchify, _load_json, _save_json, _ensure_dir_exists, phy_user_dir,) -from phy.utils.plugin import IPlugin +from phy.utils.plugin import IPlugin, get_plugin logger = logging.getLogger(__name__) @@ -100,6 +100,9 @@ def _get_dock_position(position): }[position or 'right'] +_DEFAULT_PLUGINS = ('SaveGeometryStatePlugin',) + + class GUI(QMainWindow): """A Qt main window holding docking Qt or VisPy widgets. @@ -354,6 +357,16 @@ def restore_geometry_state(self, gs): if gs.get('state', None): self.restoreState((gs['state'])) + # Plugins + # ------------------------------------------------------------------------- + + def attach_plugins(self, plugins=None): + """Attach specified plugins.""" + plugins = list(_DEFAULT_PLUGINS) + (plugins or []) + for plugin in plugins: + get_plugin(plugin)().attach_to_gui(self) + return self + # ----------------------------------------------------------------------------- # GUI state, creator, plugins From 95c9ff5b433aed782f6cf789bbc4df2d5d4c2967 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 2 Mar 2016 17:04:26 +0100 Subject: [PATCH 1024/1059] Add debug --- phy/gui/gui.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/phy/gui/gui.py b/phy/gui/gui.py index ea57bb791..5309bf3e1 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -423,10 +423,12 @@ def attach_to_gui(self, gui): @gui.connect_ def on_close(): + logger.debug("Save geometry state.") gs = gui.save_geometry_state() state['geometry_state'] = gs @gui.connect_ def on_show(): + logger.debug("Load geometry state.") gs = state.get('geometry_state', None) gui.restore_geometry_state(gs) From 2eff6a52c5125e0a5bdaa2db118e9e7e43fd47a2 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 3 Mar 2016 16:36:22 +0100 Subject: [PATCH 1025/1059] Rename _load_config() to load_config() --- phy/utils/__init__.py | 6 +++++- phy/utils/config.py | 4 ++-- phy/utils/tests/test_config.py | 6 +++--- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/phy/utils/__init__.py b/phy/utils/__init__.py index c4648a542..12c8eab5f 100644 --- a/phy/utils/__init__.py +++ b/phy/utils/__init__.py @@ -9,4 +9,8 @@ Bunch, _is_list, _bunchify) from .event import EventEmitter, ProgressReporter from .plugin import IPlugin, get_plugin -from .config import _ensure_dir_exists, load_master_config, phy_user_dir +from .config import( _ensure_dir_exists, + load_master_config, + phy_user_dir, + load_config, + ) diff --git a/phy/utils/config.py b/phy/utils/config.py index 81b4f5d79..ce08dcab2 100644 --- a/phy/utils/config.py +++ b/phy/utils/config.py @@ -35,7 +35,7 @@ def _ensure_dir_exists(path): assert op.exists(path) and op.isdir(path) -def _load_config(path): +def load_config(path): """Load a Python or JSON config file.""" if not op.exists(path): return Config() @@ -80,7 +80,7 @@ def load_master_config(user_dir=None): with open(path, 'w') as f: f.write(_default_config(user_dir=user_dir)) assert op.exists(path) - return _load_config(path) + return load_config(path) def save_config(path, config): diff --git a/phy/utils/tests/test_config.py b/phy/utils/tests/test_config.py index 6d8626938..e5c171f2a 100644 --- a/phy/utils/tests/test_config.py +++ b/phy/utils/tests/test_config.py @@ -16,7 +16,7 @@ from .. import config as _config from .._misc import _write_text from ..config import (_ensure_dir_exists, - _load_config, + load_config, load_master_config, save_config, ) @@ -86,7 +86,7 @@ class MyConfigurable(Configurable): assert MyConfigurable().my_var == 0.0 - c = _load_config(config) + c = load_config(config) assert c.MyConfigurable.my_var == 1.0 # Create a new MyConfigurable instance. @@ -117,5 +117,5 @@ def test_save_config(tempdir): path = op.join(tempdir, 'config.json') save_config(path, c) - c1 = _load_config(path) + c1 = load_config(path) assert c1.A.b == 3. From 1b174dcc6d4558d6afef634c571f74bc8d7ca936 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 3 Mar 2016 17:19:23 +0100 Subject: [PATCH 1026/1059] Remove SaveGeometryStatePlugin --- phy/gui/gui.py | 41 ++++++++------------------------------- phy/gui/tests/test_gui.py | 16 +++++---------- 2 files changed, 13 insertions(+), 44 deletions(-) diff --git a/phy/gui/gui.py b/phy/gui/gui.py index 5309bf3e1..54d1f7423 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -18,7 +18,6 @@ from phy.utils import (Bunch, _bunchify, _load_json, _save_json, _ensure_dir_exists, phy_user_dir,) -from phy.utils.plugin import IPlugin, get_plugin logger = logging.getLogger(__name__) @@ -100,9 +99,6 @@ def _get_dock_position(position): }[position or 'right'] -_DEFAULT_PLUGINS = ('SaveGeometryStatePlugin',) - - class GUI(QMainWindow): """A Qt main window holding docking Qt or VisPy widgets. @@ -169,10 +165,15 @@ def __init__(self, # Create the state. self.state = GUIState(self.name, **kwargs) + gs = self.state.get('geometry_state', None) + self.restore_geometry_state(gs) - # Save the state to disk when closing the GUI. @self.connect_ def on_close(): + logger.debug("Save geometry state.") + gs = self.save_geometry_state() + self.state['geometry_state'] = gs + # Save the state to disk when closing the GUI. self.state.save() def _set_name(self, name, subtitle): @@ -352,24 +353,15 @@ def restore_geometry_state(self, gs): """ if not gs: return + logger.debug("Load geometry state.") if gs.get('geometry', None): self.restoreGeometry((gs['geometry'])) if gs.get('state', None): self.restoreState((gs['state'])) - # Plugins - # ------------------------------------------------------------------------- - - def attach_plugins(self, plugins=None): - """Attach specified plugins.""" - plugins = list(_DEFAULT_PLUGINS) + (plugins or []) - for plugin in plugins: - get_plugin(plugin)().attach_to_gui(self) - return self - # ----------------------------------------------------------------------------- -# GUI state, creator, plugins +# GUI state, creator # ----------------------------------------------------------------------------- class GUIState(Bunch): @@ -415,20 +407,3 @@ def save(self): logger.debug("Save the GUI state to `%s`.", self.path) _save_json(self.path, {k: v for k, v in self.items() if k not in ('config_dir', 'name')}) - - -class SaveGeometryStatePlugin(IPlugin): - def attach_to_gui(self, gui): - state = gui.state - - @gui.connect_ - def on_close(): - logger.debug("Save geometry state.") - gs = gui.save_geometry_state() - state['geometry_state'] = gs - - @gui.connect_ - def on_show(): - logger.debug("Load geometry state.") - gs = state.get('geometry_state', None) - gui.restore_geometry_state(gs) diff --git a/phy/gui/tests/test_gui.py b/phy/gui/tests/test_gui.py index e7bf36734..f41b8773f 100644 --- a/phy/gui/tests/test_gui.py +++ b/phy/gui/tests/test_gui.py @@ -12,7 +12,6 @@ from ..gui import (GUI, GUIState, _try_get_matplotlib_canvas, _try_get_vispy_canvas, - SaveGeometryStatePlugin, ) from phy.utils import Bunch from phy.utils._color import _random_color @@ -96,6 +95,11 @@ def on_close_view(view): view.close() assert _close == [1, 0] + gui.close() + + assert gui.state.geometry_state['geometry'] + assert gui.state.geometry_state['state'] + gui.default_actions.exit() @@ -172,13 +176,3 @@ def test_gui_state_view(tempdir): assert not state.get_view_state(Bunch(name='MyView')) assert not state.get_view_state(Bunch(name='MyView1')) assert state.get_view_state(view) == Bunch(hello='world') - - -def test_save_geometry_state(tempdir, gui): - SaveGeometryStatePlugin().attach_to_gui(gui) - gui.close() - - assert gui.state.geometry_state['geometry'] - assert gui.state.geometry_state['state'] - - gui.show() From 8f65479f00f65d73c2c9ee6474116fc60002c271 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 3 Mar 2016 17:40:17 +0100 Subject: [PATCH 1027/1059] Rename user_dir to config_dir --- phy/cluster/manual/controller.py | 40 +++++++++++++++---- phy/cluster/manual/tests/conftest.py | 6 +-- phy/cluster/manual/tests/test_controller.py | 25 +++++++++++- .../manual/tests/test_gui_component.py | 4 +- phy/cluster/manual/tests/test_views.py | 7 +--- phy/gui/gui.py | 4 +- phy/io/context.py | 4 +- phy/io/datasets.py | 8 ++-- phy/io/tests/test_context.py | 10 ++--- phy/utils/__init__.py | 2 +- phy/utils/cli.py | 6 +-- phy/utils/config.py | 14 +++---- phy/utils/tests/conftest.py | 12 +++--- phy/utils/tests/test_cli.py | 12 +++--- phy/utils/tests/test_config.py | 12 +++--- phy/utils/tests/test_plugin.py | 24 ----------- 16 files changed, 101 insertions(+), 89 deletions(-) diff --git a/phy/cluster/manual/controller.py b/phy/cluster/manual/controller.py index 62584a88a..04876ff1b 100644 --- a/phy/cluster/manual/controller.py +++ b/phy/cluster/manual/controller.py @@ -25,7 +25,7 @@ from phy.stats.clusters import (mean, get_waveform_amplitude, ) -from phy.utils import Bunch +from phy.utils import Bunch, load_master_config, get_plugin, EventEmitter logger = logging.getLogger(__name__) @@ -34,16 +34,38 @@ # Kwik GUI #------------------------------------------------------------------------------ -class Controller(object): - """Take data out of the model and feeds it to views.""" +class Controller(EventEmitter): + """Take data out of the model and feeds it to views. + + Events + ------ + + init() + create_gui(gui) + + """ # responsible for the cache - def __init__(self): + def __init__(self, plugins=None, config_dir=None): + super(Controller, self).__init__() + self.config_dir = config_dir self._init_data() self._init_selector() self._init_context() self.n_spikes = len(self.spike_times) + # Attach the plugins. + plugins = plugins or [] + config = load_master_config(config_dir=config_dir) + c = config.get(self.__class__.__name__) + default_plugins = c.plugins if c else [] + if len(default_plugins): + plugins = default_plugins + plugins + for plugin in plugins: + get_plugin(plugin)().attach_to_controller(self) + + self.emit('init') + # Internal methods # ------------------------------------------------------------------------- @@ -327,10 +349,13 @@ def on_cluster(up): mc.attach(gui) def create_gui(self, name=None, subtitle=None, - plugins=None, config_dir=None): + plugins=None, config_dir=None, **kwargs): """Create a manual clustering GUI.""" - gui = GUI(name=name, subtitle=subtitle, config_dir=config_dir) + config_dir = config_dir or self.config_dir + gui = GUI(name=name, subtitle=subtitle, + config_dir=config_dir, **kwargs) self.set_manual_clustering(gui) + gui.controller = self # Add views. self.add_correlogram_view(gui) @@ -341,7 +366,6 @@ def create_gui(self, name=None, subtitle=None, if self.all_traces is not None: self.add_trace_view(gui) - # Attach the specified plugins. - gui.attach_plugins(plugins) + self.emit('create_gui', gui) return gui diff --git a/phy/cluster/manual/tests/conftest.py b/phy/cluster/manual/tests/conftest.py index a22d6bd9e..01eca9764 100644 --- a/phy/cluster/manual/tests/conftest.py +++ b/phy/cluster/manual/tests/conftest.py @@ -54,12 +54,8 @@ def similarity(c): class MockController(Controller): - def __init__(self, tempdir): - self.tempdir = tempdir - super(MockController, self).__init__() - def _init_data(self): - self.cache_dir = self.tempdir + self.cache_dir = self.config_dir self.n_samples_waveforms = 31 self.n_samples_t = 20000 self.n_channels = 11 diff --git a/phy/cluster/manual/tests/test_controller.py b/phy/cluster/manual/tests/test_controller.py index c3a4d3894..888b59121 100644 --- a/phy/cluster/manual/tests/test_controller.py +++ b/phy/cluster/manual/tests/test_controller.py @@ -6,6 +6,9 @@ # Imports #------------------------------------------------------------------------------ +import os.path as op +from textwrap import dedent + from .conftest import MockController @@ -14,10 +17,28 @@ #------------------------------------------------------------------------------ def test_controller_1(qtbot, tempdir): - controller = MockController(tempdir) - gui = controller.create_gui('MyGUI', config_dir=tempdir) + + plugin = dedent(''' + from phy import IPlugin + + class MockControllerPlugin(IPlugin): + def attach_to_controller(self, controller): + controller.hello = 'world' + + c = get_config() + c.MockController.plugins = ['MockControllerPlugin'] + + ''') + with open(op.join(tempdir, 'phy_config.py'), 'w') as f: + f.write(plugin) + + controller = MockController(config_dir=tempdir) + gui = controller.create_gui() gui.show() + # Ensure that the plugin has been loaded. + assert controller.hello == 'world' + controller.manual_clustering.select([2, 3]) # qtbot.stop() diff --git a/phy/cluster/manual/tests/test_gui_component.py b/phy/cluster/manual/tests/test_gui_component.py index fcac0648c..c24a8f3b3 100644 --- a/phy/cluster/manual/tests/test_gui_component.py +++ b/phy/cluster/manual/tests/test_gui_component.py @@ -150,11 +150,9 @@ def test_manual_clustering_split_2(gui, quality, similarity): def test_manual_clustering_split_lasso(tempdir, qtbot): gui = GUI(config_dir=tempdir) - gui.controller = MockController(tempdir) + gui.controller = MockController(config_dir=tempdir) gui.controller.set_manual_clustering(gui) mc = gui.controller.manual_clustering - - gui.controller = MockController(tempdir) view = gui.controller.add_feature_view(gui) gui.show() diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index f591094be..c1d7ee157 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -12,7 +12,6 @@ from vispy.util import keys from pytest import fixture -from phy.gui import GUI from phy.utils import Bunch from .conftest import MockController from ..views import (ScatterView, @@ -39,10 +38,8 @@ def state(tempdir): @fixture def gui(tempdir, state): - gui = GUI(config_dir=tempdir, **state) - gui.controller = MockController(tempdir) - gui.controller.set_manual_clustering(gui) - return gui + controller = MockController(config_dir=tempdir) + return controller.create_gui(**state) def _select_clusters(gui): diff --git a/phy/gui/gui.py b/phy/gui/gui.py index 54d1f7423..5ad17b8fb 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -17,7 +17,7 @@ from phy.utils.event import EventEmitter from phy.utils import (Bunch, _bunchify, _load_json, _save_json, - _ensure_dir_exists, phy_user_dir,) + _ensure_dir_exists, phy_config_dir,) logger = logging.getLogger(__name__) @@ -374,7 +374,7 @@ class GUIState(Bunch): def __init__(self, name='GUI', config_dir=None, **kwargs): super(GUIState, self).__init__(**kwargs) self.name = name - self.config_dir = config_dir or phy_user_dir() + self.config_dir = config_dir or phy_config_dir() _ensure_dir_exists(op.join(self.config_dir, self.name)) self.load() diff --git a/phy/io/context.py b/phy/io/context.py index a666cb0e3..8fcb6ac6e 100644 --- a/phy/io/context.py +++ b/phy/io/context.py @@ -27,7 +27,7 @@ from .array import read_array, write_array from phy.utils import (Bunch, _save_json, _load_json, _ensure_dir_exists, _fullname,) -from phy.utils.config import phy_user_dir +from phy.utils.config import phy_config_dir logger = logging.getLogger(__name__) @@ -312,7 +312,7 @@ def _get_path(self, name, location): if location == 'local': return op.join(self.cache_dir, name + '.json') elif location == 'global': - return op.join(phy_user_dir(), name + '.json') + return op.join(phy_config_dir(), name + '.json') def save(self, name, data, location='local'): """Save a dictionary in a JSON file within the cache directory.""" diff --git a/phy/io/datasets.py b/phy/io/datasets.py index 4b9e8f36b..9092b92f1 100644 --- a/phy/io/datasets.py +++ b/phy/io/datasets.py @@ -12,7 +12,7 @@ import os.path as op from phy.utils.event import ProgressReporter -from phy.utils.config import phy_user_dir, _ensure_dir_exists +from phy.utils.config import phy_config_dir, _ensure_dir_exists logger = logging.getLogger(__name__) @@ -147,10 +147,10 @@ def download_file(url, output_path): return -def download_test_data(name, user_dir=None, force=False): +def download_test_data(name, config_dir=None, force=False): """Download a test file.""" - user_dir = user_dir or phy_user_dir() - dir = op.join(user_dir, 'test_data') + config_dir = config_dir or phy_config_dir() + dir = op.join(config_dir, 'test_data') _ensure_dir_exists(dir) path = op.join(dir, name) if not force and op.exists(path): diff --git a/phy/io/tests/test_context.py b/phy/io/tests/test_context.py index f7236ad3a..d8e32596c 100644 --- a/phy/io/tests/test_context.py +++ b/phy/io/tests/test_context.py @@ -60,13 +60,13 @@ def parallel_context(tempdir, ipy_client, request): @yield_fixture -def temp_phy_user_dir(tempdir): +def temp_phy_config_dir(tempdir): """Use a temporary phy user directory.""" import phy.io.context - f = phy.io.context.phy_user_dir - phy.io.context.phy_user_dir = lambda: tempdir + f = phy.io.context.phy_config_dir + phy.io.context.phy_config_dir = lambda: tempdir yield - phy.io.context.phy_user_dir = f + phy.io.context.phy_config_dir = f #------------------------------------------------------------------------------ @@ -95,7 +95,7 @@ def test_read_write(tempdir): ae(read_array(op.join(tempdir, 'test.npy')), x) -def test_context_load_save(tempdir, context, temp_phy_user_dir): +def test_context_load_save(tempdir, context, temp_phy_config_dir): assert not context.load('unexisting') context.save('a/hello', {'text': 'world'}) diff --git a/phy/utils/__init__.py b/phy/utils/__init__.py index 12c8eab5f..2a49c9a45 100644 --- a/phy/utils/__init__.py +++ b/phy/utils/__init__.py @@ -11,6 +11,6 @@ from .plugin import IPlugin, get_plugin from .config import( _ensure_dir_exists, load_master_config, - phy_user_dir, + phy_config_dir, load_config, ) diff --git a/phy/utils/cli.py b/phy/utils/cli.py index 0033d53e0..3841b8a1d 100644 --- a/phy/utils/cli.py +++ b/phy/utils/cli.py @@ -74,11 +74,11 @@ def phy(ctx): # CLI plugins #------------------------------------------------------------------------------ -def load_cli_plugins(cli, user_dir=None): +def load_cli_plugins(cli, config_dir=None): """Load all plugins and attach them to a CLI object.""" from .config import load_master_config - config = load_master_config(user_dir=user_dir) + config = load_master_config(config_dir=config_dir) plugins = discover_plugins(config.Plugins.dirs) for plugin in plugins: @@ -88,7 +88,7 @@ def load_cli_plugins(cli, user_dir=None): # NOTE: plugin is a class, so we need to instantiate it. try: plugin().attach_to_cli(cli) - except Exception as e: + except Exception as e: # pragma: no cover logger.error("Error when loading plugin `%s`: %s", plugin, e) diff --git a/phy/utils/config.py b/phy/utils/config.py index ce08dcab2..56da6bc80 100644 --- a/phy/utils/config.py +++ b/phy/utils/config.py @@ -23,7 +23,7 @@ # Config #------------------------------------------------------------------------------ -def phy_user_dir(): +def phy_config_dir(): """Return the absolute path to the phy user directory.""" return op.expanduser('~/.phy/') @@ -52,8 +52,8 @@ def load_config(path): return config -def _default_config(user_dir=None): - path = op.join(user_dir or '~/.phy/', 'plugins/') +def _default_config(config_dir=None): + path = op.join(config_dir or '~/.phy/', 'plugins/') return dedent(""" # You can also put your plugins in ~/.phy/plugins/. @@ -69,16 +69,16 @@ def attach_to_cli(self, cli): """.format(path)) -def load_master_config(user_dir=None): +def load_master_config(config_dir=None): """Load a master Config file from `~/.phy/phy_config.py`.""" - user_dir = user_dir or phy_user_dir() - path = op.join(user_dir, 'phy_config.py') + config_dir = config_dir or phy_config_dir() + path = op.join(config_dir, 'phy_config.py') # Create a default config file if necessary. if not op.exists(path): _ensure_dir_exists(op.dirname(path)) logger.debug("Creating default phy config file at `%s`.", path) with open(path, 'w') as f: - f.write(_default_config(user_dir=user_dir)) + f.write(_default_config(config_dir=config_dir)) assert op.exists(path) return load_config(path) diff --git a/phy/utils/tests/conftest.py b/phy/utils/tests/conftest.py index 7026bc466..25f21eac2 100644 --- a/phy/utils/tests/conftest.py +++ b/phy/utils/tests/conftest.py @@ -14,18 +14,18 @@ #------------------------------------------------------------------------------ @yield_fixture -def temp_user_dir(tempdir): +def temp_config_dir(tempdir): """NOTE: the user directory should be loaded with: ```python from .. import config - config.phy_user_dir() + config.phy_config_dir() ``` and not: ```python - from config import phy_user_dir + from config import phy_config_dir ``` Otherwise, the monkey patching hack in tests won't work. @@ -33,7 +33,7 @@ def temp_user_dir(tempdir): """ from phy.utils import config - user_dir = config.phy_user_dir - config.phy_user_dir = lambda: tempdir + config_dir = config.phy_config_dir + config.phy_config_dir = lambda: tempdir yield tempdir - config.phy_user_dir = user_dir + config.phy_config_dir = config_dir diff --git a/phy/utils/tests/test_cli.py b/phy/utils/tests/test_cli.py index 99aed19e9..e8f9f22bf 100644 --- a/phy/utils/tests/test_cli.py +++ b/phy/utils/tests/test_cli.py @@ -25,9 +25,9 @@ def runner(): yield CliRunner() -def test_cli_empty(temp_user_dir, runner): +def test_cli_empty(temp_config_dir, runner): - # NOTE: make the import after the temp_user_dir fixture, to avoid + # NOTE: make the import after the temp_config_dir fixture, to avoid # loading any user plugin affecting the CLI. from ..cli import phy, load_cli_plugins load_cli_plugins(phy) @@ -44,7 +44,7 @@ def test_cli_empty(temp_user_dir, runner): assert result.output.startswith('Usage: phy') -def test_cli_plugins(temp_user_dir, runner): +def test_cli_plugins(temp_config_dir, runner): # Write a CLI plugin. cli_plugin = """ @@ -57,13 +57,13 @@ def attach_to_cli(self, cli): def hello(): click.echo("hello world") """ - path = op.join(temp_user_dir, 'plugins/hello.py') + path = op.join(temp_config_dir, 'plugins/hello.py') _write_text(path, cli_plugin) - # NOTE: make the import after the temp_user_dir fixture, to avoid + # NOTE: make the import after the temp_config_dir fixture, to avoid # loading any user plugin affecting the CLI. from ..cli import phy, load_cli_plugins - load_cli_plugins(phy, user_dir=temp_user_dir) + load_cli_plugins(phy, config_dir=temp_config_dir) # The plugin should have added a new command. result = runner.invoke(phy, ['--help']) diff --git a/phy/utils/tests/test_config.py b/phy/utils/tests/test_config.py index e5c171f2a..a13b2d3e9 100644 --- a/phy/utils/tests/test_config.py +++ b/phy/utils/tests/test_config.py @@ -26,8 +26,8 @@ # Test config #------------------------------------------------------------------------------ -def test_phy_user_dir(): - assert _config.phy_user_dir().endswith('.phy/') +def test_phy_config_dir(): + assert _config.phy_config_dir().endswith('.phy/') def test_ensure_dir_exists(tempdir): @@ -36,8 +36,8 @@ def test_ensure_dir_exists(tempdir): assert op.isdir(path) -def test_temp_user_dir(temp_user_dir): - assert _config.phy_user_dir() == temp_user_dir +def test_temp_config_dir(temp_config_dir): + assert _config.phy_config_dir() == temp_config_dir #------------------------------------------------------------------------------ @@ -98,13 +98,13 @@ class MyConfigurable(Configurable): assert configurable.my_var == 1.0 -def test_load_master_config(temp_user_dir): +def test_load_master_config(temp_config_dir): # Create a config file in the temporary user directory. config_contents = dedent(""" c = get_config() c.MyConfigurable.my_var = 1.0 """) - with open(op.join(temp_user_dir, 'phy_config.py'), 'w') as f: + with open(op.join(temp_config_dir, 'phy_config.py'), 'w') as f: f.write(config_contents) # Load the master config file. diff --git a/phy/utils/tests/test_plugin.py b/phy/utils/tests/test_plugin.py index a8119f84a..b1f3994d6 100644 --- a/phy/utils/tests/test_plugin.py +++ b/phy/utils/tests/test_plugin.py @@ -32,30 +32,6 @@ def no_native_plugins(): IPluginRegistry.plugins = plugins -@yield_fixture(params=[(False, 'my_plugins/plugin.py'), - (True, 'plugins/plugin.py'), - ]) -def plugin(no_native_plugins, temp_user_dir, request): - path = op.join(temp_user_dir, request.param[1]) - contents = """ - from phy import IPlugin - class MyPlugin(IPlugin): - pass - """ - _write_text(path, contents) - yield temp_user_dir, request.param[0], request.param[1] - - -def _write_my_plugins_dir_in_config(temp_user_dir): - # Now, we specify the path to the plugin in the phy config file. - config_contents = """ - c = get_config() - c.Plugins.dirs = [r'%s'] - """ - _write_text(op.join(temp_user_dir, 'phy_config.py'), - config_contents % op.join(temp_user_dir, 'my_plugins/')) - - #------------------------------------------------------------------------------ # Tests #------------------------------------------------------------------------------ From 770dc5b224af284afb61a0c9908389b88118b8cd Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 3 Mar 2016 17:44:52 +0100 Subject: [PATCH 1028/1059] Refactor _add_view() in controller --- phy/cluster/manual/controller.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/phy/cluster/manual/controller.py b/phy/cluster/manual/controller.py index 04876ff1b..572423b2e 100644 --- a/phy/cluster/manual/controller.py +++ b/phy/cluster/manual/controller.py @@ -280,14 +280,18 @@ def spikes_per_cluster(self, cluster_id): # View methods # ------------------------------------------------------------------------- + def _add_view(self, gui, view): + view.attach(gui) + self.emit('add_view', view) + return view + def add_waveform_view(self, gui): v = WaveformView(waveforms=self.get_waveforms, channel_positions=self.channel_positions, waveform_lims=self.get_waveform_lims(), best_channels=self.get_best_channels, ) - v.attach(gui) - return v + return self._add_view(gui, v) def add_trace_view(self, gui): v = TraceView(traces=self.get_traces, @@ -296,8 +300,7 @@ def add_trace_view(self, gui): duration=self.duration, n_channels=self.n_channels, ) - v.attach(gui) - return v + return self._add_view(gui, v) def add_feature_view(self, gui): v = FeatureView(features=self.get_features, @@ -308,16 +311,14 @@ def add_feature_view(self, gui): feature_lim=self.get_feature_lim(), best_channels=self.get_channels_by_amplitude, ) - v.attach(gui) - return v + return self._add_view(gui, v) def add_correlogram_view(self, gui): v = CorrelogramView(spike_times=self.spike_times, spike_clusters=self.spike_clusters, sample_rate=self.sample_rate, ) - v.attach(gui) - return v + return self._add_view(gui, v) # GUI methods # ------------------------------------------------------------------------- From 63e483fb75c525e634d232138234a7135fe71a19 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 3 Mar 2016 17:45:50 +0100 Subject: [PATCH 1029/1059] Refactor _add_view() in controller --- phy/cluster/manual/controller.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/phy/cluster/manual/controller.py b/phy/cluster/manual/controller.py index 572423b2e..4b64dbcd6 100644 --- a/phy/cluster/manual/controller.py +++ b/phy/cluster/manual/controller.py @@ -42,6 +42,7 @@ class Controller(EventEmitter): init() create_gui(gui) + add_view(gui, view) """ # responsible for the cache @@ -282,7 +283,7 @@ def spikes_per_cluster(self, cluster_id): def _add_view(self, gui, view): view.attach(gui) - self.emit('add_view', view) + self.emit('add_view', gui, view) return view def add_waveform_view(self, gui): From eeaed96c7a672f1f9bcb2286615cb7c73bcd81ac Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 3 Mar 2016 17:50:37 +0100 Subject: [PATCH 1030/1059] Fix geometry state in GUI --- phy/gui/gui.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/phy/gui/gui.py b/phy/gui/gui.py index 5ad17b8fb..b154ebc2b 100644 --- a/phy/gui/gui.py +++ b/phy/gui/gui.py @@ -165,12 +165,16 @@ def __init__(self, # Create the state. self.state = GUIState(self.name, **kwargs) - gs = self.state.get('geometry_state', None) - self.restore_geometry_state(gs) + + @self.connect_ + def on_show(): + logger.debug("Load the geometry state.") + gs = self.state.get('geometry_state', None) + self.restore_geometry_state(gs) @self.connect_ def on_close(): - logger.debug("Save geometry state.") + logger.debug("Save the geometry state.") gs = self.save_geometry_state() self.state['geometry_state'] = gs # Save the state to disk when closing the GUI. @@ -353,7 +357,6 @@ def restore_geometry_state(self, gs): """ if not gs: return - logger.debug("Load geometry state.") if gs.get('geometry', None): self.restoreGeometry((gs['geometry'])) if gs.get('state', None): From d16389207a7654eedd41737c7fd1a59773c72cd0 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 3 Mar 2016 18:07:12 +0100 Subject: [PATCH 1031/1059] Refactor set_manual_clustering() --- phy/cluster/manual/controller.py | 50 ++++++++++--------- phy/cluster/manual/gui_component.py | 3 +- .../manual/tests/test_gui_component.py | 9 ++-- 3 files changed, 32 insertions(+), 30 deletions(-) diff --git a/phy/cluster/manual/controller.py b/phy/cluster/manual/controller.py index 4b64dbcd6..f4a843ffc 100644 --- a/phy/cluster/manual/controller.py +++ b/phy/cluster/manual/controller.py @@ -52,6 +52,7 @@ def __init__(self, plugins=None, config_dir=None): self._init_data() self._init_selector() self._init_context() + self._set_manual_clustering() self.n_spikes = len(self.spike_times) @@ -119,6 +120,28 @@ def _init_context(self): self.spikes_per_cluster = ctx.memcache(self.spikes_per_cluster) + def _set_manual_clustering(self): + # Load the new cluster id. + new_cluster_id = self.context.load('new_cluster_id'). \ + get('new_cluster_id', None) + mc = ManualClustering(self.spike_clusters, + self.spikes_per_cluster, + similarity=self.similarity, + cluster_groups=self.cluster_groups, + new_cluster_id=new_cluster_id, + ) + + # Save the new cluster id on disk. + @mc.clustering.connect + def on_cluster(up): + new_cluster_id = mc.clustering.new_cluster_id() + logger.debug("Save the new cluster id: %d", new_cluster_id) + self.context.save('new_cluster_id', + dict(new_cluster_id=new_cluster_id)) + + self.manual_clustering = mc + mc.add_column(self.get_probe_depth, name='probe_depth') + def _select_spikes(self, cluster_id, n_max=None): assert isinstance(cluster_id, int) assert cluster_id >= 0 @@ -327,38 +350,17 @@ def add_correlogram_view(self, gui): def similarity(self, cluster_id): return self.get_close_clusters(cluster_id) - def set_manual_clustering(self, gui): - # Load the new cluster id. - new_cluster_id = self.context.load('new_cluster_id'). \ - get('new_cluster_id', None) - mc = ManualClustering(self.spike_clusters, - self.spikes_per_cluster, - similarity=self.similarity, - cluster_groups=self.cluster_groups, - new_cluster_id=new_cluster_id, - ) - - # Save the new cluster id on disk. - @mc.clustering.connect - def on_cluster(up): - new_cluster_id = mc.clustering.new_cluster_id() - logger.debug("Save the new cluster id: %d", new_cluster_id) - self.context.save('new_cluster_id', - dict(new_cluster_id=new_cluster_id)) - - self.manual_clustering = mc - mc.add_column(self.get_probe_depth, name='probe_depth') - mc.attach(gui) - def create_gui(self, name=None, subtitle=None, plugins=None, config_dir=None, **kwargs): """Create a manual clustering GUI.""" config_dir = config_dir or self.config_dir gui = GUI(name=name, subtitle=subtitle, config_dir=config_dir, **kwargs) - self.set_manual_clustering(gui) gui.controller = self + # Attach the ManualClustering component to the GUI. + self.manual_clustering.attach(gui) + # Add views. self.add_correlogram_view(gui) if self.all_features is not None: diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index e93674556..d463cb8be 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -467,7 +467,8 @@ def merge(self, cluster_ids=None): def split(self, spike_ids=None): """Split the selected spikes (NOT IMPLEMENTED YET).""" if spike_ids is None: - spike_ids = np.concatenate(self.gui.emit('request_split')) + spike_ids = self.gui.emit('request_split') + spike_ids = np.concatenate(spike_ids).astype(np.int64) if len(spike_ids) == 0: return self.clustering.split(spike_ids) diff --git a/phy/cluster/manual/tests/test_gui_component.py b/phy/cluster/manual/tests/test_gui_component.py index c24a8f3b3..624c10be1 100644 --- a/phy/cluster/manual/tests/test_gui_component.py +++ b/phy/cluster/manual/tests/test_gui_component.py @@ -149,11 +149,10 @@ def test_manual_clustering_split_2(gui, quality, similarity): def test_manual_clustering_split_lasso(tempdir, qtbot): - gui = GUI(config_dir=tempdir) - gui.controller = MockController(config_dir=tempdir) - gui.controller.set_manual_clustering(gui) - mc = gui.controller.manual_clustering - view = gui.controller.add_feature_view(gui) + controller = MockController(config_dir=tempdir) + gui = controller.create_gui() + mc = controller.manual_clustering + view = gui.list_views('FeatureView', is_visible=False)[0] gui.show() From 2c7c91ec4a3b449e9717d3589bdd713c53fd4c79 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 3 Mar 2016 18:08:51 +0100 Subject: [PATCH 1032/1059] Update comment --- phy/cluster/manual/gui_component.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index d463cb8be..221f33d87 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -465,7 +465,7 @@ def merge(self, cluster_ids=None): self._global_history.action(self.clustering) def split(self, spike_ids=None): - """Split the selected spikes (NOT IMPLEMENTED YET).""" + """Split the selected spikes.""" if spike_ids is None: spike_ids = self.gui.emit('request_split') spike_ids = np.concatenate(spike_ids).astype(np.int64) From 4d8cf13b7362042341ec753619c24f8a5e0d659b Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 4 Mar 2016 14:30:40 +0100 Subject: [PATCH 1033/1059] 2D anchor in text visual --- phy/plot/glsl/text.vert | 4 ++-- phy/plot/tests/test_plot.py | 2 +- phy/plot/tests/test_visuals.py | 23 +++++++++++++++++++++++ phy/plot/visuals.py | 17 ++++++++++------- 4 files changed, 36 insertions(+), 10 deletions(-) diff --git a/phy/plot/glsl/text.vert b/phy/plot/glsl/text.vert index 1a78c28c7..f564df0f2 100644 --- a/phy/plot/glsl/text.vert +++ b/phy/plot/glsl/text.vert @@ -4,7 +4,7 @@ attribute float a_glyph_index; // glyph index in the text attribute float a_quad_index; // quad index in the glyph attribute float a_char_index; // index of the glyph in the texture attribute float a_lengths; -attribute float a_anchor; +attribute vec2 a_anchor; uniform vec2 u_glyph_size; // (w, h) @@ -26,7 +26,7 @@ void main() { // Position of the glyph. gl_Position = transform(a_position); gl_Position.xy = gl_Position.xy + vec2(a_glyph_index * w + dx * w, dy * h); - gl_Position.x += (a_anchor - .5) * a_lengths * w; + gl_Position.xy += (a_anchor - .5) * vec2(a_lengths * w, h); // Index in the texture float i = floor(a_char_index / cols); diff --git a/phy/plot/tests/test_plot.py b/phy/plot/tests/test_plot.py index b298e9f5d..d1e57584e 100644 --- a/phy/plot/tests/test_plot.py +++ b/phy/plot/tests/test_plot.py @@ -128,7 +128,7 @@ def test_grid_lines(qtbot): def test_grid_text(qtbot): view = View(layout='grid', shape=(2, 1)) - view[0, 0].text(pos=(0, 0), text='Hello world!', anchor=0.) + view[0, 0].text(pos=(0, 0), text='Hello world!', anchor=(0., 0.)) view[1, 0].text(pos=[[-.5, 0], [+.5, 0]], text=['|', ':)']) _show(qtbot, view) diff --git a/phy/plot/tests/test_visuals.py b/phy/plot/tests/test_visuals.py index 2f60c71dd..bee2dccb2 100644 --- a/phy/plot/tests/test_visuals.py +++ b/phy/plot/tests/test_visuals.py @@ -9,6 +9,7 @@ import numpy as np +from ..transform import NDC from ..visuals import (ScatterVisual, PlotVisual, HistogramVisual, LineVisual, PolygonVisual, TextVisual, ) @@ -212,3 +213,25 @@ def test_text_0(qtbot, canvas_pz): _test_visual(qtbot, canvas_pz, TextVisual(), pos=pos, text=text) + + +def test_text_1(qtbot, canvas_pz): + c = canvas_pz + + text = ['--x--'] * 5 + pos = [[0, 0], [-.5, +.5], [+.5, +.5], [-.5, -.5], [+.5, -.5]] + anchor = [[0, 0], [-1, +1], [+1, +1], [-1, -1], [+1, -1]] + + v = TextVisual() + c.add_visual(v) + v.set_data(pos=pos, text=text, anchor=anchor, data_bounds=NDC) + + v = ScatterVisual() + c.add_visual(v) + v.set_data(pos=pos, data_bounds=NDC) + + c.show() + qtbot.waitForWindowShown(c.native) + + # qtbot.stop() + c.close() diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index 19614e78f..ee8d0e7f2 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -94,6 +94,7 @@ def validate(x=None, if pos is None: x, y = _get_pos(x, y) pos = np.c_[x, y] + pos = np.asarray(pos) assert pos.ndim == 2 assert pos.shape[1] == 2 n = pos.shape[0] @@ -363,10 +364,12 @@ def validate(pos=None, text=None, anchor=None, n_text = pos.shape[0] assert len(text) == n_text - anchor = anchor if anchor is not None else 0. - if not hasattr(anchor, '__len__'): - anchor = [anchor] * n_text - assert len(anchor) == n_text + anchor = anchor if anchor is not None else (0., 0.) + anchor = np.atleast_2d(anchor) + if anchor.shape[0] == 1: + anchor = np.repeat(anchor, n_text, axis=0) + assert anchor.ndim == 2 + assert anchor.shape == (n_text, 2) # By default, we assume that the coordinates are in NDC. if data_bounds is None: @@ -419,8 +422,8 @@ def set_data(self, *args, **kwargs): a_quad_index = np.tile(a_quad_index, n_glyphs) a_char_index = np.repeat(a_char_index, 6) - a_anchor = np.repeat(a_anchor, lengths) - a_anchor = np.repeat(a_anchor, 6) + a_anchor = np.repeat(a_anchor, lengths, axis=0) + a_anchor = np.repeat(a_anchor, 6, axis=0) a_lengths = np.repeat(lengths, lengths) a_lengths = np.repeat(a_lengths, 6) @@ -429,7 +432,7 @@ def set_data(self, *args, **kwargs): assert a_position.shape == (n_vertices, 2) assert a_glyph_index.shape == (n_vertices,) assert a_quad_index.shape == (n_vertices,) - assert a_anchor.shape == (n_vertices,) + assert a_anchor.shape == (n_vertices, 2) assert a_lengths.shape == (n_vertices,) # Transform the positions. From efb3d9bd24fe66cca8d0f3d43c65ea3c2d07dfbd Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 4 Mar 2016 14:41:14 +0100 Subject: [PATCH 1034/1059] Add channel labels in waveform view --- phy/cluster/manual/controller.py | 19 ++++++++++-------- phy/cluster/manual/tests/test_views.py | 2 +- phy/cluster/manual/views.py | 3 +++ .../static/SourceCodePro-Regular-32.npy.gz | Bin 0 -> 12030 bytes phy/plot/visuals.py | 2 +- 5 files changed, 16 insertions(+), 10 deletions(-) create mode 100644 phy/plot/static/SourceCodePro-Regular-32.npy.gz diff --git a/phy/cluster/manual/controller.py b/phy/cluster/manual/controller.py index f4a843ffc..26a0f0eae 100644 --- a/phy/cluster/manual/controller.py +++ b/phy/cluster/manual/controller.py @@ -351,7 +351,9 @@ def similarity(self, cluster_id): return self.get_close_clusters(cluster_id) def create_gui(self, name=None, subtitle=None, - plugins=None, config_dir=None, **kwargs): + plugins=None, config_dir=None, + add_default_views=True, + **kwargs): """Create a manual clustering GUI.""" config_dir = config_dir or self.config_dir gui = GUI(name=name, subtitle=subtitle, @@ -362,13 +364,14 @@ def create_gui(self, name=None, subtitle=None, self.manual_clustering.attach(gui) # Add views. - self.add_correlogram_view(gui) - if self.all_features is not None: - self.add_feature_view(gui) - if self.all_waveforms is not None: - self.add_waveform_view(gui) - if self.all_traces is not None: - self.add_trace_view(gui) + if add_default_views: + self.add_correlogram_view(gui) + if self.all_features is not None: + self.add_feature_view(gui) + if self.all_waveforms is not None: + self.add_waveform_view(gui) + if self.all_traces is not None: + self.add_trace_view(gui) self.emit('create_gui', gui) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index c1d7ee157..92a87478f 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -39,7 +39,7 @@ def state(tempdir): @fixture def gui(tempdir, state): controller = MockController(config_dir=tempdir) - return controller.create_gui(**state) + return controller.create_gui(add_default_views=False, **state) def _select_clusters(gui): diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 796077f4e..7389fe493 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -374,6 +374,9 @@ def on_select(self, cluster_ids=None): depth=depth, data_bounds=self.data_bounds, ) + # Add channel labels. + self[ch].text(pos=[[t[0, 0], 0.]], text=str(ch), + anchor=[-1., -.25]) # Zoom on the best channels when selecting clusters. channels = self.best_channels(cluster_ids) diff --git a/phy/plot/static/SourceCodePro-Regular-32.npy.gz b/phy/plot/static/SourceCodePro-Regular-32.npy.gz new file mode 100644 index 0000000000000000000000000000000000000000..46e0019dfb15a8599e367edde79d64085d00e8ac GIT binary patch literal 12030 zcmVwKQXJAxC z*S?zufdHX5k=6hY;cs-k=f3P=}G1QdiWv|U|HiOEV~3FYhArAW&@`;& z1I-&X4Vw}sQ3>I$q9-ltRoQlbXFdJ`Ni>BUB)tQHR;Y41Q(j#^qm~6Jjjj z;SH(zF>}k9%jMx`Sp!>V{q_T3h?n~RnbZxAYTjgg#BWWm_?i17o0!J0%z3!30}*m> z8~XdbGJCKSBAqmL!hHPR2A-$ym%`TC+Y|hGvw#~h zfsDN)g5}#;I9*FAP+|k_U9!1&lV7Ysg>8(xG*JGL<-w=wPqU@F!7d z9-j3oWW+BSNzAv(-bOycF-qsmbU>z<4$^%T#P6qYOuXrf8(|W62GCa!8UU*ef`qEz)jD{htxK|+~ekXB)`IaddUz5*%=s{(UvrktlUT1VX z&IO1Jt22;%hTkB*Mcagq;(7mKtiOu{J`B#3ws z=+4g5xYhG9v+@zd)-{$(=daLfKsR#*=K=NgyFssr|B{4XjWPeB#0nihUz9(=jNfRb zrHy7<+yv@Sg&fFmr_=KLjE0hBgZ@EPC* z{q7VTJ4aSSRQ-uZfenh4K#P{iQ4eyZKGKq+(TJskE5c{>DB)smi5F2)IxPibg z48yVu8zb0vQYifKP0B9R=Tjv#UXL@b>xe8`3~^|d!(O+x{$dSF^0LVk=DS)Su5xXNwL)ESGgCM zhlPw-;o|2P@00Q;2=(@6@2Tfi>IHNspDXfpy{_K%iHi2{3NTsx4p@qF<;06yq83*M zx5#A-vgC!o=d9$_?dqkJNx&m!?ku91P`BmIu9v=pj_nPg#W?tdZOM*#&P$mh)!CiB z%b-Pc^-cK`yGh`HSE+Z|YQ)u$e%gnAg3urP$mM6i zPyQmW;t_ETu#M+ZngQ>dxsixt?t>K_yMjN%=HRzS)DFTtmnqU-BDb4m{otGOCkWXj zZMIjbuOK$88{2hGx$j+Sk9b2~;cRoRk^IwGB=+r>gt6qZQW9`xV0f2V>A(&rT)#c+ zv6u_E>SyFNF?4?gpd7Jw)4a=>F({quoAM|BN4vdBZNEW2N=R#18nsxQD%LM{MtOyy zYB9x6i~Q4)_9TwZ5~aK#bUY8d;<1=lfv1hUFJQ0f{V}){DHi;_cR4b3ky?vRLwr>J zNZa3&y}OQAsc3QEqlw(lQFL!QDs!YOeV>u>eI@z2 zU$Z42o~%xloAdW!%f7WQ)qd;1b6H{`M#Al>2>u_%4hzu`_5lv9SKhT;(nQe6Bx$0?lb*UeJ8;G@avN{zP6 zYMCzE4kB;dZ6iU*?@bRsj^`(5~}vl0^)=>F_oCrU6?_uhmoH8pwb)9VbR3u7V;(lQPq6oCnwP z-r>F00lDfw_aUjom^K~8C&0AwQvj;F-_s-s`%RW5Lgh51&FS8%?`puR)@24TNBk~% z7|?c}JCoKojm9rvKfaV2uHO&TO2~X=b;q(4x?zw=gOI+R#c#f`F4av(amddPFE}WG%>_U&ial7jwmXSFg*N3x_GtxB2>AFH*^HD(~y;?5L z%k8$F1HUuE%==ebH8-wc!*P}Ar?&VSz$z+b^E|w`j>Yd?V^zL`xSH>Vo`FrpB?XD{ zMTtz+isO4gIlmBQ#bzoQXTQQ7-Ma(WUE6~|?LEMDcl~R1ZTWBjFkE47go9c1bh80` z6j>%-3E(IFF7z_F+sfi+Ja16Zs_ek=s7|*2I%jHfhhUR0;&D;^^*gv~=;=876>czn zIRiPbd7eUBUduto%lzLeXq_GDsMHD!2l~)LZt&ZOOcSpUp!ejK4Y*R)h+iK=iot-} z9Weg`_*J_(fH>tjPpF2e`Ox-TdNDRwygeHprd59h)4g%WEz%Q=Lfh$R2wwd=h^BYJ#8 zOY!mR0>zMCrmAiKGie_g!FA=VvjhJpnDCo*#09v{O;#a1fmrg7Ex&3@`t>=3rcf@G z`5!04HNQOh-+t4!+m67dF5Qm;mg}PC{R$4I6@LN_H)Leo+y6Brqv{&tA3Yr|t+i!~ z-Af|GPMUOLYiKbpewG8sp!1t1{$q!`=#1(~SRj6Hnc95P{=1nWfKSFijhT8-xcK7~`?Ep%X;ik}~8 z?X(;hzY%m#Cn)$i_Y79I`~wf*fV51<;I}RP!|+mpHC{Xm#LEpm2Pco?(6DwL#>KK@ zw@@M#v+4VwacQ@RE{`MHI}>`HbZG-xb)I|$s|^3S4@AH($l{;>H-`Q9S3z5gWnwRB zzt_d~TM}SS5%KVQz>1wrbF-AJ+HM@_XCv>C?>|GLyu{7m=chE=V!x@a-wx!wC|#9} z0%g4~ZEeu+OSPleTB$^zRe5kOPz*zamWlF*Hl(T;0IlxjGaMb7ObiUK^o3XV|M+4-rNd# zqAD9$^~LS@OJ;c3V6bqv3>a`9b-M^qweXPuV>fTYuiOb_hV!db4p?cf2?xJjkj@*5 z$7j@Gh_xeqf;H8M6ZHe7=9h8BnB7-^*@E9Y4NOPQ7+rnl*V!^A2k8U7$po!Y?%uv$Crl4He?o}2>wb7^X6)9ZT-Xr z0uEaxU=R2GG}=WLBLnpIUaf~-rcR^C;}DV4$D?|&$nM7uQ9(Zf8Q=Er5dQ3PqKz00 z8z}t9TavoCe}`86S3BWSTWuhjwHz7Ly3_0&*g8__64~it$umSZo!;%ihhIGeB*$;QJ=ywsQUmL| z;hXIu{RVDNvc~?PcTzn3_(N2w|63Q8mY=j^d?iykjQ&@w&+Wzq<-WR76EoME}d=SEXTjmQ%^Wp39XdzPa(gPh@xGTYYlkrZw)DXZ&cny!id^ zH#a3_C(qmoHRlf?3o9>45xZt zhn~vUpK~TN^Ykos2Hg2aaaAU26uqf30_=9WUc20(I=9+-m3yp#LC&?dA}6sI{3`sQ zo_XdgC zr?Xm`bU5EwY+q>JSVAt6d}PcYI@ySo3t0>t4>BfZDsAGGI19{_(^S~>v`LKJ5_Rj> zz>HrHCw2swf2LCGlnunjAy&sPO??5ik~+8$C0XuPKM$9B;(IBqiXq+;POD~`F$xET zm74zyrr2{xQ?BMxe}ly&w+wj8eVuu{k2t%zg@9kPELFvcN;d!L#9)$E_cc~^@<#TX~an-=eg{C7ez-^s5 zys}w1W16boIewL6k=~H^GsRkPC^uLTO4-&@eeC3=iamzP3)da8i1uSGrHdU{q2=?K zX3SL8iN->id}tnn^MfF-4mC5vm*FF}Ng%rW8xf#ZJ+3`4KQLtWnkL z3CULkzs?-qN-z1{9)2spdWYiTx2dIgF#j9ZJrsTmh1Ig@i^|&$yJ~?W6nDsea7WRu zx%{-_)%zv;=;27$8n856KG1R44mi}++zW{DRC{q_{BtU?OjS67p7Ya^9=IL+CWGB0 z=Hs_jxEd!|t{{z(q*(g&d6c}*x*QYb`{La{*b)Q9iDY$`WT9QvrCW4p z%GE~A>qROKd8)loT2rAY-E5jELyf2A7b3$|+bD+Hz^{*}PhJY*{>-pGMdSO)ct_#4 zUi4ABnktFe8ayERp%PA^llWCXpW*5R#eUAbRmN5?fPigG z>$T3lHU4&oJhv+4XVZbK$jtrf0ELVlJVDU-A`4ozhV^VD}HN$@}{5f5I;!) z*bYY%&d)6CF`%glKT@aF!eWKD0WR&7L5NG=hV_u#SP4N2en+sOp-c;qRoqnU zE?A{+qW_u9HfXsYq#C)YGlw@47-;&r!1x90u7#j<_|2V)(|tE}kyrt#7X&kYOs16MUsjWzJnw=s( z&jZ(9Zu4@~@jMB?@o=hotp6=t1yPHwW><}L=2A9WmG3pk`EAsiudnmpMVi80aQw!p zX9`B?<_22Y>Nz+cU}SQuV#2S9%RW!#?S$(=`RYNJVnj4kM8ID@#~WEib|QX^((6_s zJD4q_%uWp#as6ecJY&5-`tm)DT<9ubpT4N-6Mki4;8aef)hQr;qhXUzdsk1?%?4?6ga9~rUcpT3o!xbNMXf=SHt`n~_N>#o|;&n}=mnNw-~&%oZm zU1HDpHFi=xMcB;KnVF|&hG~pSsp?VcBDyUiJ=}c7r$+N4Lz~;6DL~wNPdwOh*k;#M z@6ttNqL*$dy}@tX!`4OQjVg@WpJa<)>CKIG5fxDp6;Tlt(Z7a_^OO}FAF$I@8Saee z-%Y@!D2bGNo)6+tvV4l@ z?4oCY{({iRa0Yg!od5d^V=Ok&DgggwY0h9G(8)WpCk7^*iHEL0~2D- zO(Jn=?xF5)p93zdG49?1Ql@t5IOQg7LM6(zeC+k3%_?FGG&99viN$q&JLZ3U|;( zvjPu4w=tj$wSEe5e&WC9IXE4a=$zscXak&DrxNT`cB6H<;4GYyuOv5PMhEse6jZ0fNl{5{;ouu0w| ze$_mfNsjB9WBw7*w0JsIi9rBmWh~OYR(wo`kx0X<(ZQ*W%zD{oKbLM|FtW9tq>QEu zS+8{FM53w|{?z-|v~aLIJ2zTCJ)!H_LZh!lP207B_NnC?e%)lJsKL3~xOz}D7Mp+p zi^MK>8P$)G&d=D<%|lLC7oK35Ay;t0sLt^4a}xuUskJ3qIdKyXn_Ew#0)xvEWmwOq zov7?r-0f;o5wEaSEcaFn;pV?c^L)WCC|yG62|af5ccQPbov8hy=o51<A;9;53j^WU_HdD>8#fV2AVXO-Px`8^#2#j=tn z|GsAd2i-sUf}gMgi5Bri(t677UdQ-=d=iia~r?f{$|l_lUS%~Y3{X|C=iZMW|HsIO#ocu}3qNcQ^0 zx{IpkWpuaLwb!ylJb6)9(>{4cDhs3aYiI&am^Wa#`0-K+oBYnO;*B9zi&*`%o*Hi+6-jhU#YV{V*-bb=^;-InycHBHl~1kI23wgc0ht2Q;m4CF7aSn5o}fRc=1LNifnduXqJYG=r|>;wg#ss@~l` z4MuBY%eV!GXfeL_d)TZ379M^=@u}%zHLP<>bL4jD4of?H{6}Jb`pE?kddsx+(&*g0 zE6dq1)74z;p2x#bX5PV8KyqMMCK)F*#CO5TyjXM@yln&J2R?qo%qn1^?!N{0bm4f2 zce#H$xx+8itO6GD{^d#!Av*3JavLXOl}QEkh9B#uytzkNLifbMwYW{$lI}c#-W}u? zziuWK&>Mbk`v^BwFF}&K>SIm+nPl8Q*U0~xrUJU%KR0GoF#LqFv&EoqX}1xn`JY4X zA9t3f0!~w93g3RCQsY`)Hb9H0h>EC)?pm^JG_FnZsNkKe+XK!N1rOVrXv>J&IKu5F zwfl1WT)*}o&pb}rX@9~143`1K;0dJaPA zsodvoBpRba8hCEg7B+sC%^c5DcUq)({62SUigzdF`jGZ-uyp=2@Y?{C^deG(-yKQZ zwfNN?zb(d^7BOA3)E0{@JLP1~t$l-;uCnH|JvVGA$0q9$R4U9kmwR(#7Z#oHFE8X| zp8Z50CirvlYO~Mgv6E5hyUnVrwi`dUF{PzY;Pcd*C#=|7jQ4|l*RjU>^u)IX;rTOp zQO~H%?t;*DeY`F4Cn{=Iz@W`H@^5S&fmDYAEp%-g9MG781#IuyYS z)&E6jWhwouT~wT!m5P*&@~^?cuX{Sh)9?$LFEbqdjA?zvue%l7>D6WdbM--bH5UGs;pW_nk(a@$J&jKqfF;Z z>?iEGJzI5|5{F#bq_(9MSo5D&{Wnvvl!^rQ4QlrYT5;ax)xbK7p^ZJd1SE* zM=Aen`+mpv{;mG=3y>P3r!{IJJf3<(;M=4B=N$Y}V!BTSPTLlBmULYSGI@fx}a7x><1`#J!YQL`Tu; zN8HQuPf`?crd;{iEJxv#TX%o*51WeSQSofm=NNE8*=g*f+26MTB8ssIp3!}TYUmKL z#R{)AXmeO2em4|bPfILL>knsDya0F6;=gG>`Gnu&RvanqL5kdoCnfd(;`^~zv)7}_ zm5mCY)+?}3wUzWn5vpB_USv?N>R;zeqfLeMs;LpMG<3d%Zi5{!#>*9q5(-(>$Y6UE z4_i#R<(MG#g0J{Jo`(pdg{%?3t$JnDI&sQf0jQRbb=<~8Q^~L8s6{(3Wk=_?Uv~W|X1Ldw(yIHT zEjWIbC2%h2+VwLB@-<7eJHsy_4=4LNO*Wsmi=Sr3DC}0xaPa6Q(&81hg@<1}F31;` z%?>DUXvbpoq|u0^woNHpbJP5zdYp`Lj&7B3XZVp(ue>lrwzPJe(pUUql-RrSGnrx$ z>vR%{s#OmMzq-H@`Dz*PkRjZ z(rPkNPUY+boyI$UE8$Yd!F!pQ*+$CK8hWt^Gk){56H?uYa;wc`-2D%(!kb)@#1v*BkgmVSVN?Zum81%Kz7c z$D!kwJ4m|d>A+ca2E2o4>q>Oq7fwhQHS&=v`ROYCvp)aW1cy59s7%K^HR6|z^QPs% zI~+gci$Mm<@;!_dx3BnFBGoD6b57II6Po6dY1^9}+YQ>}%G;^q=fsg$+dTe_6E@A- zsuG!tq<+-np)PM6SKqN|WY~VUL;LaZxU8X5mP|@cPI$Cshas!;E>aUIBYqoz^}WMM zAK9dawnkAX_z6G4*@==G3M7td*L~V1O3XWcLVYqlU)3|bCz-#cd$MW?Q9}anXxvA~ z?=@u+Ae}nTRAH?;tY(ap{}8{0@_qKe(K0%DPZW{bnE5sI!PvZY9?KUM^-20otwqR`fdK{ zAF1{myT<889$nTZ+OEv2DV|Tz#hQJ3JS{(I&pV9CjuF37AGt2yR2Out_q%@%{=JDu z-E?qb6V>7ZUysPnPgwgxh%V*}enQ~OTaxmVw!Hk0bW~UQ6z0n&;H4i}O-0-(0Qtk) zdFTIOT8*ok1$_nGi*rG!s_~%VzlpM_BbWE%w%lU2H=XnP7f~#n>d@5S9IOSh+my#^ z%Q`LO&A*T)lgaU_7u!A1VelF|94~rLK+)7gqaF)?c01^BF#k64|3Y77IQ58kd5QuO zm2_vgC-eU;)NIbdWLy5tBMUlvGlwuzgPD<8`ANIR6!R`xL`C#B)8Al!aQ9NpImc4- z*_n<3ziDvg#%>v2)b@GwKnCDUxWKH88Sonw6`e-Lo+JN|K=ZYyP5Zm6K+zUpmSOvn z*M1xA=0($?DF7_pQ<$}+=^0L-G5Zx%a5}i6@vVPCb=#Yu`|+n+P1XD>W~Aewgq%y^J_O5Q0bUOm*a3RRcpk z{Kbh3d(y$F^>iVIjI~uSn}TvL)IiF41)S1qX0cG&)@P-(6MCN&JC$|Mk%}EpE85pW zd7I+K&0MSptlZ$&!_1_J;1@bhayqjGGX)*LF#{E=yItDWM1v@qlQ&``i(l)j@(;4{ zOyDxL!p@&+szN^DC+x!&Z8`WOq!yR_Lr8WO3l$6cOZZm&8mmUHtevff0^5g+TG}zl zD{Cs4F|8;7ex5xgiueUZqmR%5DP<`9;$gJw3%dy%9MrnYlz2E9{GQ3j@z>fv*lU4o zkV87U^J6-;>^yp_3;&gcShGHp`&jMrd1!9of^JINf)6P>P{(DT!J#hMRM4^&R&kfY zZ!Z#*J%XOl1G67jJ1w=uu9--0bbI*O-63bu@Duvsy#IbN%9@U!*{*LtFkMPhY$5aP z3-fHR(%W|NxT&~}X!t(W_>J2dJp6)^mAF)b@T}}2G%!!nA3_`vcu9~pjax`RSGkE( zL4xoaTx#%Vzf0JxTsAj4!3x?BJL9h}h_|B8#m~W7wA5akc72AVQP(*uHEan#;Q&_S zauAK5$*!*(HIg&W4kv5$mdy(Q>>GZ)Q#4AGmczs-Sf9&%wTxeD$ZidE{M4*8 zyS}BmOF096)7TQ1!}U7HlG-QyYB@1px~St)%sW4ZD^!f`8kQuwfBQ*a%>NSh@J%?Y z&#p?ODqsesX#S~XdP_GEAY>`oY)U53Mlcc=zfX`lp=n9AeI3G2kGy;xId8R@?vhWC z%lK8y0U8NB{DerwWo9(u_|JaHux%Ie8;~^jG2^t+C;WuXiWSmKq{tKZen1zJIK#5d zRZbADDt^ldm54$*@>@_jzK@I$IY~Zc4lD>goY+WpB&2_aZ7*k3<>D7SJ3)U%I^z&Q zvhp@*aFnf?Wc)ryq7-?RgP&^Gm!sS8dD?yzM5~QRaj$#pFvdrNZ}|0)yRLvl)f$+n zR1*ci^gY6dcuL%Vdrr%sJ!o}@P!*eOD7~4rNNt*l6SSeTh~Jfk%xL2FQ3&<>*{0za zzF#ugvJAhNvrjRYw4^ys^E z9`7J8RJ1M9{;ai#5r(3dP?fy@Z|m-NMSW$1!QQE<)Yiu{O9MgskZUxj6TOY=9RcsD zbh>1;5;TyA-=oMVCxj%UKL@{>z*6aA8=TtP+=-K~qtLYAN{2Zzs~i=sdTI?i5epU5fQBTpR#OLJkzo5cg5? zb3C5~CHR#)LAn@Gmx?K7#_tm~*!f$!m$Z95;-~lB!L+%+Y_P|PFCkM=kv?Ve_z_2i zak!~DU?bh!BWFeP^Y790q&69~Y)5=)p+*`G(L?=X$gkQ2PO7=SNxK%L$anoYg_m3> zg5MjsDh`%`m++>JU+zHZ;wcBt%dMe84)WwavvqH3aQ>4NQe!zt>}BQV^iO))0OJ6Z2ur~j|p&_jwU+hG_de)81)bl@oBD~K^XR0Us$mGPjc2u|U@P5mj$PCl? zey2#51KYHIdi-`t&Fw_Ah*_Q54n2j18<{-L5nbL@Prr}ANoh6UejL|A$+xg6*56;l zMGqD)At9cMEzu+s(cxU*RE3G3HeWah9lzb*$g}PYy?SpeMq~$Td2GKu)R?H-pN6~qVA^Ui-rYohmkbvZBP_vb9$#*el$4_~w8#;kjZ)uGEwK77lO zd3yOP%Ah=8cbfI`*G)yol5`$rh%Qe%ksCl1AxT9ie5H&}C z_`e1}noRl&kN&wkr9I*D7TZZ8U)S%8Q5sVL=&r?YD;w^|EoE>5Q1G6rvF8(6?&xmC zuPtcLy&mXy-~*!{P0mx*-=Qt26^QOq{Hnw<4yQd(IGJ8L;#mKGrv39vMA67>Cfo+y zh4|HZq~n_@$mNF)GE?nOB5&Ef*>7Z>TU1ZD9l8tg`&b Date: Fri, 4 Mar 2016 15:03:27 +0100 Subject: [PATCH 1035/1059] Fix minor anchor bug in text visual --- phy/plot/glsl/text.vert | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/phy/plot/glsl/text.vert b/phy/plot/glsl/text.vert index f564df0f2..bce79fd8b 100644 --- a/phy/plot/glsl/text.vert +++ b/phy/plot/glsl/text.vert @@ -26,7 +26,7 @@ void main() { // Position of the glyph. gl_Position = transform(a_position); gl_Position.xy = gl_Position.xy + vec2(a_glyph_index * w + dx * w, dy * h); - gl_Position.xy += (a_anchor - .5) * vec2(a_lengths * w, h); + gl_Position.xy += (a_anchor - 1.) * .5 * vec2(a_lengths * w, h); // Index in the texture float i = floor(a_char_index / cols); From 821340c5e4ed5f50f8dc978c92b9fa4b1ebf050f Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 4 Mar 2016 15:04:02 +0100 Subject: [PATCH 1036/1059] Add channel labels in trace view --- phy/cluster/manual/views.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 7389fe493..ef7223ae2 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -376,7 +376,9 @@ def on_select(self, cluster_ids=None): ) # Add channel labels. self[ch].text(pos=[[t[0, 0], 0.]], text=str(ch), - anchor=[-1., -.25]) + anchor=[-1.5, -.25], + data_bounds=self.data_bounds, + ) # Zoom on the best channels when selecting clusters. channels = self.best_channels(cluster_ids) @@ -677,14 +679,17 @@ def __init__(self, def _plot_traces(self, traces=None, color=None): assert traces.shape[1] == self.n_channels t = self.interval[0] + np.arange(traces.shape[0]) * self.dt - t = np.tile(t, (self.n_channels, 1)) color = color or self.default_trace_color channels = np.arange(self.n_channels) for ch in channels: - self[ch].plot(t[ch, :], traces[:, ch], + self[ch].plot(t, traces[:, ch], color=color, data_bounds=self.data_bounds, ) + # Add channel labels. + self[ch].text(pos=[-1, 0.], text=str(ch), + anchor=[+1., -.1], + ) def _plot_spike(self, waveforms=None, channels=None, masks=None, spike_time=None, spike_cluster=None, offset_samples=0, From 5aa75a4c4dd73eb5626537dec767a8049ac12034 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 4 Mar 2016 16:46:14 +0100 Subject: [PATCH 1037/1059] Fix labels in trace view --- phy/cluster/manual/views.py | 3 ++- phy/plot/visuals.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index ef7223ae2..bcd0281cd 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -687,8 +687,9 @@ def _plot_traces(self, traces=None, color=None): data_bounds=self.data_bounds, ) # Add channel labels. - self[ch].text(pos=[-1, 0.], text=str(ch), + self[ch].text(pos=[t[0], traces[0, ch]], text=str(ch), anchor=[+1., -.1], + data_bounds=self.data_bounds, ) def _plot_spike(self, waveforms=None, channels=None, masks=None, diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index 1d9859b14..a1de2c2ca 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -444,7 +444,7 @@ def set_data(self, *args, **kwargs): pos_tr = self.transforms.apply(a_position) assert pos_tr.shape == (n_vertices, 2) - self.program['a_position'] = a_position.astype(np.float32) + self.program['a_position'] = pos_tr.astype(np.float32) self.program['a_glyph_index'] = a_glyph_index.astype(np.float32) self.program['a_quad_index'] = a_quad_index.astype(np.float32) self.program['a_char_index'] = a_char_index.astype(np.float32) From 6383cf2e77c55d1044b0aecb28dd28845c5e095e Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 4 Mar 2016 17:07:48 +0100 Subject: [PATCH 1038/1059] WIP: add dimension labels in feature view --- phy/cluster/manual/views.py | 15 ++++++++++++--- phy/plot/glsl/text.vert | 3 +++ phy/plot/visuals.py | 2 +- 3 files changed, 16 insertions(+), 4 deletions(-) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index bcd0281cd..1506d25ff 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -376,7 +376,7 @@ def on_select(self, cluster_ids=None): ) # Add channel labels. self[ch].text(pos=[[t[0, 0], 0.]], text=str(ch), - anchor=[-1.5, -.25], + anchor=[-1.01, -.25], data_bounds=self.data_bounds, ) @@ -1076,13 +1076,22 @@ def _plot_features(self, i, j, x_dim, y_dim, x, y, # The marker size is smaller for background spikes. ms = (self._default_marker_size if spike_clusters_rel is not None else 1.) - self[i, j].scatter(x=x, - y=y, + self[i, j].scatter(x=x, y=y, color=color, depth=d, data_bounds=data_bounds, size=ms * np.ones(n_spikes), ) + if i == (self.n_cols - 1): + self[i, j].text(pos=[0., -1.], + text=str(x_dim[i, j]), + anchor=[0., -1.04], + ) + if j == 0: + self[i, j].text(pos=[-1., 0.], + text=str(y_dim[i, j]), + anchor=[-1.03, 0.], + ) def _get_channel_dims(self, cluster_ids): """Select the channels to show by default.""" diff --git a/phy/plot/glsl/text.vert b/phy/plot/glsl/text.vert index bce79fd8b..f5abc380d 100644 --- a/phy/plot/glsl/text.vert +++ b/phy/plot/glsl/text.vert @@ -26,7 +26,10 @@ void main() { // Position of the glyph. gl_Position = transform(a_position); gl_Position.xy = gl_Position.xy + vec2(a_glyph_index * w + dx * w, dy * h); + // Anchor: the part in [-1, 1] is relative to the text size. gl_Position.xy += (a_anchor - 1.) * .5 * vec2(a_lengths * w, h); + // NOTE: The part beyond [-1, 1] is absolute, so that texts stay aligned. + gl_Position.xy += (a_anchor - clamp(a_anchor, -1., 1.)); // Index in the texture float i = floor(a_char_index / cols); diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py index a1de2c2ca..57dd0f90b 100644 --- a/phy/plot/visuals.py +++ b/phy/plot/visuals.py @@ -390,7 +390,7 @@ def vertex_count(pos=None, **kwargs): def set_data(self, *args, **kwargs): data = self.validate(*args, **kwargs) - pos = data.pos + pos = data.pos.astype(np.float64) assert pos.ndim == 2 assert pos.shape[1] == 2 assert pos.dtype == np.float64 From 1aa50422ab3717af9c37f70cc33bfcb2b29f4d3e Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 4 Mar 2016 17:14:15 +0100 Subject: [PATCH 1039/1059] Fix dimension labels in feature view --- phy/cluster/manual/views.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 1506d25ff..fad6f5085 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -1083,8 +1083,9 @@ def _plot_features(self, i, j, x_dim, y_dim, x, y, size=ms * np.ones(n_spikes), ) if i == (self.n_cols - 1): + dim = x_dim[i, j] if j < (self.n_cols - 1) else x_dim[i, 0] self[i, j].text(pos=[0., -1.], - text=str(x_dim[i, j]), + text=str(dim), anchor=[0., -1.04], ) if j == 0: From b8895eef5d28834727e2a9f2e7a8b59ef5489313 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 4 Mar 2016 17:17:23 +0100 Subject: [PATCH 1040/1059] Add cluster labels in correlogram view --- phy/cluster/manual/views.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index fad6f5085..d3d85924f 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -1396,6 +1396,12 @@ def on_select(self, cluster_ids=None): color=color, ylim=ylim, ) + # Cluster labels. + if i == (n_clusters - 1): + self[i, j].text(pos=[0., -1.], + text=str(cluster_ids[j]), + anchor=[0., -1.04], + ) def toggle_normalization(self): """Change the normalization of the correlograms.""" From 394f881aebdd0045b0cbe897ac5c4731e90d147c Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 4 Mar 2016 17:46:23 +0100 Subject: [PATCH 1041/1059] Minor performance improvements --- phy/cluster/manual/controller.py | 2 +- phy/cluster/manual/tests/test_views.py | 6 +++--- phy/cluster/manual/views.py | 24 ++++++++++++++---------- 3 files changed, 18 insertions(+), 14 deletions(-) diff --git a/phy/cluster/manual/controller.py b/phy/cluster/manual/controller.py index 26a0f0eae..1d87d11c5 100644 --- a/phy/cluster/manual/controller.py +++ b/phy/cluster/manual/controller.py @@ -179,7 +179,7 @@ def get_mean_masks(self, cluster_id): def get_waveforms(self, cluster_id): return [self._select_data(cluster_id, self.all_waveforms, - 100, # TODO + 50, # TODO )] def get_mean_waveforms(self, cluster_id): diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index 92a87478f..b201a8c55 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -173,13 +173,13 @@ def test_trace_view(qtbot, gui): assert v.time == .25 v.go_to(-.5) - assert v.time == .25 + assert v.time == .125 v.go_left() - assert v.time == .25 + assert v.time == .125 v.go_right() - assert v.time == .35 + assert v.time == .175 # Change interval size. v.set_interval((.25, .75)) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index d3d85924f..f948f835b 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -375,7 +375,8 @@ def on_select(self, cluster_ids=None): data_bounds=self.data_bounds, ) # Add channel labels. - self[ch].text(pos=[[t[0, 0], 0.]], text=str(ch), + self[ch].text(pos=[[t[0, 0], 0.]], + text=str(ch), anchor=[-1.01, -.25], data_bounds=self.data_bounds, ) @@ -613,7 +614,7 @@ def extract_spikes(traces, interval, sample_rate=None, class TraceView(ManualClusteringView): - interval_duration = .5 # default duration of the interval + interval_duration = .25 # default duration of the interval shift_amount = .1 scaling_coeff = 1.1 default_trace_color = (.3, .3, .3, 1.) @@ -676,7 +677,7 @@ def __init__(self, # Internal methods # ------------------------------------------------------------------------- - def _plot_traces(self, traces=None, color=None): + def _plot_traces(self, traces=None, color=None, show_labels=True): assert traces.shape[1] == self.n_channels t = self.interval[0] + np.arange(traces.shape[0]) * self.dt color = color or self.default_trace_color @@ -686,11 +687,13 @@ def _plot_traces(self, traces=None, color=None): color=color, data_bounds=self.data_bounds, ) - # Add channel labels. - self[ch].text(pos=[t[0], traces[0, ch]], text=str(ch), - anchor=[+1., -.1], - data_bounds=self.data_bounds, - ) + if show_labels: + # Add channel labels. + self[ch].text(pos=[t[0], traces[0, ch]], + text=str(ch), + anchor=[+1., -.1], + data_bounds=self.data_bounds, + ) def _plot_spike(self, waveforms=None, channels=None, masks=None, spike_time=None, spike_cluster=None, offset_samples=0, @@ -765,8 +768,9 @@ def set_interval(self, interval, change_status=True): self.data_bounds = np.array([start, m, end, M]) # Plot the traces. - for traces in all_traces: - self._plot_traces(**traces) + for i, traces in enumerate(all_traces): + # Only show labels for the first set of traces. + self._plot_traces(show_labels=(i == 0), **traces) # Plot the spikes. spikes = self.spikes(interval, all_traces) From ad9b8c27ccaf6d8edb21e5d43d32303ca762e5ac Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 7 Mar 2016 11:23:46 +0100 Subject: [PATCH 1042/1059] WIP: color all spikes in trace view --- phy/cluster/manual/views.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index f948f835b..72ee15203 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -55,7 +55,10 @@ def _selected_clusters_colors(n_clusters=None): colors = np.tile(_COLORMAP, (1 + n_clusters // _COLORMAP.shape[0], 1)) else: colors = _COLORMAP - return colors[:n_clusters, ...] / 255. + out = colors[:n_clusters, ...] / 255. + if n_clusters is not None: + assert out.shape == (n_clusters, 3) + return out def _extract_wave(traces, start, mask, wave_len=None, mask_threshold=.5): @@ -98,6 +101,11 @@ def _get_color(masks, spike_clusters_rel=None, n_clusters=None, alpha=.5): alpha = alpha if spike_clusters_rel is not None else .25 assert masks.shape == (n_spikes,) # Generate the colors. + if n_clusters is None: + if spike_clusters_rel is None: + n_clusters = _COLORMAP.shape[0] + else: + n_clusters = spike_clusters_rel.max() + 1 colors = _selected_clusters_colors(n_clusters) # Color as a function of the mask. if spike_clusters_rel is not None: @@ -708,17 +716,18 @@ def _plot_spike(self, waveforms=None, channels=None, masks=None, # Determine the color as a function of the spike's cluster. if color is None: + # TODO: improve all of this. clu = spike_cluster if self.cluster_ids is None or clu not in self.cluster_ids: - sc = None - n_clusters = None + k = len(self.cluster_ids) if self.cluster_ids else 0 + clu_rel = k + (clu % max(_COLORMAP.shape[0] - k, 0)) + sc = clu_rel * np.ones(n_channels, dtype=np.int32) else: clu_rel = self.cluster_ids.index(clu) sc = clu_rel * np.ones(n_channels, dtype=np.int32) - n_clusters = len(self.cluster_ids) color = _get_color(masks, spike_clusters_rel=sc, - n_clusters=n_clusters) + ) # Generate the x coordinates of the waveform. t = t0 + self.dt * np.arange(n_samples) From 42d7fce6b41c096e7c7da22b6df516c300a9cb41 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 8 Mar 2016 14:23:18 +0100 Subject: [PATCH 1043/1059] More color functions --- phy/utils/_color.py | 78 +++++++++++++++++++++++++++++------ phy/utils/tests/test_color.py | 23 ++++++++--- phy/utils/tests/test_types.py | 7 ++++ 3 files changed, 90 insertions(+), 18 deletions(-) diff --git a/phy/utils/_color.py b/phy/utils/_color.py index da2dd1e39..33b4d36fe 100644 --- a/phy/utils/_color.py +++ b/phy/utils/_color.py @@ -6,20 +6,21 @@ # Imports #------------------------------------------------------------------------------ -import numpy as np +from colorsys import hsv_to_rgb +import numpy as np from numpy.random import uniform -from colorsys import hsv_to_rgb +from matplotlib.colors import hsv_to_rgb, rgb_to_hsv #------------------------------------------------------------------------------ -# Colors +# Random colors #------------------------------------------------------------------------------ def _random_color(): """Generate a random RGB color.""" h, s, v = uniform(0., 1.), uniform(.5, 1.), uniform(.5, 1.) - r, g, b = hsv_to_rgb(h, s, v) + r, g, b = hsv_to_rgb((h, s, v)) return r, g, b @@ -39,7 +40,7 @@ def _random_bright_color(): #------------------------------------------------------------------------------ -# Default colormap +# Colormap #------------------------------------------------------------------------------ # Default color map for the selected clusters. @@ -49,14 +50,65 @@ def _random_bright_color(): [228, 31, 228], [2, 217, 2], [255, 147, 2], + + [212, 150, 70], + [205, 131, 201], + [201, 172, 36], + [150, 179, 62], + [95, 188, 122], + [129, 173, 190], + [231, 107, 119], ]) -def _selected_clusters_colors(n_clusters=None): - if n_clusters is None: - n_clusters = _COLORMAP.shape[0] - if n_clusters > _COLORMAP.shape[0]: - colors = np.tile(_COLORMAP, (1 + n_clusters // _COLORMAP.shape[0], 1)) - else: - colors = _COLORMAP - return colors[:n_clusters, ...] / 255. +def _apply_color_masks(color, masks=None, alpha=None): + alpha = alpha or .5 + color = np.atleast_2d(color) + hsv = rgb_to_hsv(color[:, :3]) + # Change the saturation and value as a function of the mask. + if masks is not None: + hsv[:, 1] *= masks + hsv[:, 2] *= .5 * (1. + masks) + color = hsv_to_rgb(hsv) + n = color.shape[0] + color = np.c_[color, alpha * np.ones((n, 1))] + return color + + +def _colormap(i): + n = len(_COLORMAP) + return _COLORMAP[i % n] / 255. + + +def _spike_colors(spike_clusters, masks=None, alpha=None): + n = len(_COLORMAP) + c = _COLORMAP[np.mod(spike_clusters, n), :] / 255. + c = _apply_color_masks(c, masks=masks, alpha=alpha) + return c + + +class ColorSelector(object): + """Return the color of a cluster. + + If the cluster belongs to the selection, returns the colormap color. + + Otherwise, return a random color and remember this color. + + """ + def __init__(self): + self._colors = {} + + def get(self, clu, cluster_ids=None, alpha=None): + alpha = alpha or .5 + if cluster_ids and clu in cluster_ids: + i = cluster_ids.index(clu) + color = _colormap(i) + color = tuple(color) + (alpha,) + else: + if clu in self._colors: + return self._colors[clu] + color = _random_color() + color = tuple(color) + (alpha,) + self._colors[clu] = color + assert len(color) == 4 + return color diff --git a/phy/utils/tests/test_color.py b/phy/utils/tests/test_color.py index 781616d72..2974e8b09 100644 --- a/phy/utils/tests/test_color.py +++ b/phy/utils/tests/test_color.py @@ -6,8 +6,10 @@ # Imports #------------------------------------------------------------------------------ +import numpy as np + from .._color import (_random_color, _is_bright, _random_bright_color, - _selected_clusters_colors, + _colormap, _spike_colors, ColorSelector, ) from ..testing import show_colored_canvas @@ -24,7 +26,18 @@ def test_random_color(): assert _is_bright(_random_bright_color()) -def test_selected_clusters_colors(): - assert _selected_clusters_colors().ndim == 2 - assert len(_selected_clusters_colors(3)) == 3 - assert len(_selected_clusters_colors(10)) == 10 +def test_colormap(): + assert len(_colormap(0)) == 3 + assert len(_colormap(1000)) == 3 + + assert _spike_colors([0, 1, 10, 1000]).shape == (4, 4) + assert _spike_colors([0, 1, 10, 1000], + alpha=1.).shape == (4, 4) + assert _spike_colors([0, 1, 10, 1000], + masks=np.linspace(0., 1., 4)).shape == (4, 4) + + +def test_color_selector(): + sel = ColorSelector() + assert len(sel.get(0)) == 4 + assert len(sel.get(0, [1, 0])) == 4 diff --git a/phy/utils/tests/test_types.py b/phy/utils/tests/test_types.py index 94ee65c05..a7d0006a9 100644 --- a/phy/utils/tests/test_types.py +++ b/phy/utils/tests/test_types.py @@ -11,6 +11,7 @@ from .._types import (Bunch, _bunchify, _is_integer, _is_list, _is_float, _as_list, _is_array_like, _as_array, _as_tuple, + _as_scalar, ) @@ -70,6 +71,12 @@ def test_as_tuple(): assert _as_tuple([3, 4]) == ([3, 4], ) +def test_as_scalar(): + assert _as_scalar(1) == 1 + assert _as_scalar(np.ones(1)[0]) == 1. + assert type(_as_scalar(np.ones(1)[0])) == float + + def test_array(): def _check(arr): assert isinstance(arr, np.ndarray) From d5058af1823831217b374b7bb522c8a488516a25 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 8 Mar 2016 14:30:39 +0100 Subject: [PATCH 1044/1059] Update cluster colors in views --- phy/cluster/manual/tests/test_views.py | 8 -- phy/cluster/manual/views.py | 104 ++++--------------------- phy/utils/_color.py | 10 +-- 3 files changed, 20 insertions(+), 102 deletions(-) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index b201a8c55..cbdccd659 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -16,7 +16,6 @@ from .conftest import MockController from ..views import (ScatterView, _extract_wave, - _selected_clusters_colors, _extend, ) @@ -83,13 +82,6 @@ def test_extract_wave(): [[16, 17], [21, 22], [0, 0], [0, 0]]) -def test_selected_clusters_colors(): - assert _selected_clusters_colors().shape[0] > 10 - assert _selected_clusters_colors(0).shape[0] == 0 - assert _selected_clusters_colors(1).shape[0] == 1 - assert _selected_clusters_colors(100).shape[0] == 100 - - #------------------------------------------------------------------------------ # Test waveform view #------------------------------------------------------------------------------ diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index 72ee15203..ad349f072 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -13,7 +13,6 @@ import re import numpy as np -from matplotlib.colors import hsv_to_rgb, rgb_to_hsv from vispy.util.event import Event from phy.io.array import _index_of, _get_padded, get_excerpts @@ -23,6 +22,7 @@ from phy.plot.utils import _get_boxes from phy.stats import correlograms from phy.utils import Bunch +from phy.utils._color import _spike_colors, ColorSelector logger = logging.getLogger(__name__) @@ -31,36 +31,6 @@ # Utils # ----------------------------------------------------------------------------- -# Default color map for the selected clusters. -_COLORMAP = np.array([[8, 146, 252], - [255, 2, 2], - [240, 253, 2], - [228, 31, 228], - [2, 217, 2], - [255, 147, 2], - [212, 150, 70], - [205, 131, 201], - [201, 172, 36], - [150, 179, 62], - [95, 188, 122], - [129, 173, 190], - [231, 107, 119], - ]) - - -def _selected_clusters_colors(n_clusters=None): - if n_clusters is None: - n_clusters = _COLORMAP.shape[0] - if n_clusters > _COLORMAP.shape[0]: - colors = np.tile(_COLORMAP, (1 + n_clusters // _COLORMAP.shape[0], 1)) - else: - colors = _COLORMAP - out = colors[:n_clusters, ...] / 255. - if n_clusters is not None: - assert out.shape == (n_clusters, 3) - return out - - def _extract_wave(traces, start, mask, wave_len=None, mask_threshold=.5): n_samples, n_channels = traces.shape assert mask.shape == (n_channels,) @@ -92,37 +62,6 @@ def _get_depth(masks, spike_clusters_rel=None, n_clusters=None): return depth -def _get_color(masks, spike_clusters_rel=None, n_clusters=None, alpha=.5): - """Return the color of vertices as a function of the mask and - cluster index.""" - n_spikes = masks.shape[0] - # The transparency depends on whether the spike clusters are specified. - # For background spikes, we use a smaller alpha. - alpha = alpha if spike_clusters_rel is not None else .25 - assert masks.shape == (n_spikes,) - # Generate the colors. - if n_clusters is None: - if spike_clusters_rel is None: - n_clusters = _COLORMAP.shape[0] - else: - n_clusters = spike_clusters_rel.max() + 1 - colors = _selected_clusters_colors(n_clusters) - # Color as a function of the mask. - if spike_clusters_rel is not None: - assert spike_clusters_rel.shape == (n_spikes,) - color = colors[spike_clusters_rel] - else: - # Fixed color when the spike clusters are not specified. - color = np.ones((n_spikes, 3)) - hsv = rgb_to_hsv(color[:, :3]) - # Change the saturation and value as a function of the mask. - hsv[:, 1] *= masks - hsv[:, 2] *= .5 * (1. + masks) - color = hsv_to_rgb(hsv) - color = np.c_[color, alpha * np.ones((n_spikes, 1))] - return color - - def _extend(channels, n=None): channels = list(channels) if n is None: @@ -372,11 +311,10 @@ def on_select(self, cluster_ids=None): depth = _get_depth(m, spike_clusters_rel=spike_clusters_rel, n_clusters=n_clusters) - color = _get_color(m, - spike_clusters_rel=spike_clusters_rel, - n_clusters=n_clusters, - alpha=alpha, - ) + color = _spike_colors(spike_clusters_rel, + masks=m, + alpha=alpha, + ) self[ch].plot(x=t, y=w[:, :, ch], color=color, depth=depth, @@ -667,6 +605,8 @@ def __init__(self, self._scaling = 1. self._origin = None + self._color_selector = ColorSelector() + # Initialize the view. super(TraceView, self).__init__(layout='stacked', origin=self.origin, @@ -714,21 +654,6 @@ def _plot_spike(self, waveforms=None, channels=None, masks=None, t0 = spike_time - offset_samples / sr - # Determine the color as a function of the spike's cluster. - if color is None: - # TODO: improve all of this. - clu = spike_cluster - if self.cluster_ids is None or clu not in self.cluster_ids: - k = len(self.cluster_ids) if self.cluster_ids else 0 - clu_rel = k + (clu % max(_COLORMAP.shape[0] - k, 0)) - sc = clu_rel * np.ones(n_channels, dtype=np.int32) - else: - clu_rel = self.cluster_ids.index(clu) - sc = clu_rel * np.ones(n_channels, dtype=np.int32) - color = _get_color(masks, - spike_clusters_rel=sc, - ) - # Generate the x coordinates of the waveform. t = t0 + self.dt * np.arange(n_samples) t = np.tile(t, (n_channels, 1)) # (n_unmasked_channels, n_samples) @@ -784,8 +709,11 @@ def set_interval(self, interval, change_status=True): # Plot the spikes. spikes = self.spikes(interval, all_traces) assert isinstance(spikes, (tuple, list)) + for spike in spikes: - self._plot_spike(**spike) + color = self._color_selector.get(spike.spike_cluster, + self.cluster_ids) + self._plot_spike(color=color, **spike) self.build() self.update() @@ -1082,7 +1010,7 @@ def _plot_features(self, i, j, x_dim, y_dim, x, y, m = np.maximum(mx, my) # Get the color of the markers. - color = _get_color(m, spike_clusters_rel=sc, n_clusters=n_clusters) + color = _spike_colors(sc, masks=m) assert color.shape == (n_spikes, 4) # Create the scatter plot for the current subplot. @@ -1263,7 +1191,6 @@ def on_request_split(self): f = data.data i, j = self.lasso.box - # TODO: refactor and load all features. x = self._get_feature(x_dim[i, j], spike_ids, f) y = self._get_feature(y_dim[i, j], spike_ids, f) pos = np.c_[x, y].astype(np.float64) @@ -1396,15 +1323,14 @@ def on_select(self, cluster_ids=None): ccg = self._compute_correlograms(cluster_ids) ylim = [ccg.max()] if not self.uniform_normalization else None - colors = _selected_clusters_colors(n_clusters) + colors = _spike_colors(np.arange(n_clusters), alpha=1.) self.grid.shape = (n_clusters, n_clusters) with self.building(): for i in range(n_clusters): for j in range(n_clusters): hist = ccg[i, j, :] - color = colors[i] if i == j else np.ones(3) - color = np.hstack((color, [1])) + color = colors[i] if i == j else np.ones(4) self[i, j].hist(hist, color=color, ylim=ylim, @@ -1498,7 +1424,7 @@ def on_select(self, cluster_ids=None): with self.building(): m = np.ones(n_spikes) # Get the color of the markers. - color = _get_color(m, spike_clusters_rel=sc, n_clusters=n_clusters) + color = _spike_colors(sc, masks=m) assert color.shape == (n_spikes, 4) ms = (self._default_marker_size if sc is not None else 1.) diff --git a/phy/utils/_color.py b/phy/utils/_color.py index 33b4d36fe..974779b57 100644 --- a/phy/utils/_color.py +++ b/phy/utils/_color.py @@ -6,11 +6,9 @@ # Imports #------------------------------------------------------------------------------ -from colorsys import hsv_to_rgb - import numpy as np from numpy.random import uniform -from matplotlib.colors import hsv_to_rgb, rgb_to_hsv +from matplotlib.colors import rgb_to_hsv, hsv_to_rgb #------------------------------------------------------------------------------ @@ -63,7 +61,6 @@ def _random_bright_color(): def _apply_color_masks(color, masks=None, alpha=None): alpha = alpha or .5 - color = np.atleast_2d(color) hsv = rgb_to_hsv(color[:, :3]) # Change the saturation and value as a function of the mask. if masks is not None: @@ -82,7 +79,10 @@ def _colormap(i): def _spike_colors(spike_clusters, masks=None, alpha=None): n = len(_COLORMAP) - c = _COLORMAP[np.mod(spike_clusters, n), :] / 255. + if spike_clusters is not None: + c = _COLORMAP[np.mod(spike_clusters, n), :] / 255. + else: + c = np.ones((masks.shape[0], 3)) c = _apply_color_masks(c, masks=masks, alpha=alpha) return c From 3ac73cf14fb7ea46657046bd24e036dd7447d70c Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 8 Mar 2016 14:43:53 +0100 Subject: [PATCH 1045/1059] Save interval in trace view's state --- phy/cluster/manual/tests/test_views.py | 2 +- phy/cluster/manual/views.py | 27 +++++++++++++++++--------- 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py index cbdccd659..10de92878 100644 --- a/phy/cluster/manual/tests/test_views.py +++ b/phy/cluster/manual/tests/test_views.py @@ -174,7 +174,7 @@ def test_trace_view(qtbot, gui): assert v.time == .175 # Change interval size. - v.set_interval((.25, .75)) + v.interval = (.25, .75) ac(v.interval, (.25, .75)) v.widen() ac(v.interval, (.225, .775)) diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py index ad349f072..dfa59efde 100644 --- a/phy/cluster/manual/views.py +++ b/phy/cluster/manual/views.py @@ -619,7 +619,7 @@ def __init__(self, self._update_boxes() # Initial interval. - self.interval = None + self._interval = None self.go_to(duration / 2.) # Internal methods @@ -627,7 +627,7 @@ def __init__(self, def _plot_traces(self, traces=None, color=None, show_labels=True): assert traces.shape[1] == self.n_channels - t = self.interval[0] + np.arange(traces.shape[0]) * self.dt + t = self._interval[0] + np.arange(traces.shape[0]) * self.dt color = color or self.default_trace_color channels = np.arange(self.n_channels) for ch in channels: @@ -688,7 +688,7 @@ def set_interval(self, interval, change_status=True): """Display the traces and spikes in a given interval.""" self.clear() interval = self._restrict_interval(interval) - self.interval = interval + self._interval = interval start, end = interval # Set the status message. if change_status: @@ -720,7 +720,7 @@ def set_interval(self, interval, change_status=True): def on_select(self, cluster_ids=None): super(TraceView, self).on_select(cluster_ids) - self.set_interval(self.interval, change_status=False) + self.set_interval(self._interval, change_status=False) def attach(self, gui): """Attach the view to the GUI.""" @@ -738,6 +738,7 @@ def attach(self, gui): def state(self): return Bunch(scaling=self.scaling, origin=self.origin, + interval=self._interval, ) # Scaling @@ -770,13 +771,21 @@ def origin(self, value): @property def time(self): """Time at the center of the window.""" - return sum(self.interval) * .5 + return sum(self._interval) * .5 + + @property + def interval(self): + return self._interval + + @interval.setter + def interval(self, value): + self.set_interval(value) @property def half_duration(self): """Half of the duration of the current interval.""" - if self.interval is not None: - a, b = self.interval + if self._interval is not None: + a, b = self._interval return (b - a) * .5 else: return self.interval_duration * .5 @@ -792,13 +801,13 @@ def shift(self, delay): def go_right(self): """Go to right.""" - start, end = self.interval + start, end = self._interval delay = (end - start) * .2 self.shift(delay) def go_left(self): """Go to left.""" - start, end = self.interval + start, end = self._interval delay = (end - start) * .2 self.shift(-delay) From 0184dd44e30fabf5942ec27584cc09af08154d56 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 8 Mar 2016 15:19:26 +0100 Subject: [PATCH 1046/1059] Save current sort in cluster view --- phy/cluster/manual/gui_component.py | 21 +++++++++++++++++++ .../manual/tests/test_gui_component.py | 10 +++++++++ phy/gui/qt.py | 2 +- phy/gui/widgets.py | 21 +++++++++++++++---- 4 files changed, 49 insertions(+), 5 deletions(-) diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py index 221f33d87..f9fdee400 100644 --- a/phy/cluster/manual/gui_component.py +++ b/phy/cluster/manual/gui_component.py @@ -55,6 +55,15 @@ def __init__(self): } ''') + @property + def state(self): + return {'sort_by': self.current_sort} + + def set_state(self, state): + sort_by = state.get('sort_by', None) + if sort_by: + self.sort_by(*sort_by) + class ManualClustering(object): """Component that brings manual clustering facilities to a GUI: @@ -431,6 +440,18 @@ def attach(self, gui): if self.similarity: gui.add_view(self.similarity_view, name='SimilarityView') + # Set the view state. + cv = self.cluster_view + cv.set_state(gui.state.get_view_state(cv)) + + # Save the view state in the GUI state. + @gui.connect_ + def on_close(): + gui.state.update_view_state(cv, cv.state) + # NOTE: create_gui() already saves the state, but the event + # is registered *before* we add all views. + gui.state.save() + # Update the cluster views and selection when a cluster event occurs. self.gui.connect_(self.on_cluster) return self diff --git a/phy/cluster/manual/tests/test_gui_component.py b/phy/cluster/manual/tests/test_gui_component.py index 624c10be1..fea74e020 100644 --- a/phy/cluster/manual/tests/test_gui_component.py +++ b/phy/cluster/manual/tests/test_gui_component.py @@ -148,6 +148,16 @@ def test_manual_clustering_split_2(gui, quality, similarity): assert mc.selected == [3, 2] +def test_manual_clustering_state(tempdir, qtbot, gui, manual_clustering): + mc = manual_clustering + cv = mc.cluster_view + cv.sort_by('id') + gui.close() + assert cv.state['sort_by'] == ('id', 'asc') + cv.set_state(cv.state) + assert cv.state['sort_by'] == ('id', 'asc') + + def test_manual_clustering_split_lasso(tempdir, qtbot): controller = MockController(config_dir=tempdir) gui = controller.create_gui() diff --git a/phy/gui/qt.py b/phy/gui/qt.py index 1a0a66ca4..3ae8f7ed8 100644 --- a/phy/gui/qt.py +++ b/phy/gui/qt.py @@ -19,7 +19,7 @@ # ----------------------------------------------------------------------------- from PyQt4.QtCore import (Qt, QByteArray, QMetaObject, QObject, # noqa - QVariant, QEventLoop, QTimer, + QVariant, QPyNullVariant, QEventLoop, QTimer, pyqtSignal, pyqtSlot, QSize, QUrl) from PyQt4.QtGui import (QKeySequence, QAction, QStatusBar, # noqa QMainWindow, QDockWidget, QWidget, diff --git a/phy/gui/widgets.py b/phy/gui/widgets.py index 4c675c580..b48a030d9 100644 --- a/phy/gui/widgets.py +++ b/phy/gui/widgets.py @@ -14,9 +14,9 @@ from six import text_type -from .qt import (QWebView, QWebPage, QUrl, QWebSettings, QVariant, - pyqtSlot, - _wait_signal, +from .qt import (QWebView, QWebPage, QUrl, QWebSettings, + QVariant, QPyNullVariant, + pyqtSlot, _wait_signal, ) from phy.utils import EventEmitter from phy.utils._misc import _CustomEncoder @@ -62,6 +62,19 @@ def javaScriptConsoleMessage(self, msg, line, source): logger.debug("[%d] %s", line, msg) # pragma: no cover +def _to_py(obj): + if isinstance(obj, QVariant): + return obj.toPyObject() + elif isinstance(obj, QPyNullVariant): + return None + elif isinstance(obj, list): + return [_to_py(_) for _ in obj] + elif isinstance(obj, tuple): + return tuple(_to_py(_) for _ in obj) + else: + return obj + + class HTMLWidget(QWebView): """An HTML widget that is displayed with Qt. @@ -174,7 +187,7 @@ def eval_js(self, expr): return logger.log(5, "Evaluate Javascript: `%s`.", expr) out = self.page().mainFrame().evaluateJavaScript(expr) - return out.toPyObject() if isinstance(out, QVariant) else out + return _to_py(out) @pyqtSlot(str, str) def _emit_from_js(self, name, arg_json): From 4357fc616bdcc864abb2a06d2409bcadbb12380d Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 8 Mar 2016 15:31:02 +0100 Subject: [PATCH 1047/1059] Fix Python 2 --- phy/gui/qt.py | 6 +++++- phy/gui/widgets.py | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/phy/gui/qt.py b/phy/gui/qt.py index 3ae8f7ed8..275b7e116 100644 --- a/phy/gui/qt.py +++ b/phy/gui/qt.py @@ -19,8 +19,12 @@ # ----------------------------------------------------------------------------- from PyQt4.QtCore import (Qt, QByteArray, QMetaObject, QObject, # noqa - QVariant, QPyNullVariant, QEventLoop, QTimer, + QVariant, QEventLoop, QTimer, pyqtSignal, pyqtSlot, QSize, QUrl) +try: + from PyQt4.QtCore import QPyNullVariant # noqa +except: # pragma: no cover + QPyNullVariant = None from PyQt4.QtGui import (QKeySequence, QAction, QStatusBar, # noqa QMainWindow, QDockWidget, QWidget, QMessageBox, QApplication, QMenuBar, diff --git a/phy/gui/widgets.py b/phy/gui/widgets.py index b48a030d9..f6663abf8 100644 --- a/phy/gui/widgets.py +++ b/phy/gui/widgets.py @@ -62,7 +62,7 @@ def javaScriptConsoleMessage(self, msg, line, source): logger.debug("[%d] %s", line, msg) # pragma: no cover -def _to_py(obj): +def _to_py(obj): # pragma: no cover if isinstance(obj, QVariant): return obj.toPyObject() elif isinstance(obj, QPyNullVariant): From 66204dc3a205ff834b6133208c7442131551b874 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 8 Mar 2016 15:45:27 +0100 Subject: [PATCH 1048/1059] Fix Python 2 --- phy/gui/qt.py | 2 +- phy/gui/widgets.py | 4 +++- phy/utils/_misc.py | 3 +++ 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/phy/gui/qt.py b/phy/gui/qt.py index 275b7e116..1c43dfc97 100644 --- a/phy/gui/qt.py +++ b/phy/gui/qt.py @@ -19,7 +19,7 @@ # ----------------------------------------------------------------------------- from PyQt4.QtCore import (Qt, QByteArray, QMetaObject, QObject, # noqa - QVariant, QEventLoop, QTimer, + QVariant, QEventLoop, QTimer, QString, pyqtSignal, pyqtSlot, QSize, QUrl) try: from PyQt4.QtCore import QPyNullVariant # noqa diff --git a/phy/gui/widgets.py b/phy/gui/widgets.py index f6663abf8..5adbd2f6d 100644 --- a/phy/gui/widgets.py +++ b/phy/gui/widgets.py @@ -15,7 +15,7 @@ from six import text_type from .qt import (QWebView, QWebPage, QUrl, QWebSettings, - QVariant, QPyNullVariant, + QVariant, QPyNullVariant, QString, pyqtSlot, _wait_signal, ) from phy.utils import EventEmitter @@ -65,6 +65,8 @@ def javaScriptConsoleMessage(self, msg, line, source): def _to_py(obj): # pragma: no cover if isinstance(obj, QVariant): return obj.toPyObject() + elif isinstance(obj, QString): + return unicode(obj) elif isinstance(obj, QPyNullVariant): return None elif isinstance(obj, list): diff --git a/phy/utils/_misc.py b/phy/utils/_misc.py index 1f11a6d43..a6971e997 100644 --- a/phy/utils/_misc.py +++ b/phy/utils/_misc.py @@ -39,6 +39,7 @@ def _decode_qbytearray(data_b64): class _CustomEncoder(json.JSONEncoder): def default(self, obj): + from phy.gui.qt import QVariant, QString if isinstance(obj, np.ndarray): obj_contiguous = np.ascontiguousarray(obj) data_b64 = base64.b64encode(obj_contiguous.data).decode('utf8') @@ -49,6 +50,8 @@ def default(self, obj): return {'__qbytearray__': _encode_qbytearray(obj)} elif isinstance(obj, np.generic): return np.asscalar(obj) + elif isinstance(obj, QString): + return unicode(obj) return super(_CustomEncoder, self).default(obj) # pragma: no cover From ef1c7b39e97687dc19bb4d7395cb434c640dedd9 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 8 Mar 2016 15:53:14 +0100 Subject: [PATCH 1049/1059] WIP: fixes --- phy/gui/qt.py | 6 +++++- phy/gui/widgets.py | 4 ++-- phy/utils/_misc.py | 8 ++++---- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/phy/gui/qt.py b/phy/gui/qt.py index 1c43dfc97..6fd3045b7 100644 --- a/phy/gui/qt.py +++ b/phy/gui/qt.py @@ -19,12 +19,16 @@ # ----------------------------------------------------------------------------- from PyQt4.QtCore import (Qt, QByteArray, QMetaObject, QObject, # noqa - QVariant, QEventLoop, QTimer, QString, + QVariant, QEventLoop, QTimer, pyqtSignal, pyqtSlot, QSize, QUrl) try: from PyQt4.QtCore import QPyNullVariant # noqa except: # pragma: no cover QPyNullVariant = None +try: + from PyQt4.QtCore import QString # noqa +except: # pragma: no cover + QString = None from PyQt4.QtGui import (QKeySequence, QAction, QStatusBar, # noqa QMainWindow, QDockWidget, QWidget, QMessageBox, QApplication, QMenuBar, diff --git a/phy/gui/widgets.py b/phy/gui/widgets.py index 5adbd2f6d..0379ce1cc 100644 --- a/phy/gui/widgets.py +++ b/phy/gui/widgets.py @@ -65,8 +65,8 @@ def javaScriptConsoleMessage(self, msg, line, source): def _to_py(obj): # pragma: no cover if isinstance(obj, QVariant): return obj.toPyObject() - elif isinstance(obj, QString): - return unicode(obj) + elif QString and isinstance(obj, QString): + return text_type(obj) elif isinstance(obj, QPyNullVariant): return None elif isinstance(obj, list): diff --git a/phy/utils/_misc.py b/phy/utils/_misc.py index a6971e997..d10848689 100644 --- a/phy/utils/_misc.py +++ b/phy/utils/_misc.py @@ -15,7 +15,7 @@ from textwrap import dedent import numpy as np -from six import string_types, exec_ +from six import string_types, text_type, exec_ from ._types import _is_integer @@ -39,7 +39,7 @@ def _decode_qbytearray(data_b64): class _CustomEncoder(json.JSONEncoder): def default(self, obj): - from phy.gui.qt import QVariant, QString + from phy.gui.qt import QString if isinstance(obj, np.ndarray): obj_contiguous = np.ascontiguousarray(obj) data_b64 = base64.b64encode(obj_contiguous.data).decode('utf8') @@ -50,8 +50,8 @@ def default(self, obj): return {'__qbytearray__': _encode_qbytearray(obj)} elif isinstance(obj, np.generic): return np.asscalar(obj) - elif isinstance(obj, QString): - return unicode(obj) + elif isinstance(obj, QString): # pragma: no cover + return text_type(obj) return super(_CustomEncoder, self).default(obj) # pragma: no cover From bfd770c27eed56ab18082835160c71df855fc96d Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 10 Mar 2016 22:43:21 +0100 Subject: [PATCH 1050/1059] Fix bug with Windows path in config --- phy/utils/config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/phy/utils/config.py b/phy/utils/config.py index 56da6bc80..c4b744433 100644 --- a/phy/utils/config.py +++ b/phy/utils/config.py @@ -53,7 +53,7 @@ def load_config(path): def _default_config(config_dir=None): - path = op.join(config_dir or '~/.phy/', 'plugins/') + path = op.join(config_dir or op.join('~', '.phy'), 'plugins/') return dedent(""" # You can also put your plugins in ~/.phy/plugins/. @@ -65,7 +65,7 @@ def attach_to_cli(self, cli): c = get_config() - c.Plugins.dirs = ['{}'] + c.Plugins.dirs = [r'{}'] """.format(path)) From 0c39f696f2a310da80d4a0d1208e383b7ce13316 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 10 Mar 2016 23:09:04 +0100 Subject: [PATCH 1051/1059] Show action alias in status bar --- phy/gui/actions.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/phy/gui/actions.py b/phy/gui/actions.py index 545edf01b..55b56898e 100644 --- a/phy/gui/actions.py +++ b/phy/gui/actions.py @@ -114,7 +114,7 @@ def _alias(name): @require_qt -def _create_qaction(gui, name, callback, shortcut, docstring=None): +def _create_qaction(gui, name, callback, shortcut, docstring=None, alias=''): # Create the QAction instance. action = QAction(name.capitalize().replace('_', ' '), gui) @@ -127,6 +127,7 @@ def wrapped(checked, *args, **kwargs): # pragma: no cover sequence = [sequence] action.setShortcuts(sequence) assert docstring + docstring += ' (alias: {})'.format(alias) action.setStatusTip(docstring) action.setWhatsThis(docstring) return action @@ -176,8 +177,11 @@ def add(self, callback=None, name=None, shortcut=None, alias=None, docstring = re.sub(r'[\s]{2,}', ' ', docstring) # Create and register the action. - action = _create_qaction(self.gui, name, callback, shortcut, - docstring=docstring) + action = _create_qaction(self.gui, name, callback, + shortcut, + docstring=docstring, + alias=alias, + ) action_obj = Bunch(qaction=action, name=name, alias=alias, shortcut=shortcut, callback=callback, menu=menu) if verbose and not name.startswith('_'): From e76fd55f807b8a7ce02d48b714930b783bdf4421 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Thu, 10 Mar 2016 23:38:34 +0100 Subject: [PATCH 1052/1059] Fix randomly failing test with split --- phy/cluster/manual/tests/test_gui_component.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/phy/cluster/manual/tests/test_gui_component.py b/phy/cluster/manual/tests/test_gui_component.py index fea74e020..4f186a469 100644 --- a/phy/cluster/manual/tests/test_gui_component.py +++ b/phy/cluster/manual/tests/test_gui_component.py @@ -171,9 +171,9 @@ def test_manual_clustering_split_lasso(tempdir, qtbot): # Simulate a lasso. ev = view.events - ev.mouse_press(pos=(210, 10), button=1, modifiers=(keys.CONTROL,)) - ev.mouse_press(pos=(280, 10), button=1, modifiers=(keys.CONTROL,)) - ev.mouse_press(pos=(280, 30), button=1, modifiers=(keys.CONTROL,)) + ev.mouse_press(pos=(210, 1), button=1, modifiers=(keys.CONTROL,)) + ev.mouse_press(pos=(320, 1), button=1, modifiers=(keys.CONTROL,)) + ev.mouse_press(pos=(320, 30), button=1, modifiers=(keys.CONTROL,)) ev.mouse_press(pos=(210, 30), button=1, modifiers=(keys.CONTROL,)) ups = [] From c6bdb606d892604dcf9afdfef90d347411453927 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Fri, 11 Mar 2016 18:32:32 +0100 Subject: [PATCH 1053/1059] Parameters for n_spikes in views --- phy/cluster/manual/controller.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/phy/cluster/manual/controller.py b/phy/cluster/manual/controller.py index 1d87d11c5..bb97b43e6 100644 --- a/phy/cluster/manual/controller.py +++ b/phy/cluster/manual/controller.py @@ -45,6 +45,15 @@ class Controller(EventEmitter): add_view(gui, view) """ + + n_spikes_waveforms = 100 + n_spikes_waveforms_lim = 100 + n_spikes_masks = 100 + n_spikes_features = 5000 + n_spikes_background_features = 5000 + n_spikes_features_lim = 100 + n_spikes_close_clusters = 100 + # responsible for the cache def __init__(self, plugins=None, config_dir=None): super(Controller, self).__init__() @@ -166,7 +175,7 @@ def _data_lim(self, arr, n_max): def get_masks(self, cluster_id): return self._select_data(cluster_id, self.all_masks, - 100, # TODO + self.n_spikes_masks, ) def get_mean_masks(self, cluster_id): @@ -179,14 +188,14 @@ def get_mean_masks(self, cluster_id): def get_waveforms(self, cluster_id): return [self._select_data(cluster_id, self.all_waveforms, - 50, # TODO + self.n_spikes_waveforms, )] def get_mean_waveforms(self, cluster_id): return mean(self.get_waveforms(cluster_id)[0].data) def get_waveform_lims(self): - n_spikes = 100 # TODO + n_spikes = self.n_spikes_waveforms_lim arr = self.all_waveforms n = arr.shape[0] k = max(1, n // n_spikes) @@ -214,11 +223,12 @@ def get_waveforms_amplitude(self, cluster_id): def get_features(self, cluster_id, load_all=False): return self._select_data(cluster_id, self.all_features, - 1000 if not load_all else None, # TODO + (self.n_spikes_features + if not load_all else None), ) def get_background_features(self): - k = max(1, int(self.n_spikes // 1000)) + k = max(1, int(self.n_spikes // self.n_spikes_background_features)) spike_ids = slice(None, None, k) b = Bunch() b.data = self.all_features[spike_ids] @@ -231,7 +241,7 @@ def get_mean_features(self, cluster_id): return mean(self.get_features(cluster_id).data) def get_feature_lim(self): - return self._data_lim(self.all_features, 100) # TODO + return self._data_lim(self.all_features, self.n_spikes_features_lim) # Traces # ------------------------------------------------------------------------- @@ -295,7 +305,7 @@ def get_close_clusters(self, cluster_id): assert dist.shape == (len(clusters),) # Closest clusters. ind = np.argsort(dist) - ind = ind[:100] # TODO + ind = ind[:self.n_spikes_close_clusters] return [(int(clusters[i]), float(dist[i])) for i in ind] def spikes_per_cluster(self, cluster_id): From 0440ff910d2c7615aae73f8de053e07645d57be5 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 22 Mar 2016 14:36:21 +0100 Subject: [PATCH 1054/1059] Add virtual concatenation code --- phy/io/array.py | 121 +++++++++++++++++++++++++++++++++++++ phy/io/tests/test_array.py | 31 ++++++++++ 2 files changed, 152 insertions(+) diff --git a/phy/io/array.py b/phy/io/array.py index b6439e6b4..6661f1962 100644 --- a/phy/io/array.py +++ b/phy/io/array.py @@ -267,6 +267,127 @@ def write_array(path, arr): "is not currently supported.") +# ----------------------------------------------------------------------------- +# Virtual concatenation +# ----------------------------------------------------------------------------- + +def _start_stop(item): + """Find the start and stop indices of a __getitem__ item. + + This is used only by ConcatenatedArrays. + + Only two cases are supported currently: + + * Single integer. + * Contiguous slice in the first dimension only. + + """ + if isinstance(item, tuple): + item = item[0] + if isinstance(item, slice): + # Slice. + if item.step not in (None, 1): + return NotImplementedError() + return item.start, item.stop + elif isinstance(item, (list, np.ndarray)): + # List or array of indices. + return np.min(item), np.max(item) + else: + # Integer. + return item, item + 1 + + +def _fill_index(arr, item): + if isinstance(item, tuple): + item = (slice(None, None, None),) + item[1:] + return arr[item] + else: + return arr + + +class ConcatenatedArrays(object): + """This object represents a concatenation of several memory-mapped + arrays.""" + def __init__(self, arrs, cols=None): + assert isinstance(arrs, list) + self.arrs = arrs + # Reordering of the columns. + self.cols = cols + self.offsets = np.concatenate([[0], np.cumsum([arr.shape[0] + for arr in arrs])], + axis=0) + self.dtype = arrs[0].dtype if arrs else None + + @property + def shape(self): + if self.arrs[0].ndim == 1: + return (self.offsets[-1],) + ncols = (len(self.cols) if self.cols is not None + else self.arrs[0].shape[1]) + return (self.offsets[-1], ncols) + + def _get_recording(self, index): + """Return the recording that contains a given index.""" + assert index >= 0 + recs = np.nonzero((index - self.offsets[:-1]) >= 0)[0] + if len(recs) == 0: + # If the index is greater than the total size, + # return the last recording. + return len(self.arrs) - 1 + # Return the last recording such that the index is greater than + # its offset. + return recs[-1] + + def __getitem__(self, item): + cols = self.cols if self.cols is not None else slice(None, None, None) + # Get the start and stop indices of the requested item. + start, stop = _start_stop(item) + # Return the concatenation of all arrays. + if start is None and stop is None: + return np.concatenate(self.arrs, axis=0)[..., cols] + if start is None: + start = 0 + if stop is None: + stop = self.offsets[-1] + if stop < 0: + stop = self.offsets[-1] + stop + # Get the recording indices of the first and last item. + rec_start = self._get_recording(start) + rec_stop = self._get_recording(stop) + assert 0 <= rec_start <= rec_stop < len(self.arrs) + # Find the start and stop relative to the arrays. + start_rel = start - self.offsets[rec_start] + stop_rel = stop - self.offsets[rec_stop] + # Single array case. + if rec_start == rec_stop: + # Apply the rest of the index. + return _fill_index(self.arrs[rec_start][start_rel:stop_rel], + item)[..., cols] + chunk_start = self.arrs[rec_start][start_rel:] + chunk_stop = self.arrs[rec_stop][:stop_rel] + # Concatenate all chunks. + l = [chunk_start] + if rec_stop - rec_start >= 2: + logger.warn("Loading a full virtual array: this might be slow " + "and something might be wrong.") + l += [self.arrs[r][...] for r in range(rec_start + 1, + rec_stop)] + l += [chunk_stop] + # Apply the rest of the index. + return _fill_index(np.concatenate(l, axis=0), item)[..., cols] + + def __len__(self): + return self.shape[0] + + +def _concatenate_virtual_arrays(arrs, cols=None): + """Return a virtual concatenate of several NumPy arrays.""" + n = len(arrs) + if n == 0: + return None + return ConcatenatedArrays(arrs, cols) + + # ----------------------------------------------------------------------------- # Chunking functions # ----------------------------------------------------------------------------- diff --git a/phy/io/tests/test_array.py b/phy/io/tests/test_array.py index 81a3c53c0..1e8aab621 100644 --- a/phy/io/tests/test_array.py +++ b/phy/io/tests/test_array.py @@ -27,6 +27,7 @@ data_chunk, grouped_mean, get_excerpts, + _concatenate_virtual_arrays, _range_from_slice, _pad, _get_padded, @@ -228,6 +229,36 @@ def test_read_write_dask(tempdir): ae(read_array(path, mmap_mode='r'), arr) +#------------------------------------------------------------------------------ +# Test virtual concatenation +#------------------------------------------------------------------------------ + +def test_concatenate_virtual_arrays_1(): + arrs = [np.arange(5), np.arange(10, 12), np.array([0])] + c = _concatenate_virtual_arrays(arrs) + assert c.shape == (8,) + assert c._get_recording(3) == 0 + assert c._get_recording(5) == 1 + + ae(c[:], [0, 1, 2, 3, 4, 10, 11, 0]) + ae(c[0], [0]) + ae(c[4], [4]) + ae(c[5], [10]) + ae(c[6], [11]) + + ae(c[4:6], [4, 10]) + + ae(c[:6], [0, 1, 2, 3, 4, 10]) + ae(c[4:], [4, 10, 11, 0]) + ae(c[4:-1], [4, 10, 11]) + + +def test_concatenate_virtual_arrays_2(): + arrs = [np.zeros((2, 2)), np.ones((3, 2))] + c = _concatenate_virtual_arrays(arrs) + assert c.shape == (5, 2) + + #------------------------------------------------------------------------------ # Test chunking #------------------------------------------------------------------------------ From 3f8855d5ef350610ec8632815cd47598c15b8ba6 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 22 Mar 2016 14:41:10 +0100 Subject: [PATCH 1055/1059] Increase coverage --- phy/io/tests/test_array.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/phy/io/tests/test_array.py b/phy/io/tests/test_array.py index 1e8aab621..2b69a43e1 100644 --- a/phy/io/tests/test_array.py +++ b/phy/io/tests/test_array.py @@ -257,6 +257,8 @@ def test_concatenate_virtual_arrays_2(): arrs = [np.zeros((2, 2)), np.ones((3, 2))] c = _concatenate_virtual_arrays(arrs) assert c.shape == (5, 2) + ae(c[:, :], np.vstack((np.zeros((2, 2)), np.ones((3, 2))))) + ae(c[0:4, 0], [0, 0, 1, 1]) #------------------------------------------------------------------------------ From 5a37147f3667c37f7bf78d1895ec9c8f033f886c Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 22 Mar 2016 15:23:31 +0100 Subject: [PATCH 1056/1059] NOTE: remove detection and KK2 code, which are now in the klusta package --- .travis.yml | 2 - environment.yml | 8 - phy/cluster/__init__.py | 1 - phy/cluster/algorithms/__init__.py | 2 - phy/cluster/algorithms/klustakwik.py | 148 -------- phy/cluster/algorithms/tests/__init__.py | 0 .../algorithms/tests/test_klustakwik.py | 28 -- phy/io/context.py | 215 +---------- phy/io/tests/test_context.py | 123 +------ phy/traces/__init__.py | 2 - phy/traces/detect.py | 344 ------------------ phy/traces/pca.py | 151 -------- phy/traces/spike_detect.py | 254 ------------- phy/traces/tests/test_detect.py | 293 --------------- phy/traces/tests/test_pca.py | 95 ----- phy/traces/tests/test_spike_detect.py | 177 --------- 16 files changed, 8 insertions(+), 1835 deletions(-) delete mode 100644 phy/cluster/algorithms/__init__.py delete mode 100644 phy/cluster/algorithms/klustakwik.py delete mode 100644 phy/cluster/algorithms/tests/__init__.py delete mode 100644 phy/cluster/algorithms/tests/test_klustakwik.py delete mode 100644 phy/traces/detect.py delete mode 100644 phy/traces/pca.py delete mode 100644 phy/traces/spike_detect.py delete mode 100644 phy/traces/tests/test_detect.py delete mode 100644 phy/traces/tests/test_pca.py delete mode 100644 phy/traces/tests/test_spike_detect.py diff --git a/.travis.yml b/.travis.yml index 5ed964668..633218bd2 100644 --- a/.travis.yml +++ b/.travis.yml @@ -2,7 +2,6 @@ language: python sudo: false python: - "2.7" - - "3.4" - "3.5" before_install: - pip install codecov @@ -21,7 +20,6 @@ install: # Create the environment. - conda env create python=$TRAVIS_PYTHON_VERSION - source activate phy - - pip install klustakwik2 # Dev requirements - pip install -r requirements-dev.txt - pip install -e . diff --git a/environment.yml b/environment.yml index 9adcdf1a6..8e7d1f829 100644 --- a/environment.yml +++ b/environment.yml @@ -3,22 +3,14 @@ channels: - kwikteam dependencies: - python - - pip - numpy=1.9 - vispy - matplotlib - scipy - h5py - pyqt - - ipython - requests - traitlets - six - - ipyparallel - joblib - - cython - click - - dask - - cloudpickle - - toolz - - dill diff --git a/phy/cluster/__init__.py b/phy/cluster/__init__.py index 78194577b..2df738347 100644 --- a/phy/cluster/__init__.py +++ b/phy/cluster/__init__.py @@ -3,5 +3,4 @@ """Automatic and manual clustering facilities.""" -from . import algorithms from . import manual diff --git a/phy/cluster/algorithms/__init__.py b/phy/cluster/algorithms/__init__.py deleted file mode 100644 index 442e1ed9b..000000000 --- a/phy/cluster/algorithms/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# -*- coding: utf-8 -*- -"""Automatic clustering algorithms.""" diff --git a/phy/cluster/algorithms/klustakwik.py b/phy/cluster/algorithms/klustakwik.py deleted file mode 100644 index b640f0c23..000000000 --- a/phy/cluster/algorithms/klustakwik.py +++ /dev/null @@ -1,148 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Wrapper to KlustaKwik2 implementation.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import numpy as np -import six - -from phy.io.array import chunk_bounds -from phy.utils.event import EventEmitter - - -#------------------------------------------------------------------------------ -# Sparse structures -#------------------------------------------------------------------------------ - -def sparsify_features_masks(features, masks, chunk_size=10000): - from klustakwik2 import RawSparseData - - assert features.ndim == 2 - assert masks.ndim == 2 - assert features.shape == masks.shape - - n_spikes, num_features = features.shape - - # Stage 1: read min/max of fet values for normalization - # and count total number of unmasked features. - vmin = np.ones(num_features) * np.inf - vmax = np.ones(num_features) * (-np.inf) - total_unmasked_features = 0 - for _, _, i, j in chunk_bounds(n_spikes, chunk_size): - f, m = features[i:j], masks[i:j] - inds = m > 0 - # Replace the masked values by NaN. - vmin = np.minimum(np.min(f, axis=0), vmin) - vmax = np.maximum(np.max(f, axis=0), vmax) - total_unmasked_features += inds.sum() - # Stage 2: read data line by line, normalising - vdiff = vmax - vmin - vdiff[vdiff == 0] = 1 - fetsum = np.zeros(num_features) - fet2sum = np.zeros(num_features) - nsum = np.zeros(num_features) - all_features = np.zeros(total_unmasked_features) - all_fmasks = np.zeros(total_unmasked_features) - all_unmasked = np.zeros(total_unmasked_features, dtype=int) - offsets = np.zeros(n_spikes + 1, dtype=int) - curoff = 0 - for i in six.moves.range(n_spikes): - fetvals, fmaskvals = (features[i] - vmin) / vdiff, masks[i] - inds = (fmaskvals > 0).nonzero()[0] - masked_inds = (fmaskvals == 0).nonzero()[0] - all_features[curoff:curoff + len(inds)] = fetvals[inds] - all_fmasks[curoff:curoff + len(inds)] = fmaskvals[inds] - all_unmasked[curoff:curoff + len(inds)] = inds - offsets[i] = curoff - curoff += len(inds) - fetsum[masked_inds] += fetvals[masked_inds] - fet2sum[masked_inds] += fetvals[masked_inds] ** 2 - nsum[masked_inds] += 1 - offsets[-1] = curoff - - nsum[nsum == 0] = 1 - noise_mean = fetsum / nsum - noise_variance = fet2sum / nsum - noise_mean ** 2 - - return RawSparseData(noise_mean, - noise_variance, - all_features, - all_fmasks, - all_unmasked, - offsets, - ) - - -#------------------------------------------------------------------------------ -# Clustering class -#------------------------------------------------------------------------------ - -class KlustaKwik(EventEmitter): - """KlustaKwik automatic clustering algorithm.""" - def __init__(self, **kwargs): - super(KlustaKwik, self).__init__() - self._kwargs = kwargs - self.__dict__.update(kwargs) - # Set the version. - from klustakwik2 import __version__ - self.version = __version__ - - def cluster(self, - spike_ids=None, - features=None, - masks=None, - ): - """Run the clustering algorithm on the model, or on any features - and masks. - - Return the `spike_clusters` assignments. - - Emit the `iter` event at every KlustaKwik iteration. - - """ - # Select some spikes if needed. - if spike_ids is not None: - features = features[spike_ids] - masks = masks[spike_ids] - # Convert the features and masks to the sparse structure used - # by KK. - data = sparsify_features_masks(features, masks) - data = data.to_sparse_data() - # Run KK. - from klustakwik2 import KK - kk = KK(data, **self._kwargs) - - @kk.register_callback - def f(_): - # Skip split iterations. - if _.name != '': - return - self.emit('iter', kk.clusters) - - self.params = kk.all_params - kk.cluster_mask_starts() - spike_clusters = kk.clusters - return spike_clusters - - -def cluster(features=None, masks=None, algorithm='klustakwik', - spike_ids=None, **kwargs): - """Launch an automatic clustering algorithm on the model. - - Parameters - ---------- - - features : ndarray - masks : ndarray - algorithm : str - Only 'klustakwik' is supported currently. - **kwargs - Parameters for KK. - - """ - assert algorithm == 'klustakwik' - kk = KlustaKwik(**kwargs) - return kk.cluster(features=features, masks=masks, spike_ids=spike_ids) diff --git a/phy/cluster/algorithms/tests/__init__.py b/phy/cluster/algorithms/tests/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/phy/cluster/algorithms/tests/test_klustakwik.py b/phy/cluster/algorithms/tests/test_klustakwik.py deleted file mode 100644 index 276d87309..000000000 --- a/phy/cluster/algorithms/tests/test_klustakwik.py +++ /dev/null @@ -1,28 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Tests of clustering algorithms.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -from phy.io.mock import artificial_features, artificial_masks -from ..klustakwik import cluster - - -#------------------------------------------------------------------------------ -# Tests clustering -#------------------------------------------------------------------------------ - -def test_cluster(tempdir): - n_channels = 4 - n_spikes = 100 - features = artificial_features(n_spikes, n_channels * 3) - masks = artificial_masks(n_spikes, n_channels * 3) - - spike_clusters = cluster(features, masks, num_starting_clusters=10) - assert len(spike_clusters) == n_spikes - - spike_clusters = cluster(features, masks, num_starting_clusters=10, - spike_ids=range(100)) - assert len(spike_clusters) == 100 diff --git a/phy/io/context.py b/phy/io/context.py index 8fcb6ac6e..d12814e21 100644 --- a/phy/io/context.py +++ b/phy/io/context.py @@ -12,136 +12,19 @@ import os import os.path as op -from traitlets.config.configurable import Configurable -import numpy as np from six.moves.cPickle import dump, load -from six import string_types -try: - from dask.array import Array - from dask.async import get_sync as get - from dask.core import flatten -except ImportError: # pragma: no cover - raise Exception("dask is not installed. " - "Install it with `conda install dask`.") -from .array import read_array, write_array -from phy.utils import (Bunch, _save_json, _load_json, +from phy.utils import (_save_json, _load_json, _ensure_dir_exists, _fullname,) from phy.utils.config import phy_config_dir logger = logging.getLogger(__name__) -#------------------------------------------------------------------------------ -# Utility functions -#------------------------------------------------------------------------------ - -def _iter_chunks_dask(da): - for chunk in flatten(da._keys()): - yield chunk - - #------------------------------------------------------------------------------ # Context #------------------------------------------------------------------------------ -def _mapped(i, chunk, dask, func, args, cache_dir, name): - """Top-level function to map. - - This function needs to be a top-level function for ipyparallel to work. - - """ - # Load the array's chunk. - arr = get(dask, chunk) - - # Execute the function on the chunk. - # logger.debug("Run %s on chunk %d", name, i) - res = func(arr, *args) - - # Save the result, and return the information about what we saved. - return _save_stack_chunk(i, res, cache_dir, name) - - -def _save_stack_chunk(i, arr, cache_dir, name): - """Save an output chunk array to a npy file, and return information about - it.""" - # Handle the case where several output arrays are returned. - if isinstance(arr, tuple): - # The name is a tuple of names for the different arrays returned. - assert isinstance(name, tuple) - assert len(arr) == len(name) - - return tuple(_save_stack_chunk(i, arr_, cache_dir, name_) - for arr_, name_ in zip(arr, name)) - - assert isinstance(name, string_types) - assert isinstance(arr, np.ndarray) - - dirpath = op.join(cache_dir, name) - path = op.join(dirpath, '{}.npy'.format(i)) - write_array(path, arr) - - # Return information about what we just saved. - return Bunch(dask_tuple=(read_array, path), - shape=arr.shape, - dtype=arr.dtype, - name=name, - dirpath=dirpath, - ) - - -def _save_stack_info(outputs): - """Save the npy stack info, and return one or several dask arrays from - saved npy stacks. - - The argument is a list of objects returned by `_save_stack_chunk()`. - - """ - # Handle the case where several arrays are returned, i.e. outputs is a list - # of tuples of Bunch objects. - assert len(outputs) - if isinstance(outputs[0], tuple): - return tuple(_save_stack_info(output) for output in zip(*outputs)) - - # Get metadata fields common to all chunks. - assert len(outputs) - assert isinstance(outputs[0], Bunch) - name = outputs[0].name - dirpath = outputs[0].dirpath - dtype = outputs[0].dtype - trail_shape = outputs[0].shape[1:] - trail_ndim = len(trail_shape) - - # Ensure the consistency of all chunks metadata. - assert all(output.name == name for output in outputs) - assert all(output.dirpath == dirpath for output in outputs) - assert all(output.dtype == dtype for output in outputs) - assert all(output.shape[1:] == trail_shape for output in outputs) - - # Compute the output dask array chunks and shape. - chunks = (tuple(output.shape[0] for output in outputs),) + trail_shape - n = sum(output.shape[0] for output in outputs) - shape = (n,) + trail_shape - - # Save the info object for dask npy stack. - with open(op.join(dirpath, 'info'), 'wb') as f: - dump({'chunks': chunks, 'dtype': dtype, 'axis': 0}, f) - - # Return the result as a dask array. - dask_tuples = tuple(output.dask_tuple for output in outputs) - dask = {((name, i) + (0,) * trail_ndim): chunk - for i, chunk in enumerate(dask_tuples)} - return Array(dask, name, chunks, dtype=dtype, shape=shape) - - -def _ensure_cache_dirs_exist(cache_dir, name): - if isinstance(name, tuple): - return [_ensure_cache_dirs_exist(cache_dir, name_) for name_ in name] - dirpath = op.join(cache_dir, name) - if not op.exists(dirpath): - os.makedirs(dirpath) - - class Context(object): """Handle function cacheing and parallel map with ipyparallel.""" def __init__(self, cache_dir, ipy_view=None, verbose=0): @@ -176,18 +59,6 @@ def _set_memory(self, cache_dir): "Install it with `conda install joblib`.") self._memory = None - @property - def ipy_view(self): - """ipyparallel view to parallel computing resources.""" - return self._ipy_view - - @ipy_view.setter - def ipy_view(self, value): - self._ipy_view = value - if hasattr(value, 'use_dill'): - # Dill is necessary because we need to serialize closures. - value.use_dill() - def cache(self, f): """Cache a function using the context's cache directory.""" if self._memory is None: # pragma: no cover @@ -195,7 +66,7 @@ def cache(self, f): return f assert f # NOTE: discard self in instance methods. - if 'self' in inspect.getargspec(f).args: + if 'self' in inspect.getargspec(f).args: # noqa ignore = ['self'] else: ignore = None @@ -240,74 +111,6 @@ def memcached(*args, **kwargs): return out return memcached - def map_dask_array(self, func, da, *args, **kwargs): - """Map a function on the chunks of a dask array, and return a - new dask array. - - This function works in parallel if an `ipy_view` has been set. - - Every task loads one chunk, applies the function, and saves the - result into a `.npy` file in a cache subdirectory with the specified - name (the function's name by default). The result is a new dask array - that reads data from the npy stack in the cache subdirectory. - - The mapped function can return several arrays as a tuple. In this case, - `name` must also be a tuple, and the output of this function is a - tuple of dask arrays. - - """ - assert isinstance(da, Array) - - name = kwargs.get('name', None) or func.__name__ - assert name != da.name - dask = da.dask - - # Ensure the directories exist. - _ensure_cache_dirs_exist(self.cache_dir, name) - - args_0 = list(_iter_chunks_dask(da)) - n = len(args_0) - output = self.map(_mapped, range(n), args_0, [dask] * n, - [func] * n, [args] * n, - [self.cache_dir] * n, [name] * n) - - # output contains information about the output arrays. We use this - # information to reconstruct the final dask array. - return _save_stack_info(output) - - def _map_serial(self, f, *args): - return [f(*arg) for arg in zip(*args)] - - def _map_ipy(self, f, *args, **kwargs): - if kwargs.get('sync', True): - name = 'map_sync' - else: - name = 'map_async' - return getattr(self._ipy_view, name)(f, *args) - - def map_async(self, f, *args): - """Map a function asynchronously. - - Use the ipyparallel resources if available. - - """ - if self._ipy_view: - return self._map_ipy(f, *args, sync=False) - else: - raise RuntimeError("Asynchronous execution requires an " - "ipyparallel context.") - - def map(self, f, *args): - """Map a function synchronously. - - Use the ipyparallel resources if available. - - """ - if self._ipy_view: - return self._map_ipy(f, *args, sync=True) - else: - return self._map_serial(f, *args) - def _get_path(self, name, location): if location == 'local': return op.join(self.cache_dir, name + '.json') @@ -333,7 +136,6 @@ def __getstate__(self): """Make sure that this class is picklable.""" state = self.__dict__.copy() state['_memory'] = None - state['_ipy_view'] = None return state def __setstate__(self, state): @@ -341,16 +143,3 @@ def __setstate__(self, state): self.__dict__ = state # Recreate the joblib Memory instance. self._set_memory(state['cache_dir']) - - -#------------------------------------------------------------------------------ -# Task -#------------------------------------------------------------------------------ - -class Task(Configurable): - def __init__(self, ctx=None): - super(Task, self).__init__() - self.set_context(ctx) - - def set_context(self, ctx): - self.ctx = ctx diff --git a/phy/io/tests/test_context.py b/phy/io/tests/test_context.py index d8e32596c..2d54e2be2 100644 --- a/phy/io/tests/test_context.py +++ b/phy/io/tests/test_context.py @@ -6,59 +6,27 @@ # Imports #------------------------------------------------------------------------------ -import os import os.path as op import numpy as np from numpy.testing import assert_array_equal as ae -from pytest import yield_fixture, mark, raises +from pytest import yield_fixture from six.moves import cPickle -from ..context import (Context, Task, - _iter_chunks_dask, write_array, read_array, - _fullname, - ) +from ..array import write_array, read_array +from ..context import Context, _fullname #------------------------------------------------------------------------------ # Fixtures #------------------------------------------------------------------------------ -@yield_fixture(scope='module') -def ipy_client(): - - def iptest_stdstreams_fileno(): - return os.open(os.devnull, os.O_WRONLY) - - # OMG-THIS-IS-UGLY-HACK: monkey-patch this global object to avoid - # using the nose iptest extension (we're using pytest). - # See https://github.com/ipython/ipython/blob/master/IPython/testing/iptest.py#L317-L319 # noqa - from ipyparallel import Client - import ipyparallel.tests - ipyparallel.tests.nose.iptest_stdstreams_fileno = iptest_stdstreams_fileno - - # Start two engines engine (one is launched by setup()). - ipyparallel.tests.setup() - ipyparallel.tests.add_engines(1) - yield Client(profile='iptest') - ipyparallel.tests.teardown() - - @yield_fixture(scope='function') def context(tempdir): ctx = Context('{}/cache/'.format(tempdir), verbose=1) yield ctx -@yield_fixture(scope='function', params=[False, True]) -def parallel_context(tempdir, ipy_client, request): - """Parallel and non-parallel context.""" - ctx = Context('{}/cache/'.format(tempdir)) - if request.param: - ctx.ipy_view = ipy_client[:] - yield ctx - - @yield_fixture def temp_phy_config_dir(tempdir): """Use a temporary phy user directory.""" @@ -69,15 +37,6 @@ def temp_phy_config_dir(tempdir): phy.io.context.phy_config_dir = f -#------------------------------------------------------------------------------ -# ipyparallel tests -#------------------------------------------------------------------------------ - -def test_client_1(ipy_client): - assert ipy_client.ids == [0, 1] - assert ipy_client[:].map_sync(lambda x: x * x, [1, 2, 3]) == [1, 4, 9] - - #------------------------------------------------------------------------------ # Test utils and cache #------------------------------------------------------------------------------ @@ -158,81 +117,11 @@ def f(x): assert len(_res) == 1 -def test_pickle_cache(tempdir, parallel_context): +def test_pickle_cache(tempdir, context): """Make sure the Context is picklable.""" with open(op.join(tempdir, 'test.pkl'), 'wb') as f: - cPickle.dump(parallel_context, f) + cPickle.dump(context, f) with open(op.join(tempdir, 'test.pkl'), 'rb') as f: ctx = cPickle.load(f) assert isinstance(ctx, Context) - assert ctx.cache_dir == parallel_context.cache_dir - - -#------------------------------------------------------------------------------ -# Test map -#------------------------------------------------------------------------------ - -def test_context_map(parallel_context): - - def square(x): - return x * x - - assert parallel_context.map(square, [1, 2, 3]) == [1, 4, 9] - if not parallel_context.ipy_view: - with raises(RuntimeError): - parallel_context.map_async(square, [1, 2, 3]) - else: - assert parallel_context.map_async(square, [1, 2, 3]).get() == [1, 4, 9] - - -def test_task(): - task = Task(ctx=None) - assert task - - -#------------------------------------------------------------------------------ -# Test context dask -#------------------------------------------------------------------------------ - -def test_iter_chunks_dask(): - from dask.array import from_array - - x = np.arange(10) - da = from_array(x, chunks=(3,)) - assert len(list(_iter_chunks_dask(da))) == 4 - - -@mark.parametrize('multiple_outputs', [True, False]) -def test_context_dask(parallel_context, multiple_outputs): - from dask.array import from_array, from_npy_stack - context = parallel_context - - if not multiple_outputs: - def f4(x, onset): - return x * x * x * x - name = None - else: - def f4(x, onset): - return x * x * x * x + onset, x + 1 - name = ('power_four', 'plus_one') - - x = np.arange(10) - da = from_array(x, chunks=(3,)) - res = context.map_dask_array(f4, da, 0, name=name) - - # Check that we can load the dumped dask array from disk. - # The location is in the context cache dir, in a subdirectory with the - # name of the function by default. - if not multiple_outputs: - ae(res.compute(), x ** 4) - - y = from_npy_stack(op.join(context.cache_dir, 'f4')) - ae(y.compute(), x ** 4) - else: - ae(res[0].compute(), x ** 4) - ae(res[1].compute(), x + 1) - y = from_npy_stack(op.join(context.cache_dir, 'power_four')) - ae(y.compute(), x ** 4) - - y = from_npy_stack(op.join(context.cache_dir, 'plus_one')) - ae(y.compute(), x + 1) + assert ctx.cache_dir == context.cache_dir diff --git a/phy/traces/__init__.py b/phy/traces/__init__.py index 039bc84aa..e51e25895 100644 --- a/phy/traces/__init__.py +++ b/phy/traces/__init__.py @@ -3,7 +3,5 @@ """Spike detection, waveform extraction.""" -from .detect import Thresholder, FloodFillDetector, compute_threshold from .filter import Filter, Whitening -from .pca import PCA from .waveform import WaveformLoader, WaveformExtractor, SpikeLoader diff --git a/phy/traces/detect.py b/phy/traces/detect.py deleted file mode 100644 index 00483316f..000000000 --- a/phy/traces/detect.py +++ /dev/null @@ -1,344 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Spike detection.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import numpy as np -from six import string_types -from six.moves import range, zip - -from phy.io.array import _as_array - - -#------------------------------------------------------------------------------ -# Thresholder -#------------------------------------------------------------------------------ - -def compute_threshold(arr, single_threshold=True, std_factor=None): - """Compute the threshold(s) of filtered traces. - - Parameters - ---------- - - arr : ndarray - Filtered traces, shape `(n_samples, n_channels)`. - single_threshold : bool - Whether there should be a unique threshold for all channels, or - one threshold per channel. - std_factor : float or 2-tuple - The threshold in unit of signal std. Two values can be specified - for multiple thresholds (weak and strong). - - Returns - ------- - - thresholds : ndarray - A `(2,)` or `(2, n_channels)` array with the thresholds. - - """ - assert arr.ndim == 2 - ns, nc = arr.shape - - assert std_factor is not None - if isinstance(std_factor, (int, float)): - std_factor = (std_factor, std_factor) - assert isinstance(std_factor, (tuple, list)) - assert len(std_factor) == 2 - std_factor = np.array(std_factor) - - if not single_threshold: - std_factor = std_factor[:, None] - - # Get the median of all samples in all excerpts, on all channels. - if single_threshold: - median = np.median(np.abs(arr)) - # Or independently for each channel. - else: - median = np.median(np.abs(arr), axis=0) - - # Compute the threshold from the median. - std = median / .6745 - threshold = std_factor * std - assert isinstance(threshold, np.ndarray) - - if single_threshold: - assert threshold.ndim == 1 - assert len(threshold) == 2 - else: - assert threshold.ndim == 2 - assert threshold.shape == (2, nc) - return threshold - - -class Thresholder(object): - """Threshold traces to detect spikes. - - Parameters - ---------- - - mode : str - `'positive'`, `'negative'`, or `'both'`. - thresholds : dict - A `{str: float}` mapping for multiple thresholds (e.g. `weak` - and `strong`). - - Example - ------- - - ```python - thres = Thresholder('positive', thresholds=(.1, .2)) - crossings = thres(traces) - ``` - - """ - def __init__(self, - mode=None, - thresholds=None, - ): - assert mode in ('positive', 'negative', 'both') - if isinstance(thresholds, (float, int, np.ndarray)): - thresholds = {'default': thresholds} - thresholds = thresholds if thresholds is not None else {} - assert isinstance(thresholds, dict) - self._mode = mode - self._thresholds = thresholds - - def transform(self, data): - """Return `data`, `-data`, or `abs(data)` depending on the mode.""" - if self._mode == 'positive': - return data - elif self._mode == 'negative': - return -data - elif self._mode == 'both': - return np.abs(data) - - def detect(self, data_t, threshold=None): - """Perform the thresholding operation.""" - # Accept dictionary of thresholds. - if isinstance(threshold, (list, tuple)): - return {name: self(data_t, threshold=name) - for name in threshold} - # Use the only threshold by default (if there is only one). - if threshold is None: - assert len(self._thresholds) == 1 - threshold = list(self._thresholds.keys())[0] - # Fetch the threshold from its name. - if isinstance(threshold, string_types): - assert threshold in self._thresholds - threshold = self._thresholds[threshold] - # threshold = float(threshold) - # Threshold the data. - return data_t > threshold - - def __call__(self, data, threshold=None): - # Transform the data according to the mode. - data_t = self.transform(data) - return self.detect(data_t, threshold=threshold) - - -# ----------------------------------------------------------------------------- -# Connected components -# ----------------------------------------------------------------------------- - -def connected_components(weak_crossings=None, - strong_crossings=None, - probe_adjacency_list=None, - join_size=None): - """Find all connected components in binary arrays of threshold crossings. - - Parameters - ---------- - - weak_crossings : array - `(n_samples, n_channels)` array with weak threshold crossings - strong_crossings : array - `(n_samples, n_channels)` array with strong threshold crossings - probe_adjacency_list : dict - A dict `{channel: [neighbors]}` - join_size : int - The number of samples defining the tolerance in time for - finding connected components - - Returns - ------- - - A list of lists of pairs `(samp, chan)` of the connected components in - the 2D array `weak_crossings`, where a pair is adjacent if the samples are - within `join_size` of each other, and the channels are adjacent in - `probe_adjacency_list`, the channel graph. - - Note - ---- - - The channel mapping assumes that column #i in the data array is channel #i - in the probe adjacency graph. - - """ - - probe_adjacency_list = probe_adjacency_list or {} - - # Make sure the values are sets. - probe_adjacency_list = {c: set(cs) - for c, cs in probe_adjacency_list.items()} - - if strong_crossings is None: - strong_crossings = weak_crossings - - assert weak_crossings.shape == strong_crossings.shape - - # Set of connected component labels which contain at least one strong - # node. - strong_nodes = set() - - n_s, n_ch = weak_crossings.shape - join_size = int(join_size or 0) - - # An array with the component label for each node in the array - label_buffer = np.zeros((n_s, n_ch), dtype=np.int32) - - # Component indices, a dictionary with keys the label of the component - # and values a list of pairs (sample, channel) belonging to that component - comp_inds = {} - - # mgraph is the channel graph, but with edge node connected to itself - # because we want to include ourself in the adjacency. Each key of the - # channel graph (a dictionary) is a node, and the value is a set of nodes - # which are connected to it by an edge - mgraph = {} - for source, targets in probe_adjacency_list.items(): - # we add self connections - mgraph[source] = targets.union([source]) - - # Label of the next component - c_label = 1 - - # For all pairs sample, channel which are nonzero (note that numpy .nonzero - # returns (all_i_s, all_i_ch), a pair of lists whose values at the - # corresponding place are the sample, channel pair which is nonzero. The - # lists are also returned in sorted order, so that i_s is always increasing - # and i_ch is always increasing for a given value of i_s. izip is an - # iterator version of the Python zip function, i.e. does the same as zip - # but quicker. zip(A,B) is a list of all pairs (a,b) with a in A and b in B - # in order (i.e. (A[0], B[0]), (A[1], B[1]), .... In conclusion, the next - # line loops through all the samples i_s, and for each sample it loops - # through all the channels. - for i_s, i_ch in zip(*weak_crossings.nonzero()): - # The next two lines iterate through all the neighbours of i_s, i_ch - # in the graph defined by graph in the case of edges, and - # j_s from i_s-join_size to i_s. - for j_s in range(i_s - join_size, i_s + 1): - # Allow us to leave out a channel from the graph to exclude bad - # channels - if i_ch not in mgraph: # pragma: no cover - continue - for j_ch in mgraph[i_ch]: - # Label of the adjacent element. - adjlabel = label_buffer[j_s, j_ch] - # If the adjacent element is nonzero we need to do something. - if adjlabel: - curlabel = label_buffer[i_s, i_ch] - if curlabel == 0: - # If current element is still zero, we just assign - # the label of the adjacent element to the current one. - label_buffer[i_s, i_ch] = adjlabel - # And add it to the list for the labelled component. - comp_inds[adjlabel].append((i_s, i_ch)) - - elif curlabel != adjlabel: - # If the current element is unequal to the adjacent - # one, we merge them by reassigning the elements of the - # adjacent component to the current one. - # samps_chans is an array of pairs sample, channel - # currently assigned to component adjlabel. - samps_chans = np.array(comp_inds[adjlabel], - dtype=np.int32) - - # samps_chans[:, 0] is the sample indices, so this - # gives only the samp,chan pairs that are within - # join_size of the current point. - # TODO: is this the right behaviour? If a component can - # have a width bigger than join_size I think it isn't! - samps_chans = samps_chans[i_s - samps_chans[:, 0] <= - join_size] - - # Relabel the adjacent samp,chan points with current - # label. - samps, chans = samps_chans[:, 0], samps_chans[:, 1] - label_buffer[samps, chans] = curlabel - - # Add them to the current label list, and remove the - # adjacent component entirely. - comp_inds[curlabel].extend(comp_inds.pop(adjlabel)) - - # Did not deal with merge condition, now fixed it - # seems... - # WARNING: might this "in" incur a performance hit - # here...? - if adjlabel in strong_nodes: - strong_nodes.add(curlabel) - strong_nodes.remove(adjlabel) - - # NEW: add the current component label to the set of all - # strong nodes, if the current node is strong. - if curlabel > 0 and strong_crossings[i_s, i_ch]: - strong_nodes.add(curlabel) - - if label_buffer[i_s, i_ch] == 0: - # If nothing is adjacent, we have the beginnings of a new - # component, # so we label it, create a new list for the new - # component which is given label c_label, - # then increase c_label for the next new component afterwards. - label_buffer[i_s, i_ch] = c_label - comp_inds[c_label] = [(i_s, i_ch)] - if strong_crossings[i_s, i_ch]: - strong_nodes.add(c_label) - c_label += 1 - - # Only return the values, because we don't actually need the labels. - comps = [comp_inds[key] for key in comp_inds.keys() if key in strong_nodes] - return comps - - -class FloodFillDetector(object): - """Detect spikes in weak and strong threshold crossings. - - Parameters - ---------- - - probe_adjacency_list : dict - A dict `{channel: [neighbors]}`. - join_size : int - The number of samples defining the tolerance in time for - finding connected components - - Example - ------- - - ```python - det = FloodFillDetector(probe_adjacency_list=..., - join_size=...) - components = det(weak_crossings, strong_crossings) - ``` - - `components` is a list of `(n, 2)` int arrays with the sample and channel - for every sample in the component. - - """ - def __init__(self, probe_adjacency_list=None, join_size=None): - self._adjacency_list = probe_adjacency_list - self._join_size = join_size - - def __call__(self, weak_crossings=None, strong_crossings=None): - weak_crossings = _as_array(weak_crossings, np.bool) - strong_crossings = _as_array(strong_crossings, np.bool) - - cc = connected_components(weak_crossings=weak_crossings, - strong_crossings=strong_crossings, - probe_adjacency_list=self._adjacency_list, - join_size=self._join_size, - ) - # cc is a list of list of pairs (sample, channel) - return [np.array(c) for c in cc] diff --git a/phy/traces/pca.py b/phy/traces/pca.py deleted file mode 100644 index 0d4e10960..000000000 --- a/phy/traces/pca.py +++ /dev/null @@ -1,151 +0,0 @@ -# -*- coding: utf-8 -*- - -"""PCA for features.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import numpy as np -from traitlets import Int - -from phy.io.context import Task -from ..utils._types import _as_array - - -#------------------------------------------------------------------------------ -# PCA -#------------------------------------------------------------------------------ - -def _compute_pcs(x, n_pcs=None, masks=None): - """Compute the PCs of waveforms.""" - - assert x.ndim == 3 - x = _as_array(x, np.float64) - - n_spikes, n_samples, n_channels = x.shape - - if masks is not None: - assert isinstance(masks, np.ndarray) - assert masks.shape == (n_spikes, n_channels) - - # Compute regularization cov matrix. - cov_reg = np.eye(n_samples) - if masks is not None: - unmasked = masks > 0 - # The last dimension is now time. The second dimension is channel. - x_swapped = np.swapaxes(x, 1, 2) - # This is the list of all unmasked spikes on all channels. - # shape: (n_unmasked_spikes, n_samples) - unmasked_all = x_swapped[unmasked, :] - # Let's compute the regularization cov matrix of this beast. - # shape: (n_samples, n_samples) - cov_reg_ = np.cov(unmasked_all, rowvar=0) - # Make sure the covariance matrix is valid. - if cov_reg_.ndim == 2: - cov_reg = cov_reg - assert cov_reg.shape == (n_samples, n_samples) - - pcs_list = [] - # Loop over channels - for channel in range(n_channels): - x_channel = x[:, :, channel] - # Compute cov matrix for the channel - if masks is not None: - # Unmasked waveforms on that channel - # shape: (n_unmasked, n_samples) - x_channel = np.compress(masks[:, channel] > 0, - x_channel, axis=0) - assert x_channel.ndim == 2 - # Don't compute the cov matrix if there are no unmasked spikes - # on that channel. - alpha = 1. / n_spikes - if x_channel.shape[0] <= 1: # pragma: no cover - cov = alpha * cov_reg - else: - cov_channel = np.cov(x_channel, rowvar=0) - assert cov_channel.shape == (n_samples, n_samples) - cov = alpha * cov_reg + cov_channel - # Compute the eigenelements - vals, vecs = np.linalg.eigh(cov) - pcs = vecs.T.astype(np.float32)[np.argsort(vals)[::-1]] - # Take the first n_pcs components. - if n_pcs is not None: - pcs = pcs[:n_pcs, ...] - pcs_list.append(pcs[:n_pcs, ...]) - - pcs = np.dstack(pcs_list) - return pcs - - -def _project_pcs(x, pcs): - """Project data points onto principal components. - - Parameters - ---------- - - x : array - The waveforms - pcs : array - The PCs returned by `_compute_pcs()`. - """ - # x: (n, ns, nc) - # pcs: (nf, ns, nc) - # out: (n, nc, nf) - assert pcs.ndim == 3 - assert x.ndim == 3 - n, ns, nc = x.shape - nf, ns_, nc_ = pcs.shape - assert ns == ns_ - assert nc == nc_ - - x_proj = np.einsum('ijk,...jk->...ki', pcs, x) - assert x_proj.shape == (n, nc, nf) - return x_proj - - -class PCA(Task): - """Apply PCA to waveforms.""" - n_features_per_channel = Int(3) - - def fit(self, waveforms, masks=None): - """Compute the PCs of waveforms. - - Parameters - ---------- - - waveforms : ndarray - Shape: `(n_spikes, n_samples, n_channels)` - masks : ndarray - Shape: `(n_spikes, n_channels)` - - """ - self._pcs = _compute_pcs(waveforms, - n_pcs=self.n_features_per_channel, - masks=masks, - ) - return self._pcs - - def _project(self, waveforms): - return _project_pcs(waveforms, self._pcs) - - def transform(self, waveforms, pcs=None): - """Project waveforms on the PCs. - - Parameters - ---------- - - waveforms : ndarray - Shape: `(n_spikes, n_samples, n_channels)` - - """ - self._pcs = self._pcs if pcs is None else pcs - assert self._pcs is not None - if not self.ctx: - return self._project(waveforms) - else: - import dask.array as da - assert isinstance(waveforms, da.Array) - - return self.ctx.map_dask_array(self._project, waveforms, - name='features') diff --git a/phy/traces/spike_detect.py b/phy/traces/spike_detect.py deleted file mode 100644 index 45763aa2c..000000000 --- a/phy/traces/spike_detect.py +++ /dev/null @@ -1,254 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Spike detection.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import logging - -import numpy as np -from traitlets import Int, Float, Unicode, Bool - -from phy.electrode.mea import MEA, _adjacency_subset, _remap_adjacency -from phy.io.array import get_excerpts -from phy.io.context import Task -from .detect import FloodFillDetector, Thresholder, compute_threshold -from .filter import Filter -from .waveform import WaveformExtractor - -logger = logging.getLogger(__name__) - - -#------------------------------------------------------------------------------ -# Chunking-related utility functions -#------------------------------------------------------------------------------ - -def _spikes_to_keep(spikes, trace_chunks, depth): - """Find the indices of the spikes to keep given a chunked trace array.""" - - # Find where to trim the spikes in the overlapping bands. - def _find_bounds(x, block_id=None): - n = trace_chunks[0][block_id[0]] - i = np.searchsorted(x, depth) - j = np.searchsorted(x, n + depth) - return np.array([i, j]) - - # Trim the arrays. - ij = spikes.map_blocks(_find_bounds, chunks=(2,)).compute() - return ij[::2], ij[1::2] - - -def _trim_spikes(arr, indices): - onsets, offsets = indices - - def _trim(x, block_id=None): - i = block_id[0] - on = onsets[i] - off = offsets[i] - return x[on:off, ...] - - # Compute the trimmed chunks. - chunks = (tuple(offsets - onsets),) + arr.chunks[1:] - return arr.map_blocks(_trim, chunks=chunks) - - -def _add_chunk_offset(arr, trace_chunks, depth): - - # Add the spike sample offsets. - def _add_offset(x, block_id=None): - i = block_id[0] - return x + sum(trace_chunks[0][:i]) - depth - - return arr.map_blocks(_add_offset) - - -def _concat_spikes(s, m, w, trace_chunks=None, depth=None): - indices = _spikes_to_keep(s, trace_chunks, depth) - s = _trim_spikes(s, indices) - m = _trim_spikes(m, indices) - w = _trim_spikes(w, indices) - - s = _add_chunk_offset(s, trace_chunks, depth) - return s, m, w - - -#------------------------------------------------------------------------------ -# SpikeDetector -#------------------------------------------------------------------------------ - -class SpikeDetector(Task): - do_filter = Bool(True) - filter_low = Float(500.) - filter_butter_order = Int(3) - chunk_size_seconds = Float(1) - chunk_overlap_seconds = Float(.015) - n_excerpts = Int(50) - excerpt_size_seconds = Float(1.) - use_single_threshold = Bool(True) - threshold_strong_std_factor = Float(4.5) - threshold_weak_std_factor = Float(2) - detect_spikes = Unicode('negative') - connected_component_join_size = Int(1) - extract_s_before = Int(10) - extract_s_after = Int(10) - weight_power = Float(2) - - def set_metadata(self, probe, site_label_to_traces_row=None, - sample_rate=None): - assert isinstance(probe, MEA) - self.probe = probe - - assert sample_rate > 0 - self.sample_rate = sample_rate - - # Channel mapping. - if site_label_to_traces_row is None: - site_label_to_traces_row = {c: c for c in probe.channels} - # Remove channels mapped to None or a negative value: they are dead. - site_label_to_traces_row = {k: v for (k, v) in - site_label_to_traces_row.items() - if v is not None and v >= 0} - # channel mappings is {trace_col: channel_id}. - # Trace columns and channel ids to keep. - self.trace_cols = sorted(site_label_to_traces_row.keys()) - self.channel_ids = sorted(site_label_to_traces_row.values()) - # The key is the col in traces, the val is the channel id. - adj = self.probe.adjacency # Numbers are all channel ids. - # First, we subset the adjacency list with the kept channel ids. - adj = _adjacency_subset(adj, self.channel_ids) - # Then, we remap to convert from channel ids to trace columns. - # We need to inverse the mapping. - site_label_to_traces_row_inv = {v: c for (c, v) in - site_label_to_traces_row.items()} - # Now, the adjacency list contains trace column numbers. - adj = _remap_adjacency(adj, site_label_to_traces_row_inv) - assert set(adj) <= set(self.trace_cols) - # Finally, we need to remap with relative column indices. - rel_mapping = {c: i for (i, c) in enumerate(self.trace_cols)} - adj = _remap_adjacency(adj, rel_mapping) - self._adjacency = adj - - # Array of channel idx to consider. - self.n_channels = len(self.channel_ids) - self.n_samples_waveforms = self.extract_s_before + self.extract_s_after - - def subset_traces(self, traces): - return traces[:, self.trace_cols] - - def find_thresholds(self, traces): - """Find weak and strong thresholds in filtered traces.""" - excerpt_size = int(self.excerpt_size_seconds * self.sample_rate) - single_threshold = self.use_single_threshold - std_factor = (self.threshold_weak_std_factor, - self.threshold_strong_std_factor) - - logger.info("Extracting some data for finding the thresholds...") - excerpt = get_excerpts(traces, n_excerpts=self.n_excerpts, - excerpt_size=excerpt_size) - - logger.info("Filtering the excerpts...") - excerpt_f = self.filter(excerpt) - - logger.info("Computing the thresholds...") - thresholds = compute_threshold(excerpt_f, - single_threshold=single_threshold, - std_factor=std_factor) - - thresholds = {'weak': thresholds[0], 'strong': thresholds[1]} - # logger.info("Thresholds found: {}.".format(thresholds)) - return thresholds - - def filter(self, traces): - if not self.do_filter: # pragma: no cover - return traces - f = Filter(rate=self.sample_rate, - low=self.filter_low, - high=0.95 * .5 * self.sample_rate, - order=self.filter_butter_order, - ) - logger.info("Filtering %d samples...", traces.shape[0]) - return f(traces).astype(np.float32) - - def extract_spikes(self, traces_subset, thresholds=None): - thresholds = thresholds or self._thresholds - assert thresholds is not None - self._thresholder = Thresholder(mode=self.detect_spikes, - thresholds=thresholds) - - # Filter the traces. - traces_f = self.filter(traces_subset) - - # Transform the filtered data according to the detection mode. - traces_t = self._thresholder.transform(traces_f) - - # Compute the threshold crossings. - weak = self._thresholder.detect(traces_t, 'weak') - strong = self._thresholder.detect(traces_t, 'strong') - - # Run the detection. - logger.info("Detecting connected components...") - join_size = self.connected_component_join_size - detector = FloodFillDetector(probe_adjacency_list=self._adjacency, - join_size=join_size) - components = detector(weak_crossings=weak, - strong_crossings=strong) - - # Extract all waveforms. - extractor = WaveformExtractor(extract_before=self.extract_s_before, - extract_after=self.extract_s_after, - weight_power=self.weight_power, - thresholds=thresholds, - ) - - logger.info("Extracting %d spikes...", len(components)) - s, m, w = zip(*(extractor(component, data=traces_f, data_t=traces_t) - for component in components)) - s = np.array(s, dtype=np.int64) - m = np.array(m, dtype=np.float32) - w = np.array(w, dtype=np.float32) - return s, m, w - - def detect(self, traces, thresholds=None): - - # Only keep the selected channels (given shank, no dead channels, etc.) - traces = self.subset_traces(traces) - assert traces.ndim == 2 - assert traces.shape[1] == self.n_channels - - # Find the thresholds. - if thresholds is None: - thresholds = self.find_thresholds(traces) - - # Extract the spikes, masks, waveforms. - if not self.ctx: - return self.extract_spikes(traces, thresholds=thresholds) - else: # pragma: no cover # skipped for now in the test suite - import dask.array as da - - # Chunking parameters. - chunk_size = int(self.chunk_size_seconds * self.sample_rate) - depth = int(self.chunk_overlap_seconds * self.sample_rate) - trace_chunks = (chunk_size, traces.shape[1]) - - # Chunk the data. traces is now a dask Array. - traces = da.from_array(traces, chunks=trace_chunks) - trace_chunks = traces.chunks - - # Add overlapping band in traces. - traces = da.ghost.ghost(traces, - depth={0: depth}, boundary={0: 0}) - - names = ('spike_samples', 'masks', 'waveforms') - self._thresholds = thresholds - - # Run the spike extraction procedure in parallel. - s, m, w = self.ctx.map_dask_array(self.extract_spikes, - traces, name=names) - - # Return the concatenated spike samples, masks, waveforms, as - # dask arrays reading from the cached .npy files. - return _concat_spikes(s, m, w, - trace_chunks=trace_chunks, - depth=depth) diff --git a/phy/traces/tests/test_detect.py b/phy/traces/tests/test_detect.py deleted file mode 100644 index 10c4780eb..000000000 --- a/phy/traces/tests/test_detect.py +++ /dev/null @@ -1,293 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Tests of spike detection routines.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import numpy as np -from numpy.testing import assert_array_equal as ae -from numpy.testing import assert_allclose as ac - -from ..detect import (compute_threshold, - Thresholder, - connected_components, - FloodFillDetector, - ) -from ...io.mock import artificial_traces - - -#------------------------------------------------------------------------------ -# Test thresholder -#------------------------------------------------------------------------------ - -def test_compute_threshold(): - n_samples, n_channels = 100, 10 - data = artificial_traces(n_samples, n_channels) - - # Single threshold. - threshold = compute_threshold(data, std_factor=1.) - assert threshold.shape == (2,) - assert threshold[0] > 0 - assert threshold[0] == threshold[1] - - threshold = compute_threshold(data, std_factor=[1., 2.]) - assert threshold.shape == (2,) - assert threshold[1] == 2 * threshold[0] - - # Multiple threshold. - threshold = compute_threshold(data, single_threshold=False, std_factor=2.) - assert threshold.shape == (2, n_channels) - - threshold = compute_threshold(data, - single_threshold=False, - std_factor=(1., 2.)) - assert threshold.shape == (2, n_channels) - ac(threshold[1], 2 * threshold[0]) - - -def test_thresholder(): - n_samples, n_channels = 100, 12 - strong, weak = .1, .2 - - data = artificial_traces(n_samples, n_channels) - - # Positive and strong. - thresholder = Thresholder(mode='positive', - thresholds=strong) - ae(thresholder(data), data > strong) - - # Negative and weak. - thresholder = Thresholder(mode='negative', - thresholds={'weak': weak}) - ae(thresholder(data), data < -weak) - - # Both and strong+weak. - thresholder = Thresholder(mode='both', - thresholds={'weak': weak, - 'strong': strong, - }) - ae(thresholder(data, 'weak'), np.abs(data) > weak) - ae(thresholder(data, threshold='strong'), np.abs(data) > strong) - - # Multiple thresholds. - t = thresholder(data, ('weak', 'strong')) - ae(t['weak'], np.abs(data) > weak) - ae(t['strong'], np.abs(data) > strong) - - # Array threshold. - thre = np.linspace(weak - .05, strong + .05, n_channels) - thresholder = Thresholder(mode='positive', thresholds=thre) - t = thresholder(data) - assert t.shape == data.shape - ae(t, data > thre) - - -#------------------------------------------------------------------------------ -# Test connected components -#------------------------------------------------------------------------------ - -def _as_set(c): - if isinstance(c, np.ndarray): - c = c.tolist() - c = [tuple(_) for _ in c] - return set(c) - - -def _assert_components_equal(cc1, cc2): - assert len(cc1) == len(cc2) - for c1, c2 in zip(cc1, cc2): - assert _as_set(c1) == _as_set(c2) - - -def _test_components(chunk=None, components=None, **kwargs): - - def _clip(x, m, M): - return [_ for _ in x if m <= _ < M] - - n = 5 - probe_adjacency_list = {i: set(_clip([i - 1, i + 1], 0, n)) - for i in range(n)} - - if chunk is None: - chunk = [[0, 0, 0, 0, 0], - [0, 0, 1, 0, 0], - [1, 0, 1, 1, 0], - [1, 0, 0, 1, 0], - [0, 1, 0, 1, 1], - ] - - if components is None: - components = [] - - if not isinstance(chunk, np.ndarray): - chunk = np.array(chunk) - strong_crossings = kwargs.pop('strong_crossings', None) - if (strong_crossings is not None and - not isinstance(strong_crossings, np.ndarray)): - strong_crossings = np.array(strong_crossings) - - comp = connected_components(chunk, - probe_adjacency_list=probe_adjacency_list, - strong_crossings=strong_crossings, - **kwargs) - _assert_components_equal(comp, components) - - -def test_components(): - # 1 time step, 1 element - _test_components([[0, 0, 0, 0, 0]], []) - - _test_components([[1, 0, 0, 0, 0]], [[(0, 0)]]) - - _test_components([[0, 1, 0, 0, 0]], [[(0, 1)]]) - - _test_components([[0, 0, 0, 1, 0]], [[(0, 3)]]) - - _test_components([[0, 0, 0, 0, 1]], [[(0, 4)]]) - - # 1 time step, 2 elements - _test_components([[1, 1, 0, 0, 0]], [[(0, 0), (0, 1)]]) - - _test_components([[1, 0, 1, 0, 0]], [[(0, 0)], [(0, 2)]]) - - _test_components([[1, 0, 0, 0, 1]], [[(0, 0)], [(0, 4)]]) - - _test_components([[0, 1, 0, 1, 0]], [[(0, 1)], [(0, 3)]]) - - # 1 time step, 3 elements - _test_components([[1, 1, 1, 0, 0]], [[(0, 0), (0, 1), (0, 2)]]) - - _test_components([[1, 1, 0, 1, 0]], [[(0, 0), (0, 1)], [(0, 3)]]) - - _test_components([[1, 0, 1, 1, 0]], [[(0, 0)], [(0, 2), (0, 3)]]) - - _test_components([[0, 1, 1, 1, 0]], [[(0, 1), (0, 2), (0, 3)]]) - - _test_components([[0, 1, 1, 0, 1]], [[(0, 1), (0, 2)], [(0, 4)]]) - - # 5 time steps, varying join_size - _test_components([ - [0, 0, 0, 0, 0], - [0, 0, 1, 0, 0], - [1, 0, 1, 1, 0], - [1, 0, 0, 1, 0], - [0, 1, 0, 1, 1], - ], [[(1, 2)], - [(2, 0)], - [(2, 2), (2, 3)], - [(3, 0)], - [(3, 3)], - [(4, 1)], - [(4, 3), (4, 4)], - ]) - - _test_components([ - [0, 0, 0, 0, 0], - [0, 0, 1, 0, 0], - [1, 0, 1, 1, 0], - [1, 0, 0, 1, 0], - [0, 1, 0, 1, 1], - ], [[(1, 2), (2, 2), (2, 3), (3, 3), (4, 3), (4, 4)], - [(2, 0), (3, 0), (4, 1)]], join_size=1) - - _test_components( - components=[[(1, 2), (2, 2), (2, 3), (3, 3), (4, 3), (4, 4), - (2, 0), (3, 0), (4, 1)]], join_size=2) - - # 5 time steps, strong != weak - _test_components(join_size=0, - strong_crossings=[ - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0]]) - - _test_components(components=[[(1, 2)]], - join_size=0, - strong_crossings=[ - [0, 0, 0, 0, 0], - [0, 0, 1, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0]]) - - _test_components( - components=[[(1, 2), (2, 2), (2, 3), (3, 3), (4, 3), (4, 4)]], - join_size=1, - strong_crossings=[ - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 1]]) - - _test_components( - components=[[(1, 2), (2, 2), (2, 3), (3, 3), (4, 3), (4, 4), - (2, 0), (3, 0), (4, 1)]], - join_size=2, - strong_crossings=[ - [0, 0, 0, 0, 0], - [0, 0, 1, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0]]) - - _test_components( - components=[[(1, 2), (2, 2), (2, 3), (3, 3), (4, 3), (4, 4), - (2, 0), (3, 0), (4, 1)]], - join_size=2, - strong_crossings=[ - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 1, 0, 0, 0]]) - - -def test_flood_fill(): - - graph = {0: [1, 2], 1: [0, 2], 2: [0, 1], 3: []} - - ff = FloodFillDetector(probe_adjacency_list=graph, - join_size=1, - ) - - weak = [[0, 0, 0, 0], - [0, 1, 1, 0], - [0, 0, 0, 0], - [0, 0, 1, 0], - [0, 0, 1, 1], - [0, 0, 0, 0], - [0, 0, 0, 1], - [0, 0, 0, 0], - ] - - # Weak - weak - comps = [[(1, 1), (1, 2)], - [(3, 2), (4, 2)], - [(4, 3)], - [(6, 3)], - ] - cc = ff(weak, weak) - _assert_components_equal(cc, comps) - - # Weak and strong - strong = [[0, 0, 0, 0], - [0, 0, 0, 0], - [0, 0, 0, 0], - [0, 0, 1, 0], - [0, 0, 0, 1], - [0, 0, 0, 0], - [0, 0, 0, 1], - [0, 0, 0, 0], - ] - - comps = [[(3, 2), (4, 2)], - [(4, 3)], - [(6, 3)], - ] - cc = ff(weak, strong) - _assert_components_equal(cc, comps) diff --git a/phy/traces/tests/test_pca.py b/phy/traces/tests/test_pca.py deleted file mode 100644 index ea9c71ce4..000000000 --- a/phy/traces/tests/test_pca.py +++ /dev/null @@ -1,95 +0,0 @@ -# -*- coding: utf-8 -*- - -"""PCA tests.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import numpy as np -from numpy.testing import assert_array_equal as ae - -from phy.io.tests.test_context import (ipy_client, context, # noqa - parallel_context) -from ...io.mock import artificial_waveforms, artificial_masks -from ..pca import PCA, _compute_pcs, _project_pcs - - -#------------------------------------------------------------------------------ -# Test PCA -#------------------------------------------------------------------------------ - -def test_compute_pcs(): - """Test PCA on a 2D array.""" - # Horizontal ellipsoid. - x = np.random.randn(20000, 2) * np.array([[10., 1.]]) - # Rotate the points by pi/4. - a = 1. / np.sqrt(2.) - rot = np.array([[a, -a], [a, a]]) - x = np.dot(x, rot) - # Compute the PCs. - pcs = _compute_pcs(x[..., None]) - assert pcs.ndim == 3 - assert (np.abs(pcs) - a).max() < 1e-2 - - -def test_compute_pcs_3d(): - """Test PCA on a 3D array.""" - x1 = np.random.randn(20000, 2) * np.array([[10., 1.]]) - x2 = np.random.randn(20000, 2) * np.array([[1., 10.]]) - x = np.dstack((x1, x2)) - # Compute the PCs. - pcs = _compute_pcs(x) - assert pcs.ndim == 3 - assert np.linalg.norm(pcs[0, :, 0] - np.array([-1., 0.])) < 1e-2 - assert np.linalg.norm(pcs[1, :, 0] - np.array([0., -1.])) < 1e-2 - assert np.linalg.norm(pcs[0, :, 1] - np.array([0, 1.])) < 1e-2 - assert np.linalg.norm(pcs[1, :, 1] - np.array([-1., 0.])) < 1e-2 - - -def test_project_pcs(): - n, ns, nc = 1000, 50, 100 - nf = 3 - arr = np.random.randn(n, ns, nc) - pcs = np.random.randn(nf, ns, nc) - - y1 = _project_pcs(arr, pcs) - assert y1.shape == (n, nc, nf) - - -class TestPCA(object): - def setup(self): - self.n_spikes = 100 - self.n_samples = 40 - self.n_channels = 12 - self.waveforms = artificial_waveforms(self.n_spikes, - self.n_samples, - self.n_channels) - self.masks = artificial_masks(self.n_spikes, self.n_channels) - - def _get_features(self): - pca = PCA() - pcs = pca.fit(self.waveforms, self.masks) - assert pcs.shape == (3, self.n_samples, self.n_channels) - return pca.transform(self.waveforms) - - def test_serial(self): - fet = self._get_features() - assert fet.shape == (self.n_spikes, self.n_channels, 3) - - def test_parallel(self, parallel_context): # noqa - - # Chunk the waveforms array. - from dask.array import from_array - chunks = (10, self.n_samples, self.n_channels) - waveforms = from_array(self.waveforms, chunks) - - # Compute the PCs in parallel. - pca = PCA(parallel_context) - pcs = pca.fit(waveforms, self.masks) - assert pcs.shape == (3, self.n_samples, self.n_channels) - fet = pca.transform(waveforms) - assert fet.shape == (self.n_spikes, self.n_channels, 3) - - # Check that the computed features are identical. - ae(fet, self._get_features()) diff --git a/phy/traces/tests/test_spike_detect.py b/phy/traces/tests/test_spike_detect.py deleted file mode 100644 index 526575de7..000000000 --- a/phy/traces/tests/test_spike_detect.py +++ /dev/null @@ -1,177 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Tests of spike detection routines.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import numpy as np -from numpy.testing import assert_array_equal as ae -from pytest import fixture - -from phy.io.datasets import download_test_data -from phy.io.tests.test_context import (ipy_client, context, # noqa - parallel_context) -from phy.electrode import load_probe -from ..spike_detect import (SpikeDetector, - _spikes_to_keep, - _trim_spikes, - _add_chunk_offset, - _concat_spikes, - ) - - -#------------------------------------------------------------------------------ -# Fixtures -#------------------------------------------------------------------------------ - -@fixture -def traces(): - path = download_test_data('test-32ch-10s.dat') - traces = np.fromfile(path, dtype=np.int16).reshape((200000, 32)) - traces = traces[:20000] - - return traces - - -@fixture(params=[(True,), (False,)]) -def spike_detector(request): - remap = request.param[0] - - probe = load_probe('1x32_buzsaki') - site_label_to_traces_row = ({i: i for i in range(1, 21, 2)} - if remap else None) - - sd = SpikeDetector() - sd.use_single_threshold = False - sample_rate = 20000 - sd.set_metadata(probe, - site_label_to_traces_row=site_label_to_traces_row, - sample_rate=sample_rate) - - return sd - - -#------------------------------------------------------------------------------ -# Test spike detection -#------------------------------------------------------------------------------ - -def _plot(sd, traces, spike_samples, masks): # pragma: no cover - from vispy.app import run - from phy.plot import plot_traces - plot_traces(sd.subset_traces(traces), - spike_samples=spike_samples, - masks=masks, - n_samples_per_spike=40) - run() - - -class TestConcat(object): - # [ * * 0 1 2 3 4 * * | * * 5 6 7 8 9 * * | * * 10 11 ] - # [ ! ! ! ! ! ] - # spike_samples: 1, 4, 5 - - def setup(self): - from dask.array import Array, from_array - - self.trace_chunks = ((5, 5, 2), (3,)) - self.depth = 2 - - # Create the chunked spike_samples array. - dask = {('spike_samples', 0): np.array([0, 3, 6]), - ('spike_samples', 1): np.array([2, 7]), - ('spike_samples', 2): np.array([]), - } - spikes_chunks = ((3, 2, 0),) - s = Array(dask, 'spike_samples', spikes_chunks, - shape=(5,), dtype=np.int32) - self.spike_samples = s - # Indices of the spikes that are kept (outside of overlapping bands). - self.spike_indices = np.array([1, 2, 3]) - - assert len(self.spike_samples.compute()) == 5 - - self.masks = from_array(np.arange(5 * 3).reshape((5, 3)), - spikes_chunks + (3,)) - self.waveforms = from_array(np.arange(5 * 3 * 2).reshape((5, 3, 2)), - spikes_chunks + (3, 2)) - - def test_spikes_to_keep(self): - indices = _spikes_to_keep(self.spike_samples, - self.trace_chunks, - self.depth) - onsets, offsets = indices - assert list(zip(onsets, offsets)) == [(1, 3), (0, 1), (0, 0)] - - def test_trim_spikes(self): - indices = _spikes_to_keep(self.spike_samples, - self.trace_chunks, - self.depth) - - # Trim the spikes. - spikes_trimmed = _trim_spikes(self.spike_samples, indices) - ae(spikes_trimmed.compute(), [3, 6, 2]) - - def test_add_chunk_offset(self): - indices = _spikes_to_keep(self.spike_samples, - self.trace_chunks, - self.depth) - spikes_trimmed = _trim_spikes(self.spike_samples, indices) - - # Add the chunk offsets to the spike samples. - self.spikes = _add_chunk_offset(spikes_trimmed, - self.trace_chunks, self.depth) - ae(self.spikes, [1, 4, 5]) - - def test_concat(self): - sc, mc, wc = _concat_spikes(self.spike_samples, - self.masks, - self.waveforms, - trace_chunks=self.trace_chunks, - depth=self.depth, - ) - sc = sc.compute() - mc = mc.compute() - wc = wc.compute() - - ae(sc, [1, 4, 5]) - ae(mc, self.masks.compute()[self.spike_indices]) - ae(wc, self.waveforms.compute()[self.spike_indices]) - - -def test_detect_simple(spike_detector, traces): - sd = spike_detector - - spike_samples, masks, _ = sd.detect(traces) - - n_channels = sd.n_channels - n_spikes = len(spike_samples) - - assert spike_samples.dtype == np.int64 - assert spike_samples.ndim == 1 - - assert masks.dtype == np.float32 - assert masks.ndim == 2 - assert masks.shape == (n_spikes, n_channels) - - # _plot(sd, traces, spike_samples, masks) - - -# # NOTE: skip for now to accelerate the test suite... -# def _test_detect_context(spike_detector, traces, parallel_context): # noqa -# sd = spike_detector -# sd.set_context(parallel_context) - -# spike_samples, masks, _ = sd.detect(traces) - -# n_channels = sd.n_channels -# n_spikes = len(spike_samples) - -# assert spike_samples.dtype == np.int64 -# assert spike_samples.ndim == 1 - -# assert masks.dtype == np.float32 -# assert masks.ndim == 2 -# assert masks.shape == (n_spikes, n_channels) -# # _plot(sd, traces, spike_samples.compute(), masks.compute()) From 25381c708abd2d5692740797a774e2ae169121f9 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 22 Mar 2016 15:30:01 +0100 Subject: [PATCH 1057/1059] Remove dask --- phy/__init__.py | 10 ---------- phy/io/array.py | 12 ------------ phy/io/tests/test_array.py | 13 ------------- 3 files changed, 35 deletions(-) diff --git a/phy/__init__.py b/phy/__init__.py index 5086ab6ab..70d279d89 100644 --- a/phy/__init__.py +++ b/phy/__init__.py @@ -66,16 +66,6 @@ def add_default_handler(level='INFO'): sys.argv.remove('--debug') -# Force dask to use the synchronous scheduler: we'll use ipyparallel -# manually for parallel processing. -try: - import dask.async - from dask import set_options - set_options(get=dask.async.get_sync) -except ImportError: # pragma: no cover - logger.debug("dask is not available.") - - def test(): # pragma: no cover """Run the full testing suite of phy.""" import pytest diff --git a/phy/io/array.py b/phy/io/array.py index 6661f1962..c79ba1561 100644 --- a/phy/io/array.py +++ b/phy/io/array.py @@ -250,18 +250,6 @@ def write_array(path, arr): """Write an array to a .npy file.""" file_ext = op.splitext(path)[1] if file_ext == '.npy': - try: - # Save a dask array into a .npy file chunk-by-chunk. - from dask.array import Array, store - if isinstance(arr, Array): - f = np.memmap(path, mode='w+', - dtype=arr.dtype, shape=arr.shape) - store(arr, f) - del f - except ImportError: # pragma: no cover - # We'll save the dask array normally: it works but it is less - # efficient since we need to load everything in memory. - pass return np.save(path, arr) raise NotImplementedError("The file extension `{}` ".format(file_ext) + "is not currently supported.") diff --git a/phy/io/tests/test_array.py b/phy/io/tests/test_array.py index 2b69a43e1..66c767996 100644 --- a/phy/io/tests/test_array.py +++ b/phy/io/tests/test_array.py @@ -216,19 +216,6 @@ def test_read_write(tempdir): ae(read_array(path, mmap_mode='r'), arr) -def test_read_write_dask(tempdir): - from dask.array import from_array - arr = np.arange(10).astype(np.float32) - - arr_da = from_array(arr, ((5, 5),)) - - path = op.join(tempdir, 'test.npy') - - write_array(path, arr_da) - ae(read_array(path), arr) - ae(read_array(path, mmap_mode='r'), arr) - - #------------------------------------------------------------------------------ # Test virtual concatenation #------------------------------------------------------------------------------ From 8803542263fd66d64287df3da00714587122b0dd Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 22 Mar 2016 17:20:00 +0100 Subject: [PATCH 1058/1059] Update README --- README.md | 38 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 36 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 8ec57faa0..04ddc550d 100644 --- a/README.md +++ b/README.md @@ -8,6 +8,40 @@ [![GitHub release](https://img.shields.io/github/release/kwikteam/phy.svg)](https://github.com/kwikteam/phy/releases/latest) [![Join the chat at https://gitter.im/kwikteam/phy](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/kwikteam/phy?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) -[**phy**](https://github.com/kwikteam/phy) is an open source electrophysiological data analysis package in Python for neuronal recordings made with high-density multielectrode arrays containing up to thousands of channels. +[**phy**](https://github.com/kwikteam/phy) is an open source neurophysiological data analysis package in Python. It provides features for sorting, analyzing, and visualizing extracellular recordings made with high-density multielectrode arrays containing hundreds to thousands of recording sites. -* [Documentation](http://phy.cortexlab.net) + +## Overview + +**phy** contains the following subpackages: + +* **phy.cluster.manual**: an API for manual sorting, used to create graphical interfaces for neurophysiological data +* **phy.gui**: a generic API for creating desktop applications with PyQt. +* **phy.plot**: a generic API for creating high-performance plots with VisPy (using the graphics processor via OpenGL) + +phy doesn't provide any I/O code. It only provides Python routines to process and visualize data. + + +## phy-contrib + +The [phy-contrib](https://github.com/kwikteam/phy-contrib) repo contains a set of plugins with integrated GUIs that work with dedicated automatic clustering software. Currently it provides: + +* **KwikGUI**: a manual sorting GUI that works with data processed with **klusta**, an automatic clustering package. +* **TemplateGUI**: a manual sorting GUI that works with data processed with **Spyking Circus** and **KiloSort** (not released yet), which are template-matching-based spike sorting algorithms. + + +## Getting started + +You will find installation instructions and a quick start guide in the [documentation](http://phy.readthedocs.org/en/latest/) (work in progress). + + +## Links + +* [Documentation](http://phy.readthedocs.org/en/latest/) (work in progress) +* [Mailing list](https://groups.google.com/forum/#!forum/phy-users) +* [Sample data repository](http://phy.cortexlab.net/data/) (work in progress) + + +## Credits + +**phy** is developed by [Cyrille Rossant](http://cyrille.rossant.net), [Shabnam Kadir](https://iris.ucl.ac.uk/iris/browse/profile?upi=SKADI56), [Dan Goodman](http://thesamovar.net/), [Max Hunter](https://iris.ucl.ac.uk/iris/browse/profile?upi=MLDHU99), and [Kenneth Harris](https://iris.ucl.ac.uk/iris/browse/profile?upi=KDHAR02), in the [Cortexlab](https://www.ucl.ac.uk/cortexlab), University College London. From 424a6960bf549091f88935864b09b1e8b3d498fb Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Wed, 23 Mar 2016 12:59:10 +0100 Subject: [PATCH 1059/1059] Remove appveyor for the time being --- appveyor.yml | 23 ----------------------- 1 file changed, 23 deletions(-) delete mode 100644 appveyor.yml diff --git a/appveyor.yml b/appveyor.yml deleted file mode 100644 index 6521ec916..000000000 --- a/appveyor.yml +++ /dev/null @@ -1,23 +0,0 @@ -# CI on Windows via appveyor -# This file was based on Olivier Grisel's python-appveyor-demo - -environment: - matrix: - - PYTHON: "C:\\Python34-conda64" - PYTHON_VERSION: "3.4" - PYTHON_ARCH: "64" -install: - - ps: Start-FileDownload 'http://repo.continuum.io/miniconda/Miniconda-latest-Windows-x86_64.exe' - - ps: .\Miniconda-latest-Windows-x86_64.exe /RegisterPython=1 /S /D="$Home\miniconda3" | Out-Null - - ps: $env:Path += ";$Home\miniconda3\;$Home\miniconda3\Scripts" - - ps: conda config --set ssl_verify false - - ps: conda env create python=3.4 - - ps: activate phy - - ps: conda install -c kwikteam klustakwik2 -y - - ps: conda config --set ssl_verify true - - ps: pip install -r requirements-dev.txt - - ps: pip install -e . -build: false # Not a C# project, build stuff at the test step instead. -test_script: - # Run the project tests - - py.test phy