Skip to content

Commit

Permalink
Various minor code cleanups (#9)
Browse files Browse the repository at this point in the history
* Update README

* Various minor code cleanups

---------

Co-authored-by: Reinder Vos de Wael <[email protected]>
  • Loading branch information
ReinderVosDeWael and ReinderVosDeWael authored Aug 3, 2023
1 parent 31c2012 commit 03e6db0
Show file tree
Hide file tree
Showing 10 changed files with 69 additions and 77 deletions.
42 changes: 31 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,31 +5,51 @@
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
[![L-GPL License](https://img.shields.io/badge/license-L--GPL-blue.svg)](LICENSE)

This is a command line interface (CLI) running BrainSpace on BIDS-compliant datasets. Gradients are computed for volumetric files in NIFTI format, or surface files in GIFTI format. For more details on BrainSpace, see the [BrainSpace documentation](https://brainspace.readthedocs.io/en/latest/).
This is a command line interface (CLI) for running BrainSpace on BIDS-compliant datasets. Gradients are computed for volumetric files in NIFTI format, or surface files in GIFTI format. For more details on BrainSpace, see the [BrainSpace documentation](https://brainspace.readthedocs.io/en/latest/).

## Installation

The recommended approaches for installing ba-timeseries-gradients is through Docker or PyPi. To build it as a Docker image, download the repository and run the following command from the repository root:
For local installation, the recommended approach is through Poetry. To install through Poetry, run the following commands:

```bash
docker build -t ba_timeseries_gradients .
pip install poetry
poetry install
```

To install through PyPi, run the following command:
## Usage

The basic usage of the ba_timeseries_gradients CLI is as follows for Poetry-based installations:

```bash
pip install ba_timeseries_gradients
ba_timeseries_gradients [OPTIONS] BIDS_DIR OUTPUT_DIR ANALYSIS_LEVEL
```

## Usage
The `BIDS_DIR` is the path to the BIDS directory containing the dataset to be analyzed. The `OUTPUT_DIR` is the path to the directory where the output will be saved. The `ANALYSIS_LEVEL` is the level of analysis to be performed, which can currently only be `group`.

The basic usage of the ba_timeseries_gradients CLI is as follows, depending on whether you've installed through PyPi or Docker:
For a full list of options, see:

```bash
docker run --volume LOCAL_BIDS_DIR:BIDS_DIR --volume LOCAL_OUTPUT_DIR:OUTPUT_DIR ba_timeseries_gradients [OPTIONS] BIDS_DIR OUTPUT_DIR ANALYSIS_LEVEL
ba_timeseries_gradients [OPTIONS] BIDS_DIR OUTPUT_DIR ANALYSIS_LEVEL
ba_timeseries_gradients --help
```

The `BIDS_DIR` is the path to the BIDS directory containing the dataset to be analyzed. The `OUTPUT_DIR` is the path to the directory where the output will be saved. The `ANALYSIS_LEVEL` is the level of analysis to be performed, which can currently only be `group`.
It is highly recommended to include options to filter the dataset for specific files. See the BIDS arguments section in the help for more details.

For a complete list of options see `ba_timeseries_gradients -h`. It is highly recommended to include options to filter the dataset for specific files. See the BIDS arguments section in the help for more details.
You can also run the CLI through Docker. To do so, run the following command:

```bash
docker run \
--volume LOCAL_BIDS_DIR:BIDS_DIR \
--volume LOCAL_OUTPUT_DIR:OUTPUT_DIR \
ghcr.io/cmi-dair/ba-timeseries-gradients:main \
[OPTIONS] BIDS_DIR OUTPUT_DIR ANALYSIS_LEVEL
```

Similarly, the CLI can also be run in Singularity as follows:

```bash
singularity run \
--bind LOCAL_BIDS_DIR:BIDS_DIR \
--bind LOCAL_OUTPUT_DIR:OUTPUT_DIR \
docker://ghcr.io/cmi-dair/ba-timeseries-gradients:main \
[OPTIONS] BIDS_DIR OUTPUT_DIR ANALYSIS_LEVEL
```
23 changes: 16 additions & 7 deletions src/ba_timeseries_gradients/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ def main():

logger.debug("Getting input files...")
files = _get_bids_files(args)
output_file = args.output_dir / ("gradients." + args.output_format)

if args.dry_run:
logger.info("Detected input files:\n%s", "\n".join(files))
logger.info("Output file: %s", output_file)
return

logger.debug("Checking input validity.")
_raise_invalid_input(args, files)
Expand All @@ -39,8 +45,8 @@ def main():
sparsity=args.sparsity,
)

logger.info("Saving gradient map to %s...", args.output_dir)
utils.save(output_gradients, lambdas, args)
logger.info("Saving gradient map to %s.", output_file)
utils.save(output_gradients, lambdas, output_file)


def _get_parser() -> argparse.ArgumentParser:
Expand Down Expand Up @@ -202,10 +208,16 @@ def _get_parser() -> argparse.ArgumentParser:
other_group.add_argument(
"--output_format",
required=False,
default="hdf5",
default="h5",
type=str,
help="Output file format",
choices=["hdf5", "json"],
choices=["h5", "json"],
)
other_group.add_argument(
"--dry-run",
required=False,
action="store_true",
help="Do not run the pipeline, only show what input files would be used. Note that dry run is logged at the logging.INFO level.",
)

return parser
Expand Down Expand Up @@ -238,9 +250,6 @@ def _raise_invalid_input(args: argparse.Namespace, files: list[str]) -> None:
"Must provide a parcellation if input files are volume files."
)

if args.output_format not in ("hdf5", "json"):
raise exceptions.InputError("Output format must be one of: 'hdf5', or 'json'.")


def _get_bids_files(args: argparse.Namespace) -> list[str]:
"""
Expand Down
6 changes: 1 addition & 5 deletions src/ba_timeseries_gradients/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ class BaseLoggingError(Exception):

def __init__(self, message: str):
self.message = message
super().__init__(self.message)
logger.error(self.message)
super().__init__(self.message)


class InputError(BaseLoggingError):
Expand All @@ -23,7 +23,3 @@ class InputError(BaseLoggingError):

class InternalError(BaseLoggingError):
"""Exception raised when an internal error occurs. These should never happen."""


class BrainSpaceError(BaseLoggingError):
"""Exception raised when a BrainSpace error occurs."""
22 changes: 9 additions & 13 deletions src/ba_timeseries_gradients/gradients.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,19 +51,14 @@ def compute_gradients(
connectivity_matrix = _get_connectivity_matrix(files, parcellation_file)

logger.info("Computing gradients...")
try:
gradient_map = gradient.GradientMaps(
n_components=n_components,
kernel=kernel,
approach=approach,
alignment=None,
random_state=0,
)
gradient_map.fit(connectivity_matrix, sparsity=sparsity)
except Exception as exc_info:
raise exceptions.BrainSpaceError(
f"An error occurred in BrainSpace: {exc_info}"
) from exc_info
gradient_map = gradient.GradientMaps(
n_components=n_components,
kernel=kernel,
approach=approach,
alignment=None,
random_state=0,
)
gradient_map.fit(connectivity_matrix, sparsity=sparsity)

return gradient_map.gradients_, gradient_map.lambdas_

Expand Down Expand Up @@ -92,6 +87,7 @@ def _get_connectivity_matrix(
for index, filename in enumerate(files):
logger.debug("Processing file %s of %s...", index + 1, len(files))
timeseries = _get_nifti_gifti_data(nib.load(filename)).squeeze()

timeseries_permuted = np.swapaxes(timeseries, 0, -1)
timeseries_2d = timeseries_permuted.reshape(timeseries_permuted.shape[0], -1)

Expand Down
17 changes: 8 additions & 9 deletions src/ba_timeseries_gradients/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
""" Utility functions for the BrainSpace runner. """
import argparse
import json
import pathlib

Expand All @@ -10,18 +9,18 @@


def save(
output_gradients: np.ndarray, lambdas: np.ndarray, args: argparse.Namespace
output_gradients: np.ndarray, lambdas: np.ndarray, filename: str | pathlib.Path
) -> None:
"""
Saves a numpy array to a file with the given filename.
Args:
output_gradients: The numpy array to save.
lambdas: The lambdas to save.
filename: The filename to save the array to.
"""
extension = "h5" if args.output_format == "hdf5" else args.output_format
filename = args.output_dir / f"gradients.{extension}"
filename = pathlib.Path(filename)

if filename.suffix == ".h5":
save_hdf5(output_gradients, lambdas, filename)
Expand All @@ -42,9 +41,9 @@ def save_hdf5(
filename: The filename to save the array to.
"""
with h5py.File(filename, "w") as fb:
fb.create_dataset("gradients", data=output_gradients)
fb.create_dataset("lambdas", data=lambdas)
with h5py.File(filename, "w") as h5_file:
h5_file.create_dataset("gradients", data=output_gradients)
h5_file.create_dataset("lambdas", data=lambdas)


def save_json(
Expand All @@ -58,11 +57,11 @@ def save_json(
filename: The filename to save the array to.
"""
with open(filename, "w", encoding="utf-8") as fb:
with open(filename, "w", encoding="utf-8") as file_buffer:
json.dump(
{
"gradients": output_gradients.tolist(),
"lambdas": lambdas.tolist(),
},
fb,
file_buffer,
)
3 changes: 2 additions & 1 deletion tests/integration/test_integration_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,13 @@ class MockParser:
extension: str = ".nii.gz"
dimensionality_reduction: str = "dm"
parcellation: str | None = None
output_format = "hdf5"
output_format = "h5"
kernel: str = "cosine"
sparsity: float = 0.1
n_components: int = 10
force: bool = False
verbose: int = 0
dry_run: bool = False

def parse_args(self, *args):
"""Return self."""
Expand Down
10 changes: 0 additions & 10 deletions tests/unit/test_unit_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,3 @@ def test_raise_invalid_input_no_parcellation(mock_args, mock_files) -> None:
cli._raise_invalid_input(mock_args, mock_files)

assert "Must provide a parcellation" in str(exc_info.value)


def test_raise_invalid_input_invalid_output_format(mock_args, mock_files) -> None:
"""Test _raise_invalid_input when an invalid output format is provided."""
mock_args.output_format = "invalid_format"

with pytest.raises(exceptions.InputError) as exc_info:
cli._raise_invalid_input(mock_args, mock_files)

assert "Output format must be" in str(exc_info.value)
5 changes: 1 addition & 4 deletions tests/unit/test_unit_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,12 @@
[
exceptions.BaseLoggingError,
exceptions.InputError,
exceptions.BrainSpaceError,
exceptions.InternalError,
],
)
def test_logging_error(
mocker: pytest_mock.MockFixture,
exception_type: exceptions.BaseLoggingError
| exceptions.InputError
| exceptions.BrainSpaceError,
exception_type: exceptions.BaseLoggingError | exceptions.InputError,
):
"""
Test that a BaseLoggingError is raised with the correct message and that the error is logged.
Expand Down
16 changes: 0 additions & 16 deletions tests/unit/test_unit_gradients.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,6 @@ def test_compute_gradients(mocker: pytest_mock.MockFixture) -> None:
assert np.allclose(actual_lambdas, 0)


def test_compute_gradients_brainspace_error(mocker: pytest_mock.MockFixture) -> None:
"""Test that the compute_gradients function raises an error when
BrainSpace raises an error."""
mocker.patch(
"ba_timeseries_gradients.gradients._get_connectivity_matrix",
return_value=np.ones((3, 3)),
)
mocker.patch(
"brainspace.gradient.GradientMaps",
side_effect=Exception("Error"),
)

with pytest.raises(exceptions.BrainSpaceError):
gradients.compute_gradients(files=[])


def test_connevtivity_matrix_from_2d_success(mocker: pytest_mock.MockerFixture) -> None:
"""Test that the connectivity matrix is computed correctly from a 2D
timeseries."""
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_unit_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,4 @@ def test_save_internal_error() -> None:
parser = MockArgparse(pathlib.Path("."), "txt")

with pytest.raises(exceptions.InternalError):
utils.save(np.array([1, 2, 3]), np.array([2, 3, 4]), parser) # type: ignore[arg-type]
utils.save(np.array([1, 2, 3]), np.array([2, 3, 4]), "wrong.extension") # type: ignore[arg-type]

0 comments on commit 03e6db0

Please sign in to comment.