From 5f8d4109ebf15f2004b396e2564417634bb7f5ca Mon Sep 17 00:00:00 2001
From: Gianluca Ficarelli
Date: Tue, 30 Jan 2024 11:32:14 +0100
Subject: [PATCH] Add CLI to convert and import inferred spikes in CSV format

---
 CHANGELOG.rst                                |   9 +
 src/blueetl/apps/convert.py                  |  40 +++
 src/blueetl/apps/main.py                     |   2 +
 src/blueetl/converters/__init__.py           |   1 +
 src/blueetl/converters/convert_spikes.py     | 342 +++++++++++++++++++
 src/blueetl/utils.py                         |  22 ++
 tests/unit/apps/test_convert.py              |  29 ++
 tests/unit/apps/test_main.py                 |   1 +
 tests/unit/converters/__init__.py            |   0
 tests/unit/converters/test_convert_spikes.py |  49 +++
 tests/unit/test_utils.py                     |  29 ++
 11 files changed, 524 insertions(+)
 create mode 100644 src/blueetl/apps/convert.py
 create mode 100644 src/blueetl/converters/__init__.py
 create mode 100644 src/blueetl/converters/convert_spikes.py
 create mode 100644 tests/unit/apps/test_convert.py
 create mode 100644 tests/unit/converters/__init__.py
 create mode 100644 tests/unit/converters/test_convert_spikes.py

diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index f8ed13a..fafb74a 100644
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -1,6 +1,15 @@
 Changelog
 =========
 
+Version 0.5.0
+-------------
+
+New Features
+~~~~~~~~~~~~
+
+- Add CLI to convert and import inferred spikes in CSV format.
+
+
 Version 0.4.4
 -------------
 
diff --git a/src/blueetl/apps/convert.py b/src/blueetl/apps/convert.py
new file mode 100644
index 0000000..4b730a7
--- /dev/null
+++ b/src/blueetl/apps/convert.py
@@ -0,0 +1,40 @@
+"""Convert CLI."""
+
+import logging
+from pathlib import Path
+
+import click
+
+from blueetl.converters.convert_spikes import main
+from blueetl.utils import setup_logging
+
+
+@click.command()
+@click.argument("input-file", type=click.Path(exists=True, path_type=Path))
+@click.argument("output-dir", type=click.Path(exists=False, path_type=Path))
+@click.option(
+    "--node-population",
+    help="Name of the node population to create.",
+    default="synthetic",
+    show_default=True,
+)
+@click.option("-v", "--verbose", count=True, help="-v for INFO, -vv for DEBUG")
+def convert_spikes(input_file, output_dir, node_population, verbose):
+    """Convert spikes in CSV format.
+
+    Read a CSV file containing the spikes, and write synthetic files to be used with BlueETL.
+
+    The input file should contain:
+
+    \b
+    - headers: node_ids timestamps (or: ids times)
+    - values: space separated
+    """
+    loglevel = (logging.WARNING, logging.INFO, logging.DEBUG)[min(verbose, 2)]
+    setup_logging(loglevel=loglevel, force=True)
+    main(
+        input_file=input_file,
+        output_dir=output_dir,
+        node_population=node_population,
+    )
+    click.secho("Conversion successful.", fg="green")
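For reference, the new command can be exercised end to end with click's test runner, mirroring the unit tests later in this patch. A minimal sketch, assuming the optional voxcell dependency is installed (the converter needs it to write the nodes file); file and directory names are illustrative:

    from pathlib import Path

    from click.testing import CliRunner

    from blueetl.apps.convert import convert_spikes

    runner = CliRunner()
    with runner.isolated_filesystem():
        # one spike of node 0 at t=1.5 ms, using the ids/times header variant
        Path("spikes.csv").write_text("ids times\n0 1.5\n")
        result = runner.invoke(convert_spikes, ["spikes.csv", "out", "--node-population", "demo"])
        assert result.exit_code == 0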
diff --git a/src/blueetl/apps/main.py b/src/blueetl/apps/main.py
index 3ef390d..09c3fbf 100644
--- a/src/blueetl/apps/main.py
+++ b/src/blueetl/apps/main.py
@@ -3,6 +3,7 @@
 import click
 
 from blueetl import __version__
+from blueetl.apps.convert import convert_spikes
 from blueetl.apps.migrate import migrate_config
 from blueetl.apps.run import run
 from blueetl.apps.validate import validate_config
@@ -25,3 +26,4 @@ def cli():
 cli.add_command(run)
 cli.add_command(migrate_config)
 cli.add_command(validate_config)
+cli.add_command(convert_spikes)
diff --git a/src/blueetl/converters/__init__.py b/src/blueetl/converters/__init__.py
new file mode 100644
index 0000000..f0e518a
--- /dev/null
+++ b/src/blueetl/converters/__init__.py
@@ -0,0 +1 @@
+"""Conversion utilities."""
diff --git a/src/blueetl/converters/convert_spikes.py b/src/blueetl/converters/convert_spikes.py
new file mode 100644
index 0000000..d7f8b01
--- /dev/null
+++ b/src/blueetl/converters/convert_spikes.py
@@ -0,0 +1,342 @@
+"""Convert a spikes file by generating synthetic circuit, simulation, and simulation campaign."""
+
+import dataclasses
+import logging
+import math
+from pathlib import Path
+
+import h5py
+import numpy as np
+import pandas as pd
+
+from blueetl.utils import dump_json, dump_yaml, import_optional_dependency, relpath, resolve_path
+
+L = logging.getLogger(__name__)
+
+TIMESTAMPS = "timestamps"
+NODE_IDS = "node_ids"
+DTYPES = {
+    TIMESTAMPS: np.float64,
+    NODE_IDS: np.uint64,
+}
+
+
+@dataclasses.dataclass
+class DataStats:
+    """Statistics on the imported spikes."""
+
+    rows: int
+    unique_ids: int
+    min_id: int
+    max_id: int
+    min_timestamp: float
+    max_timestamp: float
+
+
+@dataclasses.dataclass
+class OutputPaths:
+    """Output paths."""
+
+    base: Path
+
+    @property
+    def circuit_path(self) -> Path:
+        """Return the circuit_path."""
+        return self.base / "circuit"
+
+    @property
+    def simulation_path(self) -> Path:
+        """Return the simulation_path."""
+        return self.base / "simulation"
+
+    @property
+    def spikes_path(self) -> Path:
+        """Return the spikes_path."""
+        return self.simulation_path / "synthetic_spikes.h5"
+
+    @property
+    def nodes_path(self) -> Path:
+        """Return the nodes_path."""
+        return self.circuit_path / "synthetic_nodes.h5"
+
+    @property
+    def node_sets_path(self) -> Path:
+        """Return the node_sets_path."""
+        return self.circuit_path / "node_sets.json"
+
+    @property
+    def circuit_config_path(self) -> Path:
+        """Return the circuit_config_path."""
+        return self.circuit_path / "circuit_config.json"
+
+    @property
+    def simulation_config_path(self) -> Path:
+        """Return the simulation_config_path."""
+        return self.simulation_path / "simulation_config.json"
+
+    @property
+    def simulation_campaign_config_path(self) -> Path:
+        """Return the simulation_campaign_config_path."""
+        return self.base / "simulation_campaign_config.json"
+
+    @property
+    def analysis_config_path(self) -> Path:
+        """Return the analysis_config_path."""
+        return self.base / "analysis_config.yaml"
+
+    def mkdirs(self):
+        """Create the directories."""
+        self.base.mkdir(exist_ok=True, parents=True)
+        self.circuit_path.mkdir(exist_ok=True, parents=True)
+        self.simulation_path.mkdir(exist_ok=True, parents=True)
+
+
+def _load_csv(path: Path, sep: str = " ", **kwargs) -> pd.DataFrame:
+    """Load and sort spikes data from csv file.
+
+    Accepted column names: timestamps or times, node_ids or ids.
+    """
+    columns = [TIMESTAMPS, NODE_IDS]
+    valid_columns = {*columns, "ids", "times"}
+    df = pd.read_csv(path, sep=sep, usecols=lambda x: x in valid_columns, **kwargs)
+    df = df.rename(columns={"ids": NODE_IDS, "times": TIMESTAMPS})
+    if missing := set(columns).difference(df.columns):
+        raise ValueError(f"Missing columns in the CSV file: {missing}")
+    df = df[columns].sort_values(columns)
+    L.info("Loaded file %s", path)
+    return df
+
+
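To illustrate what `_load_csv` does with the accepted header variants: `usecols` keeps only the recognized columns, the `ids`/`times` aliases are renamed, and rows are sorted by timestamp, then node id. A minimal pandas sketch of the same calls, with inline data instead of a file:

    import io

    import pandas as pd

    csv = "ids times extra\n2 10.0 x\n1 5.0 y\n"
    df = pd.read_csv(io.StringIO(csv), sep=" ", usecols=lambda x: x in {"ids", "times"})
    df = df.rename(columns={"ids": "node_ids", "times": "timestamps"})
    df = df[["timestamps", "node_ids"]].sort_values(["timestamps", "node_ids"])
    # "extra" is dropped; rows come out ordered (5.0, 1), (10.0, 2)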
+def _get_data_stats(df: pd.DataFrame) -> DataStats:
+    """Calculate statistics on the imported spikes."""
+    unique_ids = df[NODE_IDS].drop_duplicates()
+    stats = DataStats(
+        rows=len(df),
+        unique_ids=len(unique_ids),
+        min_id=unique_ids.min(),
+        max_id=unique_ids.max(),
+        min_timestamp=df[TIMESTAMPS].min(),
+        max_timestamp=df[TIMESTAMPS].max(),
+    )
+    L.info(
+        "CSV stats: rows=%s, unique_ids=%s, min_id=%s, max_id=%s, min_ts=%s, max_ts=%s",
+        stats.rows,
+        stats.unique_ids,
+        stats.min_id,
+        stats.max_id,
+        stats.min_timestamp,
+        stats.max_timestamp,
+    )
+    return stats
+
+
+def _write_spikes(
+    path: Path, node_population: str, timestamps: np.ndarray, node_ids: np.ndarray
+) -> None:
+    """Write a spikes file in SONATA format."""
+    L.info("Writing %s", path)
+    sorting_type = h5py.enum_dtype({"none": 0, "by_id": 1, "by_time": 2})
+    with h5py.File(path, "w") as h5:
+        root = h5.create_group("spikes")
+        pop = root.create_group(node_population)
+        pop.attrs.create("sorting", data=2, dtype=sorting_type)
+        ts_dataset = pop.create_dataset(TIMESTAMPS, data=timestamps, dtype=DTYPES[TIMESTAMPS])
+        ts_dataset.attrs.create("units", "ms")
+        pop.create_dataset(NODE_IDS, data=node_ids, dtype=DTYPES[NODE_IDS])
+
+
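The file written above follows the SONATA spike-report layout: `/spikes/<population>/timestamps` and `/spikes/<population>/node_ids`, with a `sorting` enum attribute set to `by_time`. A sketch of reading it back with h5py, assuming the default output paths and population name:

    import h5py

    with h5py.File("output/simulation/synthetic_spikes.h5") as h5:
        pop = h5["spikes"]["synthetic"]
        timestamps = pop["timestamps"][:]  # float64, unit attribute "ms"
        node_ids = pop["node_ids"][:]      # uint64, aligned with timestamps
        assert pop.attrs["sorting"] == 2   # 2 == "by_time" in the enum above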
+def _write_circuit(path: Path, node_population: str, size: int) -> None:
+    """Write a synthetic empty circuit in SONATA format."""
+    L.info("Writing %s", path)
+    voxcell = import_optional_dependency("voxcell")
+    # CellCollection uses 1-based ids, since it predates SONATA
+    nodes = pd.DataFrame(index=range(1, size + 1))
+    nodes["_"] = np.zeros(size, dtype=np.int8)
+    cc = voxcell.CellCollection.from_dataframe(nodes)
+    cc.population_name = node_population
+    cc.save(str(path))
+
+
+def _write_node_sets(path: Path, node_population: str) -> None:
+    """Write a partial node_sets_file."""
+    L.info("Writing %s", path)
+    node_sets = {
+        "empty": {
+            "population": node_population,
+            "node_id": [],
+        }
+    }
+    dump_json(path, node_sets)
+
+
+def _write_circuit_config(
+    path: Path, node_sets_path: Path, nodes_path: Path, node_population: str
+) -> None:
+    """Write a partial circuit config in SONATA format."""
+    L.info("Writing %s", path)
+    nodes_file = str(relpath(nodes_path, path.parent))
+    node_sets_file = str(relpath(node_sets_path, path.parent))
+    circuit_config = {
+        "version": "2.4",
+        "metadata": {"status": "partial"},
+        "node_sets_file": node_sets_file,
+        "networks": {
+            "nodes": [
+                {
+                    "nodes_file": nodes_file,
+                    "populations": {node_population: {}},
+                }
+            ],
+            "edges": [],
+        },
+    }
+    dump_json(path, circuit_config)
+
+
+def _write_simulation_config(
+    path: Path, circuit_config_path: Path, spikes_path: Path, tstop: float
+) -> None:
+    """Write a simulation config in SONATA format."""
+    L.info("Writing %s", path)
+    network = str(relpath(circuit_config_path, path.parent))
+    output_dir = str(relpath(spikes_path.parent, path.parent))
+    spikes_file = spikes_path.name
+    simulation_config = {
+        "version": "2.4",
+        "network": network,
+        "run": {
+            "tstop": tstop,
+            "dt": 1.0,
+            "random_seed": 0,
+        },
+        "output": {
+            "output_dir": output_dir,
+            "spikes_file": spikes_file,
+        },
+        "metadata": {
+            "note": "Synthetic simulation",
+        },
+    }
+    dump_json(path, simulation_config)
+
+
+def _write_simulation_campaign_config(
+    path: Path, circuit_config_path: Path, simulation_config_path: Path
+) -> None:
+    """Write a simulation campaign config in BlueETL format, having only one simulation."""
+    L.info("Writing %s", path)
+    circuit_config = str(resolve_path(circuit_config_path))
+    path_prefix = str(resolve_path(simulation_config_path.parents[1]))
+    simulation_path = str(relpath(simulation_config_path, path_prefix))
+    simulation_campaign_config = {
+        "format": "blueetl",
+        "version": 1,
+        "name": "synthetic",
+        "attrs": {
+            "path_prefix": path_prefix,
+            "circuit_config": circuit_config,
+            "__coupled__": "coupled",
+        },
+        "data": [
+            {"simulation_path": simulation_path},
+        ],
+    }
+    dump_yaml(path, simulation_campaign_config)
+
+
+def _write_analysis_config(path: Path, simulation_campaign_config_path: Path) -> None:
+    """Write an analysis config in BlueETL format."""
+    L.info("Writing %s", path)
+    simulation_campaign = str(relpath(simulation_campaign_config_path, path.parent))
+    analysis_config = {
+        "version": 3,
+        "simulation_campaign": simulation_campaign,
+        "output": "analysis",
+        "analysis": {
+            "spikes": {
+                "extraction": {
+                    "report": {"type": "spikes"},
+                    "neuron_classes": {
+                        "all": {},
+                    },
+                    "limit": None,
+                    "population": "synthetic",
+                    "node_set": None,
+                    "windows": {
+                        "w1": {"bounds": [0.0, 1000.0]},
+                        "w2": {"bounds": [1000.0, 2000.0]},
+                    },
+                },
+                "features": [
+                    {
+                        "type": "multi",
+                        "groupby": ["simulation_id", "circuit_id", "neuron_class", "window"],
+                        "function": (
+                            "blueetl.external.bnac.calculate_features.calculate_features_multi"
+                        ),
+                        "params": {"export_all_neurons": True},
+                    }
+                ],
+            }
+        },
+    }
+    dump_yaml(path, analysis_config)
+
+
+def main(input_file: Path, output_dir: Path, node_population: str) -> None:
+    """Read a CSV file containing the spikes, and write synthetic files to be used with BlueETL.
+
+    Expected output files:
+
+    ├── analysis_config.yaml
+    ├── circuit
+    │   ├── circuit_config.json
+    │   ├── node_sets.json
+    │   └── synthetic_nodes.h5
+    ├── simulation
+    │   ├── simulation_config.json
+    │   └── synthetic_spikes.h5
+    └── simulation_campaign_config.json
+
+    """
+    paths = OutputPaths(base=output_dir)
+    paths.mkdirs()
+    df = _load_csv(input_file)
+    stats = _get_data_stats(df)
+
+    _write_circuit(
+        paths.nodes_path,
+        node_population=node_population,
+        size=stats.max_id + 1,
+    )
+    _write_node_sets(
+        paths.node_sets_path,
+        node_population=node_population,
+    )
+    _write_circuit_config(
+        paths.circuit_config_path,
+        nodes_path=paths.nodes_path,
+        node_sets_path=paths.node_sets_path,
+        node_population=node_population,
+    )
+    _write_spikes(
+        paths.spikes_path,
+        node_population=node_population,
+        timestamps=df[TIMESTAMPS],
+        node_ids=df[NODE_IDS],
+    )
+    _write_simulation_config(
+        paths.simulation_config_path,
+        circuit_config_path=paths.circuit_config_path,
+        spikes_path=paths.spikes_path,
+        tstop=float(math.ceil(stats.max_timestamp)),
+    )
+    _write_simulation_campaign_config(
+        paths.simulation_campaign_config_path,
+        circuit_config_path=paths.circuit_config_path,
+        simulation_config_path=paths.simulation_config_path,
+    )
+    _write_analysis_config(
+        paths.analysis_config_path,
+        simulation_campaign_config_path=paths.simulation_campaign_config_path,
+    )
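The converter can also be driven programmatically: `main` writes the files listed in its docstring, and the generated campaign then loads as usual. A sketch under the same voxcell assumption, with illustrative paths:

    from pathlib import Path

    from blueetl.campaign.config import SimulationCampaign
    from blueetl.converters.convert_spikes import main

    main(Path("spikes.csv"), Path("out"), node_population="synthetic")
    campaign = SimulationCampaign.load(Path("out") / "simulation_campaign_config.json")
    assert len(campaign) == 1  # a single synthetic simulation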
diff --git a/src/blueetl/utils.py b/src/blueetl/utils.py
index 7f853c8..679341d 100644
--- a/src/blueetl/utils.py
+++ b/src/blueetl/utils.py
@@ -71,6 +71,28 @@ def dump_yaml(filepath: StrOrPath, data: Any, **kwargs) -> None:
         yaml.dump(data, stream=f, sort_keys=False, Dumper=_get_internal_yaml_dumper(), **kwargs)
 
 
+def load_json(filepath: StrOrPath, *, encoding: str = "utf-8", **kwargs) -> Any:
+    """Load from JSON file."""
+    with open(filepath, encoding=encoding) as f:
+        return json.load(f, **kwargs)
+
+
+def dump_json(
+    filepath: StrOrPath, data: Any, *, encoding: str = "utf-8", indent: int = 2, **kwargs
+) -> None:
+    """Dump to JSON file."""
+    with open(filepath, mode="w", encoding=encoding) as fp:
+        json.dump(data, fp, indent=indent, **kwargs)
+
+
+def relpath(path: StrOrPath, start: StrOrPath) -> Path:
+    """Return a relative filepath to path from the start directory.
+
+    In Python>=3.12 it would be possible to use ``Path.relative_to`` with walk_up=True.
+    """
+    return Path(os.path.relpath(path, start=start))
+
+
 def ensure_list(x: Any) -> list:
     """Always return a list from the given argument."""
     if isinstance(x, list):
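As the parametrized `test_relpath` cases at the end of this patch show, `relpath` can walk up from the start directory, which plain `PurePath.relative_to` cannot do before Python 3.12. For example:

    from blueetl.utils import relpath

    relpath("/path/to/1", start="/path/2")   # Path("../to/1")
    relpath("/path/1", start="/path/to/2")   # Path("../../1")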
diff --git a/tests/unit/apps/test_convert.py b/tests/unit/apps/test_convert.py
new file mode 100644
index 0000000..1c56370
--- /dev/null
+++ b/tests/unit/apps/test_convert.py
@@ -0,0 +1,29 @@
+from pathlib import Path
+from unittest.mock import patch
+
+from click.testing import CliRunner
+
+from blueetl.apps import convert as test_module
+
+
+@patch(test_module.__name__ + ".main")
+def test_convert_spikes_success(mock_main, tmp_path):
+    input_file = tmp_path / "spikes.csv"
+    output_dir = tmp_path / "output"
+    runner = CliRunner()
+
+    with runner.isolated_filesystem(temp_dir=tmp_path) as td:
+        input_file.write_text("ids times\n1 0.0")
+        result = runner.invoke(
+            test_module.convert_spikes,
+            [
+                str(input_file),
+                str(output_dir),
+                "--node-population",
+                "custom",
+            ],
+        )
+
+    assert result.output.strip() == "Conversion successful."
+    assert result.exit_code == 0
+    assert mock_main.call_count == 1
diff --git a/tests/unit/apps/test_main.py b/tests/unit/apps/test_main.py
index cf7c8c1..7a73dec 100644
--- a/tests/unit/apps/test_main.py
+++ b/tests/unit/apps/test_main.py
@@ -28,6 +28,7 @@ def test_help():
   run              Run the analysis.
   migrate-config   Migrate a configuration file.
   validate-config  Validate a configuration file.
+  convert-spikes   Convert spikes in CSV format.
 """
 
     runner = CliRunner()
diff --git a/tests/unit/converters/__init__.py b/tests/unit/converters/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/unit/converters/test_convert_spikes.py b/tests/unit/converters/test_convert_spikes.py
new file mode 100644
index 0000000..b545c78
--- /dev/null
+++ b/tests/unit/converters/test_convert_spikes.py
@@ -0,0 +1,49 @@
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+from bluepysnap import Simulation
+from pandas.testing import assert_frame_equal, assert_series_equal
+
+from blueetl.campaign.config import SimulationCampaign
+from blueetl.converters import convert_spikes as test_module
+
+
+def test_main(tmp_path):
+    population = "custom"
+    input_file = tmp_path / "spikes.csv"
+    output_dir = tmp_path / "output"
+    input_file.write_text("ids times\n1 10.0\n1 10.1\n2 10.1\n2 20.0\n5 5.0")
+    test_module.main(input_file, output_dir, node_population=population)
+
+    assert (output_dir / "analysis_config.yaml").is_file()
+    assert (output_dir / "circuit" / "circuit_config.json").is_file()
+    assert (output_dir / "circuit" / "node_sets.json").is_file()
+    assert (output_dir / "circuit" / "synthetic_nodes.h5").is_file()
+    assert (output_dir / "simulation" / "simulation_config.json").is_file()
+    assert (output_dir / "simulation" / "synthetic_spikes.h5").is_file()
+    assert (output_dir / "simulation_campaign_config.json").is_file()
+
+    campaign = SimulationCampaign.load(output_dir / "simulation_campaign_config.json")
+    assert len(campaign) == 1
+    simulation = Simulation(campaign[0].path)
+    assert (
+        Path(simulation._simulation_config_path).resolve()
+        == (output_dir / "simulation" / "simulation_config.json").resolve()
+    )
+    assert (
+        Path(simulation.circuit._circuit_config_path).resolve()
+        == (output_dir / "circuit" / "circuit_config.json").resolve()
+    )
+
+    expected_nodes = pd.DataFrame(
+        {"_": np.zeros(6, dtype=np.int8)},
+        index=pd.RangeIndex(6, name="node_ids"),
+    )
+    expected_spikes = pd.Series(
+        [5, 1, 1, 2, 2],
+        index=pd.Index([5.0, 10.0, 10.1, 10.1, 20.0], name="times"),
+        name="ids",
+    )
+    assert_frame_equal(simulation.circuit.nodes[population].get(), expected_nodes)
+    assert_series_equal(simulation.spikes[population].get(), expected_spikes)
diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py
index 9054637..c59cb13 100644
--- a/tests/unit/test_utils.py
+++ b/tests/unit/test_utils.py
@@ -150,6 +150,35 @@ def test_load_yaml(tmp_path):
     assert loaded_data == expected
 
 
+def test_dump_json_load_json_roundtrip(tmp_path):
+    data = {
+        "dict": {"str": "mystr", "int": 123},
+        "list_of_int": [1, 2, 3],
+        "list_of_str": ["1", "2", "3"],
+        "path": "/custom/path",
+    }
+    filepath = tmp_path / "test.json"
+
+    test_module.dump_json(filepath, data, indent=None)
+    loaded_data = test_module.load_json(filepath)
+
+    assert loaded_data == data
+
+
+@pytest.mark.parametrize(
+    "path, start, expected",
+    [
+        ("path/1", "path/2", "../1"),
+        ("/path/1", "/path/2", "../1"),
+        ("/path/to/1", "/path/2", "../to/1"),
+        ("/path/1", "/path/to/2", "../../1"),
+    ],
+)
+def test_relpath(path, start, expected):
+    result = test_module.relpath(path, start=start)
+    assert result == Path(expected)
+
+
 @pytest.mark.parametrize(
     "d, expected",
     [