Skip to content

Commit

Permalink
add additional support for transformation parsing
Browse files Browse the repository at this point in the history
Add tests and add support for:
Model jsons with transformations
Transformation jsons
Python in memory representations of the above
  • Loading branch information
leej3 committed May 14, 2021
1 parent 02cd6fc commit a9ae623
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 6 deletions.
1 change: 0 additions & 1 deletion bids/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
"reports",
"utils",
"variables",
"statsmodels_design_synthesizer",
]

due.cite(Doi("10.1038/sdata.2016.44"),
Expand Down
10 changes: 5 additions & 5 deletions bids/variables/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,13 +584,13 @@ def parse_transforms(transforms_in, validate=True,level="run"):
# transformations has been obtained. This will most likely be the case since
# transformations at higher levels will no longer be required when the new
# "flow" approach is used.
if "nodes" in transforms_raw:
nodes_key = "nodes"
elif "steps" in transforms_raw:
nodes_key = "steps"
if "transformations" in transforms_raw:
transforms = transforms_raw["transformations"]
elif any(k in transforms_raw for k in ["nodes","steps"]):
nodes_key = "nodes" if "nodes" in transforms_raw else "steps"
transforms = transforms_raw[nodes_key][0]["transformations"]
else:
raise ValueError("Cannot find a key for nodes in the json input representing the model")
transforms = transforms_raw[nodes_key][0]["transformations"]
return transforms


35 changes: 35 additions & 0 deletions bids/variables/tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,21 @@
from bids.variables import (SparseRunVariable, SimpleVariable,
DenseRunVariable, load_variables)
from bids.variables.entities import Node, RunNode, NodeIndex
from bids.variables.io import parse_transforms
from unittest.mock import patch
import pytest
from os.path import join
from pathlib import Path
import tempfile
import json
from bids.tests import get_test_data_path
from bids.config import set_option, get_option

EXAMPLE_TRANSFORM = {
"Transformations":[{"Name":"example_trans","Inputs":["col_a","col_b"]}]
}
TRANSFORMS_JSON = join(tempfile.tempdir,"tranformations.json")
Path(TRANSFORMS_JSON).write_text(json.dumps(EXAMPLE_TRANSFORM))

@pytest.fixture
def layout1():
Expand Down Expand Up @@ -103,3 +112,29 @@ def test_load_synthetic_dataset(synthetic):
subs = index.get_nodes('subject')
assert len(subs) == 5
assert set(subs[0].variables.keys()) == {'systolic_blood_pressure'}

@pytest.mark.parametrize(
"test_case,transform_input,expected_names",
[
("raw transform json",
EXAMPLE_TRANSFORM,
["example_trans"]
),
("transform json file",
TRANSFORMS_JSON,
["example_trans"]
),
("raw model json",
{"Nodes": [EXAMPLE_TRANSFORM]},
["example_trans"]
),
("model json file",
str(Path(get_test_data_path()) / "ds005/models/ds-005_type-mfx_model.json"),
["Scale"]
),
]
)
def test_parse_transforms(test_case,transform_input,expected_names):
result = parse_transforms(transform_input)
transformation_names = [x['name'] for x in result]
assert expected_names == transformation_names

0 comments on commit a9ae623

Please sign in to comment.