diff --git a/.coveragerc b/.coveragerc index 22e6810ed..e7222f912 100644 --- a/.coveragerc +++ b/.coveragerc @@ -4,3 +4,12 @@ source = phy omit = */phy/ext/* */phy/utils/tempdir.py + */default_settings.py + +[report] +exclude_lines = + pragma: no cover + raise AssertionError + raise NotImplementedError + pass + return diff --git a/.gitignore b/.gitignore index b7aba4db4..2e500629a 100644 --- a/.gitignore +++ b/.gitignore @@ -1,18 +1,24 @@ +contrib data doc +phy-doc docker experimental htmlcov format wiki +.cache .idea .ipynb_checkpoints .*fuse* *.orig +*.log* .eggs +.profile __pycache__ - +_old *.py[cod] +*~ .coverage* *credentials diff --git a/.travis.yml b/.travis.yml index a473fb4b8..633218bd2 100644 --- a/.travis.yml +++ b/.travis.yml @@ -2,7 +2,6 @@ language: python sudo: false python: - "2.7" - - "3.4" - "3.5" before_install: - pip install codecov @@ -19,12 +18,8 @@ install: - conda update -q conda - conda info -a # Create the environment. - - conda create -q -n testenv python=$TRAVIS_PYTHON_VERSION - - source activate testenv - - conda install pip numpy vispy matplotlib scipy h5py pyqt pyzmq ipython requests six - # NOTE: cython is only temporarily needed for building KK2. - # Once we have KK2 binary builds on binstar we can remove this dependency. - - conda install cython && pip install klustakwik2 + - conda env create python=$TRAVIS_PYTHON_VERSION + - source activate phy # Dev requirements - pip install -r requirements-dev.txt - pip install -e . diff --git a/CONTRIBUTE.md b/CONTRIBUTE.md index d717d3c79..c4446d498 100644 --- a/CONTRIBUTE.md +++ b/CONTRIBUTE.md @@ -7,37 +7,35 @@ Please read this entire document before contributing any code. On your development computer: * Use the very latest Anaconda release (`conda update conda anaconda`). -* Use a special `phy` conda environment based on the latest Python 3.x (3.4 at the time of writing). +* Use a special `phy` conda environment based on Python 3.5. * Have another `phy2` clone environment based on Python 2.7. -* phy only supports Python 2.7, and Python 3.4+. -* Use `six` for writing compatible code (see [the documentation here](http://pythonhosted.org/six/)) -* You need the following dependencies for development (not required for using phy): pytest, pip, flake8, coverage, coveralls. -* For IPython, use the IPython git `master` branch (or version 3.0 when it will be released in early 2015). +* Install the dev dependencies: `pip install -r requirements-dev.txt` +* phy only supports Python 2.7 and Python 3.4+. +* Use the `six` library for writing compatible code (see [the documentation here](http://pythonhosted.org/six/)) A few rules: * Every module `phy/mypackage/mymodule.py` must come with a `phy/mypackage/tests/test_mymodule.py` test module that contains a bunch of `test_*()` functions. * Never import test modules in the main code. -* Do not import packages from `phy/__init__.py`. Every subpackage `phy.stuff` will need to be imported explicitly by the user. Dependencies required by this subpackage will only be loaded when the subpackage is loaded. This ensures that users can use `phy.subpackageA` without having to install the dependencies required by `phy.subpackageB`. -* phy's required dependencies are: numpy. Every subpackage can come with further dependencies. For example, `phy.io.kwik` depends on h5py. +* In general, do not import packages from `phy/__init__.py`. Every subpackage `phy.stuff` will need to be imported explicitly by the user. 
Dependencies required by this subpackage will only be loaded when the subpackage is loaded. This ensures that users can use `phy.subpackageA` without having to install the dependencies required by `phy.subpackageB`. +* phy's required dependencies are: pip, traitlets, click, numpy. Every subpackage may come with further dependencies. * You can experiment with ideas and prototypes in the `kwikteam/experimental` repo. Use a different folder for every experiment. -* `kwikteam/phy` will only contain a `master` branch and release tags. There should be no experimental/debugging code in the entire repository. +* `kwikteam/phy` will only contain a `master` branch, release tags, and possibly one refactoring branch. There should be no experimental/debugging code in the entire repository. ### GitHub flow -* Work through PRs from `yourfork/specialbranch` against `phy/master` exclusively. +* Work through PRs from `yourfork/specialbranch` against `phy/master` or `phy/kill-kwik` exclusively. * Set `upstream` to `kwikteam/phy` and `origin` to your fork. * When master and your PR's branch are out of sync, [rebase your branch in your fork](https://groups.google.com/forum/#!msg/vispy-dev/q-UNjxburGA/wYNkZRXiySwJ). * Two-pairs-of-eyes rule: every line of code needs to be reviewed by 2 people, including the author. * Never merge your own PR to the main phy repository, unless in exceptional circumstances. * A PR is assumed to be **not ready for merge** unless explicitly stated otherwise. * Always run `make test` before stating that a PR is ready to merge (and ideally before pushing on your PR's branch). -* We try to have a code coverage close to 100%: always test all features you implement, and verify through code coverage that all lines are covered by your tests. +* We try to have a code coverage close to 100%: always test all features you implement, and verify through code coverage that all lines are covered by your tests. Use `#pragma: no cover` comments for lines that don't absolutely need to be covered (for example, rare exception-raising code). * Always wait for Travis to be green before merging. * `phy/master` should always be stable and deployable. * Use [semantic versioning](http://www.semver.org) for stable releases. -* Do not make too many releases until the software is mature enough. Early adopters can work directly off `master`. * We follow almost all [PEP8 rules](https://www.python.org/dev/peps/pep-0008/), except [for a few exceptions](https://github.com/kwikteam/phy/blob/master/Makefile#L24). @@ -55,10 +53,10 @@ A few rules: Make sure your text editor is configured to: -* automatically clean blank lines (i.e. no lines containing only whitespace) -* use four spaces per indent level, and **never** tab indents -* enforce an empty blank line at the end of every text file -* display a vertical ruler at 79 characters (length limit of every line) +* Automatically clean blank lines (i.e. 
no lines containing only whitespace) +* Use four spaces per indent level, and **never** tab indents +* Enforce an empty blank line at the end of every text file +* Display a vertical ruler at 79 characters (length limit of every line) Below is a settings script for the popular text editor [Sublime](http://www.sublimetext.com) which you can put into your ```Preferences -> Settings (User)```: diff --git a/MANIFEST.in b/MANIFEST.in index 2531a5bd8..35bff6167 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -4,7 +4,8 @@ include README.md recursive-include tests * recursive-include phy/electrode/probes *.prb recursive-include phy/plot/glsl *.vert *.frag *.glsl +recursive-include phy/plot/static *.npy *.gz *.txt recursive-include phy/cluster/manual/static *.html *.css -recursive-include phy/gui/static *.html *.css +recursive-include phy/gui/static *.html *.css *.js recursive-exclude * __pycache__ recursive-exclude * *.py[co] diff --git a/Makefile b/Makefile index fb4d079dd..1fa8cea14 100644 --- a/Makefile +++ b/Makefile @@ -1,14 +1,3 @@ -help: - @echo "clean - remove all build, test, coverage and Python artifacts" - @echo "clean-build - remove build artifacts" - @echo "clean-pyc - remove Python file artifacts" - @echo "lint - check style with flake8" - @echo "test - run tests quickly with the default Python" - @echo "release - package and upload a release" - @echo "apidoc - build API doc" - -clean: clean-build clean-pyc - clean-build: rm -fr build/ rm -fr dist/ @@ -20,21 +9,22 @@ clean-pyc: find . -name '*~' -exec rm -f {} + find . -name '__pycache__' -exec rm -fr {} + +clean: clean-build clean-pyc + lint: flake8 phy test: lint - py.test + py.test --cov-report term-missing --cov=phy phy coverage: coverage --html -test-quick: lint - python setup.py test -a "-m \"not long\" phy" - -unit-tests: lint - python setup.py test -a phy +apidoc: + python tools/api.py -integration-tests: lint - python setup.py test -a tests +build: + python setup.py sdist --formats=zip +upload: + python setup.py sdist --formats=zip upload diff --git a/README.md b/README.md index 85482706f..04ddc550d 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,47 @@ # phy project -[![Build Status](https://travis-ci.org/kwikteam/phy.svg?branch=master)](https://travis-ci.org/kwikteam/phy) -[![Build Status](https://ci.appveyor.com/api/projects/status/fuoyuo113domjplr/branch/master?svg=true)](https://ci.appveyor.com/project/kwikteam/phy/) -[![codecov.io](https://img.shields.io/codecov/c/github/kwikteam/phy.svg?)](http://codecov.io/github/kwikteam/phy?branch=master) +[![Build Status](https://img.shields.io/travis/kwikteam/phy.svg)](https://travis-ci.org/kwikteam/phy) +[![Build Status](https://img.shields.io/appveyor/ci/kwikteam/phy.svg)](https://ci.appveyor.com/project/kwikteam/phy/) +[![codecov.io](https://img.shields.io/codecov/c/github/kwikteam/phy.svg)](http://codecov.io/github/kwikteam/phy?branch=master) [![Documentation Status](https://readthedocs.org/projects/phy/badge/?version=latest)](https://readthedocs.org/projects/phy/?badge=latest) [![PyPI release](https://img.shields.io/pypi/v/phy.svg)](https://pypi.python.org/pypi/phy) +[![GitHub release](https://img.shields.io/github/release/kwikteam/phy.svg)](https://github.com/kwikteam/phy/releases/latest) [![Join the chat at https://gitter.im/kwikteam/phy](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/kwikteam/phy?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) -[**phy**](https://github.com/kwikteam/phy) is an open source electrophysiological data 
analysis package in Python for neuronal recordings made with high-density multielectrode arrays containing up to thousands of channels. +[**phy**](https://github.com/kwikteam/phy) is an open source neurophysiological data analysis package in Python. It provides features for sorting, analyzing, and visualizing extracellular recordings made with high-density multielectrode arrays containing hundreds to thousands of recording sites. -* [Documentation](http://phy.cortexlab.net) + +## Overview + +**phy** contains the following subpackages: + +* **phy.cluster.manual**: an API for manual sorting, used to create graphical interfaces for neurophysiological data +* **phy.gui**: a generic API for creating desktop applications with PyQt. +* **phy.plot**: a generic API for creating high-performance plots with VisPy (using the graphics processor via OpenGL) + +phy doesn't provide any I/O code. It only provides Python routines to process and visualize data. + + +## phy-contrib + +The [phy-contrib](https://github.com/kwikteam/phy-contrib) repo contains a set of plugins with integrated GUIs that work with dedicated automatic clustering software. Currently it provides: + +* **KwikGUI**: a manual sorting GUI that works with data processed with **klusta**, an automatic clustering package. +* **TemplateGUI**: a manual sorting GUI that works with data processed with **Spyking Circus** and **KiloSort** (not released yet), which are template-matching-based spike sorting algorithms. + + +## Getting started + +You will find installation instructions and a quick start guide in the [documentation](http://phy.readthedocs.org/en/latest/) (work in progress). + + +## Links + +* [Documentation](http://phy.readthedocs.org/en/latest/) (work in progress) +* [Mailing list](https://groups.google.com/forum/#!forum/phy-users) +* [Sample data repository](http://phy.cortexlab.net/data/) (work in progress) + + +## Credits + +**phy** is developed by [Cyrille Rossant](http://cyrille.rossant.net), [Shabnam Kadir](https://iris.ucl.ac.uk/iris/browse/profile?upi=SKADI56), [Dan Goodman](http://thesamovar.net/), [Max Hunter](https://iris.ucl.ac.uk/iris/browse/profile?upi=MLDHU99), and [Kenneth Harris](https://iris.ucl.ac.uk/iris/browse/profile?upi=KDHAR02), in the [Cortexlab](https://www.ucl.ac.uk/cortexlab), University College London. diff --git a/appveyor.yml b/appveyor.yml deleted file mode 100644 index a999f35e9..000000000 --- a/appveyor.yml +++ /dev/null @@ -1,24 +0,0 @@ -# CI on Windows via appveyor -# This file was based on Olivier Grisel's python-appveyor-demo - -environment: - - matrix: - - PYTHON: "C:\\Python34-conda64" - PYTHON_VERSION: "3.4" - PYTHON_ARCH: "64" - -install: - # Install app from latest script - - "powershell iex ((new-object net.webclient).DownloadString('http://phy.cortexlab.net/install/latest.ps1'))" - - "SET PATH=%HOMEPATH%\\miniconda3;%HOMEPATH%\\miniconda3\\Scripts;%PATH%" - - # We want to use the git version of phy, so let's uninstall the pip version - - "pip uninstall phy --yes" - - "pip install -r requirements-dev.txt" - - "python setup.py develop" -build: false # Not a C# project, build stuff at the test step instead. 
- -test_script: - # Run the project tests - - py.test phy -m "not long" diff --git a/conftest.py b/conftest.py deleted file mode 100644 index dc101ae01..000000000 --- a/conftest.py +++ /dev/null @@ -1,85 +0,0 @@ -# -*- coding: utf-8 -*- - -"""py.test utilities.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import os - -import numpy as np -from pytest import yield_fixture - -from phy.electrode.mea import load_probe -from phy.io.mock import artificial_traces -from phy.utils._types import Bunch -from phy.utils.tempdir import TemporaryDirectory -from phy.utils.settings import _load_default_settings -from phy.utils.datasets import download_test_data - - -#------------------------------------------------------------------------------ -# Common fixtures -#------------------------------------------------------------------------------ - -@yield_fixture -def tempdir(): - with TemporaryDirectory() as tempdir: - yield tempdir - - -@yield_fixture -def chdir_tempdir(): - curdir = os.getcwd() - with TemporaryDirectory() as tempdir: - os.chdir(tempdir) - yield tempdir - os.chdir(curdir) - - -@yield_fixture -def tempdir_bis(): - with TemporaryDirectory() as tempdir: - yield tempdir - - -@yield_fixture(params=['null', 'artificial', 'real']) -def raw_dataset(request): - sample_rate = 20000 - params = _load_default_settings()['spikedetekt'] - data_type = request.param - if data_type == 'real': - path = download_test_data('test-32ch-10s.dat') - traces = np.fromfile(path, dtype=np.int16).reshape((200000, 32)) - traces = traces[:45000] - n_samples, n_channels = traces.shape - params['use_single_threshold'] = False - probe = load_probe('1x32_buzsaki') - else: - probe = {'channel_groups': { - 0: {'channels': [0, 1, 2], - 'graph': [[0, 1], [0, 2], [1, 2]], - }, - 1: {'channels': [3], - 'graph': [], - 'geometry': {3: [0., 0.]}, - } - }} - if data_type == 'null': - n_samples, n_channels = 25000, 4 - traces = np.zeros((n_samples, n_channels)) - elif data_type == 'artificial': - n_samples, n_channels = 25000, 4 - traces = artificial_traces(n_samples, n_channels) - traces[5000:5010, 1] *= 5 - traces[15000:15010, 3] *= 5 - n_samples_w = params['extract_s_before'] + params['extract_s_after'] - yield Bunch(n_channels=n_channels, - n_samples=n_samples, - sample_rate=sample_rate, - n_samples_waveforms=n_samples_w, - traces=traces, - params=params, - probe=probe, - ) diff --git a/phy/cluster/algorithms/tests/__init__.py b/docs/analysis.md similarity index 100% rename from phy/cluster/algorithms/tests/__init__.py rename to docs/analysis.md diff --git a/docs/api.md b/docs/api.md new file mode 100644 index 000000000..0de7ede1d --- /dev/null +++ b/docs/api.md @@ -0,0 +1,3 @@ +# API reference + +TODO. In the meantime, see the code directly on GitHub. diff --git a/docs/cli.md b/docs/cli.md new file mode 100644 index 000000000..2b42a372e --- /dev/null +++ b/docs/cli.md @@ -0,0 +1,63 @@ +# CLI + +When you install phy, a command-line tool named `phy` is installed: + +```bash +$ phy +Usage: phy [OPTIONS] COMMAND [ARGS]... + + By default, the `phy` command does nothing. Add subcommands with plugins + using `attach_to_cli()` and the `click` library. + +Options: + --version Show the version and exit. + -h, --help Show this message and exit. +``` + +This command doesn't do anything by default, but it serves as an entry-point for your applications. 
+ +## Adding a subcommand + +A subcommand is called with `phy subcommand ...` from the command-line. To create a subcommand, create a new plugin, and implement the `attach_to_cli(cli)` method. This uses the [click](http://click.pocoo.org/5/) library. + +Here is an example. Create a file in `~/.phy/plugins/hello.py` and write the following: + +``` +from phy import IPlugin +import click + + +class MyPlugin(IPlugin): + def attach_to_cli(self, cli): + @cli.command('hello') + @click.argument('name') + def hello(name): + print("Hello %s!" % name) +``` + +Then, type the following in a system shell: + +```bash +$ phy +Usage: phy [OPTIONS] COMMAND [ARGS]... + + By default, the `phy` command does nothing. Add subcommands with plugins + using `attach_to_cli()` and the `click` library. + +Options: + --version Show the version and exit. + -h, --help Show this message and exit. + +Commands: + hello + +$ phy hello +Usage: phy hello [OPTIONS] NAME + +Error: Missing argument "name". + +$ phy hello world +Hello world! +``` + +When the `phy` CLI is created, the `attach_to_cli(cli)` method of all discovered plugins are called. Refer to the click documentation to create subcommands with phy. diff --git a/docs/cluster-manual.md b/docs/cluster-manual.md new file mode 100644 index 000000000..e7cdf46d7 --- /dev/null +++ b/docs/cluster-manual.md @@ -0,0 +1,287 @@ +# Manual clustering + +The `phy.cluster.manual` package provides manual clustering routines. The components can be used independently in a modular way. + +## Clustering + +The `Clustering` class implements the logic of assigning clusters to spikes as a succession of undoable merge and split operations. Also, it provides efficient methods to retrieve the set of spikes belonging to one or several clusters. + +Create an instance with `clustering = Clustering(spike_clusters)` where `spike_clusters` is an `n_spikes`-long array containing the cluster number of every spike. + +Notable properties are: + +* `clustering.spikes_per_cluster`: a dictionary `{cluster_id: spike_ids}`. +* `clustering.cluster_ids`: array of all non-empty clusters +* `clustering.spike_counts`: dictionary with the number of spikes in each cluster + +Notable methods are: + +* `clustering.new_cluster_id()`: generate a new unique cluster id +* `clustering.spikes_in_clusters(cluster_ids)`: return the array of spike ids belonging to a set of clusters. +* `clustering.merge(cluster_ids)`: merge some clusters. +* `clustering.split(spike_ids)`: create a new cluster from a set of spikes. **Note**: this will change the cluster ids of all affected clusters. For example, splitting a single spike belonging to cluster 10 containing 100 spikes leads to the deletion of cluster 10, and the creation of clusters N (99 spikes) and N+1 (1 spike). This is to ensure that cluster ids are always unique. +* `clustering.undo()`: undo the last operation. +* `clustering.redo()`: redo the last operation. + +### UpdateInfo + +Every clustering action returns an `UpdateInfo` object and emits a `cluster` event. + +An `UpdateInfo` object is a `Bunch` instance (dictionary with dotted attribute access) with several keys, including: + +* `description`: can be `merge` or `assign` +* `history`: can be `None` (default), `'undo'`, or `'redo'` +* `added`: list of new clusters +* `deleted`: list of removed clusters + +A `Clustering` object emits the `cluster` event after every clustering action (including undo and redo). To register a callback function to this event, use the `connect()` method. 
+
+Here is a complete example:
+
+```python
+>>> import numpy as np
+>>> from phy.cluster.manual import Clustering
+```
+
+```python
+>>> clustering = Clustering(np.arange(5))
+```
+
+```python
+>>> @clustering.connect
+... def on_cluster(up):
+...     print("A %s just occurred." % up.description)
+```
+
+```python
+>>> up = clustering.merge([0, 1, 2])
+A merge just occurred.
+```
+
+```python
+>>> clustering.cluster_ids
+array([3, 4, 5])
+```
+
+```python
+>>> clustering.spike_counts
+{3: 1, 4: 1, 5: 3}
+```
+
+```python
+>>> for key, val in up.items():
+...     print(key, "=", val)
+undo_state = None
+description = merge
+metadata_value = None
+spike_ids = [0 1 2]
+old_spikes_per_cluster = {0: array([0]), 1: array([1]), 2: array([2])}
+descendants = [(0, 5), (1, 5), (2, 5)]
+added = [5]
+new_spikes_per_cluster = {5: array([0, 1, 2])}
+history = None
+metadata_changed = []
+deleted = [0, 1, 2]
+```
+
+## Cluster metadata
+
+The `ClusterMeta` class implements the logic of assigning metadata to every cluster (for example, a cluster group) as a succession of undoable operations.
+
+Here is an example.
+
+```python
+>>> from phy.cluster.manual import ClusterMeta
+>>> cm = ClusterMeta()
+```
+
+```python
+>>> cm.add_field('group', default_value='unsorted')
+```
+
+```python
+>>> cm.get('group', 3)
+'unsorted'
+```
+
+```python
+>>> cm.set('group', 3, 'good')
+<metadata_group [3] => good>
+```
+
+```python
+>>> cm.set('group', 3, 'bad')
+<metadata_group [3] => bad>
+```
+
+```python
+>>> cm.get('group', 3)
+'bad'
+```
+
+```python
+>>> cm.undo()
+<metadata_group [3] => bad>
+```
+
+```python
+>>> cm.get('group', 3)
+'good'
+```
+
+You can import data from, and export data to, a dictionary using the `from_dict()` and `to_dict()` methods.
+
+```python
+>>> cm.to_dict('group')
+{3: 'good'}
+```
+
+## Views
+
+There are several views typically associated with manual clustering operations.
+
+### Waveform view
+
+The waveform view displays action potentials across all channels, following the probe geometry.
+
+### Feature view
+
+The feature view shows the principal components of spikes across multiple dimensions.
+
+### Trace view
+
+The trace view shows the continuous traces from multiple channels with spikes superimposed. The spikes are in white except those belonging to the selected clusters, which are in the colors of the clusters.
+
+### Correlogram view
+
+The correlogram view computes and shows all pairwise correlograms of a set of clusters.
+
+## Manual clustering GUI component
+
+The `ManualClustering` component encapsulates all the logic for a manual clustering GUI:
+
+* cluster views
+* selection of clusters
+* navigation with a wizard
+* clustering actions: merge, split, undo stack
+* moving clusters to groups
+
+Create an object with `mc = ManualClustering(spike_clusters)`. Then attach it to a GUI with `mc.attach(gui)` to bring manual clustering facilities to that GUI. This adds the manual clustering actions and the two tables to the GUI: the cluster view and the similarity view.
+
+The main objects are the following:
+
+* `mc.clustering`: a `Clustering` instance
+* `mc.cluster_meta`: a `ClusterMeta` instance
+* `mc.cluster_view`: the cluster view (derives from `Table`)
+* `mc.similarity_view`: the similarity view (derives from `Table`)
+* `mc.actions`: the clustering actions (instance of `Actions`)
+
+Use `gui.request('manual_clustering')` to get the `ManualClustering` instance inside the `attach_to_gui(gui, model=None, state=None)` method of a GUI plugin.
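+
+For example, a minimal GUI plugin using this pattern might look like the following sketch (the plugin name, the action name, and the `shift+m` shortcut are made up for illustration; `mc.selected` and the clustering actions are described below):
+
+```python
+from phy import IPlugin
+
+
+class MyClusteringPlugin(IPlugin):
+    def attach_to_gui(self, gui, model=None, state=None):
+        # Retrieve the ManualClustering component registered by the GUI.
+        # `gui.request()` returns None if the component is not available.
+        mc = gui.request('manual_clustering')
+        if mc is None:
+            return
+
+        # Add a custom action that merges the currently-selected clusters.
+        @mc.actions.add(shortcut='shift+m')
+        def merge_selected():
+            mc.clustering.merge(mc.selected)
+```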
+ +### Cluster and similarity view + +The cluster view shows the list of all clusters with their ids, while the similarity view shows the list of all clusters sorted by decreasing similarity wrt the currently-selected clusters in the cluster view. + +You can add a new column in both views as follows: + +```python +>>> @mc.add_column +... def n_spikes(cluster_id): +... return mc.clustering.spike_counts[cluster_id] +``` + +The similarity view has an additional column compared to the cluster view: `similarity` with respect to the currently-selected clusters in the cluster view. + +See also the following methods: + +* `mc.set_default_sort(name)`: set a column as default sort in the quality cluster view +* `mc.set_similarity_func(func)`: set a similarity function for the similarity view + +### Cluster selection + +The `ManualClustering` instance is responsible for the selection of the clusters. + +* `mc.select(cluster_ids)`: select some clusters +* `mc.selected`: list of currently-selected clusters + +When the selection changes, the attached GUI raises the `select(cluster_ids, spike_ids)` event. + +Other events are `cluster(up)` when a clustering action occurs, and `request_save(spike_clusters, cluster_groups)` when the user wants to save the results of the manual clustering session. + +## Cluster store + +The **cluster store** contains a library of functions computing data and statistics for every cluster. These functions are cached on disk and possibly in memory. A `ClusterStore` instance is initialized with a `Context` which provides the caching facilities. You can add new functions with the `add(f)` method/decorator. + +The `create_cluster_store()` function creates a cluster store with a built-in library of functions that take the data from a model (for example, the `KwikModel` that works with the Kwik format). + +Use `gui.request('cluster_store')` to get the cluster store instance inside the `attach_to_gui(gui, model=None, state=None)` method of a GUI plugin. + +## GUI plugins + +You can create plugins to customize the manual clustering GUI. Here is a complete example showing how to change the quality and similarity measures. Put the following in `~/.phy/phy_config.py`: + +```python + +import numpy as np +from phy import IPlugin +from phy.cluster.manual import get_closest_clusters + + +# We write the plugin directly in the config file here, for simplicity. +# When dealing with more plugins it is a better practice to put them +# in separate files in ~/.phy/plugins/ or in your own repo that you can +# refer to in c.Plugins.dirs = ['/path/to/myplugindir']. +class MyPlugin(IPlugin): + def attach_to_gui(self, gui, model=None, state=None): + + # We can get GUI components with `gui.request(name)`. + # These are the two main components. There is also `context` to + # deal with the cache and parallel computing context. + mc = gui.request('manual_clustering') + cs = gui.request('cluster_store') + + # We add a column in the cluster view and set it as the default. + @mc.add_column(default=True) + @cs.add(cache='memory') + def mymeasure(cluster_id): + # This function takes a cluster id as input and returns a scalar. + + # We retrieve the spike_ids and waveforms for that cluster. + # spike_ids is a (n_spikes,) array. + # waveforms is a (n_spikes, n_samples, n_channels) array. + spike_ids, waveforms = cs.waveforms(cluster_id) + return waveforms.max() + + def mysim(cluster_0, cluster_1): + # This function returns a score for every pair of clusters. + + # Here we compute a distance between the mean masks. 
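+            # (Masks give, for every channel, a value between 0 and 1
+            # indicating how strongly a cluster's spikes appear on it.)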
+ m0 = cs.mean_masks(cluster_0) # (n_channels,) array + m1 = cs.mean_masks(cluster_1) # (n_channels,) array + distance = np.sum((m1 - m0) ** 2) + + # We need to convert the distance to a score: higher = better + # similarity. + score = -distance + return score + + # We set the similarity function. + @mc.set_similarity_func + @cs.add(cache='memory') + def myclosest(cluster_id): + """This function returns the list of closest clusters as + a list of `(cluster, sim)` pairs. + + By default, the 20 closest clusters are kept. + + """ + return get_closest_clusters(cluster_id, model.cluster_ids, mysim) + + +# Now we set the config object. +c = get_config() + +# Here we say that we always want to load our plugin in the KwikGUI. +c.KwikGUI.plugins = ['MyPlugin'] + +``` diff --git a/docs/config.md b/docs/config.md new file mode 100644 index 000000000..09ed3b89e --- /dev/null +++ b/docs/config.md @@ -0,0 +1,31 @@ +# Configuration and plugin system + +phy uses part of the **traitlets** package for its config system. **This is still a work in progress**. + +## Configuration file + +It is in `~/.phy/phy_config.py`. It is a regular Python file. It should begin with `c = get_config()` with no import (this function is injected in the namespace automatically by the config system). + +Then, you can set configuration options as follows: + +`c.SomeClass.some_param = some_value` + +## Plugin system + +A plugin is a Python class deriving from `phy.IPlugin`. To ensure that phy knows about your plugin, just make sure that your class is imported in the Python namespace. + +Here are two common methods: + +* Implement your plugin in a Python file and put this file in `~/.phy/plugins/`: it will be automatically discovered by phy. +* Edit `c.Plugins.dirs = ['/path/to/folder']` in your `phy_config.py` file: all Python scripts there will be automatically imported. + +Here is a minimal plugin template: + +```python +from phy import IPlugin + +class MyPlugin(IPlugin): + pass +``` + +We'll see in the next sections what methods you can implement in your plugins, and how to use them in your applications. diff --git a/docs/faq.md b/docs/faq.md new file mode 100644 index 000000000..8b72fdadc --- /dev/null +++ b/docs/faq.md @@ -0,0 +1,46 @@ +# Frequently Asked Questions + +Read this section to understand what phy is and is not. + +## What file formats does phy support? + +None, but [phycontrib](https://github.com/kwikteam/phy-contrib) contains a set of plugins supporting file formats like the Kwik format. + +If a file format becomes a widely-used standard in the future, we might support it directly in phy. + +## Are there ready-to-use scripts and GUIs for spike sorting in phy? + +No, but there are some in [phycontrib](https://github.com/kwikteam/phy-contrib). + +## Can you add feature X? + +No. + +In principle, you should be able to implement it by writing a plugin, and if not, we'd be happy to help. + +If the majority of our users are desperately asking one particular feature, we might consider implementing it. + +Here is a nice explanation of why *No* should be the default answer here (from the *Getting Real* book by 37signals): + +> **Each time you say yes to a feature, you're adopting a child**. You have to take your baby through a whole chain of events (e.g. design, implementation, testing, etc.). And once that feature's out there, you're stuck with it. Just try to take a released feature away from customers and see how pissed off they get. + +> Make each feature work hard to be implemented. 
Make each feature prove itself and show that it's a survivor. **It's like "Fight Club." You should only consider features if they're willing to stand on the porch for three days waiting to be let in.** + +> That's why you start with no. Every new feature request that comes to us – or from us – meets a no. We listen but don't act. **The initial response is "not now."** If a request for a feature keeps coming back, that's when we know it's time to take a deeper look. Then, and only then, do we start considering the feature for real. + +> And what do you say to people who complain when you won't adopt their feature idea? Remind them why they like the app in the first place. "You like it because we say no. You like it because it doesn't do 100 other things. You like it because it doesn't try to please everyone all the time." + +References: + +* [Why you should write buggy software with as few features as possible](https://astrocompute.wordpress.com/2013/07/11/why-you-should-write-buggy-software-with-as-few-features-as-possible-no-really/) +* [Getting Real](https://basecamp.com/books/Getting%20Real.pdf) + +## Where can I get some help? + +[On the mailing list](https://groups.google.com/forum/#!forum/phy-users). + +## Who is developing phy? + +Cyrille Rossant, Max Hunter, aided by Shabnam Kadir, Nick Steinmetz, from the Cortexlab (University College London), led by Kenneth Harris and Matteo Carandini. + +## How do I cite phy? diff --git a/docs/gui.md b/docs/gui.md new file mode 100644 index 000000000..1a28d5aa1 --- /dev/null +++ b/docs/gui.md @@ -0,0 +1,289 @@ +# GUI + +`phy.gui` provides generic Qt-based GUI components. You don't need to know Qt to use `phy.gui`, although it might help. + +## Creating a Qt application + +You need to create a Qt application before creating and using GUIs. There is a single Qt application object in a Python interpreter. + +In IPython (console or notebook), you can just use the following magic command before doing anything Qt-related: + +```python +>>> %gui qt +``` + +In other situations, like in regular Python scripts, you need to: + +* Call `phy.gui.create_app()` once, before you create a GUI. +* Call `phy.gui.run_app()` to launch your application. This blocks the Python interpreter and runs the Qt event loop. Generally, when this call returns, the application exits. + +For interactive use and explorative work, it is highly recommended to use IPython, for example with a Jupyter Notebook. + +## Creating a GUI + +phy provides a **GUI**, a main window with dockable widgets (`QMainWindow`). By default, a GUI is empty, but you can add views. A view is any Qt widget or a matplotlib or VisPy canvas. + +Let's create an empty GUI: + +```python +>>> from phy.gui import GUI +>>> gui = GUI(position=(400, 200), size=(600, 400)) +>>> gui.show() +``` + +## Adding a visualization + +We can add any Qt widget with `gui.add_view(widget)`, as well as visualizations with VisPy or matplotlib (which are fully compatible with Qt and phy). + +### With VisPy + +The `gui.add_view()` method accepts any VisPy canvas. For example, here we add an empty VisPy window: + +```python +>>> from vispy.app import Canvas +>>> from vispy import gloo +... +>>> c = Canvas() +... +>>> @c.connect +... def on_draw(e): +... gloo.clear('purple') +... +>>> gui.add_view(c) + +``` + +We can now dock and undock our widget from the GUI. This is particularly convenient when there are many widgets. 
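+
+Outside IPython (for example in a standalone script), a sketch of the same canvas wrapped between `create_app()` and `run_app()` could look like this (only functions already introduced above are used):
+
+```python
+from vispy.app import Canvas
+from vispy import gloo
+from phy.gui import GUI, create_app, run_app
+
+create_app()  # Create the Qt application (once, before any GUI).
+
+gui = GUI(position=(400, 200), size=(600, 400))
+c = Canvas()
+
+@c.connect
+def on_draw(e):
+    gloo.clear('purple')
+
+gui.add_view(c)
+gui.show()
+
+run_app()  # Run the Qt event loop; blocks until the application exits.
+```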
+ +### With matplotlib + +Here we add a matplotlib figure to our GUI: + +```python +>>> import numpy as np +>>> import matplotlib.pyplot as plt +... +>>> f = plt.figure() +>>> ax = f.add_subplot(111) +>>> t = np.linspace(-10., 10., 1000) +>>> ax.plot(t, np.sin(t)) +>>> gui.add_view(f) + +``` + +## Adding an HTML widget + +phy provides an `HTMLWidget` component which allows you to create widgets in HTML. This is just a `QWebView` with some user-friendly facilities. + +First, let's create a standalone HTML widget: + +```python +>>> from phy.gui import HTMLWidget +>>> widget = HTMLWidget() +>>> widget.set_body("Hello world!") +>>> widget.show() +``` + +Now that our widget is created, let's add it to the GUI: + +```python +>>> gui.add_view(widget) + +``` + +You'll find in the API reference other methods to edit the styles, scripts, header, and body of the HTML widget. + +### Table + +phy also provides a `Table` widget written in HTML and Javascript (using the [tablesort](https://github.com/tristen/tablesort) Javascript library). This widget shows a table of items, where every item (row) has an id, and every column is defined as a function `id => string`, the string being the contents of a row's cell in the table. The table can be sorted by every column. + +One or several items can be selected by the user. The `select` event is raised when rows are selected. Here is a complete example: + +```python +>>> from phy.gui import Table +>>> table = Table() +... +>>> # We add a column in the table. +... @table.add_column +... def name(id): +... # This function takes an id as input and returns a string. +... return "My id is %d" % id +... +>>> # Now we add some rows. +... table.set_rows([2, 3, 5, 7]) +... +>>> # We print something when items are selected. +... @table.connect_ +... def on_select(ids): +... # NOTE: we use `connect_` and not `connect`, because `connect` is +... # a Qt method associated to every Qt widget, and `Table` is a subclass +... # of `QWidget`. Using `connect_` ensures that we're using phy's event +... # system, not Qt's. +... print("The items %s have been selected." % ids) +... +>>> table.show() +The items [3] have been selected. +The items [3, 5] have been selected. +``` + +### Interactivity with Javascript + +We can use Javascript in an HTML widget, and we can make Python and Javascript communicate. + +```python +>>> from phy.gui import HTMLWidget +>>> widget = HTMLWidget() +>>> widget.set_body('
') +>>> # We can execute Javascript code from Python. +... widget.eval_js("document.getElementById('mydiv').innerHTML='hello'") +>>> widget.show() +>>> gui.add_view(widget) + +``` + +You can use `widget.eval_js()` to evaluate Javascript code from Python. Conversely, you can use `widget.some_method()` from Javascript, where `some_method()` is a method implemented in your widget (which should be a subclass of `HTMLWidget`). + +## Other GUI methods + +Let's display the list of views in the GUI: + +```python +>>> gui.list_views() +[, + , + , + ] +``` + +The following method allows you to check how many views of each class there are: + +```python +>>> gui.view_count() +``` + +Use the following property to change the status bar: + +```python +>>> gui.status_message = "Hello world" +``` + +Finally, the following methods allow you to save/restore the state of the GUI and the widgets: + +```python +>>> gs = gui.save_geometry_state() +``` + +```python +>>> gui.restore_geometry_state(gs) +``` + +The object `gs` is a JSON-serializable Python dictionary. + +## Adding actions + +An **action** is a Python function that the user can run from the menu bar or with a keyboard shortcut. You can create an `Actions` object to specify a list of actions attached to a GUI. + +```python +>>> from phy.gui import Actions +>>> actions = Actions(gui) +... +>>> @actions.add(shortcut='ctrl+h') +... def hello(): +... print("Hello world!") +``` + +Now, if you press *Ctrl+H* in the GUI, you'll see `Hello world!` printed in the console. + +Once an action is added, you can call it with `actions.hello()` where `hello` is the name of the action. By default, this is the name of the associated function, but you can also specify the name explicitly with the `name=...` keyword argument in `actions.add()`. + +You'll find more details about `actions.add()` in the API reference. For example, use the `menu='MenuName'` keyword argument to add the action to a menu in the menu bar. + +Every GUI comes with a `default_actions` property which implements actions always available in GUIs: + +```python +>>> gui.default_actions + +``` + +For example, the following action shows the shortcuts of all actions attached to the GUI: + +```python +>>> gui.default_actions.show_shortcuts() + +Keyboard shortcuts for GUI +enable_snippet_mode : : +exit : ctrl+q +hello : ctrl+h +show_shortcuts : f1, h +``` + +You can create multiple `Actions` instance for a single GUI, which allows you to separate between different sets of actions. + +## Snippets + +The GUI provides a convenient system to quickly execute actions without leaving one's keyboard. Inspired by console-based text editors like *vim*, it is enabled by pressing `:` on the keyboard. Once this mode is enabled, what you type is displayed in the status bar. Then, you can call a function by typing its name or its alias. You can also use arguments to the actions, using a special syntax. Here is an example. + +```python +>>> @actions.add(alias='c') +... def select(ids, obj): +... print("Select %s with %s" % (ids, obj)) +``` + +Now, pressing `:c 3-6 hello` followed by the `Enter` keystroke displays `Select [3, 4, 5, 6] with hello` in the console. + +By convention, multiple arguments are separated by spaces, sequences of numbers are given either with `2,3,5,7` or `3-6` for consecutive numbers. If an alias is not specified when adding the action, you can always use the full action's name. 
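+
+To tie these options together, here is a short sketch combining the `name`, `menu`, `alias`, and `shortcut` keyword arguments described above (the particular names and the shortcut are arbitrary examples):
+
+```python
+from phy.gui import Actions
+
+actions = Actions(gui)
+
+# All keyword arguments below are optional.
+@actions.add(name='select_some', menu='Select', alias='c', shortcut='ctrl+shift+s')
+def select(ids, obj):
+    print("Select %s with %s" % (ids, obj))
+
+# The action can now be run from the "Select" menu, with Ctrl+Shift+S,
+# programmatically with `actions.select_some([3, 4, 5, 6], 'hello')`,
+# or with the snippet `:c 3-6 hello`.
+```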
+ +## GUI plugins + +To create a specific GUI, you implement functionality in components and create GUI plugins to allow users to activate these components in their GUI. You can also specify the list of default plugins. + +You create a GUI with the function `gui = create_gui(name, model=model, plugins=plugins)`. The model provides the data while the list of plugins defines the functionality of the GUI. + +Plugins create views and actions and have full control on the GUI. Also, they can register objects or functions with `gui.register(obj, name='my_object')`. Other plugins can then retrieve these objects with `gui.request(name)`. This is how plugins can communicate. This function returns `None` if the object is not available. + +Note that the order of how plugins are attached matters, since you can only request components that have been registered in the previously-attached plugins. + +To create a GUI plugin, just define a class deriving from `IPlugin` and implementing `attach_to_gui(gui, model=None, state=None)`. + +## GUI state + +The **GUI state** is a special Python dictionary that holds info and parameters about a particular GUI session, like its position and size, the positions of the widgets, and other user preferences. This state is automatically persisted to disk (in JSON) in the config directory (passed as a parameter in the `create_gui()` function). By default, this is `~/.phy/gui_name/state.json`. + +The GUI state is a `Bunch` instance, which derives from `dict` to support the additional `bunch.name` syntax. + +Plugins can simply add fields to the GUI state and it will be persisted. There are special methods for GUI parameters: `state.save_gui_params()` and `state.load_gui_params()`. + + +## Example + +In this example we'll create a GUI plugin and show how to activate it. + +```python +>>> from phy import IPlugin +>>> from phy.gui import GUI, HTMLWidget, create_app, run_app, create_gui +>>> from phy.utils import Bunch +``` + +```python +>>> class MyComponent(IPlugin): +... def attach_to_gui(self, gui, model=None, state=None): +... # We create a widget. +... view = HTMLWidget() +... view.set_body("Hello %s!" % model.name) +... view.show() +... gui.add_view(view) +DEBUG:phy.utils.plugin:Register plugin `MyComponent`. +``` + +```python +>>> gui = create_gui('MyGUI', model=Bunch(name='world'), plugins=['MyComponent']) +DEBUG:phy.gui.gui:The GUI state file `/Users/cyrille/.phy/MyGUI/state.json` doesn't exist. +DEBUG:phy.gui.gui:Attach plugin `MyComponent` to MyGUI. +``` + +```python +>>> gui.show() +DEBUG:phy.gui.gui:Save the GUI state to `/Users/cyrille/.phy/MyGUI/state.json`. +``` + +This opens a GUI showing `Hello world!`. diff --git a/docs/index.md b/docs/index.md new file mode 100644 index 000000000..b8a67892d --- /dev/null +++ b/docs/index.md @@ -0,0 +1,15 @@ +# phy documentation + +phy is an **electrophysiology data analysis library**. It provides spike sorting and analysis routines for extracellular multielectrode recordings, as well as generic components to build command-line and graphical applications. Overall, phy lets you build custom data analysis applications, for example, a manual clustering GUI. + +phy is entirely agnostic to file formats and processing workflows. As such, it can be easily integrated with any existing system. 
+
+## List of components
+
+* GUI
+* Plotting
+* Configuration and plugin system
+* Command-line interface
+* Manual clustering routines
+* Analysis functions
+* Utilities
diff --git a/docs/install.md b/docs/install.md
new file mode 100644
index 000000000..25b1adb1a
--- /dev/null
+++ b/docs/install.md
@@ -0,0 +1,5 @@
+# Installation
+
+* Install Anaconda
+* To install the development version: `pip install git+https://github.com/kwikteam/phy.git@kill-kwik`
+* (not done yet) to install the stable version: `conda install -c kwikteam/phy`
diff --git a/docs/io.md b/docs/io.md
new file mode 100644
index 000000000..1bd8d3065
--- /dev/null
+++ b/docs/io.md
@@ -0,0 +1,29 @@
+## Data utilities
+
+The `phy.io` package contains utilities related to array manipulation, mock datasets, caching, and parallel computing.
+
+### Array
+
+The `array` module contains functions to select subsets of large data arrays, and to obtain the spikes belonging to a set of clusters (notably the `Selector` class).
+
+### Context
+
+The `Context` provides facilities to accelerate computations through caching (with **joblib**) and parallel computing (with **ipyparallel**).
+
+A `Context` is initialized with a cache directory (typically a subdirectory `.phy` within the directory containing the data). You can also provide an `ipy_view` instance to use parallel computing with the ipyparallel package.
+
+#### Cache
+
+Use `f = context.cache(f)` to cache a function. By default, the decorated function will be cached on disk in the cache directory, using joblib. NumPy arrays are fully and efficiently supported.
+
+With the `memcache=True` argument, you can *also* use memory caching. This is useful, for example, when caching functions that return a scalar for every cluster. This is the case with the functions computing the quality and similarity of clusters. These functions are called a lot during a manual clustering session.
+
+#### Parallel computing
+
+Use the `map()` and `map_async()` methods to call a function on multiple arguments in sequence or in parallel if an ipyparallel context is available (`ipy_view` keyword in the `Context`'s constructor).
+
+There is also an experimental `map_dask_array(f, da)` method to map in parallel a function that processes a single chunk of a **dask Array**. The result of every computation unit is saved in a `.npy` file in the cache directory, and the result is a new dask Array that is dynamically memory-mapped from the stack of `.npy` files. **The cache directory needs to be available from all computing units for this method to work** (using a network file system). Doing it this way should mitigate the performance issues with transferring large amounts of data over the network.
+
+#### Store
+
+You can store JSON-serializable Python dictionaries with `context.save()` and `context.load()`. The files are saved in the cache directory. NumPy arrays and Qt buffers are fully supported. You can save the GUI state and geometry there, for example.
diff --git a/docs/plot.md b/docs/plot.md
new file mode 100644
index 000000000..bd11dfa5b
--- /dev/null
+++ b/docs/plot.md
@@ -0,0 +1,121 @@
+# Plotting with VisPy
+
+phy provides a simple and fast plotting system based on VisPy's low-level **gloo** interface. This plotting system is entirely generic. Currently, it privileges speed and scalability over quality. In other words, you can display millions of points at very high speed, but the plotting quality is not as good as matplotlib, for example.
While this system uses the GPU extensively, knowledge of GPU or OpenGL is not required for most purposes.
+
+First, we need to activate the Qt event loop in IPython, or create and run the Qt application in a script.
+
+```python
+>>> %gui qt
+```
+
+## Simple view
+
+Let's create a simple view with a scatter plot.
+
+```python
+>>> import numpy as np
+>>> from phy.plot import View
+```
+
+```python
+>>> view = View()
+...
+>>> n = 1000
+>>> x, y = np.random.randn(2, n)
+>>> c = np.random.uniform(.3, .7, (n, 4))
+>>> s = np.random.uniform(5, 30, n)
+...
+>>> # NOTE: currently, the building process needs to be explicit.
+... # All commands that construct the view should be enclosed in this
+... # context manager, or at least one should ensure that
+... # `view.clear()` and `view.build()` are called before and after
+... # the building commands.
+... with view.building():
+...     view.scatter(x, y, color=c, size=s, marker='disc')
+...
+>>> view.show()
+```
+
+Note that you can pan and zoom with the mouse and keyboard.
+
+The other plotting commands currently supported are `plot()` and `hist()`. We're planning to add support for text in the near future.
+
+Several layouts are supported for subplots.
+
+## Grid view
+
+The grid view lets you create multiple subplots arranged in a grid (like in matplotlib). Subplots are all individually clipped, which means that their viewports never overlap across the grid boundaries. Here is an example:
+
+```python
+>>> view = View(layout='grid', shape=(1, 2))  # the shape is `(n_rows, n_cols)`
+...
+>>> x = np.linspace(-10., 10., 1000)
+...
+>>> with view.building():
+...     view[0, 0].plot(x, np.sin(x))
+...     view[0, 1].plot(x, np.cos(x), color=(1, 0, 0, 1))
+...
+>>> view.show()
+```
+
+Subplots are created with the `view[i, j]` syntax. The indexing scheme works like mathematical matrices (origin at the upper left).
+
+Note that there are no axes at this point, but we'll be working on it. Also, independent per-subplot panning and zooming is not supported and this is unlikely to change in the foreseeable future.
+
+## Stacked view
+
+The stacked view lets you stack several subplots vertically with no clipping. An example is a trace view showing a multichannel time-dependent signal.
+
+```python
+>>> view = View(layout='stacked', n_plots=50)
+...
+>>> with view.building():
+...     for i in range(view.n_plots):
+...         view[i].plot(y=np.random.randn(2000),
+...                      color=np.random.uniform(.5, .9, 4))
+...
+>>> view.show()
+```
+
+## Boxed view
+
+The boxed view lets you put subplots at arbitrary locations. You can dynamically change the positions and the sizes of the boxes. An example is the waveform view, where line plots are positioned at the recording sites on a multielectrode array.
+
+```python
+>>> # Generate box positions along a circle.
+... dt = np.pi / 10
+>>> t = np.arange(0, 2 * np.pi, dt)
+>>> x = np.cos(t)
+>>> y = np.sin(t)
+>>> box_pos = np.c_[x, y]
+...
+>>> view = View(layout='boxed', box_pos=box_pos)
+...
+>>> with view.building():
+...     for i in range(view.n_plots):
+...         # Create the subplots.
+...         view[i].plot(y=np.random.randn(10, 100),
+...                      color=np.random.uniform(.5, .9, 4))
+...
+>>> view.show()
+```
+
+You can use `ctrl+arrows` and `shift+arrows` to change the scaling of the positions and boxes.
+
+## Data normalization
+
+Data normalization is supported via the `data_bounds` keyword. This is a 4-tuple `(xmin, ymin, xmax, ymax)` with the coordinates of the viewport in the data coordinate system.
By default, this is obtained with the min and max of the data. Here is an example: + +```python +>>> view = View(layout='stacked', n_plots=2) +... +>>> n = 100 +>>> x = np.linspace(0., 1., n) +>>> y = np.random.rand(n) +... +>>> with view.building(): +... view[0].plot(x, y, data_bounds=(0, 0, 1, 1)) +... view[1].plot(x, y, data_bounds=(0, -10, 1, 10)) +... +>>> view.show() +``` diff --git a/environment.yml b/environment.yml index 0b61d5aae..8e7d1f829 100644 --- a/environment.yml +++ b/environment.yml @@ -9,14 +9,8 @@ dependencies: - scipy - h5py - pyqt - - ipython - requests - traitlets - six - - ipyparallel - joblib - - dask - - cython - click - - pip: - - klustakwik2 diff --git a/mkdocs.yml b/mkdocs.yml new file mode 100644 index 000000000..88ae426fa --- /dev/null +++ b/mkdocs.yml @@ -0,0 +1,14 @@ +site_name: phy +pages: +- Home: 'index.md' +- FAQ: 'faq.md' +- Installation: 'install.md' +- GUI: 'gui.md' +- Plotting: 'plot.md' +- Configuration and plugin system: 'config.md' +- Command-line interface: 'cli.md' +- Data utilities: 'io.md' +- Manual clustering: 'cluster-manual.md' +- Analysis functions: 'analysis.md' +- API reference: 'api.md' +theme: readthedocs diff --git a/phy/__init__.py b/phy/__init__.py index 8846b78d8..70d279d89 100644 --- a/phy/__init__.py +++ b/phy/__init__.py @@ -8,12 +8,16 @@ # Imports #------------------------------------------------------------------------------ +import logging import os.path as op -from pkg_resources import get_distribution, DistributionNotFound +import sys -from .utils.logging import _default_logger, set_level -from .utils.datasets import download_sample_data +from six import StringIO + +from .io.datasets import download_file, download_sample_data +from .utils.config import load_master_config from .utils._misc import _git_version +from .utils.plugin import IPlugin, get_plugin, discover_plugins #------------------------------------------------------------------------------ @@ -22,26 +26,47 @@ __author__ = 'Kwik team' __email__ = 'cyrille.rossant at gmail.com' -__version__ = '0.2.2' +__version__ = '0.3.0.dev0' __version_git__ = __version__ + _git_version() -__all__ = ['debug', 'set_level'] +# Set a null handler on the root logger +logger = logging.getLogger() +logger.setLevel(logging.DEBUG) +logger.addHandler(logging.NullHandler()) + + +_logger_fmt = '%(asctime)s [%(levelname)s] %(caller)s %(message)s' +_logger_date_fmt = '%H:%M:%S' + + +class _Formatter(logging.Formatter): + def format(self, record): + # Only keep the first character in the level name. + record.levelname = record.levelname[0] + filename = op.splitext(op.basename(record.pathname))[0] + record.caller = '{:s}:{:d}'.format(filename, record.lineno).ljust(20) + return super(_Formatter, self).format(record) + + +def add_default_handler(level='INFO'): + handler = logging.StreamHandler() + handler.setLevel(level) + formatter = _Formatter(fmt=_logger_fmt, + datefmt=_logger_date_fmt) + handler.setFormatter(formatter) -# Set up the default logger. 
-_default_logger() + logger.addHandler(handler) -def debug(enable=True): - """Enable debug logging mode.""" - if enable: - set_level('debug') - else: - set_level('info') +DEBUG = False +if '--debug' in sys.argv: # pragma: no cover + DEBUG = True + sys.argv.remove('--debug') -def test(): +def test(): # pragma: no cover """Run the full testing suite of phy.""" import pytest pytest.main() diff --git a/phy/cluster/__init__.py b/phy/cluster/__init__.py index ff6b165f7..2df738347 100644 --- a/phy/cluster/__init__.py +++ b/phy/cluster/__init__.py @@ -2,3 +2,5 @@ # flake8: noqa """Automatic and manual clustering facilities.""" + +from . import manual diff --git a/phy/cluster/algorithms/__init__.py b/phy/cluster/algorithms/__init__.py deleted file mode 100644 index 442e1ed9b..000000000 --- a/phy/cluster/algorithms/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# -*- coding: utf-8 -*- -"""Automatic clustering algorithms.""" diff --git a/phy/cluster/algorithms/default_settings.py b/phy/cluster/algorithms/default_settings.py deleted file mode 100644 index c65f06ee0..000000000 --- a/phy/cluster/algorithms/default_settings.py +++ /dev/null @@ -1,12 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Default settings for clustering.""" - - -# ----------------------------------------------------------------------------- -# Clustering -# ----------------------------------------------------------------------------- - -# NOTE: the default parameters are in klustakwik2's repository. -klustakwik2 = { -} diff --git a/phy/cluster/algorithms/klustakwik.py b/phy/cluster/algorithms/klustakwik.py deleted file mode 100644 index 46144f431..000000000 --- a/phy/cluster/algorithms/klustakwik.py +++ /dev/null @@ -1,89 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Wrapper to KlustaKwik2 implementation.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -from ...utils.array import PartialArray -from ...utils.event import EventEmitter -from ...io.kwik.sparse_kk2 import sparsify_features_masks - - -#------------------------------------------------------------------------------ -# Clustering class -#------------------------------------------------------------------------------ - -class KlustaKwik(EventEmitter): - """KlustaKwik automatic clustering algorithm.""" - def __init__(self, **kwargs): - super(KlustaKwik, self).__init__() - self._kwargs = kwargs - self.__dict__.update(kwargs) - # Set the version. - from klustakwik2 import __version__ - self.version = __version__ - - def cluster(self, - model=None, - spike_ids=None, - features=None, - masks=None, - ): - """Run the clustering algorithm on the model, or on any features - and masks. - - Return the `spike_clusters` assignements. - - Emit the `iter` event at every KlustaKwik iteration. - - """ - # Get the features and masks. - if model is not None: - if features is None: - features = PartialArray(model.features_masks, 0) - if masks is None: - masks = PartialArray(model.features_masks, 1) - # Select some spikes if needed. - if spike_ids is not None: - features = features[spike_ids] - masks = masks[spike_ids] - # Convert the features and masks to the sparse structure used - # by KK. - data = sparsify_features_masks(features, masks) - data = data.to_sparse_data() - # Run KK. - from klustakwik2 import KK - kk = KK(data, **self._kwargs) - - @kk.register_callback - def f(_): - # Skip split iterations. 
- if _.name != '': - return - self.emit('iter', kk.clusters) - - self.params = kk.all_params - kk.cluster_mask_starts() - spike_clusters = kk.clusters - return spike_clusters - - -def cluster(model, algorithm='klustakwik', spike_ids=None, **kwargs): - """Launch an automatic clustering algorithm on the model. - - Parameters - ---------- - - model : BaseModel - A model. - algorithm : str - Only 'klustakwik' is supported currently. - **kwargs - Parameters for KK. - - """ - assert algorithm == 'klustakwik' - kk = KlustaKwik(**kwargs) - return kk.cluster(model=model, spike_ids=spike_ids) diff --git a/phy/cluster/algorithms/tests/test_klustakwik.py b/phy/cluster/algorithms/tests/test_klustakwik.py deleted file mode 100644 index c0315b6b0..000000000 --- a/phy/cluster/algorithms/tests/test_klustakwik.py +++ /dev/null @@ -1,51 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Tests of clustering algorithms.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -from ....utils.logging import set_level -from ....io.kwik import KwikModel -from ....io.kwik.mock import create_mock_kwik -from ..klustakwik import cluster - - -#------------------------------------------------------------------------------ -# Fixtures -#------------------------------------------------------------------------------ - -def setup(): - set_level('info') - - -def teardown(): - set_level('info') - - -sample_rate = 10000 -n_samples = 25000 -n_channels = 4 - - -#------------------------------------------------------------------------------ -# Tests clustering -#------------------------------------------------------------------------------ - -def test_cluster(tempdir): - n_spikes = 100 - filename = create_mock_kwik(tempdir, - n_clusters=1, - n_spikes=n_spikes, - n_channels=8, - n_features_per_channel=3, - n_samples_traces=5000) - model = KwikModel(filename) - - spike_clusters = cluster(model, num_starting_clusters=10) - assert len(spike_clusters) == n_spikes - - spike_clusters = cluster(model, num_starting_clusters=10, - spike_ids=range(100)) - assert len(spike_clusters) == 100 diff --git a/phy/cluster/manual/__init__.py b/phy/cluster/manual/__init__.py index 8d7154907..d6b871021 100644 --- a/phy/cluster/manual/__init__.py +++ b/phy/cluster/manual/__init__.py @@ -3,14 +3,7 @@ """Manual clustering facilities.""" -from .view_models import (BaseClusterViewModel, - HTMLClusterViewModel, - StatsViewModel, - FeatureViewModel, - WaveformViewModel, - TraceViewModel, - CorrelogramViewModel, - ) +from ._utils import ClusterMeta from .clustering import Clustering -from .wizard import Wizard -from .gui import ClusterManualGUI +from .gui_component import ManualClustering +from .views import WaveformView, TraceView, FeatureView, CorrelogramView diff --git a/phy/cluster/manual/_history.py b/phy/cluster/manual/_history.py index 81a0a9fd0..813bf957a 100644 --- a/phy/cluster/manual/_history.py +++ b/phy/cluster/manual/_history.py @@ -94,7 +94,7 @@ def add(self, item): def back(self): """Go back in history if possible. - Return the current item after going back. + Return the undone item. 
""" if self._index <= 0: diff --git a/phy/cluster/manual/_utils.py b/phy/cluster/manual/_utils.py index 95887a607..b1ec99c86 100644 --- a/phy/cluster/manual/_utils.py +++ b/phy/cluster/manual/_utils.py @@ -7,9 +7,13 @@ #------------------------------------------------------------------------------ from copy import deepcopy +from collections import defaultdict +import logging from ._history import History -from ...utils import Bunch, _as_list +from phy.utils import Bunch, _as_list, _is_list, EventEmitter + +logger = logging.getLogger(__name__) #------------------------------------------------------------------------------ @@ -28,6 +32,18 @@ def _join(clusters): return '[{}]'.format(', '.join(map(str, clusters))) +def create_cluster_meta(cluster_groups): + """Return a ClusterMeta instance with cluster group support.""" + meta = ClusterMeta() + meta.add_field('group') + + cluster_groups = cluster_groups or {} + data = {c: {'group': v} for c, v in cluster_groups.items()} + meta.from_dict(data) + + return meta + + #------------------------------------------------------------------------------ # UpdateInfo class #------------------------------------------------------------------------------ @@ -45,9 +61,8 @@ def __init__(self, **kwargs): descendants=[], # pairs of (old_cluster, new_cluster) metadata_changed=[], # clusters with changed metadata metadata_value=None, # new metadata value - old_spikes_per_cluster={}, # only for the affected clusters - new_spikes_per_cluster={}, # only for the affected clusters - selection=[], # clusters selected before the action + undo_state=None, # returned during an undo: it contains + # information about the undone action ) d.update(kwargs) super(UpdateInfo, self).__init__(d) @@ -79,41 +94,99 @@ def __repr__(self): # ClusterMetadataUpdater class #------------------------------------------------------------------------------ -class ClusterMetadataUpdater(object): +class ClusterMeta(EventEmitter): """Handle cluster metadata changes.""" - def __init__(self, cluster_metadata): - self._cluster_metadata = cluster_metadata - # Keep a deep copy of the original structure for the undo stack. - self._data_base = deepcopy(cluster_metadata.data) - # The stack contains (clusters, field, value, update_info) tuples. - self._undo_stack = History((None, None, None, None)) - - for field, func in self._cluster_metadata._fields.items(): - - # Create self.(clusters). - def _make_get(field): - def f(clusters): - return self._cluster_metadata._get(clusters, field) - return f - setattr(self, field, _make_get(field)) - - # Create self.set_(clusters, value). - def _make_set(field): - def f(clusters, value): - return self._set(clusters, field, value) - return f - setattr(self, 'set_{0:s}'.format(field), _make_set(field)) - - def _set(self, clusters, field, value, add_to_stack=True): - self._cluster_metadata._set(clusters, field, value) + def __init__(self): + super(ClusterMeta, self).__init__() + self._fields = {} + self._reset_data() + + def _reset_data(self): + self._data = {} + self._data_base = {} + # The stack contains (clusters, field, value, update_info, undo_state) + # tuples. 
+ self._undo_stack = History((None, None, None, None, None)) + + @property + def fields(self): + """List of fields.""" + return sorted(self._fields.keys()) + + def add_field(self, name, default_value=None): + """Add a field with an optional default value.""" + self._fields[name] = default_value + + def func(cluster): + return self.get(name, cluster) + + setattr(self, name, func) + + def from_dict(self, dic): + """Import data from a {cluster: {field: value}} dictionary.""" + self._reset_data() + for cluster, vals in dic.items(): + for field, value in vals.items(): + self.set(field, [cluster], value, add_to_stack=False) + self._data_base = deepcopy(self._data) + + def to_dict(self, field): + """Export data to a {cluster: value} dictionary, for a particular + field.""" + assert field in self._fields, "This field doesn't exist" + return {cluster: self.get(field, cluster) + for cluster in self._data.keys()} + + def set(self, field, clusters, value, add_to_stack=True): + """Set the value of one of several clusters.""" + assert field in self._fields + clusters = _as_list(clusters) - info = UpdateInfo(description='metadata_' + field, - metadata_changed=clusters, - metadata_value=value, - ) + for cluster in clusters: + if cluster not in self._data: + self._data[cluster] = {} + self._data[cluster][field] = value + + up = UpdateInfo(description='metadata_' + field, + metadata_changed=clusters, + metadata_value=value, + ) + undo_state = self.emit('request_undo_state', up) + if add_to_stack: - self._undo_stack.add((clusters, field, value, info)) - return info + self._undo_stack.add((clusters, field, value, up, undo_state)) + self.emit('cluster', up) + + return up + + def get(self, field, cluster): + """Retrieve the value of one cluster.""" + if _is_list(cluster): + return [self.get(field, c) for c in cluster] + assert field in self._fields + default = self._fields[field] + return self._data.get(cluster, {}).get(field, default) + + def set_from_descendants(self, descendants): + """Update metadata of some clusters given the metadata of their + ascendants.""" + for field in self.fields: + + # This gives a set of metadata values of all the parents + # of any new cluster. + candidates = defaultdict(set) + for old, new in descendants: + candidates[new].add(self.get(field, old)) + + # Loop over all new clusters. + for new, vals in candidates.items(): + vals = list(vals) + default = self._fields[field] + # If all the parents have the same value, assign it to + # the new cluster if it is not the default. + if len(vals) == 1 and vals[0] != default: + self.set(field, new, vals[0]) + # Otherwise, the default is assumed. def undo(self): """Undo the last metadata change. @@ -127,14 +200,18 @@ def undo(self): args = self._undo_stack.back() if args is None: return - self._cluster_metadata._data = deepcopy(self._data_base) - for clusters, field, value, _ in self._undo_stack: + self._data = deepcopy(self._data_base) + for clusters, field, value, up, undo_state in self._undo_stack: if clusters is not None: - self._set(clusters, field, value, add_to_stack=False) + self.set(field, clusters, value, add_to_stack=False) + # Return the UpdateInfo instance of the undo action. - info = args[-1] - info.history = 'undo' - return info + up, undo_state = args[-2:] + up.history = 'undo' + up.undo_state = undo_state + + self.emit('cluster', up) + return up def redo(self): """Redo the next metadata change. 
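# --- Editor's note (illustrative only, not part of the patch) ----------------
# A minimal usage sketch of the new ClusterMeta API introduced in the hunks
# above. The import path follows the updated phy/cluster/manual/__init__.py
# earlier in this diff; only methods defined above (add_field, set, get, undo)
# are used.
from phy.cluster.manual import ClusterMeta

meta = ClusterMeta()
meta.add_field('group')                 # default value is None

up = meta.set('group', [2, 3], 'good')  # pushed onto the undo stack
assert meta.get('group', 2) == 'good'
assert up.metadata_changed == [2, 3]

up = meta.undo()                        # restores the default value
assert up.history == 'undo'
assert meta.get('group', 2) is None
# ------------------------------------------------------------------------------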
@@ -147,8 +224,11 @@ def redo(self): args = self._undo_stack.forward() if args is None: return - clusters, field, value, info = args - self._set(clusters, field, value, add_to_stack=False) + clusters, field, value, up, undo_state = args + self.set(field, clusters, value, add_to_stack=False) + # Return the UpdateInfo instance of the redo action. - info.history = 'redo' - return info + up.history = 'redo' + + self.emit('cluster', up) + return up diff --git a/phy/cluster/manual/clustering.py b/phy/cluster/manual/clustering.py index f7c261a38..f95a96eea 100644 --- a/phy/cluster/manual/clustering.py +++ b/phy/cluster/manual/clustering.py @@ -8,13 +8,13 @@ import numpy as np -from ...utils._types import _as_array, _is_array_like -from ...utils.array import (_unique, - _spikes_in_clusters, - _spikes_per_cluster, - ) +from phy.utils._types import _as_array, _is_array_like +from phy.io.array import (_unique, + _spikes_in_clusters, + ) from ._utils import UpdateInfo from ._history import History +from phy.utils.event import EventEmitter #------------------------------------------------------------------------------ @@ -25,7 +25,7 @@ def _extend_spikes(spike_ids, spike_clusters): """Return all spikes belonging to the clusters containing the specified spikes.""" # We find the spikes belonging to modified clusters. - # What are the old clusters that are modified by the assignement? + # What are the old clusters that are modified by the assignment? old_spike_clusters = spike_clusters[spike_ids] unique_clusters = _unique(old_spike_clusters) # Now we take all spikes from these clusters. @@ -46,7 +46,11 @@ def _concatenate_spike_clusters(*pairs): return concat[:, 0].astype(np.int64), concat[:, 1].astype(np.int64) -def _extend_assignement(spike_ids, old_spike_clusters, spike_clusters_rel): +def _extend_assignment(spike_ids, + old_spike_clusters, + spike_clusters_rel, + new_cluster_id, + ): # 1. Add spikes that belong to modified clusters. # 2. Find new cluster ids for all changed clusters. @@ -58,7 +62,6 @@ def _extend_assignement(spike_ids, old_spike_clusters, spike_clusters_rel): assert spike_clusters_rel.min() >= 0 # We renumber the new cluster indices. - new_cluster_id = old_spike_clusters.max() + 1 new_spike_clusters = (spike_clusters_rel + (new_cluster_id - spike_clusters_rel.min())) @@ -80,9 +83,7 @@ def _extend_assignement(spike_ids, old_spike_clusters, spike_clusters_rel): extended_spike_clusters)) -def _assign_update_info(spike_ids, - old_spike_clusters, old_spikes_per_cluster, - new_spike_clusters, new_spikes_per_cluster): +def _assign_update_info(spike_ids, old_spike_clusters, new_spike_clusters): old_clusters = _unique(old_spike_clusters) new_clusters = _unique(new_spike_clusters) descendants = list(set(zip(old_spike_clusters, @@ -92,13 +93,11 @@ def _assign_update_info(spike_ids, added=list(new_clusters), deleted=list(old_clusters), descendants=descendants, - old_spikes_per_cluster=old_spikes_per_cluster, - new_spikes_per_cluster=new_spikes_per_cluster, ) return update_info -class Clustering(object): +class Clustering(EventEmitter): """Handle cluster changes in a set of spikes. Features @@ -144,24 +143,22 @@ class Clustering(object): information about the clusters. metadata_changed : list List of clusters with changed metadata (cluster group changes) - old_spikes_per_cluster : dict - Dictionary of `{cluster: spikes}` for the old clusters and - old clustering. - new_spikes_per_cluster : dict - Dictionary of `{cluster: spikes}` for the new clusters and - new clustering. 
""" - def __init__(self, spike_clusters): - self._undo_stack = History(base_item=(None, None)) + def __init__(self, spike_clusters, new_cluster_id=None): + super(Clustering, self).__init__() + self._undo_stack = History(base_item=(None, None, None)) # Spike -> cluster mapping. self._spike_clusters = _as_array(spike_clusters) self._n_spikes = len(self._spike_clusters) self._spike_ids = np.arange(self._n_spikes).astype(np.int64) - # Create the spikes per cluster structure. - self._update_all_spikes_per_cluster() - # Keep a copy of the original spike clusters assignement. + self._new_cluster_id_0 = (new_cluster_id or + self._spike_clusters.max() + 1) + self._new_cluster_id = self._new_cluster_id_0 + assert self._new_cluster_id >= 0 + assert np.all(self._spike_clusters < self._new_cluster_id) + # Keep a copy of the original spike clusters assignment. self._spike_clusters_base = self._spike_clusters.copy() def reset(self): @@ -172,36 +169,28 @@ def reset(self): """ self._undo_stack.clear() self._spike_clusters = self._spike_clusters_base - self._update_all_spikes_per_cluster() + self._new_cluster_id = self._new_cluster_id_0 @property def spike_clusters(self): """A n_spikes-long vector containing the cluster ids of all spikes.""" return self._spike_clusters - @property - def spikes_per_cluster(self): - """A dictionary `{cluster: spikes}`.""" - return self._spikes_per_cluster - @property def cluster_ids(self): """Ordered list of ids of all non-empty clusters.""" - return np.array(sorted(self._spikes_per_cluster)) - - @property - def cluster_counts(self): - """Dictionary with the number of spikes in each cluster.""" - return {cluster: len(self._spikes_per_cluster[cluster]) - for cluster in self.cluster_ids} + return np.unique(self._spike_clusters) def new_cluster_id(self): """Generate a brand new cluster id. - This is `maximum cluster id + 1`. + NOTE: This new id strictly increases after an undo + new action, + meaning that old cluster ids are *not* reused. This ensures that + any cluster_id-based cache will always be valid even after undo + operations (i.e. no need for explicit cache invalidation in this case). """ - return int(np.max(self.cluster_ids)) + 1 + return self._new_cluster_id @property def n_clusters(self): @@ -225,6 +214,61 @@ def spikes_in_clusters(self, clusters): # Actions #-------------------------------------------------------------------------- + def _do_assign(self, spike_ids, new_spike_clusters): + """Make spike-cluster assignments after the spike selection has + been extended to full clusters.""" + + # Ensure spike_clusters has the right shape. + spike_ids = _as_array(spike_ids) + if len(new_spike_clusters) == 1 and len(spike_ids) > 1: + new_spike_clusters = (np.ones(len(spike_ids), dtype=np.int64) * + new_spike_clusters[0]) + old_spike_clusters = self._spike_clusters[spike_ids] + + assert len(spike_ids) == len(old_spike_clusters) + assert len(new_spike_clusters) == len(spike_ids) + + # Update the spikes per cluster structure. + old_clusters = _unique(old_spike_clusters) + + # NOTE: shortcut to a merge if this assignment is effectively a merge + # i.e. if all spikes are assigned to a single cluster. + # The fact that spike selection has been previously extended to + # whole clusters is critical here. + new_clusters = _unique(new_spike_clusters) + if len(new_clusters) == 1: + return self._do_merge(spike_ids, old_clusters, new_clusters[0]) + + # We return the UpdateInfo structure. 
+ up = _assign_update_info(spike_ids, + old_spike_clusters, + new_spike_clusters) + + # We update the new cluster id (strictly increasing during a session). + self._new_cluster_id = max(self._new_cluster_id, max(up.added) + 1) + + # We make the assignments. + self._spike_clusters[spike_ids] = new_spike_clusters + return up + + def _do_merge(self, spike_ids, cluster_ids, to): + + # Create the UpdateInfo instance here. + descendants = [(cluster, to) for cluster in cluster_ids] + up = UpdateInfo(description='merge', + spike_ids=spike_ids, + added=[to], + deleted=list(cluster_ids), + descendants=descendants, + ) + + # We update the new cluster id (strictly increasing during a session). + self._new_cluster_id = max(max(up.added) + 1, self._new_cluster_id) + + # Assign the clusters. + self.spike_clusters[spike_ids] = to + return up + def merge(self, cluster_ids, to=None): """Merge several clusters to a new cluster. @@ -266,73 +310,17 @@ def merge(self, cluster_ids, to=None): # Find all spikes in the specified clusters. spike_ids = _spikes_in_clusters(self.spike_clusters, cluster_ids) - # Create the UpdateInfo instance here. - descendants = [(cluster, to) for cluster in cluster_ids] - old_spc = {k: self._spikes_per_cluster[k] for k in cluster_ids} - new_spc = {to: spike_ids} - up = UpdateInfo(description='merge', - spike_ids=spike_ids, - added=[to], - deleted=cluster_ids, - descendants=descendants, - old_spikes_per_cluster=old_spc, - new_spikes_per_cluster=new_spc, - ) - - # Update the spikes_per_cluster structure directly. - self._spikes_per_cluster[to] = spike_ids - for cluster in cluster_ids: - del self._spikes_per_cluster[cluster] - - # Assign the clusters. - self.spike_clusters[spike_ids] = to + up = self._do_merge(spike_ids, cluster_ids, to) + undo_state = self.emit('request_undo_state', up) # Add to stack. - self._undo_stack.add((spike_ids, [to])) - - return up - - def _update_all_spikes_per_cluster(self): - self._spikes_per_cluster = _spikes_per_cluster(self._spike_ids, - self._spike_clusters) - - def _do_assign(self, spike_ids, new_spike_clusters): - """Make spike-cluster assignements after the spike selection has - been extended to full clusters.""" - - # Ensure spike_clusters has the right shape. - spike_ids = _as_array(spike_ids) - if len(new_spike_clusters) == 1 and len(spike_ids) > 1: - new_spike_clusters = (np.ones(len(spike_ids), dtype=np.int64) * - new_spike_clusters[0]) - old_spike_clusters = self._spike_clusters[spike_ids] - - assert len(spike_ids) == len(old_spike_clusters) - assert len(new_spike_clusters) == len(spike_ids) - - # Update the spikes per cluster structure. - clusters = _unique(old_spike_clusters) - old_spikes_per_cluster = {cluster: self._spikes_per_cluster[cluster] - for cluster in clusters} - new_spikes_per_cluster = _spikes_per_cluster(spike_ids, - new_spike_clusters) - self._spikes_per_cluster.update(new_spikes_per_cluster) - # All old clusters are deleted. - for cluster in clusters: - del self._spikes_per_cluster[cluster] - - # We return the UpdateInfo structure. - up = _assign_update_info(spike_ids, - old_spike_clusters, old_spikes_per_cluster, - new_spike_clusters, new_spikes_per_cluster) - - # We make the assignements. - self._spike_clusters[spike_ids] = new_spike_clusters + self._undo_stack.add((spike_ids, [to], undo_state)) + self.emit('cluster', up) return up def assign(self, spike_ids, spike_clusters_rel=0): - """Make new spike cluster assignements. + """Make new spike cluster assignments. 
Parameters ---------- @@ -387,22 +375,26 @@ def assign(self, spike_ids, spike_clusters_rel=0): return UpdateInfo() assert len(spike_ids) == len(spike_clusters_rel) assert spike_ids.min() >= 0 - assert spike_ids.max() < self._n_spikes + assert spike_ids.max() < self._n_spikes, "Some spikes don't exist." - # Normalize the spike-cluster assignement such that + # Normalize the spike-cluster assignment such that # there are only new or dead clusters, not modified clusters. - # This implies that spikes not explicitely selected, but that + # This implies that spikes not explicitly selected, but that # belong to clusters affected by the operation, will be assigned # to brand new clusters. - spike_ids, cluster_ids = _extend_assignement(spike_ids, - self._spike_clusters, - spike_clusters_rel) + spike_ids, cluster_ids = _extend_assignment(spike_ids, + self._spike_clusters, + spike_clusters_rel, + self.new_cluster_id(), + ) up = self._do_assign(spike_ids, cluster_ids) + undo_state = self.emit('request_undo_state', up) - # Add the assignement to the undo stack. - self._undo_stack.add((spike_ids, cluster_ids)) + # Add the assignment to the undo stack. + self._undo_stack.add((spike_ids, cluster_ids, undo_state)) + self.emit('cluster', up) return up def split(self, spike_ids): @@ -433,7 +425,7 @@ def split(self, spike_ids): return self.assign(spike_ids, 0) def undo(self): - """Undo the last cluster assignement operation. + """Undo the last cluster assignment operation. Returns ------- @@ -441,13 +433,13 @@ def undo(self): up : UpdateInfo instance of the changes done by this operation. """ - self._undo_stack.back() + _, _, undo_state = self._undo_stack.back() # Retrieve the initial spike_cluster structure. spike_clusters_new = self._spike_clusters_base.copy() # Loop over the history (except the last item because we undo). - for spike_ids, cluster_ids in self._undo_stack: + for spike_ids, cluster_ids, _ in self._undo_stack: # We update the spike clusters accordingly. if spike_ids is not None: spike_clusters_new[spike_ids] = cluster_ids @@ -457,13 +449,16 @@ def undo(self): spike_clusters_new)[0] clusters_changed = spike_clusters_new[changed] - up = self._do_assign(changed, - clusters_changed) + up = self._do_assign(changed, clusters_changed) up.history = 'undo' + # Add the undo_state object from the undone object. + up.undo_state = undo_state + + self.emit('cluster', up) return up def redo(self): - """Redo the last cluster assignement operation. + """Redo the last cluster assignment operation. Returns ------- @@ -471,17 +466,22 @@ def redo(self): up : UpdateInfo instance of the changes done by this operation. """ - # Go forward in the stack, and retrieve the new assignement. + # Go forward in the stack, and retrieve the new assignment. item = self._undo_stack.forward() if item is None: # No redo has been performed: abort. return - spike_ids, cluster_ids = item + # NOTE: the undo_state object is only returned when undoing. + # It represents data associated to the state + # *before* the action. What might be more useful would be the + # undo_state object of the next item in the list (if it exists). + spike_ids, cluster_ids, undo_state = item assert spike_ids is not None - # We apply the new assignement. - up = self._do_assign(spike_ids, - cluster_ids) + # We apply the new assignment. 
+ up = self._do_assign(spike_ids, cluster_ids) up.history = 'redo' + + self.emit('cluster', up) return up diff --git a/phy/cluster/manual/controller.py b/phy/cluster/manual/controller.py new file mode 100644 index 000000000..bb97b43e6 --- /dev/null +++ b/phy/cluster/manual/controller.py @@ -0,0 +1,388 @@ +# -*- coding: utf-8 -*- + +"""Controller: model -> views.""" + + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +import logging + +import numpy as np + +from phy.cluster.manual.gui_component import ManualClustering +from phy.cluster.manual.views import (WaveformView, + TraceView, + FeatureView, + CorrelogramView, + select_traces, + extract_spikes, + ) +from phy.gui import GUI +from phy.io.array import _get_data_lim, concat_per_cluster +from phy.io import Context, Selector +from phy.stats.clusters import (mean, + get_waveform_amplitude, + ) +from phy.utils import Bunch, load_master_config, get_plugin, EventEmitter + +logger = logging.getLogger(__name__) + + +#------------------------------------------------------------------------------ +# Kwik GUI +#------------------------------------------------------------------------------ + +class Controller(EventEmitter): + """Take data out of the model and feeds it to views. + + Events + ------ + + init() + create_gui(gui) + add_view(gui, view) + + """ + + n_spikes_waveforms = 100 + n_spikes_waveforms_lim = 100 + n_spikes_masks = 100 + n_spikes_features = 5000 + n_spikes_background_features = 5000 + n_spikes_features_lim = 100 + n_spikes_close_clusters = 100 + + # responsible for the cache + def __init__(self, plugins=None, config_dir=None): + super(Controller, self).__init__() + self.config_dir = config_dir + self._init_data() + self._init_selector() + self._init_context() + self._set_manual_clustering() + + self.n_spikes = len(self.spike_times) + + # Attach the plugins. + plugins = plugins or [] + config = load_master_config(config_dir=config_dir) + c = config.get(self.__class__.__name__) + default_plugins = c.plugins if c else [] + if len(default_plugins): + plugins = default_plugins + plugins + for plugin in plugins: + get_plugin(plugin)().attach_to_controller(self) + + self.emit('init') + + # Internal methods + # ------------------------------------------------------------------------- + + def _init_data(self): # pragma: no cover + self.cache_dir = None + # Child classes must set these variables. 
+ self.spike_times = None # (n_spikes,) array + + # TODO: make sure these structures are updated during a session + self.spike_clusters = None # (n_spikes,) array + self.cluster_groups = None # dict {cluster_id: None/'noise'/'mua'} + self.cluster_ids = None + + self.channel_positions = None # (n_channels, 2) array + self.n_samples_waveforms = None # int > 0 + self.n_channels = None # int > 0 + self.n_features_per_channel = None # int > 0 + self.sample_rate = None # float + self.duration = None # float + + self.all_masks = None # (n_spikes, n_channels) + self.all_waveforms = None # (n_spikes, n_samples, n_channels) + self.all_features = None # (n_spikes, n_channels, n_features) + self.all_traces = None # (n_samples_traces, n_channels) + + def _init_selector(self): + self.selector = Selector(self.spikes_per_cluster) + + def _init_context(self): + assert self.cache_dir + self.context = Context(self.cache_dir) + ctx = self.context + + self.get_masks = concat_per_cluster(ctx.cache(self.get_masks)) + self.get_features = concat_per_cluster(ctx.cache(self.get_features)) + self.get_waveforms = concat_per_cluster(ctx.cache(self.get_waveforms)) + self.get_background_features = ctx.cache(self.get_background_features) + + self.get_mean_masks = ctx.memcache(self.get_mean_masks) + self.get_mean_features = ctx.memcache(self.get_mean_features) + self.get_mean_waveforms = ctx.memcache(self.get_mean_waveforms) + + self.get_waveform_lims = ctx.memcache(self.get_waveform_lims) + self.get_feature_lim = ctx.memcache(self.get_feature_lim) + + self.get_close_clusters = ctx.memcache( + self.get_close_clusters) + self.get_probe_depth = ctx.memcache( + self.get_probe_depth) + + self.spikes_per_cluster = ctx.memcache(self.spikes_per_cluster) + + def _set_manual_clustering(self): + # Load the new cluster id. + new_cluster_id = self.context.load('new_cluster_id'). \ + get('new_cluster_id', None) + mc = ManualClustering(self.spike_clusters, + self.spikes_per_cluster, + similarity=self.similarity, + cluster_groups=self.cluster_groups, + new_cluster_id=new_cluster_id, + ) + + # Save the new cluster id on disk. 
+ @mc.clustering.connect + def on_cluster(up): + new_cluster_id = mc.clustering.new_cluster_id() + logger.debug("Save the new cluster id: %d", new_cluster_id) + self.context.save('new_cluster_id', + dict(new_cluster_id=new_cluster_id)) + + self.manual_clustering = mc + mc.add_column(self.get_probe_depth, name='probe_depth') + + def _select_spikes(self, cluster_id, n_max=None): + assert isinstance(cluster_id, int) + assert cluster_id >= 0 + return self.selector.select_spikes([cluster_id], n_max) + + def _select_data(self, cluster_id, arr, n_max=None): + spike_ids = self._select_spikes(cluster_id, n_max) + b = Bunch() + b.data = arr[spike_ids] + b.spike_ids = spike_ids + b.spike_clusters = self.spike_clusters[spike_ids] + b.masks = self.all_masks[spike_ids] + return b + + def _data_lim(self, arr, n_max): + return _get_data_lim(arr, n_spikes=n_max) + + # Masks + # ------------------------------------------------------------------------- + + # Is cached in _init_context() + def get_masks(self, cluster_id): + return self._select_data(cluster_id, + self.all_masks, + self.n_spikes_masks, + ) + + def get_mean_masks(self, cluster_id): + return mean(self.get_masks(cluster_id).data) + + # Waveforms + # ------------------------------------------------------------------------- + + # Is cached in _init_context() + def get_waveforms(self, cluster_id): + return [self._select_data(cluster_id, + self.all_waveforms, + self.n_spikes_waveforms, + )] + + def get_mean_waveforms(self, cluster_id): + return mean(self.get_waveforms(cluster_id)[0].data) + + def get_waveform_lims(self): + n_spikes = self.n_spikes_waveforms_lim + arr = self.all_waveforms + n = arr.shape[0] + k = max(1, n // n_spikes) + # Extract waveforms. + arr = arr[::k] + # Take the corresponding masks. + masks = self.all_masks[::k].copy() + arr = arr * masks[:, np.newaxis, :] + # NOTE: on some datasets, there are a few outliers that screw up + # the normalization. These parameters should be customizable. + m = np.percentile(arr, .05) + M = np.percentile(arr, 99.95) + return m, M + + def get_waveforms_amplitude(self, cluster_id): + mm = self.get_mean_masks(cluster_id) + mw = self.get_mean_waveforms(cluster_id) + assert mw.ndim == 2 + return get_waveform_amplitude(mm, mw) + + # Features + # ------------------------------------------------------------------------- + + # Is cached in _init_context() + def get_features(self, cluster_id, load_all=False): + return self._select_data(cluster_id, + self.all_features, + (self.n_spikes_features + if not load_all else None), + ) + + def get_background_features(self): + k = max(1, int(self.n_spikes // self.n_spikes_background_features)) + spike_ids = slice(None, None, k) + b = Bunch() + b.data = self.all_features[spike_ids] + b.spike_ids = spike_ids + b.spike_clusters = self.spike_clusters[spike_ids] + b.masks = self.all_masks[spike_ids] + return b + + def get_mean_features(self, cluster_id): + return mean(self.get_features(cluster_id).data) + + def get_feature_lim(self): + return self._data_lim(self.all_features, self.n_spikes_features_lim) + + # Traces + # ------------------------------------------------------------------------- + + def get_traces(self, interval): + tr = select_traces(self.all_traces, interval, + sample_rate=self.sample_rate, + ) + return [Bunch(traces=tr)] + + def get_spikes_traces(self, interval, traces): + # NOTE: we extract the spikes from the first traces array. 
+ traces = traces[0].traces + b = extract_spikes(traces, interval, + sample_rate=self.sample_rate, + spike_times=self.spike_times, + spike_clusters=self.spike_clusters, + all_masks=self.all_masks, + n_samples_waveforms=self.n_samples_waveforms, + ) + return b + + # Cluster statistics + # ------------------------------------------------------------------------- + + def get_best_channel(self, cluster_id): + wa = self.get_waveforms_amplitude(cluster_id) + return int(wa.argmax()) + + def get_best_channels(self, cluster_ids): + channels = [self.get_best_channel(cluster_id) + for cluster_id in cluster_ids] + return list(set(channels)) + + def get_channels_by_amplitude(self, cluster_ids): + wa = self.get_waveforms_amplitude(cluster_ids[0]) + return np.argsort(wa)[::-1].tolist() + + def get_best_channel_position(self, cluster_id): + cha = self.get_best_channel(cluster_id) + return tuple(self.channel_positions[cha]) + + def get_probe_depth(self, cluster_id): + return self.get_best_channel_position(cluster_id)[1] + + def get_close_clusters(self, cluster_id): + assert isinstance(cluster_id, int) + # Position of the cluster's best channel. + pos0 = self.get_best_channel_position(cluster_id) + n = len(pos0) + assert n in (2, 3) + # Positions of all clusters' best channels. + clusters = self.cluster_ids + pos = np.vstack([self.get_best_channel_position(int(clu)) + for clu in clusters]) + assert pos.shape == (len(clusters), n) + # Distance of all clusters to the current cluster. + dist = (pos - pos0) ** 2 + assert dist.shape == (len(clusters), n) + dist = np.sum(dist, axis=1) ** .5 + assert dist.shape == (len(clusters),) + # Closest clusters. + ind = np.argsort(dist) + ind = ind[:self.n_spikes_close_clusters] + return [(int(clusters[i]), float(dist[i])) for i in ind] + + def spikes_per_cluster(self, cluster_id): + return np.nonzero(self.spike_clusters == cluster_id)[0] + + # View methods + # ------------------------------------------------------------------------- + + def _add_view(self, gui, view): + view.attach(gui) + self.emit('add_view', gui, view) + return view + + def add_waveform_view(self, gui): + v = WaveformView(waveforms=self.get_waveforms, + channel_positions=self.channel_positions, + waveform_lims=self.get_waveform_lims(), + best_channels=self.get_best_channels, + ) + return self._add_view(gui, v) + + def add_trace_view(self, gui): + v = TraceView(traces=self.get_traces, + spikes=self.get_spikes_traces, + sample_rate=self.sample_rate, + duration=self.duration, + n_channels=self.n_channels, + ) + return self._add_view(gui, v) + + def add_feature_view(self, gui): + v = FeatureView(features=self.get_features, + background_features=self.get_background_features(), + spike_times=self.spike_times, + n_channels=self.n_channels, + n_features_per_channel=self.n_features_per_channel, + feature_lim=self.get_feature_lim(), + best_channels=self.get_channels_by_amplitude, + ) + return self._add_view(gui, v) + + def add_correlogram_view(self, gui): + v = CorrelogramView(spike_times=self.spike_times, + spike_clusters=self.spike_clusters, + sample_rate=self.sample_rate, + ) + return self._add_view(gui, v) + + # GUI methods + # ------------------------------------------------------------------------- + + def similarity(self, cluster_id): + return self.get_close_clusters(cluster_id) + + def create_gui(self, name=None, subtitle=None, + plugins=None, config_dir=None, + add_default_views=True, + **kwargs): + """Create a manual clustering GUI.""" + config_dir = config_dir or self.config_dir + gui = GUI(name=name, 
subtitle=subtitle, + config_dir=config_dir, **kwargs) + gui.controller = self + + # Attach the ManualClustering component to the GUI. + self.manual_clustering.attach(gui) + + # Add views. + if add_default_views: + self.add_correlogram_view(gui) + if self.all_features is not None: + self.add_feature_view(gui) + if self.all_waveforms is not None: + self.add_waveform_view(gui) + if self.all_traces is not None: + self.add_trace_view(gui) + + self.emit('create_gui', gui) + + return gui diff --git a/phy/cluster/manual/default_settings.py b/phy/cluster/manual/default_settings.py deleted file mode 100644 index 42cd55add..000000000 --- a/phy/cluster/manual/default_settings.py +++ /dev/null @@ -1,132 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Default settings for manual sorting.""" - - -# ----------------------------------------------------------------------------- -# Correlograms -# ----------------------------------------------------------------------------- - -# Number of time samples in a bin. -correlograms_binsize = 20 - -# Number of bins (odd number). -correlograms_winsize_bins = 2 * 25 + 1 - -# Maximum number of spikes for the correlograms. -# Use `None` to specify an infinite value. -correlograms_n_spikes_max = 1000000 - -# Contiguous chunks of spikes for computing the CCGs. -# Use `None` to have a regular (strided) subselection instead of a chunked -# subselection. -correlograms_excerpt_size = 100000 - - -# ----------------------------------------------------------------------------- -# Views -# ----------------------------------------------------------------------------- - -# Maximum number of spikes to display in the waveform view. -waveforms_n_spikes_max = 100 - -# Load regularly-spaced waveforms. -waveforms_excerpt_size = None - -# Maximum number of spikes to display in the feature view. -features_n_spikes_max = 2500 - -# Load a regular subselection of spikes from the cluster store. -features_excerpt_size = None - -# Maximum number of background spikes to display in the feature view. -features_n_spikes_max_bg = features_n_spikes_max - -features_grid_n_spikes_max = features_n_spikes_max -features_grid_excerpt_size = features_excerpt_size -features_grid_n_spikes_max_bg = features_n_spikes_max_bg - - -# ----------------------------------------------------------------------------- -# Clustering GUI -# ----------------------------------------------------------------------------- - -cluster_manual_shortcuts = { - 'reset_gui': 'alt+r', - 'show_shortcuts': 'ctrl+h', - 'save': 'ctrl+s', - 'exit': 'ctrl+q', - # Wizard actions. - 'reset_wizard': 'ctrl+w', - 'next': 'space', - 'previous': 'shift+space', - 'reset_wizard': 'ctrl+alt+space', - 'first': 'home', - 'last': 'end', - 'pin': 'return', - 'unpin': 'backspace', - # Clustering actions. - 'merge': 'g', - 'split': 'k', - 'undo': 'ctrl+z', - 'redo': ('ctrl+shift+z', 'ctrl+y'), - 'move_best_to_noise': 'alt+n', - 'move_best_to_mua': 'alt+m', - 'move_best_to_good': 'alt+g', - 'move_match_to_noise': 'ctrl+n', - 'move_match_to_mua': 'ctrl+m', - 'move_match_to_good': 'ctrl+g', - 'move_both_to_noise': 'ctrl+alt+n', - 'move_both_to_mua': 'ctrl+alt+m', - 'move_both_to_good': 'ctrl+alt+g', - # Views. - 'show_view_shortcuts': 'h', - 'toggle_correlogram_normalization': 'n', - 'toggle_waveforms_overlap': 'o', - 'toggle_waveforms_mean': 'm', - 'show_features_time': 't', -} - - -cluster_manual_config = [ - # The wizard panel is less useful now that there's the stats panel. 
- # ('wizard', {'position': 'right'}), - ('stats', {'position': 'right'}), - ('features_grid', {'position': 'left'}), - ('features', {'position': 'left'}), - ('correlograms', {'position': 'left'}), - ('waveforms', {'position': 'right'}), - ('traces', {'position': 'right'}), -] - - -def _select_clusters(gui, args): - # Range: '5-12' - if '-' in args: - m, M = map(int, args.split('-')) - # The second one should be included. - M += 1 - clusters = list(range(m, M)) - # List of ids: '5 6 9 12' - else: - clusters = list(map(int, args.split(' '))) - gui.select(clusters) - - -cluster_manual_snippets = { - 'c': _select_clusters, -} - - -# Whether to ask the user if they want to save when the GUI is closed. -prompt_save_on_exit = True - - -# ----------------------------------------------------------------------------- -# Internal settings -# ----------------------------------------------------------------------------- - -waveforms_scale_factor = .01 -features_scale_factor = .01 -features_grid_scale_factor = features_scale_factor -traces_scale_factor = .01 diff --git a/phy/cluster/manual/gui.py b/phy/cluster/manual/gui.py deleted file mode 100644 index 385ca4978..000000000 --- a/phy/cluster/manual/gui.py +++ /dev/null @@ -1,619 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import print_function - -"""GUI creator.""" - - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import numpy as np - -import phy -from ...gui.base import BaseGUI -from ...gui.qt import _prompt -from .view_models import (WaveformViewModel, - FeatureGridViewModel, - FeatureViewModel, - CorrelogramViewModel, - TraceViewModel, - StatsViewModel, - ) -from ...utils.logging import debug, info, warn -from ...io.kwik.model import cluster_group_id -from ._history import GlobalHistory -from ._utils import ClusterMetadataUpdater -from .clustering import Clustering -from .wizard import Wizard, WizardViewModel - - -#------------------------------------------------------------------------------ -# Manual clustering window -#------------------------------------------------------------------------------ - -def _check_list_argument(arg, name='clusters'): - if not isinstance(arg, (list, tuple, np.ndarray)): - raise ValueError("The argument should be a list or an array.") - if len(name) == 0: - raise ValueError("No {0} were selected.".format(name)) - - -def _to_wizard_group(group_id): - """Return the group name required by the wizard, as a function - of the Kwik cluster group.""" - if hasattr(group_id, '__len__'): - group_id = group_id[0] - return { - 0: 'ignored', - 1: 'ignored', - 2: 'good', - 3: None, - None: None, - }.get(group_id, 'good') - - -def _process_ups(ups): - """This function processes the UpdateInfo instances of the two - undo stacks (clustering and cluster metadata) and concatenates them - into a single UpdateInfo instance.""" - if len(ups) == 0: - return - elif len(ups) == 1: - return ups[0] - elif len(ups) == 2: - up = ups[0] - up.update(ups[1]) - return up - else: - raise NotImplementedError() - - -class ClusterManualGUI(BaseGUI): - """Manual clustering GUI. 
- - This object represents a main window with: - - * multiple views - * high-level clustering methods - * global keyboard shortcuts - - Events - ------ - - cluster - select - request_save - - """ - - _vm_classes = { - 'waveforms': WaveformViewModel, - 'features': FeatureViewModel, - 'features_grid': FeatureGridViewModel, - 'correlograms': CorrelogramViewModel, - 'traces': TraceViewModel, - 'wizard': WizardViewModel, - 'stats': StatsViewModel, - } - - def __init__(self, model=None, store=None, cluster_ids=None, **kwargs): - self.store = store - self.wizard = Wizard() - self._is_dirty = False - self._cluster_ids = cluster_ids - super(ClusterManualGUI, self).__init__(model=model, - vm_classes=self._vm_classes, - **kwargs) - - def _initialize_views(self): - # The wizard needs to be started *before* the views are created, - # so that the first cluster selection is already set for the views - # when they're created. - self.connect(self._connect_view, event='add_view') - - @self.main_window.connect_ - def on_close_gui(): - # Return False if the close event should be discarded, so that - # the close Qt event is discarded. - return self._prompt_save() - - self.on_open() - self.wizard.start() - if self._cluster_ids is None: - self._cluster_ids = self.wizard.selection - super(ClusterManualGUI, self)._initialize_views() - - # View methods - # --------------------------------------------------------------------- - - @property - def title(self): - """Title of the main window.""" - name = self.__class__.__name__ - filename = getattr(self.model, 'kwik_path', 'mock') - clustering = self.model.clustering - channel_group = self.model.channel_group - template = ("{filename} (shank {channel_group}, " - "{clustering} clustering) " - "- {name} - phy {version}") - return template.format(name=name, - version=phy.__version_git__, - filename=filename, - channel_group=channel_group, - clustering=clustering, - ) - - def _connect_view(self, view): - """Connect a view to the GUI's events (select and cluster).""" - @self.connect - def on_select(cluster_ids, auto_update=True): - view.select(cluster_ids, auto_update=auto_update) - - @self.connect - def on_cluster(up): - view.on_cluster(up) - - def _connect_store(self): - @self.connect - def on_cluster(up=None): - self.store.update_spikes_per_cluster(self.model.spikes_per_cluster) - # No need to delete the old clusters from the store, we can keep - # them for possible undo, and regularly clean up the store. - for item in self.store.items.values(): - item.on_cluster(up) - - def _set_default_view_connections(self): - """Set view connections.""" - - # Select feature dimension from waveform view. - @self.connect_views('waveforms', 'features') - def channel_click(waveforms, features): - - @waveforms.view.connect - def on_channel_click(e): - # The box id is set when the features grid view is to be - # updated. - if e.box_idx is not None: - return - dimension = (e.channel_idx, 0) - features.set_dimension(e.ax, dimension) - features.update() - - # Select feature grid dimension from waveform view. - @self.connect_views('waveforms', 'features_grid') - def channel_click_grid(waveforms, features_grid): - - @waveforms.view.connect - def on_channel_click(e): - # The box id is set when the features grid view is to be - # updated. 
- if e.box_idx is None: - return - if not (1 <= e.box_idx <= features_grid.n_rows - 1): - return - dimension = (e.channel_idx, 0) - box = (e.box_idx, e.box_idx) - features_grid.set_dimension(e.ax, box, dimension) - features_grid.update() - - # Enlarge feature subplot. - @self.connect_views('features_grid', 'features') - def enlarge(grid, features): - - @grid.view.connect - def on_enlarge(e): - features.set_dimension('x', e.x_dim, smart=False) - features.set_dimension('y', e.y_dim, smart=False) - features.update() - - def _view_model_kwargs(self, name): - kwargs = {'model': self.model, - 'store': self.store, - 'wizard': self.wizard, - 'cluster_ids': self._cluster_ids, - } - return kwargs - - # Creation methods - # --------------------------------------------------------------------- - - def _get_clusters(self, which): - # Move best/match/both to noise/mua/good. - return { - 'best': [self.wizard.best], - 'match': [self.wizard.match], - 'both': [self.wizard.best, self.wizard.match], - }[which] - - def _create_actions(self): - for action in ['reset_gui', - 'save', - 'exit', - 'show_shortcuts', - 'select', - # Wizard. - 'reset_wizard', - 'first', - 'last', - 'next', - 'previous', - 'pin', - 'unpin', - # Actions. - 'merge', - 'split', - 'undo', - 'redo', - # Views. - 'toggle_correlogram_normalization', - 'toggle_waveforms_mean', - 'toggle_waveforms_overlap', - 'show_features_time', - ]: - self._add_gui_shortcut(action) - - def _make_func(which, group): - """Return a function that moves best/match/both clusters to - a group.""" - - def func(): - clusters = self._get_clusters(which) - if None in clusters: - return - self.move(clusters, group) - - name = 'move_{}_to_{}'.format(which, group) - func.__name__ = name - setattr(self, name, func) - return name - - for which in ('best', 'match', 'both'): - for group in ('noise', 'mua', 'good'): - self._add_gui_shortcut(_make_func(which, group)) - - def _create_cluster_metadata(self): - self._cluster_metadata_updater = ClusterMetadataUpdater( - self.model.cluster_metadata) - - @self.connect - def on_cluster(up): - for cluster in up.metadata_changed: - group_0 = self._cluster_metadata_updater.group(cluster) - group_1 = self.model.cluster_metadata.group(cluster) - assert group_0 == group_1 - - def _create_clustering(self): - self.clustering = Clustering(self.model.spike_clusters) - - @self.connect - def on_cluster(up): - spc = self.clustering.spikes_per_cluster - self.model.update_spikes_per_cluster(spc) - - def _create_global_history(self): - self._global_history = GlobalHistory(process_ups=_process_ups) - - def _create_wizard(self): - - # Initialize the groups for the wizard. - def _group(cluster): - group_id = self._cluster_metadata_updater.group(cluster) - return _to_wizard_group(group_id) - - groups = {cluster: _group(cluster) - for cluster in self.clustering.cluster_ids} - self.wizard.cluster_groups = groups - - self.wizard.reset() - - # Set the similarity and quality functions for the wizard. - @self.wizard.set_similarity_function - def similarity(target, candidate): - """Compute the distance between the mean masked features.""" - - mu_0 = self.store.mean_features(target).ravel() - mu_1 = self.store.mean_features(candidate).ravel() - - omeg_0 = self.store.mean_masks(target) - omeg_1 = self.store.mean_masks(candidate) - - omeg_0 = np.repeat(omeg_0, self.model.n_features_per_channel) - omeg_1 = np.repeat(omeg_1, self.model.n_features_per_channel) - - d_0 = mu_0 * omeg_0 - d_1 = mu_1 * omeg_1 - - # WARNING: "-" because this is a distance, not a score. 
- return -np.linalg.norm(d_0 - d_1) - - @self.wizard.set_quality_function - def quality(cluster): - """Return the maximum mean_masks across all channels - for a given cluster.""" - return self.store.mean_masks(cluster).max() - - @self.connect - def on_cluster(up): - # HACK: get the current group as it is not available in `up` - # currently. - if up.description.startswith('metadata'): - up = up.copy() - cluster = up.metadata_changed[0] - group = self.model.cluster_metadata.group(cluster) - up.metadata_value = _to_wizard_group(group) - - # This called for both regular and history actions. - # Save the wizard selection and update the wizard. - self.wizard.on_cluster(up) - - # Update the wizard selection after a clustering action. - self._wizard_select_after_clustering(up) - - def _wizard_select_after_clustering(self, up): - # Make as few updates as possible in the views after clustering - # actions. This allows for better before/after comparisons. - if up.added: - self.select(up.added, auto_update=False) - elif up.description == 'metadata_group': - # Select the last selected clusters after undo/redo. - if up.history and up.selection: - self.select(up.selection, auto_update=False) - return - cluster = up.metadata_changed[0] - if cluster == self.wizard.best: - self.wizard.next_best() - elif cluster == self.wizard.match: - self.wizard.next_match() - self._wizard_select() - elif up.selection: - self.select(up.selection, auto_update=False) - - # Open data - # ------------------------------------------------------------------------- - - def on_open(self): - """Reinitialize the GUI after new data has been loaded.""" - self._create_global_history() - # This connects the callback that updates the model spikes_per_cluster. - self._create_clustering() - self._create_cluster_metadata() - # This connects the callback that updates the store. - self._connect_store() - self._create_wizard() - self._is_dirty = False - - @self.connect - def on_cluster(up): - self._is_dirty = True - - def save(self): - """Save the changes.""" - # The session saves the model when this event is emitted. 
- self.emit('request_save') - self._is_dirty = False - - def _prompt_save(self): - """Display a prompt for saving and return whether the GUI should be - closed.""" - if (self.settings.get('prompt_save_on_exit', False) and - self._is_dirty): - res = _prompt(self.main_window, - "Do you want to save your changes?", - ('save', 'cancel', 'close')) - if res == 'save': - self.save() - return True - elif res == 'cancel': - return False - elif res == 'close': - return True - return True - - # General actions - # --------------------------------------------------------------------- - - def start(self): - """Start the wizard.""" - self.wizard.start() - self._cluster_ids = self.wizard.selection - - @property - def cluster_ids(self): - """Array of all cluster ids used in the current clustering.""" - return self.clustering.cluster_ids - - @property - def n_clusters(self): - """Number of clusters in the current clustering.""" - return self.clustering.n_clusters - - # View-related actions - # --------------------------------------------------------------------- - - def toggle_correlogram_normalization(self): - """Toggle CCG normalization in the correlograms views.""" - for vm in self.get_views('correlograms'): - vm.toggle_normalization() - - def toggle_waveforms_mean(self): - """Toggle mean mode in the waveform views.""" - for vm in self.get_views('waveforms'): - vm.show_mean = not(vm.show_mean) - - def toggle_waveforms_overlap(self): - """Toggle cluster overlap in the waveform views.""" - for vm in self.get_views('waveforms'): - vm.overlap = not(vm.overlap) - - def show_features_time(self): - """Set the x dimension to time in all feature views.""" - for vm in self.get_views('features'): - vm.set_dimension('x', 'time') - vm.update() - - # Selection - # --------------------------------------------------------------------- - - def select(self, cluster_ids, **kwargs): - """Select clusters.""" - cluster_ids = list(cluster_ids) - assert len(cluster_ids) == len(set(cluster_ids)) - # Do not re-select an already-selected list of clusters. 
- if cluster_ids == self._cluster_ids: - return - if not set(cluster_ids) <= set(self.clustering.cluster_ids): - n_selected = len(cluster_ids) - cluster_ids = [cl for cl in cluster_ids - if cl in self.clustering.cluster_ids] - n_kept = len(cluster_ids) - warn("{} of the {} selected clusters do not exist.".format( - n_selected - n_kept, n_selected)) - if len(cluster_ids) >= 14: - warn("You cannot select more than 13 clusters in the GUI.") - return - debug("Select clusters {0:s}.".format(str(cluster_ids))) - self._cluster_ids = cluster_ids - self.emit('select', cluster_ids, **kwargs) - - @property - def selected_clusters(self): - """The list of selected clusters.""" - return self._cluster_ids - - # Wizard list - # --------------------------------------------------------------------- - - def _wizard_select(self, **kwargs): - self.select(self.wizard.selection, **kwargs) - - def reset_wizard(self): - """Restart the wizard.""" - self.wizard.start() - self._wizard_select() - - def first(self): - """Go to the first cluster proposed by the wizard.""" - self.wizard.first() - self._wizard_select() - - def last(self): - """Go to the last cluster proposed by the wizard.""" - self.wizard.last() - self._wizard_select() - - def next(self): - """Go to the next cluster proposed by the wizard.""" - self.wizard.next() - self._wizard_select() - - def previous(self): - """Go to the previous cluster proposed by the wizard.""" - self.wizard.previous() - self._wizard_select() - - def pin(self): - """Pin the current best cluster.""" - cluster = (self.selected_clusters[0] - if len(self.selected_clusters) else None) - self.wizard.pin(cluster) - self._wizard_select() - - def unpin(self): - """Unpin the current best cluster.""" - self.wizard.unpin() - self._wizard_select() - - # Cluster actions - # --------------------------------------------------------------------- - - def merge(self, clusters=None): - """Merge some clusters.""" - if clusters is None: - clusters = self.selected_clusters - clusters = list(clusters) - if len(clusters) <= 1: - return - up = self.clustering.merge(clusters) - up.selection = self.selected_clusters - info("Merge clusters {} to {}.".format(str(clusters), - str(up.added[0]))) - self._global_history.action(self.clustering) - self.emit('cluster', up=up) - return up - - def _spikes_to_split(self): - """Find the spikes lasso selected in a feature view for split.""" - for features in self.get_views('features', 'features_grid'): - spikes = features.spikes_in_lasso() - if spikes is not None: - features.lasso.clear() - return spikes - - def split(self, spikes=None): - """Make a new cluster out of some spikes. - - Notes - ----- - - Spikes belonging to affected clusters, but not part of the `spikes` - array, will move to brand new cluster ids. This is because a new - cluster id must be used as soon as a cluster changes. - - """ - if spikes is None: - spikes = self._spikes_to_split() - if spikes is None: - return - if not len(spikes): - info("No spikes to split.") - return - _check_list_argument(spikes, 'spikes') - info("Split {0:d} spikes.".format(len(spikes))) - up = self.clustering.split(spikes) - up.selection = self.selected_clusters - self._global_history.action(self.clustering) - self.emit('cluster', up=up) - return up - - def move(self, clusters, group): - """Move some clusters to a cluster group. 
- - Here is the list of cluster groups: - - * 0=Noise - * 1=MUA - * 2=Good - * 3=Unsorted - - """ - _check_list_argument(clusters) - info("Move clusters {0} to {1}.".format(str(clusters), group)) - group_id = cluster_group_id(group) - up = self._cluster_metadata_updater.set_group(clusters, group_id) - up.selection = self.selected_clusters - self._global_history.action(self._cluster_metadata_updater) - self.emit('cluster', up=up) - return up - - def _undo_redo(self, up): - if up: - info("{} {}.".format(up.history.title(), - up.description, - )) - self.emit('cluster', up=up) - - def undo(self): - """Undo the last clustering action.""" - up = self._global_history.undo() - self._undo_redo(up) - return up - - def redo(self): - """Redo the last undone action.""" - # debug("The saved selection before the undo is {}.".format(clusters)) - up = self._global_history.redo() - if up: - up.selection = self.selected_clusters - self._undo_redo(up) - return up diff --git a/phy/cluster/manual/gui_component.py b/phy/cluster/manual/gui_component.py new file mode 100644 index 000000000..f9fdee400 --- /dev/null +++ b/phy/cluster/manual/gui_component.py @@ -0,0 +1,563 @@ +# -*- coding: utf-8 -*- + +"""Manual clustering GUI component.""" + + +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- + +from collections import OrderedDict +from functools import partial +import logging + +import numpy as np + +from ._history import GlobalHistory +from ._utils import create_cluster_meta +from .clustering import Clustering +from phy.gui.actions import Actions +from phy.gui.widgets import Table + +logger = logging.getLogger(__name__) + + +# ----------------------------------------------------------------------------- +# Utility functions +# ----------------------------------------------------------------------------- + +def _process_ups(ups): # pragma: no cover + """This function processes the UpdateInfo instances of the two + undo stacks (clustering and cluster metadata) and concatenates them + into a single UpdateInfo instance.""" + if len(ups) == 0: + return + elif len(ups) == 1: + return ups[0] + elif len(ups) == 2: + up = ups[0] + up.update(ups[1]) + return up + else: + raise NotImplementedError() + + +# ----------------------------------------------------------------------------- +# Clustering GUI component +# ----------------------------------------------------------------------------- + +class ClusterView(Table): + def __init__(self): + super(ClusterView, self).__init__() + self.add_styles(''' + table tr[data-good='true'] { + color: #86D16D; + } + ''') + + @property + def state(self): + return {'sort_by': self.current_sort} + + def set_state(self, state): + sort_by = state.get('sort_by', None) + if sort_by: + self.sort_by(*sort_by) + + +class ManualClustering(object): + """Component that brings manual clustering facilities to a GUI: + + * Clustering instance: merge, split, undo, redo + * ClusterMeta instance: change cluster metadata (e.g. group) + * Selection + * Many manual clustering-related actions, snippets, shortcuts, etc. 
+ + Parameters + ---------- + + spike_clusters : ndarray + spikes_per_clusters : function `cluster_id -> spike_ids` + cluster_groups : dictionary + shortcuts : dict + quality: func + similarity: func + + GUI events + ---------- + + When this component is attached to a GUI, the GUI emits the following + events: + + select(cluster_ids) + when clusters are selected + cluster(up) + when a merge or split happens + request_save(spike_clusters, cluster_groups) + when a save is requested by the user + + """ + + default_shortcuts = { + # Clustering. + 'merge': 'g', + 'split': 'k', + + # Move. + 'move_best_to_noise': 'alt+n', + 'move_best_to_mua': 'alt+m', + 'move_best_to_good': 'alt+g', + + 'move_similar_to_noise': 'ctrl+n', + 'move_similar_to_mua': 'ctrl+m', + 'move_similar_to_good': 'ctrl+g', + + 'move_all_to_noise': 'ctrl+alt+n', + 'move_all_to_mua': 'ctrl+alt+m', + 'move_all_to_good': 'ctrl+alt+g', + + # Wizard. + 'reset': 'ctrl+alt+space', + 'next': 'space', + 'previous': 'shift+space', + 'next_best': 'down', + 'previous_best': 'up', + + # Misc. + 'save': 'Save', + 'show_shortcuts': 'Save', + 'undo': 'Undo', + 'redo': ('ctrl+shift+z', 'ctrl+y'), + } + + def __init__(self, + spike_clusters, + spikes_per_cluster, + cluster_groups=None, + shortcuts=None, + quality=None, + similarity=None, + new_cluster_id=None, + ): + + self.gui = None + self.quality = quality # function cluster => quality + self.similarity = similarity # function cluster => [(cl, sim), ...] + + assert hasattr(spikes_per_cluster, '__call__') + self.spikes_per_cluster = spikes_per_cluster + + # Load default shortcuts, and override with any user shortcuts. + self.shortcuts = self.default_shortcuts.copy() + self.shortcuts.update(shortcuts or {}) + + # Create Clustering and ClusterMeta. + self.clustering = Clustering(spike_clusters, + new_cluster_id=new_cluster_id) + self.cluster_meta = create_cluster_meta(cluster_groups) + self._global_history = GlobalHistory(process_ups=_process_ups) + self._register_logging() + + # Create the cluster views. + self._create_cluster_views() + self._add_default_columns() + + self._best = None + self._current_similarity_values = {} + + # Internal methods + # ------------------------------------------------------------------------- + + def _register_logging(self): + # Log the actions. + @self.clustering.connect + def on_cluster(up): + if up.history: + logger.info(up.history.title() + " cluster assign.") + elif up.description == 'merge': + logger.info("Merge clusters %s to %s.", + ', '.join(map(str, up.deleted)), + up.added[0]) + else: + logger.info("Assigned %s spikes.", len(up.spike_ids)) + + if self.gui: + self.gui.emit('cluster', up) + + @self.cluster_meta.connect # noqa + def on_cluster(up): + if up.history: + logger.info(up.history.title() + " move.") + else: + logger.info("Move clusters %s to %s.", + ', '.join(map(str, up.metadata_changed)), + up.metadata_value) + + if self.gui: + self.gui.emit('cluster', up) + + def _add_default_columns(self): + # Default columns. 
+ @self.add_column(name='n_spikes') + def n_spikes(cluster_id): + return len(self.spikes_per_cluster(cluster_id)) + + @self.add_column(show=False) + def skip(cluster_id): + """Whether to skip that cluster.""" + return (self.cluster_meta.get('group', cluster_id) + in ('noise', 'mua')) + + @self.add_column(show=False) + def good(cluster_id): + """Good column for color.""" + return self.cluster_meta.get('group', cluster_id) == 'good' + + def similarity(cluster_id): + # NOTE: there is a dictionary with the similarity to the current + # best cluster. It is updated when the selection changes in the + # cluster view. This is a bit of a hack: the HTML table expects + # a function that returns a value for every row, but here we + # cache all similarity view rows in self._current_similarity_values + return self._current_similarity_values.get(cluster_id, 0) + if self.similarity: + self.similarity_view.add_column(similarity, + name=self.similarity.__name__) + + def _create_actions(self, gui): + self.actions = Actions(gui, + name='Clustering', + menu='&Clustering', + default_shortcuts=self.shortcuts) + + # Selection. + self.actions.add(self.select, alias='c') + self.actions.separator() + + # Clustering. + self.actions.add(self.merge, alias='g') + self.actions.add(self.split, alias='k') + self.actions.separator() + + # Move. + self.actions.add(self.move) + + for group in ('noise', 'mua', 'good'): + self.actions.add(partial(self.move_best, group), + name='move_best_to_' + group, + docstring='Move the best clusters to %s.' % group) + self.actions.add(partial(self.move_similar, group), + name='move_similar_to_' + group, + docstring='Move the similar clusters to %s.' % + group) + self.actions.add(partial(self.move_all, group), + name='move_all_to_' + group, + docstring='Move all selected clusters to %s.' % + group) + self.actions.separator() + + # Others. + self.actions.add(self.undo) + self.actions.add(self.redo) + self.actions.add(self.save) + + # Wizard. + self.actions.add(self.reset, menu='&Wizard') + self.actions.add(self.next, menu='&Wizard') + self.actions.add(self.previous, menu='&Wizard') + self.actions.add(self.next_best, menu='&Wizard') + self.actions.add(self.previous_best, menu='&Wizard') + self.actions.separator() + + def _create_cluster_views(self): + # Create the cluster view. + self.cluster_view = ClusterView() + self.cluster_view.build() + + # Create the similarity view. + self.similarity_view = ClusterView() + self.similarity_view.build() + + # Selection in the cluster view. + @self.cluster_view.connect_ + def on_select(cluster_ids): + # Emit GUI.select when the selection changes in the cluster view. + self._emit_select(cluster_ids) + # Pin the clusters and update the similarity view. + self._update_similarity_view() + + # Selection in the similarity view. + @self.similarity_view.connect_ # noqa + def on_select(cluster_ids): + # Select the clusters from both views. + cluster_ids = self.cluster_view.selected + cluster_ids + self._emit_select(cluster_ids) + + # Save the current selection when an action occurs. 
+ def on_request_undo_state(up): + return {'selection': (self.cluster_view.selected, + self.similarity_view.selected)} + + self.clustering.connect(on_request_undo_state) + self.cluster_meta.connect(on_request_undo_state) + + self._update_cluster_view() + + def _update_cluster_view(self): + """Initialize the cluster view with cluster data.""" + logger.log(5, "Update the cluster view.") + cluster_ids = [int(c) for c in self.clustering.cluster_ids] + self.cluster_view.set_rows(cluster_ids) + + def _update_similarity_view(self): + """Update the similarity view with matches for the specified + clusters.""" + if not self.similarity: + return + selection = self.cluster_view.selected + if not len(selection): + return + cluster_id = selection[0] + cluster_ids = self.clustering.cluster_ids + self._best = cluster_id + logger.log(5, "Update the similarity view.") + # This is a list of pairs (closest_cluster, similarity). + similarities = self.similarity(cluster_id) + # We save the similarity values wrt the currently-selected clusters. + # Note that we keep the order of the output of the self.similary() + # function. + clusters_sim = OrderedDict([(int(cl), s) for (cl, s) in similarities]) + # List of similar clusters, remove non-existing ones. + clusters = [c for c in clusters_sim.keys() + if c in cluster_ids] + # The similarity view will use these values. + self._current_similarity_values = clusters_sim + # Set the rows of the similarity view. + # TODO: instead of the self._current_similarity_values hack, + # give the possibility to specify the values here (?). + self.similarity_view.set_rows([c for c in clusters + if c not in selection]) + + def _emit_select(self, cluster_ids): + """Choose spikes from the specified clusters and emit the + `select` event on the GUI.""" + logger.debug("Select clusters: %s.", ', '.join(map(str, cluster_ids))) + if self.gui: + self.gui.emit('select', cluster_ids) + + # Public methods + # ------------------------------------------------------------------------- + + def add_column(self, func=None, name=None, show=True, default=False): + if func is None: + return lambda f: self.add_column(f, name=name, show=show, + default=default) + name = name or func.__name__ + assert name + self.cluster_view.add_column(func, name=name, show=show) + self.similarity_view.add_column(func, name=name, show=show) + if default: + self.set_default_sort(name) + + def set_default_sort(self, name, sort_dir='desc'): + assert name + logger.debug("Set default sort `%s` %s.", name, sort_dir) + # Set the default sort. + self.cluster_view.set_default_sort(name, sort_dir) + # Reset the cluster view. + self._update_cluster_view() + # Sort by the default sort. + self.cluster_view.sort_by(name, sort_dir) + + def on_cluster(self, up): + """Update the cluster views after clustering actions.""" + + similar = self.similarity_view.selected + + # Reinitialize the cluster view if clusters have changed. + if up.added: + self._update_cluster_view() + + # Select all new clusters in view 1. + if up.history == 'undo': + # Select the clusters that were selected before the undone + # action. + clusters_0, clusters_1 = up.undo_state[0]['selection'] + self.cluster_view.select(clusters_0) + self.similarity_view.select(clusters_1) + elif up.added: + if up.description == 'assign': + # NOTE: we reverse the order such that the last selected + # cluster (with a new color) is the split cluster. 
+ added = up.added[::-1] + else: + added = up.added + self.select(added) + if similar: + self.similarity_view.next() + elif up.metadata_changed: + # Select next in similarity view if all moved are in that view. + if set(up.metadata_changed) <= set(similar): + + # Update the cluster view, and select the clusters that + # were selected before the action. + selected = self.similarity_view.selected + self._update_similarity_view() + self.similarity_view.select(selected, do_emit=False) + self.similarity_view.next() + # Otherwise, select next in cluster view. + else: + # Update the cluster view, and select the clusters that + # were selected before the action. + selected = self.cluster_view.selected + self._update_cluster_view() + self.cluster_view.select(selected, do_emit=False) + self.cluster_view.next() + if similar: + self.similarity_view.next() + + def attach(self, gui): + self.gui = gui + + # Create the actions. + self._create_actions(gui) + + # Add the cluster views. + gui.add_view(self.cluster_view, name='ClusterView') + + # Add the quality column in the cluster view. + if self.quality: + self.cluster_view.add_column(self.quality, + name=self.quality.__name__, + ) + + # Update the cluster view and sort by n_spikes at the beginning. + self._update_cluster_view() + # if not self.quality: + # self.cluster_view.sort_by('n_spikes', 'desc') + + # Add the similarity view if there is a similarity function. + if self.similarity: + gui.add_view(self.similarity_view, name='SimilarityView') + + # Set the view state. + cv = self.cluster_view + cv.set_state(gui.state.get_view_state(cv)) + + # Save the view state in the GUI state. + @gui.connect_ + def on_close(): + gui.state.update_view_state(cv, cv.state) + # NOTE: create_gui() already saves the state, but the event + # is registered *before* we add all views. + gui.state.save() + + # Update the cluster views and selection when a cluster event occurs. + self.gui.connect_(self.on_cluster) + return self + + # Selection actions + # ------------------------------------------------------------------------- + + def select(self, *cluster_ids): + """Select a list of clusters.""" + # HACK: allow for `select(1, 2, 3)` in addition to `select([1, 2, 3])` + # This makes it more convenient to select multiple clusters with + # the snippet: `:c 1 2 3` instead of `:c 1,2,3`. + if cluster_ids and isinstance(cluster_ids[0], (tuple, list)): + cluster_ids = list(cluster_ids[0]) + list(cluster_ids[1:]) + # Update the cluster view selection. 
+ self.cluster_view.select(cluster_ids) + + @property + def selected(self): + return self.cluster_view.selected + self.similarity_view.selected + + # Clustering actions + # ------------------------------------------------------------------------- + + def merge(self, cluster_ids=None): + """Merge the selected clusters.""" + if cluster_ids is None: + cluster_ids = self.selected + if len(cluster_ids or []) <= 1: + return + self.clustering.merge(cluster_ids) + self._global_history.action(self.clustering) + + def split(self, spike_ids=None): + """Split the selected spikes.""" + if spike_ids is None: + spike_ids = self.gui.emit('request_split') + spike_ids = np.concatenate(spike_ids).astype(np.int64) + if len(spike_ids) == 0: + return + self.clustering.split(spike_ids) + self._global_history.action(self.clustering) + + # Move actions + # ------------------------------------------------------------------------- + + def move(self, cluster_ids, group): + """Move clusters to a group.""" + if len(cluster_ids) == 0: + return + self.cluster_meta.set('group', cluster_ids, group) + self._global_history.action(self.cluster_meta) + + def move_best(self, group): + """Move all selected best clusters to a group.""" + self.move(self.cluster_view.selected, group) + + def move_similar(self, group): + """Move all selected similar clusters to a group.""" + self.move(self.similarity_view.selected, group) + + def move_all(self, group): + """Move all selected clusters to a group.""" + self.move(self.selected, group) + + # Wizard actions + # ------------------------------------------------------------------------- + + def reset(self): + """Reset the wizard.""" + self._update_cluster_view() + self.cluster_view.next() + + def next_best(self): + """Select the next best cluster.""" + self.cluster_view.next() + + def previous_best(self): + """Select the previous best cluster.""" + self.cluster_view.previous() + + def next(self): + """Select the next cluster.""" + if not self.selected: + self.cluster_view.next() + else: + self.similarity_view.next() + + def previous(self): + """Select the previous cluster.""" + self.similarity_view.previous() + + # Other actions + # ------------------------------------------------------------------------- + + def undo(self): + """Undo the last action.""" + self._global_history.undo() + + def redo(self): + """Undo the last undone action.""" + self._global_history.redo() + + def save(self): + """Save the manual clustering back to disk.""" + spike_clusters = self.clustering.spike_clusters + groups = {c: self.cluster_meta.get('group', c) or 'unsorted' + for c in self.clustering.cluster_ids} + self.gui.emit('request_save', spike_clusters, groups) diff --git a/phy/cluster/manual/static/__init__.py b/phy/cluster/manual/static/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/phy/cluster/manual/static/styles.css b/phy/cluster/manual/static/styles.css deleted file mode 100644 index 66ae9081d..000000000 --- a/phy/cluster/manual/static/styles.css +++ /dev/null @@ -1,56 +0,0 @@ -/* Wizard panel */ - -.control-panel { - margin: 0; - font-weight: bold; - font-size: 24pt; - padding: 10px; - text-align: center -} - -.control-panel > div { - display: inline-block; - margin: 0 auto; -} - -.control-panel .best { - margin-right: 20px; -} - -.control-panel .match { -} - -.control-panel > div .id { - margin: 10px 0 20px 0; - height: 40px; - text-align: center; - vertical-align: middle; - padding: 5px 0 10px 0; -} - -/* Progress bar */ -.control-panel progress[value] { - width: 200px; 
-} - -/* Cluster group */ -.control-panel .unsorted { -} - -.control-panel .good { - background-color: #243318; -} - -.control-panel .ignored { - background-color: #331d12; -} - -/* Stats panel */ - -.stats td, .stats th { - padding: 0 20px 0 0; -} - -.stats td { - text-align: right; -} diff --git a/phy/cluster/manual/static/wizard.html b/phy/cluster/manual/static/wizard.html deleted file mode 100644 index 6a38170d7..000000000 --- a/phy/cluster/manual/static/wizard.html +++ /dev/null @@ -1,14 +0,0 @@ -
[deleted wizard.html body — the HTML tags were lost during extraction; the fragment defined the control panel with the {best} and {match} cluster-ID placeholders]
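Before the new test fixtures below, here is a minimal usage sketch of the ManualClustering component added in gui_component.py above. It is illustrative only and not part of the patch: the toy quality/similarity callbacks and the GUI arguments simply mirror the test fixtures further down, config_dir is omitted, and a running Qt application is assumed (the test suite obtains one through the qtbot fixture).

import numpy as np

from phy.gui import GUI
from phy.cluster.manual.gui_component import ManualClustering

# Toy data: one spike per cluster, as in the test fixture below.
spike_clusters = np.array([0, 1, 2, 10, 11, 20, 30])
spikes_per_cluster = lambda c: np.nonzero(spike_clusters == c)[0]

mc = ManualClustering(spike_clusters,
                      spikes_per_cluster,
                      cluster_groups={0: 'noise', 1: 'good', 10: 'mua'},
                      quality=lambda c: c,  # cluster_id -> scalar quality
                      # cluster_id -> list of (other_cluster, similarity) pairs
                      similarity=lambda c: [(d, 1.) for d in (20, 30) if d != c],
                      )

gui = GUI(position=(200, 100), size=(500, 500))
mc.attach(gui)       # adds the cluster/similarity views, actions, and shortcuts
mc.select([0, 1])    # makes the GUI emit the `select` event
gui.show()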
diff --git a/phy/cluster/manual/tests/conftest.py b/phy/cluster/manual/tests/conftest.py new file mode 100644 index 000000000..01eca9764 --- /dev/null +++ b/phy/cluster/manual/tests/conftest.py @@ -0,0 +1,91 @@ +# -*- coding: utf-8 -*- + +"""Test fixtures.""" + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +from pytest import fixture + +import numpy as np + +from phy.cluster.manual.controller import Controller +from phy.electrode.mea import staggered_positions +from phy.io.array import (get_closest_clusters, + _spikes_in_clusters, + ) +from phy.io.mock import (artificial_waveforms, + artificial_features, + artificial_masks, + artificial_traces, + ) + + +#------------------------------------------------------------------------------ +# Fixtures +#------------------------------------------------------------------------------ + +@fixture +def cluster_ids(): + return [0, 1, 2, 10, 11, 20, 30] + # i, g, N, i, g, N, N + + +@fixture +def cluster_groups(): + return {0: 'noise', 1: 'good', 10: 'mua', 11: 'good'} + + +@fixture +def quality(): + def quality(c): + return c + return quality + + +@fixture +def similarity(cluster_ids): + sim = lambda c, d: (c * 1.01 + d) + + def similarity(c): + return get_closest_clusters(c, cluster_ids, sim) + return similarity + + +class MockController(Controller): + def _init_data(self): + self.cache_dir = self.config_dir + self.n_samples_waveforms = 31 + self.n_samples_t = 20000 + self.n_channels = 11 + self.n_clusters = 4 + self.n_spikes_per_cluster = 50 + n_spikes_total = self.n_clusters * self.n_spikes_per_cluster + n_features_per_channel = 4 + + self.n_channels = self.n_channels + self.n_spikes = n_spikes_total + self.sample_rate = 20000. + self.duration = self.n_samples_t / float(self.sample_rate) + self.spike_times = np.arange(0, self.duration, 100. 
/ self.sample_rate) + self.spike_clusters = np.repeat(np.arange(self.n_clusters), + self.n_spikes_per_cluster) + assert len(self.spike_times) == len(self.spike_clusters) + self.cluster_ids = np.unique(self.spike_clusters) + self.channel_positions = staggered_positions(self.n_channels) + + sc = self.spike_clusters + self.spikes_per_cluster = lambda c: _spikes_in_clusters(sc, [c]) + self.spike_count = lambda c: len(self.spikes_per_cluster(c)) + self.n_features_per_channel = n_features_per_channel + self.cluster_groups = {c: None for c in range(self.n_clusters)} + + self.all_traces = artificial_traces(self.n_samples_t, self.n_channels) + self.all_masks = artificial_masks(n_spikes_total, self.n_channels) + self.all_waveforms = artificial_waveforms(n_spikes_total, + self.n_samples_waveforms, + self.n_channels) + self.all_features = artificial_features(n_spikes_total, + self.n_channels, + self.n_features_per_channel) diff --git a/phy/cluster/manual/tests/test_clustering.py b/phy/cluster/manual/tests/test_clustering.py index 59a9b34c2..86124dc67 100644 --- a/phy/cluster/manual/tests/test_clustering.py +++ b/phy/cluster/manual/tests/test_clustering.py @@ -9,20 +9,17 @@ import numpy as np from numpy.testing import assert_array_equal as ae from pytest import raises -from six import itervalues -from ....io.mock import artificial_spike_clusters -from ....utils.array import (_spikes_in_clusters, - _flatten_spikes_per_cluster, - ) +from phy.io.mock import artificial_spike_clusters +from phy.io.array import (_spikes_in_clusters,) from ..clustering import (_extend_spikes, _concatenate_spike_clusters, - _extend_assignement, + _extend_assignment, Clustering) #------------------------------------------------------------------------------ -# Test assignements +# Test assignments #------------------------------------------------------------------------------ def test_extend_spikes_simple(): @@ -72,7 +69,7 @@ def test_concatenate_spike_clusters(): ae(clusters, np.arange(0, 60 + 1, 10)) -def test_extend_assignement(): +def test_extend_assignment(): spike_clusters = np.array([3, 5, 2, 9, 5, 5, 2]) spike_ids = np.array([0, 2]) @@ -84,17 +81,21 @@ def test_extend_assignement(): # This should not depend on the index chosen. for to in (123, 0, 1, 2, 3): clusters_rel = [123] * len(spike_ids) - new_spike_ids, new_cluster_ids = _extend_assignement(spike_ids, - spike_clusters, - clusters_rel) + new_spike_ids, new_cluster_ids = _extend_assignment(spike_ids, + spike_clusters, + clusters_rel, + 10, + ) ae(new_spike_ids, [0, 2, 6]) ae(new_cluster_ids, [10, 10, 11]) # Second case: we assign the spikes to different clusters. clusters_rel = [0, 1] - new_spike_ids, new_cluster_ids = _extend_assignement(spike_ids, - spike_clusters, - clusters_rel) + new_spike_ids, new_cluster_ids = _extend_assignment(spike_ids, + spike_clusters, + clusters_rel, + 10, + ) ae(new_spike_ids, [0, 2, 6]) ae(new_cluster_ids, [10, 11, 12]) @@ -103,17 +104,15 @@ def test_extend_assignement(): # Test clustering #------------------------------------------------------------------------------ -def _check_spikes_per_cluster(clustering): - ae(_flatten_spikes_per_cluster(clustering.spikes_per_cluster), - clustering.spike_clusters) - - def test_clustering_split(): spike_clusters = np.array([2, 5, 3, 2, 7, 5, 2]) # Instantiate a Clustering instance. 
clustering = Clustering(spike_clusters) ae(clustering.spike_clusters, spike_clusters) + n_spikes = len(spike_clusters) + assert clustering.n_spikes == n_spikes + ae(clustering.spike_ids, np.arange(n_spikes)) splits = [[0], [1], @@ -136,13 +135,11 @@ def test_clustering_split(): for to_split in splits: clustering.reset() clustering.split(to_split) - _check_spikes_per_cluster(clustering) # Test many splits, without reset this time. clustering.reset() for to_split in splits: clustering.split(to_split) - _check_spikes_per_cluster(clustering) def test_clustering_descendants_merge(): @@ -159,7 +156,6 @@ def test_clustering_descendants_merge(): new = up.added[0] assert new == 8 assert up.descendants == [(2, 8), (3, 8)] - _check_spikes_per_cluster(clustering) with raises(ValueError): up = clustering.merge([2, 8]) @@ -168,7 +164,6 @@ def test_clustering_descendants_merge(): new = up.added[0] assert new == 9 assert up.descendants == [(5, 9), (8, 9)] - _check_spikes_per_cluster(clustering) def test_clustering_descendants_split(): @@ -188,7 +183,6 @@ def test_clustering_descendants_split(): assert up.added == [8, 9] assert up.descendants == [(2, 8), (2, 9)] ae(clustering.spike_clusters, [8, 5, 3, 9, 7, 5, 9]) - _check_spikes_per_cluster(clustering) # Undo. up = clustering.undo() @@ -196,7 +190,6 @@ def test_clustering_descendants_split(): assert up.added == [2] assert set(up.descendants) == set([(8, 2), (9, 2)]) ae(clustering.spike_clusters, spike_clusters) - _check_spikes_per_cluster(clustering) # Redo. up = clustering.redo() @@ -204,7 +197,6 @@ def test_clustering_descendants_split(): assert up.added == [8, 9] assert up.descendants == [(2, 8), (2, 9)] ae(clustering.spike_clusters, [8, 5, 3, 9, 7, 5, 9]) - _check_spikes_per_cluster(clustering) # Second split: just replace cluster 8 by 10 (1 spike in it). up = clustering.split([0]) @@ -212,7 +204,6 @@ def test_clustering_descendants_split(): assert up.added == [10] assert up.descendants == [(8, 10)] ae(clustering.spike_clusters, [10, 5, 3, 9, 7, 5, 9]) - _check_spikes_per_cluster(clustering) # Undo again. up = clustering.undo() @@ -220,7 +211,6 @@ def test_clustering_descendants_split(): assert up.added == [8] assert up.descendants == [(10, 8)] ae(clustering.spike_clusters, [8, 5, 3, 9, 7, 5, 9]) - _check_spikes_per_cluster(clustering) def test_clustering_merge(): @@ -242,6 +232,10 @@ def _assert_is_checkpoint(index): def _assert_spikes(clusters): ae(info.spike_ids, _spikes_in_clusters(spike_clusters, clusters)) + @clustering.connect + def on_request_undo_state(up): + return 'hello' + # Checkpoint 0. _checkpoint() _assert_is_checkpoint(0) @@ -261,6 +255,7 @@ def _assert_spikes(clusters): assert info.added == [12] assert info.deleted == [2, 3] assert info.history is None + assert info.undo_state is None # undo_state is only returned when undoing. _assert_is_checkpoint(2) # Undo once. @@ -268,6 +263,7 @@ def _assert_spikes(clusters): assert info.added == [2, 3] assert info.deleted == [12] assert info.history == 'undo' + assert info.undo_state == ['hello'] _assert_is_checkpoint(1) # Redo. @@ -276,6 +272,7 @@ def _assert_spikes(clusters): assert info.added == [12] assert info.deleted == [2, 3] assert info.history == 'redo' + assert info.undo_state is None _assert_is_checkpoint(2) # No redo. @@ -308,7 +305,9 @@ def _assert_spikes(clusters): _assert_is_checkpoint(3) # We merge again. - assert clustering.new_cluster_id() == 14 + # NOTE: 14 has been wasted, move to 15: necessary to avoid explicit cache + # invalidation when caching clusterid-based functions. 
+ assert clustering.new_cluster_id() == 15 assert any(clustering.spike_clusters == 13) assert all(clustering.spike_clusters != 14) info = clustering.merge([8, 7], 15) @@ -351,8 +350,9 @@ def _checkpoint(index=None): def _assert_is_checkpoint(index): ae(clustering.spike_clusters, checkpoints[index]) - def _assert_spikes(spikes): - ae(info.spike_ids, spikes) + @clustering.connect + def on_request_undo_state(up): + return 'hello' # Checkpoint 0. _checkpoint() @@ -363,33 +363,45 @@ def _assert_spikes(spikes): my_spikes_3 = np.unique(np.random.randint(low=0, high=n_spikes, size=1000)) my_spikes_4 = np.arange(n_spikes - 5) + # Edge cases. + clustering.assign([]) + with raises(ValueError): + clustering.merge([], 1) + # Checkpoint 1. info = clustering.split(my_spikes_1) _checkpoint() + assert info.description == 'assign' assert 10 in info.added assert info.history is None _assert_is_checkpoint(1) # Checkpoint 2. info = clustering.split(my_spikes_2) + assert info.description == 'assign' assert info.history is None _checkpoint() _assert_is_checkpoint(2) # Checkpoint 3. info = clustering.assign(my_spikes_3) + assert info.description == 'assign' assert info.history is None + assert info.undo_state is None _checkpoint() _assert_is_checkpoint(3) # Undo checkpoint 3. info = clustering.undo() + assert info.description == 'assign' assert info.history == 'undo' + assert info.undo_state == ['hello'] _checkpoint() _assert_is_checkpoint(2) # Checkpoint 4. info = clustering.assign(my_spikes_4) + assert info.description == 'assign' assert info.history is None _checkpoint(4) assert len(info.deleted) >= 2 @@ -415,67 +427,53 @@ def test_clustering_long(): assert clustering.new_cluster_id() == n_clusters assert clustering.n_clusters == n_clusters - assert len(clustering.cluster_counts) == n_clusters - assert sum(itervalues(clustering.cluster_counts)) == n_spikes - _check_spikes_per_cluster(clustering) - # Updating a cluster, method 1. spike_clusters_new = spike_clusters.copy() spike_clusters_new[:10] = 100 clustering.spike_clusters[:] = spike_clusters_new[:] # Need to update explicitely. - clustering._update_all_spikes_per_cluster() + clustering._new_cluster_id = 101 ae(clustering.cluster_ids, np.r_[np.arange(n_clusters), 100]) # Updating a cluster, method 2. clustering.spike_clusters[:] = spike_clusters_base[:] clustering.spike_clusters[:10] = 100 - # Need to update manually. - clustering._update_all_spikes_per_cluster() + # HACK: need to update manually here. + clustering._new_cluster_id = 101 ae(clustering.cluster_ids, np.r_[np.arange(n_clusters), 100]) # Assign. new_cluster = 101 clustering.assign(np.arange(0, 10), new_cluster) assert new_cluster in clustering.cluster_ids - assert clustering.cluster_counts[new_cluster] == 10 assert np.all(clustering.spike_clusters[:10] == new_cluster) - _check_spikes_per_cluster(clustering) # Merge. - count = clustering.cluster_counts.copy() my_spikes_0 = np.nonzero(np.in1d(clustering.spike_clusters, [2, 3]))[0] info = clustering.merge([2, 3]) my_spikes = info.spike_ids ae(my_spikes, my_spikes_0) assert (new_cluster + 1) in clustering.cluster_ids - assert clustering.cluster_counts[new_cluster + 1] == count[2] + count[3] assert np.all(clustering.spike_clusters[my_spikes] == (new_cluster + 1)) - _check_spikes_per_cluster(clustering) # Merge to a given cluster. 
clustering.spike_clusters[:] = spike_clusters_base[:] - clustering._update_all_spikes_per_cluster() + clustering._new_cluster_id = 11 + my_spikes_0 = np.nonzero(np.in1d(clustering.spike_clusters, [4, 6]))[0] - count = clustering.cluster_counts - count4, count6 = count[4], count[6] info = clustering.merge([4, 6], 11) my_spikes = info.spike_ids ae(my_spikes, my_spikes_0) assert 11 in clustering.cluster_ids - assert clustering.cluster_counts[11] == count4 + count6 assert np.all(clustering.spike_clusters[my_spikes] == 11) - _check_spikes_per_cluster(clustering) # Split. my_spikes = [1, 3, 5] clustering.split(my_spikes) assert np.all(clustering.spike_clusters[my_spikes] == 12) - _check_spikes_per_cluster(clustering) # Assign. clusters = [0, 1, 2] clustering.assign(my_spikes, clusters) clu = clustering.spike_clusters[my_spikes] ae(clu - clu[0], clusters) - _check_spikes_per_cluster(clustering) diff --git a/phy/cluster/manual/tests/test_controller.py b/phy/cluster/manual/tests/test_controller.py new file mode 100644 index 000000000..888b59121 --- /dev/null +++ b/phy/cluster/manual/tests/test_controller.py @@ -0,0 +1,45 @@ +# -*- coding: utf-8 -*- + +"""Test controller.""" + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +import os.path as op +from textwrap import dedent + +from .conftest import MockController + + +#------------------------------------------------------------------------------ +# Test controller +#------------------------------------------------------------------------------ + +def test_controller_1(qtbot, tempdir): + + plugin = dedent(''' + from phy import IPlugin + + class MockControllerPlugin(IPlugin): + def attach_to_controller(self, controller): + controller.hello = 'world' + + c = get_config() + c.MockController.plugins = ['MockControllerPlugin'] + + ''') + with open(op.join(tempdir, 'phy_config.py'), 'w') as f: + f.write(plugin) + + controller = MockController(config_dir=tempdir) + gui = controller.create_gui() + gui.show() + + # Ensure that the plugin has been loaded. + assert controller.hello == 'world' + + controller.manual_clustering.select([2, 3]) + + # qtbot.stop() + gui.close() diff --git a/phy/cluster/manual/tests/test_gui.py b/phy/cluster/manual/tests/test_gui.py deleted file mode 100644 index 31e5d6abe..000000000 --- a/phy/cluster/manual/tests/test_gui.py +++ /dev/null @@ -1,227 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Tests of manual clustering GUI.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -from pytest import mark -import numpy as np -from numpy.testing import assert_allclose as ac -from numpy.testing import assert_array_equal as ae - -from ..gui import ClusterManualGUI -from ....utils.settings import _load_default_settings -from ....utils.logging import set_level -from ....utils.array import _spikes_in_clusters -from ....io.mock import MockModel -from ....io.kwik.store_items import create_store - - -# Skip these tests in "make test-quick". 
-pytestmark = mark.long() - - -#------------------------------------------------------------------------------ -# Kwik tests -#------------------------------------------------------------------------------ - -def setup(): - set_level('debug') - - -def _start_manual_clustering(config='none', shortcuts=None): - if config is 'none': - config = [] - model = MockModel() - spc = model.spikes_per_cluster - store = create_store(model, spikes_per_cluster=spc) - gui = ClusterManualGUI(model=model, store=store, - config=config, shortcuts=shortcuts) - return gui - - -def test_gui_clustering(qtbot): - - gui = _start_manual_clustering() - gui.show() - qtbot.addWidget(gui.main_window) - - cs = gui.store - spike_clusters = gui.model.spike_clusters.copy() - - f = gui.model.features - m = gui.model.masks - - def _check_arrays(cluster, clusters_for_sc=None, spikes=None): - """Check the features and masks in the cluster store - of a given custer.""" - if spikes is None: - if clusters_for_sc is None: - clusters_for_sc = [cluster] - spikes = _spikes_in_clusters(spike_clusters, clusters_for_sc) - shape = (len(spikes), - len(gui.model.channel_order), - gui.model.n_features_per_channel) - ac(cs.features(cluster), f[spikes, :].reshape(shape)) - ac(cs.masks(cluster), m[spikes]) - - _check_arrays(0) - _check_arrays(2) - - # Merge two clusters. - clusters = [0, 2] - up = gui.merge(clusters) - new = up.added[0] - _check_arrays(new, clusters) - - # Split some spikes. - spikes = [2, 3, 5, 7, 11, 13] - # clusters = np.unique(spike_clusters[spikes]) - up = gui.split(spikes) - _check_arrays(new + 1, spikes=spikes) - - # Undo. - gui.undo() - _check_arrays(new, clusters) - - # Undo. - gui.undo() - _check_arrays(0) - _check_arrays(2) - - # Redo. - gui.redo() - _check_arrays(new, clusters) - - # Split some spikes. - spikes = [5, 7, 11, 13, 17, 19] - # clusters = np.unique(spike_clusters[spikes]) - gui.split(spikes) - _check_arrays(new + 1, spikes=spikes) - - # Test merge-undo-different-merge combo. - spc = gui.clustering.spikes_per_cluster.copy() - clusters = gui.cluster_ids[:3] - up = gui.merge(clusters) - _check_arrays(up.added[0], spikes=up.spike_ids) - # Undo. - gui.undo() - for cluster in clusters: - _check_arrays(cluster, spikes=spc[cluster]) - # Another merge. - clusters = gui.cluster_ids[1:5] - up = gui.merge(clusters) - _check_arrays(up.added[0], spikes=up.spike_ids) - - # Move a cluster to a group. - cluster = gui.cluster_ids[0] - gui.move([cluster], 2) - assert len(gui.store.mean_probe_position(cluster)) == 2 - - spike_clusters_new = gui.model.spike_clusters.copy() - # Check that the spike clusters have changed. 
- assert not np.all(spike_clusters_new == spike_clusters) - ac(gui.model.spike_clusters, gui.clustering.spike_clusters) - - gui.close() - - -def test_gui_move_wizard(qtbot): - gui = _start_manual_clustering() - qtbot.addWidget(gui.main_window) - gui.show() - - gui.next() - gui.pin() - gui.next() - best = gui.wizard.best - assert gui.selected_clusters[0] == best - match = gui.selected_clusters[1] - gui.move([gui.wizard.match], 'mua') - assert gui.selected_clusters[0] == best - assert gui.selected_clusters[1] != match - - gui.close() - - -def test_gui_wizard(qtbot): - gui = _start_manual_clustering() - n = gui.n_clusters - qtbot.addWidget(gui.main_window) - gui.show() - - clusters = np.arange(gui.n_clusters) - best_clusters = gui.wizard.best_clusters() - - # assert gui.wizard.best_clusters(1)[0] == best_clusters[0] - ae(np.unique(best_clusters), clusters) - assert len(gui.wizard.most_similar_clusters()) == n - 1 - - assert len(gui.wizard.most_similar_clusters(0, n_max=3)) == 3 - - clusters = gui.cluster_ids[:2] - up = gui.merge(clusters) - new = up.added[0] - assert np.all(np.in1d(gui.wizard.best_clusters(), - np.arange(clusters[-1] + 1, new + 1))) - assert np.all(np.in1d(gui.wizard.most_similar_clusters(new), - np.arange(clusters[-1] + 1, new))) - - gui.close() - - -@mark.long -def test_gui_history(qtbot): - - gui = _start_manual_clustering() - qtbot.addWidget(gui.main_window) - gui.show() - - gui.wizard.start() - - spikes = _spikes_in_clusters(gui.model.spike_clusters, - gui.wizard.selection) - gui.split(spikes[::3]) - gui.undo() - gui.wizard.next() - gui.redo() - gui.undo() - - for _ in range(10): - gui.merge(gui.wizard.selection) - gui.wizard.next() - gui.undo() - gui.wizard.next() - gui.redo() - gui.wizard.next() - - spikes = _spikes_in_clusters(gui.model.spike_clusters, - gui.wizard.selection) - if len(spikes): - gui.split(spikes[::10]) - gui.wizard.next() - gui.undo() - gui.merge(gui.wizard.selection) - gui.wizard.next() - gui.wizard.next() - - gui.wizard.next_best() - ae(gui.model.spike_clusters, gui.clustering.spike_clusters) - - gui.close() - - -@mark.long -def test_gui_gui(qtbot): - settings = _load_default_settings() - config = settings['cluster_manual_config'] - shortcuts = settings['cluster_manual_shortcuts'] - - gui = _start_manual_clustering(config=config, - shortcuts=shortcuts, - ) - qtbot.addWidget(gui.main_window) - gui.show() - gui.close() diff --git a/phy/cluster/manual/tests/test_gui_component.py b/phy/cluster/manual/tests/test_gui_component.py new file mode 100644 index 000000000..4f186a469 --- /dev/null +++ b/phy/cluster/manual/tests/test_gui_component.py @@ -0,0 +1,333 @@ +# -*- coding: utf-8 -*- + +"""Test GUI component.""" + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +from pytest import yield_fixture, fixture +import numpy as np +from numpy.testing import assert_array_equal as ae +from vispy.util import keys + +from ..gui_component import (ManualClustering, + ) +from phy.io.array import _spikes_in_clusters +from phy.gui import GUI +from .conftest import MockController + + +#------------------------------------------------------------------------------ +# Fixtures +#------------------------------------------------------------------------------ + +@yield_fixture +def gui(tempdir, qtbot): + gui = GUI(position=(200, 100), size=(500, 500), config_dir=tempdir) + gui.show() + qtbot.waitForWindowShown(gui) + yield gui + qtbot.wait(5) + 
gui.close() + del gui + qtbot.wait(5) + + +@fixture +def manual_clustering(qtbot, gui, cluster_ids, cluster_groups, + quality, similarity): + spike_clusters = np.array(cluster_ids) + spikes_per_cluster = lambda c: [c] + + mc = ManualClustering(spike_clusters, + spikes_per_cluster, + cluster_groups=cluster_groups, + shortcuts={'undo': 'ctrl+z'}, + quality=quality, + similarity=similarity, + ) + mc.attach(gui) + mc.set_default_sort(quality.__name__) + + return mc + + +#------------------------------------------------------------------------------ +# Test GUI component +#------------------------------------------------------------------------------ + +def test_manual_clustering_edge_cases(manual_clustering): + mc = manual_clustering + + # Empty selection at first. + ae(mc.clustering.cluster_ids, [0, 1, 2, 10, 11, 20, 30]) + + mc.select([0]) + assert mc.selected == [0] + + mc.undo() + mc.redo() + + # Merge. + mc.merge() + assert mc.selected == [0] + + mc.merge([]) + assert mc.selected == [0] + + mc.merge([10]) + assert mc.selected == [0] + + # Split. + mc.split([]) + assert mc.selected == [0] + + # Move. + mc.move([], 'ignored') + + mc.save() + + +def test_manual_clustering_skip(qtbot, gui, manual_clustering): + mc = manual_clustering + + # yield [0, 1, 2, 10, 11, 20, 30] + # # i, g, N, i, g, N, N + expected = [30, 20, 11, 2, 1] + + for clu in expected: + mc.cluster_view.next() + assert mc.selected == [clu] + + +def test_manual_clustering_merge(manual_clustering): + mc = manual_clustering + + mc.cluster_view.select([30]) + mc.similarity_view.select([20]) + assert mc.selected == [30, 20] + + mc.merge() + assert mc.selected == [31, 11] + + mc.undo() + assert mc.selected == [30, 20] + + mc.redo() + assert mc.selected == [31, 11] + + +def test_manual_clustering_split(manual_clustering): + mc = manual_clustering + + mc.select([1, 2]) + mc.split([1, 2]) + assert mc.selected == [31] + + mc.undo() + assert mc.selected == [1, 2] + + mc.redo() + assert mc.selected == [31] + + +def test_manual_clustering_split_2(gui, quality, similarity): + spike_clusters = np.array([0, 0, 1]) + + mc = ManualClustering(spike_clusters, + lambda c: _spikes_in_clusters(spike_clusters, [c]), + similarity=similarity, + ) + mc.attach(gui) + + mc.add_column(quality, name='quality', default=True) + mc.set_default_sort('quality', 'desc') + + mc.split([0]) + assert mc.selected == [3, 2] + + +def test_manual_clustering_state(tempdir, qtbot, gui, manual_clustering): + mc = manual_clustering + cv = mc.cluster_view + cv.sort_by('id') + gui.close() + assert cv.state['sort_by'] == ('id', 'asc') + cv.set_state(cv.state) + assert cv.state['sort_by'] == ('id', 'asc') + + +def test_manual_clustering_split_lasso(tempdir, qtbot): + controller = MockController(config_dir=tempdir) + gui = controller.create_gui() + mc = controller.manual_clustering + view = gui.list_views('FeatureView', is_visible=False)[0] + + gui.show() + + # Select one cluster. + mc.select(0) + + # Simulate a lasso. 
+ ev = view.events + ev.mouse_press(pos=(210, 1), button=1, modifiers=(keys.CONTROL,)) + ev.mouse_press(pos=(320, 1), button=1, modifiers=(keys.CONTROL,)) + ev.mouse_press(pos=(320, 30), button=1, modifiers=(keys.CONTROL,)) + ev.mouse_press(pos=(210, 30), button=1, modifiers=(keys.CONTROL,)) + + ups = [] + + @mc.clustering.connect + def on_cluster(up): + ups.append(up) + + mc.split() + up = ups[0] + assert up.description == 'assign' + assert up.added == [4, 5] + assert up.deleted == [0] + + # qtbot.stop() + gui.close() + + +def test_manual_clustering_move_1(manual_clustering): + mc = manual_clustering + + mc.select([20]) + assert mc.selected == [20] + + mc.move([20], 'noise') + assert mc.selected == [11] + + mc.undo() + assert mc.selected == [20] + + mc.redo() + assert mc.selected == [11] + + +def test_manual_clustering_move_2(manual_clustering): + mc = manual_clustering + + mc.select([20]) + mc.similarity_view.select([10]) + + assert mc.selected == [20, 10] + + mc.move([10], 'noise') + assert mc.selected == [20, 2] + + mc.undo() + assert mc.selected == [20, 10] + + mc.redo() + assert mc.selected == [20, 2] + + +#------------------------------------------------------------------------------ +# Test shortcuts +#------------------------------------------------------------------------------ + +def test_manual_clustering_action_reset(qtbot, manual_clustering): + mc = manual_clustering + + mc.actions.select([10, 11]) + + mc.actions.reset() + assert mc.selected == [30] + + mc.actions.next() + assert mc.selected == [30, 20] + + mc.actions.next() + assert mc.selected == [30, 11] + + mc.actions.previous() + assert mc.selected == [30, 20] + + +def test_manual_clustering_action_nav(qtbot, manual_clustering): + mc = manual_clustering + + mc.actions.reset() + assert mc.selected == [30] + + mc.actions.next_best() + assert mc.selected == [20] + + mc.actions.previous_best() + assert mc.selected == [30] + + +def test_manual_clustering_action_move_1(qtbot, manual_clustering): + mc = manual_clustering + + mc.actions.next() + + assert mc.selected == [30] + mc.actions.move_best_to_noise() + + assert mc.selected == [20] + mc.actions.move_best_to_mua() + + assert mc.selected == [11] + mc.actions.move_best_to_good() + + assert mc.selected == [2] + + mc.cluster_meta.get('group', 30) == 'noise' + mc.cluster_meta.get('group', 20) == 'mua' + mc.cluster_meta.get('group', 11) == 'good' + + # qtbot.stop() + + +def test_manual_clustering_action_move_2(manual_clustering): + mc = manual_clustering + + mc.select([30]) + mc.similarity_view.select([20]) + + assert mc.selected == [30, 20] + mc.actions.move_similar_to_noise() + + assert mc.selected == [30, 11] + mc.actions.move_similar_to_mua() + + assert mc.selected == [30, 2] + mc.actions.move_similar_to_good() + + assert mc.selected == [30, 1] + + mc.cluster_meta.get('group', 20) == 'noise' + mc.cluster_meta.get('group', 11) == 'mua' + mc.cluster_meta.get('group', 2) == 'good' + + +def test_manual_clustering_action_move_3(manual_clustering): + mc = manual_clustering + + mc.select([30]) + mc.similarity_view.select([20]) + + assert mc.selected == [30, 20] + mc.actions.move_all_to_noise() + + assert mc.selected == [11, 2] + mc.actions.move_all_to_mua() + + assert mc.selected == [1] + mc.actions.move_all_to_good() + + assert mc.selected == [1] + + mc.cluster_meta.get('group', 30) == 'noise' + mc.cluster_meta.get('group', 20) == 'noise' + + mc.cluster_meta.get('group', 11) == 'mua' + mc.cluster_meta.get('group', 10) == 'mua' + + mc.cluster_meta.get('group', 2) == 'good' + 
mc.cluster_meta.get('group', 1) == 'good' diff --git a/phy/cluster/manual/tests/test_history.py b/phy/cluster/manual/tests/test_history.py index 5fadc274d..4ccb04044 100644 --- a/phy/cluster/manual/tests/test_history.py +++ b/phy/cluster/manual/tests/test_history.py @@ -26,6 +26,9 @@ def _assert_current(item): item1 = np.ones(4) item2 = 2 * np.ones(5) + assert not history.is_first() + assert history.is_last() + history.add(item0) _assert_current(item0) diff --git a/phy/cluster/manual/tests/test_utils.py b/phy/cluster/manual/tests/test_utils.py index 4b707080b..651ffe3b0 100644 --- a/phy/cluster/manual/tests/test_utils.py +++ b/phy/cluster/manual/tests/test_utils.py @@ -6,39 +6,66 @@ # Imports #------------------------------------------------------------------------------ -from ....utils.logging import set_level, debug -from .._utils import ClusterMetadataUpdater, UpdateInfo -from ....io.kwik.model import ClusterMetadata +import logging + +from pytest import raises + +from .._utils import (ClusterMeta, UpdateInfo, + _update_cluster_selection, create_cluster_meta) + +logger = logging.getLogger(__name__) #------------------------------------------------------------------------------ # Tests #------------------------------------------------------------------------------ -def setup(): - set_level('debug') +def test_create_cluster_meta(): + cluster_groups = {2: 3, + 3: 3, + 5: 1, + 7: 2, + } + meta = create_cluster_meta(cluster_groups) + assert meta.group(2) == 3 + assert meta.group(3) == 3 + assert meta.group(5) == 1 + assert meta.group(7) == 2 + assert meta.group(8) is None -def teardown(): - set_level('info') +def test_metadata_history_simple(): + """Test ClusterMeta history.""" + meta = ClusterMeta() + meta.add_field('group') -def test_metadata_history(): - """Test ClusterMetadataUpdater history.""" + meta.set('group', 2, 2) + assert meta.get('group', 2) == 2 - data = {2: {'group': 2, 'color': 7}, 4: {'group': 5}} + meta.undo() + assert meta.get('group', 2) is None - base_meta = ClusterMetadata(data=data) + meta.redo() + assert meta.get('group', 2) == 2 - @base_meta.default - def group(cluster): - return 3 + with raises(AssertionError): + assert meta.to_dict('grou') is None + assert meta.to_dict('group') == {2: 2} - @base_meta.default - def color(cluster): - return 0 - meta = ClusterMetadataUpdater(base_meta) +def test_metadata_history_complex(): + """Test ClusterMeta history.""" + + meta = ClusterMeta() + meta.add_field('group', 3) + meta.add_field('color', 0) + + data = {2: {'group': 2, 'color': 7}, 4: {'group': 5}} + meta.from_dict(data) + + assert meta.group(2) == 2 + assert meta.group([4, 2]) == [5, 2] # Values set in 'data'. assert meta.group(2) == 2 @@ -57,19 +84,19 @@ def color(cluster): meta.redo() # Action 1. - info = meta.set_group(2, 20) + info = meta.set('group', 2, 20) assert meta.group(2) == 20 assert info.description == 'metadata_group' assert info.metadata_changed == [2] # Action 2. - info = meta.set_color(3, 30) + info = meta.set('color', 3, 30) assert meta.color(3) == 30 assert info.description == 'metadata_color' assert info.metadata_changed == [3] # Action 3. 
- info = meta.set_color(2, 40) + info = meta.set('color', 2, 40) assert meta.color(2) == 40 assert info.description == 'metadata_color' assert info.metadata_changed == [2] @@ -113,8 +140,56 @@ def color(cluster): assert info is None +def test_metadata_descendants(): + """Test ClusterMeta history.""" + + data = {0: {'group': 0}, + 1: {'group': 1}, + 2: {'group': 2}, + 3: {'group': 3}, + } + + meta = ClusterMeta() + + meta.add_field('group', 3) + meta.from_dict(data) + + meta.set_from_descendants([]) + assert meta.group(4) == 3 + + meta.set_from_descendants([(0, 4)]) + assert meta.group(4) == 0 + + # Reset to default. + meta.set('group', 4, 3) + meta.set_from_descendants([(1, 4)]) + assert meta.group(4) == 1 + + meta.set_from_descendants([(1, 5), (2, 5)]) + # This is the default value because the parents have different values. + assert meta.group(5) == 3 + + meta.set('group', 3, 2) + meta.set_from_descendants([(2, 6), (3, 6), (10, 10)]) + assert meta.group(6) == 2 + + # If the value of the new cluster is non-default, it should not + # be changed by set_from_descendants. + meta.set_from_descendants([(3, 2)]) + assert meta.group(2) == 2 + + +def test_update_cluster_selection(): + clusters = [1, 2, 3] + up = UpdateInfo(deleted=[2], added=[4, 0]) + assert _update_cluster_selection(clusters, up) == [1, 3, 4, 0] + + def test_update_info(): - debug(UpdateInfo(deleted=range(5), added=[5], description='merge')) - debug(UpdateInfo(deleted=range(5), added=[5], description='assign')) - debug(UpdateInfo(deleted=range(5), added=[5], - description='assign', history='undo')) + logger.debug(UpdateInfo()) + logger.debug(UpdateInfo(description='hello')) + logger.debug(UpdateInfo(deleted=range(5), added=[5], description='merge')) + logger.debug(UpdateInfo(deleted=range(5), added=[5], description='assign')) + logger.debug(UpdateInfo(deleted=range(5), added=[5], + description='assign', history='undo')) + logger.debug(UpdateInfo(metadata_changed=[2, 3], description='metadata')) diff --git a/phy/cluster/manual/tests/test_view_models.py b/phy/cluster/manual/tests/test_view_models.py deleted file mode 100644 index f973fcc0b..000000000 --- a/phy/cluster/manual/tests/test_view_models.py +++ /dev/null @@ -1,238 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Tests of view model.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import os - -from pytest import mark - -from ....utils.array import _spikes_per_cluster -from ....utils.logging import set_level -from ....utils.testing import (show_test_start, - show_test_stop, - show_test_run, - ) -from ....io.kwik.mock import create_mock_kwik -from ....io.kwik import KwikModel, create_store -from ..view_models import (WaveformViewModel, - FeatureGridViewModel, - CorrelogramViewModel, - TraceViewModel, - ) - - -# Skip these tests in "make test-quick". -pytestmark = mark.long() - - -#------------------------------------------------------------------------------ -# Utilities -#------------------------------------------------------------------------------ - -_N_CLUSTERS = 5 -_N_SPIKES = 200 -_N_CHANNELS = 28 -_N_FETS = 3 -_N_SAMPLES_TRACES = 10000 -_N_FRAMES = int((float(os.environ.get('PHY_EVENT_LOOP_DELAY', 0)) * 60) or 2) - - -def setup(): - set_level('info') - - -def _test_empty(tempdir, view_model_class, stop=True, **kwargs): - # Create the test HDF5 file in the temporary directory. 
- filename = create_mock_kwik(tempdir, - n_clusters=1, - n_spikes=1, - n_channels=_N_CHANNELS, - n_features_per_channel=_N_FETS, - n_samples_traces=_N_SAMPLES_TRACES) - model = KwikModel(filename) - spikes_per_cluster = _spikes_per_cluster(model.spike_ids, - model.spike_clusters) - store = create_store(model, - path=tempdir, - spikes_per_cluster=spikes_per_cluster, - features_masks_chunk_size=10, - waveforms_n_spikes_max=10, - waveforms_excerpt_size=5, - ) - store.generate() - - vm = view_model_class(model=model, store=store, **kwargs) - vm.on_open() - - # Show the view. - show_test_start(vm.view) - show_test_run(vm.view, _N_FRAMES) - vm.select([0]) - show_test_run(vm.view, _N_FRAMES) - vm.select([]) - show_test_run(vm.view, _N_FRAMES) - - if stop: - show_test_stop(vm.view) - - return vm - - -def _test_view_model(tempdir, view_model_class, stop=True, **kwargs): - - # Create the test HDF5 file in the temporary directory. - filename = create_mock_kwik(tempdir, - n_clusters=_N_CLUSTERS, - n_spikes=_N_SPIKES, - n_channels=_N_CHANNELS, - n_features_per_channel=_N_FETS, - n_samples_traces=_N_SAMPLES_TRACES) - model = KwikModel(filename) - spikes_per_cluster = _spikes_per_cluster(model.spike_ids, - model.spike_clusters) - store = create_store(model, - path=tempdir, - spikes_per_cluster=spikes_per_cluster, - features_masks_chunk_size=15, - waveforms_n_spikes_max=20, - waveforms_excerpt_size=5, - ) - store.generate() - - vm = view_model_class(model=model, store=store, **kwargs) - vm.on_open() - show_test_start(vm.view) - - vm.select([2]) - show_test_run(vm.view, _N_FRAMES) - - vm.select([2, 3]) - show_test_run(vm.view, _N_FRAMES) - - vm.select([3, 2]) - show_test_run(vm.view, _N_FRAMES) - - if stop: - show_test_stop(vm.view) - - return vm - - -#------------------------------------------------------------------------------ -# Waveforms -#------------------------------------------------------------------------------ - -def test_waveforms_full(tempdir): - vm = _test_view_model(tempdir, WaveformViewModel, stop=False) - vm.overlap = True - show_test_run(vm.view, _N_FRAMES) - vm.show_mean = True - show_test_run(vm.view, _N_FRAMES) - show_test_stop(vm.view) - - -def test_waveforms_empty(tempdir): - _test_empty(tempdir, WaveformViewModel) - - -#------------------------------------------------------------------------------ -# Features -#------------------------------------------------------------------------------ - -def test_features_empty(tempdir): - _test_empty(tempdir, FeatureGridViewModel) - - -def test_features_full(tempdir): - _test_view_model(tempdir, FeatureGridViewModel, - marker_size=8, n_spikes_max=20) - - -def test_features_lasso(tempdir): - vm = _test_view_model(tempdir, - FeatureGridViewModel, - marker_size=8, - stop=False, - ) - show_test_run(vm.view, _N_FRAMES) - box = (1, 2) - vm.view.lasso.box = box - x, y = 0., 1. - vm.view.lasso.add((x, x)) - vm.view.lasso.add((y, x)) - vm.view.lasso.add((y, y)) - vm.view.lasso.add((x, y)) - show_test_run(vm.view, _N_FRAMES) - # Find spikes in lasso. - spikes = vm.spikes_in_lasso() - # Change their clusters. 
- vm.model.spike_clusters[spikes] = 3 - sc = vm.model.spike_clusters - vm.view.visual.spike_clusters = sc[vm.view.visual.spike_ids] - show_test_run(vm.view, _N_FRAMES) - show_test_stop(vm.view) - - -#------------------------------------------------------------------------------ -# Correlograms -#------------------------------------------------------------------------------ - -def test_ccg_empty(tempdir): - _test_empty(tempdir, - CorrelogramViewModel, - binsize=20, - winsize_bins=51, - n_excerpts=100, - excerpt_size=100, - ) - - -def test_ccg_simple(tempdir): - _test_view_model(tempdir, - CorrelogramViewModel, - binsize=10, - winsize_bins=61, - n_excerpts=80, - excerpt_size=120, - ) - - -def test_ccg_full(tempdir): - vm = _test_view_model(tempdir, - CorrelogramViewModel, - binsize=20, - winsize_bins=51, - n_excerpts=100, - excerpt_size=100, - stop=False, - ) - show_test_run(vm.view, _N_FRAMES) - vm.change_bins(half_width=100., bin=1.) - show_test_run(vm.view, _N_FRAMES) - show_test_stop(vm.view) - - -#------------------------------------------------------------------------------ -# Traces -#------------------------------------------------------------------------------ - -def test_traces_empty(tempdir): - _test_empty(tempdir, TraceViewModel) - - -def test_traces_simple(tempdir): - _test_view_model(tempdir, TraceViewModel) - - -def test_traces_full(tempdir): - vm = _test_view_model(tempdir, TraceViewModel, stop=False) - vm.move_right() - show_test_run(vm.view, _N_FRAMES) - vm.move_left() - show_test_run(vm.view, _N_FRAMES) - - show_test_stop(vm.view) diff --git a/phy/cluster/manual/tests/test_views.py b/phy/cluster/manual/tests/test_views.py new file mode 100644 index 000000000..10de92878 --- /dev/null +++ b/phy/cluster/manual/tests/test_views.py @@ -0,0 +1,260 @@ +# -*- coding: utf-8 -*- + +"""Test views.""" + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +import numpy as np +from numpy.testing import assert_equal as ae +from numpy.testing import assert_allclose as ac +from vispy.util import keys +from pytest import fixture + +from phy.utils import Bunch +from .conftest import MockController +from ..views import (ScatterView, + _extract_wave, + _extend, + ) + + +#------------------------------------------------------------------------------ +# Utils +#------------------------------------------------------------------------------ + +@fixture +def state(tempdir): + # Save a test GUI state JSON file in the tempdir. + state = Bunch() + state.WaveformView0 = Bunch(overlap=False) + state.TraceView0 = Bunch(scaling=1.) 
+ state.FeatureView0 = Bunch(feature_scaling=.5) + state.CorrelogramView0 = Bunch(uniform_normalization=True) + return state + + +@fixture +def gui(tempdir, state): + controller = MockController(config_dir=tempdir) + return controller.create_gui(add_default_views=False, **state) + + +def _select_clusters(gui): + gui.show() + mc = gui.controller.manual_clustering + assert mc + mc.select([]) + mc.select([0]) + mc.select([0, 2]) + + +#------------------------------------------------------------------------------ +# Test utils +#------------------------------------------------------------------------------ + +def test_extend(): + l = list(range(5)) + assert _extend(l) == l + assert _extend(l, 0) == [] + assert _extend(l, 4) == list(range(4)) + assert _extend(l, 5) == l + assert _extend(l, 6) == (l + [4]) + + +def test_extract_wave(): + traces = np.arange(30).reshape((6, 5)) + mask = np.array([0, 1, 1, .5, 0]) + wave_len = 4 + hwl = wave_len // 2 + + ae(_extract_wave(traces, 0 - hwl, mask, wave_len)[0], + [[0, 0], [0, 0], [1, 2], [6, 7]]) + + ae(_extract_wave(traces, 1 - hwl, mask, wave_len)[0], + [[0, 0], [1, 2], [6, 7], [11, 12]]) + + ae(_extract_wave(traces, 2 - hwl, mask, wave_len)[0], + [[1, 2], [6, 7], [11, 12], [16, 17]]) + + ae(_extract_wave(traces, 5 - hwl, mask, wave_len)[0], + [[16, 17], [21, 22], [0, 0], [0, 0]]) + + +#------------------------------------------------------------------------------ +# Test waveform view +#------------------------------------------------------------------------------ + +def test_waveform_view(qtbot, gui): + v = gui.controller.add_waveform_view(gui) + _select_clusters(gui) + + ac(v.boxed.box_size, (.1818, .0909), atol=1e-2) + + v.toggle_waveform_overlap() + v.toggle_waveform_overlap() + + v.toggle_zoom_on_channels() + v.toggle_zoom_on_channels() + + # Box scaling. + bs = v.boxed.box_size + v.increase() + v.decrease() + ac(v.boxed.box_size, bs) + + bs = v.boxed.box_size + v.widen() + v.narrow() + ac(v.boxed.box_size, bs) + + # Probe scaling. + bp = v.boxed.box_pos + v.extend_horizontally() + v.shrink_horizontally() + ac(v.boxed.box_pos, bp) + + bp = v.boxed.box_pos + v.extend_vertically() + v.shrink_vertically() + ac(v.boxed.box_pos, bp) + + a, b = v.probe_scaling + v.probe_scaling = (a, b * 2) + ac(v.probe_scaling, (a, b * 2)) + + a, b = v.box_scaling + v.box_scaling = (a * 2, b) + ac(v.box_scaling, (a * 2, b)) + + v.zoom_on_channels([0, 2, 4]) + + # Simulate channel selection. + _clicked = [] + + @v.gui.connect_ + def on_channel_click(channel_idx=None, button=None, key=None): + _clicked.append((channel_idx, button, key)) + + v.events.key_press(key=keys.Key('2')) + v.events.mouse_press(pos=(0., 0.), button=1) + v.events.key_release(key=keys.Key('2')) + + assert _clicked == [(0, 1, 2)] + + v.next_data() + + # qtbot.stop() + gui.close() + + +#------------------------------------------------------------------------------ +# Test trace view +#------------------------------------------------------------------------------ + +def test_trace_view(qtbot, gui): + v = gui.controller.add_trace_view(gui) + + _select_clusters(gui) + + ac(v.stacked.box_size, (1., .08181), atol=1e-3) + assert v.time == .5 + + v.go_to(.25) + assert v.time == .25 + + v.go_to(-.5) + assert v.time == .125 + + v.go_left() + assert v.time == .125 + + v.go_right() + assert v.time == .175 + + # Change interval size. + v.interval = (.25, .75) + ac(v.interval, (.25, .75)) + v.widen() + ac(v.interval, (.225, .775)) + v.narrow() + ac(v.interval, (.25, .75)) + + # Widen the max interval. 
+ v.set_interval((0, gui.controller.duration)) + v.widen() + + # Change channel scaling. + bs = v.stacked.box_size + v.increase() + v.decrease() + ac(v.stacked.box_size, bs, atol=1e-3) + + v.origin = 'upper' + assert v.origin == 'upper' + + # qtbot.stop() + gui.close() + + +#------------------------------------------------------------------------------ +# Test feature view +#------------------------------------------------------------------------------ + +def test_feature_view(qtbot, gui): + v = gui.controller.add_feature_view(gui) + _select_clusters(gui) + + assert v.feature_scaling == .5 + v.add_attribute('sine', + np.sin(np.linspace(-10., 10., gui.controller.n_spikes))) + + v.increase() + v.decrease() + + v.on_channel_click(channel_idx=3, button=1, key=2) + v.clear_channels() + v.toggle_automatic_channel_selection() + + # qtbot.stop() + gui.close() + + +#------------------------------------------------------------------------------ +# Test scatter view +#------------------------------------------------------------------------------ + +def test_scatter_view(qtbot, gui): + n = 1000 + v = ScatterView(coords=lambda c: Bunch(x=np.random.randn(n), + y=np.random.randn(n), + spike_ids=np.arange(n), + spike_clusters=np.ones(n). + astype(np.int32) * c[0], + ) if 2 not in c else None, + data_bounds=[-3, -3, 3, 3], + ) + v.attach(gui) + + _select_clusters(gui) + + # qtbot.stop() + gui.close() + + +#------------------------------------------------------------------------------ +# Test correlogram view +#------------------------------------------------------------------------------ + +def test_correlogram_view(qtbot, gui): + v = gui.controller.add_correlogram_view(gui) + _select_clusters(gui) + + v.toggle_normalization() + + v.set_bin(1) + v.set_window(100) + + # qtbot.stop() + gui.close() diff --git a/phy/cluster/manual/tests/test_wizard.py b/phy/cluster/manual/tests/test_wizard.py deleted file mode 100644 index 4a7943637..000000000 --- a/phy/cluster/manual/tests/test_wizard.py +++ /dev/null @@ -1,155 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Test wizard.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -from pytest import raises - -from ..wizard import (_previous, - _next, - Wizard, - ) - - -#------------------------------------------------------------------------------ -# Test wizard -#------------------------------------------------------------------------------ - -def test_utils(): - l = [2, 3, 5, 7, 11] - - def func(x): - return x in (2, 5) - - with raises(RuntimeError): - _previous(l, 1, func) - with raises(RuntimeError): - _previous(l, 15, func) - - assert _previous(l, 2, func) == 2 - assert _previous(l, 3, func) == 2 - assert _previous(l, 5, func) == 2 - assert _previous(l, 7, func) == 5 - assert _previous(l, 11, func) == 5 - - with raises(RuntimeError): - _next(l, 1, func) - with raises(RuntimeError): - _next(l, 15, func) - - assert _next(l, 2, func) == 5 - assert _next(l, 3, func) == 5 - assert _next(l, 5, func) == 5 - assert _next(l, 7, func) == 7 - assert _next(l, 11, func) == 11 - - -def test_wizard_core(): - - wizard = Wizard([2, 3, 5]) - - @wizard.set_quality_function - def quality(cluster): - return {2: .9, - 3: .3, - 5: .6, - }[cluster] - - @wizard.set_similarity_function - def similarity(cluster, other): - cluster, other = min((cluster, other)), max((cluster, other)) - return {(2, 3): 1, - (2, 5): 2, - (3, 5): 3}[cluster, other] - - assert 
wizard.best_clusters() == [2, 5, 3] - assert wizard.best_clusters(n_max=0) == [2, 5, 3] - assert wizard.best_clusters(n_max=None) == [2, 5, 3] - assert wizard.best_clusters(n_max=2) == [2, 5] - - assert wizard.best_clusters(n_max=1) == [2] - - assert wizard.most_similar_clusters() == [5, 3] - assert wizard.most_similar_clusters(2) == [5, 3] - - assert wizard.most_similar_clusters(n_max=0) == [5, 3] - assert wizard.most_similar_clusters(n_max=None) == [5, 3] - assert wizard.most_similar_clusters(n_max=1) == [5] - - -def test_wizard_nav(): - - groups = {2: None, 3: None, 5: 'ignored', 7: 'good'} - wizard = Wizard(groups) - - @wizard.set_quality_function - def quality(cluster): - return {2: .2, - 3: .3, - 5: .5, - 7: .7, - }[cluster] - - @wizard.set_similarity_function - def similarity(cluster, other): - return 1. + quality(cluster) - quality(other) - - # Loop over the best clusters. - wizard.start() - assert wizard.best == 3 - assert wizard.match is None - - wizard.next() - assert wizard.best == 2 - - wizard.previous() - assert wizard.best == 3 - - wizard.previous_best() - assert wizard.best == 3 - - wizard.next() - assert wizard.best == 2 - - wizard.next() - assert wizard.best == 7 - - wizard.next_best() - assert wizard.best == 5 - - wizard.next() - assert wizard.best == 5 - - # Now we start again. - wizard.start() - assert wizard.best == 3 - assert wizard.match is None - - # The match are sorted by group first (unsorted, good, and ignored), - # and similarity second. - wizard.pin() - assert wizard.best == 3 - assert wizard.match == 2 - - wizard.next() - assert wizard.best == 3 - assert wizard.match == 7 - - wizard.next_match() - assert wizard.best == 3 - assert wizard.match == 5 - - wizard.previous_match() - assert wizard.best == 3 - assert wizard.match == 7 - - wizard.previous() - assert wizard.best == 3 - assert wizard.match == 2 - - wizard.unpin() - assert wizard.best == 3 - assert wizard.match is None diff --git a/phy/cluster/manual/view_models.py b/phy/cluster/manual/view_models.py deleted file mode 100644 index 0f9554108..000000000 --- a/phy/cluster/manual/view_models.py +++ /dev/null @@ -1,1098 +0,0 @@ -# -*- coding: utf-8 -*- - -"""View model for clustered data.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import os.path as op - -import numpy as np -from six import string_types - -from ...io.kwik.model import _DEFAULT_GROUPS -from ...utils.array import _unique, _spikes_in_clusters, _as_array -from ...utils.selector import Selector -from ...utils._misc import _show_shortcuts -from ...utils._types import _is_integer, _is_float -from ...utils._color import _selected_clusters_colors -from ...utils import _as_list -from ...stats.ccg import correlograms, _symmetrize_correlograms -from ...plot.ccg import CorrelogramView -from ...plot.features import FeatureView -from ...plot.waveforms import WaveformView -from ...plot.traces import TraceView -from ...gui.base import BaseViewModel, HTMLViewModel -from ...gui._utils import _read - - -#------------------------------------------------------------------------------ -# Misc -#------------------------------------------------------------------------------ - -def _create_view(cls, backend=None, **kwargs): - if backend in ('pyqt4', None): - kwargs.update({'always_on_top': True}) - return cls(**kwargs) - - -def _oddify(x): - return x if x % 2 == 1 else x + 1 - - 
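# Minimal illustration of the helper above: correlogram windows must contain an
# odd number of bins so that there are as many lag bins on each side of the
# central (zero-lag) bin. The numbers below are arbitrary examples.
def _oddify(x):
    return x if x % 2 == 1 else x + 1

assert _oddify(41) == 41        # already odd: unchanged
assert _oddify(50) == 51        # even values get one extra bin
half_window = _oddify(51) // 2  # 25 lag bins on each side of zero
assert 2 * half_window + 1 == 51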
-#------------------------------------------------------------------------------ -# Base view models -#------------------------------------------------------------------------------ - -class BaseClusterViewModel(BaseViewModel): - """Interface between a view and a model.""" - _view_class = None - - def __init__(self, model=None, - store=None, wizard=None, - cluster_ids=None, **kwargs): - assert store is not None - self._store = store - self._wizard = wizard - - super(BaseClusterViewModel, self).__init__(model=model, - **kwargs) - - self._cluster_ids = None - if cluster_ids is not None: - self.select(_as_list(cluster_ids)) - - @property - def store(self): - """The cluster store.""" - return self._store - - @property - def wizard(self): - """The wizard.""" - return self._wizard - - @property - def cluster_ids(self): - """Selected clusters.""" - return self._cluster_ids - - @property - def n_clusters(self): - """Number of selected clusters.""" - return len(self._cluster_ids) - - # Public methods - #-------------------------------------------------------------------------- - - def select(self, cluster_ids, **kwargs): - """Select a list of clusters.""" - cluster_ids = _as_list(cluster_ids) - self._cluster_ids = cluster_ids - self.on_select(cluster_ids, **kwargs) - - # Callback methods - #-------------------------------------------------------------------------- - - def on_select(self, cluster_ids, **kwargs): - """Update the view after a new selection has been made. - - Must be overriden.""" - - def on_cluster(self, up): - """Called when a clustering action occurs. - - May be overriden.""" - - -def _css_cluster_colors(): - colors = _selected_clusters_colors() - # HACK: this is the maximum number of clusters that can be displayed - # in an HTML view. If this number is exceeded, cluster colors will be - # wrong for the extra clusters. - n = 32 - - def _color(i): - i = i % len(colors) - c = colors[i] - c = (255 * c).astype(np.int32) - return 'rgb({}, {}, {})'.format(*c) - - return ''.join(""".cluster_{i} {{ - color: {color}; - }}\n""".format(i=i, color=_color(i)) - for i in range(n)) - - -class HTMLClusterViewModel(BaseClusterViewModel, HTMLViewModel): - """HTML view model that displays per-cluster information.""" - - def get_css(self, **kwargs): - # TODO: improve this - # Currently, child classes *must* append some CSS to this parent's - # method. - return _css_cluster_colors() - - def on_select(self, cluster_ids, **kwargs): - """Update the view after a new selection has been made.""" - self.update(cluster_ids=cluster_ids) - - def on_cluster(self, up): - """Update the view after a clustering action.""" - self.update(cluster_ids=self._cluster_ids, up=up) - - -class VispyViewModel(BaseClusterViewModel): - """Create a VisPy view from a model. - - This object uses an internal `Selector` instance to manage spike and - cluster selection. - - """ - _imported_params = ('n_spikes_max', 'excerpt_size') - keyboard_shortcuts = {} - scale_factor = 1. - - def __init__(self, **kwargs): - super(VispyViewModel, self).__init__(**kwargs) - - # Call on_close() when the view is closed. - @self._view.connect - def on_close(e): - self.on_close() - - def _create_view(self, **kwargs): - n_spikes_max = kwargs.get('n_spikes_max', None) - excerpt_size = kwargs.get('excerpt_size', None) - backend = kwargs.get('backend', None) - position = kwargs.get('position', None) - size = kwargs.get('size', None) - - # Create the spike/cluster selector. 
- self._selector = Selector(self._model.spike_clusters, - n_spikes_max=n_spikes_max, - excerpt_size=excerpt_size, - ) - - # Create the VisPy canvas. - view = _create_view(self._view_class, - backend=backend, - position=position or (200, 200), - size=size or (600, 600), - ) - view.connect(self.on_key_press) - return view - - @property - def selector(self): - """A Selector instance managing the selected spikes and clusters.""" - return self._selector - - @property - def cluster_ids(self): - """Selected clusters.""" - return self._selector.selected_clusters - - @property - def spike_ids(self): - """Selected spikes.""" - return self._selector.selected_spikes - - @property - def n_spikes(self): - """Number of selected spikes.""" - return self._selector.n_spikes - - def update_spike_clusters(self, spikes=None, spike_clusters=None): - """Update the spike clusters and cluster colors.""" - if spikes is None: - spikes = self.spike_ids - if spike_clusters is None: - spike_clusters = self.model.spike_clusters[spikes] - n_clusters = len(_unique(spike_clusters)) - visual = self._view.visual - # This updates the list of unique clusters in the view. - visual.spike_clusters = spike_clusters - visual.cluster_colors = _selected_clusters_colors(n_clusters) - - def select(self, cluster_ids, **kwargs): - """Select a set of clusters.""" - self._selector.selected_clusters = cluster_ids - self.on_select(cluster_ids, **kwargs) - - def on_select(self, cluster_ids, **kwargs): - """Update the view after a new selection has been made. - - Must be overriden. - - """ - self.update_spike_clusters() - self._view.update() - - def on_close(self): - """Clear the view when the model is closed.""" - self._view.visual.spike_clusters = [] - self._view.update() - - def on_key_press(self, event): - """Called when a key is pressed.""" - if event.key == 'h' and 'control' not in event.modifiers: - shortcuts = self._view.keyboard_shortcuts - shortcuts.update(self.keyboard_shortcuts) - _show_shortcuts(shortcuts, name=self.name) - - def update(self): - """Update the view.""" - self.view.update() - - -#------------------------------------------------------------------------------ -# Stats panel -#------------------------------------------------------------------------------ - -class StatsViewModel(HTMLClusterViewModel): - """Display cluster statistics.""" - - def get_html(self, cluster_ids=None, up=None): - """Return the HTML table with the cluster statistics.""" - stats = self.store.items['statistics'] - names = stats.fields - if cluster_ids is None: - return '' - # Only keep scalar stats. - _arrays = {name: isinstance(getattr(self.store, name)(cluster_ids[0]), - np.ndarray) for name in names} - names = sorted([name for name in _arrays if not _arrays[name]]) - # Add the cluster group as a first statistic. - names = ['cluster_group'] + names - # Generate the table. - html = '' - for i, cluster in enumerate(cluster_ids): - html += '{cluster}'.format( - cluster=cluster, style='cluster_{}'.format(i)) - html += '' - cluster_groups = self.model.cluster_groups - group_names = dict(_DEFAULT_GROUPS) - for name in names: - html += '' - html += '{name}'.format(name=name) - for i, cluster in enumerate(cluster_ids): - if name == 'cluster_group': - # Get the cluster group. 
-                    group_id = cluster_groups.get(cluster, -1)
-                    value = group_names.get(group_id, 'unknown')
-                else:
-                    value = getattr(self.store, name)(cluster)
-                if _is_float(value):
-                    value = '{:.3f}'.format(value)
-                elif _is_integer(value):
-                    value = '{:d}'.format(value)
-                else:
-                    value = str(value)
-                html += '<td class="{style}">{value}</td>'.format(
-                    value=value, style='cluster_{}'.format(i))
-            html += '</tr>'
-        return '<table>\n' + html + '</table>
' - - def get_css(self, cluster_ids=None, up=None): - css = super(StatsViewModel, self).get_css(cluster_ids=cluster_ids, - up=up) - static_path = op.join(op.dirname(op.realpath(__file__)), - 'static/') - css += _read('styles.css', static_path=static_path) - return css - - -#------------------------------------------------------------------------------ -# Kwik view models -#------------------------------------------------------------------------------ - -class WaveformViewModel(VispyViewModel): - """Waveforms.""" - _view_class = WaveformView - _view_name = 'waveforms' - _imported_params = ('scale_factor', 'box_scale', 'probe_scale', - 'overlap', 'show_mean') - - def on_open(self): - """Initialize the view when the model is opened.""" - super(WaveformViewModel, self).on_open() - # Waveforms. - self.view.visual.channel_positions = self.model.probe.positions - self.view.visual.channel_order = self.model.channel_order - # Mean waveforms. - self.view.mean.channel_positions = self.model.probe.positions - self.view.mean.channel_order = self.model.channel_order - if self.scale_factor is None: - self.scale_factor = 1. - - def _load_waveforms(self): - # NOTE: we load all spikes from the store. - # The waveforms store item is responsible for making a subselection - # of spikes both on disk and in the view. - waveforms = self.store.load('waveforms', - clusters=self.cluster_ids, - ) - return waveforms - - def _load_mean_waveforms(self): - mean_waveforms = self.store.load('mean_waveforms', - clusters=self.cluster_ids, - ) - mean_masks = self.store.load('mean_masks', - clusters=self.cluster_ids, - ) - return mean_waveforms, mean_masks - - def update_spike_clusters(self, spikes=None): - """Update the view's spike clusters.""" - super(WaveformViewModel, self).update_spike_clusters(spikes=spikes) - self._view.mean.spike_clusters = np.sort(self.cluster_ids) - self._view.mean.cluster_colors = self._view.visual.cluster_colors - - def on_select(self, clusters, **kwargs): - """Update the view when the selection changes.""" - # Get the spikes of the stored waveforms. - n_clusters = len(clusters) - waveforms = self._load_waveforms() - spikes = self.store.items['waveforms'].spikes_in_clusters(clusters) - n_spikes = len(spikes) - _, self._n_samples, self._n_channels = waveforms.shape - mean_waveforms, mean_masks = self._load_mean_waveforms() - - self.update_spike_clusters(spikes) - - # Cluster display order. - self.view.visual.cluster_order = clusters - self.view.mean.cluster_order = clusters - - # Waveforms. - assert waveforms.shape[0] == n_spikes - self.view.visual.waveforms = waveforms * self.scale_factor - - assert mean_waveforms.shape == (n_clusters, - self._n_samples, - self._n_channels) - self.view.mean.waveforms = mean_waveforms * self.scale_factor - - # Masks. - masks = self.store.load('masks', clusters=clusters, spikes=spikes) - assert masks.shape == (n_spikes, self._n_channels) - self.view.visual.masks = masks - - assert mean_masks.shape == (n_clusters, self._n_channels) - self.view.mean.masks = mean_masks - - # Spikes. - self.view.visual.spike_ids = spikes - self.view.mean.spike_ids = np.arange(len(clusters)) - - self.view.update() - - def on_close(self): - """Clear the view when the model is closed.""" - self.view.visual.channel_positions = [] - self.view.mean.channel_positions = [] - super(WaveformViewModel, self).on_close() - - @property - def box_scale(self): - """Scale of the waveforms. - - This is a pair of scalars. 
- - """ - return self.view.box_scale - - @box_scale.setter - def box_scale(self, value): - self.view.box_scale = value - - @property - def probe_scale(self): - """Scale of the probe. - - This is a pair of scalars. - - """ - return self.view.probe_scale - - @probe_scale.setter - def probe_scale(self, value): - self.view.probe_scale = value - - @property - def overlap(self): - """Whether to overlap waveforms.""" - return self.view.overlap - - @overlap.setter - def overlap(self, value): - self.view.overlap = value - - @property - def show_mean(self): - """Whether to show mean waveforms.""" - return self.view.show_mean - - @show_mean.setter - def show_mean(self, value): - self.view.show_mean = value - - def exported_params(self, save_size_pos=True): - """Parameters to save automatically when the view is closed.""" - params = super(WaveformViewModel, self).exported_params(save_size_pos) - params.update({ - 'scale_factor': self.scale_factor, - 'box_scale': self.view.box_scale, - 'probe_scale': self.view.probe_scale, - 'overlap': self.view.overlap, - 'show_mean': self.view.show_mean, - }) - return params - - -class CorrelogramViewModel(VispyViewModel): - """Correlograms.""" - _view_class = CorrelogramView - _view_name = 'correlograms' - binsize = 20 # in number of samples - winsize_bins = 41 # in number of bins - _imported_params = ('binsize', 'winsize_bins', 'lines', 'normalization') - _normalization = 'equal' # or 'independent' - _ccgs = None - - def change_bins(self, bin=None, half_width=None): - """Change the parameters of the correlograms. - - Parameters - ---------- - bin : float (ms) - Bin size. - half_width : float (ms) - Half window size. - - """ - sr = float(self.model.sample_rate) - - # Default values. - if bin is None: - bin = 1000. * self.binsize / sr # in ms - if half_width is None: - half_width = 1000. * (float((self.winsize_bins // 2) * - self.binsize / sr)) - - bin = np.clip(bin, .1, 1e3) # in ms - self.binsize = int(sr * bin * .001) # in s - - half_width = np.clip(half_width, .1, 1e3) # in ms - self.winsize_bins = 2 * int(half_width / bin) + 1 - - self.select(self.cluster_ids) - - def on_select(self, clusters, **kwargs): - """Update the view when the selection changes.""" - super(CorrelogramViewModel, self).on_select(clusters) - spikes = self.spike_ids - self.view.cluster_ids = clusters - - # Compute the correlograms. - spike_samples = self.model.spike_samples[spikes] - spike_clusters = self.view.visual.spike_clusters - - ccgs = correlograms(spike_samples, - spike_clusters, - cluster_order=clusters, - binsize=self.binsize, - # NOTE: this must be an odd number, for symmetry - winsize_bins=_oddify(self.winsize_bins), - ) - self._ccgs = _symmetrize_correlograms(ccgs) - # Normalize the CCGs. - self.view.correlograms = self._normalize(self._ccgs) - - # Take the cluster order into account. - self.view.visual.cluster_order = clusters - self.view.update() - - def _normalize(self, ccgs): - if not len(ccgs): - return ccgs - if self._normalization == 'equal': - return ccgs * (1. / max(1., ccgs.max())) - elif self._normalization == 'independent': - return ccgs * (1. 
/ np.maximum(1., ccgs.max(axis=2)[:, :, None])) - - @property - def normalization(self): - """Correlogram normalization: `equal` or `independent`.""" - return self._normalization - - @normalization.setter - def normalization(self, value): - self._normalization = value - if self._ccgs is not None: - self.view.visual.correlograms = self._normalize(self._ccgs) - self.view.update() - - @property - def lines(self): - return self.view.lines - - @lines.setter - def lines(self, value): - self.view.lines = value - - def toggle_normalization(self): - """Change the correlogram normalization.""" - self.normalization = ('equal' if self._normalization == 'independent' - else 'independent') - - def exported_params(self, save_size_pos=True): - """Parameters to save automatically when the view is closed.""" - params = super(CorrelogramViewModel, self).exported_params( - save_size_pos) - params.update({ - 'normalization': self.normalization, - }) - return params - - -class TraceViewModel(VispyViewModel): - """Traces.""" - _view_class = TraceView - _view_name = 'traces' - _imported_params = ('scale_factor', 'channel_scale', 'interval_size') - interval_size = .25 - - def __init__(self, **kwargs): - self._interval = None - super(TraceViewModel, self).__init__(**kwargs) - - def _load_traces(self, interval): - start, end = interval - spikes = self.spike_ids - - # Load the traces. - # debug("Loading traces...") - # Using channel_order ensures that we get rid of the dead channels. - # We also keep the channel order as specified by the PRM file. - # WARNING: HDF5 does not support out-of-order indexing (...!!) - traces = self.model.traces[start:end, :][:, self.model.channel_order] - - # Normalize and set the traces. - traces_f = np.empty_like(traces, dtype=np.float32) - traces_f[...] = traces * self.scale_factor - # Detrend the traces. - m = np.mean(traces_f[::10, :], axis=0) - traces_f -= m - self.view.visual.traces = traces_f - - # Keep the spikes in the interval. - spike_samples = self.model.spike_samples[spikes] - a, b = spike_samples.searchsorted(interval) - spikes = spikes[a:b] - self.view.visual.n_spikes = len(spikes) - self.view.visual.spike_ids = spikes - - if len(spikes) == 0: - return - - # We update the spike clusters according to the subselection of spikes. - # We don't update the list of unique clusters, which only change - # when selecting or clustering, not when changing the interval. - # self.update_spike_clusters(spikes) - self.view.visual.spike_clusters = self.model.spike_clusters[spikes] - - # Set the spike samples. - spike_samples = self.model.spike_samples[spikes] - # This is in unit of samples relative to the start of the interval. - spike_samples = spike_samples - start - self.view.visual.spike_samples = spike_samples - self.view.visual.offset = start - - # Load the masks. - # TODO: ensure model.masks is always 2D, even with 1 spike - masks = np.atleast_2d(self._model.masks[spikes]) - self.view.visual.masks = masks - - @property - def interval(self): - """The interval of the view, in unit of sample.""" - return self._interval - - @interval.setter - def interval(self, value): - if self.model.traces is None: - return - if not isinstance(value, tuple) or len(value) != 2: - raise ValueError("The interval should be a (start, end) tuple.") - # Restrict the interval to the boundaries of the traces. 
- start, end = value - start, end = int(start), int(end) - n = self.model.traces.shape[0] - if start < 0: - end += (-start) - start = 0 - elif end >= n: - start -= (end - n) - end = n - start = np.clip(start, 0, end) - end = np.clip(end, start, n) - assert 0 <= start < end <= n - self._interval = (start, end) - self._load_traces((start, end)) - self.view.update() - - @property - def channel_scale(self): - """Vertical scale of the traces.""" - return self.view.channel_scale - - @channel_scale.setter - def channel_scale(self, value): - self.view.channel_scale = value - - def move(self, amount): - """Move the current interval by a given amount (in samples).""" - amount = int(amount) - start, end = self.interval - self.interval = start + amount, end + amount - - def move_right(self, fraction=.05): - """Move the current interval to the right.""" - start, end = self.interval - self.move(int(+(end - start) * fraction)) - - def move_left(self, fraction=.05): - """Move the current interval to the left.""" - start, end = self.interval - self.move(int(-(end - start) * fraction)) - - keyboard_shortcuts = { - 'scroll_left': 'ctrl+left', - 'scroll_right': 'ctrl+right', - 'fast_scroll_left': 'shift+left', - 'fast_scroll_right': 'shift+right', - } - - def on_key_press(self, event): - """Called when a key is pressed.""" - super(TraceViewModel, self).on_key_press(event) - key = event.key - if 'Control' in event.modifiers: - if key == 'Left': - self.move_left() - elif key == 'Right': - self.move_right() - if 'Shift' in event.modifiers: - if key == 'Left': - self.move_left(1) - elif key == 'Right': - self.move_right(1) - - def on_open(self): - """Initialize the view when the model is opened.""" - super(TraceViewModel, self).on_open() - self.view.visual.n_samples_per_spike = self.model.n_samples_waveforms - self.view.visual.sample_rate = self.model.sample_rate - if self.scale_factor is None: - self.scale_factor = 1. - if self.interval_size is None: - self.interval_size = .25 - self.select([]) - - def on_select(self, clusters, **kwargs): - """Update the view when the selection changes.""" - # Get the spikes in the selected clusters. - spikes = self.spike_ids - n_clusters = len(clusters) - spike_clusters = self.model.spike_clusters[spikes] - - # Update the clusters of the trace view. - visual = self._view.visual - visual.spike_clusters = spike_clusters - visual.cluster_ids = clusters - visual.cluster_order = clusters - visual.cluster_colors = _selected_clusters_colors(n_clusters) - - # Select the default interval. - half_size = int(self.interval_size * self.model.sample_rate / 2.) - if len(spikes) > 0: - # Center the default interval around the first spike. - sample = self._model.spike_samples[spikes[0]] - else: - sample = half_size - # Load traces by setting the interval. 
- visual._update_clusters_automatically = False - self.interval = sample - half_size, sample + half_size - - def exported_params(self, save_size_pos=True): - """Parameters to save automatically when the view is closed.""" - params = super(TraceViewModel, self).exported_params(save_size_pos) - params.update({ - 'scale_factor': self.scale_factor, - 'channel_scale': self.channel_scale, - }) - return params - - -#------------------------------------------------------------------------------ -# Feature view models -#------------------------------------------------------------------------------ - -def _best_channels(cluster, model=None, store=None): - """Return the channels with the largest mean features.""" - n_fet = model.n_features_per_channel - score = store.mean_features(cluster) - score = score.reshape((-1, n_fet)).mean(axis=1) - assert len(score) == len(model.channel_order) - channels = np.argsort(score)[::-1] - return channels - - -def _dimensions(x_channels, y_channels): - """Default dimensions matrix.""" - # time, depth time, (x, 0) time, (y, 0) time, (z, 0) - # time, (x', 0) (x', 0), (x, 0) (x', 1), (y, 0) (x', 2), (z, 0) - # time, (y', 0) (y', 0), (x, 1) (y', 1), (y, 1) (y', 2), (z, 1) - # time, (z', 0) (z', 0), (x, 2) (z', 1), (y, 2) (z', 2), (z, 2) - - n = len(x_channels) - assert len(y_channels) == n - y_dim = {} - x_dim = {} - # TODO: depth - x_dim[0, 0] = 'time' - y_dim[0, 0] = 'time' - - # Time in first column and first row. - for i in range(1, n + 1): - x_dim[0, i] = 'time' - y_dim[0, i] = (x_channels[i - 1], 0) - x_dim[i, 0] = 'time' - y_dim[i, 0] = (y_channels[i - 1], 0) - - for i in range(1, n + 1): - for j in range(1, n + 1): - x_dim[i, j] = (x_channels[i - 1], j - 1) - y_dim[i, j] = (y_channels[j - 1], i - 1) - - return x_dim, y_dim - - -class BaseFeatureViewModel(VispyViewModel): - """Features.""" - _view_class = FeatureView - _view_name = 'base_features' - _imported_params = ('scale_factor', 'n_spikes_max_bg', 'marker_size') - n_spikes_max_bg = 10000 - - def __init__(self, *args, **kwargs): - self._extra_features = {} - super(BaseFeatureViewModel, self).__init__(*args, **kwargs) - - def _rescale_features(self, features): - # WARNING: convert features to a 3D array - # (n_spikes, n_channels, n_features) - # because that's what the FeatureView expects currently. - n_fet = self.model.n_features_per_channel - n_channels = len(self.model.channel_order) - shape = (-1, n_channels, n_fet) - features = features[:, :n_fet * n_channels].reshape(shape) - # Scale factor. - return features * self.scale_factor - - @property - def lasso(self): - """The spike lasso visual.""" - return self.view.lasso - - def spikes_in_lasso(self): - """Return the spike ids from the selected clusters within the lasso.""" - if not len(self.cluster_ids) or self.view.lasso.n_points <= 2: - return - clusters = self.cluster_ids - # Load *all* features from the selected clusters. - features = self.store.load('features', clusters=clusters) - # Find the corresponding spike_ids. - spike_ids = _spikes_in_clusters(self.model.spike_clusters, clusters) - assert features.shape[0] == len(spike_ids) - # Extract the extra features for all the spikes in the clusters. - extra_features = self._subset_extra_features(spike_ids) - # Rescale the features (and not the extra features!). - features = features * self.scale_factor - # Extract the two relevant dimensions. - points = self.view.visual.project(self.view.lasso.box, - features=features, - extra_features=extra_features, - ) - # Find the points within the lasso. 
- in_lasso = self.view.lasso.in_lasso(points) - return spike_ids[in_lasso] - - @property - def marker_size(self): - """Marker size, in pixels.""" - return self.view.marker_size - - @marker_size.setter - def marker_size(self, value): - self.view.marker_size = value - - @property - def n_features(self): - """Number of features.""" - return self.view.background.n_features - - @property - def n_rows(self): - """Number of rows in the view. - - To be overriden. - - """ - return 1 - - def dimensions_for_clusters(self, cluster_ids): - """Return the x and y dimensions most appropriate for the set of - selected clusters. - - To be overriden. - - TODO: make this customizable. - - """ - return {}, {} - - def set_dimension(self, axis, box, dim, smart=True): - """Set a dimension. - - Parameters - ---------- - - axis : str - `'x'` or `'y'` - box : tuple - A `(i, j)` pair. - dim : str or tuple - A feature name, or a tuple `(channel_id, feature_id)`. - smart : bool - Whether to ensure the two axes in the subplot are different, - to avoid a useless `x=y` diagonal situation. - - """ - if smart: - dim = self.view.smart_dimension(axis, box, dim) - self.view.set_dimensions(axis, {box: dim}) - - def add_extra_feature(self, name, array): - """Add an extra feature. - - Parameters - ---------- - - name : str - The feature's name. - array : ndarray - A `(n_spikes,)` array with the feature's value for every spike. - - """ - assert isinstance(name, string_types) - array = _as_array(array) - n_spikes = self.model.n_spikes - if array.shape != (n_spikes,): - raise ValueError("The extra feature needs to be a 1D vector with " - "`n_spikes={}` values.".format(n_spikes)) - self._extra_features[name] = (array, array.min(), array.max()) - - def _subset_extra_features(self, spikes): - return {name: (array[spikes], m, M) - for name, (array, m, M) in self._extra_features.items()} - - def _add_extra_features_in_view(self, spikes): - """Add the extra features in the view, by selecting only the - displayed spikes.""" - subset_extra = self._subset_extra_features(spikes) - for name, (sub_array, m, M) in subset_extra.items(): - # Make the extraction for the background spikes. - array, _, _ = self._extra_features[name] - sub_array_bg = array[self.view.background.spike_ids] - self.view.add_extra_feature(name, sub_array, m, M, - array_bg=sub_array_bg) - - @property - def x_dim(self): - """List of x dimensions in every row/column.""" - return self.view.x_dim - - @property - def y_dim(self): - """List of y dimensions in every row/column.""" - return self.view.y_dim - - def on_open(self): - """Initialize the view when the model is opened.""" - # Get background features. - # TODO OPTIM: precompute this once for all and store in the cluster - # store. But might be unnecessary. - if self.n_spikes_max_bg is not None: - k = max(1, self.model.n_spikes // self.n_spikes_max_bg) - else: - k = 1 - if self.model.features is not None: - # Background features. - features_bg = self.store.load('features', - spikes=slice(None, None, k)) - self.view.background.features = self._rescale_features(features_bg) - self.view.background.spike_ids = self.model.spike_ids[::k] - # Register the time dimension. - self.add_extra_feature('time', self.model.spike_samples) - # Add the subset extra features to the visuals. 
- self._add_extra_features_in_view(slice(None, None, k)) - # Number of rows: number of features + 1 for - self.view.init_grid(self.n_rows) - - def on_select(self, clusters, auto_update=True): - """Update the view when the selection changes.""" - super(BaseFeatureViewModel, self).on_select(clusters) - spikes = self.spike_ids - - features = self.store.load('features', - clusters=clusters, - spikes=spikes) - masks = self.store.load('masks', - clusters=clusters, - spikes=spikes) - - nc = len(self.model.channel_order) - nf = self.model.n_features_per_channel - features = features.reshape((len(spikes), nc, nf)) - self.view.visual.features = self._rescale_features(features) - self.view.visual.masks = masks - - # Spikes. - self.view.visual.spike_ids = spikes - - # Extra features (including time). - self._add_extra_features_in_view(spikes) - - # Cluster display order. - self.view.visual.cluster_order = clusters - - # Set default dimensions. - if auto_update: - x_dim, y_dim = self.dimensions_for_clusters(clusters) - self.view.set_dimensions('x', x_dim) - self.view.set_dimensions('y', y_dim) - - keyboard_shortcuts = { - 'increase_scale': 'ctrl+', - 'decrease_scale': 'ctrl-', - } - - def on_key_press(self, event): - """Handle key press events.""" - super(BaseFeatureViewModel, self).on_key_press(event) - key = event.key - ctrl = 'Control' in event.modifiers - if ctrl and key in ('+', '-'): - k = 1.1 if key == '+' else .9 - self.scale_factor *= k - self.view.visual.features *= k - self.view.background.features *= k - self.view.update() - - def exported_params(self, save_size_pos=True): - """Parameters to save automatically when the view is closed.""" - params = super(BaseFeatureViewModel, - self).exported_params(save_size_pos) - params.update({ - 'scale_factor': self.scale_factor, - 'marker_size': self.marker_size, - }) - return params - - -class FeatureGridViewModel(BaseFeatureViewModel): - """Features grid""" - _view_name = 'features_grid' - - keyboard_shortcuts = { - 'enlarge_subplot': 'ctrl+click', - 'increase_scale': 'ctrl+', - 'decrease_scale': 'ctrl-', - } - - @property - def n_rows(self): - """Number of rows in the grid view.""" - return self.n_features + 1 - - def dimensions_for_clusters(self, cluster_ids): - """Return the x and y dimensions most appropriate for the set of - selected clusters. - - TODO: make this customizable. - - """ - n = len(cluster_ids) - if not n: - return {}, {} - x_channels = _best_channels(cluster_ids[min(1, n - 1)], - model=self.model, - store=self.store, - ) - y_channels = _best_channels(cluster_ids[0], - model=self.model, - store=self.store, - ) - y_channels = y_channels[:self.n_rows - 1] - # For the x axis, remove the channels that already are in - # the y axis. - x_channels = [c for c in x_channels if c not in y_channels] - # Now, select the right number of channels in the x axis. - x_channels = x_channels[:self.n_rows - 1] - return _dimensions(x_channels, y_channels) - - -class FeatureViewModel(BaseFeatureViewModel): - """Feature view with a single subplot.""" - _view_name = 'features' - _x_dim = 'time' - _y_dim = (0, 0) - - keyboard_shortcuts = { - 'increase_scale': 'ctrl+', - 'decrease_scale': 'ctrl-', - } - - @property - def n_rows(self): - """Number of rows.""" - return 1 - - @property - def x_dim(self): - """x dimension.""" - return self._x_dim - - @property - def y_dim(self): - """y dimension.""" - return self._y_dim - - def set_dimension(self, axis, dim, smart=True): - """Set a (smart) dimension. 
- - "smart" means that the dimension may be changed if it is the same - than the other dimension, to avoid x=y. - - """ - super(FeatureViewModel, self).set_dimension(axis, (0, 0), dim, - smart=smart) - - def dimensions_for_clusters(self, cluster_ids): - """Current dimensions.""" - return self._x_dim, self._y_dim diff --git a/phy/cluster/manual/views.py b/phy/cluster/manual/views.py new file mode 100644 index 000000000..dfa59efde --- /dev/null +++ b/phy/cluster/manual/views.py @@ -0,0 +1,1445 @@ +# -*- coding: utf-8 -*- + +"""Manual clustering views.""" + + +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- + +import inspect +from itertools import product +import logging +import re + +import numpy as np +from vispy.util.event import Event + +from phy.io.array import _index_of, _get_padded, get_excerpts +from phy.gui import Actions +from phy.plot import View, _get_linear_x +from phy.plot.transform import Range +from phy.plot.utils import _get_boxes +from phy.stats import correlograms +from phy.utils import Bunch +from phy.utils._color import _spike_colors, ColorSelector + +logger = logging.getLogger(__name__) + + +# ----------------------------------------------------------------------------- +# Utils +# ----------------------------------------------------------------------------- + +def _extract_wave(traces, start, mask, wave_len=None, mask_threshold=.5): + n_samples, n_channels = traces.shape + assert mask.shape == (n_channels,) + channels = np.nonzero(mask > mask_threshold)[0] + # There should be at least one non-masked channel. + if not len(channels): + return # pragma: no cover + i, j = start, start + wave_len + a, b = max(0, i), min(j, n_samples - 1) + data = traces[a:b, channels] + data = _get_padded(data, i - a, i - a + wave_len) + assert data.shape == (wave_len, len(channels)) + return data, channels + + +def _get_depth(masks, spike_clusters_rel=None, n_clusters=None): + """Return the OpenGL z-depth of vertices as a function of the + mask and cluster index.""" + n_spikes = len(masks) + assert masks.shape == (n_spikes,) + # Fixed depth for background spikes. + if spike_clusters_rel is None: + depth = .5 * np.ones(n_spikes) + else: + depth = (-0.1 - (spike_clusters_rel + masks) / + float(n_clusters + 10.)) + depth[masks <= 0.25] = 0 + assert depth.shape == (n_spikes,) + return depth + + +def _extend(channels, n=None): + channels = list(channels) + if n is None: + return channels + if not len(channels): # pragma: no cover + channels = [0] + if len(channels) < n: + channels.extend([channels[-1]] * (n - len(channels))) + channels = channels[:n] + assert len(channels) == n + return channels + + +# ----------------------------------------------------------------------------- +# Manual clustering view +# ----------------------------------------------------------------------------- + +class StatusEvent(Event): + def __init__(self, type, message=None): + super(StatusEvent, self).__init__(type) + self.message = message + + +class ManualClusteringView(View): + """Base class for clustering views. + + The views take their data with functions `cluster_ids: spike_ids, data`. + + """ + default_shortcuts = { + } + + def __init__(self, shortcuts=None, **kwargs): + + # Load default shortcuts, and override with any user shortcuts. + self.shortcuts = self.default_shortcuts.copy() + self.shortcuts.update(shortcuts or {}) + + # Message to show in the status bar. 
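+        # set_status() stores the last message here so that it can be
+        # re-emitted later (for example when the mouse moves over the view).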
+ self.status = None + + # Attached GUI. + self.gui = None + + # Keep track of the selected clusters and spikes. + self.cluster_ids = None + + super(ManualClusteringView, self).__init__(**kwargs) + self.events.add(status=StatusEvent) + + def on_select(self, cluster_ids=None): + cluster_ids = (cluster_ids if cluster_ids is not None + else self.cluster_ids) + self.cluster_ids = list(cluster_ids) if cluster_ids is not None else [] + self.cluster_ids = [int(c) for c in self.cluster_ids] + + def attach(self, gui): + """Attach the view to the GUI.""" + + # Disable keyboard pan so that we can use arrows as global shortcuts + # in the GUI. + self.panzoom.enable_keyboard_pan = False + + gui.add_view(self) + self.gui = gui + + # Set the view state. + self.set_state(gui.state.get_view_state(self)) + + gui.connect_(self.on_select) + self.actions = Actions(gui, + name=self.__class__.__name__, + menu=self.__class__.__name__, + default_shortcuts=self.shortcuts) + + # Update the GUI status message when the `self.set_status()` method + # is called, i.e. when the `status` event is raised by the VisPy + # view. + @self.connect + def on_status(e): + gui.status_message = e.message + + # Save the view state in the GUI state. + @gui.connect_ + def on_close(): + gui.state.update_view_state(self, self.state) + # NOTE: create_gui() already saves the state, but the event + # is registered *before* we add all views. + gui.state.save() + + self.show() + + @property + def state(self): + """View state. + + This Bunch will be automatically persisted in the GUI state when the + GUI is closed. + + To be overriden. + + """ + return Bunch() + + def set_state(self, state): + """Set the view state. + + The passed object is the persisted `self.state` bunch. + + May be overriden. + + """ + for k, v in state.items(): + setattr(self, k, v) + + def set_status(self, message=None): + message = message or self.status + if not message: + return + self.status = message + self.events.status(message=message) + + def on_mouse_move(self, e): # pragma: no cover + self.set_status() + + +# ----------------------------------------------------------------------------- +# Waveform view +# ----------------------------------------------------------------------------- + +class ChannelClick(Event): + def __init__(self, type, channel_idx=None, key=None, button=None): + super(ChannelClick, self).__init__(type) + self.channel_idx = channel_idx + self.key = key + self.button = button + + +class WaveformView(ManualClusteringView): + scaling_coeff = 1.1 + + default_shortcuts = { + 'toggle_waveform_overlap': 'o', + 'toggle_zoom_on_channels': 'z', + 'next_data': 'w', + + # Box scaling. + 'widen': 'ctrl+right', + 'narrow': 'ctrl+left', + 'increase': 'ctrl+up', + 'decrease': 'ctrl+down', + + # Probe scaling. + 'extend_horizontally': 'shift+right', + 'shrink_horizontally': 'shift+left', + 'extend_vertically': 'shift+up', + 'shrink_vertically': 'shift+down', + } + + def __init__(self, + waveforms=None, + channel_positions=None, + waveform_lims=None, + best_channels=None, + **kwargs): + self._key_pressed = None + self._overlap = False + self.do_zoom_on_channels = True + self.data_index = 0 + + self.best_channels = best_channels or (lambda clusters: []) + + # Channel positions and n_channels. + assert channel_positions is not None + self.channel_positions = np.asarray(channel_positions) + self.n_channels = self.channel_positions.shape[0] + + # Initialize the view. 
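+        # The boxed layout gives one subplot per channel, positioned
+        # according to the probe geometry.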
+ box_bounds = _get_boxes(channel_positions) + super(WaveformView, self).__init__(layout='boxed', + box_bounds=box_bounds, + **kwargs) + + self.events.add(channel_click=ChannelClick) + + # Box and probe scaling. + self._box_scaling = np.ones(2) + self._probe_scaling = np.ones(2) + + # Make a copy of the initial box pos and size. We'll apply the scaling + # to these quantities. + self.box_pos = np.array(self.boxed.box_pos) + self.box_size = np.array(self.boxed.box_size) + self._update_boxes() + + # Data: functions cluster_id => waveforms. + self.waveforms = waveforms + + # Waveform normalization. + assert len(waveform_lims) == 2 + self.data_bounds = [-1, waveform_lims[0], +1, waveform_lims[1]] + + # Channel positions. + assert channel_positions.shape == (self.n_channels, 2) + self.channel_positions = channel_positions + + def on_select(self, cluster_ids=None): + super(WaveformView, self).on_select(cluster_ids) + cluster_ids = self.cluster_ids + n_clusters = len(cluster_ids) + if n_clusters == 0: + return + + # Load the waveform subset. + data = self.waveforms(cluster_ids) + # Take one element in the list. + data = data[self.data_index % len(data)] + alpha = data.get('alpha', .5) + spike_ids = data.spike_ids + spike_clusters = data.spike_clusters + w = data.data + masks = data.masks + n_spikes = len(spike_ids) + assert w.ndim == 3 + n_samples = w.shape[1] + assert w.shape == (n_spikes, n_samples, self.n_channels) + assert masks.shape == (n_spikes, self.n_channels) + + # Relative spike clusters. + spike_clusters_rel = _index_of(spike_clusters, cluster_ids) + assert spike_clusters_rel.shape == (n_spikes,) + + # Fetch the waveforms. + t = _get_linear_x(n_spikes, n_samples) + # Overlap. + if not self.overlap: + t = t + 2.5 * (spike_clusters_rel[:, np.newaxis] - + (n_clusters - 1) / 2.) + # The total width should not depend on the number of clusters. + t /= n_clusters + + # Plot all waveforms. + # OPTIM: avoid the loop. + with self.building(): + for ch in range(self.n_channels): + m = masks[:, ch] + depth = _get_depth(m, + spike_clusters_rel=spike_clusters_rel, + n_clusters=n_clusters) + color = _spike_colors(spike_clusters_rel, + masks=m, + alpha=alpha, + ) + self[ch].plot(x=t, y=w[:, :, ch], + color=color, + depth=depth, + data_bounds=self.data_bounds, + ) + # Add channel labels. + self[ch].text(pos=[[t[0, 0], 0.]], + text=str(ch), + anchor=[-1.01, -.25], + data_bounds=self.data_bounds, + ) + + # Zoom on the best channels when selecting clusters. + channels = self.best_channels(cluster_ids) + if channels is not None and self.do_zoom_on_channels: + self.zoom_on_channels(channels) + + @property + def state(self): + return Bunch(box_scaling=tuple(self.box_scaling), + probe_scaling=tuple(self.probe_scaling), + overlap=self.overlap, + do_zoom_on_channels=self.do_zoom_on_channels, + ) + + def attach(self, gui): + """Attach the view to the GUI.""" + super(WaveformView, self).attach(gui) + self.actions.add(self.toggle_waveform_overlap) + self.actions.add(self.toggle_zoom_on_channels) + + # Box scaling. + self.actions.add(self.widen) + self.actions.add(self.narrow) + self.actions.add(self.increase) + self.actions.add(self.decrease) + + # Probe scaling. + self.actions.add(self.extend_horizontally) + self.actions.add(self.shrink_horizontally) + self.actions.add(self.extend_vertically) + self.actions.add(self.shrink_vertically) + + self.actions.add(self.next_data) + + # We forward the event from VisPy to the phy GUI. 
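+        # Components connected to the GUI (for instance the feature view)
+        # can then respond to channel clicks made in this view.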
+ @self.connect + def on_channel_click(e): + gui.emit('channel_click', + channel_idx=e.channel_idx, + key=e.key, + button=e.button, + ) + + # Overlap + # ------------------------------------------------------------------------- + + @property + def overlap(self): + return self._overlap + + @overlap.setter + def overlap(self, value): + self._overlap = value + # HACK: temporarily disable automatic zoom on channels when + # changing the overlap. + tmp = self.do_zoom_on_channels + self.do_zoom_on_channels = False + self.on_select() + self.do_zoom_on_channels = tmp + + def toggle_waveform_overlap(self): + """Toggle the overlap of the waveforms.""" + self.overlap = not self.overlap + + # Box scaling + # ------------------------------------------------------------------------- + + def _update_boxes(self): + self.boxed.update_boxes(self.box_pos * self.probe_scaling, + self.box_size * self.box_scaling) + + @property + def box_scaling(self): + return self._box_scaling + + @box_scaling.setter + def box_scaling(self, value): + assert len(value) == 2 + self._box_scaling = np.array(value) + self._update_boxes() + + def widen(self): + """Increase the horizontal scaling of the waveforms.""" + self._box_scaling[0] *= self.scaling_coeff + self._update_boxes() + + def narrow(self): + """Decrease the horizontal scaling of the waveforms.""" + self._box_scaling[0] /= self.scaling_coeff + self._update_boxes() + + def increase(self): + """Increase the vertical scaling of the waveforms.""" + self._box_scaling[1] *= self.scaling_coeff + self._update_boxes() + + def decrease(self): + """Decrease the vertical scaling of the waveforms.""" + self._box_scaling[1] /= self.scaling_coeff + self._update_boxes() + + # Probe scaling + # ------------------------------------------------------------------------- + + @property + def probe_scaling(self): + return self._probe_scaling + + @probe_scaling.setter + def probe_scaling(self, value): + assert len(value) == 2 + self._probe_scaling = np.array(value) + self._update_boxes() + + def extend_horizontally(self): + """Increase the horizontal scaling of the probe.""" + self._probe_scaling[0] *= self.scaling_coeff + self._update_boxes() + + def shrink_horizontally(self): + """Decrease the horizontal scaling of the waveforms.""" + self._probe_scaling[0] /= self.scaling_coeff + self._update_boxes() + + def extend_vertically(self): + """Increase the vertical scaling of the waveforms.""" + self._probe_scaling[1] *= self.scaling_coeff + self._update_boxes() + + def shrink_vertically(self): + """Decrease the vertical scaling of the waveforms.""" + self._probe_scaling[1] /= self.scaling_coeff + self._update_boxes() + + # Navigation + # ------------------------------------------------------------------------- + + def next_data(self): + """Show the next set of waveforms (if any).""" + self.data_index += 1 + self.on_select() + + def toggle_zoom_on_channels(self): + self.do_zoom_on_channels = not self.do_zoom_on_channels + + def zoom_on_channels(self, channels_rel): + """Zoom on some channels.""" + if channels_rel is None or not len(channels_rel): + return + channels_rel = np.asarray(channels_rel, dtype=np.int32) + assert 0 <= channels_rel.min() <= channels_rel.max() < self.n_channels + # Bounds of the channels. 
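+        # Compute the rectangle enclosing the selected channels' boxes:
+        # (x0, y0) is the minimum of the lower-left corners, (x1, y1) the
+        # maximum of the upper-right corners; then pan/zoom to that range.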
+ b = self.boxed.box_bounds[channels_rel] + x0, y0 = b[:, :2].min(axis=0) + x1, y1 = b[:, 2:].max(axis=0) + self.panzoom.set_range((x0, y0, x1, y1), keep_aspect=True) + + def on_key_press(self, event): + """Handle key press events.""" + key = event.key + self._key_pressed = key + + def on_mouse_press(self, e): + key = self._key_pressed + if 'Control' in e.modifiers or key in map(str, range(10)): + key = int(key.name) if key in map(str, range(10)) else None + # Get mouse position in NDC. + mouse_pos = self.panzoom.get_mouse_pos(e.pos) + channel_idx = self.boxed.get_closest_box(mouse_pos) + self.events.channel_click(channel_idx=channel_idx, + key=key, + button=e.button) + + def on_key_release(self, event): + self._key_pressed = None + + +# ----------------------------------------------------------------------------- +# Trace view +# ----------------------------------------------------------------------------- + +def select_traces(traces, interval, sample_rate=None): + """Load traces in an interval (in seconds).""" + start, end = interval + i, j = round(sample_rate * start), round(sample_rate * end) + i, j = int(i), int(j) + traces = traces[i:j, :] + traces = traces - np.mean(traces, axis=0) + return traces + + +def extract_spikes(traces, interval, sample_rate=None, + spike_times=None, spike_clusters=None, + all_masks=None, + n_samples_waveforms=None): + sr = sample_rate + ns = n_samples_waveforms + if not isinstance(ns, tuple): + ns = (ns // 2, ns // 2) + offset_samples = ns[0] + wave_len = ns[0] + ns[1] + + # Find spikes. + a, b = spike_times.searchsorted(interval) + st = spike_times[a:b] + sc = spike_clusters[a:b] + m = all_masks[a:b] + n = len(st) + assert len(sc) == n + assert m.shape[0] == n + + # Extract waveforms. + spikes = [] + for i in range(n): + b = Bunch() + # Find the start of the waveform in the extracted traces. + sample_start = int(round((st[i] - interval[0]) * sr)) + sample_start -= offset_samples + o = _extract_wave(traces, sample_start, m[i], wave_len) + if o is None: # pragma: no cover + logger.debug("Unable to extract spike %d.", i) + continue + b.waveforms, b.channels = o + # Masks on unmasked channels. + b.masks = m[i, b.channels] + b.spike_time = st[i] + b.spike_cluster = sc[i] + b.offset_samples = offset_samples + + spikes.append(b) + return spikes + + +class TraceView(ManualClusteringView): + interval_duration = .25 # default duration of the interval + shift_amount = .1 + scaling_coeff = 1.1 + default_trace_color = (.3, .3, .3, 1.) + default_shortcuts = { + 'go_left': 'alt+left', + 'go_right': 'alt+right', + 'decrease': 'alt+down', + 'increase': 'alt+up', + 'widen': 'ctrl+alt+left', + 'narrow': 'ctrl+alt+right', + } + + def __init__(self, + traces=None, + spikes=None, + sample_rate=None, + duration=None, + n_channels=None, + **kwargs): + + # traces is a function interval => [traces] + # spikes is a function interval => [Bunch(...)] + + # Sample rate. + assert sample_rate > 0 + self.sample_rate = float(sample_rate) + self.dt = 1. / self.sample_rate + + # Traces and spikes. + assert hasattr(traces, '__call__') + self.traces = traces + assert hasattr(spikes, '__call__') + self.spikes = spikes + + assert duration >= 0 + self.duration = duration + + assert n_channels >= 0 + self.n_channels = n_channels + + # Box and probe scaling. + self._scaling = 1. + self._origin = None + + self._color_selector = ColorSelector() + + # Initialize the view. 
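+        # The stacked layout displays the channels as vertically stacked
+        # rows, one per channel.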
+ super(TraceView, self).__init__(layout='stacked', + origin=self.origin, + n_plots=self.n_channels, + **kwargs) + + # Make a copy of the initial box pos and size. We'll apply the scaling + # to these quantities. + self.box_size = np.array(self.stacked.box_size) + self._update_boxes() + + # Initial interval. + self._interval = None + self.go_to(duration / 2.) + + # Internal methods + # ------------------------------------------------------------------------- + + def _plot_traces(self, traces=None, color=None, show_labels=True): + assert traces.shape[1] == self.n_channels + t = self._interval[0] + np.arange(traces.shape[0]) * self.dt + color = color or self.default_trace_color + channels = np.arange(self.n_channels) + for ch in channels: + self[ch].plot(t, traces[:, ch], + color=color, + data_bounds=self.data_bounds, + ) + if show_labels: + # Add channel labels. + self[ch].text(pos=[t[0], traces[0, ch]], + text=str(ch), + anchor=[+1., -.1], + data_bounds=self.data_bounds, + ) + + def _plot_spike(self, waveforms=None, channels=None, masks=None, + spike_time=None, spike_cluster=None, offset_samples=0, + color=None): + + n_samples, n_channels = waveforms.shape + assert len(channels) == n_channels + assert len(masks) == n_channels + sr = float(self.sample_rate) + + t0 = spike_time - offset_samples / sr + + # Generate the x coordinates of the waveform. + t = t0 + self.dt * np.arange(n_samples) + t = np.tile(t, (n_channels, 1)) # (n_unmasked_channels, n_samples) + + # The box index depends on the channel. + box_index = np.repeat(channels[:, np.newaxis], n_samples, axis=0) + self.plot(t, waveforms.T, color=color, box_index=box_index, + data_bounds=self.data_bounds) + + def _restrict_interval(self, interval): + start, end = interval + # Round the times to full samples to avoid subsampling shifts + # in the traces. + start = int(round(start * self.sample_rate)) / self.sample_rate + end = int(round(end * self.sample_rate)) / self.sample_rate + # Restrict the interval to the boundaries of the traces. + if start < 0: + end += (-start) + start = 0 + elif end >= self.duration: + start -= (end - self.duration) + end = self.duration + start = np.clip(start, 0, end) + end = np.clip(end, start, self.duration) + assert 0 <= start < end <= self.duration + return start, end + + # Public methods + # ------------------------------------------------------------------------- + + def set_interval(self, interval, change_status=True): + """Display the traces and spikes in a given interval.""" + self.clear() + interval = self._restrict_interval(interval) + self._interval = interval + start, end = interval + # Set the status message. + if change_status: + self.set_status('Interval: {:.3f} s - {:.3f} s'.format(start, end)) + + # Load the traces. + all_traces = self.traces(interval) + assert isinstance(all_traces, (tuple, list)) + # Default data bounds. + m, M = all_traces[0].traces.min(), all_traces[0].traces.max() + self.data_bounds = np.array([start, m, end, M]) + + # Plot the traces. + for i, traces in enumerate(all_traces): + # Only show labels for the first set of traces. + self._plot_traces(show_labels=(i == 0), **traces) + + # Plot the spikes. 
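+        # Each spike Bunch is drawn as a short waveform on its (unmasked)
+        # channels, colored according to its cluster.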
+ spikes = self.spikes(interval, all_traces) + assert isinstance(spikes, (tuple, list)) + + for spike in spikes: + color = self._color_selector.get(spike.spike_cluster, + self.cluster_ids) + self._plot_spike(color=color, **spike) + + self.build() + self.update() + + def on_select(self, cluster_ids=None): + super(TraceView, self).on_select(cluster_ids) + self.set_interval(self._interval, change_status=False) + + def attach(self, gui): + """Attach the view to the GUI.""" + super(TraceView, self).attach(gui) + self.actions.add(self.go_to, alias='tg') + self.actions.add(self.shift, alias='ts') + self.actions.add(self.go_right) + self.actions.add(self.go_left) + self.actions.add(self.increase) + self.actions.add(self.decrease) + self.actions.add(self.widen) + self.actions.add(self.narrow) + + @property + def state(self): + return Bunch(scaling=self.scaling, + origin=self.origin, + interval=self._interval, + ) + + # Scaling + # ------------------------------------------------------------------------- + + @property + def scaling(self): + return self._scaling + + @scaling.setter + def scaling(self, value): + self._scaling = value + self._update_boxes() + + # Origin + # ------------------------------------------------------------------------- + + @property + def origin(self): + return self._origin + + @origin.setter + def origin(self, value): + self._origin = value + self._update_boxes() + + # Navigation + # ------------------------------------------------------------------------- + + @property + def time(self): + """Time at the center of the window.""" + return sum(self._interval) * .5 + + @property + def interval(self): + return self._interval + + @interval.setter + def interval(self, value): + self.set_interval(value) + + @property + def half_duration(self): + """Half of the duration of the current interval.""" + if self._interval is not None: + a, b = self._interval + return (b - a) * .5 + else: + return self.interval_duration * .5 + + def go_to(self, time): + """Go to a specific time (in seconds).""" + half_dur = self.half_duration + self.set_interval((time - half_dur, time + half_dur)) + + def shift(self, delay): + """Shift the interval by a given delay (in seconds).""" + self.go_to(self.time + delay) + + def go_right(self): + """Go to right.""" + start, end = self._interval + delay = (end - start) * .2 + self.shift(delay) + + def go_left(self): + """Go to left.""" + start, end = self._interval + delay = (end - start) * .2 + self.shift(-delay) + + def widen(self): + """Increase the interval size.""" + t, h = self.time, self.half_duration + h *= self.scaling_coeff + self.set_interval((t - h, t + h)) + + def narrow(self): + """Decrease the interval size.""" + t, h = self.time, self.half_duration + h /= self.scaling_coeff + self.set_interval((t - h, t + h)) + + # Channel scaling + # ------------------------------------------------------------------------- + + def _update_boxes(self): + self.stacked.box_size = self.box_size * self.scaling + + def increase(self): + """Increase the scaling of the traces.""" + self.scaling *= self.scaling_coeff + self._update_boxes() + + def decrease(self): + """Decrease the scaling of the traces.""" + self.scaling /= self.scaling_coeff + self._update_boxes() + + +# ----------------------------------------------------------------------------- +# Feature view +# ----------------------------------------------------------------------------- + +def _dimensions_matrix(channels, n_cols=None, top_left_attribute=None): + """ + time,x0 y0,x0 x1,x0 y1,x0 + x0,y0 time,y0 x1,y0 
y1,y0 + x0,x1 y0,x1 time,x1 y1,x1 + x0,y1 y0,y1 x1,y1 time,y1 + """ + # Generate the dimensions matrix from the docstring. + ds = inspect.getdoc(_dimensions_matrix).strip() + x, y = channels[:2] + + def _get_dim(d): + if d == 'time': + return d + assert re.match(r'[xy][01]', d) + c = x if d[0] == 'x' else y + f = int(d[1]) + return c, f + + dims = [[_.split(',') for _ in re.split(r' +', line.strip())] + for line in ds.splitlines()] + x_dim = {(i, j): _get_dim(dims[i][j][0]) + for i, j in product(range(4), range(4))} + y_dim = {(i, j): _get_dim(dims[i][j][1]) + for i, j in product(range(4), range(4))} + return x_dim, y_dim + + +def _project_mask_depth(dim, masks, spike_clusters_rel=None, n_clusters=None): + """Return the mask and depth vectors for a given dimension.""" + n_spikes = masks.shape[0] + if isinstance(dim, tuple): + ch, fet = dim + m = masks[:, ch] + d = _get_depth(m, + spike_clusters_rel=spike_clusters_rel, + n_clusters=n_clusters) + else: + m = np.ones(n_spikes) + d = np.zeros(n_spikes) + return m, d + + +class FeatureView(ManualClusteringView): + _default_marker_size = 3. + + default_shortcuts = { + 'increase': 'ctrl++', + 'decrease': 'ctrl+-', + 'toggle_automatic_channel_selection': 'c', + } + + def __init__(self, + features=None, + background_features=None, + spike_times=None, + n_channels=None, + n_features_per_channel=None, + feature_lim=None, + best_channels=None, + **kwargs): + """ + features is a function : + `cluster_ids: Bunch(spike_ids, + features, + masks, + spike_clusters, + spike_times)` + background_features is a Bunch(...) like above. + + """ + self._scaling = 1. + + self.best_channels = best_channels or (lambda clusters=None: []) + + assert features + self.features = features + + # This is a tuple (spikes, features, masks). + self.background_features = background_features + + self.n_features_per_channel = n_features_per_channel + assert n_channels > 0 + self.n_channels = n_channels + + # Spike times. + self.n_spikes = spike_times.shape[0] + assert spike_times.shape == (self.n_spikes,) + assert self.n_spikes >= 0 + + self.n_cols = 4 + self.shape = (self.n_cols, self.n_cols) + + # Initialize the view. + super(FeatureView, self).__init__(layout='grid', + shape=self.shape, + enable_lasso=True, + **kwargs) + + # Feature normalization. + self.data_bounds = [-1, -feature_lim, +1, +feature_lim] + + # If this is True, the channels won't be automatically chosen + # when new clusters are selected. + self.fixed_channels = False + + # Channels to show. + self.channels = None + + # Attributes: extra features. This is a dictionary + # {name: (array, data_bounds)} + # where each array is a `(n_spikes,)` array. + self.attributes = {} + self.add_attribute('time', spike_times) + + # Internal methods + # ------------------------------------------------------------------------- + + def _get_feature(self, dim, spike_ids, f): + if dim in self.attributes: + # Extra features like time. + values, _ = self.attributes[dim] + values = values[spike_ids] + # assert values.shape == (f.shape[0],) + return values + else: + assert len(dim) == 2 + ch, fet = dim + assert fet < f.shape[2] + return f[:, ch, fet] * self._scaling + + def _get_dim_bounds_single(self, dim): + """Return the min and max of the bounds for a single dimension.""" + if dim in self.attributes: + # Attribute: the data bounds were computed in add_attribute(). + y0, y1 = self.attributes[dim][1] + else: + # Features: the data bounds were computed in the constructor. 
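+            # `self.data_bounds` is `[-1, -feature_lim, +1, +feature_lim]`,
+            # so the y bounds are symmetric around zero.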
+ _, y0, _, y1 = self.data_bounds + return y0, y1 + + def _get_dim_bounds(self, x_dim, y_dim): + """Return the data bounds of a subplot, as a function of the + two x-y dimensions.""" + x0, x1 = self._get_dim_bounds_single(x_dim) + y0, y1 = self._get_dim_bounds_single(y_dim) + return [x0, y0, x1, y1] + + def _plot_features(self, i, j, x_dim, y_dim, x, y, + masks=None, spike_clusters_rel=None): + """Plot the features in a subplot.""" + assert x.shape == y.shape + n_spikes = x.shape[0] + + sc = spike_clusters_rel + if sc is not None: + assert sc.shape == (n_spikes,) + n_clusters = len(self.cluster_ids) + + # Retrieve the data bounds. + data_bounds = self._get_dim_bounds(x_dim[i, j], y_dim[i, j]) + + # Retrieve the masks and depth. + mx, dx = _project_mask_depth(x_dim[i, j], masks, + spike_clusters_rel=sc, + n_clusters=n_clusters) + my, dy = _project_mask_depth(y_dim[i, j], masks, + spike_clusters_rel=sc, + n_clusters=n_clusters) + assert mx.shape == my.shape == dx.shape == dy.shape == (n_spikes,) + + d = np.maximum(dx, dy) + m = np.maximum(mx, my) + + # Get the color of the markers. + color = _spike_colors(sc, masks=m) + assert color.shape == (n_spikes, 4) + + # Create the scatter plot for the current subplot. + # The marker size is smaller for background spikes. + ms = (self._default_marker_size + if spike_clusters_rel is not None else 1.) + self[i, j].scatter(x=x, y=y, + color=color, + depth=d, + data_bounds=data_bounds, + size=ms * np.ones(n_spikes), + ) + if i == (self.n_cols - 1): + dim = x_dim[i, j] if j < (self.n_cols - 1) else x_dim[i, 0] + self[i, j].text(pos=[0., -1.], + text=str(dim), + anchor=[0., -1.04], + ) + if j == 0: + self[i, j].text(pos=[-1., 0.], + text=str(y_dim[i, j]), + anchor=[-1.03, 0.], + ) + + def _get_channel_dims(self, cluster_ids): + """Select the channels to show by default.""" + n = 2 + channels = self.best_channels(cluster_ids) + channels = (channels if channels is not None + else list(range(self.n_channels))) + channels = _extend(channels, n) + assert len(channels) == n + return channels + + # Public methods + # ------------------------------------------------------------------------- + + def add_attribute(self, name, values, top_left=True): + """Add an attribute (aka extra feature). + + The values should be a 1D array with `n_spikes` elements. + + By default, there is the `time` attribute. + + """ + assert values.shape == (self.n_spikes,) + lims = values.min(), values.max() + self.attributes[name] = (values, lims) + # Register the attribute to use in the top-left subplot. + if top_left: + self.top_left_attribute = name + + def clear_channels(self): + """Reset the dimensions.""" + self.channels = None + self.on_select() + + def on_select(self, cluster_ids=None): + super(FeatureView, self).on_select(cluster_ids) + cluster_ids = self.cluster_ids + n_clusters = len(cluster_ids) + if n_clusters == 0: + return + + # Get the spikes, features, masks. + data = self.features(cluster_ids) + spike_ids = data.spike_ids + spike_clusters = data.spike_clusters + f = data.data + masks = data.masks + assert f.ndim == 3 + assert masks.ndim == 2 + assert spike_ids.shape[0] == f.shape[0] == masks.shape[0] + + # Get the spike clusters. + sc = _index_of(spike_clusters, cluster_ids) + + # Get the background features. + data_bg = self.background_features + if data_bg is not None: + spike_ids_bg = data_bg.spike_ids + features_bg = data_bg.data + masks_bg = data_bg.masks + + # Select the dimensions. + # Choose the channels automatically unless fixed_channels is set. 
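+        # `_get_channel_dims()` queries `best_channels()` and keeps exactly
+        # two channels to display in the grid.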
+ if (not self.fixed_channels or self.channels is None): + self.channels = self._get_channel_dims(cluster_ids) + tla = self.top_left_attribute + assert self.channels + x_dim, y_dim = _dimensions_matrix(self.channels, + n_cols=self.n_cols, + top_left_attribute=tla) + + # Set the status message. + ch = ', '.join(map(str, self.channels)) + self.set_status('Channels: {}'.format(ch)) + + # Set a non-time attribute as y coordinate in the top-left subplot. + attrs = sorted(self.attributes) + attrs.remove('time') + if attrs: + y_dim[0, 0] = attrs[0] + + # Plot all features. + with self.building(): + for i in range(self.n_cols): + for j in range(self.n_cols): + + # Retrieve the x and y values for the subplot. + x = self._get_feature(x_dim[i, j], spike_ids, f) + y = self._get_feature(y_dim[i, j], spike_ids, f) + + if data_bg is not None: + # Retrieve the x and y values for the background + # spikes. + x_bg = self._get_feature(x_dim[i, j], spike_ids_bg, + features_bg) + y_bg = self._get_feature(y_dim[i, j], spike_ids_bg, + features_bg) + + # Background features. + self._plot_features(i, j, x_dim, y_dim, x_bg, y_bg, + masks=masks_bg) + + # Cluster features. + self._plot_features(i, j, x_dim, y_dim, x, y, + masks=masks, + spike_clusters_rel=sc) + + # Add axes. + self[i, j].lines(pos=[[-1., 0., +1., 0.], + [0., -1., 0., +1.]], + color=(.25, .25, .25, .5)) + + # Add the boxes. + self.grid.add_boxes(self, self.shape) + + def attach(self, gui): + """Attach the view to the GUI.""" + super(FeatureView, self).attach(gui) + self.actions.add(self.increase) + self.actions.add(self.decrease) + self.actions.add(self.clear_channels) + self.actions.add(self.toggle_automatic_channel_selection) + + gui.connect_(self.on_channel_click) + gui.connect_(self.on_request_split) + + @property + def state(self): + return Bunch(scaling=self.scaling) + + def on_channel_click(self, channel_idx=None, key=None, button=None): + """Respond to the click on a channel.""" + channels = self.channels + if channels is None: + return + assert len(channels) == 2 + assert 0 <= channel_idx < self.n_channels + # Get the axis from the pressed button (1, 2, etc.) + # axis = 'x' if button == 1 else 'y' + channels[0 if button == 1 else 1] = channel_idx + self.fixed_channels = True + self.on_select() + + def on_request_split(self): + """Return the spikes enclosed by the lasso.""" + if self.lasso.count < 3: + return [] + tla = self.top_left_attribute + assert self.channels + x_dim, y_dim = _dimensions_matrix(self.channels, + n_cols=self.n_cols, + top_left_attribute=tla) + data = self.features(self.cluster_ids, load_all=True) + spike_ids = data.spike_ids + f = data.data + i, j = self.lasso.box + + x = self._get_feature(x_dim[i, j], spike_ids, f) + y = self._get_feature(y_dim[i, j], spike_ids, f) + pos = np.c_[x, y].astype(np.float64) + + # Retrieve the data bounds. 
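+        # Map the points from the subplot's data bounds to normalized
+        # coordinates, in which the lasso polygon is defined, before testing
+        # which spikes fall inside the polygon.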
+ data_bounds = self._get_dim_bounds(x_dim[i, j], y_dim[i, j]) + pos = Range(from_bounds=data_bounds).apply(pos) + + ind = self.lasso.in_polygon(pos) + self.lasso.clear() + return spike_ids[ind] + + def toggle_automatic_channel_selection(self): + """Toggle the automatic selection of channels when the cluster + selection changes.""" + self.fixed_channels = not self.fixed_channels + + def increase(self): + """Increase the scaling of the features.""" + self.scaling *= 1.2 + self.on_select() + + def decrease(self): + """Decrease the scaling of the features.""" + self.scaling /= 1.2 + self.on_select() + + # Feature scaling + # ------------------------------------------------------------------------- + + @property + def scaling(self): + return self._scaling + + @scaling.setter + def scaling(self, value): + self._scaling = value + + +# ----------------------------------------------------------------------------- +# Correlogram view +# ----------------------------------------------------------------------------- + +class CorrelogramView(ManualClusteringView): + excerpt_size = 10000 + n_excerpts = 100 + bin_size = 1e-3 + window_size = 50e-3 + uniform_normalization = False + + default_shortcuts = { + 'go_left': 'alt+left', + 'go_right': 'alt+right', + } + + def __init__(self, + spike_times=None, + spike_clusters=None, + sample_rate=None, + **kwargs): + + assert sample_rate > 0 + self.sample_rate = float(sample_rate) + + self.spike_times = np.asarray(spike_times) + self.n_spikes, = self.spike_times.shape + + # Initialize the view. + super(CorrelogramView, self).__init__(layout='grid', + shape=(1, 1), + **kwargs) + + # Spike clusters. + assert spike_clusters.shape == (self.n_spikes,) + self.spike_clusters = spike_clusters + + # Set the default bin and window size. + self.set_bin_window(bin_size=self.bin_size, + window_size=self.window_size) + + def set_bin_window(self, bin_size=None, window_size=None): + """Set the bin and window sizes.""" + bin_size = bin_size or self.bin_size + window_size = window_size or self.window_size + assert 1e-6 < bin_size < 1e3 + assert 1e-6 < window_size < 1e3 + assert bin_size < window_size + self.bin_size = bin_size + self.window_size = window_size + # Set the status message. + b, w = self.bin_size * 1000, self.window_size * 1000 + self.set_status('Bin: {:.1f} ms. Window: {:.1f} ms.'.format(b, w)) + + def _compute_correlograms(self, cluster_ids): + + # Keep spikes belonging to the selected clusters. + ind = np.in1d(self.spike_clusters, cluster_ids) + st = self.spike_times[ind] + sc = self.spike_clusters[ind] + + # Take excerpts of the spikes. + n_spikes_total = len(st) + st = get_excerpts(st, excerpt_size=self.excerpt_size, + n_excerpts=self.n_excerpts) + sc = get_excerpts(sc, excerpt_size=self.excerpt_size, + n_excerpts=self.n_excerpts) + n_spikes_exerpts = len(st) + logger.log(5, "Computing correlograms for clusters %s (%d/%d spikes).", + ', '.join(map(str, cluster_ids)), + n_spikes_exerpts, n_spikes_total, + ) + + # Compute all pairwise correlograms. 
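+        # `correlograms()` returns an `(n_clusters, n_clusters, n_bins)` array;
+        # entry `(i, j)` is the cross-correlogram between clusters `i` and `j`.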
+ ccg = correlograms(st, sc, + cluster_ids=cluster_ids, + sample_rate=self.sample_rate, + bin_size=self.bin_size, + window_size=self.window_size, + ) + + return ccg + + def on_select(self, cluster_ids=None): + super(CorrelogramView, self).on_select(cluster_ids) + cluster_ids = self.cluster_ids + n_clusters = len(cluster_ids) + if n_clusters == 0: + return + + ccg = self._compute_correlograms(cluster_ids) + ylim = [ccg.max()] if not self.uniform_normalization else None + + colors = _spike_colors(np.arange(n_clusters), alpha=1.) + + self.grid.shape = (n_clusters, n_clusters) + with self.building(): + for i in range(n_clusters): + for j in range(n_clusters): + hist = ccg[i, j, :] + color = colors[i] if i == j else np.ones(4) + self[i, j].hist(hist, + color=color, + ylim=ylim, + ) + # Cluster labels. + if i == (n_clusters - 1): + self[i, j].text(pos=[0., -1.], + text=str(cluster_ids[j]), + anchor=[0., -1.04], + ) + + def toggle_normalization(self): + """Change the normalization of the correlograms.""" + self.uniform_normalization = not self.uniform_normalization + self.on_select() + + def attach(self, gui): + """Attach the view to the GUI.""" + super(CorrelogramView, self).attach(gui) + self.actions.add(self.toggle_normalization, shortcut='n') + self.actions.add(self.set_bin, alias='cb') + self.actions.add(self.set_window, alias='cw') + + @property + def state(self): + return Bunch(bin_size=self.bin_size, + window_size=self.window_size, + excerpt_size=self.excerpt_size, + n_excerpts=self.n_excerpts, + uniform_normalization=self.uniform_normalization, + ) + + def set_bin(self, bin_size): + """Set the correlogram bin size (in milliseconds).""" + self.set_bin_window(bin_size=bin_size * 1e-3) + self.on_select() + + def set_window(self, window_size): + """Set the correlogram window size (in milliseconds).""" + self.set_bin_window(window_size=window_size * 1e-3) + self.on_select() + + +# ----------------------------------------------------------------------------- +# Scatter view +# ----------------------------------------------------------------------------- + +class ScatterView(ManualClusteringView): + _default_marker_size = 3. + + def __init__(self, + coords=None, # function clusters: Bunch(x, y) + data_bounds=None, + **kwargs): + + assert coords + self.coords = coords + + # Initialize the view. + super(ScatterView, self).__init__(**kwargs) + + # Feature normalization. + self.data_bounds = data_bounds + + def on_select(self, cluster_ids=None): + super(ScatterView, self).on_select(cluster_ids) + cluster_ids = self.cluster_ids + n_clusters = len(cluster_ids) + if n_clusters == 0: + return + + # Get the spike times and amplitudes + data = self.coords(cluster_ids) + if data is None: + self.clear() + return + spike_ids = data.spike_ids + spike_clusters = data.spike_clusters + x = data.x + y = data.y + n_spikes = len(spike_ids) + assert n_spikes > 0 + assert spike_clusters.shape == (n_spikes,) + assert x.shape == (n_spikes,) + assert y.shape == (n_spikes,) + + # Get the spike clusters. + sc = _index_of(spike_clusters, cluster_ids) + + # Plot the amplitudes. + with self.building(): + m = np.ones(n_spikes) + # Get the color of the markers. + color = _spike_colors(sc, masks=m) + assert color.shape == (n_spikes, 4) + ms = (self._default_marker_size if sc is not None else 1.) 
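+            # Draw one marker per spike, colored according to its relative
+            # cluster index.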
+ + self.scatter(x=x, + y=y, + color=color, + data_bounds=self.data_bounds, + size=ms * np.ones(n_spikes), + ) diff --git a/phy/cluster/manual/wizard.py b/phy/cluster/manual/wizard.py deleted file mode 100644 index b0c3c05d0..000000000 --- a/phy/cluster/manual/wizard.py +++ /dev/null @@ -1,502 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Wizard.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import os.path as op -from operator import itemgetter - -from ...utils import _is_array_like -from .view_models import HTMLClusterViewModel -from ...gui._utils import _read - - -#------------------------------------------------------------------------------ -# Utility functions -#------------------------------------------------------------------------------ - -def _argsort(seq, reverse=True, n_max=None): - """Return the list of clusters in decreasing order of value from - a list of tuples (cluster, value).""" - out = [cl for (cl, v) in sorted(seq, - key=itemgetter(1), - reverse=reverse)] - if n_max in (None, 0): - return out - else: - return out[:n_max] - - -def _best_clusters(clusters, quality, n_max=None): - return _argsort([(cluster, quality(cluster)) - for cluster in clusters], n_max=n_max) - - -def _find_first(items, filter=None): - if not items: - return None - if filter is None: - return items[0] - return next(item for item in items if filter(item)) - - -def _previous(items, current, filter=None): - if current not in items: - raise RuntimeError("{0} is not in {1}.".format(current, items)) - i = items.index(current) - if i == 0: - return current - try: - return _find_first(items[:i][::-1], filter) - except StopIteration: - return current - - -def _next(items, current, filter=None): - if not items: - return current - if current not in items: - raise RuntimeError("{0} is not in {1}.".format(current, items)) - i = items.index(current) - if i == len(items) - 1: - return current - try: - return _find_first(items[i + 1:], filter) - except StopIteration: - return current - - -def _progress(value, maximum): - if maximum <= 1: - return 1 - return int(100 * value / float(maximum - 1)) - - -#------------------------------------------------------------------------------ -# Wizard -#------------------------------------------------------------------------------ - -class Wizard(object): - """Propose a selection of high-quality clusters and merge candidates.""" - def __init__(self, cluster_groups=None): - self.cluster_groups = cluster_groups - self.reset() - - def reset(self): - self._best_list = [] # This list is fixed (modulo clustering actions). - self._match_list = [] # This list may often change. - self._similarity = None - self._quality = None - self._best = None - self._match = None - - @property - def has_started(self): - return len(self._best_list) > 0 - - # Quality functions - #-------------------------------------------------------------------------- - - def set_similarity_function(self, func): - """Register a function returning the similarity between two clusters. - - Can be used as a decorator. - - """ - self._similarity = func - return func - - def set_quality_function(self, func): - """Register a function returning the quality of a cluster. - - Can be used as a decorator. 
- - """ - self._quality = func - return func - - # Internal methods - #-------------------------------------------------------------------------- - - def _group(self, cluster): - return self._cluster_groups.get(cluster, None) - - def _in_groups(self, items, groups): - """Filter out ignored clusters or pairs of clusters.""" - if not isinstance(groups, (list, tuple)): - groups = [groups] - return [item for item in items if self._group(item) in groups] - - def _is_not_ignored(self, cluster): - return self._in_groups([cluster], (None, 'good')) - - def _check(self): - clusters = set(self.cluster_ids) - assert set(self._best_list) <= clusters - assert set(self._match_list) <= clusters - if self._best is not None and len(self._best_list) >= 1: - assert self._best in self._best_list - if self._match is not None and len(self._match_list) >= 1: - assert self._match in self._match_list - if None not in (self.best, self.match): - assert self.best != self.match - - def _sort(self, items, mix_good_unsorted=False): - """Sort clusters according to their groups: - unsorted, good, and ignored.""" - if mix_good_unsorted: - return (self._in_groups(items, (None, 'good')) + - self._in_groups(items, 'ignored')) - else: - return (self._in_groups(items, None) + - self._in_groups(items, 'good') + - self._in_groups(items, 'ignored')) - - # Properties - #-------------------------------------------------------------------------- - - @property - def cluster_ids(self): - """Array of cluster ids in the current clustering.""" - return sorted(self._cluster_groups) - - @property - def cluster_groups(self): - """Dictionary with the groups of each cluster. - - The groups are: `None` (corresponds to unsorted), `good`, or `ignored`. - - """ - return self._cluster_groups - - @cluster_groups.setter - def cluster_groups(self, cluster_groups): - # cluster_groups is a dictionary or is converted to one. - if _is_array_like(cluster_groups): - # A group can be None (unsorted), `good`, or `ignored`. - cluster_groups = {clu: None for clu in cluster_groups} - self._cluster_groups = cluster_groups - - # Core methods - #-------------------------------------------------------------------------- - - def best_clusters(self, n_max=None, quality=None): - """Return the list of best clusters sorted by decreasing quality. - - The default quality function is the registered one. - - """ - if quality is None: - quality = self._quality - best = _best_clusters(self.cluster_ids, quality, n_max=n_max) - return self._sort(best) - - def most_similar_clusters(self, cluster=None, n_max=None, similarity=None): - """Return the `n_max` most similar clusters to a given cluster. - - The default similarity function is the registered one. 
- - """ - if cluster is None: - cluster = self.best - if cluster is None: - cluster = self.best_clusters(1)[0] - if similarity is None: - similarity = self._similarity - s = [(other, similarity(cluster, other)) - for other in self.cluster_ids - if other != cluster] - clusters = _argsort(s, n_max=n_max) - return self._sort(clusters, mix_good_unsorted=True) - - # List methods - #-------------------------------------------------------------------------- - - def _set_best_list(self, cluster=None, clusters=None): - if cluster is None: - cluster = self.best - if clusters is None: - clusters = self.best_clusters() - self._best_list = clusters - if clusters: - self.best = clusters[0] - - def _set_match_list(self, cluster=None, clusters=None): - if cluster is None: - cluster = self.best - if clusters is None: - clusters = self.most_similar_clusters(cluster) - self._match_list = clusters - if clusters: - self.match = clusters[0] - - @property - def best(self): - """Currently-selected best cluster.""" - return self._best - - @best.setter - def best(self, value): - assert value in self._best_list - self._best = value - - @property - def match(self): - """Currently-selected closest match.""" - return self._match - - @property - def selection(self): - """Return the current best/match cluster selection.""" - b, m = self.best, self.match - if b is None: - return [] - elif m is None: - return [b] - else: - if b == m: - return [b] - else: - return [b, m] - - @match.setter - def match(self, value): - if value is not None: - assert value in self._match_list - self._match = value - - @property - def best_list(self): - """Current list of best clusters, by decreasing quality.""" - return self._best_list - - @property - def match_list(self): - """Current list of closest matches, by decreasing similarity.""" - return self._match_list - - @property - def n_processed(self): - """Numbered of processed clusters so far. - - A cluster is considered processed if its group is not `None`. - - """ - return len(self._in_groups(self._best_list, ('good', 'ignored'))) - - @property - def n_clusters(self): - """Total number of clusters.""" - return len(self.cluster_ids) - - # Navigation - #-------------------------------------------------------------------------- - - @property - def _has_finished(self): - return self.best is not None and len(self._best_list) <= 1 - - def next_best(self): - """Select the next best cluster.""" - if self._has_finished: - return - self.best = _next(self._best_list, - self._best, - ) - if self.match is not None: - self._set_match_list() - - def previous_best(self): - """Select the previous best in cluster.""" - if self._has_finished: - return - self.best = _previous(self._best_list, - self._best, - ) - if self.match is not None: - self._set_match_list() - - def next_match(self): - """Select the next match.""" - # Handle the case where we arrive at the end of the match list. 
- if self.match is not None and len(self._match_list) <= 1: - self.next_best() - else: - self.match = _next(self._match_list, - self._match, - ) - - def previous_match(self): - """Select the previous match.""" - self.match = _previous(self._match_list, - self._match, - ) - - def next(self): - """Next cluster proposition.""" - if self.match is None: - return self.next_best() - else: - return self.next_match() - - def previous(self): - """Previous cluster proposition.""" - if self.match is None: - return self.previous_best() - else: - return self.previous_match() - - def first(self): - """First match or first best.""" - if self.match is None: - self.best = self._best_list[0] - else: - self.match = self._match_list[0] - - def last(self): - """Last match or last best.""" - if self.match is None: - self.best = self._best_list[-1] - else: - self.match = self._match_list[-1] - - # Control - #-------------------------------------------------------------------------- - - def start(self): - """Start the wizard by setting the list of best clusters.""" - self._set_best_list() - - def pin(self, cluster=None): - """Pin the current best cluster and set the list of closest matches.""" - if self._has_finished: - return - if cluster is None: - cluster = self.best - if self.match is not None and self.best == cluster: - return - self.best = cluster - self._set_match_list(cluster) - self._check() - - def unpin(self): - """Unpin the current cluster.""" - if self.match is not None: - self.match = None - self._match_list = [] - - # Actions - #-------------------------------------------------------------------------- - - def _delete(self, clusters): - for clu in clusters: - if clu in self._cluster_groups: - del self._cluster_groups[clu] - if clu in self._best_list: - self._best_list.remove(clu) - if clu in self._match_list: - self._match_list.remove(clu) - if clu == self._best: - self._best = self._best_list[0] if self._best_list else None - if clu == self._match: - self._match = None - - def _add(self, clusters, group, position=None): - for clu in clusters: - assert clu not in self._cluster_groups - assert clu not in self._best_list - assert clu not in self._match_list - self._cluster_groups[clu] = group - if self.best is not None: - if position is not None: - self._best_list.insert(position, clu) - else: - self._best_list.append(clu) - if self.match is not None: - self._match_list.append(clu) - - def _update_state(self, up): - # Update the cluster group. - if up.description == 'metadata_group': - cluster = up.metadata_changed[0] - group = up.metadata_value - self._cluster_groups[cluster] = group - # Reorder the best list, so that the clusters moved in different - # groups go to their right place in the best list. - if (self._best is not None and self._best_list and - cluster == self._best): - # Find the next best after the cluster has been moved. - next_best = _next(self._best_list, self._best) - # Reorder the list. - self._best_list = self._sort(self._best_list) - # Select the next best. - self._best = next_best - # Update the wizard with new and old clusters. - for clu in up.added: - # Add the child at the parent's position. - parents = [x for (x, y) in up.descendants if y == clu] - parent = parents[0] - group = self._group(parent) - position = (self._best_list.index(parent) - if self._best_list else None) - self._add([clu], group, position) - # Delete old clusters. - self._delete(up.deleted) - # Select the last added cluster. 
- if self.best is not None and up.added: - self.best = up.added[-1] - - def on_cluster(self, up): - if self._has_finished: - return - if self._best_list or self._match_list: - self._update_state(up) - - # Panel - #-------------------------------------------------------------------------- - - @property - def _best_progress(self): - """Progress in the best clusters.""" - value = (self.best_list.index(self.best) - if self.best in self.best_list else 0) - maximum = len(self.best_list) - return _progress(value, maximum) - - @property - def _match_progress(self): - """Progress in the processed clusters.""" - value = self.n_processed - maximum = self.n_clusters - return _progress(value, maximum) - - def get_panel_params(self): - """Return the parameters for the HTML panel.""" - return dict(best=self.best if self.best is not None else '', - match=self.match if self.match is not None else '', - best_progress=self._best_progress, - match_progress=self._match_progress, - best_group=self._group(self.best) or 'unsorted', - match_group=self._group(self.match) or 'unsorted', - ) - - -#------------------------------------------------------------------------------ -# Wizard view model -#------------------------------------------------------------------------------ - -class WizardViewModel(HTMLClusterViewModel): - def get_html(self, **kwargs): - static_path = op.join(op.dirname(op.realpath(__file__)), 'static') - params = self._wizard.get_panel_params() - html = _read('wizard.html', static_path=static_path) - return html.format(**params) - - def get_css(self, **kwargs): - css = super(WizardViewModel, self).get_css(**kwargs) - static_path = op.join(op.dirname(op.realpath(__file__)), 'static') - css += _read('styles.css', static_path=static_path) - return css diff --git a/phy/conftest.py b/phy/conftest.py new file mode 100644 index 000000000..869c5fb27 --- /dev/null +++ b/phy/conftest.py @@ -0,0 +1,42 @@ +# -*- coding: utf-8 -*- + +"""py.test utilities.""" + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +import logging +import numpy as np +import os + +from pytest import yield_fixture + +from phy import add_default_handler +from phy.utils.tempdir import TemporaryDirectory + + +#------------------------------------------------------------------------------ +# Common fixtures +#------------------------------------------------------------------------------ + +logging.getLogger().setLevel(logging.DEBUG) +add_default_handler('DEBUG') + +# Fix the random seed in the tests. 
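+# A fixed seed makes the pseudo-random data generated in the test fixtures
+# reproducible across runs.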
+np.random.seed(2015) + + +@yield_fixture +def tempdir(): + with TemporaryDirectory() as tempdir: + yield tempdir + + +@yield_fixture +def chdir_tempdir(): + curdir = os.getcwd() + with TemporaryDirectory() as tempdir: + os.chdir(tempdir) + yield tempdir + os.chdir(curdir) diff --git a/phy/detect/__init__.py b/phy/detect/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/phy/detect/default_settings.py b/phy/detect/default_settings.py deleted file mode 100644 index 6b6794eda..000000000 --- a/phy/detect/default_settings.py +++ /dev/null @@ -1,44 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Default settings for spike detection.""" - - -# ----------------------------------------------------------------------------- -# Spike detection -# ----------------------------------------------------------------------------- - -spikedetekt = { - 'filter_low': 500., - 'filter_high_factor': 0.95 * .5, # will be multiplied by the sample rate - 'filter_butter_order': 3, - - # Data chunks. - 'chunk_size_seconds': 1., - 'chunk_overlap_seconds': .015, - - # Threshold. - 'n_excerpts': 50, - 'excerpt_size_seconds': 1., - 'use_single_threshold': True, - 'threshold_strong_std_factor': 4.5, - 'threshold_weak_std_factor': 2., - 'detect_spikes': 'negative', - - # Connected components. - 'connected_component_join_size': 1, - - # Spike extractions. - 'extract_s_before': 10, - 'extract_s_after': 10, - 'weight_power': 2, - - # Features. - 'n_features_per_channel': 3, - 'pca_n_waveforms_max': 10000, - - # Waveform filtering in GUI. - 'waveform_filter': True, - 'waveform_dc_offset': None, - 'waveform_scale_factor': None, - -} diff --git a/phy/detect/spikedetekt.py b/phy/detect/spikedetekt.py deleted file mode 100644 index 3615a4823..000000000 --- a/phy/detect/spikedetekt.py +++ /dev/null @@ -1,610 +0,0 @@ -# -*- coding: utf-8 -*- - -"""SpikeDetekt algorithm.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -from collections import defaultdict - -import numpy as np - -from ..utils.array import (get_excerpts, - chunk_bounds, - data_chunk, - _as_array, - _concatenate, - ) -from ..utils._types import Bunch -from ..utils.event import EventEmitter, ProgressReporter -from ..utils.logging import debug, info -from ..electrode.mea import (_channels_per_group, - _probe_adjacency_list, - ) -from ..traces import (Filter, Thresholder, compute_threshold, - FloodFillDetector, WaveformExtractor, PCA, - ) -from .store import SpikeDetektStore - - -#------------------------------------------------------------------------------ -# Spike detection class -#------------------------------------------------------------------------------ - -def _find_dead_channels(channels_per_group, n_channels): - all_channels = sorted([item for sublist in channels_per_group.values() - for item in sublist]) - dead = np.setdiff1d(np.arange(n_channels), all_channels) - debug("Using dead channels: {}.".format(dead)) - return dead - - -def _keep_spikes(samples, bounds): - """Only keep spikes within the bounds `bounds=(start, end)`.""" - start, end = bounds - return (start <= samples) & (samples <= end) - - -def _split_spikes(groups, idx=None, **arrs): - """Split spike data according to the channel group.""" - # split: {group: {'spike_samples': ..., 'waveforms':, 'masks':}} - dtypes = {'spike_samples': np.float64, - 'waveforms': np.float32, - 'masks': np.float32, - } - groups = _as_array(groups) - if idx is not None: - 
n_spikes_chunk = np.sum(idx) - # First, remove the overlapping bands. - groups = groups[idx] - arrs_bis = arrs.copy() - for key, arr in arrs.items(): - arrs_bis[key] = arr[idx] - assert len(arrs_bis[key]) == n_spikes_chunk - # Then, split along the group. - groups_u = np.unique(groups) - out = {} - for group in groups_u: - i = (groups == group) - out[group] = {} - for key, arr in arrs_bis.items(): - out[group][key] = _concat(arr[i], dtypes.get(key, None)) - return out - - -def _array_list(arrs): - out = np.empty((len(arrs),), dtype=np.object) - out[:] = arrs - return out - - -def _concat(arr, dtype=None): - out = np.array([_[...] for _ in arr], dtype=dtype) - return out - - -def _cut_traces(traces, interval_samples): - n_samples, n_channels = traces.shape - # Take a subset if necessary. - if interval_samples is not None: - start, end = interval_samples - assert start <= end - traces = traces[start:end, ...] - n_samples = traces.shape[0] - else: - start, end = 0, n_samples - assert 0 <= start < end - if start > 0: - # TODO: add offset to the spike samples... - raise NotImplementedError("Need to add `start` to the " - "spike samples") - return traces, start - - -#------------------------------------------------------------------------------ -# Spike detection class -#------------------------------------------------------------------------------ - -_spikes_message = "{n_spikes:d} spikes in chunk {value:d}/{value_max:d}." - - -class SpikeDetektProgress(ProgressReporter): - _progress_messages = { - 'detect': ("Detecting spikes: {progress:.2f}%. " + _spikes_message, - "Spike detection complete: {n_spikes_total:d} " + - "spikes detected."), - - 'excerpt': ("Extracting waveforms subset for PCs: " + - "{progress:.2f}%. " + _spikes_message, - "Waveform subset extraction complete: " + - "{n_spikes_total} spikes."), - - 'pca': ("Performing PCA: {progress:.2f}%.", - "Principal waveform components computed."), - - 'extract': ("Extracting spikes: {progress:.2f}%. " + _spikes_message, - "Spike extraction complete: {n_spikes_total:d} " + - "spikes extracted."), - - } - - def __init__(self, n_chunks=None): - super(SpikeDetektProgress, self).__init__() - self.n_chunks = n_chunks - - def start_step(self, name, value_max): - self._iter = 0 - self.reset(value_max) - self.set_progress_message(self._progress_messages[name][0], - line_break=True) - self.set_complete_message(self._progress_messages[name][1]) - - -class SpikeDetekt(EventEmitter): - """Spike detection class. - - Parameters - ---------- - - tempdir : str - Path to the temporary directory used by the algorithm. It should be - on a SSD for best performance. - probe : dict - The probe dictionary. - **kwargs : dict - Spike detection parameters. - - """ - def __init__(self, tempdir=None, probe=None, **kwargs): - super(SpikeDetekt, self).__init__() - self._tempdir = tempdir - self._dead_channels = None - # Load a probe. 
- if probe is not None: - kwargs['probe_channels'] = _channels_per_group(probe) - kwargs['probe_adjacency_list'] = _probe_adjacency_list(probe) - self._kwargs = kwargs - self._n_channels_per_group = { - group: len(channels) - for group, channels in self._kwargs['probe_channels'].items() - } - self._groups = sorted(self._n_channels_per_group) - self._n_features = self._kwargs['n_features_per_channel'] - before = self._kwargs['extract_s_before'] - after = self._kwargs['extract_s_after'] - self._n_samples_waveforms = before + after - - # Processing objects creation - # ------------------------------------------------------------------------- - - def _create_filter(self): - rate = self._kwargs['sample_rate'] - low = self._kwargs['filter_low'] - high = self._kwargs['filter_high_factor'] * rate - order = self._kwargs['filter_butter_order'] - return Filter(rate=rate, - low=low, - high=high, - order=order, - ) - - def _create_thresholder(self, thresholds=None): - mode = self._kwargs['detect_spikes'] - return Thresholder(mode=mode, thresholds=thresholds) - - def _create_detector(self): - graph = self._kwargs['probe_adjacency_list'] - probe_channels = self._kwargs['probe_channels'] - join_size = self._kwargs['connected_component_join_size'] - return FloodFillDetector(probe_adjacency_list=graph, - join_size=join_size, - channels_per_group=probe_channels, - ) - - def _create_extractor(self, thresholds): - before = self._kwargs['extract_s_before'] - after = self._kwargs['extract_s_after'] - weight_power = self._kwargs['weight_power'] - probe_channels = self._kwargs['probe_channels'] - return WaveformExtractor(extract_before=before, - extract_after=after, - weight_power=weight_power, - channels_per_group=probe_channels, - thresholds=thresholds, - ) - - def _create_pca(self): - n_pcs = self._kwargs['n_features_per_channel'] - return PCA(n_pcs=n_pcs) - - # Misc functions - # ------------------------------------------------------------------------- - - def update_params(self, **kwargs): - self._kwargs.update(kwargs) - - # Processing functions - # ------------------------------------------------------------------------- - - def apply_filter(self, data): - """Filter the traces.""" - filter = self._create_filter() - return filter(data).astype(np.float32) - - def find_thresholds(self, traces): - """Find weak and strong thresholds in filtered traces.""" - rate = self._kwargs['sample_rate'] - n_excerpts = self._kwargs['n_excerpts'] - excerpt_size = int(self._kwargs['excerpt_size_seconds'] * rate) - single = bool(self._kwargs['use_single_threshold']) - strong_f = self._kwargs['threshold_strong_std_factor'] - weak_f = self._kwargs['threshold_weak_std_factor'] - - info("Finding the thresholds...") - excerpt = get_excerpts(traces, - n_excerpts=n_excerpts, - excerpt_size=excerpt_size) - excerpt_f = self.apply_filter(excerpt) - thresholds = compute_threshold(excerpt_f, - single_threshold=single, - std_factor=(weak_f, strong_f)) - debug("Thresholds: {}.".format(thresholds)) - return {'weak': thresholds[0], - 'strong': thresholds[1]} - - def detect(self, traces_f, thresholds=None, dead_channels=None): - """Detect connected waveform components in filtered traces. - - Parameters - ---------- - - traces_f : array - An `(n_samples, n_channels)` array with the filtered data. - thresholds : dict - The weak and strong thresholds. - dead_channels : array-like - Array of dead channels. - - Returns - ------- - - components : list - A list of `(n, 2)` arrays with `sample, channel` pairs. 
- - """ - # Threshold the data following the weak and strong thresholds. - thresholder = self._create_thresholder(thresholds) - # Transform the filtered data according to the detection mode. - traces_t = thresholder.transform(traces_f) - # Compute the threshold crossings. - weak = thresholder.detect(traces_t, 'weak') - strong = thresholder.detect(traces_t, 'strong') - # Force crossings to be False on dead channels. - if dead_channels is not None and len(dead_channels): - assert dead_channels.max() < traces_f.shape[1] - weak[:, dead_channels] = 0 - strong[:, dead_channels] = 0 - else: - debug("No dead channels specified.") - # Run the detection. - detector = self._create_detector() - return detector(weak_crossings=weak, - strong_crossings=strong) - - def extract_spikes(self, components, traces_f, - thresholds=None, keep_bounds=None): - """Extract spikes from connected components. - - Returns a split object. - - Parameters - ---------- - components : list - List of connected components. - traces_f : array - Filtered data. - thresholds : dict - The weak and strong thresholds. - keep_bounds : tuple - (keep_start, keep_end). - - """ - n_spikes = len(components) - if n_spikes == 0: - return {} - - # Transform the filtered data according to the detection mode. - thresholder = self._create_thresholder() - traces_t = thresholder.transform(traces_f) - # Extract all waveforms. - extractor = self._create_extractor(thresholds) - groups, samples, waveforms, masks = zip(*[extractor(component, - data=traces_f, - data_t=traces_t, - ) - for component in components]) - - # Create the return arrays. - groups = np.array(groups, dtype=np.int32) - assert groups.shape == (n_spikes,) - assert groups.dtype == np.int32 - - samples = np.array(samples, dtype=np.float64) - assert samples.shape == (n_spikes,) - assert samples.dtype == np.float64 - - # These are lists of arrays of various shapes (because of various - # groups). - waveforms = _array_list(waveforms) - assert waveforms.shape == (n_spikes,) - assert waveforms.dtype == np.object - - masks = _array_list(masks) - assert masks.dtype == np.object - assert masks.shape == (n_spikes,) - - # Reorder the spikes. - idx = np.argsort(samples) - groups = groups[idx] - samples = samples[idx] - waveforms = waveforms[idx] - masks = masks[idx] - - # Remove spikes in the overlapping bands. - # WARNING: add keep_start to spike_samples, because spike_samples - # is relative to the start of the chunk. - (keep_start, keep_end) = keep_bounds - idx = _keep_spikes(samples + keep_start, (keep_start, keep_end)) - - # Split the data according to the channel groups. - split = _split_spikes(groups, idx=idx, spike_samples=samples, - waveforms=waveforms, masks=masks) - # split: {group: {'spike_samples': ..., 'waveforms':, 'masks':}} - return split - - def waveform_pcs(self, waveforms, masks): - """Compute waveform principal components. - - Returns - ------- - - pcs : array - An `(n_features, n_samples, n_channels)` array. - - """ - pca = self._create_pca() - if waveforms is None or not len(waveforms): - return - assert (waveforms.shape[0], waveforms.shape[2]) == masks.shape - return pca.fit(waveforms, masks) - - def features(self, waveforms, pcs): - """Extract features from waveforms. - - Returns - ------- - - features : array - An `(n_spikes, n_channels, n_features)` array. 
- - """ - pca = self._create_pca() - out = pca.transform(waveforms, pcs=pcs) - assert out.dtype == np.float32 - return out - - # Chunking - # ------------------------------------------------------------------------- - - def iter_chunks(self, n_samples): - """Iterate over chunks.""" - rate = self._kwargs['sample_rate'] - chunk_size = int(self._kwargs['chunk_size_seconds'] * rate) - overlap = int(self._kwargs['chunk_overlap_seconds'] * rate) - for chunk_idx, bounds in enumerate(chunk_bounds(n_samples, chunk_size, - overlap=overlap)): - yield Bunch(bounds=bounds, - s_start=bounds[0], - s_end=bounds[1], - keep_start=bounds[2], - keep_end=bounds[3], - keep_bounds=(bounds[2:4]), - key=bounds[2], - chunk_idx=chunk_idx, - ) - - def n_chunks(self, n_samples): - """Number of chunks.""" - return len(list(self.iter_chunks(n_samples))) - - def chunk_keys(self, n_samples): - return [chunk.key for chunk in self.iter_chunks(n_samples)] - - # Output data - # ------------------------------------------------------------------------- - - def output_data(self): - """Bunch of values to be returned by the algorithm.""" - sc = self._store.spike_counts - chunk_keys = self._store.chunk_keys - output = Bunch(groups=self._groups, - n_chunks=len(chunk_keys), - chunk_keys=chunk_keys, - spike_samples=self._store.spike_samples(), - masks=self._store.masks(), - features=self._store.features(), - spike_counts=sc, - n_spikes_total=sc(), - n_spikes_per_group={group: sc(group=group) - for group in self._groups}, - n_spikes_per_chunk={chunk_key: sc(chunk_key=chunk_key) - for chunk_key in chunk_keys}, - ) - return output - - # Main loop - # ------------------------------------------------------------------------- - - def _iter_spikes(self, n_samples, step_spikes=1, thresholds=None): - """Iterate over extracted spikes (possibly subset). - - Yield a split dictionary `{group: {'waveforms': ..., ...}}`. - - """ - for chunk in self.iter_chunks(n_samples): - - # Extract a few components. - components = self._store.load(name='components', - chunk_key=chunk.key) - if components is None or not len(components): - yield chunk, {} - continue - - k = np.clip(step_spikes, 1, len(components)) - components = components[::k] - - # Get the filtered chunk. - chunk_f = self._store.load(name='filtered', - chunk_key=chunk.key) - - # Extract the spikes from the chunk. - split = self.extract_spikes(components, chunk_f, - keep_bounds=chunk.keep_bounds, - thresholds=thresholds) - - yield chunk, split - - def step_detect(self, traces=None, thresholds=None): - n_samples, n_channels = traces.shape - n_chunks = self.n_chunks(n_samples) - - # Pass 1: find the connected components and count the spikes. - self._pr.start_step('detect', n_chunks) - - # Dictionary {chunk_key: components}. - # Every chunk has a unique key: the `keep_start` integer. - n_spikes_total = 0 - for chunk in self.iter_chunks(n_samples): - chunk_data = data_chunk(traces, chunk.bounds, with_overlap=True) - - # Apply the filter. - data_f = self.apply_filter(chunk_data) - assert data_f.dtype == np.float32 - assert data_f.shape == chunk_data.shape - - # Save the filtered chunk. - self._store.store(name='filtered', chunk_key=chunk.key, - data=data_f) - - # Detect spikes in the filtered chunk. - components = self.detect(data_f, thresholds=thresholds, - dead_channels=self._dead_channels) - self._store.store(name='components', chunk_key=chunk.key, - data=components) - - # Report progress. 
- n_spikes_chunk = len(components) - n_spikes_total += n_spikes_chunk - self._pr.increment(n_spikes=n_spikes_chunk, - n_spikes_total=n_spikes_total) - - return n_spikes_total - - def step_excerpt(self, n_samples=None, - n_spikes_total=None, thresholds=None): - self._pr.start_step('excerpt', self.n_chunks(n_samples)) - - k = int(n_spikes_total / float(self._kwargs['pca_n_waveforms_max'])) - w_subset = defaultdict(list) - m_subset = defaultdict(list) - n_spikes_total = 0 - for chunk, split in self._iter_spikes(n_samples, step_spikes=k, - thresholds=thresholds): - n_spikes_chunk = 0 - for group, out in split.items(): - w_subset[group].append(out['waveforms']) - m_subset[group].append(out['masks']) - assert len(out['masks']) == len(out['waveforms']) - n_spikes_chunk += len(out['masks']) - - n_spikes_total += n_spikes_chunk - self._pr.increment(n_spikes=n_spikes_chunk, - n_spikes_total=n_spikes_total) - for group in self._groups: - w_subset[group] = _concatenate(w_subset[group]) - m_subset[group] = _concatenate(m_subset[group]) - - return w_subset, m_subset - - def step_pcs(self, w_subset=None, m_subset=None): - self._pr.start_step('pca', len(self._groups)) - pcs = {} - for group in self._groups: - # Perform PCA and return the components. - pcs[group] = self.waveform_pcs(w_subset[group], - m_subset[group]) - self._pr.increment() - return pcs - - def step_extract(self, n_samples=None, - pcs=None, thresholds=None): - self._pr.start_step('extract', self.n_chunks(n_samples)) - # chunk_counts = defaultdict(dict) # {group: {key: n_spikes}}. - n_spikes_total = 0 - for chunk, split in self._iter_spikes(n_samples, - thresholds=thresholds): - # Delete filtered and components cache files. - self._store.delete(name='filtered', chunk_key=chunk.key) - self._store.delete(name='components', chunk_key=chunk.key) - # split: {group: {'spike_samples': ..., 'waveforms':, 'masks':}} - for group, out in split.items(): - out['features'] = self.features(out['waveforms'], pcs[group]) - self._store.append(group=group, - chunk_key=chunk.key, - spike_samples=out['spike_samples'], - features=out['features'], - masks=out['masks'], - spike_offset=chunk.s_start, - ) - n_spikes_total = self._store.spike_counts() - n_spikes_chunk = self._store.spike_counts(chunk_key=chunk.key) - self._pr.increment(n_spikes_total=n_spikes_total, - n_spikes=n_spikes_chunk) - - def run_serial(self, traces, interval_samples=None): - """Run SpikeDetekt using one CPU.""" - traces, offset = _cut_traces(traces, interval_samples) - n_samples, n_channels = traces.shape - - # Initialize the main loop. - chunk_keys = self.chunk_keys(n_samples) - n_chunks = len(chunk_keys) - self._pr = SpikeDetektProgress(n_chunks=n_chunks) - self._store = SpikeDetektStore(self._tempdir, - groups=self._groups, - chunk_keys=chunk_keys) - - # Find the weak and strong thresholds. - thresholds = self.find_thresholds(traces) - - # Find dead channels. - probe_channels = self._kwargs['probe_channels'] - self._dead_channels = _find_dead_channels(probe_channels, n_channels) - - # Spike detection. - n_spikes_total = self.step_detect(traces=traces, - thresholds=thresholds) - - # Excerpt waveforms. - w_subset, m_subset = self.step_excerpt(n_samples=n_samples, - n_spikes_total=n_spikes_total, - thresholds=thresholds) - - # Compute the PCs. - pcs = self.step_pcs(w_subset=w_subset, m_subset=m_subset) - - # Compute all features. 
- self.step_extract(n_samples=n_samples, pcs=pcs, thresholds=thresholds) - - return self.output_data() diff --git a/phy/detect/store.py b/phy/detect/store.py deleted file mode 100644 index 179f511af..000000000 --- a/phy/detect/store.py +++ /dev/null @@ -1,208 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Spike detection store.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import os -import os.path as op -from collections import defaultdict - -import numpy as np -from six import string_types - -from ..utils.array import (_save_arrays, - _load_arrays, - _concatenate, - ) -from ..utils.logging import debug -from ..utils.settings import _ensure_dir_exists - - -#------------------------------------------------------------------------------ -# Spike counts -#------------------------------------------------------------------------------ - -class SpikeCounts(object): - """Count spikes in chunks and channel groups.""" - def __init__(self, counts=None, groups=None, chunk_keys=None): - self._groups = groups - self._chunk_keys = chunk_keys - self._counts = counts or defaultdict(lambda: defaultdict(int)) - - def append(self, group=None, chunk_key=None, count=None): - self._counts[group][chunk_key] += count - - @property - def counts(self): - return self._counts - - def per_group(self, group): - return sum(self._counts.get(group, {}).values()) - - def per_chunk(self, chunk_key): - return sum(self._counts[group].get(chunk_key, 0) - for group in self._groups) - - def __call__(self, group=None, chunk_key=None): - if group is not None and chunk_key is not None: - return self._counts.get(group, {}).get(chunk_key, 0) - elif group is not None: - return self.per_group(group) - elif chunk_key is not None: - return self.per_chunk(chunk_key) - elif group is None and chunk_key is None: - return sum(self.per_group(group) for group in self._groups) - - -#------------------------------------------------------------------------------ -# Spike detection store -#------------------------------------------------------------------------------ - -class ArrayStore(object): - def __init__(self, root_dir): - self._root_dir = op.realpath(root_dir) - _ensure_dir_exists(self._root_dir) - - def _rel_path(self, **kwargs): - """Relative to the root.""" - raise NotImplementedError() - - def _path(self, **kwargs): - """Absolute path of a data file.""" - path = op.realpath(op.join(self._root_dir, self._rel_path(**kwargs))) - _ensure_dir_exists(op.dirname(path)) - assert path.endswith('.npy') - return path - - def _offsets_path(self, path): - assert path.endswith('.npy') - return op.splitext(path)[0] + '.offsets.npy' - - def _contains_multiple_arrays(self, path): - return op.exists(path) and op.exists(self._offsets_path(path)) - - def store(self, data=None, **kwargs): - """Store an array or list of arrays.""" - path = self._path(**kwargs) - if isinstance(data, list): - if not data: - return - _save_arrays(path, data) - elif isinstance(data, np.ndarray): - dtype = data.dtype - if not data.size: - return - assert dtype != np.object - np.save(path, data) - # debug("Store {}.".format(path)) - - def load(self, **kwargs): - path = self._path(**kwargs) - if not op.exists(path): - debug("File `{}` doesn't exist.".format(path)) - return - # Multiple arrays: - # debug("Load {}.".format(path)) - if self._contains_multiple_arrays(path): - return _load_arrays(path) - else: - return np.load(path) - - def delete(self, 
**kwargs): - path = self._path(**kwargs) - if op.exists(path): - os.remove(path) - # debug("Deleted `{}`.".format(path)) - offsets_path = self._offsets_path(path) - if op.exists(offsets_path): - os.remove(offsets_path) - # debug("Deleted `{}`.".format(offsets_path)) - - -class SpikeDetektStore(ArrayStore): - """Store the following items: - - * filtered - * components - * spike_samples - * features - * masks - - """ - def __init__(self, root_dir, groups=None, chunk_keys=None): - super(SpikeDetektStore, self).__init__(root_dir) - self._groups = groups - self._chunk_keys = chunk_keys - self._spike_counts = SpikeCounts(groups=groups, chunk_keys=chunk_keys) - - def _rel_path(self, name=None, chunk_key=None, group=None): - assert chunk_key >= 0 - assert group is None or group >= 0 - assert isinstance(name, string_types) - group = group if group is not None else 'all' - return 'group_{group}/{name}/chunk_{chunk:d}.npy'.format( - chunk=chunk_key, name=name, group=group) - - @property - def groups(self): - return self._groups - - @property - def chunk_keys(self): - return self._chunk_keys - - def _iter(self, group=None, name=None): - for chunk_key in self.chunk_keys: - yield self.load(group=group, chunk_key=chunk_key, name=name) - - def spike_samples(self, group=None): - if group is None: - return {group: self.spike_samples(group) for group in self._groups} - return self.concatenate(self._iter(group=group, name='spike_samples')) - - def features(self, group=None): - """Yield chunk features.""" - if group is None: - return {group: self.features(group) for group in self._groups} - return self._iter(group=group, name='features') - - def masks(self, group=None): - """Yield chunk masks.""" - if group is None: - return {group: self.masks(group) for group in self._groups} - return self._iter(group=group, name='masks') - - @property - def spike_counts(self): - return self._spike_counts - - def append(self, group=None, chunk_key=None, - spike_samples=None, features=None, masks=None, - spike_offset=0): - if spike_samples is None or len(spike_samples) == 0: - return - n = len(spike_samples) - assert features.shape[0] == n - assert masks.shape[0] == n - spike_samples = spike_samples + spike_offset - - self.store(group=group, chunk_key=chunk_key, - name='features', data=features) - self.store(group=group, chunk_key=chunk_key, - name='masks', data=masks) - self.store(group=group, chunk_key=chunk_key, - name='spike_samples', data=spike_samples) - self._spike_counts.append(group=group, chunk_key=chunk_key, count=n) - - def concatenate(self, arrays): - return _concatenate(arrays) - - def delete_all(self, name): - """Delete all files for a given data name.""" - for group in self._groups: - for chunk_key in self._chunk_keys: - super(SpikeDetektStore, self).delete(name=name, group=group, - chunk_key=chunk_key) diff --git a/phy/detect/tests/__init__.py b/phy/detect/tests/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/phy/detect/tests/test_spikedetekt.py b/phy/detect/tests/test_spikedetekt.py deleted file mode 100644 index 1f49e267a..000000000 --- a/phy/detect/tests/test_spikedetekt.py +++ /dev/null @@ -1,157 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Tests of clustering algorithms.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import numpy as np -from numpy.testing import assert_equal as ae -from pytest import mark - -from ...utils.logging import set_level 
-from ...utils.testing import show_test -from ..spikedetekt import (SpikeDetekt, _split_spikes, _concat, _concatenate) - - -#------------------------------------------------------------------------------ -# Tests spike detection -#------------------------------------------------------------------------------ - -def setup(): - set_level('info') - - -def test_split_spikes(): - groups = np.zeros(10, dtype=np.int) - groups[1::2] = 1 - - idx = np.ones(10, dtype=np.bool) - idx[0] = False - idx[-1] = False - - a = np.random.rand(10, 2) - b = np.random.rand(10, 3, 2) - - out = _split_spikes(groups, idx, a=a, b=b) - - assert sorted(out) == [0, 1] - assert sorted(out[0]) == ['a', 'b'] - assert sorted(out[1]) == ['a', 'b'] - - ae(out[0]['a'], a[1:-1][1::2]) - ae(out[0]['b'], b[1:-1][1::2]) - - ae(out[1]['a'], a[1:-1][::2]) - ae(out[1]['b'], b[1:-1][::2]) - - -def test_spike_detect_methods(tempdir, raw_dataset): - params = raw_dataset.params - probe = raw_dataset.probe - sample_rate = raw_dataset.sample_rate - sd = SpikeDetekt(tempdir=tempdir, - probe=raw_dataset.probe, - sample_rate=sample_rate, - **params) - traces = raw_dataset.traces - n_samples = raw_dataset.n_samples - n_channels = raw_dataset.n_channels - - # Filter the data. - traces_f = sd.apply_filter(traces) - assert traces_f.shape == traces.shape - assert not np.any(np.isnan(traces_f)) - - # Thresholds. - thresholds = sd.find_thresholds(traces) - assert np.all(0 <= thresholds['weak']) - assert np.all(thresholds['weak'] <= thresholds['strong']) - - # Spike detection. - traces_f[1000:1010, :3] *= 5 - traces_f[2000:2010, [0, 2]] *= 5 - traces_f[3000:3020, :] *= 5 - components = sd.detect(traces_f, thresholds) - assert isinstance(components, list) - # n_spikes = len(components) - n_samples_waveforms = (params['extract_s_before'] + - params['extract_s_after']) - - # Spike extraction. - split = sd.extract_spikes(components, traces_f, thresholds, - keep_bounds=(0, n_samples)) - - if not split: - return - samples = _concat(split[0]['spike_samples'], np.float64) - waveforms = _concat(split[0]['waveforms'], np.float32) - masks = _concat(split[0]['masks'], np.float32) - - n_spikes = len(samples) - n_channels = len(probe['channel_groups'][0]['channels']) - - assert samples.dtype == np.float64 - assert samples.shape == (n_spikes,) - assert waveforms.shape == (n_spikes, n_samples_waveforms, n_channels) - assert masks.shape == (n_spikes, n_channels) - assert 0. <= masks.min() < masks.max() <= 1. - assert not np.any(np.isnan(samples)) - assert not np.any(np.isnan(waveforms)) - assert not np.any(np.isnan(masks)) - - # PCA. - pcs = sd.waveform_pcs(waveforms, masks) - n_pcs = params['n_features_per_channel'] - assert pcs.shape == (n_pcs, n_samples_waveforms, n_channels) - assert not np.any(np.isnan(pcs)) - - # Features. - features = sd.features(waveforms, pcs) - assert features.shape == (n_spikes, n_channels, n_pcs) - assert not np.any(np.isnan(features)) - - -@mark.long -def test_spike_detect_real_data(tempdir, raw_dataset): - - params = raw_dataset.params - probe = raw_dataset.probe - sample_rate = raw_dataset.sample_rate - sd = SpikeDetekt(tempdir=tempdir, - probe=probe, - sample_rate=sample_rate, - **params) - traces = raw_dataset.traces - n_samples = raw_dataset.n_samples - npc = params['n_features_per_channel'] - n_samples_w = params['extract_s_before'] + params['extract_s_after'] - - # Run the detection. 
- out = sd.run_serial(traces, interval_samples=(0, n_samples)) - - channels = probe['channel_groups'][0]['channels'] - n_channels = len(channels) - - spike_samples = _concatenate(out.spike_samples[0]) - masks = _concatenate(out.masks[0]) - features = _concatenate(out.features[0]) - n_spikes = out.n_spikes_per_group[0] - - if n_spikes: - assert spike_samples.shape == (n_spikes,) - assert masks.shape == (n_spikes, n_channels) - assert features.shape == (n_spikes, n_channels, npc) - - # There should not be any spike with only masked channels. - assert np.all(masks.max(axis=1) > 0) - - # Plot... - from phy.plot.traces import plot_traces - c = plot_traces(traces[:30000, channels], - spike_samples=spike_samples, - masks=masks, - n_samples_per_spike=n_samples_w, - show=False) - show_test(c) diff --git a/phy/detect/tests/test_store.py b/phy/detect/tests/test_store.py deleted file mode 100644 index 5ada61b66..000000000 --- a/phy/detect/tests/test_store.py +++ /dev/null @@ -1,133 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Tests of spike detection store.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -from pytest import yield_fixture -import numpy as np -from numpy.testing import assert_equal as ae - -from ...utils.logging import set_level -from ..store import SpikeCounts, ArrayStore, SpikeDetektStore - - -#------------------------------------------------------------------------------ -# Tests spike detection store -#------------------------------------------------------------------------------ - -def setup(): - set_level('debug') - - -@yield_fixture(params=['from_dict', 'append']) -def spike_counts(request): - groups = [0, 2] - chunk_keys = [10, 20, 30] - if request.param == 'from_dict': - c = {0: {10: 100, 20: 200}, - 2: {10: 1, 30: 300}, - } - sc = SpikeCounts(c, groups=groups, chunk_keys=chunk_keys) - elif request.param == 'append': - sc = SpikeCounts(groups=groups, chunk_keys=chunk_keys) - sc.append(group=0, chunk_key=10, count=100) - sc.append(group=0, chunk_key=20, count=200) - sc.append(group=2, chunk_key=10, count=1) - sc.append(group=2, chunk_key=30, count=300) - yield sc - - -def test_spike_counts(spike_counts): - assert spike_counts() == 601 - - assert spike_counts(group=0) == 300 - assert spike_counts(group=1) == 0 - assert spike_counts(group=2) == 301 - - assert spike_counts(chunk_key=10) == 101 - assert spike_counts(chunk_key=20) == 200 - assert spike_counts(chunk_key=30) == 300 - - -class TestArrayStore(ArrayStore): - def _rel_path(self, a=None, b=None): - return '{}/{}.npy'.format(a or 'none_a', b or 'none_b') - - -def test_array_store(tempdir): - store = TestArrayStore(tempdir) - - store.store(a=1, b=1, data=np.arange(11)) - store.store(a=1, b=2, data=np.arange(12)) - store.store(b=2, data=np.arange(2)) - store.store(b=3, data=None) - store.store(a=1, data=[np.arange(3), np.arange(3, 8)]) - - ae(store.load(a=0, b=1), None) - ae(store.load(a=1, b=1), np.arange(11)) - ae(store.load(a=1, b=2), np.arange(12)) - ae(store.load(b=2), np.arange(2)) - ae(store.load(b=3), None) - ae(store.load(a=1)[0], np.arange(3)) - ae(store.load(a=1)[1], np.arange(3, 8)) - - store.delete(b=2) - store.delete(a=1, b=2) - - -def test_spikedetekt_store(tempdir): - groups = [0, 2] - chunk_keys = [10, 20, 30] - - n_channels = 4 - npc = 2 - - _keys = [(0, 10), (0, 20), (2, 10), (2, 30)] - _counts = [100, 200, 1, 300] - - # Generate random spike samples, features, and masks. 
- s = {k: np.arange(c) for k, c in zip(_keys, _counts)} - f = {k: np.random.rand(c, n_channels, npc) - for k, c in zip(_keys, _counts)} - m = {k: np.random.rand(c, n_channels) - for k, c in zip(_keys, _counts)} - - store = SpikeDetektStore(tempdir, groups=groups, chunk_keys=chunk_keys) - - # Save data. - for group, chunk_key in _keys: - spike_samples = s.get((group, chunk_key), None) - features = f.get((group, chunk_key), None) - masks = m.get((group, chunk_key), None) - - store.append(group=group, chunk_key=chunk_key, - spike_samples=spike_samples, - features=features, - masks=masks, - spike_offset=chunk_key, - ) - - # Load data. - for group in groups: - # Check spike samples. - ae(store.spike_samples(group), - np.hstack([s[key] + key[1] for key in _keys if key[0] == group])) - - # Check features and masks. - for name in ('features', 'masks'): - # Actual data. - data_dict = f if name == 'features' else m - # Stored data (generator). - data_gen = getattr(store, name)(group) - # Go through all chunks. - for chunk_key, data in zip(chunk_keys, data_gen): - if (group, chunk_key) in _keys: - ae(data_dict[group, chunk_key], data) - else: - assert data is None - - # Test spike counts. - test_spike_counts(store.spike_counts) diff --git a/phy/electrode/mea.py b/phy/electrode/mea.py index 714cf7322..18909b18b 100644 --- a/phy/electrode/mea.py +++ b/phy/electrode/mea.py @@ -25,7 +25,7 @@ def _edges_to_adjacency_list(edges): """Convert a list of edges into an adjacency list.""" adj = {} for i, j in edges: - if i in adj: + if i in adj: # pragma: no cover ni = adj[i] else: ni = adj[i] = set() @@ -38,6 +38,18 @@ def _edges_to_adjacency_list(edges): return adj +def _adjacency_subset(adjacency, subset): + return {c: [v for v in vals if v in subset] + for (c, vals) in adjacency.items() if c in subset} + + +def _remap_adjacency(adjacency, mapping): + remapped = {} + for key, vals in adjacency.items(): + remapped[mapping[key]] = [mapping[i] for i in vals] + return remapped + + def _probe_positions(probe, group): """Return the positions of a probe channel group.""" positions = probe['channel_groups'][group]['geometry'] @@ -54,13 +66,6 @@ def _probe_channels(probe, group): return probe['channel_groups'][group]['channels'] -def _probe_all_channels(probe): - """Return the list of channels in the probe.""" - cgs = probe['channel_groups'].values() - cg_channels = [cg['channels'] for cg in cgs] - return sorted(set(itertools.chain(*cg_channels))) - - def _probe_adjacency_list(probe): """Return an adjacency list of a whole probe.""" cgs = probe['channel_groups'].values() @@ -87,7 +92,7 @@ def load_probe(name): path = op.join(curdir, 'probes/{}.prb'.format(name)) if not op.exists(path): raise IOError("The probe `{}` cannot be found.".format(name)) - return _read_python(path) + return MEA(probe=_read_python(path)) def list_probes(): @@ -120,21 +125,22 @@ def __init__(self, ): self._probe = probe self._channels = channels - if positions is not None: - assert self.n_channels == positions.shape[0] + self._check_positions(positions) self._positions = positions # This is a mapping {channel: list of neighbors}. if adjacency is None and probe is not None: adjacency = _probe_adjacency_list(probe) self.channels_per_group = _channels_per_group(probe) self._adjacency = adjacency + if probe: + # Select the first channel group. 
+ cg = sorted(self._probe['channel_groups'].keys())[0] + self.change_channel_group(cg) def _check_positions(self, positions): if positions is None: return positions = _as_array(positions) - if self.n_channels is None: - self.n_channels = positions.shape[0] if positions.shape[0] != self.n_channels: raise ValueError("'positions' " "(shape {0:s})".format(str(positions.shape)) + @@ -147,11 +153,6 @@ def positions(self): """Channel positions in the current channel group.""" return self._positions - @positions.setter - def positions(self, value): - self._check_positions(value) - self._positions = value - @property def channels(self): """Channel ids in the current channel group.""" @@ -167,10 +168,6 @@ def adjacency(self): """Adjacency graph in the current channel group.""" return self._adjacency - @adjacency.setter - def adjacency(self, value): - self._adjacency = value - def change_channel_group(self, group): """Change the current channel group.""" assert self._probe is not None diff --git a/phy/electrode/tests/test_mea.py b/phy/electrode/tests/test_mea.py index b974d6e3c..4bd580a0e 100644 --- a/phy/electrode/tests/test_mea.py +++ b/phy/electrode/tests/test_mea.py @@ -6,11 +6,14 @@ # Imports #------------------------------------------------------------------------------ +import os.path as op + from pytest import raises import numpy as np from numpy.testing import assert_array_equal as ae -from ..mea import (_probe_channels, _probe_positions, _probe_adjacency_list, +from ..mea import (_probe_channels, _remap_adjacency, _adjacency_subset, + _probe_positions, _probe_adjacency_list, MEA, linear_positions, staggered_positions, load_probe, list_probes ) @@ -20,6 +23,24 @@ # Tests #------------------------------------------------------------------------------ +def test_remap(): + adjacency = {1: [2, 3, 7], 3: [5, 11]} + mapping = {1: 3, 2: 20, 3: 30, 5: 50, 7: 70, 11: 1} + remapped = _remap_adjacency(adjacency, mapping) + assert sorted(remapped.keys()) == [3, 30] + assert remapped[3] == [20, 30, 70] + assert remapped[30] == [50, 1] + + +def test_adjacency_subset(): + adjacency = {1: [2, 3, 7], 3: [5, 11], 5: [1, 2, 11]} + subset = [1, 5, 32] + adjsub = _adjacency_subset(adjacency, subset) + assert sorted(adjsub.keys()) == [1, 5] + assert adjsub[1] == [] + assert adjsub[5] == [1] + + def test_probe(): probe = {'channel_groups': { 0: {'channels': [0, 3, 1], @@ -39,16 +60,12 @@ def test_probe(): assert _probe_adjacency_list(probe) == adjacency mea = MEA(probe=probe) - assert mea.positions is None - assert mea.channels is None - assert mea.n_channels == 0 + assert mea.adjacency == adjacency assert mea.channels_per_group == {0: [0, 3, 1], 1: [7]} - - mea.change_channel_group(0) - ae(mea.positions, [(10, 10), (20, 30), (10, 20)]) assert mea.channels == [0, 3, 1] assert mea.n_channels == 3 + ae(mea.positions, [(10, 10), (20, 30), (10, 20)]) def test_mea(): @@ -57,8 +74,7 @@ def test_mea(): channels = np.arange(n_channels) positions = np.random.randn(n_channels, 2) - mea = MEA(channels) - mea.positions = positions + mea = MEA(channels, positions=positions) ae(mea.positions, positions) assert mea.adjacency is None @@ -68,17 +84,11 @@ def test_mea(): mea = MEA(channels, positions=positions) assert mea.n_channels == n_channels - with raises(AssertionError): + with raises(ValueError): MEA(channels=np.arange(n_channels + 1), positions=positions) - with raises(AssertionError): - MEA(channels=channels, positions=positions[:-1, :]) - - mea = MEA(channels=channels) - assert mea.n_channels == n_channels - mea.positions 
= positions with raises(ValueError): - mea.positions = positions[:-1, :] + MEA(channels=channels, positions=positions[:-1, :]) def test_positions(): @@ -90,9 +100,18 @@ def test_positions(): assert probe.shape == (29, 2) -def test_library(): +def test_library(tempdir): + assert '1x32_buzsaki' in list_probes() + probe = load_probe('1x32_buzsaki') assert probe - assert probe['channel_groups'][0]['channels'] == list(range(32)) + assert probe.channels == list(range(32)) - assert '1x32_buzsaki' in list_probes() + path = op.join(tempdir, 'test.prb') + with raises(IOError): + load_probe(path) + + with open(path, 'w') as f: + f.write('') + with raises(KeyError): + load_probe(path) diff --git a/phy/gui/__init__.py b/phy/gui/__init__.py index 0ed1b6ebb..bb3202e71 100644 --- a/phy/gui/__init__.py +++ b/phy/gui/__init__.py @@ -3,11 +3,7 @@ """GUI routines.""" -from .qt import start_qt_app, run_qt_app, qt_app, enable_qt -from .dock import DockWindow - -from .base import (BaseViewModel, - HTMLViewModel, - WidgetCreator, - BaseGUI, - ) +from .qt import require_qt, create_app, run_app +from .gui import GUI, GUIState +from .actions import Actions +from .widgets import HTMLWidget, Table diff --git a/phy/gui/_utils.py b/phy/gui/_utils.py deleted file mode 100644 index ca7ae0c64..000000000 --- a/phy/gui/_utils.py +++ /dev/null @@ -1,23 +0,0 @@ -# -*- coding: utf-8 -*- - -"""HTML/CSS utilities.""" - -# ----------------------------------------------------------------------------- -# Imports -# ----------------------------------------------------------------------------- - -import os.path as op - - -# ----------------------------------------------------------------------------- -# Utilities -# ----------------------------------------------------------------------------- - -def _read(fn, static_path=None): - """Read a file in a static directory. 
- - By default, this is `./static/`.""" - if static_path is None: - static_path = op.join(op.dirname(op.realpath(__file__)), 'static') - with open(op.join(static_path, fn), 'r') as f: - return f.read() diff --git a/phy/gui/actions.py b/phy/gui/actions.py new file mode 100644 index 000000000..55b56898e --- /dev/null +++ b/phy/gui/actions.py @@ -0,0 +1,447 @@ +# -*- coding: utf-8 -*- + +"""Actions and snippets.""" + + +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- + +from functools import partial +import logging +import re +import sys +import traceback + +from six import string_types, PY3 + +from .qt import QKeySequence, QAction, require_qt +from phy.utils import Bunch + +logger = logging.getLogger(__name__) + + +# ----------------------------------------------------------------------------- +# Snippet parsing utilities +# ----------------------------------------------------------------------------- + +def _parse_arg(s): + """Parse a number or string.""" + try: + return int(s) + except ValueError: + pass + try: + return float(s) + except ValueError: + pass + return s + + +def _parse_list(s): + """Parse a comma-separated list of values (strings or numbers).""" + # Range: 'x-y' + if '-' in s: + m, M = map(_parse_arg, s.split('-')) + return list(range(m, M + 1)) + # List of ids: 'x,y,z' + elif ',' in s: + return list(map(_parse_arg, s.split(','))) + else: + return _parse_arg(s) + + +def _parse_snippet(s): + """Parse an entire snippet command.""" + return list(map(_parse_list, s.split(' '))) + + +# ----------------------------------------------------------------------------- +# Show shortcut utility functions +# ----------------------------------------------------------------------------- + +def _get_shortcut_string(shortcut): + """Return a string representation of a shortcut.""" + if shortcut is None: + return '' + if isinstance(shortcut, (tuple, list)): + return ', '.join([_get_shortcut_string(s) for s in shortcut]) + if isinstance(shortcut, string_types): + if hasattr(QKeySequence, shortcut): + shortcut = QKeySequence(getattr(QKeySequence, shortcut)) + else: + return shortcut.lower() + assert isinstance(shortcut, QKeySequence) + s = shortcut.toString() or '' + return str(s).lower() + + +def _get_qkeysequence(shortcut): + """Return a QKeySequence or list of QKeySequence from a shortcut string.""" + if shortcut is None: + return [] + if isinstance(shortcut, (tuple, list)): + return [_get_qkeysequence(s) for s in shortcut] + assert isinstance(shortcut, string_types) + if hasattr(QKeySequence, shortcut): + return QKeySequence(getattr(QKeySequence, shortcut)) + sequence = QKeySequence.fromString(shortcut) + assert not sequence.isEmpty() + return sequence + + +def _show_shortcuts(shortcuts, name=None): + """Display shortcuts.""" + name = name or '' + print('') + if name: + name = ' for ' + name + print('Keyboard shortcuts' + name) + for name in sorted(shortcuts): + shortcut = _get_shortcut_string(shortcuts[name]) + if not name.startswith('_'): + print('- {0:<40}: {1:s}'.format(name, shortcut)) + + +# ----------------------------------------------------------------------------- +# Actions +# ----------------------------------------------------------------------------- + +def _alias(name): + # Get the alias from the character after & if it exists. 
+ alias = name[name.index('&') + 1] if '&' in name else name + return alias + + +@require_qt +def _create_qaction(gui, name, callback, shortcut, docstring=None, alias=''): + # Create the QAction instance. + action = QAction(name.capitalize().replace('_', ' '), gui) + + def wrapped(checked, *args, **kwargs): # pragma: no cover + return callback(*args, **kwargs) + + action.triggered.connect(wrapped) + sequence = _get_qkeysequence(shortcut) + if not isinstance(sequence, (tuple, list)): + sequence = [sequence] + action.setShortcuts(sequence) + assert docstring + docstring += ' (alias: {})'.format(alias) + action.setStatusTip(docstring) + action.setWhatsThis(docstring) + return action + + +class Actions(object): + """Handle GUI actions. + + This class attaches to a GUI and implements the following features: + + * Add and remove actions + * Keyboard shortcuts for the actions + * Display all shortcuts + + """ + def __init__(self, gui, name=None, menu=None, default_shortcuts=None): + self._actions_dict = {} + self._aliases = {} + self._default_shortcuts = default_shortcuts or {} + self.name = name + self.menu = menu + self.gui = gui + gui.actions.append(self) + + def add(self, callback=None, name=None, shortcut=None, alias=None, + docstring=None, menu=None, verbose=True): + """Add an action with a keyboard shortcut.""" + # TODO: add menu_name option and create menu bar + if callback is None: + # Allow to use either add(func) or @add or @add(...). + return partial(self.add, name=name, shortcut=shortcut, + alias=alias, menu=menu) + assert callback + + # Get the name from the callback function if needed. + name = name or callback.__name__ + alias = alias or _alias(name) + name = name.replace('&', '') + shortcut = shortcut or self._default_shortcuts.get(name, None) + + # Skip existing action. + if name in self._actions_dict: + return + + # Set the status tip from the function's docstring. + docstring = docstring or callback.__doc__ or name + docstring = re.sub(r'[\s]{2,}', ' ', docstring) + + # Create and register the action. + action = _create_qaction(self.gui, name, callback, + shortcut, + docstring=docstring, + alias=alias, + ) + action_obj = Bunch(qaction=action, name=name, alias=alias, + shortcut=shortcut, callback=callback, menu=menu) + if verbose and not name.startswith('_'): + logger.log(5, "Add action `%s` (%s).", name, + _get_shortcut_string(action.shortcut())) + self.gui.addAction(action) + # Add the action to the menu. + menu = menu or self.menu + # Do not show private actions in the menu. + if menu and not name.startswith('_'): + self.gui.get_menu(menu).addAction(action) + self._actions_dict[name] = action_obj + # Register the alias -> name mapping. + self._aliases[alias] = name + + # Set the callback method. 
+ if callback: + setattr(self, name, callback) + + def separator(self, menu=None): + """Add a separator""" + self.gui.get_menu(menu or self.menu).addSeparator() + + def disable(self, name=None): + """Disable one or all actions.""" + if name is None: + for name in self._actions_dict: + self.disable(name) + return + self._actions_dict[name].qaction.setEnabled(False) + + def enable(self, name=None): + """Enable one or all actions.""" + if name is None: + for name in self._actions_dict: + self.enable(name) + return + self._actions_dict[name].qaction.setEnabled(True) + + def get(self, name): + """Get a QAction instance from its name.""" + return self._actions_dict[name].qaction + + def run(self, name, *args): + """Run an action as specified by its name.""" + assert isinstance(name, string_types) + # Resolve the alias if it is an alias. + name = self._aliases.get(name, name) + # Get the action. + action = self._actions_dict.get(name, None) + if not action: + raise ValueError("Action `{}` doesn't exist.".format(name)) + if not name.startswith('_'): + logger.debug("Execute action `%s`.", name) + return action.callback(*args) + + def remove(self, name): + """Remove an action.""" + self.gui.removeAction(self._actions_dict[name].qaction) + del self._actions_dict[name] + delattr(self, name) + + def remove_all(self): + """Remove all actions.""" + names = sorted(self._actions_dict.keys()) + for name in names: + self.remove(name) + + @property + def shortcuts(self): + """A dictionary of action shortcuts.""" + return {name: action.shortcut + for name, action in self._actions_dict.items()} + + def show_shortcuts(self): + """Print all shortcuts.""" + gui_name = self.gui.name + actions_name = self.name + name = ('{} - {}'.format(gui_name, actions_name) + if actions_name else gui_name) + _show_shortcuts(self.shortcuts, name) + + def __contains__(self, name): + return name in self._actions_dict + + def __repr__(self): + return ''.format(sorted(self._actions_dict)) + + +# ----------------------------------------------------------------------------- +# Snippets +# ----------------------------------------------------------------------------- + +class Snippets(object): + """Provide keyboard snippets to quickly execute actions from a GUI. + + This class attaches to a GUI and an `Actions` instance. To every command + is associated a snippet with the same name, or with an alias as indicated + in the action. The arguments of the action's callback functions can be + provided in the snippet's command with a simple syntax. For example, the + following command: + + ``` + :my_action string 3-6 + ``` + + corresponds to: + + ```python + my_action('string', (3, 4, 5, 6)) + ``` + + The snippet mode is activated with the `:` keyboard shortcut. A snippet + command is activated with `Enter`, and one can leave the snippet mode + with `Escape`. + + """ + + # HACK: Unicode characters do not seem to work on Python 2 + cursor = '\u200A\u258C' if PY3 else '' + + # Allowed characters in snippet mode. + # A Qt shortcut will be created for every character. + _snippet_chars = ("abcdefghijklmnopqrstuvwxyz0123456789" + " ,.;?!_-+~=*/\(){}[]") + + def __init__(self, gui): + self.gui = gui + self._status_message = gui.status_message + + self.actions = Actions(gui, name='Snippets', menu='Snippets') + + # Register snippet mode shortcut. 
+ @self.actions.add(shortcut=':') + def enable_snippet_mode(): + """Enable the snippet mode (type action alias in the status + bar).""" + self.mode_on() + + self._create_snippet_actions() + self.mode_off() + + @property + def command(self): + """This is used to write a snippet message in the status bar. + + A cursor is appended at the end. + + """ + msg = self.gui.status_message + n = len(msg) + n_cur = len(self.cursor) + return msg[:n - n_cur] + + @command.setter + def command(self, value): + value += self.cursor + self.gui.unlock_status() + self.gui.status_message = value + self.gui.lock_status() + + def _backspace(self): + """Erase the last character in the snippet command.""" + if self.command == ':': + return + logger.log(5, "Snippet keystroke `Backspace`.") + self.command = self.command[:-1] + + def _enter(self): + """Disable the snippet mode and execute the command.""" + command = self.command + logger.log(5, "Snippet keystroke `Enter`.") + # NOTE: we need to set back the actions (mode_off) before running + # the command. + self.mode_off() + self.run(command) + + def _create_snippet_actions(self): + """Add mock Qt actions for snippet keystrokes. + + Used to enable snippet mode. + + """ + # One action per allowed character. + for i, char in enumerate(self._snippet_chars): + + def _make_func(char): + def callback(): + logger.log(5, "Snippet keystroke `%s`.", char) + self.command += char + return callback + + self.actions.add(name='_snippet_{}'.format(i), + shortcut=char, + callback=_make_func(char)) + + self.actions.add(name='_snippet_backspace', + shortcut='backspace', + callback=self._backspace) + self.actions.add(name='_snippet_activate', + shortcut=('enter', 'return'), + callback=self._enter) + self.actions.add(name='_snippet_disable', + shortcut='escape', + callback=self.mode_off) + + def run(self, snippet): + """Executes a snippet command. + + May be overridden. + + """ + assert snippet[0] == ':' + snippet = snippet[1:] + snippet_args = _parse_snippet(snippet) + name = snippet_args[0] + + logger.info("Processing snippet `%s`.", snippet) + try: + # Try to run the snippet on all attached Actions instances. + for actions in self.gui.actions: + try: + actions.run(name, *snippet_args[1:]) + return + except ValueError: + # This Actions instance doesn't contain the requested + # snippet, trying the next attached Actions instance. + pass + logger.warn("Couldn't find action `%s`.", name) + except Exception as e: + logger.warn("Error when executing snippet: \"%s\".", str(e)) + logger.debug(''.join(traceback.format_exception(*sys.exc_info()))) + + def is_mode_on(self): + return self.command.startswith(':') + + def mode_on(self): + logger.info("Snippet mode enabled, press `escape` to leave this mode.") + # Save the current status message. + self._status_message = self.gui.status_message + self.gui.lock_status() + + # Silent all actions except the Snippets actions. + for actions in self.gui.actions: + if actions != self.actions: + actions.disable() + self.actions.enable() + + self.command = ':' + + def mode_off(self): + self.gui.unlock_status() + # Reset the GUI status message that was set before the mode was + # activated. + self.gui.status_message = self._status_message + + # Re-enable all actions except the Snippets actions. + self.actions.disable() + for actions in self.gui.actions: + if actions != self.actions: + actions.enable() + # The `:` shortcut should always be enabled. 
+ self.actions.enable('enable_snippet_mode') diff --git a/phy/gui/base.py b/phy/gui/base.py deleted file mode 100644 index 4df3d8cbe..000000000 --- a/phy/gui/base.py +++ /dev/null @@ -1,672 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Base classes for GUIs.""" - - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -from collections import Counter -import inspect - -from six import string_types, PY3 - -from ..utils._misc import _show_shortcuts -from ..utils import debug, info, warn, EventEmitter -from ._utils import _read -from .dock import DockWindow - - -#------------------------------------------------------------------------------ -# BaseViewModel -#------------------------------------------------------------------------------ - -class BaseViewModel(object): - """Interface between a view and a model. - - Events - ------ - - show_view - close_view - - """ - _view_name = '' - _imported_params = ('position', 'size',) - - def __init__(self, model=None, **kwargs): - self._model = model - self._event = EventEmitter() - - # Instantiate the underlying view. - self._view = self._create_view(**kwargs) - - # Set passed keyword arguments as attributes. - for param in self.imported_params(): - value = kwargs.get(param, None) - if value is not None: - setattr(self, param, value) - - self.on_open() - - def emit(self, *args, **kwargs): - """Emit an event.""" - return self._event.emit(*args, **kwargs) - - def connect(self, *args, **kwargs): - """Connect a callback function.""" - self._event.connect(*args, **kwargs) - - # Methods to override - #-------------------------------------------------------------------------- - - def _create_view(self, **kwargs): - """Create the view with the parameters passed to the constructor. - - Must be overriden.""" - return None - - def on_open(self): - """Initialize the view after the model has been loaded. - - May be overriden.""" - - def on_close(self): - """Called when the model is closed. 
- - May be overriden.""" - - # Parameters - #-------------------------------------------------------------------------- - - @classmethod - def imported_params(cls): - """All parameter names to be imported on object creation.""" - out = () - for base_class in inspect.getmro(cls): - if base_class == object: - continue - out += base_class._imported_params - return out - - def exported_params(self, save_size_pos=True): - """Return a dictionary of variables to save when the view is closed.""" - if save_size_pos and hasattr(self._view, 'pos'): - return { - 'position': (self._view.x(), self._view.y()), - 'size': (self._view.width(), self._view.height()), - } - else: - return {} - - @classmethod - def get_params(cls, settings): - """Return the parameter values for the creation of the view.""" - name = cls._view_name - param_names = cls.imported_params() - params = {key: settings[name + '_' + key] - for key in param_names - if (name + '_' + key) in settings} - return params - - # Properties - #-------------------------------------------------------------------------- - - @property - def model(self): - """The model.""" - return self._model - - @property - def name(self): - """The view model's name.""" - return self._view_name - - @property - def view(self): - """The underlying view.""" - return self._view - - # Public methods - #-------------------------------------------------------------------------- - - def close(self): - """Close the view.""" - self._view.close() - self.emit('close_view') - - def show(self): - """Show the view.""" - self._view.show() - self.emit('show_view') - - -#------------------------------------------------------------------------------ -# HTMLViewModel -#------------------------------------------------------------------------------ - -class HTMLViewModel(BaseViewModel): - """Widget with custom HTML code. - - To create a new HTML view, derive from `HTMLViewModel`, and implement - `get_html()` which returns HTML code. - - """ - - def _update(self, view, **kwargs): - html = self.get_html(**kwargs) - css = self.get_css(**kwargs) - wrapped = _read('wrap_qt.html') - html_wrapped = wrapped.replace('%CSS%', css).replace('%HTML%', html) - view.setHtml(html_wrapped) - - def _create_view(self, **kwargs): - from PyQt4.QtWebKit import QWebView - view = QWebView() - self._update(view, **kwargs) - return view - - def get_html(self, **kwargs): - """Return the non-formatted HTML contents of the view.""" - return '' - - def get_css(self, **kwargs): - """Return the view's CSS styles.""" - return '' - - def update(self, **kwargs): - """Update the widget's HTML contents.""" - self._update(self._view, **kwargs) - - def isVisible(self): - return self._view.isVisible() - - -#------------------------------------------------------------------------------ -# Widget creator (used to create views and GUIs) -#------------------------------------------------------------------------------ - -class WidgetCreator(EventEmitter): - """Manage the creation of widgets. - - A widget must implement: - - * `name` - * `show()` - * `connect` (for `close` event) - - Events - ------ - - add(widget): when a widget is added. - close(widget): when a widget is closed. - - """ - def __init__(self, widget_classes=None): - super(WidgetCreator, self).__init__() - self._widget_classes = widget_classes or {} - self._widgets = [] - - def _create_widget(self, widget_class, **kwargs): - """Create a new widget of a given class. - - May be overriden. 
- - """ - return widget_class(**kwargs) - - @property - def widget_classes(self): - """The registered widget classes that can be created.""" - return self._widget_classes - - def _widget_name(self, widget): - if widget.name: - return widget.name - # Fallback to the name given in widget_classes. - for name, cls in self._widget_classes.items(): - if cls == widget.__class__: - return name - - def get(self, *names): - """Return the list of widgets of a given type.""" - if not names: - return self._widgets - return [widget for widget in self._widgets - if self._widget_name(widget) in names] - - def add(self, widget_class, show=False, **kwargs): - """Add a new widget.""" - # widget_class can also be a name, but in this case it must be - # registered in self._widget_classes. - if isinstance(widget_class, string_types): - if widget_class not in self.widget_classes: - raise ValueError("Unknown widget class " - "`{}`.".format(widget_class)) - widget_class = self.widget_classes[widget_class] - widget = self._create_widget(widget_class, **kwargs) - if widget not in self._widgets: - self._widgets.append(widget) - self.emit('add', widget) - - @widget.connect - def on_close(e=None): - self.emit('close', widget) - self.remove(widget) - - if show: - widget.show() - - return widget - - def remove(self, widget): - """Remove a widget.""" - if widget in self._widgets: - debug("Remove widget {}.".format(widget)) - self._widgets.remove(widget) - else: - debug("Unable to remove widget {}.".format(widget)) - - -#------------------------------------------------------------------------------ -# Base GUI -#------------------------------------------------------------------------------ - -def _title(item): - """Default view model title.""" - if hasattr(item, 'name'): - name = item.name.capitalize() - else: - name = item.__class__.__name__.capitalize() - return name - - -def _assert_counters_equal(c_0, c_1): - c_0 = {(k, v) for (k, v) in c_0.items() if v > 0} - c_1 = {(k, v) for (k, v) in c_1.items() if v > 0} - assert c_0 == c_1 - - -class BaseGUI(EventEmitter): - """Base GUI. - - This object represents a main window with: - - * multiple dockable views - * user-exposed actions - * keyboard shortcuts - - Parameters - ---------- - - config : list - List of pairs `(name, kwargs)` to create default views. - vm_classes : dict - Dictionary `{name: view_model_class}`. - state : object - Default Qt GUI state. - shortcuts : dict - Dictionary `{function_name: keyboard_shortcut}`. - - Events - ------ - - add_view - close_view - reset_gui - - """ - - _default_shortcuts = { - 'exit': 'ctrl+q', - 'enable_snippet_mode': ':', - } - - def __init__(self, - model=None, - vm_classes=None, - state=None, - shortcuts=None, - snippets=None, - config=None, - settings=None, - ): - super(BaseGUI, self).__init__() - self.settings = settings or {} - if state is None: - state = {} - self.model = model - # Shortcuts. - s = self._default_shortcuts.copy() - s.update(shortcuts or {}) - self._shortcuts = s - self._snippets = snippets or {} - # GUI state and config. - self._state = state - if config is None: - config = [(name, {}) for name in (vm_classes or {})] - self._config = config - # Create the dock window. 
- self._dock = DockWindow(title=self.title) - self._view_creator = WidgetCreator(widget_classes=vm_classes) - self._initialize_views() - self._load_geometry_state(state) - self._set_default_shortcuts() - self._create_actions() - self._set_default_view_connections() - - def _initialize_views(self): - self._load_config(self._config, - requested_count=self._state.get('view_count', None), - ) - - #-------------------------------------------------------------------------- - # Methods to override - #-------------------------------------------------------------------------- - - @property - def title(self): - """Title of the main window. - - May be overriden. - - """ - return 'Base GUI' - - def _set_default_view_connections(self): - """Set view connections. - - May be overriden. - - Example: - - ```python - @self.main_window.connect_views('view_1', 'view_2') - def f(view_1, view_2): - # Called for every pair of views of type view_1 and view_2. - pass - ``` - - """ - pass - - def _create_actions(self): - """Create default actions in the GUI. - - The `_add_gui_shortcut()` method can be used. - - Must be overriden. - - """ - pass - - def _view_model_kwargs(self, name): - return {} - - def on_open(self): - """Callback function when the model is opened. - - Must be overriden. - - """ - pass - - #-------------------------------------------------------------------------- - # Internal methods - #-------------------------------------------------------------------------- - - def _load_config(self, config=None, - current_count=None, - requested_count=None): - """Load a GUI configuration dictionary.""" - config = config or [] - current_count = current_count or {} - requested_count = requested_count or Counter([name - for name, _ in config]) - for name, kwargs in config: - # Add the right number of views of each type. - if current_count.get(name, 0) >= requested_count.get(name, 0): - continue - debug("Adding {} view in GUI.".format(name)) - # GUI-specific keyword arguments position, size, maximized - self.add_view(name, **kwargs) - if name not in current_count: - current_count[name] = 0 - current_count[name] += 1 - _assert_counters_equal(current_count, requested_count) - - def _load_geometry_state(self, gui_state): - if gui_state: - self._dock.restore_geometry_state(gui_state) - - def _remove_actions(self): - self._dock.remove_actions() - - def _set_default_shortcuts(self): - for name, shortcut in self._default_shortcuts.items(): - self._add_gui_shortcut(name) - - def _add_gui_shortcut(self, method_name): - """Helper function to add a GUI action with a keyboard shortcut.""" - # Get the keyboard shortcut for this method. - shortcut = self._shortcuts.get(method_name, None) - - def callback(): - return getattr(self, method_name)() - - # Bind the shortcut to the method. - self._dock.add_action(method_name, - callback, - shortcut=shortcut, - ) - - #-------------------------------------------------------------------------- - # Snippet methods - #-------------------------------------------------------------------------- - - @property - def status_message(self): - """Message in the status bar.""" - return str(self._dock.status_message) - - @status_message.setter - def status_message(self, value): - self._dock.status_message = str(value) - - # HACK: Unicode characters do not appear to work on Python 2 - _snippet_message_cursor = '\u200A\u258C' if PY3 else '' - - @property - def _snippet_message(self): - """This is used to write a snippet message in the status bar. - - A cursor is appended at the end. 
- - """ - n = len(self.status_message) - n_cur = len(self._snippet_message_cursor) - return self.status_message[:n - n_cur] - - @_snippet_message.setter - def _snippet_message(self, value): - self.status_message = value + self._snippet_message_cursor - - def process_snippet(self, snippet): - """Processes a snippet. - - May be overriden. - - """ - assert snippet[0] == ':' - snippet = snippet[1:] - split = snippet.split(' ') - cmd = split[0] - snippet = snippet[len(cmd):].strip() - func = self._snippets.get(cmd, None) - if func is None: - info("The snippet `{}` could not be found.".format(cmd)) - return - try: - info("Processing snippet `{}`.".format(cmd)) - func(self, snippet) - except Exception as e: - warn("Error when executing snippet `{}`: {}.".format( - cmd, str(e))) - - def _snippet_action_name(self, char): - return self._snippet_chars.index(char) - - _snippet_chars = 'abcdefghijklmnopqrstuvwxyz0123456789 ._,+*-=:()' - - def _create_snippet_actions(self): - # One action per allowed character. - for i, char in enumerate(self._snippet_chars): - - def _make_func(char): - def callback(): - self._snippet_message += char - return callback - - self._dock.add_action('snippet_{}'.format(i), - shortcut=char, - callback=_make_func(char), - ) - - def backspace(): - if self._snippet_message == ':': - return - self._snippet_message = self._snippet_message[:-1] - - def enter(): - self.process_snippet(self._snippet_message) - self.disable_snippet_mode() - - self._dock.add_action('snippet_backspace', - shortcut='backspace', - callback=backspace, - ) - self._dock.add_action('snippet_activate', - shortcut=('enter', 'return'), - callback=enter, - ) - self._dock.add_action('snippet_disable', - shortcut='escape', - callback=self.disable_snippet_mode, - ) - - def enable_snippet_mode(self): - info("Snippet mode enabled, press `escape` to leave this mode.") - self._remove_actions() - self._create_snippet_actions() - self._snippet_message = ':' - - def disable_snippet_mode(self): - self.status_message = '' - # Reestablishes the shortcuts. - self._remove_actions() - self._set_default_shortcuts() - self._create_actions() - info("Snippet mode disabled.") - - #-------------------------------------------------------------------------- - # Public methods - #-------------------------------------------------------------------------- - - def show(self): - """Show the GUI""" - self._dock.show() - - @property - def main_window(self): - """Main Qt window.""" - return self._dock - - def add_view(self, item, title=None, **kwargs): - """Add a new view instance to the GUI.""" - position = kwargs.pop('position', None) - # Item may be a string. - if isinstance(item, string_types): - name = item - # Default view model kwargs. - kwargs.update(self._view_model_kwargs(name)) - # View model parameters from settings. - vm_class = self._view_creator._widget_classes[name] - kwargs.update(vm_class.get_params(self.settings)) - # debug("Create {} with {}.".format(name, kwargs)) - item = self._view_creator.add(item, **kwargs) - # Set the view name if necessary. - if not item._view_name: - item._view_name = name - # Default dock title. - if title is None: - title = _title(item) - # Get the underlying view. - view = item.view if isinstance(item, BaseViewModel) else item - # Add the view to the main window. - dw = self._dock.add_view(view, title=title, position=position) - - # Dock widget close event. 
- @dw.connect_ - def on_close_widget(): - self._view_creator.remove(item) - self.emit('close_view', item) - - # Make sure the callback above is called when the dock widget - # is closed directly. - # View model close event. - @item.connect - def on_close_view(e=None): - dw.close() - - # Call the user callback function. - if 'on_view_open' in self.settings: - self.settings['on_view_open'](self, item) - - self.emit('add_view', item) - - def get_views(self, *names): - """Return the list of views of a given type.""" - return self._view_creator.get(*names) - - def connect_views(self, name_0, name_1): - """Decorator for a function called on every pair of views of a - given type.""" - def wrap(func): - for view_0 in self.get_views(name_0): - for view_1 in self.get_views(name_1): - func(view_0, view_1) - return wrap - - @property - def views(self): - """List of all open views.""" - return self.get_views() - - def view_count(self): - """Number of views of each type.""" - return {name: len(self.get_views(name)) - for name in self._view_creator.widget_classes.keys()} - - def reset_gui(self): - """Reset the GUI configuration.""" - count = self.view_count() - self._load_config(self._config, - current_count=count, - ) - self.emit('reset_gui') - - def show_shortcuts(self): - """Show the list of all keyboard shortcuts.""" - _show_shortcuts(self._shortcuts, name=self.__class__.__name__) - - def isVisible(self): - return self._dock.isVisible() - - def close(self): - """Close the GUI.""" - self._dock.close() - - def exit(self): - """Close the GUI.""" - self.close() diff --git a/phy/gui/dock.py b/phy/gui/dock.py deleted file mode 100644 index 99a95b8ff..000000000 --- a/phy/gui/dock.py +++ /dev/null @@ -1,282 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Qt dock window.""" - -# ----------------------------------------------------------------------------- -# Imports -# ----------------------------------------------------------------------------- - -from collections import defaultdict - -from .qt import QtCore, QtGui -from ..utils.event import EventEmitter - - -# ----------------------------------------------------------------------------- -# Qt utilities -# ----------------------------------------------------------------------------- - -def _title(widget): - return str(widget.windowTitle()).lower() - - -def _widget(dock_widget): - """Return a Qt or VisPy widget from a dock widget.""" - widget = dock_widget.widget() - if hasattr(widget, '_vispy_canvas'): - return widget._vispy_canvas - else: - return widget - - -# ----------------------------------------------------------------------------- -# Qt windows -# ----------------------------------------------------------------------------- - -class DockWindow(QtGui.QMainWindow): - """A Qt main window holding docking Qt or VisPy widgets. - - Events - ------ - - close_gui - show_gui - keystroke - - Note - ---- - - Use `connect_()`, not `connect()`, because of a name conflict with Qt. 
- - """ - def __init__(self, - position=None, - size=None, - title=None, - ): - super(DockWindow, self).__init__() - self._actions = {} - if title is None: - title = 'phy' - self.setWindowTitle(title) - if position is not None: - self.move(position[0], position[1]) - if size is not None: - self.resize(QtCore.QSize(size[0], size[1])) - self.setObjectName(title) - QtCore.QMetaObject.connectSlotsByName(self) - self.setDockOptions(QtGui.QMainWindow.AllowTabbedDocks | - QtGui.QMainWindow.AllowNestedDocks | - QtGui.QMainWindow.AnimatedDocks - ) - # We can derive from EventEmitter because of a conflict with connect. - self._event = EventEmitter() - - self._status_bar = QtGui.QStatusBar() - self.setStatusBar(self._status_bar) - - def keyReleaseEvent(self, e): - self.emit('keystroke', e.key(), e.text()) - return super(DockWindow, self).keyReleaseEvent(e) - - # Events - # ------------------------------------------------------------------------- - - def emit(self, *args, **kwargs): - return self._event.emit(*args, **kwargs) - - def connect_(self, *args, **kwargs): - self._event.connect(*args, **kwargs) - - def unconnect_(self, *args, **kwargs): - self._event.unconnect(*args, **kwargs) - - def closeEvent(self, e): - """Qt slot when the window is closed.""" - res = self.emit('close_gui') - # Discard the close event if False is returned by one of the callback - # functions. - if False in res: - e.ignore() - return - super(DockWindow, self).closeEvent(e) - - def show(self): - """Show the window.""" - self.emit('show_gui') - super(DockWindow, self).show() - - # Actions - # ------------------------------------------------------------------------- - - def add_action(self, - name, - callback=None, - shortcut=None, - checkable=False, - checked=False, - ): - """Add an action with a keyboard shortcut.""" - if name in self._actions: - return - action = QtGui.QAction(name, self) - action.triggered.connect(callback) - action.setCheckable(checkable) - action.setChecked(checked) - if shortcut: - if not isinstance(shortcut, (tuple, list)): - shortcut = [shortcut] - for key in shortcut: - action.setShortcut(key) - self.addAction(action) - self._actions[name] = action - if callback: - setattr(self, name, callback) - return action - - def remove_action(self, name): - """Remove an action.""" - self.removeAction(self._actions[name]) - del self._actions[name] - delattr(self, name) - - def remove_actions(self): - """Remove all actions.""" - names = sorted(self._actions.keys()) - for name in names: - self.remove_action(name) - - def shortcut(self, name, key=None): - """Decorator to add a global keyboard shortcut.""" - def wrap(func): - self.add_action(name, shortcut=key, callback=func) - return wrap - - # Views - # ------------------------------------------------------------------------- - - def add_view(self, - view, - title='view', - position=None, - closable=True, - floatable=True, - floating=None, - # parent=None, # object to pass in the raised events - **kwargs): - """Add a widget to the main window.""" - - try: - from vispy.app import Canvas - if isinstance(view, Canvas): - view = view.native - except ImportError: - pass - - class DockWidget(QtGui.QDockWidget): - def __init__(self, *args, **kwargs): - super(DockWidget, self).__init__(*args, **kwargs) - self._event = EventEmitter() - - def emit(self, *args, **kwargs): - return self._event.emit(*args, **kwargs) - - def connect_(self, *args, **kwargs): - self._event.connect(*args, **kwargs) - - def closeEvent(self, e): - """Qt slot when the window is closed.""" - 
self.emit('close_widget') - super(DockWidget, self).closeEvent(e) - - # Create the dock widget. - dockwidget = DockWidget(self) - dockwidget.setObjectName(title) - dockwidget.setWindowTitle(title) - dockwidget.setWidget(view) - - # Set dock widget options. - options = QtGui.QDockWidget.DockWidgetMovable - if closable: - options = options | QtGui.QDockWidget.DockWidgetClosable - if floatable: - options = options | QtGui.QDockWidget.DockWidgetFloatable - - dockwidget.setFeatures(options) - dockwidget.setAllowedAreas(QtCore.Qt.LeftDockWidgetArea | - QtCore.Qt.RightDockWidgetArea | - QtCore.Qt.TopDockWidgetArea | - QtCore.Qt.BottomDockWidgetArea - ) - - q_position = { - 'left': QtCore.Qt.LeftDockWidgetArea, - 'right': QtCore.Qt.RightDockWidgetArea, - 'top': QtCore.Qt.TopDockWidgetArea, - 'bottom': QtCore.Qt.BottomDockWidgetArea, - }[position or 'right'] - self.addDockWidget(q_position, dockwidget) - if floating is not None: - dockwidget.setFloating(floating) - dockwidget.show() - return dockwidget - - def list_views(self, title='', is_visible=True): - """List all views which title start with a given string.""" - title = title.lower() - children = self.findChildren(QtGui.QWidget) - return [child for child in children - if isinstance(child, QtGui.QDockWidget) and - _title(child).startswith(title) and - (child.isVisible() if is_visible else True) and - child.width() >= 10 and - child.height() >= 10 - ] - - def view_count(self, is_visible=True): - """Return the number of opened views.""" - views = self.list_views() - counts = defaultdict(lambda: 0) - for view in views: - counts[_title(view)] += 1 - return dict(counts) - - # Status bar - # ------------------------------------------------------------------------- - - @property - def status_message(self): - """The message in the status bar.""" - return self._status_bar.currentMessage() - - @status_message.setter - def status_message(self, value): - self._status_bar.showMessage(value) - - # State - # ------------------------------------------------------------------------- - - def save_geometry_state(self): - """Return picklable geometry and state of the window and docks. - - This function can be called in `on_close()`. - - """ - return { - 'geometry': self.saveGeometry(), - 'state': self.saveState(), - 'view_count': self.view_count(), - } - - def restore_geometry_state(self, gs): - """Restore the position of the main window and the docks. - - The dock widgets need to be recreated first. - - This function can be called in `on_show()`. 
- - """ - if gs.get('geometry', None): - self.restoreGeometry((gs['geometry'])) - if gs.get('state', None): - self.restoreState((gs['state'])) diff --git a/phy/gui/gui.py b/phy/gui/gui.py new file mode 100644 index 000000000..b154ebc2b --- /dev/null +++ b/phy/gui/gui.py @@ -0,0 +1,412 @@ +# -*- coding: utf-8 -*- + +"""Qt dock window.""" + + +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- + +from collections import defaultdict +import logging +import os.path as op + +from .qt import (QApplication, QWidget, QDockWidget, QStatusBar, QMainWindow, + Qt, QSize, QMetaObject) +from .actions import Actions, Snippets +from phy.utils.event import EventEmitter +from phy.utils import (Bunch, _bunchify, + _load_json, _save_json, + _ensure_dir_exists, phy_config_dir,) + +logger = logging.getLogger(__name__) + + +# ----------------------------------------------------------------------------- +# GUI main window +# ----------------------------------------------------------------------------- + +def _try_get_vispy_canvas(view): + # Get the Qt widget from a VisPy canvas. + try: + from vispy.app import Canvas + if isinstance(view, Canvas): + view = view.native + except ImportError: # pragma: no cover + pass + return view + + +def _try_get_matplotlib_canvas(view): + # Get the Qt widget from a matplotlib figure. + try: + from matplotlib.pyplot import Figure + from matplotlib.backends.backend_qt4agg import FigureCanvasQTAgg + if isinstance(view, Figure): + view = FigureCanvasQTAgg(view) + except ImportError: # pragma: no cover + pass + return view + + +class DockWidget(QDockWidget): + """A QDockWidget that can emit events.""" + def __init__(self, *args, **kwargs): + super(DockWidget, self).__init__(*args, **kwargs) + self._event = EventEmitter() + + def emit(self, *args, **kwargs): + return self._event.emit(*args, **kwargs) + + def connect_(self, *args, **kwargs): + self._event.connect(*args, **kwargs) + + def closeEvent(self, e): + """Qt slot when the window is closed.""" + self.emit('close_widget') + super(DockWidget, self).closeEvent(e) + + +def _create_dock_widget(widget, name, closable=True, floatable=True): + # Create the gui widget. + dock_widget = DockWidget() + dock_widget.setObjectName(name) + dock_widget.setWindowTitle(name) + dock_widget.setWidget(widget) + + # Set gui widget options. + options = QDockWidget.DockWidgetMovable + if closable: + options = options | QDockWidget.DockWidgetClosable + if floatable: + options = options | QDockWidget.DockWidgetFloatable + + dock_widget.setFeatures(options) + dock_widget.setAllowedAreas(Qt.LeftDockWidgetArea | + Qt.RightDockWidgetArea | + Qt.TopDockWidgetArea | + Qt.BottomDockWidgetArea + ) + + return dock_widget + + +def _get_dock_position(position): + return {'left': Qt.LeftDockWidgetArea, + 'right': Qt.RightDockWidgetArea, + 'top': Qt.TopDockWidgetArea, + 'bottom': Qt.BottomDockWidgetArea, + }[position or 'right'] + + +class GUI(QMainWindow): + """A Qt main window holding docking Qt or VisPy widgets. + + `GUI` derives from `QMainWindow`. + + Events + ------ + + close + show + add_view + close_view + + Note + ---- + + Use `connect_()`, not `connect()`, because of a name conflict with Qt. + + """ + def __init__(self, + position=None, + size=None, + name=None, + subtitle=None, + **kwargs + ): + # HACK to ensure that closeEvent is called only twice (seems like a + # Qt bug). 
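+        # The `_closed` flag below ensures the close logic runs only once, even if Qt calls `closeEvent()` a second time.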
+ self._closed = False + if not QApplication.instance(): # pragma: no cover + raise RuntimeError("A Qt application must be created.") + super(GUI, self).__init__() + QMetaObject.connectSlotsByName(self) + self.setDockOptions(QMainWindow.AllowTabbedDocks | + QMainWindow.AllowNestedDocks | + QMainWindow.AnimatedDocks + ) + + self._set_name(name, subtitle) + self._set_pos_size(position, size) + + # Registered functions. + self._registered = {} + + # Mapping {name: menuBar}. + self._menus = {} + + # We can derive from EventEmitter because of a conflict with connect. + self._event = EventEmitter() + + # Status bar. + self._lock_status = False + self._status_bar = QStatusBar() + self.setStatusBar(self._status_bar) + + # List of attached Actions instances. + self.actions = [] + + # Default actions. + self._set_default_actions() + + # Create and attach snippets. + self.snippets = Snippets(self) + + # Create the state. + self.state = GUIState(self.name, **kwargs) + + @self.connect_ + def on_show(): + logger.debug("Load the geometry state.") + gs = self.state.get('geometry_state', None) + self.restore_geometry_state(gs) + + @self.connect_ + def on_close(): + logger.debug("Save the geometry state.") + gs = self.save_geometry_state() + self.state['geometry_state'] = gs + # Save the state to disk when closing the GUI. + self.state.save() + + def _set_name(self, name, subtitle): + if name is None: + name = self.__class__.__name__ + title = name if not subtitle else name + ' - ' + subtitle + self.setWindowTitle(title) + self.setObjectName(name) + # Set the name in the GUI. + self.name = name + + def _set_pos_size(self, position, size): + if position is not None: + self.move(position[0], position[1]) + if size is not None: + self.resize(QSize(size[0], size[1])) + + def _set_default_actions(self): + self.default_actions = Actions(self, name='Default', menu='&File') + + @self.default_actions.add(shortcut=('HelpContents', 'h')) + def show_all_shortcuts(): + """Show the shortcuts of all actions.""" + for actions in self.actions: + actions.show_shortcuts() + + @self.default_actions.add(shortcut='ctrl+q') + def exit(): + """Close the GUI.""" + self.close() + + # Events + # ------------------------------------------------------------------------- + + def emit(self, *args, **kwargs): + return self._event.emit(*args, **kwargs) + + def connect_(self, *args, **kwargs): + self._event.connect(*args, **kwargs) + + def unconnect_(self, *args, **kwargs): + self._event.unconnect(*args, **kwargs) + + def closeEvent(self, e): + """Qt slot when the window is closed.""" + if self._closed: + return + self._closed = True + res = self.emit('close') + # Discard the close event if False is returned by one of the callback + # functions. + if False in res: # pragma: no cover + e.ignore() + return + super(GUI, self).closeEvent(e) + + def show(self): + """Show the window.""" + self.emit('show') + super(GUI, self).show() + + # Views + # ------------------------------------------------------------------------- + + def _get_view_index(self, view): + """Index of a view in the GUI: 0 for the first view of a given + class, 1 for the next, and so on.""" + name = view.__class__.__name__ + return len(self.list_views(name)) + + def add_view(self, + view, + name=None, + position=None, + closable=False, + floatable=True, + floating=None): + """Add a widget to the main window.""" + + # Set the name in the view. + view.view_index = self._get_view_index(view) + # The view name is ``, e.g. `MyView0`. 
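+        # (The default name is the view's class name followed by the view index.)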
+ view.name = name or view.__class__.__name__ + str(view.view_index) + + # Get the Qt canvas for VisPy and matplotlib views. + widget = _try_get_vispy_canvas(view) + widget = _try_get_matplotlib_canvas(widget) + + dock_widget = _create_dock_widget(widget, view.name, + closable=closable, + floatable=floatable, + ) + self.addDockWidget(_get_dock_position(position), dock_widget) + if floating is not None: + dock_widget.setFloating(floating) + dock_widget.view = view + + # Emit the close_view event when the dock widget is closed. + @dock_widget.connect_ + def on_close_widget(): + self.emit('close_view', view) + + dock_widget.show() + self.emit('add_view', view) + logger.log(5, "Add %s to GUI.", view.name) + return dock_widget + + def list_views(self, name='', is_visible=True): + """List all views which name start with a given string.""" + children = self.findChildren(QWidget) + return [child.view for child in children + if isinstance(child, QDockWidget) and + child.view.name.startswith(name) and + (child.isVisible() if is_visible else True) and + child.width() >= 10 and + child.height() >= 10 + ] + + def view_count(self): + """Return the number of opened views.""" + views = self.list_views() + counts = defaultdict(lambda: 0) + for view in views: + counts[view.name] += 1 + return dict(counts) + + # Menu bar + # ------------------------------------------------------------------------- + + def get_menu(self, name): + """Return or create a menu.""" + if name not in self._menus: + self._menus[name] = self.menuBar().addMenu(name) + return self._menus[name] + + # Status bar + # ------------------------------------------------------------------------- + + @property + def status_message(self): + """The message in the status bar.""" + return str(self._status_bar.currentMessage()) + + @status_message.setter + def status_message(self, value): + if self._lock_status: + return + self._status_bar.showMessage(str(value)) + + def lock_status(self): + self._lock_status = True + + def unlock_status(self): + self._lock_status = False + + # State + # ------------------------------------------------------------------------- + + def save_geometry_state(self): + """Return picklable geometry and state of the window and docks. + + This function can be called in `on_close()`. + + """ + return { + 'geometry': self.saveGeometry(), + 'state': self.saveState(), + } + + def restore_geometry_state(self, gs): + """Restore the position of the main window and the docks. + + The gui widgets need to be recreated first. + + This function can be called in `on_show()`. + + """ + if not gs: + return + if gs.get('geometry', None): + self.restoreGeometry((gs['geometry'])) + if gs.get('state', None): + self.restoreState((gs['state'])) + + +# ----------------------------------------------------------------------------- +# GUI state, creator +# ----------------------------------------------------------------------------- + +class GUIState(Bunch): + """Represent the state of the GUI: positions of the views and + all parameters associated to the GUI and views. + + This is automatically loaded from the configuration directory. 
+ + """ + def __init__(self, name='GUI', config_dir=None, **kwargs): + super(GUIState, self).__init__(**kwargs) + self.name = name + self.config_dir = config_dir or phy_config_dir() + _ensure_dir_exists(op.join(self.config_dir, self.name)) + self.load() + + def get_view_state(self, view): + """Return the state of a view.""" + return self.get(view.name, Bunch()) + + def update_view_state(self, view, state): + """Update the state of a view.""" + if view.name not in self: + self[view.name] = Bunch() + self[view.name].update(state) + + @property + def path(self): + return op.join(self.config_dir, self.name, 'state.json') + + def load(self): + """Load the state from the JSON file in the config dir.""" + if not op.exists(self.path): + logger.debug("The GUI state file `%s` doesn't exist.", self.path) + # TODO: create the default state. + return + assert op.exists(self.path) + logger.debug("Load the GUI state from `%s`.", self.path) + self.update(_bunchify(_load_json(self.path))) + + def save(self): + """Save the state to the JSON file in the config dir.""" + logger.debug("Save the GUI state to `%s`.", self.path) + _save_json(self.path, {k: v for k, v in self.items() + if k not in ('config_dir', 'name')}) diff --git a/phy/gui/qt.py b/phy/gui/qt.py index dfb46e0d7..6fd3045b7 100644 --- a/phy/gui/qt.py +++ b/phy/gui/qt.py @@ -6,194 +6,127 @@ # Imports # ----------------------------------------------------------------------------- -import os +from contextlib import contextmanager +from functools import wraps +import logging import sys -import contextlib -from ..utils._misc import _is_interactive -from ..utils.logging import info, warn +logger = logging.getLogger(__name__) # ----------------------------------------------------------------------------- # PyQt import # ----------------------------------------------------------------------------- -_PYQT = False +from PyQt4.QtCore import (Qt, QByteArray, QMetaObject, QObject, # noqa + QVariant, QEventLoop, QTimer, + pyqtSignal, pyqtSlot, QSize, QUrl) try: - from PyQt4 import QtCore, QtGui, QtWebKit # noqa - from PyQt4.QtGui import QMainWindow - Qt = QtCore.Qt - _PYQT = True -except ImportError: - try: - from PyQt5 import QtCore, QtGui, QtWebKit # noqa - from PyQt5.QtGui import QMainWindow - _PYQT = True - except ImportError: - pass + from PyQt4.QtCore import QPyNullVariant # noqa +except: # pragma: no cover + QPyNullVariant = None +try: + from PyQt4.QtCore import QString # noqa +except: # pragma: no cover + QString = None +from PyQt4.QtGui import (QKeySequence, QAction, QStatusBar, # noqa + QMainWindow, QDockWidget, QWidget, + QMessageBox, QApplication, QMenuBar, + ) +from PyQt4.QtWebKit import QWebView, QWebPage, QWebSettings # noqa -def _check_qt(): - if not _PYQT: - warn("PyQt is not available.") - return False - return True +# ----------------------------------------------------------------------------- +# Utility functions +# ----------------------------------------------------------------------------- +def _button_enum_from_name(name): + return getattr(QMessageBox, name.capitalize()) -if not _check_qt(): - QMainWindow = object # noqa +def _button_name_from_enum(enum): + names = dir(QMessageBox) + for name in names: + if getattr(QMessageBox, name) == enum: + return name.lower() -# ----------------------------------------------------------------------------- -# Utility functions -# ----------------------------------------------------------------------------- -def _prompt(parent, message, buttons=('yes', 'no'), title='Question'): - buttons = [(button, 
getattr(QtGui.QMessageBox, button.capitalize())) - for button in buttons] +def _prompt(message, buttons=('yes', 'no'), title='Question'): + buttons = [(button, _button_enum_from_name(button)) for button in buttons] arg_buttons = 0 for (_, button) in buttons: arg_buttons |= button - reply = QtGui.QMessageBox.question(parent, - title, - message, - arg_buttons, - buttons[0][1], - ) - for name, button in buttons: - if reply == button: - return name + box = QMessageBox() + box.setWindowTitle(title) + box.setText(message) + box.setStandardButtons(arg_buttons) + box.setDefaultButton(buttons[0][1]) + return box -def _set_qt_widget_position_size(widget, position=None, size=None): - if position is not None: - widget.moveTo(*position) - if size is not None: - widget.resize(*size) +def _show_box(box): # pragma: no cover + return _button_name_from_enum(box.exec_()) -# ----------------------------------------------------------------------------- -# Event loop integration with IPython -# ----------------------------------------------------------------------------- +@contextmanager +def _wait_signal(signal, timeout=None): + """Block loop until signal emitted, or timeout (ms) elapses.""" + # http://jdreaver.com/posts/2014-07-03-waiting-for-signals-pyside-pyqt.html + loop = QEventLoop() + signal.connect(loop.quit) -_APP = None -_APP_RUNNING = False + yield + if timeout is not None: + QTimer.singleShot(timeout, loop.quit) + loop.exec_() -def _try_enable_ipython_qt(): - """Try to enable IPython Qt event loop integration. - Returns True in the following cases: - - * python -i test.py - * ipython -i test.py - * ipython and %run test.py +# ----------------------------------------------------------------------------- +# Qt app +# ----------------------------------------------------------------------------- - Returns False in the following cases: +def require_qt(func): + """Specify that a function requires a Qt application. - * python test.py - * ipython test.py + Use this decorator to specify that a function needs a running + Qt application before it can run. An error is raised if that is not + the case. """ - try: - from IPython import get_ipython - ip = get_ipython() - except ImportError: - return False - if not _is_interactive(): - return False - if ip: - ip.enable_gui('qt') - global _APP_RUNNING - _APP_RUNNING = True - return True - return False - - -def enable_qt(): - if not _check_qt(): - return - try: - from IPython import get_ipython - ip = get_ipython() - ip.enable_gui('qt') - global _APP_RUNNING - _APP_RUNNING = True - info("Qt event loop activated.") - except: - warn("Qt event loop not activated.") + @wraps(func) + def wrapped(*args, **kwargs): + if not QApplication.instance(): # pragma: no cover + raise RuntimeError("A Qt application must be created.") + return func(*args, **kwargs) + return wrapped -# ----------------------------------------------------------------------------- -# Qt app -# ----------------------------------------------------------------------------- +# Global variable with the current Qt application. +QT_APP = None -def start_qt_app(): - """Start a Qt application if necessary. - If a new Qt application is created, this function returns it. - If no new application is created, the function returns None. +def create_app(): + """Create a Qt application.""" + global QT_APP + QT_APP = QApplication.instance() + if QT_APP is None: # pragma: no cover + QT_APP = QApplication(sys.argv) + return QT_APP - """ - # Only start a Qt application if there is no - # IPython event loop integration. 
- if not _check_qt(): - return - global _APP - if _try_enable_ipython_qt(): - return - try: - from vispy import app - app.use_app("pyqt4") - except ImportError: - pass - if QtGui.QApplication.instance(): - _APP = QtGui.QApplication.instance() - return - if _APP: - return - _APP = QtGui.QApplication(sys.argv) - return _APP - - -def run_qt_app(): - """Start the Qt application's event loop.""" - global _APP_RUNNING - if not _check_qt(): - return - if _APP is not None and not _APP_RUNNING: - _APP_RUNNING = True - _APP.exec_() - if not _is_interactive(): - _APP_RUNNING = False - - -@contextlib.contextmanager -def qt_app(): - """Context manager to ensure that a Qt app is running.""" - if not _check_qt(): - return - app = start_qt_app() - yield app - run_qt_app() + +@require_qt +def run_app(): # pragma: no cover + """Run the Qt application.""" + global QT_APP + return QT_APP.exit(QT_APP.exec_()) # ----------------------------------------------------------------------------- # Testing utilities # ----------------------------------------------------------------------------- -def _close_qt_after(window, duration): - """Close a Qt window after a given duration.""" - def callback(): - window.close() - QtCore.QTimer.singleShot(int(1000 * duration), callback) - - -_MAX_ITER = 100 -_DELAY = max(0, float(os.environ.get('PHY_EVENT_LOOP_DELAY', .1))) - - -def _debug_trace(): +def _debug_trace(): # pragma: no cover """Set a tracepoint in the Python debugger that works with Qt.""" from PyQt4.QtCore import pyqtRemoveInputHook from pdb import set_trace diff --git a/phy/gui/static/__init__.py b/phy/gui/static/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/phy/gui/static/styles.css b/phy/gui/static/styles.css deleted file mode 100644 index e69de29bb..000000000 diff --git a/phy/gui/static/table.css b/phy/gui/static/table.css new file mode 100644 index 000000000..7e15b1c4d --- /dev/null +++ b/phy/gui/static/table.css @@ -0,0 +1,66 @@ + +th.sort-header::-moz-selection { background:transparent; } +th.sort-header::selection { background:transparent; } +th.sort-header { cursor:pointer; } + +table { + border-collapse: collapse; +} + +table td { + padding: 5px 10px; + margin: 0; +} + +table th.sort-header:after { + content: "\25B2"; + margin-left: 5px; + margin-right: 15px; + visibility: hidden; +} + +table th.sort-header:hover:after { + visibility: visible; +} + +table th.sort-up:after { + content: "\25BC"; +} +table th.sort-down:after { + content: "\25B2"; +} + +table th.sort-up:after, +table th.sort-down:after, +table th.sort-down:hover:after { + visibility: visible; +} + +table tr { cursor:pointer; } + +table tr:hover { + background-color: #222; +} + +table tr:hover th { + background-color: #000; +} + +table tr.selected td { + background-color: #444; +} + +table tr.pinned { + background-color: #888; +} + +table tr[data-skip='true'] { + color: #888; +} + +table tr, table td { + user-select: none; /* CSS3 (little to no support) */ + -ms-user-select: none; /* IE 10+ */ + -moz-user-select: none; /* Gecko (Firefox) */ + -webkit-user-select: none; /* Webkit (Safari, Chrome) */ +} diff --git a/phy/gui/static/table.js b/phy/gui/static/table.js new file mode 100644 index 000000000..f2654bdbb --- /dev/null +++ b/phy/gui/static/table.js @@ -0,0 +1,267 @@ +// Utils. + +function uniq(a) { + var seen = {}; + return a.filter(function(item) { + return seen.hasOwnProperty(item) ? 
false : (seen[item] = true); + }); +} + +function isFloat(n) { + return n === Number(n) && n % 1 !== 0; +} + +function clear(e) { + while (e.firstChild) { + e.removeChild(e.firstChild); + } +} + + +// Table class. +var Table = function (el) { + this.el = el; + this.selected = []; + this.headers = {}; // {name: th} mapping + this.rows = {}; // {id: tr} mapping + this.cols = []; + + var thead = document.createElement("thead"); + this.el.appendChild(thead); + + var tbody = document.createElement("tbody"); + this.el.appendChild(tbody); +}; + +Table.prototype.setHeaders = function(data) { + this.rows = {}; + + var that = this; + var keys = data.cols; + this.cols = data.cols; + + var thead = this.el.getElementsByTagName("thead")[0]; + clear(thead); + + // Header. + var tr = document.createElement("tr"); + for (var j = 0; j < keys.length; j++) { + var key = keys[j]; + var th = document.createElement("th"); + th.appendChild(document.createTextNode(key)); + tr.appendChild(th); + this.headers[key] = th; + } + thead.appendChild(tr); + + // Enable the tablesort plugin. + this.tablesort = new Tablesort(this.el); +} + +Table.prototype.setData = function(data) { + /* + data.cols: list of column names + data.items: list of rows (each row is an object {col: value}) + */ + // if (data.items.length == 0) return; + + // Reinitialize the state. + this.selected = []; + this.rows = {}; + var keys = data.cols; + var that = this; + + // Clear the table body. + var tbody = this.el.getElementsByTagName("tbody")[0]; + clear(tbody); + this.nrows = data.items.length; + + // Data rows. + for (var i = 0; i < data.items.length; i++) { + tr = document.createElement("tr"); + var row = data.items[i]; + for (var j = 0; j < keys.length; j++) { + var key = keys[j]; + var value = row[key]; + // Format numbers. + if (isFloat(value)) + value = value.toPrecision(3); + var td = document.createElement("td"); + td.appendChild(document.createTextNode(value)); + tr.appendChild(td); + } + + // Set the data values on the row. + for (var key in row) { + tr.dataset[key] = row[key]; + } + + tr.onclick = function(e) { + var id = parseInt(String(this.dataset.id)); + var evt = e ? e:window.event; + // Control pressed: toggle selected. + if (evt.ctrlKey || evt.metaKey) { + var index = that.selected.indexOf(id); + // If this item is already selected, deselect it. + if (index != -1) { + var selected = that.selected.slice(); + selected.splice(index, 1); + that.select(selected); + } + // Otherwise, select it. + else { + that.select(that.selected.concat([id])); + } + } + else if (evt.shiftKey) { + var clicked_idx = that.rows[id].rowIndex; + var sel_idx = that.rows[that.selected[0]].rowIndex; + if (sel_idx == undefined) return; + var i0 = Math.min(clicked_idx, sel_idx); + var i1 = Math.max(clicked_idx, sel_idx); + var sel = []; + for (var i = i0; i <= i1; i++) { + sel.push(that.el.rows[i].dataset.id); + } + that.select(sel); + } + // Otherwise, select just that item. + else { + that.select([id]); + } + } + + tbody.appendChild(tr); + this.rows[data.items[i].id] = tr; + } +}; + +Table.prototype.rowId = function(i) { + return this.el.rows[i].dataset.id; +}; + +Table.prototype.isRowSkipped = function(i) { + return this.el.rows[i].dataset.skip == 'true'; +}; + +Table.prototype.sortBy = function(header, dir) { + dir = typeof dir !== 'undefined' ? dir : 'asc'; + if (this.headers[header] == undefined) + throw "The column `" + header + "` doesn't exist." + + // Remove all sort classes. 
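+    // Only the column being sorted should keep a sort-up/sort-down class, so clear it on every header first.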
+ for (var i = 0; i < this.cols.length; i++) { + var name = this.cols[i]; + this.headers[name].classList.remove("sort-up"); + this.headers[name].classList.remove("sort-down"); + } + + var order = (dir == 'asc') ? "sort-up" : "sort-down"; + this.headers[header].classList.add(order); + + // Add sort. + this.tablesort.sortTable(this.headers[header]); +}; + +Table.prototype.currentSort = function() { + for (var header in this.headers) { + if (this.headers[header].classList.contains('sort-up')) { + return [header, 'desc']; + } + if (this.headers[header].classList.contains('sort-down')) { + return [header, 'asc']; + } + } + return [null, null]; +}; + +Table.prototype.select = function(ids, do_emit) { + do_emit = typeof do_emit !== 'undefined' ? do_emit : true; + + ids = uniq(ids); + + // Remove the class on all rows. + for (var i = 0; i < this.selected.length; i++) { + var id = this.selected[i]; + var row = this.rows[id]; + row.classList.remove('selected'); + } + + // Add the class. + for (var i = 0; i < ids.length; i++) { + ids[i] = parseInt(String(ids[i])); + this.rows[ids[i]].classList.add('selected'); + } + + this.selected = ids; + + if (do_emit) + emit("select", ids); +}; + +Table.prototype.clear = function() { + this.selected = []; +}; + +Table.prototype.firstRow = function() { + return this.el.rows[1]; +}; + +Table.prototype.lastRow = function() { + return this.el.rows[this.el.rows.length - 1]; +}; + +Table.prototype.rowIterator = function(id, doSkip) { + doSkip = typeof doSkip !== 'undefined' ? doSkip : true; + // TODO: what to do when doing next() while several items are selected. + var i0 = undefined; + if (id !== undefined) { + i0 = this.rows[id].rowIndex; + } + var that = this; + return { + i: i0, + n: that.el.rows.length, + row: function () { return that.el.rows[this.i]; }, + previous: function () { + if (this.i == undefined) this.i = this.n; + for (var i = this.i - 1; i >= 1; i--) { + if (!doSkip || !that.isRowSkipped(i)) { + this.i = i; + return this.row(); + } + } + return this.row(); + }, + next: function () { + if (this.i == undefined) this.i = 0; + for (var i = this.i + 1; i < this.n; i++) { + if (!doSkip || !that.isRowSkipped(i)) { + this.i = i; + return this.row(); + } + } + return this.row(); + } + }; +}; + +Table.prototype.next = function() { + // TODO: what to do when doing next() while several items are selected. + var id = this.selected[0]; + var iterator = this.rowIterator(id); + var row = iterator.next(); + this.select([row.dataset.id]); + row.scrollIntoView(false); + return; +}; + +Table.prototype.previous = function() { + // TODO: what to do when doing previous() while several items are selected. + var id = this.selected[0]; + var iterator = this.rowIterator(id); + var row = iterator.previous(); + this.select([row.dataset.id]); + row.scrollIntoView(false); + return; +}; diff --git a/phy/gui/static/tablesort.min.js b/phy/gui/static/tablesort.min.js new file mode 100644 index 000000000..0d1fdb157 --- /dev/null +++ b/phy/gui/static/tablesort.min.js @@ -0,0 +1,5 @@ +/*! 
+ * tablesort v4.0.0 (2015-12-17) + * http://tristen.ca/tablesort/demo/ + * Copyright (c) 2015 ; Licensed MIT +*/!function(){function a(b,c){if(!(this instanceof a))return new a(b,c);if(!b||"TABLE"!==b.tagName)throw new Error("Element must be a table");this.init(b,c||{})}var b=[],c=function(a){var b;return window.CustomEvent&&"function"==typeof window.CustomEvent?b=new CustomEvent(a):(b=document.createEvent("CustomEvent"),b.initCustomEvent(a,!1,!1,void 0)),b},d=function(a){return a.getAttribute("data-sort")||a.textContent||a.innerText||""},e=function(a,b){return a=a.toLowerCase(),b=b.toLowerCase(),a===b?0:b>a?1:-1},f=function(a,b){return function(c,d){var e=a(c.td,d.td);return 0===e?b?d.index-c.index:c.index-d.index:e}};a.extend=function(a,c,d){if("function"!=typeof c||"function"!=typeof d)throw new Error("Pattern and sort must be a function");b.push({name:a,pattern:c,sort:d})},a.prototype={init:function(a,b){var c,d,e,f,g=this;if(g.table=a,g.thead=!1,g.options=b,a.rows&&a.rows.length>0&&(a.tHead&&a.tHead.rows.length>0?(c=a.tHead.rows[a.tHead.rows.length-1],g.thead=!0):c=a.rows[0]),c){var h=function(){g.current&&g.current!==this&&(g.current.classList.remove("sort-up"),g.current.classList.remove("sort-down")),g.current=this,g.sortTable(this)};for(e=0;e0&&m.push(l),n++;if(!m)return}for(n=0;nn;n++)s[n]?(l=s[n],u++):l=r[n-u].tr,i.table.tBodies[0].appendChild(l);i.table.dispatchEvent(c("afterSort"))}},refresh:function(){void 0!==this.current&&this.sortTable(this.current,!0)}},"undefined"!=typeof module&&module.exports?module.exports=a:window.Tablesort=a}(); \ No newline at end of file diff --git a/phy/gui/static/tablesort.number.js b/phy/gui/static/tablesort.number.js new file mode 100644 index 000000000..d43405fb8 --- /dev/null +++ b/phy/gui/static/tablesort.number.js @@ -0,0 +1,26 @@ +(function(){ + var cleanNumber = function(i) { + return i.replace(/[^\-\+eE\,?0-9\.]/g, ''); + }, + + compareNumber = function(a, b) { + a = parseFloat(a); + b = parseFloat(b); + + a = isNaN(a) ? 0 : a; + b = isNaN(b) ? 
0 : b; + + return a - b; + }; + + Tablesort.extend('number', function(item) { + return item.match(/^-?[£\x24Û¢´€]?\d+\s*([,\.]\d{0,2})/) || // Prefixed currency + item.match(/^-?\d+\s*([,\.]\d{0,2})?[£\x24Û¢´€]/) || // Suffixed currency + item.match(/^-?(\d)*-?([,\.]){0,1}-?(\d)+([E,e][\-+][\d]+)?%?$/); // Number + }, function(a, b) { + a = cleanNumber(a); + b = cleanNumber(b); + + return compareNumber(b, a); + }); +}()); diff --git a/phy/gui/static/wrap_qt.html b/phy/gui/static/wrap_qt.html deleted file mode 100644 index fcad69ee7..000000000 --- a/phy/gui/static/wrap_qt.html +++ /dev/null @@ -1,23 +0,0 @@ - - - - - - - - - -%HTML% - - - diff --git a/phy/gui/tests/conftest.py b/phy/gui/tests/conftest.py new file mode 100644 index 000000000..0ad5a9d9f --- /dev/null +++ b/phy/gui/tests/conftest.py @@ -0,0 +1,33 @@ +# -*- coding: utf-8 -*- + +"""Test gui.""" + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +from pytest import yield_fixture + +from ..actions import Actions, Snippets +from ..gui import GUI + + +#------------------------------------------------------------------------------ +# Utilities and fixtures +#------------------------------------------------------------------------------ + +@yield_fixture +def gui(tempdir, qapp): + gui = GUI(position=(200, 100), size=(100, 100), config_dir=tempdir) + yield gui + gui.close() + + +@yield_fixture +def actions(gui): + yield Actions(gui) + + +@yield_fixture +def snippets(gui): + yield Snippets(gui) diff --git a/phy/gui/tests/test_actions.py b/phy/gui/tests/test_actions.py new file mode 100644 index 000000000..4077e6595 --- /dev/null +++ b/phy/gui/tests/test_actions.py @@ -0,0 +1,327 @@ +# -*- coding: utf-8 -*- + +"""Test dock.""" + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +from pytest import raises + +from ..actions import (_show_shortcuts, + _get_shortcut_string, + _get_qkeysequence, + _parse_snippet, + Actions, + ) +from phy.utils.testing import captured_output, captured_logging + + +#------------------------------------------------------------------------------ +# Test actions +#------------------------------------------------------------------------------ + +def test_shortcuts(qapp): + assert 'z' in _get_shortcut_string('Undo') + + def _assert_shortcut(name, key=None): + shortcut = _get_qkeysequence(name) + s = _get_shortcut_string(shortcut) + if key is None: + assert s == s + else: + assert key in s + + _assert_shortcut('Undo', 'z') + _assert_shortcut('Save', 's') + _assert_shortcut('q') + _assert_shortcut('ctrl+q') + _assert_shortcut(':') + _assert_shortcut(['ctrl+a', 'shift+b']) + + +def test_show_shortcuts(qapp): + # NOTE: a Qt application needs to be running so that we can use the + # KeySequence. 
+ shortcuts = { + 'test_1': 'ctrl+t', + 'test_2': ('ctrl+a', 'shift+b'), + 'test_3': 'ctrl+z', + } + with captured_output() as (stdout, stderr): + _show_shortcuts(shortcuts, 'test') + assert 'ctrl+a, shift+b' in stdout.getvalue() + assert 'ctrl+z' in stdout.getvalue() + + +def test_actions_default_shortcuts(gui): + actions = Actions(gui, default_shortcuts={'my_action': 'a'}) + actions.add(lambda: None, name='my_action') + assert actions.shortcuts['my_action'] == 'a' + + +def test_actions_simple(actions): + + _res = [] + + def _action(*args): + _res.append(args) + + actions.add(_action, 'tes&t') + # Adding an action twice has no effect. + actions.add(_action, 'test') + + # Create a shortcut and display it. + _captured = [] + + @actions.add(shortcut='h') + def show_my_shortcuts(): + with captured_output() as (stdout, stderr): + actions.show_shortcuts() + _captured.append(stdout.getvalue()) + + actions.show_my_shortcuts() + assert 'show_my_shortcuts' in _captured[0] + assert ': h' in _captured[0] + + actions.run('t', 1) + assert _res == [(1,)] + + assert 'show_my_shortcuts' in actions + assert 'unknown' not in actions + + actions.remove_all() + + assert '' + snippets.mode_on() # ':' + snippets.actions._snippet_backspace() + _run('t3 hello') + snippets.actions._snippet_activate() # 'Enter' + assert _actions[-1] == (3, ('hello',)) + snippets.mode_off() + + +def test_snippets_actions_2(actions, snippets): + + _actions = [] + + @actions.add + def test(arg): + _actions.append(arg) + + actions.test(1) + assert _actions == [1] + + snippets.mode_on() + snippets.mode_off() + + actions.test(2) + assert _actions == [1, 2] diff --git a/phy/gui/tests/test_base.py b/phy/gui/tests/test_base.py deleted file mode 100644 index ea606eedc..000000000 --- a/phy/gui/tests/test_base.py +++ /dev/null @@ -1,289 +0,0 @@ -# -*- coding: utf-8 -*-1 -from __future__ import print_function - -"""Tests of base classes.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import os.path as op - -from pytest import raises, mark - -from ..base import (BaseViewModel, - HTMLViewModel, - WidgetCreator, - BaseGUI, - ) -from ..qt import (QtGui, - _set_qt_widget_position_size, - ) -from ...utils.event import EventEmitter -from ...utils.logging import set_level -from ...io.base import BaseModel, BaseSession - - -# Skip these tests in "make test-quick". -pytestmark = mark.long() - - -#------------------------------------------------------------------------------ -# Base tests -#------------------------------------------------------------------------------ - -def setup(): - set_level('debug') - - -def teardown(): - set_level('info') - - -def test_base_view_model(qtbot): - class MyViewModel(BaseViewModel): - _view_name = 'main_window' - _imported_params = ('text',) - - def _create_view(self, text='', position=None, size=None): - view = QtGui.QMainWindow() - view.setWindowTitle(text) - _set_qt_widget_position_size(view, - position=position, - size=size, - ) - return view - - size = (400, 100) - text = 'hello' - - vm = MyViewModel(text=text, size=size) - qtbot.addWidget(vm.view) - vm.show() - - assert vm.view.windowTitle() == text - assert vm.text == text - assert vm.size == size - assert (vm.view.width(), vm.view.height()) == size - - vm.close() - - -def test_html_view_model(qtbot): - - class MyHTMLViewModel(HTMLViewModel): - def get_html(self, **kwargs): - return 'hello world!' 
- - vm = MyHTMLViewModel() - vm.show() - qtbot.addWidget(vm.view) - vm.close() - - -def test_widget_creator(): - - class MyWidget(EventEmitter): - """Mock widget.""" - def __init__(self, param=None): - super(MyWidget, self).__init__() - self.name = 'my_widget' - self._shown = False - self.param = param - - @property - def shown(self): - return self._shown - - def close(self, e=None): - self.emit('close', e) - self._shown = False - - def show(self): - self._shown = True - - widget_classes = {'my_widget': MyWidget} - - wc = WidgetCreator(widget_classes=widget_classes) - assert not wc.get() - assert not wc.get('my_widget') - - with raises(ValueError): - wc.add('my_widget_bis') - - for show in (False, True): - w = wc.add('my_widget', show=show, param=show) - assert len(wc.get()) == 1 - assert len(wc.get('my_widget')) == 1 - - assert w.shown is show - assert w.param is show - w.show() - assert w.shown - - w.close() - assert not wc.get() - assert not wc.get('my_widget') - - -def test_base_gui(qtbot): - - class V1(HTMLViewModel): - def get_html(self, **kwargs): - return 'view 1' - - class V2(HTMLViewModel): - def get_html(self, **kwargs): - return 'view 2' - - class V3(HTMLViewModel): - def get_html(self, **kwargs): - return 'view 3' - - vm_classes = {'v1': V1, 'v2': V2, 'v3': V3} - - config = [('v1', {'position': 'right'}), - ('v2', {'position': 'left'}), - ('v2', {'position': 'bottom'}), - ('v3', {'position': 'left'}), - ] - - shortcuts = {'test': 't'} - - _message = [] - - def _snippet(gui, args): - _message.append(args) - - snippets = {'hello': _snippet} - - class TestGUI(BaseGUI): - def __init__(self): - super(TestGUI, self).__init__(vm_classes=vm_classes, - config=config, - shortcuts=shortcuts, - snippets=snippets, - ) - - def _create_actions(self): - self._add_gui_shortcut('test') - - def test(self): - self.show_shortcuts() - self.reset_gui() - - gui = TestGUI() - qtbot.addWidget(gui.main_window) - gui.show() - - # Test snippet mode. - gui.enable_snippet_mode() - - def _keystroke(char=None): - """Simulate a keystroke.""" - i = gui._snippet_action_name(char) - getattr(gui.main_window, 'snippet_{}'.format(i))() - - gui.enable_snippet_mode() - for c in 'hello world': - _keystroke(c) - assert gui.status_message == ':hello world' + gui._snippet_message_cursor - gui.main_window.snippet_activate() - assert _message == ['world'] - - # Test views. 
- v2 = gui.get_views('v2') - assert len(v2) == 2 - v2[1].close() - - v3 = gui.get_views('v3') - v3[0].close() - - gui.reset_gui() - - gui.close() - - -def test_base_session(tempdir, qtbot): - - phy_dir = op.join(tempdir, 'test.phy') - - model = BaseModel() - - class V1(HTMLViewModel): - def get_html(self, **kwargs): - return 'view 1' - - class V2(HTMLViewModel): - def get_html(self, **kwargs): - return 'view 2' - - vm_classes = {'v1': V1, 'v2': V2} - - config = [('v1', {'position': 'right'}), - ('v2', {'position': 'left'}), - ('v2', {'position': 'bottom'}), - ] - - shortcuts = {'test': 't', 'exit': 'ctrl+q'} - - class TestGUI(BaseGUI): - def __init__(self, **kwargs): - super(TestGUI, self).__init__(vm_classes=vm_classes, - **kwargs) - self.on_open() - - def _create_actions(self): - self._add_gui_shortcut('test') - self._add_gui_shortcut('exit') - - def test(self): - self.show_shortcuts() - self.reset_gui() - - gui_classes = {'gui': TestGUI} - - default_settings_path = op.join(tempdir, 'default_settings.py') - - with open(default_settings_path, 'w') as f: - f.write("gui_config = {}\n".format(str(config)) + - "gui_shortcuts = {}".format(str(shortcuts))) - - session = BaseSession(model=model, - phy_user_dir=phy_dir, - default_settings_paths=[default_settings_path], - vm_classes=vm_classes, - gui_classes=gui_classes, - ) - - # New GUI. - gui = session.show_gui('gui') - qtbot.addWidget(gui.main_window) - qtbot.waitForWindowShown(gui.main_window) - - # Remove a v2 view. - v2 = gui.get_views('v2') - assert len(v2) == 2 - v2[0].close() - gui.close() - - # Reopen and check that the v2 is gone. - gui = session.show_gui('gui') - qtbot.addWidget(gui.main_window) - qtbot.waitForWindowShown(gui.main_window) - - v2 = gui.get_views('v2') - assert len(v2) == 1 - - gui.reset_gui() - v2 = gui.get_views('v2') - assert len(v2) == 2 - gui.close() - - gui = session.show_gui('gui') - qtbot.addWidget(gui.main_window) - qtbot.waitForWindowShown(gui.main_window) - - v2 = gui.get_views('v2') - assert len(v2) == 2 - gui.close() diff --git a/phy/gui/tests/test_dock.py b/phy/gui/tests/test_dock.py deleted file mode 100644 index 8ee2ce02f..000000000 --- a/phy/gui/tests/test_dock.py +++ /dev/null @@ -1,117 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Test dock.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -from pytest import mark - -from vispy import app - -from ..dock import DockWindow -from ...utils._color import _random_color -from ...utils.logging import set_level - - -# Skip these tests in "make test-quick". 
-pytestmark = mark.long - - -#------------------------------------------------------------------------------ -# Tests -#------------------------------------------------------------------------------ - -def setup(): - set_level('debug') - - -def teardown(): - set_level('info') - - -def _create_canvas(): - """Create a VisPy canvas with a color background.""" - c = app.Canvas() - c.color = _random_color() - - @c.connect - def on_draw(e): - c.context.clear(c.color) - - @c.connect - def on_key_press(e): - c.color = _random_color() - c.update() - - return c - - -def test_dock_1(qtbot): - - gui = DockWindow() - qtbot.addWidget(gui) - - @gui.shortcut('quit', 'ctrl+q') - def quit(): - gui.close() - - gui.add_view(_create_canvas(), 'view1') - gui.add_view(_create_canvas(), 'view2') - gui.show() - - assert len(gui.list_views('view')) == 2 - gui.close() - - -def test_dock_status_message(qtbot): - gui = DockWindow() - qtbot.addWidget(gui) - assert gui.status_message == '' - gui.status_message = ':hello world!' - assert gui.status_message == ':hello world!' - - -def test_dock_state(qtbot): - _gs = None - gui = DockWindow() - qtbot.addWidget(gui) - - @gui.shortcut('press', 'ctrl+g') - def press(): - pass - - gui.add_view(_create_canvas(), 'view1') - gui.add_view(_create_canvas(), 'view2') - gui.add_view(_create_canvas(), 'view2') - - @gui.connect_ - def on_close_gui(): - global _gs - _gs = gui.save_geometry_state() - - gui.show() - - assert len(gui.list_views('view')) == 3 - assert gui.view_count() == { - 'view1': 1, - 'view2': 2, - } - - gui.close() - - # Recreate the GUI with the saved state. - gui = DockWindow() - - gui.add_view(_create_canvas(), 'view1') - gui.add_view(_create_canvas(), 'view2') - gui.add_view(_create_canvas(), 'view2') - - @gui.connect_ - def on_show(): - print(_gs) - gui.restore_geometry_state(_gs) - - gui.show() - gui.close() diff --git a/phy/gui/tests/test_gui.py b/phy/gui/tests/test_gui.py new file mode 100644 index 000000000..f41b8773f --- /dev/null +++ b/phy/gui/tests/test_gui.py @@ -0,0 +1,178 @@ +# -*- coding: utf-8 -*- + +"""Test gui.""" + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +from pytest import raises + +from ..qt import Qt, QApplication, QWidget +from ..gui import (GUI, GUIState, + _try_get_matplotlib_canvas, + _try_get_vispy_canvas, + ) +from phy.utils import Bunch +from phy.utils._color import _random_color + + +#------------------------------------------------------------------------------ +# Utilities and fixtures +#------------------------------------------------------------------------------ + +def _create_canvas(): + """Create a VisPy canvas with a color background.""" + from vispy import app + c = app.Canvas() + c.color = _random_color() + + @c.connect + def on_draw(e): # pragma: no cover + c.context.clear(c.color) + + return c + + +#------------------------------------------------------------------------------ +# Test views +#------------------------------------------------------------------------------ + +def test_vispy_view(): + from vispy.app import Canvas + assert isinstance(_try_get_vispy_canvas(Canvas()), QWidget) + + +def test_matplotlib_view(): + from matplotlib.pyplot import Figure + assert isinstance(_try_get_matplotlib_canvas(Figure()), QWidget) + + +#------------------------------------------------------------------------------ +# Test GUI 
+#------------------------------------------------------------------------------ + +def test_gui_noapp(tempdir): + if not QApplication.instance(): + with raises(RuntimeError): # pragma: no cover + GUI(config_dir=tempdir) + + +def test_gui_1(tempdir, qtbot): + + gui = GUI(position=(200, 100), size=(100, 100), config_dir=tempdir) + qtbot.addWidget(gui) + + assert gui.name == 'GUI' + + # Increase coverage. + @gui.connect_ + def on_show(): + pass + gui.unconnect_(on_show) + qtbot.keyPress(gui, Qt.Key_Control) + qtbot.keyRelease(gui, Qt.Key_Control) + + view = gui.add_view(_create_canvas(), floating=True, closable=True) + gui.add_view(_create_canvas()) + view.setFloating(False) + gui.show() + + assert len(gui.list_views('Canvas')) == 2 + + # Check that the close_widget event is fired when the gui widget is + # closed. + _close = [] + + @view.connect_ + def on_close_widget(): + _close.append(0) + + @gui.connect_ + def on_close_view(view): + _close.append(1) + + view.close() + assert _close == [1, 0] + + gui.close() + + assert gui.state.geometry_state['geometry'] + assert gui.state.geometry_state['state'] + + gui.default_actions.exit() + + +def test_gui_status_message(gui): + assert gui.status_message == '' + gui.status_message = ':hello world!' + assert gui.status_message == ':hello world!' + + gui.lock_status() + gui.status_message = '' + assert gui.status_message == ':hello world!' + gui.unlock_status() + gui.status_message = '' + assert gui.status_message == '' + + +def test_gui_geometry_state(tempdir, qtbot): + _gs = [] + gui = GUI(size=(100, 100), config_dir=tempdir) + qtbot.addWidget(gui) + + gui.add_view(_create_canvas(), 'view1') + gui.add_view(_create_canvas(), 'view2') + gui.add_view(_create_canvas(), 'view2') + + @gui.connect_ + def on_close(): + _gs.append(gui.save_geometry_state()) + + gui.show() + qtbot.waitForWindowShown(gui) + + assert len(gui.list_views('view')) == 3 + assert gui.view_count() == { + 'view1': 1, + 'view2': 2, + } + + gui.close() + + # Recreate the GUI with the saved state. 
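+    # The saved geometry can only be restored once the views have been re-added, hence the `on_show` callback below.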
+ gui = GUI(config_dir=tempdir) + + gui.add_view(_create_canvas(), 'view1') + gui.add_view(_create_canvas(), 'view2') + gui.add_view(_create_canvas(), 'view2') + + @gui.connect_ + def on_show(): + gui.restore_geometry_state(_gs[0]) + + assert gui.restore_geometry_state(None) is None + + qtbot.addWidget(gui) + gui.show() + + assert len(gui.list_views('view')) == 3 + assert gui.view_count() == { + 'view1': 1, + 'view2': 2, + } + + gui.close() + + +#------------------------------------------------------------------------------ +# Test GUI state +#------------------------------------------------------------------------------ + +def test_gui_state_view(tempdir): + view = Bunch(name='MyView0') + state = GUIState(config_dir=tempdir) + state.update_view_state(view, dict(hello='world')) + assert not state.get_view_state(Bunch(name='MyView')) + assert not state.get_view_state(Bunch(name='MyView1')) + assert state.get_view_state(view) == Bunch(hello='world') diff --git a/phy/gui/tests/test_qt.py b/phy/gui/tests/test_qt.py index d147a610e..3e12053ad 100644 --- a/phy/gui/tests/test_qt.py +++ b/phy/gui/tests/test_qt.py @@ -6,41 +6,79 @@ # Imports #------------------------------------------------------------------------------ -from pytest import mark +from pytest import raises -from ..qt import (QtWebKit, QtGui, - qt_app, - _set_qt_widget_position_size, +from ..qt import (QMessageBox, Qt, QWebView, QTimer, + _button_name_from_enum, + _button_enum_from_name, _prompt, + _wait_signal, + require_qt, + create_app, + QApplication, ) -from ...utils.logging import set_level - - -# Skip these tests in "make test-quick". -pytestmark = mark.long #------------------------------------------------------------------------------ # Tests #------------------------------------------------------------------------------ -def setup(): - set_level('debug') +def test_require_qt_with_app(): + + @require_qt + def f(): + pass + + if not QApplication.instance(): + with raises(RuntimeError): # pragma: no cover + f() + +def test_require_qt_without_app(qapp): -def teardown(): - set_level('info') + @require_qt + def f(): + pass + + # This should not raise an error. 
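+    # (The `qapp` fixture provides a running QApplication, so the `require_qt` check passes.)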
+ f() + + +def test_qt_app(qtbot): + create_app() + view = QWebView() + qtbot.addWidget(view) + view.close() -def test_wrap(qtbot): +def test_wait_signal(qtbot): + x = [] - view = QtWebKit.QWebView() + def f(): + x.append(0) + + timer = QTimer() + timer.setInterval(100) + timer.setSingleShot(True) + timer.timeout.connect(f) + timer.start() + + assert x == [] + + with _wait_signal(timer.timeout): + pass + assert x == [0] + + +def test_web_view(qtbot): + + view = QWebView() def _assert(text): html = view.page().mainFrame().toHtml() assert html == '' + text + '' - _set_qt_widget_position_size(view, size=(100, 100)) + view.resize(100, 100) view.setHtml("hello") qtbot.addWidget(view) qtbot.waitForWindowShown(view) @@ -51,8 +89,8 @@ def _assert(text): _assert('world') view.close() - view = QtWebKit.QWebView() - _set_qt_widget_position_size(view, size=(100, 100)) + view = QWebView() + view.resize(100, 100) view.show() qtbot.addWidget(view) @@ -61,13 +99,13 @@ def _assert(text): view.close() -@mark.skipif() -def test_prompt(): - with qt_app(): - w = QtGui.QWidget() - w.show() - result = _prompt(w, - "How are you doing?", - buttons=['save', 'cancel', 'close'], - ) - print(result) +def test_prompt(qtbot): + + assert _button_name_from_enum(QMessageBox.Save) == 'save' + assert _button_enum_from_name('save') == QMessageBox.Save + + box = _prompt("How are you doing?", + buttons=['save', 'cancel', 'close'], + ) + qtbot.mouseClick(box.buttons()[0], Qt.LeftButton) + assert 'save' in str(box.clickedButton().text()).lower() diff --git a/phy/gui/tests/test_widgets.py b/phy/gui/tests/test_widgets.py new file mode 100644 index 000000000..fc3601241 --- /dev/null +++ b/phy/gui/tests/test_widgets.py @@ -0,0 +1,189 @@ +# -*- coding: utf-8 -*- + +"""Test widgets.""" + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +from pytest import yield_fixture, raises + +from ..widgets import HTMLWidget, Table + + +#------------------------------------------------------------------------------ +# Fixtures +#------------------------------------------------------------------------------ + +@yield_fixture +def table(qtbot): + table = Table() + table.show() + # qtbot.waitForWindowShown(table) + + def count(id): + return 10000.5 - 10 * id + table.add_column(count, show=True) + + def skip(id): + return id == 4 + table.add_column(skip) + + table.set_rows(range(10)) + + yield table + + table.close() + + +#------------------------------------------------------------------------------ +# Test widgets +#------------------------------------------------------------------------------ + +def test_widget_empty(qtbot): + widget = HTMLWidget() + widget.show() + # qtbot.waitForWindowShown(widget) + # qtbot.stop() + + +def test_widget_html(qtbot): + widget = HTMLWidget() + widget.add_styles('html, body, p {background-color: purple;}') + widget.add_header('') + widget.set_body('Hello world!') + widget.show() + # qtbot.waitForWindowShown(widget) + assert 'Hello world!' 
in widget.html() + + +def test_widget_javascript_1(qtbot): + widget = HTMLWidget() + widget.eval_js('number = 1;') + widget.show() + # qtbot.waitForWindowShown(widget) + + assert widget.eval_js('number') == 1 + + +def test_widget_javascript_2(qtbot): + widget = HTMLWidget() + widget.show() + # qtbot.waitForWindowShown(widget) + _out = [] + + @widget.connect_ + def on_test(arg): + _out.append(arg) + + widget.eval_js('emit("test", [1, 2]);') + assert _out == [[1, 2]] + + widget.unconnect_(on_test) + + # qtbot.stop() + + +#------------------------------------------------------------------------------ +# Test table +#------------------------------------------------------------------------------ + +def test_table_current_sort(): + assert Table().current_sort == (None, None) + + +def test_table_default_sort(qtbot): + table = Table() + table.show() + # qtbot.waitForWindowShown(table) + + with raises(ValueError): + table.add_column(lambda _: _) + + def count(id): + return 10000.5 - 10 * id + table.add_column(count) + table.set_default_sort('count', 'asc') + table.set_rows(range(10)) + + assert table.default_sort == ('count', 'asc') + table.next() + assert table.selected == [9] + + table.sort_by('id', 'desc') + table.set_rows(range(11)) + table.next() + assert table.selected == [10] + + table.close() + + +def test_table_duplicates(qtbot, table): + assert table.default_sort == (None, None) + + table.select([1, 1]) + assert table.selected == [1] + # qtbot.stop() + + +def test_table_nav_first(qtbot, table): + table.next() + assert table.selected == [0] + + +def test_table_nav_last(qtbot, table): + table.previous() + assert table.selected == [9] + + +def test_table_nav_edge_0(qtbot, table): + # The first item is skipped. + table.set_rows([4, 5]) + table.next() + assert table.selected == [5] + + +def test_table_nav_edge_1(qtbot, table): + # The last item is skipped. + table.set_rows([3, 4]) + table.previous() + assert table.selected == [3] + + +def test_table_nav_0(qtbot, table): + table.select([4]) + + table.next() + assert table.selected == [5] + + table.previous() + assert table.selected == [3] + + _sel = [] + + @table.connect_ + def on_select(items): + _sel.append(items) + + table.eval_js('table.select([1]);') + assert _sel == [[1]] + + assert table.selected == [1] + + # qtbot.stop() + + +def test_table_sort(qtbot, table): + table.select([1]) + + # Sort by count decreasing, and check that 0 (count 100) comes before + # 1 (count 90). This checks that sorting works with number (need to + # import tablesort.number.js). 
+ table.sort_by('count', 'desc') + + table.previous() + assert table.selected == [0] + + assert table.current_sort == ('count', 'desc') + + # qtbot.stop() diff --git a/phy/gui/widgets.py b/phy/gui/widgets.py new file mode 100644 index 000000000..0379ce1cc --- /dev/null +++ b/phy/gui/widgets.py @@ -0,0 +1,339 @@ +# -*- coding: utf-8 -*- + +"""HTML widgets for GUIs.""" + + +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- + +from collections import OrderedDict +import json +import logging +import os.path as op + +from six import text_type + +from .qt import (QWebView, QWebPage, QUrl, QWebSettings, + QVariant, QPyNullVariant, QString, + pyqtSlot, _wait_signal, + ) +from phy.utils import EventEmitter +from phy.utils._misc import _CustomEncoder + +logger = logging.getLogger(__name__) + + +# ----------------------------------------------------------------------------- +# HTML widget +# ----------------------------------------------------------------------------- + +_DEFAULT_STYLES = """ + html, body, table { + background-color: black; + color: white; + font-family: sans-serif; + font-size: 12pt; + margin: 5px 10px; + } +""" + + +_PAGE_TEMPLATE = """ + + + {title:s} + + {header:s} + + + +{body:s} + + + +""" + + +class WebPage(QWebPage): + def javaScriptConsoleMessage(self, msg, line, source): + logger.debug("[%d] %s", line, msg) # pragma: no cover + + +def _to_py(obj): # pragma: no cover + if isinstance(obj, QVariant): + return obj.toPyObject() + elif QString and isinstance(obj, QString): + return text_type(obj) + elif isinstance(obj, QPyNullVariant): + return None + elif isinstance(obj, list): + return [_to_py(_) for _ in obj] + elif isinstance(obj, tuple): + return tuple(_to_py(_) for _ in obj) + else: + return obj + + +class HTMLWidget(QWebView): + """An HTML widget that is displayed with Qt. + + Python methods can be called from Javascript with `widget.the_method()`. + They must be decorated with `pyqtSlot(str)` or similar, depending on + the parameters. 
+ + """ + title = 'Widget' + body = '' + + def __init__(self): + super(HTMLWidget, self).__init__() + self.settings().setAttribute( + QWebSettings.LocalContentCanAccessRemoteUrls, True) + self.settings().setAttribute( + QWebSettings.DeveloperExtrasEnabled, True) + self.setPage(WebPage()) + self._obj = None + self._styles = [_DEFAULT_STYLES] + self._header = '' + self._body = '' + self.add_to_js('widget', self) + self._event = EventEmitter() + self.add_header('''''') + self._pending_js_eval = [] + + # Events + # ------------------------------------------------------------------------- + + def emit(self, *args, **kwargs): + return self._event.emit(*args, **kwargs) + + def connect_(self, *args, **kwargs): + self._event.connect(*args, **kwargs) + + def unconnect_(self, *args, **kwargs): + self._event.unconnect(*args, **kwargs) + + # Headers + # ------------------------------------------------------------------------- + + def add_styles(self, s): + """Add CSS styles.""" + self._styles.append(s) + + def add_style_src(self, filename): + """Link a CSS file.""" + self.add_header(('').format(filename)) + + def add_script_src(self, filename): + """Link a JS script.""" + self.add_header(''.format(filename)) + + def add_header(self, h): + """Add HTML code to the header.""" + self._header += (h + '\n') + + # HTML methods + # ------------------------------------------------------------------------- + + def set_body(self, s): + """Set the HTML body.""" + self._body = s + + def add_body(self, s): + """Add HTML code to the body.""" + self._body += '\n' + s + '\n' + + def html(self): + """Return the full HTML source of the widget.""" + return self.page().mainFrame().toHtml() + + def build(self): + """Build the full HTML source.""" + if self.is_built(): # pragma: no cover + return + with _wait_signal(self.loadFinished, 20): + styles = '\n\n'.join(self._styles) + html = _PAGE_TEMPLATE.format(title=self.title, + styles=styles, + header=self._header, + body=self._body, + ) + logger.log(5, "Set HTML: %s", html) + static_dir = op.join(op.realpath(op.dirname(__file__)), 'static/') + base_url = QUrl().fromLocalFile(static_dir) + self.setHtml(html, base_url) + + def is_built(self): + return self.html() != '' + + # Javascript methods + # ------------------------------------------------------------------------- + + def add_to_js(self, name, var): + """Add an object to Javascript.""" + frame = self.page().mainFrame() + frame.addToJavaScriptWindowObject(name, var) + + def eval_js(self, expr): + """Evaluate a Javascript expression.""" + if not self.is_built(): + self._pending_js_eval.append(expr) + return + logger.log(5, "Evaluate Javascript: `%s`.", expr) + out = self.page().mainFrame().evaluateJavaScript(expr) + return _to_py(out) + + @pyqtSlot(str, str) + def _emit_from_js(self, name, arg_json): + self.emit(text_type(name), json.loads(text_type(arg_json))) + + def show(self): + self.build() + super(HTMLWidget, self).show() + # Call the pending JS eval calls after the page has been built. + assert self.is_built() + for expr in self._pending_js_eval: + self.eval_js(expr) + self._pending_js_eval = [] + + +# ----------------------------------------------------------------------------- +# HTML table +# ----------------------------------------------------------------------------- + +def dumps(o): + return json.dumps(o, cls=_CustomEncoder) + + +def _create_json_dict(**kwargs): + d = {} + # Remove None elements. 
+ for k, v in kwargs.items(): + if v is not None: + d[k] = v + # The custom encoder serves for NumPy scalars that are non + # JSON-serializable (!!). + return dumps(d) + + +class Table(HTMLWidget): + """A sortable table with support for selection.""" + + _table_id = 'the-table' + + def __init__(self): + super(Table, self).__init__() + self.add_style_src('table.css') + self.add_script_src('tablesort.min.js') + self.add_script_src('tablesort.number.js') + self.add_script_src('table.js') + self.set_body('
'.format( + self._table_id)) + self.add_body(''''''.format(self._table_id)) + self._columns = OrderedDict() + self._default_sort = (None, None) + self.add_column(lambda _: _, name='id') + + def add_column(self, func, name=None, show=True): + """Add a column function which takes an id as argument and + returns a value.""" + assert func + name = name or func.__name__ + if name == '': + raise ValueError("Please provide a valid name for " + name) + d = {'func': func, + 'show': show, + } + self._columns[name] = d + + # Update the headers in the widget. + data = _create_json_dict(cols=self.column_names, + ) + self.eval_js('table.setHeaders({});'.format(data)) + + return func + + @property + def column_names(self): + """List of column names.""" + return [name for (name, d) in self._columns.items() + if d.get('show', True)] + + def _get_row(self, id): + """Create a row dictionary for a given object id.""" + return {name: d['func'](id) for (name, d) in self._columns.items()} + + def set_rows(self, ids): + """Set the rows of the table.""" + # NOTE: make sure we have integers and not np.generic objects. + assert all(isinstance(i, int) for i in ids) + + # Determine the sort column and dir to set after the rows. + sort_col, sort_dir = self.current_sort + default_sort_col, default_sort_dir = self.default_sort + + sort_col = sort_col or default_sort_col + sort_dir = sort_dir or default_sort_dir or 'desc' + + # Set the rows. + logger.log(5, "Set %d rows in the table.", len(ids)) + items = [self._get_row(id) for id in ids] + # Sort the rows before passing them to the widget. + # if sort_col: + # items = sorted(items, key=itemgetter(sort_col), + # reverse=(sort_dir == 'desc')) + data = _create_json_dict(items=items, + cols=self.column_names, + ) + self.eval_js('table.setData({});'.format(data)) + + # Sort. 
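+        # Re-apply the current (or default) sort once the new rows are set.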
+ if sort_col: + self.sort_by(sort_col, sort_dir) + + def sort_by(self, name, sort_dir='asc'): + """Sort by a given variable.""" + logger.log(5, "Sort by `%s` %s.", name, sort_dir) + self.eval_js('table.sortBy("{}", "{}");'.format(name, sort_dir)) + + def next(self): + """Select the next non-skipped row.""" + self.eval_js('table.next();') + + def previous(self): + """Select the previous non-skipped row.""" + self.eval_js('table.previous();') + + def select(self, ids, do_emit=True): + """Select some rows.""" + do_emit = str(do_emit).lower() + self.eval_js('table.select({}, {});'.format(dumps(ids), do_emit)) + + @property + def default_sort(self): + """Default sort as a pair `(name, dir)`.""" + return self._default_sort + + def set_default_sort(self, name, sort_dir='desc'): + """Set the default sort column.""" + self._default_sort = name, sort_dir + + @property + def selected(self): + """Currently selected rows.""" + return [int(_) for _ in self.eval_js('table.selected')] + + @property + def current_sort(self): + """Current sort: a tuple `(name, dir)`.""" + return tuple(self.eval_js('table.currentSort()') or (None, None)) diff --git a/phy/io/__init__.py b/phy/io/__init__.py index 7c7cb105b..ce3214b80 100644 --- a/phy/io/__init__.py +++ b/phy/io/__init__.py @@ -3,9 +3,5 @@ """Input/output.""" -from .base import BaseModel, BaseSession -from .h5 import File, open_h5 -from .store import ClusterStore, StoreItem -from .traces import read_dat, read_kwd -from .kwik.creator import KwikCreator, create_kwik -from .kwik.model import KwikModel +from .context import Context +from .array import Selector, select_spikes diff --git a/phy/io/array.py b/phy/io/array.py new file mode 100644 index 000000000..c79ba1561 --- /dev/null +++ b/phy/io/array.py @@ -0,0 +1,648 @@ +# -*- coding: utf-8 -*- + +"""Utility functions for NumPy arrays.""" + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +from collections import defaultdict +from functools import wraps +import logging +import math +from math import floor, exp +from operator import itemgetter +import os.path as op + +import numpy as np + +from phy.utils import Bunch, _as_scalar, _as_scalars +from phy.utils._types import _as_array, _is_array_like + +logger = logging.getLogger(__name__) + + +#------------------------------------------------------------------------------ +# Utility functions +#------------------------------------------------------------------------------ + +def _range_from_slice(myslice, start=None, stop=None, step=None, length=None): + """Convert a slice to an array of integers.""" + assert isinstance(myslice, slice) + # Find 'step'. + step = step if step is not None else myslice.step + if step is None: + step = 1 + # Find 'start'. + start = start if start is not None else myslice.start + if start is None: + start = 0 + # Find 'stop' as a function of length if 'stop' is unspecified. + stop = stop if stop is not None else myslice.stop + if length is not None: + stop_inferred = floor(start + step * length) + if stop is not None and stop < stop_inferred: + raise ValueError("'stop' ({stop}) and ".format(stop=stop) + + "'length' ({length}) ".format(length=length) + + "are not compatible.") + stop = stop_inferred + if stop is None and length is None: + raise ValueError("'stop' and 'length' cannot be both unspecified.") + myrange = np.arange(start, stop, step) + # Check the length if it was specified. 
+    if length is not None:
+        assert len(myrange) == length
+    return myrange
+
+
+def _unique(x):
+    """Faster version of np.unique().
+
+    This version is restricted to 1D arrays of non-negative integers.
+
+    It is only faster if len(x) >> len(unique(x)).
+
+    """
+    if x is None or len(x) == 0:
+        return np.array([], dtype=np.int64)
+    # WARNING: only keep positive values.
+    # cluster=-1 means "unclustered".
+    x = _as_array(x)
+    x = x[x >= 0]
+    bc = np.bincount(x)
+    return np.nonzero(bc)[0]
+
+
+def _normalize(arr, keep_ratio=False):
+    """Normalize an array into [0, 1]."""
+    (x_min, y_min), (x_max, y_max) = arr.min(axis=0), arr.max(axis=0)
+
+    if keep_ratio:
+        a = 1. / max(x_max - x_min, y_max - y_min)
+        ax = ay = a
+        bx = .5 - .5 * a * (x_max + x_min)
+        by = .5 - .5 * a * (y_max + y_min)
+    else:
+        ax = 1. / (x_max - x_min)
+        ay = 1. / (y_max - y_min)
+        bx = -x_min / (x_max - x_min)
+        by = -y_min / (y_max - y_min)
+
+    arr_n = arr.copy()
+    arr_n[:, 0] *= ax
+    arr_n[:, 0] += bx
+    arr_n[:, 1] *= ay
+    arr_n[:, 1] += by
+
+    return arr_n
+
+
+def _index_of(arr, lookup):
+    """Replace scalars in an array by their indices in a lookup table.
+
+    Implicitly assume that:
+
+    * All elements of arr and lookup are non-negative integers.
+    * All elements of arr belong to lookup.
+
+    This is not checked for performance reasons.
+
+    """
+    # Equivalent of np.digitize(arr, lookup) - 1, but much faster.
+    # TODO: assertions to disable in production for performance reasons.
+    # TODO: np.searchsorted(lookup, arr) is faster on small arrays with large
+    # values
+    lookup = np.asarray(lookup, dtype=np.int32)
+    m = (lookup.max() if len(lookup) else 0) + 1
+    tmp = np.zeros(m + 1, dtype=np.int)
+    # Ensure that -1 values are kept.
+    tmp[-1] = -1
+    if len(lookup):
+        tmp[lookup] = np.arange(len(lookup))
+    return tmp[arr]
+
+
+def _pad(arr, n, dir='right'):
+    """Pad an array with zeros along the first axis.
+
+    Parameters
+    ----------
+
+    n : int
+        Size of the returned array in the first axis.
+    dir : str
+        Direction of the padding. Must be one of 'left' or 'right'.
+
+    """
+    assert dir in ('left', 'right')
+    if n < 0:
+        raise ValueError("'n' must be positive: {0}.".format(n))
+    elif n == 0:
+        return np.zeros((0,) + arr.shape[1:], dtype=arr.dtype)
+    n_arr = arr.shape[0]
+    shape = (n,) + arr.shape[1:]
+    if n_arr == n:
+        assert arr.shape == shape
+        return arr
+    elif n_arr < n:
+        out = np.zeros(shape, dtype=arr.dtype)
+        if dir == 'left':
+            out[-n_arr:, ...] = arr
+        elif dir == 'right':
+            out[:n_arr, ...] = arr
+        assert out.shape == shape
+        return out
+    else:
+        if dir == 'left':
+            out = arr[-n:, ...]
+        elif dir == 'right':
+            out = arr[:n, ...]
+        assert out.shape == shape
+        return out
+
+
+def _get_padded(data, start, end):
+    """Return `data[start:end]`, filling in with zeros outside array bounds.
+
+    Assumes that either `start<0` or `end>len(data)` but not both.
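+
+    Illustrative example: for an array `data` of shape `(10, 2)`,
+    `_get_padded(data, -2, 3)` returns a `(5, 2)` array whose first two rows
+    are zeros and whose last three rows are `data[:3]`.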
+ + """ + if start < 0 and end > data.shape[0]: + raise RuntimeError() + if start < 0: + start_zeros = np.zeros((-start, data.shape[1]), + dtype=data.dtype) + return np.vstack((start_zeros, data[:end])) + elif end > data.shape[0]: + end_zeros = np.zeros((end - data.shape[0], data.shape[1]), + dtype=data.dtype) + return np.vstack((data[start:], end_zeros)) + else: + return data[start:end] + + +def _in_polygon(points, polygon): + """Return the points that are inside a polygon.""" + from matplotlib.path import Path + points = _as_array(points) + polygon = _as_array(polygon) + assert points.ndim == 2 + assert polygon.ndim == 2 + if len(polygon): + polygon = np.vstack((polygon, polygon[0])) + path = Path(polygon, closed=True) + return path.contains_points(points) + + +def _get_data_lim(arr, n_spikes=None): + n = arr.shape[0] + k = max(1, n // n_spikes) if n_spikes else 1 + arr = np.abs(arr[::k]) + n = arr.shape[0] + arr = arr.reshape((n, -1)) + return arr.max() + + +def get_closest_clusters(cluster_id, cluster_ids, sim_func, max_n=None): + """Return a list of pairs `(cluster, similarity)` sorted by decreasing + similarity to a given cluster.""" + l = [(_as_scalar(candidate), _as_scalar(sim_func(cluster_id, candidate))) + for candidate in _as_scalars(cluster_ids)] + l = sorted(l, key=itemgetter(1), reverse=True) + return l[:max_n] + + +def concat_per_cluster(f): + """Take a function accepting a single cluster, and return a function + accepting multiple clusters.""" + @wraps(f) + def wrapped(cluster_ids, **kwargs): + # Single cluster. + if not hasattr(cluster_ids, '__len__'): + return f(cluster_ids, **kwargs) + # Concatenate the result of multiple clusters. + l = [f(c, **kwargs) for c in cluster_ids] + # Handle the case where every function returns a list of Bunch. + if l and isinstance(l[0], list): + # We assume that all items have the same length. + n = len(l[0]) + return [Bunch(_accumulate([item[i] for item in l])) + for i in range(n)] + else: + return Bunch(_accumulate(l)) + return wrapped + + +# ----------------------------------------------------------------------------- +# I/O functions +# ----------------------------------------------------------------------------- + +def read_array(path, mmap_mode=None): + """Read a .npy array.""" + file_ext = op.splitext(path)[1] + if file_ext == '.npy': + return np.load(path, mmap_mode=mmap_mode) + raise NotImplementedError("The file extension `{}` ".format(file_ext) + + "is not currently supported.") + + +def write_array(path, arr): + """Write an array to a .npy file.""" + file_ext = op.splitext(path)[1] + if file_ext == '.npy': + return np.save(path, arr) + raise NotImplementedError("The file extension `{}` ".format(file_ext) + + "is not currently supported.") + + +# ----------------------------------------------------------------------------- +# Virtual concatenation +# ----------------------------------------------------------------------------- + +def _start_stop(item): + """Find the start and stop indices of a __getitem__ item. + + This is used only by ConcatenatedArrays. + + Only two cases are supported currently: + + * Single integer. + * Contiguous slice in the first dimension only. + + """ + if isinstance(item, tuple): + item = item[0] + if isinstance(item, slice): + # Slice. + if item.step not in (None, 1): + return NotImplementedError() + return item.start, item.stop + elif isinstance(item, (list, np.ndarray)): + # List or array of indices. + return np.min(item), np.max(item) + else: + # Integer. 
+ return item, item + 1 + + +def _fill_index(arr, item): + if isinstance(item, tuple): + item = (slice(None, None, None),) + item[1:] + return arr[item] + else: + return arr + + +class ConcatenatedArrays(object): + """This object represents a concatenation of several memory-mapped + arrays.""" + def __init__(self, arrs, cols=None): + assert isinstance(arrs, list) + self.arrs = arrs + # Reordering of the columns. + self.cols = cols + self.offsets = np.concatenate([[0], np.cumsum([arr.shape[0] + for arr in arrs])], + axis=0) + self.dtype = arrs[0].dtype if arrs else None + + @property + def shape(self): + if self.arrs[0].ndim == 1: + return (self.offsets[-1],) + ncols = (len(self.cols) if self.cols is not None + else self.arrs[0].shape[1]) + return (self.offsets[-1], ncols) + + def _get_recording(self, index): + """Return the recording that contains a given index.""" + assert index >= 0 + recs = np.nonzero((index - self.offsets[:-1]) >= 0)[0] + if len(recs) == 0: + # If the index is greater than the total size, + # return the last recording. + return len(self.arrs) - 1 + # Return the last recording such that the index is greater than + # its offset. + return recs[-1] + + def __getitem__(self, item): + cols = self.cols if self.cols is not None else slice(None, None, None) + # Get the start and stop indices of the requested item. + start, stop = _start_stop(item) + # Return the concatenation of all arrays. + if start is None and stop is None: + return np.concatenate(self.arrs, axis=0)[..., cols] + if start is None: + start = 0 + if stop is None: + stop = self.offsets[-1] + if stop < 0: + stop = self.offsets[-1] + stop + # Get the recording indices of the first and last item. + rec_start = self._get_recording(start) + rec_stop = self._get_recording(stop) + assert 0 <= rec_start <= rec_stop < len(self.arrs) + # Find the start and stop relative to the arrays. + start_rel = start - self.offsets[rec_start] + stop_rel = stop - self.offsets[rec_stop] + # Single array case. + if rec_start == rec_stop: + # Apply the rest of the index. + return _fill_index(self.arrs[rec_start][start_rel:stop_rel], + item)[..., cols] + chunk_start = self.arrs[rec_start][start_rel:] + chunk_stop = self.arrs[rec_stop][:stop_rel] + # Concatenate all chunks. + l = [chunk_start] + if rec_stop - rec_start >= 2: + logger.warn("Loading a full virtual array: this might be slow " + "and something might be wrong.") + l += [self.arrs[r][...] for r in range(rec_start + 1, + rec_stop)] + l += [chunk_stop] + # Apply the rest of the index. + return _fill_index(np.concatenate(l, axis=0), item)[..., cols] + + def __len__(self): + return self.shape[0] + + +def _concatenate_virtual_arrays(arrs, cols=None): + """Return a virtual concatenate of several NumPy arrays.""" + n = len(arrs) + if n == 0: + return None + return ConcatenatedArrays(arrs, cols) + + +# ----------------------------------------------------------------------------- +# Chunking functions +# ----------------------------------------------------------------------------- + +def _excerpt_step(n_samples, n_excerpts=None, excerpt_size=None): + """Compute the step of an excerpt set as a function of the number + of excerpts or their sizes.""" + assert n_excerpts >= 2 + step = max((n_samples - excerpt_size) // (n_excerpts - 1), + excerpt_size) + return step + + +def chunk_bounds(n_samples, chunk_size, overlap=0): + """Return chunk bounds. 
+ + Chunks have the form: + + [ overlap/2 | chunk_size-overlap | overlap/2 ] + s_start keep_start keep_end s_end + + Except for the first and last chunks which do not have a left/right + overlap. + + This generator yields (s_start, s_end, keep_start, keep_end). + + """ + s_start = 0 + s_end = chunk_size + keep_start = s_start + keep_end = s_end - overlap // 2 + yield s_start, s_end, keep_start, keep_end + + while s_end - overlap + chunk_size < n_samples: + s_start = s_end - overlap + s_end = s_start + chunk_size + keep_start = keep_end + keep_end = s_end - overlap // 2 + if s_start < s_end: + yield s_start, s_end, keep_start, keep_end + + s_start = s_end - overlap + s_end = n_samples + keep_start = keep_end + keep_end = s_end + if s_start < s_end: + yield s_start, s_end, keep_start, keep_end + + +def excerpts(n_samples, n_excerpts=None, excerpt_size=None): + """Yield (start, end) where start is included and end is excluded.""" + assert n_excerpts >= 2 + step = _excerpt_step(n_samples, + n_excerpts=n_excerpts, + excerpt_size=excerpt_size) + for i in range(n_excerpts): + start = i * step + if start >= n_samples: + break + end = min(start + excerpt_size, n_samples) + yield start, end + + +def data_chunk(data, chunk, with_overlap=False): + """Get a data chunk.""" + assert isinstance(chunk, tuple) + if len(chunk) == 2: + i, j = chunk + elif len(chunk) == 4: + if with_overlap: + i, j = chunk[:2] + else: + i, j = chunk[2:] + else: + raise ValueError("'chunk' should have 2 or 4 elements, " + "not {0:d}".format(len(chunk))) + return data[i:j, ...] + + +def get_excerpts(data, n_excerpts=None, excerpt_size=None): + assert n_excerpts is not None + assert excerpt_size is not None + if len(data) < n_excerpts * excerpt_size: + return data + elif n_excerpts == 0: + return data[:0] + elif n_excerpts == 1: + return data[:excerpt_size] + out = np.concatenate([data_chunk(data, chunk) + for chunk in excerpts(len(data), + n_excerpts=n_excerpts, + excerpt_size=excerpt_size)]) + assert len(out) <= n_excerpts * excerpt_size + return out + + +# ----------------------------------------------------------------------------- +# Spike clusters utility functions +# ----------------------------------------------------------------------------- + +def _spikes_in_clusters(spike_clusters, clusters): + """Return the ids of all spikes belonging to the specified clusters.""" + if len(spike_clusters) == 0 or len(clusters) == 0: + return np.array([], dtype=np.int) + return np.nonzero(np.in1d(spike_clusters, clusters))[0] + + +def _spikes_per_cluster(spike_clusters, spike_ids=None): + """Return a dictionary {cluster: list_of_spikes}.""" + if spike_clusters is None or not len(spike_clusters): + return {} + if spike_ids is None: + spike_ids = np.arange(len(spike_clusters)).astype(np.int64) + rel_spikes = np.argsort(spike_clusters) + abs_spikes = spike_ids[rel_spikes] + spike_clusters = spike_clusters[rel_spikes] + + diff = np.empty_like(spike_clusters) + diff[0] = 1 + diff[1:] = np.diff(spike_clusters) + + idx = np.nonzero(diff > 0)[0] + clusters = spike_clusters[idx] + + spikes_in_clusters = {clusters[i]: np.sort(abs_spikes[idx[i]:idx[i + 1]]) + for i in range(len(clusters) - 1)} + spikes_in_clusters[clusters[-1]] = np.sort(abs_spikes[idx[-1]:]) + + return spikes_in_clusters + + +def _flatten_per_cluster(per_cluster): + """Convert a dictionary {cluster: spikes} to a spikes array.""" + return np.sort(np.concatenate(list(per_cluster.values()))).astype(np.int64) + + +def grouped_mean(arr, spike_clusters): + """Compute the mean of a 
spike-dependent quantity for every cluster.
+
+    The two arguments should be 1D arrays with `n_spikes` elements.
+
+    The output is a 1D array with `n_clusters` elements. The clusters are
+    sorted in increasing order.
+
+    """
+    arr = np.asarray(arr)
+    spike_clusters = np.asarray(spike_clusters)
+    assert arr.ndim == 1
+    assert arr.shape[0] == len(spike_clusters)
+    cluster_ids = _unique(spike_clusters)
+    spike_clusters_rel = _index_of(spike_clusters, cluster_ids)
+    spike_counts = np.bincount(spike_clusters_rel)
+    assert len(spike_counts) == len(cluster_ids)
+    t = np.zeros(len(cluster_ids))
+    # Compute the sum with possible repetitions.
+    np.add.at(t, spike_clusters_rel, arr)
+    return t / spike_counts
+
+
+def regular_subset(spikes, n_spikes_max=None, offset=0):
+    """Prune the current selection to get at most n_spikes_max spikes."""
+    assert spikes is not None
+    # Nothing to do if the selection already satisfies n_spikes_max.
+    if n_spikes_max is None or len(spikes) <= n_spikes_max:  # pragma: no cover
+        return spikes
+    step = math.ceil(np.clip(1. / n_spikes_max * len(spikes),
+                             1, len(spikes)))
+    step = int(step)
+    # Note: randomly-changing selections are confusing...
+    my_spikes = spikes[offset::step][:n_spikes_max]
+    assert len(my_spikes) <= len(spikes)
+    assert len(my_spikes) <= n_spikes_max
+    return my_spikes
+
+
+def select_spikes(cluster_ids=None,
+                  max_n_spikes_per_cluster=None,
+                  spikes_per_cluster=None):
+    """Return a selection of spikes belonging to the specified clusters."""
+    assert _is_array_like(cluster_ids)
+    if not len(cluster_ids):
+        return np.array([], dtype=np.int64)
+    if max_n_spikes_per_cluster in (None, 0):
+        selection = {c: spikes_per_cluster(c) for c in cluster_ids}
+    else:
+        assert max_n_spikes_per_cluster > 0
+        selection = {}
+        n_clusters = len(cluster_ids)
+        for cluster in cluster_ids:
+            # Decrease the number of spikes per cluster when there
+            # are more clusters.
+            n = int(max_n_spikes_per_cluster * exp(-.1 * (n_clusters - 1)))
+            n = max(1, n)
+            spikes = spikes_per_cluster(cluster)
+            selection[cluster] = regular_subset(spikes, n_spikes_max=n)
+    return _flatten_per_cluster(selection)
+
+
+class Selector(object):
+    """This object is passed with the `select` event when clusters are
+    selected. It allows making selections of spikes."""
+    def __init__(self, spikes_per_cluster):
+        # NOTE: spikes_per_cluster is a function.
+        self.spikes_per_cluster = spikes_per_cluster
+
+    def select_spikes(self, cluster_ids=None,
+                      max_n_spikes_per_cluster=None):
+        if cluster_ids is None or not len(cluster_ids):
+            return None
+        ns = max_n_spikes_per_cluster
+        assert len(cluster_ids) >= 1
+        # Select a subset of the spikes.
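+        # Note that select_spikes() shrinks the per-cluster budget as more
+        # clusters are selected, so the total selection size does not grow
+        # linearly with the number of clusters.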
+ return select_spikes(cluster_ids, + spikes_per_cluster=self.spikes_per_cluster, + max_n_spikes_per_cluster=ns) + + +# ----------------------------------------------------------------------------- +# Accumulator +# ----------------------------------------------------------------------------- + +def _flatten(l): + return [item for sublist in l for item in sublist] + + +class Accumulator(object): + """Accumulate arrays for concatenation.""" + def __init__(self): + self._data = defaultdict(list) + + def add(self, name, val): + """Add an array.""" + self._data[name].append(val) + + def get(self, name): + """Return the list of arrays for a given name.""" + return _flatten(self._data[name]) + + @property + def names(self): + """List of names.""" + return set(self._data) + + def __getitem__(self, name): + """Concatenate all arrays with a given name.""" + l = self._data[name] + # Process scalars: only return the first one and don't concatenate. + if len(l) and not hasattr(l[0], '__len__'): + return l[0] + return np.concatenate(l, axis=0) + + +def _accumulate(data_list, no_concat=()): + """Concatenate a list of dicts `(name, array)`. + + You can specify some names which arrays should not be concatenated. + This is necessary with lists of plots with different sizes. + + """ + acc = Accumulator() + for data in data_list: + for name, val in data.items(): + acc.add(name, val) + out = {name: acc[name] for name in acc.names if name not in no_concat} + + # Some variables should not be concatenated but should be kept as lists. + # This is when there can be several arrays of variable length (NumPy + # doesn't support ragged arrays). + out.update({name: acc.get(name) for name in no_concat}) + return out diff --git a/phy/io/base.py b/phy/io/base.py deleted file mode 100644 index 02c7050cf..000000000 --- a/phy/io/base.py +++ /dev/null @@ -1,519 +0,0 @@ -# -*- coding: utf-8 -*- - -"""The BaseModel class holds the data from an experiment.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import os.path as op -from collections import defaultdict - -import numpy as np - -import six -from ..utils import debug, EventEmitter -from ..utils._types import _as_list, _is_list -from ..utils.settings import (Settings, - _ensure_dir_exists, - _phy_user_dir, - ) - - -#------------------------------------------------------------------------------ -# ClusterMetadata class -#------------------------------------------------------------------------------ - -class ClusterMetadata(object): - """Hold cluster metadata. - - Features - -------- - - * New metadata fields can be easily registered - * Arbitrary functions can be used for default values - - Notes - ---- - - If a metadata field `group` is registered, then two methods are - dynamically created: - - * `group(cluster)` returns the group of a cluster, or the default value - if the cluster doesn't exist. - * `set_group(cluster, value)` sets a value for the `group` metadata field. - - """ - def __init__(self, data=None): - self._fields = {} - self._data = defaultdict(dict) - # Fill the existing values. 
- if data is not None: - self._data.update(data) - - @property - def data(self): - return self._data - - def _get_one(self, cluster, field): - """Return the field value for a cluster, or the default value if it - doesn't exist.""" - if cluster in self._data: - if field in self._data[cluster]: - return self._data[cluster][field] - elif field in self._fields: - # Call the default field function. - return self._fields[field](cluster) - else: - return None - else: - if field in self._fields: - return self._fields[field](cluster) - else: - return None - - def _get(self, clusters, field): - if _is_list(clusters): - return [self._get_one(cluster, field) - for cluster in _as_list(clusters)] - else: - return self._get_one(clusters, field) - - def _set_one(self, cluster, field, value): - """Set a field value for a cluster.""" - self._data[cluster][field] = value - - def _set(self, clusters, field, value): - clusters = _as_list(clusters) - for cluster in clusters: - self._set_one(cluster, field, value) - - def default(self, func): - """Register a new metadata field with a function - returning the default value of a cluster.""" - field = func.__name__ - # Register the decorated function as the default field function. - self._fields[field] = func - # Create self.(clusters). - setattr(self, field, lambda clusters: self._get(clusters, field)) - # Create self.set_(clusters, value). - setattr(self, 'set_{0:s}'.format(field), - lambda clusters, value: self._set(clusters, field, value)) - return func - - -#------------------------------------------------------------------------------ -# BaseModel class -#------------------------------------------------------------------------------ - -class BaseModel(object): - """This class holds data from an experiment. - - This base class must be derived. - - """ - def __init__(self): - self.name = 'model' - self._channel_group = None - self._clustering = None - - @property - def path(self): - return None - - # Channel groups - # ------------------------------------------------------------------------- - - @property - def channel_group(self): - return self._channel_group - - @channel_group.setter - def channel_group(self, value): - assert isinstance(value, six.integer_types) - self._channel_group = value - self._channel_group_changed(value) - - def _channel_group_changed(self, value): - """Called when the channel group changes. - - May be implemented by child classes. - - """ - pass - - @property - def channel_groups(self): - """List of channel groups. - - May be implemented by child classes. - - """ - return [] - - # Clusterings - # ------------------------------------------------------------------------- - - @property - def clustering(self): - return self._clustering - - @clustering.setter - def clustering(self, value): - # The clustering is specified by a string. - assert isinstance(value, six.string_types) - self._clustering = value - self._clustering_changed(value) - - def _clustering_changed(self, value): - """Called when the clustering changes. - - May be implemented by child classes. - - """ - pass - - @property - def clusterings(self): - """List of clusterings. - - May be implemented by child classes. - - """ - return [] - - # Data - # ------------------------------------------------------------------------- - - @property - def metadata(self): - """A dictionary holding metadata about the experiment. - - May be implemented by child classes. - - """ - raise NotImplementedError() - - @property - def traces(self): - """Traces (may be memory-mapped). 
- - May be implemented by child classes. - - """ - raise NotImplementedError() - - @property - def spike_samples(self): - """Spike times from the current channel_group. - - Must be implemented by child classes. - - """ - raise NotImplementedError() - - @property - def sample_rate(self): - pass - - @property - def spike_times(self): - """Spike times from the current channel_group. - - This is a NumPy array containing `float64` values (in seconds). - - The spike times of all recordings are concatenated. There is no gap - between consecutive recordings, currently. - - """ - return self.spike_samples.astype(np.float64) / self.sample_rate - - @property - def spike_clusters(self): - """Spike clusters from the current channel_group. - - Must be implemented by child classes. - - """ - raise NotImplementedError() - - def spike_train(self, cluster_id): - """Return the spike times of a given cluster.""" - return self.spike_times[self.spikes_per_cluster[cluster_id]] - - @property - def spikes_per_cluster(self): - """Spikes per cluster dictionary. - - Must be implemented by child classes. - - """ - raise NotImplementedError() - - def update_spikes_per_cluster(self, spc): - raise NotImplementedError() - - @property - def cluster_metadata(self): - """ClusterMetadata instance holding information about the clusters. - - Must be implemented by child classes. - - """ - raise NotImplementedError() - - @property - def cluster_groups(self): - """Groups of all clusters in the current channel group and clustering. - - This is a regular Python dictionary. - - """ - return {cluster: self.cluster_metadata.group(cluster) - for cluster in self.cluster_ids} - - @property - def features(self): - """Features from the current channel_group (may be memory-mapped). - - May be implemented by child classes. - - """ - raise NotImplementedError() - - @property - def masks(self): - """Masks from the current channel_group (may be memory-mapped). - - May be implemented by child classes. - - """ - raise NotImplementedError() - - @property - def waveforms(self): - """Waveforms from the current channel_group (may be memory-mapped). - - May be implemented by child classes. - - """ - raise NotImplementedError() - - @property - def probe(self): - """A Probe instance. - - May be implemented by child classes. - - """ - raise NotImplementedError() - - def save(self): - """Save the data. - - May be implemented by child classes. - - """ - raise NotImplementedError() - - def close(self): - """Close the model and the underlying files. - - May be implemented by child classes. - - """ - pass - - -#------------------------------------------------------------------------------ -# Session -#------------------------------------------------------------------------------ - -class BaseSession(EventEmitter): - """Give access to the data, views, and GUIs in an interactive session. 
- - The model must implement: - - * `model(path)` - * `model.path` - * `model.close()` - - Events - ------ - - open - close - - """ - def __init__(self, - model=None, - path=None, - phy_user_dir=None, - default_settings_paths=None, - vm_classes=None, - gui_classes=None, - ): - super(BaseSession, self).__init__() - - self.model = None - if phy_user_dir is None: - phy_user_dir = _phy_user_dir() - _ensure_dir_exists(phy_user_dir) - self.phy_user_dir = phy_user_dir - - self._create_settings(default_settings_paths) - - if gui_classes is None: - gui_classes = self.settings['gui_classes'] - - # HACK: avoid Qt import here - try: - from ..gui.base import WidgetCreator - self._gui_creator = WidgetCreator(widget_classes=gui_classes) - except ImportError: - self._gui_creator = None - - self.connect(self.on_open) - self.connect(self.on_close) - - # Custom `on_open()` callback function. - if 'on_open' in self.settings: - @self.connect - def on_open(): - self.settings['on_open'](self) - - self._pre_open() - if model or path: - self.open(path, model=model) - - def _create_settings(self, default_settings_paths): - self.settings = Settings(phy_user_dir=self.phy_user_dir, - default_paths=default_settings_paths, - ) - - @self.connect - def on_open(): - # Initialize the settings with the model's path. - self.settings.on_open(self.experiment_path) - - # Methods to override - # ------------------------------------------------------------------------- - - def _pre_open(self): - pass - - def _create_model(self, path): - """Create a model from a path. - - Must be overriden. - - """ - pass - - def _save_model(self): - """Save a model. - - Must be overriden. - - """ - pass - - def on_open(self): - pass - - def on_close(self): - pass - - # File-related actions - # ------------------------------------------------------------------------- - - def open(self, path=None, model=None): - """Open a dataset.""" - # Close the session if it is already open. - if self.model: - self.close() - if model is None: - model = self._create_model(path) - self.model = model - self.experiment_path = (op.realpath(path) - if path else self.phy_user_dir) - self.emit('open') - - def reopen(self): - self.open(model=self.model) - - def save(self): - self._save_model() - - def close(self): - """Close the currently-open dataset.""" - self.model.close() - self.emit('close') - self.model = None - - # Views and GUIs - # ------------------------------------------------------------------------- - - def show_gui(self, name=None, show=True, **kwargs): - """Show a new GUI.""" - if name is None: - gui_classes = list(self._gui_creator.widget_classes.keys()) - if gui_classes: - name = gui_classes[0] - - # Get the default GUI config. - params = {p: self.settings.get('{}_{}'.format(name, p), None) - for p in ('config', 'shortcuts', 'snippets', 'state')} - params.update(kwargs) - - # Create the GUI. - gui = self._gui_creator.add(name, - model=self.model, - settings=self.settings, - **params) - gui._save_state = True - - # Connect the 'open' event. - self.connect(gui.on_open) - - @gui.main_window.connect_ - def on_close_gui(): - self.unconnect(gui.on_open) - # Save the params of every view in the GUI. 
- for vm in gui.views: - self.save_view_params(vm, save_size_pos=False) - gs = gui.main_window.save_geometry_state() - gs['view_count'] = gui.view_count() - if not gui._save_state: - gs['state'] = None - gs['geometry'] = None - self.settings['{}_state'.format(name)] = gs - self.settings.save() - - # HACK: do not save GUI state when views have been closed or reset - # in the session, otherwise Qt messes things up in the GUI. - @gui.connect - def on_close_view(view): - gui._save_state = False - - @gui.connect - def on_reset_gui(): - gui._save_state = False - - # Custom `on_gui_open()` callback. - if 'on_gui_open' in self.settings: - self.settings['on_gui_open'](self, gui) - - if show: - gui.show() - - return gui - - def save_view_params(self, vm, save_size_pos=True): - """Save the parameters exported by a view model instance.""" - to_save = vm.exported_params(save_size_pos=save_size_pos) - for key, value in to_save.items(): - assert vm.name - name = '{}_{}'.format(vm.name, key) - self.settings[name] = value - debug("Save {0}={1} for {2}.".format(name, value, vm.name)) diff --git a/phy/io/context.py b/phy/io/context.py new file mode 100644 index 000000000..d12814e21 --- /dev/null +++ b/phy/io/context.py @@ -0,0 +1,145 @@ +# -*- coding: utf-8 -*- + +"""Execution context that handles parallel processing and cacheing.""" + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +from functools import wraps +import inspect +import logging +import os +import os.path as op + +from six.moves.cPickle import dump, load + +from phy.utils import (_save_json, _load_json, + _ensure_dir_exists, _fullname,) +from phy.utils.config import phy_config_dir + +logger = logging.getLogger(__name__) + + +#------------------------------------------------------------------------------ +# Context +#------------------------------------------------------------------------------ + +class Context(object): + """Handle function cacheing and parallel map with ipyparallel.""" + def __init__(self, cache_dir, ipy_view=None, verbose=0): + self.verbose = verbose + # Make sure the cache directory exists. + self.cache_dir = op.realpath(op.expanduser(cache_dir)) + if not op.exists(self.cache_dir): + logger.debug("Create cache directory `%s`.", self.cache_dir) + os.makedirs(self.cache_dir) + + # Ensure the memcache directory exists. + path = op.join(self.cache_dir, 'memcache') + if not op.exists(path): + os.mkdir(path) + + self._set_memory(self.cache_dir) + self.ipy_view = ipy_view if ipy_view else None + self._memcache = {} + + def _set_memory(self, cache_dir): + # Try importing joblib. + try: + from joblib import Memory + self._memory = Memory(cachedir=self.cache_dir, + mmap_mode=None, + verbose=self.verbose, + ) + logger.debug("Initialize joblib cache dir at `%s`.", + self.cache_dir) + except ImportError: # pragma: no cover + logger.warn("Joblib is not installed. " + "Install it with `conda install joblib`.") + self._memory = None + + def cache(self, f): + """Cache a function using the context's cache directory.""" + if self._memory is None: # pragma: no cover + logger.debug("Joblib is not installed: skipping cacheing.") + return f + assert f + # NOTE: discard self in instance methods. 
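+        # Otherwise joblib would hash the instance itself, making the cache
+        # depend on the object rather than on the actual arguments.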
+ if 'self' in inspect.getargspec(f).args: # noqa + ignore = ['self'] + else: + ignore = None + disk_cached = self._memory.cache(f, ignore=ignore) + return disk_cached + + def load_memcache(self, name): + # Load the memcache from disk, if it exists. + path = op.join(self.cache_dir, 'memcache', name + '.pkl') + if op.exists(path): + logger.debug("Load memcache for `%s`.", name) + with open(path, 'rb') as fd: + cache = load(fd) + else: + cache = {} + self._memcache[name] = cache + return cache + + def save_memcache(self): + for name, cache in self._memcache.items(): + path = op.join(self.cache_dir, 'memcache', name + '.pkl') + logger.debug("Save memcache for `%s`.", name) + with open(path, 'wb') as fd: + dump(cache, fd) + + def memcache(self, f): + from joblib import hash + name = _fullname(f) + cache = self.load_memcache(name) + + @wraps(f) + def memcached(*args, **kwargs): + """Cache the function in memory.""" + h = hash((args, kwargs)) + if h in cache: + # logger.debug("Get %s(%s) from memcache.", name, str(args)) + return cache[h] + else: + # logger.debug("Compute %s(%s).", name, str(args)) + out = f(*args, **kwargs) + cache[h] = out + return out + return memcached + + def _get_path(self, name, location): + if location == 'local': + return op.join(self.cache_dir, name + '.json') + elif location == 'global': + return op.join(phy_config_dir(), name + '.json') + + def save(self, name, data, location='local'): + """Save a dictionary in a JSON file within the cache directory.""" + path = self._get_path(name, location) + _ensure_dir_exists(op.dirname(path)) + logger.debug("Save data to `%s`.", path) + _save_json(path, data) + + def load(self, name, location='local'): + """Load saved data from the cache directory.""" + path = self._get_path(name, location) + if not op.exists(path): + logger.debug("The file `%s` doesn't exist.", path) + return {} + return _load_json(path) + + def __getstate__(self): + """Make sure that this class is picklable.""" + state = self.__dict__.copy() + state['_memory'] = None + return state + + def __setstate__(self, state): + """Make sure that this class is picklable.""" + self.__dict__ = state + # Recreate the joblib Memory instance. 
+ self._set_memory(state['cache_dir']) diff --git a/phy/utils/datasets.py b/phy/io/datasets.py similarity index 78% rename from phy/utils/datasets.py rename to phy/io/datasets.py index 9480af76a..9092b92f1 100644 --- a/phy/utils/datasets.py +++ b/phy/io/datasets.py @@ -7,12 +7,14 @@ #------------------------------------------------------------------------------ import hashlib +import logging import os import os.path as op -from .logging import debug, warn -from .settings import _phy_user_dir, _ensure_dir_exists -from .event import ProgressReporter +from phy.utils.event import ProgressReporter +from phy.utils.config import phy_config_dir, _ensure_dir_exists + +logger = logging.getLogger(__name__) #------------------------------------------------------------------------------ @@ -28,7 +30,7 @@ def _remote_file_size(path): import requests - try: + try: # pragma: no cover response = requests.head(path) return int(response.headers.get('content-length', 0)) except Exception: @@ -52,16 +54,15 @@ def _save_stream(r, path): downloaded += len(chunk) if i % 100 == 0: pr.value = downloaded - if size: - assert size == downloaded + assert ((size == downloaded) if size else True) pr.set_complete() def _download(url, stream=None): from requests import get r = get(url, stream=stream) - if r.status_code != 200: - debug("Error while downloading `{}`.".format(url)) + if r.status_code != 200: # pragma: no cover + logger.debug("Error while downloading %s.", url) r.raise_for_status() return r @@ -84,9 +85,7 @@ def _md5(path, blocksize=2 ** 20): def _check_md5(path, checksum): - if checksum is None: - return - return _md5(path) == checksum + return (_md5(path) == checksum) if checksum else None def _check_md5_of_url(output_path, url): @@ -106,11 +105,11 @@ def _validate_output_dir(output_dir): output_dir = output_dir + '/' output_dir = op.realpath(op.dirname(output_dir)) if not op.exists(output_dir): - os.mkdir(output_dir) + os.makedirs(output_dir) return output_dir -def download_file(url, output_path=None): +def download_file(url, output_path): """Download a binary file from an URL. The checksum will be downloaded from `URL + .md5`. If this download @@ -121,43 +120,37 @@ def download_file(url, output_path=None): url : str The file's URL. - output_path : str or None - The path where the file is to be saved. - - Returns - ------- - output_path : str - The path where the file was downloaded. + The path where the file is to be saved. 
""" - if output_path is None: - output_path = url.split('/')[-1] + output_path = op.realpath(output_path) + assert output_path is not None if op.exists(output_path): checked = _check_md5_of_url(output_path, url) if checked is False: - debug("The file `{}` already exists ".format(output_path) + - "but is invalid: redownloading.") + logger.debug("The file `%s` already exists " + "but is invalid: redownloading.", output_path) elif checked is True: - debug("The file `{}` already exists: ".format(output_path) + - "skipping.") - return + logger.debug("The file `%s` already exists: skipping.", + output_path) + return output_path r = _download(url, stream=True) _save_stream(r, output_path) if _check_md5_of_url(output_path, url) is False: - debug("The checksum doesn't match: retrying the download.") + logger.debug("The checksum doesn't match: retrying the download.") r = _download(url, stream=True) _save_stream(r, output_path) if _check_md5_of_url(output_path, url) is False: raise RuntimeError("The checksum of the downloaded file " "doesn't match the provided checksum.") - return output_path + return -def download_test_data(name, phy_user_dir=None, force=False): +def download_test_data(name, config_dir=None, force=False): """Download a test file.""" - phy_user_dir = phy_user_dir or _phy_user_dir() - dir = op.join(phy_user_dir, 'test_data') + config_dir = config_dir or phy_config_dir() + dir = op.join(config_dir, 'test_data') _ensure_dir_exists(dir) path = op.join(dir, name) if not force and op.exists(path): @@ -187,5 +180,5 @@ def download_sample_data(filename, output_dir=None, base='cortexlab'): try: download_file(url, output_path=output_path) except Exception as e: - warn("An error occurred while downloading `{}` to `{}`: {}".format( - url, output_path, str(e))) + logger.warn("An error occurred while downloading `%s` to `%s`: %s", + url, output_path, str(e)) diff --git a/phy/io/default_settings.py b/phy/io/default_settings.py deleted file mode 100644 index 8dd5bf997..000000000 --- a/phy/io/default_settings.py +++ /dev/null @@ -1,24 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Default settings for I/O.""" - - -# ----------------------------------------------------------------------------- -# Traces -# ----------------------------------------------------------------------------- - -traces = { - 'raw_data_files': [], - 'n_channels': None, - 'dtype': None, - 'sample_rate': None, -} - - -# ----------------------------------------------------------------------------- -# Store settings -# ----------------------------------------------------------------------------- - -# Number of spikes to load at once from the features_masks array -# during the cluster store generation. -features_masks_chunk_size = 100000 diff --git a/phy/io/h5.py b/phy/io/h5.py deleted file mode 100644 index cdebc29ba..000000000 --- a/phy/io/h5.py +++ /dev/null @@ -1,293 +0,0 @@ -# -*- coding: utf-8 -*- - -"""HDF5 input and output.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import numpy as np -import h5py -from six import string_types - -from ..utils.logging import debug, warn - - -#------------------------------------------------------------------------------ -# HDF5 utility functions -#------------------------------------------------------------------------------ - -def _split_hdf5_path(path): - """Return the group and dataset of the path.""" - # Make sure the path starts with a leading slash. 
- if not path.startswith('/'): - raise ValueError(("The HDF5 path '{0:s}' should start with a " - "leading slash '/'.").format(path)) - if '//' in path: - raise ValueError(("There should be no double slash in the HDF5 path " - "'{0:s}'.").format(path)) - # Handle the special case '/'. - if path == '/': - return '/', '' - # Temporarily remove the leading '/', we'll add it later (otherwise split - # and join will mess it up). - path = path[1:] - # We split the path by slash and we get the head and tail. - _split = path.split('/') - group_path = '/'.join(_split[:-1]) - name = _split[-1] - # Make some consistency checks. - assert not group_path.endswith('/') - assert '/' not in name - # Finally, we add the leading slash at the beginning of the group path. - return '/' + group_path, name - - -def _check_hdf5_path(h5_file, path): - """Check that an HDF5 path exists in a file.""" - if path not in h5_file: - raise ValueError("{path} doesn't exist.".format(path=path)) - - -#------------------------------------------------------------------------------ -# File class -#------------------------------------------------------------------------------ - -class File(object): - def __init__(self, filename, mode=None): - if mode is None: - mode = 'r' - self.filename = filename - self.mode = mode - self._h5py_file = None - - # Open and close - #-------------------------------------------------------------------------- - - @property - def h5py_file(self): - """Native h5py file handle.""" - return self._h5py_file - - def is_open(self): - return self._h5py_file is not None - - def open(self, mode=None): - if mode is not None: - self.mode = mode - if not self.is_open(): - self._h5py_file = h5py.File(self.filename, self.mode) - - def close(self): - if self.is_open(): - self._h5py_file.close() - self._h5py_file = None - - # Datasets - #-------------------------------------------------------------------------- - - def read(self, path): - """Read an HDF5 dataset, given its HDF5 path in the file.""" - _check_hdf5_path(self._h5py_file, path) - return self._h5py_file[path] - - def write(self, path, array=None, dtype=None, shape=None, overwrite=False): - """Write a NumPy array in the file. - - Parameters - ---------- - path : str - Full HDF5 path to the dataset to create. - array : ndarray - Array to write in the file. - dtype : dtype - If `array` is None, the dtype of the array. - shape : tuple - If `array` is None, the shape of the array. - overwrite : bool - If False, raise an error if the dataset already exists. Defaults - to False. - - """ - # Get the group path and the dataset name. - group_path, dset_name = _split_hdf5_path(path) - - # If the parent group doesn't already exist, create it. - if group_path not in self._h5py_file: - self._h5py_file.create_group(group_path) - - group = self._h5py_file[group_path] - - # Check that the dataset does not already exist. - if path in self._h5py_file: - if overwrite: - # Force rewriting the dataset if 'overwrite' is True. - del self._h5py_file[path] - else: - # Otherwise, raise an error. - raise ValueError(("The dataset '{0:s}' already exists." 
- ).format(path)) - - if array is not None: - return group.create_dataset(dset_name, data=array) - else: - assert dtype - assert shape - return group.create_dataset(dset_name, dtype=dtype, shape=shape) - - # Copy and rename - #-------------------------------------------------------------------------- - - def _check_move_copy(self, path, new_path): - if not self.exists(path): - raise ValueError("'{0}' doesn't exist.".format(path)) - if self.exists(new_path): - raise ValueError("'{0}' already exist.".format(new_path)) - - def move(self, path, new_path): - """Move a group or dataset to another location.""" - self._check_move_copy(path, new_path) - self._h5py_file.move(path, new_path) - - def copy(self, path, new_path): - """Copy a group or dataset to another location.""" - self._check_move_copy(path, new_path) - self._h5py_file.copy(path, new_path) - - def delete(self, path): - """Delete a group or dataset.""" - if not path.startswith('/'): - path = '/' + path - if not self.exists(path): - raise ValueError("'{0}' doesn't exist.".format(path)) - - path, name = _split_hdf5_path(path) - parent = self.read(path) - del parent[name] - - # Attributes - #-------------------------------------------------------------------------- - - def read_attr(self, path, attr_name): - """Read an attribute of an HDF5 group.""" - _check_hdf5_path(self._h5py_file, path) - attrs = self._h5py_file[path].attrs - if attr_name in attrs: - try: - out = attrs[attr_name] - if isinstance(out, np.ndarray) and out.dtype.kind == 'S': - if len(out) == 1: - out = out[0].decode('UTF-8') - return out - except (TypeError, IOError): - debug("Unable to read attribute `{}` at `{}`.".format( - attr_name, path)) - return - else: - raise KeyError("The attribute '{0:s}' ".format(attr_name) + - "at `{}` doesn't exist.".format(path)) - - def write_attr(self, path, attr_name, value): - """Write an attribute of an HDF5 group.""" - assert isinstance(path, string_types) - assert isinstance(attr_name, string_types) - if value is None: - value = '' - # Ensure lists of strings are converted to ASCII arrays. - if isinstance(value, list): - if not value: - value = None - if value and isinstance(value[0], string_types): - value = np.array(value, dtype='S') - # Use string arrays instead of vlen arrays (crash in h5py 2.5.0 win64). - if isinstance(value, string_types): - value = np.array([value], dtype='S') - # Idem: fix crash with boolean attributes on win64. - if isinstance(value, bool): - value = int(value) - # If the parent group doesn't already exist, create it. 
- if path not in self._h5py_file: - self._h5py_file.create_group(path) - try: - self._h5py_file[path].attrs[attr_name] = value - # debug("Write `{}={}` at `{}`.".format(attr_name, - # str(value), path)) - except TypeError: - warn("Unable to write attribute `{}={}` at `{}`.".format( - attr_name, value, path)) - - def attrs(self, path='/'): - """Return the list of attributes at the given path.""" - if path in self._h5py_file: - return sorted(self._h5py_file[path].attrs) - else: - return [] - - def has_attr(self, path, attr_name): - """Return whether an attribute exists at a given path.""" - if path not in self._h5py_file: - return False - else: - return attr_name in self._h5py_file[path].attrs - - # Children - #-------------------------------------------------------------------------- - - def children(self, path='/'): - """Return the list of children of a given node.""" - return sorted(self._h5py_file[path].keys()) - - def groups(self, path='/'): - """Return the list of groups under a given node.""" - return [key for key in self.children(path) - if isinstance(self._h5py_file[path + '/' + key], - h5py.Group)] - - def datasets(self, path='/'): - """Return the list of datasets under a given node.""" - return [key for key in self.children(path) - if isinstance(self._h5py_file[path + '/' + key], - h5py.Dataset)] - - # Miscellaneous properties - #-------------------------------------------------------------------------- - - def exists(self, path): - return path in self._h5py_file - - def _print_node_info(self, name, node): - """Print node information.""" - info = ('/' + name).ljust(50) - if isinstance(node, h5py.Group): - pass - elif isinstance(node, h5py.Dataset): - info += str(node.shape).ljust(20) - info += str(node.dtype).ljust(8) - print(info) - - def describe(self): - """Display the list of all groups and datasets in the file.""" - if not self.is_open(): - raise IOError("Cannot display file information because the file" - " '{0:s}' is not open.".format(self.filename)) - self._h5py_file['/'].visititems(self._print_node_info) - - # Context manager - #-------------------------------------------------------------------------- - - def __contains__(self, path): - return path in self._h5py_file - - def __enter__(self): - self.open() - return self - - def __exit__(self, type, value, tb): - self.close() - - -def open_h5(filename, mode=None): - """Open an HDF5 file and return a File instance.""" - file = File(filename, mode=mode) - file.open() - return file diff --git a/phy/io/kwik/__init__.py b/phy/io/kwik/__init__.py deleted file mode 100644 index f1aabd050..000000000 --- a/phy/io/kwik/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -# -*- coding: utf-8 -*- -# flake8: noqa - -"""Kwik format.""" - -from .store_items import (FeatureMasks, - Waveforms, - ClusterStatistics, - create_store, - ) -from .creator import KwikCreator, create_kwik -from .model import KwikModel diff --git a/phy/io/kwik/creator.py b/phy/io/kwik/creator.py deleted file mode 100644 index 2a7400863..000000000 --- a/phy/io/kwik/creator.py +++ /dev/null @@ -1,463 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Kwik creator.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import os -import os.path as op - -import numpy as np -from h5py import Dataset -from six import string_types -from six.moves import zip - -import phy -from ...electrode.mea import load_probe -from ..h5 import open_h5 -from ..traces import 
_dat_n_samples -from ...utils._misc import _read_python -from ...utils._types import _as_array -from ...utils.array import _unique -from ...utils.logging import warn -from ...utils.settings import _load_default_settings - - -#------------------------------------------------------------------------------ -# Kwik creator -#------------------------------------------------------------------------------ - -def _write_by_chunk(dset, arrs): - # Note: arrs should be a generator for performance reasons. - assert isinstance(dset, Dataset) - # Start the data. - offset = 0 - for arr in arrs: - n = arr.shape[0] - arr = arr[...] - # Match the shape of the chunk array with the dset shape. - assert arr.shape == (n,) + dset.shape[1:] - dset[offset:offset + n, ...] = arr - offset += arr.shape[0] - # Check that the copy is complete. - assert offset == dset.shape[0] - - -def _concat(arrs): - return np.hstack([_[...] for _ in arrs]) - - -_DEFAULT_GROUPS = [(0, 'Noise'), - (1, 'MUA'), - (2, 'Good'), - (3, 'Unsorted'), - ] - - -class KwikCreator(object): - """Create and modify a `.kwik` file.""" - def __init__(self, basename=None, kwik_path=None, kwx_path=None): - # Find the .kwik filename. - if kwik_path is None: - assert basename is not None - if basename.endswith('.kwik'): - basename, _ = op.splitext(basename) - kwik_path = basename + '.kwik' - self.kwik_path = kwik_path - if basename is None: - basename, _ = op.splitext(kwik_path) - self.basename = basename - - # Find the .kwx filename. - if kwx_path is None: - basename, _ = op.splitext(kwik_path) - kwx_path = basename + '.kwx' - self.kwx_path = kwx_path - - def create_empty(self): - """Create empty `.kwik` and `.kwx` files.""" - assert not op.exists(self.kwik_path) - with open_h5(self.kwik_path, 'w') as f: - f.write_attr('/', 'kwik_version', 2) - f.write_attr('/', 'name', self.basename) - f.write_attr('/', 'creator_version', 'phy ' - + phy.__version_git__) - - assert not op.exists(self.kwx_path) - with open_h5(self.kwx_path, 'w') as f: - f.write_attr('/', 'kwik_version', 2) - f.write_attr('/', 'creator_version', 'phy ' - + phy.__version_git__) - - def set_metadata(self, path, **kwargs): - """Set metadata fields in a HDF5 path.""" - assert isinstance(path, string_types) - assert path - with open_h5(self.kwik_path, 'a') as f: - for key, value in kwargs.items(): - f.write_attr(path, key, value) - - def set_probe(self, probe): - """Save a probe dictionary in the file.""" - with open_h5(self.kwik_path, 'a') as f: - probe = probe['channel_groups'] - for group, d in probe.items(): - group = int(group) - channels = np.array(list(d['channels']), dtype=np.int32) - - # Write the channel order. - f.write_attr('/channel_groups/{:d}'.format(group), - 'channel_order', channels) - - # Write the probe adjacency graph. - graph = d.get('graph', []) - graph = np.array(graph, dtype=np.int32) - f.write_attr('/channel_groups/{:d}'.format(group), - 'adjacency_graph', graph) - - # Write the channel positions. - positions = d.get('geometry', {}) - for channel in channels: - channel = int(channel) - # Get the channel position. - if channel in positions: - position = positions[channel] - else: - # Default position. 
- position = (0, channel) - path = '/channel_groups/{:d}/channels/{:d}'.format( - group, channel) - - f.write_attr(path, 'name', str(channel)) - - position = np.array(position, dtype=np.float32) - f.write_attr(path, 'position', position) - - def add_spikes(self, - group=None, - spike_samples=None, - spike_recordings=None, - masks=None, - features=None, - n_channels=None, - n_features=None, - ): - """Add spikes in the file. - - Parameters - ---------- - - group : int - Channel group. - spike_samples : ndarray - The spike times in number of samples. - spike_recordings : ndarray (optional) - The recording indices of every spike. - masks : ndarray or list of ndarrays - The masks. - features : ndarray or list of ndarrays - The features. - - Note - ---- - - The features and masks can be passed as lists or generators of - (memmapped) data chunks in order to avoid loading the entire arrays - in RAM. - - """ - assert group >= 0 - assert n_channels >= 0 - assert n_features >= 0 - - if spike_samples is None: - return - if isinstance(spike_samples, list): - spike_samples = _concat(spike_samples) - spike_samples = _as_array(spike_samples, dtype=np.float64).ravel() - n_spikes = len(spike_samples) - if spike_recordings is None: - spike_recordings = np.zeros(n_spikes, dtype=np.int32) - spike_recordings = spike_recordings.ravel() - - # Add spikes in the .kwik file. - assert op.exists(self.kwik_path) - with open_h5(self.kwik_path, 'a') as f: - # This method can only be called once. - if '/channel_groups/{:d}/spikes/time_samples'.format(group) in f: - raise RuntimeError("Spikes have already been added to this " - "dataset.") - time_samples = spike_samples.astype(np.uint64) - frac = ((spike_samples - time_samples) * 255).astype(np.uint8) - f.write('/channel_groups/{:d}/spikes/time_samples'.format(group), - time_samples) - f.write('/channel_groups/{}/spikes/time_fractional'.format(group), - frac) - f.write('/channel_groups/{:d}/spikes/recording'.format(group), - spike_recordings) - - if masks is None and features is None: - return - # Add features and masks in the .kwx file. - assert masks is not None - assert features is not None - - # Determine the shape of the features_masks array. - shape = (n_spikes, n_channels * n_features, 2) - - def transform_f(f): - return f[...].reshape((-1, n_channels * n_features)) - - def transform_m(m): - return np.repeat(m[...], 3, axis=1) - - assert op.exists(self.kwx_path) - with open_h5(self.kwx_path, 'a') as f: - fm = f.write('/channel_groups/{:d}/features_masks'.format(group), - shape=shape, dtype=np.float32) - - # Write the features and masks in one block. - if (isinstance(features, np.ndarray) and - isinstance(masks, np.ndarray)): - fm[:, :, 0] = transform_f(features) - fm[:, :, 1] = transform_m(masks) - # Write the features and masks chunk by chunk. - else: - # Concatenate the features/masks chunks. - fm_arrs = (np.dstack((transform_f(fet), transform_m(m))) - for (fet, m) in zip(features, masks) - if fet is not None and m is not None) - _write_by_chunk(fm, fm_arrs) - - def add_recording(self, id=None, raw_path=None, - start_sample=None, sample_rate=None): - """Add a recording. - - Parameters - ---------- - - id : int - The recording id (0, 1, 2, etc.). - raw_path : str - Path to the file containing the raw data. - start_sample : int - The offset of the recording, in number of samples. 
- sample_rate : float - The sample rate of the recording - - """ - path = '/recordings/{:d}'.format(id) - start_sample = int(start_sample) - sample_rate = float(sample_rate) - - with open_h5(self.kwik_path, 'a') as f: - f.write_attr(path, 'name', 'recording_{:d}'.format(id)) - f.write_attr(path, 'start_sample', start_sample) - f.write_attr(path, 'sample_rate', sample_rate) - f.write_attr(path, 'start_time', start_sample / sample_rate) - if raw_path: - if op.splitext(raw_path)[1] == '.kwd': - f.write_attr(path + '/raw', 'hdf5_path', raw_path) - elif op.splitext(raw_path)[1] == '.dat': - f.write_attr(path + '/raw', 'dat_path', raw_path) - - def _add_recordings_from_dat(self, files, sample_rate=None, - n_channels=None, dtype=None): - start_sample = 0 - for i, filename in enumerate(files): - # WARNING: different sample rates in recordings is not - # supported yet. - self.add_recording(id=i, - start_sample=start_sample, - sample_rate=sample_rate, - raw_path=filename, - ) - assert op.splitext(filename)[1] == '.dat' - # Compute the offset for different recordings. - start_sample += _dat_n_samples(filename, - n_channels=n_channels, - dtype=dtype) - - def _add_recordings_from_kwd(self, file, sample_rate=None): - assert file.endswith('.kwd') - start_sample = 0 - with open_h5(file, 'r') as f: - recordings = f.children('/recordings') - for recording in recordings: - path = '/recordings/{}'.format(recording) - if f.has_attr(path, 'sample_rate'): - sample_rate = f.read_attr(path, 'sample_rate') - assert sample_rate > 0 - self.add_recording(id=int(recording), - start_sample=start_sample, - sample_rate=sample_rate, - raw_path=file, - ) - start_sample += f.read(path + '/data').shape[0] - - def add_cluster_group(self, - group=None, - id=None, - name=None, - clustering=None, - ): - """Add a cluster group. - - Parameters - ---------- - - group : int - The channel group. - id : int - The cluster group id. - name : str - The cluster group name. - clustering : str - The name of the clustering. - - """ - assert group >= 0 - cg_path = ('/channel_groups/{0:d}/' - 'cluster_groups/{1:s}/{2:d}').format(group, - clustering, - id, - ) - with open_h5(self.kwik_path, 'a') as f: - f.write_attr(cg_path, 'name', name) - - def add_clustering(self, - group=None, - name=None, - spike_clusters=None, - cluster_groups=None, - ): - """Add a clustering. - - Parameters - ---------- - - group : int - The channel group. - name : str - The clustering name. - spike_clusters : ndarray - The spike clusters assignements. This is `(n_spikes,)` array. - cluster_groups : dict - The cluster group of every cluster. - - """ - if cluster_groups is None: - cluster_groups = {} - path = '/channel_groups/{0:d}/spikes/clusters/{1:s}'.format( - group, name) - - with open_h5(self.kwik_path, 'a') as f: - assert not f.exists(path) - - # Save spike_clusters. - spike_clusters = spike_clusters.astype(np.int32).ravel() - f.write(path, spike_clusters) - cluster_ids = _unique(spike_clusters) - - # Create cluster metadata. - for cluster in cluster_ids: - cluster_path = '/channel_groups/{0:d}/clusters/{1:s}/{2:d}'. \ - format(group, name, cluster) - - # Default group: unsorted. - cluster_group = cluster_groups.get(cluster, 3) - f.write_attr(cluster_path, 'cluster_group', cluster_group) - - # Create cluster group metadata. 
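Taken together, the creator methods above support a workflow along these lines; the file names, sample rate and array contents are purely illustrative:

    import numpy as np

    creator = KwikCreator('experiment.kwik')            # hypothetical path
    creator.create_empty()
    creator.set_probe(probe)                            # a 'channel_groups' dict as sketched earlier
    creator.add_recording(id=0, raw_path='experiment.dat',
                          start_sample=0, sample_rate=20000.)
    creator.add_spikes(group=0,
                       spike_samples=np.array([10., 20., 30.]),
                       features=np.zeros((3, 12), dtype=np.float32),  # n_channels * n_features
                       masks=np.ones((3, 4), dtype=np.float32),
                       n_channels=4, n_features=3)
    creator.add_clustering(group=0, name='main',
                           spike_clusters=np.array([0, 0, 1]))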
- for group_id, cg_name in _DEFAULT_GROUPS: - self.add_cluster_group(id=group_id, - name=cg_name, - clustering=name, - group=group, - ) - - -def create_kwik(prm_file=None, kwik_path=None, overwrite=False, - probe=None, **kwargs): - """Create a new Kwik dataset from a PRM file.""" - prm = _read_python(prm_file) if prm_file else {} - - if prm: - assert 'spikedetekt' in prm - assert 'traces' in prm - sample_rate = prm['traces']['sample_rate'] - - if 'sample_rate' in kwargs: - sample_rate = kwargs['sample_rate'] - - assert sample_rate > 0 - - # Default SpikeDetekt parameters. - settings = _load_default_settings() - params = settings['spikedetekt'] - params.update(settings['traces']) - # Update with PRM and user parameters. - if prm: - params['experiment_name'] = prm['experiment_name'] - params['prb_file'] = prm['prb_file'] - params.update(prm['spikedetekt']) - params.update(prm['klustakwik2']) - params.update(prm['traces']) - params.update(kwargs) - - kwik_path = kwik_path or params['experiment_name'] + '.kwik' - kwx_path = op.splitext(kwik_path)[0] + '.kwx' - if op.exists(kwik_path): - if overwrite: - os.remove(kwik_path) - os.remove(kwx_path) - else: - raise IOError("The `.kwik` file already exists. Please use " - "the `--overwrite` option.") - - # Ensure the probe file exists if it is required. - if probe is None: - probe = load_probe(params['prb_file']) - assert probe - - # KwikCreator. - creator = KwikCreator(kwik_path) - creator.create_empty() - creator.set_probe(probe) - - # Add the recordings. - raw_data_files = params.get('raw_data_files', None) - if isinstance(raw_data_files, string_types): - if raw_data_files.endswith('.raw.kwd'): - creator._add_recordings_from_kwd(raw_data_files, - sample_rate=sample_rate, - ) - else: - raw_data_files = [raw_data_files] - if isinstance(raw_data_files, list) and len(raw_data_files): - if len(raw_data_files) > 1: - raise NotImplementedError("There is no support for " - "multiple .dat files yet.") - # The dtype must be a string so that it can be serialized in HDF5. - if not params.get('dtype', None): - warn("The `dtype` parameter is mandatory. Using a default value " - "of `int16` for now. Please update your `.prm` file.") - params['dtype'] = 'int16' - assert 'dtype' in params and isinstance(params['dtype'], string_types) - dtype = np.dtype(params['dtype']) - assert dtype is not None - - # The number of channels in the .dat file *must* be specified. 
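For reference, a minimal `.prm` file that `create_kwik()` can digest is plain Python (read with `_read_python`); all values below are illustrative:

    experiment_name = 'myexperiment'
    prb_file = 'myprobe.prb'
    traces = dict(raw_data_files=['myexperiment.dat'],
                  sample_rate=20000.,
                  n_channels=32,
                  dtype='int16')
    spikedetekt = dict(extract_s_before=16, extract_s_after=16)
    klustakwik2 = dict()    # KlustaKwik2 overrides, may be empty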
- n_channels = params['n_channels'] - assert n_channels > 0 - creator._add_recordings_from_dat(raw_data_files, - sample_rate=sample_rate, - n_channels=n_channels, - dtype=dtype, - ) - - creator.set_metadata('/application_data/spikedetekt', **params) - - return kwik_path diff --git a/phy/io/kwik/mock.py b/phy/io/kwik/mock.py deleted file mode 100644 index 357286c8a..000000000 --- a/phy/io/kwik/mock.py +++ /dev/null @@ -1,133 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Mock Kwik files.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import os.path as op - -import numpy as np - -from ...electrode.mea import staggered_positions -from ..mock import (artificial_spike_samples, - artificial_spike_clusters, - artificial_features, - artificial_masks, - artificial_traces) -from ..h5 import open_h5 -from .model import _create_clustering - - -#------------------------------------------------------------------------------ -# Mock Kwik file -#------------------------------------------------------------------------------ - -def create_mock_kwik(dir_path, n_clusters=None, n_spikes=None, - n_channels=None, n_features_per_channel=None, - n_samples_traces=None, - with_kwx=True, - with_kwd=True, - add_original=True, - ): - """Create a test kwik file.""" - filename = op.join(dir_path, '_test.kwik') - kwx_filename = op.join(dir_path, '_test.kwx') - kwd_filename = op.join(dir_path, '_test.raw.kwd') - - # Create the kwik file. - with open_h5(filename, 'w') as f: - f.write_attr('/', 'kwik_version', 2) - - def _write_metadata(key, value): - f.write_attr('/application_data/spikedetekt', key, value) - - _write_metadata('sample_rate', 20000.) - - # Filter parameters. - _write_metadata('filter_low', 500.) - _write_metadata('filter_high_factor', 0.95 * .5) - _write_metadata('filter_butter_order', 3) - - _write_metadata('extract_s_before', 15) - _write_metadata('extract_s_after', 25) - - _write_metadata('n_features_per_channel', n_features_per_channel) - - # Create spike times. - spike_samples = artificial_spike_samples(n_spikes).astype(np.int64) - spike_recordings = np.zeros(n_spikes, dtype=np.uint16) - # Size of the first recording. - recording_size = 2 * n_spikes // 3 - if recording_size > 0: - # Find the recording offset. - recording_offset = spike_samples[recording_size] - recording_offset += spike_samples[recording_size + 1] - recording_offset //= 2 - spike_recordings[recording_size:] = 1 - # Make sure the spike samples of the second recording start over. - spike_samples[recording_size:] -= spike_samples[recording_size] - spike_samples[recording_size:] += 10 - else: - recording_offset = 1 - - if spike_samples.max() >= n_samples_traces: - raise ValueError("There are too many spikes: decrease 'n_spikes'.") - - f.write('/channel_groups/1/spikes/time_samples', spike_samples) - f.write('/channel_groups/1/spikes/recording', spike_recordings) - f.write_attr('/channel_groups/1', 'channel_order', - np.arange(1, n_channels - 1)[::-1]) - graph = np.array([[1, 2], [2, 3]]) - f.write_attr('/channel_groups/1', 'adjacency_graph', graph) - - # Create channels. - positions = staggered_positions(n_channels) - for channel in range(n_channels): - group = '/channel_groups/1/channels/{0:d}'.format(channel) - f.write_attr(group, 'name', str(channel)) - f.write_attr(group, 'position', positions[channel]) - - # Create spike clusters. 
- clusterings = [('main', n_clusters)] - if add_original: - clusterings += [('original', n_clusters * 2)] - for clustering, n_clusters_rec in clusterings: - spike_clusters = artificial_spike_clusters(n_spikes, - n_clusters_rec) - groups = {0: 0, 1: 1, 2: 2} - _create_clustering(f, clustering, 1, spike_clusters, groups) - - # Create recordings. - f.write_attr('/recordings/0', 'name', 'recording_0') - f.write_attr('/recordings/1', 'name', 'recording_1') - - f.write_attr('/recordings/0/raw', 'hdf5_path', kwd_filename) - f.write_attr('/recordings/1/raw', 'hdf5_path', kwd_filename) - - # Create the kwx file. - if with_kwx: - with open_h5(kwx_filename, 'w') as f: - f.write_attr('/', 'kwik_version', 2) - features = artificial_features(n_spikes, - (n_channels - 2) * - n_features_per_channel) - masks = artificial_masks(n_spikes, - (n_channels - 2) * - n_features_per_channel) - fm = np.dstack((features, masks)).astype(np.float32) - f.write('/channel_groups/1/features_masks', fm) - - # Create the raw kwd file. - if with_kwd: - with open_h5(kwd_filename, 'w') as f: - f.write_attr('/', 'kwik_version', 2) - traces = artificial_traces(n_samples_traces, n_channels) - # TODO: int16 traces - f.write('/recordings/0/data', - traces[:recording_offset, ...].astype(np.float32)) - f.write('/recordings/1/data', - traces[recording_offset:, ...].astype(np.float32)) - - return filename diff --git a/phy/io/kwik/model.py b/phy/io/kwik/model.py deleted file mode 100644 index c8724cd8a..000000000 --- a/phy/io/kwik/model.py +++ /dev/null @@ -1,1243 +0,0 @@ -# -*- coding: utf-8 -*- - -"""The KwikModel class manages in-memory structures and Kwik file open/save.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import os.path as op -from random import randint -import os - -import numpy as np - -from .creator import KwikCreator -import six -from ..base import BaseModel, ClusterMetadata -from ..h5 import open_h5, File -from ..traces import _dat_to_traces -from ...traces.waveform import WaveformLoader, SpikeLoader -from ...traces.filter import bandpass_filter, apply_filter -from ...electrode.mea import MEA -from ...utils.logging import debug, warn -from ...utils.array import (PartialArray, - _concatenate_virtual_arrays, - _spikes_per_cluster, - _unique, - ) -from ...utils.settings import _load_default_settings -from ...utils._types import _is_integer, _as_array - - -#------------------------------------------------------------------------------ -# Kwik utility functions -#------------------------------------------------------------------------------ - -def _to_int_list(l): - """Convert int strings to ints.""" - return [int(_) for _ in l] - - -def _list_int_children(group): - """Return the list of int children of a HDF5 group.""" - return sorted(_to_int_list(group.keys())) - - -# TODO: refactor the functions below with h5.File.children(). - -def _list_channel_groups(kwik): - """Return the list of channel groups in a kwik file.""" - if 'channel_groups' in kwik: - return _list_int_children(kwik['/channel_groups']) - else: - return [] - - -def _list_recordings(kwik): - """Return the list of recordings in a kwik file.""" - if '/recordings' in kwik: - recordings = _list_int_children(kwik['/recordings']) - else: - recordings = [] - # TODO: return a dictionary of recordings instead of a list of recording - # ids. 
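For orientation, the helpers in this module walk a `.kwik` HDF5 hierarchy laid out roughly as follows (paths as used throughout this file):

    /channel_groups/<group>/channels/<channel>               # 'name', 'position' attributes
    /channel_groups/<group>/spikes/time_samples
    /channel_groups/<group>/spikes/clusters/<clustering>
    /channel_groups/<group>/clusters/<clustering>/<id>       # 'cluster_group' attribute
    /channel_groups/<group>/cluster_groups/<clustering>/<id>
    /recordings/<recording>/raw                              # 'hdf5_path' or 'dat_path' attribute
    /application_data/spikedetekt                            # acquisition and detection parameters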
- # return {rec: Bunch({ - # 'start': kwik['/recordings/{0}'.format(rec)].attrs['start_sample'] - # }) for rec in recordings} - return recordings - - -def _list_channels(kwik, channel_group=None): - """Return the list of channels in a kwik file.""" - assert isinstance(channel_group, six.integer_types) - path = '/channel_groups/{0:d}/channels'.format(channel_group) - if path in kwik: - channels = _list_int_children(kwik[path]) - return channels - else: - return [] - - -def _list_clusterings(kwik, channel_group=None): - """Return the list of clusterings in a kwik file.""" - if channel_group is None: - raise RuntimeError("channel_group must be specified when listing " - "the clusterings.") - assert isinstance(channel_group, six.integer_types) - path = '/channel_groups/{0:d}/clusters'.format(channel_group) - if path not in kwik: - return [] - clusterings = sorted(kwik[path].keys()) - # Ensure 'main' is the first if it exists. - if 'main' in clusterings: - clusterings.remove('main') - clusterings = ['main'] + clusterings - return clusterings - - -def _concatenate_spikes(spikes, recs, offsets): - """Concatenate spike samples belonging to consecutive recordings.""" - assert offsets is not None - spikes = _as_array(spikes) - offsets = _as_array(offsets) - recs = _as_array(recs) - return (spikes + offsets[recs]).astype(np.uint64) - - -def _create_cluster_group(f, group_id, name, - clustering=None, - channel_group=None, - write_color=True, - ): - cg_path = ('/channel_groups/{0:d}/' - 'cluster_groups/{1:s}/{2:d}').format(channel_group, - clustering, - group_id, - ) - kv_path = cg_path + '/application_data/klustaviewa' - f.write_attr(cg_path, 'name', name) - if write_color: - f.write_attr(kv_path, 'color', randint(2, 10)) - - -def _create_clustering(f, name, - channel_group=None, - spike_clusters=None, - cluster_groups=None, - ): - if cluster_groups is None: - cluster_groups = {} - assert isinstance(f, File) - path = '/channel_groups/{0:d}/spikes/clusters/{1:s}'.format(channel_group, - name) - assert not f.exists(path) - - # Save spike_clusters. - f.write(path, spike_clusters.astype(np.int32)) - - cluster_ids = _unique(spike_clusters) - - # Create cluster metadata. - for cluster in cluster_ids: - cluster_path = '/channel_groups/{0:d}/clusters/{1:s}/{2:d}'.format( - channel_group, name, cluster) - kv_path = cluster_path + '/application_data/klustaviewa' - - # Default group: unsorted. - cluster_group = cluster_groups.get(cluster, 3) - f.write_attr(cluster_path, 'cluster_group', cluster_group) - f.write_attr(kv_path, 'color', randint(2, 10)) - - # Create cluster group metadata. 
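A small worked example of `_concatenate_spikes()` above, with made-up numbers; spike samples are stored relative to their recording and shifted by that recording's start offset:

    spikes  = np.array([10, 50, 5, 40])    # samples relative to each recording
    recs    = np.array([0, 0, 1, 1])       # recording index of each spike
    offsets = np.array([0, 100])           # start sample of each recording
    # spikes + offsets[recs]  ->  [10, 50, 105, 140], cast to uint64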
- for group_id, cg_name in _DEFAULT_GROUPS: - _create_cluster_group(f, group_id, cg_name, - clustering=name, - channel_group=channel_group, - ) - - -def list_kwik(folders): - """Return the list of Kwik files found in a list of folders.""" - ret = [] - for d in folders: - for root, dirs, files in os.walk(os.path.expanduser(d)): - for f in files: - if f.endswith(".kwik"): - ret.append(os.path.join(root, f)) - return ret - - -def _open_h5_if_exists(kwik_path, file_type, mode=None): - basename, ext = op.splitext(kwik_path) - path = '{basename}.{ext}'.format(basename=basename, ext=file_type) - return open_h5(path, mode=mode) if op.exists(path) else None - - -def _read_traces(kwik, dtype=None, n_channels=None): - kwd_path = None - dat_path = None - if '/recordings' not in kwik: - return (None, None) - recordings = kwik.children('/recordings') - traces = [] - opened_files = [] - for recording in recordings: - # Is there a path specified to a .raw.kwd file which exists in - # [KWIK]/recordings/[X]/raw? If so, open it. - path = '/recordings/{}/raw'.format(recording) - if kwik.has_attr(path, 'hdf5_path'): - kwd_path = kwik.read_attr(path, 'hdf5_path')[:-8] - kwd = _open_h5_if_exists(kwd_path, 'raw.kwd') - if kwd is None: - debug("{} not found, trying same basename in KWIK dir" - .format(kwd_path)) - else: - debug("Loading traces: {}" - .format(kwd_path)) - traces.append(kwd.read('/recordings/{}/data' - .format(recording))) - opened_files.append(kwd) - continue - # Is there a path specified to a .dat file which exists? - elif kwik.has_attr(path, 'dat_path'): - assert dtype is not None - assert n_channels - dat_path = kwik.read_attr(path, 'dat_path') - if not op.exists(dat_path): - debug("{} not found, trying same basename in KWIK dir" - .format(dat_path)) - else: - debug("Loading traces: {}" - .format(dat_path)) - dat = _dat_to_traces(dat_path, dtype=dtype, - n_channels=n_channels) - traces.append(dat) - opened_files.append(dat) - continue - - # Does a file exist with the same name in the current directory? - # If so, open it. - if kwd_path is not None: - rel_path = str(op.basename(kwd_path)) - rel_path = op.join(op.dirname(op.realpath(kwik.filename)), - rel_path) - kwd = _open_h5_if_exists(rel_path, 'raw.kwd') - if kwd is None: - debug("{} not found, trying experiment basename in KWIK dir" - .format(rel_path)) - else: - debug("Loading traces: {}" - .format(rel_path)) - traces.append(kwd.read('/recordings/{}/data' - .format(recording))) - opened_files.append(kwd) - continue - elif dat_path is not None: - rel_path = op.basename(dat_path) - rel_path = op.join(op.dirname(op.realpath(kwik.filename)), - rel_path) - if not op.exists(rel_path): - debug("{} not found, trying experiment basename in KWIK dir" - .format(rel_path)) - else: - debug("Loading traces: {}" - .format(rel_path)) - dat = _dat_to_traces(rel_path, dtype=dtype, - n_channels=n_channels) - traces.append(dat) - opened_files.append(dat) - continue - - # Finally, is there a `raw.kwd` with the experiment basename in the - # current directory? If so, open it. - kwd = _open_h5_if_exists(kwik.filename, 'raw.kwd') - if kwd is None: - warn("Could not find any data source for traces (raw.kwd or " - ".dat or .bin.) 
Waveforms and traces will not be available.") - else: - debug("Successfully loaded basename.raw.kwd in same directory") - traces.append(kwd.read('/recordings/{}/data'.format(recording))) - opened_files.append(kwd) - - return traces, opened_files - - -_DEFAULT_GROUPS = [(0, 'Noise'), - (1, 'MUA'), - (2, 'Good'), - (3, 'Unsorted'), - ] - - -"""Metadata fields that must be provided when creating the Kwik file.""" -_mandatory_metadata_fields = ('dtype', - 'n_channels', - 'prb_file', - 'raw_data_files', - ) - - -def cluster_group_id(name_or_id): - """Return the id of a cluster group from its name.""" - if isinstance(name_or_id, six.string_types): - d = {group.lower(): id for id, group in _DEFAULT_GROUPS} - return d[name_or_id.lower()] - else: - assert _is_integer(name_or_id) - return name_or_id - - -#------------------------------------------------------------------------------ -# KwikModel class -#------------------------------------------------------------------------------ - -class KwikModel(BaseModel): - """Holds data contained in a kwik file.""" - - """Names of the default cluster groups.""" - default_cluster_groups = dict(_DEFAULT_GROUPS) - - def __init__(self, kwik_path=None, - channel_group=None, - clustering=None, - waveform_filter=True, - ): - super(KwikModel, self).__init__() - - # Initialize fields. - self._spike_samples = None - self._spike_clusters = None - self._spikes_per_cluster = None - self._metadata = None - self._clustering = clustering or 'main' - self._probe = None - self._channels = [] - self._channel_order = None - self._features = None - self._features_masks = None - self._masks = None - self._waveforms = None - self._cluster_metadata = None - self._clustering_metadata = {} - self._traces = None - self._recording_offsets = None - self._waveform_loader = None - self._waveform_filter = waveform_filter - self._opened_files = [] - - # Open the experiment. - self.kwik_path = kwik_path - self.open(kwik_path, - channel_group=channel_group, - clustering=clustering) - - @property - def path(self): - return self.kwik_path - - # Internal properties and methods - # ------------------------------------------------------------------------- - - def _check_kwik_version(self): - # This class only works with kwik version 2 for now. 
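`cluster_group_id()` above maps the default group names to their numeric ids and passes integer ids through unchanged, for example:

    cluster_group_id('noise')    # -> 0
    cluster_group_id('good')     # -> 2
    cluster_group_id(3)          # -> 3 (already an id)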
- kwik_version = self._kwik.read_attr('/', 'kwik_version') - if kwik_version != 2: - raise IOError("The kwik version is {v} != 2.".format(kwik_version)) - - @property - def _channel_groups_path(self): - return '/channel_groups/{0:d}'.format(self._channel_group) - - @property - def _spikes_path(self): - return '{0:s}/spikes'.format(self._channel_groups_path) - - @property - def _channels_path(self): - return '{0:s}/channels'.format(self._channel_groups_path) - - @property - def _clusters_path(self): - return '{0:s}/clusters'.format(self._channel_groups_path) - - def _cluster_path(self, cluster): - return '{0:s}/{1:d}'.format(self._clustering_path, cluster) - - @property - def _spike_clusters_path(self): - return '{0:s}/clusters/{1:s}'.format(self._spikes_path, - self._clustering) - - @property - def _clustering_path(self): - return '{0:s}/{1:s}'.format(self._clusters_path, self._clustering) - - # Loading and saving - # ------------------------------------------------------------------------- - - def _open_kwik_if_needed(self, mode=None): - if not self._kwik.is_open(): - self._kwik.open(mode=mode) - return True - else: - if mode is not None: - self._kwik.mode = mode - return False - - @property - def n_samples_waveforms(self): - return (self._metadata['extract_s_before'] + - self._metadata['extract_s_after']) - - def _create_waveform_loader(self): - """Create a waveform loader.""" - n_samples = (self._metadata['extract_s_before'], - self._metadata['extract_s_after']) - order = self._metadata['filter_butter_order'] - rate = self._metadata['sample_rate'] - low = self._metadata['filter_low'] - high = self._metadata['filter_high_factor'] * rate - b_filter = bandpass_filter(rate=rate, - low=low, - high=high, - order=order) - - if self._metadata.get('waveform_filter', True): - debug("Enable waveform filter.") - - def filter(x): - return apply_filter(x, b_filter) - - filter_margin = order * 3 - else: - debug("Disable waveform filter.") - filter = None - filter_margin = 0 - - dc_offset = self._metadata.get('waveform_dc_offset', None) - scale_factor = self._metadata.get('waveform_scale_factor', None) - self._waveform_loader = WaveformLoader(n_samples=n_samples, - filter=filter, - filter_margin=filter_margin, - dc_offset=dc_offset, - scale_factor=scale_factor, - ) - - def _update_waveform_loader(self): - if self._traces is not None: - self._waveform_loader.traces = self._traces - else: - self._waveform_loader.traces = np.zeros((0, self.n_channels), - dtype=np.float32) - - # Update the list of channels for the waveform loader. - self._waveform_loader.channels = self._channel_order - - def _create_cluster_metadata(self): - self._cluster_metadata = ClusterMetadata() - - @self._cluster_metadata.default - def group(cluster): - # Default group is unsorted. - return 3 - - def _load_meta(self): - """Load metadata from kwik file.""" - # Automatically load all metadata from spikedetekt group. - path = '/application_data/spikedetekt/' - params = {} - for attr in self._kwik.attrs(path): - params[attr] = self._kwik.read_attr(path, attr) - # Make sure all params are there. - default_params = {} - settings = _load_default_settings() - default_params.update(settings['traces']) - default_params.update(settings['spikedetekt']) - default_params.update(settings['klustakwik2']) - for name, default_value in default_params.items(): - if name not in params: - params[name] = default_value - self._metadata = params - - def _load_probe(self): - # Re-create the probe from the Kwik file. 
- channel_groups = {} - for group in self._channel_groups: - cg_p = '/channel_groups/{:d}'.format(group) - c_p = cg_p + '/channels' - channels = self._kwik.read_attr(cg_p, 'channel_order') - graph = self._kwik.read_attr(cg_p, 'adjacency_graph') - positions = { - channel: self._kwik.read_attr(c_p + '/' + str(channel), - 'position') - for channel in channels - } - channel_groups[group] = { - 'channels': channels, - 'graph': graph, - 'geometry': positions, - } - probe = {'channel_groups': channel_groups} - self._probe = MEA(probe=probe) - - def _load_recordings(self): - # Load recordings. - self._recordings = _list_recordings(self._kwik.h5py_file) - # This will be updated later if a KWD file is present. - self._recording_offsets = [0] * (len(self._recordings) + 1) - - def _load_channels(self): - self._channels = np.array(_list_channels(self._kwik.h5py_file, - self._channel_group)) - self._channel_order = self._probe.channels - assert set(self._channel_order) <= set(self._channels) - - def _load_channel_groups(self, channel_group=None): - self._channel_groups = _list_channel_groups(self._kwik.h5py_file) - if channel_group is None and self._channel_groups: - # Choose the default channel group if not specified. - channel_group = self._channel_groups[0] - # Load the channel group. - self._channel_group = channel_group - - def _load_features_masks(self): - - # Load features masks. - path = '{0:s}/features_masks'.format(self._channel_groups_path) - - nfpc = self._metadata['n_features_per_channel'] - nc = len(self.channel_order) - - if self._kwx is not None: - self._kwx = _open_h5_if_exists(self.kwik_path, 'kwx') - if path not in self._kwx: - debug("There are no features and masks in the `.kwx` file.") - # No need to keep the file open if it is empty. - self._kwx.close() - return - fm = self._kwx.read(path) - self._features_masks = fm - self._features = PartialArray(fm, 0) - - # This partial array simulates a (n_spikes, n_channels) array. - self._masks = PartialArray(fm, (slice(0, nfpc * nc, nfpc), 1)) - assert self._masks.shape == (self.n_spikes, nc) - - def _load_spikes(self): - # Load spike samples. - path = '{0:s}/time_samples'.format(self._spikes_path) - - # Concatenate the spike samples from consecutive recordings. - if path not in self._kwik: - debug("There are no spikes in the dataset.") - return - _spikes = self._kwik.read(path)[:] - self._spike_recordings = self._kwik.read( - '{0:s}/recording'.format(self._spikes_path))[:] - self._spike_samples = _concatenate_spikes(_spikes, - self._spike_recordings, - self._recording_offsets) - - def _load_spike_clusters(self): - self._spike_clusters = self._kwik.read(self._spike_clusters_path)[:] - - def _save_spike_clusters(self, spike_clusters): - assert spike_clusters.shape == self._spike_clusters.shape - assert spike_clusters.dtype == self._spike_clusters.dtype - self._spike_clusters = spike_clusters - sc = self._kwik.read(self._spike_clusters_path) - sc[:] = spike_clusters - - def _load_clusterings(self, clustering=None): - # Once the channel group is loaded, list the clusterings. - self._clusterings = _list_clusterings(self._kwik.h5py_file, - self.channel_group) - # Choose the first clustering (should always be 'main'). - if clustering is None and self.clusterings: - clustering = self.clusterings[0] - # Load the specified clustering. 
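To make the slicing in `_load_features_masks()` concrete: `features_masks` has shape `(n_spikes, n_channels * nfpc, 2)`, and the two partial views carve it up as follows (shapes only, assuming the layout written by the creator):

    features = fm[:, :, 0]          # (n_spikes, n_channels * nfpc)
    masks    = fm[:, 0::nfpc, 1]    # (n_spikes, n_channels), e.g. columns 0, 3, 6, ... when nfpc == 3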
- self._clustering = clustering - - def _load_cluster_groups(self): - clusters = self._kwik.groups(self._clustering_path) - clusters = [int(cluster) for cluster in clusters] - for cluster in clusters: - path = self._cluster_path(cluster) - group = self._kwik.read_attr(path, 'cluster_group') - self._cluster_metadata.set_group([cluster], group) - - def _save_cluster_groups(self, cluster_groups): - assert isinstance(cluster_groups, dict) - for cluster, group in cluster_groups.items(): - path = self._cluster_path(cluster) - self._kwik.write_attr(path, 'cluster_group', group) - self._cluster_metadata.set_group([cluster], group) - - def _load_clustering_metadata(self): - attrs = self._kwik.attrs(self._clustering_path) - metadata = {} - for attr in attrs: - try: - metadata[attr] = self._kwik.read_attr(self._clustering_path, - attr) - except OSError: - debug("Error when reading `{}:{}`.".format( - self._clustering_path, attr)) - self._clustering_metadata = metadata - - def _save_clustering_metadata(self, metadata): - if not metadata: - return - assert isinstance(metadata, dict) - for name, value in metadata.items(): - path = self._clustering_path - self._kwik.write_attr(path, name, value) - self._clustering_metadata.update(metadata) - - def _load_traces(self): - n_channels = self._metadata.get('n_channels', None) - dtype = self._metadata.get('dtype', None) - dtype = np.dtype(dtype) if dtype else None - traces, opened_files = _read_traces(self._kwik, - dtype=dtype, - n_channels=n_channels) - - if not traces: - return - - # Update the list of opened files for cleanup. - self._opened_files.extend(opened_files) - - # Set the recordings offsets (no delay between consecutive recordings). - i = 0 - self._recording_offsets = [] - for trace in traces: - self._recording_offsets.append(i) - i += trace.shape[0] - self._traces = _concatenate_virtual_arrays(traces) - - def has_kwx(self): - """Returns whether the `.kwx` file is present. - - If not, the features and masks won't be available. - - """ - return self._kwx is not None - - def open(self, kwik_path, channel_group=None, clustering=None): - """Open a Kwik dataset. - - The `.kwik` and `.kwx` must be in the same folder with the - same basename. - - The files containing the traces (`.raw.kwd` or `.dat` / `.bin`) are - determined according to the following logic: - - - Is there a path specified to a file which exists in - [KWIK]/recordings/[X]/raw? If so, open it. - - If this file does not exist, does a file exist with the same name - in the current directory? If so, open it. - - If such a file does not exist, or no filename is specified in - the [KWIK], then is there a `raw.kwd` with the experiment basename - in the current directory? If so, open it. - - If not, return with a warning. - - Notes - ----- - - The `.kwik` file is opened in read-only mode, and is automatically - closed when this function returns. It is temporarily reopened when - the channel group or clustering changes. - - The `.kwik` file is temporarily opened in append mode when saving. - - The `.kwx` and `.raw.kwd` or `.dat` / `.bin` files stay open in - read-only mode as long as `model.close()` is not called. This is - because there might be read accesses to `features_masks` (`.kwx`) - and waveforms (`.raw.kwd` or `.dat` / `.bin`) while the dataset is - opened. - - Parameters - ---------- - - kwik_path : str - Path to a `.kwik` file. - channel_group : int or None (default is None) - The channel group (shank) index to use. This can be changed - later after the file has been opened. 
By default, the first - channel group is used. - clustering : str or None (default is None) - The clustering to use. This can be changed later after the file - has been opened. By default, the `main` clustering is used. An - error is raised if the `main` clustering doesn't exist. - - """ - - if kwik_path is None: - raise ValueError("No kwik_path specified.") - - if not kwik_path.endswith('.kwik'): - raise ValueError("File does not end in .kwik") - - # Open the file. - kwik_path = op.realpath(kwik_path) - self.kwik_path = kwik_path - self.name = op.splitext(op.basename(kwik_path))[0] - - # Open the KWIK file. - self._kwik = _open_h5_if_exists(kwik_path, 'kwik') - if self._kwik is None: - raise IOError("File `{0}` doesn't exist.".format(kwik_path)) - if not self._kwik.is_open(): - raise IOError("File `{0}` failed to open.".format(kwik_path)) - self._check_kwik_version() - - # Open the KWX and KWD files. - self._kwx = _open_h5_if_exists(kwik_path, 'kwx') - if self._kwx is None: - warn("The `.kwx` file hasn't been found. " - "Features and masks won't be available.") - - # KwikCreator instance. - self.creator = KwikCreator(kwik_path=kwik_path) - - # Load the data. - self._load_meta() - - # This needs metadata. - self._create_waveform_loader() - - self._load_recordings() - - # This generates the recording offset. - self._load_traces() - - self._load_channel_groups(channel_group) - - # Load the probe. - self._load_probe() - - # This needs channel groups. - self._load_clusterings(clustering) - - # This needs the recording offsets. - # This loads channels, channel_order, spikes, probe. - self._channel_group_changed(self._channel_group) - - # This loads spike clusters and cluster groups. - self._clustering_changed(clustering) - - # This needs channels, channel_order, and waveform loader. - self._update_waveform_loader() - - # No need to keep the kwik file open. - self._kwik.close() - - def save(self, spike_clusters, cluster_groups, clustering_metadata=None): - """Save the spike clusters and cluster groups in the Kwik file.""" - - # REFACTOR: with() to open/close the file if needed - to_close = self._open_kwik_if_needed(mode='a') - - self._save_spike_clusters(spike_clusters) - self._save_cluster_groups(cluster_groups) - self._save_clustering_metadata(clustering_metadata) - - if to_close: - self._kwik.close() - - def describe(self): - """Display information about the dataset.""" - def _print(name, value): - print("{0: <24}{1}".format(name, value)) - _print("Kwik file", self.kwik_path) - _print("Recordings", self.n_recordings) - - # List of channel groups. - cg = ['{:d}'.format(g) + ('*' if g == self.channel_group else '') - for g in self.channel_groups] - _print("List of shanks", ', '.join(cg)) - - # List of clusterings. - cl = ['{:s}'.format(c) + ('*' if c == self.clustering else '') - for c in self.clusterings] - _print("Clusterings", ', '.join(cl)) - - _print("Channels", self.n_channels) - _print("Spikes", self.n_spikes) - _print("Clusters", self.n_clusters) - _print("Duration", "{:.0f}s".format(self.duration)) - - # Changing channel group and clustering - # ------------------------------------------------------------------------- - - def _channel_group_changed(self, value): - """Called when the channel group changes.""" - if value not in self.channel_groups: - raise ValueError("The channel group {0} is invalid.".format(value)) - self._channel_group = value - - # Load data. 
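Putting `open()` and the properties defined further down together, a typical read-only session might look like this (the path is hypothetical):

    model = KwikModel('experiment.kwik')
    model.describe()                     # shanks, clusterings, spike/cluster counts
    samples  = model.spike_samples       # (n_spikes,) uint64 sample times
    clusters = model.spike_clusters      # (n_spikes,) cluster id per spike
    model.clustering = 'original'        # switch clustering, if it exists in the file
    model.close()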
- _to_close = self._open_kwik_if_needed() - - if self._kwik.h5py_file: - clusterings = _list_clusterings(self._kwik.h5py_file, - self._channel_group) - else: - warn(".kwik filepath doesn't exist.") - clusterings = None - - if clusterings: - if 'main' in clusterings: - self._load_clusterings('main') - self.clustering = 'main' - else: - self._load_clusterings(clusterings[0]) - self.clustering = clusterings[0] - - self._probe.change_channel_group(value) - self._load_channels() - self._load_spikes() - self._load_features_masks() - - # Recalculate spikes_per_cluster manually - self._spikes_per_cluster = \ - _spikes_per_cluster(self.spike_ids, self._spike_clusters) - - if _to_close: - self._kwik.close() - - # Update the list of channels for the waveform loader. - self._waveform_loader.channels = self._channel_order - - def _clustering_changed(self, value): - """Called when the clustering changes.""" - if value is None: - return - if value not in self.clusterings: - raise ValueError("The clustering {0} is invalid.".format(value)) - self._clustering = value - - # Load data. - _to_close = self._open_kwik_if_needed() - self._create_cluster_metadata() - self._load_spike_clusters() - self._load_cluster_groups() - self._load_clustering_metadata() - if _to_close: - self._kwik.close() - - # Managing cluster groups - # ------------------------------------------------------------------------- - - def _write_cluster_group(self, group_id, name, write_color=True): - if group_id <= 3: - raise ValueError("Default groups cannot be changed.") - - _to_close = self._open_kwik_if_needed(mode='a') - - _create_cluster_group(self._kwik, group_id, name, - clustering=self._clustering, - channel_group=self._channel_group, - write_color=write_color, - ) - - if _to_close: - self._kwik.close() - - def add_cluster_group(self, group_id, name): - """Add a new cluster group.""" - self._write_cluster_group(group_id, name, write_color=True) - - def rename_cluster_group(self, group_id, name): - """Rename an existing cluster group.""" - self._write_cluster_group(group_id, name, write_color=False) - - def delete_cluster_group(self, group_id): - if group_id <= 3: - raise ValueError("Default groups cannot be deleted.") - - path = ('/channel_groups/{0:d}/' - 'cluster_groups/{1:s}/{2:d}').format(self._channel_group, - self._clustering, - group_id, - ) - - _to_close = self._open_kwik_if_needed(mode='a') - - self._kwik.delete(path) - - if _to_close: - self._kwik.close() - - # Managing clusterings - # ------------------------------------------------------------------------- - - def add_clustering(self, name, spike_clusters): - """Save a new clustering to the file.""" - if name in self._clusterings: - raise ValueError("The clustering '{0}' ".format(name) + - "already exists.") - assert len(spike_clusters) == self.n_spikes - - _to_close = self._open_kwik_if_needed(mode='a') - - _create_clustering(self._kwik, - name, - channel_group=self._channel_group, - spike_clusters=spike_clusters, - ) - - # Update the list of clusterings. 
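For instance, with the methods above one can register a custom group and save a new clustering (the name, id and array below are made up; ids 0-3 are reserved for the default groups):

    model.add_cluster_group(4, 'needs_review')
    model.add_clustering('manual', my_spike_clusters)    # (n_spikes,) integer array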
- self._load_clusterings(self._clustering) - - if _to_close: - self._kwik.close() - - def _move_clustering(self, old_name, new_name, copy=None): - if not copy and old_name == self._clustering: - raise ValueError("You cannot move the current clustering.") - if new_name in self._clusterings: - raise ValueError("The clustering '{0}' ".format(new_name) + - "already exists.") - - _to_close = self._open_kwik_if_needed(mode='a') - - if copy: - func = self._kwik.copy - else: - func = self._kwik.move - - # /channel_groups/x/spikes/clusters/ - p = self._spikes_path + '/clusters/' - func(p + old_name, p + new_name) - - # /channel_groups/x/clusters/ - p = self._clusters_path + '/' - func(p + old_name, p + new_name) - - # /channel_groups/x/cluster_groups/ - p = self._channel_groups_path + '/cluster_groups/' - func(p + old_name, p + new_name) - - # Update the list of clusterings. - self._load_clusterings(self._clustering) - - if _to_close: - self._kwik.close() - - def rename_clustering(self, old_name, new_name): - """Rename a clustering in the `.kwik` file.""" - self._move_clustering(old_name, new_name, copy=False) - - def copy_clustering(self, name, new_name): - """Copy a clustering in the `.kwik` file.""" - self._move_clustering(name, new_name, copy=True) - - def delete_clustering(self, name): - """Delete a clustering.""" - if name == self._clustering: - raise ValueError("You cannot delete the current clustering.") - if name not in self._clusterings: - raise ValueError(("The clustering {0} " - "doesn't exist.").format(name)) - - _to_close = self._open_kwik_if_needed(mode='a') - - # /channel_groups/x/spikes/clusters/ - parent = self._kwik.read(self._spikes_path + '/clusters/') - del parent[name] - - # /channel_groups/x/clusters/ - parent = self._kwik.read(self._clusters_path) - del parent[name] - - # /channel_groups/x/cluster_groups/ - parent = self._kwik.read(self._channel_groups_path + - '/cluster_groups/') - del parent[name] - - # Update the list of clusterings. - self._load_clusterings(self._clustering) - - if _to_close: - self._kwik.close() - - # Data - # ------------------------------------------------------------------------- - - @property - def duration(self): - """Duration of the experiment (in seconds).""" - if self._traces is None: - return 0. - return float(self.traces.shape[0]) / self.sample_rate - - @property - def channel_groups(self): - """List of channel groups found in the Kwik file.""" - return self._channel_groups - - @property - def n_features_per_channel(self): - """Number of features per channel (generally 3).""" - return self._metadata['n_features_per_channel'] - - @property - def channels(self): - """List of all channels in the current channel group. - - This list comes from the /channel_groups HDF5 group in the Kwik file. - - """ - # TODO: rename to channel_ids? - return self._channels - - @property - def channel_order(self): - """List of kept channels in the current channel group. - - If you want the channels used in the features, masks, and waveforms, - this is the property you want to use, and not `model.channels`. - - The channel order is the same than the one from the PRB file. - This order was used when generating the features and masks - in SpikeDetekt2. The same order is used in phy when loading the - waveforms from the traces file(s). 
- - """ - return self._channel_order - - @property - def n_channels(self): - """Number of all channels in the current channel group.""" - return len(self._channels) - - @property - def recordings(self): - """List of recording indices found in the Kwik file.""" - return self._recordings - - @property - def n_recordings(self): - """Number of recordings found in the Kwik file.""" - return len(self._recordings) - - @property - def clusterings(self): - """List of clusterings found in the Kwik file. - - The first one is always `main`. - - """ - return self._clusterings - - @property - def clustering(self): - """The currently-active clustering. - - Default is `main`. - - """ - return self._clustering - - @clustering.setter - def clustering(self, value): - """Change the currently-active clustering.""" - self._clustering_changed(value) - - @property - def clustering_metadata(self): - """A dictionary of key-value metadata specific to the current - clustering.""" - return self._clustering_metadata - - @property - def metadata(self): - """A dictionary holding metadata about the experiment. - - This information comes from the PRM file. It was used by - SpikeDetekt2 and KlustaKwik during automatic clustering. - - """ - return self._metadata - - @property - def probe(self): - """A `Probe` instance representing the probe used for the recording. - - This object contains information about the adjacency graph and - the channel positions. - - """ - return self._probe - - @property - def traces(self): - """Raw traces as found in the traces file(s). - - This object is memory-mapped to the HDF5 file, or `.dat` / `.bin` file, - or both. - - """ - return self._traces - - @property - def spike_samples(self): - """Spike samples from the current channel group. - - This is a NumPy array containing `uint64` values (number of samples - in unit of the sample rate). - - The spike times of all recordings are concatenated. There is no gap - between consecutive recordings, currently. - - """ - return self._spike_samples - - @property - def sample_rate(self): - """Sample rate of the recording. - - This value is found in the metadata coming from the PRM file. - - """ - return float(self._metadata['sample_rate']) - - @property - def spike_recordings(self): - """The recording index for each spike. - - This is a NumPy array of integers with `n_spikes` elements. - - """ - return self._spike_recordings - - @property - def n_spikes(self): - """Number of spikes in the current channel group.""" - return (len(self._spike_samples) - if self._spike_samples is not None else 0) - - @property - def features(self): - """Features from the current channel group. - - This is memory-mapped to the `.kwx` file. - - Note: in general, it is better to use the cluster store to access - the features and masks of some clusters. - - """ - return self._features - - @property - def masks(self): - """Masks from the current channel group. - - This is memory-mapped to the `.kwx` file. - - Note: in general, it is better to use the cluster store to access - the features and masks of some clusters. - - """ - return self._masks - - @property - def features_masks(self): - """Features-masks from the current channel group. - - This is memory-mapped to the `.kwx` file. - - Note: in general, it is better to use the cluster store to access - the features and masks of some clusters. - - """ - return self._features_masks - - @property - def waveforms(self): - """High-passed filtered waveforms from the current channel group. 
- - This is a virtual array mapped to the traces file(s). Filtering is - done on the fly. - - The shape is `(n_spikes, n_samples, n_channels)`. - - """ - return SpikeLoader(self._waveform_loader, self.spike_samples) - - @property - def spike_clusters(self): - """Spike clusters from the current channel group and clustering. - - Every element is the cluster identifier of a spike. - - The shape is `(n_spikes,)`. - - """ - return self._spike_clusters - - @property - def spikes_per_cluster(self): - """Spikes per cluster from the current channel group and clustering.""" - if self._spikes_per_cluster is None: - if self._spike_clusters is None: - self._spikes_per_cluster = {0: self.spike_ids} - else: - self._spikes_per_cluster = \ - _spikes_per_cluster(self.spike_ids, self._spike_clusters) - return self._spikes_per_cluster - - def update_spikes_per_cluster(self, spc): - self._spikes_per_cluster = spc - - @property - def cluster_metadata(self): - """Metadata about the clusters in the current channel group and - clustering. - - `cluster_metadata.group(cluster_id)` returns the group of a given - cluster. The default group is 3 (unsorted). - - """ - return self._cluster_metadata - - @property - def cluster_ids(self): - """List of cluster ids from the current channel group and clustering. - - This is a sorted list of unique cluster ids as found in the current - `spike_clusters` array. - - """ - return _unique(self._spike_clusters) - - @property - def spike_ids(self): - """List of spike ids.""" - return np.arange(self.n_spikes, dtype=np.int32) - - @property - def n_clusters(self): - """Number of clusters in the current channel group and clustering.""" - return len(self.cluster_ids) - - # Close - # ------------------------------------------------------------------------- - - def close(self): - """Close the `.kwik` and `.kwx` files if they are open, and cleanup - handles to all raw data files""" - - debug("Closing files") - if self._kwx is not None: - self._kwx.close() - for f in self._opened_files: - # upside-down if statement to avoid false positive lint error - if not (isinstance(f, np.ndarray)): - f.close() - else: - del f - self._kwik.close() diff --git a/phy/io/kwik/sparse_kk2.py b/phy/io/kwik/sparse_kk2.py deleted file mode 100644 index a28257b6f..000000000 --- a/phy/io/kwik/sparse_kk2.py +++ /dev/null @@ -1,75 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Convert data to sparse structures for KK2.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import numpy as np - -from ...utils.array import chunk_bounds -import six - - -#------------------------------------------------------------------------------ -# KlustaKwik2 functions -#------------------------------------------------------------------------------ - -def sparsify_features_masks(features, masks, chunk_size=10000): - from klustakwik2 import RawSparseData - - assert features.ndim == 2 - assert masks.ndim == 2 - assert features.shape == masks.shape - - n_spikes, num_features = features.shape - - # Stage 1: read min/max of fet values for normalization - # and count total number of unmasked features. - vmin = np.ones(num_features) * np.inf - vmax = np.ones(num_features) * (-np.inf) - total_unmasked_features = 0 - for _, _, i, j in chunk_bounds(n_spikes, chunk_size): - f, m = features[i:j], masks[i:j] - inds = m > 0 - # Replace the masked values by NaN. 
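In outline, the function above makes two passes: the first (here) records per-feature minima and maxima and counts unmasked values; the second normalises each spike's features to the unit range and keeps only the unmasked entries, while masked entries accumulate into the noise statistics. Schematically, for spike `i`:

    fetvals = (features[i] - vmin) / (vmax - vmin)
    kept    = (masks[i] > 0).nonzero()[0]       # indices of unmasked features
    # all_features[offsets[i]:offsets[i + 1]] receives fetvals[kept];
    # masked entries update noise_mean / noise_variance instead.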
- vmin = np.minimum(np.min(f, axis=0), vmin) - vmax = np.maximum(np.max(f, axis=0), vmax) - total_unmasked_features += inds.sum() - # Stage 2: read data line by line, normalising - vdiff = vmax - vmin - vdiff[vdiff == 0] = 1 - fetsum = np.zeros(num_features) - fet2sum = np.zeros(num_features) - nsum = np.zeros(num_features) - all_features = np.zeros(total_unmasked_features) - all_fmasks = np.zeros(total_unmasked_features) - all_unmasked = np.zeros(total_unmasked_features, dtype=int) - offsets = np.zeros(n_spikes + 1, dtype=int) - curoff = 0 - for i in six.moves.range(n_spikes): - fetvals, fmaskvals = (features[i] - vmin) / vdiff, masks[i] - inds = (fmaskvals > 0).nonzero()[0] - masked_inds = (fmaskvals == 0).nonzero()[0] - all_features[curoff:curoff + len(inds)] = fetvals[inds] - all_fmasks[curoff:curoff + len(inds)] = fmaskvals[inds] - all_unmasked[curoff:curoff + len(inds)] = inds - offsets[i] = curoff - curoff += len(inds) - fetsum[masked_inds] += fetvals[masked_inds] - fet2sum[masked_inds] += fetvals[masked_inds] ** 2 - nsum[masked_inds] += 1 - offsets[-1] = curoff - - nsum[nsum == 0] = 1 - noise_mean = fetsum / nsum - noise_variance = fet2sum / nsum - noise_mean ** 2 - - return RawSparseData(noise_mean, - noise_variance, - all_features, - all_fmasks, - all_unmasked, - offsets, - ) diff --git a/phy/io/kwik/store_items.py b/phy/io/kwik/store_items.py deleted file mode 100644 index f8d8f8f5f..000000000 --- a/phy/io/kwik/store_items.py +++ /dev/null @@ -1,711 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import print_function - -"""Store items for Kwik.""" - - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import os -import os.path as op - -import numpy as np - -from ...utils.selector import Selector -from ...utils.array import (_index_of, - _spikes_per_cluster, - _concatenate_per_cluster_arrays, - ) -from ..store import ClusterStore, FixedSizeItem, VariableSizeItem - - -#------------------------------------------------------------------------------ -# Utility functions -#------------------------------------------------------------------------------ - -def _default_array(shape, value=0, n_spikes=0, dtype=np.float32): - shape = (n_spikes,) + shape[1:] - out = np.empty(shape, dtype=dtype) - out.fill(value) - return out - - -def _atleast_nd(arr, ndim): - if arr.ndim == ndim: - return arr - assert arr.ndim < ndim - if ndim - arr.ndim == 1: - return arr[None, ...] - elif ndim - arr.ndim == 2: - return arr[None, None, ...] - - -def _mean(arr, shape): - if arr is not None: - # Make sure (possibly lazy) memmapped arrays are fully loaded. - arr = arr[...] - assert isinstance(arr, np.ndarray) - if arr.shape[0]: - return arr.mean(axis=0) - return np.zeros(shape, dtype=np.float32) - - -#------------------------------------------------------------------------------ -# Store items -#------------------------------------------------------------------------------ - -class FeatureMasks(VariableSizeItem): - """Store all features and masks of all clusters.""" - name = 'features and masks' - fields = ['features', 'masks'] - - def __init__(self, *args, **kwargs): - # Size of the chunk used when reading features and masks from the HDF5 - # .kwx file. 
- self.chunk_size = kwargs.pop('chunk_size') - - super(FeatureMasks, self).__init__(*args, **kwargs) - - self.n_features = self.model.n_features_per_channel - self.n_channels = len(self.model.channel_order) - self.n_spikes = self.model.n_spikes - self.n_chunks = self.n_spikes // self.chunk_size + 1 - self._shapes['masks'] = (-1, self.n_channels) - self._shapes['features'] = (-1, self.n_channels, self.n_features) - - def _store(self, - cluster, - chunk_spikes, - chunk_spikes_per_cluster, - chunk_features_masks, - ): - - nc = self.n_channels - nf = self.n_features - - # Number of spikes in the cluster and in the current - # chunk. - ns = len(chunk_spikes_per_cluster[cluster]) - - # Find the indices of the spikes in that cluster - # relative to the chunk. - idx = _index_of(chunk_spikes_per_cluster[cluster], chunk_spikes) - - # Extract features and masks for that cluster, in the - # current chunk. - tmp = chunk_features_masks[idx, :] - - # NOTE: channel order has already been taken into account - # by SpikeDetekt2 when saving the features and masks. - # All we need to know here is the number of channels - # in channel_order, there is no need to reorder. - - # Features. - f = tmp[:, :nc * nf, 0] - assert f.shape == (ns, nc * nf) - f = f.ravel().astype(np.float32) - - # Masks. - m = tmp[:, :nc * nf, 1][:, ::nf] - assert m.shape == (ns, nc) - m = m.ravel().astype(np.float32) - - # Save the data to disk. - self.disk_store.store(cluster, - features=f, - masks=m, - append=True, - ) - - def mean_masks(self, cluster): - masks = self.cluster_store.masks(cluster) - mean_masks = _mean(masks, (self.n_channels,)) - return mean_masks - - def mean_features(self, cluster): - features = self.cluster_store.features(cluster) - mean_features = _mean(features, (self.n_channels,)) - return mean_features - - def _store_means(self, cluster): - if not self.disk_store: - return - self.disk_store.store(cluster, - mean_masks=self.mean_masks(cluster), - mean_features=self.mean_features(cluster), - ) - - def is_consistent(self, cluster, spikes): - """Return whether the filesizes of the two cluster store files - (`.features` and `.masks`) are correct.""" - cluster_size = len(spikes) - expected_file_sizes = [('masks', (cluster_size * - self.n_channels * - 4)), - ('features', (cluster_size * - self.n_channels * - self.n_features * - 4))] - for name, expected_file_size in expected_file_sizes: - path = self.disk_store._cluster_path(cluster, name) - if not op.exists(path): - return False - actual_file_size = os.stat(path).st_size - if expected_file_size != actual_file_size: - return False - return True - - def store_all(self, mode=None): - """Store the features, masks, and their means of the clusters that - need it. - - Parameters - ---------- - - mode : str or None - How to choose whether cluster files need to be re-generated. - Can be one of the following options: - - * `None` or `default`: only regenerate the missing or inconsistent - clusters - * `force`: fully regenerate all clusters - * `read-only`: just load the existing files, do not write anything - - """ - - # No need to regenerate the cluster store if it exists and is valid. - clusters_to_generate = self.to_generate(mode=mode) - need_generate = len(clusters_to_generate) > 0 - - if need_generate: - - self._pr.value_max = self.n_chunks + 1 - - fm = self.model.features_masks - assert fm.shape[0] == self.n_spikes - - for i in range(self.n_chunks): - a, b = i * self.chunk_size, (i + 1) * self.chunk_size - - # Load a chunk from HDF5. 
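As a concrete reading of `is_consistent()` above: the per-cluster store files hold flat float32 arrays (4 bytes per value), so the expected file sizes are:

    expected_masks_bytes    = n_spikes_in_cluster * n_channels * 4
    expected_features_bytes = n_spikes_in_cluster * n_channels * n_features * 4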
- chunk_features_masks = fm[a:b] - assert isinstance(chunk_features_masks, np.ndarray) - if chunk_features_masks.shape[0] == 0: - break - - chunk_spike_clusters = self.model.spike_clusters[a:b] - chunk_spikes = np.arange(a, b) - - # Split the spikes. - chunk_spc = _spikes_per_cluster(chunk_spikes, - chunk_spike_clusters) - - # Go through the clusters appearing in the chunk and that - # need to be re-generated. - clusters = (set(chunk_spc.keys()). - intersection(set(clusters_to_generate))) - for cluster in sorted(clusters): - self._store(cluster, - chunk_spikes, - chunk_spc, - chunk_features_masks, - ) - - # Update the progress reporter. - self._pr.value += 1 - - # Store mean features and masks on disk. - self._pr.value = 0 - self._pr.value_max = len(clusters_to_generate) - for cluster in clusters_to_generate: - self._store_means(cluster) - self._pr.value += 1 - - self._pr.set_complete() - - def load(self, cluster, name): - """Load features or masks for a cluster. - - This uses the cluster store if possible, otherwise it falls back - to the model (much slower). - - """ - assert name in ('features', 'masks') - dtype = np.float32 - shape = self._shapes[name] - if self.disk_store: - data = self.disk_store.load(cluster, name, dtype, shape) - if data is not None: - return data - # Fallback to load_spikes if the data could not be obtained from - # the store. - spikes = self.spikes_per_cluster[cluster] - return self.load_spikes(spikes, name) - - def load_spikes(self, spikes, name): - """Load features or masks for an array of spikes.""" - assert name in ('features', 'masks') - shape = self._shapes[name] - data = getattr(self.model, name) - if data is not None and (isinstance(spikes, slice) or len(spikes)): - out = data[spikes] - # Default masks and features. - else: - out = _default_array(shape, - value=0. if name == 'features' else 1., - n_spikes=len(spikes), - ) - return out.reshape(shape) - - def on_merge(self, up): - """Create the cluster store files of the merged cluster - from the files of the old clusters. - - This is basically a concatenation of arrays, but the spike order - needs to be taken into account. - - """ - if not self.disk_store: - return - clusters = up.deleted - spc = up.old_spikes_per_cluster - # We load all masks and features of the merged clusters. - for name, shape in [('features', - (-1, self.n_channels, self.n_features)), - ('masks', - (-1, self.n_channels)), - ]: - arrays = {cluster: self.disk_store.load(cluster, - name, - dtype=np.float32, - shape=shape) - for cluster in clusters} - # Then, we concatenate them using the right insertion order - # as defined by the spikes. - - # OPTIM: this could be made a bit faster by passing - # both arrays at once. - concat = _concatenate_per_cluster_arrays(spc, arrays) - - # Finally, we store the result into the new cluster. - self.disk_store.store(up.added[0], **{name: concat}) - - def on_assign(self, up): - """Create the cluster store files of the new clusters - from the files of the old clusters. - - The files of all old clusters are loaded, re-split and concatenated - to form the new cluster files. - - """ - if not self.disk_store: - return - for name, shape in [('features', - (-1, self.n_channels, self.n_features)), - ('masks', - (-1, self.n_channels)), - ]: - # Load all data from the old clusters. - old_arrays = {cluster: self.disk_store.load(cluster, - name, - dtype=np.float32, - shape=shape) - for cluster in up.deleted} - # Create the new arrays. 
- for new in up.added: - # Find the old clusters which are parents of the current - # new cluster. - old_clusters = [o - for (o, n) in up.descendants - if n == new] - # Spikes per old cluster, used to create - # the concatenated array. - spc = {} - old_arrays_sub = {} - # Find the relative spike indices of every old cluster - # for the current new cluster. - for old in old_clusters: - # Find the spike indices in the old and new cluster. - old_spikes = up.old_spikes_per_cluster[old] - new_spikes = up.new_spikes_per_cluster[new] - old_in_new = np.in1d(old_spikes, new_spikes) - old_spikes_subset = old_spikes[old_in_new] - spc[old] = old_spikes_subset - # Extract the data from the old cluster to - # be moved to the new cluster. - old_spikes_rel = _index_of(old_spikes_subset, - old_spikes) - old_arrays_sub[old] = old_arrays[old][old_spikes_rel] - # Construct the array of the new cluster. - concat = _concatenate_per_cluster_arrays(spc, - old_arrays_sub) - # Save it in the cluster store. - self.disk_store.store(new, **{name: concat}) - - def on_cluster(self, up=None): - super(FeatureMasks, self).on_cluster(up) - # No need to change anything in the store if this is an undo or - # a redo. - if up is None or up.history is not None: - return - # Store the means of the new clusters. - for cluster in up.added: - self._store_means(cluster) - - -class Waveforms(VariableSizeItem): - """A cluster store item that manages the waveforms of all clusters.""" - name = 'waveforms' - fields = ['waveforms'] - - def __init__(self, *args, **kwargs): - self.n_spikes_max = kwargs.pop('n_spikes_max') - self.excerpt_size = kwargs.pop('excerpt_size') - # Size of the chunk used when reading waveforms from the raw data. - self.chunk_size = kwargs.pop('chunk_size') - - super(Waveforms, self).__init__(*args, **kwargs) - - self.n_channels = len(self.model.channel_order) - self.n_samples = self.model.n_samples_waveforms - self.n_spikes = self.model.n_spikes - - self._shapes['waveforms'] = (-1, self.n_samples, self.n_channels) - self._selector = Selector(self.model.spike_clusters, - n_spikes_max=self.n_spikes_max, - excerpt_size=self.excerpt_size, - ) - - # Get or create the subset spikes per cluster dictionary. - spc = (self.disk_store.load_file('waveforms_spikes') - if self.disk_store else None) - if spc is None: - spc = self._selector.subset_spikes_clusters(self.model.cluster_ids) - if self.disk_store: - self.disk_store.save_file('waveforms_spikes', spc) - self._spikes_per_cluster = spc - - def _subset_spikes_cluster(self, cluster, force=False): - if force or cluster not in self._spikes_per_cluster: - spikes = self._selector.subset_spikes_clusters([cluster])[cluster] - # Persist the new _spikes_per_cluster array on disk. - self._spikes_per_cluster[cluster] = spikes - if self.disk_store: - self.disk_store.save_file('waveforms_spikes', - self._spikes_per_cluster) - return self._spikes_per_cluster[cluster] - - @property - def spikes_per_cluster(self): - """Spikes per cluster.""" - # WARNING: this is read-only, because this item is responsible - # for the spike subselection. 
- return self._spikes_per_cluster - - def _store_mean(self, cluster): - if not self.disk_store: - return - self.disk_store.store(cluster, - mean_waveforms=self.mean_waveforms(cluster), - ) - - def waveforms_and_mean(self, cluster): - spikes = self._subset_spikes_cluster(cluster, force=True) - waveforms = self.model.waveforms[spikes].astype(np.float32) - mean_waveforms = _mean(waveforms, (self.n_samples, self.n_channels)) - return waveforms, mean_waveforms - - def mean_waveforms(self, cluster): - return self.waveforms_and_mean(cluster)[1] - - def store(self, cluster): - """Store waveforms and mean waveforms.""" - if not self.disk_store: - return - # NOTE: make sure to erase old spikes for that cluster. - # Typical case merge, undo, different merge. - waveforms, mean_waveforms = self.waveforms_and_mean(cluster) - self.disk_store.store(cluster, - waveforms=waveforms, - mean_waveforms=mean_waveforms, - ) - - def store_all(self, mode=None): - if not self.disk_store: - return - - clusters_to_generate = self.to_generate(mode=mode) - need_generate = len(clusters_to_generate) > 0 - - if need_generate: - spc = {cluster: self._subset_spikes_cluster(cluster, force=True) - for cluster in clusters_to_generate} - - # All spikes to fetch and save in the store. - spike_ids = _concatenate_per_cluster_arrays(spc, spc) - - # Load waveforms chunk by chunk for I/O contiguity. - n_chunks = len(spike_ids) // self.chunk_size + 1 - self._pr.value_max = n_chunks + 1 - - for i in range(n_chunks): - a, b = i * self.chunk_size, (i + 1) * self.chunk_size - spk = spike_ids[a:b] - - # Load a chunk of waveforms. - chunk_waveforms = self.model.waveforms[spk] - assert isinstance(chunk_waveforms, np.ndarray) - if chunk_waveforms.shape[0] == 0: - break - - chunk_spike_clusters = self.model.spike_clusters[spk] - - # Split the spikes. - chunk_spc = _spikes_per_cluster(spk, chunk_spike_clusters) - - # Go through the clusters appearing in the chunk and that - # need to be re-generated. - clusters = (set(chunk_spc.keys()). - intersection(set(clusters_to_generate))) - for cluster in sorted(clusters): - i = _index_of(chunk_spc[cluster], spk) - w = chunk_waveforms[i].astype(np.float32) - self.disk_store.store(cluster, - waveforms=w, - append=True - ) - - self._pr.increment() - - # Store mean waveforms on disk. - self._pr.value = 0 - self._pr.value_max = len(clusters_to_generate) - for cluster in clusters_to_generate: - self._store_mean(cluster) - self._pr.value += 1 - - def is_consistent(self, cluster, spikes): - """Return whether the waveforms and spikes match.""" - path_w = self.disk_store._cluster_path(cluster, 'waveforms') - if not op.exists(path_w): - return False - file_size_w = os.stat(path_w).st_size - n_spikes_w = file_size_w // (self.n_channels * self.n_samples * 4) - if n_spikes_w != len(self.spikes_per_cluster[cluster]): - return False - return True - - def load(self, cluster, name='waveforms'): - """Load waveforms for a cluster. - - This uses the cluster store if possible, otherwise it falls back - to the model (much slower). - - """ - assert name == 'waveforms' - dtype = np.float32 - if self.disk_store: - data = self.disk_store.load(cluster, - name, - dtype, - self._shapes[name], - ) - if data is not None: - return data - # Fallback to load_spikes if the data could not be obtained from - # the store. 
- spikes = self._subset_spikes_cluster(cluster) - return self.load_spikes(spikes, name) - - def load_spikes(self, spikes, name): - """Load waveforms for an array of spikes.""" - assert name == 'waveforms' - data = getattr(self.model, name) - shape = self._shapes[name] - if data is not None and (isinstance(spikes, slice) or len(spikes)): - return data[spikes] - # Default waveforms. - return _default_array(shape, value=0., n_spikes=len(spikes)) - - -class ClusterStatistics(FixedSizeItem): - """Manage cluster statistics.""" - name = 'statistics' - - def __init__(self, *args, **kwargs): - super(ClusterStatistics, self).__init__(*args, **kwargs) - self._funcs = {} - self.n_channels = len(self.model.channel_order) - self.n_samples_waveforms = self.model.n_samples_waveforms - self.n_features = self.model.n_features_per_channel - self._shapes = { - 'mean_masks': (-1, self.n_channels), - 'mean_features': (-1, self.n_channels, self.n_features), - 'mean_waveforms': (-1, self.n_samples_waveforms, self.n_channels), - 'mean_probe_position': (-1, 2), - 'main_channels': (-1, self.n_channels), - 'n_unmasked_channels': (-1,), - 'n_spikes': (-1,), - } - self.fields = list(self._shapes.keys()) - - def add(self, name, func, shape): - """Add a new statistics.""" - self.fields.append(name) - self._funcs[name] = func - self._shapes[name] = shape - - def remove(self, name): - """Remove a statistics.""" - self.fields.remove(name) - del self.funcs[name] - - def _load_mean(self, cluster, name): - if self.disk_store: - # Load from the disk store if possible. - return self.disk_store.load(cluster, name, np.float32, - self._shapes[name][1:]) - else: - # Otherwise compute the mean directly from the model. - item_name = ('waveforms' - if name == 'mean_waveforms' - else 'features and masks') - item = self.cluster_store.items[item_name] - return getattr(item, name)(cluster) - - def mean_masks(self, cluster): - return self._load_mean(cluster, 'mean_masks') - - def mean_features(self, cluster): - return self._load_mean(cluster, 'mean_features') - - def mean_waveforms(self, cluster): - return self._load_mean(cluster, 'mean_waveforms') - - def unmasked_channels(self, cluster): - mean_masks = self.load(cluster, 'mean_masks') - return np.nonzero(mean_masks > .1)[0] - - def n_unmasked_channels(self, cluster): - unmasked_channels = self.load(cluster, 'unmasked_channels') - return len(unmasked_channels) - - def mean_probe_position(self, cluster): - mean_masks = self.load(cluster, 'mean_masks') - if mean_masks is not None and mean_masks.shape[0]: - mean_cluster_position = (np.sum(self.model.probe.positions * - mean_masks[:, np.newaxis], axis=0) / - max(1, np.sum(mean_masks))) - else: - mean_cluster_position = np.zeros((2,), dtype=np.float32) - - return mean_cluster_position - - def n_spikes(self, cluster): - return len(self._spikes_per_cluster[cluster]) - - def main_channels(self, cluster): - mean_masks = self.load(cluster, 'mean_masks') - unmasked_channels = self.load(cluster, 'unmasked_channels') - # Weighted mean of the channels, weighted by the mean masks. - main_channels = np.argsort(mean_masks)[::-1] - main_channels = np.array([c for c in main_channels - if c in unmasked_channels]) - return main_channels - - def store_default(self, cluster): - """Compute the built-in statistics for one cluster. - - The mean masks, features, and waveforms are loaded from disk. 
- - """ - self.memory_store.store(cluster, - mean_masks=self.mean_masks(cluster), - mean_features=self.mean_features(cluster), - mean_waveforms=self.mean_waveforms(cluster), - ) - - n_spikes = self.n_spikes(cluster) - - # Note: some of the default statistics below rely on other statistics - # computed previously and stored in the memory store. - unmasked_channels = self.unmasked_channels(cluster) - self.memory_store.store(cluster, - unmasked_channels=unmasked_channels) - - n_unmasked_channels = self.n_unmasked_channels(cluster) - self.memory_store.store(cluster, - n_unmasked_channels=n_unmasked_channels) - - mean_probe_position = self.mean_probe_position(cluster) - main_channels = self.main_channels(cluster) - n_unmasked_channels = self.n_unmasked_channels(cluster) - self.memory_store.store(cluster, - mean_probe_position=mean_probe_position, - main_channels=main_channels, - n_unmasked_channels=n_unmasked_channels, - n_spikes=n_spikes, - ) - - def store(self, cluster, name=None): - """Compute all statistics for one cluster.""" - if name is None: - self.store_default(cluster) - for func in self._funcs.values(): - func(cluster) - else: - assert name in self._funcs - self._funcs[name](cluster) - - def load(self, cluster, name): - """Return a cluster statistic.""" - if cluster in self.memory_store: - return self.memory_store.load(cluster, name) - else: - # If the item hadn't been stored, compute it here by calling - # the corresponding method. - if hasattr(self, name): - return getattr(self, name)(cluster) - # Custom statistic. - else: - return self._funcs[name](cluster) - - def is_consistent(self, cluster, spikes): - """Return whether a cluster is consistent.""" - return cluster in self.memory_store - - -#------------------------------------------------------------------------------ -# Store creation -#------------------------------------------------------------------------------ - -def create_store(model, - spikes_per_cluster=None, - path=None, - features_masks_chunk_size=100000, - waveforms_n_spikes_max=None, - waveforms_excerpt_size=None, - waveforms_chunk_size=1000, - ): - """Create a cluster store for a model.""" - assert spikes_per_cluster is not None - cluster_store = ClusterStore(model=model, - spikes_per_cluster=spikes_per_cluster, - path=path, - ) - - # Create the FeatureMasks store item. - # chunk_size is the number of spikes to load at once from - # the features_masks array. 
- cluster_store.register_item(FeatureMasks, - chunk_size=features_masks_chunk_size, - ) - cluster_store.register_item(Waveforms, - n_spikes_max=waveforms_n_spikes_max, - excerpt_size=waveforms_excerpt_size, - chunk_size=waveforms_chunk_size, - ) - cluster_store.register_item(ClusterStatistics) - return cluster_store diff --git a/phy/io/kwik/tests/__init__.py b/phy/io/kwik/tests/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/phy/io/kwik/tests/test_creator.py b/phy/io/kwik/tests/test_creator.py deleted file mode 100644 index 4ca7b7683..000000000 --- a/phy/io/kwik/tests/test_creator.py +++ /dev/null @@ -1,217 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Tests of Kwik file creator.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import os.path as op - -import numpy as np -from numpy.testing import assert_array_equal as ae -from numpy.testing import assert_allclose as ac -from pytest import raises - -from ...h5 import open_h5 -from ..creator import KwikCreator, _write_by_chunk, create_kwik -from ..mock import (artificial_spike_samples, - artificial_features, - artificial_masks, - ) - - -#------------------------------------------------------------------------------ -# Tests -#------------------------------------------------------------------------------ - -def test_write_by_chunk(tempdir): - n = 5 - arrs = [np.random.rand(i + 1, 3).astype(np.float32) for i in range(n)] - n_tot = sum(_.shape[0] for _ in arrs) - - path = op.join(tempdir, 'test.h5') - with open_h5(path, 'w') as f: - ds = f.write('/test', shape=(n_tot, 3), dtype=np.float32) - _write_by_chunk(ds, arrs) - with open_h5(path, 'r') as f: - ds = f.read('/test')[...] - offset = 0 - for i, arr in enumerate(arrs): - size = arr.shape[0] - assert size == (i + 1) - ac(ds[offset:offset + size, ...], arr) - offset += size - - -def test_creator_simple(tempdir): - basename = op.join(tempdir, 'my_file') - - creator = KwikCreator(basename) - - # Test create empty files. - creator.create_empty() - assert op.exists(basename + '.kwik') - assert op.exists(basename + '.kwx') - - # Test metadata. - creator.set_metadata('/application_data/spikedetekt', - a=1, b=2., c=[0, 1]) - - with open_h5(creator.kwik_path, 'r') as f: - assert f.read_attr('/application_data/spikedetekt', 'a') == 1 - assert f.read_attr('/application_data/spikedetekt', 'b') == 2. - ae(f.read_attr('/application_data/spikedetekt', 'c'), [0, 1]) - - # Test add spikes in one block. - n_spikes = 100 - n_channels = 8 - n_features = 3 - - spike_samples = artificial_spike_samples(n_spikes) - features = artificial_features(n_spikes, n_channels, n_features) - masks = artificial_masks(n_spikes, n_channels) - - creator.add_spikes(group=0, - spike_samples=spike_samples, - features=features.astype(np.float32), - masks=masks.astype(np.float32), - n_channels=n_channels, - n_features=n_features, - ) - - # Test the spike samples. - with open_h5(creator.kwik_path, 'r') as f: - s = f.read('/channel_groups/0/spikes/time_samples')[...] - assert s.dtype == np.uint64 - ac(s, spike_samples) - - # Test the features and masks. - with open_h5(creator.kwx_path, 'r') as f: - fm = f.read('/channel_groups/0/features_masks')[...] - assert fm.dtype == np.float32 - ac(fm[:, :, 0], features.reshape((-1, n_channels * n_features))) - ac(fm[:, ::n_features, 1], masks) - - # Spikes can only been added once. 
- with raises(RuntimeError): - creator.add_spikes(group=0, - spike_samples=spike_samples, - n_channels=n_channels, - n_features=n_features) - - -def test_creator_chunks(tempdir): - basename = op.join(tempdir, 'my_file') - - creator = KwikCreator(basename) - creator.create_empty() - - # Test add spikes in one block. - n_spikes = 100 - n_channels = 8 - n_features = 3 - - spike_samples = artificial_spike_samples(n_spikes) - features = artificial_features(n_spikes, n_channels, - n_features).astype(np.float32) - masks = artificial_masks(n_spikes, n_channels).astype(np.float32) - - def _split(arr): - n = n_spikes // 10 - return [arr[k:k + n, ...] for k in range(0, n_spikes, n)] - - creator.add_spikes(group=0, - spike_samples=spike_samples, - features=_split(features), - masks=_split(masks), - n_channels=n_channels, - n_features=n_features, - ) - - # Test the spike samples. - with open_h5(creator.kwik_path, 'r') as f: - s = f.read('/channel_groups/0/spikes/time_samples')[...] - assert s.dtype == np.uint64 - ac(s, spike_samples) - - # Test the features and masks. - with open_h5(creator.kwx_path, 'r') as f: - fm = f.read('/channel_groups/0/features_masks')[...] - assert fm.dtype == np.float32 - ac(fm[:, :, 0], features.reshape((-1, n_channels * n_features))) - ac(fm[:, ::n_features, 1], masks) - - -def test_creator_add_no_spikes(tempdir): - basename = op.join(tempdir, 'my_file') - - creator = KwikCreator(basename) - creator.create_empty() - - creator.add_spikes(group=0, n_channels=4, n_features=2) - - -def test_creator_metadata(tempdir): - basename = op.join(tempdir, 'my_file') - - creator = KwikCreator(basename) - - # Add recording. - creator.add_recording(0, start_sample=20000, sample_rate=10000) - - with open_h5(creator.kwik_path, 'r') as f: - assert f.read_attr('/recordings/0', 'start_sample') == 20000 - assert f.read_attr('/recordings/0', 'start_time') == 2. - assert f.read_attr('/recordings/0', 'sample_rate') == 10000 - - # Add probe. - channels = [0, 3, 1] - graph = [[0, 3], [1, 0]] - probe = {'channel_groups': { - 0: {'channels': channels, - 'graph': graph, - 'geometry': {0: (10, 10)}, - }}} - creator.set_probe(probe) - - with open_h5(creator.kwik_path, 'r') as f: - ae(f.read_attr('/channel_groups/0', 'channel_order'), channels) - ae(f.read_attr('/channel_groups/0', 'adjacency_graph'), graph) - for channel in channels: - path = '/channel_groups/0/channels/{:d}'.format(channel) - position = (10, 10) if channel == 0 else (0, channel) - ae(f.read_attr(path, 'position'), position) - - # Add clustering. 
- creator.add_clustering(group=0, - name='main', - spike_clusters=np.arange(10), - cluster_groups={3: 1}, - ) - - with open_h5(creator.kwik_path, 'r') as f: - sc = f.read('/channel_groups/0/spikes/clusters/main') - ae(sc, np.arange(10)) - - for cluster in range(10): - path = '/channel_groups/0/clusters/main/{:d}'.format(cluster) - cg = f.read_attr(path, 'cluster_group') - assert cg == 3 if cluster != 3 else 1 - - -def test_create_kwik(tempdir): - - channels = [0, 3, 1] - graph = [[0, 3], [1, 0]] - probe = {'channel_groups': { - 0: {'channels': channels, - 'graph': graph, - 'geometry': {0: (10, 10)}, - }}} - - kwik_path = op.join(tempdir, 'test.kwik') - create_kwik(kwik_path=kwik_path, - probe=probe, - sample_rate=20000, - ) diff --git a/phy/io/kwik/tests/test_mock.py b/phy/io/kwik/tests/test_mock.py deleted file mode 100644 index 49563e559..000000000 --- a/phy/io/kwik/tests/test_mock.py +++ /dev/null @@ -1,34 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Tests of mock Kwik file creation.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -from ...h5 import open_h5 -from ..mock import create_mock_kwik - - -#------------------------------------------------------------------------------ -# Tests -#------------------------------------------------------------------------------ - -def test_create_kwik(tempdir): - - n_clusters = 10 - n_spikes = 50 - n_channels = 28 - n_fets = 2 - n_samples_traces = 3000 - - # Create the test HDF5 file in the temporary directory. - filename = create_mock_kwik(tempdir, - n_clusters=n_clusters, - n_spikes=n_spikes, - n_channels=n_channels, - n_features_per_channel=n_fets, - n_samples_traces=n_samples_traces) - - with open_h5(filename) as f: - assert f diff --git a/phy/io/kwik/tests/test_model.py b/phy/io/kwik/tests/test_model.py deleted file mode 100644 index b51ab3933..000000000 --- a/phy/io/kwik/tests/test_model.py +++ /dev/null @@ -1,384 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Tests of Kwik file opening routines.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import os.path as op - -import numpy as np -from numpy.testing import assert_array_equal as ae -from pytest import raises - -from ....electrode.mea import MEA, staggered_positions -from ....utils.logging import StringLogger, register, unregister -from ..model import (KwikModel, - _list_channel_groups, - _list_channels, - _list_recordings, - _list_clusterings, - _concatenate_spikes, - ) -from ..mock import create_mock_kwik -from ..creator import create_kwik - - -#------------------------------------------------------------------------------ -# Tests -#------------------------------------------------------------------------------ - -_N_CLUSTERS = 10 -_N_SPIKES = 100 -_N_CHANNELS = 28 -_N_FETS = 2 -_N_SAMPLES_TRACES = 10000 - - -def test_kwik_utility(tempdir): - - channels = list(range(_N_CHANNELS)) - - # Create the test HDF5 file in the temporary directory. 
- filename = create_mock_kwik(tempdir, - n_clusters=_N_CLUSTERS, - n_spikes=_N_SPIKES, - n_channels=_N_CHANNELS, - n_features_per_channel=_N_FETS, - n_samples_traces=_N_SAMPLES_TRACES) - model = KwikModel(filename) - - model._kwik.open() - assert _list_channel_groups(model._kwik.h5py_file) == [1] - assert _list_recordings(model._kwik.h5py_file) == [0, 1] - assert _list_clusterings(model._kwik.h5py_file, 1) == ['main', - 'original', - ] - assert _list_channels(model._kwik.h5py_file, 1) == channels - - -def test_concatenate_spikes(): - spikes = [2, 3, 5, 0, 11, 1] - recs = [0, 0, 0, 1, 1, 2] - offsets = [0, 7, 100] - concat = _concatenate_spikes(spikes, recs, offsets) - ae(concat, [2, 3, 5, 7, 18, 101]) - - -def test_kwik_empty(tempdir): - - channels = [0, 3, 1] - graph = [[0, 3], [1, 0]] - probe = {'channel_groups': { - 0: {'channels': channels, - 'graph': graph, - 'geometry': {0: (10, 10)}, - }}} - sample_rate = 20000 - - kwik_path = op.join(tempdir, 'test.kwik') - create_kwik(kwik_path=kwik_path, probe=probe, sample_rate=sample_rate) - - model = KwikModel(kwik_path) - ae(model.channels, sorted(channels)) - ae(model.channel_order, channels) - - assert model.sample_rate == sample_rate - assert model.n_channels == 3 - assert model.spike_samples is None - assert model.has_kwx() - assert model.n_spikes == 0 - assert model.n_clusters == 0 - model.describe() - - -def test_kwik_open_full(tempdir): - - # Create the test HDF5 file in the temporary directory. - filename = create_mock_kwik(tempdir, - n_clusters=_N_CLUSTERS, - n_spikes=_N_SPIKES, - n_channels=_N_CHANNELS, - n_features_per_channel=_N_FETS, - n_samples_traces=_N_SAMPLES_TRACES) - - with raises(ValueError): - KwikModel() - - # NOTE: n_channels - 2 because we use a special channel order. - nc = _N_CHANNELS - 2 - - # Test implicit open() method. - kwik = KwikModel(filename) - kwik.describe() - - kwik.metadata - ae(kwik.channels, np.arange(_N_CHANNELS)) - assert kwik.n_channels == _N_CHANNELS - assert kwik.n_spikes == _N_SPIKES - ae(kwik.channel_order, np.arange(1, _N_CHANNELS - 1)[::-1]) - - assert kwik.spike_samples.shape == (_N_SPIKES,) - assert kwik.spike_samples.dtype == np.uint64 - - # Make sure the spike samples are increasing, even with multiple - # recordings. - # WARNING: need to cast to int64, otherwise negative values will - # overflow and be positive, making the test pass while the - # spike samples are *not* increasing! - assert np.all(np.diff(kwik.spike_samples.astype(np.int64)) >= 0) - - assert kwik.spike_times.shape == (_N_SPIKES,) - assert kwik.spike_times.dtype == np.float64 - - assert kwik.spike_recordings.shape == (_N_SPIKES,) - assert kwik.spike_recordings.dtype == np.uint16 - - assert kwik.spike_clusters.shape == (_N_SPIKES,) - assert kwik.spike_clusters.min() in (0, 1, 2) - assert kwik.spike_clusters.max() in(_N_CLUSTERS - 2, _N_CLUSTERS - 1) - - assert kwik.features.shape == (_N_SPIKES, nc * _N_FETS) - kwik.features[0, ...] - - assert kwik.masks.shape == (_N_SPIKES, nc) - - assert kwik.traces.shape == (_N_SAMPLES_TRACES, _N_CHANNELS) - - assert kwik.waveforms[0].shape == (1, 40, nc) - assert kwik.waveforms[-1].shape == (1, 40, nc) - assert kwik.waveforms[-10].shape == (1, 40, nc) - assert kwik.waveforms[10].shape == (1, 40, nc) - assert kwik.waveforms[[10, 20]].shape == (2, 40, nc) - with raises(IndexError): - kwik.waveforms[_N_SPIKES + 10] - - with raises(ValueError): - kwik.clustering = 'foo' - with raises(ValueError): - kwik.channel_group = 42 - assert kwik.n_recordings == 2 - - # Test cluster groups. 
- for cluster in range(_N_CLUSTERS): - print(cluster) - assert kwik.cluster_metadata.group(cluster) == min(cluster, 3) - for cluster, group in kwik.cluster_groups.items(): - assert group == min(cluster, 3) - - # Test probe. - assert isinstance(kwik.probe, MEA) - assert kwik.probe.positions.shape == (nc, 2) - ae(kwik.probe.positions, staggered_positions(_N_CHANNELS)[1:-1][::-1]) - - kwik.close() - - -def test_kwik_open_no_kwx(tempdir): - - # Create the test HDF5 file in the temporary directory. - filename = create_mock_kwik(tempdir, - n_clusters=_N_CLUSTERS, - n_spikes=_N_SPIKES, - n_channels=_N_CHANNELS, - n_features_per_channel=_N_FETS, - n_samples_traces=_N_SAMPLES_TRACES, - with_kwx=False) - - # Test implicit open() method. - kwik = KwikModel(filename) - kwik.close() - - -def test_kwik_open_no_kwd(tempdir): - - # Create the test HDF5 file in the temporary directory. - filename = create_mock_kwik(tempdir, - n_clusters=_N_CLUSTERS, - n_spikes=_N_SPIKES, - n_channels=_N_CHANNELS, - n_features_per_channel=_N_FETS, - n_samples_traces=_N_SAMPLES_TRACES, - with_kwd=False) - - # Test implicit open() method. - kwik = KwikModel(filename) - l = StringLogger(level='debug') - register(l) - kwik.waveforms[:] - # Enusure that there is no error message. - assert not str(l).strip() - kwik.close() - unregister(l) - - -def test_kwik_save(tempdir): - - # Create the test HDF5 file in the temporary directory. - filename = create_mock_kwik(tempdir, - n_clusters=_N_CLUSTERS, - n_spikes=_N_SPIKES, - n_channels=_N_CHANNELS, - n_features_per_channel=_N_FETS, - n_samples_traces=_N_SAMPLES_TRACES) - - kwik = KwikModel(filename) - - cluster_groups = {cluster: kwik.cluster_metadata.group(cluster) - for cluster in range(_N_CLUSTERS)} - sc_0 = kwik.spike_clusters.copy() - sc_1 = sc_0.copy() - new_cluster = _N_CLUSTERS + 10 - sc_1[_N_SPIKES // 2:] = new_cluster - cluster_groups[new_cluster] = 7 - ae(kwik.spike_clusters, sc_0) - - assert kwik.cluster_metadata.group(new_cluster) == 3 - kwik.save(sc_1, cluster_groups, {'test': (1, 2.)}) - ae(kwik.spike_clusters, sc_1) - assert kwik.cluster_metadata.group(new_cluster) == 7 - - kwik.close() - - kwik = KwikModel(filename) - ae(kwik.spike_clusters, sc_1) - assert kwik.cluster_metadata.group(new_cluster) == 7 - ae(kwik.clustering_metadata['test'], [1, 2]) - - -def test_kwik_clusterings(tempdir): - - # Create the test HDF5 file in the temporary directory. - filename = create_mock_kwik(tempdir, - n_clusters=_N_CLUSTERS, - n_spikes=_N_SPIKES, - n_channels=_N_CHANNELS, - n_features_per_channel=_N_FETS, - n_samples_traces=_N_SAMPLES_TRACES) - - kwik = KwikModel(filename) - assert kwik.clusterings == ['main', 'original'] - - # The default clustering is 'main'. - assert kwik.n_spikes == _N_SPIKES - assert kwik.n_clusters == _N_CLUSTERS - assert kwik.cluster_groups[_N_CLUSTERS - 1] == 3 - ae(kwik.cluster_ids, np.arange(_N_CLUSTERS)) - - # Change clustering. - kwik.clustering = 'original' - n_clu = kwik.n_clusters - assert kwik.n_spikes == _N_SPIKES - # Some clusters may be empty with a small number of spikes like here - assert _N_CLUSTERS * 2 - 4 <= n_clu <= _N_CLUSTERS * 2 - assert kwik.cluster_groups[n_clu - 1] == 3 - assert len(kwik.cluster_ids) == n_clu - - -def test_kwik_manage_clusterings(tempdir): - - # Create the test HDF5 file in the temporary directory. 
- filename = create_mock_kwik(tempdir, - n_clusters=_N_CLUSTERS, - n_spikes=_N_SPIKES, - n_channels=_N_CHANNELS, - n_features_per_channel=_N_FETS, - n_samples_traces=_N_SAMPLES_TRACES) - - kwik = KwikModel(filename) - spike_clusters = kwik.spike_clusters - assert kwik.clusterings == ['main', 'original'] - - # Test renaming. - kwik.clustering = 'original' - with raises(ValueError): - kwik.rename_clustering('a', 'b') - with raises(ValueError): - kwik.rename_clustering('original', 'b') - with raises(ValueError): - kwik.rename_clustering('main', 'original') - - kwik.clustering = 'main' - kwik.rename_clustering('original', 'original_2') - assert kwik.clusterings == ['main', 'original_2'] - with raises(ValueError): - kwik.clustering = 'original' - kwik.clustering = 'original_2' - n_clu = kwik.n_clusters - if (n_clu - 1) in kwik.cluster_groups: - assert kwik.cluster_groups[n_clu - 1] == 3 - assert len(kwik.cluster_ids) == n_clu - - # Test copy. - with raises(ValueError): - kwik.copy_clustering('a', 'b') - with raises(ValueError): - kwik.copy_clustering('original', 'b') - with raises(ValueError): - kwik.copy_clustering('main', 'original_2') - - # You cannot move the current clustering, but you can copy it. - with raises(ValueError): - kwik.rename_clustering('original_2', 'original_2_copy') - kwik.copy_clustering('original_2', 'original_2_copy') - kwik.delete_clustering('original_2_copy') - - kwik.clustering = 'main' - kwik.copy_clustering('original_2', 'original') - assert kwik.clusterings == ['main', 'original', 'original_2'] - - kwik.clustering = 'original' - cg = kwik.cluster_groups - ci = kwik.cluster_ids - - kwik.clustering = 'original_2' - assert kwik.cluster_groups == cg - ae(kwik.cluster_ids, ci) - - # Test delete. - with raises(ValueError): - kwik.delete_clustering('a') - kwik.delete_clustering('original') - kwik.clustering = 'main' - kwik.delete_clustering('original_2') - assert kwik.clusterings == ['main', 'original'] - - # Test add. - sc = np.ones(_N_SPIKES, dtype=np.int32) - sc[1] = sc[-2] = 3 - kwik.add_clustering('new', sc) - ae(kwik.spike_clusters, spike_clusters) - kwik.clustering = 'new' - ae(kwik.spike_clusters, sc) - assert kwik.n_clusters == 2 - ae(kwik.cluster_ids, [1, 3]) - assert kwik.cluster_groups == {1: 3, - 3: 3} - - -def test_kwik_manage_cluster_groups(tempdir): - - # Create the test HDF5 file in the temporary directory. 
- filename = create_mock_kwik(tempdir, - n_clusters=_N_CLUSTERS, - n_spikes=_N_SPIKES, - n_channels=_N_CHANNELS, - n_features_per_channel=_N_FETS, - n_samples_traces=_N_SAMPLES_TRACES) - - kwik = KwikModel(filename) - - with raises(ValueError): - kwik.delete_cluster_group(2) - with raises(ValueError): - kwik.add_cluster_group(1, 'new') - with raises(ValueError): - kwik.rename_cluster_group(1, 'renamed') - - kwik.add_cluster_group(4, 'new') - kwik.rename_cluster_group(4, 'renamed') - - kwik.delete_cluster_group(4) - with raises(ValueError): - kwik.delete_cluster_group(4) diff --git a/phy/io/kwik/tests/test_sparse_kk2.py b/phy/io/kwik/tests/test_sparse_kk2.py deleted file mode 100644 index 6766719a1..000000000 --- a/phy/io/kwik/tests/test_sparse_kk2.py +++ /dev/null @@ -1,89 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Tests of sparse KK2 routines.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import numpy as np -from numpy.testing import assert_array_equal as ae -from numpy.testing import assert_array_almost_equal as aee - -from ..sparse_kk2 import sparsify_features_masks - - -#------------------------------------------------------------------------------ -# Tests -#------------------------------------------------------------------------------ - -def test_sparsify_features_masks(): - # Data as arrays - fet = np.array([[1, 3, 5, 7, 11], - [6, 7, 8, 9, 10], - [11, 12, 13, 14, 15], - [16, 17, 18, 19, 20]], dtype=float) - - fmask = np.array([[1, 0.5, 0, 0, 0], - [0, 1, 1, 0, 0], - [0, 1, 1, 0, 0], - [0, 0, 0, 1, 1]], dtype=float) - - # Normalisation to [0, 1] - fet_original = fet.copy() - fet = (fet - np.amin(fet, axis=0)) / \ - (np.amax(fet, axis=0) - np.amin(fet, axis=0)) - - nanmasked_fet = fet.copy() - nanmasked_fet[fmask > 0] = np.nan - - # Correct computation of the corrected data and correction term - x = fet - w = fmask - nu = np.nanmean(nanmasked_fet, axis=0)[np.newaxis, :] - sigma2 = np.nanvar(nanmasked_fet, axis=0)[np.newaxis, :] - y = w * x + (1 - w) * nu - z = w * x * x + (1 - w) * (nu * nu + sigma2) - correction_terms = z - y * y - features = y - - data = sparsify_features_masks(fet_original, fmask) - - aee(data.noise_mean, np.nanmean(nanmasked_fet, axis=0)) - aee(data.noise_variance, np.nanvar(nanmasked_fet, axis=0)) - assert np.amin(data.features) == 0 - assert np.amax(data.features) == 1 - assert len(data.offsets) == 5 - - for i in range(4): - data_u = data.unmasked[data.offsets[i]:data.offsets[i + 1]] - true_u, = fmask[i, :].nonzero() - ae(data_u, true_u) - data_f = data.features[data.offsets[i]:data.offsets[i + 1]] - true_f = fet[i, data_u] - ae(data_f, true_f) - data_m = data.masks[data.offsets[i]:data.offsets[i + 1]] - true_m = fmask[i, data_u] - ae(data_m, true_m) - - # PART 2: Check that converting to SparseData is correct - # compute unique masks and apply correction terms to data - data = data.to_sparse_data() - - assert data.num_spikes == 4 - assert data.num_features == 5 - assert data.num_masks == 3 - - for i in range(4): - data_u = data.unmasked[data.unmasked_start[i]:data.unmasked_end[i]] - true_u, = fmask[i, :].nonzero() - ae(data_u, true_u) - data_f = data.features[data.values_start[i]:data.values_end[i]] - true_f = features[i, data_u] - aee(data_f, true_f) - data_c = data.correction_terms[data.values_start[i]:data.values_end[i]] - true_c = correction_terms[i, data_u] - aee(data_c, true_c) - data_m = 
data.masks[data.values_start[i]:data.values_end[i]] - true_m = fmask[i, data_u] - aee(data_m, true_m) diff --git a/phy/io/kwik/tests/test_store_items.py b/phy/io/kwik/tests/test_store_items.py deleted file mode 100644 index 026afa293..000000000 --- a/phy/io/kwik/tests/test_store_items.py +++ /dev/null @@ -1,167 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Tests of Kwik store items.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import numpy as np -from numpy.testing import assert_array_equal as ae - -from ....utils.array import _spikes_per_cluster, _spikes_in_clusters -from ..model import (KwikModel, - ) -from ..mock import create_mock_kwik -from ..store_items import create_store - - -#------------------------------------------------------------------------------ -# Tests -#------------------------------------------------------------------------------ - -_N_CLUSTERS = 5 -_N_SPIKES = 100 -_N_CHANNELS = 28 -_N_FETS = 2 -_N_SAMPLES_TRACES = 10000 - - -def test_kwik_store(tempdir): - - # Create the test HDF5 file in the temporary directory. - filename = create_mock_kwik(tempdir, - n_clusters=_N_CLUSTERS, - n_spikes=_N_SPIKES, - n_channels=_N_CHANNELS, - n_features_per_channel=_N_FETS, - n_samples_traces=_N_SAMPLES_TRACES) - - nc = _N_CHANNELS - 2 - nf = _N_FETS - - model = KwikModel(filename) - ns = model.n_samples_waveforms - spc = _spikes_per_cluster(np.arange(_N_SPIKES), model.spike_clusters) - clusters = sorted(spc.keys()) - - # We initialize the ClusterStore. - cs = create_store(model, - path=tempdir, - spikes_per_cluster=spc, - features_masks_chunk_size=15, - waveforms_chunk_size=15, - waveforms_n_spikes_max=5, - waveforms_excerpt_size=2, - ) - - # We add a custom statistic function. - def mean_features_bis(cluster): - fet = cs.features(cluster) - cs.memory_store.store(cluster, mean_features_bis=fet.mean(axis=0)) - - cs.items['statistics'].add('mean_features_bis', - mean_features_bis, - (-1, nc), - ) - cs.register_field('mean_features_bis', 'statistics') - - waveforms_item = cs.items['waveforms'] - - # Now we generate the store. - cs.generate() - - # One cluster at a time. - for cluster in clusters: - # Check features. - fet_store = cs.features(cluster) - fet_expected = model.features[spc[cluster]].reshape((-1, nc, nf)) - ae(fet_store, fet_expected) - - # Check masks. - masks_store = cs.masks(cluster) - masks_expected = model.masks[spc[cluster]] - ae(masks_store, masks_expected) - - # Check waveforms. - waveforms_store = cs.waveforms(cluster) - # Find the spikes. - spikes = waveforms_item.spikes_per_cluster[cluster] - ae(waveforms_item.spikes_per_cluster[cluster], spikes) - waveforms_expected = model.waveforms[spikes] - ae(waveforms_store, waveforms_expected) - - # Check some statistics. - ae(cs.mean_features(cluster), fet_expected.mean(axis=0)) - ae(cs.mean_features_bis(cluster), fet_expected.mean(axis=0)) - ae(cs.mean_masks(cluster), masks_expected.mean(axis=0)) - ae(cs.mean_waveforms(cluster), waveforms_expected.mean(axis=0)) - - assert cs.n_unmasked_channels(cluster) >= 0 - assert cs.main_channels(cluster).shape == (nc,) - assert cs.mean_probe_position(cluster).shape == (2,) - - # Multiple clusters. - for clusters in (clusters[::2], [clusters[0]], []): - n_clusters = len(clusters) - spikes = _spikes_in_clusters(model.spike_clusters, clusters) - n_spikes = len(spikes) - - # Features. 
- if n_clusters: - fet_expected = model.features[spikes].reshape((n_spikes, - nc, nf)) - else: - fet_expected = np.zeros((0, nc, nf), dtype=np.float32) - ae(cs.load('features', clusters=clusters), fet_expected) - ae(cs.load('features', spikes=spikes), fet_expected) - - # Masks. - if n_clusters: - masks_expected = model.masks[spikes] - else: - masks_expected = np.ones((0, nc), dtype=np.float32) - ae(cs.load('masks', clusters=clusters), masks_expected) - ae(cs.load('masks', spikes=spikes), masks_expected) - - # Waveforms. - spc = waveforms_item.spikes_per_cluster - if n_clusters: - spikes = _spikes_in_clusters(spc, clusters) - waveforms_expected = model.waveforms[spikes] - else: - spikes = np.array([], dtype=np.int64) - waveforms_expected = np.zeros((0, ns, nc), dtype=np.float32) - ae(cs.load('waveforms', clusters=clusters), waveforms_expected) - ae(cs.load('waveforms', spikes=spikes), waveforms_expected) - - assert (cs.load('mean_features', clusters=clusters).shape == - (n_clusters, nc, nf)) - assert (cs.load('mean_features_bis', clusters=clusters).shape == - (n_clusters, nc, nf) if n_clusters else (0,)) - assert (cs.load('mean_masks', clusters=clusters).shape == - (n_clusters, nc)) - assert (cs.load('mean_waveforms', clusters=clusters).shape == - (n_clusters, ns, nc)) - - assert (cs.load('n_unmasked_channels', clusters=clusters).shape == - (n_clusters,)) - assert (cs.load('main_channels', clusters=clusters).shape == - (n_clusters, nc)) - assert (cs.load('mean_probe_position', clusters=clusters).shape == - (n_clusters, 2)) - - # Slice spikes. - spikes = slice(None, None, 3) - - # Features. - fet_expected = model.features[spikes].reshape((-1, nc, nf)) - ae(cs.load('features', spikes=spikes), fet_expected) - - # Masks. - masks_expected = model.masks[spikes] - ae(cs.load('masks', spikes=spikes), masks_expected) - - # Waveforms. - waveforms_expected = model.waveforms[spikes] - ae(cs.load('waveforms', spikes=spikes), waveforms_expected) diff --git a/phy/io/mock.py b/phy/io/mock.py index c24260fbd..ecfcbea60 100644 --- a/phy/io/mock.py +++ b/phy/io/mock.py @@ -9,11 +9,6 @@ import numpy as np import numpy.random as nr -from ..utils._color import _random_color -from ..utils.array import _unique, _spikes_per_cluster -from .base import BaseModel, ClusterMetadata -from ..electrode.mea import MEA, staggered_positions - #------------------------------------------------------------------------------ # Artificial data @@ -49,126 +44,3 @@ def artificial_spike_samples(n_spikes, max_isi=50): def artificial_correlograms(n_clusters, n_samples): return nr.uniform(size=(n_clusters, n_clusters, n_samples)) - - -#------------------------------------------------------------------------------ -# Artificial Model -#------------------------------------------------------------------------------ - -class MockModel(BaseModel): - n_channels = 28 - n_features_per_channel = 2 - n_features = 28 * n_features_per_channel - n_spikes = 1000 - n_samples_traces = 20000 - n_samples_waveforms = 40 - n_clusters = 10 - sample_rate = 20000. 
- - def __init__(self, n_spikes=None, n_clusters=None): - super(MockModel, self).__init__() - if n_spikes is not None: - self.n_spikes = n_spikes - if n_clusters is not None: - self.n_clusters = n_clusters - self.name = 'mock' - self._clustering = 'main' - nfpc = self.n_features_per_channel - self._metadata = {'description': 'A mock model.', - 'n_features_per_channel': nfpc} - self._cluster_metadata = ClusterMetadata() - - @self._cluster_metadata.default - def group(cluster): - if cluster <= 2: - return cluster - # Default group is unsorted. - return 3 - - @self._cluster_metadata.default - def color(cluster): - return _random_color() - - positions = staggered_positions(self.n_channels) - self._probe = MEA(channels=self.channels, positions=positions) - self._traces = artificial_traces(self.n_samples_traces, - self.n_channels) - self._spike_clusters = artificial_spike_clusters(self.n_spikes, - self.n_clusters) - self._spike_ids = np.arange(self.n_spikes).astype(np.int64) - self._spikes_per_cluster = _spikes_per_cluster(self._spike_ids, - self._spike_clusters) - self._spike_samples = artificial_spike_samples(self.n_spikes, 30) - assert self._spike_samples[-1] < self.n_samples_traces - self._features = artificial_features(self.n_spikes, self.n_features) - self._masks = artificial_masks(self.n_spikes, self.n_channels) - self._features_masks = np.dstack((self._features, - np.repeat(self._masks, - nfpc, - axis=1))) - self._waveforms = artificial_waveforms(self.n_spikes, - self.n_samples_waveforms, - self.n_channels) - - @property - def channels(self): - return np.arange(self.n_channels) - - @property - def channel_order(self): - return self.channels - - @property - def metadata(self): - return self._metadata - - @property - def traces(self): - return self._traces - - @property - def spike_samples(self): - return self._spike_samples - - @property - def spike_clusters(self): - return self._spike_clusters - - @property - def spikes_per_cluster(self): - return self._spikes_per_cluster - - def update_spikes_per_cluster(self, spc): - self._spikes_per_cluster = spc - - @property - def spike_ids(self): - return self._spike_ids - - @property - def cluster_ids(self): - return _unique(self._spike_clusters) - - @property - def cluster_metadata(self): - return self._cluster_metadata - - @property - def features(self): - return self._features - - @property - def masks(self): - return self._masks - - @property - def features_masks(self): - return self._features_masks - - @property - def waveforms(self): - return self._waveforms - - @property - def probe(self): - return self._probe diff --git a/phy/io/sparse.py b/phy/io/sparse.py deleted file mode 100644 index 01fe470f7..000000000 --- a/phy/io/sparse.py +++ /dev/null @@ -1,155 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Sparse matrix structures.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import numpy as np - -from ..utils._types import _as_array - - -#------------------------------------------------------------------------------ -# Sparse CSR -#------------------------------------------------------------------------------ - -def _csr_from_dense(dense): - """Create a CSR structure from a dense NumPy array.""" - raise NotImplementedError(("Creating CSR from dense matrix is not " - "implemented yet.")) - - -def _check_sparse_components(shape=None, data=None, - channels=None, spikes_ptr=None): - """Ensure the components of a sparse 
matrix are consistent.""" - if not isinstance(shape, (list, tuple, np.ndarray)): - raise ValueError("The shape is required and should be an " - "array-like list ({0} was given).".format(shape)) - if len(shape) != data.ndim + 1: - raise ValueError("'shape' {shape} and {ndim}D-array 'data' are " - "not consistent.".format(shape=shape, - ndim=data.ndim)) - if channels.ndim != 1: - raise ValueError("'channels' should be a 1D array.") - if spikes_ptr.ndim != 1: - raise ValueError("'spikes_ptr' should be a 1D array.") - nitems = data.shape[-1] - if nitems > np.prod(shape): - raise ValueError("'data' is too large (n={0:d}) ".format(nitems) + - " for the specified shape " - "{shape}.".format(shape=shape)) - if len(channels) != (shape[1]): - raise ValueError(("'channels' should have " - "{nexp} elements, " - "not {nact}.").format(nexp=(shape[1]), - nact=len(channels))) - if len(spikes_ptr) != (shape[0] + 1): - raise ValueError(("'spikes_ptr' should have " - "{nexp} elements, " - "not {nact}.").format(nexp=(shape[0] + 1), - nact=len(spikes_ptr))) - if len(data) != len(channels): - raise ValueError("'data' (n={0:d}) and ".format(len(data)) + - "'channels' (n={0:d}) ".format(len(channels)) + - "should have the same length") - return True - - -class SparseCSR(object): - """Sparse CSR matrix data structure.""" - def __init__(self, shape=None, data=None, channels=None, spikes_ptr=None): - # Ensure the arguments are all arrays. - data = _as_array(data) - channels = _as_array(channels) - spikes_ptr = _as_array(spikes_ptr) - # Ensure the arguments are consistent. - assert _check_sparse_components(shape=shape, - data=data, - channels=channels, - spikes_ptr=spikes_ptr) - nitems = data.shape[-1] - # Structure info. - self._nitems = nitems - # Create the structure. - self._shape = shape - self._data = data - self._channels = channels - self._spikes_ptr = spikes_ptr - - @property - def shape(self): - """Shape of the array.""" - return self._shape - - def __eq__(self, other): - return np.all((self._data == other._data) & - (self._channels == other._channels) & - (self._spikes_ptr == other._spikes_ptr)) - - # I/O methods - # ------------------------------------------------------------------------- - - def save_h5(self, f, path): - """Save the array in an HDF5 file.""" - f.write_attr(path, 'sparse_type', 'csr') - f.write_attr(path, 'shape', self._shape) - f.write(path + '/data', self._data) - f.write(path + '/channels', self._channels) - f.write(path + '/spikes_ptr', self._spikes_ptr) - - @staticmethod - def load_h5(f, path): - """Load a SparseCSR array from an HDF5 file.""" - f.read_attr(path, 'sparse_type') == 'csr' - shape = f.read_attr(path, 'shape') - data = f.read(path + '/data')[...] - channels = f.read(path + '/channels')[...] - spikes_ptr = f.read(path + '/spikes_ptr')[...] - return SparseCSR(shape=shape, - data=data, - channels=channels, - spikes_ptr=spikes_ptr) - - -def csr_matrix(dense=None, shape=None, - data=None, channels=None, spikes_ptr=None): - """Create a CSR matrix from a dense matrix, or from sparse data.""" - if dense is not None: - # Ensure 'dense' is a ndarray. 
- dense = _as_array(dense) - return _csr_from_dense(dense) - if data is None or channels is None or spikes_ptr is None: - raise ValueError("data, channels, and spikes_ptr must be specified.") - return SparseCSR(shape=shape, - data=data, channels=channels, spikes_ptr=spikes_ptr) - - -#------------------------------------------------------------------------------ -# HDF5 functions -#------------------------------------------------------------------------------ - -def save_h5(f, path, arr, overwrite=False): - """Save a sparse array into an HDF5 file.""" - if isinstance(arr, SparseCSR): - arr.save_h5(f, path) - elif isinstance(arr, np.ndarray): - f.write(path, arr, overwrite=overwrite) - else: - raise ValueError("The array should be a SparseCSR or " - "dense NumPy array.") - - -def load_h5(f, path): - """Load a sparse array from an HDF5 file.""" - # Sparse array. - if f.has_attr(path, 'sparse_type'): - if f.read_attr(path, 'sparse_type') == 'csr': - return SparseCSR.load_h5(f, path) - else: - raise NotImplementedError("Only SparseCSR arrays are implemented " - "currently.") - # Regular dense dataset. - else: - return f.read(path)[...] diff --git a/phy/io/store.py b/phy/io/store.py deleted file mode 100644 index 1dd8ab799..000000000 --- a/phy/io/store.py +++ /dev/null @@ -1,752 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Cluster store.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -from collections import OrderedDict -import os -import os.path as op -import re - -import numpy as np -from six import string_types - -from ..utils._types import _as_int, _is_integer, _is_array_like -from ..utils._misc import _load_json, _save_json -from ..utils.array import (PerClusterData, _spikes_in_clusters, - _subset_spc, _load_ndarray, - _save_arrays, _load_arrays, - ) -from ..utils.event import ProgressReporter -from ..utils.logging import debug, info, warn -from ..utils.settings import _ensure_dir_exists - - -#------------------------------------------------------------------------------ -# Utility functions -#------------------------------------------------------------------------------ - -def _directory_size(path): - """Return the total size in bytes of a directory.""" - total_size = 0 - for dirpath, dirnames, filenames in os.walk(path): - for f in filenames: - fp = os.path.join(dirpath, f) - total_size += os.path.getsize(fp) - return total_size - - -def _file_cluster_id(path): - return int(op.splitext(op.basename(path))[0]) - - -def _default_array(shape, value=0, n_spikes=0, dtype=np.float32): - shape = (n_spikes,) + shape[1:] - out = np.empty(shape, dtype=dtype) - out.fill(value) - return out - - -def _assert_per_cluster_data_compatible(d_0, d_1): - n_0 = {k: len(v) for (k, v) in d_0.items()} - n_1 = {k: len(v) for (k, v) in d_1.items()} - if n_0 != n_1: - raise IOError("Inconsistency in the cluster store: please remove " - "`./.phy/cluster_store/`.") - - -#------------------------------------------------------------------------------ -# Base store -#------------------------------------------------------------------------------ - -class BaseStore(object): - def __init__(self, root_dir): - self._root_dir = op.realpath(root_dir) - _ensure_dir_exists(self._root_dir) - - def _location(self, filename): - return op.join(self._root_dir, filename) - - def _offsets_path(self, path): - assert path.endswith('.npy') - return op.splitext(path)[0] + '.offsets.npy' - - def 
_contains_multiple_arrays(self, path): - return op.exists(path) and op.exists(self._offsets_path(path)) - - def _save(self, filename, data): - """Save an array or list of arrays.""" - path = self._location(filename) - dir_path = op.dirname(path) - if not op.exists(dir_path): - os.makedirs(dir_path) - if isinstance(data, list): - if not data: - return - _save_arrays(path, data) - elif isinstance(data, np.ndarray): - dtype = data.dtype - if not data.size: - return - assert dtype != np.object - np.save(path, data) - - def _open(self, filename): - path = self._location(filename) - if not op.exists(path): - debug("File `{}` doesn't exist.".format(path)) - return - # Multiple arrays: - if self._contains_multiple_arrays(path): - return _load_arrays(path) - else: - return np.load(path) - - def _delete_file(self, filename): - path = self._location(filename) - if op.exists(path): - os.remove(path) - offsets_path = self._offsets_path(path) - if op.exists(offsets_path): - os.remove(offsets_path) - - def describe(self): - raise NotImplementedError() - - -#------------------------------------------------------------------------------ -# Memory store -#------------------------------------------------------------------------------ - -class MemoryStore(object): - """Store cluster-related data in memory.""" - def __init__(self): - self._ds = {} - - def store(self, cluster, **data): - """Store cluster-related data.""" - if cluster not in self._ds: - self._ds[cluster] = {} - self._ds[cluster].update(data) - - def load(self, cluster, keys=None): - """Load cluster-related data.""" - if keys is None: - return self._ds.get(cluster, {}) - else: - if isinstance(keys, string_types): - return self._ds.get(cluster, {}).get(keys, None) - assert isinstance(keys, (list, tuple)) - return {key: self._ds.get(cluster, {}).get(key, None) - for key in keys} - - @property - def cluster_ids(self): - """List of cluster ids in the store.""" - return sorted(self._ds.keys()) - - def erase(self, clusters): - """Delete some clusters from the store.""" - assert isinstance(clusters, list) - for cluster in clusters: - if cluster in self._ds: - del self._ds[cluster] - - def clear(self): - """Clear the store completely by deleting all clusters.""" - self.erase(self.cluster_ids) - - def __contains__(self, item): - return item in self._ds - - -#------------------------------------------------------------------------------ -# Disk store -#------------------------------------------------------------------------------ - -class DiskStore(object): - """Store cluster-related data in HDF5 files.""" - def __init__(self, directory): - assert directory is not None - # White list of extensions, to be sure we don't erase - # the wrong files. - self._allowed_extensions = set() - self._directory = op.realpath(op.expanduser(directory)) - - @property - def path(self): - return self._directory - - # Internal methods - # ------------------------------------------------------------------------- - - def _check_extension(self, file): - """Check that a file extension belongs to the white list of - allowed extensions. This is for safety.""" - _, extension = op.splitext(file) - extension = extension[1:] - if extension not in self._allowed_extensions: - raise RuntimeError("The extension '{0}' ".format(extension) + - "hasn't been registered.") - - def _cluster_path(self, cluster, key): - """Return the absolute path of a cluster in the disk store.""" - # TODO: subfolders - # Example of filename: `123.mykey`. 
- cluster = _as_int(cluster) - filename = '{0:d}.{1:s}'.format(cluster, key) - return op.realpath(op.join(self._directory, filename)) - - def _cluster_file_exists(self, cluster, key): - """Return whether a cluster file exists.""" - cluster = _as_int(cluster) - return op.exists(self._cluster_path(cluster, key)) - - def _is_cluster_file(self, path): - """Return whether a filename is of the form `xxx.yyy` where xxx is a - numbe and yyy belongs to the set of allowed extensions.""" - filename = op.basename(path) - extensions = '({0})'.format('|'.join(sorted(self._allowed_extensions))) - regex = r'^[0-9]+\.' + extensions + '$' - return re.match(regex, filename) is not None - - # Public methods - # ------------------------------------------------------------------------- - - def register_file_extensions(self, extensions): - """Register file extensions explicitely. This is a security - to make sure that we don't accidentally delete the wrong files.""" - if isinstance(extensions, string_types): - extensions = [extensions] - assert isinstance(extensions, list) - for extension in extensions: - self._allowed_extensions.add(extension) - - def store(self, cluster, append=False, **data): - """Store a NumPy array to disk.""" - # Do not create the file if there's nothing to write. - if not data: - return - mode = 'wb' if not append else 'ab' - for key, value in data.items(): - assert isinstance(value, np.ndarray) - path = self._cluster_path(cluster, key) - self._check_extension(path) - assert self._is_cluster_file(path) - with open(path, mode) as f: - value.tofile(f) - - def _get(self, cluster, key, dtype=None, shape=None): - # The cluster doesn't exist: return None for all keys. - if not self._cluster_file_exists(cluster, key): - return None - else: - return _load_ndarray(self._cluster_path(cluster, key), - dtype=dtype, shape=shape, lazy=False) - - def load(self, cluster, keys, dtype=None, shape=None): - """Load cluster-related data. 
Return a file handle, to be used - with np.fromfile() once the dtype and shape are known.""" - assert keys is not None - if isinstance(keys, string_types): - return self._get(cluster, keys, dtype=dtype, shape=shape) - assert isinstance(keys, list) - out = {} - for key in keys: - out[key] = self._get(cluster, key, dtype=dtype, shape=shape) - return out - - def save_file(self, filename, data): - path = op.realpath(op.join(self._directory, filename)) - _save_json(path, data) - - def load_file(self, filename): - path = op.realpath(op.join(self._directory, filename)) - if not op.exists(path): - return None - try: - return _load_json(path) - except ValueError as e: - warn("Error when loading `{}`: {}.".format(path, e)) - return None - - @property - def files(self): - """List of files present in the directory.""" - if not op.exists(self._directory): - return [] - return sorted(filter(self._is_cluster_file, - os.listdir(self._directory))) - - @property - def cluster_ids(self): - """List of cluster ids in the store.""" - clusters = set([_file_cluster_id(file) for file in self.files]) - return sorted(clusters) - - def erase(self, clusters): - """Delete some clusters from the store.""" - for cluster in clusters: - for key in self._allowed_extensions: - path = self._cluster_path(cluster, key) - if not op.exists(path): - continue - # Safety first: http://bit.ly/1ITJyF6 - self._check_extension(path) - if self._is_cluster_file(path): - os.remove(path) - else: - raise RuntimeError("The file {0} was about ".format(path) + - "to be removed, but it doesn't appear " - "to be a valid cluster file.") - - def clear(self): - """Clear the store completely by deleting all clusters.""" - self.erase(self.cluster_ids) - - -#------------------------------------------------------------------------------ -# Store item -#------------------------------------------------------------------------------ - -class StoreItem(object): - """A class describing information stored in the cluster store. - - Parameters - ---------- - - name : str - Name of the item. - fields : list - A list of field names. - model : Model - A `Model` instance for the current dataset. - memory_store : MemoryStore - The `MemoryStore` instance for the current dataset. - disk_store : DiskStore - The DiskStore instance for the current dataset. - - """ - name = 'item' - fields = None # list of names - - def __init__(self, cluster_store=None): - self.cluster_store = cluster_store - self.model = cluster_store.model - self.memory_store = cluster_store.memory_store - self.disk_store = cluster_store.disk_store - self._spikes_per_cluster = cluster_store.spikes_per_cluster - self._pr = ProgressReporter() - self._pr.set_progress_message('Initializing ' + self.name + - ': {progress:.1f}%.') - self._pr.set_complete_message(self.name.capitalize() + ' initialized.') - self._shapes = {} - - def empty_values(self, name): - """Return a null array of the right shape for a given field.""" - return _default_array(self._shapes.get(name, (-1,)), value=0.) 
- - @property - def progress_reporter(self): - """Progress reporter instance.""" - return self._pr - - @property - def spikes_per_cluster(self): - """Spikes per cluster.""" - return self._spikes_per_cluster - - @spikes_per_cluster.setter - def spikes_per_cluster(self, value): - self._spikes_per_cluster = value - - def spikes_in_clusters(self, clusters): - """Return the spikes belonging to clusters.""" - return _spikes_in_clusters(self._spikes_per_cluster, clusters) - - @property - def cluster_ids(self): - """Array of cluster ids.""" - return sorted(self._spikes_per_cluster) - - def is_consistent(self, cluster, spikes): - """Return whether the stored item is consistent. - - To be overriden.""" - return False - - def to_generate(self, mode=None): - """Return the list of clusters that need to be regenerated.""" - if mode in (None, 'default'): - return [cluster for cluster in self.cluster_ids - if not self.is_consistent(cluster, - self.spikes_per_cluster[cluster], - )] - elif mode == 'force': - return self.cluster_ids - elif mode == 'read-only': - return [] - else: - raise ValueError("`mode` should be None, `default`, `force`, " - "or `read-only`.") - - def store(self, cluster): - """Store data for a cluster from the model to the store. - - May be overridden. - - """ - pass - - def store_all(self, mode=None, **kwargs): - """Copy all data for that item from the model to the cluster store.""" - clusters = self.to_generate(mode) - self._pr.value_max = len(clusters) - for cluster in clusters: - self.store(cluster, **kwargs) - self._pr.value += 1 - self._pr.set_complete() - - def load(self, cluster, name): - """Load data for one cluster.""" - raise NotImplementedError() - - def load_multi(self, clusters, name): - """Load data for several clusters.""" - raise NotImplementedError() - - def load_spikes(self, spikes, name): - """Load data from an array of spikes.""" - raise NotImplementedError() - - def on_merge(self, up): - """Called when a new merge occurs. - - May be overriden if there's an efficient way to update the data - after a merge. - - """ - self.on_assign(up) - - def on_assign(self, up): - """Called when a new split occurs. - - May be overriden. - - """ - for cluster in up.added: - self.store(cluster) - - def on_cluster(self, up=None): - """Called when the clusters change. - - Old data is kept on disk and in memory, which is useful for - undo and redo. The `cluster_store.clean()` method can be called to - delete the old files. - - Nothing happens during undo and redo (the data is already there). - - """ - # No need to change anything in the store if this is an undo or - # a redo. - if up is None or up.history is not None: - return - if up.description == 'merge': - self.on_merge(up) - elif up.description == 'assign': - self.on_assign(up) - - -class FixedSizeItem(StoreItem): - """Store data which size doesn't depend on the cluster size.""" - def load_multi(self, clusters, name): - """Load data for several clusters.""" - if not len(clusters): - return self.empty_values(name) - return np.array([self.load(cluster, name) - for cluster in clusters]) - - -class VariableSizeItem(StoreItem): - """Store data which size does depend on the cluster size.""" - def load_multi(self, clusters, name, spikes=None): - """Load data for several clusters. - - A subset of spikes can also be specified. 
- - """ - if not len(clusters) or (spikes is not None and not len(spikes)): - return self.empty_values(name) - arrays = {cluster: self.load(cluster, name) - for cluster in clusters} - spc = _subset_spc(self._spikes_per_cluster, clusters) - _assert_per_cluster_data_compatible(spc, arrays) - pcd = PerClusterData(spc=spc, - arrays=arrays, - ) - if spikes is not None: - pcd = pcd.subset(spike_ids=spikes) - assert pcd.array.shape[0] == len(spikes) - return pcd.array - - -#------------------------------------------------------------------------------ -# Cluster store -#------------------------------------------------------------------------------ - -class ClusterStore(object): - """Hold per-cluster information on disk and in memory. - - Note - ---- - - Currently, this is used to accelerate access to per-cluster data - and statistics. All data is dynamically updated when clustering - changes occur. - - """ - def __init__(self, - model=None, - spikes_per_cluster=None, - path=None, - ): - self._model = model - self._spikes_per_cluster = spikes_per_cluster - self._memory = MemoryStore() - self._disk = DiskStore(path) if path is not None else None - self._items = OrderedDict() - self._item_per_field = {} - - # Core methods - #-------------------------------------------------------------------------- - - @property - def model(self): - """Model.""" - return self._model - - @property - def memory_store(self): - """Hold some cluster statistics.""" - return self._memory - - @property - def disk_store(self): - """Manage the cache of per-cluster voluminous data.""" - return self._disk - - @property - def spikes_per_cluster(self): - """Dictionary `{cluster_id: spike_ids}`.""" - return self._spikes_per_cluster - - def update_spikes_per_cluster(self, spikes_per_cluster): - self._spikes_per_cluster = spikes_per_cluster - for item in self._items.values(): - try: - item.spikes_per_cluster = spikes_per_cluster - except AttributeError: - debug("Skipping set spikes_per_cluster on " - "store item {}.".format(item.name)) - - @property - def cluster_ids(self): - """All cluster ids appearing in the `spikes_per_cluster` dictionary.""" - return sorted(self._spikes_per_cluster) - - # Store items - #-------------------------------------------------------------------------- - - @property - def items(self): - """Dictionary of registered store items.""" - return self._items - - def register_field(self, name, item_name=None): - """Register a new piece of data to store on memory or on disk. - - Parameters - ---------- - - name : str - The name of the field. - item_name : str - The name of the item. - - - """ - self._item_per_field[name] = self._items[item_name] - if self._disk: - self._disk.register_file_extensions(name) - - # Create the load function. - def _make_func(name): - def load(*args, **kwargs): - return self.load(name, *args, **kwargs) - return load - - load = _make_func(name) - - # We create the `self.()` method for loading. - # We need to ensure that the method name isn't already attributed. - assert not hasattr(self, name) - setattr(self, name, load) - - def register_item(self, item_cls, **kwargs): - """Register a `StoreItem` class in the store. - - A `StoreItem` class is responsible for storing some data to disk - and memory. It must register one or several pieces of data. - - """ - # Instantiate the item. - item = item_cls(cluster_store=self, **kwargs) - assert item.fields is not None - - # Register the StoreItem instance. - self._items[item.name] = item - - # Register all fields declared by the store item. 
- for field in item.fields: - self.register_field(field, item_name=item.name) - - return item - - # Files - #-------------------------------------------------------------------------- - - @property - def path(self): - """Path to the disk store cache.""" - return self.disk_store.path - - @property - def old_clusters(self): - """Clusters in the disk store that are no longer in the clustering.""" - return sorted(set(self.disk_store.cluster_ids) - - set(self.cluster_ids)) - - @property - def files(self): - """List of files present in the disk store.""" - return self.disk_store.files - - # Status - #-------------------------------------------------------------------------- - - @property - def total_size(self): - """Total size of the disk store.""" - return _directory_size(self.path) - - def is_consistent(self): - """Return whether the cluster store is probably consistent. - - Return true if all cluster stores files exist and have the expected - file size. - - """ - valid = set(self.cluster_ids) - # All store items should be consistent on all valid clusters. - consistent = all(all(item.is_consistent(clu, - self.spikes_per_cluster.get(clu, [])) - for clu in valid) - for item in self._items.values()) - return consistent - - @property - def status(self): - """Return the current status of the cluster store.""" - in_store = set(self.disk_store.cluster_ids) - valid = set(self.cluster_ids) - invalid = in_store - valid - - n_store = len(in_store) - n_old = len(invalid) - size = self.total_size / (1024. ** 2) - consistent = str(self.is_consistent()).rjust(5) - - status = '' - header = "Cluster store status ({0})".format(self.path) - status += header + '\n' - status += '-' * len(header) + '\n' - status += "Number of clusters in the store {0: 4d}\n".format(n_store) - status += "Number of old clusters {0: 4d}\n".format(n_old) - status += "Total size (MB) {0: 7.0f}\n".format(size) - status += "Consistent {0}\n".format(consistent) - return status - - def display_status(self): - """Display the current status of the cluster store.""" - print(self.status) - - # Store management - #-------------------------------------------------------------------------- - - def clear(self): - """Erase all files in the store.""" - self.memory_store.clear() - self.disk_store.clear() - info("Cluster store cleared.") - - def clean(self): - """Erase all old files in the store.""" - to_delete = self.old_clusters - self.memory_store.erase(to_delete) - self.disk_store.erase(to_delete) - n = len(to_delete) - info("{0} clusters deleted from the cluster store.".format(n)) - - def generate(self, mode=None): - """Generate the cluster store. - - Parameters - ---------- - - mode : str (default is None) - How the cluster store should be generated. Options are: - - * None or `default`: only regenerate the missing or inconsistent - clusters - * `force`: fully regenerate the cluster - * `read-only`: just load the existing files, do not write anything - - """ - assert isinstance(self._spikes_per_cluster, dict) - if hasattr(self._model, 'name'): - name = self._model.name - else: - name = 'the current model' - debug("Initializing the cluster store for {0:s}.".format(name)) - for item in self._items.values(): - item.store_all(mode) - - # Load - #-------------------------------------------------------------------------- - - def load(self, name, clusters=None, spikes=None): - """Load some data for a number of clusters and spikes.""" - item = self._item_per_field[name] - # Clusters requested. 
- if clusters is not None: - if _is_integer(clusters): - # Single cluster case. - return item.load(clusters, name) - clusters = np.unique(clusters) - if spikes is None: - return item.load_multi(clusters, name) - else: - return item.load_multi(clusters, name, spikes=spikes) - # Spikes requested. - elif spikes is not None: - assert clusters is None - if _is_array_like(spikes): - spikes = np.unique(spikes) - out = item.load_spikes(spikes, name) - assert isinstance(out, np.ndarray) - if _is_array_like(spikes): - assert out.shape[0] == len(spikes) - return out diff --git a/phy/io/tests/test_array.py b/phy/io/tests/test_array.py new file mode 100644 index 000000000..66c767996 --- /dev/null +++ b/phy/io/tests/test_array.py @@ -0,0 +1,399 @@ +# -*- coding: utf-8 -*- + +"""Tests of array utility functions.""" + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +import os.path as op + +import numpy as np +from pytest import raises + +from ..array import (_unique, + _normalize, + _index_of, + _in_polygon, + _spikes_in_clusters, + _spikes_per_cluster, + _flatten_per_cluster, + _get_data_lim, + select_spikes, + Selector, + chunk_bounds, + regular_subset, + excerpts, + data_chunk, + grouped_mean, + get_excerpts, + _concatenate_virtual_arrays, + _range_from_slice, + _pad, + _get_padded, + read_array, + write_array, + ) +from phy.utils._types import _as_array +from phy.utils.testing import _assert_equal as ae +from ..mock import artificial_spike_clusters + + +#------------------------------------------------------------------------------ +# Test utility functions +#------------------------------------------------------------------------------ + +def test_range_from_slice(): + """Test '_range_from_slice'.""" + + class _SliceTest(object): + """Utility class to make it more convenient to test slice objects.""" + def __init__(self, **kwargs): + self._kwargs = kwargs + + def __getitem__(self, item): + if isinstance(item, slice): + return _range_from_slice(item, **self._kwargs) + + with raises(ValueError): + _SliceTest()[:] + with raises(ValueError): + _SliceTest()[1:] + ae(_SliceTest()[:5], [0, 1, 2, 3, 4]) + ae(_SliceTest()[1:5], [1, 2, 3, 4]) + + with raises(ValueError): + _SliceTest()[::2] + with raises(ValueError): + _SliceTest()[1::2] + ae(_SliceTest()[1:5:2], [1, 3]) + + with raises(ValueError): + _SliceTest(start=0)[:] + with raises(ValueError): + _SliceTest(start=1)[:] + with raises(ValueError): + _SliceTest(step=2)[:] + + ae(_SliceTest(stop=5)[:], [0, 1, 2, 3, 4]) + ae(_SliceTest(start=1, stop=5)[:], [1, 2, 3, 4]) + ae(_SliceTest(stop=5)[1:], [1, 2, 3, 4]) + ae(_SliceTest(start=1)[:5], [1, 2, 3, 4]) + ae(_SliceTest(start=1, step=2)[:5], [1, 3]) + ae(_SliceTest(start=1)[:5:2], [1, 3]) + + ae(_SliceTest(length=5)[:], [0, 1, 2, 3, 4]) + with raises(ValueError): + _SliceTest(length=5)[:3] + ae(_SliceTest(length=5)[:10], [0, 1, 2, 3, 4]) + ae(_SliceTest(length=5)[:5], [0, 1, 2, 3, 4]) + ae(_SliceTest(start=1, length=5)[:], [1, 2, 3, 4, 5]) + ae(_SliceTest(start=1, length=5)[:6], [1, 2, 3, 4, 5]) + with raises(ValueError): + _SliceTest(start=1, length=5)[:4] + ae(_SliceTest(start=1, step=2, stop=5)[:], [1, 3]) + ae(_SliceTest(start=1, stop=5)[::2], [1, 3]) + ae(_SliceTest(stop=5)[1::2], [1, 3]) + + +def test_pad(): + arr = np.random.rand(10, 3) + + ae(_pad(arr, 0, 'right'), arr[:0, :]) + ae(_pad(arr, 3, 'right'), arr[:3, :]) + ae(_pad(arr, 9), arr[:9, :]) + ae(_pad(arr, 10), arr) + + 
ae(_pad(arr, 12, 'right')[:10, :], arr) + ae(_pad(arr, 12)[10:, :], np.zeros((2, 3))) + + ae(_pad(arr, 0, 'left'), arr[:0, :]) + ae(_pad(arr, 3, 'left'), arr[7:, :]) + ae(_pad(arr, 9, 'left'), arr[1:, :]) + ae(_pad(arr, 10, 'left'), arr) + + ae(_pad(arr, 12, 'left')[2:, :], arr) + ae(_pad(arr, 12, 'left')[:2, :], np.zeros((2, 3))) + + with raises(ValueError): + _pad(arr, -1) + + +def test_get_padded(): + arr = np.array([1, 2, 3])[:, np.newaxis] + + with raises(RuntimeError): + ae(_get_padded(arr, -2, 5).ravel(), [1, 2, 3, 0, 0]) + ae(_get_padded(arr, 1, 2).ravel(), [2]) + ae(_get_padded(arr, 0, 5).ravel(), [1, 2, 3, 0, 0]) + ae(_get_padded(arr, -2, 3).ravel(), [0, 0, 1, 2, 3]) + + +def test_get_data_lim(): + arr = np.random.rand(10, 5) + assert 0 < _get_data_lim(arr) < 1 + assert 0 < _get_data_lim(arr, 2) < 1 + + +def test_unique(): + """Test _unique() function""" + _unique([]) + + n_spikes = 300 + n_clusters = 3 + spike_clusters = artificial_spike_clusters(n_spikes, n_clusters) + ae(_unique(spike_clusters), np.arange(n_clusters)) + + +def test_normalize(): + """Test _normalize() function.""" + + n_channels = 10 + positions = 1 + 2 * np.random.randn(n_channels, 2) + + # Keep ration is False. + positions_n = _normalize(positions) + + x_min, y_min = positions_n.min(axis=0) + x_max, y_max = positions_n.max(axis=0) + + np.allclose(x_min, 0.) + np.allclose(x_max, 1.) + np.allclose(y_min, 0.) + np.allclose(y_max, 1.) + + # Keep ratio is True. + positions_n = _normalize(positions, keep_ratio=True) + + x_min, y_min = positions_n.min(axis=0) + x_max, y_max = positions_n.max(axis=0) + + np.allclose(min(x_min, y_min), 0.) + np.allclose(max(x_max, y_max), 1.) + np.allclose(x_min + x_max, 1) + np.allclose(y_min + y_max, 1) + + +def test_index_of(): + """Test _index_of.""" + arr = [36, 42, 42, 36, 36, 2, 42] + lookup = _unique(arr) + ae(_index_of(arr, lookup), [1, 2, 2, 1, 1, 0, 2]) + + +def test_as_array(): + ae(_as_array(3), [3]) + ae(_as_array([3]), [3]) + ae(_as_array(3.), [3.]) + ae(_as_array([3.]), [3.]) + + with raises(ValueError): + _as_array(map) + + +def test_in_polygon(): + polygon = [[0, 0], [1, 0], [1, 1], [0, 1], [0, 0]] + points = np.random.uniform(size=(100, 2), low=-1, high=1) + idx_expected = np.nonzero((points[:, 0] > 0) & + (points[:, 1] > 0) & + (points[:, 0] < 1) & + (points[:, 1] < 1))[0] + idx = np.nonzero(_in_polygon(points, polygon))[0] + ae(idx, idx_expected) + + +#------------------------------------------------------------------------------ +# Test read/save +#------------------------------------------------------------------------------ + +def test_read_write(tempdir): + arr = np.arange(10).astype(np.float32) + + path = op.join(tempdir, 'test.npy') + + write_array(path, arr) + ae(read_array(path), arr) + ae(read_array(path, mmap_mode='r'), arr) + + +#------------------------------------------------------------------------------ +# Test virtual concatenation +#------------------------------------------------------------------------------ + +def test_concatenate_virtual_arrays_1(): + arrs = [np.arange(5), np.arange(10, 12), np.array([0])] + c = _concatenate_virtual_arrays(arrs) + assert c.shape == (8,) + assert c._get_recording(3) == 0 + assert c._get_recording(5) == 1 + + ae(c[:], [0, 1, 2, 3, 4, 10, 11, 0]) + ae(c[0], [0]) + ae(c[4], [4]) + ae(c[5], [10]) + ae(c[6], [11]) + + ae(c[4:6], [4, 10]) + + ae(c[:6], [0, 1, 2, 3, 4, 10]) + ae(c[4:], [4, 10, 11, 0]) + ae(c[4:-1], [4, 10, 11]) + + +def test_concatenate_virtual_arrays_2(): + arrs = [np.zeros((2, 2)), np.ones((3, 2))] 
+ c = _concatenate_virtual_arrays(arrs) + assert c.shape == (5, 2) + ae(c[:, :], np.vstack((np.zeros((2, 2)), np.ones((3, 2))))) + ae(c[0:4, 0], [0, 0, 1, 1]) + + +#------------------------------------------------------------------------------ +# Test chunking +#------------------------------------------------------------------------------ + +def test_chunk_bounds(): + chunks = chunk_bounds(200, 100, overlap=20) + + assert next(chunks) == (0, 100, 0, 90) + assert next(chunks) == (80, 180, 90, 170) + assert next(chunks) == (160, 200, 170, 200) + + +def test_chunk(): + data = np.random.randn(200, 4) + chunks = chunk_bounds(data.shape[0], 100, overlap=20) + + with raises(ValueError): + data_chunk(data, (0, 0, 0)) + + assert data_chunk(data, (0, 0)).shape == (0, 4) + + # Chunk 1. + ch = next(chunks) + d = data_chunk(data, ch) + d_o = data_chunk(data, ch, with_overlap=True) + + ae(d_o, data[0:100]) + ae(d, data[0:90]) + + # Chunk 2. + ch = next(chunks) + d = data_chunk(data, ch) + d_o = data_chunk(data, ch, with_overlap=True) + + ae(d_o, data[80:180]) + ae(d, data[90:170]) + + +def test_excerpts_1(): + bounds = [(start, end) for (start, end) in excerpts(100, + n_excerpts=3, + excerpt_size=10)] + assert bounds == [(0, 10), (45, 55), (90, 100)] + + +def test_excerpts_2(): + bounds = [(start, end) for (start, end) in excerpts(10, + n_excerpts=3, + excerpt_size=10)] + assert bounds == [(0, 10)] + + +def test_get_excerpts(): + data = np.random.rand(100, 2) + subdata = get_excerpts(data, n_excerpts=10, excerpt_size=5) + assert subdata.shape == (50, 2) + ae(subdata[:5, :], data[:5, :]) + ae(subdata[-5:, :], data[-10:-5, :]) + + data = np.random.rand(10, 2) + subdata = get_excerpts(data, n_excerpts=10, excerpt_size=5) + ae(subdata, data) + + data = np.random.rand(10, 2) + subdata = get_excerpts(data, n_excerpts=1, excerpt_size=10) + ae(subdata, data) + + assert len(get_excerpts(data, n_excerpts=0, excerpt_size=10)) == 0 + + +def test_regular_subset(): + spikes = [2, 3, 5, 7, 11, 13, 17] + ae(regular_subset(spikes), spikes) + ae(regular_subset(spikes, 100), spikes) + ae(regular_subset(spikes, 100, offset=2), spikes) + ae(regular_subset(spikes, 3), [2, 7, 17]) + ae(regular_subset(spikes, 3, offset=1), [3, 11]) + + +#------------------------------------------------------------------------------ +# Test spike clusters functions +#------------------------------------------------------------------------------ + +def test_spikes_in_clusters(): + """Test _spikes_in_clusters().""" + + n_spikes = 100 + n_clusters = 5 + spike_clusters = artificial_spike_clusters(n_spikes, n_clusters) + + ae(_spikes_in_clusters(spike_clusters, []), []) + + for i in range(n_clusters): + assert np.all(spike_clusters[_spikes_in_clusters(spike_clusters, + [i])] == i) + + clusters = [1, 2, 3] + assert np.all(np.in1d(spike_clusters[_spikes_in_clusters(spike_clusters, + clusters)], + clusters)) + + +def test_spikes_per_cluster(): + """Test _spikes_per_cluster().""" + + n_spikes = 100 + n_clusters = 3 + spike_clusters = artificial_spike_clusters(n_spikes, n_clusters) + + assert not _spikes_per_cluster([]) + + spikes_per_cluster = _spikes_per_cluster(spike_clusters) + assert list(spikes_per_cluster.keys()) == list(range(n_clusters)) + + for i in range(n_clusters): + ae(spikes_per_cluster[i], np.sort(spikes_per_cluster[i])) + assert np.all(spike_clusters[spikes_per_cluster[i]] == i) + + +def test_flatten_per_cluster(): + spc = {2: [2, 7, 11], 3: [3, 5], 5: []} + arr = _flatten_per_cluster(spc) + ae(arr, [2, 3, 5, 7, 11]) + + +def 
test_grouped_mean(): + spike_clusters = np.array([2, 3, 2, 2, 5]) + arr = spike_clusters * 10 + ae(grouped_mean(arr, spike_clusters), [20, 30, 50]) + + +def test_select_spikes(): + with raises(AssertionError): + select_spikes() + spikes = [2, 3, 5, 7, 11] + spc = lambda c: {2: [2, 7, 11], 3: [3, 5], 5: []}.get(c, None) + ae(select_spikes([], spikes_per_cluster=spc), []) + ae(select_spikes([2, 3, 5], spikes_per_cluster=spc), spikes) + ae(select_spikes([2, 5], spikes_per_cluster=spc), spc(2)) + + ae(select_spikes([2, 3, 5], 0, spikes_per_cluster=spc), spikes) + ae(select_spikes([2, 3, 5], None, spikes_per_cluster=spc), spikes) + ae(select_spikes([2, 3, 5], 1, spikes_per_cluster=spc), [2, 3]) + ae(select_spikes([2, 5], 2, spikes_per_cluster=spc), [2]) + + sel = Selector(spc) + assert sel.select_spikes() is None + ae(sel.select_spikes([2, 5]), spc(2)) + ae(sel.select_spikes([2, 5], 2), [2]) diff --git a/phy/io/tests/test_base.py b/phy/io/tests/test_base.py deleted file mode 100644 index c400c0433..000000000 --- a/phy/io/tests/test_base.py +++ /dev/null @@ -1,85 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Tests of the BaseModel class.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -from pytest import raises - -from ..base import BaseModel, ClusterMetadata - - -#------------------------------------------------------------------------------ -# Tests -#------------------------------------------------------------------------------ - -def test_base_cluster_metadata(): - meta = ClusterMetadata() - - @meta.default - def group(cluster): - return 3 - - @meta.default - def color(cluster): - return 0 - - assert meta.group(0) is not None - assert meta.group(2) == 3 - assert meta.group(10) == 3 - - meta.set_color(10, 5) - assert meta.color(10) == 5 - - # Alternative __setitem__ syntax. 
- meta.set_color([10, 11], 5) - assert meta.color(10) == 5 - assert meta.color(11) == 5 - - meta.set_color([10, 11], 6) - assert meta.color(10) == 6 - assert meta.color(11) == 6 - assert meta.color([10, 11]) == [6, 6] - - meta.set_color(10, 20) - assert meta.color(10) == 20 - - -def test_base(): - model = BaseModel() - - assert model.channel_group is None - - model.channel_group = 1 - assert model.channel_group == 1 - - assert model.channel_groups == [] - assert model.clusterings == [] - - model.clustering = 'original' - assert model.clustering == 'original' - - with raises(NotImplementedError): - model.metadata - with raises(NotImplementedError): - model.traces - with raises(NotImplementedError): - model.spike_samples - with raises(NotImplementedError): - model.spike_clusters - with raises(NotImplementedError): - model.cluster_metadata - with raises(NotImplementedError): - model.features - with raises(NotImplementedError): - model.masks - with raises(NotImplementedError): - model.waveforms - with raises(NotImplementedError): - model.probe - with raises(NotImplementedError): - model.save() - - model.close() diff --git a/phy/io/tests/test_context.py b/phy/io/tests/test_context.py new file mode 100644 index 000000000..2d54e2be2 --- /dev/null +++ b/phy/io/tests/test_context.py @@ -0,0 +1,127 @@ +# -*- coding: utf-8 -*- + +"""Test context.""" + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +import os.path as op + +import numpy as np +from numpy.testing import assert_array_equal as ae +from pytest import yield_fixture +from six.moves import cPickle + +from ..array import write_array, read_array +from ..context import Context, _fullname + + +#------------------------------------------------------------------------------ +# Fixtures +#------------------------------------------------------------------------------ + +@yield_fixture(scope='function') +def context(tempdir): + ctx = Context('{}/cache/'.format(tempdir), verbose=1) + yield ctx + + +@yield_fixture +def temp_phy_config_dir(tempdir): + """Use a temporary phy user directory.""" + import phy.io.context + f = phy.io.context.phy_config_dir + phy.io.context.phy_config_dir = lambda: tempdir + yield + phy.io.context.phy_config_dir = f + + +#------------------------------------------------------------------------------ +# Test utils and cache +#------------------------------------------------------------------------------ + +def test_fullname(): + def myfunction(x): + return x + + assert _fullname(myfunction) == 'phy.io.tests.test_context.myfunction' + + +def test_read_write(tempdir): + x = np.arange(10) + write_array(op.join(tempdir, 'test.npy'), x) + ae(read_array(op.join(tempdir, 'test.npy')), x) + + +def test_context_load_save(tempdir, context, temp_phy_config_dir): + assert not context.load('unexisting') + + context.save('a/hello', {'text': 'world'}) + assert context.load('a/hello')['text'] == 'world' + + context.save('a/hello', {'text': 'world!'}, location='global') + assert context.load('a/hello', location='global')['text'] == 'world!' + + +def test_context_cache(context): + + _res = [] + + def f(x): + _res.append(x) + return x ** 2 + + x = np.arange(5) + x2 = x * x + + ae(f(x), x2) + assert len(_res) == 1 + + f = context.cache(f) + + # Run it a first time. + ae(f(x), x2) + assert len(_res) == 2 + + # The second time, the cache is used. 
+ ae(f(x), x2) + assert len(_res) == 2 + + +def test_context_memcache(tempdir, context): + + _res = [] + + @context.memcache + def f(x): + _res.append(x) + return x ** 2 + + # Compute the function a first time. + x = np.arange(10) + ae(f(x), x ** 2) + assert len(_res) == 1 + + # The second time, the memory cache is used. + ae(f(x), x ** 2) + assert len(_res) == 1 + + # We artificially clear the memory cache. + context.save_memcache() + del context._memcache[_fullname(f)] + context.load_memcache(_fullname(f)) + + # This time, the result is loaded from disk. + ae(f(x), x ** 2) + assert len(_res) == 1 + + +def test_pickle_cache(tempdir, context): + """Make sure the Context is picklable.""" + with open(op.join(tempdir, 'test.pkl'), 'wb') as f: + cPickle.dump(context, f) + with open(op.join(tempdir, 'test.pkl'), 'rb') as f: + ctx = cPickle.load(f) + assert isinstance(ctx, Context) + assert ctx.cache_dir == context.cache_dir diff --git a/phy/utils/tests/test_datasets.py b/phy/io/tests/test_datasets.py similarity index 84% rename from phy/utils/tests/test_datasets.py rename to phy/io/tests/test_datasets.py index 77f356743..aafc4d102 100644 --- a/phy/utils/tests/test_datasets.py +++ b/phy/io/tests/test_datasets.py @@ -6,6 +6,7 @@ # Imports #------------------------------------------------------------------------------ +import logging import os.path as op from itertools import product @@ -19,18 +20,17 @@ download_sample_data, _check_md5_of_url, _BASE_URL, + _validate_output_dir, ) -from ..logging import register, StringLogger, set_level +from phy.utils.testing import captured_logging + +logger = logging.getLogger(__name__) #------------------------------------------------------------------------------ # Fixtures #------------------------------------------------------------------------------ -def setup(): - set_level('debug') - - # Test URL and data _URL = 'http://test/data' _DATA = np.linspace(0., 1., 100000).astype(np.float32) @@ -51,7 +51,7 @@ def _add_mock_response(url, body, file_type='binary'): def mock_url(): _add_mock_response(_URL, _DATA.tostring()) _add_mock_response(_URL + '.md5', _CHECKSUM + ' ' + op.basename(_URL)) - yield + yield _URL responses.reset() @@ -82,6 +82,7 @@ def mock_urls(request): def _dl(path): + assert path download_file(_URL, path) with open(path, 'rb') as f: data = f.read() @@ -96,6 +97,12 @@ def _check(data): # Test utility functions #------------------------------------------------------------------------------ +def test_validate_output_dir(chdir_tempdir): + _validate_output_dir(None) + _validate_output_dir(op.join(chdir_tempdir, 'a/b/c')) + assert op.exists(op.join(chdir_tempdir, 'a/b/c/')) + + @responses.activate def test_check_md5_of_url(tempdir, mock_url): output_path = op.join(tempdir, 'data') @@ -116,25 +123,23 @@ def test_download_not_found(tempdir): @responses.activate def test_download_already_exists_invalid(tempdir, mock_url): - logger = StringLogger(level='debug') - register(logger) - path = op.join(tempdir, 'test') - # Create empty file. - open(path, 'a').close() - _check(_dl(path)) - assert 'redownload' in str(logger) + with captured_logging() as buf: + path = op.join(tempdir, 'test') + # Create empty file. + open(path, 'a').close() + _check(_dl(path)) + assert 'redownload' in buf.getvalue() @responses.activate def test_download_already_exists_valid(tempdir, mock_url): - logger = StringLogger(level='debug') - register(logger) - path = op.join(tempdir, 'test') - # Create valid file. 
- with open(path, 'ab') as f: - f.write(_DATA.tostring()) - _check(_dl(path)) - assert 'skip' in str(logger) + with captured_logging() as buf: + path = op.join(tempdir, 'test') + # Create valid file. + with open(path, 'ab') as f: + f.write(_DATA.tostring()) + _check(_dl(path)) + assert 'skip' in buf.getvalue() @responses.activate @@ -173,6 +178,9 @@ def test_download_sample_data(tempdir): data = f.read() ae(np.fromstring(data, np.float32), _DATA) + # Warning. + download_sample_data(name + '_', tempdir) + responses.reset() @@ -189,4 +197,6 @@ def test_dat_file(tempdir): arr = np.fromfile(f, dtype=np.int16).reshape((-1, 4)) assert arr.shape == (20000, 4) + assert download_test_data(fn, tempdir) == path + responses.reset() diff --git a/phy/io/tests/test_h5.py b/phy/io/tests/test_h5.py deleted file mode 100644 index 72038dd0e..000000000 --- a/phy/io/tests/test_h5.py +++ /dev/null @@ -1,282 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Tests of HDF5 routines.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import os.path as op - -import numpy as np -from numpy.testing import assert_array_equal as ae -from pytest import raises - -from ...utils.testing import captured_output -from ..h5 import open_h5, _split_hdf5_path - - -#------------------------------------------------------------------------------ -# Utility test routines -#------------------------------------------------------------------------------ - -def _create_test_file(dirpath): - filename = op.join(dirpath, '_test.h5') - with open_h5(filename, 'w') as tempfile: - # Create a random dataset using h5py directly. - h5file = tempfile.h5py_file - h5file.create_dataset('ds1', (10,), dtype=np.float32) - group = h5file.create_group('/mygroup') - h5file.create_dataset('/mygroup/ds2', (10,), dtype=np.int8) - group.attrs['myattr'] = 123 - return tempfile.filename - - -#------------------------------------------------------------------------------ -# Tests -#------------------------------------------------------------------------------ - -def test_split_hdf5_path(): - # The path should always start with a leading '/'. - with raises(ValueError): - _split_hdf5_path('') - with raises(ValueError): - _split_hdf5_path('path') - - h, t = _split_hdf5_path('/') - assert (h == '/') and (t == '') - - h, t = _split_hdf5_path('/path') - assert (h == '/') and (t == 'path') - - h, t = _split_hdf5_path('/path/') - assert (h == '/path') and (t == '') - - h, t = _split_hdf5_path('/path/to') - assert (h == '/path') and (t == 'to') - - h, t = _split_hdf5_path('/path/to/') - assert (h == '/path/to') and (t == '') - - # Check that invalid paths raise errors. - with raises(ValueError): - _split_hdf5_path('path/') - with raises(ValueError): - _split_hdf5_path('/path//') - with raises(ValueError): - _split_hdf5_path('/path//to') - - -def test_h5_read(tempdir): - # Create the test HDF5 file in the temporary directory. - filename = _create_test_file(tempdir) - - # Test close() method. - f = open_h5(filename) - assert f.is_open() - f.close() - assert not f.is_open() - with raises(IOError): - f.describe() - - # Open the test HDF5 file. - with open_h5(filename) as f: - assert f.is_open() - - assert f.children() == ['ds1', 'mygroup'] - assert f.groups() == ['mygroup'] - assert f.datasets() == ['ds1'] - assert f.attrs('/mygroup') == ['myattr'] - assert f.attrs('/mygroup_nonexisting') == [] - - # Check dataset ds1. 
- ds1 = f.read('/ds1')[:] - assert isinstance(ds1, np.ndarray) - assert ds1.shape == (10,) - assert ds1.dtype == np.float32 - - # Check dataset ds2. - ds2 = f.read('/mygroup/ds2')[:] - assert isinstance(ds2, np.ndarray) - assert ds2.shape == (10,) - assert ds2.dtype == np.int8 - - # Check HDF5 group attribute. - assert f.has_attr('/mygroup', 'myattr') - assert not f.has_attr('/mygroup', 'myattr_bis') - assert not f.has_attr('/mygroup_bis', 'myattr_bis') - value = f.read_attr('/mygroup', 'myattr') - assert value == 123 - - # Check that errors are raised when the paths are invalid. - with raises(Exception): - f.read('//path') - with raises(Exception): - f.read('/path//') - with raises(ValueError): - f.read('/nonexistinggroup') - with raises(ValueError): - f.read('/nonexistinggroup/ds34') - - assert not f.is_open() - - -def test_h5_append(tempdir): - # Create the test HDF5 file in the temporary directory. - filename = _create_test_file(tempdir) - - with open_h5(filename, 'a') as f: - f.write('/ds_empty', dtype=np.float32, shape=(10, 2)) - arr = f.read('/ds_empty') - arr[:5, 0] = 1 - - with open_h5(filename, 'r') as f: - arr = f.read('/ds_empty')[...] - assert np.all(arr[:5, 0] == 1) - - -def test_h5_write(tempdir): - # Create the test HDF5 file in the temporary directory. - filename = _create_test_file(tempdir) - - # Create some array. - temp_array = np.zeros(10, dtype=np.float32) - - # Open the test HDF5 file in read-only mode (the default) and - # try to write in it. This should raise an exception. - with open_h5(filename) as f: - with raises(Exception): - f.write('/ds1', temp_array) - - # Open the test HDF5 file in read/write mode and - # try to write in an existing dataset. - with open_h5(filename, 'a') as f: - # This raises an exception because the file already exists, - # and by default this is forbidden. - with raises(ValueError): - f.write('/ds1', temp_array) - - # This works, though, because we force overwriting the dataset. - f.write('/ds1', temp_array, overwrite=True) - ae(f.read('/ds1'), temp_array) - - # Write a new array. - f.write('/ds2', temp_array) - ae(f.read('/ds2'), temp_array) - - # Write a new array in a nonexistent group. - f.write('/ds3/ds4/ds5', temp_array) - ae(f.read('/ds3/ds4/ds5'), temp_array) - - # Write an existing attribute. - f.write_attr('/ds1', 'myattr', 456) - - with raises(KeyError): - f.read_attr('/ds1', 'nonexistingattr') - - assert f.read_attr('/ds1', 'myattr') == 456 - - # Write a new attribute in a dataset. - f.write_attr('/ds1', 'mynewattr', 789) - assert f.read_attr('/ds1', 'mynewattr') == 789 - - # Write a new attribute in a group. - f.write_attr('/mygroup', 'mynewattr', 890) - assert f.read_attr('/mygroup', 'mynewattr') == 890 - - # Write a new attribute in a nonexisting group. - f.write_attr('/nonexistinggroup', 'mynewattr', 2) - assert f.read_attr('/nonexistinggroup', 'mynewattr') == 2 - - # Write a new attribute two levels into a nonexisting group. - f.write_attr('/nonexistinggroup2/group3', 'mynewattr', 2) - assert f.read_attr('/nonexistinggroup2/group3', 'mynewattr') == 2 - - -def test_h5_describe(tempdir): - # Create the test HDF5 file in the temporary directory. - filename = _create_test_file(tempdir) - - # Open the test HDF5 file. - with open_h5(filename) as f: - with captured_output() as (out, err): - f.describe() - output = out.getvalue().strip() - output_lines = output.split('\n') - assert len(output_lines) == 3 - - -def test_h5_move(tempdir): - # Create the test HDF5 file in the temporary directory. 
- filename = _create_test_file(tempdir) - - with open_h5(filename, 'a') as f: - - # Test dataset move. - assert f.exists('ds1') - arr = f.read('ds1')[:] - assert len(arr) == 10 - f.move('ds1', 'ds1_new') - assert not f.exists('ds1') - assert f.exists('ds1_new') - arr_new = f.read('ds1_new')[:] - assert len(arr_new) == 10 - ae(arr, arr_new) - - # Test group move. - assert f.exists('mygroup/ds2') - arr = f.read('mygroup/ds2') - f.move('mygroup', 'g/mynewgroup') - assert not f.exists('mygroup') - assert f.exists('g/mynewgroup') - assert f.exists('g/mynewgroup/ds2') - arr_new = f.read('g/mynewgroup/ds2') - ae(arr, arr_new) - - -def test_h5_copy(tempdir): - # Create the test HDF5 file in the temporary directory. - filename = _create_test_file(tempdir) - - with open_h5(filename, 'a') as f: - - # Test dataset copy. - assert f.exists('ds1') - arr = f.read('ds1')[:] - assert len(arr) == 10 - f.copy('ds1', 'ds1_new') - assert f.exists('ds1') - assert f.exists('ds1_new') - arr_new = f.read('ds1_new')[:] - assert len(arr_new) == 10 - ae(arr, arr_new) - - # Test group copy. - assert f.exists('mygroup/ds2') - arr = f.read('mygroup/ds2') - f.copy('mygroup', 'g/mynewgroup') - assert f.exists('mygroup') - assert f.exists('g/mynewgroup') - assert f.exists('g/mynewgroup/ds2') - arr_new = f.read('g/mynewgroup/ds2') - ae(arr, arr_new) - - -def test_h5_delete(tempdir): - # Create the test HDF5 file in the temporary directory. - filename = _create_test_file(tempdir) - - with open_h5(filename, 'a') as f: - - # Test dataset delete. - assert f.exists('ds1') - with raises(ValueError): - f.delete('a') - f.delete('ds1') - assert not f.exists('ds1') - - # Test group delete. - assert f.exists('mygroup/ds2') - f.delete('mygroup') - assert not f.exists('mygroup') - assert not f.exists('mygroup/ds2') diff --git a/phy/io/tests/test_mock.py b/phy/io/tests/test_mock.py index 1097a93b9..e53e18063 100644 --- a/phy/io/tests/test_mock.py +++ b/phy/io/tests/test_mock.py @@ -8,15 +8,15 @@ import numpy as np from numpy.testing import assert_array_equal as ae -from pytest import raises -from ...electrode.mea import MEA from ..mock import (artificial_waveforms, artificial_traces, artificial_spike_clusters, artificial_features, artificial_masks, - MockModel) + artificial_spike_samples, + artificial_correlograms, + ) #------------------------------------------------------------------------------ @@ -57,23 +57,15 @@ def _test_artificial(n_spikes=None, n_clusters=None): masks = artificial_masks(n_spikes, n_channels) assert masks.shape == (n_spikes, n_channels) + # Spikes. + spikes = artificial_spike_samples(n_spikes) + assert spikes.shape == (n_spikes,) + + # CCG. + ccg = artificial_correlograms(n_clusters, 10) + assert ccg.shape == (n_clusters, n_clusters, 10) + def test_artificial(): _test_artificial(n_spikes=100, n_clusters=10) _test_artificial(n_spikes=0, n_clusters=0) - - -def test_mock_model(): - model = MockModel() - - assert model.metadata['description'] == 'A mock model.' 
- assert model.traces.ndim == 2 - assert model.spike_samples.ndim == 1 - assert model.spike_clusters.ndim == 1 - assert model.features.ndim == 2 - assert model.masks.ndim == 2 - assert model.waveforms.ndim == 3 - - assert isinstance(model.probe, MEA) - with raises(NotImplementedError): - model.save() diff --git a/phy/io/tests/test_sparse.py b/phy/io/tests/test_sparse.py deleted file mode 100644 index a92502692..000000000 --- a/phy/io/tests/test_sparse.py +++ /dev/null @@ -1,122 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Tests of sparse matrix structures.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import os.path as op - -import numpy as np -from numpy.testing import assert_array_equal as ae -from pytest import raises - -from ..sparse import csr_matrix, SparseCSR, load_h5, save_h5 -from ..h5 import open_h5 - - -#------------------------------------------------------------------------------ -# Tests -#------------------------------------------------------------------------------ - -def _dense_matrix_example(): - """Sparse matrix example: - - * 1 2 * * - 3 * * 4 * - * * * * * - * * 5 * * - - Return a dense array. - - """ - arr = np.zeros((4, 5)) - arr[0, 1] = 1 - arr[0, 2] = 2 - arr[1, 0] = 3 - arr[1, 3] = 4 - arr[3, 2] = 5 - return arr - - -def _sparse_matrix_example(): - """Return a sparse representation of the sparse matrix example.""" - shape = (4, 5) - data = np.array([1, 2, 3, 4, 5]) - channels = np.array([1, 2, 0, 3, 2]) - spikes_ptr = np.array([0, 2, 4, 4, 5]) - return shape, data, channels, spikes_ptr - - -def test_sparse_csr_check(): - """Test the checks performed when creating a sparse matrix.""" - dense = _dense_matrix_example() - shape, data, channels, spikes_ptr = _sparse_matrix_example() - - # Dense to sparse conversion not implemented yet. - with raises(NotImplementedError): - csr_matrix(dense) - - # Need the three sparse components and the shape. - with raises(ValueError): - csr_matrix(data=data, channels=channels) - with raises(ValueError): - csr_matrix(data=data, spikes_ptr=spikes_ptr) - with raises(ValueError): - csr_matrix(channels=channels, spikes_ptr=spikes_ptr) - with raises(ValueError): - csr_matrix(data=data, channels=channels, spikes_ptr=spikes_ptr) - with raises(ValueError): - csr_matrix(shape=shape, - data=data[:-1], channels=channels, spikes_ptr=spikes_ptr) - with raises(ValueError): - csr_matrix(shape=shape, channels=[[0]]) - with raises(ValueError): - csr_matrix(shape=shape, spikes_ptr=[0]) - with raises(ValueError): - csr_matrix(shape=(4, 5, 6), data=data, channels=np.zeros((2, 2)), - spikes_ptr=spikes_ptr) - with raises(ValueError): - csr_matrix(shape=shape, data=data, channels=np.zeros((2, 2)), - spikes_ptr=spikes_ptr) - with raises(ValueError): - csr_matrix(shape=shape, data=data, channels=channels, - spikes_ptr=np.zeros((2, 2))) - with raises(ValueError): - csr_matrix(shape=shape, data=np.zeros((100)), channels=channels, - spikes_ptr=spikes_ptr) - with raises(ValueError): - csr_matrix(shape=shape, data=data, channels=np.zeros(100), - spikes_ptr=spikes_ptr) - with raises(ValueError): - csr_matrix(shape=shape, data=data, channels=channels, - spikes_ptr=np.zeros(100)) - - # This one should pass. 
- sparse = csr_matrix(shape=shape, - data=data, channels=channels, spikes_ptr=spikes_ptr) - assert isinstance(sparse, SparseCSR) - ae(sparse.shape, shape) - ae(sparse._data, data) - ae(sparse._channels, channels) - ae(sparse._spikes_ptr, spikes_ptr) - - -def test_sparse_hdf5(tempdir): - """Test the checks performed when creating a sparse matrix.""" - shape, data, channels, spikes_ptr = _sparse_matrix_example() - sparse = csr_matrix(shape=shape, data=data, channels=channels, - spikes_ptr=spikes_ptr) - dense = _dense_matrix_example() - - with open_h5(op.join(tempdir, 'test.h5'), 'w') as f: - path_sparse = '/my_sparse_array' - save_h5(f, path_sparse, sparse) - sparse_bis = load_h5(f, path_sparse) - assert sparse == sparse_bis - - path_dense = '/my_dense_array' - save_h5(f, path_dense, dense) - dense_bis = load_h5(f, path_dense) - ae(dense, dense_bis) diff --git a/phy/io/tests/test_store.py b/phy/io/tests/test_store.py deleted file mode 100644 index 158d7b495..000000000 --- a/phy/io/tests/test_store.py +++ /dev/null @@ -1,423 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Test cluster store.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import os -import os.path as op - -import numpy as np -from numpy.testing import assert_array_equal as ae -from numpy.testing import assert_allclose as ac - -from ...utils._types import Bunch -from ...utils.array import _spikes_per_cluster -from ..store import (MemoryStore, - BaseStore, - DiskStore, - ClusterStore, - VariableSizeItem, - FixedSizeItem, - ) - - -#------------------------------------------------------------------------------ -# Test data stores -#------------------------------------------------------------------------------ - -def test_memory_store(): - ms = MemoryStore() - assert ms.load(2) == {} - - assert ms.load(3).get('key', None) is None - assert ms.load(3) == {} - assert ms.load(3, ['key']) == {'key': None} - assert ms.load(3) == {} - assert ms.cluster_ids == [] - - ms.store(3, key='a') - assert ms.load(3) == {'key': 'a'} - assert ms.load(3, ['key']) == {'key': 'a'} - assert ms.load(3, 'key') == 'a' - assert ms.cluster_ids == [3] - - ms.store(3, key_bis='b') - assert ms.load(3) == {'key': 'a', 'key_bis': 'b'} - assert ms.load(3, ['key']) == {'key': 'a'} - assert ms.load(3, ['key_bis']) == {'key_bis': 'b'} - assert ms.load(3, ['key', 'key_bis']) == {'key': 'a', 'key_bis': 'b'} - assert ms.load(3, 'key_bis') == 'b' - assert ms.cluster_ids == [3] - - ms.erase([2, 3]) - assert ms.load(3) == {} - assert ms.load(3, ['key']) == {'key': None} - assert ms.cluster_ids == [] - - -def test_base_store(tempdir): - store = BaseStore(tempdir) - - store._save('11.npy', np.arange(11)) - store._save('12.npy', np.arange(12)) - store._save('2.npy', np.arange(2)) - store._save('3.npy', None) - store._save('1.npy', [np.arange(3), np.arange(3, 8)]) - - ae(store._open('01.npy'), None) - ae(store._open('11.npy'), np.arange(11)) - ae(store._open('12.npy'), np.arange(12)) - ae(store._open('2.npy'), np.arange(2)) - ae(store._open('3.npy'), None) - ae(store._open('1.npy')[0], np.arange(3)) - ae(store._open('1.npy')[1], np.arange(3, 8)) - - store._delete_file('2.npy') - store._delete_file('12.npy') - ae(store._open('2.npy'), None) - - -def test_disk_store(tempdir): - - dtype = np.float32 - sha = (2, 4) - shb = (3, 5) - a = np.random.rand(*sha).astype(dtype) - b = np.random.rand(*shb).astype(dtype) - - def _assert_equal(d_0, d_1): - """Test the 
equality of two dictionaries containing NumPy arrays.""" - assert sorted(d_0.keys()) == sorted(d_1.keys()) - for key in d_0.keys(): - ac(d_0[key], d_1[key]) - - ds = DiskStore(tempdir) - - ds.register_file_extensions(['key', 'key_bis']) - assert ds.cluster_ids == [] - - ds.store(3, key=a) - _assert_equal(ds.load(3, - ['key'], - dtype=dtype, - shape=sha, - ), - {'key': a}) - loaded = ds.load(3, 'key', dtype=dtype, shape=sha) - ac(loaded, a) - - # Loading a non-existing key returns None. - assert ds.load(3, 'key_bis') is None - assert ds.cluster_ids == [3] - - ds.store(3, key_bis=b) - _assert_equal(ds.load(3, ['key'], dtype=dtype, shape=sha), {'key': a}) - _assert_equal(ds.load(3, ['key_bis'], - dtype=dtype, - shape=shb, - ), - {'key_bis': b}) - _assert_equal(ds.load(3, - ['key', 'key_bis'], - dtype=dtype, - ), - {'key': a.ravel(), 'key_bis': b.ravel()}) - ac(ds.load(3, 'key_bis', dtype=dtype, shape=shb), b) - assert ds.cluster_ids == [3] - - ds.erase([2, 3]) - assert ds.load(3, ['key']) == {'key': None} - assert ds.cluster_ids == [] - - # Test load/save file. - ds.save_file('test', {'a': a}) - ds = DiskStore(tempdir) - data = ds.load_file('test') - ae(data['a'], a) - assert ds.load_file('test2') is None - - -def test_cluster_store_1(tempdir): - - # We define some data and a model. - n_spikes = 100 - n_clusters = 10 - - spike_ids = np.arange(n_spikes) - spike_clusters = np.random.randint(size=n_spikes, - low=0, high=n_clusters) - spikes_per_cluster = _spikes_per_cluster(spike_ids, spike_clusters) - - model = {'spike_clusters': spike_clusters} - - # We initialize the ClusterStore. - cs = ClusterStore(model=model, - path=tempdir, - spikes_per_cluster=spikes_per_cluster, - ) - - # We create a n_spikes item to be stored in memory, - # and we define how to generate it for a given cluster. - class MyItem(FixedSizeItem): - name = 'my item' - fields = ['n_spikes'] - - def store(self, cluster): - spikes = self.spikes_per_cluster[cluster] - self.memory_store.store(cluster, n_spikes=len(spikes)) - - def load(self, cluster, name): - return self.memory_store.load(cluster, name) - - def on_cluster(self, up): - if up.description == 'merge': - n = sum(len(up.old_spikes_per_cluster[cl]) - for cl in up.deleted) - self.memory_store.store(up.added[0], n_spikes=n) - else: - super(MyItem, self).on_cluster(up) - - item = cs.register_item(MyItem) - item.progress_reporter.set_progress_message("Progress {progress}.\n") - item.progress_reporter.set_complete_message("Finished.\n") - - # Now we generate the store. - cs.generate() - - # We check that the n_spikes field has successfully been created. - for cluster in sorted(spikes_per_cluster): - assert cs.n_spikes(cluster) == len(spikes_per_cluster[cluster]) - - # Merge. - spc = spikes_per_cluster.copy() - spikes = np.sort(np.concatenate([spc[0], spc[1]])) - spc[20] = spikes - del spc[0] - del spc[1] - up = Bunch(description='merge', - added=[20], - deleted=[0, 1], - spike_ids=spikes, - new_spikes_per_cluster=spc, - old_spikes_per_cluster=spikes_per_cluster,) - - cs.items['my item'].on_cluster(up) - - # Check the list of clusters in the store. - ae(cs.memory_store.cluster_ids, list(range(0, n_clusters)) + [20]) - ae(cs.disk_store.cluster_ids, []) - assert cs.n_spikes(20) == len(spikes) - - # Recreate the cluster store. 
- cs = ClusterStore(model=model, - spikes_per_cluster=spikes_per_cluster, - path=tempdir, - ) - cs.register_item(MyItem) - cs.generate() - ae(cs.memory_store.cluster_ids, list(range(n_clusters))) - ae(cs.disk_store.cluster_ids, []) - - -def test_cluster_store_multi(): - """This tests the cluster store when a store item has several fields.""" - - cs = ClusterStore(spikes_per_cluster={0: [0, 2], 1: [1, 3, 4]}) - - class MyItem(FixedSizeItem): - name = 'my item' - fields = ['d', 'm'] - - def store(self, cluster): - spikes = self.spikes_per_cluster[cluster] - self.memory_store.store(cluster, d=len(spikes), m=len(spikes) ** 2) - - def load(self, cluster, name): - return self.memory_store.load(cluster, name) - - cs.register_item(MyItem) - - cs.generate() - - assert cs.memory_store.load(0, ['d', 'm']) == {'d': 2, 'm': 4} - assert cs.d(0) == 2 - assert cs.m(0) == 4 - - assert cs.memory_store.load(1, ['d', 'm']) == {'d': 3, 'm': 9} - assert cs.d(1) == 3 - assert cs.m(1) == 9 - - -def test_cluster_store_load(tempdir): - - # We define some data and a model. - n_spikes = 100 - n_clusters = 10 - - spike_ids = np.arange(n_spikes) - spike_clusters = np.random.randint(size=n_spikes, - low=0, high=n_clusters) - spikes_per_cluster = _spikes_per_cluster(spike_ids, spike_clusters) - model = {'spike_clusters': spike_clusters} - - # We initialize the ClusterStore. - cs = ClusterStore(model=model, - spikes_per_cluster=spikes_per_cluster, - path=tempdir, - ) - - # We create a n_spikes item to be stored in memory, - # and we define how to generate it for a given cluster. - class MyItem(VariableSizeItem): - name = 'my item' - fields = ['spikes_square'] - - def store(self, cluster): - spikes = spikes_per_cluster[cluster] - data = (spikes ** 2).astype(np.int32) - self.disk_store.store(cluster, spikes_square=data) - - def load(self, cluster, name): - return self.disk_store.load(cluster, name, np.int32) - - def load_spikes(self, spikes, name): - return (spikes ** 2).astype(np.int32) - - cs.register_item(MyItem) - cs.generate() - - # All spikes in cluster 1. - cluster = 1 - spikes = spikes_per_cluster[cluster] - ae(cs.load('spikes_square', clusters=[cluster]), spikes ** 2) - - # Some spikes in several clusters. - clusters = [2, 3, 5] - spikes = np.concatenate([spikes_per_cluster[cl][::3] - for cl in clusters]) - ae(cs.load('spikes_square', spikes=spikes), np.unique(spikes) ** 2) - - # Empty selection. - assert len(cs.load('spikes_square', clusters=[])) == 0 - assert len(cs.load('spikes_square', spikes=[])) == 0 - - -def test_cluster_store_management(tempdir): - - # We define some data and a model. - n_spikes = 100 - n_clusters = 10 - - spike_ids = np.arange(n_spikes) - spike_clusters = np.random.randint(size=n_spikes, - low=0, high=n_clusters) - spikes_per_cluster = _spikes_per_cluster(spike_ids, spike_clusters) - - model = Bunch({'spike_clusters': spike_clusters, - 'cluster_ids': np.arange(n_clusters), - }) - - # We initialize the ClusterStore. - cs = ClusterStore(model=model, - spikes_per_cluster=spikes_per_cluster, - path=tempdir, - ) - - # We create a n_spikes item to be stored in memory, - # and we define how to generate it for a given cluster. 
- class MyItem(VariableSizeItem): - name = 'my item' - fields = ['spikes_square'] - - def store(self, cluster): - spikes = self.spikes_per_cluster[cluster] - if not self.is_consistent(cluster, spikes): - data = (spikes ** 2).astype(np.int32) - self.disk_store.store(cluster, spikes_square=data) - - def is_consistent(self, cluster, spikes): - data = self.disk_store.load(cluster, - 'spikes_square', - dtype=np.int32, - ) - if data is None: - return False - if len(data) != len(spikes): - return False - expected = (spikes ** 2).astype(np.int32) - return np.all(data == expected) - - cs.register_item(MyItem) - cs.update_spikes_per_cluster(spikes_per_cluster) - - def _check_to_generate(cs, clusters): - item = cs.items['my item'] - ae(item.to_generate(), clusters) - ae(item.to_generate(None), clusters) - ae(item.to_generate('default'), clusters) - ae(item.to_generate('force'), np.arange(n_clusters)) - ae(item.to_generate('read-only'), []) - - # Check the list of clusters to generate. - _check_to_generate(cs, np.arange(n_clusters)) - - # Generate the store. - cs.generate() - - # Check the status. - assert 'True' in cs.status - - # We re-initialize the ClusterStore. - cs = ClusterStore(model=model, - spikes_per_cluster=spikes_per_cluster, - path=tempdir, - ) - cs.register_item(MyItem) - cs.update_spikes_per_cluster(spikes_per_cluster) - - # Check the list of clusters to generate. - _check_to_generate(cs, []) - cs.display_status() - - # We erase a file. - path = op.join(cs.path, '1.spikes_square') - os.remove(path) - - # Check the list of clusters to generate. - _check_to_generate(cs, [1]) - assert '9' in cs.status - assert 'False' in cs.status - - cs.generate() - - # Check the status. - assert 'True' in cs.status - - # Now, we make new assignements. - spike_clusters = np.random.randint(size=n_spikes, - low=n_clusters, high=n_clusters + 5) - spikes_per_cluster = _spikes_per_cluster(spike_ids, spike_clusters) - cs.update_spikes_per_cluster(spikes_per_cluster) - - # All files are now old and should be removed by clean(). 
- assert not cs.is_consistent() - item = cs.items['my item'] - ae(item.to_generate(), np.arange(n_clusters, n_clusters + 5)) - - ae(cs.cluster_ids, np.arange(n_clusters, n_clusters + 5)) - ae(cs.old_clusters, np.arange(n_clusters)) - cs.clean() - - ae(cs.cluster_ids, np.arange(n_clusters, n_clusters + 5)) - ae(cs.old_clusters, []) - ae(item.to_generate(), np.arange(n_clusters, n_clusters + 5)) - assert not cs.is_consistent() - cs.generate() - - assert cs.is_consistent() - ae(cs.cluster_ids, np.arange(n_clusters, n_clusters + 5)) - ae(cs.old_clusters, []) - ae(item.to_generate(), []) diff --git a/phy/io/tests/test_traces.py b/phy/io/tests/test_traces.py deleted file mode 100644 index bbc635f28..000000000 --- a/phy/io/tests/test_traces.py +++ /dev/null @@ -1,57 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Tests of read traces functions.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import os.path as op - -import numpy as np -from numpy.testing import assert_array_equal as ae -from numpy.testing import assert_allclose as ac - -from ..h5 import open_h5 -from ..traces import read_dat, _dat_n_samples, read_kwd -from ..mock import artificial_traces - - -#------------------------------------------------------------------------------ -# Tests -#------------------------------------------------------------------------------ - -def test_read_dat(tempdir): - n_samples = 100 - n_channels = 10 - - arr = artificial_traces(n_samples, n_channels) - - path = op.join(tempdir, 'test') - arr.tofile(path) - assert _dat_n_samples(path, dtype=np.float64, - n_channels=n_channels) == n_samples - data = read_dat(path, dtype=arr.dtype, shape=arr.shape) - ae(arr, data) - data = read_dat(path, dtype=arr.dtype, n_channels=n_channels) - ae(arr, data) - - -def test_read_kwd(tempdir): - n_samples = 100 - n_channels = 10 - - arr = artificial_traces(n_samples, n_channels) - - path = op.join(tempdir, 'test') - - with open_h5(path, 'w') as f: - f.write('/recordings/0/data', - arr[:n_samples // 2, ...].astype(np.float32)) - f.write('/recordings/1/data', - arr[n_samples // 2:, ...].astype(np.float32)) - - with open_h5(path, 'r') as f: - data = read_kwd(f)[:] - - ac(arr, data) diff --git a/phy/io/traces.py b/phy/io/traces.py deleted file mode 100644 index e61651c90..000000000 --- a/phy/io/traces.py +++ /dev/null @@ -1,89 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Raw data readers.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import os.path as op - -import numpy as np - -from ..utils.array import _concatenate_virtual_arrays - - -#------------------------------------------------------------------------------ -# Raw data readers -#------------------------------------------------------------------------------ - -def read_kwd(kwd_handle): - """Read all traces in a  `.kwd` file. - - The output is a memory-mapped file. - - """ - f = kwd_handle - if '/recordings' not in f: - return - recordings = f.children('/recordings') - traces = [] - for recording in recordings: - traces.append(f.read('/recordings/{}/data'.format(recording))) - return _concatenate_virtual_arrays(traces) - - -def read_dat(filename, dtype=None, shape=None, offset=0, n_channels=None): - """Read traces from a flat binary `.dat` file. - - The output is a memory-mapped file. 
- - Parameters - ---------- - - filename : str - The path to the `.dat` file. - dtype : dtype - The NumPy dtype. - offset : 0 - The header size. - n_channels : int - The number of channels in the data. - shape : tuple (optional) - The array shape. Typically `(n_samples, n_channels)`. The shape is - automatically computed from the file size if the number of channels - and dtype are specified. - - """ - if shape is None: - assert n_channels > 0 - n_samples = _dat_n_samples(filename, dtype=dtype, - n_channels=n_channels) - shape = (n_samples, n_channels) - return np.memmap(filename, dtype=dtype, shape=shape, - mode='r', offset=offset) - - -def _dat_to_traces(dat_path, n_channels, dtype): - assert dtype is not None - assert n_channels is not None - n_samples = _dat_n_samples(dat_path, - n_channels=n_channels, - dtype=dtype, - ) - return read_dat(dat_path, - dtype=dtype, - shape=(n_samples, n_channels)) - - -def _dat_n_samples(filename, dtype=None, n_channels=None): - assert dtype is not None - item_size = np.dtype(dtype).itemsize - n_samples = op.getsize(filename) // (item_size * n_channels) - assert n_samples >= 0 - return n_samples - - -def read_ns5(filename): - # TODO - raise NotImplementedError() diff --git a/phy/plot/__init__.py b/phy/plot/__init__.py index 03d945c18..e5d3c1f12 100644 --- a/phy/plot/__init__.py +++ b/phy/plot/__init__.py @@ -1,11 +1,28 @@ # -*- coding: utf-8 -*- # flake8: noqa -"""Interactive and static visualization of data.""" - -from ._panzoom import PanZoom, PanZoomGrid -from ._vispy_utils import BaseSpikeCanvas, BaseSpikeVisual -from .waveforms import WaveformView, WaveformVisual, plot_waveforms -from .features import FeatureView, FeatureVisual, plot_features -from .traces import TraceView, TraceView, plot_traces -from .ccg import CorrelogramView, CorrelogramView, plot_correlograms +"""VisPy plotting.""" + + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +import os.path as op + +from vispy import config + +from .plot import View # noqa +from .transform import Translate, Scale, Range, Subplot, NDC +from .panzoom import PanZoom +from .utils import _get_linear_x + + +#------------------------------------------------------------------------------ +# Add the `glsl/ path` for shader include +#------------------------------------------------------------------------------ + +curdir = op.dirname(op.realpath(__file__)) +glsl_path = op.join(curdir, 'glsl') +if not config['include_path']: + config['include_path'] = [glsl_path] diff --git a/phy/plot/_mpl_utils.py b/phy/plot/_mpl_utils.py deleted file mode 100644 index ff2b38a23..000000000 --- a/phy/plot/_mpl_utils.py +++ /dev/null @@ -1,20 +0,0 @@ -# -*- coding: utf-8 -*- - -"""matplotlib utilities.""" - - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - - -#------------------------------------------------------------------------------ -# matplotlib utilities -#------------------------------------------------------------------------------ - -def _bottom_left_frame(ax): - """Only keep the bottom and left ticks in a matplotlib Axes.""" - ax.spines['right'].set_visible(False) - ax.spines['top'].set_visible(False) - ax.xaxis.set_ticks_position('bottom') - ax.yaxis.set_ticks_position('left') diff --git a/phy/plot/_panzoom.py b/phy/plot/_panzoom.py deleted file mode 100644 index 
6dc56fa1b..000000000 --- a/phy/plot/_panzoom.py +++ /dev/null @@ -1,788 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Pan & zoom transform.""" - - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import math - -import numpy as np - -from ..utils._types import _as_array - - -#------------------------------------------------------------------------------ -# PanZoom class -#------------------------------------------------------------------------------ - -class PanZoom(object): - """Pan & zoom transform. - - The panzoom transform allow to translate and scale an object in the window - space coordinate (2D). This means that whatever point you grab on the - screen, it should remains under the mouse pointer. Zoom is realized using - the mouse scroll and is always centered on the mouse pointer. - - You can also control programmatically the transform using: - - * aspect: control the aspect ratio of the whole scene - * pan : translate the scene to the given 2D coordinates - * zoom : set the zoom level (centered at current pan coordinates) - * zmin : minimum zoom level - * zmax : maximum zoom level - - """ - - _default_zoom_coeff = 1.5 - _default_wheel_coeff = .1 - _arrows = ('Left', 'Right', 'Up', 'Down') - _pm = ('+', '-') - - def __init__(self, aspect=1.0, pan=(0.0, 0.0), zoom=(1.0, 1.0), - zmin=1e-5, zmax=1e5, - xmin=None, xmax=None, - ymin=None, ymax=None, - ): - """ - Initialize the transform. - - Parameters - ---------- - - aspect : float (default is None) - Indicate what is the aspect ratio of the object displayed. This is - necessary to convert pixel drag move in object space coordinates. - - pan : float, float (default is 0, 0) - Initial translation - - zoom : float, float (default is 1) - Initial zoom level - - zmin : float (default is 0.01) - Minimum zoom level - - zmax : float (default is 1000) - Maximum zoom level - """ - - self._aspect = aspect - self._zmin = zmin - self._zmax = zmax - self._xmin = xmin - self._xmax = xmax - self._ymin = ymin - self._ymax = ymax - - self._zoom_to_pointer = True - - # Canvas this transform is attached to - self._canvas = None - self._canvas_aspect = np.ones(2) - self._width = 1 - self._height = 1 - - self._create_pan_and_zoom(_as_array(pan), _as_array(zoom)) - - # Programs using this transform - self._programs = [] - - def _create_pan_and_zoom(self, pan, zoom): - self._pan = np.array(pan) - self._zoom = np.array(zoom) - self._zoom_coeff = self._default_zoom_coeff - self._wheel_coeff = self._default_wheel_coeff - - # Various properties - # ------------------------------------------------------------------------- - - @property - def zoom_to_pointer(self): - """Whether to zoom toward the pointer position.""" - return self._zoom_to_pointer - - @zoom_to_pointer.setter - def zoom_to_pointer(self, value): - self._zoom_to_pointer = value - - @property - def is_attached(self): - """Whether the transform is attached to a canvas.""" - return self._canvas is not None - - @property - def aspect(self): - """Aspect (width/height).""" - return self._aspect - - @aspect.setter - def aspect(self, value): - """Aspect (width/height).""" - self._aspect = value - - # xmin/xmax - # ------------------------------------------------------------------------- - - @property - def xmin(self): - """Minimum x allowed for pan.""" - return self._xmin - - @xmin.setter - def xmin(self, value): - if self._xmax is not None: - self._xmin = np.minimum(value, self._xmax) - 
else: - self._xmin = value - - @property - def xmax(self): - """Maximum x allowed for pan.""" - return self._xmax - - @xmax.setter - def xmax(self, value): - if self._xmin is not None: - self._xmax = np.maximum(value, self._xmin) - else: - self._xmax = value - - # ymin/ymax - # ------------------------------------------------------------------------- - - @property - def ymin(self): - """Minimum y allowed for pan.""" - return self._ymin - - @ymin.setter - def ymin(self, value): - if self._ymax is not None: - self._ymin = min(value, self._ymax) - else: - self._ymin = value - - @property - def ymax(self): - """Maximum y allowed for pan.""" - return self._ymax - - @ymax.setter - def ymax(self, value): - if self._ymin is not None: - self._ymax = max(value, self._ymin) - else: - self._ymax = value - - # zmin/zmax - # ------------------------------------------------------------------------- - - @property - def zmin(self): - """Minimum zoom level.""" - return self._zmin - - @zmin.setter - def zmin(self, value): - """Minimum zoom level.""" - self._zmin = min(value, self._zmax) - - @property - def zmax(self): - """Maximal zoom level.""" - return self._zmax - - @zmax.setter - def zmax(self, value): - """Maximal zoom level.""" - self._zmax = max(value, self._zmin) - - # Internal methods - # ------------------------------------------------------------------------- - - def _apply_pan_zoom(self): - zoom = self._zoom_aspect() - for program in self._programs: - program["u_pan"] = self._pan - program["u_zoom"] = zoom - - def _zoom_aspect(self, zoom=None): - if zoom is None: - zoom = self._zoom - zoom = _as_array(zoom) - if self._aspect is not None: - aspect = self._canvas_aspect * self._aspect - else: - aspect = 1. - return zoom * aspect - - def _normalize(self, x_y, restrict_to_box=True): - x_y = np.asarray(x_y, dtype=np.float32) - size = np.array([self._width, self._height], dtype=np.float32) - pos = x_y / (size / 2.) - 1 - return pos - - def _constrain_pan(self): - """Constrain bounding box.""" - if self.xmin is not None and self._xmax is not None: - p0 = self.xmin + 1. / self._zoom[0] - p1 = self.xmax - 1. / self._zoom[0] - p0, p1 = min(p0, p1), max(p0, p1) - self._pan[0] = np.clip(self._pan[0], p0, p1) - - if self.ymin is not None and self._ymax is not None: - p0 = self.ymin + 1. / self._zoom[1] - p1 = self.ymax - 1. / self._zoom[1] - p0, p1 = min(p0, p1), max(p0, p1) - self._pan[1] = np.clip(self._pan[1], p0, p1) - - def _constrain_zoom(self): - """Constrain bounding box.""" - if self.xmin is not None: - self._zoom[0] = max(self._zoom[0], - 1. / (self._pan[0] - self.xmin)) - if self.xmax is not None: - self._zoom[0] = max(self._zoom[0], - 1. / (self.xmax - self._pan[0])) - - if self.ymin is not None: - self._zoom[1] = max(self._zoom[1], - 1. / (self._pan[1] - self.ymin)) - if self.ymax is not None: - self._zoom[1] = max(self._zoom[1], - 1. 
/ (self.ymax - self._pan[1])) - - # Pan and zoom - # ------------------------------------------------------------------------- - - @property - def pan(self): - """Pan translation.""" - return self._pan - - @pan.setter - def pan(self, value): - """Pan translation.""" - assert len(value) == 2 - self._pan[:] = value - self._constrain_pan() - self._apply_pan_zoom() - - @property - def zoom(self): - """Zoom level.""" - return self._zoom - - @zoom.setter - def zoom(self, value): - """Zoom level.""" - if isinstance(value, (int, float)): - value = (value, value) - assert len(value) == 2 - self._zoom = np.clip(value, self._zmin, self._zmax) - if not self.is_attached: - return - - # Constrain bounding box. - self._constrain_pan() - self._constrain_zoom() - - self._apply_pan_zoom() - - def _do_pan(self, d): - - dx, dy = d - - pan_x, pan_y = self.pan - zoom_x, zoom_y = self._zoom_aspect(self._zoom) - - self.pan = (pan_x + dx / zoom_x, - pan_y + dy / zoom_y) - - self._canvas.update() - - def _do_zoom(self, d, p, c=1.): - dx, dy = d - x0, y0 = p - - pan_x, pan_y = self._pan - zoom_x, zoom_y = self._zoom - zoom_x_new, zoom_y_new = (zoom_x * math.exp(c * self._zoom_coeff * dx), - zoom_y * math.exp(c * self._zoom_coeff * dy)) - - zoom_x_new = max(min(zoom_x_new, self._zmax), self._zmin) - zoom_y_new = max(min(zoom_y_new, self._zmax), self._zmin) - - self.zoom = zoom_x_new, zoom_y_new - - if self._zoom_to_pointer: - zoom_x, zoom_y = self._zoom_aspect((zoom_x, - zoom_y)) - zoom_x_new, zoom_y_new = self._zoom_aspect((zoom_x_new, - zoom_y_new)) - - self.pan = (pan_x - x0 * (1. / zoom_x - 1. / zoom_x_new), - pan_y + y0 * (1. / zoom_y - 1. / zoom_y_new)) - - self._canvas.update() - - # Event callbacks - # ------------------------------------------------------------------------- - - keyboard_shortcuts = { - 'pan': ('left click and drag', 'arrows'), - 'zoom': ('right click and drag', '+', '-'), - 'reset': 'r', - } - - def on_resize(self, event): - """Resize event.""" - - self._width = float(event.size[0]) - self._height = float(event.size[1]) - aspect = max(1., self._width / max(self._height, 1.)) - if aspect > 1.0: - self._canvas_aspect = np.array([1.0 / aspect, 1.0]) - else: - self._canvas_aspect = np.array([1.0, aspect / 1.0]) - - # Update zoom level - self.zoom = self._zoom - - def on_mouse_move(self, event): - """Pan and zoom with the mouse.""" - if event.modifiers: - return - if event.is_dragging: - x0, y0 = self._normalize(event.press_event.pos) - x1, y1 = self._normalize(event.last_event.pos, False) - x, y = self._normalize(event.pos, False) - dx, dy = x - x1, -(y - y1) - if event.button == 1: - self._do_pan((dx, dy)) - elif event.button == 2: - c = np.sqrt(self._width) * .03 - self._do_zoom((dx, dy), (x0, y0), c=c) - - def on_mouse_wheel(self, event): - """Zoom with the mouse wheel.""" - if event.modifiers: - return - dx = np.sign(event.delta[1]) * self._wheel_coeff - # Zoom toward the mouse pointer. - x0, y0 = self._normalize(event.pos) - self._do_zoom((dx, dx), (x0, y0)) - - def _zoom_keyboard(self, key): - k = .05 - if key == '-': - k = -k - self._do_zoom((k, k), (0, 0)) - - def _pan_keyboard(self, key): - k = .1 / self.zoom - if key == 'Left': - self.pan += (+k[0], +0) - elif key == 'Right': - self.pan += (-k[0], +0) - elif key == 'Down': - self.pan += (+0, +k[1]) - elif key == 'Up': - self.pan += (+0, -k[1]) - self._canvas.update() - - def _reset_keyboard(self): - self.pan = (0., 0.) - self.zoom = 1. 
- self._canvas.update() - - def on_key_press(self, event): - """Key press event.""" - - # Zooming with the keyboard. - key = event.key - if event.modifiers: - return - - # Pan. - if key in self._arrows: - self._pan_keyboard(key) - - # Zoom. - if key in self._pm: - self._zoom_keyboard(key) - - # Reset with 'R'. - if key == 'R': - self._reset_keyboard() - - # Canvas methods - # ------------------------------------------------------------------------- - - def add(self, programs): - """Add programs to this tranform.""" - - if not isinstance(programs, (list, tuple)): - programs = [programs] - - for program in programs: - self._programs.append(program) - - self._apply_pan_zoom() - - def attach(self, canvas): - """Attach this tranform to a canvas.""" - self._canvas = canvas - self._width = float(canvas.size[0]) - self._height = float(canvas.size[1]) - - aspect = self._width / max(1, self._height) - if aspect > 1.0: - self._canvas_aspect = np.array([1.0 / aspect, 1.0]) - else: - self._canvas_aspect = np.array([1.0, aspect / 1.0]) - - canvas.connect(self.on_resize) - canvas.connect(self.on_mouse_wheel) - canvas.connect(self.on_mouse_move) - canvas.connect(self.on_key_press) - - -class PanZoomGrid(PanZoom): - """Pan & zoom transform for a grid view. - - This is used in a grid view with independent per-subplot pan & zoom. - - The currently-active subplot depends on where the cursor was when - the mouse was clicked. - - """ - - def __init__(self, *args, **kwargs): - self._index = (0, 0) # current index of the box being pan/zoom-ed - self._n_rows = 1 - self._create_pan_and_zoom(pan=(0., 0.), zoom=(1., 1.)) - self._set_pan_zoom_coeffs() - super(PanZoomGrid, self).__init__(*args, **kwargs) - - # Grid properties - # ------------------------------------------------------------------------- - - @property - def n_rows(self): - """Number of rows.""" - assert self._n_rows is not None - return self._n_rows - - @n_rows.setter - def n_rows(self, value): - self._n_rows = int(value) - assert self._n_rows >= 1 - if self._n_rows > 16: - raise RuntimeError("There cannot be more than 16x16 subplots. " - "The limitation comes from the maximum " - "uniform array size used by panzoom. " - "But you can try to increase the number 256 in " - "'plot/glsl/grid.glsl'.") - self._set_pan_zoom_coeffs() - - # Pan and zoom - # ------------------------------------------------------------------------- - - def _create_pan_and_zoom(self, pan, zoom): - pan = _as_array(pan) - zoom = _as_array(zoom) - n = 16 # maximum number of rows - self._pan_matrix = np.empty((n, n, 2)) - self._pan_matrix[...] = pan[None, None, :] - - self._zoom_matrix = np.empty((n, n, 2)) - self._zoom_matrix[...] = zoom[None, None, :] - - def _set_pan_zoom_coeffs(self): - # The zoom coefficient for mouse zoom should be proportional - # to the subplot size. - c = 3. 
/ np.sqrt(self._n_rows) - self._zoom_coeff = self._default_zoom_coeff * c - self._wheel_coeff = self._default_wheel_coeff - - def _set_current_box(self, pos): - self._index = self._get_box(pos) - - @property - def _box(self): - i, j = self._index - return int(i * self._n_rows + j) - - @property - def _pan(self): - i, j = self._index - return self._pan_matrix[i, j, :] - - @_pan.setter - def _pan(self, value): - i, j = self._index - self._pan_matrix[i, j, :] = value - - @property - def _zoom(self): - i, j = self._index - return self._zoom_matrix[i, j, :] - - @_zoom.setter - def _zoom(self, value): - i, j = self._index - self._zoom_matrix[i, j, :] = value - - @property - def zoom_matrix(self): - """Zoom in every subplot.""" - return self._zoom_matrix - - @property - def pan_matrix(self): - """Pan in every subplot.""" - return self._pan_matrix - - def _apply_pan_zoom(self): - pan = self._pan - zoom = self._zoom_aspect(self._zoom) - value = (pan[0], pan[1], zoom[0], zoom[1]) - for program in self._programs: - program["u_pan_zoom[{0:d}]".format(self._box)] = value - - def _map_box(self, position, inverse=False): - position = _as_array(position) - if position.ndim == 1: - position = position[None, :] - n_rows = self._n_rows - rc_x, rc_y = self._index - - rc_x += 0.5 - rc_y += 0.5 - - x = -1.0 + rc_y * (2.0 / n_rows) - y = +1.0 - rc_x * (2.0 / n_rows) - - width = 0.95 / (1.0 * n_rows) - height = 0.95 / (1.0 * n_rows) - - if not inverse: - return (x + width * position[:, 0], - y + height * position[:, 1]) - else: - return np.c_[((position[:, 0] - x) / width, - (position[:, 1] - y) / height)] - - def _map_pan_zoom(self, position, inverse=False): - position = _as_array(position) - if position.ndim == 1: - position = position[None, :] - n_rows = self._n_rows - pan = self._pan - zoom = self._zoom_aspect(self._zoom) - if not inverse: - return zoom * (position + n_rows * pan) - else: - return (position / zoom - n_rows * pan) - - # xmin/xmax - # ------------------------------------------------------------------------- - - @property - def xmin(self): - return self._local(self._xmin) - - @xmin.setter - def xmin(self, value): - if self._xmax is not None: - self._xmin = np.minimum(value, self._xmax) - else: - self._xmin = value - - @property - def xmax(self): - return self._local(self._xmax) - - @xmax.setter - def xmax(self, value): - if self._xmin is not None: - self._xmax = np.maximum(value, self._xmin) - else: - self._xmax = value - - # ymin/ymax - # ------------------------------------------------------------------------- - - @property - def ymin(self): - return self._local(self._ymin) - - @ymin.setter - def ymin(self, value): - if self._ymax is not None: - self._ymin = min(value, self._ymax) - else: - self._ymin = value - - @property - def ymax(self): - return self._local(self._ymax) - - @ymax.setter - def ymax(self, value): - if self._ymin is not None: - self._ymax = max(value, self._ymin) - else: - self._ymax = value - - # Internal methods - # ------------------------------------------------------------------------- - - def _local(self, value): - if isinstance(value, np.ndarray): - i, j = self._index - value = value[i, j] - if value == np.nan: - value = None - return value - - def _constrain_pan(self): - """Constrain bounding box.""" - if self.xmin is not None and self._xmax is not None: - p0 = (self.xmin + 1. / self._zoom[0]) / self._n_rows - p1 = (self.xmax - 1. 
/ self._zoom[0]) / self._n_rows - p0, p1 = min(p0, p1), max(p0, p1) - self._pan[0] = np.clip(self._pan[0], p0, p1) - - if self.ymin is not None and self._ymax is not None: - p0 = (self.ymin + 1. / self._zoom[1]) / self._n_rows - p1 = (self.ymax - 1. / self._zoom[1]) / self._n_rows - p0, p1 = min(p0, p1), max(p0, p1) - self._pan[1] = np.clip(self._pan[1], p0, p1) - - def _get_box(self, x_y): - x0, y0 = x_y - - x0 /= self._width - y0 /= self._height - - x0 *= self._n_rows - y0 *= self._n_rows - - return (int(math.floor(y0)), int(math.floor(x0))) - - def _normalize(self, x_y, restrict_to_box=True): - x0, y0 = x_y - - x0 /= self._width - y0 /= self._height - - x0 *= self._n_rows - y0 *= self._n_rows - - if restrict_to_box: - x0 = x0 % 1 - y0 = y0 % 1 - - x0 = -(1 - 2 * x0) - y0 = -(1 - 2 * y0) - - x0 /= self._n_rows - y0 /= self._n_rows - - return x0, y0 - - def _initialize_pan_zoom(self): - # Initialize and set the uniform array. - # NOTE: 256 is the maximum size used for this uniform array. - # This corresponds to a hard-limit of 16x16 subplots. - self._u_pan_zoom = np.zeros((256, 4), - dtype=np.float32) - self._u_pan_zoom[:, 2:] = 1. - for program in self._programs: - program["u_pan_zoom"] = self._u_pan_zoom - - def _reset(self): - pan = (0., 0.) - zoom = (1., 1.) - - # Keep the current box index. - index = self._index - - # Update pan and zoom of all subplots. - for i in range(self._n_rows): - for j in range(self._n_rows): - self._index = (i, j) - - # Find out which axes to update. - if pan is not None: - p = list(pan) - self.pan = p - - if zoom is not None: - z = list(zoom) - self.zoom = z - - # Set back the box index. - self._index = index - - # Event callbacks - # ------------------------------------------------------------------------- - - keyboard_shortcuts = { - 'subplot_pan': ('left click and drag', 'arrows'), - 'subplot_zoom': ('right click and drag', '+ or -'), - 'global_reset': 'r', - } - - def on_mouse_move(self, event): - """Mouse move event.""" - # Set box index as a function of the press position. - if event.is_dragging: - self._set_current_box(event.press_event.pos) - super(PanZoomGrid, self).on_mouse_move(event) - - def on_mouse_press(self, event): - """Mouse press event.""" - # Set box index as a function of the press position. - self._set_current_box(event.pos) - super(PanZoomGrid, self).on_mouse_press(event) - - def on_mouse_double_click(self, event): - """Double click event.""" - # Set box index as a function of the press position. - self._set_current_box(event.pos) - super(PanZoomGrid, self).on_mouse_double_click(event) - - def on_mouse_wheel(self, event): - """Mouse wheel event.""" - # Set box index as a function of the press position. - self._set_current_box(event.pos) - super(PanZoomGrid, self).on_mouse_wheel(event) - - def on_key_press(self, event): - """Key press event.""" - super(PanZoomGrid, self).on_key_press(event) - - key = event.key - - # Reset with 'R'. 
- if key == 'R': - self._reset() - self._canvas.update() - - # Canvas methods - # ------------------------------------------------------------------------- - - def add(self, programs): - """Add programs to this tranform.""" - if not isinstance(programs, (list, tuple)): - programs = [programs] - for program in programs: - self._programs.append(program) - self._initialize_pan_zoom() - self._apply_pan_zoom() diff --git a/phy/plot/_vispy_utils.py b/phy/plot/_vispy_utils.py deleted file mode 100644 index 19103f5df..000000000 --- a/phy/plot/_vispy_utils.py +++ /dev/null @@ -1,553 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Plotting/VisPy utilities.""" - - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import os.path as op -from functools import wraps - -import numpy as np - -from vispy import app, gloo, config -from vispy.util.event import Event - -from ..utils._types import _as_array, _as_list -from ..utils.array import _unique, _in_polygon -from ..utils.logging import debug -from ._panzoom import PanZoom - - -#------------------------------------------------------------------------------ -# Misc -#------------------------------------------------------------------------------ - -def _load_shader(filename): - """Load a shader file.""" - path = op.join(op.dirname(op.realpath(__file__)), 'glsl', filename) - with open(path, 'r') as f: - return f.read() - - -def _tesselate_histogram(hist): - assert hist.ndim == 1 - nsamples = len(hist) - dx = 2. / nsamples - - x0 = -1 + dx * np.arange(nsamples) - - x = np.zeros(5 * nsamples + 1) - y = -1.0 * np.ones(5 * nsamples + 1) - - x[0:-1:5] = x0 - x[1::5] = x0 - x[2::5] = x0 + dx - x[3::5] = x0 - x[4::5] = x0 + dx - x[-1] = 1 - - y[1::5] = y[2::5] = -1 + 2. * hist - - return np.c_[x, y] - - -def _enable_depth_mask(): - gloo.set_state(clear_color='black', - depth_test=True, - depth_range=(0., 1.), - # depth_mask='true', - depth_func='lequal', - blend=True, - blend_func=('src_alpha', 'one_minus_src_alpha')) - gloo.set_clear_depth(1.0) - - -def _wrap_vispy(f): - """Decorator for a function returning a VisPy canvas. - - Add `show=True` parameter. - - """ - @wraps(f) - def wrapped(*args, **kwargs): - show = kwargs.pop('show', True) - canvas = f(*args, **kwargs) - if show: - canvas.show() - return canvas - return wrapped - - -#------------------------------------------------------------------------------ -# Base spike visual -#------------------------------------------------------------------------------ - -class _BakeVisual(object): - _shader_name = '' - _gl_draw_mode = '' - - def __init__(self, **kwargs): - super(_BakeVisual, self).__init__(**kwargs) - self._to_bake = [] - self._empty = True - - curdir = op.dirname(op.realpath(__file__)) - config['include_path'] = [op.join(curdir, 'glsl')] - - vertex = _load_shader(self._shader_name + '.vert') - fragment = _load_shader(self._shader_name + '.frag') - self.program = gloo.Program(vertex, fragment) - - @property - def empty(self): - """Specify whether the visual is currently empty or not.""" - return self._empty - - @empty.setter - def empty(self, value): - """Specify whether the visual is currently empty or not.""" - self._empty = value - - def set_to_bake(self, *bakes): - """Mark data items to be prepared for GPU.""" - for bake in bakes: - if bake not in self._to_bake: - self._to_bake.append(bake) - - def _bake(self): - """Prepare and upload the data on the GPU. 
- - Return whether something has been baked or not. - - """ - if self._empty: - return - n_bake = len(self._to_bake) - # Bake what needs to be baked. - # WARNING: the bake functions are called in alphabetical order. - # Tweak the names if there are dependencies between the functions. - for bake in sorted(self._to_bake): - # Name of the private baking method. - name = '_bake_{0:s}'.format(bake) - if hasattr(self, name): - getattr(self, name)() - self._to_bake = [] - return n_bake > 0 - - def draw(self): - """Draw the waveforms.""" - # Bake what needs to be baked at this point. - self._bake() - if not self._empty: - self.program.draw(self._gl_draw_mode) - - def update(self): - self.draw() - - -class BaseSpikeVisual(_BakeVisual): - """Base class for a VisPy visual showing spike data. - - There is a notion of displayed spikes and displayed clusters. - - """ - - _transparency = True - - def __init__(self, **kwargs): - super(BaseSpikeVisual, self).__init__(**kwargs) - self.n_spikes = None - self._spike_clusters = None - self._spike_ids = None - self._cluster_ids = None - self._cluster_order = None - self._cluster_colors = None - self._update_clusters_automatically = True - - if self._transparency: - gloo.set_state(clear_color='black', blend=True, - blend_func=('src_alpha', 'one_minus_src_alpha')) - - # Data properties - # ------------------------------------------------------------------------- - - def _set_or_assert_n_spikes(self, arr): - """If n_spikes is None, set it using the array's shape. Otherwise, - check that the array has n_spikes rows.""" - # TODO: improve this - if self.n_spikes is None: - self.n_spikes = arr.shape[0] - assert arr.shape[0] == self.n_spikes - - def _update_clusters(self): - self._cluster_ids = _unique(self._spike_clusters) - - @property - def spike_clusters(self): - """The clusters assigned to the displayed spikes.""" - return self._spike_clusters - - @spike_clusters.setter - def spike_clusters(self, value): - """Set all spike clusters.""" - value = _as_array(value) - self._spike_clusters = value - if self._update_clusters_automatically: - self._update_clusters() - self.set_to_bake('spikes_clusters') - - @property - def cluster_order(self): - """List of selected clusters in display order.""" - if self._cluster_order is None: - return self._cluster_ids - else: - return self._cluster_order - - @cluster_order.setter - def cluster_order(self, value): - value = _as_array(value) - assert sorted(value.tolist()) == sorted(self._cluster_ids) - self._cluster_order = value - - @property - def masks(self): - """Masks of the displayed spikes.""" - return self._masks - - @masks.setter - def masks(self, value): - assert isinstance(value, np.ndarray) - value = _as_array(value) - if value.ndim == 1: - value = value[None, :] - self._set_or_assert_n_spikes(value) - # TODO: support sparse structures - assert value.ndim == 2 - assert value.shape == (self.n_spikes, self.n_channels) - self._masks = value - self.set_to_bake('spikes') - - @property - def spike_ids(self): - """Spike ids to display.""" - if self._spike_ids is None: - self._spike_ids = np.arange(self.n_spikes) - return self._spike_ids - - @spike_ids.setter - def spike_ids(self, value): - value = _as_array(value) - self._set_or_assert_n_spikes(value) - self._spike_ids = value - self.set_to_bake('spikes') - - @property - def cluster_ids(self): - """Cluster ids of the displayed spikes.""" - return self._cluster_ids - - @cluster_ids.setter - def cluster_ids(self, value): - """Clusters of the displayed spikes.""" - self._cluster_ids 
= _as_array(value) - - @property - def n_clusters(self): - """Number of displayed clusters.""" - if self._cluster_ids is None: - return None - else: - return len(self._cluster_ids) - - @property - def cluster_colors(self): - """Colors of the displayed clusters. - - The first color is the color of the smallest cluster. - - """ - return self._cluster_colors - - @cluster_colors.setter - def cluster_colors(self, value): - self._cluster_colors = _as_array(value) - assert len(self._cluster_colors) >= self.n_clusters - self.set_to_bake('cluster_color') - - # Data baking - # ------------------------------------------------------------------------- - - def _bake_cluster_color(self): - if self.n_clusters == 0: - u_cluster_color = np.zeros((0, 0, 3)) - else: - u_cluster_color = self.cluster_colors.reshape((1, - self.n_clusters, - 3)) - assert u_cluster_color.ndim == 3 - assert u_cluster_color.shape[2] == 3 - u_cluster_color = (u_cluster_color * 255).astype(np.uint8) - self.program['u_cluster_color'] = gloo.Texture2D(u_cluster_color) - - -#------------------------------------------------------------------------------ -# Axes and boxes visual -#------------------------------------------------------------------------------ - -class BoxVisual(_BakeVisual): - """Box frames in a square grid of subplots.""" - _shader_name = 'box' - _gl_draw_mode = 'lines' - - def __init__(self, **kwargs): - super(BoxVisual, self).__init__(**kwargs) - self._n_rows = None - - @property - def n_rows(self): - """Number of rows in the grid.""" - return self._n_rows - - @n_rows.setter - def n_rows(self, value): - assert value >= 0 - self._n_rows = value - self._empty = not(self._n_rows > 0) - self.set_to_bake('n_rows') - - @property - def n_boxes(self): - """Number of boxes in the grid.""" - return self._n_rows * self._n_rows - - def _bake_n_rows(self): - if not self._n_rows: - return - arr = np.array([[-1, -1], - [-1, +1], - [-1, +1], - [+1, +1], - [+1, +1], - [+1, -1], - [+1, -1], - [-1, -1]]) * .975 - arr = np.tile(arr, (self.n_boxes, 1)) - position = np.empty((8 * self.n_boxes, 3), dtype=np.float32) - position[:, :2] = arr - position[:, 2] = np.repeat(np.arange(self.n_boxes), 8) - self.program['a_position'] = position - self.program['n_rows'] = self._n_rows - - -class AxisVisual(BoxVisual): - """Subplot axes in a subplot grid.""" - _shader_name = 'ax' - - def __init__(self, **kwargs): - super(AxisVisual, self).__init__(**kwargs) - self._xs = [] - self._ys = [] - self.program['u_color'] = (.2, .2, .2, 1.) - - def _bake_n_rows(self): - self.program['n_rows'] = self._n_rows - - @property - def xs(self): - """A list of x coordinates.""" - return self._xs - - @xs.setter - def xs(self, value): - self._xs = _as_list(value) - self.set_to_bake('positions') - - @property - def ys(self): - """A list of y coordinates.""" - return self._ys - - @ys.setter - def ys(self, value): - self._ys = _as_list(value) - self.set_to_bake('positions') - - @property - def color(self): - return tuple(self.program['u_color']) - - @color.setter - def color(self, value): - self.program['u_color'] = tuple(value) - - def _bake_positions(self): - if not self._n_rows: - return - nx = len(self._xs) - ny = len(self._ys) - n = nx + ny - position = np.empty((2 * n * self.n_boxes, 4), dtype=np.float32) - c = 1. - arr = [[x, -c, x, +c] for x in self._xs] - arr += [[-c, y, +c, y] for y in self._ys] - arr = np.hstack(arr).astype(np.float32) - arr = arr.reshape((-1, 2)) - # Positions. - position[:, :2] = np.tile(arr, (self.n_boxes, 1)) - # Index. 
- position[:, 2] = np.repeat(np.arange(self.n_boxes), 2 * n) - # Axes. - position[:, 3] = np.tile(([0] * (2 * nx)) + ([1] * (2 * ny)), - self.n_boxes) - self.program['a_position'] = position - - -class LassoVisual(_BakeVisual): - """Lasso.""" - _shader_name = 'lasso' - _gl_draw_mode = 'line_loop' - - def __init__(self, **kwargs): - super(LassoVisual, self).__init__(**kwargs) - self._points = [] - self._n_rows = None - self.program['u_box'] = 0 - - @property - def n_rows(self): - """Number of rows in the grid.""" - return self._n_rows - - @n_rows.setter - def n_rows(self, value): - assert value >= 0 - self._n_rows = value - self.set_to_bake('n_rows') - - @property - def points(self): - """Control points.""" - return self._points - - def _update_points(self): - self._empty = len(self._points) <= 1 - self.set_to_bake('points') - - @points.setter - def points(self, value): - value = list(value) - self._points = value - self._update_points() - - def add(self, xy): - """Add a new point.""" - self._points.append((xy)) - self._update_points() - debug("Add lasso point.") - - def clear(self): - """Remove all points.""" - self._points = [] - self._update_points() - debug("Clear lasso.") - - def in_lasso(self, points): - """Find points within the lasso. - - Parameters - ---------- - points : array - A `(n_points, 2)` array with coordinates in `[-1, 1]`. - - """ - if self.n_points <= 1: - return - polygon = self._points - # Close the polygon. - polygon.append(polygon[0]) - return _in_polygon(points, polygon) - - @property - def n_points(self): - return len(self._points) - - @property - def box(self): - """The row and column where the lasso is to be shown.""" - u_box = int(self.program['u_box'][0]) - return (u_box // self._n_rows, u_box % self._n_rows) - - @box.setter - def box(self, value): - assert len(value) == 2 - i, j = value - assert 0 <= i < self._n_rows - assert 0 <= j < self._n_rows - u_box = i * self._n_rows + j - self.program['u_box'] = u_box - - @property - def n_boxes(self): - """Number of boxes in the grid.""" - return self._n_rows * self._n_rows - - def _bake_n_rows(self): - if not self._n_rows: - return - self.program['n_rows'] = self._n_rows - - def _bake_points(self): - if self.n_points <= 1: - return - self.program['a_position'] = np.array(self._points, dtype=np.float32) - - -#------------------------------------------------------------------------------ -# Base spike canvas -#------------------------------------------------------------------------------ - -class BaseSpikeCanvas(app.Canvas): - """Base class for a VisPy canvas with spike data. - - Display a main `BaseSpikeVisual` with pan zoom. 
- - """ - - _visual_class = None - _pz = None - _events = () - keyboard_shortcuts = {} - - def __init__(self, **kwargs): - super(BaseSpikeCanvas, self).__init__(**kwargs) - self._create_visuals() - self._create_pan_zoom() - self._add_events() - self.keyboard_shortcuts.update(self._pz.keyboard_shortcuts) - - def _create_visuals(self): - self.visual = self._visual_class() - - def _create_pan_zoom(self): - self._pz = PanZoom() - self._pz.add(self.visual.program) - self._pz.attach(self) - - def _add_events(self): - self.events.add(**{event: Event for event in self._events}) - - def emit(self, name, **kwargs): - return getattr(self.events, name)(**kwargs) - - def on_draw(self, event): - """Draw the main visual.""" - self.context.clear() - self.visual.draw() - - def on_resize(self, event): - """Resize the OpenGL context.""" - self.context.set_viewport(0, 0, event.size[0], event.size[1]) diff --git a/phy/plot/base.py b/phy/plot/base.py new file mode 100644 index 000000000..78e33e2f2 --- /dev/null +++ b/phy/plot/base.py @@ -0,0 +1,329 @@ +# -*- coding: utf-8 -*- + +"""Base VisPy classes.""" + + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +from collections import defaultdict +import logging +import re + +from vispy import gloo +from vispy.app import Canvas +from vispy.util.event import Event + +from .transform import TransformChain, Clip +from .utils import _load_shader, _enable_depth_mask + +logger = logging.getLogger(__name__) + + +#------------------------------------------------------------------------------ +# Utils +#------------------------------------------------------------------------------ + +def indent(text): + return '\n'.join(' ' + l.strip() for l in text.splitlines()) + + +#------------------------------------------------------------------------------ +# Base spike visual +#------------------------------------------------------------------------------ + +class BaseVisual(object): + """A Visual represents one object (or homogeneous set of objects). + + It is rendered with a single pass of a single gloo program with a single + type of GL primitive. + + """ + + """Data variables that can be lists of arrays.""" + allow_list = () + + def __init__(self): + self.gl_primitive_type = None + self.transforms = TransformChain() + self.inserter = GLSLInserter() + self.inserter.insert_vert('uniform vec2 u_window_size;', 'header') + # The program will be set by the canvas when the visual is + # added to the canvas. + self.program = None + self.set_canvas_transforms_filter(lambda t: t) + + # Visual definition + # ------------------------------------------------------------------------- + + def set_shader(self, name): + self.vertex_shader = _load_shader(name + '.vert') + self.fragment_shader = _load_shader(name + '.frag') + + def set_primitive_type(self, primitive_type): + self.gl_primitive_type = primitive_type + + def on_draw(self): + """Draw the visual.""" + # Skip the drawing if the program hasn't been built yet. + # The program is built by the interact. + if self.program: + # Draw the program. + self.program.draw(self.gl_primitive_type) + else: # pragma: no cover + logger.debug("Skipping drawing visual `%s` because the program " + "has not been built yet.", self) + + def on_resize(self, size): + # HACK: we check whether u_window_size is used in order to avoid + # the VisPy warning. We only update it if that uniform is active. 
+ s = '\n'.join(self.program.shaders) + s = s.replace('uniform vec2 u_window_size;', '') + if 'u_window_size' in s: + self.program['u_window_size'] = size + + # To override + # ------------------------------------------------------------------------- + + @staticmethod + def validate(**kwargs): + """Make consistent the input data for the visual.""" + return kwargs # pragma: no cover + + @staticmethod + def vertex_count(**kwargs): + """Return the number of vertices as a function of the input data.""" + return 0 # pragma: no cover + + def set_data(self): + """Set data to the program. + + Must be called *after* attach(canvas), because the program is built + when the visual is attached to the canvas. + + """ + raise NotImplementedError() + + def set_canvas_transforms_filter(self, f): + """Set a function filtering the canvas' transforms.""" + self.canvas_transforms_filter = f + + +#------------------------------------------------------------------------------ +# Build program with interacts +#------------------------------------------------------------------------------ + +def _insert_glsl(vertex, fragment, to_insert): + """Insert snippets in a shader. + + to_insert is a dict `{(shader_type, location): snippet}`. + + Snippets can contain `{{ var }}` placeholders for the transformed variable + name. + + """ + # Find the place where to insert the GLSL snippet. + # This is "gl_Position = transform(data_var_name);" where + # data_var_name is typically an attribute. + vs_regex = re.compile(r'gl_Position = transform\(([\S]+)\);') + r = vs_regex.search(vertex) + if not r: + logger.debug("The vertex shader doesn't contain the transform " + "placeholder: skipping the transform chain " + "GLSL insertion.") + return vertex, fragment + assert r + logger.log(5, "Found transform placeholder in vertex code: `%s`", + r.group(0)) + + # Find the GLSL variable with the data (should be a `vec2`). + var = r.group(1) + assert var and var in vertex + + # Headers. + vertex = to_insert['vert', 'header'] + '\n\n' + vertex + fragment = to_insert['frag', 'header'] + '\n\n' + fragment + + # Get the pre and post transforms. + vs_insert = to_insert['vert', 'before_transforms'] + vs_insert += to_insert['vert', 'transforms'] + vs_insert += to_insert['vert', 'after_transforms'] + + # Insert the GLSL snippet in the vertex shader. + vertex = vs_regex.sub(indent(vs_insert), vertex) + + # Now, we make the replacements in the fragment shader. + fs_regex = re.compile(r'(void main\(\)\s*\{)') + # NOTE: we add the `void main(){` that was removed by the regex. + fs_insert = '\\1\n' + to_insert['frag', 'before_transforms'] + fragment = fs_regex.sub(indent(fs_insert), fragment) + + # Replace the transformed variable placeholder by its name. + vertex = vertex.replace('{{ var }}', var) + + return vertex, fragment + + +class GLSLInserter(object): + """Insert GLSL snippets into shader codes.""" + + def __init__(self): + self._to_insert = defaultdict(list) + self.insert_vert('vec2 temp_pos_tr = {{ var }};', + 'before_transforms') + self.insert_vert('gl_Position = vec4(temp_pos_tr, 0., 1.);', + 'after_transforms') + self.insert_vert('varying vec2 v_temp_pos_tr;\n', 'header') + self.insert_frag('varying vec2 v_temp_pos_tr;\n', 'header') + + def _insert(self, shader_type, glsl, location): + assert location in ( + 'header', + 'before_transforms', + 'transforms', + 'after_transforms', + ) + self._to_insert[shader_type, location].append(glsl) + + def insert_vert(self, glsl, location='transforms'): + """Insert a GLSL snippet into the vertex shader. 
+ + The location can be: + + * `header`: declaration of GLSL variables + * `before_transforms`: just before the transforms in the vertex shader + * `transforms`: where the GPU transforms are applied in the vertex + shader + * `after_transforms`: just after the GPU transforms + + """ + self._insert('vert', glsl, location) + + def insert_frag(self, glsl, location=None): + """Insert a GLSL snippet into the fragment shader.""" + self._insert('frag', glsl, location) + + def add_transform_chain(self, tc): + """Insert the GLSL snippets of a transform chain.""" + # Generate the transforms snippet. + for t in tc.gpu_transforms: + if isinstance(t, Clip): + # Set the varying value in the vertex shader. + self.insert_vert('v_temp_pos_tr = temp_pos_tr;') + continue + self.insert_vert(t.glsl('temp_pos_tr')) + # Clipping. + clip = tc.get('Clip') + if clip: + self.insert_frag(clip.glsl('v_temp_pos_tr'), 'before_transforms') + + def insert_into_shaders(self, vertex, fragment): + """Apply the insertions to shader code.""" + to_insert = defaultdict(str) + to_insert.update({key: '\n'.join(self._to_insert[key]) + '\n' + for key in self._to_insert}) + return _insert_glsl(vertex, fragment, to_insert) + + def __add__(self, inserter): + """Concatenate two inserters.""" + for key, values in self._to_insert.items(): + values.extend([_ for _ in inserter._to_insert[key] + if _ not in values]) + return self + + +#------------------------------------------------------------------------------ +# Base canvas +#------------------------------------------------------------------------------ + +class VisualEvent(Event): + def __init__(self, type, visual=None): + super(VisualEvent, self).__init__(type) + self.visual = visual + + +class BaseCanvas(Canvas): + """A blank VisPy canvas with a custom event system that keeps the order.""" + def __init__(self, *args, **kwargs): + super(BaseCanvas, self).__init__(*args, **kwargs) + self.transforms = TransformChain() + self.inserter = GLSLInserter() + self.visuals = [] + self.events.add(visual_added=VisualEvent) + + # Enable transparency. + _enable_depth_mask() + + def add_visual(self, visual): + """Add a visual to the canvas, and build its program by the same + occasion. + + We can't build the visual's program before, because we need the canvas' + transforms first. + + """ + # Retrieve the visual's GLSL inserter. + inserter = visual.inserter + # Add the visual's transforms. + inserter.add_transform_chain(visual.transforms) + # Then, add the canvas' transforms. + canvas_transforms = visual.canvas_transforms_filter(self.transforms) + inserter.add_transform_chain(canvas_transforms) + # Also, add the canvas' inserter. + inserter += self.inserter + # Now, we insert the transforms GLSL into the shaders. + vs, fs = visual.vertex_shader, visual.fragment_shader + vs, fs = inserter.insert_into_shaders(vs, fs) + # Finally, we create the visual's program. + visual.program = gloo.Program(vs, fs) + logger.log(5, "Vertex shader: %s", vs) + logger.log(5, "Fragment shader: %s", fs) + # Initialize the size. + visual.on_resize(self.size) + # Register the visual in the list of visuals in the canvas. 
+ self.visuals.append(visual) + self.events.visual_added(visual=visual) + + def on_resize(self, event): + """Resize the OpenGL context.""" + self.context.set_viewport(0, 0, event.size[0], event.size[1]) + for visual in self.visuals: + visual.on_resize(event.size) + self.update() + + def on_draw(self, e): + """Draw all visuals.""" + gloo.clear() + for visual in self.visuals: + visual.on_draw() + + +#------------------------------------------------------------------------------ +# Base interact +#------------------------------------------------------------------------------ + +class BaseInteract(object): + """Implement dynamic transforms on a canvas.""" + canvas = None + + def attach(self, canvas): + """Attach this interact to a canvas.""" + self.canvas = canvas + + @canvas.connect + def on_visual_added(e): + self.update_program(e.visual.program) + + def update_program(self, program): + """Override this method to update programs when `self.update()` + is called.""" + pass + + def update(self): + """Update all visuals in the attached canvas.""" + if not self.canvas: + return + for visual in self.canvas.visuals: + self.update_program(visual.program) + self.canvas.update() diff --git a/phy/plot/ccg.py b/phy/plot/ccg.py deleted file mode 100644 index 81b7bf93b..000000000 --- a/phy/plot/ccg.py +++ /dev/null @@ -1,271 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Plotting CCGs.""" - - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import numpy as np -from vispy import gloo - -from ._mpl_utils import _bottom_left_frame -from ._vispy_utils import (BaseSpikeVisual, - BaseSpikeCanvas, - BoxVisual, - AxisVisual, - _tesselate_histogram, - _wrap_vispy) -from ._panzoom import PanZoomGrid -from ..utils._types import _as_array, _as_list -from ..utils._color import _selected_clusters_colors - - -#------------------------------------------------------------------------------ -# CCG visual -#------------------------------------------------------------------------------ - -class CorrelogramVisual(BaseSpikeVisual): - """Display a grid of auto- and cross-correlograms.""" - - _shader_name = 'correlograms' - _gl_draw_mode = 'triangle_strip' - - def __init__(self, **kwargs): - super(CorrelogramVisual, self).__init__(**kwargs) - self._correlograms = None - self._cluster_ids = None - self.n_bins = None - - # Data properties - # ------------------------------------------------------------------------- - - @property - def correlograms(self): - """Displayed correlograms. - - This is a `(n_clusters, n_clusters, n_bins)` array. 
- - """ - return self._correlograms - - @correlograms.setter - def correlograms(self, value): - value = _as_array(value) - # WARNING: need to set cluster_ids first - assert value.ndim == 3 - if self._cluster_ids is None: - self._cluster_ids = np.arange(value.shape[0]) - assert value.shape[:2] == (self.n_clusters, self.n_clusters) - self.n_bins = value.shape[2] - self._correlograms = value - self._empty = self.n_clusters == 0 or self.n_bins == 0 - self.set_to_bake('correlograms', 'color') - - @property - def cluster_ids(self): - """Displayed cluster ids.""" - return self._cluster_ids - - @cluster_ids.setter - def cluster_ids(self, value): - self._cluster_ids = np.asarray(value, dtype=np.int32) - - @property - def n_boxes(self): - """Number of boxes in the grid view.""" - return self.n_clusters * self.n_clusters - - # Data baking - # ------------------------------------------------------------------------- - - def _bake_correlograms(self): - n_points = self.n_boxes * (5 * self.n_bins + 1) - - # index increases from top to bottom, left to right - # same as matrix indices (i, j) starting at 0 - positions = [] - boxes = [] - - for i in range(self.n_clusters): - for j in range(self.n_clusters): - index = self.n_clusters * i + j - - hist = self._correlograms[i, j, :] - pos = _tesselate_histogram(hist) - n_points_hist = pos.shape[0] - - positions.append(pos) - boxes.append(index * np.ones(n_points_hist, dtype=np.float32)) - - positions = np.vstack(positions).astype(np.float32) - boxes = np.hstack(boxes) - - assert positions.shape == (n_points, 2) - assert boxes.shape == (n_points,) - - self.program['a_position'] = positions.copy() - self.program['a_box'] = boxes - self.program['n_rows'] = self.n_clusters - - -class CorrelogramView(BaseSpikeCanvas): - """A VisPy canvas displaying correlograms.""" - - _visual_class = CorrelogramVisual - _lines = [] - - def _create_visuals(self): - super(CorrelogramView, self)._create_visuals() - self.boxes = BoxVisual() - self.axes = AxisVisual() - - def _create_pan_zoom(self): - self._pz = PanZoomGrid() - self._pz.add(self.visual.program) - self._pz.add(self.axes.program) - self._pz.attach(self) - self._pz.aspect = None - self._pz.zmin = 1. - self._pz.xmin = -1. - self._pz.xmax = +1. - self._pz.ymin = -1. - self._pz.ymax = +1. 
- self._pz.zoom_to_pointer = False - - def set_data(self, - correlograms=None, - colors=None, - lines=None): - - if correlograms is not None: - correlograms = np.asarray(correlograms) - else: - correlograms = self.visual.correlograms - assert correlograms.ndim == 3 - n_clusters = len(correlograms) - assert correlograms.shape[:2] == (n_clusters, n_clusters) - - if colors is None: - colors = _selected_clusters_colors(n_clusters) - - self.cluster_ids = np.arange(n_clusters) - self.visual.correlograms = correlograms - - if len(colors): - self.visual.cluster_colors = colors - - if lines is not None: - self.lines = lines - - self.update() - - @property - def cluster_ids(self): - """Displayed cluster ids.""" - return self.visual.cluster_ids - - @cluster_ids.setter - def cluster_ids(self, value): - self.visual.cluster_ids = value - self.boxes.n_rows = self.visual.n_clusters - if self.visual.n_clusters >= 1: - self._pz.n_rows = self.visual.n_clusters - self.axes.n_rows = self.visual.n_clusters - if self._lines: - self.lines = self.lines - - @property - def correlograms(self): - return self.visual.correlograms - - @correlograms.setter - def correlograms(self, value): - self.visual.correlograms = value - # Update the lines which depend on the number of bins. - self.lines = self.lines - - @property - def lines(self): - """List of x coordinates where to put vertical lines. - - This is unit of samples. - - """ - return self._lines - - @lines.setter - def lines(self, value): - self._lines = _as_list(value) - c = 2. / (float(max(1, self.visual.n_bins or 0))) - self.axes.xs = np.array(self._lines) * c - self.axes.color = (.5, .5, .5, 1.) - - @property - def lines_color(self): - return self.axes.color - - @lines_color.setter - def lines_color(self, value): - self.axes.color = value - - def on_draw(self, event): - """Draw the correlograms visual.""" - gloo.clear() - self.visual.draw() - self.boxes.draw() - if self._lines: - self.axes.draw() - - -#------------------------------------------------------------------------------ -# CCG plotting -#------------------------------------------------------------------------------ - -@_wrap_vispy -def plot_correlograms(correlograms, **kwargs): - """Plot an array of correlograms. - - Parameters - ---------- - - correlograms : array - A `(n_clusters, n_clusters, n_bins)` array. - lines : ndarray - Array of x coordinates where to put vertical lines (in number of - samples). - colors : array-like (optional) - A list of colors as RGB tuples. - - """ - c = CorrelogramView(keys='interactive') - c.set_data(correlograms, **kwargs) - return c - - -def _plot_ccg_mpl(ccg, baseline=None, bin=1., color=None, ax=None): - """Plot a CCG with matplotlib and return an Axes instance.""" - import matplotlib.pyplot as plt - if ax is None: - ax = plt.subplot(111) - assert ccg.ndim == 1 - n = ccg.shape[0] - assert n % 2 == 1 - bin = float(bin) - x_min = -n // 2 * bin - bin / 2 - x_max = (n // 2 - 1) * bin + bin / 2 - width = bin * 1.05 - left = np.linspace(x_min, x_max, n) - ax.bar(left, ccg, facecolor=color, width=width, linewidth=0) - if baseline is not None: - ax.axhline(baseline, color='k', linewidth=2, linestyle='-') - ax.axvline(color='k', linewidth=2, linestyle='--') - - ax.set_xlim(x_min, x_max + bin / 2) - ax.set_ylim(0) - - # Only keep the bottom and left ticks. 
- _bottom_left_frame(ax) - - return ax diff --git a/phy/plot/features.py b/phy/plot/features.py deleted file mode 100644 index 230d615fc..000000000 --- a/phy/plot/features.py +++ /dev/null @@ -1,659 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Plotting features.""" - - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import numpy as np -from six import string_types -from vispy import gloo - -from ._vispy_utils import (BaseSpikeVisual, - BaseSpikeCanvas, - BoxVisual, - AxisVisual, - LassoVisual, - _enable_depth_mask, - _wrap_vispy, - ) -from ._panzoom import PanZoomGrid -from ..utils._types import _as_array, _is_integer -from ..utils.array import _index_of, _unique -from ..utils._color import _selected_clusters_colors - - -#------------------------------------------------------------------------------ -# Features visual -#------------------------------------------------------------------------------ - -def _get_feature_dim(dim, features=None, extra_features=None): - if isinstance(dim, (tuple, list)): - channel, feature = dim - return features[:, channel, feature] - elif isinstance(dim, string_types) and dim in extra_features: - x, m, M = extra_features[dim] - m, M = (float(m), float(M)) - x = _as_array(x, np.float32) - # Normalize extra feature. - d = float(max(1., M - m)) - x = (-1. + 2 * (x - m) / d) * .8 - return x - - -class BaseFeatureVisual(BaseSpikeVisual): - """Display a grid of multidimensional features.""" - - _shader_name = None - _gl_draw_mode = 'points' - - def __init__(self, **kwargs): - super(BaseFeatureVisual, self).__init__(**kwargs) - - self._features = None - # Mapping {feature_name: array} where the array must have n_spikes - # element. - self._extra_features = {} - self._x_dim = np.empty((0, 0), dtype=object) - self._y_dim = np.empty((0, 0), dtype=object) - self.n_channels, self.n_features = None, None - self.n_rows = None - - _enable_depth_mask() - - def add_extra_feature(self, name, array, array_min, array_max): - assert isinstance(array, np.ndarray) - assert array.ndim == 1 - if self.n_spikes: - if array.shape != (self.n_spikes,): - msg = ("Unable to add the extra feature " - "`{}`: ".format(name) + - "there should be {} ".format(self.n_spikes) + - "elements in the specified vector, not " - "{}.".format(len(array))) - raise ValueError(msg) - - self._extra_features[name] = (array, array_min, array_max) - - @property - def extra_features(self): - return self._extra_features - - @property - def features(self): - """Displayed features. - - This is a `(n_spikes, n_features)` array. - - """ - return self._features - - @features.setter - def features(self, value): - self._set_features_to_bake(value) - - def _set_features_to_bake(self, value): - # WARNING: when setting new data, features need to be set first. - # n_spikes will be set as a function of features. 
- value = _as_array(value) - # TODO: support sparse structures - assert value.ndim == 3 - self.n_spikes, self.n_channels, self.n_features = value.shape - self._features = value - self._empty = self.n_spikes == 0 - self.set_to_bake('spikes',) - - def _check_dimension(self, dim): - if _is_integer(dim): - dim = (dim, 0) - if isinstance(dim, tuple): - assert len(dim) == 2 - channel, feature = dim - assert _is_integer(channel) - assert _is_integer(feature) - assert 0 <= channel < self.n_channels - assert 0 <= feature < self.n_features - elif isinstance(dim, string_types): - assert dim in self._extra_features - elif dim: - raise ValueError('{0} should be (channel, feature) '.format(dim) + - 'or one of the extra features.') - - def project(self, box, features=None, extra_features=None): - """Project data to a subplot's two-dimensional subspace. - - Parameters - ---------- - box : 2-tuple - The `(row, col)` of the box. - features : array - extra_features : dict - - Notes - ----- - - The coordinate system is always the world coordinate system, i.e. - `[-1, 1]`. - - """ - i, j = box - dim_x = self._x_dim[i, j] - dim_y = self._y_dim[i, j] - - fet_x = _get_feature_dim(dim_x, - features=features, - extra_features=extra_features, - ) - fet_y = _get_feature_dim(dim_y, - features=features, - extra_features=extra_features, - ) - return np.c_[fet_x, fet_y] - - @property - def x_dim(self): - """Dimensions in the x axis of all subplots. - - This is a matrix of items which can be: - - * tuple `(channel_id, feature_idx)` - * an extra feature name (string) - - """ - return self._x_dim - - @x_dim.setter - def x_dim(self, value): - self._x_dim = value - self._update_dimensions() - - @property - def y_dim(self): - """Dimensions in the y axis of all subplots. - - This is a matrix of items which can be: - - * tuple `(channel_id, feature_idx)` - * an extra feature name (string) - - """ - return self._y_dim - - @y_dim.setter - def y_dim(self, value): - self._y_dim = value - self._update_dimensions() - - def _update_dimensions(self): - """Update the GPU data afte the dimensions have changed.""" - self._check_dimension_matrix(self._x_dim) - self._check_dimension_matrix(self._y_dim) - self.set_to_bake('spikes',) - - def _check_dimension_matrix(self, value): - if not isinstance(value, np.ndarray): - value = np.array(value, dtype=object) - assert value.ndim == 2 - assert value.shape[0] == value.shape[1] - assert value.dtype == object - self.n_rows = len(value) - for dim in value.flat: - self._check_dimension(dim) - - def set_dimension(self, axis, box, dim): - matrix = self._x_dim if axis == 'x' else self._y_dim - matrix[box] = dim - self._update_dimensions() - - @property - def n_boxes(self): - """Number of boxes in the grid.""" - return self.n_rows * self.n_rows - - # Data baking - # ------------------------------------------------------------------------- - - def _bake_spikes(self): - n_points = self.n_boxes * self.n_spikes - - # index increases from top to bottom, left to right - # same as matrix indices (i, j) starting at 0 - positions = [] - boxes = [] - - for i in range(self.n_rows): - for j in range(self.n_rows): - pos = self.project((i, j), - features=self._features, - extra_features=self._extra_features, - ) - positions.append(pos) - index = self.n_rows * i + j - boxes.append(index * np.ones(self.n_spikes, dtype=np.float32)) - - positions = np.vstack(positions).astype(np.float32) - boxes = np.hstack(boxes) - - assert positions.shape == (n_points, 2) - assert boxes.shape == (n_points,) - - self.program['a_position'] 
= positions.copy() - self.program['a_box'] = boxes - self.program['n_rows'] = self.n_rows - - -class BackgroundFeatureVisual(BaseFeatureVisual): - """Display a grid of multidimensional features in the background.""" - - _shader_name = 'features_bg' - _transparency = False - - -class FeatureVisual(BaseFeatureVisual): - """Display a grid of multidimensional features.""" - - _shader_name = 'features' - - def __init__(self, **kwargs): - super(FeatureVisual, self).__init__(**kwargs) - self.program['u_size'] = 3. - - # Data properties - # ------------------------------------------------------------------------- - - def _set_features_to_bake(self, value): - super(FeatureVisual, self)._set_features_to_bake(value) - self.set_to_bake('spikes', 'spikes_clusters', 'color') - - def _get_mask_dim(self, dim): - if isinstance(dim, (tuple, list)): - channel, feature = dim - return self._masks[:, channel] - else: - return np.ones(self.n_spikes) - - def _update_dimensions(self): - super(FeatureVisual, self)._update_dimensions() - self.set_to_bake('spikes_clusters', 'color') - - # Data baking - # ------------------------------------------------------------------------- - - def _bake_spikes(self): - n_points = self.n_boxes * self.n_spikes - - # index increases from top to bottom, left to right - # same as matrix indices (i, j) starting at 0 - positions = [] - masks = [] - boxes = [] - - for i in range(self.n_rows): - for j in range(self.n_rows): - pos = self.project((i, j), - features=self._features, - extra_features=self._extra_features, - ) - positions.append(pos) - - # The mask depends on both the `x` and `y` coordinates. - mask = np.maximum(self._get_mask_dim(self._x_dim[i, j]), - self._get_mask_dim(self._y_dim[i, j])) - masks.append(mask.astype(np.float32)) - - index = self.n_rows * i + j - boxes.append(index * np.ones(self.n_spikes, dtype=np.float32)) - - positions = np.vstack(positions).astype(np.float32) - masks = np.hstack(masks) - boxes = np.hstack(boxes) - - assert positions.shape == (n_points, 2) - assert masks.shape == (n_points,) - assert boxes.shape == (n_points,) - - self.program['a_position'] = positions.copy() - self.program['a_mask'] = masks - self.program['a_box'] = boxes - - self.program['n_clusters'] = self.n_clusters - self.program['n_rows'] = self.n_rows - - def _bake_spikes_clusters(self): - # Get the spike cluster indices (between 0 and n_clusters-1). - spike_clusters_idx = self.spike_clusters - # We take the cluster order into account here. 
- spike_clusters_idx = _index_of(spike_clusters_idx, self.cluster_order) - a_cluster = np.tile(spike_clusters_idx, - self.n_boxes).astype(np.float32) - self.program['a_cluster'] = a_cluster - self.program['n_clusters'] = self.n_clusters - - @property - def marker_size(self): - """Marker size in pixels.""" - return float(self.program['u_size']) - - @marker_size.setter - def marker_size(self, value): - value = np.clip(value, .1, 100) - self.program['u_size'] = float(value) - self.update() - - -def _iter_dimensions(dimensions): - if isinstance(dimensions, dict): - for box, dim in dimensions.items(): - yield (box, dim) - elif isinstance(dimensions, np.ndarray): - n_rows, n_cols = dimensions.shape - for i in range(n_rows): - for j in range(n_rows): - yield (i, j), dimensions[i, j] - elif isinstance(dimensions, list): - for i in range(len(dimensions)): - l = dimensions[i] - for j in range(len(l)): - dim = l[j] - yield (i, j), dim - - -class FeatureView(BaseSpikeCanvas): - """A VisPy canvas displaying features.""" - _visual_class = FeatureVisual - _events = ('enlarge',) - - def _create_visuals(self): - self.boxes = BoxVisual() - self.axes = AxisVisual() - self.background = BackgroundFeatureVisual() - self.lasso = LassoVisual() - super(FeatureView, self)._create_visuals() - - def _create_pan_zoom(self): - self._pz = PanZoomGrid() - self._pz.add(self.visual.program) - self._pz.add(self.background.program) - self._pz.add(self.lasso.program) - self._pz.add(self.axes.program) - self._pz.aspect = None - self._pz.attach(self) - - def init_grid(self, n_rows): - """Initialize the view with a given number of rows. - - Note - ---- - - This function *must* be called before setting the attributes. - - """ - assert n_rows >= 0 - - x_dim = np.empty((n_rows, n_rows), dtype=object) - y_dim = np.empty((n_rows, n_rows), dtype=object) - x_dim.fill('time') - y_dim.fill((0, 0)) - - self.visual.n_rows = n_rows - # NOTE: update the private variable because we don't want dimension - # checking at this point nor do we want to prepare the GPU data. - self.visual._x_dim = x_dim - self.visual._y_dim = y_dim - - self.background.n_rows = n_rows - self.background._x_dim = x_dim - self.background._y_dim = y_dim - - self.boxes.n_rows = n_rows - self.lasso.n_rows = n_rows - self.axes.n_rows = n_rows - self.axes.xs = [0] - self.axes.ys = [0] - self._pz.n_rows = n_rows - - xmin = np.empty((n_rows, n_rows)) - xmax = np.empty((n_rows, n_rows)) - ymin = np.empty((n_rows, n_rows)) - ymax = np.empty((n_rows, n_rows)) - for arr in (xmin, xmax, ymin, ymax): - arr.fill(np.nan) - self._pz._xmin = xmin - self._pz._xmax = xmax - self._pz._ymin = ymin - self._pz._ymax = ymax - - for i in range(n_rows): - for j in range(n_rows): - self._pz._xmin[i, j] = -1. - self._pz._xmax[i, j] = +1. - - @property - def x_dim(self): - return self.visual.x_dim - - @property - def y_dim(self): - return self.visual.y_dim - - def _set_dimension(self, axis, box, dim): - self.background.set_dimension(axis, box, dim) - self.visual.set_dimension(axis, box, dim) - min = self._pz._xmin if axis == 'x' else self._pz._ymin - max = self._pz._xmax if axis == 'x' else self._pz._ymax - if dim == 'time': - # NOTE: the private variables are the matrices. - min[box] = -1. - max[box] = +1. - else: - min[box] = None - max[box] = None - - def set_dimensions(self, axis, dimensions): - for box, dim in _iter_dimensions(dimensions): - self._set_dimension(axis, box, dim) - self. 
_update_dimensions() - - def smart_dimension(self, - axis, - box, - dim, - ): - """Smartify a dimension selection by ensuring x != y.""" - if not isinstance(dim, tuple): - return dim - n_features = self.visual.n_features - # Find current dimensions. - mat = self.x_dim if axis == 'x' else self.y_dim - mat_other = self.x_dim if axis == 'y' else self.y_dim - prev_dim = mat[box] - prev_dim_other = mat_other[box] - # Select smart new dimension. - if not isinstance(prev_dim, string_types): - channel, feature = dim - prev_channel, prev_feature = prev_dim - # Scroll the feature if the channel is the same. - if prev_channel == channel: - feature = (prev_feature + 1) % n_features - # Scroll the feature if it is the same than in the other axis. - if (prev_dim_other != 'time' and - prev_dim_other == (channel, feature)): - feature = (feature + 1) % n_features - dim = (channel, feature) - return dim - - def _update_dimensions(self): - self.background._update_dimensions() - self.visual._update_dimensions() - - def set_data(self, - features=None, - n_rows=1, - x_dimensions=None, - y_dimensions=None, - masks=None, - spike_clusters=None, - extra_features=None, - background_features=None, - colors=None, - ): - if features is not None: - assert isinstance(features, np.ndarray) - if features.ndim == 2: - features = features[..., None] - assert features.ndim == 3 - else: - features = self.visual.features - n_spikes, n_channels, n_features = features.shape - - if spike_clusters is None: - spike_clusters = np.zeros(n_spikes, dtype=np.int32) - cluster_ids = _unique(spike_clusters) - n_clusters = len(cluster_ids) - - if masks is None: - masks = np.ones(features.shape[:2], dtype=np.float32) - - if colors is None: - colors = _selected_clusters_colors(n_clusters) - - self.visual.features = features - - if background_features is not None: - assert features.shape[1:] == background_features.shape[1:] - self.background.features = background_features.astype(np.float32) - else: - self.background.n_channels = self.visual.n_channels - self.background.n_features = self.visual.n_features - - if masks is not None: - self.visual.masks = masks - - self.visual.spike_clusters = spike_clusters - assert spike_clusters.shape == (n_spikes,) - - if len(colors): - self.visual.cluster_colors = colors - - # Dimensions. - self.init_grid(n_rows) - if not extra_features: - extra_features = {'time': (np.linspace(0., 1., n_spikes), 0., 1.)} - for name, (array, m, M) in (extra_features or {}).items(): - self.add_extra_feature(name, array, m, M) - self.set_dimensions('x', x_dimensions or {(0, 0): 'time'}) - self.set_dimensions('y', y_dimensions or {(0, 0): (0, 0)}) - - self.update() - - def add_extra_feature(self, name, array, array_min, array_max, - array_bg=None): - self.visual.add_extra_feature(name, array, array_min, array_max) - # Note: the background array has a different number of spikes. 
- if array_bg is None: - array_bg = array - self.background.add_extra_feature(name, array_bg, - array_min, array_max) - - @property - def marker_size(self): - """Marker size.""" - return self.visual.marker_size - - @marker_size.setter - def marker_size(self, value): - self.visual.marker_size = value - self.update() - - def on_draw(self, event): - """Draw the features in a grid view.""" - gloo.clear(color=True, depth=True) - self.axes.draw() - self.background.draw() - self.visual.draw() - self.lasso.draw() - self.boxes.draw() - - keyboard_shortcuts = { - 'marker_size_increase': 'alt+', - 'marker_size_decrease': 'alt-', - 'add_lasso_point': 'shift+left click', - 'clear_lasso': 'shift+right click', - } - - def on_mouse_press(self, e): - control = e.modifiers == ('Control',) - shift = e.modifiers == ('Shift',) - if shift: - if e.button == 1: - # Lasso. - n_rows = self.lasso.n_rows - - box = self._pz._get_box(e.pos) - self.lasso.box = box - - position = self._pz._normalize(e.pos) - x, y = position - x *= n_rows - y *= -n_rows - pos = (x, y) - # pos = self._pz._map_box((x, y), inverse=True) - pos = self._pz._map_pan_zoom(pos, inverse=True) - self.lasso.add(pos.ravel()) - elif e.button == 2: - self.lasso.clear() - self.update() - elif control: - # Enlarge. - box = self._pz._get_box(e.pos) - self.emit('enlarge', - box=box, - x_dim=self.x_dim[box], - y_dim=self.y_dim[box], - ) - - def on_key_press(self, event): - """Handle key press events.""" - coeff = .25 - if 'Alt' in event.modifiers: - if event.key == '+' or event.key == '=': - self.marker_size += coeff - if event.key == '-': - self.marker_size -= coeff - - -#------------------------------------------------------------------------------ -# Plotting functions -#------------------------------------------------------------------------------ - -@_wrap_vispy -def plot_features(features, **kwargs): - """Plot features. - - Parameters - ---------- - - features : ndarray - The features to plot. A `(n_spikes, n_channels, n_features)` array. - spike_clusters : ndarray (optional) - A `(n_spikes,)` int array with the spike clusters. - masks : ndarray (optional) - A `(n_spikes, n_channels)` float array with the spike masks. - n_rows : int - Number of rows (= number of columns) in the grid view. - x_dimensions : list - List of dimensions for the x axis. - y_dimensions : list - List of dimensions for the yœ axis. - extra_features : dict - A dictionary `{feature_name: array}` where `array` has - `n_spikes` elements. - background_features : ndarray - The background features. A `(n_spikes, n_channels, n_features)` array. - - """ - c = FeatureView(keys='interactive') - c.set_data(features, **kwargs) - return c diff --git a/phy/plot/glsl/__init__.py b/phy/plot/glsl/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/phy/plot/glsl/ax.frag b/phy/plot/glsl/ax.frag deleted file mode 100644 index b3d4f845e..000000000 --- a/phy/plot/glsl/ax.frag +++ /dev/null @@ -1,9 +0,0 @@ -#include "grid.glsl" -uniform vec4 u_color; - -void main() { - // Clipping. 
- if (grid_clip(v_position, .975)) discard; - - gl_FragColor = u_color; -} diff --git a/phy/plot/glsl/ax.vert b/phy/plot/glsl/ax.vert deleted file mode 100644 index b87b20bd9..000000000 --- a/phy/plot/glsl/ax.vert +++ /dev/null @@ -1,30 +0,0 @@ -// xy is the vertex position in NDC -// index is the box index -// ax is 0 for x, 1 for y -attribute vec4 a_position; // xy, index, ax - -#include "grid.glsl" - -vec2 pan_zoom(vec2 position, float index, float ax) -{ - vec4 pz = fetch_pan_zoom(index); - vec2 pan = pz.xy; - vec2 zoom = pz.zw; - - if (ax < 0.5) - return vec2(zoom.x * (position.x + n_rows * pan.x), position.y); - else - return vec2(position.x, zoom.y * (position.y + n_rows * pan.y)); -} - -void main() { - - vec2 pos = a_position.xy; - vec2 position = pan_zoom(pos, a_position.z, a_position.w); - - gl_Position = vec4(to_box(position, a_position.z), - 0.0, 1.0); - - // Used for clipping. - v_position = position; -} diff --git a/phy/plot/glsl/box.frag b/phy/plot/glsl/box.frag deleted file mode 100644 index 4f60c17ad..000000000 --- a/phy/plot/glsl/box.frag +++ /dev/null @@ -1,3 +0,0 @@ -void main() { - gl_FragColor = vec4(.35, .35, .35, 1.); -} diff --git a/phy/plot/glsl/box.vert b/phy/plot/glsl/box.vert deleted file mode 100644 index ea72ad645..000000000 --- a/phy/plot/glsl/box.vert +++ /dev/null @@ -1,8 +0,0 @@ -#include "grid.glsl" - -attribute vec3 a_position; // x, y, index - -void main() { - gl_Position = vec4(to_box(a_position.xy, a_position.z), - 0.0, 1.0); -} diff --git a/phy/plot/glsl/color.glsl b/phy/plot/glsl/color.glsl deleted file mode 100644 index 4d0005597..000000000 --- a/phy/plot/glsl/color.glsl +++ /dev/null @@ -1,14 +0,0 @@ - -vec3 get_color(float cluster, sampler2D texture, float n_clusters) { - if (cluster < 0) - return vec3(.5, .5, .5); - return texture2D(texture, vec2(cluster / (n_clusters - 1.), .5)).xyz; -} - -vec3 color_mask(vec3 color, float mask) { - vec3 hsv = rgb_to_hsv(color); - // Change the saturation and value as a function of the mask. - hsv.y *= mask; - hsv.z *= .5 * (1. + mask); - return hsv_to_rgb(hsv); -} diff --git a/phy/plot/glsl/correlograms.frag b/phy/plot/glsl/correlograms.frag deleted file mode 100644 index 4afa3c773..000000000 --- a/phy/plot/glsl/correlograms.frag +++ /dev/null @@ -1,13 +0,0 @@ -#include "grid.glsl" - -varying vec4 v_color; -varying float v_box; - -void main() -{ - // Clipping. - if (grid_clip(v_position)) discard; - if (fract(v_box) > 0.) discard; - - gl_FragColor = v_color; -} diff --git a/phy/plot/glsl/correlograms.vert b/phy/plot/glsl/correlograms.vert deleted file mode 100644 index d00a3f407..000000000 --- a/phy/plot/glsl/correlograms.vert +++ /dev/null @@ -1,34 +0,0 @@ -#include "colormaps/color-space.glsl" -#include "color.glsl" -#include "grid.glsl" - -attribute vec2 a_position; -attribute float a_box; // (from 0 to n_rows**2-1) - -uniform sampler2D u_cluster_color; - -varying vec4 v_color; -varying float v_box; - -void main (void) -{ - // ACG/CCG color. - vec2 rc = row_col(a_box, n_rows); - if (abs(rc.x - rc.y) < .1) { - v_color.rgb = get_color(rc.x, - u_cluster_color, - n_rows); - } - else { - v_color.rgb = vec3(1., 1., 1.); - } - v_color.a = 1.; - - vec2 position = pan_zoom_grid(a_position, a_box); - vec2 box_position = to_box(position, a_box); - gl_Position = vec4(box_position, 0., 1.); - - // Used for clipping. 
- v_position = position; - v_box = a_box; -} diff --git a/phy/plot/glsl/depth_mask.glsl b/phy/plot/glsl/depth_mask.glsl deleted file mode 100644 index 298a67ff2..000000000 --- a/phy/plot/glsl/depth_mask.glsl +++ /dev/null @@ -1,8 +0,0 @@ -float depth_mask(float cluster, float mask, float n_clusters) { - // Depth and mask. - float depth = 0.0; - if (mask > 0.25) { - depth = -.1 - (cluster + mask) / (n_clusters + 10.); - } - return depth; -} diff --git a/phy/plot/glsl/features.frag b/phy/plot/glsl/features.frag deleted file mode 100644 index beed3e6cb..000000000 --- a/phy/plot/glsl/features.frag +++ /dev/null @@ -1,17 +0,0 @@ -#include "markers/disc.glsl" -#include "filled_antialias.glsl" -#include "grid.glsl" - -varying float v_size; -varying vec4 v_color; - -void main() -{ - // Clipping. - if (grid_clip(v_position)) discard; - - vec2 P = gl_PointCoord.xy - vec2(0.5,0.5); - float point_size = v_size + 2. * (1.0 + 1.5*1.0); - float distance = marker_disc(P*point_size, v_size); - gl_FragColor = filled(distance, 1.0, 1.0, v_color); -} diff --git a/phy/plot/glsl/features.vert b/phy/plot/glsl/features.vert deleted file mode 100644 index 5e3891fc2..000000000 --- a/phy/plot/glsl/features.vert +++ /dev/null @@ -1,37 +0,0 @@ -#include "colormaps/color-space.glsl" -#include "color.glsl" -#include "grid.glsl" -#include "depth_mask.glsl" - -attribute vec2 a_position; -attribute float a_mask; -attribute float a_cluster; // cluster idx -attribute float a_box; // (from 0 to n_rows**2-1) - -uniform float u_size; -uniform float n_clusters; -uniform sampler2D u_cluster_color; - -varying vec4 v_color; -varying float v_size; - -void main (void) -{ - v_size = u_size; - - v_color.rgb = color_mask(get_color(a_cluster, u_cluster_color, n_clusters), - a_mask); - v_color.a = .5; - - vec2 position = pan_zoom_grid(a_position, a_box); - vec2 box_position = to_box(position, a_box); - - // Depth as a function of the mask and cluster index. - float depth = depth_mask(mod(a_box, n_rows), a_mask, n_clusters); - - gl_Position = vec4(box_position, depth, 1.); - gl_PointSize = u_size + 2.0 * (1.0 + 1.5 * 1.0); - - // Used for clipping. - v_position = position; -} diff --git a/phy/plot/glsl/features_bg.frag b/phy/plot/glsl/features_bg.frag deleted file mode 100644 index 602dcd437..000000000 --- a/phy/plot/glsl/features_bg.frag +++ /dev/null @@ -1,9 +0,0 @@ -#include "grid.glsl" - -void main() -{ - // Clipping. - if (grid_clip(v_position)) discard; - - gl_FragColor = vec4(.5, .5, .5, .25); -} diff --git a/phy/plot/glsl/features_bg.vert b/phy/plot/glsl/features_bg.vert deleted file mode 100644 index 5973df500..000000000 --- a/phy/plot/glsl/features_bg.vert +++ /dev/null @@ -1,16 +0,0 @@ -#include "grid.glsl" - -attribute vec2 a_position; -attribute float a_box; // (from 0 to n_rows**2-1) - -void main (void) -{ - vec2 position = pan_zoom_grid(a_position, a_box); - vec2 box_position = to_box(position, a_box); - - gl_Position = vec4(box_position, 0., 1.); - gl_PointSize = 3.0; - - // Used for clipping. 
- v_position = position; -} diff --git a/phy/plot/glsl/filled_antialias.glsl b/phy/plot/glsl/filled_antialias.glsl deleted file mode 100644 index 7e95e7ee8..000000000 --- a/phy/plot/glsl/filled_antialias.glsl +++ /dev/null @@ -1,28 +0,0 @@ -vec4 filled(float distance, float linewidth, float antialias, vec4 bg_color) -{ - vec4 frag_color; - float t = linewidth/2.0 - antialias; - float signed_distance = distance; - float border_distance = abs(signed_distance) - t; - float alpha = border_distance/antialias; - alpha = exp(-alpha*alpha); - - if (border_distance < 0.0) - frag_color = bg_color; - else if (signed_distance < 0.0) - frag_color = bg_color; - else { - if (abs(signed_distance) < (linewidth/2.0 + antialias)) { - frag_color = vec4(bg_color.rgb, alpha * bg_color.a); - } - else { - discard; - } - } - return frag_color; -} - -vec4 filled(float distance, float linewidth, float antialias, vec4 fg_color, vec4 bg_color) -{ - return filled(distance, linewidth, antialias, fg_color); -} diff --git a/phy/plot/glsl/grid.glsl b/phy/plot/glsl/grid.glsl deleted file mode 100644 index 1bc154b77..000000000 --- a/phy/plot/glsl/grid.glsl +++ /dev/null @@ -1,46 +0,0 @@ - -uniform float n_rows; -uniform vec4 u_pan_zoom[256]; // maximum grid size: 16 x 16 - -varying vec2 v_position; - -vec2 row_col(float index, float n_rows) { - float row = floor(index / n_rows); - float col = mod(index, n_rows); - return vec2(row, col); -} - -vec4 fetch_pan_zoom(float index) { - vec4 pz = u_pan_zoom[int(index)]; - vec2 pan = pz.xy; - vec2 zoom = pz.zw; - return vec4(pan, zoom); -} - -vec2 to_box(vec2 position, float index) { - vec2 rc = row_col(index, n_rows) + 0.5; - - float x = -1.0 + rc.y * (2.0 / n_rows); - float y = +1.0 - rc.x * (2.0 / n_rows); - - float width = 0.95 / (1.0 * n_rows); - float height = 0.95 / (1.0 * n_rows); - - return vec2(x + width * position.x, - y + height * position.y); -} - -bool grid_clip(vec2 position, float lim) { - return ((position.x < -lim) || (position.x > +lim) || - (position.y < -lim) || (position.y > +lim)); -} - -bool grid_clip(vec2 position) { - return grid_clip(position, .95); -} - -vec2 pan_zoom_grid(vec2 position, float index) -{ - vec4 pz = fetch_pan_zoom(index); // (pan_x, pan_y, zoom_x, zoom_y) - return pz.zw * (position + n_rows * pz.xy); -} diff --git a/phy/plot/glsl/histogram.frag b/phy/plot/glsl/histogram.frag new file mode 100644 index 000000000..7c31f83aa --- /dev/null +++ b/phy/plot/glsl/histogram.frag @@ -0,0 +1,5 @@ +varying vec4 v_color; + +void main() { + gl_FragColor = v_color; +} diff --git a/phy/plot/glsl/histogram.vert b/phy/plot/glsl/histogram.vert new file mode 100644 index 000000000..1889ca2f4 --- /dev/null +++ b/phy/plot/glsl/histogram.vert @@ -0,0 +1,17 @@ +#include "utils.glsl" + +attribute vec2 a_position; +attribute float a_hist_index; // 0..n_hists-1 + +uniform sampler2D u_color; +uniform float n_hists; + +varying vec4 v_color; +varying float v_hist_index; + +void main() { + gl_Position = transform(a_position); + + v_color = fetch_texture(a_hist_index, u_color, n_hists); + v_hist_index = a_hist_index; +} diff --git a/phy/plot/glsl/lasso.frag b/phy/plot/glsl/lasso.frag deleted file mode 100644 index a4ec5fa9e..000000000 --- a/phy/plot/glsl/lasso.frag +++ /dev/null @@ -1,8 +0,0 @@ -#include "grid.glsl" - -void main() { - // Clipping. 
- if (grid_clip(v_position, .975)) discard; - - gl_FragColor = vec4(.5, .5, .5, 1.); -} diff --git a/phy/plot/glsl/lasso.vert b/phy/plot/glsl/lasso.vert deleted file mode 100644 index 6b5595e52..000000000 --- a/phy/plot/glsl/lasso.vert +++ /dev/null @@ -1,16 +0,0 @@ -#include "grid.glsl" - -attribute vec2 a_position; - -uniform float u_box; - -void main() { - - vec2 pos = a_position; - vec2 position = pan_zoom_grid(pos, u_box); - - gl_Position = vec4(to_box(position, u_box), -1.0, 1.0); - - // Used for clipping. - v_position = position; -} diff --git a/phy/plot/glsl/line.frag b/phy/plot/glsl/line.frag new file mode 100644 index 000000000..7c31f83aa --- /dev/null +++ b/phy/plot/glsl/line.frag @@ -0,0 +1,5 @@ +varying vec4 v_color; + +void main() { + gl_FragColor = v_color; +} diff --git a/phy/plot/glsl/line.vert b/phy/plot/glsl/line.vert new file mode 100644 index 000000000..a5c0c5474 --- /dev/null +++ b/phy/plot/glsl/line.vert @@ -0,0 +1,8 @@ +attribute vec2 a_position; +attribute vec4 a_color; +varying vec4 v_color; + +void main() { + gl_Position = transform(a_position); + v_color = a_color; +} diff --git a/phy/plot/glsl/pan_zoom.glsl b/phy/plot/glsl/pan_zoom.glsl deleted file mode 100644 index 05dd679cf..000000000 --- a/phy/plot/glsl/pan_zoom.glsl +++ /dev/null @@ -1,7 +0,0 @@ -uniform vec2 u_zoom; -uniform vec2 u_pan; - -vec2 pan_zoom(vec2 position) -{ - return u_zoom * (position + u_pan); -} diff --git a/phy/plot/glsl/plot.frag b/phy/plot/glsl/plot.frag new file mode 100644 index 000000000..09faf3bde --- /dev/null +++ b/phy/plot/glsl/plot.frag @@ -0,0 +1,11 @@ +varying vec4 v_color; +varying float v_signal_index; + +void main() { + + // Discard pixels between signals. + if (fract(v_signal_index) > 0.) + discard; + + gl_FragColor = v_color; +} diff --git a/phy/plot/glsl/plot.vert b/phy/plot/glsl/plot.vert new file mode 100644 index 000000000..05583b8b1 --- /dev/null +++ b/phy/plot/glsl/plot.vert @@ -0,0 +1,17 @@ +#include "utils.glsl" + +attribute vec3 a_position; +attribute vec4 a_color; +attribute float a_signal_index; // 0..n_signals-1 + +varying vec4 v_color; +varying float v_signal_index; + +void main() { + vec2 xy = a_position.xy; + gl_Position = transform(xy); + gl_Position.z = a_position.z; + + v_color = a_color; + v_signal_index = a_signal_index; +} diff --git a/phy/plot/glsl/polygon.frag b/phy/plot/glsl/polygon.frag new file mode 100644 index 000000000..8df552c17 --- /dev/null +++ b/phy/plot/glsl/polygon.frag @@ -0,0 +1,5 @@ +uniform vec4 u_color; + +void main() { + gl_FragColor = u_color; +} diff --git a/phy/plot/glsl/polygon.vert b/phy/plot/glsl/polygon.vert new file mode 100644 index 000000000..a82e7671b --- /dev/null +++ b/phy/plot/glsl/polygon.vert @@ -0,0 +1,7 @@ +attribute vec2 a_position; +uniform vec4 u_color; + +void main() { + gl_Position = transform(a_position); + gl_Position.z = -.5; +} diff --git a/phy/plot/glsl/scatter.frag b/phy/plot/glsl/scatter.frag new file mode 100644 index 000000000..16097b612 --- /dev/null +++ b/phy/plot/glsl/scatter.frag @@ -0,0 +1,13 @@ +#include "markers/%MARKER.glsl" +#include "antialias/filled.glsl" + +varying vec4 v_color; +varying float v_size; + +void main() +{ + vec2 P = gl_PointCoord.xy - vec2(0.5, 0.5); + float point_size = v_size + 5.; + float distance = marker_%MARKER(P *point_size, v_size); + gl_FragColor = filled(distance, 1.0, 1.0, v_color); +} diff --git a/phy/plot/glsl/scatter.vert b/phy/plot/glsl/scatter.vert new file mode 100644 index 000000000..c241fca42 --- /dev/null +++ b/phy/plot/glsl/scatter.vert @@ -0,0 +1,19 @@ 
+attribute vec3 a_position; +attribute vec4 a_color; +attribute float a_size; + +varying vec4 v_color; +varying float v_size; + +void main() { + vec2 xy = a_position.xy; + gl_Position = transform(xy); + gl_Position.z = a_position.z; + + // Point size as a function of the marker size and antialiasing. + gl_PointSize = a_size + 5.0; + + // Set the varyings. + v_color = a_color; + v_size = a_size; +} diff --git a/phy/plot/glsl/simple.frag b/phy/plot/glsl/simple.frag new file mode 100644 index 000000000..8df552c17 --- /dev/null +++ b/phy/plot/glsl/simple.frag @@ -0,0 +1,5 @@ +uniform vec4 u_color; + +void main() { + gl_FragColor = u_color; +} diff --git a/phy/plot/glsl/simple.vert b/phy/plot/glsl/simple.vert new file mode 100644 index 000000000..2a59757e8 --- /dev/null +++ b/phy/plot/glsl/simple.vert @@ -0,0 +1,6 @@ +attribute vec2 a_position; +uniform vec4 u_color; + +void main() { + gl_Position = transform(a_position); +} diff --git a/phy/plot/glsl/text.frag b/phy/plot/glsl/text.frag new file mode 100644 index 000000000..ef8644e2d --- /dev/null +++ b/phy/plot/glsl/text.frag @@ -0,0 +1,8 @@ + +uniform sampler2D u_tex; + +varying vec2 v_tex_coords; + +void main() { + gl_FragColor = texture2D(u_tex, v_tex_coords); +} diff --git a/phy/plot/glsl/text.vert b/phy/plot/glsl/text.vert new file mode 100644 index 000000000..f5abc380d --- /dev/null +++ b/phy/plot/glsl/text.vert @@ -0,0 +1,49 @@ + +attribute vec2 a_position; // text position +attribute float a_glyph_index; // glyph index in the text +attribute float a_quad_index; // quad index in the glyph +attribute float a_char_index; // index of the glyph in the texture +attribute float a_lengths; +attribute vec2 a_anchor; + +uniform vec2 u_glyph_size; // (w, h) + +varying vec2 v_tex_coords; + +const float rows = 6; +const float cols = 16; + +void main() { + float w = u_glyph_size.x / u_window_size.x; + float h = u_glyph_size.y / u_window_size.y; + + float dx = mod(a_quad_index, 2.); + float dy = 0.; + if ((2. <= a_quad_index) && (a_quad_index <= 4.)) { + dy = 1.; + } + + // Position of the glyph. + gl_Position = transform(a_position); + gl_Position.xy = gl_Position.xy + vec2(a_glyph_index * w + dx * w, dy * h); + // Anchor: the part in [-1, 1] is relative to the text size. + gl_Position.xy += (a_anchor - 1.) * .5 * vec2(a_lengths * w, h); + // NOTE: The part beyond [-1, 1] is absolute, so that texts stay aligned. + gl_Position.xy += (a_anchor - clamp(a_anchor, -1., 1.)); + + // Index in the texture + float i = floor(a_char_index / cols); + float j = mod(a_char_index, cols); + + // uv position in the texture for the glyph. + vec2 uv = vec2(j, rows - 1. - i); + uv /= vec2(cols, rows); + + // Little margin to avoid edge effects between glyphs. + dx = .01 + .98 * dx; + dy = .01 + .98 * dy; + // Texture coordinates for the fragment shader. + vec2 duv = vec2(dx / cols, dy /rows); + + v_tex_coords = uv + duv; +} diff --git a/phy/plot/glsl/traces.frag b/phy/plot/glsl/traces.frag deleted file mode 100644 index b388fe09d..000000000 --- a/phy/plot/glsl/traces.frag +++ /dev/null @@ -1,20 +0,0 @@ - -varying vec3 v_index; // (channel, cluster, mask) -varying vec3 v_color_channel; -varying vec3 v_color_spike; - -void main() { - vec3 color; - - // Discard vertices between two channels. - if ((fract(v_index.x) > 0.)) - discard; - - // Avoid color interpolation at spike boundaries. - if ((v_index.y >= 0) && (fract(v_index.y) == 0.) 
&& (v_index.z > 0.)) - color = v_color_spike; - else - color = v_color_channel; - - gl_FragColor = vec4(color, .85); -} diff --git a/phy/plot/glsl/traces.vert b/phy/plot/glsl/traces.vert deleted file mode 100644 index 24f8a1b11..000000000 --- a/phy/plot/glsl/traces.vert +++ /dev/null @@ -1,59 +0,0 @@ -#include "pan_zoom.glsl" -#include "colormaps/color-space.glsl" -#include "color.glsl" - -attribute float a_position; -attribute vec2 a_index; // (channel_idx, t) -attribute vec2 a_spike; // (cluster_idx, mask) - -uniform sampler2D u_channel_color; -uniform sampler2D u_cluster_color; - -uniform float n_channels; -uniform float n_clusters; -uniform float n_samples; -uniform float u_scale; - -varying vec3 v_index; // (channel, cluster, mask) -varying vec3 v_color_channel; -varying vec3 v_color_spike; - -float get_x(float x_index) { - // 'x_index' is between 0 and nsamples. - return -1. + 2. * x_index / (float(n_samples) - 1.); -} - -float get_y(float y_index, float sample) { - // 'y_index' is between 0 and n_channels. - float a = float(u_scale) / float(n_channels); - float b = -1. + 2. * (y_index + .5) / float(n_channels); - return a * sample + .9 * b; -} - -void main() { - float channel = a_index.x; - - float x = get_x(a_index.y); - float y = get_y(channel, a_position); - vec2 position = vec2(x, y); - - gl_Position = vec4(pan_zoom(position), 0.0, 1.0); - - // Spike color as a function of the cluster and mask. - v_color_spike = color_mask(get_color(a_spike.x, // cluster_id - u_cluster_color, - n_clusters), - a_spike.y // mask - ); - - // Channel color. - v_color_channel = get_color(channel, - u_channel_color, - n_channels); - - // The fragment shader needs to know: - // * the channel (to discard fragments between channels) - // * the cluster (for the color) - // * the mask (for the alpha channel) - v_index = vec3(channel, a_spike.x, a_spike.y); -} diff --git a/phy/plot/glsl/utils.glsl b/phy/plot/glsl/utils.glsl new file mode 100644 index 000000000..944fa9e86 --- /dev/null +++ b/phy/plot/glsl/utils.glsl @@ -0,0 +1,3 @@ +vec4 fetch_texture(float index, sampler2D texture, float size) { + return texture2D(texture, vec2(index / (size - 1.), .5)); +} diff --git a/phy/plot/glsl/waveforms.frag b/phy/plot/glsl/waveforms.frag deleted file mode 100644 index f10c808e7..000000000 --- a/phy/plot/glsl/waveforms.frag +++ /dev/null @@ -1,8 +0,0 @@ -varying vec4 v_color; -varying vec2 v_box; - -void main() { - if ((fract(v_box.x) > 0.) || (fract(v_box.y) > 0.)) - discard; - gl_FragColor = v_color; -} diff --git a/phy/plot/glsl/waveforms.vert b/phy/plot/glsl/waveforms.vert deleted file mode 100644 index 8d619ca3b..000000000 --- a/phy/plot/glsl/waveforms.vert +++ /dev/null @@ -1,53 +0,0 @@ -#include "colormaps/color-space.glsl" -#include "color.glsl" -#include "pan_zoom.glsl" -#include "depth_mask.glsl" - -attribute vec2 a_data; // position (-1..1), mask -attribute float a_time; // -1..1 -attribute vec2 a_box; // 0..(n_clusters-1, n_channels-1) - -uniform float n_clusters; -uniform float n_channels; -uniform vec2 u_data_scale; -uniform vec2 u_channel_scale; -uniform sampler2D u_channel_pos; -uniform sampler2D u_cluster_color; -uniform float u_overlap; -uniform float u_alpha; - -varying vec4 v_color; -varying vec2 v_box; - -vec2 get_box_pos(vec2 box) { // box = (cluster, channel) - vec2 box_pos = texture2D(u_channel_pos, - vec2(box.y / (n_channels - 1.), .5)).xy; - box_pos = 2. * box_pos - 1.; - box_pos = box_pos * u_channel_scale; - // Spacing between cluster boxes. 
- float h = 2.5 * u_data_scale.x; - if (u_overlap < 0.5) - box_pos.x += h * (box.x - .5 * (n_clusters - 1.)) / n_clusters; - return box_pos; -} - -void main() { - vec2 pos = u_data_scale * vec2(a_time, a_data.x); // -1..1 - vec2 box_pos = get_box_pos(a_box); - v_box = a_box; - - // Depth as a function of the mask and cluster index. - float depth = depth_mask(a_box.x, a_data.y, n_clusters); - - vec2 x_coeff = vec2(1. / max(n_clusters, 1.), 1.); - if (u_overlap > 0.5) - x_coeff.x = 1.; - // The z coordinate is the depth: it depends on the mask. - gl_Position = vec4(pan_zoom(x_coeff * pos + box_pos), depth, 1.); - - // Compute the waveform color as a function of the cluster color - // and the mask. - v_color.rgb = color_mask(get_color(a_box.x, u_cluster_color, n_clusters), - a_data.y); - v_color.a = u_alpha; -} diff --git a/phy/plot/interact.py b/phy/plot/interact.py new file mode 100644 index 000000000..064ecda6a --- /dev/null +++ b/phy/plot/interact.py @@ -0,0 +1,305 @@ +# -*- coding: utf-8 -*- + +"""Common interacts.""" + + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +import numpy as np +from vispy.gloo import Texture2D + +from .base import BaseInteract +from .transform import Scale, Range, Subplot, Clip, NDC +from .utils import _get_texture, _get_boxes, _get_box_pos_size +from .visuals import LineVisual + + +#------------------------------------------------------------------------------ +# Grid interact +#------------------------------------------------------------------------------ + +class Grid(BaseInteract): + """Grid interact. + + NOTE: to be used in a grid, a visual must define `a_box_index` + (by default) or another GLSL variable specified in `box_var`. + + Parameters + ---------- + + shape : tuple or str + Number of rows, cols in the grid. + box_var : str + Name of the GLSL variable with the box index. + + """ + + margin = .075 + + def __init__(self, shape=(1, 1), shape_var='u_grid_shape', box_var=None): + # Name of the variable with the box index. 
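A minimal usage sketch of the `Grid` interact defined here, assuming the `BaseCanvas`/`add_visual`/`ScatterVisual` API referenced elsewhere in this diff; the `ScatterVisual` constructor and `set_data` signature are assumptions, not taken from the patch:

```python
import numpy as np

from phy.plot.base import BaseCanvas
from phy.plot.interact import Grid
from phy.plot.visuals import ScatterVisual

canvas = BaseCanvas()
grid = Grid(shape=(2, 2))      # 2 rows x 2 columns of subplots
grid.attach(canvas)            # insert the grid transforms into the shaders
grid.add_boxes(canvas)         # draw the subplot frames

visual = ScatterVisual()
canvas.add_visual(visual)
visual.set_data(pos=np.random.randn(100, 2))   # assumed signature

# Route every vertex to subplot (row=0, col=1) through the default box variable.
visual.program['a_box_index'] = np.tile(
    np.array([0., 1.], dtype=np.float32), (100, 1))

canvas.show()
```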
+ self.box_var = box_var or 'a_box_index' + self.shape_var = shape_var + self._shape = shape + ms = 1 - self.margin + mc = 1 - self.margin + self._transforms = [Scale((ms, ms)), + Clip([-mc, -mc, +mc, +mc]), + Subplot(self.shape_var, self.box_var), + ] + + def attach(self, canvas): + super(Grid, self).attach(canvas) + canvas.transforms.add_on_gpu(self._transforms) + canvas.inserter.insert_vert(""" + attribute vec2 {}; + uniform vec2 {}; + """.format(self.box_var, self.shape_var), + 'header') + + def map(self, arr, box=None): + assert box is not None + assert len(box) == 2 + arr = self._transforms[0].apply(arr) + arr = Subplot(self.shape, box).apply(arr) + return arr + + def imap(self, arr, box=None): + assert box is not None + arr = Subplot(self.shape, box).inverse().apply(arr) + arr = self._transforms[0].inverse().apply(arr) + return arr + + def add_boxes(self, canvas, shape=None): + shape = shape or self.shape + assert isinstance(shape, tuple) + n, m = shape + n_boxes = n * m + a = 1 + .05 + + pos = np.array([[-a, -a, +a, -a], + [+a, -a, +a, +a], + [+a, +a, -a, +a], + [-a, +a, -a, -a], + ]) + pos = np.tile(pos, (n_boxes, 1)) + + box_index = [] + for i in range(n): + for j in range(m): + box_index.append([i, j]) + box_index = np.vstack(box_index) + box_index = np.repeat(box_index, 8, axis=0) + + boxes = LineVisual() + + @boxes.set_canvas_transforms_filter + def _remove_clip(tc): + return tc.remove('Clip') + + canvas.add_visual(boxes) + boxes.set_data(pos=pos) + boxes.program['a_box_index'] = box_index.astype(np.float32) + + def get_closest_box(self, pos): + x, y = pos + rows, cols = self.shape + j = np.clip(int(cols * (1. + x) / 2.), 0, cols - 1) + i = np.clip(int(rows * (1. - y) / 2.), 0, rows - 1) + return i, j + + def update_program(self, program): + program[self.shape_var] = self._shape + # Only set the default box index if necessary. + try: + program[self.box_var] + except KeyError: + program[self.box_var] = (0, 0) + + @property + def shape(self): + return self._shape + + @shape.setter + def shape(self, value): + self._shape = value + self.update() + + +#------------------------------------------------------------------------------ +# Boxed interact +#------------------------------------------------------------------------------ + +class Boxed(BaseInteract): + """Boxed interact. + + NOTE: to be used in a boxed, a visual must define `a_box_index` + (by default) or another GLSL variable specified in `box_var`. + + Parameters + ---------- + + box_bounds : array-like + A (n, 4) array where each row contains the `(xmin, ymin, xmax, ymax)` + bounds of every box, in normalized device coordinates. + + NOTE: the box bounds need to be contained within [-1, 1] at all times, + otherwise an error will be raised. This is to prevent silent clipping + of the values when they are passed to a gloo Texture2D. + + box_var : str + Name of the GLSL variable with the box index. + + """ + def __init__(self, + box_bounds=None, + box_pos=None, + box_size=None, + box_var=None, + keep_aspect_ratio=True, + ): + self._key_pressed = None + self.keep_aspect_ratio = keep_aspect_ratio + + # Name of the variable with the box index. + self.box_var = box_var or 'a_box_index' + + # Find the box bounds if only the box positions are passed. + if box_bounds is None: + assert box_pos is not None + # This will find a good box size automatically if it is not + # specified. 
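The box layout helpers below can be exercised without a canvas; a short sketch with made-up box centers (the automatically inferred box size comes from `_get_boxes`, whose exact output is not checked here):

```python
import numpy as np

from phy.plot.interact import Boxed

# Four boxes specified by their centers; the size is inferred automatically.
box_pos = np.array([[-.5, -.5],
                    [-.5, +.5],
                    [+.5, -.5],
                    [+.5, +.5]])
boxed = Boxed(box_pos=box_pos)

assert boxed.n_boxes == 4
assert boxed.box_bounds.shape == (4, 4)   # (xmin, ymin, xmax, ymax) per box

# Index of the box whose center is closest to a given NDC position
# (expected to be 3, the box centered at (+.5, +.5)).
idx = boxed.get_closest_box((.4, .6))
```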
+ box_bounds = _get_boxes(box_pos, size=box_size, + keep_aspect_ratio=self.keep_aspect_ratio) + + self._box_bounds = np.atleast_2d(box_bounds) + assert self._box_bounds.shape[1] == 4 + + self.n_boxes = len(self._box_bounds) + self._transforms = [Range(NDC, 'box_bounds')] + + def attach(self, canvas): + super(Boxed, self).attach(canvas) + canvas.transforms.add_on_gpu(self._transforms) + canvas.inserter.insert_vert(""" + #include "utils.glsl" + attribute float {}; + uniform sampler2D u_box_bounds; + uniform float n_boxes;""".format(self.box_var), 'header') + canvas.inserter.insert_vert(""" + // Fetch the box bounds for the current box (`box_var`). + vec4 box_bounds = fetch_texture({}, + u_box_bounds, + n_boxes); + box_bounds = (2 * box_bounds - 1); // See hack in Python. + """.format(self.box_var), 'before_transforms') + + def map(self, arr, box=None): + assert box is not None + assert 0 <= box < len(self.box_bounds) + return Range(NDC, self.box_bounds[box]).apply(arr) + + def imap(self, arr, box=None): + assert 0 <= box < len(self.box_bounds) + return Range(NDC, self.box_bounds[box]).inverse().apply(arr) + + def update_program(self, program): + # Signal bounds (positions). + box_bounds = _get_texture(self._box_bounds, NDC, self.n_boxes, [-1, 1]) + box_bounds = box_bounds.astype(np.float32) + # TODO OPTIM: set the texture at initialization and update the data + program['u_box_bounds'] = Texture2D(box_bounds, + internalformat='rgba32f') + program['n_boxes'] = self.n_boxes + + # Change the box bounds, positions, or size + #-------------------------------------------------------------------------- + + @property + def box_bounds(self): + return self._box_bounds + + @box_bounds.setter + def box_bounds(self, val): + assert val.shape == (self.n_boxes, 4) + self._box_bounds = val + self.update() + + @property + def box_pos(self): + box_pos, _ = _get_box_pos_size(self._box_bounds) + return box_pos + + @box_pos.setter + def box_pos(self, val): + assert val.shape == (self.n_boxes, 2) + self.box_bounds = _get_boxes(val, size=self.box_size, + keep_aspect_ratio=self.keep_aspect_ratio) + + @property + def box_size(self): + _, box_size = _get_box_pos_size(self._box_bounds) + return box_size + + @box_size.setter + def box_size(self, val): + assert len(val) == 2 + self.box_bounds = _get_boxes(self.box_pos, size=val, + keep_aspect_ratio=self.keep_aspect_ratio) + + def get_closest_box(self, pos): + """Get the box closest to some position.""" + pos = np.atleast_2d(pos) + d = np.sum((np.array(self.box_pos) - pos) ** 2, axis=1) + idx = np.argmin(d) + return idx + + def update_boxes(self, box_pos, box_size): + """Set the box bounds from specified box positions and sizes.""" + assert box_pos.shape == (self.n_boxes, 2) + assert len(box_size) == 2 + self.box_bounds = _get_boxes(box_pos, + size=box_size, + keep_aspect_ratio=self.keep_aspect_ratio, + ) + + +class Stacked(Boxed): + """Stacked interact. + + NOTE: to be used in a stacked, a visual must define `a_box_index` + (by default) or another GLSL variable specified in `box_var`. + + Parameters + ---------- + + n_boxes : int + Number of boxes to stack vertically. + margin : int (0 by default) + The margin between the stacked subplots. Can be negative. Must be + between -1 and 1. The unit is relative to each box's size. + box_var : str + Name of the GLSL variable with the box index. + + """ + def __init__(self, n_boxes, margin=0, box_var=None, origin=None): + + # The margin must be in [-1, 1] + margin = np.clip(margin, -1, 1) + # Normalize the margin. + margin = 2. 
* margin / float(n_boxes) + + # Signal bounds. + b = np.zeros((n_boxes, 4)) + b[:, 0] = -1 + b[:, 1] = np.linspace(-1, 1 - 2. / n_boxes + margin, n_boxes) + b[:, 2] = 1 + b[:, 3] = np.linspace(-1 + 2. / n_boxes - margin, 1., n_boxes) + if origin == 'upper': + b = b[::-1, :] + + super(Stacked, self).__init__(b, box_var=box_var, + keep_aspect_ratio=False, + ) diff --git a/phy/plot/panzoom.py b/phy/plot/panzoom.py new file mode 100644 index 000000000..252a12b86 --- /dev/null +++ b/phy/plot/panzoom.py @@ -0,0 +1,496 @@ +# -*- coding: utf-8 -*- + +"""Pan & zoom transform.""" + + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +import math +import sys + +import numpy as np + +from .base import BaseInteract +from .transform import Translate, Scale, pixels_to_ndc +from phy.utils._types import _as_array + + +#------------------------------------------------------------------------------ +# PanZoom class +#------------------------------------------------------------------------------ + +class PanZoom(BaseInteract): + """Pan and zoom interact. + + To use it: + + ```python + # Create a visual. + visual = MyVisual(...) + visual.set_data(...) + + # Attach the visual to the canvas. + canvas = BaseCanvas() + visual.attach(canvas, 'PanZoom') + + # Create and attach the PanZoom interact. + pz = PanZoom() + pz.attach(canvas) + + canvas.show() + ``` + + """ + + _default_zoom_coeff = 1.5 + _default_wheel_coeff = .1 + _arrows = ('Left', 'Right', 'Up', 'Down') + _pm = ('+', '-') + + def __init__(self, + aspect=1.0, + pan=(0.0, 0.0), zoom=(1.0, 1.0), + zmin=1e-5, zmax=1e5, + xmin=None, xmax=None, + ymin=None, ymax=None, + constrain_bounds=None, + pan_var_name='u_pan', + zoom_var_name='u_zoom', + enable_mouse_wheel=None, + ): + if constrain_bounds: + assert xmin is None + assert ymin is None + assert xmax is None + assert ymax is None + xmin, ymin, xmax, ymax = constrain_bounds + + self.pan_var_name = pan_var_name + self.zoom_var_name = zoom_var_name + + self._aspect = aspect + + self._zmin = zmin + self._zmax = zmax + self._xmin = xmin + self._xmax = xmax + self._ymin = ymin + self._ymax = ymax + + self._pan = np.array(pan) + self._zoom = np.array(zoom) + + self._zoom_coeff = self._default_zoom_coeff + self._wheel_coeff = self._default_wheel_coeff + + self.enable_keyboard_pan = True + + # Touch-related variables. + self._pinch = None + self._last_pinch_scale = None + if enable_mouse_wheel is None: + enable_mouse_wheel = sys.platform != 'darwin' + self.enable_mouse_wheel = enable_mouse_wheel + + self._zoom_to_pointer = True + self._canvas_aspect = np.ones(2) + + # Will be set when attached to a canvas. 
+ self.canvas = None + self._translate = Translate(self.pan_var_name) + self._scale = Scale(self.zoom_var_name) + + # Various properties + # ------------------------------------------------------------------------- + + @property + def aspect(self): + """Aspect (width/height).""" + return self._aspect + + @aspect.setter + def aspect(self, value): + self._aspect = value + + # xmin/xmax + # ------------------------------------------------------------------------- + + @property + def xmin(self): + """Minimum x allowed for pan.""" + return self._xmin + + @xmin.setter + def xmin(self, value): + self._xmin = (np.minimum(value, self._xmax) + if self._xmax is not None else value) + + @property + def xmax(self): + """Maximum x allowed for pan.""" + return self._xmax + + @xmax.setter + def xmax(self, value): + self._xmax = (np.maximum(value, self._xmin) + if self._xmin is not None else value) + + # ymin/ymax + # ------------------------------------------------------------------------- + + @property + def ymin(self): + """Minimum y allowed for pan.""" + return self._ymin + + @ymin.setter + def ymin(self, value): + self._ymin = (min(value, self._ymax) + if self._ymax is not None else value) + + @property + def ymax(self): + """Maximum y allowed for pan.""" + return self._ymax + + @ymax.setter + def ymax(self, value): + self._ymax = (max(value, self._ymin) + if self._ymin is not None else value) + + # zmin/zmax + # ------------------------------------------------------------------------- + + @property + def zmin(self): + """Minimum zoom level.""" + return self._zmin + + @zmin.setter + def zmin(self, value): + self._zmin = min(value, self._zmax) + + @property + def zmax(self): + """Maximal zoom level.""" + return self._zmax + + @zmax.setter + def zmax(self, value): + self._zmax = max(value, self._zmin) + + # Internal methods + # ------------------------------------------------------------------------- + + def _zoom_aspect(self, zoom=None): + zoom = zoom if zoom is not None else self._zoom + zoom = _as_array(zoom) + aspect = (self._canvas_aspect * self._aspect + if self._aspect is not None else 1.) + return zoom * aspect + + def _normalize(self, pos): + return pixels_to_ndc(pos, size=self.size) + + def _constrain_pan(self): + """Constrain bounding box.""" + if self.xmin is not None and self.xmax is not None: + p0 = self.xmin + 1. / self._zoom[0] + p1 = self.xmax - 1. / self._zoom[0] + p0, p1 = min(p0, p1), max(p0, p1) + self._pan[0] = np.clip(self._pan[0], p0, p1) + + if self.ymin is not None and self.ymax is not None: + p0 = self.ymin + 1. / self._zoom[1] + p1 = self.ymax - 1. / self._zoom[1] + p0, p1 = min(p0, p1), max(p0, p1) + self._pan[1] = np.clip(self._pan[1], p0, p1) + + def _constrain_zoom(self): + """Constrain bounding box.""" + if self.xmin is not None: + self._zoom[0] = max(self._zoom[0], + 1. / (self._pan[0] - self.xmin)) + if self.xmax is not None: + self._zoom[0] = max(self._zoom[0], + 1. / (self.xmax - self._pan[0])) + + if self.ymin is not None: + self._zoom[1] = max(self._zoom[1], + 1. / (self._pan[1] - self.ymin)) + if self.ymax is not None: + self._zoom[1] = max(self._zoom[1], + 1. 
/ (self.ymax - self._pan[1])) + + def get_mouse_pos(self, pos): + """Return the mouse coordinates in NDC, taking panzoom into account.""" + position = np.asarray(self._normalize(pos)) + zoom = np.asarray(self._zoom_aspect()) + pan = np.asarray(self.pan) + mouse_pos = ((position / zoom) - pan) + return mouse_pos + + # Pan and zoom + # ------------------------------------------------------------------------- + + @property + def pan(self): + """Pan translation.""" + return list(self._pan) + + @pan.setter + def pan(self, value): + """Pan translation.""" + assert len(value) == 2 + self._pan[:] = value + self._constrain_pan() + self.update() + + @property + def zoom(self): + """Zoom level.""" + return list(self._zoom) + + @zoom.setter + def zoom(self, value): + """Zoom level.""" + if isinstance(value, (int, float)): + value = (value, value) + assert len(value) == 2 + self._zoom = np.clip(value, self._zmin, self._zmax) + + # Constrain bounding box. + self._constrain_pan() + self._constrain_zoom() + + self.update() + + def pan_delta(self, d): + """Pan the view by a given amount.""" + dx, dy = d + + pan_x, pan_y = self.pan + zoom_x, zoom_y = self._zoom_aspect(self._zoom) + + self.pan = (pan_x + dx / zoom_x, pan_y + dy / zoom_y) + self.update() + + def zoom_delta(self, d, p=(0., 0.), c=1.): + """Zoom the view by a given amount.""" + dx, dy = d + x0, y0 = p + + pan_x, pan_y = self._pan + zoom_x, zoom_y = self._zoom + zoom_x_new, zoom_y_new = (zoom_x * math.exp(c * self._zoom_coeff * dx), + zoom_y * math.exp(c * self._zoom_coeff * dy)) + + zoom_x_new = max(min(zoom_x_new, self._zmax), self._zmin) + zoom_y_new = max(min(zoom_y_new, self._zmax), self._zmin) + + self.zoom = zoom_x_new, zoom_y_new + + if self._zoom_to_pointer: + zoom_x, zoom_y = self._zoom_aspect((zoom_x, + zoom_y)) + zoom_x_new, zoom_y_new = self._zoom_aspect((zoom_x_new, + zoom_y_new)) + + self.pan = (pan_x - x0 * (1. / zoom_x - 1. / zoom_x_new), + pan_y - y0 * (1. / zoom_y - 1. / zoom_y_new)) + + self.update() + + def set_pan_zoom(self, pan=None, zoom=None): + self._pan = pan + self._zoom = np.clip(zoom, self._zmin, self._zmax) + + # Constrain bounding box. + self._constrain_pan() + self._constrain_zoom() + + self.update() + + def set_range(self, bounds, keep_aspect=False): + """Zoom to fit a box.""" + # a * (v0 + t) = -1 + # a * (v1 + t) = +1 + # => + # a * (v1 - v0) = 2 + bounds = np.asarray(bounds, dtype=np.float64) + v0 = bounds[:2] + v1 = bounds[2:] + pan = -.5 * (v0 + v1) + zoom = 2. / (v1 - v0) + if keep_aspect: + zoom = zoom.min() * np.ones(2) + self.set_pan_zoom(pan=pan, zoom=zoom) + + def get_range(self): + """Return the bounds currently visible.""" + p, z = np.asarray(self.pan), np.asarray(self.zoom) + x0, y0 = -1. / z - p + x1, y1 = +1. / z - p + return (x0, y0, x1, y1) + + # Event callbacks + # ------------------------------------------------------------------------- + + keyboard_shortcuts = { + 'pan': ('left click and drag', 'arrows'), + 'zoom': ('right click and drag', '+', '-'), + 'reset': 'r', + } + + def _set_canvas_aspect(self): + w, h = self.size + aspect = w / max(float(h), 1.) 
+ if aspect > 1.0: + self._canvas_aspect = np.array([1.0 / aspect, 1.0]) + else: + self._canvas_aspect = np.array([1.0, aspect / 1.0]) + + def _zoom_keyboard(self, key): + k = .05 + if key == '-': + k = -k + self.zoom_delta((k, k), (0, 0)) + + def _pan_keyboard(self, key): + k = .1 / np.asarray(self.zoom) + if key == 'Left': + self.pan_delta((+k[0], +0)) + elif key == 'Right': + self.pan_delta((-k[0], +0)) + elif key == 'Down': + self.pan_delta((+0, +k[1])) + elif key == 'Up': + self.pan_delta((+0, -k[1])) + self.update() + + def reset(self): + """Reset the view.""" + self.pan = (0., 0.) + self.zoom = 1. + self.update() + + def on_resize(self, event): + """Resize event.""" + self._set_canvas_aspect() + # Update zoom level + self.zoom = self._zoom + + def on_mouse_move(self, event): + """Pan and zoom with the mouse.""" + if event.modifiers: + return + if event.is_dragging: + x0, y0 = self._normalize(event.press_event.pos) + x1, y1 = self._normalize(event.last_event.pos) + x, y = self._normalize(event.pos) + dx, dy = x - x1, y - y1 + if event.button == 1: + self.pan_delta((dx, dy)) + elif event.button == 2: + c = np.sqrt(self.size[0]) * .03 + self.zoom_delta((dx, dy), (x0, y0), c=c) + + def on_touch(self, event): + if event.type == 'end': + self._pinch = None + elif event.type == 'pinch': + if event.scale in (1., self._last_pinch_scale): + self._pinch = None + return + self._last_pinch_scale = event.scale + x0, y0 = self._normalize(event.pos) + s = math.log(event.scale / event.last_scale) + c = np.sqrt(self.size[0]) * .05 + self.zoom_delta((s, s), + (x0, y0), + c=c) + self._pinch = True + elif event.type == 'touch': + if self._pinch: + return + x0, y0 = self._normalize(np.array(event.pos).mean(axis=0)) + x1, y1 = self._normalize(np.array(event.last_pos).mean(axis=0)) + dx, dy = x0 - x1, y0 - y1 + c = 5 + self.pan_delta((c * dx, c * dy)) + + def on_mouse_wheel(self, event): + """Zoom with the mouse wheel.""" + # NOTE: not called on OS X because of touchpad + if event.modifiers: + return + dx = np.sign(event.delta[1]) * self._wheel_coeff + # Zoom toward the mouse pointer. + x0, y0 = self._normalize(event.pos) + self.zoom_delta((dx, dx), (x0, y0)) + + def on_key_press(self, event): + """Pan and zoom with the keyboard.""" + # Zooming with the keyboard. + key = event.key + if event.modifiers: + return + + # Pan. + if self.enable_keyboard_pan and key in self._arrows: + self._pan_keyboard(key) + + # Zoom. + if key in self._pm: + self._zoom_keyboard(key) + + # Reset with 'R'. + if key == 'R': + self.reset() + + # Canvas methods + # ------------------------------------------------------------------------- + + @property + def size(self): + if self.canvas: + return self.canvas.size + else: + return (1, 1) + + def attach(self, canvas): + """Attach this interact to a canvas.""" + super(PanZoom, self).attach(canvas) + canvas.panzoom = self + + canvas.transforms.add_on_gpu([self._translate, self._scale]) + # Add the variable declarations. 
+ vs = ('uniform vec2 {};\n'.format(self.pan_var_name) + + 'uniform vec2 {};\n'.format(self.zoom_var_name)) + canvas.inserter.insert_vert(vs, 'header') + + canvas.connect(self.on_resize) + canvas.connect(self.on_mouse_move) + canvas.connect(self.on_touch) + canvas.connect(self.on_key_press) + + if self.enable_mouse_wheel: + canvas.connect(self.on_mouse_wheel) + + self._set_canvas_aspect() + + def map(self, arr): + arr = Translate(self.pan).apply(arr) + arr = Scale(self.zoom).apply(arr) + return arr + + def imap(self, arr): + arr = Scale(self.zoom).inverse().apply(arr) + arr = Translate(self.pan).inverse().apply(arr) + return arr + + def update_program(self, program): + program[self.pan_var_name] = self._pan + program[self.zoom_var_name] = self._zoom_aspect() diff --git a/phy/plot/plot.py b/phy/plot/plot.py new file mode 100644 index 000000000..48cbfaf4e --- /dev/null +++ b/phy/plot/plot.py @@ -0,0 +1,263 @@ +# -*- coding: utf-8 -*- + +"""Plotting interface.""" + + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +from collections import OrderedDict +from contextlib import contextmanager + +import numpy as np + +from phy.io.array import _accumulate, _in_polygon +from phy.utils._types import _as_tuple +from .base import BaseCanvas +from .interact import Grid, Boxed, Stacked +from .panzoom import PanZoom +from .utils import _get_array +from .visuals import (ScatterVisual, PlotVisual, HistogramVisual, + LineVisual, TextVisual, PolygonVisual) + + +#------------------------------------------------------------------------------ +# Utils +#------------------------------------------------------------------------------ + +# NOTE: we ensure that we only create every type *once*, so that +# View._items has only one key for any class. 
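Returning to the end of `panzoom.py` for a moment: `map` and `imap` simply compose the pan translation with the zoom scaling, and undo them in reverse order. A plain-NumPy sketch of the same arithmetic, using the values from `test_panzoom_map` later in this patch (function names are illustrative):

```python
import numpy as np

def pz_map(pos, pan, zoom):
    """Data coordinates -> NDC: translate by pan, then scale by zoom."""
    pos = np.atleast_2d(pos).astype(np.float64)
    return (pos + np.asarray(pan)) * np.asarray(zoom)

def pz_imap(pos, pan, zoom):
    """NDC -> data coordinates: invert the scale, then the translation."""
    pos = np.atleast_2d(pos).astype(np.float64)
    return pos / np.asarray(zoom) - np.asarray(pan)

# Same numbers as in test_panzoom_map().
assert np.allclose(pz_map([0., 0.], pan=(1., -1.), zoom=(2., .5)), [[2., -.5]])
assert np.allclose(pz_imap([2., -.5], pan=(1., -1.), zoom=(2., .5)), [[0., 0.]])
```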
+_SCATTER_CLASSES = {} + + +def _make_scatter_class(marker): + """Return a temporary ScatterVisual class with a given marker.""" + name = 'ScatterVisual' + marker.title() + if name not in _SCATTER_CLASSES: + cls = type(name, (ScatterVisual,), {'_default_marker': marker}) + _SCATTER_CLASSES[name] = cls + return _SCATTER_CLASSES[name] + + +#------------------------------------------------------------------------------ +# Plotting interface +#------------------------------------------------------------------------------ + +class View(BaseCanvas): + """High-level plotting canvas.""" + _default_box_index = (0,) + + def __init__(self, layout=None, shape=None, n_plots=None, origin=None, + box_bounds=None, box_pos=None, box_size=None, + enable_lasso=False, + **kwargs): + if not kwargs.get('keys', None): + kwargs['keys'] = None + super(View, self).__init__(**kwargs) + self.layout = layout + + if layout == 'grid': + self._default_box_index = (0, 0) + self.grid = Grid(shape) + self.grid.attach(self) + self.interact = self.grid + + elif layout == 'boxed': + self.n_plots = (len(box_bounds) + if box_bounds is not None else len(box_pos)) + self.boxed = Boxed(box_bounds=box_bounds, + box_pos=box_pos, + box_size=box_size) + self.boxed.attach(self) + self.interact = self.boxed + + elif layout == 'stacked': + self.n_plots = n_plots + self.stacked = Stacked(n_plots, margin=.1, origin=origin) + self.stacked.attach(self) + self.interact = self.stacked + + else: + self.interact = None + + self.panzoom = PanZoom(aspect=None, + constrain_bounds=[-2, -2, +2, +2]) + self.panzoom.attach(self) + + if enable_lasso: + self.lasso = Lasso() + self.lasso.attach(self) + else: + self.lasso = None + + self.clear() + + def clear(self): + """Reset the view.""" + self._items = OrderedDict() + self.visuals = [] + self.update() + + def _add_item(self, cls, *args, **kwargs): + """Add a plot item.""" + box_index = kwargs.pop('box_index', self._default_box_index) + + data = cls.validate(*args, **kwargs) + n = cls.vertex_count(**data) + + if not isinstance(box_index, np.ndarray): + k = len(self._default_box_index) + box_index = _get_array(box_index, (n, k)) + data['box_index'] = box_index + + if cls not in self._items: + self._items[cls] = [] + self._items[cls].append(data) + return data + + def plot(self, *args, **kwargs): + """Add a line plot.""" + return self._add_item(PlotVisual, *args, **kwargs) + + def scatter(self, *args, **kwargs): + """Add a scatter plot.""" + cls = _make_scatter_class(kwargs.pop('marker', + ScatterVisual._default_marker)) + return self._add_item(cls, *args, **kwargs) + + def hist(self, *args, **kwargs): + """Add some histograms.""" + return self._add_item(HistogramVisual, *args, **kwargs) + + def text(self, *args, **kwargs): + """Add text.""" + return self._add_item(TextVisual, *args, **kwargs) + + def lines(self, *args, **kwargs): + """Add some lines.""" + return self._add_item(LineVisual, *args, **kwargs) + + def __getitem__(self, box_index): + self._default_box_index = _as_tuple(box_index) + return self + + def build(self): + """Build all added items. + + Visuals are created, added, and built. The `set_data()` methods can + be called afterwards. + + """ + for cls, data_list in self._items.items(): + # Some variables are not concatenated. They are specified + # in `allow_list`. 
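Stepping back from `build()` for a moment: in practice the `View` above is driven through subplot indexing and the `building()` context manager, as the new `test_plot.py` tests later in this patch do. A minimal usage sketch, assuming a working VisPy/Qt backend (not part of the patch itself):

```python
import numpy as np
from phy.plot.plot import View

view = View(layout='grid', shape=(2, 3))     # 2x3 grid of subplots
x, y = np.random.randn(2, 100)

with view.building():                        # clear, add items, then build()
    view[0, 0].scatter(x, y)                 # plain scatter in subplot (0, 0)
    view[0, 2].scatter(x, y, color=np.random.uniform(.5, .8, size=(100, 4)))
    view[1, 1].hist(np.random.rand(5, 10),   # five histograms in one subplot
                    color=np.random.uniform(.4, .9, size=(5, 4)))

view.show()                                  # needs a running Qt event loop
```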
+ data = _accumulate(data_list, cls.allow_list) + box_index = data.pop('box_index') + visual = cls() + self.add_visual(visual) + visual.set_data(**data) + # NOTE: visual.program.__contains__ is implemented in vispy master + # so we can replace this with `if 'a_box_index' in visual.program` + # after the next VisPy release. + if 'a_box_index' in visual.program._code_variables: + visual.program['a_box_index'] = box_index.astype(np.float32) + # TODO: refactor this when there is the possibility to update existing + # visuals without recreating the whole scene. + if self.lasso: + self.lasso.create_visual() + self.update() + + def get_pos_from_mouse(self, pos, box): + # From window coordinates to NDC (pan & zoom taken into account). + pos = self.panzoom.get_mouse_pos(pos) + # From NDC to data coordinates. + pos = self.interact.imap(pos, box) if self.interact else pos + return pos + + @contextmanager + def building(self): + """Context manager to specify the plots.""" + self.clear() + yield + self.build() + + +#------------------------------------------------------------------------------ +# Interactive tools +#------------------------------------------------------------------------------ + +class Lasso(object): + def __init__(self): + self._points = [] + self.view = None + self.visual = None + self.box = None + + def add(self, pos): + self._points.append(pos) + self.update_visual() + + @property + def polygon(self): + l = self._points + # Close the polygon. + # l = l + l[0] if len(l) else l + out = np.array(l, dtype=np.float64) + out = np.reshape(out, (out.size // 2, 2)) + assert out.ndim == 2 + assert out.shape[1] == 2 + return out + + def clear(self): + self._points = [] + self.box = None + self.update_visual() + + @property + def count(self): + return len(self._points) + + def in_polygon(self, pos): + return _in_polygon(pos, self.polygon) + + def attach(self, view): + view.connect(self.on_mouse_press) + self.view = view + + def create_visual(self): + self.visual = PolygonVisual() + self.view.add_visual(self.visual) + self.update_visual() + + def update_visual(self): + if not self.visual: + return + # Update the polygon. + self.visual.set_data(pos=self.polygon) + # Set the box index for the polygon, depending on the box + # where the first point was clicked in. + box = (self.box if self.box is not None + else self.view._default_box_index) + k = len(self.view._default_box_index) + n = self.visual.vertex_count(pos=self.polygon) + box_index = _get_array(box, (n, k)).astype(np.float32) + self.visual.program['a_box_index'] = box_index + self.view.update() + + def on_mouse_press(self, e): + if 'Control' in e.modifiers: + if e.button == 1: + # Find the box. + ndc = self.view.panzoom.get_mouse_pos(e.pos) + # NOTE: we don't update the box after the second point. + # In other words, the first point determines the box for the + # lasso. + if self.box is None and self.view.interact: + self.box = self.view.interact.get_closest_box(ndc) + # Transform from window coordinates to NDC. 
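On the window-to-NDC step mentioned in the comment above: the conversion implied by `test_pixels_to_ndc` further down maps the top-left pixel to `(-1, 1)` and the canvas centre to `(0, 0)`. A rough sketch of that conversion (the helper name is illustrative; phy's `pixels_to_ndc` may differ in details):

```python
import numpy as np

def to_ndc(pos, size):
    """Map a pixel position (origin at the top-left, y down) into [-1, 1]^2."""
    x, y = pos
    w, h = size
    ndc_x = -1. + 2. * x / float(w)   # x = 0 -> -1, x = w -> +1
    ndc_y = +1. - 2. * y / float(h)   # y = 0 -> +1, y = h -> -1 (y axis flipped)
    return np.array([ndc_x, ndc_y])

assert np.allclose(to_ndc((0, 0), size=(10, 10)), [-1., 1.])
assert np.allclose(to_ndc((5, 5), size=(10, 10)), [0., 0.])
```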
+ pos = self.view.get_pos_from_mouse(e.pos, self.box) + self.add(pos) + else: + self.clear() + self.box = None diff --git a/phy/plot/static/SourceCodePro-Regular-32.npy.gz b/phy/plot/static/SourceCodePro-Regular-32.npy.gz new file mode 100644 index 000000000..46e0019df Binary files /dev/null and b/phy/plot/static/SourceCodePro-Regular-32.npy.gz differ diff --git a/phy/plot/static/SourceCodePro-Regular-48.npy.gz b/phy/plot/static/SourceCodePro-Regular-48.npy.gz new file mode 100644 index 000000000..992f5cce3 Binary files /dev/null and b/phy/plot/static/SourceCodePro-Regular-48.npy.gz differ diff --git a/phy/plot/static/chars.txt b/phy/plot/static/chars.txt new file mode 100644 index 000000000..8c24215d1 --- /dev/null +++ b/phy/plot/static/chars.txt @@ -0,0 +1 @@ + !"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\]^_`abcdefghijklmnopqrstuvwxyz{|}~ \ No newline at end of file diff --git a/phy/plot/tests/conftest.py b/phy/plot/tests/conftest.py new file mode 100644 index 000000000..522a356ef --- /dev/null +++ b/phy/plot/tests/conftest.py @@ -0,0 +1,31 @@ +# -*- coding: utf-8 -*- + +"""Test VisPy.""" + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +from vispy.app import use_app +from pytest import yield_fixture + +from ..base import BaseCanvas +from ..panzoom import PanZoom + + +#------------------------------------------------------------------------------ +# Utilities and fixtures +#------------------------------------------------------------------------------ + +@yield_fixture +def canvas(qapp): + use_app('pyqt4') + c = BaseCanvas(keys='interactive') + yield c + c.close() + + +@yield_fixture +def canvas_pz(canvas): + PanZoom(enable_mouse_wheel=True).attach(canvas) + yield canvas diff --git a/phy/plot/tests/test_base.py b/phy/plot/tests/test_base.py new file mode 100644 index 000000000..8f2984ac0 --- /dev/null +++ b/phy/plot/tests/test_base.py @@ -0,0 +1,166 @@ +# -*- coding: utf-8 -*- + +"""Test base.""" + + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +import numpy as np +from pytest import yield_fixture + +from ..base import BaseVisual, BaseInteract, GLSLInserter +from ..transform import (subplot_bounds, Translate, Scale, Range, + Clip, Subplot, TransformChain) + + +#------------------------------------------------------------------------------ +# Fixtures +#------------------------------------------------------------------------------ + +@yield_fixture +def vertex_shader_nohook(): + yield """ + attribute vec2 a_position; + void main() { + gl_Position = vec4(a_position.xy, 0, 1); + } + """ + + +@yield_fixture +def vertex_shader(): + yield """ + attribute vec2 a_position; + void main() { + gl_Position = transform(a_position.xy); + gl_PointSize = 2.0; + } + """ + + +@yield_fixture +def fragment_shader(): + yield """ + void main() { + gl_FragColor = vec4(1, 1, 1, 1); + } + """ + + +#------------------------------------------------------------------------------ +# Test base +#------------------------------------------------------------------------------ + +def test_glsl_inserter_nohook(vertex_shader_nohook, fragment_shader): + vertex_shader = vertex_shader_nohook + inserter = GLSLInserter() + inserter.insert_vert('uniform float boo;', 'header') + inserter.insert_frag('// In fragment shader.', 'before_transforms') + vs, fs = 
inserter.insert_into_shaders(vertex_shader, fragment_shader) + assert vs == vertex_shader + assert fs == fragment_shader + + +def test_glsl_inserter_hook(vertex_shader, fragment_shader): + inserter = GLSLInserter() + inserter.insert_vert('uniform float boo;', 'header') + inserter.insert_frag('// In fragment shader.', 'before_transforms') + tc = TransformChain() + tc.add_on_gpu([Scale(.5)]) + inserter.add_transform_chain(tc) + vs, fs = inserter.insert_into_shaders(vertex_shader, fragment_shader) + assert 'temp_pos_tr = temp_pos_tr * 0.5;' in vs + assert 'uniform float boo;' in vs + assert '// In fragment shader.' in fs + + +def test_visual_1(qtbot, canvas): + class TestVisual(BaseVisual): + def __init__(self): + super(TestVisual, self).__init__() + self.set_shader('simple') + self.set_primitive_type('lines') + + def set_data(self): + self.program['a_position'] = [[-1, 0], [1, 0]] + self.program['u_color'] = [1, 1, 1, 1] + + v = TestVisual() + canvas.add_visual(v) + # Must be called *after* add_visual(). + v.set_data() + + canvas.show() + qtbot.waitForWindowShown(canvas.native) + # qtbot.stop() + + +def test_visual_2(qtbot, canvas, vertex_shader, fragment_shader): + """Test a BaseVisual with multiple CPU and GPU transforms. + + There should be points filling the entire right upper (2, 3) subplot. + + """ + + class TestVisual(BaseVisual): + def __init__(self): + super(TestVisual, self).__init__() + self.vertex_shader = vertex_shader + self.fragment_shader = fragment_shader + self.set_primitive_type('points') + self.transforms.add_on_cpu(Scale((.1, .1))) + self.transforms.add_on_cpu(Translate((-1, -1))) + self.transforms.add_on_cpu(Range((-1, -1, 1, 1), + (-1.5, -1.5, 1.5, 1.5), + )) + s = 'gl_Position.y += (1 + 1e-8 * u_window_size.x);' + self.inserter.insert_vert(s, 'after_transforms') + + def set_data(self): + data = np.random.uniform(0, 20, (1000, 2)) + pos = self.transforms.apply(data).astype(np.float32) + self.program['a_position'] = pos + + bounds = subplot_bounds(shape=(2, 3), index=(1, 2)) + canvas.transforms.add_on_gpu([Subplot((2, 3), (1, 2)), + Clip(bounds), + ]) + + # We attach the visual to the canvas. By default, a BaseInteract is used. 
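For context on the `Subplot((2, 3), (1, 2))` transform added to the canvas above: the subplot and grid tests elsewhere in this patch pin down a linear squeeze of NDC into one grid cell. A standalone sketch of that arithmetic, inferred from the test values rather than taken from phy's implementation:

```python
import numpy as np

def subplot_map(pos, shape, index):
    """Map NDC coordinates into grid cell (row, col) of an (n_rows, n_cols) grid."""
    x, y = np.asarray(pos, dtype=np.float64)
    n_rows, n_cols = shape
    row, col = index
    x_cell = -1. + (2. * col + (x + 1.)) / n_cols   # columns run left to right
    y_cell = +1. - (2. * row + (1. - y)) / n_rows   # rows run top to bottom
    return np.array([x_cell, y_cell])

# Same values as test_subplot_cpu() below.
assert np.allclose(subplot_map([-1, -1], (2, 3), (0, 0)), [-1., 0.])
assert np.allclose(subplot_map([0, 0], (2, 3), (0, 0)), [-2. / 3., .5])
assert np.allclose(subplot_map([0, 1], (2, 3), (1, 1)), [0., 0.])
```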
+ v = TestVisual() + canvas.add_visual(v) + v.set_data() + + v = TestVisual() + canvas.add_visual(v) + v.set_data() + + canvas.show() + qtbot.waitForWindowShown(canvas.native) + # qtbot.stop() + + +def test_interact_1(qtbot, canvas): + interact = BaseInteract() + interact.update() + + class TestVisual(BaseVisual): + def __init__(self): + super(TestVisual, self).__init__() + self.set_shader('simple') + self.set_primitive_type('lines') + + def set_data(self): + self.program['a_position'] = [[-1, 0], [1, 0]] + self.program['u_color'] = [1, 1, 1, 1] + + interact.attach(canvas) + v = TestVisual() + canvas.add_visual(v) + v.set_data() + + canvas.show() + qtbot.waitForWindowShown(canvas.native) + interact.update() diff --git a/phy/plot/tests/test_ccg.py b/phy/plot/tests/test_ccg.py deleted file mode 100644 index cfeaf2eb2..000000000 --- a/phy/plot/tests/test_ccg.py +++ /dev/null @@ -1,63 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Test CCG plotting.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -from pytest import mark - -import numpy as np - -from ..ccg import _plot_ccg_mpl, CorrelogramView, plot_correlograms -from ...utils._color import _random_color -from ...io.mock import artificial_correlograms -from ...utils.testing import show_test - - -# Skip these tests in "make test-quick". -pytestmark = mark.long() - - -#------------------------------------------------------------------------------ -# Tests matplotlib -#------------------------------------------------------------------------------ - -def test_plot_ccg(): - n_bins = 51 - ccg = np.random.randint(size=n_bins, low=10, high=50) - _plot_ccg_mpl(ccg, baseline=20, color='g') - - -def test_plot_correlograms(): - n_bins = 51 - ccg = np.random.uniform(size=(3, 3, n_bins)) - c = plot_correlograms(ccg, lines=[-10, 0, 20], show=False) - show_test(c) - - -#------------------------------------------------------------------------------ -# Tests VisPy -#------------------------------------------------------------------------------ - -def _test_correlograms(n_clusters=None): - n_samples = 51 - - correlograms = artificial_correlograms(n_clusters, n_samples) - - c = CorrelogramView(keys='interactive') - c.cluster_ids = np.arange(n_clusters) - c.visual.correlograms = correlograms - c.visual.cluster_colors = np.array([_random_color() - for _ in range(n_clusters)]) - c.lines = [-5, 0, 5] - show_test(c) - - -def test_correlograms_empty(): - _test_correlograms(n_clusters=0) - - -def test_correlograms_full(): - _test_correlograms(n_clusters=3) diff --git a/phy/plot/tests/test_features.py b/phy/plot/tests/test_features.py deleted file mode 100644 index e5665a0ac..000000000 --- a/phy/plot/tests/test_features.py +++ /dev/null @@ -1,102 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Test feature plotting.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -from pytest import mark - -import numpy as np - -from ..features import FeatureView, plot_features -from ...utils._color import _random_color -from ...io.mock import (artificial_features, - artificial_masks, - artificial_spike_clusters, - artificial_spike_samples) -from ...utils.testing import show_test - - -# Skip these tests in "make test-quick". 
-pytestmark = mark.long() - - -#------------------------------------------------------------------------------ -# Tests -#------------------------------------------------------------------------------ - -def _test_features(n_spikes=None, n_clusters=None): - n_channels = 32 - n_features = 3 - - features = artificial_features(n_spikes, n_channels, n_features) - masks = artificial_masks(n_spikes, n_channels) - spike_clusters = artificial_spike_clusters(n_spikes, n_clusters) - spike_samples = artificial_spike_samples(n_spikes).astype(np.float32) - - c = FeatureView(keys='interactive') - c.init_grid(2) - c.visual.features = features - c.background.features = features[::2] * 2. - # Useful to test depth. - # masks[n_spikes//2:, ...] = 0 - c.visual.masks = masks - M = spike_samples.max() if len(spike_samples) else 1 - c.add_extra_feature('time', spike_samples, 0, M, - array_bg=spike_samples[::2]) - c.add_extra_feature('test', - np.sin(np.linspace(-10., 10., n_spikes)), - -1., 1., - array_bg=spike_samples[::2]) - c.set_dimensions('x', [['time', (1, 0)], - [(2, 1), 'time']]) - c.set_dimensions('y', [[(0, 0), (1, 1)], - [(1, 0), 'test']]) - c.visual.spike_clusters = spike_clusters - c.visual.cluster_colors = np.array([_random_color() - for _ in range(n_clusters)]) - - show_test(c) - - -def test_features_empty(): - _test_features(n_spikes=0, n_clusters=0) - - -def test_features_full(): - _test_features(n_spikes=100, n_clusters=3) - - -def test_plot_features(): - n_spikes = 1000 - n_channels = 32 - n_features = 1 - n_clusters = 2 - - features = artificial_features(n_spikes, n_channels, n_features) - masks = artificial_masks(n_spikes, n_channels) - spike_clusters = artificial_spike_clusters(n_spikes, n_clusters) - # Unclustered spikes. - spike_clusters[::3] = -1 - - c = plot_features(features[:, :1, :], - show=False, - x_dimensions=[['time']], - y_dimensions=[[(0, 0)]], - ) - show_test(c) - - c = plot_features(features, - show=False) - show_test(c) - - c = plot_features(features, show=False) - show_test(c) - - c = plot_features(features, masks=masks, show=False) - show_test(c) - - c = plot_features(features, spike_clusters=spike_clusters, show=False) - show_test(c) diff --git a/phy/plot/tests/test_interact.py b/phy/plot/tests/test_interact.py new file mode 100644 index 000000000..7255bd82a --- /dev/null +++ b/phy/plot/tests/test_interact.py @@ -0,0 +1,216 @@ +# -*- coding: utf-8 -*- + +"""Test interact.""" + + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +from itertools import product + +import numpy as np +from numpy.testing import assert_equal as ae +from numpy.testing import assert_allclose as ac + +from ..base import BaseVisual +from ..interact import Grid, Boxed, Stacked +from ..panzoom import PanZoom +from ..transform import NDC + + +#------------------------------------------------------------------------------ +# Fixtures +#------------------------------------------------------------------------------ + +class MyTestVisual(BaseVisual): + def __init__(self): + super(MyTestVisual, self).__init__() + self.vertex_shader = """ + attribute vec2 a_position; + void main() { + gl_Position = transform(a_position); + gl_PointSize = 2.; + } + """ + self.fragment_shader = """ + void main() { + gl_FragColor = vec4(1, 1, 1, 1); + } + """ + self.set_primitive_type('points') + + def set_data(self): + n = 1000 + + coeff = [(1 + i + j) for i, j in product(range(2), range(3))] + coeff = 
np.repeat(coeff, n) + coeff = coeff[:, None] + + position = .1 * coeff * np.random.randn(2 * 3 * n, 2) + + self.program['a_position'] = position.astype(np.float32) + + +def _create_visual(qtbot, canvas, interact, box_index): + c = canvas + + # Attach the interact *and* PanZoom. The order matters! + interact.attach(c) + PanZoom(aspect=None, constrain_bounds=NDC).attach(c) + + visual = MyTestVisual() + c.add_visual(visual) + visual.set_data() + + visual.program['a_box_index'] = box_index.astype(np.float32) + + c.show() + qtbot.waitForWindowShown(c.native) + + +#------------------------------------------------------------------------------ +# Test grid +#------------------------------------------------------------------------------ + +def test_grid_interact(): + grid = Grid((4, 8)) + ac(grid.map([0., 0.], (0, 0)), [[-0.875, 0.75]]) + ac(grid.map([0., 0.], (1, 3)), [[-0.125, 0.25]]) + ac(grid.map([0., 0.], (3, 7)), [[0.875, -0.75]]) + + ac(grid.imap([[0.875, -0.75]], (3, 7)), [[0., 0.]]) + + +def test_grid_closest_box(): + grid = Grid((3, 7)) + ac(grid.get_closest_box((0., 0.)), (1, 3)) + ac(grid.get_closest_box((-1., +1.)), (0, 0)) + ac(grid.get_closest_box((+1., -1.)), (2, 6)) + ac(grid.get_closest_box((-1., -1.)), (2, 0)) + ac(grid.get_closest_box((+1., +1.)), (0, 6)) + + +def test_grid_1(qtbot, canvas): + + n = 1000 + + box_index = [[i, j] for i, j in product(range(2), range(3))] + box_index = np.repeat(box_index, n, axis=0) + + grid = Grid((2, 3)) + _create_visual(qtbot, canvas, grid, box_index) + + grid.add_boxes(canvas) + + # qtbot.stop() + + +def test_grid_2(qtbot, canvas): + + n = 1000 + + box_index = [[i, j] for i, j in product(range(2), range(3))] + box_index = np.repeat(box_index, n, axis=0) + + grid = Grid() + _create_visual(qtbot, canvas, grid, box_index) + grid.shape = (3, 3) + assert grid.shape == (3, 3) + + # qtbot.stop() + + +#------------------------------------------------------------------------------ +# Test boxed +#------------------------------------------------------------------------------ + +def test_boxed_1(qtbot, canvas): + + n = 6 + b = np.zeros((n, 4)) + + b[:, 0] = b[:, 1] = np.linspace(-1., 1. - 1. / 3., n) + b[:, 2] = b[:, 3] = np.linspace(-1. + 1. / 3., 1., n) + + n = 1000 + box_index = np.repeat(np.arange(6), n, axis=0) + + boxed = Boxed(box_bounds=b) + _create_visual(qtbot, canvas, boxed, box_index) + + ae(boxed.box_bounds, b) + boxed.box_bounds = b + + boxed.update_boxes(boxed.box_pos, boxed.box_size) + ac(boxed.box_bounds, b) + + # qtbot.stop() + + +def test_boxed_2(qtbot, canvas): + """Test setting the box position and size dynamically.""" + + n = 1000 + pos = np.c_[np.zeros(6), np.linspace(-1., 1., 6)] + box_index = np.repeat(np.arange(6), n, axis=0) + + boxed = Boxed(box_pos=pos) + _create_visual(qtbot, canvas, boxed, box_index) + + boxed.box_pos *= .25 + boxed.box_size = [1, .1] + + idx = boxed.get_closest_box((.5, .25)) + assert idx == 4 + + # qtbot.stop() + + +def test_boxed_interact(): + + n = 8 + b = np.zeros((n, 4)) + b[:, 0] = b[:, 1] = np.linspace(-1., 1. - 1. / 4., n) + b[:, 2] = b[:, 3] = np.linspace(-1. + 1. 
/ 4., 1., n) + + boxed = Boxed(box_bounds=b) + ac(boxed.map([0., 0.], 0), [[-.875, -.875]]) + ac(boxed.map([0., 0.], 7), [[.875, .875]]) + ac(boxed.imap([[.875, .875]], 7), [[0., 0.]]) + + +def test_boxed_closest_box(): + b = np.array([[-.5, -.5, 0., 0.], + [0., 0., +.5, +.5]]) + boxed = Boxed(box_bounds=b) + + ac(boxed.get_closest_box((-1, -1)), 0) + ac(boxed.get_closest_box((-0.001, 0)), 0) + ac(boxed.get_closest_box((+0.001, 0)), 1) + ac(boxed.get_closest_box((-1, +1)), 0) + + +#------------------------------------------------------------------------------ +# Test stacked +#------------------------------------------------------------------------------ + +def test_stacked_1(qtbot, canvas): + + n = 1000 + box_index = np.repeat(np.arange(6), n, axis=0) + + stacked = Stacked(n_boxes=6, margin=-10, origin='upper') + _create_visual(qtbot, canvas, stacked, box_index) + + # qtbot.stop() + + +def test_stacked_closest_box(): + stacked = Stacked(n_boxes=4, origin='upper') + ac(stacked.get_closest_box((-.5, .9)), 0) + ac(stacked.get_closest_box((+.5, -.9)), 3) + + stacked = Stacked(n_boxes=4, origin='lower') + ac(stacked.get_closest_box((-.5, .9)), 3) + ac(stacked.get_closest_box((+.5, -.9)), 0) diff --git a/phy/plot/tests/test_panzoom.py b/phy/plot/tests/test_panzoom.py new file mode 100644 index 000000000..94659c29d --- /dev/null +++ b/phy/plot/tests/test_panzoom.py @@ -0,0 +1,325 @@ +# -*- coding: utf-8 -*- + +"""Test panzoom.""" + + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +import numpy as np +from numpy.testing import assert_allclose as ac +from pytest import yield_fixture +from vispy.app import MouseEvent +from vispy.util import keys + +from ..base import BaseVisual +from ..panzoom import PanZoom + + +#------------------------------------------------------------------------------ +# Fixtures +#------------------------------------------------------------------------------ + +class MyTestVisual(BaseVisual): + def __init__(self): + super(MyTestVisual, self).__init__() + self.set_shader('simple') + self.set_primitive_type('lines') + + def set_data(self): + self.program['a_position'] = [[-1, 0], [1, 0]] + self.program['u_color'] = [1, 1, 1, 1] + + +@yield_fixture +def panzoom(qtbot, canvas_pz): + c = canvas_pz + visual = MyTestVisual() + c.add_visual(visual) + visual.set_data() + + c.show() + qtbot.waitForWindowShown(c.native) + + yield c.panzoom + + +#------------------------------------------------------------------------------ +# Test panzoom +#------------------------------------------------------------------------------ + +def test_panzoom_basic_attrs(): + pz = PanZoom() + + # Aspect. + assert pz.aspect == 1. + pz.aspect = 2. + assert pz.aspect == 2. + + # Constraints. + for name in ('xmin', 'xmax', 'ymin', 'ymax'): + assert getattr(pz, name) is None + setattr(pz, name, 1.) + assert getattr(pz, name) == 1. + + for name, v in (('zmin', 1e-5), ('zmax', 1e5)): + assert getattr(pz, name) == v + setattr(pz, name, v * 2) + assert getattr(pz, name) == v * 2 + + +def test_panzoom_basic_constrain(): + pz = PanZoom(constrain_bounds=(-1, -1, 1, 1)) + + # Aspect. + assert pz.aspect == 1. + pz.aspect = 2. + assert pz.aspect == 2. + + # Constraints. + assert pz.xmin == pz.ymin == -1 + assert pz.xmax == pz.ymax == +1 + + +def test_panzoom_basic_pan_zoom(): + pz = PanZoom() + + # Pan. + assert pz.pan == [0., 0.] + pz.pan = (1., -1.) + assert pz.pan == [1., -1.] + + # Zoom. 
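A note on the zoom-to-pointer behaviour exercised by `test_panzoom_mouse_pos` further down: `zoom_delta` rescales the zoom exponentially and then shifts the pan so that the data point under the given NDC position stays put. A minimal sketch of that invariant (the zoom coefficient below is illustrative, and the aspect-ratio correction of the real `PanZoom` is ignored):

```python
import math

def zoom_about(pan, zoom, d, p, coeff=1.5):
    """Exponential zoom keeping the data point under the NDC position `p` fixed."""
    dx, dy = d
    x0, y0 = p
    zx, zy = zoom
    zx_new = zx * math.exp(coeff * dx)
    zy_new = zy * math.exp(coeff * dy)
    pan_x = pan[0] - x0 * (1. / zx - 1. / zx_new)
    pan_y = pan[1] - y0 * (1. / zy - 1. / zy_new)
    return (pan_x, pan_y), (zx_new, zy_new)

def under_cursor(p, pan, zoom):
    """NDC cursor position -> data coordinates (same arithmetic as get_mouse_pos)."""
    return (p[0] / zoom[0] - pan[0], p[1] / zoom[1] - pan[1])

p = (.5, .25)
before = under_cursor(p, (0., 0.), (1., 1.))
pan2, zoom2 = zoom_about((0., 0.), (1., 1.), d=(1., 1.), p=p)
after = under_cursor(p, pan2, zoom2)
assert all(abs(a - b) < 1e-12 for a, b in zip(before, after))
```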
+ assert pz.zoom == [1., 1.] + pz.zoom = (2., .5) + assert pz.zoom == [2., .5] + pz.zoom = (1., 1.) + + # Pan delta. + pz.pan_delta((-1., 1.)) + assert pz.pan == [0., 0.] + + # Zoom delta. + pz.zoom_delta((1., 1.)) + assert pz.zoom[0] > 2 + assert pz.zoom[0] == pz.zoom[1] + pz.zoom = (1., 1.) + + # Zoom delta. + pz.zoom_delta((2., 3.), (.5, .5)) + assert pz.zoom[0] > 2 + assert pz.zoom[1] > 3 * pz.zoom[0] + + +def test_panzoom_map(): + pz = PanZoom() + pz.pan = (1., -1.) + ac(pz.map([0., 0.]), [[1., -1.]]) + + pz.zoom = (2., .5) + ac(pz.map([0., 0.]), [[2., -.5]]) + + ac(pz.imap([2., -.5]), [[0., 0.]]) + + +def test_panzoom_constraints_x(): + pz = PanZoom() + pz.xmin, pz.xmax = -2, 2 + + # Pan beyond the bounds. + pz.pan_delta((-2, 2)) + assert pz.pan == [-1, 2] + pz.reset() + + # Zoom beyond the bounds. + pz.zoom_delta((-1, -2)) + assert pz.pan == [0, 0] + assert pz.zoom[0] == .5 + assert pz.zoom[1] < .5 + + +def test_panzoom_constraints_y(): + pz = PanZoom() + pz.ymin, pz.ymax = -2, 2 + + # Pan beyond the bounds. + pz.pan_delta((2, -2)) + assert pz.pan == [2, -1] + pz.reset() + + # Zoom beyond the bounds. + pz.zoom_delta((-2, -1)) + assert pz.pan == [0, 0] + assert pz.zoom[0] < .5 + assert pz.zoom[1] == .5 + + +def test_panzoom_constraints_z(): + pz = PanZoom() + pz.zmin, pz.zmax = .5, 2 + + # Zoom beyond the bounds. + pz.zoom_delta((-10, -10)) + assert pz.zoom == [.5, .5] + pz.reset() + + pz.zoom_delta((10, 10)) + assert pz.zoom == [2, 2] + + +def test_panzoom_set_range(): + pz = PanZoom() + + def _test_range(*bounds): + pz.set_range(bounds) + ac(pz.get_range(), bounds) + + _test_range(-1, -1, 1, 1) + ac(pz.zoom, (1, 1)) + + _test_range(-.5, -.5, .5, .5) + ac(pz.zoom, (2, 2)) + + _test_range(0, 0, 1, 1) + ac(pz.zoom, (2, 2)) + + _test_range(-1, 0, 1, 1) + ac(pz.zoom, (1, 2)) + + pz.set_range((-1, 0, 1, 1), keep_aspect=True) + ac(pz.zoom, (1, 1)) + + +def test_panzoom_mouse_pos(): + pz = PanZoom() + pz.zoom_delta((10, 10), (.5, .25)) + pos = pz.get_mouse_pos((.01, -.01)) + ac(pos, (.5, .25), atol=1e-3) + + +#------------------------------------------------------------------------------ +# Test panzoom on canvas +#------------------------------------------------------------------------------ + +def test_panzoom_pan_mouse(qtbot, canvas_pz, panzoom): + c = canvas_pz + pz = panzoom + + # Pan with mouse. + press = MouseEvent(type='mouse_press', pos=(0, 0)) + c.events.mouse_move(pos=(10., 0.), button=1, + last_event=press, press_event=press) + assert pz.pan[0] > 0 + assert pz.pan[1] == 0 + pz.pan = (0, 0) + + # Panning with a modifier should not pan. + press = MouseEvent(type='mouse_press', pos=(0, 0)) + c.events.mouse_move(pos=(10., 0.), button=1, + last_event=press, press_event=press, + modifiers=(keys.CONTROL,)) + assert pz.pan == [0, 0] + + # qtbot.stop() + + +def test_panzoom_touch(qtbot, canvas_pz, panzoom): + c = canvas_pz + pz = panzoom + + # Pan with mouse. + c.events.touch(type='pinch', pos=(0, 0), scale=1, last_scale=1) + c.events.touch(type='pinch', pos=(0, 0), scale=2, last_scale=1) + assert pz.zoom[0] >= 2 + c.events.touch(type='end') + + c.events.touch(type='touch', pos=(0.1, 0), last_pos=(0, 0)) + assert pz.pan[0] >= 1 + + +def test_panzoom_pan_keyboard(qtbot, canvas_pz, panzoom): + c = canvas_pz + pz = panzoom + + # Pan with keyboard. + c.events.key_press(key=keys.UP) + assert pz.pan[0] == 0 + assert pz.pan[1] < 0 + + # All panning movements with keys. 
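For reference, the keyboard pan step in `_pan_keyboard` is `.1 / zoom`, i.e. a fixed fraction of the currently visible extent, which is why opposite arrow presses cancel exactly at any zoom level. A tiny sketch of that step size:

```python
import numpy as np

def keyboard_pan_step(zoom, fraction=.1):
    """Pan step in data units: a fixed fraction of the visible extent.

    With zoom z the visible half-width is 1/z, so dividing by the zoom keeps
    the on-screen displacement constant (mirrors _pan_keyboard above).
    """
    return fraction / np.asarray(zoom, dtype=np.float64)

assert np.allclose(keyboard_pan_step([1., 1.]), [.1, .1])
assert np.allclose(keyboard_pan_step([4., 2.]), [.025, .05])  # smaller steps when zoomed in
```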
+ c.events.key_press(key=keys.LEFT) + c.events.key_press(key=keys.DOWN) + c.events.key_press(key=keys.RIGHT) + assert pz.pan == [0, 0] + + # Reset with R. + c.events.key_press(key=keys.RIGHT) + c.events.key_press(key=keys.Key('r')) + assert pz.pan == [0, 0] + + # Using modifiers should not pan. + c.events.key_press(key=keys.UP, modifiers=(keys.CONTROL,)) + assert pz.pan == [0, 0] + + # Disable keyboard pan. + pz.enable_keyboard_pan = False + c.events.key_press(key=keys.UP, modifiers=(keys.CONTROL,)) + assert pz.pan == [0, 0] + + +def test_panzoom_zoom_mouse(qtbot, canvas_pz, panzoom): + c = canvas_pz + pz = panzoom + + # Zoom with mouse. + press = MouseEvent(type='mouse_press', pos=(50., 50.)) + c.events.mouse_move(pos=(0., 0.), button=2, + last_event=press, press_event=press) + assert pz.pan[0] < 0 + assert pz.pan[1] < 0 + assert pz.zoom[0] < 1 + assert pz.zoom[1] > 1 + pz.reset() + + # Zoom with mouse. + size = np.asarray(c.size) + c.events.mouse_wheel(pos=size / 2., delta=(0., 1.)) + assert pz.pan == [0, 0] + assert pz.zoom[0] > 1 + assert pz.zoom[1] > 1 + pz.reset() + + # Using modifiers with the wheel should not zoom. + c.events.mouse_wheel(pos=(0., 0.), delta=(0., 1.), + modifiers=(keys.CONTROL,)) + assert pz.pan == [0, 0] + assert pz.zoom == [1, 1] + pz.reset() + + +def test_panzoom_zoom_keyboard(qtbot, canvas_pz, panzoom): + c = canvas_pz + pz = panzoom + + # Zoom with keyboard. + c.events.key_press(key=keys.Key('+')) + assert pz.pan == [0, 0] + assert pz.zoom[0] > 1 + assert pz.zoom[1] > 1 + + # Unzoom with keyboard. + c.events.key_press(key=keys.Key('-')) + assert pz.pan == [0, 0] + assert pz.zoom == [1, 1] + + +def test_panzoom_resize(qtbot, canvas_pz, panzoom): + c = canvas_pz + pz = panzoom + + # Increase coverage with different aspect ratio. + c.native.resize(400, 600) + # qtbot.stop() + # c.events.resize(size=(100, 1000)) + assert list(pz._canvas_aspect) == [1., 2. 
/ 3] diff --git a/phy/plot/tests/test_plot.py b/phy/plot/tests/test_plot.py new file mode 100644 index 000000000..d1e57584e --- /dev/null +++ b/phy/plot/tests/test_plot.py @@ -0,0 +1,252 @@ +# -*- coding: utf-8 -*- + +"""Test plotting interface.""" + + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +import numpy as np +from numpy.testing import assert_array_equal as ae +from vispy.util import keys + +from ..panzoom import PanZoom +from ..plot import View +from ..transform import NDC +from ..utils import _get_linear_x + + +#------------------------------------------------------------------------------ +# Utils +#------------------------------------------------------------------------------ + +def _show(qtbot, view, stop=False): + view.build() + view.show() + qtbot.waitForWindowShown(view.native) + if stop: # pragma: no cover + qtbot.stop() + view.close() + + +#------------------------------------------------------------------------------ +# Test plotting interface +#------------------------------------------------------------------------------ + +def test_building(qtbot): + view = View(keys='interactive') + n = 1000 + + x = np.random.randn(n) + y = np.random.randn(n) + + with view.building(): + view.scatter(x, y) + + view.show() + qtbot.waitForWindowShown(view.native) + view.close() + + +def test_simple_view(qtbot): + view = View() + n = 1000 + + x = np.random.randn(n) + y = np.random.randn(n) + + view.scatter(x, y) + _show(qtbot, view) + + +#------------------------------------------------------------------------------ +# Test visuals in grid +#------------------------------------------------------------------------------ + +def test_grid_scatter(qtbot): + view = View(layout='grid', shape=(2, 3)) + n = 100 + + assert isinstance(view.panzoom, PanZoom) + + x = np.random.randn(n) + y = np.random.randn(n) + + view[0, 1].scatter(x, y) + view[0, 2].scatter(x, y, color=np.random.uniform(.5, .8, size=(n, 4))) + + view[1, 0].scatter(x, y, size=np.random.uniform(5, 20, size=n)) + view[1, 1] + + # Multiple scatters in the same subplot. 
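The scatters that follow use two distinct markers, so they end up in two different underlying visuals: `View.scatter()` resolves the marker through `_make_scatter_class`, and `View._items` then batches items per class. A standalone sketch of that create-once class cache, using a stand-in base class instead of `ScatterVisual`:

```python
_MARKER_CLASSES = {}

class FakeScatter(object):
    """Stand-in for ScatterVisual, just to show the caching pattern."""
    _default_marker = 'disc'

def make_marker_class(marker):
    """Return the per-marker subclass, creating it only once."""
    name = 'FakeScatter' + marker.title()
    if name not in _MARKER_CLASSES:
        _MARKER_CLASSES[name] = type(name, (FakeScatter,),
                                     {'_default_marker': marker})
    return _MARKER_CLASSES[name]

assert make_marker_class('heart') is make_marker_class('heart')        # cached
assert make_marker_class('heart') is not make_marker_class('asterisk')
assert make_marker_class('heart')._default_marker == 'heart'
```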
+ view[1, 2].scatter(x[2::6], y[2::6], marker='asterisk', + color=(0, 1, 0, .25), size=20) + view[1, 2].scatter(x[::5], y[::5], marker='heart', + color=(1, 0, 0, .35), size=50) + view[1, 2].scatter(x[1::3], y[1::3], marker='heart', + color=(1, 0, 1, .35), size=30) + + _show(qtbot, view) + + +def test_grid_plot(qtbot): + view = View(layout='grid', shape=(1, 2)) + n_plots, n_samples = 5, 50 + + x = _get_linear_x(n_plots, n_samples) + y = np.random.randn(n_plots, n_samples) + + view[0, 0].plot(x, y) + view[0, 1].plot(x, y, color=np.random.uniform(.5, .8, size=(n_plots, 4))) + + _show(qtbot, view) + + +def test_grid_hist(qtbot): + view = View(layout='grid', shape=(3, 3)) + + hist = np.random.rand(3, 3, 20) + + for i in range(3): + for j in range(3): + view[i, j].hist(hist[i, j, :], + color=np.random.uniform(.5, .8, size=4)) + + _show(qtbot, view) + + +def test_grid_lines(qtbot): + view = View(layout='grid', shape=(1, 2)) + + view[0, 0].lines(pos=[-1, -.5, +1, -.5]) + view[0, 1].lines(pos=[-1, +.5, +1, +.5]) + + _show(qtbot, view) + + +def test_grid_text(qtbot): + view = View(layout='grid', shape=(2, 1)) + + view[0, 0].text(pos=(0, 0), text='Hello world!', anchor=(0., 0.)) + view[1, 0].text(pos=[[-.5, 0], [+.5, 0]], text=['|', ':)']) + + _show(qtbot, view) + + +def test_grid_complete(qtbot): + view = View(layout='grid', shape=(2, 2)) + t = _get_linear_x(1, 1000).ravel() + + view[0, 0].scatter(*np.random.randn(2, 100)) + view[0, 1].plot(t, np.sin(20 * t), color=(1, 0, 0, 1)) + + view[1, 1].hist(np.random.rand(5, 10), + color=np.random.uniform(.4, .9, size=(5, 4))) + + _show(qtbot, view) + + +#------------------------------------------------------------------------------ +# Test other interact +#------------------------------------------------------------------------------ + +def test_stacked_complete(qtbot): + view = View(layout='stacked', n_plots=3) + + t = _get_linear_x(1, 1000).ravel() + view[0].scatter(*np.random.randn(2, 100)) + + # Different types of visuals in the same subplot. + view[1].hist(np.random.rand(5, 10), + color=np.random.uniform(.4, .9, size=(5, 4))) + view[1].plot(t, np.sin(20 * t), color=(1, 0, 0, 1)) + + # TODO + # v = view[2].plot(t[::2], np.sin(20 * t[::2]), color=(1, 0, 0, 1)) + # v.update(color=(0, 1, 0, 1)) + + _show(qtbot, view) + + +def test_boxed_complete(qtbot): + n = 3 + b = np.zeros((n, 4)) + b[:, 0] = b[:, 1] = np.linspace(-1., 1. - 2. / 3., n) + b[:, 2] = b[:, 3] = np.linspace(-1. + 2. 
/ 3., 1., n) + view = View(layout='boxed', box_bounds=b) + + t = _get_linear_x(1, 1000).ravel() + view[0].scatter(*np.random.randn(2, 100)) + view[1].plot(t, np.sin(20 * t), color=(1, 0, 0, 1)) + view[2].hist(np.random.rand(5, 10), + color=np.random.uniform(.4, .9, size=(5, 4))) + + _show(qtbot, view) + + +#------------------------------------------------------------------------------ +# Test lasso +#------------------------------------------------------------------------------ + +def test_lasso_simple(qtbot): + view = View(enable_lasso=True, keys='interactive') + n = 1000 + + x = np.random.randn(n) + y = np.random.randn(n) + + view.scatter(x, y) + + l = view.lasso + ev = view.events + ev.mouse_press(pos=(0, 0), button=1, modifiers=(keys.CONTROL,)) + l.add((+1, -1)) + l.add((+1, +1)) + l.add((-1, +1)) + assert l.count == 4 + assert l.polygon.shape == (4, 2) + b = [[-1, -1], [+1, -1], [+1, +1], [-1, +1]] + ae(l.in_polygon(b), [False, False, True, True]) + + ev.mouse_press(pos=(0, 0), button=2, modifiers=(keys.CONTROL,)) + assert l.count == 0 + + _show(qtbot, view) + + +def test_lasso_grid(qtbot): + view = View(layout='grid', shape=(1, 2), + enable_lasso=True, keys='interactive') + x, y = np.meshgrid(np.linspace(-1., 1., 20), np.linspace(-1., 1., 20)) + x, y = x.ravel(), y.ravel() + view[0, 1].scatter(x, y, data_bounds=NDC) + + l = view.lasso + ev = view.events + + # Square selection in the left panel. + ev.mouse_press(pos=(100, 100), button=1, modifiers=(keys.CONTROL,)) + assert l.box == (0, 0) + ev.mouse_press(pos=(200, 100), button=1, modifiers=(keys.CONTROL,)) + ev.mouse_press(pos=(200, 200), button=1, modifiers=(keys.CONTROL,)) + ev.mouse_press(pos=(100, 200), button=1, modifiers=(keys.CONTROL,)) + assert l.box == (0, 0) + + # Clear. + ev.mouse_press(pos=(100, 200), button=2, modifiers=(keys.CONTROL,)) + assert l.box is None + + # Square selection in the right panel. + ev.mouse_press(pos=(500, 100), button=1, modifiers=(keys.CONTROL,)) + assert l.box == (0, 1) + ev.mouse_press(pos=(700, 100), button=1, modifiers=(keys.CONTROL,)) + ev.mouse_press(pos=(700, 300), button=1, modifiers=(keys.CONTROL,)) + ev.mouse_press(pos=(500, 300), button=1, modifiers=(keys.CONTROL,)) + assert l.box == (0, 1) + + ind = l.in_polygon(np.c_[x, y]) + view[0, 1].scatter(x[ind], y[ind], color=(1., 0., 0., 1.), + data_bounds=NDC) + + _show(qtbot, view) diff --git a/phy/plot/tests/test_traces.py b/phy/plot/tests/test_traces.py deleted file mode 100644 index 5525b1512..000000000 --- a/phy/plot/tests/test_traces.py +++ /dev/null @@ -1,90 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Test CCG plotting.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -from pytest import mark - -import numpy as np - -from ..traces import TraceView, plot_traces -from ...utils._color import _random_color -from ...io.mock import (artificial_traces, - artificial_masks, - artificial_spike_clusters, - ) -from ...utils.testing import show_test - - -# Skip these tests in "make test-quick". 
-pytestmark = mark.long() - - -#------------------------------------------------------------------------------ -# Tests VisPy -#------------------------------------------------------------------------------ - -def _test_traces(n_samples=None): - n_channels = 20 - n_spikes = 50 - n_clusters = 3 - - traces = artificial_traces(n_samples, n_channels) - masks = artificial_masks(n_spikes, n_channels) - spike_clusters = artificial_spike_clusters(n_spikes, n_clusters) - spike_samples = np.linspace(50, n_samples - 50, n_spikes).astype(np.uint64) - - c = TraceView(keys='interactive') - c.visual.traces = traces - c.visual.n_samples_per_spike = 20 - c.visual.spike_samples = spike_samples - c.visual.spike_clusters = spike_clusters - c.visual.cluster_colors = np.array([_random_color() - for _ in range(n_clusters)]) - c.visual.masks = masks - c.visual.sample_rate = 20000. - c.visual.offset = 0 - - show_test(c) - - -def test_traces_empty(): - _test_traces(n_samples=0) - - -def test_traces_full(): - _test_traces(n_samples=2000) - - -def test_plot_traces(): - n_samples = 10000 - n_channels = 20 - n_spikes = 50 - n_clusters = 3 - - traces = artificial_traces(n_samples, n_channels) - masks = artificial_masks(n_spikes, n_channels) - spike_clusters = artificial_spike_clusters(n_spikes, n_clusters) - spike_samples = np.linspace(50, n_samples - 50, n_spikes).astype(np.uint64) - - c = plot_traces(traces, show=False) - show_test(c) - - c = plot_traces(traces, - spike_samples=spike_samples, - masks=masks, show=False) - show_test(c) - - c = plot_traces(traces, - spike_samples=spike_samples, - spike_clusters=spike_clusters, show=False) - show_test(c) - - c = plot_traces(traces, - spike_samples=spike_samples, - spike_clusters=spike_clusters, - show=False) - show_test(c) diff --git a/phy/plot/tests/test_transform.py b/phy/plot/tests/test_transform.py new file mode 100644 index 000000000..67c507e83 --- /dev/null +++ b/phy/plot/tests/test_transform.py @@ -0,0 +1,245 @@ +# -*- coding: utf-8 -*- + +"""Test transform.""" + + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +from textwrap import dedent + +import numpy as np +from numpy.testing import assert_equal as ae +from pytest import yield_fixture + +from ..transform import (_glslify, pixels_to_ndc, + Translate, Scale, Range, Clip, Subplot, + TransformChain, + ) + + +#------------------------------------------------------------------------------ +# Fixtures +#------------------------------------------------------------------------------ + +def _check_forward(transform, array, expected): + transformed = transform.apply(array) + if array is None or not len(array): + assert transformed is None or not len(transformed) + return + array = np.atleast_2d(array) + if isinstance(array, np.ndarray): + assert transformed.shape[1] == array.shape[1] + if not len(transformed): + assert not len(expected) + else: + assert np.allclose(transformed, expected, atol=1e-7) + + +def _check(transform, array, expected): + array = np.array(array, dtype=np.float64) + expected = np.array(expected, dtype=np.float64) + _check_forward(transform, array, expected) + # Test the inverse transform if it is implemented. 
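Among the transforms driven through this helper, `Range` has the least obvious arithmetic: it linearly maps a source box onto a target box (see `test_range_cpu` and the GLSL expression further down). A plain-NumPy sketch of that interpolation, using the same numbers as the tests:

```python
import numpy as np

def range_apply(pos, from_bounds, to_bounds):
    """Linearly map the box `from_bounds` (x0, y0, x1, y1) onto `to_bounds`."""
    pos = np.atleast_2d(pos).astype(np.float64)
    f = np.asarray(from_bounds, dtype=np.float64)
    t = np.asarray(to_bounds, dtype=np.float64)
    return t[:2] + (t[2:] - t[:2]) * (pos - f[:2]) / (f[2:] - f[:2])

# Same numbers as test_range_cpu().
assert np.allclose(range_apply([.5, .5], [0, 0, 1, 1], [-1, -1, 1, 1]), [[0., 0.]])
assert np.allclose(range_apply([1, 1], [0, 0, 1, 1], [-1, -1, 1, 1]), [[1., 1.]])
# The inverse is the same mapping with the two boxes swapped.
assert np.allclose(range_apply([0., 0.], [-1, -1, 1, 1], [0, 0, 1, 1]), [[.5, .5]])
```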
+ inv = transform.inverse() + _check_forward(inv, expected, array) + + +#------------------------------------------------------------------------------ +# Test utils +#------------------------------------------------------------------------------ + +def test_glslify(): + assert _glslify('a') == 'a', 'b' + assert _glslify((1, 2, 3, 4)) == 'vec4(1, 2, 3, 4)' + assert _glslify((1., 2.)) == 'vec2(1.0, 2.0)' + + +def test_pixels_to_ndc(): + assert list(pixels_to_ndc((0, 0), size=(10, 10))) == [-1, 1] + + +#------------------------------------------------------------------------------ +# Test transform +#------------------------------------------------------------------------------ + +def test_types(): + _check(Translate([1, 2]), [], []) + + for ab in [[3, 4], [3., 4.]]: + for arr in [ab, [ab], np.array(ab), np.array([ab]), + np.array([ab, ab, ab])]: + _check(Translate([1, 2]), arr, [[4, 6]]) + + +def test_translate_cpu(): + _check(Translate([1, 2]), [3, 4], [[4, 6]]) + + +def test_scale_cpu(): + _check(Scale([-1, 2]), [3, 4], [[-3, 8]]) + + +def test_range_cpu(): + _check(Range([0, 0, 1, 1], [-1, -1, 1, 1]), [-1, -1], [[-3, -3]]) + _check(Range([0, 0, 1, 1], [-1, -1, 1, 1]), [0, 0], [[-1, -1]]) + _check(Range([0, 0, 1, 1], [-1, -1, 1, 1]), [0.5, 0.5], [[0, 0]]) + _check(Range([0, 0, 1, 1], [-1, -1, 1, 1]), [1, 1], [[1, 1]]) + + _check(Range([0, 0, 1, 1], [-1, -1, 1, 1]), + [[0, .5], [1.5, -.5]], [[-1, 0], [2, -2]]) + + +def test_range_cpu_vectorized(): + arr = np.arange(6).reshape((3, 2)) * 1. + arr_tr = arr / 5. + arr_tr[2, :] /= 10 + + f = np.tile([0, 0, 5, 5], (3, 1)) + f[2, :] *= 10 + + t = np.tile([0, 0, 1, 1], (3, 1)) + + _check(Range(f, t), arr, arr_tr) + + +def test_clip_cpu(): + _check(Clip(), [0, 0], [0, 0]) # Default bounds. + + _check(Clip([0, 1, 2, 3]), [0, 1], [0, 1]) + _check(Clip([0, 1, 2, 3]), [1, 2], [1, 2]) + _check(Clip([0, 1, 2, 3]), [2, 3], [2, 3]) + + _check(Clip([0, 1, 2, 3]), [-1, -1], []) + _check(Clip([0, 1, 2, 3]), [3, 4], []) + _check(Clip([0, 1, 2, 3]), [[-1, 0], [3, 4]], []) + + +def test_subplot_cpu(): + shape = (2, 3) + + _check(Subplot(shape, (0, 0)), [-1, -1], [-1, +0]) + _check(Subplot(shape, (0, 0)), [+0, +0], [-2. / 3., .5]) + + _check(Subplot(shape, (1, 0)), [-1, -1], [-1, -1]) + _check(Subplot(shape, (1, 0)), [+1, +1], [-1. 
/ 3, 0]) + + _check(Subplot(shape, (1, 1)), [0, 1], [0, 0]) + + +#------------------------------------------------------------------------------ +# Test GLSL transforms +#------------------------------------------------------------------------------ + +def test_translate_glsl(): + t = Translate('u_translate').glsl('x') + assert 'x = x + u_translate' in t + + +def test_scale_glsl(): + assert 'x = x * u_scale' in Scale('u_scale').glsl('x') + + +def test_range_glsl(): + + assert Range([-1, -1, 1, 1]).glsl('x') + + expected = ('u_to.xy + (u_to.zw - u_to.xy) * (x - u_from.xy) / ' + '(u_from.zw - u_from.xy)') + r = Range('u_from', 'u_to') + assert expected in r.glsl('x') + + +def test_clip_glsl(): + expected = dedent(""" + if ((x.x < b.x) || + (x.y < b.y) || + (x.x > b.z) || + (x.y > b.w)) { + discard; + } + """).strip() + assert expected in Clip('b').glsl('x') + + +def test_subplot_glsl(): + glsl = Subplot('u_shape', 'a_index').glsl('x') + assert 'x = ' in glsl + + +#------------------------------------------------------------------------------ +# Test transform chain +#------------------------------------------------------------------------------ + +@yield_fixture +def array(): + yield np.array([[-1., 0.], [1., 2.]]) + + +def test_transform_chain_empty(array): + t = TransformChain() + + assert t.cpu_transforms == [] + assert t.gpu_transforms == [] + + ae(t.apply(array), array) + + +def test_transform_chain_one(array): + translate = Translate([1, 2]) + t = TransformChain() + t.add_on_cpu([translate]) + + assert t.cpu_transforms == [translate] + assert t.gpu_transforms == [] + + ae(t.apply(array), [[0, 2], [2, 4]]) + + +def test_transform_chain_two(array): + translate = Translate([1, 2]) + scale = Scale([.5, .5]) + t = TransformChain() + t.add_on_cpu([translate, scale]) + + assert t.cpu_transforms == [translate, scale] + assert t.gpu_transforms == [] + + assert isinstance(t.get('Translate'), Translate) + assert t.get('Unknown') is None + + ae(t.apply(array), [[0, 1], [1, 2]]) + + +def test_transform_chain_complete(array): + t = TransformChain() + t.add_on_cpu([Scale(.5), Scale(2.)]) + t.add_on_cpu(Range([-3, -3, 1, 1])) + t.add_on_gpu(Clip()) + t.add_on_gpu([Subplot('u_shape', 'a_box_index')]) + + assert len(t.cpu_transforms) == 3 + assert len(t.gpu_transforms) == 2 + + ae(t.apply(array), [[0, .5], [1, 1.5]]) + + assert len(t.remove('Scale').cpu_transforms) == len(t.cpu_transforms) - 2 + + +def test_transform_chain_add(): + tc = TransformChain() + tc.add_on_cpu([Scale(.5)]) + + tc_2 = TransformChain() + tc_2.add_on_cpu([Scale(2.)]) + + ae((tc + tc_2).apply([3.]), [[3.]]) + + +def test_transform_chain_inverse(): + tc = TransformChain() + tc.add_on_cpu([Scale(.5), Translate((1, 0)), Scale(2)]) + tci = tc.inverse() + ae(tc.apply([[1., 0.]]), [[3., 0.]]) + ae(tci.apply([[3., 0.]]), [[1., 0.]]) diff --git a/phy/plot/tests/test_utils.py b/phy/plot/tests/test_utils.py index 2268c6f1b..8cef15ab7 100644 --- a/phy/plot/tests/test_utils.py +++ b/phy/plot/tests/test_utils.py @@ -1,80 +1,140 @@ # -*- coding: utf-8 -*- -"""Test utils plotting.""" +"""Test plotting/VisPy utilities.""" + #------------------------------------------------------------------------------ # Imports #------------------------------------------------------------------------------ -from pytest import mark - -from vispy import app - -from ...utils.testing import (show_test_start, - show_test_run, - show_test_stop, - ) -from .._vispy_utils import LassoVisual -from .._panzoom import PanZoom, PanZoomGrid +import os +import os.path as op 
+import numpy as np +from numpy.testing import assert_array_equal as ae +from numpy.testing import assert_allclose as ac +from vispy import config -# Skip these tests in "make test-quick". -pytestmark = mark.long() +from phy.electrode.mea import linear_positions, staggered_positions +from ..utils import (_load_shader, + _tesselate_histogram, + _enable_depth_mask, + _get_data_bounds, + _boxes_overlap, + _binary_search, + _get_boxes, + _get_box_pos_size, + ) #------------------------------------------------------------------------------ -# Tests VisPy +# Test utilities #------------------------------------------------------------------------------ -_N_FRAMES = 2 - - -class TestCanvas(app.Canvas): - _pz = None - - def __init__(self, visual, grid=False, **kwargs): - super(TestCanvas, self).__init__(keys='interactive', **kwargs) - self.visual = visual - self._grid = grid - self._create_pan_zoom() - - def _create_pan_zoom(self): - if self._grid: - self._pz = PanZoomGrid() - self._pz.n_rows = self.visual.n_rows - else: - self._pz = PanZoom() - self._pz.add(self.visual.program) - self._pz.attach(self) - - def on_draw(self, event): - """Draw the main visual.""" - self.context.clear() - self.visual.draw() - - def on_resize(self, event): - """Resize the OpenGL context.""" - self.context.set_viewport(0, 0, event.size[0], event.size[1]) - - -def _show_visual(visual, grid=False, stop=True): - view = TestCanvas(visual, grid=grid) - show_test_start(view) - show_test_run(view, _N_FRAMES) - if stop: - show_test_stop(view) - return view - - -def test_lasso(): - lasso = LassoVisual() - lasso.n_rows = 4 - lasso.box = (1, 3) - lasso.points = [[+.8, +.8], - [-.8, +.8], - [-.8, -.8], - ] - view = _show_visual(lasso, grid=True, stop=False) - view.visual.add([+.8, -.8]) - show_test_run(view, _N_FRAMES) - show_test_stop(view) +def test_load_shader(): + assert 'main()' in _load_shader('simple.vert') + assert config['include_path'] + assert op.exists(config['include_path'][0]) + assert op.isdir(config['include_path'][0]) + assert os.listdir(config['include_path'][0]) + + +def test_tesselate_histogram(): + n = 7 + hist = np.arange(n) + thist = _tesselate_histogram(hist) + assert thist.shape == (6 * n, 2) + ac(thist[0], [0, 0]) + ac(thist[-3], [n, n - 1]) + ac(thist[-1], [n, 0]) + + +def test_enable_depth_mask(qtbot, canvas): + + @canvas.connect + def on_draw(e): + _enable_depth_mask() + + canvas.show() + qtbot.waitForWindowShown(canvas.native) + + +def test_get_data_bounds(): + db0 = np.array([[0, 1, 4, 5], + [0, 1, 4, 5], + [0, 1, 4, 5]]) + arr = np.arange(6).reshape((3, 2)) + assert np.all(_get_data_bounds(None, arr) == [[0, 1, 4, 5]]) + + db = db0.copy() + assert np.all(_get_data_bounds(db, arr) == [[0, 1, 4, 5]]) + + db = db0.copy() + db[2, :] = [1, 1, 1, 1] + assert np.all(_get_data_bounds(db, arr)[:2, :] == [[0, 1, 4, 5]]) + assert np.all(_get_data_bounds(db, arr)[2, :] == [0, 0, 2, 2]) + + db = db0.copy() + db[:2, :] = [1, 1, 1, 1] + assert np.all(_get_data_bounds(db, arr)[:2, :] == [[0, 0, 2, 2]]) + assert np.all(_get_data_bounds(db, arr)[2, :] == [0, 1, 4, 5]) + + +def test_boxes_overlap(): + + def _get_args(boxes): + x0, y0, x1, y1 = np.array(boxes).T + x0 = x0[:, np.newaxis] + x1 = x1[:, np.newaxis] + y0 = y0[:, np.newaxis] + y1 = y1[:, np.newaxis] + return x0, y0, x1, y1 + + boxes = [[-1, -1, 0, 0], [0.01, 0.01, 1, 1]] + x0, y0, x1, y1 = _get_args(boxes) + assert not _boxes_overlap(x0, y0, x1, y1) + + boxes = [[-1, -1, 0.1, 0.1], [0, 0, 1, 1]] + x0, y0, x1, y1 = _get_args(boxes) + assert _boxes_overlap(x0, 
y0, x1, y1) + + +def test_binary_search(): + def f(x): + return x < .4 + ac(_binary_search(f, 0, 1), .4) + ac(_binary_search(f, 0, .3), .3) + ac(_binary_search(f, .5, 1), .5) + + +def test_get_boxes(): + positions = [[-1, 0], [1, 0]] + boxes = _get_boxes(positions) + ac(boxes, [[-1, -.25, 0, .25], + [+0, -.25, 1, .25]], atol=1e-4) + + positions = [[-1, 0], [1, 0]] + boxes = _get_boxes(positions, keep_aspect_ratio=False) + ac(boxes, [[-1, -1, 0, 1], + [0, -1, 1, 1]], atol=1e-4) + + positions = linear_positions(4) + boxes = _get_boxes(positions) + ac(boxes, [[-0.5, -1.0, +0.5, -0.5], + [-0.5, -0.5, +0.5, +0.0], + [-0.5, +0.0, +0.5, +0.5], + [-0.5, +0.5, +0.5, +1.0], + ], atol=1e-4) + + positions = staggered_positions(8) + boxes = _get_boxes(positions) + ac(boxes[:, 1], np.arange(.75, -1.1, -.25), atol=1e-6) + ac(boxes[:, 3], np.arange(1, -.76, -.25), atol=1e-7) + + +def test_get_box_pos_size(): + bounds = [[-1, -.25, 0, .25], + [+0, -.25, 1, .25]] + pos, size = _get_box_pos_size(bounds) + ae(pos, [[-.5, 0], [.5, 0]]) + assert size == (.5, .25) diff --git a/phy/plot/tests/test_visuals.py b/phy/plot/tests/test_visuals.py new file mode 100644 index 000000000..bee2dccb2 --- /dev/null +++ b/phy/plot/tests/test_visuals.py @@ -0,0 +1,237 @@ +# -*- coding: utf-8 -*- + +"""Test visuals.""" + + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +import numpy as np + +from ..transform import NDC +from ..visuals import (ScatterVisual, PlotVisual, HistogramVisual, + LineVisual, PolygonVisual, TextVisual, + ) + + +#------------------------------------------------------------------------------ +# Fixtures +#------------------------------------------------------------------------------ + +def _test_visual(qtbot, c, v, stop=False, **kwargs): + c.add_visual(v) + data = v.validate(**kwargs) + assert v.vertex_count(**data) >= 0 + v.set_data(**kwargs) + c.show() + qtbot.waitForWindowShown(c.native) + if stop: # pragma: no cover + qtbot.stop() + c.close() + + +#------------------------------------------------------------------------------ +# Test scatter visual +#------------------------------------------------------------------------------ + +def test_scatter_empty(qtbot, canvas): + _test_visual(qtbot, canvas, ScatterVisual(), x=np.zeros(0), y=np.zeros(0)) + + +def test_scatter_markers(qtbot, canvas_pz): + c = canvas_pz + + n = 100 + x = .2 * np.random.randn(n) + y = .2 * np.random.randn(n) + + v = ScatterVisual(marker='vbar') + c.add_visual(v) + v.set_data(x=x, y=y) + + c.show() + qtbot.waitForWindowShown(c.native) + + # qtbot.stop() + + +def test_scatter_custom(qtbot, canvas_pz): + + n = 100 + + # Random position. + pos = .2 * np.random.randn(n, 2) + + # Random colors. 
+ c = np.random.uniform(.4, .7, size=(n, 4)) + c[:, -1] = .5 + + # Random sizes + s = 5 + 20 * np.random.rand(n) + + _test_visual(qtbot, canvas_pz, ScatterVisual(), + pos=pos, color=c, size=s) + + +#------------------------------------------------------------------------------ +# Test plot visual +#------------------------------------------------------------------------------ + +def test_plot_empty(qtbot, canvas): + y = np.zeros((1, 0)) + _test_visual(qtbot, canvas, PlotVisual(), + y=y) + + +def test_plot_0(qtbot, canvas_pz): + y = np.zeros((1, 10)) + _test_visual(qtbot, canvas_pz, PlotVisual(), + y=y) + + +def test_plot_1(qtbot, canvas_pz): + y = .2 * np.random.randn(10) + _test_visual(qtbot, canvas_pz, PlotVisual(), + y=y) + + +def test_plot_2(qtbot, canvas_pz): + + n_signals = 50 + n_samples = 10 + y = 20 * np.random.randn(n_signals, n_samples) + + # Signal colors. + c = np.random.uniform(.5, 1, size=(n_signals, 4)) + c[:, 3] = .5 + + # Depth. + depth = np.linspace(0., -1., n_signals) + + _test_visual(qtbot, canvas_pz, PlotVisual(), + y=y, depth=depth, + data_bounds=[-1, -50, 1, 50], + color=c) + + +def test_plot_list(qtbot, canvas_pz): + y = [np.random.randn(i) for i in (5, 20)] + + c = np.random.uniform(.5, 1, size=(2, 4)) + c[:, 3] = .5 + + _test_visual(qtbot, canvas_pz, PlotVisual(), + y=y, color=c) + + +#------------------------------------------------------------------------------ +# Test histogram visual +#------------------------------------------------------------------------------ + +def test_histogram_empty(qtbot, canvas): + hist = np.zeros((1, 0)) + _test_visual(qtbot, canvas, HistogramVisual(), + hist=hist) + + +def test_histogram_0(qtbot, canvas_pz): + hist = np.zeros((10,)) + _test_visual(qtbot, canvas_pz, HistogramVisual(), + hist=hist) + + +def test_histogram_1(qtbot, canvas_pz): + hist = np.random.rand(1, 10) + _test_visual(qtbot, canvas_pz, HistogramVisual(), + hist=hist) + + +def test_histogram_2(qtbot, canvas_pz): + + n_hists = 5 + hist = np.random.rand(n_hists, 21) + + # Histogram colors. 
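As background for these histogram tests: `test_tesselate_histogram` in `test_utils.py` expects each bin to expand into 6 vertices (two triangles per bar). A standalone sketch of one triangulation consistent with those assertions; the exact vertex order phy's `_tesselate_histogram` uses may differ:

```python
import numpy as np

def tesselate_histogram(hist):
    """Expand n bins into 6 vertices per bin (two triangles per bar)."""
    hist = np.asarray(hist, dtype=np.float64)
    n = len(hist)
    out = np.zeros((6 * n, 2))
    for i, h in enumerate(hist):
        x0, x1 = i, i + 1
        # Two triangles covering the rectangle [x0, x1] x [0, h].
        out[6 * i:6 * i + 6] = [(x0, 0), (x0, h), (x1, h),
                                (x0, 0), (x1, h), (x1, 0)]
    return out

thist = tesselate_histogram(np.arange(7))
assert thist.shape == (42, 2)
assert np.allclose(thist[0], [0, 0]) and np.allclose(thist[-1], [7, 0])
```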
+ c = np.random.uniform(.3, .6, size=(n_hists, 4)) + c[:, 3] = 1 + + _test_visual(qtbot, canvas_pz, HistogramVisual(), + hist=hist, color=c, ylim=2 * np.ones(n_hists)) + + +#------------------------------------------------------------------------------ +# Test line visual +#------------------------------------------------------------------------------ + +def test_line_empty(qtbot, canvas): + pos = np.zeros((0, 4)) + _test_visual(qtbot, canvas, LineVisual(), pos=pos) + + +def test_line_0(qtbot, canvas_pz): + n = 10 + y = np.linspace(-.5, .5, 10) + pos = np.c_[-np.ones(n), y, np.ones(n), y] + color = np.random.uniform(.5, .9, (n, 4)) + _test_visual(qtbot, canvas_pz, LineVisual(), + pos=pos, color=color, data_bounds=[-1, -1, 1, 1]) + + +#------------------------------------------------------------------------------ +# Test polygon visual +#------------------------------------------------------------------------------ + +def test_polygon_empty(qtbot, canvas): + pos = np.zeros((0, 2)) + _test_visual(qtbot, canvas, PolygonVisual(), pos=pos) + + +def test_polygon_0(qtbot, canvas_pz): + n = 9 + x = .5 * np.cos(np.linspace(0., 2 * np.pi, n)) + y = .5 * np.sin(np.linspace(0., 2 * np.pi, n)) + pos = np.c_[x, y] + _test_visual(qtbot, canvas_pz, PolygonVisual(), pos=pos) + + +#------------------------------------------------------------------------------ +# Test text visual +#------------------------------------------------------------------------------ + +def test_text_empty(qtbot, canvas): + pos = np.zeros((0, 2)) + _test_visual(qtbot, canvas, TextVisual(), pos=pos, text=[]) + _test_visual(qtbot, canvas, TextVisual()) + + +def test_text_0(qtbot, canvas_pz): + text = '0123456789' + text = [text[:n] for n in range(1, 11)] + + pos = np.c_[np.linspace(-.5, .5, 10), np.linspace(-.5, .5, 10)] + + _test_visual(qtbot, canvas_pz, TextVisual(), + pos=pos, text=text) + + +def test_text_1(qtbot, canvas_pz): + c = canvas_pz + + text = ['--x--'] * 5 + pos = [[0, 0], [-.5, +.5], [+.5, +.5], [-.5, -.5], [+.5, -.5]] + anchor = [[0, 0], [-1, +1], [+1, +1], [-1, -1], [+1, -1]] + + v = TextVisual() + c.add_visual(v) + v.set_data(pos=pos, text=text, anchor=anchor, data_bounds=NDC) + + v = ScatterVisual() + c.add_visual(v) + v.set_data(pos=pos, data_bounds=NDC) + + c.show() + qtbot.waitForWindowShown(c.native) + + # qtbot.stop() + c.close() diff --git a/phy/plot/tests/test_waveforms.py b/phy/plot/tests/test_waveforms.py deleted file mode 100644 index ed2dcf2b9..000000000 --- a/phy/plot/tests/test_waveforms.py +++ /dev/null @@ -1,94 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Test waveform plotting.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -from pytest import mark - -import numpy as np - -from ..waveforms import WaveformView, plot_waveforms -from ...utils._color import _random_color -from ...io.mock import (artificial_waveforms, artificial_masks, - artificial_spike_clusters) -from ...electrode.mea import staggered_positions -from ...utils.testing import show_test - - -# Skip these tests in "make test-quick". 
-pytestmark = mark.long() - - -#------------------------------------------------------------------------------ -# Tests -#------------------------------------------------------------------------------ - - -def _test_waveforms(n_spikes=None, n_clusters=None): - n_channels = 32 - n_samples = 40 - - channel_positions = staggered_positions(n_channels) - - waveforms = artificial_waveforms(n_spikes, n_samples, n_channels) - masks = artificial_masks(n_spikes, n_channels) - spike_clusters = artificial_spike_clusters(n_spikes, n_clusters) - - c = WaveformView(keys='interactive') - c.visual.waveforms = waveforms - # Test depth. - # masks[n_spikes//2:, ...] = 0 - # Test position of masks. - # masks[:, :n_channels // 2] = 0 - # masks[:, n_channels // 2:] = 1 - c.visual.masks = masks - c.visual.spike_clusters = spike_clusters - c.visual.cluster_colors = np.array([_random_color() - for _ in range(n_clusters)]) - c.visual.channel_positions = channel_positions - c.visual.channel_order = np.arange(1, n_channels + 1) - - @c.connect - def on_channel_click(e): - print(e.channel_id, e.key) - - show_test(c) - - -def test_waveforms_empty(): - _test_waveforms(n_spikes=0, n_clusters=0) - - -def test_waveforms_full(): - _test_waveforms(n_spikes=100, n_clusters=3) - - -def test_plot_waveforms(): - n_spikes = 100 - n_clusters = 2 - n_channels = 32 - n_samples = 40 - - channel_positions = staggered_positions(n_channels) - - waveforms = artificial_waveforms(n_spikes, n_samples, n_channels) - masks = artificial_masks(n_spikes, n_channels) - spike_clusters = artificial_spike_clusters(n_spikes, n_clusters) - - c = plot_waveforms(waveforms, show=False) - show_test(c) - - c = plot_waveforms(waveforms, masks=masks, show=False) - show_test(c) - - c = plot_waveforms(waveforms, spike_clusters=spike_clusters, show=False) - show_test(c) - - c = plot_waveforms(waveforms, - spike_clusters=spike_clusters, - channel_positions=channel_positions, - show=False) - show_test(c) diff --git a/phy/plot/traces.py b/phy/plot/traces.py deleted file mode 100644 index 32c2d3d47..000000000 --- a/phy/plot/traces.py +++ /dev/null @@ -1,325 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Plotting traces.""" - - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import numpy as np - -from vispy import gloo - -from ._vispy_utils import (BaseSpikeVisual, - BaseSpikeCanvas, - _wrap_vispy, - ) -from ..utils._color import _selected_clusters_colors -from ..utils._types import _as_array -from ..utils.array import _index_of, _unique - - -#------------------------------------------------------------------------------ -# CCG visual -#------------------------------------------------------------------------------ - -class TraceVisual(BaseSpikeVisual): - """Display multi-channel extracellular traces with spikes. - - The visual displays a small portion of the traces at once. There is an - optional offset. - - """ - - _shader_name = 'traces' - _gl_draw_mode = 'line_strip' - - def __init__(self, **kwargs): - super(TraceVisual, self).__init__(**kwargs) - self._traces = None - self._spike_samples = None - self._n_samples_per_spike = None - self._sample_rate = None - self._offset = None - - self.program['u_scale'] = 1. - - # Data properties - # ------------------------------------------------------------------------- - - @property - def traces(self): - """Displayed traces. - - This is a `(n_samples, n_channels)` array. 
- - """ - return self._traces - - @traces.setter - def traces(self, value): - value = _as_array(value) - assert value.ndim == 2 - self.n_samples, self.n_channels = value.shape - self._traces = value - self._empty = self.n_samples == 0 - self._channel_colors = .5 * np.ones((self.n_channels, 3), - dtype=np.float32) - self.set_to_bake('traces', 'channel_color') - - @property - def channel_colors(self): - """Colors of the displayed channels.""" - return self._channel_colors - - @channel_colors.setter - def channel_colors(self, value): - self._channel_colors = _as_array(value) - assert len(self._channel_colors) == self.n_channels - self.set_to_bake('channel_color') - - @property - def spike_samples(self): - """Time samples of the displayed spikes.""" - return self._spike_samples - - @spike_samples.setter - def spike_samples(self, value): - assert isinstance(value, np.ndarray) - self._set_or_assert_n_spikes(value) - self._spike_samples = value - self.set_to_bake('spikes') - - @property - def n_samples_per_spike(self): - """Number of time samples per displayed spikes.""" - return self._n_samples_per_spike - - @n_samples_per_spike.setter - def n_samples_per_spike(self, value): - self._n_samples_per_spike = int(value) - - @property - def sample_rate(self): - """Sample rate of the recording.""" - return self._sample_rate - - @sample_rate.setter - def sample_rate(self, value): - self._sample_rate = int(value) - - @property - def offset(self): - """Offset of the displayed traces (in time samples).""" - return self._offset - - @offset.setter - def offset(self, value): - self._offset = int(value) - - @property - def channel_scale(self): - """Vertical scaling of the traces.""" - return np.asscalar(self.program['u_scale']) - - @channel_scale.setter - def channel_scale(self, value): - self.program['u_scale'] = value - - # Data baking - # ------------------------------------------------------------------------- - - def _bake_traces(self): - ns, nc = self.n_samples, self.n_channels - - a_index = np.empty((nc * ns, 2), dtype=np.float32) - a_index[:, 0] = np.repeat(np.arange(nc), ns) - a_index[:, 1] = np.tile(np.arange(ns), nc) - - self.program['a_position'] = self._traces.T.ravel().astype(np.float32) - self.program['a_index'] = a_index - self.program['n_channels'] = nc - self.program['n_samples'] = ns - - def _bake_channel_color(self): - u_channel_color = self._channel_colors.reshape((1, - self.n_channels, - -1)) - u_channel_color = (u_channel_color * 255).astype(np.uint8) - self.program['u_channel_color'] = gloo.Texture2D(u_channel_color) - - def _bake_spikes(self): - # Handle the case where there are no spikes. - if self.n_spikes == 0: - a_spike = np.zeros((self.n_channels * self.n_samples, 2), - dtype=np.float32) - a_spike[:, 0] = -1. - self.program['a_spike'] = a_spike - self.program['n_clusters'] = 0 - return - - spike_clusters_idx = self.spike_clusters - spike_clusters_idx = _index_of(spike_clusters_idx, self.cluster_ids) - assert spike_clusters_idx.shape == (self.n_spikes,) - - samples = self._spike_samples - assert samples.shape == (self.n_spikes,) - - # -1 = there's no spike at this vertex - a_clusters = np.empty((self.n_channels, self.n_samples), - dtype=np.float32) - a_clusters.fill(-1.) - a_masks = np.zeros((self.n_channels, self.n_samples), - dtype=np.float32) - masks = self._masks # (n_spikes, n_channels) - - # Add all spikes, one by one. - k = self._n_samples_per_spike // 2 - for i, s in enumerate(samples): - m = masks[i, :] # masks across all channels - channels = (m > 0.) 
- c = spike_clusters_idx[i] # cluster idx - i = max(s - k, 0) - j = min(s + k, self.n_samples) - a_clusters[channels, i:j] = c - a_masks[channels, i:j] = m[channels, None] - - a_spike = np.empty((self.n_channels * self.n_samples, 2), - dtype=np.float32) - a_spike[:, 0] = a_clusters.ravel() - a_spike[:, 1] = a_masks.ravel() - assert a_spike.dtype == np.float32 - self.program['a_spike'] = a_spike - self.program['n_clusters'] = self.n_clusters - - def _bake_spikes_clusters(self): - self._bake_spikes() - - -class TraceView(BaseSpikeCanvas): - """A VisPy canvas displaying traces.""" - _visual_class = TraceVisual - - def _create_pan_zoom(self): - super(TraceView, self)._create_pan_zoom() - self._pz.aspect = None - self._pz.zmin = .5 - self._pz.xmin = -1. - self._pz.xmax = +1. - self._pz.ymin = -2. - self._pz.ymax = +2. - - def set_data(self, - traces=None, - spike_samples=None, - spike_clusters=None, - n_samples_per_spike=50, - masks=None, - colors=None, - ): - if traces is not None: - assert isinstance(traces, np.ndarray) - assert traces.ndim == 2 - else: - traces = self.visual.traces - # Detrend the traces. - traces = traces - traces.mean(axis=0) - s = traces.std() - if s > 0: - traces /= s - n_samples, n_channels = traces.shape - - if spike_samples is not None: - n_spikes = len(spike_samples) - else: - n_spikes = 0 - - if spike_clusters is None: - spike_clusters = np.zeros(n_spikes, dtype=np.int32) - cluster_ids = _unique(spike_clusters) - n_clusters = len(cluster_ids) - - if masks is None: - masks = np.ones((n_spikes, n_channels), dtype=np.float32) - - if colors is None: - colors = _selected_clusters_colors(n_clusters) - - self.visual.traces = traces.astype(np.float32) - - if masks is not None: - self.visual.masks = masks - - if n_samples_per_spike is not None: - self.visual.n_samples_per_spike = n_samples_per_spike - - if spike_samples is not None: - assert spike_samples.shape == (n_spikes,) - self.visual.spike_samples = spike_samples - - if spike_clusters is not None: - assert spike_clusters.shape == (n_spikes,) - self.visual.spike_clusters = spike_clusters - - if len(colors): - self.visual.cluster_colors = colors - - self.update() - - @property - def channel_scale(self): - """Vertical scale of the traces.""" - return self.visual.channel_scale - - @channel_scale.setter - def channel_scale(self, value): - self.visual.channel_scale = value - self.update() - - keyboard_shortcuts = { - 'channel_scale_increase': 'ctrl+', - 'channel_scale_decrease': 'ctrl-', - } - - def on_key_press(self, event): - """Handle key press events.""" - key = event.key - ctrl = 'Control' in event.modifiers - - # Box scale. - if ctrl and key in ('+', '-', '='): - coeff = 1.1 - u = self.channel_scale - if key == '-': - self.channel_scale = u / coeff - elif key == '+' or key == '=': - self.channel_scale = u * coeff - - -#------------------------------------------------------------------------------ -# Plotting functions -#------------------------------------------------------------------------------ - -@_wrap_vispy -def plot_traces(traces, **kwargs): - """Plot traces. - - Parameters - ---------- - - traces : ndarray - The traces to plot. A `(n_samples, n_channels)` array. - spike_samples : ndarray (optional) - A `(n_spikes,)` int array with the spike times in number of samples. - spike_clusters : ndarray (optional) - A `(n_spikes,)` int array with the spike clusters. - masks : ndarray (optional) - A `(n_spikes, n_channels)` float array with the spike masks. 
- n_samples_per_spike : int - Waveform size in number of samples. - - """ - c = TraceView(keys='interactive') - c.set_data(traces, **kwargs) - return c diff --git a/phy/plot/transform.py b/phy/plot/transform.py new file mode 100644 index 000000000..de6f90fec --- /dev/null +++ b/phy/plot/transform.py @@ -0,0 +1,306 @@ +# -*- coding: utf-8 -*- + +"""Transforms.""" + + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +from textwrap import dedent + +import numpy as np +from six import string_types + +import logging + +logger = logging.getLogger(__name__) + + +#------------------------------------------------------------------------------ +# Utils +#------------------------------------------------------------------------------ + +def _wrap_apply(f): + def wrapped(arr, **kwargs): + if arr is None or not len(arr): + return arr + arr = np.atleast_2d(arr) + assert arr.ndim == 2 + assert arr.dtype == np.float64 + out = f(arr, **kwargs) + assert out.dtype == np.float64 + out = np.atleast_2d(out) + assert out.ndim == 2 + assert out.shape[1] == arr.shape[1] + return out + return wrapped + + +def _wrap_glsl(f): + def wrapped(var, **kwargs): + out = f(var, **kwargs) + out = dedent(out).strip() + return out + return wrapped + + +def _glslify(r): + """Transform a string or a n-tuple to a valid GLSL expression.""" + if isinstance(r, string_types): + return r + else: + assert 2 <= len(r) <= 4 + return 'vec{}({})'.format(len(r), ', '.join(map(str, r))) + + +def _minus(value): + if isinstance(value, np.ndarray): + return -value + else: + assert len(value) == 2 + return -value[0], -value[1] + + +def _inverse(value): + if isinstance(value, np.ndarray): + return 1. / value + elif hasattr(value, '__len__'): + assert len(value) == 2 + return 1. / value[0], 1. / value[1] + else: + return 1. / value + + +def subplot_bounds(shape=None, index=None): + i, j = index + n_rows, n_cols = shape + + assert 0 <= i <= n_rows - 1 + assert 0 <= j <= n_cols - 1 + + width = 2.0 / n_cols + height = 2.0 / n_rows + + x = -1.0 + j * width + y = +1.0 - (i + 1) * height + + return [x, y, x + width, y + height] + + +def subplot_bounds_glsl(shape=None, index=None): + x0 = '-1.0 + 2.0 * {i}.y / {s}.y'.format(s=shape, i=index) + y0 = '+1.0 - 2.0 * ({i}.x + 1) / {s}.x'.format(s=shape, i=index) + x1 = '-1.0 + 2.0 * ({i}.y + 1) / {s}.y'.format(s=shape, i=index) + y1 = '+1.0 - 2.0 * ({i}.x) / {s}.x'.format(s=shape, i=index) + return 'vec4({x0}, {y0}, {x1}, {y1})'.format(x0=x0, y0=y0, x1=x1, y1=y1) + + +def pixels_to_ndc(pos, size=None): + """Convert from pixels to normalized device coordinates (in [-1, 1]).""" + pos = np.asarray(pos, dtype=np.float64) + size = np.asarray(size, dtype=np.float64) + pos = pos / (size / 2.) - 1 + # Flip y, because the origin in pixels is at the top left corner of the + # window. 
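+    # For example, with size=(800, 600):
+    #   pixel (0, 0)     (top-left)     -> NDC (-1., +1.)
+    #   pixel (400, 300) (center)       -> NDC ( 0.,  0.)
+    #   pixel (800, 600) (bottom-right) -> NDC (+1., -1.)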
+ pos[1] = -pos[1] + return pos + + +"""Bounds in Normalized Device Coordinates (NDC).""" +NDC = (-1.0, -1.0, +1.0, +1.0) + + +#------------------------------------------------------------------------------ +# Transforms +#------------------------------------------------------------------------------ + +class BaseTransform(object): + def __init__(self, value=None): + self.value = value + self.apply = _wrap_apply(self.apply) + self.glsl = _wrap_glsl(self.glsl) + + def apply(self, arr): + raise NotImplementedError() + + def glsl(self, var): + raise NotImplementedError() + + def inverse(self): + raise NotImplementedError() + + +class Translate(BaseTransform): + def apply(self, arr, value=None): + assert isinstance(arr, np.ndarray) + value = value if value is not None else self.value + return arr + np.asarray(value) + + def glsl(self, var): + assert var + return """{var} = {var} + {translate};""".format(var=var, + translate=self.value) + + def inverse(self): + if isinstance(self.value, string_types): + return Translate('-' + self.value) + else: + return Translate(_minus(self.value)) + + +class Scale(BaseTransform): + def apply(self, arr, value=None): + value = value if value is not None else self.value + return arr * np.asarray(value) + + def glsl(self, var): + assert var + return """{var} = {var} * {scale};""".format(var=var, scale=self.value) + + def inverse(self): + if isinstance(self.value, string_types): + return Scale('1.0 / ' + self.value) + else: + return Scale(_inverse(self.value)) + + +class Range(BaseTransform): + def __init__(self, from_bounds=None, to_bounds=None): + super(Range, self).__init__() + self.from_bounds = from_bounds if from_bounds is not None else NDC + self.to_bounds = to_bounds if to_bounds is not None else NDC + + def apply(self, arr, from_bounds=None, to_bounds=None): + from_bounds = np.asarray(from_bounds if from_bounds is not None + else self.from_bounds) + to_bounds = np.asarray(to_bounds if to_bounds is not None + else self.to_bounds) + f0 = from_bounds[..., :2] + f1 = from_bounds[..., 2:] + t0 = to_bounds[..., :2] + t1 = to_bounds[..., 2:] + + return t0 + (t1 - t0) * (arr - f0) / (f1 - f0) + + def glsl(self, var): + assert var + + from_bounds = _glslify(self.from_bounds) + to_bounds = _glslify(self.to_bounds) + + return ("{var} = {t}.xy + ({t}.zw - {t}.xy) * " + "({var} - {f}.xy) / ({f}.zw - {f}.xy);" + "").format(var=var, f=from_bounds, t=to_bounds) + + def inverse(self): + return Range(from_bounds=self.to_bounds, + to_bounds=self.from_bounds) + + +class Clip(BaseTransform): + def __init__(self, bounds=None): + super(Clip, self).__init__() + self.bounds = bounds or NDC + + def apply(self, arr, bounds=None): + bounds = bounds if bounds is not None else self.bounds + index = ((arr[:, 0] >= bounds[0]) & + (arr[:, 1] >= bounds[1]) & + (arr[:, 0] <= bounds[2]) & + (arr[:, 1] <= bounds[3])) + return arr[index, ...] 
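+        # For example, Clip().apply(np.array([[0., 0.], [2., 0.]])) keeps only
+        # the first row: points outside the NDC square are dropped on the CPU,
+        # whereas the GLSL version below discards them per fragment.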
+ + def glsl(self, var): + assert var + bounds = _glslify(self.bounds) + + return """ + if (({var}.x < {bounds}.x) || + ({var}.y < {bounds}.y) || + ({var}.x > {bounds}.z) || + ({var}.y > {bounds}.w)) {{ + discard; + }} + """.format(bounds=bounds, var=var) + + def inverse(self): + return self + + +class Subplot(Range): + """Assume that the from_bounds is [-1, -1, 1, 1].""" + + def __init__(self, shape, index=None): + super(Subplot, self).__init__() + self.shape = shape + self.index = index + self.from_bounds = NDC + if isinstance(self.shape, tuple) and isinstance(self.index, tuple): + self.to_bounds = subplot_bounds(shape=self.shape, index=self.index) + elif (isinstance(self.shape, string_types) and + isinstance(self.index, string_types)): + self.to_bounds = subplot_bounds_glsl(shape=self.shape, + index=self.index) + + +#------------------------------------------------------------------------------ +# Transform chains +#------------------------------------------------------------------------------ + +class TransformChain(object): + """A linear sequence of transforms that happen on the CPU and GPU.""" + def __init__(self): + self.transformed_var_name = None + self.cpu_transforms = [] + self.gpu_transforms = [] + + def add_on_cpu(self, transforms): + """Add some transforms.""" + if not isinstance(transforms, list): + transforms = [transforms] + self.cpu_transforms.extend(transforms or []) + return self + + def add_on_gpu(self, transforms): + """Add some transforms.""" + if not isinstance(transforms, list): + transforms = [transforms] + self.gpu_transforms.extend(transforms or []) + return self + + def get(self, class_name): + """Get a transform in the chain from its name.""" + for transform in self.cpu_transforms + self.gpu_transforms: + if transform.__class__.__name__ == class_name: + return transform + + def _remove_transform(self, transforms, name): + return [t for t in transforms if t.__class__.__name__ != name] + + def remove(self, name): + """Remove a transform in the chain.""" + cpu_transforms = self._remove_transform(self.cpu_transforms, name) + gpu_transforms = self._remove_transform(self.gpu_transforms, name) + return (TransformChain().add_on_cpu(cpu_transforms). 
+ add_on_gpu(gpu_transforms)) + + def apply(self, arr): + """Apply all CPU transforms on an array.""" + for t in self.cpu_transforms: + arr = t.apply(arr) + return arr + + def inverse(self): + """Return the inverse chain of transforms.""" + transforms = self.cpu_transforms + self.gpu_transforms + inv_transforms = [transform.inverse() + for transform in transforms[::-1]] + return TransformChain().add_on_cpu(inv_transforms) + + def __add__(self, tc): + assert isinstance(tc, TransformChain) + assert tc.transformed_var_name == self.transformed_var_name + self.cpu_transforms.extend(tc.cpu_transforms) + self.gpu_transforms.extend(tc.gpu_transforms) + return self diff --git a/phy/plot/utils.py b/phy/plot/utils.py new file mode 100644 index 000000000..ff5b3e6c4 --- /dev/null +++ b/phy/plot/utils.py @@ -0,0 +1,281 @@ +# -*- coding: utf-8 -*- + +"""Plotting/VisPy utilities.""" + + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +import logging +import os.path as op + +import numpy as np +from vispy import gloo + +from .transform import Range, NDC + +logger = logging.getLogger(__name__) + + +#------------------------------------------------------------------------------ +# Box positioning +#------------------------------------------------------------------------------ + +def _boxes_overlap(x0, y0, x1, y1): + n = len(x0) + overlap_matrix = ((x0 < x1.T) & (x1 > x0.T) & (y0 < y1.T) & (y1 > y0.T)) + overlap_matrix[np.arange(n), np.arange(n)] = False + return np.any(overlap_matrix.ravel()) + + +def _binary_search(f, xmin, xmax, eps=1e-9): + """Return the largest x such f(x) is True.""" + middle = (xmax + xmin) / 2. + while xmax - xmin > eps: + assert xmin < xmax + middle = (xmax + xmin) / 2. + if f(xmax): + return xmax + if not f(xmin): + return xmin + if f(middle): + xmin = middle + else: + xmax = middle + return middle + + +def _get_box_size(x, y, ar=.5, margin=0): + + # Deal with degenerate x case. + xmin, xmax = x.min(), x.max() + if xmin == xmax: + # If all positions are vertical, the width can be maximum. + wmax = 1. + else: + wmax = xmax - xmin + + def f1(w): + """Return true if the configuration with the current box size + is non-overlapping.""" + # NOTE: w|h are the *half* width|height. + h = w * ar # fixed aspect ratio + return not _boxes_overlap(x - w, y - h, x + w, y + h) + + # Find the largest box size leading to non-overlapping boxes. + w = _binary_search(f1, 0, wmax) + w = w * (1 - margin) # margin + # Clip the half-width. + h = w * ar # aspect ratio + + return w, h + + +def _get_boxes(pos, size=None, margin=0, keep_aspect_ratio=True): + """Generate non-overlapping boxes in NDC from a set of positions.""" + + # Get x, y. + pos = np.asarray(pos, dtype=np.float64) + x, y = pos.T + x = x[:, np.newaxis] + y = y[:, np.newaxis] + + w, h = size if size is not None else _get_box_size(x, y) + + x0, y0 = x - w, y - h + x1, y1 = x + w, y + h + + # Renormalize the whole thing by keeping the aspect ratio. + x0min, y0min, x1max, y1max = x0.min(), y0.min(), x1.max(), y1.max() + if not keep_aspect_ratio: + b = (x0min, y0min, x1max, y1max) + else: + dx = x1max - x0min + dy = y1max - y0min + if dx > dy: + b = (x0min, (y1max + y0min) / 2. - dx / 2., + x1max, (y1max + y0min) / 2. + dx / 2.) + else: + b = ((x1max + x0min) / 2. - dy / 2., y0min, + (x1max + x0min) / 2. 
+ dy / 2., y1max) + r = Range(from_bounds=b, + to_bounds=(-1, -1, 1, 1)) + return np.c_[r.apply(np.c_[x0, y0]), r.apply(np.c_[x1, y1])] + + +def _get_box_pos_size(box_bounds): + box_bounds = np.asarray(box_bounds) + x0, y0, x1, y1 = box_bounds.T + w = (x1 - x0) * .5 + h = (y1 - y0) * .5 + x = (x0 + x1) * .5 + y = (y0 + y1) * .5 + return np.c_[x, y], (w.mean(), h.mean()) + + +#------------------------------------------------------------------------------ +# Data validation +#------------------------------------------------------------------------------ + +def _get_texture(arr, default, n_items, from_bounds): + """Prepare data to be uploaded as a texture. + + The from_bounds must be specified. + + """ + if not hasattr(default, '__len__'): # pragma: no cover + default = [default] + n_cols = len(default) + if arr is None: # pragma: no cover + arr = np.tile(default, (n_items, 1)) + assert arr.shape == (n_items, n_cols) + # Convert to 3D texture. + arr = arr[np.newaxis, ...].astype(np.float64) + assert arr.shape == (1, n_items, n_cols) + # NOTE: we need to cast the texture to [0., 1.] (float texture). + # This is easy as soon as we assume that the signal bounds are in + # [-1, 1]. + assert len(from_bounds) == 2 + m, M = map(float, from_bounds) + assert np.all(arr >= m) + assert np.all(arr <= M) + arr = (arr - m) / (M - m) + assert np.all(arr >= 0) + assert np.all(arr <= 1.) + return arr + + +def _get_array(val, shape, default=None): + """Ensure an object is an array with the specified shape.""" + assert val is not None or default is not None + if hasattr(val, '__len__') and len(val) == 0: # pragma: no cover + val = None + out = np.zeros(shape, dtype=np.float64) + # This solves `ValueError: could not broadcast input array from shape (n) + # into shape (n, 1)`. + if val is not None and isinstance(val, np.ndarray): + if val.size == out.size: + val = val.reshape(out.shape) + out[...] = val if val is not None else default + assert out.shape == shape + return out + + +def _check_data_bounds(data_bounds): + assert data_bounds.ndim == 2 + assert data_bounds.shape[1] == 4 + assert np.all(data_bounds[:, 0] < data_bounds[:, 2]) + assert np.all(data_bounds[:, 1] < data_bounds[:, 3]) + + +def _get_data_bounds(data_bounds, pos=None, length=None): + """"Prepare data bounds, possibly using min/max of the data.""" + if data_bounds is None: + if pos is not None and len(pos): + m, M = pos.min(axis=0), pos.max(axis=0) + data_bounds = [m[0], m[1], M[0], M[1]] + else: + data_bounds = NDC + data_bounds = np.atleast_2d(data_bounds) + + ind_x = data_bounds[:, 0] == data_bounds[:, 2] + ind_y = data_bounds[:, 1] == data_bounds[:, 3] + if np.sum(ind_x): + data_bounds[ind_x, 0] -= 1 + data_bounds[ind_x, 2] += 1 + if np.sum(ind_y): + data_bounds[ind_y, 1] -= 1 + data_bounds[ind_y, 3] += 1 + + # Extend the data_bounds if needed. + if length is None: + length = pos.shape[0] if pos is not None else 1 + if data_bounds.shape[0] == 1: + data_bounds = np.tile(data_bounds, (length, 1)) + + # Check the shape of data_bounds. + assert data_bounds.shape == (length, 4) + + _check_data_bounds(data_bounds) + return data_bounds + + +def _get_pos(x, y): + assert x is not None + assert y is not None + + x = np.asarray(x, dtype=np.float64) + y = np.asarray(y, dtype=np.float64) + + # Validate the position. 
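+    # For example, _get_pos([0, 1, 2], (3., 4., 5.)) returns two float64
+    # arrays of shape (3,); scalar or 2D inputs fail the assertions below.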
+ assert x.ndim == y.ndim == 1 + assert x.shape == y.shape + + return x, y + + +def _get_index(n_items, item_size, n): + """Prepare an index attribute for GPU uploading.""" + index = np.arange(n_items) + index = np.repeat(index, item_size) + index = index.astype(np.float64) + assert index.shape == (n,) + return index + + +def _get_linear_x(n_signals, n_samples): + return np.tile(np.linspace(-1., 1., n_samples), (n_signals, 1)) + + +#------------------------------------------------------------------------------ +# Misc +#------------------------------------------------------------------------------ + +def _load_shader(filename): + """Load a shader file.""" + curdir = op.dirname(op.realpath(__file__)) + glsl_path = op.join(curdir, 'glsl') + path = op.join(glsl_path, filename) + with open(path, 'r') as f: + return f.read() + + +def _tesselate_histogram(hist): + """ + + 2/4 3 + ____ + |\ | + | \ | + | \ | + |___\| + + 0 1/5 + + """ + assert hist.ndim == 1 + nsamples = len(hist) + + x0 = np.arange(nsamples) + + x = np.zeros(6 * nsamples) + y = np.zeros(6 * nsamples) + + x[0::2] = np.repeat(x0, 3) + x[1::2] = x[0::2] + 1 + + y[2::6] = y[3::6] = y[4::6] = hist + + return np.c_[x, y] + + +def _enable_depth_mask(): + gloo.set_state(clear_color='black', + depth_test=True, + depth_range=(0., 1.), + # depth_mask='true', + depth_func='lequal', + blend=True, + blend_func=('src_alpha', 'one_minus_src_alpha')) + gloo.set_clear_depth(1.0) diff --git a/phy/plot/visuals.py b/phy/plot/visuals.py new file mode 100644 index 000000000..57dd0f90b --- /dev/null +++ b/phy/plot/visuals.py @@ -0,0 +1,568 @@ +# -*- coding: utf-8 -*- + +"""Common visuals.""" + + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +import gzip +import os.path as op + +import numpy as np +from six import string_types +from vispy.gloo import Texture2D + +from .base import BaseVisual +from .transform import Range, NDC +from .utils import (_tesselate_histogram, + _get_texture, + _get_array, + _get_data_bounds, + _get_pos, + _get_index, + ) +from phy.utils import Bunch + + +#------------------------------------------------------------------------------ +# Utils +#------------------------------------------------------------------------------ + +DEFAULT_COLOR = (0.03, 0.57, 0.98, .75) + + +#------------------------------------------------------------------------------ +# Visuals +#------------------------------------------------------------------------------ + +class ScatterVisual(BaseVisual): + _default_marker_size = 10. + _default_marker = 'disc' + _default_color = DEFAULT_COLOR + _supported_markers = ( + 'arrow', + 'asterisk', + 'chevron', + 'clover', + 'club', + 'cross', + 'diamond', + 'disc', + 'ellipse', + 'hbar', + 'heart', + 'infinity', + 'pin', + 'ring', + 'spade', + 'square', + 'tag', + 'triangle', + 'vbar', + ) + + def __init__(self, marker=None): + super(ScatterVisual, self).__init__() + self.n_points = None + + # Set the marker type. 
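+        # For example, ScatterVisual(marker='square') or marker='vbar';
+        # any name outside _supported_markers fails the assertion below.
+        # The chosen name is substituted for %MARKER in the fragment shader.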
+ self.marker = marker or self._default_marker + assert self.marker in self._supported_markers + + self.set_shader('scatter') + self.fragment_shader = self.fragment_shader.replace('%MARKER', + self.marker) + self.set_primitive_type('points') + self.data_range = Range(NDC) + self.transforms.add_on_cpu(self.data_range) + + @staticmethod + def vertex_count(x=None, y=None, pos=None, **kwargs): + return y.size if y is not None else len(pos) + + @staticmethod + def validate(x=None, + y=None, + pos=None, + color=None, + size=None, + depth=None, + data_bounds=None, + ): + if pos is None: + x, y = _get_pos(x, y) + pos = np.c_[x, y] + pos = np.asarray(pos) + assert pos.ndim == 2 + assert pos.shape[1] == 2 + n = pos.shape[0] + + # Validate the data. + color = _get_array(color, (n, 4), ScatterVisual._default_color) + size = _get_array(size, (n, 1), ScatterVisual._default_marker_size) + depth = _get_array(depth, (n, 1), 0) + data_bounds = _get_data_bounds(data_bounds, pos) + assert data_bounds.shape[0] == n + + return Bunch(pos=pos, color=color, size=size, + depth=depth, data_bounds=data_bounds) + + def set_data(self, *args, **kwargs): + data = self.validate(*args, **kwargs) + self.data_range.from_bounds = data.data_bounds + pos_tr = self.transforms.apply(data.pos) + pos_tr = np.c_[pos_tr, data.depth] + self.program['a_position'] = pos_tr.astype(np.float32) + self.program['a_size'] = data.size.astype(np.float32) + self.program['a_color'] = data.color.astype(np.float32) + + +def _as_list(arr): + if isinstance(arr, np.ndarray): + if arr.ndim == 1: + return [arr] + elif arr.ndim == 2: + return list(arr) + assert isinstance(arr, list) + return arr + + +def _min(arr): + return arr.min() if len(arr) else 0 + + +def _max(arr): + return arr.max() if len(arr) else 1 + + +class PlotVisual(BaseVisual): + _default_color = DEFAULT_COLOR + allow_list = ('x', 'y') + + def __init__(self): + super(PlotVisual, self).__init__() + + self.set_shader('plot') + self.set_primitive_type('line_strip') + + self.data_range = Range(NDC) + self.transforms.add_on_cpu(self.data_range) + + @staticmethod + def validate(x=None, + y=None, + color=None, + depth=None, + data_bounds=None, + ): + + assert y is not None + y = _as_list(y) + + if x is None: + x = [np.linspace(-1., 1., len(_)) for _ in y] + x = _as_list(x) + + # Remove empty elements. 
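+        # At this point x and y are lists of equal length: each (x[i], y[i])
+        # pair is one signal, e.g. y=[np.zeros(5), np.zeros(20)] describes
+        # two signals of 5 and 20 samples, with x generated automatically.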
+ assert len(x) == len(y) + + assert [len(_) for _ in x] == [len(_) for _ in y] + + n_signals = len(x) + + if data_bounds is None: + xmin = [_min(_) for _ in x] + ymin = [_min(_) for _ in y] + xmax = [_max(_) for _ in x] + ymax = [_max(_) for _ in y] + data_bounds = np.c_[xmin, ymin, xmax, ymax] + + color = _get_array(color, (n_signals, 4), PlotVisual._default_color) + assert color.shape == (n_signals, 4) + + depth = _get_array(depth, (n_signals, 1), 0) + assert depth.shape == (n_signals, 1) + + data_bounds = _get_data_bounds(data_bounds, length=n_signals) + data_bounds = data_bounds.astype(np.float64) + assert data_bounds.shape == (n_signals, 4) + + return Bunch(x=x, y=y, + color=color, depth=depth, + data_bounds=data_bounds) + + @staticmethod + def vertex_count(y=None, **kwargs): + """Take the output of validate() as input.""" + return y.size if isinstance(y, np.ndarray) else sum(len(_) for _ in y) + + def set_data(self, *args, **kwargs): + data = self.validate(*args, **kwargs) + + assert isinstance(data.y, list) + n_signals = len(data.y) + n_samples = [len(_) for _ in data.y] + n = sum(n_samples) + x = np.concatenate(data.x) if len(data.x) else np.array([]) + y = np.concatenate(data.y) if len(data.y) else np.array([]) + + # Generate the position array. + pos = np.empty((n, 2), dtype=np.float64) + pos[:, 0] = x.ravel() + pos[:, 1] = y.ravel() + assert pos.shape == (n, 2) + + # Generate the color attribute. + color = data.color + assert color.shape == (n_signals, 4) + color = np.repeat(color, n_samples, axis=0) + assert color.shape == (n, 4) + + # Generate signal index. + signal_index = np.repeat(np.arange(n_signals), n_samples) + signal_index = _get_array(signal_index, (n, 1)) + assert signal_index.shape == (n, 1) + + # Transform the positions. + data_bounds = np.repeat(data.data_bounds, n_samples, axis=0) + self.data_range.from_bounds = data_bounds + pos_tr = self.transforms.apply(pos) + + # Position and depth. + depth = np.repeat(data.depth, n_samples, axis=0) + self.program['a_position'] = np.c_[pos_tr, depth].astype(np.float32) + self.program['a_color'] = color.astype(np.float32) + self.program['a_signal_index'] = signal_index.astype(np.float32) + + +class HistogramVisual(BaseVisual): + _default_color = DEFAULT_COLOR + + def __init__(self): + super(HistogramVisual, self).__init__() + + self.set_shader('histogram') + self.set_primitive_type('triangles') + + self.data_range = Range([0, 0, 1, 1]) + self.transforms.add_on_cpu(self.data_range) + + @staticmethod + def validate(hist=None, + color=None, + ylim=None): + assert hist is not None + hist = np.asarray(hist, np.float64) + if hist.ndim == 1: + hist = hist[None, :] + assert hist.ndim == 2 + n_hists, n_bins = hist.shape + + # Validate the data. + color = _get_array(color, (n_hists, 4), HistogramVisual._default_color) + + # Validate ylim. + if ylim is None: + ylim = hist.max() if hist.size > 0 else 1. 
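+        # ylim sets the top of each histogram's y-range; e.g.
+        # ylim=2 * np.ones(n_hists), as in the tests, gives every histogram
+        # the same fixed vertical scale.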
+ ylim = np.atleast_1d(ylim) + if len(ylim) == 1: + ylim = np.tile(ylim, n_hists) + if ylim.ndim == 1: + ylim = ylim[:, np.newaxis] + assert ylim.shape == (n_hists, 1) + + return Bunch(hist=hist, + ylim=ylim, + color=color, + ) + + @staticmethod + def vertex_count(hist, **kwargs): + hist = np.atleast_2d(hist) + n_hists, n_bins = hist.shape + return 6 * n_hists * n_bins + + def set_data(self, *args, **kwargs): + data = self.validate(*args, **kwargs) + hist = data.hist + + n_hists, n_bins = hist.shape + n = self.vertex_count(hist) + + # NOTE: this must be set *before* `apply_cpu_transforms` such + # that the histogram is correctly normalized. + data_bounds = np.c_[np.zeros((n_hists, 2)), + n_bins * np.ones((n_hists, 1)), + data.ylim] + data_bounds = np.repeat(data_bounds, 6 * n_bins, axis=0) + self.data_range.from_bounds = data_bounds + + # Set the transformed position. + pos = np.vstack(_tesselate_histogram(row) for row in hist) + pos_tr = self.transforms.apply(pos) + assert pos_tr.shape == (n, 2) + self.program['a_position'] = pos_tr.astype(np.float32) + + # Generate the hist index. + hist_index = _get_index(n_hists, n_bins * 6, n) + self.program['a_hist_index'] = hist_index.astype(np.float32) + + # Hist colors. + tex = _get_texture(data.color, self._default_color, n_hists, [0, 1]) + self.program['u_color'] = tex.astype(np.float32) + self.program['n_hists'] = n_hists + + +class TextVisual(BaseVisual): + """Display strings at multiple locations. + + Currently, the color, font family, and font size is not customizable. + + """ + _default_color = (1., 1., 1., 1.) + + def __init__(self, color=None): + super(TextVisual, self).__init__() + self.set_shader('text') + self.set_primitive_type('triangles') + self.data_range = Range(NDC) + self.transforms.add_on_cpu(self.data_range) + + # Load the font. + curdir = op.realpath(op.dirname(__file__)) + font_name = 'SourceCodePro-Regular' + font_size = 32 + # The font texture is gzipped. + fn = '%s-%d.npy.gz' % (font_name, font_size) + with gzip.open(op.join(curdir, 'static', fn), 'rb') as f: + self._tex = np.load(f) + with open(op.join(curdir, 'static', 'chars.txt'), 'r') as f: + self._chars = f.read() + self.color = color if color is not None else self._default_color + assert len(self.color) == 4 + + def _get_glyph_indices(self, s): + return [self._chars.index(char) for char in s] + + @staticmethod + def validate(pos=None, text=None, anchor=None, + data_bounds=None): + + if text is None: + text = [] + if isinstance(text, string_types): + text = [text] + if pos is None: + pos = np.zeros((len(text), 2)) + + assert pos is not None + pos = np.atleast_2d(pos) + assert pos.ndim == 2 + assert pos.shape[1] == 2 + n_text = pos.shape[0] + assert len(text) == n_text + + anchor = anchor if anchor is not None else (0., 0.) + anchor = np.atleast_2d(anchor) + if anchor.shape[0] == 1: + anchor = np.repeat(anchor, n_text, axis=0) + assert anchor.ndim == 2 + assert anchor.shape == (n_text, 2) + + # By default, we assume that the coordinates are in NDC. + if data_bounds is None: + data_bounds = NDC + data_bounds = _get_data_bounds(data_bounds, pos) + assert data_bounds.shape[0] == n_text + data_bounds = data_bounds.astype(np.float64) + assert data_bounds.shape == (n_text, 4) + + return Bunch(pos=pos, text=text, anchor=anchor, + data_bounds=data_bounds) + + @staticmethod + def vertex_count(pos=None, **kwargs): + """Take the output of validate() as input.""" + # Total number of glyphs * 6 (6 vertices per glyph). 
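+        # For example, text=['ab', 'c'] contains 3 glyphs, hence 18 vertices
+        # (two triangles per glyph quad).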
+ return sum(map(len, kwargs['text'])) * 6 + + def set_data(self, *args, **kwargs): + data = self.validate(*args, **kwargs) + pos = data.pos.astype(np.float64) + assert pos.ndim == 2 + assert pos.shape[1] == 2 + assert pos.dtype == np.float64 + + # Concatenate all strings. + text = data.text + lengths = list(map(len, text)) + text = ''.join(text) + a_char_index = self._get_glyph_indices(text) + n_glyphs = len(a_char_index) + + tex = self._tex + glyph_height = tex.shape[0] // 6 + glyph_width = tex.shape[1] // 16 + glyph_size = (glyph_width, glyph_height) + + # Position of all glyphs. + a_position = np.repeat(pos, lengths, axis=0) + if not len(lengths): + a_glyph_index = np.zeros((0,)) + else: + a_glyph_index = np.concatenate([np.arange(n) for n in lengths]) + a_quad_index = np.arange(6) + + a_anchor = data.anchor + + a_position = np.repeat(a_position, 6, axis=0) + a_glyph_index = np.repeat(a_glyph_index, 6) + a_quad_index = np.tile(a_quad_index, n_glyphs) + a_char_index = np.repeat(a_char_index, 6) + + a_anchor = np.repeat(a_anchor, lengths, axis=0) + a_anchor = np.repeat(a_anchor, 6, axis=0) + + a_lengths = np.repeat(lengths, lengths) + a_lengths = np.repeat(a_lengths, 6) + + n_vertices = n_glyphs * 6 + assert a_position.shape == (n_vertices, 2) + assert a_glyph_index.shape == (n_vertices,) + assert a_quad_index.shape == (n_vertices,) + assert a_anchor.shape == (n_vertices, 2) + assert a_lengths.shape == (n_vertices,) + + # Transform the positions. + data_bounds = data.data_bounds + data_bounds = np.repeat(data_bounds, lengths, axis=0) + data_bounds = np.repeat(data_bounds, 6, axis=0) + assert data_bounds.shape == (n_vertices, 4) + self.data_range.from_bounds = data_bounds + pos_tr = self.transforms.apply(a_position) + assert pos_tr.shape == (n_vertices, 2) + + self.program['a_position'] = pos_tr.astype(np.float32) + self.program['a_glyph_index'] = a_glyph_index.astype(np.float32) + self.program['a_quad_index'] = a_quad_index.astype(np.float32) + self.program['a_char_index'] = a_char_index.astype(np.float32) + self.program['a_anchor'] = a_anchor.astype(np.float32) + self.program['a_lengths'] = a_lengths.astype(np.float32) + + self.program['u_glyph_size'] = glyph_size + # TODO: color + + self.program['u_tex'] = Texture2D(tex[::-1, :]) + + +class LineVisual(BaseVisual): + """Lines.""" + _default_color = (.3, .3, .3, 1.) + + def __init__(self, color=None): + super(LineVisual, self).__init__() + self.set_shader('line') + self.set_primitive_type('lines') + self.data_range = Range(NDC) + self.transforms.add_on_cpu(self.data_range) + + @staticmethod + def validate(pos=None, color=None, data_bounds=None): + assert pos is not None + pos = np.atleast_2d(pos) + assert pos.ndim == 2 + n_lines = pos.shape[0] + assert pos.shape[1] == 4 + + # Color. + color = _get_array(color, (n_lines, 4), LineVisual._default_color) + + # By default, we assume that the coordinates are in NDC. 
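+        # Each row of pos is one independent segment (x0, y0, x1, y1), e.g.
+        # pos=[[-1., 0., 1., 0.]] is a single horizontal line across the view.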
+ if data_bounds is None: + data_bounds = NDC + data_bounds = _get_data_bounds(data_bounds, length=n_lines) + data_bounds = data_bounds.astype(np.float64) + assert data_bounds.shape == (n_lines, 4) + + return Bunch(pos=pos, color=color, data_bounds=data_bounds) + + @staticmethod + def vertex_count(pos=None, **kwargs): + """Take the output of validate() as input.""" + return pos.shape[0] * 2 + + def set_data(self, *args, **kwargs): + data = self.validate(*args, **kwargs) + pos = data.pos + assert pos.ndim == 2 + assert pos.shape[1] == 4 + assert pos.dtype == np.float64 + n_lines = pos.shape[0] + n_vertices = 2 * n_lines + pos = pos.reshape((-1, 2)) + + # Transform the positions. + data_bounds = np.repeat(data.data_bounds, 2, axis=0) + self.data_range.from_bounds = data_bounds + pos_tr = self.transforms.apply(pos) + + # Position. + assert pos_tr.shape == (n_vertices, 2) + self.program['a_position'] = pos_tr.astype(np.float32) + + # Color. + color = np.repeat(data.color, 2, axis=0) + self.program['a_color'] = color.astype(np.float32) + + +class PolygonVisual(BaseVisual): + """Polygon.""" + _default_color = (.5, .5, .5, 1.) + + def __init__(self): + super(PolygonVisual, self).__init__() + self.set_shader('polygon') + self.set_primitive_type('line_loop') + self.data_range = Range(NDC) + self.transforms.add_on_cpu(self.data_range) + + @staticmethod + def validate(pos=None, data_bounds=None): + assert pos is not None + pos = np.atleast_2d(pos) + assert pos.ndim == 2 + assert pos.shape[1] == 2 + + # By default, we assume that the coordinates are in NDC. + if data_bounds is None: + data_bounds = NDC + data_bounds = _get_data_bounds(data_bounds) + data_bounds = data_bounds.astype(np.float64) + assert data_bounds.shape == (1, 4) + + return Bunch(pos=pos, data_bounds=data_bounds) + + @staticmethod + def vertex_count(pos=None, **kwargs): + """Take the output of validate() as input.""" + return pos.shape[0] + + def set_data(self, *args, **kwargs): + data = self.validate(*args, **kwargs) + pos = data.pos + assert pos.ndim == 2 + assert pos.shape[1] == 2 + assert pos.dtype == np.float64 + n_vertices = pos.shape[0] + + # Transform the positions. + self.data_range.from_bounds = data.data_bounds + pos_tr = self.transforms.apply(pos) + + # Position. 
+ assert pos_tr.shape == (n_vertices, 2) + self.program['a_position'] = pos_tr.astype(np.float32) + + self.program['u_color'] = self._default_color diff --git a/phy/plot/waveforms.py b/phy/plot/waveforms.py deleted file mode 100644 index 1b937b402..000000000 --- a/phy/plot/waveforms.py +++ /dev/null @@ -1,512 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Plotting waveforms.""" - - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import numpy as np - -from vispy import gloo -from vispy.gloo import Texture2D - -from ._panzoom import PanZoom -from ._vispy_utils import (BaseSpikeVisual, - BaseSpikeCanvas, - _enable_depth_mask, - _wrap_vispy, - ) -from ..utils._types import _as_array -from ..utils._color import _selected_clusters_colors -from ..utils.array import _index_of, _normalize, _unique -from ..electrode.mea import linear_positions - - -#------------------------------------------------------------------------------ -# Waveform visual -#------------------------------------------------------------------------------ - -class WaveformVisual(BaseSpikeVisual): - """Display waveforms with probe geometry.""" - - _shader_name = 'waveforms' - _gl_draw_mode = 'line_strip' - - def __init__(self, **kwargs): - super(WaveformVisual, self).__init__(**kwargs) - - self._waveforms = None - self.n_channels, self.n_samples = None, None - self._channel_order = None - self._channel_positions = None - - self.program['u_data_scale'] = (.05, .05) - self.program['u_channel_scale'] = (1., 1.) - self.program['u_overlap'] = 0. - self.program['u_alpha'] = 0.5 - _enable_depth_mask() - - # Data properties - # ------------------------------------------------------------------------- - - @property - def waveforms(self): - """Displayed waveforms. - - This is a `(n_spikes, n_samples, n_channels)` array. - - """ - return self._waveforms - - @waveforms.setter - def waveforms(self, value): - # WARNING: when setting new data, waveforms need to be set first. - # n_spikes will be set as a function of waveforms. - value = _as_array(value) - # TODO: support sparse structures - assert value.ndim == 3 - self.n_spikes, self.n_samples, self.n_channels = value.shape - self._waveforms = value - self._empty = self.n_spikes == 0 - self.set_to_bake('spikes', 'spikes_clusters', 'color') - - @property - def channel_positions(self): - """Positions of the channels. - - This is a `(n_channels, 2)` array. - - """ - return self._channel_positions - - @channel_positions.setter - def channel_positions(self, value): - value = _as_array(value) - self._channel_positions = value - self.set_to_bake('channel_positions') - - @property - def channel_order(self): - return self._channel_order - - @channel_order.setter - def channel_order(self, value): - self._channel_order = value - - @property - def alpha(self): - """Alpha transparency (between 0 and 1).""" - return self.program['u_alpha'] - - @alpha.setter - def alpha(self, value): - self.program['u_alpha'] = value - - @property - def box_scale(self): - """Scale of the waveforms. - - This is a pair of scalars. - - """ - return tuple(self.program['u_data_scale']) - - @box_scale.setter - def box_scale(self, value): - assert len(value) == 2 - self.program['u_data_scale'] = value - - @property - def probe_scale(self): - """Scale of the probe. - - This is a pair of scalars. 
- - """ - return tuple(self.program['u_channel_scale']) - - @probe_scale.setter - def probe_scale(self, value): - assert len(value) == 2 - self.program['u_channel_scale'] = value - - @property - def overlap(self): - """Whether to overlap waveforms.""" - return True if self.program['u_overlap'][0] > .5 else False - - @overlap.setter - def overlap(self, value): - assert value in (True, False) - self.program['u_overlap'] = 1. if value else 0. - - def channel_hover(self, position): - """Return the channel id closest to the mouse pointer. - - Parameters - ---------- - - position : tuple - The normalized coordinates of the mouse pointer, in world - coordinates (in `[-1, 1]`). - - """ - mouse_pos = position / self.probe_scale - # Normalize channel positions. - positions = self.channel_positions.astype(np.float32) - positions = _normalize(positions, keep_ratio=True) - positions = .1 + .8 * positions - positions = 2 * positions - 1 - # Find closest channel. - d = np.sum((positions - mouse_pos[None, :]) ** 2, axis=1) - idx = np.argmin(d) - # if self.channel_order is not None: - # channel_id = self.channel_order[idx] - # WARNING: by convention this is the relative channel index. - return idx - - # Data baking - # ------------------------------------------------------------------------- - - def _bake_channel_positions(self): - # WARNING: channel_positions must be in [0,1] because we have a - # texture. - positions = self.channel_positions.astype(np.float32) - positions = _normalize(positions, keep_ratio=True) - positions = positions.reshape((1, self.n_channels, -1)) - # Rescale a bit and recenter. - positions = .1 + .8 * positions - u_channel_pos = np.dstack((positions, - np.zeros((1, self.n_channels, 1)))) - u_channel_pos = (u_channel_pos * 255).astype(np.uint8) - # TODO: more efficient to update the data from an existing texture - self.program['u_channel_pos'] = Texture2D(u_channel_pos, - wrapping='clamp_to_edge') - - def _bake_spikes(self): - - # Bake masks. - # WARNING: swap channel/time axes in the waveforms array. - waveforms = np.swapaxes(self._waveforms, 1, 2) - assert waveforms.shape == (self.n_spikes, - self.n_channels, - self.n_samples, - ) - masks = np.repeat(self._masks.ravel(), self.n_samples) - data = np.c_[waveforms.ravel(), masks.ravel()].astype(np.float32) - # TODO: more efficient to update the data from an existing VBO - self.program['a_data'] = data - - # TODO: SparseCSR, this should just be 'channel' - self._channels_per_spike = np.tile(np.arange(self.n_channels). - astype(np.float32), - self.n_spikes) - - # TODO: SparseCSR, this should be np.diff(spikes_ptr) - self._n_channels_per_spike = self.n_channels * np.ones(self.n_spikes, - dtype=np.int32) - - self._n_waveforms = np.sum(self._n_channels_per_spike) - - # TODO: precompute this with a maximum number of waveforms? - a_time = np.tile(np.linspace(-1., 1., self.n_samples), - self._n_waveforms).astype(np.float32) - - self.program['a_time'] = a_time - self.program['n_channels'] = self.n_channels - - def _bake_spikes_clusters(self): - # WARNING: needs to be called *after* _bake_spikes(). - if not hasattr(self, '_n_channels_per_spike'): - raise RuntimeError("'_bake_spikes()' needs to be called before " - "'bake_spikes_clusters().") - # Get the spike cluster indices (between 0 and n_clusters-1). - spike_clusters_idx = self.spike_clusters - # We take the cluster order into account here. - spike_clusters_idx = _index_of(spike_clusters_idx, self.cluster_order) - # Generate the box attribute. 
- assert len(spike_clusters_idx) == len(self._n_channels_per_spike) - a_cluster = np.repeat(spike_clusters_idx, - self._n_channels_per_spike * self.n_samples) - a_channel = np.repeat(self._channels_per_spike, self.n_samples) - a_box = np.c_[a_cluster, a_channel].astype(np.float32) - # TODO: more efficient to update the data from an existing VBO - self.program['a_box'] = a_box - self.program['n_clusters'] = self.n_clusters - - -class WaveformView(BaseSpikeCanvas): - """A VisPy canvas displaying waveforms.""" - _visual_class = WaveformVisual - _arrows = ('Left', 'Right', 'Up', 'Down') - _pm = ('+', '-', '=') - _events = ('channel_click',) - _key_pressed = None - _show_mean = False - - def _create_visuals(self): - super(WaveformView, self)._create_visuals() - self.mean = WaveformVisual() - self.mean.alpha = 1. - - def _create_pan_zoom(self): - self._pz = PanZoom() - self._pz.add(self.visual.program) - self._pz.add(self.mean.program) - self._pz.attach(self) - - def set_data(self, - waveforms=None, - masks=None, - spike_clusters=None, - channel_positions=None, - channel_order=None, - colors=None, - **kwargs - ): - - if waveforms is not None: - assert isinstance(waveforms, np.ndarray) - if waveforms.ndim == 2: - waveforms = waveforms[None, ...] - assert waveforms.ndim == 3 - else: - waveforms = self.visual.waveforms - n_spikes, n_samples, n_channels = waveforms.shape - - if spike_clusters is None: - spike_clusters = np.zeros(n_spikes, dtype=np.int32) - spike_clusters = _as_array(spike_clusters) - cluster_ids = _unique(spike_clusters) - n_clusters = len(cluster_ids) - - if masks is None: - masks = np.ones((n_spikes, n_channels), dtype=np.float32) - - if colors is None: - colors = _selected_clusters_colors(n_clusters) - - if channel_order is None: - channel_order = self.visual.channel_order - if channel_order is None: - channel_order = np.arange(n_channels) - - if channel_positions is None: - channel_positions = self.visual.channel_positions - if channel_positions is None: - channel_positions = linear_positions(n_channels) - - self.visual.waveforms = waveforms.astype(np.float32) - - if masks is not None: - self.visual.masks = masks - - self.visual.spike_clusters = spike_clusters - assert spike_clusters.shape == (n_spikes,) - - if len(colors): - self.visual.cluster_colors = colors - - self.visual.channel_positions = channel_positions - self.visual.channel_order = channel_order - - # Extra parameters. - for k, v in kwargs.items(): - setattr(self, k, v) - - self.update() - - @property - def box_scale(self): - """Scale of the waveforms. - - This is a pair of scalars. - - """ - return self.visual.box_scale - - @box_scale.setter - def box_scale(self, value): - self.visual.box_scale = value - self.mean.box_scale = value - self.update() - - @property - def probe_scale(self): - """Scale of the probe. - - This is a pair of scalars. 
- - """ - return self.visual.probe_scale - - @probe_scale.setter - def probe_scale(self, value): - self.visual.probe_scale = value - self.mean.probe_scale = value - self.update() - - @property - def alpha(self): - """Opacity.""" - return self.visual.alpha - - @alpha.setter - def alpha(self, value): - self.visual.alpha = value - self.mean.alpha = value - self.update() - - @property - def overlap(self): - """Whether to overlap waveforms.""" - return self.visual.overlap - - @overlap.setter - def overlap(self, value): - self.visual.overlap = value - self.mean.overlap = value - self.update() - - @property - def show_mean(self): - """Whether to show_mean waveforms.""" - return self._show_mean - - @show_mean.setter - def show_mean(self, value): - self._show_mean = value - self.update() - - keyboard_shortcuts = { - 'waveform_scale_increase': ('ctrl+', - 'ctrl+up', - 'shift+wheel up', - ), - 'waveform_scale_decrease': ('ctrl-', - 'ctrl+down', - 'shift+wheel down', - ), - 'waveform_width_increase': ('ctrl+right', 'ctrl+wheel up'), - 'waveform_width_decrease': ('ctrl+left', 'ctrl+wheel down'), - 'probe_width_increase': ('shift+right', 'ctrl+alt+wheel up'), - 'probe_width_decrease': ('shift+left', 'ctrl+alt+wheel down'), - 'probe_height_increase': ('shift+up', 'shift+alt+wheel up'), - 'probe_height_decrease': ('shift+down', 'shift+alt+wheel down'), - 'select_channel': ('ctrl+left click', 'ctrl+right click', 'num+click'), - } - - def on_key_press(self, event): - """Handle key press events.""" - key = event.key - - self._key_pressed = key - - ctrl = 'Control' in event.modifiers - shift = 'Shift' in event.modifiers - - # Box scale. - if ctrl and key in self._arrows + self._pm: - coeff = 1.1 - u, v = self.box_scale - if key == 'Left': - self.box_scale = (u / coeff, v) - elif key == 'Right': - self.box_scale = (u * coeff, v) - elif key in ('Down', '-'): - self.box_scale = (u, v / coeff) - elif key in ('Up', '+', '='): - self.box_scale = (u, v * coeff) - - # Probe scale. - if shift and key in self._arrows: - coeff = 1.1 - u, v = self.probe_scale - if key == 'Left': - self.probe_scale = (u / coeff, v) - elif key == 'Right': - self.probe_scale = (u * coeff, v) - elif key == 'Down': - self.probe_scale = (u, v / coeff) - elif key == 'Up': - self.probe_scale = (u, v * coeff) - - def on_key_release(self, event): - self._key_pressed = None - - def on_mouse_wheel(self, event): - """Handle mouse wheel events.""" - ctrl = 'Control' in event.modifiers - shift = 'Shift' in event.modifiers - alt = 'Alt' in event.modifiers - coeff = 1. + .1 * event.delta[1] - - # Box scale. - if ctrl and not alt: - u, v = self.box_scale - self.box_scale = (u * coeff, v) - if shift and not alt: - u, v = self.box_scale - self.box_scale = (u, v * coeff) - - # Probe scale. - if ctrl and alt: - u, v = self.probe_scale - self.probe_scale = (u * coeff, v) - if shift and alt: - u, v = self.probe_scale - self.probe_scale = (u, v * coeff) - - def on_mouse_press(self, e): - key = self._key_pressed - if 'Control' in e.modifiers or key in map(str, range(10)): - box_idx = int(key.name) if key in map(str, range(10)) else None - # Normalise mouse position. - position = self._pz._normalize(e.pos) - position[1] = -position[1] - zoom = self._pz._zoom_aspect() - pan = self._pz.pan - mouse_pos = ((position / zoom) - pan) - # Find the channel id. 
- channel_idx = self.visual.channel_hover(mouse_pos) - self.emit("channel_click", - channel_idx=channel_idx, - ax='x' if e.button == 1 else 'y', - box_idx=box_idx, - ) - - def on_draw(self, event): - """Draw the visual.""" - gloo.clear(color=True, depth=True) - if self._show_mean: - self.mean.draw() - else: - self.visual.draw() - - -#------------------------------------------------------------------------------ -# Plotting functions -#------------------------------------------------------------------------------ - -@_wrap_vispy -def plot_waveforms(waveforms, **kwargs): - """Plot waveforms. - - Parameters - ---------- - - waveforms : ndarray - The waveforms to plot. A `(n_spikes, n_samples, n_channels)` array. - spike_clusters : ndarray (optional) - A `(n_spikes,)` int array with the spike clusters. - masks : ndarray (optional) - A `(n_spikes, n_channels)` float array with the spike masks. - channel_positions : ndarray - A `(n_channels, 2)` array with the channel positions. - - """ - c = WaveformView(keys='interactive') - c.set_data(waveforms, **kwargs) - return c diff --git a/phy/scripts/__init__.py b/phy/scripts/__init__.py deleted file mode 100644 index 1e350cdd4..000000000 --- a/phy/scripts/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -# -*- coding: utf-8 -*- -# flake8: noqa - -"""Main CLI tool.""" - - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -from .phy_script import main diff --git a/phy/scripts/phy_script.py b/phy/scripts/phy_script.py deleted file mode 100644 index 2cc6d7c0c..000000000 --- a/phy/scripts/phy_script.py +++ /dev/null @@ -1,479 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import print_function - -"""phy main CLI tool. 
- -Usage: - - phy --help - -""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import sys -import os.path as op -import argparse -from textwrap import dedent - -import numpy as np -from six import exec_, string_types - - -#------------------------------------------------------------------------------ -# Parser utilities -#------------------------------------------------------------------------------ - -class CustomFormatter(argparse.ArgumentDefaultsHelpFormatter, - argparse.RawDescriptionHelpFormatter): - pass - - -class Parser(argparse.ArgumentParser): - def error(self, message): - sys.stderr.write(message + '\n\n') - self.print_help() - sys.exit(2) - - -_examples = dedent(""" - -examples: - phy -v display the version of phy - phy download hybrid_120sec.dat -o data/ - download a sample raw data file in `data/` - phy describe my_file.kwik - display information about a Kwik dataset - phy spikesort my_params.prm - run the whole suite (spike detection and clustering) - phy detect my_params.prm - run spike detection on a parameters file - phy cluster-auto my_file.kwik - run klustakwik on a dataset (after spike detection) - phy cluster-manual my_file.kwik - run the manual clustering GUI - -""") - - -#------------------------------------------------------------------------------ -# Parser creator -#------------------------------------------------------------------------------ - -class ParserCreator(object): - def __init__(self): - self.create_main() - self.create_download() - self.create_traces() - self.create_describe() - self.create_spikesort() - self.create_detect() - self.create_auto() - self.create_manual() - self.create_notebook() - - @property - def parser(self): - return self._parser - - def _add_sub_parser(self, name, desc): - p = self._subparsers.add_parser(name, help=desc, description=desc) - self._add_options(p) - return p - - def _add_options(self, parser): - parser.add_argument('--debug', '-d', - action='store_true', - help='activate debug logging mode') - - parser.add_argument('--hide-traceback', - action='store_true', - help='hide the traceback for cleaner error ' - 'messages') - - parser.add_argument('--profiler', '-p', - action='store_true', - help='activate the profiler') - - parser.add_argument('--line-profiler', '-lp', - dest='line_profiler', - action='store_true', - help='activate the line-profiler -- you ' - 'need to decorate the functions ' - 'to profile with `@profile` ' - 'in the code') - - parser.add_argument('--ipython', '-i', action='store_true', - help='launch the script in an interactive ' - 'IPython console') - - parser.add_argument('--pdb', action='store_true', - help='activate the Python debugger') - - def create_main(self): - import phy - - desc = sys.modules['phy'].__doc__ - self._parser = Parser(description=desc, - epilog=_examples, - formatter_class=CustomFormatter, - ) - self._parser.add_argument('--version', '-v', - action='version', - version=phy.__version_git__, - help='print the version of phy') - self._add_options(self._parser) - self._subparsers = self._parser.add_subparsers(dest='command', - title='subcommand', - ) - - def create_download(self): - desc = 'download a sample dataset' - p = self._add_sub_parser('download', desc) - p.add_argument('file', help='dataset filename') - p.add_argument('--output-dir', '-o', help='output directory') - p.add_argument('--base', - default='cortexlab', - choices=('cortexlab', 'github'), 
- help='data repository name: `cortexlab` or `github`', - ) - p.set_defaults(func=download) - - def create_describe(self): - desc = 'describe a `.kwik` file' - p = self._add_sub_parser('describe', desc) - p.add_argument('file', help='path to a `.kwik` file') - p.add_argument('--clustering', default='main', - help='name of the clustering to use') - p.set_defaults(func=describe) - - def create_traces(self): - desc = 'show the traces of a raw data file' - p = self._add_sub_parser('traces', desc) - p.add_argument('file', help='path to a `.kwd` or `.dat` file') - p.add_argument('--interval', - help='detection interval in seconds (e.g. `0,10`)') - p.add_argument('--n-channels', '-n', - help='number of channels in the recording ' - '(only required when using a flat binary file)') - p.add_argument('--dtype', - help='NumPy data type ' - '(only required when using a flat binary file)', - default='int16', - ) - p.add_argument('--sample-rate', '-s', - help='sample rate in Hz ' - '(only required when using a flat binary file)') - p.set_defaults(func=traces) - - def create_spikesort(self): - desc = 'launch the whole spike sorting pipeline on a `.prm` file' - p = self._add_sub_parser('spikesort', desc) - p.add_argument('file', help='path to a `.prm` file') - p.add_argument('--kwik-path', help='filename of the `.kwik` file ' - 'to create (by default, `"experiment_name".kwik`)') - p.add_argument('--overwrite', action='store_true', default=False, - help='overwrite the `.kwik` file ') - p.add_argument('--interval', - help='detection interval in seconds (e.g. `0,10`)') - p.set_defaults(func=spikesort) - - def create_detect(self): - desc = 'launch the spike detection algorithm on a `.prm` file' - p = self._add_sub_parser('detect', desc) - p.add_argument('file', help='path to a `.prm` file') - p.add_argument('--kwik-path', help='filename of the `.kwik` file ' - 'to create (by default, `"experiment_name".kwik`)') - p.add_argument('--overwrite', action='store_true', default=False, - help='overwrite the `.kwik` file ') - p.add_argument('--interval', - help='detection interval in seconds (e.g. 
`0,10`)') - p.set_defaults(func=detect) - - def create_auto(self): - desc = 'launch the automatic clustering algorithm on a `.kwik` file' - p = self._add_sub_parser('cluster-auto', desc) - p.add_argument('file', help='path to a `.kwik` file') - p.add_argument('--clustering', default='main', - help='name of the clustering to use') - p.add_argument('--channel-group', default=None, - help='channel group to cluster') - p.set_defaults(func=cluster_auto) - - def create_manual(self): - desc = 'launch the manual clustering GUI on a `.kwik` file' - p = self._add_sub_parser('cluster-manual', desc) - p.add_argument('file', help='path to a `.kwik` file') - p.add_argument('--clustering', default='main', - help='name of the clustering to use') - p.add_argument('--channel-group', default=None, - help='channel group to manually cluster') - p.add_argument('--cluster-ids', '-c', - help='list of clusters to select initially') - p.add_argument('--no-store', action='store_true', default=False, - help='do not create the store (faster loading time, ' - 'slower GUI)') - p.set_defaults(func=cluster_manual) - - def create_notebook(self): - # TODO - pass - - def parse(self, args): - try: - return self._parser.parse_args(args) - except SystemExit as e: - if e.code != 0: - raise e - - -#------------------------------------------------------------------------------ -# Subcommand functions -#------------------------------------------------------------------------------ - -def _get_kwik_path(args): - kwik_path = args.file - - if not op.exists(kwik_path): - raise IOError("The file `{}` doesn't exist.".format(kwik_path)) - - return kwik_path - - -def _create_session(args, **kwargs): - from phy.session import Session - kwik_path = _get_kwik_path(args) - session = Session(kwik_path, **kwargs) - return session - - -def describe(args): - from phy.io.kwik import KwikModel - path = _get_kwik_path(args) - model = KwikModel(path, clustering=args.clustering) - return 'model.describe()', dict(model=model) - - -def download(args): - from phy import download_sample_data - download_sample_data(args.file, - output_dir=args.output_dir, - base=args.base, - ) - - -def traces(args): - from vispy.app import run - from phy.plot.traces import TraceView - from phy.io.h5 import open_h5 - from phy.io.traces import read_kwd, read_dat - - path = args.file - if path.endswith('.kwd'): - f = open_h5(args.file) - traces = read_kwd(f) - elif path.endswith(('.dat', '.bin')): - if not args.n_channels: - raise ValueError("Please specify `--n-channels`.") - if not args.dtype: - raise ValueError("Please specify `--dtype`.") - if not args.sample_rate: - raise ValueError("Please specify `--sample-rate`.") - n_channels = int(args.n_channels) - dtype = np.dtype(args.dtype) - traces = read_dat(path, dtype=dtype, n_channels=n_channels) - - start, end = map(int, args.interval.split(',')) - sample_rate = float(args.sample_rate) - start = int(sample_rate * start) - end = int(sample_rate * end) - - c = TraceView(keys='interactive') - c.visual.traces = .01 * traces[start:end, ...] - c.show() - run() - - return None, None - - -def detect(args): - from phy.io import create_kwik - - assert args.file.endswith('.prm') - kwik_path = args.kwik_path - kwik_path = create_kwik(args.file, - overwrite=args.overwrite, - kwik_path=kwik_path) - - interval = args.interval - if interval is not None: - interval = list(map(float, interval.split(','))) - - # Create the session with the newly-created .kwik file. 
- args.file = kwik_path - session = _create_session(args, use_store=False) - return ('session.detect(interval=interval)', - dict(session=session, interval=interval)) - - -def cluster_auto(args): - from phy.utils._misc import _read_python - from phy.session import Session - - assert args.file.endswith('.prm') - - channel_group = (int(args.channel_group) - if args.channel_group is not None else None) - - params = _read_python(args.file) - kwik_path = params['experiment_name'] + '.kwik' - session = Session(kwik_path) - - ns = dict(session=session, - clustering=args.clustering, - channel_group=channel_group, - ) - cmd = ('session.cluster(' - 'clustering=clustering, ' - 'channel_group=channel_group)') - return (cmd, ns) - - -def spikesort(args): - from phy.io import create_kwik - - assert args.file.endswith('.prm') - kwik_path = args.kwik_path - kwik_path = create_kwik(args.file, - overwrite=args.overwrite, - kwik_path=kwik_path, - ) - # Create the session with the newly-created .kwik file. - args.file = kwik_path - session = _create_session(args, use_store=False) - - interval = args.interval - if interval is not None: - interval = list(map(float, interval.split(','))) - - ns = dict(session=session, - interval=interval, - n_s_clusters=100, # TODO: better handling of KK parameters - ) - cmd = ('session.detect(interval=interval); session.cluster();') - return (cmd, ns) - - -def cluster_manual(args): - channel_group = (int(args.channel_group) - if args.channel_group is not None else None) - session = _create_session(args, - clustering=args.clustering, - channel_group=channel_group, - use_store=not(args.no_store), - ) - cluster_ids = (list(map(int, args.cluster_ids.split(','))) - if args.cluster_ids else None) - - session.model.describe() - - from phy.gui import start_qt_app - start_qt_app() - - gui = session.show_gui(cluster_ids=cluster_ids, show=False) - print("\nPress `ctrl+h` to see the list of keyboard shortcuts.\n") - return 'gui.show()', dict(session=session, gui=gui, requires_qt=True) - - -#------------------------------------------------------------------------------ -# Main functions -#------------------------------------------------------------------------------ - -def main(args=None): - p = ParserCreator() - if args is None: - args = sys.argv[1:] - elif isinstance(args, string_types): - args = args.split(' ') - args = p.parse(args) - if args is None: - return - - if args.profiler or args.line_profiler: - from phy.utils.testing import _enable_profiler, _profile - prof = _enable_profiler(args.line_profiler) - else: - prof = None - - import phy - if args.debug: - phy.debug() - - # Hide the traceback. - if args.hide_traceback: - def exception_handler(exception_type, exception, traceback): - print("{}: {}".format(exception_type.__name__, exception)) - - sys.excepthook = exception_handler - - # Activate IPython debugger. - if args.pdb: - from IPython.core import ultratb - sys.excepthook = ultratb.FormattedTB(mode='Verbose', - color_scheme='Linux', - call_pdb=1, - ) - - func = getattr(args, 'func', None) - if func is None: - p.parser.print_help() - return - - out = func(args) - if not out: - return - cmd, ns = out - if not cmd: - return - requires_qt = ns.pop('requires_qt', False) - requires_vispy = ns.pop('requires_vispy', False) - - # Default variables in namespace. - ns.update(phy=phy, path=args.file) - if 'session' in ns: - ns['model'] = ns['session'].model - - # Interactive mode with IPython. 
- if args.ipython: - print("\nStarting IPython...") - from IPython import start_ipython - args_ipy = ["-i", "-c='{}'".format(cmd)] - if requires_qt or requires_vispy: - # Activate Qt event loop integration with Qt. - args_ipy += ["--gui=qt"] - start_ipython(args_ipy, user_ns=ns) - else: - if not prof: - exec_(cmd, {}, ns) - else: - _profile(prof, cmd, {}, ns) - - if requires_qt: - # Launch the Qt app. - from phy.gui import run_qt_app - run_qt_app() - elif requires_vispy: - # Launch the VisPy Qt app. - from vispy.app import use_app, run - use_app('pyqt4') - run() - - -#------------------------------------------------------------------------------ -# Entry point -#------------------------------------------------------------------------------ - -if __name__ == '__main__': - main() diff --git a/phy/scripts/tests/__init__.py b/phy/scripts/tests/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/phy/scripts/tests/test_phy_script.py b/phy/scripts/tests/test_phy_script.py deleted file mode 100644 index 763e6ec05..000000000 --- a/phy/scripts/tests/test_phy_script.py +++ /dev/null @@ -1,42 +0,0 @@ -# -*- coding: utf-8 -*-1 - -"""Tests of the script.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -from ..phy_script import ParserCreator, main - - -#------------------------------------------------------------------------------ -# Script tests -#------------------------------------------------------------------------------ - -def test_parse_version(): - p = ParserCreator() - p.parse(['--version']) - - -def test_parse_cluster_manual(): - p = ParserCreator() - args = p.parse(['cluster-manual', 'test', '-i', '--debug']) - assert args.command == 'cluster-manual' - assert args.ipython - assert args.debug - assert not args.profiler - assert not args.line_profiler - - -def test_parse_cluster_auto(): - p = ParserCreator() - args = p.parse(['cluster-auto', 'test', '-lp']) - assert args.command == 'cluster-auto' - assert not args.ipython - assert not args.debug - assert not args.profiler - assert args.line_profiler - - -def test_download(chdir_tempdir): - main('download hybrid_10sec.prm') diff --git a/phy/session/__init__.py b/phy/session/__init__.py deleted file mode 100644 index 741969821..000000000 --- a/phy/session/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -# -*- coding: utf-8 -*- -# flake8: noqa - -"""Interactive session around a model.""" - - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -from .session import Session diff --git a/phy/session/default_settings.py b/phy/session/default_settings.py deleted file mode 100644 index 7e6ddc6cb..000000000 --- a/phy/session/default_settings.py +++ /dev/null @@ -1,35 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Default settings for the session.""" - - -# ----------------------------------------------------------------------------- -# Session settings -# ----------------------------------------------------------------------------- - -def on_open(session): - """You can update the session when a model is opened. - - For example, you can register custom statistics with - `session.register_statistic`. 
- - """ - pass - - -def on_gui_open(session, gui): - """You can customize a GUI when it is open.""" - pass - - -def on_view_open(gui, view): - """You can customize a view when it is open.""" - pass - - -# ----------------------------------------------------------------------------- -# Misc settings -# ----------------------------------------------------------------------------- - -# Logging level in the log file. -log_file_level = 'debug' diff --git a/phy/session/session.py b/phy/session/session.py deleted file mode 100644 index f3131e474..000000000 --- a/phy/session/session.py +++ /dev/null @@ -1,468 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import print_function - -"""Session structure.""" - - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import os.path as op -import shutil - -import numpy as np - -from ..utils.array import _concatenate -from ..utils.logging import debug, info, FileLogger, unregister, register -from ..utils.settings import _ensure_dir_exists -from ..io.base import BaseSession -from ..io.kwik.model import KwikModel -from ..io.kwik.store_items import create_store -# HACK: avoid Qt import -try: - from ..cluster.manual.gui import ClusterManualGUI -except ImportError: - class ClusterManualGUI(object): - _vm_classes = {} -from ..cluster.algorithms.klustakwik import KlustaKwik -from ..detect.spikedetekt import SpikeDetekt - - -#------------------------------------------------------------------------------ -# Session class -#------------------------------------------------------------------------------ - -def _process_ups(ups): - """This function processes the UpdateInfo instances of the two - undo stacks (clustering and cluster metadata) and concatenates them - into a single UpdateInfo instance.""" - if len(ups) == 0: - return - elif len(ups) == 1: - return ups[0] - elif len(ups) == 2: - up = ups[0] - up.update(ups[1]) - return up - else: - raise NotImplementedError() - - -class Session(BaseSession): - """A manual clustering session. - - This is the main object used for manual clustering. It implements - all common actions: - - * Loading a dataset (`.kwik` file) - * Listing the clusters - * Changing the current channel group or current clustering - * Showing views (waveforms, features, correlograms, etc.) 
- * Clustering actions: merge, split, undo, redo - * Wizard: cluster quality, best clusters, most similar clusters - * Save back to .kwik - - """ - - _vm_classes = ClusterManualGUI._vm_classes - _gui_classes = {'cluster_manual': ClusterManualGUI} - - def __init__(self, - kwik_path=None, - clustering=None, - channel_group=None, - model=None, - use_store=True, - phy_user_dir=None, - waveform_filter=True, - ): - self._clustering = clustering - self._channel_group = channel_group - self._use_store = use_store - self._file_logger = None - self._waveform_filter = waveform_filter - if kwik_path: - kwik_path = op.realpath(kwik_path) - super(Session, self).__init__(model=model, - path=kwik_path, - phy_user_dir=phy_user_dir, - vm_classes=self._vm_classes, - gui_classes=self._gui_classes, - ) - - def _backup_kwik(self, kwik_path): - """Save a copy of the Kwik file before opening it.""" - if kwik_path is None: - return - backup_kwik_path = kwik_path + '.bak' - if not op.exists(backup_kwik_path): - info("Saving a backup of the Kwik file " - "in {0}.".format(backup_kwik_path)) - shutil.copyfile(kwik_path, backup_kwik_path) - - def _create_model(self, path): - model = KwikModel(path, - clustering=self._clustering, - channel_group=self._channel_group, - waveform_filter=self._waveform_filter, - ) - self._create_logger(path) - return model - - def _create_logger(self, path): - path = op.splitext(path)[0] + '.log' - level = self.settings['log_file_level'] - if not self._file_logger: - self._file_logger = FileLogger(filename=path, level=level) - register(self._file_logger) - - def _save_model(self): - """Save the spike clusters and cluster groups to the Kwik file.""" - groups = {cluster: self.model.cluster_metadata.group(cluster) - for cluster in self.cluster_ids} - self.model.save(self.model.spike_clusters, - groups, - clustering_metadata=self.model.clustering_metadata, - ) - info("Saved {0:s}.".format(self.model.kwik_path)) - - # File-related actions - # ------------------------------------------------------------------------- - - def open(self, kwik_path=None, model=None): - """Open a `.kwik` file.""" - self._backup_kwik(kwik_path) - return super(Session, self).open(model=model, path=kwik_path) - - @property - def kwik_path(self): - """Path to the `.kwik` file.""" - return self.model.path - - @property - def has_unsaved_changes(self): - """Whether there are unsaved changes in the model. - - If true, a prompt message for saving will be displayed when closing - the GUI. - - """ - # TODO - pass - - # Properties - # ------------------------------------------------------------------------- - - @property - def n_spikes(self): - """Number of spikes in the current channel group.""" - return self.model.n_spikes - - @property - def cluster_ids(self): - """Array of all cluster ids used in the current clustering.""" - return self.model.cluster_ids - - @property - def n_clusters(self): - """Number of clusters in the current clustering.""" - return self.model.n_clusters - - # Event callbacks - # ------------------------------------------------------------------------- - - def _create_cluster_store(self): - # Do not create the store if there is only one cluster. - if self.model.n_clusters <= 1 or not self._use_store: - # Just use a mock store. - self.store = create_store(self.model, - self.model.spikes_per_cluster, - ) - return - - # Kwik store in experiment_dir/name.phy/1/main/cluster_store. 
- store_path = op.join(self.settings.exp_settings_dir, - 'cluster_store', - str(self.model.channel_group), - self.model.clustering - ) - _ensure_dir_exists(store_path) - - # Instantiate the store. - spc = self.model.spikes_per_cluster - cs = self.settings['features_masks_chunk_size'] - wns = self.settings['waveforms_n_spikes_max'] - wes = self.settings['waveforms_excerpt_size'] - self.store = create_store(self.model, - path=store_path, - spikes_per_cluster=spc, - features_masks_chunk_size=cs, - waveforms_n_spikes_max=wns, - waveforms_excerpt_size=wes, - ) - - # Generate the cluster store if it doesn't exist or is invalid. - # If the cluster store already exists and is consistent - # with the data, it is not recreated. - self.store.generate() - - def change_channel_group(self, channel_group): - """Change the current channel group.""" - self._channel_group = channel_group - self.model.channel_group = channel_group - info("Switched to channel group {}.".format(channel_group)) - self.emit('open') - - def change_clustering(self, clustering): - """Change the current clustering.""" - self._clustering = clustering - self.model.clustering = clustering - info("Switched to `{}` clustering.".format(clustering)) - self.emit('open') - - def on_open(self): - self._create_cluster_store() - - def on_close(self): - if self._file_logger: - unregister(self._file_logger) - self._file_logger = None - - def register_statistic(self, func=None, shape=(-1,)): - """Decorator registering a custom cluster statistic. - - Parameters - ---------- - - func : function - A function that takes a cluster index as argument, and returns - some statistics (generally a NumPy array). - - Notes - ----- - - This function will be called on every cluster when a dataset is opened. - It is also automatically called on new clusters when clusters change. - You can access the data from the model and from the cluster store. - - """ - if func is not None: - return self.register_statistic()(func) - - def decorator(func): - - name = func.__name__ - - def _wrapper(cluster): - out = func(cluster) - self.store.memory_store.store(cluster, **{name: out}) - - # Add the statistics. - stats = self.store.items['statistics'] - stats.add(name, _wrapper, shape) - # Register it in the global cluster store. - self.store.register_field(name, 'statistics') - # Compute it on all existing clusters. - stats.store_all(name=name, mode='force') - info("Registered statistic `{}`.".format(name)) - - return decorator - - # Spike sorting - # ------------------------------------------------------------------------- - - def detect(self, traces=None, - interval=None, - algorithm='spikedetekt', - **kwargs): - """Detect spikes in traces. - - Parameters - ---------- - - traces : array - An `(n_samples, n_channels)` array. If unspecified, the Kwik - file's raw data is used. - interval : tuple (optional) - A tuple `(start, end)` (in seconds) where to detect spikes. - algorithm : str - The algorithm name. Only `spikedetekt` currently. - **kwargs : dictionary - Algorithm parameters. - - Returns - ------- - - result : dict - A `{channel_group: tuple}` mapping, where the tuple is: - - * `spike_times` : the spike times (in seconds). - * `masks`: the masks of the spikes `(n_spikes, n_channels)`. - - """ - assert algorithm == 'spikedetekt' - # Create `.phy/spikedetekt/` directory for temporary files. - sd_dir = op.join(self.settings.exp_settings_dir, 'spikedetekt') - _ensure_dir_exists(sd_dir) - # Default interval. 
- if interval is not None: - (start_sec, end_sec) = interval - sr = self.model.sample_rate - interval_samples = (int(start_sec * sr), - int(end_sec * sr)) - else: - interval_samples = None - # Find the raw traces. - traces = traces if traces is not None else self.model.traces - # Take the parameters in the Kwik file, coming from the PRM file. - params = self.model.metadata - params.update(kwargs) - # Probe parameters required by SpikeDetekt. - params['probe_channels'] = self.model.probe.channels_per_group - params['probe_adjacency_list'] = self.model.probe.adjacency - # Start the spike detection. - debug("Running SpikeDetekt with the following parameters: " - "{}.".format(params)) - sd = SpikeDetekt(tempdir=sd_dir, **params) - out = sd.run_serial(traces, interval_samples=interval_samples) - n_features = params['n_features_per_channel'] - - # Add the spikes in the `.kwik` and `.kwx` files. - for group in out.groups: - spike_samples = _concatenate(out.spike_samples[group]) - n_spikes = len(spike_samples) if spike_samples is not None else 0 - n_channels = sd._n_channels_per_group[group] - self.model.creator.add_spikes(group=group, - spike_samples=spike_samples, - spike_recordings=None, # TODO - masks=out.masks[group], - features=out.features[group], - n_channels=n_channels, - n_features=n_features, - ) - sc = np.zeros(n_spikes, dtype=np.int32) - self.model.creator.add_clustering(group=group, - name='main', - spike_clusters=sc) - self.emit('open') - - if out.groups: - self.change_channel_group(out.groups[0]) - - def cluster(self, - clustering=None, - algorithm='klustakwik', - spike_ids=None, - channel_group=None, - **kwargs): - """Run an automatic clustering algorithm on all or some of the spikes. - - Parameters - ---------- - - clustering : str - The name of the clustering in which to save the results. - algorithm : str - The algorithm name. Only `klustakwik` currently. - spike_ids : array-like - Array of spikes to cluster. - - Returns - ------- - - spike_clusters : array - The spike_clusters assignements returned by the algorithm. - - """ - if clustering is None: - clustering = 'main' - if channel_group is not None: - self.change_channel_group(channel_group) - - kk2_dir = op.join(self.settings.exp_settings_dir, 'klustakwik2') - _ensure_dir_exists(kk2_dir) - - # Take KK2's default parameters. - from klustakwik2.default_parameters import default_parameters - params = default_parameters.copy() - # Update the PRM ones, by filtering them. - params.update({k: v for k, v in self.model.metadata.items() - if k in default_parameters}) - # Update the ones passed to the function. - params.update(kwargs) - - # Original spike_clusters array. - if self.model.spike_clusters is None: - n_spikes = (len(spike_ids) if spike_ids is not None - else self.model.n_spikes) - spike_clusters_orig = np.zeros(n_spikes, dtype=np.int32) - else: - spike_clusters_orig = self.model.spike_clusters.copy() - - # HACK: there needs to be one clustering. - if 'empty' not in self.model.clusterings: - self.model.add_clustering('empty', spike_clusters_orig) - - # Instantiate the KlustaKwik instance. - kk = KlustaKwik(**params) - - # Save the current clustering in the Kwik file. - @kk.connect - def on_iter(sc): - # Update the original spike clusters. - spike_clusters = spike_clusters_orig.copy() - spike_clusters[spike_ids] = sc - # Save to a text file. - path = op.join(kk2_dir, 'spike_clusters.txt') - # Backup. 
- if op.exists(path): - shutil.copy(path, path + '~') - np.savetxt(path, spike_clusters, fmt='%d') - - info("Running {}...".format(algorithm)) - # Run KK. - sc = kk.cluster(model=self.model, spike_ids=spike_ids) - info("The automatic clustering process has finished.") - - # Save the results in the Kwik file. - spike_clusters = spike_clusters_orig.copy() - spike_clusters[spike_ids] = sc - - # Add a new clustering and switch to it. - if clustering in self.model.clusterings: - self.change_clustering('empty') - self.model.delete_clustering(clustering) - self.model.add_clustering(clustering, spike_clusters) - - # Copy the main clustering to original (only if this is the very - # first run of the clustering algorithm). - if clustering == 'main': - self.model.copy_clustering('main', 'original') - self.change_clustering(clustering) - - # Set the new clustering metadata. - params = kk.params - params['version'] = kk.version - metadata = {'{}_{}'.format(algorithm, name): value - for name, value in params.items()} - self.model.clustering_metadata.update(metadata) - self.save() - info("The clustering has been saved in the " - "`{}` clustering in the `.kwik` file.".format(clustering)) - self.model.delete_clustering('empty') - return sc - - # GUI - # ------------------------------------------------------------------------- - - def show_gui(self, **kwargs): - """Show a GUI.""" - gui = super(Session, self).show_gui(store=self.store, - **kwargs) - - @gui.connect - def on_request_save(): - self.save() - - return gui diff --git a/phy/session/tests/__init__.py b/phy/session/tests/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/phy/session/tests/test_session.py b/phy/session/tests/test_session.py deleted file mode 100644 index db44052ad..000000000 --- a/phy/session/tests/test_session.py +++ /dev/null @@ -1,401 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Tests of session structure.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import os -import os.path as op - -import numpy as np -from numpy.testing import assert_allclose as ac -from numpy.testing import assert_equal as ae -from pytest import raises, yield_fixture, mark - -from ..session import Session -from ...utils.array import _spikes_in_clusters -from ...utils.logging import set_level -from ...io.mock import MockModel, artificial_traces -from ...io.kwik.mock import create_mock_kwik -from ...io.kwik.creator import create_kwik - - -# Skip these tests in "make test-quick". -pytestmark = mark.long() - - -#------------------------------------------------------------------------------ -# Fixtures -#------------------------------------------------------------------------------ - -def setup(): - set_level('debug') - - -def teardown(): - set_level('info') - - -n_clusters = 5 -n_spikes = 50 -n_channels = 28 -n_fets = 2 -n_samples_traces = 3000 - - -def _start_manual_clustering(kwik_path=None, - model=None, - tempdir=None, - chunk_size=None, - ): - session = Session(phy_user_dir=tempdir) - session.open(kwik_path=kwik_path, model=model) - session.settings['waveforms_scale_factor'] = 1. - session.settings['features_scale_factor'] = 1. - session.settings['traces_scale_factor'] = 1. 
- session.settings['prompt_save_on_exit'] = False - if chunk_size is not None: - session.settings['features_masks_chunk_size'] = chunk_size - return session - - -@yield_fixture -def session(tempdir): - # Create the test HDF5 file in the temporary directory. - kwik_path = create_mock_kwik(tempdir, - n_clusters=n_clusters, - n_spikes=n_spikes, - n_channels=n_channels, - n_features_per_channel=n_fets, - n_samples_traces=n_samples_traces) - - session = _start_manual_clustering(kwik_path=kwik_path, - tempdir=tempdir) - session.tempdir = tempdir - - yield session - - session.close() - - -#------------------------------------------------------------------------------ -# Tests -#------------------------------------------------------------------------------ - -def test_store_corruption(tempdir): - # Create the test HDF5 file in the temporary directory. - kwik_path = create_mock_kwik(tempdir, - n_clusters=n_clusters, - n_spikes=n_spikes, - n_channels=n_channels, - n_features_per_channel=n_fets, - n_samples_traces=n_samples_traces) - - session = Session(kwik_path, phy_user_dir=tempdir) - store_path = session.store.path - session.close() - - # Corrupt a file in the store. - fn = op.join(store_path, 'waveforms_spikes') - with open(fn, 'rb') as f: - contents = f.read() - with open(fn, 'wb') as f: - f.write(contents[1:-1]) - - session = Session(kwik_path, phy_user_dir=tempdir) - session.close() - - -def test_session_one_cluster(tempdir): - session = Session(phy_user_dir=tempdir) - # The disk store is not created if there is only one cluster. - session.open(model=MockModel(n_clusters=1)) - assert session.store.disk_store is None - - -def test_session_store_features(tempdir): - """Check that the cluster store works for features and masks.""" - - model = MockModel(n_spikes=50, n_clusters=3) - s0 = np.nonzero(model.spike_clusters == 0)[0] - s1 = np.nonzero(model.spike_clusters == 1)[0] - - session = _start_manual_clustering(model=model, - tempdir=tempdir, - chunk_size=4, - ) - - f = session.store.features(0) - m = session.store.masks(1) - w = session.store.waveforms(1) - - assert f.shape == (len(s0), 28, 2) - assert m.shape == (len(s1), 28,) - assert w.shape == (len(s1), model.n_samples_waveforms, 28,) - - ac(f, model.features[s0].reshape((f.shape[0], -1, 2)), 1e-3) - ac(m, model.masks[s1], 1e-3) - - -def test_session_gui_clustering(qtbot, session): - - cs = session.store - spike_clusters = session.model.spike_clusters.copy() - - f = session.model.features - m = session.model.masks - - def _check_arrays(cluster, clusters_for_sc=None, spikes=None): - """Check the features and masks in the cluster store - of a given custer.""" - if spikes is None: - if clusters_for_sc is None: - clusters_for_sc = [cluster] - spikes = _spikes_in_clusters(spike_clusters, clusters_for_sc) - shape = (len(spikes), - len(session.model.channel_order), - session.model.n_features_per_channel) - ac(cs.features(cluster), f[spikes, :].reshape(shape)) - ac(cs.masks(cluster), m[spikes]) - - _check_arrays(0) - _check_arrays(2) - - gui = session.show_gui() - qtbot.addWidget(gui.main_window) - - # Merge two clusters. - clusters = [0, 2] - gui.merge(clusters) # Create cluster 5. - _check_arrays(5, clusters) - - # Split some spikes. - spikes = [2, 3, 5, 7, 11, 13] - # clusters = np.unique(spike_clusters[spikes]) - gui.split(spikes) # Create cluster 6 and more. - _check_arrays(6, spikes=spikes) - - # Undo. - gui.undo() - _check_arrays(5, clusters) - - # Undo. - gui.undo() - _check_arrays(0) - _check_arrays(2) - - # Redo. 
- gui.redo() - _check_arrays(5, clusters) - - # Split some spikes. - spikes = [5, 7, 11, 13, 17, 19] - # clusters = np.unique(spike_clusters[spikes]) - gui.split(spikes) # Create cluster 6 and more. - _check_arrays(6, spikes=spikes) - - # Test merge-undo-different-merge combo. - spc = gui.clustering.spikes_per_cluster.copy() - clusters = gui.cluster_ids[:3] - up = gui.merge(clusters) - _check_arrays(up.added[0], spikes=up.spike_ids) - # Undo. - gui.undo() - for cluster in clusters: - _check_arrays(cluster, spikes=spc[cluster]) - # Another merge. - clusters = gui.cluster_ids[1:5] - up = gui.merge(clusters) - _check_arrays(up.added[0], spikes=up.spike_ids) - - # Move a cluster to a group. - cluster = gui.cluster_ids[0] - gui.move([cluster], 2) - assert len(gui.store.mean_probe_position(cluster)) == 2 - - # Save. - spike_clusters_new = gui.model.spike_clusters.copy() - # Check that the spike clusters have changed. - assert not np.all(spike_clusters_new == spike_clusters) - ac(session.model.spike_clusters, gui.clustering.spike_clusters) - session.save() - - # Re-open the file and check that the spike clusters and - # cluster groups have correctly been saved. - session = _start_manual_clustering(kwik_path=session.model.path, - tempdir=session.tempdir) - ac(session.model.spike_clusters, gui.clustering.spike_clusters) - ac(session.model.spike_clusters, spike_clusters_new) - #  Check the cluster groups. - clusters = gui.clustering.cluster_ids - groups = session.model.cluster_groups - assert groups[cluster] == 2 - - gui.close() - - -def test_session_gui_multiple_clusterings(qtbot, session): - - gui = session.show_gui() - qtbot.addWidget(gui.main_window) - - assert session.model.n_spikes == n_spikes - assert session.model.n_clusters == n_clusters - assert len(session.model.cluster_ids) == n_clusters - assert gui.clustering.n_clusters == n_clusters - assert session.model.cluster_metadata.group(1) == 1 - - # Change clustering. - with raises(ValueError): - session.change_clustering('automat') - session.change_clustering('original') - - n_clusters_2 = session.model.n_clusters - assert session.model.n_spikes == n_spikes - assert session.model.n_clusters == n_clusters_2 - assert len(session.model.cluster_ids) == n_clusters_2 - assert gui.clustering.n_clusters == n_clusters_2 - assert session.model.cluster_metadata.group(2) == 2 - - # Merge the clusters and save, for the current clustering. - gui.clustering.merge(gui.clustering.cluster_ids) - session.save() - - # Re-open the session. - session = _start_manual_clustering(kwik_path=session.model.path, - tempdir=session.tempdir) - - # The default clustering is the main one: nothing should have - # changed here. - assert session.model.n_clusters == n_clusters - - session.change_clustering('original') - assert session.model.n_spikes == n_spikes - assert session.model.n_clusters == 1 - assert session.model.cluster_ids == n_clusters_2 - - gui.close() - - -def test_session_kwik(session): - - # Check backup. - assert op.exists(op.join(session.tempdir, session.kwik_path + '.bak')) - - cs = session.store - nc = n_channels - 2 - - # Check the stored items. 
- for cluster in range(n_clusters): - n_spikes = len(session.model.spikes_per_cluster[cluster]) - n_unmasked_channels = cs.n_unmasked_channels(cluster) - - assert cs.features(cluster).shape == (n_spikes, nc, n_fets) - assert cs.masks(cluster).shape == (n_spikes, nc) - assert cs.mean_masks(cluster).shape == (nc,) - assert n_unmasked_channels <= nc - assert cs.mean_probe_position(cluster).shape == (2,) - assert cs.main_channels(cluster).shape == (n_unmasked_channels,) - - -def test_session_gui_statistics(qtbot, session): - """Test registration of new statistic.""" - - gui = session.show_gui() - qtbot.addWidget(gui.main_window) - - @session.register_statistic - def n_spikes_2(cluster): - return gui.clustering.cluster_counts.get(cluster, 0) ** 2 - - store = gui.store - stats = store.items['statistics'] - - def _check(): - for clu in gui.cluster_ids: - assert (store.n_spikes_2(clu) == - store.features(clu).shape[0] ** 2) - - assert 'n_spikes_2' in stats.fields - _check() - - # Merge the clusters and check that the statistics has been - # recomputed for the new cluster. - clusters = gui.cluster_ids - gui.merge(clusters) - _check() - assert gui.cluster_ids == [max(clusters) + 1] - - gui.undo() - _check() - - gui.merge(gui.cluster_ids[::2]) - _check() - - gui.close() - - -@mark.parametrize('spike_ids', [None, np.arange(20)]) -def test_session_auto(session, spike_ids): - set_level('info') - sc = session.cluster(num_starting_clusters=10, - spike_ids=spike_ids, - clustering='test', - ) - assert session.model.clustering == 'test' - assert session.model.clusterings == ['main', - 'original', 'test'] - - # Re-open the dataset and check that the clustering has been saved. - session = _start_manual_clustering(kwik_path=session.model.path, - tempdir=session.tempdir) - if spike_ids is None: - spike_ids = slice(None, None, None) - assert len(session.model.spike_clusters) == n_spikes - assert not np.all(session.model.spike_clusters[spike_ids] == sc) - - assert 'klustakwik_version' not in session.model.clustering_metadata - kk_dir = op.join(session.settings.exp_settings_dir, 'klustakwik2') - - # Check temporary files with latest clustering. - files = os.listdir(kk_dir) - assert 'spike_clusters.txt' in files - if len(files) >= 2: - assert 'spike_clusters.txt~' in files - sc_txt = np.loadtxt(op.join(kk_dir, 'spike_clusters.txt')) - ae(sc, sc_txt[spike_ids]) - - for clustering in ('test',): - session.change_clustering(clustering) - assert len(session.model.spike_clusters) == n_spikes - assert np.all(session.model.spike_clusters[spike_ids] == sc) - - # Test clustering metadata. 
- session.change_clustering('test') - metadata = session.model.clustering_metadata - assert 'klustakwik_version' in metadata - assert metadata['klustakwik_num_starting_clusters'] == 10 - - -def test_session_detect(tempdir): - channels = range(n_channels) - graph = [[i, i + 1] for i in range(n_channels - 1)] - probe = {'channel_groups': { - 0: {'channels': channels, - 'graph': graph, - }}} - sample_rate = 10000 - n_samples_traces = 10000 - traces = artificial_traces(n_samples_traces, n_channels) - assert traces is not None - - kwik_path = op.join(tempdir, 'test.kwik') - create_kwik(kwik_path=kwik_path, probe=probe, sample_rate=sample_rate) - session = Session(kwik_path, phy_user_dir=tempdir) - session.detect(traces=traces) - m = session.model - if m.n_spikes > 0: - shape = (m.n_spikes, n_channels * m.n_features_per_channel) - assert m.features.shape == shape diff --git a/phy/stats/__init__.py b/phy/stats/__init__.py index 53deb63f3..67522338b 100644 --- a/phy/stats/__init__.py +++ b/phy/stats/__init__.py @@ -3,4 +3,4 @@ """Statistics functions.""" -from .ccg import pairwise_correlograms +from .ccg import correlograms diff --git a/phy/stats/ccg.py b/phy/stats/ccg.py index 3121df132..c6a038c01 100644 --- a/phy/stats/ccg.py +++ b/phy/stats/ccg.py @@ -8,8 +8,8 @@ import numpy as np -from ..utils._types import _as_array -from ..utils.array import _index_of, _unique +from phy.utils._types import _as_array +from phy.io.array import _index_of, _unique #------------------------------------------------------------------------------ @@ -36,26 +36,49 @@ def _create_correlograms_array(n_clusters, winsize_bins): dtype=np.int32) -def correlograms(spike_samples, spike_clusters, - cluster_order=None, - binsize=None, winsize_bins=None): +def _symmetrize_correlograms(correlograms): + """Return the symmetrized version of the CCG arrays.""" + + n_clusters, _, n_bins = correlograms.shape + assert n_clusters == _ + + # We symmetrize c[i, j, 0]. + # This is necessary because the algorithm in correlograms() + # is sensitive to the order of identical spikes. + correlograms[..., 0] = np.maximum(correlograms[..., 0], + correlograms[..., 0].T) + + sym = correlograms[..., 1:][..., ::-1] + sym = np.transpose(sym, (1, 0, 2)) + + return np.dstack((sym, correlograms)) + + +def correlograms(spike_times, + spike_clusters, + cluster_ids=None, + sample_rate=1., + bin_size=None, + window_size=None, + symmetrize=True, + ): """Compute all pairwise cross-correlograms among the clusters appearing in `spike_clusters`. Parameters ---------- - spike_samples : array-like - Spike times in samples (integers). + spike_times : array-like + Spike times in seconds. spike_clusters : array-like Spike-cluster mapping. - cluster_order : array-like + cluster_ids : array-like The list of unique clusters, in any order. That order will be used in the output array. - binsize : int - Number of time samples in one bin. - winsize_bins : int (odd number) - Number of bins in the window. + bin_size : float + Size of the bin, in seconds. + window_size : float + Size of the window, in seconds. Returns ------- @@ -64,37 +87,37 @@ def correlograms(spike_samples, spike_clusters, A `(n_clusters, n_clusters, winsize_samples)` array with all pairwise CCGs. 
- Notes - ----- - - If winsize_samples is the (odd) number of time samples in the window - then: - - winsize_bins = 2 * ((winsize_samples // 2) // binsize) + 1 - assert winsize_bins % 2 == 1 - - For performance reasons, it is recommended to compute the CCGs on a subset - with only a few thousands or tens of thousands of spikes. - """ + assert sample_rate > 0. + assert np.all(np.diff(spike_times) >= 0), ("The spike times must be " + "increasing.") - spike_clusters = _as_array(spike_clusters) - spike_samples = _as_array(spike_samples) - if spike_samples.dtype in (np.int32, np.int64): - spike_samples = spike_samples.astype(np.uint64) + # Get the spike samples. + spike_times = np.asarray(spike_times, dtype=np.float64) + spike_samples = (spike_times * sample_rate).astype(np.int64) - assert spike_samples.dtype == np.uint64 + spike_clusters = _as_array(spike_clusters) assert spike_samples.ndim == 1 assert spike_samples.shape == spike_clusters.shape + # Find `binsize`. + bin_size = np.clip(bin_size, 1e-5, 1e5) # in seconds + binsize = int(sample_rate * bin_size) # in samples + assert binsize >= 1 + + # Find `winsize_bins`. + window_size = np.clip(window_size, 1e-5, 1e5) # in seconds + winsize_bins = 2 * int(.5 * window_size / bin_size) + 1 + + assert winsize_bins >= 1 assert winsize_bins % 2 == 1 # Take the cluster oder into account. - if cluster_order is None: + if cluster_ids is None: clusters = _unique(spike_clusters) else: - clusters = _as_array(cluster_order) + clusters = _as_array(cluster_ids) n_clusters = len(clusters) # Like spike_clusters, but with 0..n_clusters-1 indices. @@ -148,45 +171,7 @@ def correlograms(spike_samples, spike_clusters, np.arange(n_clusters), 0] = 0 - return correlograms - - -#------------------------------------------------------------------------------ -# Helper functions for CCG data structures -#------------------------------------------------------------------------------ - -def _symmetrize_correlograms(correlograms): - """Return the symmetrized version of the CCG arrays.""" - - n_clusters, _, n_bins = correlograms.shape - assert n_clusters == _ - - # We symmetrize c[i, j, 0]. - # This is necessary because the algorithm in correlograms() - # is sensitive to the order of identical spikes. - correlograms[..., 0] = np.maximum(correlograms[..., 0], - correlograms[..., 0].T) - - sym = correlograms[..., 1:][..., ::-1] - sym = np.transpose(sym, (1, 0, 2)) - - return np.dstack((sym, correlograms)) - - -def pairwise_correlograms(spike_samples, - spike_clusters, - binsize=None, - winsize_bins=None, - ): - """Compute all pairwise correlograms in a set of neurons. - - TODO: improve interface and documentation. 
- - """ - ccgs = correlograms(spike_samples, - spike_clusters, - binsize=binsize, - winsize_bins=winsize_bins, - ) - ccgs = _symmetrize_correlograms(ccgs) - return ccgs + if symmetrize: + return _symmetrize_correlograms(correlograms) + else: + return correlograms diff --git a/phy/stats/clusters.py b/phy/stats/clusters.py new file mode 100644 index 000000000..72bf3d25e --- /dev/null +++ b/phy/stats/clusters.py @@ -0,0 +1,80 @@ +# -*- coding: utf-8 -*- + +"""Cluster statistics.""" + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +import numpy as np + + +#------------------------------------------------------------------------------ +# Cluster statistics +#------------------------------------------------------------------------------ + +def mean(x): + return x.mean(axis=0) + + +def get_unmasked_channels(mean_masks, min_mask=.25): + return np.nonzero(mean_masks > min_mask)[0] + + +def get_mean_probe_position(mean_masks, site_positions): + return (np.sum(site_positions * mean_masks[:, np.newaxis], axis=0) / + max(1, np.sum(mean_masks))) + + +def get_sorted_main_channels(mean_masks, unmasked_channels): + # Weighted mean of the channels, weighted by the mean masks. + main_channels = np.argsort(mean_masks)[::-1] + main_channels = np.array([c for c in main_channels + if c in unmasked_channels]) + return main_channels + + +#------------------------------------------------------------------------------ +# Wizard measures +#------------------------------------------------------------------------------ + +def get_waveform_amplitude(mean_masks, mean_waveforms): + """Return the amplitude of the waveforms on all channels.""" + + assert mean_waveforms.ndim == 2 + n_samples, n_channels = mean_waveforms.shape + + assert mean_masks.ndim == 1 + assert mean_masks.shape == (n_channels,) + + mean_waveforms = mean_waveforms * mean_masks + assert mean_waveforms.shape == (n_samples, n_channels) + + # Amplitudes. + m, M = mean_waveforms.min(axis=0), mean_waveforms.max(axis=0) + return M - m + + +def get_mean_masked_features_distance(mean_features_0, + mean_features_1, + mean_masks_0, + mean_masks_1, + n_features_per_channel=None, + ): + """Compute the distance between the mean masked features.""" + + assert n_features_per_channel > 0 + + mu_0 = mean_features_0.ravel() + mu_1 = mean_features_1.ravel() + + omeg_0 = mean_masks_0 + omeg_1 = mean_masks_1 + + omeg_0 = np.repeat(omeg_0, n_features_per_channel) + omeg_1 = np.repeat(omeg_1, n_features_per_channel) + + d_0 = mu_0 * omeg_0 + d_1 = mu_1 * omeg_1 + + return np.linalg.norm(d_0 - d_1) diff --git a/phy/stats/tests/test_ccg.py b/phy/stats/tests/test_ccg.py index 3f7eb2348..2c5affe7c 100644 --- a/phy/stats/tests/test_ccg.py +++ b/phy/stats/tests/test_ccg.py @@ -12,7 +12,6 @@ from ..ccg import (_increment, _diff_shifted, correlograms, - _symmetrize_correlograms, ) @@ -30,15 +29,7 @@ def _random_data(max_cluster): def _ccg_params(): - # window = 50 ms - winsize_samples = 2 * (25 * 20) + 1 - # bin = 1 ms - binsize = 1 * 20 - # 51 bins - winsize_bins = 2 * ((winsize_samples // 2) // binsize) + 1 - assert winsize_bins % 2 == 1 - - return binsize, winsize_bins + return .001, .05 def test_utils(): @@ -78,7 +69,8 @@ def test_ccg_0(): c_expected[0, 1, 0] = 0 # This is a peculiarity of the algorithm. 
c = correlograms(spike_samples, spike_clusters, - binsize=binsize, winsize_bins=winsize_bins) + bin_size=binsize, window_size=winsize_bins, + cluster_ids=[0, 1], symmetrize=False) ae(c, c_expected) @@ -94,7 +86,8 @@ def test_ccg_1(): c_expected[0, 0, 2] = 1 c = correlograms(spike_samples, spike_clusters, - binsize=binsize, winsize_bins=winsize_bins) + bin_size=binsize, window_size=winsize_bins, + symmetrize=False) ae(c, c_expected) @@ -105,7 +98,8 @@ def test_ccg_2(): binsize, winsize_bins = _ccg_params() c = correlograms(spike_samples, spike_clusters, - binsize=binsize, winsize_bins=winsize_bins) + bin_size=binsize, window_size=winsize_bins, + sample_rate=20000, symmetrize=False) assert c.shape == (max_cluster, max_cluster, 26) @@ -117,14 +111,16 @@ def test_ccg_symmetry_time(): binsize, winsize_bins = _ccg_params() c0 = correlograms(spike_samples, spike_clusters, - binsize=binsize, winsize_bins=winsize_bins) + bin_size=binsize, window_size=winsize_bins, + sample_rate=20000, symmetrize=False) spike_samples_1 = np.cumsum(np.r_[np.arange(1), np.diff(spike_samples)[::-1]]) spike_samples_1 = spike_samples_1.astype(np.uint64) spike_clusters_1 = spike_clusters[::-1] c1 = correlograms(spike_samples_1, spike_clusters_1, - binsize=binsize, winsize_bins=winsize_bins) + bin_size=binsize, window_size=winsize_bins, + sample_rate=20000, symmetrize=False) # The ACGs are identical. ae(c0[0, 0], c1[0, 0]) @@ -142,11 +138,13 @@ def test_ccg_symmetry_clusters(): binsize, winsize_bins = _ccg_params() c0 = correlograms(spike_samples, spike_clusters, - binsize=binsize, winsize_bins=winsize_bins) + bin_size=binsize, window_size=winsize_bins, + sample_rate=20000, symmetrize=False) spike_clusters_1 = 1 - spike_clusters c1 = correlograms(spike_samples, spike_clusters_1, - binsize=binsize, winsize_bins=winsize_bins) + bin_size=binsize, window_size=winsize_bins, + sample_rate=20000, symmetrize=False) # The ACGs are identical. ae(c0[0, 0], c1[1, 1]) @@ -161,10 +159,9 @@ def test_symmetrize_correlograms(): spike_samples, spike_clusters = _random_data(3) binsize, winsize_bins = _ccg_params() - c = correlograms(spike_samples, spike_clusters, - binsize=binsize, winsize_bins=winsize_bins) - - sym = _symmetrize_correlograms(c) + sym = correlograms(spike_samples, spike_clusters, + bin_size=binsize, window_size=winsize_bins, + sample_rate=20000) assert sym.shape == (3, 3, 51) # The ACG are reversed. 
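For reference, a minimal sketch of how the refactored `correlograms` API above might be called, using made-up spike data and the new keyword signature introduced in `phy/stats/ccg.py` (this is an illustrative usage sketch, not part of the patch):

```python
import numpy as np
from phy.stats.ccg import correlograms

# Synthetic spike train: times in seconds (must be non-decreasing) and labels.
spike_times = np.sort(np.random.uniform(0., 10., size=1000))
spike_clusters = np.random.randint(0, 3, size=1000)

# 1 ms bins over a 50 ms window, for data sampled at 20 kHz.
ccg = correlograms(spike_times, spike_clusters,
                   cluster_ids=[0, 1, 2],
                   sample_rate=20000.,
                   bin_size=.001,
                   window_size=.05,
                   symmetrize=True)

# One symmetrized CCG per pair of clusters: 51 bins for a 50 ms window
# with 1 ms bins, matching the shape checked in the updated tests.
assert ccg.shape == (3, 3, 51)
```

Passing `symmetrize=False` returns the half-window correlograms used internally (26 bins with these parameters), which is what most of the updated tests request.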
diff --git a/phy/stats/tests/test_clusters.py b/phy/stats/tests/test_clusters.py new file mode 100644 index 000000000..643f9f80d --- /dev/null +++ b/phy/stats/tests/test_clusters.py @@ -0,0 +1,145 @@ +# -*- coding: utf-8 -*- + +"""Tests of cluster statistics.""" + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +import numpy as np +from numpy.testing import assert_array_equal as ae +from numpy.testing import assert_allclose as ac +from pytest import yield_fixture + +from ..clusters import (mean, + get_unmasked_channels, + get_mean_probe_position, + get_sorted_main_channels, + get_mean_masked_features_distance, + get_waveform_amplitude, + ) +from phy.electrode.mea import staggered_positions +from phy.io.mock import (artificial_features, + artificial_masks, + artificial_waveforms, + ) + + +#------------------------------------------------------------------------------ +# Fixtures +#------------------------------------------------------------------------------ + +@yield_fixture +def n_channels(): + yield 28 + + +@yield_fixture +def n_spikes(): + yield 50 + + +@yield_fixture +def n_samples(): + yield 40 + + +@yield_fixture +def n_features_per_channel(): + yield 4 + + +@yield_fixture +def features(n_spikes, n_channels, n_features_per_channel): + yield artificial_features(n_spikes, n_channels, n_features_per_channel) + + +@yield_fixture +def masks(n_spikes, n_channels): + yield artificial_masks(n_spikes, n_channels) + + +@yield_fixture +def waveforms(n_spikes, n_samples, n_channels): + yield artificial_waveforms(n_spikes, n_samples, n_channels) + + +@yield_fixture +def site_positions(n_channels): + yield staggered_positions(n_channels) + + +#------------------------------------------------------------------------------ +# Tests +#------------------------------------------------------------------------------ + +def test_mean(features, n_channels, n_features_per_channel): + mf = mean(features) + assert mf.shape == (n_channels, n_features_per_channel) + ae(mf, features.mean(axis=0)) + + +def test_unmasked_channels(masks, n_channels): + # Mask many values in the masks array. + threshold = .05 + masks[:, 1::2] *= threshold + # Compute the mean masks. + mean_masks = mean(masks) + # Find the unmasked channels. + channels = get_unmasked_channels(mean_masks, threshold) + # These are 0, 2, 4, etc. + ae(channels, np.arange(0, n_channels, 2)) + + +def test_mean_probe_position(masks, site_positions): + masks[:, ::2] *= .05 + mean_masks = mean(masks) + mean_pos = get_mean_probe_position(mean_masks, site_positions) + assert mean_pos.shape == (2,) + assert mean_pos[0] < 0 + assert mean_pos[1] > 0 + + +def test_sorted_main_channels(masks): + masks *= .05 + masks[:, [5, 7]] *= 20 + mean_masks = mean(masks) + channels = get_sorted_main_channels(mean_masks, + get_unmasked_channels(mean_masks)) + assert np.all(np.in1d(channels, [5, 7])) + + +def test_waveform_amplitude(masks, waveforms): + waveforms *= .1 + masks *= .1 + + waveforms[:, 10, :] *= 10 + masks[:, 10] *= 10 + + mean_waveforms = mean(waveforms) + mean_masks = mean(masks) + + amplitude = get_waveform_amplitude(mean_masks, mean_waveforms) + assert np.all(amplitude >= 0) + assert amplitude.shape == (mean_waveforms.shape[1],) + + +def test_mean_masked_features_distance(features, + n_channels, + n_features_per_channel, + ): + + # Shifted feature vectors. + shift = 10. 
+ f0 = mean(features) + f1 = mean(features) + shift + + # Only one channel is unmasked. + m0 = m1 = np.zeros(n_channels) + m0[n_channels // 2] = 1 + + # Check the distance. + d_expected = np.sqrt(n_features_per_channel) * shift + d_computed = get_mean_masked_features_distance(f0, f1, m0, m1, + n_features_per_channel) + ac(d_expected, d_computed) diff --git a/phy/traces/__init__.py b/phy/traces/__init__.py index 039bc84aa..e51e25895 100644 --- a/phy/traces/__init__.py +++ b/phy/traces/__init__.py @@ -3,7 +3,5 @@ """Spike detection, waveform extraction.""" -from .detect import Thresholder, FloodFillDetector, compute_threshold from .filter import Filter, Whitening -from .pca import PCA from .waveform import WaveformLoader, WaveformExtractor, SpikeLoader diff --git a/phy/traces/detect.py b/phy/traces/detect.py deleted file mode 100644 index 3c90361f4..000000000 --- a/phy/traces/detect.py +++ /dev/null @@ -1,373 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Spike detection.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import numpy as np -from six import string_types -from six.moves import range, zip - -from ..utils.array import _as_array - - -#------------------------------------------------------------------------------ -# Thresholder -#------------------------------------------------------------------------------ - -def compute_threshold(arr, single_threshold=True, std_factor=None): - """Compute the threshold(s) of filtered traces. - - Parameters - ---------- - - arr : ndarray - Filtered traces, shape `(n_samples, n_channels)`. - single_threshold : bool - Whether there should be a unique threshold for all channels, or - one threshold per channel. - std_factor : float or 2-tuple - The threshold in unit of signal std. Two values can be specified - for multiple thresholds (weak and strong). - - Returns - ------- - - thresholds : ndarray - A `(2,)` or `(2, n_channels)` array with the thresholds. - - """ - assert arr.ndim == 2 - ns, nc = arr.shape - - assert std_factor is not None - if isinstance(std_factor, (int, float)): - std_factor = (std_factor, std_factor) - assert isinstance(std_factor, (tuple, list)) - assert len(std_factor) == 2 - std_factor = np.array(std_factor) - - if not single_threshold: - std_factor = std_factor[:, None] - - # Get the median of all samples in all excerpts, on all channels. - if single_threshold: - median = np.median(np.abs(arr)) - # Or independently for each channel. - else: - median = np.median(np.abs(arr), axis=0) - - # Compute the threshold from the median. - std = median / .6745 - threshold = std_factor * std - assert isinstance(threshold, np.ndarray) - - if single_threshold: - assert threshold.ndim == 1 - assert len(threshold) == 2 - else: - assert threshold.ndim == 2 - assert threshold.shape == (2, nc) - return threshold - - -class Thresholder(object): - """Threshold traces to detect spikes. - - Parameters - ---------- - - mode : str - `'positive'`, `'negative'`, or `'both'`. - thresholds : dict - A `{str: float}` mapping for multiple thresholds (e.g. `weak` - and `strong`). 
- - Example - ------- - - ```python - thres = Thresholder('positive', thresholds=(.1, .2)) - crossings = thres(traces) - ``` - - """ - def __init__(self, - mode=None, - thresholds=None, - ): - assert mode in ('positive', 'negative', 'both') - if isinstance(thresholds, (float, int, np.ndarray)): - thresholds = {'default': thresholds} - if thresholds is None: - thresholds = {} - assert isinstance(thresholds, dict) - self._mode = mode - self._thresholds = thresholds - - def transform(self, data): - """Return `data`, `-data`, or `abs(data)` depending on the mode.""" - if self._mode == 'positive': - return data - elif self._mode == 'negative': - return -data - elif self._mode == 'both': - return np.abs(data) - - def detect(self, data_t, threshold=None): - """Perform the thresholding operation.""" - # Accept dictionary of thresholds. - if isinstance(threshold, (list, tuple)): - return {name: self(data_t, threshold=name) - for name in threshold} - # Use the only threshold by default (if there is only one). - if threshold is None: - assert len(self._thresholds) == 1 - threshold = list(self._thresholds.keys())[0] - # Fetch the threshold from its name. - if isinstance(threshold, string_types): - assert threshold in self._thresholds - threshold = self._thresholds[threshold] - # threshold = float(threshold) - # Threshold the data. - return data_t > threshold - - def __call__(self, data, threshold=None): - # Transform the data according to the mode. - data_t = self.transform(data) - return self.detect(data_t, threshold=threshold) - - -# ----------------------------------------------------------------------------- -# Connected components -# ----------------------------------------------------------------------------- - -def _to_tuples(x): - return ((i, j) for (i, j) in x) - - -def _to_list(x): - return [(i, j) for (i, j) in x] - - -def connected_components(weak_crossings=None, - strong_crossings=None, - probe_adjacency_list=None, - join_size=None, - channels=None): - """Find all connected components in binary arrays of threshold crossings. - - Parameters - ---------- - - weak_crossings : array - `(n_samples, n_channels)` array with weak threshold crossings - strong_crossings : array - `(n_samples, n_channels)` array with strong threshold crossings - probe_adjacency_list : dict - A dict `{channel: [neighbors]}` - channels : array - An (n_channels,) array with a list of all non-dead channels - join_size : int - The number of samples defining the tolerance in time for - finding connected components - - - Returns - ------- - - A list of lists of pairs `(samp, chan)` of the connected components in - the 2D array `weak_crossings`, where a pair is adjacent if the samples are - within `join_size` of each other, and the channels are adjacent in - `probe_adjacency_list`, the channel graph. - - Note - ---- - - The channel mapping assumes that column #i in the data array is channel #i - in the probe adjacency graph. - - """ - - if probe_adjacency_list is None: - probe_adjacency_list = {} - - if channels is None: - channels = [] - - # If the channels aren't referenced at all but exist in 'channels', add a - # trivial self-connection so temporal floodfill will work. If this channel - # is dead, it should be removed from 'channels'. - probe_adjacency_list.update({i: {i} for i in channels - if not probe_adjacency_list.get(i)}) - - # Make sure the values are sets. 
- probe_adjacency_list = {c: set(cs) - for c, cs in probe_adjacency_list.items()} - - if strong_crossings is None: - strong_crossings = weak_crossings - - assert weak_crossings.shape == strong_crossings.shape - - # Set of connected component labels which contain at least one strong - # node. - strong_nodes = set() - - n_s, n_ch = weak_crossings.shape - join_size = int(join_size or 0) - - # An array with the component label for each node in the array - label_buffer = np.zeros((n_s, n_ch), dtype=np.int32) - - # Component indices, a dictionary with keys the label of the component - # and values a list of pairs (sample, channel) belonging to that component - comp_inds = {} - - # mgraph is the channel graph, but with edge node connected to itself - # because we want to include ourself in the adjacency. Each key of the - # channel graph (a dictionary) is a node, and the value is a set of nodes - # which are connected to it by an edge - mgraph = {} - for source, targets in probe_adjacency_list.items(): - # we add self connections - mgraph[source] = targets.union([source]) - - # Label of the next component - c_label = 1 - - # For all pairs sample, channel which are nonzero (note that numpy .nonzero - # returns (all_i_s, all_i_ch), a pair of lists whose values at the - # corresponding place are the sample, channel pair which is nonzero. The - # lists are also returned in sorted order, so that i_s is always increasing - # and i_ch is always increasing for a given value of i_s. izip is an - # iterator version of the Python zip function, i.e. does the same as zip - # but quicker. zip(A,B) is a list of all pairs (a,b) with a in A and b in B - # in order (i.e. (A[0], B[0]), (A[1], B[1]), .... In conclusion, the next - # line loops through all the samples i_s, and for each sample it loops - # through all the channels. - for i_s, i_ch in zip(*weak_crossings.nonzero()): - # The next two lines iterate through all the neighbours of i_s, i_ch - # in the graph defined by graph in the case of edges, and - # j_s from i_s-join_size to i_s. - for j_s in range(i_s - join_size, i_s + 1): - # Allow us to leave out a channel from the graph to exclude bad - # channels - if i_ch not in mgraph: - continue - for j_ch in mgraph[i_ch]: - # Label of the adjacent element. - adjlabel = label_buffer[j_s, j_ch] - # If the adjacent element is nonzero we need to do something. - if adjlabel: - curlabel = label_buffer[i_s, i_ch] - if curlabel == 0: - # If current element is still zero, we just assign - # the label of the adjacent element to the current one. - label_buffer[i_s, i_ch] = adjlabel - # And add it to the list for the labelled component. - comp_inds[adjlabel].append((i_s, i_ch)) - - elif curlabel != adjlabel: - # If the current element is unequal to the adjacent - # one, we merge them by reassigning the elements of the - # adjacent component to the current one. - # samps_chans is an array of pairs sample, channel - # currently assigned to component adjlabel. - samps_chans = np.array(comp_inds[adjlabel], - dtype=np.int32) - - # samps_chans[:, 0] is the sample indices, so this - # gives only the samp,chan pairs that are within - # join_size of the current point. - # TODO: is this the right behaviour? If a component can - # have a width bigger than join_size I think it isn't! - samps_chans = samps_chans[i_s - samps_chans[:, 0] <= - join_size] - - # Relabel the adjacent samp,chan points with current - # label. 
- samps, chans = samps_chans[:, 0], samps_chans[:, 1] - label_buffer[samps, chans] = curlabel - - # Add them to the current label list, and remove the - # adjacent component entirely. - comp_inds[curlabel].extend(comp_inds.pop(adjlabel)) - - # Did not deal with merge condition, now fixed it - # seems... - # WARNING: might this "in" incur a performance hit - # here...? - if adjlabel in strong_nodes: - strong_nodes.add(curlabel) - strong_nodes.remove(adjlabel) - - # NEW: add the current component label to the set of all - # strong nodes, if the current node is strong. - if curlabel > 0 and strong_crossings[i_s, i_ch]: - strong_nodes.add(curlabel) - - if label_buffer[i_s, i_ch] == 0: - # If nothing is adjacent, we have the beginnings of a new - # component, # so we label it, create a new list for the new - # component which is given label c_label, - # then increase c_label for the next new component afterwards. - label_buffer[i_s, i_ch] = c_label - comp_inds[c_label] = [(i_s, i_ch)] - if strong_crossings[i_s, i_ch]: - strong_nodes.add(c_label) - c_label += 1 - - # Only return the values, because we don't actually need the labels. - comps = [comp_inds[key] for key in comp_inds.keys() if key in strong_nodes] - return comps - - -class FloodFillDetector(object): - """Detect spikes in weak and strong threshold crossings. - - Parameters - ---------- - - probe_adjacency_list : dict - A dict `{channel: [neighbors]}`. - join_size : int - The number of samples defining the tolerance in time for - finding connected components - - Example - ------- - - ```python - det = FloodFillDetector(probe_adjacency_list=..., - join_size=...) - components = det(weak_crossings, strong_crossings) - ``` - - `components` is a list of `(n, 2)` int arrays with the sample and channel - for every sample in the component. 
- - """ - def __init__(self, probe_adjacency_list=None, join_size=None, - channels_per_group=None): - self._adjacency_list = probe_adjacency_list - self._join_size = join_size - self._channels_per_group = channels_per_group - - def __call__(self, weak_crossings=None, strong_crossings=None): - weak_crossings = _as_array(weak_crossings, np.bool) - strong_crossings = _as_array(strong_crossings, np.bool) - all_channels = sorted([item for sublist - in self._channels_per_group.values() - for item in sublist]) - - cc = connected_components(weak_crossings=weak_crossings, - strong_crossings=strong_crossings, - probe_adjacency_list=self._adjacency_list, - channels=all_channels, - join_size=self._join_size, - ) - # cc is a list of list of pairs (sample, channel) - return [np.array(c) for c in cc] diff --git a/phy/traces/filter.py b/phy/traces/filter.py index 89b748be3..cba7a7881 100644 --- a/phy/traces/filter.py +++ b/phy/traces/filter.py @@ -25,13 +25,13 @@ def bandpass_filter(rate=None, low=None, high=None, order=None): 'pass') -def apply_filter(x, filter=None): +def apply_filter(x, filter=None, axis=0): """Apply a filter to an array.""" x = _as_array(x) - if x.shape[0] == 0: + if x.shape[axis] == 0: return x b, a = filter - return signal.filtfilt(b, a, x, axis=0) + return signal.filtfilt(b, a, x, axis=axis) class Filter(object): diff --git a/phy/traces/pca.py b/phy/traces/pca.py deleted file mode 100644 index e9ff85093..000000000 --- a/phy/traces/pca.py +++ /dev/null @@ -1,140 +0,0 @@ -# -*- coding: utf-8 -*- - -"""PCA for features.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import numpy as np - -from ..utils._types import _as_array - - -#------------------------------------------------------------------------------ -# PCA -#------------------------------------------------------------------------------ - -def _compute_pcs(x, n_pcs=None, masks=None): - """Compute the PCs of waveforms.""" - - assert x.ndim == 3 - x = _as_array(x, np.float64) - - n_spikes, n_samples, n_channels = x.shape - - if masks is not None: - assert isinstance(masks, np.ndarray) - assert masks.shape == (n_spikes, n_channels) - - # Compute regularization cov matrix. - cov_reg = np.eye(n_samples) - if masks is not None: - unmasked = masks > 0 - # The last dimension is now time. The second dimension is channel. - x_swapped = np.swapaxes(x, 1, 2) - # This is the list of all unmasked spikes on all channels. - # shape: (n_unmasked_spikes, n_samples) - unmasked_all = x_swapped[unmasked, :] - # Let's compute the regularization cov matrix of this beast. - # shape: (n_samples, n_samples) - cov_reg_ = np.cov(unmasked_all, rowvar=0) - # Make sure the covariance matrix is valid. - if cov_reg_.ndim == 2: - cov_reg = cov_reg - assert cov_reg.shape == (n_samples, n_samples) - - pcs_list = [] - # Loop over channels - for channel in range(n_channels): - x_channel = x[:, :, channel] - # Compute cov matrix for the channel - if masks is not None: - # Unmasked waveforms on that channel - # shape: (n_unmasked, n_samples) - x_channel = np.compress(masks[:, channel] > 0, - x_channel, axis=0) - assert x_channel.ndim == 2 - # Don't compute the cov matrix if there are no unmasked spikes - # on that channel. - alpha = 1. 
/ n_spikes - if x_channel.shape[0] <= 1: - cov = alpha * cov_reg - else: - cov_channel = np.cov(x_channel, rowvar=0) - assert cov_channel.shape == (n_samples, n_samples) - cov = alpha * cov_reg + cov_channel - # Compute the eigenelements - vals, vecs = np.linalg.eigh(cov) - pcs = vecs.T.astype(np.float32)[np.argsort(vals)[::-1]] - # Take the first n_pcs components. - if n_pcs is not None: - pcs = pcs[:n_pcs, ...] - pcs_list.append(pcs[:n_pcs, ...]) - - pcs = np.dstack(pcs_list) - return pcs - - -def _project_pcs(x, pcs): - """Project data points onto principal components. - - Parameters - ---------- - - x : array - The waveforms - pcs : array - The PCs returned by `_compute_pcs()`. - """ - # pcs: (nf, ns, nc) - # x: (n, ns, nc) - # out: (n, nc, nf) - assert pcs.ndim == 3 - assert x.ndim == 3 - n, ns, nc = x.shape - nf, ns_, nc_ = pcs.shape - assert ns == ns_ - assert nc == nc_ - - x_proj = np.einsum('ijk,...jk->...ki', pcs, x) - assert x_proj.shape == (n, nc, nf) - return x_proj - - -class PCA(object): - """Apply PCA to waveforms.""" - def __init__(self, n_pcs=None): - self._n_pcs = n_pcs - self._pcs = None - - def fit(self, waveforms, masks=None): - """Compute the PCs of waveforms. - - Parameters - ---------- - - waveforms : ndarray - Shape: `(n_spikes, n_samples, n_channels)` - masks : ndarray - Shape: `(n_spikes, n_channels)` - - """ - self._pcs = _compute_pcs(waveforms, n_pcs=self._n_pcs, masks=masks) - return self._pcs - - def transform(self, waveforms, pcs=None): - """Project waveforms on the PCs. - - Parameters - ---------- - - waveforms : ndarray - Shape: `(n_spikes, n_samples, n_channels)` - - """ - if pcs is None: - pcs = self._pcs - # Need to call fit() if the pcs are None here. - if pcs is not None: - return _project_pcs(waveforms, pcs) diff --git a/phy/traces/tests/test_detect.py b/phy/traces/tests/test_detect.py deleted file mode 100644 index fc511d59f..000000000 --- a/phy/traces/tests/test_detect.py +++ /dev/null @@ -1,321 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Tests of spike detection routines.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import numpy as np -from numpy.testing import assert_array_equal as ae -from numpy.testing import assert_allclose as ac - -from ..detect import (compute_threshold, - Thresholder, - connected_components, - FloodFillDetector, - ) -from ...io.mock import artificial_traces - - -#------------------------------------------------------------------------------ -# Test thresholder -#------------------------------------------------------------------------------ - -def test_compute_threshold(): - n_samples, n_channels = 100, 10 - data = artificial_traces(n_samples, n_channels) - - # Single threshold. - threshold = compute_threshold(data, std_factor=1.) - assert threshold.shape == (2,) - assert threshold[0] > 0 - assert threshold[0] == threshold[1] - - threshold = compute_threshold(data, std_factor=[1., 2.]) - assert threshold.shape == (2,) - assert threshold[1] == 2 * threshold[0] - - # Multiple threshold. - threshold = compute_threshold(data, single_threshold=False, std_factor=2.) 
- assert threshold.shape == (2, n_channels) - - threshold = compute_threshold(data, - single_threshold=False, - std_factor=(1., 2.)) - assert threshold.shape == (2, n_channels) - ac(threshold[1], 2 * threshold[0]) - - -def test_thresholder(): - n_samples, n_channels = 100, 12 - strong, weak = .1, .2 - - data = artificial_traces(n_samples, n_channels) - - # Positive and strong. - thresholder = Thresholder(mode='positive', - thresholds=strong) - ae(thresholder(data), data > strong) - - # Negative and weak. - thresholder = Thresholder(mode='negative', - thresholds={'weak': weak}) - ae(thresholder(data), data < -weak) - - # Both and strong+weak. - thresholder = Thresholder(mode='both', - thresholds={'weak': weak, - 'strong': strong, - }) - ae(thresholder(data, 'weak'), np.abs(data) > weak) - ae(thresholder(data, threshold='strong'), np.abs(data) > strong) - - # Multiple thresholds. - t = thresholder(data, ('weak', 'strong')) - ae(t['weak'], np.abs(data) > weak) - ae(t['strong'], np.abs(data) > strong) - - # Array threshold. - thre = np.linspace(weak - .05, strong + .05, n_channels) - thresholder = Thresholder(mode='positive', thresholds=thre) - t = thresholder(data) - assert t.shape == data.shape - ae(t, data > thre) - - -#------------------------------------------------------------------------------ -# Test connected components -#------------------------------------------------------------------------------ - -def _as_set(c): - if isinstance(c, np.ndarray): - c = c.tolist() - c = [tuple(_) for _ in c] - return set(c) - - -def _assert_components_equal(cc1, cc2): - assert len(cc1) == len(cc2) - for c1, c2 in zip(cc1, cc2): - assert _as_set(c1) == _as_set(c2) - - -def _test_components(chunk=None, components=None, **kwargs): - - def _clip(x, m, M): - return [_ for _ in x if m <= _ < M] - - n = 5 - probe_adjacency_list = {i: set(_clip([i - 1, i + 1], 0, n)) - for i in range(n)} - - if chunk is None: - chunk = [[0, 0, 0, 0, 0], - [0, 0, 1, 0, 0], - [1, 0, 1, 1, 0], - [1, 0, 0, 1, 0], - [0, 1, 0, 1, 1], - ] - - if components is None: - components = [] - - if not isinstance(chunk, np.ndarray): - chunk = np.array(chunk) - strong_crossings = kwargs.pop('strong_crossings', None) - if (strong_crossings is not None and - not isinstance(strong_crossings, np.ndarray)): - strong_crossings = np.array(strong_crossings) - - comp = connected_components(chunk, - probe_adjacency_list=probe_adjacency_list, - strong_crossings=strong_crossings, - **kwargs) - _assert_components_equal(comp, components) - - -def test_components(): - # 1 time step, 1 element - _test_components([[0, 0, 0, 0, 0]], []) - - _test_components([[1, 0, 0, 0, 0]], [[(0, 0)]]) - - _test_components([[0, 1, 0, 0, 0]], [[(0, 1)]]) - - _test_components([[0, 0, 0, 1, 0]], [[(0, 3)]]) - - _test_components([[0, 0, 0, 0, 1]], [[(0, 4)]]) - - # 1 time step, 2 elements - _test_components([[1, 1, 0, 0, 0]], [[(0, 0), (0, 1)]]) - - _test_components([[1, 0, 1, 0, 0]], [[(0, 0)], [(0, 2)]]) - - _test_components([[1, 0, 0, 0, 1]], [[(0, 0)], [(0, 4)]]) - - _test_components([[0, 1, 0, 1, 0]], [[(0, 1)], [(0, 3)]]) - - # 1 time step, 3 elements - _test_components([[1, 1, 1, 0, 0]], [[(0, 0), (0, 1), (0, 2)]]) - - _test_components([[1, 1, 0, 1, 0]], [[(0, 0), (0, 1)], [(0, 3)]]) - - _test_components([[1, 0, 1, 1, 0]], [[(0, 0)], [(0, 2), (0, 3)]]) - - _test_components([[0, 1, 1, 1, 0]], [[(0, 1), (0, 2), (0, 3)]]) - - _test_components([[0, 1, 1, 0, 1]], [[(0, 1), (0, 2)], [(0, 4)]]) - - # 5 time steps, varying join_size - _test_components([ - [0, 0, 0, 0, 0], - [0, 
0, 1, 0, 0], - [1, 0, 1, 1, 0], - [1, 0, 0, 1, 0], - [0, 1, 0, 1, 1], - ], [[(1, 2)], - [(2, 0)], - [(2, 2), (2, 3)], - [(3, 0)], - [(3, 3)], - [(4, 1)], - [(4, 3), (4, 4)], - ]) - - _test_components([ - [0, 0, 0, 0, 0], - [0, 0, 1, 0, 0], - [1, 0, 1, 1, 0], - [1, 0, 0, 1, 0], - [0, 1, 0, 1, 1], - ], [[(1, 2), (2, 2), (2, 3), (3, 3), (4, 3), (4, 4)], - [(2, 0), (3, 0), (4, 1)]], join_size=1) - - _test_components( - components=[[(1, 2), (2, 2), (2, 3), (3, 3), (4, 3), (4, 4), - (2, 0), (3, 0), (4, 1)]], join_size=2) - - # 5 time steps, strong != weak - _test_components(join_size=0, - strong_crossings=[ - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0]]) - - _test_components(components=[[(1, 2)]], - join_size=0, - strong_crossings=[ - [0, 0, 0, 0, 0], - [0, 0, 1, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0]]) - - _test_components( - components=[[(1, 2), (2, 2), (2, 3), (3, 3), (4, 3), (4, 4)]], - join_size=1, - strong_crossings=[ - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 1]]) - - _test_components( - components=[[(1, 2), (2, 2), (2, 3), (3, 3), (4, 3), (4, 4), - (2, 0), (3, 0), (4, 1)]], - join_size=2, - strong_crossings=[ - [0, 0, 0, 0, 0], - [0, 0, 1, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0]]) - - _test_components( - components=[[(1, 2), (2, 2), (2, 3), (3, 3), (4, 3), (4, 4), - (2, 0), (3, 0), (4, 1)]], - join_size=2, - strong_crossings=[ - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 1, 0, 0, 0]]) - - -def test_flood_fill(): - - graph = {0: [1, 2], 1: [0, 2], 2: [0, 1], 3: []} - - channels = {0: [1, 2, 3]} - - ff = FloodFillDetector(probe_adjacency_list=graph, - join_size=1, - channels_per_group=channels - ) - - weak = [[0, 0, 0, 0], - [0, 1, 1, 0], - [0, 0, 0, 0], - [0, 0, 1, 0], - [0, 0, 1, 1], - [0, 0, 0, 0], - [0, 0, 0, 1], - [0, 0, 0, 0], - ] - - # Weak - weak - comps = [[(1, 1), (1, 2)], - [(3, 2), (4, 2)], - [(4, 3)], - [(6, 3)], - ] - cc = ff(weak, weak) - _assert_components_equal(cc, comps) - - # Weak and strong - strong = [[0, 0, 0, 0], - [0, 0, 0, 0], - [0, 0, 0, 0], - [0, 0, 1, 0], - [0, 0, 0, 1], - [0, 0, 0, 0], - [0, 0, 0, 1], - [0, 0, 0, 0], - ] - - comps = [[(3, 2), (4, 2)], - [(4, 3)], - [(6, 3)], - ] - cc = ff(weak, strong) - _assert_components_equal(cc, comps) - - channels = {0: [1, 2, 3, 4]} - - ff = FloodFillDetector(probe_adjacency_list=graph, - join_size=2, - channels_per_group=channels - ) - - weak = [[0, 0, 0, 0, 0], - [0, 1, 0, 1, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 1], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 1], - [0, 0, 0, 0, 0], - ] - - comps = [[(1, 1)], - [(1, 3)], - [(4, 4), (6, 4)], - ] - - cc = ff(weak, weak) - _assert_components_equal(cc, comps) diff --git a/phy/traces/tests/test_pca.py b/phy/traces/tests/test_pca.py deleted file mode 100644 index 54788c91d..000000000 --- a/phy/traces/tests/test_pca.py +++ /dev/null @@ -1,58 +0,0 @@ -# -*- coding: utf-8 -*- - -"""PCA tests.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import numpy as np - -from ...io.mock import artificial_waveforms, artificial_masks -from ..pca import PCA, _compute_pcs - - -#------------------------------------------------------------------------------ -# Test PCA -#------------------------------------------------------------------------------ - -def test_pca(): - n_spikes 
= 100 - n_samples = 40 - n_channels = 12 - waveforms = artificial_waveforms(n_spikes, n_samples, n_channels) - masks = artificial_masks(n_spikes, n_channels) - - pca = PCA(n_pcs=3) - pcs = pca.fit(waveforms, masks) - assert pcs.shape == (3, n_samples, n_channels) - fet = pca.transform(waveforms) - assert fet.shape == (n_spikes, n_channels, 3) - - -def test_compute_pcs(): - """Test PCA on a 2D array.""" - # Horizontal ellipsoid. - x = np.random.randn(20000, 2) * np.array([[10., 1.]]) - # Rotate the points by pi/4. - a = 1. / np.sqrt(2.) - rot = np.array([[a, -a], [a, a]]) - x = np.dot(x, rot) - # Compute the PCs. - pcs = _compute_pcs(x[..., None]) - assert pcs.ndim == 3 - assert (np.abs(pcs) - a).max() < 1e-2 - - -def test_compute_pcs_3d(): - """Test PCA on a 3D array.""" - x1 = np.random.randn(20000, 2) * np.array([[10., 1.]]) - x2 = np.random.randn(20000, 2) * np.array([[1., 10.]]) - x = np.dstack((x1, x2)) - # Compute the PCs. - pcs = _compute_pcs(x) - assert pcs.ndim == 3 - assert np.linalg.norm(pcs[0, :, 0] - np.array([-1., 0.])) < 1e-2 - assert np.linalg.norm(pcs[1, :, 0] - np.array([0., -1.])) < 1e-2 - assert np.linalg.norm(pcs[0, :, 1] - np.array([0, 1.])) < 1e-2 - assert np.linalg.norm(pcs[1, :, 1] - np.array([-1., 0.])) < 1e-2 diff --git a/phy/traces/tests/test_waveform.py b/phy/traces/tests/test_waveform.py index 7c2735046..836aa80bd 100644 --- a/phy/traces/tests/test_waveform.py +++ b/phy/traces/tests/test_waveform.py @@ -8,11 +8,15 @@ import numpy as np from numpy.testing import assert_array_equal as ae -import numpy.random as npr -from pytest import raises - -from ...io.mock import artificial_traces -from ..waveform import _slice, WaveformLoader, WaveformExtractor +from pytest import raises, yield_fixture + +from phy.io.mock import artificial_traces, artificial_spike_samples +from phy.utils import Bunch +from ..waveform import (_slice, + WaveformLoader, + WaveformExtractor, + SpikeLoader, + ) from ..filter import bandpass_filter, apply_filter @@ -25,9 +29,6 @@ def test_extract_simple(): strong = 2. nc = 4 ns = 20 - channels = list(range(nc)) - cpg = {0: channels} - # graph = {0: [1, 2], 1: [0, 2], 2: [0, 1], 3: []} data = np.random.uniform(size=(ns, nc), low=0., high=1.) @@ -49,17 +50,14 @@ def test_extract_simple(): we = WaveformExtractor(extract_before=3, extract_after=5, - thresholds={'weak': weak, - 'strong': strong}, - channels_per_group=cpg, ) + we.set_thresholds(weak=weak, strong=strong) # _component() comp = we._component(component, n_samples=ns) ae(comp.comp_s, [10, 10, 11, 11, 12, 12]) ae(comp.comp_ch, [0, 1, 0, 1, 0, 1]) assert (comp.s_min, comp.s_max) == (10 - 3, 12 + 4) - ae(comp.channels, range(nc)) # _normalize() assert we._normalize(weak) == 0 @@ -82,7 +80,7 @@ def test_extract_simple(): assert 11 <= s < 12 # extract() - wave_e = we.extract(data, s, channels=channels) + wave_e = we.extract(data, s) assert wave_e.shape[1] == wave.shape[1] ae(wave[3:6, :2], wave_e[3:6, :2]) @@ -91,9 +89,8 @@ def test_extract_simple(): assert wave_a.shape == (3 + 5, nc) # Test final call. 
- groups, s_f, wave_f, masks_f = we(component, data=data, data_t=data) + s_f, masks_f, wave_f = we(component, data=data, data_t=data) assert s_f == s - assert np.all(groups == 0) ae(masks_f, masks) ae(wave_f, wave_a) @@ -102,57 +99,92 @@ def test_extract_simple(): extract_after=5, thresholds={'weak': weak, 'strong': strong}, - channels_per_group={0: [1, 0, 3]}, ) - groups, s_f_o, wave_f_o, masks_f_o = we(component, data=data, data_t=data) - assert np.all(groups == 0) + s_f_o, masks_f_o, wave_f_o = we(component, data=data, data_t=data) assert s_f == s_f_o - assert np.allclose(wave_f[:, [1, 0, 3]], wave_f_o) - ae(masks_f_o, [1., 0.5, 0.]) + assert np.allclose(wave_f, wave_f_o) + ae(masks_f_o, [0.5, 1., 0., 0.]) #------------------------------------------------------------------------------ -# Tests loader +# Tests utility functions #------------------------------------------------------------------------------ def test_slice(): assert _slice(0, (20, 20)) == slice(0, 20, None) -def test_loader(): - n_samples_trace, n_channels = 10000, 100 - n_samples = 40 - n_spikes = n_samples_trace // (2 * n_samples) +#------------------------------------------------------------------------------ +# Tests loader +#------------------------------------------------------------------------------ + +@yield_fixture(params=[(None, None), (-1, 2)]) +def waveform_loader(request): + scale_factor, dc_offset = request.param + + n_samples_trace, n_channels = 1000, 5 + h = 10 + n_samples_waveforms = 2 * h + n_spikes = n_samples_trace // (2 * n_samples_waveforms) traces = artificial_traces(n_samples_trace, n_channels) - spike_samples = np.cumsum(npr.randint(low=0, high=2 * n_samples, - size=n_spikes)) + spike_samples = artificial_spike_samples(n_spikes, + max_isi=2 * n_samples_waveforms) with raises(ValueError): WaveformLoader(traces) - # Create a loader. - loader = WaveformLoader(traces, n_samples=n_samples) - assert id(loader.traces) == id(traces) - loader.traces = traces - - # Extract a waveform. - t = spike_samples[10] - waveform = loader._load_at(t) - assert waveform.shape == (n_samples, n_channels) - ae(waveform, traces[t - 20:t + 20, :]) + loader = WaveformLoader(traces=traces, + n_samples_waveforms=n_samples_waveforms, + scale_factor=scale_factor, + dc_offset=dc_offset, + ) + b = Bunch(loader=loader, + n_samples_waveforms=n_samples_waveforms, + n_spikes=n_spikes, + spike_samples=spike_samples, + ) + yield b + + +def test_loader_edge_case(): + wl = WaveformLoader(n_samples_waveforms=3) + wl.traces = np.random.rand(0, 2) + wl[0] + + +def test_loader_simple(waveform_loader): + b = waveform_loader + spike_samples = b.spike_samples + loader = b.loader + traces = loader.traces + dc_offset = loader.dc_offset or 0 + scale_factor = loader.scale_factor or 1 + n_samples_traces, n_channels = traces.shape + n_samples_waveforms = b.n_samples_waveforms + h = n_samples_waveforms // 2 + + assert loader.offset == 0 + assert loader.dc_offset in (dc_offset, None) + assert loader.scale_factor in (scale_factor, None) + + def _transform(arr): + return (arr - dc_offset) * scale_factor waveforms = loader[spike_samples[10:20]] - assert waveforms.shape == (10, n_samples, n_channels) + assert waveforms.shape == (10, n_samples_waveforms, n_channels) t = spike_samples[15] w1 = waveforms[5, ...] 
- w2 = traces[t - 20:t + 20, :] + w2 = _transform(traces[t - h:t + h, :]) assert np.allclose(w1, w2) + sl = SpikeLoader(loader, spike_samples) + assert np.allclose(sl[15], w2) + def test_edges(): - n_samples_trace, n_channels = 1000, 10 - n_samples = 40 + n_samples_trace, n_channels = 100, 10 + n_samples_waveforms = 20 traces = artificial_traces(n_samples_trace, n_channels) @@ -165,7 +197,7 @@ def test_edges(): # Create a loader. loader = WaveformLoader(traces, - n_samples=n_samples, + n_samples_waveforms=n_samples_waveforms, filter=lambda x: apply_filter(x, b_filter), filter_margin=filter_margin) @@ -173,57 +205,50 @@ def test_edges(): with raises(ValueError): loader._load_at(200000) - assert loader._load_at(0).shape == (n_samples, n_channels) - assert loader._load_at(5).shape == (n_samples, n_channels) - assert loader._load_at(n_samples_trace - 5).shape == (n_samples, - n_channels) - assert loader._load_at(n_samples_trace - 1).shape == (n_samples, - n_channels) + ns = n_samples_waveforms + filter_margin + assert loader._load_at(0).shape == (ns, n_channels) + assert loader._load_at(5).shape == (ns, n_channels) + assert loader._load_at(n_samples_trace - 5).shape == (ns, n_channels) + assert loader._load_at(n_samples_trace - 1).shape == (ns, n_channels) def test_loader_channels(): - n_samples_trace, n_channels = 1000, 50 - n_samples = 40 + n_samples_trace, n_channels = 1000, 10 + n_samples_waveforms = 20 traces = artificial_traces(n_samples_trace, n_channels) # Create a loader. - loader = WaveformLoader(traces, n_samples=n_samples) + loader = WaveformLoader(traces, n_samples_waveforms=n_samples_waveforms) loader.traces = traces - channels = [10, 20, 30] + channels = [2, 5, 7] loader.channels = channels assert loader.channels == channels - assert loader[500].shape == (1, n_samples, 3) - assert loader[[500, 501, 600, 300]].shape == (4, n_samples, 3) + assert loader[500].shape == (1, n_samples_waveforms, 3) + assert loader[[500, 501, 600, 300]].shape == (4, n_samples_waveforms, 3) # Test edge effects. - assert loader[3].shape == (1, n_samples, 3) - assert loader[995].shape == (1, n_samples, 3) + assert loader[3].shape == (1, n_samples_waveforms, 3) + assert loader[995].shape == (1, n_samples_waveforms, 3) with raises(NotImplementedError): loader[500:510] def test_loader_filter(): - n_samples_trace, n_channels = 1000, 100 - n_samples = 40 - n_spikes = n_samples_trace // (2 * n_samples) - - traces = artificial_traces(n_samples_trace, n_channels) - spike_samples = np.cumsum(npr.randint(low=0, high=2 * n_samples, - size=n_spikes)) + traces = np.c_[np.arange(20), np.arange(20, 40)].astype(np.int32) + n_samples_trace, n_channels = traces.shape + h = 3 - # With filter. 
- def my_filter(x): + def my_filter(x, axis=0): return x * x loader = WaveformLoader(traces, - n_samples=(n_samples // 2, n_samples // 2), + n_samples_waveforms=(h, h), filter=my_filter, - filter_margin=5) + filter_margin=2) - t = spike_samples[5] - waveform_filtered = loader._load_at(t) + t = 10 + waveform_filtered = loader[t] traces_filtered = my_filter(traces) - traces_filtered[t - 20:t + 20, :] - assert np.allclose(waveform_filtered, traces_filtered[t - 20:t + 20, :]) + assert np.allclose(waveform_filtered, traces_filtered[t - h:t + h, :]) diff --git a/phy/traces/waveform.py b/phy/traces/waveform.py index 4231bcee4..0ea3ad589 100644 --- a/phy/traces/waveform.py +++ b/phy/traces/waveform.py @@ -6,38 +6,21 @@ # Imports #------------------------------------------------------------------------------ +import logging + import numpy as np from scipy.interpolate import interp1d from ..utils._types import _as_array, Bunch -from ..utils.array import _pad -from ..utils.logging import warn +from phy.io.array import _pad, _get_padded + +logger = logging.getLogger(__name__) #------------------------------------------------------------------------------ # Waveform extractor from a connected component #------------------------------------------------------------------------------ -def _get_padded(data, start, end): - """Return `data[start:end]` filling in with zeros outside array bounds - - Assumes that either `start<0` or `end>len(data)` but not both. - - """ - if start < 0 and end >= data.shape[0]: - raise RuntimeError() - if start < 0: - start_zeros = np.zeros((-start, data.shape[1]), - dtype=data.dtype) - return np.vstack((start_zeros, data[:end])) - elif end > data.shape[0]: - end_zeros = np.zeros((end - data.shape[0], data.shape[1]), - dtype=data.dtype) - return np.vstack((data[start:], end_zeros)) - else: - return data[start:end] - - class WaveformExtractor(object): """Extract waveforms after data filtering and spike detection.""" def __init__(self, @@ -45,31 +28,15 @@ def __init__(self, extract_after=None, weight_power=None, thresholds=None, - channels_per_group=None, ): self._extract_before = extract_before self._extract_after = extract_after self._weight_power = weight_power if weight_power is not None else 1. self._thresholds = thresholds or {} - self._channels_per_group = channels_per_group - # mapping channel => channels in the shank - self._dep_channels = {i: channels - for channels in channels_per_group.values() - for i in channels} - self._channel_groups = {i: g - for g, channels in channels_per_group.items() - for i in channels} def _component(self, component, data=None, n_samples=None): comp_s = component[:, 0] # shape: (component_size,) comp_ch = component[:, 1] # shape: (component_size,) - channel = comp_ch[0] - if channel not in self._dep_channels: - raise RuntimeError("Channel `{}` appears to be dead and should " - "have been excluded from the threshold " - "crossings.".format(channel)) - channels = self._dep_channels[channel] - group = self._channel_groups[comp_ch[0]] # Get the temporal window around the waveform. 
s_min, s_max = (comp_s.min() - 3), (comp_s.max() + 4) @@ -81,8 +48,6 @@ def _component(self, component, data=None, n_samples=None): comp_ch=comp_ch, s_min=s_min, s_max=s_max, - channels=channels, - group=group, ) def _normalize(self, x): @@ -103,7 +68,6 @@ def _comp_wave(self, data_t, comp): def masks(self, data_t, wave, comp): nc = data_t.shape[1] - channels = comp.channels comp_ch = comp.comp_ch s_min = comp.s_min @@ -119,7 +83,6 @@ def masks(self, data_t, wave, comp): # Compute the float masks. masks_float = self._normalize(peaks_values) # Keep shank channels. - masks_float = masks_float[channels] return masks_float def spike_sample_aligned(self, wave, comp): @@ -132,13 +95,13 @@ def spike_sample_aligned(self, wave, comp): s_aligned = np.sum(wave_n_p * u) / np.sum(wave_n_p) + s_min return s_aligned - def extract(self, data, s_aligned, channels=None): + def extract(self, data, s_aligned): s = int(s_aligned) # Get block of given size around peak sample. waveform = _get_padded(data, s - self._extract_before - 1, s + self._extract_after + 2) - return waveform[:, channels] # Keep shank channels. + return waveform def align(self, waveform, s_aligned): s = int(s_aligned) @@ -149,8 +112,8 @@ def align(self, waveform, s_aligned): try: f = interp1d(old_s, waveform, bounds_error=True, kind='cubic', axis=0) - except ValueError: - warn("Interpolation error at time {0:d}".format(s)) + except ValueError: # pragma: no cover + logger.warn("Interpolation error at time %d", s) return waveform return f(new_s) @@ -163,20 +126,19 @@ def __call__(self, component=None, data=None, data_t=None): data=data, n_samples=data_t.shape[0], ) - channels = comp.channels wave = self._comp_wave(data_t, comp) masks = self.masks(data_t, wave, comp) s_aligned = self.spike_sample_aligned(wave, comp) - waveform_unaligned = self.extract(data, s_aligned, channels=channels) + waveform_unaligned = self.extract(data, s_aligned) waveform_aligned = self.align(waveform_unaligned, s_aligned) assert waveform_aligned.ndim == 2 assert masks.ndim == 1 assert waveform_aligned.shape[1] == masks.shape[0] - return comp.group, s_aligned, waveform_aligned, masks + return s_aligned, masks, waveform_aligned #------------------------------------------------------------------------------ @@ -224,17 +186,17 @@ def __init__(self, offset=0, filter=None, filter_margin=0, - n_samples=None, + n_samples_waveforms=None, channels=None, scale_factor=None, dc_offset=None, + dtype=None, ): - # A (possibly memmapped) array-like structure with traces. if traces is not None: self.traces = traces else: self._traces = None - self.dtype = np.float32 + self.dtype = dtype or (traces.dtype if traces is not None else None) # Scale factor for the loaded waveforms. self._scale_factor = scale_factor self._dc_offset = dc_offset @@ -242,14 +204,14 @@ def __init__(self, self._offset = int(offset) # List of channels to use when loading the waveforms. self._channels = channels - # A filter function that takes a (n_samples, n_channels) array as - # input. + # A filter function that takes a (n_samples_waveforms, n_channels) + # array as input. self._filter = filter # Number of samples to return, can be an int or a # tuple (before, after). 
- if n_samples is None: - raise ValueError("'n_samples' must be specified.") - self.n_samples_before_after = _before_after(n_samples) + if n_samples_waveforms is None: + raise ValueError("'n_samples_waveforms' must be specified.") + self.n_samples_before_after = _before_after(n_samples_waveforms) self.n_samples_waveforms = sum(self.n_samples_before_after) # Number of additional samples to use for filtering. self._filter_margin = _before_after(filter_margin) @@ -257,6 +219,18 @@ def __init__(self, self._n_samples_extract = (self.n_samples_waveforms + sum(self._filter_margin)) + @property + def offset(self): + return self._offset + + @property + def dc_offset(self): + return self._dc_offset + + @property + def scale_factor(self): + return self._scale_factor + @property def traces(self): """Raw traces.""" @@ -303,30 +277,13 @@ def _load_at(self, time): elif slice_extract.stop >= ns - 1: extract = _pad(extract, self._n_samples_extract, 'right') - assert extract.shape[0] == self._n_samples_extract - - # Filter the waveforms. - # TODO: do the filtering in a vectorized way for higher performance. - if self._filter is not None: - waveforms = self._filter(extract) - else: - waveforms = extract - - # Remove the margin. - margin_before, margin_after = self._filter_margin - if margin_after > 0: - assert margin_before >= 0 - waveforms = waveforms[margin_before:-margin_after, :] - # Make a subselection with the specified channels. if self._channels is not None: - out = waveforms[..., self._channels] - else: - out = waveforms + extract = extract[..., self._channels] - assert out.shape == (self.n_samples_waveforms, - self.n_channels_waveforms) - return out + assert extract.shape == (self._n_samples_extract, + self.n_channels_waveforms) + return extract def __getitem__(self, item): """Load waveforms.""" @@ -335,27 +292,46 @@ def __getitem__(self, item): "implemented yet.") if not hasattr(item, '__len__'): item = [item] + # Ensure a list of time samples are being requested. spikes = _as_array(item) n_spikes = len(spikes) + # Initialize the array. # TODO: int16 - shape = (n_spikes, self.n_samples_waveforms, - self.n_channels_waveforms) + shape = (n_spikes, self._n_samples_extract, self.n_channels_waveforms) + # No traces: return null arrays. if self.n_samples_trace == 0: return np.zeros(shape, dtype=self.dtype) - waveforms = np.empty(shape, dtype=self.dtype) + waveforms = np.zeros(shape, dtype=self.dtype) + # Load all spikes. for i, time in enumerate(spikes): try: waveforms[i, ...] = self._load_at(time) - except ValueError as e: - warn("Error while loading waveform: {0}".format(str(e))) + except ValueError as e: # pragma: no cover + logger.warn("Error while loading waveform: %s", str(e)) + + # Filter the waveforms. + if self._filter is not None: + waveforms = self._filter(waveforms, axis=1) + + # Remove the margin. + margin_before, margin_after = self._filter_margin + if margin_after > 0: + assert margin_before >= 0 + waveforms = waveforms[:, margin_before:-margin_after, :] + + # Transform. 
if self._dc_offset: waveforms -= self._dc_offset if self._scale_factor: waveforms *= self._scale_factor + + assert waveforms.shape == (n_spikes, self.n_samples_waveforms, + self.n_channels_waveforms) + return waveforms @@ -370,6 +346,7 @@ def __init__(self, waveforms, spike_samples): self.shape = (len(spike_samples), waveforms.n_samples_waveforms, waveforms.n_channels_waveforms) + self.ndim = len(self.shape) def __getitem__(self, item): times = self._spike_samples[item] diff --git a/phy/utils/__init__.py b/phy/utils/__init__.py index 66243eac6..2a49c9a45 100644 --- a/phy/utils/__init__.py +++ b/phy/utils/__init__.py @@ -3,8 +3,14 @@ """Utilities.""" -from ._types import _is_array_like, _as_array, _as_tuple, _as_list, Bunch -from .datasets import download_file, download_sample_data +from ._misc import _load_json, _save_json, _fullname +from ._types import (_is_array_like, _as_array, _as_tuple, _as_list, + _as_scalar, _as_scalars, + Bunch, _is_list, _bunchify) from .event import EventEmitter, ProgressReporter -from .logging import debug, info, warn, register, unregister, set_level -from .settings import Settings, _ensure_dir_exists +from .plugin import IPlugin, get_plugin +from .config import( _ensure_dir_exists, + load_master_config, + phy_config_dir, + load_config, + ) diff --git a/phy/utils/_color.py b/phy/utils/_color.py index ae0fcb24b..974779b57 100644 --- a/phy/utils/_color.py +++ b/phy/utils/_color.py @@ -7,19 +7,18 @@ #------------------------------------------------------------------------------ import numpy as np - -from random import uniform -from colorsys import hsv_to_rgb +from numpy.random import uniform +from matplotlib.colors import rgb_to_hsv, hsv_to_rgb #------------------------------------------------------------------------------ -# Colors +# Random colors #------------------------------------------------------------------------------ def _random_color(): """Generate a random RGB color.""" h, s, v = uniform(0., 1.), uniform(.5, 1.), uniform(.5, 1.) - r, g, b = hsv_to_rgb(h, s, v) + r, g, b = hsv_to_rgb((h, s, v)) return r, g, b @@ -39,7 +38,7 @@ def _random_bright_color(): #------------------------------------------------------------------------------ -# Default colormap +# Colormap #------------------------------------------------------------------------------ # Default color map for the selected clusters. @@ -49,6 +48,7 @@ def _random_bright_color(): [228, 31, 228], [2, 217, 2], [255, 147, 2], + [212, 150, 70], [205, 131, 201], [201, 172, 36], @@ -59,11 +59,56 @@ def _random_bright_color(): ]) -def _selected_clusters_colors(n_clusters=None): - if n_clusters is None: - n_clusters = _COLORMAP.shape[0] - if n_clusters > _COLORMAP.shape[0]: - colors = np.tile(_COLORMAP, (1 + n_clusters // _COLORMAP.shape[0], 1)) +def _apply_color_masks(color, masks=None, alpha=None): + alpha = alpha or .5 + hsv = rgb_to_hsv(color[:, :3]) + # Change the saturation and value as a function of the mask. + if masks is not None: + hsv[:, 1] *= masks + hsv[:, 2] *= .5 * (1. + masks) + color = hsv_to_rgb(hsv) + n = color.shape[0] + color = np.c_[color, alpha * np.ones((n, 1))] + return color + + +def _colormap(i): + n = len(_COLORMAP) + return _COLORMAP[i % n] / 255. + + +def _spike_colors(spike_clusters, masks=None, alpha=None): + n = len(_COLORMAP) + if spike_clusters is not None: + c = _COLORMAP[np.mod(spike_clusters, n), :] / 255. else: - colors = _COLORMAP - return colors[:n_clusters, ...] / 255. 
+ c = np.ones((masks.shape[0], 3)) + c = _apply_color_masks(c, masks=masks, alpha=alpha) + return c + + +class ColorSelector(object): + """Return the color of a cluster. + + If the cluster belongs to the selection, returns the colormap color. + + Otherwise, return a random color and remember this color. + + """ + def __init__(self): + self._colors = {} + + def get(self, clu, cluster_ids=None, alpha=None): + alpha = alpha or .5 + if cluster_ids and clu in cluster_ids: + i = cluster_ids.index(clu) + color = _colormap(i) + color = tuple(color) + (alpha,) + else: + if clu in self._colors: + return self._colors[clu] + color = _random_color() + color = tuple(color) + (alpha,) + self._colors[clu] = color + assert len(color) == 4 + return color diff --git a/phy/utils/_misc.py b/phy/utils/_misc.py index db70effbf..d10848689 100644 --- a/phy/utils/_misc.py +++ b/phy/utils/_misc.py @@ -11,34 +11,15 @@ import json import os.path as op import os -import sys import subprocess -from inspect import getargspec +from textwrap import dedent import numpy as np -from six import string_types, exec_ -from six.moves import builtins, cPickle +from six import string_types, text_type, exec_ from ._types import _is_integer -#------------------------------------------------------------------------------ -# Pickle utility functions -#------------------------------------------------------------------------------ - -def _load_pickle(path): - path = op.realpath(op.expanduser(path)) - assert op.exists(path) - with open(path, 'rb') as f: - return cPickle.load(f) - - -def _save_pickle(path, data): - path = op.realpath(op.expanduser(path)) - with open(path, 'wb') as f: - cPickle.dump(data, f, protocol=2) - - #------------------------------------------------------------------------------ # JSON utility functions #------------------------------------------------------------------------------ @@ -50,14 +31,15 @@ def _encode_qbytearray(arr): def _decode_qbytearray(data_b64): - from phy.gui.qt import QtCore + from phy.gui.qt import QByteArray encoded = base64.b64decode(data_b64) - out = QtCore.QByteArray.fromBase64(encoded) + out = QByteArray.fromBase64(encoded) return out class _CustomEncoder(json.JSONEncoder): def default(self, obj): + from phy.gui.qt import QString if isinstance(obj, np.ndarray): obj_contiguous = np.ascontiguousarray(obj) data_b64 = base64.b64encode(obj_contiguous.data).decode('utf8') @@ -68,7 +50,9 @@ def default(self, obj): return {'__qbytearray__': _encode_qbytearray(obj)} elif isinstance(obj, np.generic): return np.asscalar(obj) - return super(_CustomEncoder, self).default(obj) + elif isinstance(obj, QString): # pragma: no cover + return text_type(obj) + return super(_CustomEncoder, self).default(obj) # pragma: no cover def _json_custom_hook(d): @@ -117,13 +101,18 @@ def _save_json(path, data): data = _stringify_keys(data) path = op.realpath(op.expanduser(path)) with open(path, 'w') as f: - json.dump(data, f, cls=_CustomEncoder, indent=2) + json.dump(data, f, cls=_CustomEncoder, indent=2, sort_keys=True) #------------------------------------------------------------------------------ # Various Python utility functions #------------------------------------------------------------------------------ +def _fullname(o): + """Return the fully-qualified name of a function.""" + return o.__module__ + "." 
+ o.__name__ if o.__module__ else o.__name__ + + def _read_python(path): path = op.realpath(op.expanduser(path)) assert op.exists(path) @@ -135,57 +124,15 @@ def _read_python(path): return metadata -def _fun_arg_count(f): - """Return the number of arguments of a function. - - WARNING: with methods, only works if the first argument is named 'self'. - - """ - args = getargspec(f).args - if args and args[0] == 'self': - args = args[1:] - return len(args) - - -def _is_in_ipython(): - return '__IPYTHON__' in dir(builtins) - - -def _is_interactive(): - """Determine whether the user has requested interactive mode.""" - # The Python interpreter sets sys.flags correctly, so use them! - if sys.flags.interactive: - return True - - # IPython does not set sys.flags when -i is specified, so first - # check it if it is already imported. - if '__IPYTHON__' not in dir(builtins): - return False - - # Then we check the application singleton and determine based on - # a variable it sets. - try: - from IPython.config.application import Application as App - return App.initialized() and App.instance().interact - except (ImportError, AttributeError): - return False - - -def _show_shortcut(shortcut): - if isinstance(shortcut, string_types): - return shortcut - elif isinstance(shortcut, tuple): - return ', '.join(shortcut) - - -def _show_shortcuts(shortcuts, name=''): - print() - if name: - name = ' for ' + name - print('Keyboard shortcuts' + name) - for name in sorted(shortcuts): - print('{0:<40}: {1:s}'.format(name, _show_shortcut(shortcuts[name]))) - print() +def _write_text(path, contents): + contents = dedent(contents) + dir_path = op.dirname(path) + if not op.exists(dir_path): + os.mkdir(dir_path) + assert op.isdir(dir_path) + assert not op.exists(path) + with open(path, 'w') as f: + f.write(contents) def _git_version(): @@ -199,7 +146,7 @@ def _git_version(): '--always', '--tags'], stderr=fnull).strip().decode('ascii')) return version - except (OSError, subprocess.CalledProcessError): + except (OSError, subprocess.CalledProcessError): # pragma: no cover return "" finally: os.chdir(curdir) diff --git a/phy/utils/_types.py b/phy/utils/_types.py index 8b1f11fb1..72394744a 100644 --- a/phy/utils/_types.py +++ b/phy/utils/_types.py @@ -31,10 +31,31 @@ def copy(self): return Bunch(super(Bunch, self).copy()) +def _bunchify(b): + """Ensure all dict elements are Bunch.""" + assert isinstance(b, dict) + b = Bunch(b) + for k in b: + if isinstance(b[k], dict): + b[k] = Bunch(b[k]) + return b + + def _is_list(obj): return isinstance(obj, list) +def _as_scalar(obj): + if isinstance(obj, np.generic): + return np.asscalar(obj) + assert isinstance(obj, (int, float)) + return obj + + +def _as_scalars(arr): + return [_as_scalar(x) for x in arr] + + def _is_integer(x): return isinstance(x, integer_types + (np.generic,)) @@ -43,17 +64,14 @@ def _is_float(x): return isinstance(x, (float, np.float32, np.float64)) -def _as_int(x): - if isinstance(x, integer_types): - return x - x = np.asscalar(x) - return x - - def _as_list(obj): """Ensure an object is a list.""" - if isinstance(obj, string_types): + if obj is None: + return None + elif isinstance(obj, string_types): return [obj] + elif isinstance(obj, tuple): + return list(obj) elif not hasattr(obj, '__len__'): return [obj] else: @@ -70,6 +88,8 @@ def _as_array(arr, dtype=None): Avoid a copy if possible. 
""" + if arr is None: + return None if isinstance(arr, np.ndarray) and dtype is None: return arr if isinstance(arr, integer_types + (float,)): @@ -88,8 +108,6 @@ def _as_tuple(item): """Ensure an item is a tuple.""" if item is None: return None - # elif hasattr(item, '__len__'): - # return tuple(item) elif not isinstance(item, tuple): return (item,) else: diff --git a/phy/utils/array.py b/phy/utils/array.py deleted file mode 100644 index c0a8a5a64..000000000 --- a/phy/utils/array.py +++ /dev/null @@ -1,854 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Utility functions for NumPy arrays.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import os -import os.path as op -from math import floor -from operator import mul -from functools import reduce -import math - -import numpy as np -from six import integer_types, string_types - -from .logging import warn -from ._types import _as_tuple, _as_array - - -#------------------------------------------------------------------------------ -# Utility functions -#------------------------------------------------------------------------------ - -def _range_from_slice(myslice, start=None, stop=None, step=None, length=None): - """Convert a slice to an array of integers.""" - assert isinstance(myslice, slice) - # Find 'step'. - step = step if step is not None else myslice.step - if step is None: - step = 1 - # Find 'start'. - start = start if start is not None else myslice.start - if start is None: - start = 0 - # Find 'stop' as a function of length if 'stop' is unspecified. - stop = stop if stop is not None else myslice.stop - if length is not None: - stop_inferred = floor(start + step * length) - if stop is not None and stop < stop_inferred: - raise ValueError("'stop' ({stop}) and ".format(stop=stop) + - "'length' ({length}) ".format(length=length) + - "are not compatible.") - stop = stop_inferred - if stop is None and length is None: - raise ValueError("'stop' and 'length' cannot be both unspecified.") - myrange = np.arange(start, stop, step) - # Check the length if it was specified. - if length is not None: - assert len(myrange) == length - return myrange - - -def _unique(x): - """Faster version of np.unique(). - - This version is restricted to 1D arrays of non-negative integers. - - It is only faster if len(x) >> len(unique(x)). - - """ - if x is None or len(x) == 0: - return np.array([], dtype=np.int64) - # WARNING: only keep positive values. - # cluster=-1 means "unclustered". - x = _as_array(x) - x = x[x >= 0] - bc = np.bincount(x) - return np.nonzero(bc)[0] - - -def _ensure_unique(func): - """Apply unique() to the output of a function.""" - def wrapped(*args, **kwargs): - out = func(*args, **kwargs) - return _unique(out) - return wrapped - - -def _normalize(arr, keep_ratio=False): - """Normalize an array into [0, 1].""" - (x_min, y_min), (x_max, y_max) = arr.min(axis=0), arr.max(axis=0) - - if keep_ratio: - a = 1. / max(x_max - x_min, y_max - y_min) - ax = ay = a - bx = .5 - .5 * a * (x_max + x_min) - by = .5 - .5 * a * (y_max + y_min) - else: - ax = 1. / (x_max - x_min) - ay = 1. / (y_max - y_min) - bx = -x_min / (x_max - x_min) - by = -y_min / (y_max - y_min) - - arr_n = arr.copy() - arr_n[:, 0] *= ax - arr_n[:, 0] += bx - arr_n[:, 1] *= ay - arr_n[:, 1] += by - - return arr_n - - -def _index_of(arr, lookup): - """Replace scalars in an array by their indices in a lookup table. 
- - Implicitely assume that: - - * All elements of arr and lookup are non-negative integers. - * All elements or arr belong to lookup. - - This is not checked for performance reasons. - - """ - # Equivalent of np.digitize(arr, lookup) - 1, but much faster. - # TODO: assertions to disable in production for performance reasons. - m = (lookup.max() if len(lookup) else 0) + 1 - tmp = np.zeros(m + 1, dtype=np.int) - # Ensure that -1 values are kept. - tmp[-1] = -1 - if len(lookup): - tmp[lookup] = np.arange(len(lookup)) - return tmp[arr] - - -def index_of_small(arr, lookup): - """Faster on small arrays with large values.""" - return np.searchsorted(lookup, arr) - - -def _partial_shape(shape, trailing_index): - """Return the shape of a partial array.""" - if shape is None: - shape = () - if trailing_index is None: - trailing_index = () - trailing_index = _as_tuple(trailing_index) - # Length of the selection items for the partial array. - len_item = len(shape) - len(trailing_index) - # Array for the trailing dimensions. - _arr = np.empty(shape=shape[len_item:]) - try: - trailing_arr = _arr[trailing_index] - except IndexError: - raise ValueError("The partial shape index is invalid.") - return shape[:len_item] + trailing_arr.shape - - -def _pad(arr, n, dir='right'): - """Pad an array with zeros along the first axis. - - Parameters - ---------- - - n : int - Size of the returned array in the first axis. - dir : str - Direction of the padding. Must be one 'left' or 'right'. - - """ - assert dir in ('left', 'right') - if n < 0: - raise ValueError("'n' must be positive: {0}.".format(n)) - elif n == 0: - return np.zeros((0,) + arr.shape[1:], dtype=arr.dtype) - n_arr = arr.shape[0] - shape = (n,) + arr.shape[1:] - if n_arr == n: - assert arr.shape == shape - return arr - elif n_arr < n: - out = np.zeros(shape, dtype=arr.dtype) - if dir == 'left': - out[-n_arr:, ...] = arr - elif dir == 'right': - out[:n_arr, ...] = arr - assert out.shape == shape - return out - else: - if dir == 'left': - out = arr[-n:, ...] - elif dir == 'right': - out = arr[:n, ...] - assert out.shape == shape - return out - - -def _start_stop(item): - """Find the start and stop indices of a __getitem__ item. - - This is used only by ConcatenatedArrays. - - Only two cases are supported currently: - - * Single integer. - * Contiguous slice in the first dimension only. - - """ - if isinstance(item, tuple): - item = item[0] - if isinstance(item, slice): - # Slice. - if item.step not in (None, 1): - return NotImplementedError() - return item.start, item.stop - elif isinstance(item, (list, np.ndarray)): - # List or array of indices. - return np.min(item), np.max(item) - else: - # Integer. 
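# Illustrative sketch (editor's addition, not part of this patch): behavior of
# `_pad()` above, with the values used in the deleted `test_pad`:
#
#   arr = np.random.rand(10, 3)
#   _pad(arr, 12).shape == (12, 3)                       # zero-padded on the right by default
#   (_pad(arr, 12, 'left')[:2] == 0).all()               # zeros go on the left instead
#   np.array_equal(_pad(arr, 3, 'left'), arr[7:, :])     # truncation keeps the last rows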
- return item, item + 1 - - -def _len_index(item, max_len=0): - """Return the expected length of the output of __getitem__(item).""" - if isinstance(item, (list, np.ndarray)): - return len(item) - elif isinstance(item, slice): - stop = item.stop or max_len - start = item.start or 0 - step = item.step or 1 - start = np.clip(start, 0, stop) - assert 0 <= start <= stop - return 1 + ((stop - 1 - start) // step) - else: - return 1 - - -def _fill_index(arr, item): - if isinstance(item, tuple): - item = (slice(None, None, None),) + item[1:] - return arr[item] - else: - return arr - - -def _in_polygon(points, polygon): - """Return the points that are inside a polygon.""" - from matplotlib.path import Path - points = _as_array(points) - polygon = _as_array(polygon) - assert points.ndim == 2 - assert polygon.ndim == 2 - path = Path(polygon, closed=True) - return path.contains_points(points) - - -def _concatenate(arrs): - if arrs is None: - return - arrs = [_as_array(arr) for arr in arrs if arr is not None] - if not arrs: - return - return np.concatenate(arrs, axis=0) - - -# ----------------------------------------------------------------------------- -# I/O functions -# ----------------------------------------------------------------------------- - -def _prod(l): - return reduce(mul, l, 1) - - -class LazyMemmap(object): - """A memmapped array that only opens the file handle when required.""" - def __init__(self, path, dtype=None, shape=None, mode='r'): - assert isinstance(path, string_types) - assert dtype - self._path = path - self._f = None - self.mode = mode - self.dtype = dtype - self.shape = shape - self.ndim = len(shape) if shape else None - - def _open_file_if_needed(self): - if self._f is None: - self._f = np.memmap(self._path, - dtype=self.dtype, - shape=self.shape, - mode=self.mode, - ) - self.shape = self._f.shape - self.ndim = self._f.ndim - - def __getitem__(self, item): - self._open_file_if_needed() - return self._f[item] - - def __len__(self): - self._open_file_if_needed() - return len(self._f) - - def __del__(self): - if self._f is not None: - del self._f - - -def _memmap(f_or_path, dtype=None, shape=None, lazy=True): - if not lazy: - return np.memmap(f_or_path, dtype=dtype, shape=shape, mode='r') - else: - return LazyMemmap(f_or_path, dtype=dtype, shape=shape, mode='r') - - -def _file_size(f_or_path): - if isinstance(f_or_path, string_types): - with open(f_or_path, 'rb') as f: - return _file_size(f) - else: - return os.fstat(f_or_path.fileno()).st_size - - -def _load_ndarray(f_or_path, dtype=None, shape=None, memmap=False, lazy=False): - if dtype is None: - return f_or_path - else: - if not memmap: - arr = np.fromfile(f_or_path, dtype=dtype) - if shape is not None: - arr = arr.reshape(shape) - else: - # memmap doesn't accept -1 in shapes, but we can compute - # the missing dimension from the file size. - if shape and shape[0] == -1: - n_bytes = _file_size(f_or_path) - n_items = n_bytes // np.dtype(dtype).itemsize - n_rows = n_items // _prod(shape[1:]) - shape = (n_rows,) + shape[1:] - assert _prod(shape) == n_items - arr = _memmap(f_or_path, dtype=dtype, shape=shape, lazy=lazy) - return arr - - -def _save_arrays(path, arrays): - """Save multiple arrays in a single file by concatenating them along - the first axis. - - A second array is stored with the offsets. 
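# Illustrative sketch (editor's addition, not part of this patch): the storage scheme
# described above (concatenation plus offsets), written out with plain NumPy calls.
import numpy as np
arrays = [np.random.rand(n, 3) for n in (4, 2, 5)]
offsets = np.cumsum([a.shape[0] for a in arrays])     # [4, 6, 11], saved as `.offsets.npy`
concat = np.concatenate(arrays, axis=0)               # saved as `.npy`
loaded = np.split(concat, offsets[:-1], axis=0)       # what _load_arrays() reconstructs
assert all(np.array_equal(a, b) for a, b in zip(arrays, loaded))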
- - """ - assert path.endswith('.npy') - path = op.splitext(path)[0] - offsets = np.cumsum([arr.shape[0] for arr in arrays]) - if not len(arrays): - return - concat = np.concatenate(arrays, axis=0) - np.save(path + '.npy', concat) - np.save(path + '.offsets.npy', offsets) - - -def _load_arrays(path): - assert path.endswith('.npy') - if not op.exists(path): - return [] - path = op.splitext(path)[0] - concat = np.load(path + '.npy') - offsets = np.load(path + '.offsets.npy') - return np.split(concat, offsets[:-1], axis=0) - - -# ----------------------------------------------------------------------------- -# Chunking functions -# ----------------------------------------------------------------------------- - -def _excerpt_step(n_samples, n_excerpts=None, excerpt_size=None): - """Compute the step of an excerpt set as a function of the number - of excerpts or their sizes.""" - assert n_excerpts >= 2 - step = max((n_samples - excerpt_size) // (n_excerpts - 1), - excerpt_size) - return step - - -def chunk_bounds(n_samples, chunk_size, overlap=0): - """Return chunk bounds. - - Chunks have the form: - - [ overlap/2 | chunk_size-overlap | overlap/2 ] - s_start keep_start keep_end s_end - - Except for the first and last chunks which do not have a left/right - overlap. - - This generator yields (s_start, s_end, keep_start, keep_end). - - """ - s_start = 0 - s_end = chunk_size - keep_start = s_start - keep_end = s_end - overlap // 2 - yield s_start, s_end, keep_start, keep_end - - while s_end - overlap + chunk_size < n_samples: - s_start = s_end - overlap - s_end = s_start + chunk_size - keep_start = keep_end - keep_end = s_end - overlap // 2 - if s_start < s_end: - yield s_start, s_end, keep_start, keep_end - - s_start = s_end - overlap - s_end = n_samples - keep_start = keep_end - keep_end = s_end - if s_start < s_end: - yield s_start, s_end, keep_start, keep_end - - -def excerpts(n_samples, n_excerpts=None, excerpt_size=None): - """Yield (start, end) where start is included and end is excluded.""" - assert n_excerpts >= 2 - step = _excerpt_step(n_samples, - n_excerpts=n_excerpts, - excerpt_size=excerpt_size) - for i in range(n_excerpts): - start = i * step - if start >= n_samples: - break - end = min(start + excerpt_size, n_samples) - yield start, end - - -def data_chunk(data, chunk, with_overlap=False): - """Get a data chunk.""" - assert isinstance(chunk, tuple) - if len(chunk) == 2: - i, j = chunk - elif len(chunk) == 4: - if with_overlap: - i, j = chunk[:2] - else: - i, j = chunk[2:] - else: - raise ValueError("'chunk' should have 2 or 4 elements, " - "not {0:d}".format(len(chunk))) - return data[i:j, ...] - - -def get_excerpts(data, n_excerpts=None, excerpt_size=None): - assert n_excerpts is not None - assert excerpt_size is not None - if len(data) < n_excerpts * excerpt_size: - return data - elif n_excerpts == 0: - return data[:0] - elif n_excerpts == 1: - return data[:excerpt_size] - out = np.concatenate([data_chunk(data, chunk) - for chunk in excerpts(len(data), - n_excerpts=n_excerpts, - excerpt_size=excerpt_size)]) - assert len(out) <= n_excerpts * excerpt_size - return out - - -def regular_subset(spikes=None, n_spikes_max=None): - """Prune the current selection to get at most n_spikes_max spikes.""" - assert spikes is not None - # Nothing to do if the selection already satisfies n_spikes_max. - if n_spikes_max is None or len(spikes) <= n_spikes_max: - return spikes - step = math.ceil(np.clip(1. / n_spikes_max * len(spikes), - 1, len(spikes))) - step = int(step) - # Random shift. 
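# Illustrative sketch (editor's addition, not part of this patch): the bounds yielded
# by `chunk_bounds()` above, matching the deleted `test_chunk_bounds`:
#
#   list(chunk_bounds(200, 100, overlap=20)) ==
#       [(0, 100, 0, 90),          # (s_start, s_end, keep_start, keep_end)
#        (80, 180, 90, 170),
#        (160, 200, 170, 200)]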
- # start = np.random.randint(low=0, high=step) - # Note: randomly-changing selections are confusing... - start = 0 - my_spikes = spikes[start::step][:n_spikes_max] - assert len(my_spikes) <= len(spikes) - assert len(my_spikes) <= n_spikes_max - return my_spikes - - -# ----------------------------------------------------------------------------- -# Spike clusters utility functions -# ----------------------------------------------------------------------------- - -def _spikes_in_clusters(spike_clusters, clusters): - """Return the ids of all spikes belonging to the specified clusters.""" - if len(spike_clusters) == 0 or len(clusters) == 0: - return np.array([], dtype=np.int) - # spikes_per_cluster case. - if isinstance(spike_clusters, dict): - return np.sort(np.concatenate([spike_clusters[cluster] - for cluster in clusters])) - return np.nonzero(np.in1d(spike_clusters, clusters))[0] - - -def _spikes_per_cluster(spike_ids, spike_clusters): - """Return a dictionary {cluster: list_of_spikes}.""" - if not len(spike_ids): - return {} - rel_spikes = np.argsort(spike_clusters) - abs_spikes = spike_ids[rel_spikes] - spike_clusters = spike_clusters[rel_spikes] - - diff = np.empty_like(spike_clusters) - diff[0] = 1 - diff[1:] = np.diff(spike_clusters) - - idx = np.nonzero(diff > 0)[0] - clusters = spike_clusters[idx] - - spikes_in_clusters = {clusters[i]: np.sort(abs_spikes[idx[i]:idx[i + 1]]) - for i in range(len(clusters) - 1)} - spikes_in_clusters[clusters[-1]] = np.sort(abs_spikes[idx[-1]:]) - - return spikes_in_clusters - - -def _flatten_spikes_per_cluster(spikes_per_cluster): - """Convert a dictionary {cluster: list_of_spikes} to a - spike_clusters array.""" - clusters = sorted(spikes_per_cluster) - clusters_arr = np.concatenate([(cluster * - np.ones(len(spikes_per_cluster[cluster]))) - for cluster in clusters]).astype(np.int64) - spikes_arr = np.concatenate([spikes_per_cluster[cluster] - for cluster in clusters]) - spike_clusters = np.vstack((spikes_arr, clusters_arr)) - ind = np.argsort(spike_clusters[0, :]) - return spike_clusters[1, ind] - - -def _concatenate_per_cluster_arrays(spikes_per_cluster, arrays): - """Concatenate arrays from a {cluster: array} dictionary.""" - assert set(arrays) <= set(spikes_per_cluster) - clusters = sorted(arrays) - # Check the sizes of the spikes per cluster and the arrays. - n_0 = [len(spikes_per_cluster[cluster]) for cluster in clusters] - n_1 = [len(arrays[cluster]) for cluster in clusters] - assert n_0 == n_1 - - # Concatenate all spikes to find the right insertion order. - if not len(clusters): - return np.array([]) - - spikes = np.concatenate([spikes_per_cluster[cluster] - for cluster in clusters]) - idx = np.argsort(spikes) - # NOTE: concatenate all arrays along the first axis, because we assume - # that the first axis represents the spikes. - arrays = np.concatenate([_as_array(arrays[cluster]) - for cluster in clusters]) - return arrays[idx, ...] - - -def _subset_spc(spc, clusters): - return {c: s for c, s in spc.items() - if c in clusters} - - -class PerClusterData(object): - """Store data associated to every spike. - - This class provides several data structures, with per-spike data and - per-cluster data. It also defines a `subset()` method that allows to - make a subset of the data using either spikes or clusters. - - """ - def __init__(self, - spike_ids=None, array=None, spike_clusters=None, - spc=None, arrays=None): - if (array is not None and spike_ids is not None): - # From array to per-cluster arrays. 
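# Illustrative sketch (editor's addition, not part of this patch): round trip between
# a `spike_clusters` array and the `{cluster: spike_ids}` dictionary used above.
import numpy as np
spike_clusters = np.array([1, 0, 1, 1, 0])
spike_ids = np.arange(len(spike_clusters))
# _spikes_per_cluster(spike_ids, spike_clusters)
#     -> {0: array([1, 4]), 1: array([0, 2, 3])}
# _flatten_spikes_per_cluster({0: [1, 4], 1: [0, 2, 3]})
#     -> array([1, 0, 1, 1, 0])    # recovers the original spike_clusters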
- self._spike_ids = _as_array(spike_ids) - self._array = _as_array(array) - self._spike_clusters = _as_array(spike_clusters) - self._check_array() - self._split() - self._check_dict() - elif (arrays is not None and spc is not None): - # From per-cluster arrays to array. - self._spc = spc - self._arrays = arrays - self._check_dict() - self._concatenate() - self._check_array() - else: - raise ValueError() - - @property - def spike_ids(self): - """Sorted array of all spike ids.""" - return self._spike_ids - - @property - def spike_clusters(self): - """Array with the cluster id of every spike.""" - return self._spike_clusters - - @property - def array(self): - """Data array. - - The first dimension of the array corresponds to the spikes in the - cluster. - - """ - return self._array - - @property - def arrays(self): - """Dictionary of arrays `{cluster: array}`. - - The first dimension of the arrays correspond to the spikes in the - cluster. - - """ - return self._arrays - - @property - def spc(self): - """Spikes per cluster dictionary.""" - return self._spc - - @property - def cluster_ids(self): - """Sorted list of clusters.""" - return self._cluster_ids - - @property - def n_clusters(self): - return len(self._cluster_ids) - - def _check_dict(self): - assert set(self._arrays) == set(self._spc) - clusters = sorted(self._arrays) - n_0 = [len(self._spc[cluster]) for cluster in clusters] - n_1 = [len(self._arrays[cluster]) for cluster in clusters] - assert n_0 == n_1 - - def _check_array(self): - assert len(self._array) == len(self._spike_ids) - assert len(self._spike_clusters) == len(self._spike_ids) - - def _concatenate(self): - self._cluster_ids = sorted(self._spc) - n = len(self.cluster_ids) - if n == 0: - self._array = np.array([]) - self._spike_clusters = np.array([], dtype=np.int32) - self._spike_ids = np.array([], dtype=np.int64) - elif n == 1: - c = self.cluster_ids[0] - self._array = _as_array(self._arrays[c]) - self._spike_ids = self._spc[c] - self._spike_clusters = c * np.ones(len(self._spike_ids), - dtype=np.int32) - else: - # Concatenate all spikes to find the right insertion order. - spikes = np.concatenate([self._spc[cluster] - for cluster in self.cluster_ids]) - idx = np.argsort(spikes) - self._spike_ids = np.sort(spikes) - # NOTE: concatenate all arrays along the first axis, because we - # assume that the first axis represents the spikes. - # TODO OPTIM: use ConcatenatedArray and implement custom indices. - # array = ConcatenatedArrays([_as_array(self._arrays[cluster]) - # for cluster in self.cluster_ids]) - array = np.concatenate([_as_array(self._arrays[cluster]) - for cluster in self.cluster_ids]) - self._array = array[idx] - self._spike_clusters = _flatten_spikes_per_cluster(self._spc) - - def _split(self): - self._spc = _spikes_per_cluster(self._spike_ids, - self._spike_clusters) - self._cluster_ids = sorted(self._spc) - n = len(self.cluster_ids) - # Optimization for single cluster. - if n == 0: - self._arrays = {} - elif n == 1: - c = self._cluster_ids[0] - self._arrays = {c: self._array} - else: - self._arrays = {} - for cluster in sorted(self._cluster_ids): - spk = _as_array(self._spc[cluster]) - spk_rel = _index_of(spk, self._spike_ids) - self._arrays[cluster] = self._array[spk_rel] - - def subset(self, spike_ids=None, clusters=None, spc=None): - """Return a new PerClusterData instance with a subset of the data. 
- - There are three ways to specify the subset: - - * With a list of spikes - * With a list of clusters - * With a dictionary of `{cluster: some_spikes}` - - """ - if spike_ids is not None: - if np.array_equal(spike_ids, self._spike_ids): - return self - assert np.all(np.in1d(spike_ids, self._spike_ids)) - spike_ids_s_rel = _index_of(spike_ids, self._spike_ids) - array_s = self._array[spike_ids_s_rel] - spike_clusters_s = self._spike_clusters[spike_ids_s_rel] - return PerClusterData(spike_ids=spike_ids, - array=array_s, - spike_clusters=spike_clusters_s, - ) - elif clusters is not None: - assert set(clusters) <= set(self._cluster_ids) - spc_s = {clu: self._spc[clu] for clu in clusters} - arrays_s = {clu: self._arrays[clu] for clu in clusters} - return PerClusterData(spc=spc_s, arrays=arrays_s) - elif spc is not None: - clusters = sorted(spc) - assert set(clusters) <= set(self._cluster_ids) - arrays_s = {} - for cluster in clusters: - spk_rel = _index_of(_as_array(spc[cluster]), - _as_array(self._spc[cluster])) - arrays_s[cluster] = _as_array(self._arrays[cluster])[spk_rel] - return PerClusterData(spc=spc, arrays=arrays_s) - - -# ----------------------------------------------------------------------------- -# PartialArray -# ----------------------------------------------------------------------------- - -class PartialArray(object): - """Proxy to a view of an array, allowing selection along the first - dimensions and fixing the trailing dimensions.""" - def __init__(self, arr, trailing_index=None): - self._arr = arr - self._trailing_index = _as_tuple(trailing_index) - self.shape = _partial_shape(arr.shape, self._trailing_index) - self.dtype = arr.dtype - self.ndim = len(self.shape) - - def __getitem__(self, item): - if self._trailing_index is None: - return self._arr[item] - else: - item = _as_tuple(item) - k = len(item) - n = len(self._arr.shape) - t = len(self._trailing_index) - if k < (n - t): - item += (slice(None, None, None),) * (n - k - t) - item += self._trailing_index - if len(item) != n: - raise ValueError("The array selection is invalid: " - "{0}".format(str(item))) - return self._arr[item] - - def __len__(self): - return self.shape[0] - - -class ConcatenatedArrays(object): - """This object represents a concatenation of several memory-mapped - arrays.""" - def __init__(self, arrs): - assert isinstance(arrs, list) - self.arrs = arrs - self.offsets = np.concatenate([[0], np.cumsum([arr.shape[0] - for arr in arrs])], - axis=0) - self.dtype = arrs[0].dtype if arrs else None - self.shape = (self.offsets[-1],) + arrs[0].shape[1:] - - def _get_recording(self, index): - """Return the recording that contains a given index.""" - assert index >= 0 - recs = np.nonzero((index - self.offsets[:-1]) >= 0)[0] - if len(recs) == 0: - # If the index is greater than the total size, - # return the last recording. - return len(self.arrs) - 1 - # Return the last recording such that the index is greater than - # its offset. - return recs[-1] - - def __getitem__(self, item): - # Get the start and stop indices of the requested item. - start, stop = _start_stop(item) - # Return the concatenation of all arrays. - if start is None and stop is None: - return np.concatenate(self.arrs, axis=0) - if start is None: - start = 0 - if stop is None: - stop = self.offsets[-1] - if stop < 0: - stop = self.offsets[-1] + stop - # Get the recording indices of the first and last item. 
- rec_start = self._get_recording(start) - rec_stop = self._get_recording(stop) - assert 0 <= rec_start <= rec_stop < len(self.arrs) - # Find the start and stop relative to the arrays. - start_rel = start - self.offsets[rec_start] - stop_rel = stop - self.offsets[rec_stop] - # Single array case. - if rec_start == rec_stop: - # Apply the rest of the index. - return _fill_index(self.arrs[rec_start][start_rel:stop_rel], - item) - chunk_start = self.arrs[rec_start][start_rel:] - chunk_stop = self.arrs[rec_stop][:stop_rel] - # Concatenate all chunks. - l = [chunk_start] - if rec_stop - rec_start >= 2: - warn("Loading a full virtual array: this might be slow " - "and something might be wrong.") - l += [self.arrs[r][...] for r in range(rec_start + 1, - rec_stop)] - l += [chunk_stop] - # Apply the rest of the index. - return _fill_index(np.concatenate(l, axis=0), item) - - def __len__(self): - return self.shape[0] - - -class VirtualMappedArray(object): - """A virtual mapped array that yields null arrays to any selection.""" - def __init__(self, shape, dtype, fill=0): - self.shape = shape - self.dtype = dtype - self.ndim = len(self.shape) - self._fill = fill - - def __getitem__(self, item): - if isinstance(item, integer_types): - return self._fill * np.ones(self.shape[1:], dtype=self.dtype) - else: - assert not isinstance(item, tuple) - n = _len_index(item, max_len=self.shape[0]) - return self._fill * np.ones((n,) + self.shape[1:], - dtype=self.dtype) - - def __len__(self): - return self.shape[0] - - -def _concatenate_virtual_arrays(arrs): - """Return a virtual concatenate of several NumPy arrays.""" - n = len(arrs) - if n == 0: - return None - elif n == 1: - return arrs[0] - return ConcatenatedArrays(arrs) diff --git a/phy/utils/cli.py b/phy/utils/cli.py new file mode 100644 index 000000000..3841b8a1d --- /dev/null +++ b/phy/utils/cli.py @@ -0,0 +1,96 @@ +# -*- coding: utf-8 -*- +# flake8: noqa + +"""CLI tool.""" + + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +import logging +import os +import os.path as op +import sys +from traceback import format_exception + +import click + +from phy import (add_default_handler, DEBUG, _Formatter, _logger_fmt, + __version_git__, discover_plugins) +from phy.utils import _fullname + +logger = logging.getLogger(__name__) + + +#------------------------------------------------------------------------------ +# Set up logging with the CLI tool +#------------------------------------------------------------------------------ + +add_default_handler(level='DEBUG' if DEBUG else 'INFO') + + +def exceptionHandler(exception_type, exception, traceback): # pragma: no cover + logger.error("An error has occurred (%s): %s", + exception_type.__name__, exception) + logger.debug(''.join(format_exception(exception_type, + exception, + traceback))) + +# Only show traceback in debug mode (--debug). 
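# Illustrative sketch (editor's addition, not part of this patch): what a user plugin
# hooking into this CLI could look like. `HelloPlugin` and the `hello` subcommand are
# hypothetical examples, not something added by this PR.
import click
from phy import IPlugin

class HelloPlugin(IPlugin):
    def attach_to_cli(self, cli):
        @cli.command('hello')          # registers `phy hello`
        @click.argument('name')
        def hello(name):
            click.echo('Hello %s!' % name)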
+# if not DEBUG: +sys.excepthook = exceptionHandler + + +def _add_log_file(filename): + """Create a `phy.log` log file with DEBUG level in the + current directory.""" + handler = logging.FileHandler(filename) + + handler.setLevel(logging.DEBUG) + formatter = _Formatter(fmt=_logger_fmt, + datefmt='%Y-%m-%d %H:%M:%S') + handler.setFormatter(formatter) + logging.getLogger().addHandler(handler) + + +#------------------------------------------------------------------------------ +# CLI tool +#------------------------------------------------------------------------------ + +@click.group() +@click.version_option(version=__version_git__) +@click.help_option('-h', '--help') +@click.pass_context +def phy(ctx): + """By default, the `phy` command does nothing. Add subcommands with plugins + using `attach_to_cli()` and the `click` library.""" + + # Create a `phy.log` log file with DEBUG level in the current directory. + _add_log_file(op.join(os.getcwd(), 'phy.log')) + + +#------------------------------------------------------------------------------ +# CLI plugins +#------------------------------------------------------------------------------ + +def load_cli_plugins(cli, config_dir=None): + """Load all plugins and attach them to a CLI object.""" + from .config import load_master_config + + config = load_master_config(config_dir=config_dir) + plugins = discover_plugins(config.Plugins.dirs) + + for plugin in plugins: + if not hasattr(plugin, 'attach_to_cli'): # pragma: no cover + continue + logger.debug("Attach plugin `%s` to CLI.", _fullname(plugin)) + # NOTE: plugin is a class, so we need to instantiate it. + try: + plugin().attach_to_cli(cli) + except Exception as e: # pragma: no cover + logger.error("Error when loading plugin `%s`: %s", plugin, e) + + +# Load all plugins when importing this module. +load_cli_plugins(phy) diff --git a/phy/utils/config.py b/phy/utils/config.py new file mode 100644 index 000000000..c4b744433 --- /dev/null +++ b/phy/utils/config.py @@ -0,0 +1,91 @@ +# -*- coding: utf-8 -*- + +"""Config.""" + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +import logging +import os +import os.path as op +from textwrap import dedent + +from traitlets.config import (Config, + PyFileConfigLoader, + JSONFileConfigLoader, + ) + +logger = logging.getLogger(__name__) + + +#------------------------------------------------------------------------------ +# Config +#------------------------------------------------------------------------------ + +def phy_config_dir(): + """Return the absolute path to the phy user directory.""" + return op.expanduser('~/.phy/') + + +def _ensure_dir_exists(path): + """Ensure a directory exists.""" + if not op.exists(path): + os.makedirs(path) + assert op.exists(path) and op.isdir(path) + + +def load_config(path): + """Load a Python or JSON config file.""" + if not op.exists(path): + return Config() + path = op.realpath(path) + dirpath, filename = op.split(path) + file_ext = op.splitext(path)[1] + logger.debug("Load config file `%s`.", path) + if file_ext == '.py': + config = PyFileConfigLoader(filename, dirpath, + log=logger).load_config() + elif file_ext == '.json': + config = JSONFileConfigLoader(filename, dirpath, + log=logger).load_config() + return config + + +def _default_config(config_dir=None): + path = op.join(config_dir or op.join('~', '.phy'), 'plugins/') + return dedent(""" + # You can also put your plugins in ~/.phy/plugins/. 
+ + from phy import IPlugin + + class MyPlugin(IPlugin): + def attach_to_cli(self, cli): + pass + + + c = get_config() + c.Plugins.dirs = [r'{}'] + """.format(path)) + + +def load_master_config(config_dir=None): + """Load a master Config file from `~/.phy/phy_config.py`.""" + config_dir = config_dir or phy_config_dir() + path = op.join(config_dir, 'phy_config.py') + # Create a default config file if necessary. + if not op.exists(path): + _ensure_dir_exists(op.dirname(path)) + logger.debug("Creating default phy config file at `%s`.", path) + with open(path, 'w') as f: + f.write(_default_config(config_dir=config_dir)) + assert op.exists(path) + return load_config(path) + + +def save_config(path, config): + """Save a config object to a JSON file.""" + import json + config['version'] = 1 + with open(path, 'w') as f: + json.dump(config, f) diff --git a/phy/utils/event.py b/phy/utils/event.py index 47250a561..b04ca2a6a 100644 --- a/phy/utils/event.py +++ b/phy/utils/event.py @@ -11,7 +11,6 @@ import re from collections import defaultdict from functools import partial -from inspect import getargspec #------------------------------------------------------------------------------ @@ -44,9 +43,9 @@ def on_my_event(arg, key=None): """ def __init__(self): - self.reset() + self._reset() - def reset(self): + def _reset(self): """Remove all registered callbacks.""" self._callbacks = defaultdict(list) @@ -112,19 +111,13 @@ def emit(self, event, *args, **kwargs): """Call all callback functions registered with an event. Any positional and keyword arguments can be passed here, and they will - be fowarded to the callback functions. + be forwarded to the callback functions. Return the list of callback return results. """ res = [] for callback in self._callbacks.get(event, []): - argspec = getargspec(callback) - if not argspec.keywords: - # Only keep the kwargs that are part of the callback's - # arg spec, unless the callback accepts `**kwargs`. 
- kwargs = {n: v for n, v in kwargs.items() - if n in argspec.args} res.append(callback(*args, **kwargs)) return res @@ -153,7 +146,7 @@ def format_field(self, value, spec): def _default_on_progress(message, value, value_max, end='\r', **kwargs): - if value_max == 0: + if value_max == 0: # pragma: no cover return if value <= value_max: progress = 100 * value / float(value_max) @@ -243,7 +236,6 @@ def increment(self, **kwargs): def reset(self, value_max=None): """Reset the value to 0 and the value max to a given value.""" - super(ProgressReporter, self).reset() self._value = 0 if value_max is not None: self._value_max = value_max diff --git a/phy/utils/logging.py b/phy/utils/logging.py deleted file mode 100644 index 2608de6ce..000000000 --- a/phy/utils/logging.py +++ /dev/null @@ -1,210 +0,0 @@ -from __future__ import absolute_import - -"""Logger utility classes and functions.""" - - -# ----------------------------------------------------------------------------- -# Imports -# ----------------------------------------------------------------------------- - -import os.path as op -import sys -import logging - -from six import iteritems, string_types - - -# ----------------------------------------------------------------------------- -# Stream classes -# ----------------------------------------------------------------------------- - -class StringStream(object): - """Logger stream used to store all logs in a string.""" - def __init__(self): - self.string = "" - - def write(self, line): - self.string += line - - def flush(self): - pass - - def __repr__(self): - return self.string - - -# ----------------------------------------------------------------------------- -# Custom formatter -# ----------------------------------------------------------------------------- - -_LONG_FORMAT = ('%(asctime)s [%(levelname)s] ' - '%(caller)s %(message)s', - '%Y-%m-%d %H:%M:%S') -_SHORT_FORMAT = ('%(asctime)s [%(levelname)s] %(message)s', - '%H:%M:%S') - - -class Formatter(logging.Formatter): - def format(self, record): - # Only keep the first character in the level name. - record.levelname = record.levelname[0] - filename = op.splitext(op.basename(record.pathname))[0] - record.caller = '{:s}:{:d}'.format(filename, record.lineno).ljust(16) - return super(Formatter, self).format(record) - - -# ----------------------------------------------------------------------------- -# Logger classes -# ----------------------------------------------------------------------------- - -class Logger(object): - """Save logging information to a stream.""" - def __init__(self, - fmt=None, - stream=None, - level=None, - name=None, - handler=None, - ): - if stream is None: - stream = sys.stdout - if name is None: - name = self.__class__.__name__ - self.name = name - if handler is None: - self.stream = stream - self.handler = logging.StreamHandler(self.stream) - else: - self.handler = handler - self.level = level - self.fmt = fmt - # Set the level and corresponding formatter. - self.set_level(level, fmt) - - def set_level(self, level=None, fmt=None): - # Default level and format. - if level is None: - level = self.level or logging.INFO - if isinstance(level, string_types): - level = getattr(logging, level.upper()) - fmt = fmt or self.fmt - # Create the Logger object. - self._logger = logging.getLogger(self.name) - # Create the formatter. 
- if fmt is None: - fmt, datefmt = (_LONG_FORMAT if level == logging.DEBUG - else _SHORT_FORMAT) - else: - datefmt = None - formatter = Formatter(fmt=fmt, datefmt=datefmt) - self.handler.setFormatter(formatter) - # Configure the logger. - self._logger.setLevel(level) - self._logger.propagate = False - self._logger.addHandler(self.handler) - - def close(self): - pass - - def debug(self, msg): - self._logger.debug(msg) - - def info(self, msg): - self._logger.info(msg) - - def warn(self, msg): - self._logger.warn(msg) - - -class StringLogger(Logger): - """Log to a string.""" - def __init__(self, **kwargs): - kwargs['stream'] = StringStream() - super(StringLogger, self).__init__(**kwargs) - - def __repr__(self): - return self.stream.__repr__() - - -class ConsoleLogger(Logger): - """Log to the standard output.""" - def __init__(self, **kwargs): - kwargs['stream'] = sys.stdout - super(ConsoleLogger, self).__init__(**kwargs) - - -class FileLogger(Logger): - """Log to a file.""" - def __init__(self, filename=None, **kwargs): - kwargs['handler'] = logging.FileHandler(filename) - super(FileLogger, self).__init__(**kwargs) - - def close(self): - self.handler.close() - self._logger.removeHandler(self.handler) - del self.handler - del self._logger - - -# ----------------------------------------------------------------------------- -# Global variables -# ----------------------------------------------------------------------------- - -LOGGERS = {} - - -def register(logger): - """Register a logger.""" - name = logger.name - if name not in LOGGERS: - LOGGERS[name] = logger - - -def unregister(logger): - """Unregister a logger.""" - name = logger.name - if name in LOGGERS: - LOGGERS[name].close() - del LOGGERS[name] - - -def _log(level, *msg): - if isinstance(msg, tuple): - msg = ' '.join(str(_) for _ in msg) - for name, logger in iteritems(LOGGERS): - getattr(logger, level)(msg) - - -def debug(*msg): - """Generate a debug message.""" - _log('debug', *msg) - - -def info(*msg): - """Generate an info message.""" - _log('info', *msg) - - -def warn(*msg): - """Generate a warning.""" - _log('warn', *msg) - - -def set_level(level): - """Set the level of all registered loggers. - - Parameters - ---------- - - level : str - Can be `warn`, `info`, or `debug`. - - """ - for name, logger in iteritems(LOGGERS): - logger.set_level(level) - - -def _default_logger(level='info'): - """Create a default logger in `info` mode by default.""" - register(ConsoleLogger()) - set_level(level) diff --git a/phy/utils/plugin.py b/phy/utils/plugin.py new file mode 100644 index 000000000..b38214d7d --- /dev/null +++ b/phy/utils/plugin.py @@ -0,0 +1,112 @@ +# -*- coding: utf-8 -*- + +"""Plugin system. 
+ +Code from http://eli.thegreenplace.net/2012/08/07/fundamental-concepts-of-plugin-infrastructures # noqa + +""" + + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +import imp +import logging +import os +import os.path as op + +from six import with_metaclass + +from ._misc import _fullname + +logger = logging.getLogger(__name__) + + +#------------------------------------------------------------------------------ +# IPlugin interface +#------------------------------------------------------------------------------ + +class IPluginRegistry(type): + plugins = [] + + def __init__(cls, name, bases, attrs): + if name != 'IPlugin': + logger.debug("Register plugin `%s`.", _fullname(cls)) + if _fullname(cls) not in (_fullname(_) + for _ in IPluginRegistry.plugins): + IPluginRegistry.plugins.append(cls) + + +class IPlugin(with_metaclass(IPluginRegistry)): + """A class deriving from IPlugin can implement the following methods: + + * `attach_to_cli(cli)`: called when the CLI is created. + + """ + pass + + +def get_plugin(name): + """Get a plugin class from its name.""" + for plugin in IPluginRegistry.plugins: + if name in plugin.__name__: + return plugin + raise ValueError("The plugin %s cannot be found." % name) + + +#------------------------------------------------------------------------------ +# Plugins discovery +#------------------------------------------------------------------------------ + +def _iter_plugin_files(dirs): + for plugin_dir in dirs: + plugin_dir = op.realpath(op.expanduser(plugin_dir)) + if not op.exists(plugin_dir): + continue + for subdir, dirs, files in os.walk(plugin_dir): + # Skip test folders. + base = op.basename(subdir) + if 'test' in base or '__' in base: # pragma: no cover + continue + logger.debug("Scanning `%s`.", subdir) + for filename in files: + if (filename.startswith('__') or + not filename.endswith('.py')): + continue # pragma: no cover + logger.debug("Found plugin module `%s`.", filename) + yield op.join(subdir, filename) + + +def discover_plugins(dirs): + """Discover the plugin classes contained in Python files. + + Parameters + ---------- + + dirs : list + List of directory names to scan. + + Returns + ------- + + plugins : list + List of plugin classes. + + """ + # Scan all subdirectories recursively. + for path in _iter_plugin_files(dirs): + filename = op.basename(path) + subdir = op.dirname(path) + modname, ext = op.splitext(filename) + file, path, descr = imp.find_module(modname, [subdir]) + if file: + # Loading the module registers the plugin in + # IPluginRegistry. 
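# Illustrative sketch (editor's addition, not part of this patch): defining a subclass
# of `IPlugin` registers it automatically, and it can then be retrieved by name.
# `MyPlugin` is a hypothetical example class.
from phy import IPlugin
from phy.utils.plugin import get_plugin

class MyPlugin(IPlugin):
    pass

assert get_plugin('MyPlugin') is MyPlugin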
+ try: + mod = imp.load_module(modname, file, path, descr) # noqa + except Exception as e: # pragma: no cover + logger.exception(e) + finally: + file.close() + return IPluginRegistry.plugins diff --git a/phy/utils/selector.py b/phy/utils/selector.py deleted file mode 100644 index 891fd3d21..000000000 --- a/phy/utils/selector.py +++ /dev/null @@ -1,214 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Selector structure.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import numpy as np - -from ._types import _as_array, _as_list -from .array import (regular_subset, - get_excerpts, - _unique, - _ensure_unique, - _spikes_in_clusters, - _spikes_per_cluster, - ) - - -#------------------------------------------------------------------------------ -# Utility functions -#------------------------------------------------------------------------------ - -def _concat(l): - if not len(l): - return np.array([], dtype=np.int64) - return np.sort(np.hstack(l)) - - -#------------------------------------------------------------------------------ -# Selector class -#------------------------------------------------------------------------------ - -class Selector(object): - """Object representing a selection of spikes or clusters.""" - def __init__(self, spike_clusters, - n_spikes_max=None, - excerpt_size=None, - ): - self._spike_clusters = spike_clusters - self._n_spikes_max = n_spikes_max - self._n_spikes = (len(spike_clusters) - if spike_clusters is not None else 0) - self._excerpt_size = excerpt_size - self._selected_spikes = np.array([], dtype=np.int64) - self._selected_clusters = None - - @property - def n_spikes_max(self): - """Maximum number of spikes allowed in the selection.""" - return self._n_spikes_max - - @n_spikes_max.setter - def n_spikes_max(self, value): - """Change the maximum number of spikes allowed.""" - self._n_spikes_max = value - # Update the selected spikes accordingly. - self.selected_spikes = self.subset_spikes() - if self._n_spikes_max is not None: - assert len(self._selected_spikes) <= self._n_spikes_max - - @property - def excerpt_size(self): - """Maximum number of spikes allowed in the selection.""" - return self._excerpt_size - - @excerpt_size.setter - def excerpt_size(self, value): - """Change the excerpt size.""" - self._excerpt_size = value - # Update the selected spikes accordingly. - self.selected_spikes = self.subset_spikes() - - @_ensure_unique - def subset_spikes(self, - spikes=None, - n_spikes_max=None, - excerpt_size=None, - ): - """Prune the current selection to get at most `n_spikes_max` spikes. - - Parameters - ---------- - - spikes : array-like - Array of spike ids to subset from. By default, this is - `selector.selected_spikes`. - n_spikes_max : int or None - Maximum number of spikes allowed in the selection. - excerpt_size : int or None - If None, the method returns a regular strided selection. - Otherwise, returns a regular selection of contiguous chunks - with the specified chunk size. - - """ - # Default arguments. - if spikes is None: - spikes = self._selected_spikes - if spikes is None or len(spikes) == 0: - return spikes - if n_spikes_max is None: - n_spikes_max = self._n_spikes_max or len(spikes) - if excerpt_size is None: - excerpt_size = self._excerpt_size - # Nothing to do if there are less spikes than the maximum number. - if len(spikes) <= n_spikes_max: - return spikes - # Take a regular or chunked subset of the spikes. 
- if excerpt_size is None: - return regular_subset(spikes, n_spikes_max) - else: - n_excerpts = n_spikes_max // excerpt_size - return get_excerpts(spikes, - n_excerpts=n_excerpts, - excerpt_size=excerpt_size, - ) - - def subset_spikes_clusters(self, clusters, - n_spikes_max=None, - excerpt_size=None, - ): - """Take a subselection of spikes belonging to a set of clusters. - - This method ensures that the same number of spikes is chosen - for every spike. - - `n_spikes_max` is the maximum number of spikers *per cluster*. - - """ - if not len(clusters): - return {} - # Get the selection parameters. - if n_spikes_max is None: - n_spikes_max = self._n_spikes_max or self._n_spikes - if excerpt_size is None: - excerpt_size = self._excerpt_size - # Take all spikes from the selected clusters. - spikes = _spikes_in_clusters(self._spike_clusters, clusters) - if not len(spikes): - return {} - # Group the spikes per cluster. - spc = _spikes_per_cluster(spikes, self._spike_clusters[spikes]) - # Do nothing if there are less spikes than the maximum number. - if len(spikes) <= n_spikes_max: - return spc - # Take a regular or chunked subset of the spikes. - if excerpt_size is None: - return {cluster: regular_subset(spc[cluster], n_spikes_max) - for cluster in clusters} - else: - n_excerpts = n_spikes_max // excerpt_size - return {cluster: get_excerpts(spc[cluster], - n_excerpts=n_excerpts, - excerpt_size=excerpt_size, - ) - for cluster in clusters} - - @property - def selected_spikes(self): - """Ids of the selected spikes.""" - return self._selected_spikes - - @selected_spikes.setter - def selected_spikes(self, value): - """Explicitely select some spikes. - - The selection is automatically pruned to ensure that less than - `n_spikes_max` spikes are selected. - - """ - value = _as_array(value) - # Make sure there are less spikes than n_spikes_max. - self._selected_spikes = self.subset_spikes(value) - - @property - def selected_clusters(self): - """Cluster ids appearing in the current spike selection.""" - if self._selected_clusters is not None: - return self._selected_clusters - clusters = _unique(self._spike_clusters[self._selected_spikes]) - return clusters - - @selected_clusters.setter - def selected_clusters(self, value): - """Select some clusters. - - This will select less than `n_spikes_max` spikes belonging to - those clusters. - - """ - self._selected_clusters = _as_list(value) - value = _as_array(value) - # Make sure there are less spikes than n_spikes_max. - spk = self.subset_spikes_clusters(value) - self._selected_spikes = _concat(spk.values()) - - @property - def n_spikes(self): - return len(self._selected_spikes) - - @property - def n_clusters(self): - return len(self.selected_clusters) - - def on_cluster(self, up=None): - """Callback method called when the clustering has changed. - - This currently does nothing, i.e. the spike selection remains - unchanged when merges and splits occur. 
- - """ - # TODO - pass diff --git a/phy/utils/settings.py b/phy/utils/settings.py deleted file mode 100644 index dad69b24f..000000000 --- a/phy/utils/settings.py +++ /dev/null @@ -1,239 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Settings.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import os -import os.path as op - -from .logging import debug, warn -from ._misc import _load_json, _save_json, _read_python - - -#------------------------------------------------------------------------------ -# Settings -#------------------------------------------------------------------------------ - -def _create_empty_settings(path): - """Create an empty Python file if the path doesn't exist.""" - # If the file exists of if it is an internal settings file, skip. - if op.exists(path) or op.splitext(path)[1] == '': - return - debug("Creating empty settings file: {}.".format(path)) - with open(path, 'a') as f: - f.write("# Settings file. Refer to phy's documentation " - "for more details.\n") - - -def _recursive_dirs(): - """Yield all subdirectories paths in phy's package.""" - phy_root = op.join(op.realpath(op.dirname(__file__)), '../') - for root, dirs, files in os.walk(phy_root): - root = op.realpath(root) - root = op.relpath(root, phy_root) - if ('.' in root or '_' in root or 'tests' in root or - 'static' in root or 'glsl' in root): - continue - yield op.realpath(op.join(phy_root, root)) - - -def _default_settings_paths(): - return [op.join(dir, 'default_settings.py') - for dir in _recursive_dirs()] - - -def _load_default_settings(paths=None): - """Load all default settings in phy's package.""" - if paths is None: - paths = _default_settings_paths() - settings = BaseSettings() - for path in paths: - if op.exists(path): - settings.load(path) - return settings - - -class BaseSettings(object): - """Store key-value pairs.""" - def __init__(self): - self._store = {} - self._to_save = {} - - def __getitem__(self, key): - return self._store[key] - - def __setitem__(self, key, value): - self._store[key] = value - self._to_save[key] = value - - def __contains__(self, key): - return key in self._store - - def __repr__(self): - return self._store.__repr__() - - def keys(self): - """List of settings keys.""" - return self._store.keys() - - def _update(self, d): - for k, v in d.items(): - if isinstance(v, dict) and k in self._store: - # Update instead of overwrite settings dictionaries. - self._store[k].update(v) - else: - self._store[k] = v - - def _try_load_json(self, path): - try: - self._update(_load_json(path)) - # debug("Loaded internal settings file " - # "from `{}`.".format(path)) - return True - except Exception: - return False - - def _try_load_python(self, path): - try: - self._update(_read_python(path)) - # debug("Loaded internal settings file " - # "from `{}`.".format(path)) - return True - except Exception: - return False - - def load(self, path): - """Load a settings file.""" - path = op.realpath(path) - if not op.exists(path): - debug("Settings file `{}` doesn't exist.".format(path)) - return - # Try JSON first, then Python. - has_ext = op.splitext(path)[1] != '' - if not has_ext: - if self._try_load_json(path): - return - if not self._try_load_python(path): - warn("Unable to read '{}'. 
".format(path) + - "Please try to delete this file.") - - def save(self, path): - """Save the settings to a JSON file.""" - path = op.realpath(path) - try: - _save_json(path, self._to_save) - debug("Saved internal settings file " - "to `{}`.".format(path)) - except Exception as e: - warn("Unable to save the internal settings file " - "to `{}`:\n{}".format(path, str(e))) - self._to_save = {} - - -class Settings(object): - """Manage default, user-wide, and experiment-wide settings.""" - - def __init__(self, phy_user_dir=None, default_paths=None): - self.phy_user_dir = phy_user_dir - if self.phy_user_dir: - _ensure_dir_exists(self.phy_user_dir) - self._default_paths = default_paths or _default_settings_paths() - self._bs = BaseSettings() - self._load_user_settings() - - def _load_user_settings(self): - # Load phy's defaults. - if self._default_paths: - for path in self._default_paths: - if op.exists(path): - self._bs.load(path) - - if not self.phy_user_dir: - return - - # User settings. - self.user_settings_path = op.join(self.phy_user_dir, - 'user_settings.py') - - # Create empty settings path if necessary. - _create_empty_settings(self.user_settings_path) - - self._bs.load(self.user_settings_path) - - # Load the user's internal settings. - self.internal_settings_path = op.join(self.phy_user_dir, - 'internal_settings') - self._bs.load(self.internal_settings_path) - - def on_open(self, path): - """Initialize settings when loading an experiment.""" - if path is None: - debug("Unable to initialize the settings for unspecified " - "model path.") - return - # Get the experiment settings path. - path = op.realpath(op.expanduser(path)) - self.exp_path = path - self.exp_name = op.splitext(op.basename(path))[0] - self.exp_dir = op.dirname(path) - self.exp_settings_dir = op.join(self.exp_dir, self.exp_name + '.phy') - - self.exp_settings_path = op.join(self.exp_settings_dir, - 'user_settings.py') - _ensure_dir_exists(self.exp_settings_dir) - - # Create empty settings path if necessary. - _create_empty_settings(self.exp_settings_path) - - # Load experiment-wide settings. 
- self._load_user_settings() - self._bs.load(self.exp_settings_path) - - def save(self): - """Save settings to an internal settings file.""" - self._bs.save(self.internal_settings_path) - - def get(self, key, default=None): - """Return a settings value.""" - if key in self: - return self[key] - else: - return default - - def __getitem__(self, key): - return self._bs[key] - - def __setitem__(self, key, value): - self._bs[key] = value - - def __contains__(self, key): - return key in self._bs - - def __repr__(self): - return "".format(self._bs.__repr__()) - - def keys(self): - """Return the list of settings keys.""" - return self._bs.keys() - - -#------------------------------------------------------------------------------ -# Config -#------------------------------------------------------------------------------ - -_PHY_USER_DIR_NAME = '.phy' - - -def _phy_user_dir(): - """Return the absolute path to the phy user directory.""" - home = op.expanduser("~") - path = op.realpath(op.join(home, _PHY_USER_DIR_NAME)) - return path - - -def _ensure_dir_exists(path): - if not op.exists(path): - os.makedirs(path) diff --git a/phy/utils/testing.py b/phy/utils/testing.py index e5bd2b3d4..375b1d51a 100644 --- a/phy/utils/testing.py +++ b/phy/utils/testing.py @@ -6,13 +6,15 @@ # Imports #------------------------------------------------------------------------------ -import sys -import time from contextlib import contextmanager -from timeit import default_timer from cProfile import Profile -import os.path as op import functools +import logging +import os +import os.path as op +import sys +import time +from timeit import default_timer from numpy.testing import assert_array_equal as ae from numpy.testing import assert_allclose as ac @@ -20,8 +22,9 @@ from six.moves import builtins from ._types import _is_array_like -from .logging import info -from .settings import _ensure_dir_exists +from .config import _ensure_dir_exists + +logger = logging.getLogger(__name__) #------------------------------------------------------------------------------ @@ -39,6 +42,23 @@ def captured_output(): sys.stdout, sys.stderr = old_out, old_err +@contextmanager +def captured_logging(name=None): + buffer = StringIO() + logger = logging.getLogger(name) + handlers = logger.handlers + for handler in logger.handlers: + logger.removeHandler(handler) + handler = logging.StreamHandler(buffer) + handler.setLevel(logging.DEBUG) + logger.addHandler(handler) + yield buffer + buffer.flush() + logger.removeHandler(handler) + for handler in handlers: + logger.addHandler(handler) + + def _assert_equal(d_0, d_1): """Check that two objects are equal.""" # Compare arrays. @@ -49,10 +69,9 @@ def _assert_equal(d_0, d_1): ac(d_0, d_1) # Compare dicts recursively. elif isinstance(d_0, dict): - assert sorted(d_0) == sorted(d_1) - for (k_0, k_1) in zip(sorted(d_0), sorted(d_1)): - assert k_0 == k_1 - _assert_equal(d_0[k_0], d_1[k_1]) + assert set(d_0) == set(d_1) + for k_0 in d_0: + _assert_equal(d_0[k_0], d_1[k_0]) else: # General comparison. assert d_0 == d_1 @@ -67,10 +86,10 @@ def benchmark(name='', repeats=1): start = default_timer() yield duration = (default_timer() - start) * 1000. 
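# Illustrative sketch (editor's addition, not part of this patch): how the new
# `captured_logging()` context manager above can be used in a test.
import logging
from phy.utils.testing import captured_logging

logger = logging.getLogger('phy')
with captured_logging('phy') as buf:
    logger.warning("something happened")
assert "something happened" in buf.getvalue()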
- info("{} took {:.6f}ms.".format(name, duration / repeats)) + logger.info("%s took %.6fms.", name, duration / repeats) -class ContextualProfile(Profile): +class ContextualProfile(Profile): # pragma: no cover def __init__(self, *args, **kwds): super(ContextualProfile, self).__init__(*args, **kwds) self.enable_count = 0 @@ -116,7 +135,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.disable_by_count() -def _enable_profiler(line_by_line=False): +def _enable_profiler(line_by_line=False): # pragma: no cover if 'profile' in builtins.__dict__: return builtins.__dict__['profile'] if line_by_line: @@ -136,13 +155,13 @@ def _profile(prof, statement, glob, loc): # Capture stdout. old_stdout = sys.stdout sys.stdout = output = StringIO() - try: + try: # pragma: no cover from line_profiler import LineProfiler if isinstance(prof, LineProfiler): prof.print_stats() else: prof.print_stats('cumulative') - except ImportError: + except ImportError: # pragma: no cover prof.print_stats('cumulative') sys.stdout = old_stdout stats = output.getvalue() @@ -156,47 +175,21 @@ def _profile(prof, statement, glob, loc): # Testing VisPy canvas #------------------------------------------------------------------------------ -def _frame(canvas): - canvas.update() - canvas.app.process_events() - time.sleep(1. / 60.) - - -def show_test(canvas, n_frames=2): +def show_test(canvas): """Show a VisPy canvas for a fraction of second.""" - with canvas as c: - show_test_run(c, n_frames) - - -def show_test_start(canvas): - """This is the __enter__ of with canvas.""" - canvas.show() - canvas._backend._vispy_warmup() - - -def show_test_run(canvas, n_frames=2): - """Display frames of a canvas.""" - if n_frames == 0: - while not canvas._closed: - _frame(canvas) - else: - for _ in range(n_frames): - _frame(canvas) - if canvas._closed: - return - - -def show_test_stop(canvas): - """This is the __exit__ of with canvas.""" - # ensure all GL calls are complete - if not canvas._closed: - canvas._backend._vispy_set_current() - canvas.context.finish() - canvas.close() - time.sleep(0.025) # ensure window is really closed/destroyed + with canvas: + # Interactive mode for tests. + if 'PYTEST_INTERACT' in os.environ: # pragma: no cover + while not canvas._closed: + canvas.update() + canvas.app.process_events() + time.sleep(1. / 60) + else: + canvas.update() + canvas.app.process_events() -def show_colored_canvas(color, n_frames=5): +def show_colored_canvas(color): """Show a transient VisPy canvas with a uniform background color.""" from vispy import app, gloo c = app.Canvas() @@ -205,4 +198,4 @@ def show_colored_canvas(color, n_frames=5): def on_draw(e): gloo.clear(color) - show_test(c, n_frames=n_frames) + show_test(c) diff --git a/phy/utils/tests/conftest.py b/phy/utils/tests/conftest.py new file mode 100644 index 000000000..25f21eac2 --- /dev/null +++ b/phy/utils/tests/conftest.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- + +"""py.test fixtures.""" + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +from pytest import yield_fixture + + +#------------------------------------------------------------------------------ +# Common fixtures +#------------------------------------------------------------------------------ + +@yield_fixture +def temp_config_dir(tempdir): + """NOTE: the user directory should be loaded with: + + ```python + from .. 
import config + config.phy_config_dir() + ``` + + and not: + + ```python + from config import phy_config_dir + ``` + + Otherwise, the monkey patching hack in tests won't work. + + """ + from phy.utils import config + + config_dir = config.phy_config_dir + config.phy_config_dir = lambda: tempdir + yield tempdir + config.phy_config_dir = config_dir diff --git a/phy/utils/tests/test_array.py b/phy/utils/tests/test_array.py deleted file mode 100644 index c7688ed75..000000000 --- a/phy/utils/tests/test_array.py +++ /dev/null @@ -1,610 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Tests of array utility functions.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import os.path as op -from itertools import product - -import numpy as np -from pytest import raises, mark - -from .._types import _as_array, _as_tuple -from ..array import (_unique, - _normalize, - _index_of, - _in_polygon, - _load_ndarray, - _len_index, - _spikes_in_clusters, - _spikes_per_cluster, - _flatten_spikes_per_cluster, - _concatenate_per_cluster_arrays, - chunk_bounds, - excerpts, - data_chunk, - get_excerpts, - PartialArray, - VirtualMappedArray, - PerClusterData, - _partial_shape, - _range_from_slice, - _pad, - _concatenate_virtual_arrays, - _load_arrays, - _save_arrays, - ) -from ..testing import _assert_equal as ae -from ...io.mock import artificial_spike_clusters - - -#------------------------------------------------------------------------------ -# Test utility functions -#------------------------------------------------------------------------------ - -def test_range_from_slice(): - """Test '_range_from_slice'.""" - - class _SliceTest(object): - """Utility class to make it more convenient to test slice objects.""" - def __init__(self, **kwargs): - self._kwargs = kwargs - - def __getitem__(self, item): - if isinstance(item, slice): - return _range_from_slice(item, **self._kwargs) - - with raises(ValueError): - _SliceTest()[:] - with raises(ValueError): - _SliceTest()[1:] - ae(_SliceTest()[:5], [0, 1, 2, 3, 4]) - ae(_SliceTest()[1:5], [1, 2, 3, 4]) - - with raises(ValueError): - _SliceTest()[::2] - with raises(ValueError): - _SliceTest()[1::2] - ae(_SliceTest()[1:5:2], [1, 3]) - - with raises(ValueError): - _SliceTest(start=0)[:] - with raises(ValueError): - _SliceTest(start=1)[:] - with raises(ValueError): - _SliceTest(step=2)[:] - - ae(_SliceTest(stop=5)[:], [0, 1, 2, 3, 4]) - ae(_SliceTest(start=1, stop=5)[:], [1, 2, 3, 4]) - ae(_SliceTest(stop=5)[1:], [1, 2, 3, 4]) - ae(_SliceTest(start=1)[:5], [1, 2, 3, 4]) - ae(_SliceTest(start=1, step=2)[:5], [1, 3]) - ae(_SliceTest(start=1)[:5:2], [1, 3]) - - ae(_SliceTest(length=5)[:], [0, 1, 2, 3, 4]) - with raises(ValueError): - _SliceTest(length=5)[:3] - ae(_SliceTest(length=5)[:10], [0, 1, 2, 3, 4]) - ae(_SliceTest(length=5)[:5], [0, 1, 2, 3, 4]) - ae(_SliceTest(start=1, length=5)[:], [1, 2, 3, 4, 5]) - ae(_SliceTest(start=1, length=5)[:6], [1, 2, 3, 4, 5]) - with raises(ValueError): - _SliceTest(start=1, length=5)[:4] - ae(_SliceTest(start=1, step=2, stop=5)[:], [1, 3]) - ae(_SliceTest(start=1, stop=5)[::2], [1, 3]) - ae(_SliceTest(stop=5)[1::2], [1, 3]) - - -def test_pad(): - arr = np.random.rand(10, 3) - - ae(_pad(arr, 0, 'right'), arr[:0, :]) - ae(_pad(arr, 3, 'right'), arr[:3, :]) - ae(_pad(arr, 9), arr[:9, :]) - ae(_pad(arr, 10), arr) - - ae(_pad(arr, 12, 'right')[:10, :], arr) - ae(_pad(arr, 12)[10:, :], np.zeros((2, 3))) - - ae(_pad(arr, 0, 
'left'), arr[:0, :]) - ae(_pad(arr, 3, 'left'), arr[7:, :]) - ae(_pad(arr, 9, 'left'), arr[1:, :]) - ae(_pad(arr, 10, 'left'), arr) - - ae(_pad(arr, 12, 'left')[2:, :], arr) - ae(_pad(arr, 12, 'left')[:2, :], np.zeros((2, 3))) - - with raises(ValueError): - _pad(arr, -1) - - -def test_unique(): - """Test _unique() function""" - _unique([]) - - n_spikes = 1000 - n_clusters = 10 - spike_clusters = artificial_spike_clusters(n_spikes, n_clusters) - ae(_unique(spike_clusters), np.arange(n_clusters)) - - -def test_normalize(): - """Test _normalize() function.""" - - n_channels = 10 - positions = 1 + 2 * np.random.randn(n_channels, 2) - - # Keep ration is False. - positions_n = _normalize(positions) - - x_min, y_min = positions_n.min(axis=0) - x_max, y_max = positions_n.max(axis=0) - - np.allclose(x_min, 0.) - np.allclose(x_max, 1.) - np.allclose(y_min, 0.) - np.allclose(y_max, 1.) - - # Keep ratio is True. - positions_n = _normalize(positions, keep_ratio=True) - - x_min, y_min = positions_n.min(axis=0) - x_max, y_max = positions_n.max(axis=0) - - np.allclose(min(x_min, y_min), 0.) - np.allclose(max(x_max, y_max), 1.) - np.allclose(x_min + x_max, 1) - np.allclose(y_min + y_max, 1) - - -def test_index_of(): - """Test _index_of.""" - arr = [36, 42, 42, 36, 36, 2, 42] - lookup = _unique(arr) - ae(_index_of(arr, lookup), [1, 2, 2, 1, 1, 0, 2]) - - -def test_as_tuple(): - assert _as_tuple(3) == (3,) - assert _as_tuple((3,)) == (3,) - assert _as_tuple(None) is None - assert _as_tuple((None,)) == (None,) - assert _as_tuple((3, 4)) == (3, 4) - assert _as_tuple([3]) == ([3], ) - assert _as_tuple([3, 4]) == ([3, 4], ) - - -def test_len_index(): - arr = np.arange(10) - - class _Check(object): - def __getitem__(self, item): - if isinstance(item, tuple): - item, max_len = item - else: - max_len = None - assert _len_index(item, max_len) == (len(arr[item]) - if hasattr(arr[item], - '__len__') else 1) - - _check = _Check() - - for start in (0, 1, 2): - _check[start] - _check[start:1] - _check[start:2] - _check[start:3] - _check[start:3:2] - _check[start:5] - _check[start:5:2] - _check[start:, 10] - _check[start::2, 10] - _check[start::3, 10] - - -def test_virtual_mapped_array(): - shape = (10, 2) - dtype = np.float32 - arr = VirtualMappedArray(shape, dtype, 1) - arr_actual = np.ones(shape, dtype=dtype) - - class _Check(object): - def __getitem__(self, item): - ae(arr[item], arr_actual[item]) - - _check = _Check() - - for start in (0, 1, 2): - _check[start] - _check[start:1] - _check[start:2] - _check[start:3] - _check[start:3:2] - _check[start:5] - _check[start:5:2] - _check[start:] - _check[start::2] - _check[start::3] - - -def test_as_array(): - ae(_as_array(3), [3]) - ae(_as_array([3]), [3]) - ae(_as_array(3.), [3.]) - ae(_as_array([3.]), [3.]) - - with raises(ValueError): - _as_array(map) - - -def test_concatenate_virtual_arrays(): - arr1 = np.random.rand(5, 2) - arr2 = np.random.rand(4, 2) - - def _concat(*arrs): - return np.concatenate(arrs, axis=0) - - # Single array. - concat = _concatenate_virtual_arrays([arr1]) - ae(concat[:], arr1) - ae(concat[1:], arr1[1:]) - ae(concat[:3], arr1[:3]) - ae(concat[1:4], arr1[1:4]) - - # Two arrays. - concat = _concatenate_virtual_arrays([arr1, arr2]) - # First array. - ae(concat[1:], _concat(arr1[1:], arr2)) - ae(concat[:3], arr1[:3]) - ae(concat[1:4], arr1[1:4]) - # Second array. - ae(concat[5:], arr2) - ae(concat[6:], arr2[1:]) - ae(concat[5:8], arr2[:3]) - ae(concat[7:9], arr2[2:]) - ae(concat[7:12], arr2[2:]) - ae(concat[5:-1], arr2[:-1]) - # Both arrays. 
- ae(concat[:], _concat(arr1, arr2)) - ae(concat[1:], _concat(arr1[1:], arr2)) - ae(concat[:-1], _concat(arr1, arr2[:-1])) - ae(concat[:9], _concat(arr1, arr2)) - ae(concat[:10], _concat(arr1, arr2)) - ae(concat[:8], _concat(arr1, arr2[:-1])) - ae(concat[1:7], _concat(arr1[1:], arr2[:-2])) - ae(concat[4:7], _concat(arr1[4:], arr2[:-2])) - - # Check second axis. - for idx in (slice(None, None, None), - 0, - 1, - [0], - [1], - [0, 1], - [1, 0], - ): - # First array. - ae(concat[1:4, idx], arr1[1:4, idx]) - # Second array. - ae(concat[6:, idx], arr2[1:, idx]) - # Both arrays. - ae(concat[1:7, idx], _concat(arr1[1:, idx], arr2[:-2, idx])) - - -def test_in_polygon(): - polygon = [[0, 0], [1, 0], [1, 1], [0, 1], [0, 0]] - points = np.random.uniform(size=(100, 2), low=-1, high=1) - idx_expected = np.nonzero((points[:, 0] > 0) & - (points[:, 1] > 0) & - (points[:, 0] < 1) & - (points[:, 1] < 1))[0] - idx = np.nonzero(_in_polygon(points, polygon))[0] - ae(idx, idx_expected) - - -#------------------------------------------------------------------------------ -# Test I/O functions -#------------------------------------------------------------------------------ - -@mark.parametrize('memmap,lazy', product([False, True], [False, True])) -def test_load_ndarray(tempdir, memmap, lazy): - n, m = 10000, 100 - dtype = np.float32 - arr = np.random.randn(n, m).astype(dtype) - path = op.join(tempdir, 'test') - with open(path, 'wb') as f: - arr.tofile(f) - arr_m = _load_ndarray(path, - dtype=dtype, - shape=(n, m), - memmap=memmap, - lazy=lazy, - ) - ae(arr, arr_m[...]) - - -@mark.parametrize('n', [20, 0]) -def test_load_save_arrays(tempdir, n): - path = op.join(tempdir, 'test.npy') - # Random arrays. - arrays = [] - for i in range(n): - size = np.random.randint(low=3, high=50) - assert size > 0 - arr = np.random.rand(size, 3).astype(np.float32) - arrays.append(arr) - _save_arrays(path, arrays) - - arrays_loaded = _load_arrays(path) - - assert len(arrays) == len(arrays_loaded) - for arr, arr_loaded in zip(arrays, arrays_loaded): - assert arr.shape == arr_loaded.shape - assert arr.dtype == arr_loaded.dtype - ae(arr, arr_loaded) - - -#------------------------------------------------------------------------------ -# Test chunking -#------------------------------------------------------------------------------ - -def test_chunk_bounds(): - chunks = chunk_bounds(200, 100, overlap=20) - - assert next(chunks) == (0, 100, 0, 90) - assert next(chunks) == (80, 180, 90, 170) - assert next(chunks) == (160, 200, 170, 200) - - -def test_chunk(): - data = np.random.randn(200, 4) - chunks = chunk_bounds(data.shape[0], 100, overlap=20) - - with raises(ValueError): - data_chunk(data, (0, 0, 0)) - - assert data_chunk(data, (0, 0)).shape == (0, 4) - - # Chunk 1. - ch = next(chunks) - d = data_chunk(data, ch) - d_o = data_chunk(data, ch, with_overlap=True) - - ae(d_o, data[0:100]) - ae(d, data[0:90]) - - # Chunk 2. 
- ch = next(chunks) - d = data_chunk(data, ch) - d_o = data_chunk(data, ch, with_overlap=True) - - ae(d_o, data[80:180]) - ae(d, data[90:170]) - - -def test_excerpts_1(): - bounds = [(start, end) for (start, end) in excerpts(100, - n_excerpts=3, - excerpt_size=10)] - assert bounds == [(0, 10), (45, 55), (90, 100)] - - -def test_excerpts_2(): - bounds = [(start, end) for (start, end) in excerpts(10, - n_excerpts=3, - excerpt_size=10)] - assert bounds == [(0, 10)] - - -def test_get_excerpts(): - data = np.random.rand(100, 2) - subdata = get_excerpts(data, n_excerpts=10, excerpt_size=5) - assert subdata.shape == (50, 2) - ae(subdata[:5, :], data[:5, :]) - ae(subdata[-5:, :], data[-10:-5, :]) - - data = np.random.rand(10, 2) - subdata = get_excerpts(data, n_excerpts=10, excerpt_size=5) - ae(subdata, data) - - -#------------------------------------------------------------------------------ -# Test spike clusters functions -#------------------------------------------------------------------------------ - -def test_spikes_in_clusters(): - """Test _spikes_in_clusters().""" - - n_spikes = 1000 - n_clusters = 10 - spike_clusters = artificial_spike_clusters(n_spikes, n_clusters) - - ae(_spikes_in_clusters(spike_clusters, []), []) - - for i in range(n_clusters): - assert np.all(spike_clusters[_spikes_in_clusters(spike_clusters, - [i])] == i) - - clusters = [1, 5, 9] - assert np.all(np.in1d(spike_clusters[_spikes_in_clusters(spike_clusters, - clusters)], - clusters)) - - -def test_spikes_per_cluster(): - """Test _spikes_per_cluster().""" - - n_spikes = 1000 - spike_ids = np.arange(n_spikes).astype(np.int64) - n_clusters = 10 - spike_clusters = artificial_spike_clusters(n_spikes, n_clusters) - - spikes_per_cluster = _spikes_per_cluster(spike_ids, spike_clusters) - assert list(spikes_per_cluster.keys()) == list(range(n_clusters)) - - for i in range(10): - ae(spikes_per_cluster[i], np.sort(spikes_per_cluster[i])) - assert np.all(spike_clusters[spikes_per_cluster[i]] == i) - - sc = _flatten_spikes_per_cluster(spikes_per_cluster) - ae(spike_clusters, sc) - - -def test_concatenate_per_cluster_arrays(): - """Test _spikes_per_cluster().""" - - def _column(arr): - out = np.zeros((len(arr), 10)) - out[:, 0] = arr - return out - - # 8, 11, 12, 13, 17, 18, 20 - spikes_per_cluster = {2: [11, 13, 17], 3: [8, 12], 5: [18, 20]} - - arrays_1d = {2: [1, 3, 7], 3: [8, 2], 5: [8, 0]} - - arrays_2d = {2: _column([1, 3, 7]), - 3: _column([8, 2]), - 5: _column([8, 0])} - - concat = _concatenate_per_cluster_arrays(spikes_per_cluster, arrays_1d) - ae(concat, [8, 1, 2, 3, 7, 8, 0]) - - concat = _concatenate_per_cluster_arrays(spikes_per_cluster, arrays_2d) - ae(concat[:, 0], [8, 1, 2, 3, 7, 8, 0]) - ae(concat[:, 1:], np.zeros((7, 9))) - - -def test_per_cluster_data(): - - spike_ids = [8, 11, 12, 13, 17, 18, 20] - spc = { - 2: [11, 13, 17], - 3: [8, 12], - 5: [18, 20], - } - spike_clusters = [3, 2, 3, 2, 2, 5, 5] - arrays = { - 2: [1, 3, 7], - 3: [8, 2], - 5: [8, 0], - } - array = [8, 1, 2, 3, 7, 8, 0] - - def _check(pcd): - ae(pcd.spike_ids, spike_ids) - ae(pcd.spike_clusters, spike_clusters) - ae(pcd.array, array) - ae(pcd.spc, spc) - ae(pcd.arrays, arrays) - - # Check subset on 1 cluster. - pcd_s = pcd.subset(clusters=[2]) - ae(pcd_s.spike_ids, [11, 13, 17]) - ae(pcd_s.spike_clusters, [2, 2, 2]) - ae(pcd_s.array, [1, 3, 7]) - ae(pcd_s.spc, {2: [11, 13, 17]}) - ae(pcd_s.arrays, {2: [1, 3, 7]}) - - # Check subset on some spikes. 
- pcd_s = pcd.subset(spike_ids=[11, 12, 13, 17]) - ae(pcd_s.spike_ids, [11, 12, 13, 17]) - ae(pcd_s.spike_clusters, [2, 3, 2, 2]) - ae(pcd_s.array, [1, 2, 3, 7]) - ae(pcd_s.spc, {2: [11, 13, 17], 3: [12]}) - ae(pcd_s.arrays, {2: [1, 3, 7], 3: [2]}) - - # Check subset on 2 complete clusters. - pcd_s = pcd.subset(clusters=[3, 5]) - ae(pcd_s.spike_ids, [8, 12, 18, 20]) - ae(pcd_s.spike_clusters, [3, 3, 5, 5]) - ae(pcd_s.array, [8, 2, 8, 0]) - ae(pcd_s.spc, {3: [8, 12], 5: [18, 20]}) - ae(pcd_s.arrays, {3: [8, 2], 5: [8, 0]}) - - # Check subset on 2 incomplete clusters. - pcd_s = pcd.subset(spc={3: [8, 12], 5: [20]}) - ae(pcd_s.spike_ids, [8, 12, 20]) - ae(pcd_s.spike_clusters, [3, 3, 5]) - ae(pcd_s.array, [8, 2, 0]) - ae(pcd_s.spc, {3: [8, 12], 5: [20]}) - ae(pcd_s.arrays, {3: [8, 2], 5: [0]}) - - # From flat arrays. - pcd = PerClusterData(spike_ids=spike_ids, - array=array, - spike_clusters=spike_clusters, - ) - _check(pcd) - - # From dicts. - pcd = PerClusterData(spc=spc, arrays=arrays) - _check(pcd) - - -#------------------------------------------------------------------------------ -# Test PartialArray -#------------------------------------------------------------------------------ - -def test_partial_shape(): - - _partial_shape(None, ()) - _partial_shape((), None) - _partial_shape((), ()) - _partial_shape(None, None) - - assert _partial_shape((5, 3), 1) == (5,) - assert _partial_shape((5, 3), (1,)) == (5,) - assert _partial_shape((5, 10, 2), 1) == (5, 10) - with raises(ValueError): - _partial_shape((5, 10, 2), (1, 2)) - assert _partial_shape((5, 10, 3), (1, 2)) == (5,) - assert _partial_shape((5, 10, 3), (slice(None, None, None), 2)) == (5, 10) - assert _partial_shape((5, 10, 3), (slice(1, None, None), 2)) == (5, 9) - assert _partial_shape((5, 10, 3), (slice(1, 5, None), 2)) == (5, 4) - assert _partial_shape((5, 10, 3), (slice(4, None, 3), 2)) == (5, 2) - - -def test_partial_array(): - # 2D array. - arr = np.random.rand(5, 2) - - ae(PartialArray(arr)[:], arr) - - pa = PartialArray(arr, 1) - assert pa.shape == (5,) - ae(pa[0], arr[0, 1]) - ae(pa[0:2], arr[0:2, 1]) - ae(pa[[1, 2]], arr[[1, 2], 1]) - with raises(ValueError): - pa[[1, 2], 0] - - # 3D array. - arr = np.random.rand(5, 3, 2) - - pa = PartialArray(arr, (2, 1)) - assert pa.shape == (5,) - ae(pa[0], arr[0, 2, 1]) - ae(pa[0:2], arr[0:2, 2, 1]) - ae(pa[[1, 2]], arr[[1, 2], 2, 1]) - with raises(ValueError): - pa[[1, 2], 0] - - pa = PartialArray(arr, (1,)) - assert pa.shape == (5, 3) - ae(pa[0, 2], arr[0, 2, 1]) - ae(pa[0:2, 1], arr[0:2, 1, 1]) - ae(pa[[1, 2], 0], arr[[1, 2], 0, 1]) - ae(pa[[1, 2]], arr[[1, 2], :, 1]) - - # Slice and 3D. 
- arr = np.random.rand(5, 10, 2) - - pa = PartialArray(arr, (slice(1, None, 3), 1)) - assert pa.shape == (5, 3) - ae(pa[0], arr[0, 1::3, 1]) - ae(pa[0:2], arr[0:2, 1::3, 1]) - ae(pa[[1, 2]], arr[[1, 2], 1::3, 1]) diff --git a/phy/utils/tests/test_cli.py b/phy/utils/tests/test_cli.py new file mode 100644 index 000000000..e8f9f22bf --- /dev/null +++ b/phy/utils/tests/test_cli.py @@ -0,0 +1,76 @@ +# -*- coding: utf-8 -*- +# flake8: noqa + +"""Test CLI tool.""" + + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +import os.path as op + +from click.testing import CliRunner +from pytest import yield_fixture + +from .._misc import _write_text + + +#------------------------------------------------------------------------------ +# Test CLI tool +#------------------------------------------------------------------------------ + +@yield_fixture +def runner(): + yield CliRunner() + + +def test_cli_empty(temp_config_dir, runner): + + # NOTE: make the import after the temp_config_dir fixture, to avoid + # loading any user plugin affecting the CLI. + from ..cli import phy, load_cli_plugins + load_cli_plugins(phy) + + result = runner.invoke(phy, []) + assert result.exit_code == 0 + + result = runner.invoke(phy, ['--version']) + assert result.exit_code == 0 + assert result.output.startswith('phy,') + + result = runner.invoke(phy, ['--help']) + assert result.exit_code == 0 + assert result.output.startswith('Usage: phy') + + +def test_cli_plugins(temp_config_dir, runner): + + # Write a CLI plugin. + cli_plugin = """ + import click + from phy import IPlugin + + class MyPlugin(IPlugin): + def attach_to_cli(self, cli): + @cli.command() + def hello(): + click.echo("hello world") + """ + path = op.join(temp_config_dir, 'plugins/hello.py') + _write_text(path, cli_plugin) + + # NOTE: make the import after the temp_config_dir fixture, to avoid + # loading any user plugin affecting the CLI. + from ..cli import phy, load_cli_plugins + load_cli_plugins(phy, config_dir=temp_config_dir) + + # The plugin should have added a new command. + result = runner.invoke(phy, ['--help']) + assert result.exit_code == 0 + assert 'hello' in result.output + + # The plugin should have added a new command. + result = runner.invoke(phy, ['hello']) + assert result.exit_code == 0 + assert result.output == 'hello world\n' diff --git a/phy/utils/tests/test_color.py b/phy/utils/tests/test_color.py index 94eb4ccf8..2974e8b09 100644 --- a/phy/utils/tests/test_color.py +++ b/phy/utils/tests/test_color.py @@ -6,16 +6,14 @@ # Imports #------------------------------------------------------------------------------ -from pytest import mark +import numpy as np -from .._color import _random_color, _is_bright, _random_bright_color +from .._color import (_random_color, _is_bright, _random_bright_color, + _colormap, _spike_colors, ColorSelector, + ) from ..testing import show_colored_canvas -# Skip these tests in "make test-quick". 
-pytestmark = mark.long - - #------------------------------------------------------------------------------ # Tests #------------------------------------------------------------------------------ @@ -24,4 +22,22 @@ def test_random_color(): color = _random_color() show_colored_canvas(color) - assert _is_bright(_random_bright_color()) + for _ in range(10): + assert _is_bright(_random_bright_color()) + + +def test_colormap(): + assert len(_colormap(0)) == 3 + assert len(_colormap(1000)) == 3 + + assert _spike_colors([0, 1, 10, 1000]).shape == (4, 4) + assert _spike_colors([0, 1, 10, 1000], + alpha=1.).shape == (4, 4) + assert _spike_colors([0, 1, 10, 1000], + masks=np.linspace(0., 1., 4)).shape == (4, 4) + + +def test_color_selector(): + sel = ColorSelector() + assert len(sel.get(0)) == 4 + assert len(sel.get(0, [1, 0])) == 4 diff --git a/phy/utils/tests/test_config.py b/phy/utils/tests/test_config.py new file mode 100644 index 000000000..a13b2d3e9 --- /dev/null +++ b/phy/utils/tests/test_config.py @@ -0,0 +1,121 @@ +# -*- coding: utf-8 -*- + +"""Test config.""" + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +import os.path as op +from textwrap import dedent + +from pytest import yield_fixture +from traitlets import Float +from traitlets.config import Configurable + +from .. import config as _config +from .._misc import _write_text +from ..config import (_ensure_dir_exists, + load_config, + load_master_config, + save_config, + ) + + +#------------------------------------------------------------------------------ +# Test config +#------------------------------------------------------------------------------ + +def test_phy_config_dir(): + assert _config.phy_config_dir().endswith('.phy/') + + +def test_ensure_dir_exists(tempdir): + path = op.join(tempdir, 'a/b/c') + _ensure_dir_exists(path) + assert op.isdir(path) + + +def test_temp_config_dir(temp_config_dir): + assert _config.phy_config_dir() == temp_config_dir + + +#------------------------------------------------------------------------------ +# Config tests +#------------------------------------------------------------------------------ + +@yield_fixture +def py_config(tempdir): + # Create and load a config file. + config_contents = """ + c = get_config() + c.MyConfigurable.my_var = 1.0 + """ + path = op.join(tempdir, 'config.py') + _write_text(path, config_contents) + yield path + + +@yield_fixture +def json_config(tempdir): + # Create and load a config file. + config_contents = """ + { + "MyConfigurable": { + "my_var": 1.0 + } + } + """ + path = op.join(tempdir, 'config.json') + _write_text(path, config_contents) + yield path + + +@yield_fixture(params=['python', 'json']) +def config(py_config, json_config, request): + if request.param == 'python': + yield py_config + elif request.param == 'json': + yield json_config + + +def test_load_config(config): + + class MyConfigurable(Configurable): + my_var = Float(0.0, config=True) + + assert MyConfigurable().my_var == 0.0 + + c = load_config(config) + assert c.MyConfigurable.my_var == 1.0 + + # Create a new MyConfigurable instance. + configurable = MyConfigurable() + assert configurable.my_var == 0.0 + + # Load the config object. + configurable.update_config(c) + assert configurable.my_var == 1.0 + + +def test_load_master_config(temp_config_dir): + # Create a config file in the temporary user directory. 
+ config_contents = dedent(""" + c = get_config() + c.MyConfigurable.my_var = 1.0 + """) + with open(op.join(temp_config_dir, 'phy_config.py'), 'w') as f: + f.write(config_contents) + + # Load the master config file. + c = load_master_config() + assert c.MyConfigurable.my_var == 1. + + +def test_save_config(tempdir): + c = {'A': {'b': 3.}} + path = op.join(tempdir, 'config.json') + save_config(path, c) + + c1 = load_config(path) + assert c1.A.b == 3. diff --git a/phy/utils/tests/test_event.py b/phy/utils/tests/test_event.py index 1b0d19d2f..8a2d19aa2 100644 --- a/phy/utils/tests/test_event.py +++ b/phy/utils/tests/test_event.py @@ -20,6 +20,9 @@ def test_event_system(): _list = [] + with raises(ValueError): + ev.connect(lambda x: x) + @ev.connect(set_method=True) def on_my_event(arg, kwarg=None): _list.append((arg, kwarg)) @@ -62,6 +65,7 @@ def on_complete(): pr.value_max = 10 pr.value = 0 pr.value = 5 + assert pr.value == 5 assert pr.progress == .5 assert not pr.is_complete() pr.value = 10 @@ -89,7 +93,8 @@ def on_complete(): def test_progress_message(): """Test messages with the progress reporter.""" pr = ProgressReporter() - pr.set_progress_message("The progress is {progress}%. ({hello})") + pr.reset(5) + pr.set_progress_message("The progress is {progress}%. ({hello:d})") pr.set_complete_message("Finished {hello}.") pr.value_max = 10 @@ -97,6 +102,10 @@ def test_progress_message(): print() pr.value = 5 print() - pr.increment(hello='hello world') + pr.increment() + print() + pr.increment(hello='hello') + print() + pr.increment(hello=3) print() pr.value = 10 diff --git a/phy/utils/tests/test_logging.py b/phy/utils/tests/test_logging.py deleted file mode 100644 index 49c427809..000000000 --- a/phy/utils/tests/test_logging.py +++ /dev/null @@ -1,72 +0,0 @@ -"""Unit tests for logging module.""" - -# ----------------------------------------------------------------------------- -# Imports -# ----------------------------------------------------------------------------- - -import os - -from ..logging import (StringLogger, ConsoleLogger, debug, info, warn, - FileLogger, register, unregister, - set_level) - - -# ----------------------------------------------------------------------------- -# Tests -# ----------------------------------------------------------------------------- -def test_string_logger(): - l = StringLogger(fmt='') - l.info("test 1") - l.info("test 2") - - log = str(l) - logs = log.split('\n') - - assert "test 1" in logs[0] - assert "test 2" in logs[1] - - -def test_console_logger(): - l = ConsoleLogger(fmt='') - l.info("test 1") - l.info("test 2") - - l = ConsoleLogger(level='debug') - l.info("test 1") - l.info("test 2") - - -def test_file_logger(): - logfile = os.path.join(os.path.dirname(os.path.abspath(__file__)), - 'log.txt') - l = FileLogger(logfile, fmt='', level='debug') - l.debug("test file 1") - l.debug("test file 2") - l.info("test file info") - l.warn("test file warn") - l.close() - - with open(logfile, 'r') as f: - contents = f.read() - - assert contents.strip().startswith("test file 1\ntest file 2") - - os.remove(logfile) - - -def test_register(): - l = StringLogger(fmt='') - register(l) - - set_level('info') - debug("test D1") - info("test I1") - warn("test W1") - - set_level('Debug') - debug("test D2") - info("test I2") - warn("test W2") - assert len(str(l).strip().split('\n')) == 5 - - unregister(l) diff --git a/phy/utils/tests/test_misc.py b/phy/utils/tests/test_misc.py index a40488136..8a8c0a2f3 100644 --- a/phy/utils/tests/test_misc.py +++ 
b/phy/utils/tests/test_misc.py @@ -13,16 +13,41 @@ import numpy as np from numpy.testing import assert_array_equal as ae from pytest import raises +from six import string_types -from .._misc import _git_version, _load_json, _save_json +from .._misc import (_git_version, _load_json, _save_json, _read_python, + _write_text, + _encode_qbytearray, _decode_qbytearray, + ) #------------------------------------------------------------------------------ -# Tests +# Misc tests #------------------------------------------------------------------------------ +def test_qbytearray(tempdir): + + from phy.gui.qt import QByteArray + arr = QByteArray() + arr.append('1') + arr.append('2') + arr.append('3') + + encoded = _encode_qbytearray(arr) + assert isinstance(encoded, string_types) + decoded = _decode_qbytearray(encoded) + assert arr == decoded + + # Test JSON serialization of QByteArray. + d = {'arr': arr} + path = op.join(tempdir, 'test') + _save_json(path, d) + d_bis = _load_json(path) + assert d == d_bis + + def test_json_simple(tempdir): - d = {'a': 1, 'b': 'bb', 3: '33'} + d = {'a': 1, 'b': 'bb', 3: '33', 'mock': {'mock': True}} path = op.join(tempdir, 'test') _save_json(path, d) @@ -53,6 +78,23 @@ def test_json_numpy(tempdir): assert d['b'] == d_bis['b'] +def test_read_python(tempdir): + path = op.join(tempdir, 'mock.py') + with open(path, 'w') as f: + f.write("""a = {'b': 1}""") + + assert _read_python(path) == {'a': {'b': 1}} + + +def test_write_text(tempdir): + for path in (op.join(tempdir, 'test_1'), + op.join(tempdir, 'test_dir/test_2.txt'), + ): + _write_text(path, 'hello world') + with open(path, 'r') as f: + assert f.read() == 'hello world' + + def test_git_version(): v = _git_version() @@ -63,6 +105,6 @@ def test_git_version(): subprocess.check_output(['git', '-C', filedir, 'status'], stderr=fnull) assert v is not "", "git_version failed to return" - assert v[:6] == "-git-v", "Git version does not begin in -git-v" - except (OSError, subprocess.CalledProcessError): + assert v[:5] == "-git-", "Git version does not begin in -git-" + except (OSError, subprocess.CalledProcessError): # pragma: no cover assert v == "" diff --git a/phy/utils/tests/test_plugin.py b/phy/utils/tests/test_plugin.py new file mode 100644 index 000000000..b1f3994d6 --- /dev/null +++ b/phy/utils/tests/test_plugin.py @@ -0,0 +1,57 @@ +# -*- coding: utf-8 -*- + +"""Test plugin system.""" + + +#------------------------------------------------------------------------------ +# Imports +#------------------------------------------------------------------------------ + +import os.path as op + +from pytest import yield_fixture, raises + +from ..plugin import (IPluginRegistry, + IPlugin, + get_plugin, + discover_plugins, + ) +from .._misc import _write_text + + +#------------------------------------------------------------------------------ +# Fixtures +#------------------------------------------------------------------------------ + +@yield_fixture +def no_native_plugins(): + # Save the plugins. 
+ plugins = IPluginRegistry.plugins + IPluginRegistry.plugins = [] + yield + IPluginRegistry.plugins = plugins + + +#------------------------------------------------------------------------------ +# Tests +#------------------------------------------------------------------------------ + +def test_plugin_1(no_native_plugins): + class MyPlugin(IPlugin): + pass + + assert IPluginRegistry.plugins == [MyPlugin] + assert get_plugin('MyPlugin').__name__ == 'MyPlugin' + + with raises(ValueError): + get_plugin('unknown') + + +def test_discover_plugins(tempdir, no_native_plugins): + path = op.join(tempdir, 'my_plugin.py') + contents = '''from phy import IPlugin\nclass MyPlugin(IPlugin): pass''' + _write_text(path, contents) + + plugins = discover_plugins([tempdir]) + assert plugins + assert plugins[0].__name__ == 'MyPlugin' diff --git a/phy/utils/tests/test_selector.py b/phy/utils/tests/test_selector.py deleted file mode 100644 index 5b4af745e..000000000 --- a/phy/utils/tests/test_selector.py +++ /dev/null @@ -1,112 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Test selector.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import numpy as np -from numpy.testing import assert_array_equal as ae - -from ...io.mock import artificial_spike_clusters -from ..array import _spikes_in_clusters -from ..selector import Selector - - -#------------------------------------------------------------------------------ -# Tests -#------------------------------------------------------------------------------ - -def test_selector_spikes(): - """Test selecting spikes.""" - n_spikes = 1000 - n_clusters = 10 - spike_clusters = artificial_spike_clusters(n_spikes, n_clusters) - - selector = Selector(spike_clusters) - selector.on_cluster() - assert selector.n_spikes_max is None - selector.n_spikes_max = None - ae(selector.selected_spikes, []) - - # Select a few spikes. - my_spikes = [10, 20, 30] - selector.selected_spikes = my_spikes - ae(selector.selected_spikes, my_spikes) - - # Check selected clusters. - ae(selector.selected_clusters, np.unique(spike_clusters[my_spikes])) - - # Specify a maximum number of spikes. - selector.n_spikes_max = 3 - assert selector.n_spikes_max is 3 - my_spikes = [10, 20, 30, 40] - selector.selected_spikes = my_spikes[:3] - ae(selector.selected_spikes, my_spikes[:3]) - selector.selected_spikes = my_spikes - assert len(selector.selected_spikes) <= 3 - assert selector.n_spikes == len(selector.selected_spikes) - assert np.all(np.in1d(selector.selected_spikes, my_spikes)) - - # Check that this doesn't raise any error. - selector.selected_clusters = [100] - selector.selected_spikes = [] - - -def test_selector_clusters(): - """Test selecting clusters.""" - n_spikes = 1000 - n_clusters = 10 - spike_clusters = artificial_spike_clusters(n_spikes, n_clusters) - - selector = Selector(spike_clusters) - selector.selected_clusters = [] - ae(selector.selected_spikes, []) - - # Select 1 cluster. - selector.selected_clusters = [0] - ae(selector.selected_spikes, _spikes_in_clusters(spike_clusters, [0])) - assert np.all(spike_clusters[selector.selected_spikes] == 0) - - # Select 2 clusters. - selector.selected_clusters = [1, 3] - ae(selector.selected_spikes, _spikes_in_clusters(spike_clusters, [1, 3])) - assert np.all(np.in1d(spike_clusters[selector.selected_spikes], (1, 3))) - assert selector.n_clusters == 2 - - # Specify a maximum number of spikes. 
- selector.n_spikes_max = 10 - selector.selected_clusters = [4, 2] - assert len(selector.selected_spikes) <= (10 * 2) - assert selector.selected_clusters == [4, 2] - assert np.all(np.in1d(spike_clusters[selector.selected_spikes], (2, 4))) - - # Reduce the number of maximum spikes: the selection should update - # accordingly. - selector.n_spikes_max = 5 - assert len(selector.selected_spikes) <= 5 - assert np.all(np.in1d(spike_clusters[selector.selected_spikes], (2, 4))) - - -def test_selector_subset(): - n_spikes = 1000 - n_clusters = 10 - spike_clusters = artificial_spike_clusters(n_spikes, n_clusters) - - selector = Selector(spike_clusters) - selector.subset_spikes(excerpt_size=10) - selector.subset_spikes(np.arange(n_spikes), excerpt_size=10) - - -def test_selector_subset_clusters(): - n_spikes = 100 - spike_clusters = np.zeros(n_spikes, dtype=np.int32) - spike_clusters[10:15] = 1 - spike_clusters[85:90] = 1 - - selector = Selector(spike_clusters) - spc = selector.subset_spikes_clusters([0, 1], excerpt_size=10) - counts = {_: len(spc[_]) for _ in sorted(spc.keys())} - # TODO - assert counts diff --git a/phy/utils/tests/test_settings.py b/phy/utils/tests/test_settings.py deleted file mode 100644 index f1dc7f2d0..000000000 --- a/phy/utils/tests/test_settings.py +++ /dev/null @@ -1,153 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Test settings.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -import os.path as op - -from pytest import raises - -from ..settings import (BaseSettings, - Settings, - _load_default_settings, - _recursive_dirs, - ) - - -#------------------------------------------------------------------------------ -# Test settings -#------------------------------------------------------------------------------ - -def test_recursive_dirs(): - dirs = list(_recursive_dirs()) - assert len(dirs) >= 5 - root = op.join(op.realpath(op.dirname(__file__)), '../../') - for dir in dirs: - dir = op.relpath(dir, root) - assert '.' not in dir - assert '_' not in dir - - -def test_load_default_settings(): - settings = _load_default_settings() - keys = settings.keys() - assert 'log_file_level' in keys - assert 'on_open' in keys - assert 'spikedetekt' in keys - assert 'klustakwik2' in keys - assert 'traces' in keys - assert 'cluster_manual_config' in keys - - -def test_base_settings(): - s = BaseSettings() - - # Namespaces are mandatory. - with raises(KeyError): - s['a'] - - s['a'] = 3 - assert s['a'] == 3 - - -def test_user_settings(tempdir): - path = op.join(tempdir, 'test.py') - - # Create a simple settings file. - contents = '''a = 4\nb = 5\nd = {'k1': 2, 'k2': 3}\n''' - with open(path, 'w') as f: - f.write(contents) - - s = BaseSettings() - - s['a'] = 3 - s['c'] = 6 - assert s['a'] == 3 - - # Now, set the settings file. - s.load(path=path) - assert s['a'] == 4 - assert s['b'] == 5 - assert s['c'] == 6 - assert s['d'] == {'k1': 2, 'k2': 3} - - s = BaseSettings() - s['d'] = {'k2': 30, 'k3': 40} - s.load(path=path) - assert s['d'] == {'k1': 2, 'k2': 3, 'k3': 40} - - -def test_internal_settings(tempdir): - path = op.join(tempdir, 'test') - - s = BaseSettings() - - # Set the 'test' namespace. 
- s['a'] = 3 - s['c'] = 6 - assert s['a'] == 3 - assert s['c'] == 6 - - s.save(path) - assert s['a'] == 3 - assert s['c'] == 6 - - s = BaseSettings() - with raises(KeyError): - s['a'] - - s.load(path) - assert s['a'] == 3 - assert s['c'] == 6 - - -def test_settings_manager(tempdir, tempdir_bis): - tempdir_exp = tempdir_bis - sm = Settings(tempdir) - - # Check paths. - assert sm.phy_user_dir == tempdir - assert sm.internal_settings_path == op.join(tempdir, - 'internal_settings') - assert sm.user_settings_path == op.join(tempdir, 'user_settings.py') - - # User settings. - with raises(KeyError): - sm['a'] - # Artificially populate the user settings. - sm._bs._store['a'] = 3 - assert sm['a'] == 3 - - # Internal settings. - sm['c'] = 5 - assert sm['c'] == 5 - - # Set an experiment path. - path = op.join(tempdir_exp, 'myexperiment.dat') - sm.on_open(path) - assert op.realpath(sm.exp_path) == op.realpath(path) - assert sm.exp_name == 'myexperiment' - assert (op.realpath(sm.exp_settings_dir) == - op.realpath(op.join(tempdir_exp, 'myexperiment.phy'))) - assert (op.realpath(sm.exp_settings_path) == - op.realpath(op.join(tempdir_exp, 'myexperiment.phy/' - 'user_settings.py'))) - - # User settings. - assert sm['a'] == 3 - sm._bs._store['a'] = 30 - assert sm['a'] == 30 - - # Internal settings. - sm['c'] = 50 - assert sm['c'] == 50 - - # Check persistence. - sm.save() - sm = Settings(tempdir) - sm.on_open(path) - assert sm['c'] == 50 - assert 'a' not in sm diff --git a/phy/utils/tests/test_testing.py b/phy/utils/tests/test_testing.py index c443a3135..68a90faee 100644 --- a/phy/utils/tests/test_testing.py +++ b/phy/utils/tests/test_testing.py @@ -6,7 +6,19 @@ # Imports #------------------------------------------------------------------------------ -from ..testing import captured_output +from copy import deepcopy +import logging +import os.path as op +import time + +import numpy as np +from pytest import mark +from vispy.app import Canvas + +from ..testing import (benchmark, captured_output, captured_logging, show_test, + _assert_equal, _enable_profiler, _profile, + show_colored_canvas, + ) #------------------------------------------------------------------------------ @@ -17,3 +29,41 @@ def test_captured_output(): with captured_output() as (out, err): print('Hello world!') assert out.getvalue().strip() == 'Hello world!' + + +def test_captured_logging(): + logger = logging.getLogger() + handlers = logger.handlers + with captured_logging() as buf: + logger.debug('Hello world!') + assert 'Hello world!' in buf.getvalue() + assert logger.handlers == handlers + + +def test_assert_equal(): + d = {'a': {'b': np.random.rand(5), 3: 'c'}, 'b': 2.} + d_bis = deepcopy(d) + d_bis['a']['b'] = d_bis['a']['b'] + 1e-10 + _assert_equal(d, d_bis) + + +def test_benchmark(): + with benchmark(): + time.sleep(.002) + + +def test_canvas(): + c = Canvas(keys='interactive') + show_test(c) + + +def test_show_colored_canvas(): + show_colored_canvas((.6, 0, .8)) + + +@mark.parametrize('line_by_line', [False, True]) +def test_profile(chdir_tempdir, line_by_line): + # Remove the profile from the builtins. 
+ prof = _enable_profiler(line_by_line=line_by_line) + _profile(prof, 'import time; time.sleep(.001)', {}, {}) + assert op.exists(op.join(chdir_tempdir, '.profile', 'stats.txt')) diff --git a/phy/utils/tests/test_types.py b/phy/utils/tests/test_types.py index bbb7fdf3d..a7d0006a9 100644 --- a/phy/utils/tests/test_types.py +++ b/phy/utils/tests/test_types.py @@ -7,8 +7,12 @@ #------------------------------------------------------------------------------ import numpy as np +from pytest import raises -from .._types import Bunch, _is_integer +from .._types import (Bunch, _bunchify, _is_integer, _is_list, _is_float, + _as_list, _is_array_like, _as_array, _as_tuple, + _as_scalar, + ) #------------------------------------------------------------------------------ @@ -21,9 +25,77 @@ def test_bunch(): assert obj.a == 1 obj.b = 2 assert obj['b'] == 2 + assert obj.copy() == obj -def test_integer(): +def test_bunchify(): + d = {'a': {'b': 0}} + b = _bunchify(d) + assert isinstance(b, Bunch) + assert isinstance(b['a'], Bunch) + + +def test_number(): + assert not _is_integer(None) + assert not _is_integer(3.) assert _is_integer(3) assert _is_integer(np.arange(1)[0]) - assert not _is_integer(3.) + + assert not _is_float(None) + assert not _is_float(3) + assert not _is_float(np.array([3])[0]) + assert _is_float(3.) + assert _is_float(np.array([3.])[0]) + + +def test_list(): + assert not _is_list(None) + assert not _is_list(()) + assert _is_list([]) + + assert _as_list(None) is None + assert _as_list(3) == [3] + assert _as_list([3]) == [3] + assert _as_list((3,)) == [3] + assert _as_list('3') == ['3'] + assert np.all(_as_list(np.array([3])) == np.array([3])) + + +def test_as_tuple(): + assert _as_tuple(3) == (3,) + assert _as_tuple((3,)) == (3,) + assert _as_tuple(None) is None + assert _as_tuple((None,)) == (None,) + assert _as_tuple((3, 4)) == (3, 4) + assert _as_tuple([3]) == ([3], ) + assert _as_tuple([3, 4]) == ([3, 4], ) + + +def test_as_scalar(): + assert _as_scalar(1) == 1 + assert _as_scalar(np.ones(1)[0]) == 1. 
+ assert type(_as_scalar(np.ones(1)[0])) == float + + +def test_array(): + def _check(arr): + assert isinstance(arr, np.ndarray) + assert np.all(arr == [3]) + + _check(_as_array(3)) + _check(_as_array(3.)) + _check(_as_array([3])) + + _check(_as_array(3, np.float)) + _check(_as_array(3., np.float)) + _check(_as_array([3], np.float)) + _check(_as_array(np.array([3]))) + with raises(ValueError): + _check(_as_array(np.array([3]), dtype=np.object)) + _check(_as_array(np.array([3]), np.float)) + + assert _as_array(None) is None + assert not _is_array_like(None) + assert not _is_array_like(3) + assert _is_array_like([3]) + assert _is_array_like(np.array([3])) diff --git a/pytest.ini b/pytest.ini deleted file mode 100644 index cf6b0c3e4..000000000 --- a/pytest.ini +++ /dev/null @@ -1,2 +0,0 @@ -[pytest] -norecursedirs = experimental diff --git a/requirements-dev.txt b/requirements-dev.txt index d616703b5..d4e0c5384 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -2,9 +2,12 @@ # https://github.com/pytest-dev/pytest/issues/744 git+https://github.com/pytest-dev/pytest.git +# Need development version of pytest-qt for qtbot.wait() method +git+https://github.com/pytest-dev/pytest-qt.git + flake8 -coverage +coverage==3.7.1 coveralls responses pytest-cov -pytest-qt +nose diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index ddd42871c..000000000 --- a/requirements.txt +++ /dev/null @@ -1 +0,0 @@ -wheel==0.23.0 diff --git a/setup.cfg b/setup.cfg index e53fb1af0..677afbc9a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -2,6 +2,7 @@ universal = 1 [pytest] -addopts = --cov-report term-missing --cov phy -s +norecursedirs = experimental _* + [flake8] -ignore=E265 +ignore=E265,E731 diff --git a/setup.py b/setup.py index 7b03d24c0..2554c9f38 100644 --- a/setup.py +++ b/setup.py @@ -50,11 +50,12 @@ def _package_tree(pkgroot): packages=_package_tree('phy'), package_dir={'phy': 'phy'}, package_data={ - 'phy': ['*.vert', '*.frag', '*.glsl', '*.html', '*.css', '*.prb'], + 'phy': ['*.vert', '*.frag', '*.glsl', '*.npy', '*.gz', '*.txt', + '*.html', '*.css', '*.js', '*.prb'], }, entry_points={ 'console_scripts': [ - 'phy=phy.scripts.phy_script:main', + 'phy = phy.utils.cli:phy' ], }, include_package_data=True, diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index 5788a4a03..000000000 --- a/tests/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# -*- coding: utf-8 -*- -"""Integration and functional tests for phy.""" diff --git a/tests/scripts/test_phy_spikesort.py b/tests/scripts/test_phy_spikesort.py deleted file mode 100644 index df4c5520f..000000000 --- a/tests/scripts/test_phy_spikesort.py +++ /dev/null @@ -1,39 +0,0 @@ -# -*- coding: utf-8 -*-1 - -"""Tests of phy spike sorting commands.""" - -#------------------------------------------------------------------------------ -# Imports -#------------------------------------------------------------------------------ - -from phy.scripts import main - - -#------------------------------------------------------------------------------ -# Tests -#------------------------------------------------------------------------------ - -def test_version(): - main('-v') - - -def test_cluster_auto_prm(chdir_tempdir): - main('download hybrid_10sec.dat') - main('download hybrid_10sec.prm') - main('detect hybrid_10sec.prm') - main('cluster-auto hybrid_10sec.prm --channel-group=0') - - -def test_quick_start(chdir_tempdir): - main('download hybrid_10sec.dat') - main('download hybrid_10sec.prm') - main('spikesort hybrid_10sec.prm') - # 
TODO: implement auto-close - # main('cluster-manual hybrid_10sec.kwik') - - -# def test_traces(chdir_tempdir): - # TODO: implement auto-close - # main('download hybrid_10sec.dat') - # main('traces --n-channels=32 --dtype=int16 ' - # '--sample-rate=20000 --interval=0,3 hybrid_10sec.dat')
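
Note: the new `phy/utils/tests/test_cli.py` and `test_plugin.py` above exercise the plugin-based CLI that replaces the old `phy.scripts.phy_script` entry point (see the `setup.py` hunk). The sketch below is not part of the patch; it is a minimal user plugin assembled from the code exercised in `test_cli_plugins()`. The plugin location (`~/.phy/plugins/`) is inferred from `test_phy_config_dir()` and the `plugins/hello.py` path used in the test, and may differ on a given setup.

```python
# ~/.phy/plugins/hello.py -- hypothetical path, inferred from test_cli_plugins()
# and test_phy_config_dir() in this diff; adjust to your actual config directory.
import click

from phy import IPlugin


class HelloPlugin(IPlugin):
    """Add a `hello` subcommand to the `phy` CLI (same pattern as in test_cli.py)."""

    def attach_to_cli(self, cli):
        @cli.command()
        def hello():
            # Echo a greeting when `phy hello` is run.
            click.echo("hello world")
```

With such a file in place, `phy --help` lists the `hello` command and `phy hello` prints `hello world`, as asserted in `test_cli_plugins()`.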